Skip to content

Commit

Permalink
Refactor Python string utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
colesbury authored and apaszke committed Apr 28, 2017
1 parent 775481e commit 4c1cdb6
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 117 deletions.
53 changes: 18 additions & 35 deletions torch/csrc/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <libshm.h>
#include <TH/TH.h>

#include "torch/csrc/utils/python_strings.h"

#ifdef WITH_CUDNN
#include "cudnn/Module.h"
#endif
Expand Down Expand Up @@ -73,13 +75,9 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg)

THPObjectPtr module_name = PyObject_GetAttrString(obj, "__module__");
if (!module_name) return NULL;
#if PY_MAJOR_VERSION == 2
THPUtils_assert(PyString_Check(module_name.get()), "expected __module__ to be a string");
std::string name = PyString_AS_STRING(module_name.get());
#else
THPUtils_assert(PyUnicode_Check(module_name.get()), "expected __module__ to be a string");
std::string name = PyUnicode_AsUTF8(module_name.get());
#endif
THPUtils_assert(THPUtils_checkString(module_name.get()),
"expected __module__ to be a string");
std::string name = THPUtils_unpackString(module_name.get());
names.push_back(name + "." + type->tp_name);
type->tp_name = names.back().c_str();
}
Expand All @@ -89,15 +87,13 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg)
static bool THPModule_assignStateless(PyObject *self)
{
#define INIT_STATELESS(type) \
stateless = PyObject_Call((PyObject*)&TH_CONCAT_2(type, TensorStatelessType), arg, NULL); \
stateless = PyObject_CallFunctionObjArgs((PyObject*)&TH_CONCAT_2(type, TensorStatelessType), NULL); \
if (!stateless) { \
THPUtils_setError("stateless method initialization error"); \
return false; \
} \
if (PyObject_SetAttrString(TH_CONCAT_3(THP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \
THPUtils_setError("stateless method initialization error (on assignment)");\
return false; \
}
PyObject *arg = PyTuple_New(0);
PyObject *stateless;
INIT_STATELESS(Double);
INIT_STATELESS(Float);
Expand All @@ -107,23 +103,25 @@ static bool THPModule_assignStateless(PyObject *self)
INIT_STATELESS(Short);
INIT_STATELESS(Char);
INIT_STATELESS(Byte);
Py_DECREF(arg);
return true;
#undef INIT_STATELESS
}
//
// Callback for python part. Used for additional initialization of python classes
static PyObject * THPModule_initExtension(PyObject *self, PyObject *shm_manager_path)
{
if (!THPUtils_checkBytes(shm_manager_path)) {
HANDLE_TH_ERRORS
if (!THPUtils_checkString(shm_manager_path)) {
THPUtils_setError("initialization error - expected bytes/string object as shm_manager_path!");
return NULL;
}
libshm_init(THPUtils_bytesAsString(shm_manager_path));
std::string path = THPUtils_unpackString(shm_manager_path);
libshm_init(path.c_str());
if (!THPModule_loadClasses(self)) return NULL;
if (!THPModule_assignStateless(self)) return NULL;
if (!THPAutograd_initFunctions(self)) return NULL;
return PyBool_FromLong(true);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

static PyObject * THPModule_getNumThreads(PyObject *module)
Expand Down Expand Up @@ -429,22 +427,6 @@ PyObject *THPModule_safeCall(PyObject *_unused, PyObject *args, PyObject *kwargs
return result;
}

static std::string parseString(PyObject *obj)
{
if (PyBytes_Check(obj)) {
return std::string(PyBytes_AS_STRING(obj));
#if PY_MAJOR_VERSION == 3
} else if (PyUnicode_Check(obj)) {
return std::string(PyUnicode_AsUTF8(obj));
#else
} else if (PyUnicode_Check(obj)) {
THPObjectPtr utf8 = PyUnicode_AsUTF8String(obj);
return std::string(PyBytes_AS_STRING(utf8.get()));
#endif
}
return "<invalid string>";
}

PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
{
// adds a __doc__ string to a function, similar to numpy's arr_add_docstring
Expand All @@ -455,8 +437,11 @@ PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
return NULL;
}

all_docs.push_back(parseString(doc_obj));
const char* doc_str = all_docs.back().c_str();
const char* doc_str = "<invalid string>";
if (THPUtils_checkString(doc_obj)) {
all_docs.push_back(THPUtils_unpackString(doc_obj));
doc_str = all_docs.back().c_str();
}

if (Py_TYPE(obj) == &PyCFunction_Type) {
PyCFunctionObject* f = (PyCFunctionObject *)obj;
Expand Down Expand Up @@ -499,7 +484,6 @@ extern PyObject * THCPModule_seedAll(PyObject *_unused);
extern PyObject * THCPModule_initialSeed(PyObject *_unused);
extern PyObject * THCPModule_cudaHostAllocator(PyObject *_unused);
extern PyObject * THCPModule_cudaSynchronize(PyObject *_unused);
extern PyObject * THCPModule_getLibPath(PyObject *_unused);
extern PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles);
extern PyObject * THCPModule_cudaLockMutex(PyObject *module);
extern PyObject * THCPModule_cudaUnlockMutex(PyObject *module);
Expand Down Expand Up @@ -532,7 +516,6 @@ static PyMethodDef TorchMethods[] = {
{"_cuda_initialSeed", (PyCFunction)THCPModule_initialSeed, METH_NOARGS, NULL},
{"_cuda_cudaHostAllocator", (PyCFunction)THCPModule_cudaHostAllocator, METH_NOARGS, NULL},
{"_cuda_synchronize", (PyCFunction)THCPModule_cudaSynchronize, METH_NOARGS, NULL},
{"_cuda_getLibPath", (PyCFunction)THCPModule_getLibPath, METH_NOARGS, NULL},
{"_cuda_sleep", (PyCFunction)THCPModule_cudaSleep, METH_O, NULL},
{"_cuda_sparse_init", (PyCFunction)THCSPModule_initExtension, METH_NOARGS, NULL},
{"_cuda_lock_mutex", (PyCFunction)THCPModule_cudaLockMutex, METH_NOARGS, NULL},
Expand Down
7 changes: 2 additions & 5 deletions torch/csrc/Size.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "Size.h"

#include <string>
#include "torch/csrc/utils/python_strings.h"
#include "THP.h"

PyObject* THPSizeClass = NULL;
Expand Down Expand Up @@ -47,11 +48,7 @@ static PyObject * THPSize_repr(THPSize *self)
repr += std::to_string(PyLong_AsLong(PyTuple_GET_ITEM(self, i)));
}
repr += "])";
#if PY_MAJOR_VERSION == 2
return PyString_FromString(repr.c_str());
#else
return PyUnicode_FromString(repr.c_str());
#endif
return THPUtils_packString(repr);
}

extern PyTypeObject THPSizeType;
Expand Down
12 changes: 3 additions & 9 deletions torch/csrc/autograd/python_hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/Exceptions.h"
#include <THPP/THPP.h>

Expand Down Expand Up @@ -184,15 +185,8 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject

static std::string hook_name(PyObject* hook) {
THPObjectPtr name = PyObject_GetAttrString(hook, "__name__");
#if PY_MAJOR_VERSION == 2
if (name && PyString_Check(name.get())) {
return std::string(PyString_AS_STRING(name.get()));
if (name && THPUtils_checkString(name.get())) {
return THPUtils_unpackString(name.get());
}
#else
if (name && PyUnicode_Check(name.get())) {
THPObjectPtr tmp = PyUnicode_AsASCIIString(name.get());
return std::string(PyBytes_AS_STRING(tmp.get()));
}
#endif
return "<unknown>";
}
13 changes: 0 additions & 13 deletions torch/csrc/cuda/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,19 +259,6 @@ PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
Py_RETURN_NONE;
}

PyObject * THCPModule_getLibPath(PyObject *_unused)
{
#define _STR(x) #x
#define STR(x) _STR(x)
#if PY_MAJOR_VERSION == 2
return PyString_FromString(STR(CUDA_LIB_PATH));
#else
return PyUnicode_FromString(STR(CUDA_LIB_PATH));
#endif
#undef STR
#undef _STR
}

////////////////////////////////////////////////////////////////////////////////
// Cuda module initialization
////////////////////////////////////////////////////////////////////////////////
Expand Down
36 changes: 10 additions & 26 deletions torch/csrc/distributed/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,43 +42,27 @@ static bool THDPModule_loadClasses(PyObject *module_dict)
static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
static std::unordered_map<PyObject*, THDGroup> obj2group;

static THPObjectPtr _ensureBytes(PyObject *obj)
{
#if PY_MAJOR_VERSION == 2
if (PyString_Check(obj)) {
#elif PY_MAJOR_VERSION == 3
if (PyBytes_Check(obj)) {
#endif
Py_INCREF(obj);
return obj;
}
if (PyUnicode_Check(obj)) {
return PyUnicode_AsASCIIString(obj);
}
return NULL;
}

PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *_backend)
PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *backend)
{
HANDLE_TH_ERRORS
THPObjectPtr backend_bytes = _ensureBytes(_backend);
THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
"object, but got %s", THPUtils_typename(_backend));
char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
THPUtils_assert(THPUtils_checkString(backend),
"backend argument has to be a string/bytes object, but got %s",
THPUtils_typename(backend));
std::string backend_name = THPUtils_unpackString(backend);
THDChannelType channel_type = name2channel_type.at(backend_name);
THPUtils_assert(THDProcessGroupInit(channel_type), "failed to initialize "
"distributed library (THD)");
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *_backend)
PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *backend)
{
HANDLE_TH_ERRORS
THPObjectPtr backend_bytes = _ensureBytes(_backend);
THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
"object, but got %s", THPUtils_typename(_backend));
char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
THPUtils_assert(THPUtils_checkString(backend),
"backend argument has to be a string/bytes object, but got %s",
THPUtils_typename(backend));
std::string backend_name = THPUtils_unpackString(backend);
THDChannelType channel_type = name2channel_type.at(backend_name);
THPUtils_assert(THDMasterWorkerInit(channel_type), "failed to initialize "
"distributed library (THD)");
Expand Down
18 changes: 10 additions & 8 deletions torch/csrc/generic/StorageSharing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self)
ctx = (libshm_context*)storage->allocatorContext;
}

