// blob: 4aec2b35384f9f33e50ad8dbb875eb4f3ff3dab2 [file] [log] [blame]
#pragma once
// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).
#include <Python.h>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/input_buffer.h"
namespace torch { namespace autograd {
// Forward declarations only: Engine below refers to these types by
// reference, by pointer, or inside std::shared_ptr, so their full
// definitions are not needed in this header.
struct FunctionTask;
struct GraphTask;
struct ReadyQueue;
// Only a single instance of this struct should be created for the whole
// process lifetime. The worker thread creation logic and Engine's destructor
// rely on this.
struct Engine {
  Engine();
  virtual ~Engine();

  // Copying is disabled explicitly: the struct is meant to have exactly one
  // instance, and its std::once_flag / std::mutex members are not copyable
  // anyway, so this only makes the existing restriction visible.
  Engine(const Engine&) = delete;
  Engine& operator=(const Engine&) = delete;

  // Sequence of (function, input buffer) pairs.
  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
  // Integer counter associated with each function, keyed by raw pointer.
  using dependencies_type = std::unordered_map<Function*, int>;
  // Callback receiving a function and a mutable variable list.
  // NOTE(review): the meaning of the returned bool is not visible in this
  // header - confirm against the implementation.
  using pre_callback_type = std::function<bool (Function*, variable_list&)>;
  // Multimap, so several pre-callbacks may be attached to the same function.
  using pre_callback_map = std::unordered_multimap<Function*, pre_callback_type>;
  // Callback receiving a function and two mutable variable lists.
  // NOTE(review): presumably the function's inputs and outputs, and the bool
  // meaning is likewise not visible here - confirm against the implementation.
  using post_callback_type = std::function<bool (Function*, variable_list&, variable_list&)>;
  // Multimap, so several post-callbacks may be attached to the same function.
  using post_callback_map = std::unordered_multimap<Function*, post_callback_type>;

  // Given a list of (Function, input number) pairs computes the value of the
  // graph by following next_function references.
  //
  // Both callback maps default to empty maps. Because this function is
  // virtual, the defaults are chosen by the static type used at the call
  // site; overrides should repeat the same defaults.
  virtual void execute(
      const function_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      const pre_callback_map& pre_callbacks = pre_callback_map(),
      const post_callback_map& post_callbacks = post_callback_map());

  // Registers a callback taking no arguments and returning nothing.
  // NOTE(review): presumably stored in final_callbacks and invoked when
  // execution finishes - confirm against the implementation.
  void queue_callback(std::function<void()> callback);

protected:
  void compute_dependencies(Function* root, GraphTask& task);
  void evaluate_function(FunctionTask& task);
  // Returns the ready queue associated with the given device index.
  ReadyQueue& ready_queue(int device);
  void start_threads();

  // Customisation points for derived engines.
  virtual void thread_init(int device);
  virtual void thread_main(GraphTask *task);
  virtual void thread_on_exception(FunctionTask& task, std::exception& e);

  // NOTE(review): presumably used with std::call_once so start_threads()
  // runs only once - confirm against the implementation.
  std::once_flag start_threads_flag;
  std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
  std::vector<std::function<void()>> final_callbacks;
  // NOTE(review): despite the name, this presumably guards final_callbacks -
  // confirm against the implementation.
  std::mutex post_callbacks_lock;
};
}} // namespace torch::autograd