| #include "torch/csrc/autograd/engine.h" |
| |
| #include "torch/csrc/autograd/function.h" |
| #include "torch/csrc/autograd/functions/basic_ops.h" |
| #include "torch/csrc/autograd/grad_mode.h" |
| #include "torch/csrc/autograd/variable.h" |
| #include "torch/csrc/utils/auto_gpu.h" |
| |
| #include <algorithm> |
| #include <atomic> |
| #include <condition_variable> |
| #include <cstdint> |
| #include <exception> |
| #include <functional> |
| #include <iostream> |
| #include <memory> |
| #include <mutex> |
| #include <queue> |
| #include <set> |
| #include <sstream> |
| #include <string> |
| #include <thread> |
| #include <typeinfo> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| #include <TH/TH.h> |
| |
| #ifdef WITH_CUDA |
| #include <cuda.h> |
| #include <THC/THC.h> |
| #endif |
| |
| namespace torch { namespace autograd { |
| |
| // NB: -1 indicates the CPU worker! |
| static constexpr int NO_DEVICE = -2; |
| |
| // Threads spawned by the engine are assigned a constant 'worker_device' |
| // specifying what device they process work for. This variable is initialized |
| // at thread creation time and is constant afterwards. This is used when |
| // handling reentrant backwards calls; see Note [Reentrant backwards] |
| static thread_local int worker_device = NO_DEVICE; |
| |
| // This variable is true if ALL invocations in the stack of re-entrant engine |
| // invocations are imperative backwards. This special variable is needed for the |
| // gradient checkpointing feature only. |
| static thread_local bool checkpoint_valid = true; |
| |
| // XXX: Changes to the way multithreading works in execute should be done with |
| // great care. Right now the implementation guarantees that a single function's |
| // apply will never be entered concurrently (even if multiple graphs are |
| // executed at the same time). Adding multiple threads per-device or removing |
| // engine thread affinity to the device can break this invariant, and we depend |
| // on it in a few places (e.g. AccumulateGrad function). |
| |
| struct FunctionTask { |
| GraphTask* base; |
| std::shared_ptr<Function> fn; |
| // This buffer serves as an implicit "addition" node for all of the |
| // gradients flowing here. Once all the dependencies are finished, we |
| // use the contents of this buffer to run the function. |
| InputBuffer inputs; |
| |
| FunctionTask(GraphTask* base, std::shared_ptr<Function> fn, InputBuffer inputs) |
| : base(base) |
| , fn(fn) |
| , inputs(std::move(inputs)) {} |
| }; |
| |
| // Returns true when t2 should be (weakly) BEFORE t1 in the queue. |
| // std::priority_queue pops its largest element first, so tasks with a higher |
| // sequence_nr (functions created later in the forward pass) are run first. |
| struct CompareFunctionTaskTime { |
| bool operator()(FunctionTask const & t1, FunctionTask const & t2) { |
| return t1.fn->sequence_nr() < t2.fn->sequence_nr(); |
| } |
| }; |
| |
| struct ReadyQueue { |
| std::priority_queue<FunctionTask, std::vector<FunctionTask>, CompareFunctionTaskTime> heap; |
| std::condition_variable not_empty; |
| std::mutex mutex; |
| |
| void push(FunctionTask item); |
| FunctionTask pop(); |
| }; |
| |
| // Note [Reentrant backwards] |
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // To understand the reentrant backwards problem, we have to notice two |
| // aspects of how the autograd engine is implemented today: |
| // |
| // 1. When you call Engine::execute(), you want to block until |
| // differentiation finishes so that you can get the final result variables |
| // of the backwards pass. |
| // |
| // 2. The engine operates by having a single worker thread per work queue, |
| // and every work queue is pinned to a specific device where the |
| // operation is executed. |
| // |
| // The problem is, suppose that you call backward() inside of a worker |
| // thread. By property (1), we're supposed to block until the nested task |
| // finishes. However, by property (2), this worker thread is on the |
| // hook for processing the tasks assigned to it; we better not block, |
| // because then all of our backward executions (including the one we |
| // just started) will deadlock! |
| // |
| // Here's our cunning idea: instead of blocking, just get back to work |
| // on whatever task queue you should have been working on previously |
| // (this is saved via the thread local variable worker_device)! There are |
| // "simply" two things you have to arrange for: |
| // |
| // - We have to promptly kick ourselves out of the thread_main() loop |
| //    when our graph_task completes, because we need to unblock the |
| // parent function tasks that started the reentrant execution in |
| // the first place. This is why thread_main() takes an optional |
| // graph_task as input. |
| // |
| // - When we finish a GraphTask, we have to make sure we wake up the worker |
| // thread so that it actually has a chance to exit the thread_main() |
| // loop. Thus the faffing about in thread_main() after |
| // evaluate_function() completes. |
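| // |
| // As an illustration only (a hypothetical sketch, not code from this file): |
| // a reentrant call arises whenever code running inside a backward function |
| // invokes the engine again, e.g. a pre/post hook or a custom Function whose |
| // apply() does something like |
| // |
| //   Engine::get_default_engine().execute(roots, grads, |
| //                                         /*keep_graph=*/false, |
| //                                         /*create_graph=*/false, |
| //                                         /*outputs=*/{}); |
| // |
| // Since that call happens on a worker thread (worker_device != NO_DEVICE), |
| // execute() takes the thread_main(&graph_task) branch below instead of |
| // blocking on not_done, which is exactly the scheme described above. |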
| |
| |
| // GraphTask holds metadata needed for a single execution of backward() |
| struct GraphTask { |
| std::exception_ptr exception; |
| // Indicates if an error occurred while executing any task. When this is |
| // true, it signals all threads to stop executing. |
| std::atomic_bool has_error; |
| std::atomic<uint64_t> outstanding_tasks; |
| bool keep_graph; |
| bool grad_mode; |
| |
| std::mutex mutex; |
| // Notified when a task finishes executing. Check outstanding_tasks to see |
| // if all tasks are done. |
| std::condition_variable not_done; |
| std::unordered_map<Function*, InputBuffer> not_ready; |
| std::unordered_map<Function*, int> dependencies; |
| |
| struct ExecInfo { |
| struct Capture { |
| Capture(int input_idx, int output_idx) : input_idx(input_idx), output_idx(output_idx) {} |
| int input_idx; // within Function inputs |
| int output_idx; // within the output vector of a GraphTask |
| }; |
| |
| bool should_execute() const { |
| return needed || captures; |
| } |
| |
| bool needed = false; |
| std::unique_ptr<std::vector<Capture>> captures; |
| }; |
| // exec_info has somewhat subtle semantics. If it's empty, the task is run in |
| // "default" mode, and every next_edge we encounter gets executed. If it's not |
| // empty, only functions that have an entry with needed == true are executed. |
| // exec_info.empty() means this is a .backward() call; otherwise it's .grad(), |
| // which only needs to run the subgraph leading to the requested outputs. |
| std::unordered_map<Function*, ExecInfo> exec_info; |
| std::vector<Variable> captured_vars; |
| |
| void init_to_execute(Function& graph_root, const edge_list& outputs); |
| |
| // The value of worker_device in the thread that created this task. |
| // See Note [Reentrant backwards] |
| int owner; |
| |
| bool can_checkpoint() const { |
| return exec_info.empty(); |
| } |
| |
| GraphTask(bool keep_graph, bool grad_mode) |
| : exception() |
| , has_error(false) |
| , outstanding_tasks(0) |
| , keep_graph(keep_graph) |
| , grad_mode(grad_mode) |
| , mutex() |
| , not_done() |
| , not_ready() |
| , dependencies() |
| , owner(NO_DEVICE) {} |
| }; |
| |
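| // Bump the owning GraphTask's outstanding_tasks counter and enqueue the task |
| // under the queue lock, then wake up one worker waiting in pop(). |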
| auto ReadyQueue::push(FunctionTask item) -> void { |
| { |
| std::lock_guard<std::mutex> lock(mutex); |
| ++item.base->outstanding_tasks; |
| heap.push(std::move(item)); |
| } |
| not_empty.notify_one(); |
| } |
| |
| auto ReadyQueue::pop() -> FunctionTask { |
| std::unique_lock<std::mutex> lock(mutex); |
| not_empty.wait(lock, [this]{ return !heap.empty(); }); |
| auto task = std::move(const_cast<FunctionTask&>(heap.top())); heap.pop(); |
| return task; |
| } |
| |
| Engine::Engine() : ready_queues() { |
| } |
| |
| // This Engine's ReadyQueues and their corresponding threads are leaked here |
| Engine::~Engine() = default; |
| |
| auto Engine::thread_init(int device) -> void { |
| THInferNumThreads(); |
| AutoGPU guard(device); |
| worker_device = device; |
| thread_main(nullptr); |
| } |
| |
| // NOTE: graph_tasks do not necessarily form a stack. Imagine this |
| // case: |
| // |
| // +----> Eval1 |
| // Root |
| // +----> Eval2 |
| // |
| // Once Root is executed, both Eval1 and Eval2 are added to the ready queue. |
| // Next, Eval1 is run and this causes the worker to enter thread_main again. |
| // Then, it pops the next task from the queue, but at this point it is Eval2. |
| // It enters thread_main once again, but now with graph_task of Eval2, which is |
| // completely unrelated to that of Eval1 (it's not a recursive call). |
| // This is all fine and handled correctly today, but it has to be kept in |
| // mind if this code is ever changed. |
| auto Engine::thread_main(GraphTask *graph_task) -> void { |
| auto queue = ready_queues[worker_device + 1]; |
| // Why the test on graph_task->outstanding_tasks? See |
| // Note [Reentrant backwards] |
| while (!graph_task || graph_task->outstanding_tasks > 0) { |
| FunctionTask task = queue->pop(); |
| if (task.fn && !task.base->has_error.load()) { |
| GradMode::set_enabled(task.base->grad_mode); |
| try { |
| evaluate_function(task); |
| } catch (std::exception& e) { |
| thread_on_exception(task, e); |
| } |
| } |
| // Notify downstream about the completion of tasks depending |
| // on both where the task was executed, and who owned the overall |
| // graph (in case of reentrant execution.) See Note [Reentrant backwards]. |
| auto base_owner = task.base->owner; |
| // Task from a non-worker thread. Easy case. |
| if (base_owner == NO_DEVICE) { |
| if (--task.base->outstanding_tasks == 0) { |
| std::lock_guard<std::mutex> lock(task.base->mutex); |
| task.base->not_done.notify_all(); |
| } |
| } else { |
| // If the task was initiated from this thread, just decrease the counter; |
| // don't do anything else - the loop condition will do all the checks for |
| // us on the next iteration. |
| if (base_owner == worker_device) { |
| --task.base->outstanding_tasks; |
| // Otherwise send a dummy function task to the owning thread just to |
| // ensure that it's not sleeping. If it has work, it might see that |
| // graph_task->outstanding_tasks == 0 before it gets to the task, but |
| // it's a no-op anyway. |
| } else if (base_owner != worker_device) { |
| if (--task.base->outstanding_tasks == 0) { |
| // Synchronize outstanding_tasks with queue mutex |
| std::atomic_thread_fence(std::memory_order_release); |
| ready_queue(base_owner).push(FunctionTask(task.base, nullptr, InputBuffer(0))); |
| } |
| } |
| } |
| } |
| } |
| |
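| // Records the first exception raised while executing a GraphTask; later |
| // exceptions are dropped. Setting has_error makes the remaining workers skip |
| // the rest of this task's functions (see the check in thread_main above). |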
| auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void { |
| std::lock_guard<std::mutex> lock(task.base->mutex); |
| if (!task.base->has_error.load()) { |
| task.base->exception = std::current_exception(); |
| task.base->has_error = true; |
| } |
| } |
| |
| static variable_list call_pre_hooks(Function& fn, variable_list inputs) { |
| for (const auto& hook : fn.pre_hooks()) { |
| inputs = (*hook)(inputs); |
| } |
| return inputs; |
| } |
| |
| static variable_list call_post_hooks(Function& fn, variable_list outputs, variable_list inputs) { |
| for (const auto& hook : fn.post_hooks()) { |
| outputs = (*hook)(outputs, inputs); |
| } |
| return outputs; |
| } |
| |
| static bool is_compatible_type(const at::Type& expected, const at::Type& actual) { |
| // Types are compatible if they exactly match or if the gradient is a sparse |
| // version of the expected type. |
| return expected == actual || (actual.is_sparse() && |
| expected == actual.toBackend(toDense(actual.backend()))); |
| } |
| |
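| // Checks that the gradients produced for `edges` match each edge's input |
| // metadata: the count of gradients, and each defined gradient's shape and |
| // type. format_error lets callers prepend context to the error message. |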
| template<typename F> |
| static void validate_outputs(const edge_list& edges, const variable_list& grads, const F& format_error) { |
| if (grads.size() != edges.size()) { |
| std::stringstream ss; |
| ss << "invalid number of gradients - expected "; |
| ss << edges.size() << ", but got " << grads.size(); |
| throw std::runtime_error(format_error(ss.str())); |
| } |
| for (size_t i = 0; i < grads.size(); i++) { |
| const auto& edge = edges[i]; |
| if (!edge.is_valid()) continue; |
| |
| const auto& metadata = edge.function->input_metadata(edge.input_nr); |
| const auto& output = grads[i]; |
| if (!output.defined()) { |
| // FIXME: TestJit.test_ge_optimized fails this assertion. |
| // std::stringstream ss; |
| // ss << "undefined gradient at index " << i; |
| // throw std::runtime_error(format_error(ss.str())); |
| continue; |
| } |
| if (!grads[i].sizes().equals(metadata.shape())) { |
| std::stringstream ss; |
| ss << "invalid gradient at index " << i << " - expected shape "; |
| ss << metadata.shape() << " but got " << grads[i].sizes(); |
| throw std::runtime_error(format_error(ss.str())); |
| } |
| if (!is_compatible_type(metadata.type(), grads[i].type())) { |
| std::stringstream ss; |
| ss << "invalid gradient at index " << i << " - expected type "; |
| ss << metadata.type() << " but got " << grads[i].type(); |
| throw std::runtime_error(format_error(ss.str())); |
| } |
| } |
| } |
| |
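| // Runs a single function: pre-hooks, the function itself (notifying it via |
| // will_release_variables() when the graph is not going to be kept), output |
| // validation against next_edges, then post-hooks. checkpoint_valid is |
| // cleared for the duration of a .grad() call (exec_info non-empty), since |
| // only part of the graph is executed in that case. |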
| static variable_list call_function(FunctionTask& task) { |
| bool prev_checkpoint_valid_state = checkpoint_valid; |
| checkpoint_valid = task.base->can_checkpoint() && prev_checkpoint_valid_state; |
| auto& fn = *task.fn; |
| auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(task.inputs))); |
| |
| if (!task.base->keep_graph) { |
| fn.will_release_variables(); |
| } |
| auto outputs = fn(inputs); |
| validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) { |
| std::ostringstream ss; |
| ss << "Function " << fn.name() << " returned an " << msg; |
| return ss.str(); |
| }); |
| checkpoint_valid = prev_checkpoint_valid_state; |
| return call_post_hooks(fn, std::move(outputs), std::move(inputs)); |
| } |
| |
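| // Evaluates one ready function: captures any outputs requested via exec_info, |
| // calls the function, and routes each produced gradient to the corresponding |
| // next function - either pushing it onto a ready queue once all of that |
| // function's dependencies are satisfied, or accumulating it into a not_ready |
| // InputBuffer until they are. |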
| auto Engine::evaluate_function(FunctionTask& task) -> void { |
| // If exec_info is not empty, we have to instrument the execution |
| auto & exec_info = task.base->exec_info; |
| if (!exec_info.empty()) { |
| auto & fn_info = exec_info.at(task.fn.get()); |
| if (auto *capture_vec = fn_info.captures.get()) { |
| std::lock_guard<std::mutex> lock(task.base->mutex); |
| for (auto capture : *capture_vec) { |
| task.base->captured_vars[capture.output_idx] = task.inputs[capture.input_idx]; |
| } |
| } |
| if (!fn_info.needed) return; |
| } |
| |
| auto outputs = call_function(task); |
| |
| auto& fn = *task.fn; |
| if (!task.base->keep_graph) { |
| fn.release_variables(); |
| } |
| |
| int num_outputs = outputs.size(); |
| if (num_outputs == 0) return; // Don't even acquire the mutex |
| std::lock_guard<std::mutex> lock(task.base->mutex); |
| for (int i = 0; i < num_outputs; ++i) { |
| auto& output = outputs[i]; |
| const auto& next = fn.next_edge(i); |
| |
| if (!next.is_valid()) continue; |
| |
| // Check if the next function is ready to be computed |
| bool is_ready = false; |
| auto& dependencies = task.base->dependencies; |
| auto it = dependencies.find(next.function.get()); |
| if (it == dependencies.end()) { |
| auto name = next.function->name(); |
| throw std::runtime_error(std::string("dependency not found for ") + name); |
| } else if (--it->second == 0) { |
| dependencies.erase(it); |
| is_ready = true; |
| } |
| |
| auto& not_ready = task.base->not_ready; |
| auto not_ready_it = not_ready.find(next.function.get()); |
| if (not_ready_it == not_ready.end()) { |
| // Skip functions that aren't supposed to be executed |
| if (!exec_info.empty()) { |
| auto it = exec_info.find(next.function.get()); |
| if (it == exec_info.end() || !it->second.should_execute()) { |
| continue; |
| } |
| } |
| // No buffers have been allocated for the function |
| InputBuffer input_buffer(next.function->num_inputs()); |
| input_buffer.add(next.input_nr, std::move(output)); |
| if (is_ready) { |
| auto& queue = ready_queue(input_buffer.device()); |
| queue.push(FunctionTask(task.base, next.function, std::move(input_buffer))); |
| } else { |
| not_ready.emplace(next.function.get(), std::move(input_buffer)); |
| } |
| } else { |
| // The function already has a buffer |
| auto &input_buffer = not_ready_it->second; |
| input_buffer.add(next.input_nr, std::move(output)); |
| if (is_ready) { |
| auto& queue = ready_queue(input_buffer.device()); |
| queue.push(FunctionTask(task.base, next.function, std::move(input_buffer))); |
| not_ready.erase(not_ready_it); |
| } |
| } |
| } |
| } |
| |
| /* Computes the number of dependencies for each function which requires grad */ |
| auto Engine::compute_dependencies(Function* root, GraphTask& task) -> void { |
| // Just to make sure that they will never be added to the queue again |
| std::unordered_set<Function*> seen; |
| std::vector<Function*> queue { root }; |
| |
| // Queue contains all nodes that will start propagating gradients. |
| // We no longer have to expand functions that don't require grad. |
| auto& dependencies = task.dependencies; |
| while (!queue.empty()) { |
| auto fn = queue.back(); queue.pop_back(); |
| for (const auto& edge : fn->next_edges()) { |
| if (auto next_ptr = edge.function.get()) { |
| dependencies[next_ptr] += 1; |
| const bool was_inserted = seen.insert(next_ptr).second; |
| if (was_inserted) queue.push_back(next_ptr); |
| } |
| } |
| } |
| } |
| |
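| // RAII guard that clears the engine's final callbacks both on construction |
| // and on destruction, so callbacks registered during one backward pass never |
| // leak into the next one. |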
| struct ClearCallbacks { |
| ClearCallbacks(std::vector<std::function<void()>>& callbacks, |
| std::mutex &callbacks_lock) |
| : callbacks(callbacks) |
| , callbacks_lock(callbacks_lock) { clear(); } |
| ~ClearCallbacks() { clear(); } |
| |
| void clear() { |
| std::lock_guard<std::mutex> lock(callbacks_lock); |
| callbacks.clear(); |
| } |
| |
| std::vector<std::function<void()>>& callbacks; |
| std::mutex& callbacks_lock; |
| }; |
| |
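| // Entry point for a backward pass. Validates the root gradients, builds a |
| // GraphRoot holding them, computes per-function dependency counts, optionally |
| // restricts execution to the subgraph needed for `outputs` (the .grad() case), |
| // and pushes the root task onto the CPU queue. A non-worker caller then |
| // blocks until outstanding_tasks drops to zero; a worker re-enters |
| // thread_main with this GraphTask instead (see Note [Reentrant backwards]). |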
| auto Engine::execute(const edge_list& input_roots, |
| const variable_list& inputs, |
| bool keep_graph, |
| bool create_graph, |
| const edge_list& outputs) -> variable_list { |
| std::call_once(start_threads_flag, &Engine::start_threads, this); |
| |
| validate_outputs(input_roots, inputs, [](const std::string& msg) { |
| return msg; |
| }); |
| |
| // Callbacks are only valid for the duration of this run and should always be cleared |
| ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock); |
| |
| GraphTask graph_task(keep_graph, create_graph); |
| std::unique_lock<std::mutex> lock(graph_task.mutex); |
| |
| // Now compute the dependencies for all executable functions and queue the root |
| auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs); |
| compute_dependencies(graph_root.get(), graph_task); |
| if (!outputs.empty()) { |
| graph_task.init_to_execute(*graph_root, outputs); |
| } |
| ready_queue(-1).push(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0))); |
| |
| // Not a worker |
| if (worker_device == NO_DEVICE) { |
| // Wait for all tasks to complete |
| graph_task.not_done.wait(lock, [&graph_task]{ |
| return graph_task.outstanding_tasks.load() == 0; |
| }); |
| } else { |
| // Get back to work while we wait for our new graph_task to |
| // complete! |
| // See Note [Reentrant backwards] |
| graph_task.owner = worker_device; |
| lock.unlock(); |
| thread_main(&graph_task); |
| } |
| |
| // Check for an exception while running backwards |
| if (graph_task.has_error.load()) { |
| std::rethrow_exception(graph_task.exception); |
| } |
| |
| if (!graph_task.not_ready.empty()) { |
| throw std::runtime_error("could not compute gradients for some functions"); |
| } |
| |
| // Unlocking is necessary, because the callback can register |
| // more callbacks (or they can be registered from other threads |
| // while it's waiting). |
| std::unique_lock<std::mutex> cb_lock(post_callbacks_lock); |
| for (size_t i = 0; i < final_callbacks.size(); ++i) { |
| cb_lock.unlock(); |
| final_callbacks[i](); |
| cb_lock.lock(); |
| } |
| |
| return graph_task.captured_vars; |
| } |
| |
| // Note that when Python is present, this base engine will be overridden |
| // with a PythonEngine. Because this typically happens before get_default_engine |
| // is called, this base engine will never be created. |
| static Engine& get_base_engine() { |
| static Engine engine; |
| return engine; |
| } |
| |
| std::atomic<EngineStub> engine_stub(get_base_engine); |
| |
| void set_default_engine_stub(EngineStub stub) { |
| engine_stub.store(stub); |
| } |
| |
| |
| Engine& Engine::get_default_engine() { |
| return engine_stub.load()(); |
| } |
| |
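| // Registers a callback to be run after the current backward pass finishes. |
| // Illustrative use (a hypothetical sketch, not code from this file): |
| // |
| //   Engine::get_default_engine().queue_callback([] { |
| //     // Runs on the thread that called Engine::execute, after all tasks of |
| //     // the current GraphTask have completed. |
| //   }); |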
| void Engine::queue_callback(std::function<void()> callback) { |
| std::lock_guard<std::mutex> lock(post_callbacks_lock); |
| final_callbacks.emplace_back(std::move(callback)); |
| } |
| |
| bool Engine::is_checkpoint_valid() { |
| return checkpoint_valid; |
| } |
| |
| auto Engine::ready_queue(int device) -> ReadyQueue& { |
| return *ready_queues.at(device + 1); |
| } |
| |
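| // Spawns one detached worker thread per ready queue: index 0 (device -1) |
| // serves CPU work, and index i + 1 serves CUDA device i when built with CUDA |
| // and devices are available. |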
| auto Engine::start_threads() -> void { |
| int num_devices = 0; |
| #ifdef WITH_CUDA |
| // check for case of compiled with CUDA but no available devices |
| if (cudaGetDeviceCount(&num_devices) != cudaSuccess) { |
| cudaGetLastError(); |
| num_devices = 0; |
| } |
| #endif |
| // One for CPU, plus one for every GPU device |
| int num_threads = num_devices + 1; |
| ready_queues = std::vector<std::shared_ptr<ReadyQueue>>(num_threads); |
| for (auto& queue : ready_queues) |
| queue.reset(new ReadyQueue()); |
| for (int i = 0; i < num_threads; ++i) { |
| std::thread t(&Engine::thread_init, this, i - 1); |
| t.detach(); |
| } |
| } |
| |
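| // Marks the part of the graph that actually has to run for a .grad() call: |
| // records a Capture for every requested output edge, then walks the graph |
| // from the root and sets needed = true on every function from which one of |
| // those captures is reachable through next_edges. |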
| void GraphTask::init_to_execute(Function& graph_root, const edge_list& outputs) { |
| exec_info[&graph_root].needed = true; |
| |
| int output_idx = 0; |
| for (auto & output_edge : outputs) { |
| Function *output = output_edge.function.get(); |
| auto & info = exec_info[output]; |
| if (!info.captures) |
| info.captures.reset(new std::vector<ExecInfo::Capture>()); |
| info.captures->emplace_back(output_edge.input_nr, output_idx++); |
| } |
| captured_vars.resize(output_idx); |
| |
| // NB: this is an uglier version (recursion replaced with iteration) of the following code: |
| // is_needed = {} |
| // def compute_is_needed(fn): |
| // if fn not in is_needed: |
| // is_needed[fn] = any(compute_is_needed(next_edge) |
| // for next_edge in fn.next_edges) |
| // return is_needed[fn] |
| struct Frame { |
| Frame (Function *fn) : fn(fn), next_next_fn(0) {} |
| Function *fn; |
| size_t next_next_fn; |
| |
| Function* get_next_fn() { |
| const auto & next = fn->next_edges(); |
| auto num_next = next.size(); |
| while (next_next_fn < num_next) { |
| auto fn = next[next_next_fn++].function.get(); |
| if (fn) return fn; |
| } |
| return nullptr; |
| } |
| }; |
| std::vector<Frame> stack; |
| std::unordered_set<Function*> seen; |
| for (const auto & input : graph_root.next_edges()) { |
| if (seen.count(input.function.get()) > 0) continue; |
| stack.emplace_back(input.function.get()); |
| while (!stack.empty()) { |
| auto &frame = stack.back(); |
| if (Function *next_fn = frame.get_next_fn()) { |
| if (/* bool unseen = */ seen.emplace(next_fn).second) { |
| stack.emplace_back(next_fn); |
| continue; // recurse |
| } |
| } else { |
| // NB: if we were using real recursion we could have saved some lookups |
| //     using a return value from the recursive call. It would make this manually unrolled |
| // version a lot more complicated, so I skipped that. |
| const auto & next_edges = frame.fn->next_edges(); |
| const bool needed = std::any_of( |
| next_edges.begin(), next_edges.end(), [&](const Edge& edge) { |
| auto it = exec_info.find(edge.function.get()); |
| return it != exec_info.end() && it->second.should_execute(); |
| }); |
| exec_info[frame.fn].needed = needed; |
| stack.pop_back(); |
| } |
| } |
| } |
| } |
| |
| }} // namespace torch::autograd |