| #pragma once |
| #include <c10/core/impl/TorchDispatchModeTLS.h> |
| #include <torch/csrc/autograd/custom_function.h> |
| #include <torch/csrc/autograd/engine.h> |
| #include <torch/csrc/utils/python_stub.h> |
| #include <torch/csrc/utils/torch_dispatch_mode.h> |
| #include <iostream> |
| #include <typeindex> |
| #include <vector> |
| |
| // see [Note: Compiled Autograd] |
| |
| namespace torch::dynamo::autograd { |
| using namespace torch::autograd; |
| |
| struct SizeInput { |
  // Note: the int value is still needed even when dynamic, to pass as an arg
| enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 }; |
| SizeInput(DynType dt, int64_t v) : dyn_type(dt), value(v) {} |
| DynType dyn_type; |
| int64_t value; |
| }; |
| |
| struct CacheKeyBuffer { |
| CacheKeyBuffer(const uint8_t* key, uint16_t len) : data(new uint8_t[len]) { |
| std::memcpy(data.get(), key, len); |
| } |
| const uint8_t* get() const { |
| return data.get(); |
| } |
| |
| private: |
| std::unique_ptr<uint8_t[]> data; |
| }; |
| |
| struct CacheKey { |
| // Key to find the next node in the shadow graph. We use C++ RTTI for the |
| // type of the node (ntype), then a key generated with a visitor pattern. |
| CacheKey(const std::type_index& ntype, const uint8_t* key, uint16_t len) |
| : node_type(ntype), key_size(len), key(key) {} |
| |
| bool operator<(const CacheKey& other) const { |
| if (node_type != other.node_type) { |
| return node_type < other.node_type; |
| } |
| if (key_size != other.key_size) { |
| return key_size < other.key_size; |
| } |
| return std::memcmp(key, other.key, key_size) < 0; |
| } |
| |
| bool operator==(const CacheKey& other) const { |
| return node_type == other.node_type && key_size == other.key_size && |
| std::memcmp(key, other.key, key_size) == 0; |
| } |
| |
| size_t hash() const { |
    // don't bother hashing the key data; the common case is 1 cache entry per
    // node
| return std::hash<std::type_index>()(node_type) ^ key_size; |
| } |
| |
| std::type_index node_type; |
| uint16_t key_size; |
| const uint8_t* key; |
| }; |
| |
| struct NodeCall { |
| NodeCall(uint32_t id_, std::shared_ptr<Node> node_) |
| : id(id_), node(std::move(node_)) {} |
| |
| void mark_output(int input_nr, int output_idx) { |
    graph_output.emplace_back(input_nr, output_idx);
| } |
| |
| uint32_t id; |
| std::shared_ptr<Node> node; |
| std::vector<std::pair<int, int>> tensor_pre_hooks; |
| std::vector<int> pre_hooks; |
| std::vector<int> post_hooks; |
| std::vector<std::pair<int, int>> graph_output; |
| bool needed = true; |
| }; |
| |
| struct NodeCalls : public std::unordered_map<Node*, NodeCall> { |
| NodeCall& lookup(const std::shared_ptr<Node>& function) { |
| auto it = find(function.get()); |
| if (it == end()) { |
| it = emplace(function.get(), NodeCall(_next_id++, function)).first; |
| } |
| return it->second; |
| } |
| |
| private: |
| uint32_t _next_id = 0; |
| }; |
| |
| struct TensorArg { |
| // Represents a de-duplicated tensor that will be passed into the graph |
| TensorArg(uint32_t i = 0) : id(i) {} |
| uint32_t index() const { |
| TORCH_INTERNAL_ASSERT(defined()); |
| return id - 1; |
| } |
| bool defined() const { |
| return id != 0; |
| } |
| uint32_t id; |
| at::Tensor proxy_tensor; |
| }; |
| |
| struct TensorArgs { |
  // Manages a collection of TensorArgs and the mappings from
  // Tensors/SavedVariables to them. This also allows us to unpack each
  // SavedVariable exactly once and store the unpacked Tensor.
| |
| TensorArg& lookup(const at::Tensor& tensor, bool create = false) { |
| if (!tensor.defined()) { |
| return _undefined; |
| } |
| auto impl = tensor.unsafeGetTensorImpl(); |
| auto it = _args.find(impl); |
| if (it == _args.end()) { |
| TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1); |
| it = _args.emplace(impl, TensorArg(_next_id++)).first; |
| inputs.emplace_back(tensor); |
| } |
| return it->second; |
| } |
| |
| TensorArg& lookup(const SavedVariable& sv) { |
| auto it = _saved_variables.find(&sv); |
| TORCH_INTERNAL_ASSERT(it != _saved_variables.end()); |
| return *it->second; |
| } |
| |
| TensorArg& add(const at::Tensor& tensor) { |
| return lookup(tensor, true); |
| } |
| |
| TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) { |
| // TODO(jansel): Here we unpack the SavedVariable exactly once. This might |
| // fire SavedTensor hooks. In the future we should try to put saved tensor |
| // hooks into the graph. |
| at::Tensor tensor = sv.unpack(node); |
| TensorArg& arg = add(tensor); |
| _saved_variables.emplace(&sv, &arg); |
| return arg; |
| } |
| |
| // the concrete tensors that will get passed into the graph as inputs |
| std::vector<at::Tensor> inputs; |
| |
| private: |
| std::unordered_map<const c10::TensorImpl*, TensorArg> _args; |
  // Every TensorArg here is owned by _args (or by _undefined), which is why
  // we store a non-owning pointer in this map.
| std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables; |
| TensorArg _undefined; |
| uint32_t _next_id = 1; // id=0 used by _undefined |
| }; |
| |
| struct AutogradCompilerCall { |
| void add_size_input(const c10::SymInt& s) { |
    all_size_inputs.emplace_back(default_dyn_type, s.expect_int());
| } |
| |
| int emplace_hook(c10::SafePyObject&& fn) { |
| hooks.emplace_back(std::move(fn)); |
    return static_cast<int>(hooks.size() - 1);
| } |
| |
| TensorArgs tensor_args; |
| std::vector<SizeInput> all_size_inputs; |
| std::vector<int64_t> dyn_size_inputs; |
| std::vector<c10::SafePyObject> hooks; |
| NodeCalls node_calls; |
| SizeInput::DynType default_dyn_type = SizeInput::STATIC; |
| }; |
| |
| class CompiledNodeArgs { |
  // CompiledNodeArgs builds a representation of the constant values found
  // across all the nodes in the compiled graph, via 'collect' overloads. The
  // collected constants are specialized on by concatenating them into a cache
  // key. Tensor and SymInt arguments (which are lifted to become graph inputs
  // rather than specialized on) are forwarded to the compiler and not
  // included in the key.
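  // A sketch of how a node might participate (the node and field names here
  // are illustrative, not a real autograd node): constants are folded into
  // the cache key, while tensors and symints are lifted to graph inputs.
  //   void MyMulBackward::compiled_args(CompiledNodeArgs& args) {
  //     args.collect(saved_self_); // SavedVariable -> lifted graph input
  //     args.collect(alpha_);      // c10::Scalar   -> specialized into key
  //   }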
| public: |
| void collect(const TensorArg& t) { |
| collect_size(t.id); |
| if (t.defined()) { |
| const at::Tensor& tensor = _compiler.tensor_args.inputs[t.index()]; |
| // including these in the cache key means dynamo-level tensor guards can |
| // be skipped |
| collect(tensor.device()); |
| collect(tensor.dtype()); |
| collect(tensor.requires_grad()); |
| } |
| } |
| |
| void collect(const at::Tensor& t) { |
| collect(_compiler.tensor_args.add(t)); |
| } |
| void collect(const SavedVariable& t) { |
| collect(_compiler.tensor_args.add(t, _node_call.node)); |
| } |
| void collect(const c10::SymInt& t) { |
| _compiler.add_size_input(t); |
| } |
| template <typename T> |
| void collect(const std::vector<T>& t) { |
| collect_size(t.size()); |
| for (const T& i : t) { |
| collect(i); |
| } |
| } |
| template <typename T> |
| void collect(const c10::ArrayRef<T>& t) { |
| collect_size(t.size()); |
| for (const T& i : t) { |
| collect(i); |
| } |
| } |
| template <typename T> |
| void collect(const c10::OptionalArray<T>& t) { |
| collect(t.list); |
| } |
| template <typename T> |
| void collect(const c10::optional<T>& t) { |
| if (cond(t.has_value())) { |
| collect(*t); |
| } |
| } |
| template <typename A, typename B> |
| void collect(const std::pair<A, B>& t) { |
| collect(t.first); |
| collect(t.second); |
| } |
| void collect(const c10::Scalar& t) { |
| auto type = t.type(); |
| specialize_on_bytes(type); |
| if (type == c10::ScalarType::Double) { |
| collect(t.toDouble()); |
| } else if (type == c10::ScalarType::Long) { |
| collect(t.toLong()); |
| } else if (type == c10::ScalarType::Bool) { |
| collect(t.toBool()); |
| } else if (type == c10::ScalarType::ComplexDouble) { |
| auto c = t.toComplexDouble(); |
| collect(c.real()); |
| collect(c.imag()); |
| } else { |
| TORCH_INTERNAL_ASSERT(false); |
| } |
| } |
| void collect(const c10::TensorOptions& t) { |
| collect(t.device()); |
| collect(t.dtype()); |
| collect(t.layout()); |
| collect(t.requires_grad()); |
| collect(t.pinned_memory()); |
| collect(t.memory_format_opt()); |
| } |
| void collect(const at::TensorGeometry& t) { |
| collect(t.sym_sizes()); |
| collect(t.sym_strides()); |
| collect(t.sym_storage_offset()); |
| } |
| void collect(const torch::autograd::TypeAndSize& t) { |
| collect(t.sym_sizes); |
| collect(t.options); |
| } |
| void collect(const c10::Device& t) { |
| collect(t.type()); |
| collect(t.index()); |
| } |
| void collect(const std::string& t) { |
| collect_size(t.size()); |
| for (char c : t) { |
| collect(c); |
| } |
| } |
| void collect(const caffe2::TypeMeta& t) { |
| specialize_on_bytes(t.id()); |
| } |
| void collect(const std::shared_ptr<Node>& t) { |
    // Note: this only captures the ID of the node, not everything contained
    // inside it. It is used for tracking connections between nodes; the
    // details of the node itself must be handled by a separate call to
    // `node->compiled_args()`.
| if (cond((bool)t)) { |
| collect(_compiler.node_calls.lookup(t)); |
| } |
| } |
| void collect(const NodeCall& t) { |
| collect_size(t.id); |
| collect(t.graph_output); |
| collect_hooks_from(t.node.get()); |
| } |
| void collect(const Edge& t) { |
| if (cond(t.is_valid())) { |
| collect_size(_compiler.node_calls.lookup(t.function).id); |
| collect_size(t.input_nr); |
| collect(t.function->input_metadata(t.input_nr)); // for validate_outputs |
| } |
| } |
| void collect(const InputMetadata& t) { |
| TORCH_CHECK(!t.is_nested_tensor(), "NestedTensor not implemented"); |
| collect(t.options()); |
| collect(t.is_tensor_subclass()); |
| collect(t.shape_as_dim_vector()); |
| } |
| void collect(const VariableInfo& t) { |
| collect(t.layout); |
| collect(t.device); |
| collect(t.scalar_type); |
| collect(t.size); |
| collect(t.requires_grad); |
| collect(t.is_empty); |
| } |
| bool cond(bool cond) { |
| collect(cond); |
| return cond; |
| } |
| |
| #define COLLECT_AS_BYTES(T) \ |
| void collect(T t) { \ |
| specialize_on_bytes(t); \ |
| } |
| COLLECT_AS_BYTES(c10::ScalarType); |
| COLLECT_AS_BYTES(c10::DeviceType); |
| COLLECT_AS_BYTES(c10::Layout); |
| COLLECT_AS_BYTES(c10::MemoryFormat); |
| COLLECT_AS_BYTES(int8_t); |
| COLLECT_AS_BYTES(int16_t); |
| COLLECT_AS_BYTES(int32_t); |
| COLLECT_AS_BYTES(int64_t); |
| COLLECT_AS_BYTES(uint8_t); |
| COLLECT_AS_BYTES(uint16_t); |
| COLLECT_AS_BYTES(uint32_t); |
| COLLECT_AS_BYTES(uint64_t); |
| COLLECT_AS_BYTES(bool); |
| COLLECT_AS_BYTES(float); |
| COLLECT_AS_BYTES(double); |
| #undef COLLECT_AS_BYTES |
| |
| void collect_hooks_from(Node* fn) { |
| TORCH_CHECK( |
| fn->retains_grad_hooks().empty(), |
| "retains_grad_hooks not implemented for compiled autograd"); |
| for (auto& i : fn->tensor_pre_hooks()) { |
| i->compiled_args(*this); |
| } |
| for (auto& i : fn->pre_hooks()) { |
| i->compiled_args(*this); |
| } |
| for (auto& i : fn->post_hooks()) { |
| i->compiled_args(*this); |
| } |
| collect_size(_node_call.tensor_pre_hooks.size()); |
| collect_size(_node_call.pre_hooks.size()); |
| collect_size(_node_call.post_hooks.size()); |
| for (const auto& h : _node_call.tensor_pre_hooks) { |
| collect_size(h.second); // index |
| } |
| } |
| |
| CacheKey key() const { |
| Node* node = _node_call.node.get(); |
| return CacheKey( |
| typeid(*node), _specialization_key, _specialization_key_size); |
| } |
| |
| void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) { |
| auto fn_id = _compiler.emplace_hook(std::move(obj)); |
| collect_size(fn_id); |
    _node_call.tensor_pre_hooks.emplace_back(fn_id, index);
| } |
| |
| void add_pre_hook(c10::SafePyObject&& obj) { |
| auto fn_id = _compiler.emplace_hook(std::move(obj)); |
| collect_size(fn_id); |
| _node_call.pre_hooks.emplace_back(fn_id); |
| } |
| |
| void add_post_hook(c10::SafePyObject&& obj) { |
| auto fn_id = _compiler.emplace_hook(std::move(obj)); |
| collect_size(fn_id); |
| _node_call.post_hooks.emplace_back(fn_id); |
| } |
| |
| void collect_size(size_t s) { |
| // we expect sizes to be small, so try to cram them into a single byte |
| constexpr uint8_t encode_as_u64 = std::numeric_limits<uint8_t>::max(); |
| constexpr uint8_t encode_as_u32 = encode_as_u64 - 1; |
| constexpr uint8_t encode_as_u16 = encode_as_u64 - 2; |
| if (C10_UNLIKELY(s >= encode_as_u16)) { |
| // first write a byte indicating the path we followed, then the data |
| if (s <= std::numeric_limits<uint16_t>::max()) { |
| // 3 bytes |
| specialize_on_bytes(encode_as_u16); |
| specialize_on_bytes(static_cast<uint16_t>(s)); |
| } else if (s <= std::numeric_limits<uint32_t>::max()) { |
| // 5 bytes |
| specialize_on_bytes(encode_as_u32); |
| specialize_on_bytes(static_cast<uint32_t>(s)); |
| } else { |
| // 9 bytes |
| specialize_on_bytes(encode_as_u64); |
| specialize_on_bytes(s); |
| } |
| } else { |
| // happy case, 1 byte |
| specialize_on_bytes(static_cast<uint8_t>(s)); |
| } |
| } |
| |
| SizeInput::DynType set_default_dyn_type(SizeInput::DynType default_dyn_type) { |
| return std::exchange(_compiler.default_dyn_type, default_dyn_type); |
| } |
| |
| CompiledNodeArgs(AutogradCompilerCall& compiler, NodeCall& node_call) |
| : _compiler(compiler), |
| _node_call(node_call), |
| _specialization_key_size(0), |
| _specialization_key_storage(1024), |
| _specialization_key( |
| (uint8_t*)std::malloc(_specialization_key_storage)) {} |
| ~CompiledNodeArgs() { |
| std::free(_specialization_key); |
| } |
| CompiledNodeArgs(const CompiledNodeArgs&) = delete; |
| |
| private: |
| template <typename T> |
| void specialize_on_bytes(const T& t) { |
    while (C10_UNLIKELY(
        _specialization_key_size + sizeof(T) > _specialization_key_storage)) {
      _specialization_key_storage *= 2;
      // realloc returns nullptr on failure and leaves the old block alive, so
      // check before overwriting our only pointer to it
      auto* grown = (uint8_t*)std::realloc(
          _specialization_key, _specialization_key_storage);
      TORCH_INTERNAL_ASSERT(grown != nullptr);
      _specialization_key = grown;
    }
| std::memcpy(_specialization_key + _specialization_key_size, &t, sizeof(T)); |
| _specialization_key_size += sizeof(T); |
| } |
| |
| AutogradCompilerCall& _compiler; |
| NodeCall& _node_call; |
| size_t _specialization_key_size; |
| size_t _specialization_key_storage; |
| uint8_t* _specialization_key; |
| }; |
| |
| struct TraceState { |
| TraceState( |
| const std::vector<c10::optional<c10::SymInt>>& ss, |
| size_t num_outputs) |
| : sym_sizes_index(0), sym_sizes(ss), outputs(num_outputs) {} |
| |
| void debug_asserts() { |
| TORCH_INTERNAL_ASSERT(sym_sizes_index == sym_sizes.size()); |
| } |
| c10::optional<c10::SymInt> next_sym_size() { |
| TORCH_INTERNAL_ASSERT(sym_sizes_index < sym_sizes.size()); |
| return sym_sizes[sym_sizes_index++]; |
| } |
| |
| size_t sym_sizes_index; |
| std::vector<c10::optional<c10::SymInt>> sym_sizes; |
| variable_list outputs; |
| std::vector<size_t> output_grad_targets; |
| }; |
| |
| #define SWAP_SAVED_VARIABLES_SAVE(mapping, var, move) \ |
| bool inserted = mapping.emplace(&var, move).second; \ |
| TORCH_INTERNAL_ASSERT(inserted, "duplicate before()"); |
| |
| #define SWAP_SAVED_VARIABLES_RESTORE(mapping, var) \ |
| auto it = mapping.find(&var); \ |
| TORCH_INTERNAL_ASSERT(it != mapping.end(), "duplicate after()"); \ |
| var = std::move(it->second); \ |
| mapping.erase(it); |
| |
| class SwapSavedVariables { |
  // SwapSavedVariables is used during the tracing/compilation phase after a
  // cache miss. It swaps out any 'lifted' inputs (tensors, symints) for
  // proxies, lets tracing happen, then swaps them back afterwards.
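  // Call-pattern sketch (the real driver lives in the compiled autograd
  // engine; this is illustrative only). Every before() must be paired with
  // exactly one after(), which the SWAP_SAVED_VARIABLES_* asserts enforce:
  //   SwapSavedVariables saved(compiler_call, trace_state);
  //   saved.before(saved_self_); // stash the real tensor, swap in the proxy
  //   ... trace the node's apply() with the proxies in place ...
  //   saved.after(saved_self_);  // restore the original tensor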
| public: |
| void before(at::Tensor& t) { |
| TensorArg& arg = compiler.tensor_args.lookup(t); |
| SWAP_SAVED_VARIABLES_SAVE(stashed_tensors, t, std::move(t)); |
| if (arg.defined()) { |
| TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined()); |
| t = arg.proxy_tensor; |
| } |
| } |
| void after(at::Tensor& t) { |
| SWAP_SAVED_VARIABLES_RESTORE(stashed_tensors, t); |
| } |
| |
| void before(SavedVariable& t) { |
| TensorArg& arg = compiler.tensor_args.lookup(t); |
| SWAP_SAVED_VARIABLES_SAVE(stashed_variables, t, std::move(t)); |
| if (arg.defined()) { |
| TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined()); |
| t = SavedVariable(arg.proxy_tensor, false); |
| } |
| } |
| void after(SavedVariable& t) { |
| SWAP_SAVED_VARIABLES_RESTORE(stashed_variables, t); |
| } |
| |
| void before(c10::SymInt& t) { |
| SWAP_SAVED_VARIABLES_SAVE(stashed_symints, t, t); |
| auto opt_value = state.next_sym_size(); |
| if (opt_value.has_value()) { |
| t = *opt_value; // dynamic shape |
| } |
| } |
| void after(c10::SymInt& t) { |
| SWAP_SAVED_VARIABLES_RESTORE(stashed_symints, t); |
| } |
| |
| void before(Edge& t) { |
| if (t.is_valid()) { |
      // needed for symints used by validate_outputs
| before(t.function->mutable_input_metadata(t.input_nr)); |
| } |
| } |
| void after(Edge& t) { |
| if (t.is_valid()) { |
| after(t.function->mutable_input_metadata(t.input_nr)); |
| } |
| } |
| void before(InputMetadata& t) { |
| before(t.mutable_shape_as_dim_vector()); |
| } |
| void after(InputMetadata& t) { |
| after(t.mutable_shape_as_dim_vector()); |
| } |
| void before(at::TensorGeometry& t) { |
| before(t.mutable_sizes()); |
| before(t.mutable_strides()); |
| before(t.mutable_storage_offset()); |
| t.recompute(); |
| } |
| void after(at::TensorGeometry& t) { |
| after(t.mutable_sizes()); |
| after(t.mutable_strides()); |
| after(t.mutable_storage_offset()); |
| t.recompute(); |
| } |
| void before(torch::autograd::TypeAndSize& t) { |
| before(t.sym_sizes); |
| before(t.options); |
| } |
| void after(torch::autograd::TypeAndSize& t) { |
| after(t.sym_sizes); |
| after(t.options); |
| } |
| void before(VariableInfo& t) { |
| before(t.size); |
| } |
| void after(VariableInfo& t) { |
| after(t.size); |
| } |
| |
| template <typename T> |
| void before(std::vector<T>& t) { |
| for (T& i : t) { |
| before(i); |
| } |
| } |
| template <typename T> |
| void after(std::vector<T>& t) { |
| for (T& i : t) { |
| after(i); |
| } |
| } |
| template <typename T, unsigned N> |
| void before(c10::SmallVector<T, N>& t) { |
| for (T& i : t) { |
| before(i); |
| } |
| } |
| template <typename T, unsigned N> |
| void after(c10::SmallVector<T, N>& t) { |
| for (T& i : t) { |
| after(i); |
| } |
| } |
| |
| template <typename T> |
| void before(c10::OptionalArray<T>& t) { |
| before(t.list); |
| } |
| template <typename T> |
| void after(c10::OptionalArray<T>& t) { |
| after(t.list); |
| } |
| |
| template <typename T> |
| void before(c10::optional<T>& t) { |
| if (t.has_value()) { |
| before(*t); |
| } |
| } |
| template <typename T> |
| void after(c10::optional<T>& t) { |
| if (t.has_value()) { |
| after(*t); |
| } |
| } |
| |
| #define NO_OP_VISIT(T) \ |
| void before(const T&) {} \ |
| void after(const T&) {} |
| NO_OP_VISIT(caffe2::TypeMeta); |
| NO_OP_VISIT(c10::Device); |
| NO_OP_VISIT(c10::DeviceType); |
| NO_OP_VISIT(c10::Layout); |
| NO_OP_VISIT(c10::MemoryFormat); |
| NO_OP_VISIT(c10::ScalarType); |
| NO_OP_VISIT(c10::Scalar); |
| NO_OP_VISIT(c10::TensorOptions); |
| NO_OP_VISIT(std::string); |
| NO_OP_VISIT(int64_t); |
| NO_OP_VISIT(bool); |
| NO_OP_VISIT(double); |
| #undef NO_OP_VISIT |
| |
  // Record the need to run `dst.mutable_grad() = src` after the graph runs.
  // dst is a real tensor; src is a fake tensor.
| void assign_mutable_grad(const at::Tensor& dst, const at::Tensor& src) { |
| const TensorArg& arg = compiler.tensor_args.lookup(dst); |
| TORCH_INTERNAL_ASSERT(arg.defined()); |
| TORCH_INTERNAL_ASSERT( |
| state.outputs.size() == state.output_grad_targets.size()); |
| state.outputs.emplace_back(src); |
| state.output_grad_targets.emplace_back(arg.index()); |
| } |
| |
| SwapSavedVariables(AutogradCompilerCall& c, TraceState& s) |
| : compiler(c), state(s) {} |
| |
| private: |
| AutogradCompilerCall& compiler; |
| TraceState& state; |
| |
  // These mappings are used to save the prior values when we overwrite things
  // in before(). In after(), we use them to clean up after ourselves.
| std::unordered_map<const SavedVariable*, SavedVariable> stashed_variables; |
| std::unordered_map<const at::Tensor*, at::Tensor> stashed_tensors; |
| std::unordered_map<const c10::SymInt*, c10::SymInt> stashed_symints; |
| }; |
| |
| } // namespace torch::dynamo::autograd |
| |
| template <> |
| struct std::hash<torch::dynamo::autograd::CacheKey> { |
| size_t operator()(const torch::dynamo::autograd::CacheKey& k) const { |
| return k.hash(); |
| } |
| }; |