// blob: 75c74eeafcfffabc38ac126646b057d91c786828 [file] [log] [blame]
#pragma once
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/utils/variadic.h>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace torch { namespace autograd {
// Forward declarations: `Edge` is stored by value in `edge_list`, and the hook
// types are only held through `std::unique_ptr` in this header.
struct Edge;
struct FunctionPostHook;
struct FunctionPreHook;
// Convenience aliases for the containers used throughout the autograd API.
using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
// A (first, second) pair of output-edge indices. `Node::should_compute_output`
// walks indices from `first` up to, but not including, `second`.
using IndexRange = std::pair<size_t, size_t>;
// Custom deleter to prevent stack overflows.
// NOTE(review): only declared here; presumably installed as the deleter for
// heap-allocated `Node`s so that tearing down a long chain of nodes does not
// recurse -- confirm against the definition.
TORCH_API void deleteNode(Node* function);
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Node
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A `Node` is an abstract class that represents an operation taking zero
/// or more input `Variable`s and producing zero or more output `Variable`s. All
/// functions in PyTorch's autograd machinery derive from this class and
/// override its `apply` method. Instances of such subclasses will then be
/// invokeable via the call operator.
///
/// Nodes in the Autograd Graph
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// When viewing the autograd system as a graph, `Node`s are the vertices or
/// nodes, connected to each other via (directed) `Edge`s, which themselves are
/// represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to
/// and inputs of `Node`s, and travel between these edges during execution
/// of the graph. When two or more `Edge`s (from different sources) point at the
/// same input to a `Node`, the values produced along all of these edges are
/// implicitly summed prior to being forwarded to the target `Node`.
///
/// Hierarchy
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Subclasses usually represent differentiable functions as well as their
/// gradient operators. Note, however, that due to the very general definition
/// of a `Node` taking *zero* or more inputs and producing *zero* or more
/// outputs, uses of `Node`s are flexible and extend beyond purely
/// mathematical operations. For example, the `AccumulateGrad` function is a
/// *sink*: it takes one input, but produces no outputs, instead accumulating
/// the input as a side effect. At the other extreme, the `GraphRoot` function
/// receives no inputs from other functions, but produces multiple outputs.
///
/// Interface
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// The most important method on `Node` is the call operator, which takes in
/// a list of variables and produces a list of variables. The precise size of
/// these lists can be determined with `num_inputs()` and `num_outputs()`.
/// `Node`s are stitched together via their `next_edge` interface, which lets
/// you manipulate the set of outgoing edges of a `Node`. You can add an
/// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and
/// iterate over them via the `next_edges()` method. Other methods exist for
/// integration with the JIT and other parts of PyTorch. Every `Node` has a
/// *sequence number* that increases monotonically in the order of `Node`
/// construction. It can be retrieved via the `sequence_nr()` method. Note that
/// this sequence number is *thread local*. This means that when `Node`s
/// `A`, `B` and `C` are created consecutively in the same thread, their
/// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B`
/// are created in one thread and `C` is created in a new thread, there are *no
/// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`.
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct TORCH_API Node : std::enable_shared_from_this<Node> {
 public:
  /// Construct a new `Node` with the given `next_edges`. `sequence_nr` is
  /// a (currently THE) hint to prioritization in the backward() pass, with
  /// higher sequence numbers prioritized before lower sequence numbers.
  /// If anomaly mode is enabled at construction time, `store_stack()` is
  /// invoked on this node's anomaly metadata.
  explicit Node(
      uint64_t sequence_nr,
      edge_list&& next_edges = edge_list())
      : sequence_nr_(sequence_nr),
        next_edges_(std::move(next_edges)) {
    if (AnomalyMode::is_enabled()) {
      metadata()->store_stack();
    }
  }

  /// Construct a new `Node` whose sequence number is the current value of
  /// `get_next_sequence_nr()`; that counter is post-incremented.
  explicit Node(edge_list&& next_edges = edge_list())
      : Node(get_next_sequence_nr()++, std::move(next_edges)) {}

  /// Nodes are neither copyable nor moveable.
  Node(const Node& other) = delete;
  Node(Node&& other) = delete;
  Node& operator=(const Node& other) = delete;
  Node& operator=(Node&& other) = delete;
  virtual ~Node() = default;

  /// Evaluates the function on the given inputs and returns the result of the
  /// function call. The call is wrapped in `RECORD_FUNCTION` and forwarded to
  /// `apply()`.
  variable_list operator()(variable_list&& inputs) {
    RECORD_FUNCTION(
        this, std::vector<c10::IValue>(inputs.begin(), inputs.end()));
    // In the first iteration of named tensors, autograd ignores names and
    // operates on unnamed tensors. In the long term, autograd should
    // probably operate with names.
    at::NoNamesGuard no_names_guard;
    return apply(std::move(inputs));
  }

  // Graph Connectivity API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
  // forward function.

  // Marker for expected undefined input
  struct undefined_input {};

  /// Adds the type and shape metadata for a new input. Returns the index of
  /// the new input, i.e. the number of inputs registered before this call.
  uint32_t add_input_metadata(
      const at::TensorOptions& options,
      at::IntArrayRef shape,
      at::Device device) noexcept {
    // The index is captured before the append; the narrowing from the
    // container's size type is made explicit.
    const auto input_nr = static_cast<uint32_t>(input_metadata_.size());
    input_metadata_.emplace_back(options, shape, device);
    return input_nr;
  }

  /// Adds the metadata of tensor `t` as a new input. Returns the index of the
  /// new input.
  uint32_t add_input_metadata(const at::Tensor& t) noexcept {
    const auto input_nr = static_cast<uint32_t>(input_metadata_.size());
    input_metadata_.emplace_back(t);
    return input_nr;
  }

  /// Adds a placeholder for an input that will not be used. The tag argument
  /// only selects this overload, so it is left unnamed. Returns the index of
  /// the new (default-constructed) input metadata entry.
  uint32_t add_input_metadata(undefined_input) noexcept {
    const auto input_nr = static_cast<uint32_t>(input_metadata_.size());
    input_metadata_.emplace_back();
    return input_nr;
  }

  /// Number of input metadata entries registered on this `Node`.
  uint32_t num_inputs() const noexcept {
    return static_cast<uint32_t>(input_metadata_.size());
  }

  /// Returns the metadata of the input at `index`. No bounds check is
  /// performed.
  const InputMetadata& input_metadata(size_t index) const {
    return input_metadata_[index];
  }

  /**
   * Note: Function Streams
   * A function's stream (for a given device type) is the stream of the first
   * element of its input buffer on a device of that type.
   *
   * If all elements are on the same device they MUST share a stream. If
   * elements are on different devices (across multiple GPUs, for example)
   * they may have different streams.
   *
   * Returns `c10::nullopt` when no input metadata entry is on a device of
   * `device_type`. The lookup does not modify the node, hence `const`.
   */
  c10::optional<c10::Stream> stream(const c10::DeviceType device_type) const {
    for (const auto& metadata : input_metadata_) {
      if (metadata.device().type() == device_type) {
        return metadata.stream();
      }
    }
    return c10::nullopt;
  }

  /// Removes all registered input metadata.
  void clear_input_metadata() {
    input_metadata_.clear();
  }

  // Outputs ("Next Edges")

  /// Returns the outgoing edge at `index`. No bounds check is performed.
  const Edge& next_edge(size_t index) const noexcept {
    return next_edges_[index];
  }

  /// Replaces the outgoing edge at `index`. No bounds check is performed.
  void set_next_edge(size_t index, Edge edge) {
    next_edges_[index] = std::move(edge);
  }

  /// Appends `edge` to the list of outgoing edges.
  void add_next_edge(Edge edge) {
    next_edges_.push_back(std::move(edge));
  }

  /// Replaces the whole list of outgoing edges.
  void set_next_edges(edge_list&& next_edges) {
    next_edges_ = std::move(next_edges);
  }

  /// Read-only access to the outgoing edges.
  const edge_list& next_edges() const noexcept {
    return next_edges_;
  }

  /// Mutable access to the outgoing edges.
  edge_list& next_edges() noexcept {
    return next_edges_;
  }

  /// Number of outgoing edges.
  uint32_t num_outputs() const noexcept {
    return static_cast<uint32_t>(next_edges_.size());
  }

  // Miscellaneous Methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// The sequence number of this `Node`.
  uint64_t sequence_nr() const noexcept {
    return sequence_nr_;
  }

  /// Returns the name of the dynamic type of the function, for debugging.
  virtual std::string name() const;

  /// Returns true if the particular output edge is active, and that particular
  /// output of this function should be computed. `TORCH_CHECK` fails with
  /// "Index out of range" when `output_edge_index >= num_outputs()`.
  bool should_compute_output(size_t output_edge_index) const {
    TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
    return next_edges_[output_edge_index].is_valid();
  }

  /// Returns true if any of the output edges in any of the ranges are active.
  /// Each range covers the indices from `first` up to, but excluding,
  /// `second`.
  bool should_compute_output(std::initializer_list<IndexRange> idxs) const {
    return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
      for (auto i = range.first; i < range.second; i++) {
        if (should_compute_output(i)) {
          return true;
        }
      }
      return false;
    });
  }

  /// Returns the `PyObject` stored for this `Node` (for Python
  /// interaction).
  PyObject* pyobj() const noexcept {
    return pyobj_;
  }

  /// Sets the `PyObject` stored for this `Node` (for Python interaction).
  void set_pyobj(PyObject* pyobj) noexcept {
    pyobj_ = pyobj;
  }

  /// Returns the anomaly metadata stored for this `Node`.
  /// If none exist, creates a new empty one.
  AnomalyMetadata* metadata() noexcept;

  // Hook API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Takes ownership of `post_hook`, appends it to the post hooks and returns
  /// a key identifying it.
  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
    post_hooks_.push_back(std::move(post_hook));
    // Use the raw pointer as the unique key to identify this hook. This key
    // can then be used in del_post_hook(key) to remove this hook.
    return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
  }

  /// Read-only access to the registered post hooks.
  const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
      noexcept {
    return post_hooks_;
  }

  /// Deletes the first post hook whose key (as returned by `add_post_hook`)
  /// matches `key`. Returns true if a hook was removed, false if no hook
  /// matched.
  bool del_post_hook(const uintptr_t& key) {
    const auto match = std::find_if(
        post_hooks_.begin(),
        post_hooks_.end(),
        [&key](const std::unique_ptr<FunctionPostHook>& hook) {
          return key == reinterpret_cast<std::uintptr_t>(hook.get());
        });
    if (match == post_hooks_.end()) {
      return false;
    }
    post_hooks_.erase(match);
    return true;
  }

  /// Mutable access to the registered post hooks.
  std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
    return post_hooks_;
  }

  /// Takes ownership of `pre_hook` and appends it to the pre hooks.
  void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
    pre_hooks_.push_back(std::move(pre_hook));
  }

  /// Read-only access to the registered pre hooks.
  const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() const
      noexcept {
    return pre_hooks_;
  }

  /// Mutable access to the registered pre hooks.
  std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
    return pre_hooks_;
  }

  // Customization Points for Subclasses
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Releases saved variables if the operation won't be reused.
  /// The default implementation does nothing.
  virtual void release_variables() {}

  /// Called before an apply if `release_variables()` is going to be called.
  /// Allows larger ops like `InterpreterAutogradFunction` to incrementally
  /// release variables as they run. The default implementation does nothing.
  virtual void will_release_variables() {}

  /// Returns true if this function is traceable. An op is traceable if all
  /// operations happening within `apply()` are performed on autograd
  /// `Variables` (i.e. apply mostly instantiates and applies other functions).
  /// The default is `false`.
  virtual bool is_traceable() {
    return false;
  }

  /// A `Node` is said to pass state transparently to backward, if the
  /// state consists only of (Saved)Variables and only non-variable objects
  /// that parameterize the operation in some way that defines the graph
  /// structure AND the backward function is traceable. In particular,
  /// parametrization MUST NOT depend on the data of any `Variable`.
  /// TODO: it might be possible to handle cases where backward is
  /// non-traceable but state passing could be considered transparent. This
  /// will probably depend on saved_variable_list being mutable.
  /// NOTE: this value matters only if is_traceable() returns false.
  virtual bool passes_state_transparently() {
    return false;
  }

  /// Declared here, defined elsewhere.
  /// NOTE(review): by its name this reads the next sequence number without
  /// consuming it -- confirm against the definition.
  static uint64_t peek_at_next_sequence_nr();

 protected:
  /// Returns a mutable reference to the counter from which default-constructed
  /// sequence numbers are drawn (see the single-argument constructor).
  static uint64_t& get_next_sequence_nr();

  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;

  /// Calls `apply()`, but instruments it with tracing machinery.
  variable_list traced_apply(variable_list inputs);

  // Since `Node`s are neither copyable nor moveable, we can have const
  // fields.
  const uint64_t sequence_nr_;

  edge_list next_edges_;
  PyObject* pyobj_ = nullptr; // weak reference
  std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;
  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
  at::SmallVector<InputMetadata, 2> input_metadata_;
};
/// A `Node` whose `is_traceable()` always answers `true`; see
/// `Node::is_traceable()` for the definition of traceability. The
/// constructors are inherited from `Node` unchanged.
struct TraceableFunction : Node {
  using Node::Node;

