blob: d74b6ef7860198f19dea482a69f70fc534727a9f [file] [log] [blame]
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/InferenceMode.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <ATen/core/VariableHooksInterface.h>
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <c10/util/Exception.h>
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>
#include <typeinfo>
#include <iostream>
namespace torch {
namespace autograd {
// Constructs the autograd metadata for a differentiable view tensor.
// `backward_info` / `forward_info` describe how this tensor relates to its
// base for backward- and forward-mode AD respectively (either may be absent).
// `creation_meta` records the circumstances under which the view was created
// (see CreationMeta) so later inplace updates can be validated.
DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl,
c10::optional<ViewInfo> backward_info,
c10::optional<ViewInfo> forward_info,
CreationMeta creation_meta)
: AutogradMeta(self_impl),
backward_info_(std::move(backward_info)),
forward_info_(std::move(forward_info)),
creation_meta_(creation_meta) {
is_view_ = true;
if (backward_info_.has_value()) {
// A backward view shares its version counter with its base so an inplace
// update on either is visible from both; record the version at creation
// time so a later mismatch signals that an inplace update happened.
self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_));
attr_version_ = self_impl->version_counter().current_version();
}
}
// Chain this view info with the new view op between base and tensor
// Chain this ViewInfo (root base_ -> intermediate view `base`) with the new
// view op that produced `tensor` from `base`, yielding a ViewInfo that maps
// the root base_ directly to `tensor`.
ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor,
std::function<Variable(const Variable&)> view_func) const {
// Set `view_func` using the root base as input.
// `view_func` is used to recover views in backward when either as_strided is not supported
// or the view function changes the metadata which is not recorded by as_strided
// See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor]
// for more details how we use this function in backward.
if (view_func) {
// Both current_view and its parent have a view_func: compose them.
if (view_fn_) {
// Copy parent view function to gain ownership
auto prev_fn = view_fn_;
// NB: the [=] capture copies the *current* value of `view_func` when the
// closure is constructed, before the assignment replaces it, so this is
// not self-referential.
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_fn(root_base);
return view_func(temp);
};
} else {
// current_view has a view_func but its parent doesn't have one
if (base.unsafeGetTensorImpl()->support_as_strided()) {
// Recover the parent view via as_strided from its captured geometry,
// then apply the new view function on top of it.
auto size = base.sizes().vec();
auto stride = base.strides().vec();
auto storage_offset = base.storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = root_base.as_strided(size, stride, storage_offset);
return view_func(temp);
};
} else {
// When base is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's
// a view that doesn't support inplace update, e.g. unbind.
// In this case we should throw an error when inplace update happens in **forward**.
// One would naturally think the following function will be first called in backward pass.
// But the first call site is indeed in **forward** pass when we refresh `grad_fn`
// triggered by inplace update.
// Search Note [View + Inplace update for view tensor] for the call site.
view_func = [=](const at::Tensor& root_base) {
TORCH_CHECK(false, "This view is the output of a function that returns multiple views."
"Such functions do not allow the output views to be modified inplace."
"You should replace the inplace operation by an out-of-place one");
return root_base;
};
}
}
} else if(view_fn_) {
// current_view doesn't have a view_func but its parent has one: replay the
// parent function, then recover the child view via as_strided from
// `tensor`'s captured geometry.
// Copy parent view function to gain ownership
auto prev_view_fn = view_fn_;
auto size = tensor.sizes().vec();
auto stride = tensor.strides().vec();
auto storage_offset = tensor.storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_view_fn(root_base);
return temp.as_strided(size, stride, storage_offset);
};
}
return ViewInfo(base_, view_func);
}
namespace {
// Shared undefined tensor handed out by undefined_tensor() below.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
at::Tensor singleton_undefined_tensor;
// Factory that lets the autograd-agnostic c10 layer create AutogradMeta
// objects and obtain an undefined tensor without linking against autograd.
struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory {
std::unique_ptr<c10::AutogradMetaInterface> make() const override {
return std::make_unique<AutogradMeta>();
}
const at::Tensor& undefined_tensor() const override {
return singleton_undefined_tensor;
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
ConcreteAutogradMetaFactory meta_factory;
// Registers the factory with c10 at static-initialization time.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(&meta_factory);
}
namespace impl {
// Returns the AutogradMeta of `self`, lazily creating a default one the
// first time it is requested. Throws if `self` is undefined.
AutogradMeta* materialize_autograd_meta(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call materialize_autograd_meta() on undefined tensor");
  auto* impl_ptr = self.unsafeGetTensorImpl();
  if (impl_ptr->autograd_meta() == nullptr) {
    impl_ptr->set_autograd_meta(std::make_unique<AutogradMeta>());
  }
  return get_autograd_meta(self);
}
// Updates the autograd history of `self` after an in-place operation whose
// backward node is `gradient_edge.function`. For a backward-differentiable
// view, the graph is rewired through a CopySlices node installed on the
// view's base (see NOTE [ View + Inplace detection ]); otherwise the new
// gradient edge is installed on `self` directly.
void rebase_history(const Variable& self, Edge gradient_edge) {
TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
auto diff_view_meta = get_view_autograd_meta(self);
if (diff_view_meta && diff_view_meta->has_bw_view()) {
// See NOTE [ View + Inplace detection ]
auto creation_meta = diff_view_meta->get_creation_meta();
if (creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) {
// Do not use handle_view_on_rebase here as check_inplace should have been called before this
// and either throw an error or clear the warning
// Temporary error message as a full fix is too risky for now
// Should be an internal assert again
TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
TORCH_INTERNAL_ASSERT(gradient_edge.function);
TORCH_CHECK(
gradient_edge.function->num_inputs() == 1,
"Functions which modify views in-place must return a single Variable");
auto view_info = diff_view_meta->get_backward_view();
diff_view_meta->output_nr_ = gradient_edge.input_nr;
// CopySlices replays the in-place op's backward on the slice of the base
// that this view covers; it becomes the base's new grad_fn.
auto copy_slices = std::make_shared<CopySlices>(
view_info.base_, at::TensorGeometry(self), view_info.view_fn_, std::move(gradient_edge.function));
set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
self.grad_fn(); // trigger an update to the view's grad_fn
return;
}
}
set_gradient_edge(self, std::move(gradient_edge));
}
// Creates (or replaces) the list of C++ hooks for `self` and wires it up
// both as a variable pre-hook and as a pre-hook on `self`'s grad_fn.
void create_cpp_hook(const Variable& self) {
// `list` aliases the member shared_ptr, so the reset below is immediately
// visible to anyone holding the autograd metadata (e.g. _register_hook).
auto &list = materialize_autograd_meta(self)->cpp_hooks_list_;
// NOLINTNEXTLINE(modernize-make-shared)
list.reset(new hooks_list());
// Pre-hook attached to grad_fn below: indexed by this variable's output_nr.
std::unique_ptr<FunctionPreHook> hook_ptr(new CppFunctionPreHook(list, self.output_nr()));
// NOTE(review): the variable-level hook is registered with value index 0
// while the grad_fn hook uses output_nr() — presumably intentional, but
// worth confirming against CppFunctionPreHook's indexing convention.
clear_hooks(self);
add_hook(self, std::make_shared<CppFunctionPreHook>(list, 0));
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto fn = self.grad_fn();
if (fn) {
fn->add_pre_hook(std::move(hook_ptr));
}
}
// Stores a weak reference to the gradient accumulator node on `self`'s
// autograd metadata, creating the metadata if it does not exist yet.
void set_grad_accumulator(const Variable& self,
                          std::weak_ptr<Node> grad_accumulator) {
  auto* meta = materialize_autograd_meta(self);
  meta->grad_accumulator_ = std::move(grad_accumulator);
}
// Returns the gradient accumulator if it has already been created and is
// still alive, nullptr otherwise. Never creates one.
std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
  auto* meta = get_autograd_meta(self);
  return meta ? meta->grad_accumulator_.lock() : nullptr;
}
// Returns the gradient accumulator for a leaf Variable that requires grad,
// creating and caching it if necessary. Returns nullptr when `self` has no
// autograd metadata or does not require grad; throws if `self` is not a leaf.
std::shared_ptr<Node> grad_accumulator(const Variable& self) {
auto autograd_meta = get_autograd_meta(self);
if (!autograd_meta) {
return nullptr;
}
if (autograd_meta->grad_fn_) {
throw std::logic_error(
"grad_accumulator() should be only called on leaf Variables");
}
if (!autograd_meta->requires_grad_) {
return nullptr;
}
// Serialize creation so concurrent callers get the same accumulator.
std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
auto result = autograd_meta->grad_accumulator_.lock();
if (result)
return result;
// reclaim() assumes ownership of one refcount, so incref first: the
// AccumulateGrad node ends up holding a strong reference to `self`, while
// the metadata only keeps a weak_ptr to the node (no ownership cycle).
c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
auto intrusive_from_this = c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this)));
autograd_meta->grad_accumulator_ = result;
return result;
}
// Returns the edge that gradients flowing into `self` should follow.
// If grad_fn is null (as is the case for a leaf node), we instead
// interpret the gradient function to be a gradient accumulator, which will
// accumulate its inputs into the grad property of the variable. These
// nodes get suppressed in some situations, see "suppress gradient
// accumulation" below. Note that only variables which have `requires_grad =
// True` can have gradient accumulators.
Edge gradient_edge(const Variable& self) {
  const auto& fn = self.grad_fn();
  if (!fn) {
    return Edge(grad_accumulator(self), 0);
  }
  return Edge(fn, self.output_nr());
}
// Installs `edge` as the grad_fn / output_nr of `self`, materializing
// autograd metadata if needed.
void set_gradient_edge(const Variable& self, Edge edge) {
auto* meta = materialize_autograd_meta(self);
meta->grad_fn_ = std::move(edge.function);
meta->output_nr_ = edge.input_nr;
// For views, make sure this new grad_fn_ is not overwritten unless it is necessary
// in the VariableHooks::grad_fn below.
// This logic is only relevant for custom autograd Functions for which multiple
// operations can happen on a given Tensor before its gradient edge is set when
// exiting the custom Function.
auto diff_view_meta = get_view_autograd_meta(self);
if (diff_view_meta && diff_view_meta->has_bw_view()) {
// Syncing attr_version with the current tensor version prevents the lazy
// refresh in VariableHooks::grad_fn from discarding this grad_fn_.
diff_view_meta->set_attr_version(self._version());
}
}
// Raw (non-owning) access to the grad_fn pointer; nullptr when `self` has
// no autograd metadata or no grad_fn. "Unsafe" because no view refresh or
// lifetime guarantee is provided.
Node* grad_fn_unsafe(const Variable& self) {
  auto* meta = get_autograd_meta(self);
  return meta ? meta->grad_fn_.get() : nullptr;
}
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Replaces `self`'s version counter wholesale. Rejects undefined tensors.
void set_version_counter(
    const Variable& self,
    const c10::VariableVersion& version_counter) {
  TORCH_CHECK(self.defined(), "cannot call set_version_counter() on undefined tensor");
  auto* impl_ptr = self.unsafeGetTensorImpl();
  impl_ptr->set_version_counter(version_counter);
}
// Increments `self`'s version counter; used to flag in-place modification.
void bump_version(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor");
  auto* impl_ptr = self.unsafeGetTensorImpl();
  impl_ptr->bump_version();
}
// Read-only access to `self`'s version counter. Rejects undefined tensors.
const c10::VariableVersion& version_counter(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call version_counter() on undefined tensor");
  const auto* impl_ptr = self.unsafeGetTensorImpl();
  return impl_ptr->version_counter();
}
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Appends a pre-hook to `self`, materializing autograd metadata if needed.
void add_hook(const Variable& self, std::shared_ptr<FunctionPreHook> hook) {
  auto* meta = materialize_autograd_meta(self);
  meta->hooks_.push_back(std::move(hook));
}
namespace {
// Shared empty hook list returned for variables without autograd metadata.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
std::vector<std::shared_ptr<FunctionPreHook>> empty_singleton;
}
// TODO: Return an ArrayRef instead (and delete the singleton while you're at
// it
// Returns the list of pre-hooks registered on `self`, or a shared empty
// list when `self` has no autograd metadata.
// TODO: Return an ArrayRef instead (and delete the singleton while you're at
// it
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks(const Variable& self)
{
  auto* meta = get_autograd_meta(self);
  if (meta == nullptr) {
    return empty_singleton;
  }
  return meta->hooks_;
}
// Removes all registered pre-hooks from `self`.
void clear_hooks(const Variable& self) {
// This is a little goofy, but usually this should be a no-op: we
// materialize the autograd metadata just to clear a (typically empty) list.
materialize_autograd_meta(self)->hooks_.clear();
}
// Assigns a debug name to the variable (stored on its autograd metadata).
void set_name(const Variable& self, const std::string& name) {
  auto* meta = materialize_autograd_meta(self);
  meta->name_ = name;
}
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Attaches the owning Python object (if any) to this tensor's TensorImpl.
void set_pyobj(const Variable& self, PyObject* pyobj) {
  TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined tensor");
  auto* impl_ptr = self.unsafeGetTensorImpl();
  impl_ptr->set_pyobj(pyobj);
}
// Returns the Python object associated with this tensor's TensorImpl (may
// be nullptr if none was ever attached).
PyObject* pyobj(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined tensor");
  auto* impl_ptr = self.unsafeGetTensorImpl();
  return impl_ptr->pyobj();
}
// Returns `self`'s AutogradMeta without creating it.
// NB: could return nullptr
AutogradMeta* get_autograd_meta(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor");
  auto* raw_meta = self.unsafeGetTensorImpl()->autograd_meta();
  return static_cast<AutogradMeta*>(raw_meta);
}
// Returns `self`'s metadata downcast to DifferentiableViewMeta.
// NB: return nullptr if self is not a view
DifferentiableViewMeta* get_view_autograd_meta(const Variable& self) {
  auto* meta = get_autograd_meta(self);
  if (meta == nullptr || !meta->is_view_) {
    return nullptr;
  }
  return static_cast<DifferentiableViewMeta*>(meta);
}
} // namespace impl
using at::Tensor;
// Implementation of ATen's VariableHooksInterface: routes Tensor methods
// that need autograd knowledge (grad_fn, hooks, view info, backward, ...)
// from the ATen layer into this library. A single instance is registered
// globally below.
struct VariableHooks final : at::impl::VariableHooksInterface {
Tensor tensor_data(const Tensor&) const override;
Tensor variable_data(const Tensor&) const override;
const std::shared_ptr<torch::autograd::Node>& grad_fn(const Tensor&) const override;
unsigned _register_hook(const Tensor&, std::function<Tensor(const Tensor&)> hook) const override;
void remove_hook(const Tensor&, unsigned pos) const override;
bool is_view(const Tensor&) const override;
const Tensor& base(const Tensor&) const override;
const std::string& name(const Tensor&) const override;
bool is_leaf(const Tensor&) const override;
int64_t output_nr(const Tensor&) const override;
void set_data(const Tensor & self, const Tensor & new_data) const override;
Tensor data(const Tensor & self) const override;
int64_t _version(const Tensor & self) const override;
void retain_grad(const Tensor & self) const override;
void _backward(const Tensor& self, at::TensorList inputs,
const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph,
bool create_graph) const override;
void requires_grad_(const Tensor& self, bool _requires_grad) const override;
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
VariableHooks variableHooks;
// Registers the hooks with ATen at static-initialization time.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);
// Returns a shallow copy of `self` that shares storage but is fully
// detached from autograd: fresh version counter (0), no autograd metadata,
// and metadata changes disallowed.
Tensor VariableHooks::variable_data(const Tensor& self) const {
TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor");
auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/0,
/*allow_tensor_metadata_change=*/false);
self_impl_copy->set_autograd_meta(nullptr);
return at::Tensor(self_impl_copy);
}
// Returns a shallow copy of `self` that, unlike variable_data(), reuses
// `self`'s version counter and preserves the allow_tensor_metadata_change
// flag (see NOTE [ Version Counter Sharing ] elsewhere in autograd).
Tensor VariableHooks::tensor_data(const Tensor& self) const {
TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/self.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/self.unsafeGetTensorImpl()->allow_tensor_metadata_change());
return at::Tensor(self_impl_copy);
}
// A tensor is a leaf when it was not produced by a recorded differentiable
// operation; tensors without autograd metadata are trivially leaves.
bool VariableHooks::is_leaf(const Tensor & self) const {
  auto* meta = impl::get_autograd_meta(self);
  return meta == nullptr || meta->grad_fn_ == nullptr;
}
// Index of `self` among the outputs of its grad_fn; 0 when there is no
// autograd metadata.
int64_t VariableHooks::output_nr(const Tensor & self) const {
  auto* meta = impl::get_autograd_meta(self);
  if (meta == nullptr) {
    return 0;
  }
  return meta->output_nr_;
}
void VariableHooks::set_data(const Tensor & self, const Tensor & new_data) const {
// `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
// from `new_data` to `var`. It requires that `new_data` and `var` have compatible
// tensor type.
TORCH_CHECK(
_has_compatible_shallow_copy_type(self, new_data),
"Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");
// Resets gradient accumulator if metadata is out of date
AutogradMeta* autograd_meta = impl::get_autograd_meta(self);
if (autograd_meta) {
std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
if (prior_accumulator) {
const auto prior_device = prior_accumulator->input_metadata(0).device();
const auto new_device = new_data.device();
// The cached accumulator is only reusable if neither the tensor type nor
// the device changed; otherwise drop it so it gets rebuilt lazily.
if (!new_data.options().type_equal(self.options()) || prior_device != new_device) {
autograd_meta->grad_accumulator_.reset();
}
}
}
// Version counter is not shared when we replace a `Variable`'s tensor data
// by calling `set_data(...)`. The original version of the `Variable` is always preserved.
// See NOTE [ Version Counter Sharing ] for details.
//
// `var.set_data(new_data)` always ignores `var`'s `allow_tensor_metadata_change_`, because
// users need this API as an escape hatch for changing a tensor's metadata regardless of its
// `allow_tensor_metadata_change_` value, and the users are responsible for ensuring this is
// the behavior they want.
self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr());
}
// `.data` is implemented as a detached shallow copy, i.e. variable_data().
Tensor VariableHooks::data(const Tensor & self) const {
  auto detached = self.variable_data();
  return detached;
}
// Current value of the tensor's version counter (bumped on inplace ops).
int64_t VariableHooks::_version(const Tensor & self) const {
  const auto& counter = self.unsafeGetTensorImpl()->version_counter();
  return counter.current_version();
}
// Makes a non-leaf tensor keep its `.grad` during backward by registering a
// hook that clones/accumulates incoming gradients into `mutable_grad()`.
// Idempotent: a second call is a no-op once retains_grad_ is set.
void VariableHooks::retain_grad(const Tensor & self) const {
TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False");
if (self.is_leaf()) { // no-op for leaves
return;
}
if (impl::get_autograd_meta(self)->retains_grad_) {
return;
}
// Weak reference: the hook must not keep the tensor alive.
c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());
// First gradient seen is cloned into `grad`; subsequent ones are added.
std::function<void(Tensor)> retain_grad_hook([weak_self](const Tensor& grad) {
if (weak_self.expired()) {
return;
} else {
auto var = weak_self.lock();
if (!var->grad().defined()) {
if (grad.is_sparse()) {
var->mutable_grad() = grad.clone();
} else {
var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous);
}
} else {
var->mutable_grad() = var->grad() + grad;
}
}
});
self.register_hook(retain_grad_hook);
impl::get_autograd_meta(self)->retains_grad_ = true;
}
void VariableHooks::_backward(
const Tensor& self,
at::TensorList inputs,
const c10::optional<Tensor>& gradient,
c10::optional<bool> keep_graph,
bool create_graph) const {
// TODO torch::autograd::backward should take the c10::optional<Tensor> gradient directly
// instead of us having to unwrap it to Tensor _gradient here.
Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars);
}
// In-place requires_grad setter. Turning requires_grad off is only legal on
// leaf variables; interior graph nodes always require grad once recorded.
void VariableHooks::requires_grad_(const Tensor& self, bool _requires_grad) const {
  const bool illegal_disable = !_requires_grad && !self.is_leaf();
  if (illegal_disable) {
    throw std::runtime_error(
        autograd::utils::requires_grad_leaf_error(_requires_grad));
  }
  self.set_requires_grad(_requires_grad);
}
// Backward View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// True only for tensors carrying DifferentiableViewMeta with backward-mode
// view info.
bool VariableHooks::is_view(const Tensor& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  return diff_view_meta != nullptr && diff_view_meta->has_bw_view();
}
// Returns the base tensor of a backward-mode view. Throws for non-views and
// for tensors with view metadata but no backward view info.
const Tensor& VariableHooks::base(const Tensor& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta == nullptr) {
    throw std::runtime_error("Can't get base of non-view Tensor");
  }
  TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor");
  return diff_view_meta->get_backward_view().base_;
}
namespace {
// Shared empty name returned for tensors without autograd metadata.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
std::string singleton_string;
}
const std::string& VariableHooks::name(const Tensor& self) const {
TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor");
if (torch::autograd::impl::get_autograd_meta(self)) {
return torch::autograd::impl::get_autograd_meta(self)->name_;
} else {
return singleton_string;
}
}
namespace {
// Shared null grad_fn returned for tensors without autograd metadata.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
}
// Returns `self`'s grad_fn. For backward-mode views this lazily refreshes
// the cached grad_fn when the tensor's version counter shows that the base
// (or another view of it) was modified inplace since the grad_fn was last
// computed; see NOTE [ View + Inplace detection ].
const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tensor& self) const {
auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
if (diff_view_meta && diff_view_meta->has_bw_view()) {
// See NOTE [ View + Inplace detection ]
if (diff_view_meta->get_creation_meta() != CreationMeta::MULTI_OUTPUT_SAFE) {
// Guard the lazy refresh against concurrent callers.
std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
auto view_info = diff_view_meta->get_backward_view();
// No grad_fn and a base that doesn't require grad: nothing to refresh.
if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
return diff_view_meta->grad_fn_;
}
auto current_version = self._version();
if (diff_view_meta->get_attr_version() != current_version) {
// This is an indirect rebase_history due to another view or the base being modified inplace
handle_view_on_rebase(diff_view_meta, /* indirect */ true);
TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
// Note [View + Inplace update for view tensor]
// An inplace update happened on Tensor `self` (which is a view).
// For example:
// view_1 = view_op_1(diff_view_meta->base_)
// view_2 = view_op_2(view_1)
// ...
// self = view_op_n(view_n-1)
// self = inplace_op(self)
//
// For CPU/CUDA backends, we employ one AsStridedBackward Node to represent the chain of
// view backward ops for efficiency.
//
// However in XLA backend we don't have full support of AsStridedBackward, we instead run a full
// forward pass with a tensor that requires gradient to get proper grad_fn setup,
// then save it to DifferentiableViewMeta for future use.
// This is fairly cheap for XLA lazy tensor approach (but would be really expensive for CPU/CUDA).
// XLA Tensor only runs through VariableType dispatch and lowers the forward pass to a XLA HLO graph,
// then we take grad_fn and never materialize the tensor content.
// So we only construct the graph but not execute it, which is a fairly cheap operation to do.
//
// See Note [View + Inplace update for base tensor] for what we do to base tensor when
// an in-place operation happens.
//
// TODO: Potentially the following logic can be replaced by special logic in VariableType_x.cpp
// that would provide a way to recreate the grad_fn chain.
if (view_info.has_view_fn()) {
// Replay the recorded view function on the base to get a fresh grad_fn.
auto view_fn = view_info.view_fn();
auto diff_view = view_fn(view_info.base_);
diff_view_meta->grad_fn_ = diff_view.grad_fn();
} else {
// No view function recorded: the whole view chain can be represented
// by a single AsStridedBackward over the base's geometry.
auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>();
fn->self_geometry = at::TensorGeometry(view_info.base_);
fn->size = self.sizes().vec();
fn->stride = self.strides().vec();
fn->storage_offset = self.storage_offset();
fn->set_next_edges(torch::autograd::collect_next_edges(view_info.base_));
fn->add_input_metadata(
view_info.base_.options(),
self.sizes(), // Note: sizes(), not base_.sizes(), is intentional
view_info.base_.device());
diff_view_meta->grad_fn_ = std::move(fn);
}
// Mark the cached grad_fn as up to date with the current tensor version.
diff_view_meta->set_attr_version(current_version);
}
return diff_view_meta->grad_fn_;
}
}
if (torch::autograd::impl::get_autograd_meta(self)) {
return torch::autograd::impl::get_autograd_meta(self)->grad_fn_;
} else {
return singleton_shared_ptr;
}
}
// Invalidates the hook at `pos`. The slot is nulled rather than erased so
// that indices previously returned by _register_hook remain stable.
void VariableHooks::remove_hook(const Tensor& self, unsigned pos) const {
  auto& list = torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_;
  const bool valid = list != nullptr && pos < list->size();
  TORCH_CHECK(valid, "Invalid index, no hook at position ", pos);
  // Hook will be ignored
  (*list)[pos] = nullptr;
}
// Registers a C++ hook on `self` and returns its index in the hook list
// (usable with remove_hook). Requires self.requires_grad().
unsigned VariableHooks::_register_hook(const Tensor& self, std::function<Tensor(const Tensor&)> hook) const {
TORCH_CHECK(self.requires_grad(), "cannot register a hook on a variable that "
"doesn't require gradient");
// NB: materialize_autograd_meta unnecessary due to requires grad check
// `list` aliases the shared_ptr member, so it observes the list that
// create_cpp_hook installs below.
auto &list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_;
if(!list) {
torch::autograd::impl::create_cpp_hook(self);
}
unsigned idx = list->size();
list->push_back(hook);
return idx;
}
// Validates an inplace update on a view according to how the view was
// created (its CreationMeta): errors out for forbidden cases, warns for the
// deprecated no_grad case, and is a no-op for CreationMeta::DEFAULT.
// `indirect` is true when the update happened on the base or on another
// view of the same base rather than on this view itself.
void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect) {
/// See NOTE [ View + Inplace detection ] for justification of the logic below
auto creation_meta = diff_view_meta->get_creation_meta();
if (creation_meta != CreationMeta::DEFAULT) {
auto grad_fn = diff_view_meta->grad_fn_.get();
std::string msg;
std::string modified_obj;
// Create the header for the error message.
if (indirect) {
modified_obj = "its base or another view of its base has been";
} else {
modified_obj = "is being";
}
if (grad_fn) {
msg = c10::str("Output ", diff_view_meta->output_nr_, " of ", grad_fn->name(), " is a view and ",
modified_obj, " modified inplace.");
} else if (creation_meta == CreationMeta::INFERENCE_MODE) {
msg = c10::str("A view was created in inference mode and ", modified_obj, " modified inplace in normal mode.");
} else {
msg = c10::str("A view was created in no_grad mode and ", modified_obj, " modified inplace with grad mode enabled.");
}
if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
TORCH_CHECK(false, msg, " This view is the output of a function that returns multiple views. Such functions do not"
" allow the output views to be modified inplace. You should replace the inplace operation by an"
" out-of-place one.");
} else {
if (creation_meta == CreationMeta::NO_GRAD_MODE) {
TORCH_INTERNAL_ASSERT(!grad_fn);
msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden."
" You can clarify your code and remove this warning by moving both the view and the inplace either both"
" inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want"
" the inplace to be tracked).");
} else if (creation_meta == CreationMeta::INFERENCE_MODE) {
TORCH_INTERNAL_ASSERT(!grad_fn);
msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden."
" You can clarify your code by moving both the view and the inplace either both"
" inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want"
" the inplace to be tracked).");
TORCH_CHECK(false, msg);
} else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
msg = c10::str(msg, " This view was created inside a custom Function (or because an input was returned as-is) and the"
" autograd logic to handle view+inplace would override the custom backward associated with the custom"
" Function, leading to incorrect gradients. This behavior is forbidden. You can remove this warning by"
" cloning the output of the custom Function.");
} else if (creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
msg = c10::str(msg, " This view is an output of a function that "
"returns multiple views. Inplace operators on such "
"views is forbidden. You should replace the inplace "
"operation by an out-of-place one.");
} else {
TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
}
// NO_GRAD_MODE is the only case above that (for now) merely warns;
// every other creation meta reaching this point is a hard error.
if (creation_meta == CreationMeta::NO_GRAD_MODE) {
// TODO: remove this before 1.9 once all code is properly updated
TORCH_WARN(msg);
} else {
TORCH_CHECK(false, msg);
}
}
// We warn only once per view: reset to DEFAULT so subsequent rebases skip
// this whole branch.
// Note that if a Tensor is modified inplace from two threads at the same time, this is not thread safe and can warn
// multiple time. This is ok as it should be a rare event.
diff_view_meta->set_creation_meta(CreationMeta::DEFAULT);
}
}
}} // namespace torch::autograd