blob: 51a3df7b87f2c3343a841aa5c8507385a7ff4710 [file] [log] [blame]
#include <Python.h>
#include <sys/types.h>
#ifndef _MSC_VER
#include <sys/socket.h>
#endif
#include <stdbool.h>
#include <deque>
#include <unordered_map>
#include <libshm.h>
#include <TH/TH.h>
#include <ATen/ATen.h>
#include <ATen/dlpack.h>
#include <ATen/DLConvertor.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/autograd/generated/python_nn_functions.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/utils/tensor_numpy.h"
#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/jit/init.h"
#include "torch/csrc/jit/python_ir.h"
#ifdef WITH_CUDNN
#include "cudnn.h"
#endif
#define WITH_NUMPY_IMPORT_ARRAY
#include "THP.h"
#include "ModuleSparse.cpp"
#include "DataLoader.cpp"
namespace py = pybind11;
// Extension-module object. NOTE(review): assigned outside this view (presumably in module init) — confirm.
PyObject* module;
// Value of torch._tensor_classes, fetched by THPModule_loadClasses and used by
// THPModule_isTensor for PySet_Contains membership tests.
PyObject* tensor_classes;
// Fallback dispatch target for stateless functions when no argument is a
// tensor (findTensor, nonzero, cat); set by THPModule_setDefaultTensorType.
PyObject *THPDefaultTensorClass = NULL;
// Default random generator. NOTE(review): initialized outside this view — confirm.
THPGenerator *THPDefaultGenerator = NULL;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Imports the Python-level `torch` package, caches `torch._tensor_classes`
// in the global `tensor_classes`, and runs the post-init hook of every CPU
// tensor and storage type.
//
// Returns true on success. On failure a Python error is set and false is returned.
static bool THPModule_loadClasses(PyObject *self)
{
#define ASSERT_NOT_NULL(ptr) if (!(ptr)) { THPUtils_setError("couldn't load classes"); return false; }
  // PyImport_ImportModule returns a new reference. Keeping it in an owning
  // pointer releases it on every return path instead of leaking it.
  THPObjectPtr torch_module_ptr(PyImport_ImportModule("torch"));
  PyObject *torch_module = torch_module_ptr.get();
  if (!torch_module) {
    THPUtils_setError("class loader couldn't access torch module");
    return false;
  }
  // The global keeps the new reference returned by PyObject_GetAttrString.
  ASSERT_NOT_NULL(tensor_classes = PyObject_GetAttrString(torch_module, "_tensor_classes"));
  if (!THPDoubleTensor_postInit(torch_module)) return false;
  if (!THPFloatTensor_postInit(torch_module)) return false;
  if (!THPHalfTensor_postInit(torch_module)) return false;
  if (!THPLongTensor_postInit(torch_module)) return false;
  if (!THPIntTensor_postInit(torch_module)) return false;
  if (!THPShortTensor_postInit(torch_module)) return false;
  if (!THPCharTensor_postInit(torch_module)) return false;
  if (!THPByteTensor_postInit(torch_module)) return false;
  // Any results of the storage post-init hooks are not checked here.
  THPDoubleStorage_postInit(torch_module);
  THPFloatStorage_postInit(torch_module);
  THPHalfStorage_postInit(torch_module);
  THPLongStorage_postInit(torch_module);
  THPIntStorage_postInit(torch_module);
  THPShortStorage_postInit(torch_module);
  THPCharStorage_postInit(torch_module);
  THPByteStorage_postInit(torch_module);
  return true;
#undef ASSERT_NOT_NULL
}
// Rewrites `tp_name` of every type object in the sequence `arg` to
// "<__module__>.<original tp_name>".
//
// tp_name is a plain `const char*` that the type keeps pointing at, so the
// backing strings live in a function-local static container for the rest of
// the process. A std::deque is used because push_back never relocates its
// existing elements, so c_str() pointers from earlier calls stay valid.
// With the previous std::vector, growing the buffer moved the strings, and
// names short enough for small-string storage were left dangling.
//
// Returns None, or NULL with a Python error set. Types processed before a
// failing element keep their new names.
static PyObject * THPModule_initNames(PyObject *self, PyObject *arg)
{
  static std::deque<std::string> names;
  THPObjectPtr types(PySequence_Fast(arg, "expected a sequence"));
  if (!types) return NULL;
  Py_ssize_t num_classes = PySequence_Fast_GET_SIZE(types.get());
  for (Py_ssize_t i = 0; i < num_classes; i++) {
    PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
    THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject");
    PyTypeObject* type = (PyTypeObject*)obj;
    THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
    if (!module_name) return NULL;
    THPUtils_assert(THPUtils_checkString(module_name.get()),
        "expected __module__ to be a string");
    std::string name = THPUtils_unpackString(module_name.get());
    names.push_back(name + "." + type->tp_name);
    type->tp_name = names.back().c_str();
  }
  Py_RETURN_NONE;
}
// For each CPU tensor type, instantiates its <Type>TensorStatelessType and
// stores the instance on the tensor class under THP_STATELESS_ATTRIBUTE_NAME.
//
// Returns true on success, or false with a Python error set.
static bool THPModule_assignStateless(PyObject *self)
{
  // PyObject_CallFunctionObjArgs returns a new reference, and
  // PyObject_SetAttrString takes its own. The local reference is held in a
  // THPObjectPtr so it is released at the end of each expansion; previously
  // it leaked for every type.
#define INIT_STATELESS(type) \
  { \
    THPObjectPtr stateless(PyObject_CallFunctionObjArgs( \
        (PyObject*)&TH_CONCAT_2(type, TensorStatelessType), NULL)); \
    if (!stateless) { \
      return false; \
    } \
    if (PyObject_SetAttrString(TH_CONCAT_3(THP,type,TensorClass), \
          THP_STATELESS_ATTRIBUTE_NAME, stateless.get()) == -1) { \
      return false; \
    } \
  }
  INIT_STATELESS(Double);
  INIT_STATELESS(Float);
  INIT_STATELESS(Half);
  INIT_STATELESS(Long);
  INIT_STATELESS(Int);
  INIT_STATELESS(Short);
  INIT_STATELESS(Char);
  INIT_STATELESS(Byte);
  return true;
#undef INIT_STATELESS
}
// torch._C._initExtension: a callback the Python side uses for additional
// initialization of Python classes.
//
// `shm_manager_path` must be a bytes/str object. It is passed to libshm_init,
// and then the tensor classes, the stateless methods and the autograd
// functions are initialized in that order, stopping at the first failure.
// Returns None, or NULL with a Python error set.
static PyObject * THPModule_initExtension(PyObject *self, PyObject *shm_manager_path)
{
  HANDLE_TH_ERRORS
  if (THPUtils_checkString(shm_manager_path)) {
    const std::string manager_path = THPUtils_unpackString(shm_manager_path);
    libshm_init(manager_path.c_str());
  } else {
    THPUtils_setError("initialization error - expected bytes/string object as shm_manager_path!");
    return NULL;
  }
  const bool initialized = THPModule_loadClasses(self) &&
                           THPModule_assignStateless(self) &&
                           THPAutograd_initFunctions(self);
  if (!initialized) return NULL;
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// torch.get_num_threads: returns THGetNumThreads() as a Python int.
static PyObject * THPModule_getNumThreads(PyObject *module)
{
  const long thread_count = THGetNumThreads();
  return PyLong_FromLong(thread_count);
}
// torch.set_num_threads: forwards an integer argument to THSetNumThreads.
// A non-int argument fails the THPUtils_assert check, which reports an error.
static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg)
{
  THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, "
          "but got %s", THPUtils_typename(arg));
  const int requested_threads = static_cast<int>(THPUtils_unpackLong(arg));
  THSetNumThreads(requested_threads);
  Py_RETURN_NONE;
}
// Returns whether the type of `obj` (exact type, via Py_TYPE) is a member of
// the `tensor_classes` set. Throws std::logic_error if PySet_Contains fails.
bool THPModule_isTensor(PyObject *obj)
{
  PyObject *obj_type = (PyObject*)Py_TYPE(obj);
  const int membership = PySet_Contains(tensor_classes, obj_type);
  if (membership != -1) {
    return membership != 0;
  }
  throw std::logic_error("FATAL: tensor_classes isn't a set!");
}
// torch._C._set_default_tensor_type: records `type` as the fallback dispatch
// target that findTensor, nonzero and cat use when no tensor argument is found.
//
// A reference is now taken on `type`. Previously a borrowed pointer was
// stored, which could dangle if the caller dropped the last reference. The
// previously stored object is deliberately not released, because this view
// does not show whether every other writer of THPDefaultTensorClass owns a
// reference. NOTE(review): if this is the only writer, Py_XDECREF the old value.
PyObject * THPModule_setDefaultTensorType(PyObject *_unused, PyObject *type)
{
  Py_INCREF(type);
  THPDefaultTensorClass = type;
  Py_RETURN_NONE;
}
// torch.from_numpy: converts `array` with torch::utils::tensor_from_numpy and
// returns the resulting tensor's Python object. C++ errors are translated by
// HANDLE_TH_ERRORS.
PyObject * THPModule_fromNumpy(PyObject *_unused, PyObject *array)
{
  HANDLE_TH_ERRORS
  auto converted = torch::utils::tensor_from_numpy(array);
  return torch::createPyObject(converted);
  END_HANDLE_TH_ERRORS
}
/**
* STATELESS FUNCTIONS
**/
// Chooses the object whose type drives dispatch of a stateless call. It returns
// the first positional argument that is a tensor or Variable, otherwise the first
// such keyword value (PyDict_Next order), otherwise THPDefaultTensorClass.
// Returns a borrowed reference.
static PyObject * findTensor(PyObject *args, PyObject *kwargs) {
  auto is_tensor_like = [](PyObject *candidate) {
    return THPModule_isTensor(candidate) || THPVariable_Check(candidate);
  };
  const Py_ssize_t positional_count = PyTuple_Size(args);
  for (Py_ssize_t idx = 0; idx < positional_count; ++idx) {
    PyObject *positional = PyTuple_GET_ITEM(args, idx);
    if (is_tensor_like(positional)) return positional;
  }
  if (!kwargs) return THPDefaultTensorClass;
  PyObject *kw_name = nullptr;
  PyObject *kw_value = nullptr;
  Py_ssize_t cursor = 0;
  while (PyDict_Next(kwargs, &cursor, &kw_name, &kw_value)) {
    if (is_tensor_like(kw_value)) return kw_value;
  }
  return THPDefaultTensorClass;
}
// Returns a new tuple equal to `args` with its first two items swapped (new
// reference), or nullptr if PyTuple_New fails.
// A tuple with fewer than two items is copied unchanged. Previously a
// one-element tuple made the loop read index 1, which is out of bounds.
static PyObject * swapFirstTwoItems(PyObject *args) {
  auto size = PyTuple_GET_SIZE(args);
  auto r = THPObjectPtr{PyTuple_New(size)};
  if (!r) return nullptr;
  const bool can_swap = size >= 2;
  for (Py_ssize_t i = 0; i < size; i++) {
    Py_ssize_t src = (can_swap && i <= 1) ? 1 - i : i;
    PyObject* obj = PyTuple_GET_ITEM(args, src);
    // PyTuple_SET_ITEM steals a reference, so take one for the new tuple.
    Py_INCREF(obj);
    PyTuple_SET_ITEM(r.get(), i, obj);
  }
  return r.release();
}
// Forwards the stateless function `name` to THPUtils_dispatchStateless,
// dispatching on the object chosen by findTensor. args/kwargs pass through unchanged.
static PyObject * dispatchStateless(PyObject *args, PyObject *kwargs, const char *name) {
  return THPUtils_dispatchStateless(findTensor(args, kwargs), name, args, kwargs);
}
// Like dispatchStateless, but when the dispatch object is a Variable passed
// as the second positional argument, the first two arguments are swapped
// before dispatching.
// Returns NULL (with the Python error from PyTuple_New) if building the
// swapped tuple fails. Previously a null tuple was passed on to dispatch.
static PyObject * dispatchStatelessSwap(PyObject *args, PyObject *kwargs, const char *name) {
  PyObject *tensor = findTensor(args, kwargs);
  const bool variable_self_is_second = THPVariable_Check(tensor) &&
      PyTuple_GET_SIZE(args) >= 2 && tensor == PyTuple_GET_ITEM(args, 1);
  if (!variable_self_is_second) {
    return THPUtils_dispatchStateless(tensor, name, args, kwargs);
  }
  // Unlike tensors, the stateless methods on Variables are dispatched in a different manner.
  // On Variables, the `self` argument must be the first argument when dispatching.
  // For stateless methods which have more than one argument and where `self` comes second
  // (e.g., `polygamma(n, x)`), `self` needs to be swapped into the first position first.
  auto newArgs = THPObjectPtr{swapFirstTwoItems(args)};
  if (!newArgs) return nullptr;
  return THPUtils_dispatchStateless(tensor, name, newArgs.get(), kwargs);
}
// IMPLEMENT_STATELESS(name) defines `static PyObject* THPModule_<name>(...)`,
// a METH_VARARGS | METH_KEYWORDS entry point that forwards to
// dispatchStateless with the stringized function name.
// (Comments cannot be placed inside the line-continued macro body.)
#define IMPLEMENT_STATELESS(name) \
static PyObject * TH_CONCAT_2(THPModule_, name)(PyObject *_unused, PyObject *args, PyObject *kwargs) \
{ \
return dispatchStateless(args, kwargs, #name); \
}
// IMPLEMENT_STATELESS_SWAP(name) is identical, but forwards to
// dispatchStatelessSwap, which moves a Variable `self` passed second into
// first position.
#define IMPLEMENT_STATELESS_SWAP(name) \
static PyObject * TH_CONCAT_2(THPModule_, name)(PyObject *_unused, PyObject *args, PyObject *kwargs) \
{ \
return dispatchStatelessSwap(args, kwargs, #name); \
}
// This handles the deprecated torch.addxx signatures. For example,
// torch.addmm(1, var, 2, a, b) -> var.addmm(1, 2, a, b)
#define IMPLEMENT_STATELESS_ADDXX IMPLEMENT_STATELESS_SWAP
// Instantiate one THPModule_<name> wrapper per stateless torch.* function.
// Each wrapper must also be listed in the TorchMethods table to be exposed to Python.
IMPLEMENT_STATELESS(sigmoid)
IMPLEMENT_STATELESS(log)
IMPLEMENT_STATELESS(log1p)
IMPLEMENT_STATELESS(lgamma)
IMPLEMENT_STATELESS(digamma)
IMPLEMENT_STATELESS(erf)
IMPLEMENT_STATELESS(erfinv)
IMPLEMENT_STATELESS(exp)
IMPLEMENT_STATELESS(expm1)
IMPLEMENT_STATELESS(cos)
IMPLEMENT_STATELESS(acos)
IMPLEMENT_STATELESS(cosh)
IMPLEMENT_STATELESS(sin)
IMPLEMENT_STATELESS(asin)
IMPLEMENT_STATELESS(sinh)
IMPLEMENT_STATELESS(tan)
IMPLEMENT_STATELESS(atan)
IMPLEMENT_STATELESS(tanh)
IMPLEMENT_STATELESS(sqrt)
IMPLEMENT_STATELESS(rsqrt)
IMPLEMENT_STATELESS(ceil)
IMPLEMENT_STATELESS(floor)
IMPLEMENT_STATELESS(round)
IMPLEMENT_STATELESS(abs)
IMPLEMENT_STATELESS(trunc)
IMPLEMENT_STATELESS(frac)
IMPLEMENT_STATELESS(mean)
IMPLEMENT_STATELESS(std)
IMPLEMENT_STATELESS(var)
IMPLEMENT_STATELESS(norm)
IMPLEMENT_STATELESS(reciprocal)
IMPLEMENT_STATELESS(neg)
IMPLEMENT_STATELESS(add)
IMPLEMENT_STATELESS(mul)
IMPLEMENT_STATELESS(div)
IMPLEMENT_STATELESS(fmod)
IMPLEMENT_STATELESS(min)
IMPLEMENT_STATELESS(max)
IMPLEMENT_STATELESS(dot)
IMPLEMENT_STATELESS(sum)
IMPLEMENT_STATELESS(prod)
IMPLEMENT_STATELESS(remainder)
IMPLEMENT_STATELESS(cumsum)
IMPLEMENT_STATELESS(cumprod)
IMPLEMENT_STATELESS(clamp)
IMPLEMENT_STATELESS(equal)
IMPLEMENT_STATELESS(eye)
IMPLEMENT_STATELESS(diag)
IMPLEMENT_STATELESS(numel)
IMPLEMENT_STATELESS(sign)
IMPLEMENT_STATELESS(trace)
IMPLEMENT_STATELESS(tril)
IMPLEMENT_STATELESS(triu)
IMPLEMENT_STATELESS(zero)
IMPLEMENT_STATELESS(kthvalue)
IMPLEMENT_STATELESS(mode)
IMPLEMENT_STATELESS(median)
IMPLEMENT_STATELESS(cross)
IMPLEMENT_STATELESS(sort)
IMPLEMENT_STATELESS(topk)
IMPLEMENT_STATELESS(t)
IMPLEMENT_STATELESS(transpose)
IMPLEMENT_STATELESS(squeeze)
IMPLEMENT_STATELESS(unsqueeze)
IMPLEMENT_STATELESS(renorm)
IMPLEMENT_STATELESS(dist)
IMPLEMENT_STATELESS(linspace)
IMPLEMENT_STATELESS(logspace)
IMPLEMENT_STATELESS(histc)
IMPLEMENT_STATELESS(atan2)
IMPLEMENT_STATELESS(pow)
IMPLEMENT_STATELESS(lerp)
IMPLEMENT_STATELESS(zeros)
IMPLEMENT_STATELESS(zeros_like)
IMPLEMENT_STATELESS(ones)
IMPLEMENT_STATELESS(ones_like)
IMPLEMENT_STATELESS(index_select)
IMPLEMENT_STATELESS(take)
IMPLEMENT_STATELESS(ger)
IMPLEMENT_STATELESS(mv)
IMPLEMENT_STATELESS(mm)
IMPLEMENT_STATELESS(bmm)
// TODO: this doesn't implement options that return numbers!
IMPLEMENT_STATELESS(multinomial)
IMPLEMENT_STATELESS(normal)
IMPLEMENT_STATELESS(standard_gamma)
IMPLEMENT_STATELESS(dirichlet_grad)
IMPLEMENT_STATELESS(bernoulli)
IMPLEMENT_STATELESS(range)
IMPLEMENT_STATELESS(arange)
IMPLEMENT_STATELESS(gather)
IMPLEMENT_STATELESS(rand)
IMPLEMENT_STATELESS(randn)
IMPLEMENT_STATELESS(masked_select)
IMPLEMENT_STATELESS(gesv)
IMPLEMENT_STATELESS(gels)
IMPLEMENT_STATELESS(trtrs)
IMPLEMENT_STATELESS(symeig)
IMPLEMENT_STATELESS(eig)
IMPLEMENT_STATELESS(svd)
IMPLEMENT_STATELESS(inverse)
IMPLEMENT_STATELESS(potrf)
IMPLEMENT_STATELESS(potrs)
IMPLEMENT_STATELESS(potri)
IMPLEMENT_STATELESS(pstrf)
IMPLEMENT_STATELESS(qr)
IMPLEMENT_STATELESS(geqrf)
IMPLEMENT_STATELESS(orgqr)
IMPLEMENT_STATELESS(ormqr)
IMPLEMENT_STATELESS(btrifact)
IMPLEMENT_STATELESS(btrifact_with_info)
IMPLEMENT_STATELESS(btrisolve)
IMPLEMENT_STATELESS(gt)
IMPLEMENT_STATELESS(lt)
IMPLEMENT_STATELESS(ge)
IMPLEMENT_STATELESS(le)
IMPLEMENT_STATELESS(eq)
IMPLEMENT_STATELESS(ne)
// For torch.polygamma(n, x), the `self` argument comes second, so the
// first two arguments need to be swapped before dispatch.
IMPLEMENT_STATELESS_SWAP(polygamma)
// Deprecated torch.addxx(beta, self, ...) signatures: `self` comes second.
IMPLEMENT_STATELESS_ADDXX(addmm)
IMPLEMENT_STATELESS_ADDXX(addmv)
IMPLEMENT_STATELESS_ADDXX(addr)
IMPLEMENT_STATELESS_ADDXX(addbmm)
IMPLEMENT_STATELESS_ADDXX(baddbmm)
IMPLEMENT_STATELESS_ADDXX(addcmul)
IMPLEMENT_STATELESS_ADDXX(addcdiv)
#undef IMPLEMENT_STATELESS
#undef IMPLEMENT_STATELESS_SWAP
#undef IMPLEMENT_STATELESS_ADDXX
// torch.nonzero. With two positional arguments, the first may be a LongTensor
// that receives the indices, so dispatch uses the second argument's type.
// With exactly one positional argument, dispatch uses that argument.
// Otherwise it uses the default tensor class.
static PyObject * THPModule_nonzero(PyObject *_unused, PyObject *args, PyObject *kwargs)
{
  PyObject *dispatch_on;
  switch (PyTuple_Size(args)) {
    case 1:
      dispatch_on = PyTuple_GET_ITEM(args, 0);
      break;
    case 2:
      dispatch_on = PyTuple_GET_ITEM(args, 1);
      break;
    default:
      dispatch_on = THPDefaultTensorClass;
      break;
  }
  return THPUtils_dispatchStateless(dispatch_on, "nonzero", args, kwargs);
}
// torch.randperm: dispatches on the `out=` keyword argument if present,
// otherwise on THPLongTensorClass.
static PyObject * THPModule_randperm(PyObject *_unused, PyObject *args, PyObject *kwargs)
{
  PyObject *out_arg = kwargs ? PyDict_GetItemString(kwargs, "out") : nullptr;
  PyObject *dispatch_on = out_arg ? out_arg : THPLongTensorClass;
  return THPUtils_dispatchStateless(dispatch_on, "randperm", args, kwargs);
}
// torch.cat: chooses the dispatch object from the sequence argument, which is
// either the first positional argument or the `seq=` keyword.
// - If that argument is itself a tensor, dispatch uses it.
// - Otherwise, if it is a sequence whose first element is a tensor or
//   Variable, dispatch uses that element.
// - Otherwise dispatch falls back to the default tensor class.
static PyObject * THPModule_cat(PyObject *_unused, PyObject *args, PyObject *kwargs)
{
  PyObject *tensor = THPDefaultTensorClass;
  // Owns the new reference from PySequence_GetItem until after dispatch.
  THPObjectPtr item;
  PyObject *first_arg = nullptr;
  if (args && PyTuple_GET_SIZE(args) > 0) {
    first_arg = PyTuple_GET_ITEM(args, 0);
  } else if (kwargs) {
    // Previously this condition also evaluated PyTuple_GET_ITEM(args, 0),
    // which reads past the end of an empty tuple or dereferences NULL args.
    first_arg = PyDict_GetItemString(kwargs, "seq");
  }
  if (first_arg) {
    if (THPModule_isTensor(first_arg)) {
      tensor = first_arg;
    } else if (PySequence_Check(first_arg)) {
      item = PySequence_GetItem(first_arg, 0);
      if (item && (THPModule_isTensor(item.get()) || THPVariable_Check(item.get()))) {
        tensor = item.get();
      }
    }
    // Probing the sequence may fail (e.g. it is empty). Clear that error;
    // dispatch then keeps the fallback chosen above.
    PyErr_Clear();
  }
  return THPUtils_dispatchStateless(tensor, "cat", args, kwargs);
}
// torch._C._safe_call(fn, *args, **kwargs) calls fn(*args, **kwargs). Any
// std::exception escaping the call is converted into a Python
// THPException_FatalError instead of propagating further.
// Returns the call's result, or NULL with a Python error set.
PyObject *THPModule_safeCall(PyObject *_unused, PyObject *args, PyObject *kwargs)
{
  PyObject *result = NULL;
  PyObject *args_slice = NULL;
  PyThreadState *thread_state = PyThreadState_Get();
  Py_ssize_t num_args = args ? PyTuple_Size(args) : 0;
  THPUtils_assert(num_args > 0, "expected at least one argument");
  try {
    args_slice = PyTuple_GetSlice(args, 1, num_args);
    if (!args_slice) return NULL;
    result = PyObject_Call(PyTuple_GET_ITEM(args, 0), args_slice, kwargs);
  } catch (std::exception &e) {
    // NOTE(review): restoring the thread state captured on entry and leaving
    // one recursion level presumably undo state the exception skipped past — confirm.
    PyEval_RestoreThread(thread_state);
    PyErr_SetString(THPException_FatalError, e.what());
    Py_LeaveRecursiveCall();
  }
  // Single release point for the slice. The catch block used to release it
  // too, so the exception path decremented the reference twice.
  Py_XDECREF(args_slice);
  return result;
}
// torch._C._add_docstr(obj, doc): attaches `doc` as the __doc__ of a builtin
// function or method descriptor, similar to numpy's arr_add_docstring.
// Returns `obj` (new reference). Raises RuntimeError if obj already has a
// docstring and TypeError for unsupported object types. A non-string `doc`
// is stored as "<invalid string>".
PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
{
  // ml_doc keeps pointing into these strings for the rest of the process.
  // std::deque never relocates existing elements on push_back. A std::vector
  // moves its strings when it grows, which left short (small-string-stored)
  // docstrings with dangling ml_doc pointers.
  static std::deque<std::string> all_docs;
  PyObject *obj;
  PyObject *doc_obj;
  if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) {
    return NULL;
  }
  const char* doc_str = "<invalid string>";
  if (THPUtils_checkString(doc_obj)) {
    all_docs.push_back(THPUtils_unpackString(doc_obj));
    doc_str = all_docs.back().c_str();
  }
  if (Py_TYPE(obj) == &PyCFunction_Type) {
    PyCFunctionObject* f = (PyCFunctionObject *)obj;
    if (f->m_ml->ml_doc) {
      return PyErr_Format(PyExc_RuntimeError,
          "function '%s' already has a docstring", f->m_ml->ml_name);
    }
    f->m_ml->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
    PyMethodDescrObject* m = (PyMethodDescrObject *)obj;
    if (m->d_method->ml_doc) {
      return PyErr_Format(PyExc_RuntimeError,
          "method '%s' already has a docstring", m->d_method->ml_name);
    }
    m->d_method->ml_doc = doc_str;
  } else {
    return PyErr_Format(PyExc_TypeError,
        "don't know how to add docstring to type '%s'", Py_TYPE(obj)->tp_name);
  }
  Py_INCREF(obj);
  return obj;
}
// torch._C._infer_size(size1, size2): returns the torch.Size that
// THLongStorage_inferSize2 infers from two torch.Size arguments. If
// inference fails, the error message it writes is raised.
PyObject *THPModule_inferSize(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size(args) : 0;
  THPUtils_assert(num_args == 2, "expected exactly 2 arguments");
  PyObject *arg1 = PyTuple_GET_ITEM(args, 0);
  THPUtils_assert(THPSize_Check(arg1), "expected a torch.Size as argument 1");
  PyObject *arg2 = PyTuple_GET_ITEM(args, 1);
  THPUtils_assert(THPSize_Check(arg2), "expected a torch.Size as argument 2");
  THLongStoragePtr size1_guard = THPUtils_unpackSize(arg1);
  THLongStorage *size1 = size1_guard.get();
  THLongStoragePtr size2_guard = THPUtils_unpackSize(arg2);
  THLongStorage *size2 = size2_guard.get();
  THLongStoragePtr sizes_guard(THLongStorage_new());
  THLongStorage *sizes = sizes_guard.get();
  char error_buffer[1024];
  int ret = THLongStorage_inferSize2(sizes, size1->data, size1->size, size2->data, size2->size,
                                     error_buffer, sizeof(error_buffer));
  // The message is generated at runtime, so it is passed as a "%s" argument.
  // Previously it was the format string itself, and any '%' in it was interpreted.
  THPUtils_assert(ret == 0, "%s", error_buffer);
  return THPSize_New(sizes->size, sizes->data);
  END_HANDLE_TH_ERRORS
}
// Sets the backwards-compatibility broadcast-warning flag. `arg` must be a Python bool.
static PyObject *THPModule_setBackcompatBroadcastWarn(PyObject *module, PyObject *arg) {
  THPUtils_assert(PyBool_Check(arg), "set_backcompat_broadcast_warn expects a bool, "
          "but got %s", THPUtils_typename(arg));
  const bool warn_enabled = (arg == Py_True);
  setBackCompatBroadcastWarn(warn_enabled);
  Py_RETURN_NONE;
}
// Returns the backwards-compatibility broadcast-warning flag as a Python bool.
static PyObject *THPModule_getBackcompatBroadcastWarn(PyObject *module)
{
  if (!getBackCompatBroadcastWarn()) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
// Sets the backwards-compatibility keepdim-warning flag. `arg` must be a Python bool.
static PyObject *THPModule_setBackcompatKeepdimWarn(PyObject *module, PyObject *arg) {
  THPUtils_assert(PyBool_Check(arg), "set_backcompat_keepdim_warn expects a bool, "
          "but got %s", THPUtils_typename(arg));
  const bool warn_enabled = (arg == Py_True);
  setBackCompatKeepdimWarn(warn_enabled);
  Py_RETURN_NONE;
}
// Returns the backwards-compatibility keepdim-warning flag as a Python bool.
static PyObject *THPModule_getBackcompatKeepdimWarn(PyObject *module)
{
  if (!getBackCompatKeepdimWarn()) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
// Returns True if this build defines WITH_DISTRIBUTED, otherwise False.
PyObject *THPModule_hasDistributed(PyObject *_unused)
{
#if defined(WITH_DISTRIBUTED)
  const bool built_with_distributed = true;
#else
  const bool built_with_distributed = false;
#endif
  if (built_with_distributed) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
// torch._C._to_dlpack: exports a tensor as a PyCapsule named "dltensor"
// that wraps the DLManagedTensor produced by at::toDLPack.
// The capsule has no destructor, so the DLManagedTensor is released only
// when a consumer takes ownership (see _from_dlpack). NOTE(review): an
// unconsumed capsule presumably leaks it.
PyObject *THPModule_toDLPack(PyObject *_unused, PyObject *data)
{
  // Translates C++ exceptions (e.g. from THPModule_isTensor or ATen) into
  // Python errors instead of letting them unwind into the interpreter.
  HANDLE_TH_ERRORS
  THPUtils_assert(THPModule_isTensor(data), "data must be a Tensor");
  auto atTensor = torch::createTensor(data);
  DLManagedTensor* dlMTensor = at::toDLPack(atTensor);
  PyObject *capsule = PyCapsule_New(dlMTensor, "dltensor", NULL);
  if (!capsule) {
    // Nothing else owns the export yet, so release it before reporting the error.
    if (dlMTensor->deleter) dlMTensor->deleter(dlMTensor);
    return NULL;
  }
  return capsule;
  END_HANDLE_TH_ERRORS
}
// torch._C._from_dlpack: builds a tensor from a "dltensor" PyCapsule,
// taking ownership of its DLManagedTensor. The capsule is then renamed to
// "used_dltensor" so it cannot be consumed again.
PyObject *THPModule_fromDLPack(PyObject *_unused, PyObject *data)
{
  HANDLE_TH_ERRORS
  DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor");
  THPUtils_assert(dlMTensor, "from_dlpack received an invalid capsule. "
    "Note that DLTensor capsules can be consumed only once, "
    "so you might have already constructed a tensor from it once.");
  // atensor steals the ownership of the underlying storage. It also passes a
  // destructor function that will be called when the underlying storage goes
  // out of scope. When the destructor is called, the dlMTensor is destructed too.
  at::Tensor atensor = at::fromDLPack(dlMTensor);
  // Mark the capsule as consumed immediately after ownership moves to
  // atensor. Previously this happened only after the fallible CUDA init
  // below, so a failure there left a still-"dltensor" capsule pointing at
  // memory that atensor frees.
  PyCapsule_SetName(data, "used_dltensor");
  // It is possible that the call to at::fromDLPack is the very first
  // call to create a Tensor in PyTorch. If so, then _lazy_init has
  // not been called, and the attempt to call createPyObject will fail
  // because cuda ATen types have not been registered in Python yet.
  // So if we have a cuda tensor, we need to make sure _lazy_init is called here.
  if (atensor.is_cuda()) {
    py::module::import("torch.cuda").attr("init")();
  }
  return torch::createPyObject(atensor);
  END_HANDLE_TH_ERRORS
}
// torch._C._set_cudnn_enabled: stores the user-level cuDNN switch in the
// global ATen context. `arg` must be a Python bool.
PyObject *THPModule_setUserEnabledCuDNN(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(PyBool_Check(arg), "set_enabled_cudnn expects a bool, "
          "but got %s", THPUtils_typename(arg));
  auto &context = at::globalContext();
  context.setUserEnabledCuDNN(arg == Py_True);
  Py_RETURN_NONE;
}
// torch._C._get_cudnn_enabled: reports the user-level cuDNN switch as a Python bool.
PyObject *THPModule_userEnabledCuDNN(PyObject *_unused)
{
  const bool enabled = at::globalContext().userEnabledCuDNN();
  if (!enabled) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
// torch._C._set_cudnn_deterministic: stores the cuDNN determinism flag in the
// global ATen context. `arg` must be a Python bool.
PyObject *THPModule_setDeterministicCuDNN(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(PyBool_Check(arg), "set_deterministic_cudnn expects a bool, "
          "but got %s", THPUtils_typename(arg));
  auto &context = at::globalContext();
  context.setDeterministicCuDNN(arg == Py_True);
  Py_RETURN_NONE;
}
// torch._C._get_cudnn_deterministic: reports the cuDNN determinism flag as a Python bool.
PyObject *THPModule_deterministicCuDNN(PyObject *_unused)
{
  const bool deterministic = at::globalContext().deterministicCuDNN();
  if (!deterministic) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
// torch._C._set_cudnn_benchmark: stores the cuDNN benchmark flag in the
// global ATen context. `arg` must be a Python bool.
PyObject *THPModule_setBenchmarkCuDNN(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(PyBool_Check(arg), "set_benchmark_cudnn expects a bool, "
          "but got %s", THPUtils_typename(arg));
  auto &context = at::globalContext();
  context.setBenchmarkCuDNN(arg == Py_True);
  Py_RETURN_NONE;
}
// torch._C._get_cudnn_benchmark: reports the cuDNN benchmark flag as a Python bool.
PyObject *THPModule_benchmarkCuDNN(PyObject *_unused)
{
  const bool benchmark = at::globalContext().benchmarkCuDNN();
  if (!benchmark) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
// CUDA sparse initializer, defined in another translation unit; registered
// below as "_cuda_sparse_init" in CUDA builds.
#ifdef WITH_CUDA
extern PyObject * THCSPModule_initExtension(PyObject *self);
#endif
// Method table for the torch._C extension module: internal helpers
// (underscore-prefixed), public utilities, and the stateless tensor functions
// generated above. The table is terminated by a {NULL, NULL, 0, NULL} sentinel.
static PyMethodDef TorchMethods[] = {
{"_initExtension", (PyCFunction)THPModule_initExtension, METH_O, NULL},
{"_autograd_init", (PyCFunction)THPAutograd_initExtension, METH_NOARGS, NULL},
{"_add_docstr", (PyCFunction)THPModule_addDocStr, METH_VARARGS, NULL},
{"_sparse_init", (PyCFunction)THSPModule_initExtension, METH_NOARGS, NULL},
{"_init_names", (PyCFunction)THPModule_initNames, METH_O, NULL},
{"_has_distributed",(PyCFunction)THPModule_hasDistributed, METH_NOARGS, NULL},
#ifdef WITH_CUDA
{"_cuda_sparse_init", (PyCFunction)THCSPModule_initExtension, METH_NOARGS, NULL},
#endif
{"_safe_call", (PyCFunction)THPModule_safeCall, METH_VARARGS | METH_KEYWORDS, NULL},
{"_set_default_tensor_type", (PyCFunction)THPModule_setDefaultTensorType, METH_O, NULL},
{"_infer_size", (PyCFunction)THPModule_inferSize, METH_VARARGS, NULL},
{"_set_backcompat_broadcast_warn", (PyCFunction)THPModule_setBackcompatBroadcastWarn, METH_O, NULL},
{"_get_backcompat_broadcast_warn", (PyCFunction)THPModule_getBackcompatBroadcastWarn, METH_NOARGS, NULL},
{"_set_backcompat_keepdim_warn", (PyCFunction)THPModule_setBackcompatKeepdimWarn, METH_O, NULL},
{"_get_backcompat_keepdim_warn", (PyCFunction)THPModule_getBackcompatKeepdimWarn, METH_NOARGS, NULL},
{"get_num_threads", (PyCFunction)THPModule_getNumThreads, METH_NOARGS, NULL},
{"set_num_threads", (PyCFunction)THPModule_setNumThreads, METH_O, NULL},
{"_get_cudnn_enabled", (PyCFunction)THPModule_userEnabledCuDNN, METH_NOARGS, NULL},
{"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, NULL},
{"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, NULL},
{"_set_cudnn_benchmark", (PyCFunction)THPModule_setBenchmarkCuDNN, METH_O, NULL},
{"_get_cudnn_deterministic", (PyCFunction)THPModule_deterministicCuDNN, METH_NOARGS, NULL},
{"_set_cudnn_deterministic", (PyCFunction)THPModule_setDeterministicCuDNN, METH_O, NULL},
{"from_numpy", (PyCFunction)THPModule_fromNumpy, METH_O, NULL},
{"_to_dlpack", (PyCFunction)THPModule_toDLPack, METH_O, NULL},
{"_from_dlpack", (PyCFunction)THPModule_fromDLPack, METH_O, NULL},
// Stateless tensor functions (IMPLEMENT_STATELESS* wrappers and hand-written dispatchers).
{"sigmoid", (PyCFunction)THPModule_sigmoid, METH_VARARGS | METH_KEYWORDS, NULL},
{"log", (PyCFunction)THPModule_log, METH_VARARGS | METH_KEYWORDS, NULL},
{"log1p", (PyCFunction)THPModule_log1p, METH_VARARGS | METH_KEYWORDS, NULL},
{"lgamma", (PyCFunction)THPModule_lgamma, METH_VARARGS | METH_KEYWORDS, NULL},
{"digamma", (PyCFunction)THPModule_digamma, METH_VARARGS | METH_KEYWORDS, NULL},
{"polygamma", (PyCFunction)THPModule_polygamma, METH_VARARGS | METH_KEYWORDS, NULL},
{"erf", (PyCFunction)THPModule_erf, METH_VARARGS | METH_KEYWORDS, NULL},
{"erfinv", (PyCFunction)THPModule_erfinv, METH_VARARGS | METH_KEYWORDS, NULL},
{"exp", (PyCFunction)THPModule_exp, METH_VARARGS | METH_KEYWORDS, NULL},
{"expm1", (PyCFunction)THPModule_expm1, METH_VARARGS | METH_KEYWORDS, NULL},
{"cos", (PyCFunction)THPModule_cos, METH_VARARGS | METH_KEYWORDS, NULL},
{"acos", (PyCFunction)THPModule_acos, METH_VARARGS | METH_KEYWORDS, NULL},
{"cosh", (PyCFunction)THPModule_cosh, METH_VARARGS | METH_KEYWORDS, NULL},
{"sin", (PyCFunction)THPModule_sin, METH_VARARGS | METH_KEYWORDS, NULL},
{"asin", (PyCFunction)THPModule_asin, METH_VARARGS | METH_KEYWORDS, NULL},
{"sinh", (PyCFunction)THPModule_sinh, METH_VARARGS | METH_KEYWORDS, NULL},
{"tan", (PyCFunction)THPModule_tan, METH_VARARGS | METH_KEYWORDS, NULL},
{"atan", (PyCFunction)THPModule_atan, METH_VARARGS | METH_KEYWORDS, NULL},
{"tanh", (PyCFunction)THPModule_tanh, METH_VARARGS | METH_KEYWORDS, NULL},
{"sqrt", (PyCFunction)THPModule_sqrt, METH_VARARGS | METH_KEYWORDS, NULL},
{"rsqrt", (PyCFunction)THPModule_rsqrt, METH_VARARGS | METH_KEYWORDS, NULL},
{"ceil", (PyCFunction)THPModule_ceil, METH_VARARGS | METH_KEYWORDS, NULL},
{"floor", (PyCFunction)THPModule_floor, METH_VARARGS | METH_KEYWORDS, NULL},
{"round", (PyCFunction)THPModule_round, METH_VARARGS | METH_KEYWORDS, NULL},
{"abs", (PyCFunction)THPModule_abs, METH_VARARGS | METH_KEYWORDS, NULL},
{"trunc", (PyCFunction)THPModule_trunc, METH_VARARGS | METH_KEYWORDS, NULL},
{"frac", (PyCFunction)THPModule_frac, METH_VARARGS | METH_KEYWORDS, NULL},
{"mean", (PyCFunction)THPModule_mean, METH_VARARGS | METH_KEYWORDS, NULL},
{"std", (PyCFunction)THPModule_std, METH_VARARGS | METH_KEYWORDS, NULL},
{"var", (PyCFunction)THPModule_var, METH_VARARGS | METH_KEYWORDS, NULL},
{"norm", (PyCFunction)THPModule_norm, METH_VARARGS | METH_KEYWORDS, NULL},
{"reciprocal", (PyCFunction)THPModule_reciprocal, METH_VARARGS | METH_KEYWORDS, NULL},
{"neg", (PyCFunction)THPModule_neg, METH_VARARGS | METH_KEYWORDS, NULL},
{"add", (PyCFunction)THPModule_add, METH_VARARGS | METH_KEYWORDS, NULL},
{"mul", (PyCFunction)THPModule_mul, METH_VARARGS | METH_KEYWORDS, NULL},
{"div", (PyCFunction)THPModule_div, METH_VARARGS | METH_KEYWORDS, NULL},
{"fmod", (PyCFunction)THPModule_fmod, METH_VARARGS | METH_KEYWORDS, NULL},
{"min", (PyCFunction)THPModule_min, METH_VARARGS | METH_KEYWORDS, NULL},
{"max", (PyCFunction)THPModule_max, METH_VARARGS | METH_KEYWORDS, NULL},
{"dot", (PyCFunction)THPModule_dot, METH_VARARGS | METH_KEYWORDS, NULL},
{"sum", (PyCFunction)THPModule_sum, METH_VARARGS | METH_KEYWORDS, NULL},
{"prod", (PyCFunction)THPModule_prod, METH_VARARGS | METH_KEYWORDS, NULL},
{"remainder", (PyCFunction)THPModule_remainder, METH_VARARGS | METH_KEYWORDS, NULL},
{"cumsum", (PyCFunction)THPModule_cumsum, METH_VARARGS | METH_KEYWORDS, NULL},
{"cumprod", (PyCFunction)THPModule_cumprod, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp", (PyCFunction)THPModule_clamp, METH_VARARGS | METH_KEYWORDS, NULL},
{"equal", (PyCFunction)THPModule_equal, METH_VARARGS | METH_KEYWORDS, NULL},
{"eye", (PyCFunction)THPModule_eye, METH_VARARGS | METH_KEYWORDS, NULL},
{"diag", (PyCFunction)THPModule_diag, METH_VARARGS | METH_KEYWORDS, NULL},
{"numel", (PyCFunction)THPModule_numel, METH_VARARGS | METH_KEYWORDS, NULL},
{"sign", (PyCFunction)THPModule_sign, METH_VARARGS | METH_KEYWORDS, NULL},
{"trace", (PyCFunction)THPModule_trace, METH_VARARGS | METH_KEYWORDS, NULL},
{"tril", (PyCFunction)THPModule_tril, METH_VARARGS | METH_KEYWORDS, NULL},
{"triu", (PyCFunction)THPModule_triu, METH_VARARGS | METH_KEYWORDS, NULL},
{"zero", (PyCFunction)THPModule_zero, METH_VARARGS | METH_KEYWORDS, NULL},
{"gt", (PyCFunction)THPModule_gt, METH_VARARGS | METH_KEYWORDS, NULL},
{"lt", (PyCFunction)THPModule_lt, METH_VARARGS | METH_KEYWORDS, NULL},
{"ge", (PyCFunction)THPModule_ge, METH_VARARGS | METH_KEYWORDS, NULL},
{"le", (PyCFunction)THPModule_le, METH_VARARGS | METH_KEYWORDS, NULL},
{"eq", (PyCFunction)THPModule_eq, METH_VARARGS | METH_KEYWORDS, NULL},
{"ne", (PyCFunction)THPModule_ne, METH_VARARGS | METH_KEYWORDS, NULL},
{"kthvalue", (PyCFunction)THPModule_kthvalue, METH_VARARGS | METH_KEYWORDS, NULL},
{"mode", (PyCFunction)THPModule_mode, METH_VARARGS | METH_KEYWORDS, NULL},
{"median", (PyCFunction)THPModule_median, METH_VARARGS | METH_KEYWORDS, NULL},
{"cross", (PyCFunction)THPModule_cross, METH_VARARGS | METH_KEYWORDS, NULL},
{"sort", (PyCFunction)THPModule_sort, METH_VARARGS | METH_KEYWORDS, NULL},
{"topk", (PyCFunction)THPModule_topk, METH_VARARGS | METH_KEYWORDS, NULL},
{"t", (PyCFunction)THPModule_t, METH_VARARGS | METH_KEYWORDS, NULL},
{"transpose", (PyCFunction)THPModule_transpose, METH_VARARGS | METH_KEYWORDS, NULL},
{"squeeze", (PyCFunction)THPModule_squeeze, METH_VARARGS | METH_KEYWORDS, NULL},
{"unsqueeze", (PyCFunction)THPModule_unsqueeze, METH_VARARGS | METH_KEYWORDS, NULL},
{"nonzero", (PyCFunction)THPModule_nonzero, METH_VARARGS | METH_KEYWORDS, NULL},
{"renorm", (PyCFunction)THPModule_renorm, METH_VARARGS | METH_KEYWORDS, NULL},
{"dist", (PyCFunction)THPModule_dist, METH_VARARGS | METH_KEYWORDS, NULL},
{"linspace", (PyCFunction)THPModule_linspace, METH_VARARGS | METH_KEYWORDS, NULL},
{"logspace", (PyCFunction)THPModule_logspace, METH_VARARGS | METH_KEYWORDS, NULL},
{"histc", (PyCFunction)THPModule_histc, METH_VARARGS | METH_KEYWORDS, NULL},
{"atan2", (PyCFunction)THPModule_atan2, METH_VARARGS | METH_KEYWORDS, NULL},
{"pow", (PyCFunction)THPModule_pow, METH_VARARGS | METH_KEYWORDS, NULL},
{"lerp", (PyCFunction)THPModule_lerp, METH_VARARGS | METH_KEYWORDS, NULL},
{"zeros", (PyCFunction)THPModule_zeros, METH_VARARGS | METH_KEYWORDS, NULL},
{"zeros_like", (PyCFunction)THPModule_zeros_like, METH_VARARGS | METH_KEYWORDS, NULL},
{"ones", (PyCFunction)THPModule_ones, METH_VARARGS | METH_KEYWORDS, NULL},
{"ones_like", (PyCFunction)THPModule_ones_like, METH_VARARGS | METH_KEYWORDS, NULL},
{"index_select", (PyCFunction)THPModule_index_select, METH_VARARGS | METH_KEYWORDS, NULL},
{"take", (PyCFunction)THPModule_take, METH_VARARGS | METH_KEYWORDS, NULL},
{"addmm", (PyCFunction)THPModule_addmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"addmv", (PyCFunction)THPModule_addmv, METH_VARARGS | METH_KEYWORDS, NULL},
{"addr", (PyCFunction)THPModule_addr, METH_VARARGS | METH_KEYWORDS, NULL},
{"ger", (PyCFunction)THPModule_ger, METH_VARARGS | METH_KEYWORDS, NULL},
{"mv", (PyCFunction)THPModule_mv, METH_VARARGS | METH_KEYWORDS, NULL},
{"addbmm", (PyCFunction)THPModule_addbmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"baddbmm", (PyCFunction)THPModule_baddbmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcmul", (PyCFunction)THPModule_addcmul, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcdiv", (PyCFunction)THPModule_addcdiv, METH_VARARGS | METH_KEYWORDS, NULL},
{"mm", (PyCFunction)THPModule_mm, METH_VARARGS | METH_KEYWORDS, NULL},
{"bmm", (PyCFunction)THPModule_bmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"multinomial", (PyCFunction)THPModule_multinomial, METH_VARARGS | METH_KEYWORDS, NULL},
{"normal", (PyCFunction)THPModule_normal, METH_VARARGS | METH_KEYWORDS, NULL},
{"_standard_gamma", (PyCFunction)THPModule_standard_gamma, METH_VARARGS | METH_KEYWORDS, NULL},
{"_dirichlet_grad", (PyCFunction)THPModule_dirichlet_grad, METH_VARARGS | METH_KEYWORDS, NULL},
{"bernoulli", (PyCFunction)THPModule_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL},
{"rand", (PyCFunction)THPModule_rand, METH_VARARGS | METH_KEYWORDS, NULL},
{"randn", (PyCFunction)THPModule_randn, METH_VARARGS | METH_KEYWORDS, NULL},
{"randperm", (PyCFunction)THPModule_randperm, METH_VARARGS | METH_KEYWORDS, NULL},
{"range", (PyCFunction)THPModule_range, METH_VARARGS | METH_KEYWORDS, NULL},
{"arange", (PyCFunction)THPModule_arange, METH_VARARGS | METH_KEYWORDS, NULL},
{"gather", (PyCFunction)THPModule_gather, METH_VARARGS | METH_KEYWORDS, NULL},
{"cat", (PyCFunction)THPModule_cat, METH_VARARGS | METH_KEYWORDS, NULL},
{"masked_select", (PyCFunction)THPModule_masked_select, METH_VARARGS | METH_KEYWORDS, NULL},
{"gesv", (PyCFunction)THPModule_gesv, METH_VARARGS | METH_KEYWORDS, NULL},
{"gels", (PyCFunction)THPModule_gels, METH_VARARGS | METH_KEYWORDS, NULL},
{"trtrs", (PyCFunction)THPModule_trtrs, METH_VARARGS | METH_KEYWORDS, NULL},
{"symeig", (PyCFunction)THPModule_symeig, METH_VARARGS | METH_KEYWORDS, NULL},
{"eig", (PyCFunction)THPModule_eig, METH_VARARGS | METH_KEYWORDS, NULL},
{"svd", (PyCFunction)THPModule_svd, METH_VARARGS | METH_KEYWORDS, NULL},
{"inverse", (PyCFunction)THPModule_inverse, METH_VARARGS | METH_KEYWORDS, NULL},
{"potrf", (PyCFunction)THPModule_potrf, METH_VARARGS | METH_KEYWORDS, NULL},
{"potrs", (PyCFunction)THPModule_potrs, METH_VARARGS | METH_KEYWORDS, NULL},
{"potri", (PyCFunction)THPModule_potri, METH_VARARGS | METH_KEYWORDS, NULL},
{"pstrf", (PyCFunction)THPModule_pstrf, METH_VARARGS | METH_KEYWORDS, NULL},
{"qr", (PyCFunction)THPModule_qr, METH_VARARGS | METH_KEYWORDS, NULL},
{"geqrf", (PyCFunction)THPModule_geqrf, METH_VARARGS | METH_KEYWORDS, NULL},
{"orgqr", (PyCFunction)THPModule_orgqr, METH_VARARGS | METH_KEYWORDS, NULL},
{"ormqr", (PyCFunction)THPModule_ormqr, METH_VARARGS | METH_KEYWORDS, NULL},
{"btrifact", (PyCFunction)THPModule_btrifact, METH_VARARGS | METH_KEYWORDS, NULL},
{"btrifact_with_info", (PyCFunction)THPModule_btrifact_with_info, METH_VARARGS | METH_KEYWORDS, NULL},
{"btrisolve", (PyCFunction)THPModule_btrisolve, METH_VARARGS | METH_KEYWORDS, NULL},
// Sparse functions
{"smm", (PyCFunction)THSPModule_sspmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"saddmm", (PyCFunction)THSPModule_sspaddmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"dsmm", (PyCFunction)THSPModule_spmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"hsmm", (PyCFunction)THSPModule_hspmm, METH_VARARGS | METH_KEYWORDS, NULL},
// Sentinel terminating the table.
{NULL, NULL, 0, NULL}
};
// Forward declarations of the CUDA dense storage/tensor initializers and the
// CUDA stream initializer. initModule() calls these only in WITH_CUDA builds,
// wrapped in ASSERT_TRUE, so a false return is treated as an import failure.
bool THCPDoubleStorage_init(PyObject *module);
bool THCPFloatStorage_init(PyObject *module);
bool THCPHalfStorage_init(PyObject *module);
bool THCPLongStorage_init(PyObject *module);
bool THCPIntStorage_init(PyObject *module);
bool THCPShortStorage_init(PyObject *module);
bool THCPCharStorage_init(PyObject *module);
bool THCPByteStorage_init(PyObject *module);
bool THCPDoubleTensor_init(PyObject *module);
bool THCPFloatTensor_init(PyObject *module);
bool THCPHalfTensor_init(PyObject *module);
bool THCPLongTensor_init(PyObject *module);
bool THCPIntTensor_init(PyObject *module);
bool THCPShortTensor_init(PyObject *module);
bool THCPCharTensor_init(PyObject *module);
bool THCPByteTensor_init(PyObject *module);
bool THCPStream_init(PyObject *module);
// Accessor for the CUDA method table; initModule() merges it into `methods`
// in WITH_CUDA builds.
#ifdef WITH_CUDA
PyMethodDef* THCPModule_methods();
#endif
// CUDA sparse tensor initializers, likewise called only in WITH_CUDA builds.
bool THCSPDoubleTensor_init(PyObject *module);
bool THCSPFloatTensor_init(PyObject *module);
bool THCSPHalfTensor_init(PyObject *module);
bool THCSPLongTensor_init(PyObject *module);
bool THCSPIntTensor_init(PyObject *module);
bool THCSPShortTensor_init(PyObject *module);
bool THCSPCharTensor_init(PyObject *module);
bool THCSPByteTensor_init(PyObject *module);
// Distributed storage/tensor initializers, called by initModule() only when
// WITH_DISTRIBUTED_MW is defined. The Half variants are commented out here,
// and their calls in initModule() are disabled to match.
bool THDPDoubleStorage_init(PyObject *module);
bool THDPFloatStorage_init(PyObject *module);
//bool THDPHalfStorage_init(PyObject *module);
bool THDPLongStorage_init(PyObject *module);
bool THDPIntStorage_init(PyObject *module);
bool THDPShortStorage_init(PyObject *module);
bool THDPCharStorage_init(PyObject *module);
bool THDPByteStorage_init(PyObject *module);
bool THDPDoubleTensor_init(PyObject *module);
bool THDPFloatTensor_init(PyObject *module);
//bool THDPHalfTensor_init(PyObject *module);
bool THDPLongTensor_init(PyObject *module);
bool THDPIntTensor_init(PyObject *module);
bool THDPShortTensor_init(PyObject *module);
bool THDPCharTensor_init(PyObject *module);
bool THDPByteTensor_init(PyObject *module);
// Method table for torch._C. initModule() fills it from TorchMethods,
// DataLoaderMethods, the autograd functions and the optional CUDA / cuDNN /
// distributed tables. It is file-static rather than a local because
// initModule() hands methods.data() to Py_InitModule / PyModuleDef, so the
// storage must stay alive (and must not be reallocated) after module creation.
static std::vector<PyMethodDef> methods;
// Accessor for the distributed backend's method table; initModule() merges it
// into `methods` in WITH_DISTRIBUTED builds.
#ifdef WITH_DISTRIBUTED
PyMethodDef* THDPModule_methods();
#endif
// TODO: Refactor this in some less manual way
#ifdef WITH_CUDNN
// Returns the cuDNN version this extension was compiled against (the integer
// CUDNN_VERSION macro from cudnn.h). Exposed to Python as
// torch._C._cudnn_version(); it is registered with METH_VARARGS, so any
// positional arguments are accepted and ignored.
static PyObject * THCUDNN_cudnn_version(PyObject * /*self*/, PyObject * /*args*/)
{
  const long compiled_version = CUDNN_VERSION;
  return PyLong_FromLong(compiled_version);
}

// Methods contributed to torch._C in cuDNN builds; the table ends with an
// all-NULL sentinel entry.
static PyMethodDef _THCUDNN_methods[] = {
  {"_cudnn_version", (PyCFunction)THCUDNN_cudnn_version, METH_VARARGS, NULL},
  {NULL, NULL, 0, NULL}
};

// Accessor used by initModule() to merge the cuDNN methods into `methods`.
PyMethodDef* THCUDNN_methods()
{
  return _THCUDNN_methods;
}
#endif
// Builds and populates the `torch._C` extension module.
//
// Steps, in this exact order:
//   1. Assemble the module's method table (`methods`) from the CPU,
//      DataLoader and autograd tables, plus the optional CUDA / cuDNN /
//      distributed ones.
//   2. Create the module (Py_InitModule on Python 2, PyModule_Create on 3).
//   3. Register the core types, the autograd/JIT/NN bindings and the CPU
//      dense and sparse storage/tensor classes.
//   4. In CUDA (and WITH_DISTRIBUTED_MW) builds, register the base classes the
//      Python side finishes setting up later.
//   5. Initialize ATen, expose `has_cudnn` and `default_generator`, and
//      import the NumPy C API when built with NumPy.
//
// Returns the module object, or NULL if any step fails. C++ exceptions thrown
// during initialization are caught by the HANDLE_TH_ERRORS /
// END_HANDLE_TH_ERRORS wrapper.
static PyObject* initModule() {
  HANDLE_TH_ERRORS
  THInferNumThreads();

// Abort initialization (return NULL) as soon as an init step reports failure.
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL

  THPUtils_addPyMethodDefs(methods, TorchMethods);
  THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
  THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
#ifdef WITH_CUDA
  THPUtils_addPyMethodDefs(methods, THCPModule_methods());
#endif
#ifdef WITH_CUDNN
  THPUtils_addPyMethodDefs(methods, THCUDNN_methods());
#endif
#ifdef WITH_DISTRIBUTED
  THPUtils_addPyMethodDefs(methods, THDPModule_methods());
#endif

#if PY_MAJOR_VERSION == 2
  ASSERT_TRUE(module = Py_InitModule("torch._C", methods.data()));
#else
  // Static so the definition outlives this call. The module keeps using the
  // PyMethodDef entries behind methods.data(), which is why `methods` is a
  // file-static vector that must not be modified afterwards.
  static struct PyModuleDef torchmodule = {
    PyModuleDef_HEAD_INIT,
    "torch._C",
    NULL,
    -1,
    methods.data()
  };
  ASSERT_TRUE(module = PyModule_Create(&torchmodule));
#endif

  ASSERT_TRUE(THPWrapper_init(module));
  ASSERT_TRUE(THPGenerator_init(module));
  ASSERT_TRUE(THPException_init(module));
  ASSERT_TRUE(THPSize_init(module));
  ASSERT_TRUE(THPVariable_initModule(module));
  ASSERT_TRUE(THPFunction_initModule(module));
  ASSERT_TRUE(THPEngine_initModule(module));
  torch::autograd::initAutogradClosureBindings(module);
  torch::jit::initJITBindings(module);
  torch::autograd::initNNFunctions(module);

  // CPU dense storages and tensors.
  ASSERT_TRUE(THPDoubleStorage_init(module));
  ASSERT_TRUE(THPFloatStorage_init(module));
  ASSERT_TRUE(THPHalfStorage_init(module));
  ASSERT_TRUE(THPLongStorage_init(module));
  ASSERT_TRUE(THPIntStorage_init(module));
  ASSERT_TRUE(THPShortStorage_init(module));
  ASSERT_TRUE(THPCharStorage_init(module));
  ASSERT_TRUE(THPByteStorage_init(module));
  ASSERT_TRUE(THPDoubleTensor_init(module));
  ASSERT_TRUE(THPFloatTensor_init(module));
  ASSERT_TRUE(THPHalfTensor_init(module));
  ASSERT_TRUE(THPLongTensor_init(module));
  ASSERT_TRUE(THPIntTensor_init(module));
  ASSERT_TRUE(THPShortTensor_init(module));
  ASSERT_TRUE(THPCharTensor_init(module));
  ASSERT_TRUE(THPByteTensor_init(module));

  // CPU sparse tensors (no Half variant is registered here).
  ASSERT_TRUE(THSPDoubleTensor_init(module));
  ASSERT_TRUE(THSPFloatTensor_init(module));
  ASSERT_TRUE(THSPLongTensor_init(module));
  ASSERT_TRUE(THSPIntTensor_init(module));
  ASSERT_TRUE(THSPShortTensor_init(module));
  ASSERT_TRUE(THSPCharTensor_init(module));
  ASSERT_TRUE(THSPByteTensor_init(module));

#ifdef WITH_CUDA
  // This will only initialise base classes and attach them to library namespace
  // They won't be ready for real usage until importing cuda module, that will
  // complete the process (but it defines Python classes before calling back into
  // C, so these lines have to execute first).
  ASSERT_TRUE(THCPDoubleStorage_init(module));
  ASSERT_TRUE(THCPFloatStorage_init(module));
  ASSERT_TRUE(THCPHalfStorage_init(module));
  ASSERT_TRUE(THCPLongStorage_init(module));
  ASSERT_TRUE(THCPIntStorage_init(module));
  ASSERT_TRUE(THCPShortStorage_init(module));
  ASSERT_TRUE(THCPCharStorage_init(module));
  ASSERT_TRUE(THCPByteStorage_init(module));
  ASSERT_TRUE(THCPDoubleTensor_init(module));
  ASSERT_TRUE(THCPFloatTensor_init(module));
  ASSERT_TRUE(THCPHalfTensor_init(module));
  ASSERT_TRUE(THCPLongTensor_init(module));
  ASSERT_TRUE(THCPIntTensor_init(module));
  ASSERT_TRUE(THCPShortTensor_init(module));
  ASSERT_TRUE(THCPCharTensor_init(module));
  ASSERT_TRUE(THCPByteTensor_init(module));
  ASSERT_TRUE(THCPStream_init(module));
  ASSERT_TRUE(THCSPDoubleTensor_init(module));
  ASSERT_TRUE(THCSPFloatTensor_init(module));
  ASSERT_TRUE(THCSPHalfTensor_init(module));
  ASSERT_TRUE(THCSPLongTensor_init(module));
  ASSERT_TRUE(THCSPIntTensor_init(module));
  ASSERT_TRUE(THCSPShortTensor_init(module));
  ASSERT_TRUE(THCSPCharTensor_init(module));
  ASSERT_TRUE(THCSPByteTensor_init(module));
#endif

#ifdef WITH_CUDNN
  PyObject *has_cudnn = Py_True;
#else
  PyObject *has_cudnn = Py_False;
#endif
  Py_INCREF(has_cudnn);
  // PyModule_AddObject only steals the reference on success; release the
  // reference we just took if it fails.
  if (PyModule_AddObject(module, "has_cudnn", has_cudnn) != 0) {
    Py_DECREF(has_cudnn);
    return NULL;
  }

#ifdef WITH_DISTRIBUTED_MW
  // See comment on CUDA objects
  ASSERT_TRUE(THDPDoubleStorage_init(module));
  ASSERT_TRUE(THDPFloatStorage_init(module));
  //ASSERT_TRUE(THDPHalfStorage_init(module));
  ASSERT_TRUE(THDPLongStorage_init(module));
  ASSERT_TRUE(THDPIntStorage_init(module));
  ASSERT_TRUE(THDPShortStorage_init(module));
  ASSERT_TRUE(THDPCharStorage_init(module));
  ASSERT_TRUE(THDPByteStorage_init(module));
  ASSERT_TRUE(THDPDoubleTensor_init(module));
  ASSERT_TRUE(THDPFloatTensor_init(module));
  //ASSERT_TRUE(THDPHalfTensor_init(module));
  ASSERT_TRUE(THDPLongTensor_init(module));
  ASSERT_TRUE(THDPIntTensor_init(module));
  ASSERT_TRUE(THDPShortTensor_init(module));
  ASSERT_TRUE(THDPCharTensor_init(module));
  ASSERT_TRUE(THDPByteTensor_init(module));
#endif

  // force ATen to initialize because it handles
  // setting up TH Errors so that they throw C++ exceptions
  at::init();
  auto& defaultGenerator = at::globalContext().defaultGenerator(at::kCPU);
  THPDefaultGenerator = (THPGenerator*)THPGenerator_NewWithGenerator(
    defaultGenerator);
  ASSERT_TRUE(THPDefaultGenerator);
  // On success the module takes over the reference returned by
  // THPGenerator_NewWithGenerator, and the global keeps a borrowed pointer.
  // On failure the reference is still ours: release it and reset the global
  // so it never refers to a freed object.
  if (PyModule_AddObject(module, "default_generator", (PyObject*)THPDefaultGenerator) != 0) {
    Py_CLEAR(THPDefaultGenerator);
    return NULL;
  }

#ifdef WITH_NUMPY
  if (_import_array() < 0) return NULL;
#endif
  return module;
  END_HANDLE_TH_ERRORS
}
#if PY_MAJOR_VERSION == 2
// Python 2 import hook for torch._C. Py_InitModule() inside initModule()
// registers the module itself, and this entry point does not return the
// module, so the result of initModule() is intentionally discarded.
PyMODINIT_FUNC init_C()
{
  initModule();
}
#else
// Python 3 import hook for torch._C: hands the module object (or NULL on
// failure) straight back to the interpreter.
PyMODINIT_FUNC PyInit__C()
{
  return initModule();
}
#endif