#pragma once
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/assert.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/init_pass.h"
#include <memory>
#include <mutex>
#include <vector>
#include <iostream>
#include <cstdint>
#include <unordered_map>
namespace torch { namespace jit { namespace tracer {
using torch::autograd::Variable;
using variable_list = std::vector<std::shared_ptr<Variable>>;
// TracingState tracks the necessary state when we are tracing the execution of
// autograd code; most importantly, it holds a reference to the actual IR
// graph which we are recording the trace to.
//
// The liveness of a TracingState is expected to be a superset of the region
// of code being traced; in particular, Variables do not keep a TracingState
// live. Instead, they hold weak pointers to TracingState, to prevent leaks
// from arising when a variable that participated in a trace outlives the
// actual trace itself.
//
// Multi-threaded tracing is NOT yet supported: it is better to think of
// tracing as a thread local affair, where we enter and then exit
// a tracing region. This not only coincides with how Python code
// is run (GIL = single-threaded), but also how single Functions in
// an autograd closure are applied. Note that the execution of an
// *entire* autograd closure is multithreaded, in which case extra
// locking is necessary.
struct TracingState : public std::enable_shared_from_this<TracingState> {
TracingState()
: graph(new Graph())
, active(false) {}
std::shared_ptr<Graph> graph;
// The void* keys are raw (unsafe) TH tensor pointers. NON-OWNING, so they might get invalidated.
// TODO: Perhaps, turn this into an owning reference. The buffers
// are persistent, so this won't lead to a leak.
std::unordered_map<void*, Node*> buffer_map;
bool active;
std::mutex mutex;
variable_list inputs; // Used only for the duration of first stage
std::unique_lock<std::mutex> lock() { return std::unique_lock<std::mutex>(mutex); };
};
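// A rough usage sketch for TracingState (comments only, not part of the API):
// code that wants to append to the trace from a context that might race
// (e.g. an autograd closure) takes the lock and re-checks 'active' before
// touching the graph:
//
//   auto guard = state->lock();   // std::unique_lock on state->mutex
//   if (state->active) {
//     // safe to append nodes to state->graph here
//   }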
struct FunctionTracingState {
bool in_eval_subgraph = false;
};
// Should a function which takes 'vars' as inputs be traced?
// It suffices for ONE variable to be actively traced: any "untraced" variables
// are treated as constants.
inline bool isTracing(const variable_list& vars) {
for (auto& var : vars) {
if (!var) continue;
auto state = var->tracing_state.state.lock();
if (state && state->active)
return true;
}
return false;
}
// Retrieve the tracing state which a function applied with 'vars' should
// be recorded to. Precondition: isTracing(vars) == true. At the moment,
// we don't support mixing up variables from different traces; this code
// will need to be revisited if that ever becomes supported.
inline std::shared_ptr<TracingState> getTracingState(const variable_list& vars) {
std::shared_ptr<TracingState> state;
for (auto& var : vars) {
if (!var) continue;
auto var_state = var->tracing_state.state.lock();
if (var_state) {
if (!state) {
state = var_state;
}
JIT_ASSERT(state == var_state);
}
}
JIT_ASSERT(state);
return state;
}
// Having finished adding a new 'node' to the graph IR owned by TracingState 'state',
// 'setValueTrace' associates this node with an output variable, so that further operations
// involving this variable know which node in the IR to reference.
inline void setValueTrace(const std::shared_ptr<TracingState>& state, const std::shared_ptr<Variable>& var, Node *node) {
JIT_ASSERT(var->tracing_state.state.lock() == state || var->tracing_state.state.expired());
var->tracing_state.state = state;
var->tracing_state.trace = node;
}
// Given a variable 'var', return the 'node' which represents the instruction
// which computes the value of this variable in the IR. When 'mustExist' is
// false, we interpret untraced variables as constants that are just embedded
// in the graph. This is useful to handle code which does things like this
// (from torch.autograd.variable):
//
// def mm(self, matrix):
// output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
// return Addmm.apply(output, self, matrix, 0, 1, True)
//
// Here, mm fakes up a dummy variable with uninitialized data to do an inplace
// update on, but subsequently ignores it because the alpha scaling factor is zero.
// This is one of the cases where a Variable can be created inside of a trace, and
// if we treat it as a constant, everything will work out.
inline Node* getValueTrace(const std::shared_ptr<TracingState>& state, const std::shared_ptr<Variable>& var, bool mustExist = false) {
//JIT_ASSERTM(var, "Not supported. NULL Variables will need to be removed from autograd");
if (!var) {
return state->graph->appendNode(state->graph->createUndefined());
}
auto var_state = var->tracing_state.state.lock();
if (var_state) {
JIT_ASSERT(var_state == state);
return var->tracing_state.trace;
}
if (mustExist) throw std::runtime_error("untraced variable");
Node *constant = state->graph->appendNode(state->graph->createConstant(var->data));
constant->inferTypeFrom(var->data);
setValueTrace(state, var, constant);
return constant;
}
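// Taken together, recording a single traced operation roughly follows this
// pattern (a sketch only: the construction of the op node is elided, and
// 'inputs'/'output' are hypothetical variables, not part of this header):
//
//   if (isTracing(inputs)) {
//     auto state = getTracingState(inputs);
//     Node* node = /* create an IR node for this op in state->graph */;
//     for (auto& input : inputs)
//       node->addInput(getValueTrace(state, input));
//     setValueTrace(state, output, node);   // later ops can now find 'output'
//   }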
// Look up the IR input node that was registered for a (non-Variable) buffer
// when the trace was entered; buffers are keyed by their raw TH pointer.
inline Node* getBufferTrace(const std::unordered_map<void*, Node*>& buffer_map, at::Tensor buf) {
auto it = buffer_map.find(buf.unsafeGetTH(false));
if (it == buffer_map.end()) {
throw std::runtime_error("untraced buffer");
} else {
return it->second;
}
}
// Exactly one of the two fields may be set: either 'variable' is non-null,
// or 'buffer' is defined.
struct TraceInput {
std::shared_ptr<Variable> variable;
at::Tensor buffer;
TraceInput(std::shared_ptr<Variable> variable) : variable(variable) {}
TraceInput(at::Tensor buffer) : buffer(buffer) {}
};
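// For instance, a caller mixing a variable argument with a persistent buffer
// (hypothetical names) would build one TraceInput of each kind and hand them
// to enter() below:
//
//   std::vector<TraceInput> trace_inputs;
//   trace_inputs.emplace_back(input_var);      // std::shared_ptr<Variable>
//   trace_inputs.emplace_back(running_mean);   // at::Tensor buffer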
// Start tracing, treating 'inputs' as inputs to the trace, which can be
// varied on subsequent invocations of the trace. Any other variables
// will be treated as constants.
//
// NB: Why does this take an rvalue reference? We need to get a non-const
// reference to at::Tensor buffer to call unsafeGetTH, but you can't get this
// out of a const vector (silly std::vector...)
inline std::shared_ptr<TracingState> enter(std::vector<TraceInput>&& trace_inputs) {
auto state = std::make_shared<TracingState>();
// TODO: Figure out what's going on with batchnorm backwards...
variable_list inputs;
for (auto& trace_input : trace_inputs) {
if (trace_input.variable != nullptr) {
JIT_ASSERT(!trace_input.buffer.defined());
auto& input = trace_input.variable;
JIT_ASSERT(input->tracing_state.state.expired());
Node *input_node = state->graph->addInput();
setValueTrace(state, input, input_node);
input_node->inferTypeFrom(input->data);
inputs.push_back(input);
} else {
JIT_ASSERT(trace_input.buffer.defined());
// NON-owning reference. Pointers may be dead!
auto& buffer = trace_input.buffer;
Node* n = state->graph->addInput();
state->buffer_map.insert({buffer.unsafeGetTH(false), n});
n->inferTypeFrom(buffer);
}
}
state->active = true;
state->inputs = inputs;
return state;
}
namespace detail {
// Exit code shared between exit and TraceExitHook::run
inline void _exit(const std::shared_ptr<TracingState>& state, const variable_list& outputs) {
for (auto& output : outputs) {
state->graph->registerOutput(getValueTrace(state, output, true));
}
state->active = false;
}
// Marks a backwards subgraph that should be traced as the next stage.
// Mutates some of the outputs.
void traceBackward(const std::shared_ptr<TracingState>& state, const variable_list& inputs,
const variable_list& outputs);
} // namespace detail
// Exit a trace, treating 'outputs' as the outputs of the trace. These
// are the variables whose values will be computed upon subsequent
// invocations of the trace.
inline void exit(const variable_list& outputs) {
auto state = getTracingState(outputs);
detail::_exit(state, outputs);
detail::traceBackward(state, state->inputs, outputs);
state->inputs.clear();
}
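// A rough end-to-end sketch of a tracing session (the 'run_traced_code' call
// is hypothetical; it stands for whatever autograd code actually executes and
// records itself into the graph via the helpers above):
//
//   auto state = enter(std::move(trace_inputs));
//   variable_list outputs = run_traced_code(state->inputs);
//   exit(outputs);
//   // state->graph now holds the recorded IR for this (first) stage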
// Marks part of the backward graph as non-traceable (i.e. one that should be replaced
// with an Eval in the trace).
void nontraceableBackwardSubgraph(const variable_list& inputs, const variable_list& outputs);
}}} // namespace torch::jit::tracer