blob: 0b71d435c32c149ed31d2fdef1dd453fb7db1a06 [file] [log] [blame]
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalInverses.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_to_copy.h>
#endif
namespace at {
void FunctionalTensorWrapper::set_constructor_metadata() {
TORCH_INTERNAL_ASSERT(value_.defined());
// Note: "level" is a concept that we don't know how to compute in core.
// For now I'm retroactively setting this in functorch,
// but once Open Multiple Dispatch lands we should be able to calculate this in core.
level_ = -1;
// mirror all of the generic tensor metadata onto the wrapper
copy_generic_tensor_metadata(value_.getIntrusivePtr().get(), this);
refresh_numel();
refresh_contiguous();
storage_access_should_throw_ = false;
// In general, the sizes/stride metadata on a tensor can change as it is mutated,
// and these changes need to be reflected in the metadata of the wrapper.
set_allow_tensor_metadata_change(true);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
// All of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
// We override a bunch of _custom(), so make sure they get called
// TODO: metadata copying may not actually be necessary then
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
set_custom_device(true);
}
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
: c10::TensorImpl(
c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value)),
c10::DispatchKeySet(DispatchKey::Functionalize) | value.key_set(),
value.dtype()
),
value_(value)
{
set_constructor_metadata();
}
void FunctionalTensorWrapper::freeze_storage() const {
functional_storage_impl()->freeze();
}
// Note [Functionalization: Alias Removal]
// When someone calls a view() op during the functionalization pass, e.g. 'b = a.view(...)',
// we link `b` and `a` to a shared Alias object to preserve the aliasing relationship.
//
// How do we do that?
//
// Every FunctionalTensorWrapper contains a dummy FunctionalStorageImpl, which subclasses from c10::StorageImpl.
// It doesn't contain any data (similar to MetaTensor storage), but it contains an Alias object that knows about the base tensor.
// When a tensor is created through a view operation, both the new and old tensor point to the same FunctionalStorageImpl.
//
// As mutations are applied to any of the views, we also queue each mutation up on the Alias object, so we can replay them.
// When the user requests a tensor that's had a view taken, we check if it's up to date.
// If it's not up to date, we first replay all of the queued up mutations onto the alias, and then re-apply the current view
// on top of the newly updated alias.
//
// Why do we queue up and lazily run mutations on the alias, instead of updating the alias eagerly?
// This behavior was taken from pytorch/xla, which the alias-removal logic was inspired from.
// One benefit of the laziness is that we save work in the cases where a user has multiple views and mutates one of them,
// but never uses the other views later in the program (in which case we'll never update the alias).
// It also has downsides though: repeatedly applying mutations to the same view without syncing
// will silently use up more and more memory as more mutations are queued up.
//
// Corresponding diagram:
//
// b = a.view(...)
//
// a b
// | | If the user asks for b and it’s out of date,
// \/ \/ We regenerate b by replaying it’s views from the alias.
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - .
// | FunctionalTensorWrapper | | FunctionalTensorWrapper |
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - .
// | value | storage | | storage | Value |
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - .
// | \ / |
// | \ / |
// | . - - - - - - - - - - - - . |
// | | FunctionalStorageImpl | |
// | . - - - - - - - - - - - - . |
// | | Alias | |
// | . - - - - - - - - - - - - . |
// | / mutations to a or b |
// | / are queued onto Alias |
// | / |
// \/ / \/
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
// | TensorImpl | | TensorImpl |
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
// | value | storage | | storage | Value |
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
// | |
// | |
// | |
// | In this picture the two tensor views their own storages, |
// | have their own storages, but backends like functorch |
// \/ are allowed to re-alias underneath the pass \/
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
// | underyling_storage | | underyling_storage |
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
//
// This constructor is only used by view ops.
// - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, functionalization::ViewMeta meta)
: c10::TensorImpl(
c10::DispatchKeySet(DispatchKey::Functionalize),
view_value.dtype(),
view_value.device()
),
value_(view_value)
{
set_constructor_metadata();
// Copy the original tensor's ViewMeta vector and push the current one.
if (!base->view_metas_.empty()) {
view_metas_ = base->view_metas_; // copy
}
view_metas_.push_back(meta);
storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}
functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
}
void FunctionalTensorWrapper::commit_update() {
auto storage_impl = functional_storage_impl();
storage_impl->add_update(value_, view_metas_);
// As an optimization, we used to mark the tensor here as "up-to-date",
// That way, code like:
// x = torch.ones(1'000'000)
// x[0].add_(1)
// doesn't result in an unnecessary materialization of the base.
// This optimization results in the slice temporarily haven't incorrect
// stride/storage_offset though, and DCE should handle that optimization anyway.
// generation_ = storage_impl->generation();
}
bool FunctionalTensorWrapper::is_up_to_date() const {
auto alias_generation = functional_storage_impl()->generation();
return generation_ == alias_generation;
}
// See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta meta) {
view_metas_.push_back(meta);
// Note [Functionalization Pass - Inplace View Ops]
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
value_ = meta.forward_fn(value_, meta.out_index);
}
// Note [Functionalization: Mutation Removal]
// Mutation removal is used to take a program like this:
//
// a.add_(b)
//
// and replace it with a slightly different program that has the same semantics:
//
// tmp = a.add(b)
// a.replace_(tmp)
//
// Where the replace_() call is implemented directly in the functionalization pass, so it is transparent to the backend.
// This is useful for backends that aren't able to handle certain types of mutations, like functorch.
//
// Why do we need to wrap every tensor in a FunctionalTensorWrapper? Consider this program:
//
// Before:
// tensor.add_(batched_tensor)
//
// After:
// tmp = tensor.add(batched_tensor)
// tensor.replace_(tmp)
//
// In the above, tmp is a batched tensor (because adding a normal tensor to a batched tensor does broadcasting and creates a batched tensor).
// But we can't just replace the underlying memory backing `tensor` with `tmp` - a batched tensor takes up more space!
// Instead, every input, intermediate and output of the program is wrapped in a FunctionalTensorImpl, which wraps the underlying tensor.
void FunctionalTensorWrapper::replace_(const Tensor& other) {
// TODO: going to need to change this if we want nested functionalize() transforms.
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
value_ = other;
// out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
// We need to propagate that metadata mutation to the wrapper (new size).
set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
// .to() should not re-entrantly go through functionalization.
at::AutoDispatchSkipFunctionalize guard;
// and we want _to_copy() to show up in the graph, not the composite .to() operator
// (this can happen if autograd has already run by the time we enter this code)
value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
}
}
void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
// Note [resize_() in functionalization pass]
// resize_() is a special operator in functionalization because it can reallocate its underlying storage.
// This function is only ever called in the case that resize_() needs to reallocate its storage to a larger size.
//
// However, functionalization currently bans the following code:
// a = torch.ones(2)
// b = a.view(2)
// b.resize_(4) # b is a view tensor, that we are trying to increase the storage size of
//
// Why is this code difficult to handle?
// The functionalization pass currently keeps aliases in sync by making the following assumptions:
// - The “base” tensor always refers to “all of the data”
// - Whenever you have b = view_op(a), “b” should always refer to a subset of “a”s memory.
//
// The code above breaks that assumption b.resize_(4) actually needs to update "a"
// to tell it that it is now actually some slice of a pre-existing larger storage.
// We're also no longer re-generate "b" fully from "a" anymore, since "a" refers to a slice of "b"'s data.
//
// This is probably fixable in theory, but:
// - the fix would likey complicated the functionalization logic quite a bit.
// - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators
// - resize_() also can give you weird results today if you try to resize_() a weirdly strided tensor.
//
// Given all of the above, for now we're just banning the above usage.
TORCH_CHECK(storage().use_count() == 1, "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
TORCH_CHECK(view_metas_.empty(), "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
// If this tensor is not a view (and has no outstanding views taken out on it),
// Then it's safe to throw out the old storage and replace it with the new, larger one.
storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
value_ = other;
generation_ = 0;
// And update the metadata on the wrapper to reflect the new sizes and strides
set_sizes_and_strides(value_.sizes(), value_.strides());
refresh_numel();
// (Technically we should be guaranteed that the tensor was already contiguous,
// since it's guaranteed not to have been a view. Doesnt hurt to run though)
refresh_contiguous();
}
void FunctionalTensorWrapper::sync_() {
if (is_up_to_date()) {
return;
}
apply_updates();
regenerate_from_base();
}
void FunctionalTensorWrapper::regenerate_from_base() {
at::AutoDispatchSkipFunctionalize guard;
auto storage_impl = functional_storage_impl();
auto t = storage_impl->base();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
// Reapply views to get the viewed tensor from the base in alias_
for (auto& view_meta: view_metas_) {
t = view_meta.forward_fn(t, view_meta.out_index);
}
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
replace_(t);
generation_ = storage_impl->generation();
}
bool FunctionalTensorWrapper::apply_updates() {
// Apply all updates on alias_
auto storage_impl = functional_storage_impl();
return storage_impl->apply_updates();
}
const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
return "FunctionalTensorWrapper";
}
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
if (key_set_.has(DispatchKey::Python) &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
if (r) {
r->set_version_counter(std::forward<VariableVersion>(version_counter));
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return r;
}
}
auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
/*version_counter=*/std::forward<VariableVersion>(version_counter),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->level_ = level_;
impl->generation_ = generation_;
impl->view_metas_ = view_metas_;
impl->refresh_numel();
impl->refresh_contiguous();
return impl;
}
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
return shallow_copy_and_detach_core(
version_counter, allow_tensor_metadata_change);
}
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
return shallow_copy_and_detach_core(
std::move(version_counter), allow_tensor_metadata_change);
}
c10::Device FunctionalTensorWrapper::device_custom() const {
return value_.unsafeGetTensorImpl()->device();
}
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
return value_.unsafeGetTensorImpl()->sizes();
}
at::IntArrayRef FunctionalTensorWrapper::strides_custom() const {
return value_.unsafeGetTensorImpl()->strides();
}
int64_t FunctionalTensorWrapper::dim_custom() const {
return value_.unsafeGetTensorImpl()->dim();
}
int64_t FunctionalTensorWrapper::numel_custom() const {
return value_.unsafeGetTensorImpl()->numel();
}
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
return value_.unsafeGetTensorImpl()->sym_sizes();
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_strides_custom() const {
return value_.unsafeGetTensorImpl()->sym_strides();
}
c10::SymInt FunctionalTensorWrapper::sym_size_custom(int64_t d) const {
return value_.unsafeGetTensorImpl()->sym_size(d);
}
c10::SymInt FunctionalTensorWrapper::sym_storage_offset_custom() const {
return value_.unsafeGetTensorImpl()->sym_storage_offset();
}
namespace functionalization {
namespace impl {
Tensor to_functional_tensor(const Tensor& tensor) {
// Note [Wrapped Numbers <> Functionalization]
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
return tensor;
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isFunctionalTensor(tensor));
return at::detail::make_tensor<FunctionalTensorWrapper>(tensor);
}
c10::optional<Tensor> to_functional_tensor(const c10::optional<Tensor>& tensor) {
if (tensor.has_value()) {
return c10::make_optional<Tensor>(to_functional_tensor(*tensor));
}
return c10::nullopt;
}
c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_functional_tensor(t_list[i]));
}
return outputs;
}
std::vector<Tensor> to_functional_tensor(ITensorListRef t_list) {
std::vector<Tensor> outputs;
outputs.reserve(t_list.size());
for (const auto& tensor : t_list) {
outputs.push_back(to_functional_tensor(tensor));
}
return outputs;
}
Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
// Note [Wrapped Numbers <> Functionalization]
if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
return tensor;
}
if (isFunctionalTensor(tensor)) {
auto impl = unsafeGetFunctionalWrapper(tensor);
return impl->value();
} else {
// If the current tensor is not functional, then raise an error
// if assert_functional is true. Otherwise, return the input.
TORCH_INTERNAL_ASSERT(!assert_functional)
return tensor;
}
}
c10::optional<Tensor> from_functional_tensor(const c10::optional<Tensor>& t, bool assert_functional) {
if (t.has_value()) {
return c10::make_optional<Tensor>(from_functional_tensor(*t, assert_functional));
}
return c10::nullopt;
}
std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
std::vector<Tensor> outputs;
outputs.reserve(t_list.size());
for (const auto& tensor : t_list) {
// from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call
// it on a non-functional input,
// but from_functional_tensor(TensorList) can recieve a list containing both
// functional and non-functional tensors.
// Example of when that can happen: torch.cat(function_input_tensor, global_state_tensor).
// When that happens, we're okay with only unwrapping the functional tensors.
outputs.push_back(from_functional_tensor(tensor, /*assert_functional=*/false));
}
return outputs;
}
c10::List<c10::optional<Tensor>> from_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
}
return outputs;
}
void sync(const Tensor& t) {
if (t.unsafeGetTensorImpl()->is_wrapped_number()) {
// Note [Wrapped Numbers <> Functionalization]
// Unfortunately, we can't easily guarantee that wrapped numbers (scalar-tensors)
// get wrapped up in a FunctionalTensorWrapper object, since they skip the dispatcher.
// That shouldn't matter, since I don't think we're allowed to assign to wrapped numbers anyway.
return;
}
// Not every tensor that hits a functionalization kernel is necessarily a functional tensor.
// For example, xla_tensor.copy_(cpu_tensor) needs to hit the functionalization kernel
// to sync xla_tensor, but not cpu_tensor.
if (!at::functionalization::impl::isFunctionalTensor(t)) {
return;
}
auto functional_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
functional_impl->sync_();
}
void sync(const c10::optional<Tensor>& t) {
if (t.has_value()) {
sync(*t);
}
}
void sync(ITensorListRef t_list) {
for (const auto& t : t_list) {
sync(t);
}
}
void sync(const c10::List<c10::optional<Tensor>> t_list) {
for (const auto i : c10::irange(t_list.size())) {
sync(t_list[i]);
}
}
void replace_(const Tensor& functional_tensor, const Tensor& other) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
unsafeGetFunctionalWrapper(functional_tensor)->replace_(other);
}
void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
auto functional_tensor_it = functional_tensor.begin();
auto other_it = other.begin();
for (const auto i : c10::irange(functional_tensor.size())) {
(void)i; // Suppress unused variable warning
replace_(*functional_tensor_it++, *other_it++);
}
}
void commit_update(const Tensor& functional_tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
}
void commit_update(ITensorListRef functional_tensor) {
for (const auto& t : functional_tensor) {
commit_update(t);
}
}
bool isFunctionalTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}
bool isFunctionalTensor(const c10::optional<Tensor>& t) {
if (t.has_value()) {
return isFunctionalTensor(*t);
} else {
return false;
}
}
bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
if (t_list.empty()) return false;
auto functional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
if (!t_list[i].has_value() || !t_list[i]->defined()) continue;
if (isFunctionalTensor(t_list[i])) {
++functional_count;
}
}
return functional_count > 0;
}
template <typename T>
bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
if (list.size() == 0) return false;
auto functional_count = 0;
for (const auto& tensor : list) {
if (!tensor.defined()) continue;
if (isFunctionalTensor(tensor)) {
++functional_count;
}
}
return functional_count > 0;
}
bool isFunctionalTensor(ITensorListRef list) {
return isFunctionalTensorIListRef(list);
}
void freeze_functional_tensor(const Tensor& tensor) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(tensor));
auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
functional_base_impl->freeze_storage();
}
Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
if (out_idx != 0) {
// Note [out_idx in ViewMeta]
// When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
// Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
meta = meta.to_out_idx(out_idx);
}
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}
std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) {
std::vector<Tensor> outputs(view_to_wrap.size());
int64_t i = 0;
for (const auto& tensor : view_to_wrap) {
outputs[i] = create_functional_tensor_with_view_meta(tensor, base, meta, i);
i++;
}
return outputs;
}
void mutate_view_meta(const at::Tensor& self, functionalization::ViewMeta meta) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
self_impl->mutate_view_meta(std::move(meta));
}
// Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass
// calls each {view} reference implementations with meta tensors.
// The output meta tensor's stride info serves as a reference for what the correct strides should be.
void set_sizes_strides_offset(const Tensor& out, const Tensor& reference_out) {
out.unsafeGetTensorImpl()->set_sizes_and_strides(reference_out.sym_sizes(), reference_out.sym_strides(), reference_out.sym_storage_offset());
}
void set_sizes_strides_offset(const std::vector<Tensor>& outs, const std::vector<Tensor>& reference_outs) {
TORCH_INTERNAL_ASSERT(outs.size() == reference_outs.size());
for (const auto i : c10::irange(reference_outs.size())) {
set_sizes_strides_offset(outs[i], reference_outs[i]);
}
}
thread_local bool _functionalizationReapplyViews;
bool getFunctionalizationReapplyViewsTLS() {
return _functionalizationReapplyViews;
}
void setFunctionalizationReapplyViewsTLS(bool reapply_views) {
_functionalizationReapplyViews = reapply_views;
}
} // namespace impl
// Given an **out-of-place** op that might internally call view/inplace ops,
// This function will "functionalize" it.
// That is, it will call the operator, but removing any intermediate views/mutations
// that are performed inside of it.
// This is useful for LTC/XLA, which would like to re-use some of our composite kernels
// from pytorch core but not have to worry about the view ops that they might call.
// e.g. at::block_diag
void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
const auto arguments_begin = stack->size() - num_arguments;
auto arguments = torch::jit::last(stack, num_arguments);
// Wrap all tensor-like inputs into FunctionalTensorWrappers.
// When we re-invoke the dispatcher, this will automatically enable the functionalization pass.
for (uint64_t idx = 0; idx < num_arguments; ++idx) {
const auto& ivalue = arguments[idx];
if (ivalue.isTensor()) {
const auto& t = ivalue.toTensor();
if (t.defined()) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t),
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
(*stack)[arguments_begin + idx] = t_new;
}
} else if (ivalue.isTensorList()) {
auto tensors = ivalue.toTensorList();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors),
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
(*stack)[arguments_begin + idx] = t_new;
} else if (ivalue.isOptionalTensorList()) {
auto opt_tensors = ivalue.toOptionalTensorList();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors),
"The composite op functionalization fallback expects its inputs all not to be functional tensors");
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
(*stack)[arguments_begin + idx] = t_new;
}
}
{
// Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap
// the output in a functional tensor based on TLS.
// In this code, we're re-entrantly entering functionalization in the same call-stack,
// so we need to manually fix up TLS as if it hadn't already been called.
auto curr_tls = c10::impl::tls_local_dispatch_key_set();
auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
tls_reenable_functionalize.set_included(curr_tls.included_);
tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
// So, we should probably provide a way to directly call a kernel registered to
// the `CompositeExplicitAutograd` key.
// We can't do that today, so this should be a reasonably good proxy
// (It won't work in cases where an op has both a CompositeExplicitAutograd kernel
// AND a dedicated meta kernel, but that probably shouldn't ever happen).
op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack);
}
const auto num_returns = schema.returns().size();
const auto returns_begin = stack->size() - num_returns;
auto returns = torch::jit::last(stack, num_returns);
for (const auto idx : c10::irange(num_returns)) {
const auto& ivalue = returns[idx];
if (ivalue.isTensor()) {
const auto& t = ivalue.toTensor();
if (!t.defined()) continue;
at::functionalization::impl::sync(t);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
(*stack)[returns_begin + idx] = t_new;
} else if (ivalue.isTensorList()) {
auto tensors = ivalue.toTensorList();
at::functionalization::impl::sync(tensors);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
(*stack)[returns_begin + idx] = t_new;
} else if (ivalue.isOptionalTensorList()) {
auto opt_tensors = ivalue.toOptionalTensorList();
at::functionalization::impl::sync(opt_tensors);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
(*stack)[returns_begin + idx] = t_new;
}
}
}
} // namespace functionalization
} // namespace at