blob: bba674420c2f38b06a2c382ea5d5298110ce7663 [file] [log] [blame]
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/python_function.h"
#include "torch/csrc/THP.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/PtrWrapper.h"
#include "torch/csrc/utils/auto_gil.h"
#include <unordered_set>
using namespace torch::autograd;
struct THPEngine {
PyObject_HEAD
};
static torch::autograd::python::PythonEngine engine;
namespace torch { namespace autograd { namespace python {
void PythonEngine::thread_main(std::shared_ptr<ReadyQueue> queue, int device) {
// Create a PyThreadState, but release the GIL. This lets AutoGIL calls
// inside thread_main acquire the GIL without having to create a new
// PyThreadState each time.
AutoGIL gil;
AutoNoGIL no_gil;
Engine::thread_main(queue, device);
}
void PythonEngine::thread_on_exception(FunctionTask& task, std::exception& e) {
auto python_err = dynamic_cast<python_error*>(&e);
if (python_err) {
python_err->persist();
}
Engine::thread_on_exception(task, e);
}
PythonEngine& PythonEngine::getDefaultEngine() {
return engine;
}
}}} // namespace torch::autograd::python
PyObject *THPEngineClass = NULL;
struct CallbackContext {
std::string error;
THPObjectPtr outputs;
// Used to determine which callback arguments should be used to
// fill outputs.
// Function -> ([grad_nr, outputs_idx], is_leaf)
std::unordered_map<
std::shared_ptr<Function>,
std::pair<std::vector<std::pair<int, int>>, bool>> output_map;
};
void compute_partial_exec_callbacks(const function_list& roots,
const CallbackContext& ctx,
Engine::pre_callback_map& map) {
// This callback is used to suppress the computation of a node
// if it is not necessary.
static Engine::pre_callback_type abort_callback(
[](Function* fn, variable_list &vars) { return false; });
std::vector<Function*> queue;
std::unordered_set<Function*> seen; // for the initial DFS
std::unordered_set<Function*> needed; // functions to compute
std::unordered_map<Function*, std::vector<Function*>> rev_graph;
// Reverse the next_fn edges
queue.reserve(roots.size());
for (auto& root : roots) {
auto ptr = root.first.get();
bool unseen;
std::tie(std::ignore, unseen) = seen.insert(ptr);
if (unseen) queue.emplace_back(ptr);
}
while (!queue.empty()) {
auto fn = queue.back(); queue.pop_back();
for (auto& next_fn_pair : fn->next_functions) {
auto next_fn = next_fn_pair.first.get();
if (!next_fn) continue;
rev_graph[next_fn].push_back(fn);
if (seen.insert(next_fn).second) {
queue.push_back(next_fn);
}
}
}
auto all_functions = std::move(seen); // this is cheap and improves readability
// Find all functions we need to compute
queue.clear();
for (auto input_info: ctx.output_map) {
auto input = input_info.first.get();
auto& rev_edges = rev_graph[input];
if (rev_edges.size() == 0) throw std::runtime_error("differentiated input is unreachable");
queue.emplace_back(input);
needed.insert(input);
}
while (!queue.empty()) {
auto fn = queue.back(); queue.pop_back();
for (auto rev_next_fn : rev_graph[fn]) {
if (needed.insert(rev_next_fn).second) {
queue.push_back(rev_next_fn);
}
}
}
// Prevent expansion for functions in {all_vertices} \ {needed}
for (auto fn : all_functions) {
if (needed.count(fn) > 0) continue;
map.emplace(fn, abort_callback);
}
}
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
PyObject *variables = NULL;
PyObject *grad_variables = NULL;
unsigned char keep_graph = 0;
PyObject *inputs = NULL;
unsigned char only_inputs = 0;
const char *accepted_kwargs[] = {"variables", "grad_variables",
"keep_graph", "inputs", "only_inputs", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOb|Ob", (char**)accepted_kwargs,
&variables, &grad_variables, &keep_graph, &inputs, &only_inputs))
return NULL;
THPUtils_assert(PyTuple_Check(variables), "variables argument is expected to "
"be a tuple, but got %s", THPUtils_typename(variables));
THPUtils_assert(PyTuple_Check(grad_variables), "variables argument is "
"expected to be a tuple, but got %s", THPUtils_typename(grad_variables));
Py_ssize_t num_variables = PyTuple_GET_SIZE(variables);
Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_variables);
THPUtils_assert(num_variables == num_gradients, "got %ld variables and %ld "
"gradients", num_variables, num_gradients);
function_list roots(num_variables);
variable_list grads(num_variables);
for (int i = 0; i < num_variables; i++) {
PyObject *_variable = PyTuple_GET_ITEM(variables, i);
THPUtils_assert(THPVariable_Check(_variable), "element %d of variables "
"tuple is not a Variable", i);
auto& variable = ((THPVariable*)_variable)->cdata;
THPUtils_assert(!variable->is_volatile,
"element %d of variables tuple is volatile", i);
// If grad_fn is NULL (as is the case for a leaf node), we instead
// interpret the gradient function to be a grad accumulator,
// which will accumulate its inputs into the grad property of the
// variable. These nodes get suppressed in some situations,
// see "suppress grad accumulation" below.
auto grad_fn = variable->grad_fn ? variable->grad_fn : variable->get_grad_accumulator();
int output_nr = variable->grad_fn ? variable->output_nr : 0;
roots[i] = std::make_pair<>(std::move(grad_fn), output_nr);
PyObject *grad = PyTuple_GET_ITEM(grad_variables, i);
if (THPVariable_Check(grad)) {
grads[i] = ((THPVariable*)grad)->cdata;
} else {
THPUtils_assert(grad == Py_None,
"element %d of gradients tuple is not a Variable or None", i);
THPUtils_assert(!variable->requires_grad,
"element %d of gradients tuple is None, but the corresponding Variable requires grad");
}
}
Engine::pre_callback_map callbacks;
CallbackContext ctx;
if (inputs != NULL) {
THPUtils_assert(PyTuple_Check(inputs), "inputs argument has to be a tuple");
int num_inputs = PyTuple_GET_SIZE(inputs);
ctx.outputs = PyTuple_New(num_inputs);
// First, find all relevant functions and fill ctx.output_map
for (int i = 0; i < num_inputs; ++i) {
PyObject *input = PyTuple_GET_ITEM(inputs, i);
THPUtils_assert(THPVariable_Check(input),
"all inputs have to be Variables, but got %s", THPUtils_typename(input));
THPVariable *input_var = (THPVariable*)input;
auto grad_fn = input_var->cdata->grad_fn;
int output_nr = input_var->cdata->output_nr;
bool is_leaf = !grad_fn;
if (is_leaf) {
grad_fn = input_var->cdata->grad_accumulator.lock();
}
THPUtils_assert(grad_fn, "One of the differentiated Variables appears to not have "
"been used in the graph");
auto& fn_info = ctx.output_map[grad_fn];
fn_info.first.emplace_back(output_nr, i);
fn_info.second = is_leaf;
}
// Register callbacks that will gather the outputs
for (auto& entry : ctx.output_map) {
auto& fn_info = entry.second;
callbacks.emplace(entry.first.get(), [&ctx, &fn_info](Function* _unused, variable_list& grads) {
auto& saved_outputs = fn_info.first;
bool is_leaf = fn_info.second;
AutoGIL gil;
for (auto& saved_out : saved_outputs) {
PyTuple_SET_ITEM(ctx.outputs.get(), saved_out.second,
THPVariable_Wrap(grads[saved_out.first]));
}
// Suppress grad accumulation.
// If the variable is a leaf, the next function to execute
// is a grad_accumulator. But when inputs != NULL, we should
// NOT accumulate, so terminate execution.
return !is_leaf;
});
}
// Disable execution for all unneeded functions
if (only_inputs) {
compute_partial_exec_callbacks(roots, ctx, callbacks);
}
}
try {
AutoNoGIL no_gil;
engine.execute(roots, grads, keep_graph, callbacks);
} catch (python_error &e) {
e.restore();
return nullptr;
}
if (ctx.outputs) {
return ctx.outputs.release();
} else {
Py_RETURN_NONE;
}
END_HANDLE_TH_ERRORS
}
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
Py_INCREF(_callback);
engine.queue_callback([callback]() {
AutoGIL gil;
THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), NULL)};
if (!result) throw python_error();
});
Py_RETURN_NONE;
}
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
return type->tp_alloc(type, 0);
}
static struct PyMethodDef THPEngine_methods[] = {
{(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, NULL},
{(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, NULL},
{NULL}
};
PyTypeObject THPEngineType = {
PyVarObject_HEAD_INIT(NULL, 0)
"torch._C._EngineBase", /* tp_name */
sizeof(THPEngine), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
NULL, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
THPEngine_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
THPEngine_new /* tp_new */
};
bool THPEngine_initModule(PyObject *module)
{
if (PyType_Ready(&THPEngineType) < 0)
return false;
Py_INCREF(&THPEngineType);
PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType);
return true;
}