#include "torch/csrc/utils/pybind.h"
#include "batch_normalization.h"
#include "convolution.h"
#include "accumulate_grad.h"
#include "basic_ops.h"
#include "tensor.h"
#include "special.h"
#include "jit_closure.h"
#include "torch/csrc/THP.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/utils/tuple_parser.h"
#include "torch/csrc/DynamicTypes.h"

namespace pybind11 { namespace detail {

// handle Python <-> torch::autograd::Function conversions
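// (load() unwraps a THPFunction into its shared_ptr<Function>; cast() wraps
// a Function back into a Python object via functionToPyObject. This
// presumably exists so pybind11-bound signatures, such as the
// AutogradClosureFactory bindings at the end of this file, can traffic in
// Function pointers directly.)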
template <> struct type_caster<std::shared_ptr<torch::autograd::Function>> {
public:
  PYBIND11_TYPE_CASTER(std::shared_ptr<torch::autograd::Function>, _("std::shared_ptr<torch::autograd::Function>"));

  bool load(handle src, bool) {
    if (!THPFunction_Check(src.ptr())) return false;
    value = THPFunction_asFunction((THPFunction*)src.ptr());
    return true;
  }

  static handle cast(std::shared_ptr<torch::autograd::Function> src, return_value_policy /* policy */, handle /* parent */) {
    auto fn = functionToPyObject(src);
    return handle(fn);
  }
};

}} // namespace pybind11::detail

using namespace torch::autograd;
using torch::TupleParser;
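
// Constructor functors. Each unpacks a Python argument tuple into the
// matching params struct and heap-allocates the corresponding forward
// function; addClass() below passes the functor type on to
// createForwardFunctionPyTypeObject, which presumably installs it as the
// Python-level constructor.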
struct BatchNormCtor {
  BatchNormForward* operator()(PyObject* args) {
    BatchNormParams params;

    TupleParser parser(args, 6);
    parser.parse(params.running_mean, "running_mean");
    parser.parse(params.running_var, "running_var");
    parser.parse(params.training, "training");
    parser.parse(params.momentum, "momentum");
    parser.parse(params.eps, "eps");
    parser.parse(params.cudnn_enabled, "cudnn_enabled");

    return new BatchNormForward(std::move(params));
  }
};

struct ConvCtor {
  ConvForward* operator()(PyObject* args) {
    ConvParams params;

    TupleParser parser(args, 8);
    parser.parse(params.stride, "stride");
    parser.parse(params.padding, "padding");
    parser.parse(params.dilation, "dilation");
    parser.parse(params.transposed, "transposed");
    parser.parse(params.output_padding, "output_padding");
    parser.parse(params.groups, "groups");
    parser.parse(params.benchmark, "benchmark");
    parser.parse(params.cudnn_enabled, "cudnn_enabled");

    return new ConvForward(std::move(params));
  }
};

struct DelayedErrorCtor {
  DelayedError* operator()(PyObject* args) {
    std::string msg;

    TupleParser parser(args, 1);
    parser.parse(msg, "msg");

    return new DelayedError(msg);
  }
};

struct NoCtor {
  Function* operator()(PyObject* args) {
    throw std::runtime_error("Cannot construct");
  }
};
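
// Registers a Python type named `name` for the C++ function class C, using
// the functor T to construct instances, and records the typeid(C) -> type
// mapping so instances created from C++ can be handed back to Python.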
template<typename C, typename T>
static void addClass(PyObject* module, PyTypeObject& type, const char* name,
                     PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL)
{
  createForwardFunctionPyTypeObject<T>(type, name, function_properties, function_methods);
  Py_INCREF(&type);
  PyModule_AddObject(module, name, (PyObject*)&type);
  registerCppFunction(typeid(C), &type);
}
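
// Generic attribute getters used by the property tables below. getTupleAttr
// converts a std::vector params field into a Python tuple, getValueAttr
// converts a scalar field with the supplied Convert function, and
// getTensorAttr wraps an at::Tensor field (returning None if it is
// undefined).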
template<typename T, typename ValueT, typename ParamsT, ValueT ParamsT::*ptr,
         typename ConvertArgT, PyObject* (*Convert)(ConvertArgT)>
PyObject* getTupleAttr(PyObject* obj, void* _unused)
{
  HANDLE_TH_ERRORS
  THPCppFunction* self = (THPCppFunction*)obj;
  auto& arr = ((T*)(self->cdata.get()))->*ptr;
  auto num_elems = arr.size();
  THPObjectPtr py_tuple(PyTuple_New(num_elems));
  if (!py_tuple) return NULL;
  for (size_t i = 0; i < num_elems; ++i) {
    PyTuple_SET_ITEM(py_tuple.get(), i, Convert(arr[i]));
  }
  return py_tuple.release();
  END_HANDLE_TH_ERRORS
}

template<typename T, typename ValueT, typename ParamsT, ValueT ParamsT::*ptr,
         typename ConvertArgT, PyObject* (*Convert)(ConvertArgT)>
PyObject* getValueAttr(PyObject* obj, void* _unused)
{
  HANDLE_TH_ERRORS
  THPCppFunction* self = (THPCppFunction*)obj;
  auto& val = ((T*)(self->cdata.get()))->*ptr;
  return Convert(val);
  END_HANDLE_TH_ERRORS
}

template<typename T, typename ParamsT, at::Tensor ParamsT::*ptr>
PyObject* getTensorAttr(PyObject* obj, void* _unused)
{
  HANDLE_TH_ERRORS
  THPCppFunction* self = (THPCppFunction*)obj;
  auto& val = ((T*)(self->cdata.get()))->*ptr;
  THPObjectPtr py_tensor;
  if (!val.defined()) {
    Py_INCREF(Py_None);
    py_tensor = Py_None;
  } else {
    py_tensor = torch::createPyObject(val);
  }
  return py_tensor.release();
  END_HANDLE_TH_ERRORS
}
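
// Property tables exposing each function's params as read-only Python
// attributes (e.g. `stride` on a ConvNd grad_fn). The Forward / Backward /
// BackwardBackward variants differ only in the class they read from.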
static struct PyGetSetDef batch_norm_forward_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  {(char*)"running_mean", (getter)getTensorAttr<BatchNormForward, BatchNormParams,
      &BatchNormParams::running_mean>, NULL, NULL, NULL},
  {(char*)"running_var", (getter)getTensorAttr<BatchNormForward, BatchNormParams,
      &BatchNormParams::running_var>, NULL, NULL, NULL},
  {(char*)"training", (getter)getValueAttr<BatchNormForward, bool, BatchNormParams,
      &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL},
  {(char*)"momentum", (getter)getValueAttr<BatchNormForward, double, BatchNormParams,
      &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL},
  {(char*)"eps", (getter)getValueAttr<BatchNormForward, double, BatchNormParams,
      &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL},
  {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormForward, bool, BatchNormParams,
      &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL},
  {NULL}
};

static struct PyGetSetDef batch_norm_backward_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  {(char*)"running_mean", (getter)getTensorAttr<BatchNormBackward, BatchNormParams,
      &BatchNormParams::running_mean>, NULL, NULL, NULL},
  {(char*)"running_var", (getter)getTensorAttr<BatchNormBackward, BatchNormParams,
      &BatchNormParams::running_var>, NULL, NULL, NULL},
  {(char*)"training", (getter)getValueAttr<BatchNormBackward, bool, BatchNormParams,
      &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL},
  {(char*)"momentum", (getter)getValueAttr<BatchNormBackward, double, BatchNormParams,
      &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL},
  {(char*)"eps", (getter)getValueAttr<BatchNormBackward, double, BatchNormParams,
      &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL},
  {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormBackward, bool, BatchNormParams,
      &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL},
  {NULL}
};

static struct PyGetSetDef batch_norm_backward_backward_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  {(char*)"running_mean", (getter)getTensorAttr<BatchNormBackwardBackward, BatchNormParams,
      &BatchNormParams::running_mean>, NULL, NULL, NULL},
  {(char*)"running_var", (getter)getTensorAttr<BatchNormBackwardBackward, BatchNormParams,
      &BatchNormParams::running_var>, NULL, NULL, NULL},
  {(char*)"training", (getter)getValueAttr<BatchNormBackwardBackward, bool, BatchNormParams,
      &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL},
  {(char*)"momentum", (getter)getValueAttr<BatchNormBackwardBackward, double, BatchNormParams,
      &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL},
  {(char*)"eps", (getter)getValueAttr<BatchNormBackwardBackward, double, BatchNormParams,
      &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL},
  {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormBackwardBackward, bool, BatchNormParams,
      &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL},
  {NULL}
};

static struct PyGetSetDef conv_forward_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  {(char*)"stride", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
      &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"padding", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
      &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"dilation", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
      &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"transposed", (getter)getValueAttr<ConvForward, bool, ConvParams,
      &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL},
  {(char*)"output_padding", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
      &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"groups", (getter)getValueAttr<ConvForward, int, ConvParams,
      &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL},
  {NULL}
};

static struct PyGetSetDef conv_backward_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  {(char*)"stride", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
      &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"padding", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
      &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"dilation", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
      &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"transposed", (getter)getValueAttr<ConvBackward, bool, ConvParams,
      &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL},
  {(char*)"output_padding", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
      &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"groups", (getter)getValueAttr<ConvBackward, int, ConvParams,
      &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL},
  {NULL}
};

static struct PyGetSetDef conv_backward_backward_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  {(char*)"stride", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
      &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"padding", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
      &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"dilation", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
      &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"transposed", (getter)getValueAttr<ConvBackwardBackward, bool, ConvParams,
      &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL},
  {(char*)"output_padding", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
      &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL},
  {(char*)"groups", (getter)getValueAttr<ConvBackwardBackward, int, ConvParams,
      &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL},
  {NULL}
};
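
// AccumulateGrad holds only a weak reference to its variable, so this
// getter returns None once the variable has been destroyed.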
static PyObject* accumulateGradVar(PyObject *_self, void* _unused)
{
  THPCppFunction* self = (THPCppFunction*)_self;
  auto grad_acc = (AccumulateGrad*)self->cdata.get();
  auto var = grad_acc->variable.lock();
  if (!var) Py_RETURN_NONE;
  return THPVariable_Wrap(var);
}

static struct PyGetSetDef accumulate_grad_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  {(char*)"variable", accumulateGradVar, NULL, NULL, NULL},
  {NULL}
};
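
// Creates the torch._C._functions module, registers all of the types above
// in it, and attaches it to torch._C; afterwards the classes are reachable
// from Python, e.g. as torch._C._functions.ConvNd.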
bool THPAutograd_initFunctions(PyObject* _unused)
{
  THPObjectPtr module(PyModule_New("torch._C._functions"));
  if (!module) return false;

  static PyTypeObject BatchNormClass, BatchNormBackwardClass, BatchNormBackwardBackwardClass;
  addClass<BatchNormForward, BatchNormCtor>(module, BatchNormClass, "BatchNorm", batch_norm_forward_properties);
  addClass<BatchNormBackward, NoCtor>(module, BatchNormBackwardClass, "BatchNormBackward", batch_norm_backward_properties);
  addClass<BatchNormBackwardBackward, NoCtor>(module, BatchNormBackwardBackwardClass, "BatchNormBackwardBackward", batch_norm_backward_backward_properties);

  static PyTypeObject ConvClass, ConvBackwardClass, ConvBackwardBackwardClass;
  addClass<ConvForward, ConvCtor>(module, ConvClass, "ConvNd", conv_forward_properties);
  addClass<ConvBackward, NoCtor>(module, ConvBackwardClass, "ConvNdBackward", conv_backward_properties);
  addClass<ConvBackwardBackward, NoCtor>(module, ConvBackwardBackwardClass, "ConvNdBackwardBackward", conv_backward_backward_properties);

  static PyTypeObject AccumulateGradClass;
  addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties);

  static PyTypeObject AddClass, AddBackwardClass;
  addClass<Add, NoCtor>(module, AddClass, "Add");
  addClass<AddBackward, NoCtor>(module, AddBackwardClass, "AddBackward");

  static PyTypeObject ErrorClass;
  addClass<Error, NoCtor>(module, ErrorClass, "Error");

  static PyTypeObject DelayedErrorClass;
  addClass<DelayedError, DelayedErrorCtor>(module, DelayedErrorClass, "DelayedError");

  static PyTypeObject CloneClass;
  addClass<Clone, NoCtor>(module, CloneClass, "Clone");

  static PyTypeObject ContiguousClass;
  addClass<Contiguous, NoCtor>(module, ContiguousClass, "Contiguous");

  static PyTypeObject IdentityClass;
  addClass<Identity, NoCtor>(module, IdentityClass, "Identity");

  static PyTypeObject TransposeClass;
  addClass<Transpose, NoCtor>(module, TransposeClass, "Transpose");

  static PyTypeObject ViewClass;
  addClass<View, NoCtor>(module, ViewClass, "View");

  static PyTypeObject ExpandClass;
  addClass<Expand, NoCtor>(module, ExpandClass, "Expand");

  static PyTypeObject NarrowClass;
  addClass<Narrow, NoCtor>(module, NarrowClass, "Narrow");

  static PyTypeObject CatClass;
  addClass<Cat, NoCtor>(module, CatClass, "Cat");

  static PyTypeObject EvalClass;
  addClass<Eval, NoCtor>(module, EvalClass, "Eval");

  static PyTypeObject AutogradClosureClass;
  addClass<AutogradClosure, NoCtor>(module, AutogradClosureClass, "AutogradClosure");

  THPObjectPtr parent(PyImport_ImportModule("torch._C"));
  if (!parent) return false;
  PyModule_AddObject(parent.get(), "_functions", module.release());
  return true;
}

namespace torch { namespace autograd {
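
// pybind11 bindings for the JIT's closure factory: AutogradClosureFactory
// objects become callable from Python, and _jit_createAutogradClosure
// builds one from a TracingState.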
void initAutogradClosureBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<AutogradClosureFactory, std::shared_ptr<AutogradClosureFactory>>(m, "AutogradClosureFactory")
    .def("__call__", &AutogradClosureFactory::construct)
    ;

  m.def("_jit_createAutogradClosure", [](jit::tracer::TracingState* tracing_state) {
    return std::make_shared<AutogradClosureFactory>(tracing_state);
  });
}

}} // namespace torch::autograd