#include "python_nn_functions.h"
// ${generated_comment}
#include "torch/csrc/Device.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/structseq.h"
#include "python_nn_functions_dispatch.h"
using at::Tensor;
using at::Scalar;
using at::MemoryFormat;
using namespace torch::autograd::utils;
namespace torch { namespace autograd {
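// Exposed in Python as torch._C._nn._parse_to: normalizes the arguments of
// Tensor.to()/Module.to() into a (device, dtype, non_blocking) tuple, using
// Py_None for any component the caller did not specify.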
static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  // Copying is disallowed here because nn.Module.to never copies.
  auto parsed = parse_to_conversion(args, kwargs, /*allow_copy*/ false);
  auto& device = std::get<0>(parsed);
  auto& scalarType = std::get<1>(parsed);
  auto non_blocking = std::get<2>(parsed);
  // Build the (device, dtype, non_blocking) result tuple; note that
  // PyTuple_SET_ITEM steals the reference to each item it stores.
  auto tuple = THPObjectPtr{PyTuple_New(3)};
  if (!tuple) throw python_error();
  if (device) {
    PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device));
  } else {
    Py_INCREF(Py_None);
    PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
  }
  if (scalarType) {
    PyTuple_SET_ITEM(tuple.get(), 1, torch::autograd::utils::wrap(torch::getDtype(*scalarType)));
  } else {
    Py_INCREF(Py_None);
    PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
  }
  PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
  return tuple.release();
  END_HANDLE_TH_ERRORS
}
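// Python-side usage is roughly (a sketch, not generated code):
//   device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)
// which is how nn.Module.to normalizes its arguments before converting each
// parameter and buffer.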
${py_methods}
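// ${py_methods} above is substituted at build time with the generated
// THPVariable_* NN wrapper functions; the autograd code generator
// (tools/autograd/gen_python_functions.py) fills in this template.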
static PyMethodDef nn_functions[] = {
  {"_parse_to", (PyCFunction)(void(*)(void))THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr},
  ${py_method_defs}
  {NULL}
};
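// ${py_method_defs} expands to one PyMethodDef entry per generated binding;
// an entry would look roughly like this (illustrative, hypothetical name):
//   {"some_nn_op", (PyCFunction)THPVariable_some_nn_op, METH_VARARGS | METH_KEYWORDS, nullptr},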
// Creates the torch._C._nn submodule from nn_functions and attaches it to
// the given torch._C module object.
void initNNFunctions(PyObject* module) {
#if PY_MAJOR_VERSION == 2
  PyObject* nn = Py_InitModule("torch._C._nn", nn_functions);
  Py_XINCREF(nn); // Py_InitModule returns a "borrowed" reference
#else
  static struct PyModuleDef def = {
    PyModuleDef_HEAD_INIT,
    "torch._C._nn",  // module name
    NULL,            // module docstring
    -1,              // per-interpreter state size (-1: state kept in globals)
    nn_functions     // method table
  };
  PyObject* nn = PyModule_Create(&def);
#endif
  if (!nn) {
    throw python_error();
  }
  // PyModule_AddObject steals a reference to nn on success
  if (PyModule_AddObject(module, "_nn", nn) != 0) {
    throw python_error();
  }
}
}} // namespace torch::autograd