#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/PtrWrapper.h"
#include "torch/csrc/THP.h"
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/python_function.h"
#include "torch/csrc/utils/auto_gil.h"
#ifndef _WIN32
#include <pthread.h>
#endif
#include <new> // for placement new in _maybe_reinitialize_engine_after_fork
#include <unordered_set>
using namespace torch::autograd;
struct THPEngine {
PyObject_HEAD
};
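// A single process-wide engine instance. It is handed out through
// get_python_engine() below, which is registered with the core autograd
// library via set_default_engine_stub() in THPEngine_initModule(), so
// backward passes started from C++ also go through the Python-aware engine.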
static torch::autograd::python::PythonEngine engine;
static Engine& get_python_engine() {
return engine;
}
namespace torch { namespace autograd { namespace python {
void PythonEngine::thread_init(int device) {
// Create a PyThreadState, but release the GIL. This lets AutoGIL calls
// inside thread_main acquire the GIL without having to create a new
// PyThreadState each time.
AutoGIL gil;
AutoNoGIL no_gil;
Engine::thread_init(device);
}
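// When a task fails with a Python exception, persist() captures the current
// Python error state so it survives beyond the worker thread that raised it;
// execute() below restores it on the calling thread before re-throwing.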
void PythonEngine::thread_on_exception(FunctionTask& task, std::exception& e) {
auto python_err = dynamic_cast<python_error*>(&e);
if (python_err) {
python_err->persist();
}
Engine::thread_on_exception(task, e);
}
variable_list PythonEngine::execute(
const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
const edge_list& outputs) {
try {
return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
} catch (python_error& e) {
e.restore();
throw;
}
}
}}} // namespace torch::autograd::python
PyObject *THPEngineClass = nullptr;
static bool _reinitialize_engine = false;
static void _maybe_reinitialize_engine_after_fork() {
// This is "probably" thread-safe because the flag is set in a fork handler
// before any threads are created, and this function is only called with the
// GIL held. However, using fork + threads is playing with fire so this is
// more of a "best effort" thing. For example, if the fork occurs while the
// backwards threads hold a lock, we'll probably deadlock in the engine
// destructor.
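// The reset below uses the destroy-then-placement-new idiom: explicitly run
// the destructor, then construct a fresh PythonEngine in the same storage.
// This keeps the static `engine` object at the same address, so the reference
// returned by get_python_engine() stays valid across the reset.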
if (_reinitialize_engine) {
engine.~PythonEngine();
new (&engine) torch::autograd::python::PythonEngine();
_reinitialize_engine = false;
}
}
// Implementation of torch._C._EngineBase.run_backward
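// From Python this is reached via torch.autograd.backward() and
// torch.autograd.grad(), which call (roughly -- a sketch of the Python-side
// caller, not its exact code):
//   Variable._execution_engine.run_backward(
//       tensors, grad_tensors, retain_graph, create_graph, inputs,
//       allow_unused)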
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
_maybe_reinitialize_engine_after_fork();
PyObject *tensors = nullptr;
PyObject *grad_tensors = nullptr;
unsigned char keep_graph = 0;
unsigned char create_graph = 0;
PyObject *inputs = nullptr;
unsigned char allow_unreachable = 0;
const char *accepted_kwargs[] = {
"tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
"allow_unreachable", nullptr
};
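// Format "OObb|Ob": two PyObject* arguments (tensors, grad_tensors), two
// booleans parsed as unsigned char ('b'), then an optional PyObject* (inputs)
// and an optional flag (allow_unreachable). 'b' is also why the flags above
// are declared unsigned char rather than bool.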
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
&tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
return nullptr;
THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
"be a tuple, but got %s", THPUtils_typename(tensors));
THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
"expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));
Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
"gradients", (long)num_tensors, (long)num_gradients);
edge_list roots;
roots.reserve(num_tensors);
variable_list grads;
grads.reserve(num_tensors);
for (int i = 0; i < num_tensors; i++) {
PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
"tuple is not a Tensor", i);
auto& variable = ((THPVariable*)_tensor)->cdata;
auto gradient_edge = variable.gradient_edge();
THPUtils_assert(gradient_edge.function,
"element %d of tensors does not require grad and does not have a grad_fn", i);
roots.push_back(std::move(gradient_edge));
PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
if (THPVariable_Check(grad)) {
grads.push_back(((THPVariable*)grad)->cdata);
} else {
THPUtils_assert(grad == Py_None,
"element %d of gradients tuple is not a Tensor or None", i);
THPUtils_assert(!variable.requires_grad(),
"element %d of gradients tuple is None, but the corresponding Tensor requires grad", i);
}
}
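// If explicit inputs were given, build one Edge per input identifying where
// in the graph its gradient is produced: the function that computes it,
// paired with the input's index among that function's outputs. For leaf
// Variables the "function" is the gradient accumulator. An invalid
// (default-constructed) Edge marks an input whose gradient is unreachable;
// the engine leaves the corresponding output undefined, which is diagnosed
// below unless allow_unreachable is set.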
std::vector<Edge> output_edges;
if (inputs != nullptr) {
THPUtils_assert(PyTuple_Check(inputs), "inputs argument is expected to "
"be a tuple, but got %s", THPUtils_typename(inputs));
int num_inputs = PyTuple_GET_SIZE(inputs);
output_edges.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
PyObject *input = PyTuple_GET_ITEM(inputs, i);
THPUtils_assert(THPVariable_Check(input),
"all inputs have to be Tensors, but got %s", THPUtils_typename(input));
THPVariable *input_var = (THPVariable*)input;
const auto output_nr = input_var->cdata.output_nr();
auto grad_fn = input_var->cdata.grad_fn();
if (!grad_fn) {
grad_fn = input_var->cdata.try_get_grad_accumulator();
}
THPUtils_assert(input_var->cdata.requires_grad(),
"One of the differentiated Tensors does not require grad");
if (!grad_fn) {
output_edges.emplace_back();
} else {
output_edges.emplace_back(grad_fn, output_nr);
}
}
}
variable_list outputs;
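// Release the GIL for the duration of the backward pass: engine worker
// threads may need to acquire it themselves (e.g. to run Python-defined
// autograd Functions), and holding it here would deadlock them.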
{
AutoNoGIL no_gil;
outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
}
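// With explicit inputs (the torch.autograd.grad() path), return the computed
// gradients as a tuple; otherwise (the torch.autograd.backward() path) the
// gradients have been accumulated into the leaves' .grad attributes and
// there is nothing to return.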
if (inputs != nullptr) {
int num_inputs = PyTuple_GET_SIZE(inputs);
THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
if (!py_outputs) return nullptr;
for (int i = 0; i < num_inputs; i++) {
THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
"differentiated Tensors appears to not have been used "
"in the graph. Set allow_unused=True if this is the "
"desired behavior.");
PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
}
return py_outputs.release();
} else {
Py_RETURN_NONE;
}
END_HANDLE_TH_ERRORS
}
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
HANDLE_TH_ERRORS
_maybe_reinitialize_engine_after_fork();
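// Take out a reference on the callback before handing it to a shared_ptr
// whose deleter acquires the GIL before Py_DECREF: the final reference may
// be dropped on an engine worker thread, where the GIL is not held.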
Py_INCREF(_callback);
std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
engine.queue_callback([callback]() {
AutoGIL gil;
THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
if (!result) throw python_error();
});
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THPEngine_is_checkpoint_valid(PyObject *self) {
HANDLE_TH_ERRORS
if (engine.is_checkpoint_valid()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
return type->tp_alloc(type, 0);
}
static struct PyMethodDef THPEngine_methods[] = {
{(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
{(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
{(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr},
{nullptr}
};
PyTypeObject THPEngineType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._EngineBase", /* tp_name */
sizeof(THPEngine), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
nullptr, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
THPEngine_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
THPEngine_new /* tp_new */
};
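// Registered with pthread_atfork() below; runs in the child process right
// after fork() and flags the engine for lazy reinitialization the next time
// a Python entry point runs (_maybe_reinitialize_engine_after_fork above).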
static void child_atfork() {
_reinitialize_engine = true;
}
bool THPEngine_initModule(PyObject *module)
{
#ifndef _WIN32
if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
throw std::runtime_error("unable to set pthread_atfork handler");
}
#endif
if (PyType_Ready(&THPEngineType) < 0)
return false;
Py_INCREF(&THPEngineType);
PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType);
set_default_engine_stub(get_python_engine);
return true;
}