| |
| #include <ATen/FunctionalTensorWrapper.h> |
| |
| #include <ATen/FunctionalInverses.h> |
| #include <ATen/TensorUtils.h> |
| #include <ATen/WrapDimUtils.h> |
| #include <ATen/core/LegacyTypeDispatch.h> |
| #include <c10/util/Exception.h> |
| |
| #include <c10/util/irange.h> |
| |
| namespace at { |
| |
| void FunctionalTensorWrapper::set_constructor_metadata() { |
| TORCH_INTERNAL_ASSERT(value_.defined()); |
| // Note: "level" is a concept that we don't know how to compute in core. |
| // For now I'm retroactively setting this in functorch, |
| // but once Open Multiple Dispatch lands we should be able to calculate this in core. |
| level_ = -1; |
| // shallow_copy_from overwrites the storage and dispatch keyset... |
| auto functional_storage = storage_; |
| shallow_copy_from(value_.getIntrusivePtr()); |
| storage_ = functional_storage; |
| storage_access_should_throw_ = false; |
| key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set(); |
| } |
| |
| FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value) |
| : c10::TensorImpl( |
| c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value)), |
| c10::DispatchKeySet(DispatchKey::Functionalize) | value.key_set(), |
| value.dtype() |
| ), |
| value_(value) |
| { |
| set_constructor_metadata(); |
| } |
| |
| // Note [Functionalization: Alias Removal] |
| // When someone calls a view() op during the functionalization pass, e.g. 'b = a.view(...)', |
| // we link `b` and `a` to a shared Alias object to preserve the aliasing relationship. |
| // |
| // How do we do that? |
| // |
| // Every FunctionalTensorWrapper contains a dummy FunctionalStorageImpl, which subclasses from c10::StorageImpl. |
| // It doesn't contain any data (similar to MetaTensor storage), but it contains an Alias object that knows about the base tensor. |
| // When a tensor is created through a view operation, both the new and old tensor point to the same FunctionalStorageImpl. |
| // |
| // As mutations are applied to any of the views, we also queue each mutation up on the Alias object, so we can replay them. |
| // When the user requests a tensor that's had a view taken, we check if it's up to date. |
| // If it's not up to date, we first replay all of the queued up mutations onto the alias, and then re-apply the current view |
| // on top of the newly updated alias. |
| // |
| // Why do we queue up and lazily run mutations on the alias, instead of updating the alias eagerly? |
| // This behavior was taken from pytorch/xla, from which the alias-removal logic was adapted. |
| // One benefit of the laziness is that we save work in the cases where a user has multiple views and mutates one of them, |
| // but never uses the other views later in the program (in which case we'll never update the alias). |
| // It also has downsides though: repeatedly applying mutations to the same view without syncing |
| // will silently use up more and more memory as more mutations are queued up. |
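| // |
| // For example (an illustrative user program, not code from this file): |
| // |
| //   at::Tensor a = at::ones({4}); |
| //   at::Tensor b = a.view({2, 2});   // a and b now share one FunctionalStorageImpl |
| //   b.add_(1);                       // the mutation is queued on the shared Alias; |
| //                                    // a's value_ is now stale |
| //   auto s = a.sum();                // before a is used, the pass syncs it: the queued |
| //                                    // mutation is replayed onto the alias, and a's value_ |
| //                                    // is regenerated from the updated base. |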
| // |
| // Corresponding diagram: |
| // |
| // b = a.view(...) |
| // |
| //           a                                                     b |
| //           |                                                     |      If the user asks for b and it's out of date, |
| //          \/                                                    \/      we regenerate b by replaying its views from the alias. |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - . |
| // |  FunctionalTensorWrapper  |                |  FunctionalTensorWrapper  | |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - . |
| // |     value   |   storage   |                |   storage   |    value    | |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - . |
| //           |          \                            /              | |
| //           |           \                          /               | |
| //           |         . - - - - - - - - - - - - .                  | |
| //           |         |  FunctionalStorageImpl  |                  | |
| //           |         . - - - - - - - - - - - - .                  | |
| //           |         |          Alias          |                  | |
| //           |         . - - - - - - - - - - - - .                  | |
| //           |        /   mutations to a or b                       | |
| //           |       /    are queued onto Alias                     | |
| //           |      /                                               | |
| //          \/      /                                              \/ |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - - - . |
| // |         TensorImpl        |                |           TensorImpl          | |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - - - . |
| // |     value   |   storage   |                |   storage   |      value      | |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - - - . |
| //           |                                                      | |
| //           |                                                      | |
| //           |   In this picture the two tensor views have their    | |
| //           |   own storages, but backends like functorch are      | |
| //          \/   allowed to re-alias underneath the pass           \/ |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - - - . |
| // |     underlying_storage    |                |      underlying_storage       | |
| // . - - - - - - - - - - - - - .                . - - - - - - - - - - - - - - - . |
| // |
| // This constructor is only used by view ops. |
| // - view_value: The output tensor that we need to wrap. |
| // - base: The "base" of the view that `view_value` was generated from. |
| // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic. |
| FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, functionalization::ViewMeta meta) |
| : c10::TensorImpl( |
| c10::DispatchKeySet(DispatchKey::Functionalize), |
| view_value.dtype(), |
| view_value.device() |
| ), |
| value_(view_value) |
| { |
| set_constructor_metadata(); |
| // Copy the original tensor's ViewMeta vector and push the current one. |
| if (base->view_metas_.size() > 0) { |
| view_metas_ = base->view_metas_; // copy |
| } |
| view_metas_.push_back(meta); |
| storage_ = base->storage_; // alias this tensor's storage with the base tensor's |
| } |
| |
| functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const { |
| return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl()); |
| } |
| |
| void FunctionalTensorWrapper::commit_update() { |
| auto storage_impl = functional_storage_impl(); |
| storage_impl->add_update(value_, view_metas_); |
| // Invariant: commit_update() is called during an inplace operation. |
| // Tensor inputs to the operation are synced before running the op, |
| // so the current tensor must be up-to-date with its alias at this point. |
| generation_ = storage_impl->generation(); |
| } |
| |
| bool FunctionalTensorWrapper::is_up_to_date() const { |
| auto alias_generation = functional_storage_impl()->generation(); |
| return generation_ == alias_generation; |
| } |
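| |
| // An illustrative sketch of how the generation counter behaves (user-level |
| // pseudocode, not code from this file): |
| // |
| //   at::Tensor a = at::zeros({4});   // functional tensor; storage generation is 0 |
| //   at::Tensor b = a.view({2, 2});   // b shares a's FunctionalStorageImpl |
| //   a.add_(1);                       // commit_update() bumps the storage generation to 1 |
| //                                    // and sets a's generation_ to 1 |
| //   // b's generation_ is still 0, so is_up_to_date() reports false for b until it is |
| //   // synced, which applies the queued update and regenerates b from the base. |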
| |
| // See Note [Functionalization Pass - Inplace View Ops] |
| void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta meta) { |
| view_metas_.push_back(meta); |
| // Note [Functionalization Pass - Inplace View Ops] |
| // So, these ops are special - they're mutation AND view ops. They get special codegen. |
| // An example is transpose_, e.g. `a.transpose_()` |
| // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas. |
| // We also need to force a sync (even if a is already up to date), because a's underlying tensor hasn't actually |
| // been updated to reflect the new view yet. |
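| // For example (illustrative): |
| //   at::Tensor a = at::ones({2, 3});   // functionalized |
| //   a.transpose_(0, 1);                // appends a transpose ViewMeta to a's view_metas_, |
| //                                      // then regenerates a's value_ from its base, so a |
| //                                      // immediately reports the transposed sizes/strides. |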
| regenerate_from_base(); |
| } |
| |
| // Note [Functionalization: Mutation Removal] |
| // Mutation removal is used to take a program like this: |
| // |
| // a.add_(b) |
| // |
| // and replace it with a slightly different program that has the same semantics: |
| // |
| // tmp = a.add(b) |
| // a.replace_(tmp) |
| // |
| // Where the replace_() call is implemented directly in the functionalization pass, so it is transparent to the backend. |
| // This is useful for backends that aren't able to handle certain types of mutations, like functorch. |
| // |
| // Why do we need to wrap every tensor in a FunctionalTensorWrapper? Consider this program: |
| // |
| // Before: |
| // tensor.add_(batched_tensor) |
| // |
| // After: |
| // tmp = tensor.add(batched_tensor) |
| // tensor.replace_(tmp) |
| // |
| // In the above, tmp is a batched tensor (because adding a normal tensor to a batched tensor does broadcasting and creates a batched tensor). |
| // But we can't just replace the underlying memory backing `tensor` with `tmp` - a batched tensor takes up more space! |
| // Instead, every input, intermediate and output of the program is wrapped in a FunctionalTensorWrapper, which wraps the underlying tensor. |
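| // |
| // As a rough sketch (simplified and illustrative - the real kernels are code-generated, |
| // and the function name below is made up), the functionalization kernel for add_ |
| // behaves roughly like this: |
| // |
| //   at::Tensor& add__functionalization_kernel(at::Tensor& self, const at::Tensor& other) { |
| //     at::functionalization::impl::sync(self); |
| //     at::functionalization::impl::sync(other); |
| //     auto self_  = at::functionalization::impl::from_functional_tensor(self); |
| //     auto other_ = at::functionalization::impl::from_functional_tensor(other); |
| //     at::Tensor tmp; |
| //     { |
| //       at::AutoDispatchSkipFunctionalize guard; |
| //       tmp = self_.add(other_);  // the functional counterpart of the mutation |
| //     } |
| //     auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); |
| //     self_impl->replace_(tmp);      // swap the new value into the wrapper |
| //     self_impl->commit_update();    // queue the mutation on the alias |
| //     return self; |
| //   } |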
| void FunctionalTensorWrapper::replace_(const Tensor& other) { |
| // TODO: going to need to change this if we want nested functionalize() transforms. |
| TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other)); |
| value_ = other; |
| // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor. |
| // We need to propagate that metadata mutation to the wrapper (new size). |
| set_sizes_and_strides(value_.sizes(), value_.strides()); |
| } |
| |
| |
| void FunctionalTensorWrapper::sync_() { |
| if (is_up_to_date()) { |
| return; |
| } |
| apply_updates(); |
| regenerate_from_base(); |
| } |
| |
| void FunctionalTensorWrapper::regenerate_from_base() { |
| at::AutoDispatchSkipFunctionalize guard; |
| auto storage_impl = functional_storage_impl(); |
| auto t = storage_impl->base(); |
| TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); |
| // Reapply the views to regenerate this tensor's value from the alias's base tensor. |
| for (auto& view_meta: view_metas_) { |
| t = view_meta.forward_fn(t, view_meta.out_index); |
| } |
| TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); |
| replace_(t); |
| generation_ = storage_impl->generation(); |
| } |
| |
| void FunctionalTensorWrapper::apply_updates() { |
| // Apply all updates on alias_ |
| auto storage_impl = functional_storage_impl(); |
| storage_impl->apply_updates(); |
| } |
| |
| const char* FunctionalTensorWrapper::tensorimpl_type_name() const { |
| return "FunctionalTensorWrapper"; |
| } |
| |
| namespace functionalization { |
| namespace impl { |
| |
| Tensor to_functional_tensor(const Tensor& tensor) { |
| // Note [Wrapped Numbers <> Functionalization] |
| if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { |
| return tensor; |
| } |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isFunctionalTensor(tensor)); |
| return at::detail::make_tensor<FunctionalTensorWrapper>(tensor); |
| } |
| c10::optional<Tensor> to_functional_tensor(const c10::optional<Tensor>& tensor) { |
| if (tensor.has_value()) { |
| return c10::make_optional<Tensor>(to_functional_tensor(*tensor)); |
| } |
| return c10::nullopt; |
| } |
| c10::List<Tensor> to_functional_tensor(const c10::List<Tensor>& t_list) { |
| c10::List<Tensor> outputs; |
| outputs.reserve(t_list.size()); |
| for (const auto i : c10::irange(t_list.size())) { |
| outputs.push_back(to_functional_tensor(t_list[i])); |
| } |
| return outputs; |
| } |
| c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) { |
| c10::List<c10::optional<Tensor>> outputs; |
| outputs.reserve(t_list.size()); |
| for (const auto i : c10::irange(t_list.size())) { |
| outputs.push_back(to_functional_tensor(t_list[i])); |
| } |
| return outputs; |
| } |
| std::vector<Tensor> to_functional_tensor(const std::vector<Tensor>& t_list) { |
| std::vector<Tensor> outputs(t_list.size()); |
| for (const auto i : c10::irange(t_list.size())) { |
| outputs[i] = to_functional_tensor(t_list[i]); |
| } |
| return outputs; |
| } |
| TensorList to_functional_tensor(const TensorList& t_list) { |
| std::vector<Tensor> outputs(t_list.size()); |
| for (const auto i : c10::irange(t_list.size())) { |
| outputs[i] = to_functional_tensor(t_list[i]); |
| } |
| return outputs; |
| } |
| |
| Tensor from_functional_tensor(const Tensor& tensor) { |
| // Note [Wrapped Numbers <> Functionalization] |
| if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { |
| return tensor; |
| } |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor)); |
| auto impl = unsafeGetFunctionalWrapper(tensor); |
| return impl->value(); |
| } |
| c10::optional<Tensor> from_functional_tensor(const c10::optional<Tensor>& t) { |
| if (t.has_value()) { |
| return c10::make_optional<Tensor>(from_functional_tensor(*t)); |
| } |
| return c10::nullopt; |
| } |
| c10::List<Tensor> from_functional_tensor(const c10::List<Tensor>& t_list) { |
| c10::List<Tensor> outputs; |
| outputs.reserve(t_list.size()); |
| for (const auto i : c10::irange(t_list.size())) { |
| outputs.push_back(from_functional_tensor(t_list[i])); |
| } |
| return outputs; |
| } |
| c10::List<c10::optional<Tensor>> from_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) { |
| c10::List<c10::optional<Tensor>> outputs; |
| outputs.reserve(t_list.size()); |
| for (const auto i : c10::irange(t_list.size())) { |
| outputs.push_back(from_functional_tensor(t_list[i])); |
| } |
| return outputs; |
| } |
| TensorList from_functional_tensor(const TensorList& t_list) { |
| std::vector<Tensor> outputs(t_list.size()); |
| for (const auto i : c10::irange(t_list.size())) { |
| outputs[i] = from_functional_tensor(t_list[i]); |
| } |
| return outputs; |
| } |
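| |
| // Example usage of the wrap/unwrap helpers above (illustrative): |
| // |
| //   at::Tensor t = at::ones({2, 2}); |
| //   at::Tensor f = at::functionalization::impl::to_functional_tensor(t); |
| //   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(f)); |
| //   at::Tensor inner = at::functionalization::impl::from_functional_tensor(f); |
| //   // `inner` is the plain tensor currently wrapped by `f` (initially `t` itself). |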
| |
| void sync(const Tensor& t) { |
| if (t.unsafeGetTensorImpl()->is_wrapped_number()) { |
| // Note [Wrapped Numbers <> Functionalization] |
| // Unfortunately, we can't easily guarantee that wrapped numbers (scalar-tensors) |
| // get wrapped up in a FunctionalTensorWrapper object, since they skip the dispatcher. |
| // That shouldn't matter, since I don't think we're allowed to assign to wrapped numbers anyway. |
| return; |
| } |
| // Not every tensor that hits a functionalization kernel is necessarily a functional tensor. |
| // For example, xla_tensor.copy_(cpu_tensor) needs to hit the functionalization kernel |
| // to sync xla_tensor, but not cpu_tensor. |
| if (!at::functionalization::impl::isFunctionalTensor(t)) { |
| return; |
| } |
| auto functional_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); |
| functional_impl->sync_(); |
| } |
| void sync(const c10::optional<Tensor>& t) { |
| if (t.has_value()) { |
| sync(*t); |
| } |
| } |
| void sync(const c10::List<Tensor> t_list) { |
| for (const auto i : c10::irange(t_list.size())) { |
| sync(t_list[i]); |
| } |
| } |
| void sync(const at::TensorList t_list) { |
| for (auto t: t_list) { |
| sync(t); |
| } |
| } |
| void sync(const c10::List<c10::optional<Tensor>> t_list) { |
| for (const auto i : c10::irange(t_list.size())) { |
| sync(t_list[i]); |
| } |
| } |
| |
| bool isFunctionalTensor(const at::Tensor& tensor) { |
| return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); |
| } |
| |
| bool isFunctionalTensor(const c10::optional<Tensor>& t) { |
| if (t.has_value()) { |
| return isFunctionalTensor(*t); |
| } else { |
| return false; |
| } |
| } |
| |
| bool isFunctionalTensor(const c10::List<Tensor>& t_list) { |
| if (t_list.size() == 0) return false; |
| bool any_functional = isFunctionalTensor(t_list[0]); |
| for (const auto i : c10::irange(1, t_list.size())) { |
| auto curr_functional = isFunctionalTensor(t_list[i]); |
| TORCH_INTERNAL_ASSERT( |
| curr_functional == any_functional, |
| "Functionalization encountered a list of tensors where some are functional ", |
| "and some are not, which is not currently supported."); |
| } |
| return any_functional; |
| } |
| |
| bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) { |
| if (t_list.size() == 0) return false; |
| bool any_functional = isFunctionalTensor(t_list[0]); |
| for (const auto i : c10::irange(1, t_list.size())) { |
| auto curr_functional = isFunctionalTensor(t_list[i]); |
| TORCH_INTERNAL_ASSERT( |
| curr_functional == any_functional, |
| "Functionalization encountered a list of tensors where some are functional ", |
| "and some are not, which is not currently supported."); |
| } |
| return any_functional; |
| } |
| |
| bool isFunctionalTensor(const c10::ArrayRef<Tensor> t_list) { |
| if (t_list.size() == 0) return false; |
| bool any_functional = isFunctionalTensor(t_list[0]); |
| for (const auto i : c10::irange(1, t_list.size())) { |
| auto curr_functional = isFunctionalTensor(t_list[i]); |
| TORCH_INTERNAL_ASSERT( |
| curr_functional == any_functional, |
| "Functionalization encountered a list of tensors where some are functional ", |
| "and some are not, which is not currently supported."); |
| } |
| return any_functional; |
| } |
| |
| Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { |
| TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); |
| TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); |
| auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); |
| if (out_idx != 0) { |
| // Note [out_idx in ViewMeta] |
| // When a view op outputs multiple tensors, each output needs its own separate ViewMeta. |
| // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function. |
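| // For example (illustrative): at::split produces several view outputs of its input; |
| // the ViewMeta recorded for output i carries out_index == i, so that replaying the |
| // view later can pick the right output of the multi-output op. |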
| meta = meta.to_out_idx(out_idx); |
| } |
| return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta); |
| } |
| |
| std::vector<Tensor> create_functional_tensor_with_view_meta(const c10::List<at::Tensor>& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) { |
| std::vector<Tensor> outputs(view_to_wrap.size()); |
| for (const auto i : c10::irange(view_to_wrap.size())) { |
| outputs[i] = create_functional_tensor_with_view_meta(view_to_wrap[i], base, meta, i); |
| } |
| return outputs; |
| } |
| |
| std::vector<Tensor> create_functional_tensor_with_view_meta(const std::vector<at::Tensor>& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) { |
| std::vector<Tensor> outputs(view_to_wrap.size()); |
| for (const auto i : c10::irange(view_to_wrap.size())) { |
| outputs[i] = create_functional_tensor_with_view_meta(view_to_wrap[i], base, meta, i); |
| } |
| return outputs; |
| } |
| |
| void mutate_view_meta(const at::Tensor& self, functionalization::ViewMeta meta) { |
| TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); |
| auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); |
| self_impl->mutate_view_meta(meta); |
| } |
| |
| // Note [Propagating strides in the functionalization pass] |
| // In order to properly compute stride information, the functionalization pass |
| // calls the reference implementation of each {view} operator with meta tensors. |
| // The output meta tensor's stride info serves as a reference for what the correct strides should be. |
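| // For example (an illustrative sketch with hypothetical names, not the actual |
| // codegen output), a generated view kernel might do: |
| // |
| //   at::Tensor reference_out = some_view_op(self_meta, args...);  // runs on meta tensors |
| //   set_sizes_strides_offset(out, reference_out); |
| // |
| // so that `out` picks up the sizes, strides and storage_offset computed by the |
| // reference implementation. |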
| void set_sizes_strides_offset(const Tensor& out, const Tensor& reference_out) { |
| out.unsafeGetTensorImpl()->set_sizes_and_strides(reference_out.sizes(), reference_out.strides()); |
| out.unsafeGetTensorImpl()->set_storage_offset(reference_out.storage_offset()); |
| } |
| |
| void set_sizes_strides_offset(const std::vector<Tensor>& outs, const std::vector<Tensor>& reference_outs) { |
| TORCH_INTERNAL_ASSERT(outs.size() == reference_outs.size()); |
| for (const auto i : c10::irange(reference_outs.size())) { |
| set_sizes_strides_offset(outs[i], reference_outs[i]); |
| } |
| } |
| |
| thread_local bool _functionalizationReapplyViews; |
| |
| bool getFunctionalizationReapplyViewsTLS() { |
| return _functionalizationReapplyViews; |
| } |
| void setFunctionalizationReapplyViewsTLS(bool reapply_views) { |
| _functionalizationReapplyViews = reapply_views; |
| } |
| |
| } // namespace impl |
| } // namespace functionalization |
| } // namespace at |