blob: 7491185bae419e508bddec399f393b62ab3bcf38 [file] [log] [blame]
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/Storage.cpp"
#else
PyObject *THPStorageClass = NULL;
PyObject * THPStorage_(New)(THStorage *ptr)
{
PyTypeObject *type = (PyTypeObject *)THPStorageClass;
PyObject *obj = type->tp_alloc(type, 0);
if (obj) {
((THPStorage *)obj)->cdata = ptr;
} else {
THStorage_(free)(LIBRARY_STATE ptr);
}
return obj;
}
static void THPStorage_(dealloc)(THPStorage* self)
{
THStorage_(free)(LIBRARY_STATE self->cdata);
Py_TYPE(self)->tp_free((PyObject*)self);
}
static THStorage* THPStorage_(newWithAllocator)(int64_t size, THAllocator* allocator)
{
#if defined(THC_GENERIC_FILE) || defined(THD_GENERIC_FILE)
THPUtils_setError(THPStorageStr " does not support custom allocators");
return NULL;
#else
return THStorage_(newWithAllocator)(LIBRARY_STATE size, allocator, NULL);
#endif
}
static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
Py_ssize_t num_args = args ? PyTuple_Size(args) : 0;
THPStoragePtr self((THPStorage *)type->tp_alloc(type, 0));
THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
THAllocator* allocator = NULL;
// Internally we allow constructing with a keywoard only argument cdata
if (kwargs != NULL) {
PyObject *allocator_ptr = PyDict_GetItemString(kwargs, "allocator");
if (allocator_ptr) {
THPUtils_assert(THPUtils_checkLong(allocator_ptr), "invalid allocator");
allocator = (THAllocator*) PyLong_AsVoidPtr(allocator_ptr);
PyDict_DelItemString(kwargs, "allocator");
}
Py_ssize_t num_kwargs = PyDict_Size(kwargs);
if (num_args == 0) {
PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) {
THStorage *ptr = (THStorage*)PyLong_AsVoidPtr(cdata_ptr);
self->cdata = ptr;
return (PyObject*)self.release();
}
}
THPUtils_assert(num_kwargs == 0, THPStorageStr "(): invalid keyword arguments");
}
// torch.Storage()
if (num_args == 0) {
if (allocator) {
self->cdata = THPStorage_(newWithAllocator)(0, allocator);
} else {
self->cdata = THStorage_(new)(LIBRARY_STATE_NOARGS);
}
return (PyObject*)self.release();
}
PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
// torch.Storage(size)
if (num_args == 1 && THPUtils_checkLong(first_arg)) {
int64_t size = THPUtils_unpackLong(first_arg);
if (allocator) {
self->cdata = THPStorage_(newWithAllocator)(size, allocator);
} else {
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE size);
}
return (PyObject*)self.release();
}
// torch.Storage(view_source, [offset, [size]])
if (num_args < 4 && THPStorage_(Check)(first_arg)) {
#ifdef THD_GENERIC_FILE
THPUtils_setError("distributed storages don't support storage views");
return NULL;
#else
THPStorage *storage_arg = (THPStorage *)first_arg;
int64_t numel = storage_arg->cdata->size;
int64_t offset = 0;
if (num_args >= 2) {
PyObject *second_arg = PyTuple_GET_ITEM(args, 1);
if (!THPUtils_checkLong(second_arg))
goto invalid_arguments;
offset = THPUtils_unpackLong(second_arg);
}
int64_t size = numel - offset;
if (num_args >= 3) {
PyObject *third_arg = PyTuple_GET_ITEM(args, 2);
if (!THPUtils_checkLong(third_arg))
goto invalid_arguments;
size = THPUtils_unpackLong(third_arg);
}
THPUtils_assert(offset >= 0 && offset <= numel, "specified an offset of "
"%" PRId64 ", but the viewed storage has only %" PRId64 " element(s)", offset, numel);
THPUtils_assert(size >= 1 && size <= numel - offset, "specified a size of "
"%" PRId64 ", but the viewed storage has only %" PRId64 " element(s) after offset %" PRId64,
size, numel - offset, offset);
real *data_ptr = storage_arg->cdata->data + offset;
THStoragePtr storage(THStorage_(newWithData)(LIBRARY_STATE data_ptr, size));
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
storage->view = storage_arg->cdata;
THStorage_(retain)(LIBRARY_STATE storage_arg->cdata);
self->cdata = storage.release();
return (PyObject*)self.release();
#endif
}
// torch.Storage(sequence)
if (num_args == 1 && PySequence_Check(first_arg)) {
#ifdef THD_GENERIC_FILE
THPUtils_setError("distributed storages don't support construction from a sequence");
#else
Py_ssize_t length = PySequence_Length(first_arg);
THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
THPUtils_typename(first_arg));
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE length);
THPObjectPtr item;
try {
for (Py_ssize_t i = 0; i < length; i++) {
item = PySequence_GetItem(first_arg, i);
real value = THPUtils_(unpackReal)(item.get());
#if !defined(THC_GENERIC_FILE)
self->cdata->data[i] = value;
#else
// TODO: this might be slow - consider batched updates?
THCStorage_(set)(LIBRARY_STATE self->cdata, i, value);
#endif
}
} catch (std::runtime_error &e) {
THPUtils_setError("tried to construct a storage from a sequence (%s), "
"but one of the items was of type %s instead of %s",
THPUtils_typename(first_arg),
THPUtils_typename(item.get()),
THPUtils_typeTraits<real>::python_type_str);
return NULL;
}
return (PyObject*)self.release();
#endif
}
#ifndef THD_GENERIC_FILE
invalid_arguments:
#endif
THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
"no arguments",
"(int size)",
"(Sequence data)",
"(" THPStorageStr " view_source)",
"(" THPStorageStr " view_source, int offset)",
"(" THPStorageStr " view_source, int offset, int size)");
return NULL;
END_HANDLE_TH_ERRORS
}
static Py_ssize_t THPStorage_(length)(THPStorage *self)
{
HANDLE_TH_ERRORS
return THStorage_(size)(LIBRARY_STATE self->cdata);
END_HANDLE_TH_ERRORS_RET(-1)
}
static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
{
HANDLE_TH_ERRORS
/* Integer index */
if (THPUtils_checkLong(index)) {
int64_t nindex = THPUtils_unpackLong(index);
if (nindex < 0)
nindex += THStorage_(size)(LIBRARY_STATE self->cdata);
if (nindex < 0 || nindex >= self->cdata->size) {
PyErr_Format(PyExc_IndexError, "index %" PRId64 " out of range for storage of "
"size %" PRId64, (int64_t) nindex, (int64_t) self->cdata->size);
return NULL;
}
real value = THStorage_(get)(LIBRARY_STATE self->cdata, nindex);
return THPUtils_(newReal)(value);
/* Slice index */
} else if (PySlice_Check(index)) {
#ifdef THD_GENERIC_FILE
THPUtils_setError("distributed storages don't support slicing");
return NULL;
#else
Py_ssize_t start, stop, slicelength, step;
int64_t len = THStorage_(size)(LIBRARY_STATE self->cdata);
if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
return NULL;
if (step != 1) {
THPUtils_setError("Trying to slice with a step of %" PRId64 ", but only a step of "
"1 is supported", (int64_t)step);
return NULL;
}
real *data = THStorage_(data)(LIBRARY_STATE self->cdata);
THStoragePtr new_storage(THStorage_(newWithData)(LIBRARY_STATE data + start, slicelength));
new_storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
new_storage->view = self->cdata;
THStorage_(retain)(LIBRARY_STATE self->cdata);
PyObject *_ret = THPStorage_(New)(new_storage);
new_storage.release();
return _ret;
#endif
}
PyErr_Format(PyExc_TypeError, "can't index a " THPStorageStr " with %s",
THPUtils_typename(index));
return NULL;
END_HANDLE_TH_ERRORS
}
static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
{
HANDLE_TH_ERRORS
if (!THPUtils_(checkReal)(value)) {
THPUtils_setError("can only set storage content with a %s, but got "
"%s instead", THPUtils_typeTraits<real>::python_type_str,
THPUtils_typename(value));
return -1;
}
real rvalue = THPUtils_(unpackReal)(value);
if (THPUtils_checkLong(index)) {
int64_t nindex = THPUtils_unpackLong(index);
THStorage_(set)(LIBRARY_STATE self->cdata, nindex, rvalue);
return 0;
} else if (PySlice_Check(index)) {
Py_ssize_t start, stop, slicelength, step;
int64_t len = THStorage_(size)(LIBRARY_STATE self->cdata);
if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
return -1;
if (step != 1) {
THPUtils_setError("Trying to slice with a step of %" PRId64 ", but only a step of "
"1 is supported", (int64_t)step);
return 0;
}
// TODO: check the bounds only once
// TODO: fill?
for (;start < stop; start++)
THStorage_(set)(LIBRARY_STATE self->cdata, start, rvalue);
return 0;
}
THPUtils_setError("can't index a " THPStorageStr " with %s",
THPUtils_typename(index));
return -1;
END_HANDLE_TH_ERRORS_RET(-1)
}
static PyMappingMethods THPStorage_(mappingmethods) = {
(lenfunc)THPStorage_(length),
(binaryfunc)THPStorage_(get),
(objobjargproc)THPStorage_(set)
};
// TODO: implement equality
PyTypeObject THPStorageType = {
PyVarObject_HEAD_INIT(NULL, 0)
"torch._C." THPStorageBaseStr, /* tp_name */
sizeof(THPStorage), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)THPStorage_(dealloc), /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
&THPStorage_(mappingmethods), /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
NULL, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* will be assigned in init */ /* tp_methods */
0, /* will be assigned in init */ /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
THPStorage_(pynew), /* tp_new */
};
static struct PyMemberDef THPStorage_(members)[] = {
{(char*)"_cdata", T_ULONGLONG, offsetof(THPStorage, cdata), READONLY, NULL},
{NULL}
};
extern THPCopyList THStorage_(copy_functions);
THPCopyList THStorage_(copy_functions);
void THPStorage_(initCopyMethods)()
{
#ifndef THD_GENERIC_FILE
auto& h = THStorage_(copy_functions);
// copy from CPU types
THPInsertStorageCopyFunction(h, &THStorage_(copyByte));
THPInsertStorageCopyFunction(h, &THStorage_(copyChar));
THPInsertStorageCopyFunction(h, &THStorage_(copyShort));
THPInsertStorageCopyFunction(h, &THStorage_(copyInt));
THPInsertStorageCopyFunction(h, &THStorage_(copyLong));
THPInsertStorageCopyFunction(h, &THStorage_(copyHalf));
THPInsertStorageCopyFunction(h, &THStorage_(copyFloat));
THPInsertStorageCopyFunction(h, &THStorage_(copyDouble));
#ifdef THC_GENERIC_FILE
// copy from GPU types
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaByte));
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaChar));
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaShort));
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaInt));
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaLong));
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaFloat));
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaDouble));
#ifdef CUDA_HALF_TENSOR
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaHalf));
#endif
// add CPU <- GPU copies to base type
#define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
extern THPCopyList THCpuStorage_(copy_functions);
auto& b = THCpuStorage_(copy_functions);
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaByte));
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaChar));
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaShort));
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaInt));
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaLong));
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaFloat));
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaDouble));
#ifdef CUDA_HALF_TENSOR
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaHalf));
#endif
#undef THCpuStorage_
#endif
#endif // !defined(THD_GENERIC_FILE)
}
#include "StorageMethods.cpp"
#ifndef THD_GENERIC_FILE
#include "StorageSharing.cpp"
#endif
bool THPStorage_(init)(PyObject *module)
{
static std::vector<PyMethodDef> methods;
THPUtils_addPyMethodDefs(methods, THPStorage_(methods));
#ifndef THD_GENERIC_FILE
THPUtils_addPyMethodDefs(methods, THPStorage_(sharingMethods));
#endif
THPStorageType.tp_methods = methods.data();
THPStorageType.tp_members = THPStorage_(members);
if (PyType_Ready(&THPStorageType) < 0)
return false;
Py_INCREF(&THPStorageType);
PyModule_AddObject(module, THPStorageBaseStr, (PyObject *)&THPStorageType);
THPStorage_(initCopyMethods)();
return true;
}
void THPStorage_(postInit)(PyObject *module)
{
THPStorageClass = PyObject_GetAttrString(module,(char*)TH_CONCAT_STRING_2(Real,Storage));
if (!THPStorageClass) throw python_error();
bool is_cuda = false;
#ifdef THC_GENERIC_FILE
is_cuda = true;
#endif
const char *type_name = TH_CONCAT_STRING_2(Real,);
torch::registerStoragePyTypeObject((PyTypeObject*)THPStorageClass, type_name, is_cuda, false);
}
#endif