#include "torch/csrc/autograd/engine.h"
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <exception>
#include <iostream>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <THPP/THPP.h>
#ifdef WITH_CUDA
#include <cuda.h>
#include <THC/THC.h>
#endif
using thpp::Tensor;
namespace torch { namespace autograd {
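// A single unit of backward work: the BackwardTask it belongs to, the
// Function whose backward should be run, and the buffer of gradients that
// will be fed to it.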
struct FunctionTask {
BackwardTask* base;
std::shared_ptr<Function> fn;
GradBuffer grad;
FunctionTask(BackwardTask* base, std::shared_ptr<Function> fn, GradBuffer grad)
: base(base)
, fn(fn)
, grad(std::move(grad)) {}
};
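// Per-device queue of tasks that are ready to run. Producers push at the
// front and workers pop from the back, so tasks execute in FIFO order;
// not_empty blocks a worker while its queue is empty.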
struct ReadyQueue {
std::deque<FunctionTask> queue;
std::condition_variable not_empty;
std::mutex mutex;
void push_front(FunctionTask item);
FunctionTask pop_back();
};
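// Shared book-keeping for a single call to Engine::backward. Workers
// decrement outstanding_tasks as they finish and signal not_done when it
// reaches zero; the first error encountered is stored in exception and
// re-thrown by the thread that called backward.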
struct BackwardTask {
std::exception_ptr exception;
std::atomic_bool has_error;
std::atomic<uint64_t> outstanding_tasks;
bool retain_variables;
bool node_requires_grad;
std::mutex mutex;
std::condition_variable not_done;
std::unordered_map<Function*, GradBuffer> not_ready;
std::unordered_map<Function*, int> dependencies;
BackwardTask(bool retain_variables)
: exception()
, has_error(false)
, outstanding_tasks(0)
, retain_variables(retain_variables)
, node_requires_grad(false)
, mutex()
, not_done()
, not_ready()
, dependencies() {}
};
auto ReadyQueue::push_front(FunctionTask item) -> void {
{
std::lock_guard<std::mutex> lock(mutex);
++item.base->outstanding_tasks;
queue.push_front(std::move(item));
}
not_empty.notify_one();
}
auto ReadyQueue::pop_back() -> FunctionTask {
std::unique_lock<std::mutex> lock(mutex);
not_empty.wait(lock, [this]{ return !queue.empty(); });
auto task = std::move(queue.back()); queue.pop_back();
return task;
}
Engine::Engine() : ready_queues() {
}
Engine::~Engine() = default;
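// Main loop of a worker thread. Each worker services one ReadyQueue,
// evaluating tasks until the process exits; errors are recorded in the
// owning BackwardTask instead of escaping the thread.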
auto Engine::thread_main(ReadyQueue& queue) -> void {
while (1) {
FunctionTask task = queue.pop_back();
if (!task.base->has_error.load()) {
try {
evaluate_function(task);
} catch (std::exception& e) {
thread_on_exception(task, e);
}
}
if (--task.base->outstanding_tasks == 0) {
std::lock_guard<std::mutex> lock(task.base->mutex);
task.base->not_done.notify_all();
}
}
}
auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void {
std::lock_guard<std::mutex> lock(task.base->mutex);
if (!task.base->has_error.load()) {
task.base->exception = std::current_exception();
task.base->has_error = true;
}
}
static variable_list call_pre_hooks(Function& fn, variable_list grad_output) {
for (auto& hook : fn.pre_hooks) {
grad_output = (*hook)(grad_output);
}
return grad_output;
}
static variable_list call_post_hooks(Function& fn, variable_list grad_input, variable_list grad_output) {
for (auto& hook : fn.post_hooks) {
grad_input = (*hook)(grad_input, grad_output);
}
return grad_input;
}
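// Runs a function's backward pass: pre-hooks may rewrite the incoming
// gradients, Function::apply computes the gradients w.r.t. its inputs, and
// post-hooks may rewrite those before they are propagated further.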
static variable_list call_function(FunctionTask& task) {
auto grad_output = call_pre_hooks(*task.fn, GradBuffer::variables(std::move(task.grad)));
auto grad_input = task.fn->apply(grad_output);
return call_post_hooks(*task.fn, std::move(grad_input), std::move(grad_output));
}
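// Evaluates a single function and routes the resulting gradients to its
// previous functions: gradients for Variables are handed over immediately,
// while gradients for other functions are accumulated in a buffer until all
// of their dependencies are satisfied, at which point the function is
// pushed onto a ready queue.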
auto Engine::evaluate_function(FunctionTask& task) -> void {
auto grad_inputs = call_function(task);
auto& fn = *task.fn;
if (!task.base->retain_variables) {
fn.releaseVariables();
}
if (grad_inputs.size() != fn.previous_functions.size()) {
std::stringstream ss;
ss << "Function '" << fn.name() << "' returned an invalid number of gradients - expected ";
ss << fn.previous_functions.size() << ", but got " << grad_inputs.size();
throw std::runtime_error(ss.str());
}
int size = grad_inputs.size();
for (int i = 0; i < size; ++i) {
auto& grad_input = grad_inputs[i];
auto& prev_fn = fn.previous_functions[i].first;
int output_nr = fn.previous_functions[i].second;
// null inputs have no previous_function and we skip them here
if (!prev_fn) {
continue;
}
// Stochastic functions are placed in the ready queue by
// find_stochastic_functions, so we can skip them here.
if (prev_fn->is_stochastic || !prev_fn->requires_grad) {
continue;
}
std::lock_guard<std::mutex> lock(task.base->mutex);
if (auto var = dynamic_cast<Variable*>(prev_fn.get())) {
if (!grad_input) {
// NOTE: grad_input can be NULL if the function returns None for a
// non_differentiable input. We may need to track additional information
// at the function level to determine if a NULL grad_input is an error.
std::stringstream ss;
ss << "Function '" << fn.name() << "' missing gradient at " << i;
throw std::runtime_error(ss.str());
}
var->backward(grad_input);
continue;
}
// Check if the function is ready for backward
bool is_ready = false;
auto& dependencies = task.base->dependencies;
auto it = dependencies.find(prev_fn.get());
if (it == dependencies.end()) {
auto name = prev_fn->name();
throw std::runtime_error(std::string("dependency not found for ") + name);
} else if (--it->second == 0) {
dependencies.erase(it);
is_ready = true;
}
auto& not_ready = task.base->not_ready;
auto not_ready_it = not_ready.find(prev_fn.get());
if (not_ready_it == not_ready.end()) {
// No buffers have been allocated for the function
GradBuffer prev_buffer(prev_fn->num_outputs);
prev_buffer.addGrad(output_nr, std::move(grad_input));
if (is_ready) {
auto& queue = ready_queue(prev_buffer.device());
queue.push_front(FunctionTask(task.base, prev_fn, std::move(prev_buffer)));
} else {
not_ready.emplace(prev_fn.get(), std::move(prev_buffer));
}
} else {
// The function already has a buffer
auto &prev_buffer = not_ready_it->second;
prev_buffer.addGrad(output_nr, std::move(grad_input));
if (is_ready) {
auto& queue = ready_queue(prev_buffer.device());
queue.push_front(FunctionTask(task.base, prev_fn, std::move(prev_buffer)));
not_ready.erase(not_ready_it);
}
}
}
}
/** Finds all stochastic functions and appends them to the queue */
auto Engine::find_stochastic_functions(function_queue& queue, BackwardTask& task) -> void {
std::unordered_set<Function*> seen;
function_queue search_queue(queue);
while (search_queue.size() > 0) {
auto fn = search_queue.back(); search_queue.pop_back();
for (auto& prev_fn_pair : fn->previous_functions) {
auto& prev_fn = prev_fn_pair.first;
Function* prev_ptr = prev_fn.get();
if (!prev_ptr) continue;
if (prev_ptr->is_stochastic && prev_ptr->requires_grad && seen.count(prev_ptr) == 0) {
ready_queue(-1).push_front(FunctionTask(&task, prev_fn, GradBuffer(0)));
queue.push_back(prev_ptr);
task.node_requires_grad = true;
}
if (seen.count(prev_ptr) == 0) {
seen.insert(prev_ptr);
search_queue.push_back(prev_ptr);
}
}
}
}
/** Computes the number of dependencies for each function which requires grad */
auto Engine::compute_dependencies(function_queue queue, BackwardTask& task) -> void {
// Just to make sure that they will never be added to the queue again
std::unordered_set<Function*> seen(queue.begin(), queue.end());
// Queue contains all nodes that will start propagating gradients.
// We no longer have to expand functions that don't require grad.
auto& dependencies = task.dependencies;
while (queue.size() > 0) {
auto fn = std::move(queue.back()); queue.pop_back();
// This is needed only to filter out backward roots that don't require grad
if (!fn->requires_grad) continue;
for (auto& prev_fn_pair : fn->previous_functions) {
Function* prev_ptr = prev_fn_pair.first.get();
if (!prev_ptr) continue;
if (dynamic_cast<Variable*>(prev_ptr)) continue;
if (!prev_ptr->requires_grad) continue;
if (prev_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already
dependencies[prev_ptr] += 1;
if (seen.count(prev_ptr) == 0) {
seen.insert(prev_ptr);
queue.push_back(prev_ptr);
}
}
}
}
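/** Finds the unique creators of the given variables and seeds their gradient
    buffers. Leaf variables (no creator) receive their gradient immediately;
    creators that require grad are pushed onto a ready queue. Returns the
    list of creators, which is the starting point for dependency counting. */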
auto Engine::find_creators(const variable_list& variables,
tensor_list& grad_variables,
BackwardTask& task) -> function_queue {
function_queue creators;
std::unordered_map<std::shared_ptr<Function>, std::unique_ptr<GradBuffer>> creator_grad;
int size = variables.size();
for (int i = 0; i < size; ++i) {
auto& var = variables[i];
auto& grad = grad_variables[i];
if (!var->creator) {
// A variable with no creator is a leaf, so the gradient is accumulated into it directly.
if (var->requires_grad) {
var->backward(std::make_shared<Variable>(std::move(grad), false, true));
task.node_requires_grad = true;
}
} else {
auto& creator = var->creator;
auto& buf = creator_grad[creator];
if (creator->requires_grad) {
if (!buf) buf.reset(new GradBuffer(creator->num_outputs));
buf->addGrad(var->output_nr, Variable::of(std::move(grad)));
}
}
}
for (auto& entry: creator_grad) {
const auto& creator = entry.first;
creators.push_back(creator.get());
if (creator->requires_grad) {
// NOTE: buf is only allocated when the creator requires grad, which is why
// the dereference below is guarded by this check.
auto& buf = entry.second;
auto& queue = ready_queue(buf->device());
queue.push_front(FunctionTask(&task, creator, std::move(*buf)));
task.node_requires_grad = true;
}
}
return creators;
}
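/** Entry point of the engine. Lazily starts the worker threads on first use,
    seeds the ready queues from the given variables and gradients, then
    blocks until every enqueued task has been processed, re-throwing any
    error raised by a worker thread. */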
auto Engine::backward(const variable_list& variables,
tensor_list& grad_variables,
bool retain_variables) -> void {
static std::once_flag once_flag;
std::call_once(once_flag, &Engine::start_threads, this);
BackwardTask backward_task(retain_variables);
std::unique_lock<std::mutex> lock(backward_task.mutex);
// Find the unique creators and backprop into variables which don't have creators.
auto creators = find_creators(variables, grad_variables, backward_task);
// Search the graph and find all stochastic functions. Append them to the queue.
find_stochastic_functions(creators, backward_task);
if (!backward_task.node_requires_grad) {
throw std::runtime_error(
"there are no graph nodes that require computing gradients");
}
// Now compute the dependencies for each function which requires grad
compute_dependencies(std::move(creators), backward_task);
// wait for all tasks to complete
backward_task.not_done.wait(lock, [&backward_task]{
return backward_task.outstanding_tasks.load() == 0;
});
// check for an exception while running backwards
if (backward_task.has_error.load()) {
std::rethrow_exception(backward_task.exception);
}
if (!backward_task.not_ready.empty()) {
throw std::runtime_error("could not compute gradients for some functions");
}
}
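// One queue per device: index 0 (device == -1) is the CPU queue and CUDA
// devices follow in order.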
auto Engine::ready_queue(int device) -> ReadyQueue& {
return *ready_queues.at(device + 1);
}
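// Called once, lazily, from backward(): allocates one ReadyQueue per device
// and spawns a detached worker thread for each of them.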
auto Engine::start_threads() -> void {
int num_devices = 0;
#ifdef WITH_CUDA
THCudaCheck(cudaGetDeviceCount(&num_devices));
#endif
ready_queues = std::vector<std::unique_ptr<ReadyQueue>>(num_devices + 1);
for (auto& queue : ready_queues) {
queue.reset(new ReadyQueue());
std::thread t(&Engine::thread_main, this, std::ref(*queue));
t.detach();
}
}
}} // namespace torch::autograd