THPObjectPtr manager_handle = THPUtils_bytesFromString(ctx->manager_handle);
THPObjectPtr manager_handle = PyBytes_FromString(ctx->manager_handle);
if (!manager_handle) return NULL;
THPObjectPtr storage_handle =
THPUtils_bytesFromString(THMapAllocatorContext_filename(ctx->th_context));
PyBytes_FromString(THMapAllocatorContext_filename(ctx->th_context));
if (!storage_handle) return NULL;
THPObjectPtr size = PyLong_FromLong(storage->size);
if (!size) return NULL;
Expand All @@ -124,20 +124,21 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self)
static PyObject * THPStorage_(newSharedFilename)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
THPUtils_assert(PyTuple_GET_SIZE(args) == 3, "tuple of 3 items expected");
PyObject *_manager_handle = PyTuple_GET_ITEM(args, 0);
PyObject *_object_handle = PyTuple_GET_ITEM(args, 1);
PyObject *_size = PyTuple_GET_ITEM(args, 2);
if (!THPUtils_checkBytes(_manager_handle) || !THPUtils_checkBytes(_object_handle) || !THPUtils_checkLong(_size)) {
if (!PyBytes_Check(_manager_handle) || !PyBytes_Check(_object_handle) || !THPUtils_checkLong(_size)) {
THPUtils_invalidArguments(args, NULL, "_new_shared in file system mode", 1,
"a handle (string/bytes) and storage size (int)");
return NULL;
}
const char *manager_handle = THPUtils_bytesAsString(_manager_handle);
const char *object_handle = THPUtils_bytesAsString(_object_handle);
const char *manager_handle = PyBytes_AS_STRING(_manager_handle);
const char *object_handle = PyBytes_AS_STRING(_object_handle);
long size = THPUtils_unpackLong(_size);

libshm_context *ctx = libshm_context_new(manager_handle, object_handle,
TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_NOCREATE);
int flags = TH_ALLOCATOR_MAPPED_SHAREDMEM |
TH_ALLOCATOR_MAPPED_NOCREATE;
libshm_context *ctx = libshm_context_new(manager_handle, object_handle, flags);
return THPStorage_(New)(THStorage_(newWithAllocator)(size,
&THManagedSharedAllocator, (void*)ctx));
END_HANDLE_TH_ERRORS
Expand Down Expand Up @@ -199,6 +200,7 @@ static PyObject * THPStorage_(shareFd)(THPStorage *self)
static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
PyObject *_tmp_fd = PyTuple_GET_ITEM(args, 0);
PyObject *_size = PyTuple_GET_ITEM(args, 1);
if (!THPUtils_checkLong(_tmp_fd) || !THPUtils_checkLong(_size)) {
Expand Down
12 changes: 2 additions & 10 deletions torch/csrc/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <algorithm>
#include <unordered_map>
#include "THP.h"
#include "torch/csrc/utils/python_strings.h"

#include "generic/utils.cpp"
#include <TH/THGenerateAllTypes.h>
Expand Down Expand Up @@ -457,15 +458,6 @@ std::vector<std::string> _tryMatchKwargs(const Option& option,
return unmatched;
}

std::string _parseDictKey(PyObject *key_str) {
#if PY_MAJOR_VERSION == 3
THPObjectPtr ascii = PyUnicode_AsASCIIString(key_str);
return std::string(PyBytes_AS_STRING(ascii.get()));
#else
return std::string(PyString_AS_STRING(key_str));
#endif
}

void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs,
const char *function_name, size_t num_options, ...) {
std::vector<std::string> option_strings;
Expand Down Expand Up @@ -493,7 +485,7 @@ void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs,
Py_ssize_t pos = 0;

while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
kwargs.emplace(_parseDictKey(key), value);
kwargs.emplace(THPUtils_unpackString(key), value);
}
}

Expand Down
11 changes: 0 additions & 11 deletions torch/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@
(throw std::runtime_error("Could not unpack long"), 0))
#endif


#if PY_MAJOR_VERSION == 2
#define THPUtils_bytesFromString(c_string) PyString_FromString(c_string)
#define THPUtils_checkBytes(obj) PyString_Check(obj)
#define THPUtils_bytesAsString(obj) PyString_AS_STRING(obj)
#else
#define THPUtils_bytesFromString(c_string) PyBytes_FromString(c_string)
#define THPUtils_checkBytes(obj) PyBytes_Check(obj)
#define THPUtils_bytesAsString(obj) PyBytes_AS_STRING(obj)
#endif

#if PY_MAJOR_VERSION == 2
#define THPUtils_checkReal_FLOAT(object) \
(PyFloat_Check(object) || PyLong_Check(object) || PyInt_Check(object))
Expand Down
Loading

0 comments on commit 4c1cdb6

Please sign in to comment.