  /// Always traceable; `final` prevents subclasses from changing the answer.
  bool is_traceable() final {
    return true;
  }
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Associated Free Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace detail {
// Visitor behind `collect_next_edges`: every `Variable` it is applied to
// contributes exactly one entry to `next_edges`, in visitation order.
struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
  // Edges gathered so far.
  edge_list next_edges;

  // Keep the generic overloads supplied by `IterArgs` visible alongside the
  // `Variable` overload below.
  using IterArgs<MakeNextFunctionList>::operator();

  // An undefined variable yields a default-constructed `Edge`; a defined one
  // yields its gradient edge.
  void operator()(const Variable& variable) {
    if (!variable.defined()) {
      next_edges.emplace_back();
      return;
    }
    next_edges.push_back(impl::gradient_edge(variable));
  }
};
} // namespace detail
/// Create an `Edge` between the given `variable` and the `function`, which is
/// assumed to be the gradient function of this variable (i.e. the function
/// through which this variable is backpropagated during the backward pass).
/// The variable is registered as a *new* input of `function`: its metadata is
/// appended with `function->add_input_metadata(variable)`, and the index
/// returned by that call becomes the edge's `input_nr`. The resulting
/// (`function`, `input_nr`) pair is then installed on the variable through
/// `impl::set_gradient_edge`. As a consequence the `Node`'s number of inputs
/// grows by one. If you don't want the `Node`'s `num_inputs` to be
/// incremented, use `set_gradient_edge` directly.
inline void create_gradient_edge(
    Variable& variable,
    std::shared_ptr<Node> function) {
  // The input slot must be obtained while `function` is still valid, i.e.
  // before ownership is moved into the edge below.
  const uint32_t slot = function->add_input_metadata(variable);
  impl::set_gradient_edge(variable, {std::move(function), slot});
}
/// Return true if any of the variables in the list require a gradient.
inline bool any_variable_requires_grad(const variable_list& variables) {
return std::any_of(
variables.begin(), variables.end(), [](const Variable& variable) {
return variable.defined() && variable.requires_grad();
});
}
/// Return the next edges of all the given variables, or tuples of variables.
/// When grad mode is disabled the result is an empty list and the arguments
/// are not inspected.
template <typename... Variables>
edge_list collect_next_edges(Variables&&... variables) {
  if (GradMode::is_enabled()) {
    detail::MakeNextFunctionList collector;
    collector.apply(std::forward<Variables>(variables)...);
    // `next_edges` is a member of a local, so it has to be moved explicitly.
    return std::move(collector.next_edges);
  }
  return edge_list();
}
}} // namespace torch::autograd