| #include <torch/csrc/autograd/variable.h> |
| |
| #include <torch/csrc/autograd/autograd.h> |
| #include <torch/csrc/autograd/edge.h> |
| #include <torch/csrc/autograd/engine.h> |
| #include <torch/csrc/autograd/function.h> |
| #include <torch/csrc/autograd/InferenceMode.h> |
| #include <torch/csrc/autograd/functions/accumulate_grad.h> |
| #include <torch/csrc/autograd/functions/tensor.h> |
| #include <torch/csrc/autograd/generated/Functions.h> |
| #include <torch/csrc/autograd/utils/error_messages.h> |
| |
| #include <ATen/core/VariableHooksInterface.h> |
| |
| #include <ATen/ATen.h> |
| #include <ATen/MemoryOverlap.h> |
| #include <c10/util/Exception.h> |
| |
| #include <list> |
| #include <memory> |
| #include <mutex> |
| #include <stdexcept> |
| #include <string> |
| #include <vector> |
| #include <typeinfo> |
| #include <iostream> |
| |
| namespace torch { |
| namespace autograd { |
| |
| |
| DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, |
| c10::optional<ViewInfo> backward_info, |
| c10::optional<ViewInfo> forward_info, |
| CreationMeta creation_meta) |
| : AutogradMeta(self_impl), |
| backward_info_(std::move(backward_info)), |
| forward_info_(std::move(forward_info)), |
| creation_meta_(creation_meta) { |
| is_view_ = true; |
| if (backward_info_.has_value()) { |
| self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_)); |
| attr_version_ = self_impl->version_counter().current_version(); |
| } |
| } |
| |
| // Chain this view info with the new view op between base and tensor |
| ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor, |
| std::function<Variable(const Variable&)> view_func) const { |
| // Set `view_func` using the root base as input. |
| // `view_func` is used to recover views in backward when either as_strided is not supported |
| // or the view function changes the metadata which is not recorded by as_strided. |
| // See Note [View + Inplace update for base tensor] and Note [View + Inplace update for view tensor] |
| // for more details on how we use this function in backward. |
| if (view_func) { |
| // both current_view and its parent have a view_func |
| if (view_fn_) { |
| // Copy parent view function to gain ownership |
| auto prev_fn = view_fn_; |
| view_func = [=](const at::Tensor& root_base) { |
| auto temp = prev_fn(root_base); |
| return view_func(temp); |
| }; |
| } else { |
| // current_view has a view_func but its parent doesn't have one |
| if (base.unsafeGetTensorImpl()->support_as_strided()) { |
| auto size = base.sizes().vec(); |
| auto stride = base.strides().vec(); |
| auto storage_offset = base.storage_offset(); |
| view_func = [=](const at::Tensor& root_base) { |
| auto temp = root_base.as_strided(size, stride, storage_offset); |
| return view_func(temp); |
| }; |
| } else { |
| // When base is a view but doesn't carry a view_fn in its DifferentiableViewMeta, it's |
| // a view that doesn't support inplace updates, e.g. unbind. |
| // In this case we should throw an error when an inplace update happens in **forward**. |
| // One would naturally expect the following function to be called first in the backward pass, |
| // but its first call site is actually in the **forward** pass, when we refresh `grad_fn` |
| // after an inplace update. |
| // Search Note [View + Inplace update for view tensor] for the call site. |
| view_func = [=](const at::Tensor& root_base) { |
| TORCH_CHECK(false, "This view is the output of a function that returns multiple views." |
| "Such functions do not allow the output views to be modified inplace." |
| "You should replace the inplace operation by an out-of-place one"); |
| return root_base; |
| }; |
| } |
| } |
| } else if (view_fn_) { |
| // current_view doesn't have a view_func but its parent has one |
| // Copy parent view function to gain ownership |
| auto prev_view_fn = view_fn_; |
| auto size = tensor.sizes().vec(); |
| auto stride = tensor.strides().vec(); |
| auto storage_offset = tensor.storage_offset(); |
| view_func = [=](const at::Tensor& root_base) { |
| auto temp = prev_view_fn(root_base); |
| return temp.as_strided(size, stride, storage_offset); |
| }; |
| } |
| |
| return ViewInfo(base_, view_func); |
| } |
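| |
| // Illustration (not part of this translation unit): if the parent view was produced by a |
| // view op `op1` that recorded a view_func, and the current view by `op2`, the chained |
| // view_func built above is morally |
| // |
| //   [=](const at::Tensor& root_base) { return op2(op1(root_base)); } |
| // |
| // so that, given the root base, backward (or a grad_fn refresh) can re-create this exact |
| // view without relying on as_strided. |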
| |
| namespace { |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| at::Tensor singleton_undefined_tensor; |
| |
| struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory { |
| std::unique_ptr<c10::AutogradMetaInterface> make() const override { |
| return std::make_unique<AutogradMeta>(); |
| } |
| const at::Tensor& undefined_tensor() const override { |
| return singleton_undefined_tensor; |
| } |
| }; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| ConcreteAutogradMetaFactory meta_factory; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(&meta_factory); |
| |
| } |
| |
| namespace impl { |
| |
| AutogradMeta* materialize_autograd_meta(const Variable& self) { |
| TORCH_CHECK(self.defined(), "cannot call materialize_autograd_meta() on undefined tensor"); |
| auto p = self.unsafeGetTensorImpl(); |
| if (!p->autograd_meta()) { |
| p->set_autograd_meta(std::make_unique<AutogradMeta>()); |
| } |
| return get_autograd_meta(self); |
| } |
| |
| void rebase_history(const Variable& self, Edge gradient_edge) { |
| TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr); |
| auto diff_view_meta = get_view_autograd_meta(self); |
| if (diff_view_meta && diff_view_meta->has_bw_view()) { |
| // See NOTE [ View + Inplace detection ] |
| auto creation_meta = diff_view_meta->get_creation_meta(); |
| if (creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) { |
| // Do not use handle_view_on_rebase here: check_inplace should have been called before this |
| // and would have either thrown an error or cleared the warning already. |
| // Temporary error message as a full fix is too risky for now; |
| // this should become an internal assert again. |
| TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT); |
| TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0); |
| TORCH_INTERNAL_ASSERT(gradient_edge.function); |
| TORCH_CHECK( |
| gradient_edge.function->num_inputs() == 1, |
| "Functions which modify views in-place must return a single Variable"); |
| auto view_info = diff_view_meta->get_backward_view(); |
| diff_view_meta->output_nr_ = gradient_edge.input_nr; |
| auto copy_slices = std::make_shared<CopySlices>( |
| view_info.base_, at::TensorGeometry(self), view_info.view_fn_, std::move(gradient_edge.function)); |
| set_gradient_edge(view_info.base_, {std::move(copy_slices), 0}); |
| self.grad_fn(); // trigger an update to the view's grad_fn |
| return; |
| } |
| } |
| |
| set_gradient_edge(self, std::move(gradient_edge)); |
| } |
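| |
| // Rough usage sketch (illustration only; any tensor values are hypothetical): |
| // |
| //   auto x = at::zeros({4}).requires_grad_(); |
| //   auto base = x + 0;                 // non-leaf base, so inplace on its views is allowed |
| //   auto view = base.slice(0, 0, 2); |
| //   view.add_(1);                      // rebase_history(view, <add_ backward edge>) |
| // |
| // Afterwards the base's grad_fn is a CopySlices wrapping the add_ backward node, and the |
| // view's grad_fn is refreshed via self.grad_fn() (see VariableHooks::grad_fn below). |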
| |
| void create_cpp_hook(const Variable& self) { |
| auto &list = materialize_autograd_meta(self)->cpp_hooks_list_; |
| // NOLINTNEXTLINE(modernize-make-shared) |
| list.reset(new hooks_list()); |
| std::unique_ptr<FunctionPreHook> hook_ptr(new CppFunctionPreHook(list, self.output_nr())); |
| clear_hooks(self); |
| add_hook(self, std::make_shared<CppFunctionPreHook>(list, 0)); |
| // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
| auto fn = self.grad_fn(); |
| if (fn) { |
| fn->add_pre_hook(std::move(hook_ptr)); |
| } |
| } |
| |
| void set_grad_accumulator(const Variable& self, |
| std::weak_ptr<Node> grad_accumulator) { |
| materialize_autograd_meta(self)->grad_accumulator_ = std::move(grad_accumulator); |
| } |
| |
| std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) { |
| if (get_autograd_meta(self)) { |
| return get_autograd_meta(self)->grad_accumulator_.lock(); |
| } else { |
| return nullptr; |
| } |
| } |
| |
| std::shared_ptr<Node> grad_accumulator(const Variable& self) { |
| auto autograd_meta = get_autograd_meta(self); |
| if (!autograd_meta) { |
| return nullptr; |
| } |
| if (autograd_meta->grad_fn_) { |
| throw std::logic_error( |
| "grad_accumulator() should be only called on leaf Variables"); |
| } |
| if (!autograd_meta->requires_grad_) { |
| return nullptr; |
| } |
| |
| std::lock_guard<std::mutex> lock(autograd_meta->mutex_); |
| |
| auto result = autograd_meta->grad_accumulator_.lock(); |
| if (result) |
| return result; |
| |
| c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl()); |
| auto intrusive_from_this = c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl()); |
| result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this))); |
| autograd_meta->grad_accumulator_ = result; |
| return result; |
| } |
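| |
| // Sketch (illustration only): for a leaf that requires grad, the accumulator is where |
| // gradients land during backward: |
| // |
| //   auto leaf = at::ones({2}).requires_grad_(); |
| //   auto acc = impl::grad_accumulator(leaf);           // AccumulateGrad node, created lazily |
| //   auto same = impl::try_get_grad_accumulator(leaf);  // returns the cached one, if still alive |
| // |
| // Calling grad_accumulator() on a non-leaf throws std::logic_error, and it returns |
| // nullptr for tensors that don't require grad. |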
| |
| Edge gradient_edge(const Variable& self) { |
| // If grad_fn is null (as is the case for a leaf node), we instead |
| // interpret the gradient function to be a gradient accumulator, which will |
| // accumulate its inputs into the grad property of the variable. These |
| // nodes get suppressed in some situations, see "suppress gradient |
| // accumulation" below. Note that only variables which have `requires_grad = |
| // True` can have gradient accumulators. |
| if (const auto& gradient = self.grad_fn()) { |
| return Edge(gradient, self.output_nr()); |
| } else { |
| return Edge(grad_accumulator(self), 0); |
| } |
| } |
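| |
| // Illustrative sketch: for a non-leaf the edge points at its grad_fn and output_nr; |
| // for a requires-grad leaf it points at the AccumulateGrad node above: |
| // |
| //   auto leaf = at::ones({2}).requires_grad_(); |
| //   auto y = leaf * 2; |
| //   auto e1 = impl::gradient_edge(leaf);  // {grad_accumulator(leaf), 0} |
| //   auto e2 = impl::gradient_edge(y);     // {y.grad_fn(), y.output_nr()} |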
| |
| void set_gradient_edge(const Variable& self, Edge edge) { |
| auto* meta = materialize_autograd_meta(self); |
| meta->grad_fn_ = std::move(edge.function); |
| meta->output_nr_ = edge.input_nr; |
| // For views, make sure this new grad_fn_ is not overwritten unless it is necessary |
| // in the VariableHooks::grad_fn below. |
| // This logic is only relevant for custom autograd Functions for which multiple |
| // operations can happen on a given Tensor before its gradient edge is set when |
| // exiting the custom Function. |
| auto diff_view_meta = get_view_autograd_meta(self); |
| if (diff_view_meta && diff_view_meta->has_bw_view()) { |
| diff_view_meta->set_attr_version(self._version()); |
| } |
| } |
| |
| Node* grad_fn_unsafe(const Variable& self) { |
| if (get_autograd_meta(self)) { |
| return get_autograd_meta(self)->grad_fn_.get(); |
| } else { |
| return nullptr; |
| } |
| } |
| |
| // Versions |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| void set_version_counter( |
| const Variable& self, |
| const c10::VariableVersion& version_counter) { |
| TORCH_CHECK(self.defined(), "cannot call set_version_counter() on undefined tensor"); |
| self.unsafeGetTensorImpl()->set_version_counter(version_counter); |
| } |
| |
| void bump_version(const Variable& self) { |
| TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor"); |
| self.unsafeGetTensorImpl()->bump_version(); |
| } |
| |
| const c10::VariableVersion& version_counter(const Variable& self) { |
| TORCH_CHECK(self.defined(), "cannot call version_counter() on undefined tensor"); |
| return self.unsafeGetTensorImpl()->version_counter(); |
| } |
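| |
| // Sketch (illustration only): in-place ops bump the version counter, and views share the |
| // counter with their base (see NOTE [ Version Counter Sharing ]): |
| // |
| //   auto t = at::ones({4}); |
| //   auto v = t.view({2, 2}); |
| //   auto before = impl::version_counter(t).current_version(); |
| //   t.add_(1);   // bump_version on t; v observes the same counter, now before + 1 |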
| |
| // Hooks |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| void add_hook(const Variable& self, std::shared_ptr<FunctionPreHook> hook) { |
| materialize_autograd_meta(self)->hooks_.push_back(std::move(hook)); |
| } |
| |
| namespace { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| std::vector<std::shared_ptr<FunctionPreHook>> empty_singleton; |
| } |
| |
| // TODO: Return an ArrayRef instead (and delete the singleton while you're at |
| // it) |
| const std::vector<std::shared_ptr<FunctionPreHook>>& hooks(const Variable& self) |
| { |
| if (get_autograd_meta(self)) { |
| return get_autograd_meta(self)->hooks_; |
| } else { |
| return empty_singleton; |
| } |
| } |
| |
| void clear_hooks(const Variable& self) { |
| // This is a little goofy, but usually this should be a no-op |
| materialize_autograd_meta(self)->hooks_.clear(); |
| } |
| |
| void set_name(const Variable& self, const std::string& name) { |
| materialize_autograd_meta(self)->name_ = name; |
| } |
| |
| // Miscellaneous |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| void set_pyobj(const Variable& self, PyObject* pyobj) { |
| TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined tensor"); |
| self.unsafeGetTensorImpl()->set_pyobj(pyobj); |
| } |
| |
| PyObject* pyobj(const Variable& self) { |
| TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined tensor"); |
| return self.unsafeGetTensorImpl()->pyobj(); |
| } |
| |
| AutogradMeta* get_autograd_meta(const Variable& self) { |
| // NB: could return nullptr |
| TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor"); |
| return static_cast<AutogradMeta*>(self.unsafeGetTensorImpl()->autograd_meta()); |
| } |
| |
| DifferentiableViewMeta* get_view_autograd_meta(const Variable& self) { |
| // NB: returns nullptr if self is not a view |
| AutogradMeta* meta = get_autograd_meta(self); |
| if (meta && meta->is_view_) { |
| return static_cast<DifferentiableViewMeta*>(meta); |
| } else { |
| return nullptr; |
| } |
| } |
| |
| } // namespace impl |
| |
| using at::Tensor; |
| |
| struct VariableHooks final : at::impl::VariableHooksInterface { |
| Tensor tensor_data(const Tensor&) const override; |
| Tensor variable_data(const Tensor&) const override; |
| const std::shared_ptr<torch::autograd::Node>& grad_fn(const Tensor&) const override; |
| unsigned _register_hook(const Tensor&, std::function<Tensor(const Tensor&)> hook) const override; |
| void remove_hook(const Tensor&, unsigned pos) const override; |
| bool is_view(const Tensor&) const override; |
| const Tensor& base(const Tensor&) const override; |
| const std::string& name(const Tensor&) const override; |
| bool is_leaf(const Tensor&) const override; |
| int64_t output_nr(const Tensor&) const override; |
| void set_data(const Tensor & self, const Tensor & new_data) const override; |
| Tensor data(const Tensor & self) const override; |
| int64_t _version(const Tensor & self) const override; |
| void retain_grad(const Tensor & self) const override; |
| void _backward(const Tensor& self, at::TensorList inputs, |
| const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph, |
| bool create_graph) const override; |
| void requires_grad_(const Tensor& self, bool _requires_grad) const override; |
| }; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| VariableHooks variableHooks; |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks); |
| |
| Tensor VariableHooks::variable_data(const Tensor& self) const { |
| TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor"); |
| auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach( |
| /*version_counter=*/0, |
| /*allow_tensor_metadata_change=*/false); |
| self_impl_copy->set_autograd_meta(nullptr); |
| return at::Tensor(self_impl_copy); |
| } |
| |
| Tensor VariableHooks::tensor_data(const Tensor& self) const { |
| TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor"); |
| auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach( |
| /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(), |
| /*allow_tensor_metadata_change=*/self.unsafeGetTensorImpl()->allow_tensor_metadata_change()); |
| return at::Tensor(self_impl_copy); |
| } |
| |
| bool VariableHooks::is_leaf(const Tensor & self) const { |
| if (impl::get_autograd_meta(self)) { |
| return impl::get_autograd_meta(self)->grad_fn_ == nullptr; |
| } else { |
| return true; |
| } |
| } |
| |
| int64_t VariableHooks::output_nr(const Tensor & self) const { |
| if (impl::get_autograd_meta(self)) { |
| return impl::get_autograd_meta(self)->output_nr_; |
| } else { |
| return 0; |
| } |
| } |
| |
| void VariableHooks::set_data(const Tensor & self, const Tensor & new_data) const { |
| // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields |
| // from `new_data` to `var`. It requires that `new_data` and `var` have compatible |
| // tensor types. |
| TORCH_CHECK( |
| _has_compatible_shallow_copy_type(self, new_data), |
| "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type."); |
| |
| // Resets gradient accumulator if metadata is out of date |
| AutogradMeta* autograd_meta = impl::get_autograd_meta(self); |
| if (autograd_meta) { |
| std::lock_guard<std::mutex> lock(autograd_meta->mutex_); |
| auto prior_accumulator = autograd_meta->grad_accumulator_.lock(); |
| if (prior_accumulator) { |
| const auto prior_device = prior_accumulator->input_metadata(0).device(); |
| const auto new_device = new_data.device(); |
| |
| if (!new_data.options().type_equal(self.options()) || prior_device != new_device) { |
| autograd_meta->grad_accumulator_.reset(); |
| } |
| } |
| } |
| |
| // Version counter is not shared when we replace a `Variable`'s tensor data |
| // by calling `set_data(...)`. The original version of the `Variable` is always preserved. |
| // See NOTE [ Version Counter Sharing ] for details. |
| // |
| // `var.set_data(new_data)` always ignores `var`'s `allow_tensor_metadata_change_`, because |
| // users need this API as an escape hatch for changing a tensor's metadata regardless of its |
| // `allow_tensor_metadata_change_` value, and the users are responsible for ensuring this is |
| // the behavior they want. |
| self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr()); |
| } |
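| |
| // Illustrative sketch of the semantics described above: |
| // |
| //   auto var = at::ones({2}).requires_grad_(); |
| //   auto other = at::zeros({2}); |
| //   var.set_data(other);   // var now shallow-copies other's storage and metadata, |
| //                          // but keeps its own autograd metadata and version counter |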
| |
| Tensor VariableHooks::data(const Tensor & self) const { |
| return self.variable_data(); |
| } |
| |
| int64_t VariableHooks::_version(const Tensor & self) const { |
| return self.unsafeGetTensorImpl()->version_counter().current_version(); |
| } |
| |
| void VariableHooks::retain_grad(const Tensor & self) const { |
| TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False"); |
| if (self.is_leaf()) { // no-op for leaves |
| return; |
| } |
| if (impl::get_autograd_meta(self)->retains_grad_) { |
| return; |
| } |
| c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr()); |
| |
| std::function<void(Tensor)> retain_grad_hook([weak_self](const Tensor& grad) { |
| if (weak_self.expired()) { |
| return; |
| } else { |
| auto var = weak_self.lock(); |
| if (!var->grad().defined()) { |
| if (grad.is_sparse()) { |
| var->mutable_grad() = grad.clone(); |
| } else { |
| var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous); |
| } |
| } else { |
| var->mutable_grad() = var->grad() + grad; |
| } |
| } |
| }); |
| |
| self.register_hook(retain_grad_hook); |
| impl::get_autograd_meta(self)->retains_grad_ = true; |
| } |
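| |
| // Sketch (illustration only): retain_grad() makes .grad() populated on a non-leaf by |
| // registering the hook above: |
| // |
| //   auto x = at::ones({2}).requires_grad_(); |
| //   auto y = x * 2;          // non-leaf: grad is normally not retained |
| //   y.retain_grad(); |
| //   y.sum().backward(); |
| //   // y.grad() is now defined (here equal to at::ones({2})) |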
| |
| void VariableHooks::_backward( |
| const Tensor& self, |
| at::TensorList inputs, |
| const c10::optional<Tensor>& gradient, |
| c10::optional<bool> keep_graph, |
| bool create_graph) const { |
| // TODO torch::autograd::backward should take the c10::optional<Tensor> gradient directly |
| // instead of us having to unwrap it to Tensor _gradient here. |
| Tensor _gradient = gradient.has_value() ? *gradient : Tensor(); |
| std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end()); |
| torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars); |
| } |
| |
| void VariableHooks::requires_grad_(const Tensor& self, bool _requires_grad) const { |
| if (!self.is_leaf() && !_requires_grad) { |
| throw std::runtime_error( |
| autograd::utils::requires_grad_leaf_error(_requires_grad) |
| ); |
| } |
| self.set_requires_grad(_requires_grad); |
| } |
| |
| // Backward View Variables |
| //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| |
| bool VariableHooks::is_view(const Tensor& self) const { |
| auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); |
| if (diff_view_meta) { |
| return diff_view_meta->has_bw_view(); |
| } else { |
| return false; |
| } |
| } |
| |
| const Tensor& VariableHooks::base(const Tensor& self) const { |
| auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); |
| if (diff_view_meta) { |
| TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor"); |
| return diff_view_meta->get_backward_view().base_; |
| } else { |
| throw std::runtime_error("Can't get base of non-view Tensor"); |
| } |
| } |
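| |
| // Sketch (illustration only): is_view()/_base() report the backward-view relationship: |
| // |
| //   auto b = at::ones({4}).requires_grad_(); |
| //   auto v = b.view({2, 2}); |
| //   // v.is_view() == true and v._base() is b |
| //   // b.is_view() == false; calling b._base() throws |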
| |
| namespace { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| std::string singleton_string; |
| } |
| |
| const std::string& VariableHooks::name(const Tensor& self) const { |
| TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor"); |
| if (torch::autograd::impl::get_autograd_meta(self)) { |
| return torch::autograd::impl::get_autograd_meta(self)->name_; |
| } else { |
| return singleton_string; |
| } |
| } |
| |
| namespace { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| std::shared_ptr<torch::autograd::Node> singleton_shared_ptr; |
| } |
| |
| const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tensor& self) const { |
| auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); |
| if (diff_view_meta && diff_view_meta->has_bw_view()) { |
| // See NOTE [ View + Inplace detection ] |
| if (diff_view_meta->get_creation_meta() != CreationMeta::MULTI_OUTPUT_SAFE) { |
| std::lock_guard<std::mutex> lock(diff_view_meta->mutex_); |
| auto view_info = diff_view_meta->get_backward_view(); |
| if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) { |
| return diff_view_meta->grad_fn_; |
| } |
| auto current_version = self._version(); |
| if (diff_view_meta->get_attr_version() != current_version) { |
| // This is an indirect rebase_history due to another view or the base being modified inplace |
| handle_view_on_rebase(diff_view_meta, /* indirect */ true); |
| TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0); |
| // Note [View + Inplace update for view tensor] |
| // An inplace update happened on Tensor `self` (which is a view). |
| // For example: |
| // view_1 = view_op_1(diff_view_meta->base_) |
| // view_2 = view_op_2(view_1) |
| // ... |
| // self = view_op_n(view_n-1) |
| // self = inplace_op(self) |
| // |
| // For the CPU/CUDA backends, we employ a single AsStridedBackward Node to represent the chain of |
| // view backward ops for efficiency. |
| // |
| // However, the XLA backend doesn't fully support AsStridedBackward, so we instead run a full |
| // forward pass with a tensor that requires gradient to get a proper grad_fn setup, |
| // then save it to DifferentiableViewMeta for future use. |
| // This is fairly cheap for the XLA lazy tensor approach (but would be really expensive for CPU/CUDA). |
| // XLA tensors only run through VariableType dispatch and lower the forward pass to an XLA HLO graph; |
| // we then take the grad_fn and never materialize the tensor content. |
| // So we only construct the graph without executing it, which is a fairly cheap operation to do. |
| // |
| // See Note [View + Inplace update for base tensor] for what we do to base tensor when |
| // an in-place operation happens. |
| // |
| // TODO: Potentially the following logic can be replaced by special logic in VariableType_x.cpp |
| // that would provide a way to recreate the grad_fn chain. |
| if (view_info.has_view_fn()) { |
| auto view_fn = view_info.view_fn(); |
| auto diff_view = view_fn(view_info.base_); |
| diff_view_meta->grad_fn_ = diff_view.grad_fn(); |
| } else { |
| auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>(); |
| fn->self_geometry = at::TensorGeometry(view_info.base_); |
| fn->size = self.sizes().vec(); |
| fn->stride = self.strides().vec(); |
| fn->storage_offset = self.storage_offset(); |
| fn->set_next_edges(torch::autograd::collect_next_edges(view_info.base_)); |
| fn->add_input_metadata( |
| view_info.base_.options(), |
| self.sizes(), // Note: sizes(), not base_.sizes(), is intentional |
| view_info.base_.device()); |
| diff_view_meta->grad_fn_ = std::move(fn); |
| } |
| diff_view_meta->set_attr_version(current_version); |
| } |
| return diff_view_meta->grad_fn_; |
| } |
| } |
| |
| if (torch::autograd::impl::get_autograd_meta(self)) { |
| return torch::autograd::impl::get_autograd_meta(self)->grad_fn_; |
| } else { |
| return singleton_shared_ptr; |
| } |
| } |
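| |
| // Sketch of the refresh path above (illustration only): |
| // |
| //   auto x = at::ones({4}).requires_grad_(); |
| //   auto base = x * 2;             // non-leaf, so in-place on it is allowed |
| //   auto view = base.view({2, 2}); |
| //   base.add_(1);                  // bumps the shared version counter |
| //   auto fn = view.grad_fn();      // attr_version mismatch -> grad_fn rebuilt here |
| //                                  // (AsStridedBackward on CPU/CUDA, or via the saved |
| //                                  // view_fn on backends without as_strided support) |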
| |
| void VariableHooks::remove_hook(const Tensor& self, unsigned pos) const { |
| auto &list = torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_; |
| TORCH_CHECK(list && pos < list->size() , "Invalid index, no hook at position ", pos); |
| // Hook will be ignored |
| (*list)[pos] = nullptr; |
| } |
| |
| unsigned VariableHooks::_register_hook(const Tensor& self, std::function<Tensor(const Tensor&)> hook) const { |
| TORCH_CHECK(self.requires_grad(), "cannot register a hook on a variable that " |
| "doesn't require gradient"); |
| // NB: materialize_autograd_meta unnecessary due to requires grad check |
| auto &list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_; |
| if(!list) { |
| torch::autograd::impl::create_cpp_hook(self); |
| } |
| unsigned idx = list->size(); |
| list->push_back(hook); |
| return idx; |
| } |
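| |
| // Sketch (illustration only): registering and removing a C++ hook through the Tensor API, |
| // which routes to _register_hook/remove_hook above: |
| // |
| //   auto x = at::ones({2}).requires_grad_(); |
| //   unsigned pos = x.register_hook([](const at::Tensor& grad) { return grad * 2; }); |
| //   x.remove_hook(pos);   // the slot is kept, but the hook is skipped from now on |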
| |
| void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect) { |
| // See NOTE [ View + Inplace detection ] for justification of the logic below |
| auto creation_meta = diff_view_meta->get_creation_meta(); |
| if (creation_meta != CreationMeta::DEFAULT) { |
| auto grad_fn = diff_view_meta->grad_fn_.get(); |
| std::string msg; |
| std::string modified_obj; |
| // Create the header for the error message. |
| if (indirect) { |
| modified_obj = "its base or another view of its base has been"; |
| } else { |
| modified_obj = "is being"; |
| } |
| if (grad_fn) { |
| msg = c10::str("Output ", diff_view_meta->output_nr_, " of ", grad_fn->name(), " is a view and ", |
| modified_obj, " modified inplace."); |
| } else if (creation_meta == CreationMeta::INFERENCE_MODE) { |
| msg = c10::str("A view was created in inference mode and ", modified_obj, " modified inplace in normal mode."); |
| } else { |
| msg = c10::str("A view was created in no_grad mode and ", modified_obj, " modified inplace with grad mode enabled."); |
| } |
| |
| if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) { |
| TORCH_CHECK(false, msg, " This view is the output of a function that returns multiple views. Such functions do not" |
| " allow the output views to be modified inplace. You should replace the inplace operation by an" |
| " out-of-place one."); |
| } else { |
| if (creation_meta == CreationMeta::NO_GRAD_MODE) { |
| TORCH_INTERNAL_ASSERT(!grad_fn); |
| msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden." |
| " You can clarify your code and remove this warning by moving both the view and the inplace either both" |
| " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want" |
| " the inplace to be tracked)."); |
| } else if (creation_meta == CreationMeta::INFERENCE_MODE) { |
| TORCH_INTERNAL_ASSERT(!grad_fn); |
| msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden." |
| " You can clarify your code by moving both the view and the inplace either both" |
| " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want" |
| " the inplace to be tracked)."); |
| TORCH_CHECK(false, msg); |
| } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) { |
| msg = c10::str(msg, " This view was created inside a custom Function (or because an input was returned as-is) and the" |
| " autograd logic to handle view+inplace would override the custom backward associated with the custom" |
| " Function, leading to incorrect gradients. This behavior is forbidden. You can remove this warning by" |
| " cloning the output of the custom Function."); |
| } else if (creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) { |
| msg = c10::str(msg, " This view is an output of a function that " |
| "returns multiple views. Inplace operators on such " |
| "views is forbidden. You should replace the inplace " |
| "operation by an out-of-place one."); |
| } else { |
| TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state"); |
| } |
| |
| if (creation_meta == CreationMeta::NO_GRAD_MODE) { |
| // TODO: remove this before 1.9 once all code is properly updated |
| TORCH_WARN(msg); |
| } else { |
| TORCH_CHECK(false, msg); |
| } |
| } |
| |
| // We warn only once per view. |
| // Note that if a Tensor is modified inplace from two threads at the same time, this is not thread safe and can warn |
| // multiple times. This is ok as it should be a rare event. |
| diff_view_meta->set_creation_meta(CreationMeta::DEFAULT); |
| } |
| } |
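| |
| // Sketch of a case that reaches the multi-output error above (illustration only; exact |
| // CreationMeta tagging of each op is defined elsewhere): |
| // |
| //   auto t = at::ones({2, 2}).requires_grad_(); |
| //   auto views = (t * 2).unbind(0);  // outputs of a multi-output view op such as unbind |
| //   views[0].add_(1);                // check_inplace -> handle_view_on_rebase -> error |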
| |
| |
| |
| }} // namespace torch::autograd |