blob: 8b101f94e6d453b0c0e09114ec41735ab97a163e [file] [log] [blame]
#include "THCP.h"
static bool THCSPModule_loadClasses(PyObject *sparse_module)
{
if (!THCSPDoubleTensor_postInit(sparse_module)) return false;
if (!THCSPFloatTensor_postInit(sparse_module)) return false;
#ifdef CUDA_HALF_TENSOR
if (!THCSPHalfTensor_postInit(sparse_module)) return false;
#endif
if (!THCSPLongTensor_postInit(sparse_module)) return false;
if (!THCSPIntTensor_postInit(sparse_module)) return false;
if (!THCSPShortTensor_postInit(sparse_module)) return false;
if (!THCSPCharTensor_postInit(sparse_module)) return false;
if (!THCSPByteTensor_postInit(sparse_module)) return false;
return true;
}
static bool THCSPModule_assignStateless()
{
#define INIT_STATELESS(type) \
stateless = PyObject_Call((PyObject*)&TH_CONCAT_3(CudaSparse, type, TensorStatelessType), arg, NULL); \
if (!stateless) { \
THPUtils_setError("stateless method initialization error"); \
return false; \
} \
if (PyObject_SetAttrString(TH_CONCAT_3(THCSP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \
THPUtils_setError("stateless method initialization error (on assignment)");\
}
PyObject *arg = PyTuple_New(0);
PyObject *stateless;
INIT_STATELESS(Double);
INIT_STATELESS(Float);
#ifdef CUDA_HALF_TENSOR
INIT_STATELESS(Half);
#endif
INIT_STATELESS(Long);
INIT_STATELESS(Int);
INIT_STATELESS(Short);
INIT_STATELESS(Char);
INIT_STATELESS(Byte);
Py_DECREF(arg);
return true;
#undef INIT_STATELESS
}
////////////////////////////////////////////////////////////////////////////////
// Sparse Cuda module initialization
////////////////////////////////////////////////////////////////////////////////
bool THCSPModule_initCudaSparse(PyObject *module) {
#define ASSERT_TRUE(cond) if (!(cond)) { return false; }
ASSERT_TRUE(THCSPModule_loadClasses(module));
ASSERT_TRUE(THCSPModule_assignStateless());
return true;
#undef ASSERT_TRUE
}
PyObject * THCSPModule_initExtension(PyObject *self)
{
PyObject *module = PyImport_ImportModule("torch.cuda.sparse");
if (!module) {
THPUtils_setError("class loader couldn't access torch.cuda.sparse module");
return NULL;
}
if (!THCSPModule_initCudaSparse(module)) {
return NULL;
}
Py_RETURN_NONE;
}