#pragma once
// Function is an abstract class that represents a single operation from one or
// more variables to one or more variables.
//
// Subclasses may represent "forward" or "backward" operations (i.e. functions
// and their derivatives). Some functions may be used as both.
#include <Python.h>
#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/jit/tracer.h"
#include <ATen/ATen.h>
#include <memory>
#include <vector>
namespace torch { namespace autograd {
struct Function;
struct Variable;
using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_type = std::pair<std::shared_ptr<Function>, int>;
using function_list = std::vector<edge_type>;
using saved_variable_list = std::vector<SavedVariable>;
struct edge_hasher {
std::size_t operator()(const edge_type& edge) const {
#define HASH_IDX(idx) std::hash<std::tuple_element<idx, edge_type>::type>()(std::get<idx>(edge))
// TODO: that's probably a bad hash function, but whatever
return HASH_IDX(0) ^ HASH_IDX(1);
#undef HASH_IDX
}
};
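// Example (sketch): edge_hasher lets edges key standard hash containers, e.g.
// when deduplicating graph edges during a traversal. `fn` below stands in for
// any std::shared_ptr<Function>; the container choice is illustrative, not
// prescribed by this header:
//
//   std::unordered_set<edge_type, edge_hasher> seen;
//   seen.insert(std::make_pair(fn, 0));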
// State used to create "backward" functions
struct FunctionFlags {
// Roughly speaking, is_executable corresponds to requires_grad.
// See http://pytorch.org/docs/notes/autograd.html for more details:
// both is_executable and is_volatile specify whether or not backwards
// gradient computation will be performed for a function, but they differ in
// their precedence.
bool is_executable = false;
bool is_volatile = false;
// The functions that take the output of this function as input.
// There is one entry per output of this function.
function_list next_functions;
};
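// Example (hedged sketch): a forward implementation typically derives flags
// from its inputs and hands them to the backward function it constructs.
// `MyBackward` is a hypothetical Function subclass that inherits the
// Function(FunctionFlags&&) constructor; concrete subclasses follow the same
// pattern:
//
//   auto flags = Function::flags(inputs);  // fills is_executable, is_volatile,
//                                          // and next_functions
//   auto grad_fn = std::make_shared<MyBackward>(std::move(flags));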
struct Function : std::enable_shared_from_this<Function> {
Function()
: num_inputs(0)
, next_functions()
, is_executable(false)
, is_stochastic(false)
, pre_hooks()
, post_hooks()
, pyobj(nullptr)
{}
Function(FunctionFlags&& flags)
: num_inputs(0)
, next_functions(std::move(flags.next_functions))
, is_executable(flags.is_executable)
, is_stochastic(false)
, pre_hooks()
, post_hooks()
, pyobj(nullptr)
{}
Function(const Function& other) = delete;
Function(Function&& other) = delete;
virtual ~Function() {}
// Implements the operation
// NOTE: Don't call this function directly. Use apply_fn or operator() instead.
virtual variable_list apply(const variable_list& inputs) = 0;
variable_list tracedApply(variable_list inputs);
variable_list operator()(const variable_list& inputs) {
profiler::RecordFunction rec(this);
if (jit::tracer::isTracing(inputs)) {
return tracedApply(inputs);
}
return apply(inputs);
}
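// Example (sketch): a minimal subclass only overrides apply(); callers must
// still go through operator() so profiling and tracing are handled. `Identity`
// is illustrative only and not part of this header:
//
//   struct Identity : public Function {
//     using Function::Function;
//     virtual variable_list apply(const variable_list& inputs) override {
//       return inputs;  // forward everything unchanged
//     }
//   };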
// PyFunctions are not managed by shared_ptrs by default, but are bound to the
// lifetime of their Python object instead.
virtual std::shared_ptr<Function> getSharedPtr() {
return shared_from_this();
}
// Computes is_executable, is_volatile, and next_functions from a list
// of input variables
static FunctionFlags flags(const variable_list& inputs);
// Releases saved variables if the operation won't be reused
virtual inline void releaseVariables() {}
// Function name for debugging
virtual std::string name();
inline bool should_compute_output(int i) const {
auto& fn = next_functions[i].first;
return fn && fn->is_executable;
}
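// Example (sketch): backward implementations typically use this to skip
// gradient computation for inputs nobody needs; `computeGradForInput1` is a
// hypothetical helper:
//
//   if (should_compute_output(1)) {
//     grad_inputs[1] = computeGradForInput1(grad_output);
//   }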
inline void set_flags(FunctionFlags&& flags) {
is_executable = flags.is_executable;
next_functions = std::move(flags.next_functions);
}
// An op is traceable if all operations happening within apply() are performed
// on autograd Variables (i.e. apply mostly instantiates and applies other functions).
virtual inline bool is_traceable() { return false; }
// An op is said to pass state transparently to backward, if the state consists
// only of (Saved)Variables and only non-variable objects that parametrize the
// operation in some way that defines the graph structure AND the backward function
// is traceable. In particular, parametrization MUST NOT depend on the data
// of any Variable.
// TODO: it might be possible to handle cases where backward is non-traceable
// but state passing could be considered transparent. This will probably depend
// on saved_variable_list being mutable.
// NOTE: this value matters only if is_traceable() returns false.
virtual inline bool passes_state_transparently() { return false; }
// Lets the JIT find inputs to apply that are not present explicitly in its
// arguments. Required only for functions that are not traceable, don't pass
// state to backward transparently, and are not backward closures of functions
// that don't pass state transparently. This means that hopefully they will
// hardly ever need to be implemented :)
virtual inline std::unique_ptr<saved_variable_list> saved_variables() { return nullptr; }
static void setUpContextEdge(jit::Node* this_node, int ctx_output_nr,
const variable_list& inputs, const variable_list& outputs);
int num_inputs;
function_list next_functions;
bool is_executable;
bool is_stochastic;
std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
PyObject *pyobj; // weak reference
auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state;
};
// Despite the name, ForwardFunction applies to all functions that are either
// applied only in forward, OR are backward closures that don't save any
// Variables. I chose this name because the second situation is quite rare.
template<bool transparent_state = false>
struct ForwardFunction : public Function {
using Function::Function;
virtual inline std::unique_ptr<saved_variable_list> saved_variables() final {
return std::unique_ptr<saved_variable_list>(new saved_variable_list());
}
virtual inline bool is_traceable() final { return false; }
virtual inline bool passes_state_transparently() final { return transparent_state; }
};
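// Example (sketch): a backward closure whose behavior is fully determined by
// the graph structure (no saved Variables, no data-dependent parameters) can
// opt into transparent state passing via the template parameter. `AddBackward`
// is hypothetical and only illustrates the pattern:
//
//   struct AddBackward : public ForwardFunction<true> {
//     using ForwardFunction<true>::ForwardFunction;
//     virtual variable_list apply(const variable_list& grad_outputs) override {
//       return {grad_outputs[0], grad_outputs[0]};  // d(a+b) w.r.t. a and b
//     }
//   };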
// See Function::is_traceable() for definition.
struct TraceableFunction : public Function {
using Function::Function;
virtual inline bool is_traceable() final { return true; }
};
template<typename T>
struct apply_fn {
template<typename... Args>
apply_fn(Args&& ...args)
: fn_(std::make_shared<T>(std::forward<Args>(args)...)) {}
Variable operator()(const variable_list& inputs) {
return (*fn_)(inputs)[0];
}
template<typename... Args>
Variable operator()(Args&& ...inputs) {
return (*fn_)(variable_list{inputs...})[0];
}
std::shared_ptr<T> fn_;
};
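// Example (sketch): apply_fn wraps construction and invocation of a Function
// subclass and returns the first output Variable. `Add` is assumed to be a
// Function subclass with a default constructor; it is not defined in this
// header:
//
//   Variable c = apply_fn<Add>()(a, b);  // builds variable_list{a, b},
//                                        // applies Add, returns output 0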
}} // namespace torch::autograd