| #pragma once |
| |
| #include <torch/csrc/autograd/anomaly_mode.h> |
| #include <torch/csrc/autograd/edge.h> |
| #include <torch/csrc/autograd/grad_mode.h> |
| #include <torch/csrc/autograd/graph_task.h> |
| #include <torch/csrc/autograd/input_metadata.h> |
| #include <torch/csrc/autograd/saved_variable.h> |
| #include <torch/csrc/autograd/variable.h> |
| #include <torch/csrc/utils/python_stub.h> |
| #include <torch/csrc/utils/variadic.h> |
| |
| #include <ATen/SequenceNumber.h> |
| #include <ATen/core/Tensor.h> |
| #include <ATen/record_function.h> |
| #include <c10/util/Exception.h> |
| #include <c10/util/irange.h> |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <initializer_list> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| C10_CLANG_DIAGNOSTIC_PUSH() |
| #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") |
| C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") |
| #endif |
| |
| namespace torch { |
| namespace autograd { |
| |
| struct Edge; |
| struct FunctionPostHook; |
| struct FunctionPreHook; |
| |
| using tensor_list = std::vector<at::Tensor>; |
| using variable_list = std::vector<Variable>; |
| using edge_list = std::vector<Edge>; |
| using saved_variable_list = std::vector<SavedVariable>; |
| using IndexRange = std::pair<size_t, size_t>; |
| |
| // Custom deleter to prevent stack overflows. |
| TORCH_API void deleteNode(Node* function); |
| |
| // Guard that sets and restores the evaluating node |
| class NodeGuard { |
| public: |
| explicit NodeGuard(std::shared_ptr<Node> node); |
| ~NodeGuard(); |
| |
| private: |
| std::shared_ptr<Node> last_evaluating_node_; |
| }; |
| |
| // Return the Node currently being evaluated (if any) |
| // This is only set during the backward pass while a Node is being |
| // executed. |
| TORCH_API std::shared_ptr<Node> get_current_node(); |
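| 
| // A minimal usage sketch of `NodeGuard` (illustrative; `node` is assumed to
| // be a previously obtained `std::shared_ptr<Node>`):
| //
| //   {
| //     NodeGuard guard(node);
| //     // get_current_node() now returns `node`.
| //   }
| //   // The previously evaluating node, if any, is restored here.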
| |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // Node |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // A `Node` is an abstract class that represents an operation taking zero |
| // or more input `Variable`s and producing zero or more output `Variable`s. All |
| // functions in PyTorch's autograd machinery derive from this class and |
| // override its `apply` method. Instances of such subclasses will then be |
| // invocable via the call operator.
| // |
| // Nodes in the Autograd Graph |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // When viewing the autograd system as a graph, `Node`s are the vertices or |
| // nodes, connected to each other via (directed) `Edge`s, which themselves are |
| // represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to |
| // and inputs of `Node`s, and travel along these edges during execution
| // of the graph. When two or more `Edge`s (from different sources) point at the |
| // same input to a `Node`, the values produced along all of these edges are |
| // implicitly summed prior to being forwarded to the target `Node`. |
| // |
| // Hierarchy |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // Subclasses usually represent differentiable functions as well as their |
| // gradient operators. Note, however, that due to the very general definition |
| // of a `Node` taking *zero* or more inputs and producing *zero* or more |
| // outputs, uses of `Node`s are flexible and extend beyond purely |
| // mathematical operations. For example, the `AccumulateGrad` function is a |
| // *sink*: it takes one input, but produces no outputs, instead accumulating |
| // the input as a side effect. At the other extreme, the `GraphRoot` function |
| // receives no inputs from other functions, but produces multiple outputs. |
| // |
| // Interface |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // The most important method on `Node` is the call operator, which takes in |
| // a list of variables and produces a list of variables. The precise size of |
| // these lists can be determined with `num_inputs()` and `num_outputs()`. |
| // `Node`s are stitched together via their `next_edge` interface, which lets
| // you manipulate the set of outgoing edges of a `Node`. You can add an |
| // edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and |
| // iterate over them via the `next_edges()` method. Other methods exist for |
| // integration with the JIT and other parts of PyTorch. Every `Node` has a |
| // *sequence number* that increases monotonically in the order of `Node` |
| // construction. It can be retrieved via the `sequence_nr()` method. Note that |
| // this sequence number is *thread local*. This means that when `Node`s |
| // `A`, `B` and `C` are created consecutively in the same thread, their |
| // sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B` |
| // are created in one thread and `C` is created in a new thread, there are *no |
| // guarantees* w.r.t. the ordering of `C` relative to `A` or `B`. |
| // See NOTE [ Sequence Number ] for more details on the usage of sequence
| // numbers.
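| //
| // Example
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| // A minimal sketch of a custom subclass (illustrative only; `ScaleBackward`
| // is a hypothetical node that doubles each incoming gradient):
| //
| //   struct ScaleBackward : public Node {
| //     std::string name() const override {
| //       return "ScaleBackward";
| //     }
| //
| //    protected:
| //     variable_list apply(variable_list&& grads) override {
| //       variable_list outputs;
| //       outputs.reserve(grads.size());
| //       for (auto& g : grads) {
| //         outputs.push_back(g.defined() ? g * 2 : g);
| //       }
| //       return outputs;
| //     }
| //   };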
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| struct TORCH_API Node : std::enable_shared_from_this<Node> { |
| public: |
| /// Construct a new `Node` with the given `next_edges` |
| explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list()) |
| : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) { |
| for (const Edge& edge : next_edges_) { |
| update_topological_nr(edge); |
| } |
| |
| if (AnomalyMode::is_enabled()) { |
| metadata()->store_stack(); |
| |
| // If anomaly mode is enabled while the graph is being constructed,
| // assign the currently evaluating node as the parent of this node.
| // A parent is the Node during whose execution this Node was created.
| // We track parents so that anomalies can be traced across multiple
| // backward operations.
| assign_parent(); |
| } |
| |
| // Store the thread_id of the forward operator. |
| // See NOTE [ Sequence Number ]
| thread_id_ = at::RecordFunction::currentThreadId(); |
| } |
| |
| explicit Node(edge_list&& next_edges = edge_list()) |
| : Node( |
| /*sequence_nr=*/at::sequence_number::get_and_increment(), |
| std::move(next_edges)) {} |
| |
| /// Nodes are neither copyable nor moveable. |
| Node(const Node& other) = delete; |
| Node(Node&& other) = delete; |
| Node& operator=(const Node& other) = delete; |
| Node& operator=(Node&& other) = delete; |
| virtual ~Node() = default; |
| |
| std::shared_ptr<Node> getptr() { |
| return shared_from_this(); |
| } |
| /// Evaluates the function on the given inputs and returns the result of the |
| /// function call. |
| variable_list operator()(variable_list&& inputs) { |
| // In the first iteration of named tensors, autograd ignores names and |
| // operates on unnamed tensors. In the long term, autograd should |
| // probably operate with names. |
| at::NoNamesGuard no_names_guard; |
| |
| #ifdef USE_ROCM |
| // Keep track of backward pass for rocblas. |
| at::ROCmBackwardPassGuard in_backward; |
| #endif |
| |
| auto step_callbacks = |
| at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); |
| if (C10_UNLIKELY(step_callbacks.has_value())) { |
| at::RecordFunction guard(std::move(*step_callbacks)); |
| // Using sequence number and thread id to correlate with |
| // the forward pass function |
| guard.setForwardThreadId(thread_id_); |
| if (guard.needsInputs()) { |
| std::vector<c10::IValue> inputs_vec(inputs.begin(), inputs.end()); |
| guard.before( |
| name(), |
| c10::ArrayRef<const c10::IValue>( |
| inputs_vec.data(), inputs_vec.size()), |
| sequence_nr()); |
| } else { |
| guard.before(name(), sequence_nr()); |
| } |
| return apply(std::move(inputs)); |
| } else { |
| return apply(std::move(inputs)); |
| } |
| } |
| |
| // Graph Connectivity API |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the |
| // forward function. |
| |
| // Marker for expected undefined input |
| struct undefined_input {}; |
| |
| /// Adds the type and shape metadata for a new input. Returns the index
| /// of the new input.
| uint32_t add_input_metadata( |
| const at::TensorOptions& options, |
| c10::SymIntArrayRef shape, |
| bool is_tensor_subclass) noexcept { |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| uint32_t input_nr = input_metadata_.size(); |
| auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape}; |
| input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass); |
| return input_nr; |
| } |
| |
| uint32_t add_input_metadata(const at::Tensor& t) noexcept { |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| uint32_t input_nr = input_metadata_.size(); |
| input_metadata_.emplace_back(t); |
| return input_nr; |
| } |
| |
| /// Adds a placeholder for an input that will not be used. |
| uint32_t add_input_metadata(undefined_input u) noexcept { |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| uint32_t input_nr = input_metadata_.size(); |
| input_metadata_.emplace_back(); |
| return input_nr; |
| } |
| |
| uint32_t num_inputs() const noexcept { |
| return input_metadata_.size(); |
| } |
| |
| const InputMetadata& input_metadata(size_t index) const { |
| return input_metadata_[index]; |
| } |
| |
| /** |
| * Note: Function Streams |
| * A function's stream (for a given device type) is the stream of the first |
| * element of its input buffer on a device of that type. |
| * |
| * If all elements are on the same device they MUST share a stream. If |
| * elements are on different devices (across multiple GPUs, for example) |
| * they may have different streams. |
| */ |
| c10::optional<c10::Stream> stream(const c10::DeviceType device_type) { |
| for (const auto& metadata : input_metadata_) { |
| if (metadata.device().type() == device_type) |
| return metadata.stream(); |
| } |
| |
| return c10::nullopt; |
| } |
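| 
| // A usage sketch (illustrative): query the stream to synchronize with for
| // this node's CUDA inputs, if any:
| //
| //   c10::optional<c10::Stream> s = node->stream(c10::DeviceType::CUDA);
| //   if (s.has_value()) {
| //     // make the consumer stream wait on *s before reading the inputs
| //   }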
| |
| void clear_input_metadata() { |
| input_metadata_.clear(); |
| } |
| |
| // Outputs ("Next Edges") |
| |
| void update_topological_nr(const Edge& edge) { |
| TORCH_INTERNAL_ASSERT( |
| !has_parent_, |
| "Cannot update a node's topological_nr after it already has a parent." |
| " If we allow this, we can no longer guarantee that a parent's" |
| " topo_nr is always greater than those of all its children") |
| Node* node = edge.function.get(); |
| if (node) { |
| auto topo_nr = node->topological_nr(); |
| if (topological_nr_ <= topo_nr) { |
| topological_nr_ = topo_nr + 1; |
| } |
| } |
| } |
| |
| void set_next_edge(size_t index, Edge edge) { |
| update_topological_nr(edge); |
| next_edges_[index] = std::move(edge); |
| } |
| |
| void add_next_edge(Edge edge) { |
| update_topological_nr(edge); |
| next_edges_.push_back(std::move(edge)); |
| } |
| |
| void set_next_edges(edge_list&& next_edges) { |
| next_edges_ = std::move(next_edges); |
| for (const auto& next_edge : next_edges_) { |
| update_topological_nr(next_edge); |
| } |
| } |
| |
| const Edge& next_edge(size_t index) const noexcept { |
| return next_edges_[index]; |
| } |
| |
| const edge_list& next_edges() const noexcept { |
| return next_edges_; |
| } |
| |
| edge_list& next_edges() noexcept { |
| return next_edges_; |
| } |
| |
| uint32_t num_outputs() const noexcept { |
| return next_edges_.size(); |
| } |
| |
| // Miscellaneous Methods |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| /// NOTE [ Sequence Number ]
| /// |
| /// The sequence_nr has two main usages in autograd: |
| /// |
| /// 1) Helps determine the node's execution priority in the engine. |
| /// All else being equal, nodes with higher priority numbers are executed |
| /// first. Thus, nodes corresponding to ops executed later are the first to |
| /// be executed in the backward pass. One caveat is that we prioritize |
| /// AccumulateGrad nodes by explicitly setting their sequence_nr to
| /// UINT64_MAX.
| /// 2) The sequence number of this `Node` is paired with the thread_id it was
| /// created in, and the pair serves as a unique identifier the profiler uses
| /// to annotate recorded events. This helps users (and possibly programs)
| /// interpreting the profiler's output correlate backward nodes with their
| /// forward ops. We need both sequence_nr and thread_id to identify a node
| /// because sequence_nr is thread-local, i.e., it starts counting up from
| /// zero in each new thread.
| uint64_t sequence_nr() const noexcept { |
| return sequence_nr_; |
| } |
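| 
| // For example (illustrative): if ops A and B run forward on thread 1 and
| // op C runs forward on thread 2, A and B may receive sequence numbers 0
| // and 1 on thread 1 while C also receives sequence number 0 on thread 2;
| // only the (sequence_nr, thread_id) pair uniquely identifies a node.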
| |
| // NOTE [ Topological Number ] |
| // |
| // topological_nr is used to prune branches in the DAG during autograd |
| // discovery as maintaining topological_nr helps us check in O(1) if there |
| // does NOT exist a directed path between two nodes. |
| // |
| // The topological order number of this `Node` represents the length of the
| // longest possible path from this Node to any leaf node. For a leaf node
| // itself, e.g., AccumulateGrad, this will be zero. This value has the
| // property that, for every pair of nodes X, Y in G, the existence of a
| // directed path from X to Y implies topo_nr(X) > topo_nr(Y). The converse
| // is not true, however, so we cannot prove the existence of a path from X
| // to Y, only its non-existence.
| //
| // One assumption we make when using topo_nr is that once a node has been
| // used, i.e., has a parent node, its own topo_nr does not change. We have
| // added some checks with the `has_parent_` field to enforce this.
| // |
| // What NOT to do: |
| // |
| //   1) 2 -> 1 -> 0              In this diagram we label nodes with their
| //      2 -> 1 -> 0              topo_nr. We have two simple graphs that can
| //                               each arise from `t.exp().exp()`, for example.
| //   2) 2 -> 1 -> 0
| //            /
| //      2 -> 1 -> 0              We add 2 as a next edge to 1 even though 1
| //                               already has a parent.
| //   3) 2 -> 1 -> 0
| //            /
| //      2 -> 3 -> 0              2 < 3, yet there exists a path from 2 to 3!
| // |
| uint64_t topological_nr() const noexcept { |
| has_parent_ = true; |
| return topological_nr_; |
| } |
| |
| // Assigns a node as a parent to this node.
| void assign_parent(); |
| |
| /// Id of the thread that created Node |
| uint64_t thread_id() const noexcept { |
| return thread_id_; |
| } |
| |
| /// Returns the name of the dynamic type of the function, for debugging. |
| virtual std::string name() const; |
| |
| /// The difference between functions `should_compute_output` and |
| /// `task_should_compute_output`: |
| /// - `should_compute_output` should only be used during graph construction |
| /// and takes into account only requires_grad information |
| /// - `task_should_compute_output` should only be called during the backward |
| /// pass (unless called directly through grad_fn) and takes into account the |
| /// current graph task. Specifically, the autograd engine trims unnecessary |
| /// edges when `inputs` are specified, and during backward untrimmed nodes |
| /// left on the graph can/should check `task_should_compute_output` to see if |
| /// any outgoing edges have been trimmed by the engine. If that is the case, |
| /// gradient computation wrt those edges can be omitted. |
| /// |
| /// Returns true if the particular output edge is active, and that particular |
| /// output of this function should be computed. |
| bool should_compute_output(size_t output_edge_index) const { |
| TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); |
| return next_edges_[output_edge_index].is_valid(); |
| } |
| |
| /// Returns true if any of the output edges in any of the ranges are active. |
| bool should_compute_output(std::initializer_list<IndexRange> idxs) const { |
| return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { |
| for (const auto i : c10::irange(range.first, range.second)) { |
| if (should_compute_output(i)) |
| return true; |
| } |
| return false; |
| }); |
| } |
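| 
| // A usage sketch (illustrative; mirrors the pattern used in generated
| // derivative code, with hypothetical ranges `self_ix` and `other_ix`):
| //
| //   IndexRange self_ix{0, 1}, other_ix{1, 2};
| //   if (should_compute_output({self_ix, other_ix})) {
| //     // at least one of the corresponding output edges is active
| //   }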
| |
| /// Same as the above `should_compute_output` function but will also |
| /// check whether this edge is needed within the current graph task. |
| bool task_should_compute_output(size_t output_edge_index) const { |
| TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); |
| const auto& next = next_edges_[output_edge_index]; |
| if (next.is_valid()) { |
| const auto exec_info = get_current_graph_task_exec_info(); |
| if (exec_info && !exec_info->empty()) { |
| auto it = exec_info->find(next.function.get()); |
| if (it == exec_info->end() || !it->second.should_execute()) { |
| return false; // this edge is not needed for the current graph_task |
| } |
| } |
| return true; |
| } |
| return false; |
| } |
| |
| /// Returns true if any of the output edges in any of the ranges are active |
| /// and should be computed in the current graph task. |
| bool task_should_compute_output( |
| std::initializer_list<IndexRange> idxs) const { |
| return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { |
| for (const auto i : c10::irange(range.first, range.second)) { |
| if (task_should_compute_output(i)) |
| return true; |
| } |
| return false; |
| }); |
| } |
| |
| /// Returns the `PyObject` stored for this `Node` (for Python |
| /// interaction). |
| PyObject* pyobj() const noexcept { |
| return pyobj_; |
| } |
| |
| /// Sets the `PyObject` stored for this `Node` (for Python interaction). |
| void set_pyobj(PyObject* pyobj) noexcept { |
| pyobj_ = pyobj; |
| } |
| |
| /// Returns the anomaly metadata stored for this `Node`. |
| /// If none exist, creates a new empty one. |
| AnomalyMetadata* metadata() noexcept; |
| |
| // Hook API |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) { |
| post_hooks_.push_back(std::move(post_hook)); |
| // Use the raw pointer as the unique key to identify this hook. This key |
| // can then be used in del_post_hook(key) to remove this hook. |
| return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get()); |
| } |
| |
| const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() |
| const noexcept { |
| return post_hooks_; |
| } |
| |
| // Deletes a post hook matching the key; returns true if one was removed.
| bool del_post_hook(const uintptr_t& key) { |
| for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) { |
| if (key == reinterpret_cast<std::uintptr_t>(it->get())) { |
| post_hooks_.erase(it); |
| return true; |
| } |
| } |
| return false; |
| } |
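| 
| // A usage sketch (illustrative; `MyPostHook` is a hypothetical
| // FunctionPostHook subclass):
| //
| //   uintptr_t key = node->add_post_hook(std::make_unique<MyPostHook>());
| //   // ... later ...
| //   bool removed = node->del_post_hook(key); // true if the hook was found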
| |
| std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept { |
| return post_hooks_; |
| } |
| |
| void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) { |
| pre_hooks_.push_back(std::move(pre_hook)); |
| } |
| |
| void add_tensor_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) { |
| tensor_pre_hooks_.push_back(std::move(pre_hook)); |
| } |
| |
| void add_retains_grad_hook( |
| std::unique_ptr<FunctionPreHook>&& pre_hook, |
| int output_idx) { |
| retains_grad_hooks_[output_idx] = std::move(pre_hook); |
| } |
| |
| std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) { |
| auto ret = std::move(retains_grad_hooks_[output_idx]); |
| retains_grad_hooks_.erase(output_idx); |
| return ret; |
| } |
| |
| const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() |
| const noexcept { |
| return pre_hooks_; |
| } |
| |
| std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept { |
| return pre_hooks_; |
| } |
| |
| virtual std::vector<std::unique_ptr<FunctionPreHook>>& |
| tensor_pre_hooks() noexcept { |
| return tensor_pre_hooks_; |
| } |
| |
| std::unordered_map<int, std::unique_ptr<FunctionPreHook>>& |
| retains_grad_hooks() noexcept { |
| return retains_grad_hooks_; |
| } |
| |
| // Customization Points for Subclasses |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| /// Releases saved variables if the operation won't be reused. |
| virtual void release_variables() {} |
| |
| /// Called before an apply if `release_variables()` is going to be called. |
| /// Allows larger ops like `InterpreterAutogradFunction` to incrementally |
| /// release variables as they run. |
| virtual void will_release_variables() {} |
| |
| /// Returns true if this function is traceable. An op is traceable if all |
| /// operations happening within `apply()` are performed on autograd |
| /// `Variables` (i.e. apply mostly instantiates and applies other functions). |
| virtual bool is_traceable() { |
| return false; |
| } |
| |
| /// A `Node` is said to pass state transparently to backward if the state
| /// consists only of (Saved)Variables and non-variable objects that
| /// parameterize the operation in some way that defines the graph structure,
| /// AND the backward function is traceable. In particular, the
| /// parametrization MUST NOT depend on the data of any `Variable`.
| /// TODO: it might be possible to handle cases where backward is |
| /// non-traceable but state passing could be considered transparent. This |
| /// will probably depend on saved_variable_list being mutable. |
| /// NOTE: this value matters only if is_traceable() returns false. |
| virtual bool passes_state_transparently() { |
| return false; |
| } |
| |
| protected: |
| /// Performs the `Node`'s actual operation. |
| virtual variable_list apply(variable_list&& inputs) = 0; |
| |
| /// Calls `apply()`, but instruments it with tracing machinery. |
| variable_list traced_apply(variable_list inputs); |
| |
| // Sequence number used to correlate backward nodes with forward ops in the
| // profiler and provide determinism in the engine.
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| const uint64_t sequence_nr_; |
| |
| // See NOTE [ Topological Number ] |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| uint64_t topological_nr_ = 0; |
| |
| // Tracks whether this node has been added as the next_edge of another node
| // via set_next_edge(s), which always calls topological_nr() of all its
| // children. See NOTE [ Topological Number ] for why we need this.
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| mutable bool has_parent_ = false; |
| |
| // Id of the thread that created the instance |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| uint64_t thread_id_ = 0; |
| |
| // Note [Thread Safety on Autograd Node]
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| // The autograd engine lets the owning thread that calls Engine::execute
| // drive the GraphTask execution. Parts of a GraphTask may be shared across
| // different `backward()` or `grad()` calls, e.g., when the user forks new
| // threads in the middle of the forward pass and calls `backward()`
| // separately from different threads. We therefore need to ensure thread
| // safety on the NodeTask to prevent data races on shared variables.
| //
| // NB: This is only needed for autograd Nodes that run on CPU; technically,
| // "CUDA" and "XLA" nodes don't need locking because device threads are
| // always single threaded.
| //
| // Here we add a mutex to help protect the Node's thread safety, so that
| // different threads cannot race on shared data when executing the same
| // NodeTask from multiple CPU threads. It IS the user/developer's
| // responsibility to take advantage of this mutex to protect the thread
| // safety of their autograd Node. The general strategy for thread safety on
| // an autograd Node:
| //
| // 1. Users should lock the mutex during Node::release_variables() if the
| //    Node needs to release variables on the fly. This ensures that when we
| //    release saved variables from one thread, no other thread can release
| //    them concurrently.
| // 2. Users should lock the mutex during Node::apply(), to ensure that
| //    writes to shared variables are not racing across threads (e.g.,
| //    AccumulateGrad, or custom C++ autograd Nodes that write to shared
| //    variables).
| // 3. Items 1 and 2 work together so that while we release saved variables
| //    from one thread, no other thread can call Node::apply(); this ensures
| //    that variable references held by other threads aren't dangling.
| // 4. If the Node doesn't release any variables and has no shared data
| //    reads/writes, i.e., it is purely functional, users don't need to lock
| //    the mutex.
| //
| // This protects thread safety on the autograd Node itself, but it does not
| // protect the Node's pre/post C++ hooks (Python hooks are automatically
| // thread safe); we rely on the user to write thread-safe C++ hooks if they
| // want the hooks to be applied correctly in a multithreaded environment.
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| std::mutex mutex_; |
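| 
| // A locking sketch for a stateful custom Node (illustrative; `MyNode` is a
| // hypothetical subclass following items 1 and 2 above):
| //
| //   variable_list MyNode::apply(variable_list&& grads) {
| //     std::lock_guard<std::mutex> lock(mutex_);
| //     // ... read/write shared state and compute gradients safely ...
| //   }
| //
| //   void MyNode::release_variables() {
| //     std::lock_guard<std::mutex> lock(mutex_);
| //     // ... release saved variables safely ...
| //   }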
| |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| edge_list next_edges_; |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| PyObject* pyobj_ = nullptr; // weak reference |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr; |
| |
| // NOTE [Hooks ordering] |
| // We have 3 separate fields for pre hooks registered to the autograd nodes |
| // because the conditions under which they execute are different, and we |
| // want more fine-grained control over the order in which different types |
| // of hooks are executed. |
| // - pre_hooks are only executed when the node itself is executed |
| // - tensor_pre_hooks are executed as long as the engine traverses the node,
| //   even if that node won't be executed.
| // - retains_grad_hooks are like tensor_pre_hooks except they are always
| //   ordered after all other tensor pre hooks
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_; |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_; |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_; |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_; |
| // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
| at::SmallVector<InputMetadata, 2> input_metadata_; |
| }; |
| |
| /// See Node::is_traceable() for definition. |
| struct TraceableFunction : public Node { |
| using Node::Node; |
| bool is_traceable() final { |
| return true; |
| } |
| }; |
| |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // Associated Free Nodes |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| namespace detail { |
| // Implementation of `collect_next_edges` (see below). |
| struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> { |
| edge_list next_edges; |
| using IterArgs<MakeNextFunctionList>::operator(); |
| void operator()(const Variable& variable) { |
| // NOLINTNEXTLINE(bugprone-branch-clone) |
| if (variable.defined()) { |
| next_edges.push_back(impl::gradient_edge(variable)); |
| } else { |
| next_edges.emplace_back(); |
| } |
| } |
| void operator()(const Variable* variable) { |
| // NOLINTNEXTLINE(bugprone-branch-clone) |
| if (variable->defined()) { |
| next_edges.push_back(impl::gradient_edge(*variable)); |
| } else { |
| next_edges.emplace_back(); |
| } |
| } |
| void operator()(const c10::optional<Variable>& variable) { |
| // NOLINTNEXTLINE(bugprone-branch-clone) |
| if (variable.has_value() && variable->defined()) { |
| next_edges.push_back(impl::gradient_edge(*variable)); |
| } else { |
| next_edges.emplace_back(); |
| } |
| } |
| }; |
| } // namespace detail |
| |
| /// Create an `Edge` between the given `variable` and the `function`, which is |
| /// assumed to be the gradient function of this variable (i.e. the function |
| /// through which this variable is backpropagated during the backward pass). |
| /// This sets the `grad_fn` property of the `variable`. This function assumes |
| /// that the `Variable` is a new input to the gradient function and its |
| /// `input_nr` is thus equal to `function->num_inputs()`. Additionally, it
| /// increments the `Node`'s number of inputs by one. Approximately |
| /// equivalent to `variable.set_gradient_edge(function, |
| /// function->add_input_metadata(variable.dispatch_type(), variable.sizes()))`. |
| /// If you don't want the `Node`'s `num_inputs` to be incremented, use |
| /// `set_gradient_edge` directly. |
| inline void create_gradient_edge( |
| Variable& variable, |
| std::shared_ptr<Node> function) { |
| // Copy before move. |
| const auto input_nr = function->add_input_metadata(variable); |
| impl::set_gradient_edge(variable, {std::move(function), input_nr}); |
| } |
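| 
| /// A usage sketch (illustrative; `MyBackward` is a hypothetical Node
| /// subclass):
| ///
| ///   std::shared_ptr<Node> fn = std::make_shared<MyBackward>();
| ///   create_gradient_edge(output, fn);
| ///   // `output`'s grad_fn is now `fn`, and `fn->num_inputs()` grew by one.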
| |
| /// Return true if any of the variables in the list require a gradient. |
| inline bool any_variable_requires_grad(const variable_list& variables) { |
| return std::any_of( |
| variables.begin(), variables.end(), [](const Variable& variable) { |
| return variable.defined() && variable.requires_grad(); |
| }); |
| } |
| |
| /// Return the next edges of all the given variables, or tuples of variables. |
| template <typename... Variables> |
| edge_list collect_next_edges(Variables&&... variables) { |
| detail::MakeNextFunctionList make; |
| make.apply(std::forward<Variables>(variables)...); |
| return std::move(make.next_edges); |
| } |
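| 
| /// A usage sketch (illustrative; `MyBackward` is a hypothetical Node
| /// subclass wired up for a two-input op):
| ///
| ///   auto grad_fn = std::make_shared<MyBackward>();
| ///   grad_fn->set_next_edges(collect_next_edges(self, other));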
| } // namespace autograd |
| } // namespace torch |
| |
| C10_CLANG_DIAGNOSTIC_POP() |