blob: 49f53980c1a33f780efa9087068f166ea5cce640 [file] [log] [blame]
#include "THP.h"
PyObject* sparse_tensor_classes;
////////////////////////////////////////////////////////////////////////////////
// SPARSE MODULE INITIALIZATION
////////////////////////////////////////////////////////////////////////////////
static bool THSPModule_loadClasses(PyObject *sparse_module)
{
if (!THSPDoubleTensor_postInit(sparse_module)) return false;
if (!THSPFloatTensor_postInit(sparse_module)) return false;
if (!THSPLongTensor_postInit(sparse_module)) return false;
if (!THSPIntTensor_postInit(sparse_module)) return false;
if (!THSPShortTensor_postInit(sparse_module)) return false;
if (!THSPCharTensor_postInit(sparse_module)) return false;
if (!THSPByteTensor_postInit(sparse_module)) return false;
return true;
}
static bool THSPModule_assignStateless()
{
#define INIT_STATELESS(type) \
stateless = PyObject_Call((PyObject*)&TH_CONCAT_3(Sparse, type, TensorStatelessType), arg, NULL); \
if (!stateless) { \
THPUtils_setError("stateless method initialization error"); \
return false; \
} \
if (PyObject_SetAttrString(TH_CONCAT_3(THSP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \
THPUtils_setError("stateless method initialization error (on assignment)");\
}
PyObject *arg = PyTuple_New(0);
PyObject *stateless;
INIT_STATELESS(Double);
INIT_STATELESS(Float);
INIT_STATELESS(Long);
INIT_STATELESS(Int);
INIT_STATELESS(Short);
INIT_STATELESS(Char);
INIT_STATELESS(Byte);
Py_DECREF(arg);
return true;
#undef INIT_STATELESS
}
// Callback for python part. Used for additional initialization of python classes
PyObject *THSPModule_initExtension(PyObject *self)
{
PyObject *module = PyImport_ImportModule("torch.sparse");
if (!module) return NULL;
if (!THSPModule_loadClasses(module)) return NULL;
if (!THSPModule_assignStateless()) return NULL;
Py_RETURN_NONE;
}
////////////////////////////////////////////////////////////////////////////////
// Sparse Stateless Functions
////////////////////////////////////////////////////////////////////////////////
bool THPModule_isSparseTensor(PyObject *obj)
{
int result = PySet_Contains(sparse_tensor_classes, (PyObject*)Py_TYPE(obj));
if (result == -1)
throw std::logic_error("FATAL: sparse_tensor_classes isn't a set!");
return result;
}
#define IMPLEMENT_SPARSE_STATELESS(name) \
static PyObject * TH_CONCAT_2(THSPModule_, name)(PyObject *_unused, PyObject *args, PyObject *kwargs) \
{ \
PyObject *tensor = THSPFloatTensorClass; \
PyObject *key, *value; \
Py_ssize_t pos = 0; \
for (int i = 0; i < PyTuple_Size(args); i++) { \
PyObject *item = PyTuple_GET_ITEM(args, i); \
if (THPModule_isTensor(item) || THPVariable_Check(item)) { \
tensor = item; \
goto dispatch; \
} \
} \
if (kwargs) { \
while (PyDict_Next(kwargs, &pos, &key, &value)) { \
if (THPModule_isTensor(value) || THPVariable_Check(value)) { \
tensor = value; \
goto dispatch; \
} \
} \
} \
\
dispatch: \
return THPUtils_dispatchStateless(tensor, #name, args, kwargs); \
}
IMPLEMENT_SPARSE_STATELESS(spmm);
IMPLEMENT_SPARSE_STATELESS(sspmm);
IMPLEMENT_SPARSE_STATELESS(sspaddmm);
IMPLEMENT_SPARSE_STATELESS(hspmm);
#undef IMPLEMENT_SPARSE_STATELESS