| #include "python_nn_functions.h" |
| |
| // ${generated_comment} |
| |
| |
| #include "torch/csrc/Device.h" |
| #include "torch/csrc/DynamicTypes.h" |
| #include "torch/csrc/Exceptions.h" |
| #include "torch/csrc/autograd/python_variable.h" |
| #include "torch/csrc/autograd/utils/wrap_outputs.h" |
| #include "torch/csrc/autograd/utils/python_arg_parsing.h" |
| #include "torch/csrc/utils/python_arg_parser.h" |
| #include "torch/csrc/utils/structseq.h" |
| |
| #include "python_nn_functions_dispatch.h" |
| |
| using at::Tensor; |
| using at::Scalar; |
| using at::MemoryFormat; |
| using namespace torch::autograd::utils; |
| |
| namespace torch { namespace autograd { |
| |
| static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs) |
| { |
| HANDLE_TH_ERRORS |
| auto parsed = parse_to_conversion(args, kwargs, /*allow_copy*/ false); // we don't want copy for nn.Module.to |
| auto& device = std::get<0>(parsed); |
| auto& scalarType = std::get<1>(parsed); |
| auto non_blocking = std::get<2>(parsed); |
| auto tuple = THPObjectPtr{PyTuple_New(3)}; |
| if (!tuple) throw python_error(); |
| if (device) { |
| PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device)); |
| } else { |
| Py_INCREF(Py_None); |
| PyTuple_SET_ITEM(tuple.get(), 0, Py_None); |
| } |
| if (scalarType) { |
| PyTuple_SET_ITEM(tuple.get(), 1, torch::autograd::utils::wrap(torch::getDtype(*scalarType))); |
| } else { |
| Py_INCREF(Py_None); |
| PyTuple_SET_ITEM(tuple.get(), 1, Py_None); |
| } |
| PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking)); |
| return tuple.release(); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| ${py_methods} |
| |
| static PyMethodDef nn_functions[] = { |
| {"_parse_to", (PyCFunction)(void(*)(void))THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr}, |
| ${py_method_defs} |
| {NULL} |
| }; |
| |
| void initNNFunctions(PyObject* module) { |
| #if PY_MAJOR_VERSION == 2 |
| PyObject* nn = Py_InitModule("torch._C._nn", nn_functions); |
| Py_XINCREF(nn); // Py_InitModule returns "borrowed" reference |
| #else |
| static struct PyModuleDef def = { |
| PyModuleDef_HEAD_INIT, |
| "torch._C._nn", |
| NULL, |
| -1, |
| nn_functions |
| }; |
| PyObject* nn = PyModule_Create(&def); |
| #endif |
| if (!nn) { |
| throw python_error(); |
| } |
| // steals a reference to nn |
| if (PyModule_AddObject(module, "_nn", nn) != 0) { |
| throw python_error(); |
| } |
| } |
| |
| }} // namespace torch::autograd |