#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/auto_gpu.h"
namespace torch { namespace jit {
using value_map = std::unordered_map<Value*, Value*>;
using value_set = std::unordered_set<Value*>;
// Creates a node that computes a + b and inserts it after the given node.
// If node is a null pointer, the new node is appended at the end of the
// graph's node list instead.
Value* addValues(Value *a, Value *b, Node *node = nullptr) {
  Graph *graph = a->node()->owningGraph();
  Node *add_node = graph->create(kadd, {a, b})
                        ->t_(kalpha, at::Scalar(1).toTensor());
  if (node) {
    add_node->insertAfter(node);
  } else {
    graph->appendNode(add_node);
  }
  Value *add_output = add_node->output();
  add_output->setType(a->typeOption());
  return add_output;
}
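// NOTE: addValues is used below both to accumulate multiple vjp contributions
// flowing into the same value (see set_grad in addReverseInline) and to sum in
// the extra vjp inputs added for temporaries in lambdaLiftReverse.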
std::unordered_set<Symbol> differentiable_kinds = {
  kadd, ksub, kmul,
};
bool isDifferentiable(Node * n) {
  return differentiable_kinds.count(n->kind()) > 0;
}
static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_values) {
  const auto build_sym_grad = [node](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
    auto inputs = node->inputs();
    switch (node->kind()) {
      case kadd:
        return {grads[0], grads[0]};
      case ksub:
        return {grads[0], -grads[0]};
      case kmul:
        return {grads[0] * inputs[1], grads[0] * inputs[0]};
    }
    throw std::runtime_error(std::string("don't support differentiation of `") +
                             node->kind().toString() + "`");
  };
  auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
  return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
}
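// Walks the graph once and returns the set of values that (transitively)
// depend on an input marked as requiring grad. A single forward pass is
// enough because graph.nodes() is kept in topological order.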
static value_set findAllRequiresGradNodes(
        Graph& graph, const std::vector<bool>& input_requires_grad) {
  JIT_ASSERT(graph.inputs().size() == input_requires_grad.size());
  std::unordered_set<Value*> requires_grad_set;
  const auto requires_grad = [&](Value *v) { return requires_grad_set.count(v) > 0; };
  auto inputs = graph.inputs();
  for (std::size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
    if (!input_requires_grad[i]) continue;
    requires_grad_set.emplace(inputs[i]);
  }
  for (Node * node : graph.nodes()) {
    if (std::none_of(node->inputs().begin(), node->inputs().end(), requires_grad)) continue;
    for (Value * output : node->outputs())
      requires_grad_set.emplace(output);
  }
  return requires_grad_set;
}
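// Emits a "constant" node producing a zero-filled tensor with the same scalar
// type, sizes and device as v. The is_zero attribute tags it as an all-zero
// gradient, presumably so that later passes can recognize and simplify it.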
static Value* allocZerosLike(Value *v) {
  static const Symbol constant_sym = "constant"_sym;
  static const Symbol is_zero_sym = "is_zero"_sym;
  static const Symbol value_sym = "value"_sym;
  JIT_EXPECTM(v->hasType(), "can't allocate zero gradient for a value without a type");
  Graph *graph = v->owningGraph();
  auto type = v->type()->expect<TensorType>();
  AutoGPU gpu_guard(type->device());
  auto & at_type = type->device() == -1 ? at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
  auto zeros = at_type.zeros({1}).expand(type->sizes());
  Node *constant = graph->create(constant_sym)
                        ->t_(value_sym, zeros)
                        ->i_(is_zero_sym, 1);
  graph->appendNode(constant);
  return constant->output();
}
struct ReverseDetails {
  ReverseDetails(value_map&& grad_map, value_set&& requires_grad_set)
    : grad_map(std::move(grad_map))
    , requires_grad_set(std::move(requires_grad_set)) {}
  value_map grad_map;           // x -> dx mapping built by addReverseInline
  value_set requires_grad_set;  // values that (transitively) require grad
};
// Before:
// - graph has only stage 0
// - grad_desc doesn't have any fields initialized
// After:
// - graph has stage 0 and stage 1 that computes its vjp
// - grad_desc has df_input_vjps and df_output_vjps set
// (but df_input_vjps will be modified later as well)
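// For example (an illustrative sketch only): for a graph computing y = a * b,
// addReverseInline appends a stage-1 input dy and stage-1 nodes computing
// da = dy * b and db = dy * a, registers da and db as stage-1 outputs, and
// records df_input_vjps = [0] (vjp supplied for output y) and
// df_output_vjps = [0, 1] (vjps produced for inputs a and b).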
static ReverseDetails addReverseInline(Graph& graph, Gradient& grad_desc,
                                       const std::vector<bool>& input_requires_grad) {
  JIT_ASSERT(graph.stage() == 0);
  graph.advanceStage();
  auto requires_grad_set = findAllRequiresGradNodes(graph, input_requires_grad);
  const auto requires_grad = [&](Value *v) { return requires_grad_set.count(v) > 0; };
  value_map grad_map; // x -> dx mapping
  const auto get_grad = [&](Value* v) -> Value* {
    auto it = grad_map.find(v);
    if (it == grad_map.end()) {
      std::tie(it, std::ignore) = grad_map.emplace(v, allocZerosLike(v));
    }
    return it->second;
  };
  const auto set_grad = [&](Value *x, Value *dx) {
    if (Value * prev_grad = grad_map[x]) {
      Value * new_grad = addValues(prev_grad, dx);
      grad_map[x] = new_grad;
    } else {
      grad_map[x] = dx;
    }
  };
  auto outputs = graph.outputs();
  for (std::size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
    Value * output = outputs[i];
    if (!requires_grad(output)) continue;
    Value * output_grad = graph.addInput()->setType(output->typeOption());
    set_grad(output, output_grad);
    grad_desc.df_input_vjps.push_back(i);
  }
  for (auto it = graph.rbegin(), end = graph.rend(); it != end; ++it) {
    Node *node = *it;
    auto inputs = node->inputs();
    if (std::none_of(inputs.begin(), inputs.end(), requires_grad)) continue;
    value_list grad_inputs = gradientForNode(node, fmap(node->outputs(), get_grad));
    JIT_ASSERT(grad_inputs.size() == node->inputs().size());
    for (std::size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
      set_grad(inputs[i], grad_inputs[i]);
    }
  }
  auto inputs = graph.inputs();
  for (std::size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
    Value * input = inputs[i];
    if (input->stage() > 0) break;
    if (!requires_grad(input)) continue;
    graph.registerOutput(get_grad(input));
    grad_desc.df_output_vjps.push_back(i);
  }
  return ReverseDetails(std::move(grad_map), std::move(requires_grad_set));
}
// This function takes the graph and returns a new one that:
// - contains all nodes of the original graph that belong to the given stage
// - has one input corresponding to each value in the inputs array
// - returns the values corresponding to the outputs array
// It requires that the values contained in inputs are sufficient to compute
// all values in the given stage. An exception will be thrown if this is
// not the case.
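// For reference, lambdaLiftReverse below calls this twice:
//   grad_desc.f  = splitOffStage(graph, 0, primal_inputs, primal_outputs);
//   grad_desc.df = splitOffStage(graph, 1, reverse_inputs, reverse_outputs);
// so each stage becomes a standalone graph whose signature is given by the
// input/output lists assembled there.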
static std::shared_ptr<Graph> splitOffStage(
        Graph& graph,
        std::size_t stage,
        ArrayRef<Value*> inputs,
        ArrayRef<Value*> outputs) {
  auto graph_clone = std::make_shared<Graph>();
  value_map val_map; // values in graph -> values in graph_clone
  const auto lookup_val = [&](Value *v) { return val_map.at(v); };
  for (Value *input : inputs)
    val_map[input] = graph_clone->addInput()->setType(input->typeOption());
  for (Node *node : graph.nodes()) {
    if (node->stage() != stage) continue;
    Node *node_clone = graph_clone->createClone(node, lookup_val);
    for (std::size_t i = 0, num_outputs = node_clone->outputs().size(); i < num_outputs; ++i)
      val_map[node->outputs()[i]] = node_clone->outputs()[i];
    graph_clone->appendNode(node_clone);
  }
  for (Value *output : outputs) {
    JIT_ASSERT(output->stage() == stage);
    graph_clone->registerOutput(val_map.at(output));
  }
  return graph_clone;
}
// Takes a graph returned from `addReverseInline` and splits it into two graphs
// (one for each stage). All intermediates needed in the second stage are added to
// outputs of the first graph, and taken as inputs in the second one. For a more
// detailed description see Note [Gradient graphs] in autodiff.h.
// This function also initializes the fields in grad_desc that were undefined after
// `addReverseInline` (and modifies `df_input_vjps`).
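// Continuing the y = a * b example from above (again just an illustrative
// sketch): both captured values are primal inputs, so no extra outputs are
// appended and the result is
//   f:  (a, b)     -> (y)         with f_real_outputs = 1
//   df: (a, b, dy) -> (da, db)    with df_input_captures = [(Input, 0), (Input, 1)]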
static void lambdaLiftReverse(Graph& graph,
                              ReverseDetails& rev_info,
                              Gradient& grad_desc) {
  static const auto is_stage_0 = [](Value *v) { return v->stage() == 0; };
  static const auto is_stage_1 = [](Value *v) { return v->stage() == 1; };
  // NOTE: in the comments inside this function first stage is stage 0
  JIT_ASSERT(graph.stage() == 1);
  // --------------------------------------------------------------------------
  // 1. Find values of stage 0 that need to be captured.
  // --------------------------------------------------------------------------
  // First, we need to find all values that are produced in the first stage,
  // and used in the second one. They will need to be added as inputs of the reverse
  // graph, and some of them may also need to be appended as outputs of the primal graph.
  value_set reverse_captures_set;
  value_list reverse_captures; // Invariant: topo sorted
  auto check_uses = [&](Value *v) {
    for (auto use : v->uses()) {
      if (use.user->stage() != 1) continue;
      if (/* bool unseen = */ reverse_captures_set.emplace(v).second) {
        reverse_captures.push_back(v);
      }
    }
  };
  for (Value * input : graph.inputs()) {
    if (input->stage() != 0) break;
    check_uses(input);
  }
  for (Node * node : graph.nodes()) {
    if (node->stage() != 0) break;
    for (Value * output : node->outputs())
      check_uses(output);
  }
  // --------------------------------------------------------------------------
  // 2. Prepare input/output lists for both graphs.
  // --------------------------------------------------------------------------
  // It's simple to construct primal_inputs/reverse_outputs,
  // but primal_outputs/reverse_inputs are much more subtle.
  // Here's a summary of what they are supposed to look like:
  //
  // Primal outputs:
  //   [original outputs], [temporaries]
  //
  // Reverse inputs:
  //   [captured primal values, in topological order],
  //   [output vjps (aka grad_outputs)], [temporary vjps]
  // -- Simple cases -----------------------------------------------------------
  value_list primal_inputs = filter(graph.inputs(), is_stage_0);
  value_list reverse_outputs = filter(graph.outputs(), is_stage_1);
  // -- Construct primal_outputs, df_input_captures, f_real_outputs ----
  value_list primal_outputs = filter(graph.outputs(), is_stage_0);
  grad_desc.f_real_outputs = primal_outputs.size();
  std::unordered_map<Value*, std::size_t> orig_primal_outputs_idx;
  std::unordered_map<Value*, std::size_t> orig_primal_inputs_idx;
  // NOTE: we use emplace to avoid replacing an existing index if an output is repeated
  for (std::size_t i = 0, num_outputs = primal_outputs.size(); i < num_outputs; ++i)
    orig_primal_outputs_idx.emplace(primal_outputs[i], i);
  for (std::size_t i = 0, num_inputs = primal_inputs.size(); i < num_inputs; ++i)
    orig_primal_inputs_idx[primal_inputs[i]] = i;
  // NB: reverse_captures are already deduplicated, and in topo order
  for (Value * capture_val : reverse_captures) {
    // If it's already an output we don't have to add anything,
    // but register the fact that it needs to be captured.
    if (orig_primal_outputs_idx.count(capture_val) > 0) {
      grad_desc.df_input_captures.emplace_back(Capture::Kind::Output,
                                               orig_primal_outputs_idx[capture_val]);
    // If it's an input, we could add it as an output but in fact it's
    // more efficient to use a special kind of capture.
    } else if (orig_primal_inputs_idx.count(capture_val) > 0) {
      grad_desc.df_input_captures.emplace_back(Capture::Kind::Input,
                                               orig_primal_inputs_idx.at(capture_val));
    // Otherwise it's just a regular intermediate value that we need to add as an output
    } else {
      primal_outputs.emplace_back(capture_val);
      grad_desc.df_input_captures.emplace_back(Capture::Kind::Output,
                                               primal_outputs.size() - 1);
    }
  }
  // -- Add VJPs for temporaries, adjust df_input_vjps -------------------------
  // NB [possible optimization]: use the newly added vjp input as soon as the first
  // vjp for that value is generated, to reduce the lifespan of this input
  // (currently we add it to the final vjp after all adds).
  JIT_ASSERT(graph.stage() == 1); // We will be adding inputs to stage 1
  for (std::size_t i = grad_desc.f_real_outputs; i < primal_outputs.size(); ++i) {
    Value * tmp = primal_outputs.at(i);
    // Add VJP inputs only for intermediates that actually required grad.
    if (rev_info.requires_grad_set.count(tmp) == 0) continue;
    Value * tmp_vjp_in = graph.addInput()->setType(tmp->typeOption());
    Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
    // This is quite weird because we can't first make a sum and then replace all uses
    // of tmp_vjp_prev (that would replace its use in the sum too!), so we create an
    // incorrect sum that doesn't use prev vjp, replace uses, and fix the sum.
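    // Step by step, with the hypothetical names d_tmp_1, d_tmp_2 standing for
    // vjp contributions already accumulated into tmp_vjp_prev:
    //   before:        tmp_vjp_prev = add(d_tmp_1, d_tmp_2)
    //   create sum:    new_vjp      = add(tmp_vjp_in, tmp_vjp_in)
    //   replace uses:  every consumer of tmp_vjp_prev now reads new_vjp
    //   fix the sum:   new_vjp      = add(tmp_vjp_in, tmp_vjp_prev)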
    Value * new_vjp = addValues(tmp_vjp_in, tmp_vjp_in, tmp_vjp_prev->node());
    tmp_vjp_prev->replaceAllUsesWith(new_vjp);
    new_vjp->node()->replaceInput(1, tmp_vjp_prev);
    grad_desc.df_input_vjps.emplace_back(i);
  }
  // -- Construct reverse_inputs -----------------------------------------------
  // Quick reference:
  //   [captured primal values, in topological order],      1st loop below
  //   [output vjps (aka grad_outputs)], [temporary vjps]   2nd loop below
  value_list reverse_inputs;
  for (Capture capture : grad_desc.df_input_captures) {
    auto & source = capture.kind == Capture::Kind::Input ? primal_inputs : primal_outputs;
    reverse_inputs.push_back(source[capture.offset]);
  }
  // These are the vjps computed by differentiate + the code above
  for (Value * reverse_vjp : filter(graph.inputs(), is_stage_1))
    reverse_inputs.push_back(reverse_vjp);
  // Finally, we can split the graph into two parts.
  grad_desc.f = splitOffStage(graph, 0, primal_inputs, primal_outputs);
  grad_desc.df = splitOffStage(graph, 1, reverse_inputs, reverse_outputs);
}
Gradient differentiate(std::shared_ptr<Graph>& _graph, const std::vector<bool>& requires_grad) {
  // Take ownership of the graph
  std::shared_ptr<Graph> graph;
  JIT_ASSERTM(_graph.use_count() == 1,
              "differentiate will mutate and destroy the graph, so it requires "
              "graph.use_count() == 1");
  std::swap(_graph, graph);
  // XXX: Take care when handling outputs - they can be duplicated!
  Gradient grad_desc;
  // Fills in df_input_vjps and df_output_vjps
  auto rev_info = addReverseInline(*graph, grad_desc, requires_grad);
  // addReverseInline has to call gradientForNode as soon as *any* input of a
  // node requires grad, but gradientForNode emits vjps for *all* of that
  // node's inputs. Use DCE to remove the unnecessary ones.
  EliminateDeadCode(graph);
  // Fills in f, df, f_real_outputs, df_input_captures,
  // modifies df_input_vjps (new vjps are added for temporaries)
  lambdaLiftReverse(*graph, rev_info, grad_desc);
  return grad_desc;
}
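// A minimal usage sketch (hypothetical caller; buildMyGraph is not a real
// helper, and both inputs are assumed to require grad):
//   std::shared_ptr<Graph> g = buildMyGraph();
//   Gradient grad = differentiate(g, {true, true});
//   // grad.f computes the original outputs (+ captured temporaries),
//   // grad.df maps captures and output vjps to input vjps.
// Note that differentiate takes ownership of the graph, so g is left empty
// after the call.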
}}