// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/extension.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/FunctionalTensorWrapper.h>

#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <torch/csrc/functorch/CompileCache.h>
#include <c10/core/AutogradState.h>
#include <functorch/csrc/dim/dim.h>

// This file contains functorch's Python bindings.

namespace at {
namespace functorch {

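// Returns true if `self` is a BatchedTensor whose outermost batch level is at
// least `level`, i.e. it has already been wrapped for that vmap level.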
static bool has_level(const Tensor& self, int64_t level) {
  const auto* batched = maybeGetBatchedImpl(self);
  if (!batched) {
    return false;
  }
  return batched->level() >= level;
}

Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
  return addBatchDim(self, batch_dim, level);
}

Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) {
  auto t = at::functionalization::impl::to_functional_tensor(self);
  at::functionalization::impl::unsafeGetFunctionalWrapper(t)->set_level(level);
  return t;
}

void _assert_wrapped_functional(const Tensor& unwrapped, const Tensor& wrapped) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(wrapped));
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(unwrapped));
  auto wrapped_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
  auto& wrapped_inner = wrapped_impl->value();
  TORCH_INTERNAL_ASSERT(unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl());
}

void _propagate_functional_input_mutation(const Tensor& unwrapped, const Tensor& wrapped) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(wrapped));
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(unwrapped));
  auto wrapped_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
  // Ensure that the input is up to date by committing any pending updates to the alias.
  wrapped_impl->sync_();
  auto& wrapped_inner = wrapped_impl->value();
  // It would probably be more reasonable to check that the two tensors are aliased,
  // but we can't do that unless we give BatchedTensorImpl a notion of storage.
  if (unwrapped.unsafeGetTensorImpl() != wrapped_inner.unsafeGetTensorImpl()) {
    if (unwrapped.nbytes() != wrapped_inner.nbytes()) {
      // Functions might resize zero-sized inputs, which we need to reflect here.
      unwrapped.resize_(wrapped_inner.sizes());
    }
    // If the input tensor's metadata was mutated, then use as_strided_()
    // to propagate the metadata change.
    if (unwrapped.sizes() != wrapped_inner.sizes()) {
      unwrapped.as_strided_(wrapped_inner.sizes(), wrapped_inner.strides());
    }
    unwrapped.copy_(wrapped_inner);
  }
}
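
// A rough sketch (in Python, for illustration only) of how the functional
// wrappers above are expected to compose; the authoritative call sites live
// in functorch's Python layer, and `f` and `x` are hypothetical placeholders:
//   level = _func_increment_nesting(reapply_views)
//   try:
//       wrapped = _wrap_functional_tensor(x, level)
//       out = f(wrapped)
//       _propagate_functional_input_mutation(x, wrapped)
//       result = _unwrap_functional_tensor(out, reapply_views)
//   finally:
//       _func_decrement_nesting()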

static std::pair<Tensor, int64_t> remove_existing_batch_dim(
    const BatchedTensorImpl* batched, int64_t level) {
  TORCH_INTERNAL_ASSERT(batched->level() == level);
  return std::make_pair(batched->value(), batched->bdim());
}

// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
// while preserving the order of the other dimensions.
// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
// When we do, replace the following with it.
static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
  auto logical_dim = self.dim();
  src = maybe_wrap_dim(src, logical_dim);
  dst = maybe_wrap_dim(dst, logical_dim);
  if (src == dst) {
    return self;
  }
  VmapDimVector permutation;
  permutation.reserve(logical_dim);
  for (int64_t dim = 0; dim < logical_dim; dim++) {
    if (dim == src) {
      continue;
    }
    permutation.push_back(dim);
  }
  permutation.insert(permutation.begin() + dst, src);
  return self.permute(permutation);
}
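
// For illustration: with self.dim() == 3, _movedim(self, /*src=*/0, /*dst=*/2)
// builds the permutation (1, 2, 0), so the old dim 0 ends up at position 2
// and the relative order of the remaining dims is preserved.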

// Removes the batch dim with level `level` from `self`. If this causes the
// last batch dim to be removed from a BatchedTensor, then this returns a
// regular Tensor.
//
// If the `level` of the batch dim to remove does not exist in `self`, then we
// add the batch dim in. This can happen if `self` didn't interact with a tensor
// inside the vmap level, for example:
//   self = torch.randn(3)
//   y = torch.randn(5)
//   out = vmap(lambda x: vmap(lambda y: x)(y))(self)
//   assert out.shape == (3, 5)
// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
// corresponding to the *outer* vmap level and it doesn't have any dimensions that
// correspond to the inner vmap level, so we need to create one for the user.
//
// `out_dim` controls where we should put the batch dimension in the output tensor.
Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
  if (!has_level(self, level)) {
    auto self_sizes = self.sizes();
    VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
    expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
    auto result = self.expand(expanded_sizes);
    return result;
  }

  // Must be batched if has_level(self, /*any_level*/)
  const auto* batched = maybeGetBatchedImpl(self);
  TORCH_INTERNAL_ASSERT(batched != nullptr);

  Tensor self_without_bdim;
  int64_t newly_exposed_logical_dim;
  std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level);
  auto result = _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
  return result;
}

Tensor _unwrap_functional_tensor(const Tensor& self, bool add_back_views) {
  // We only ever call this after popping out of a functionalize() call, in which case
  // the current tensors should always be wrapped in a FunctionalTensorWrapper.
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto functional = at::functionalization::impl::unsafeGetFunctionalWrapper(self);

  // When regenerating the (potentially mutated) input tensors, the functionalization
  // pass regenerates them through a series of view_copy() op calls. Functorch wants to
  // turn those back into view ops, though, so commit any pending updates to the alias
  // before returning the inner tensor.
  at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(add_back_views);
  bool any_updates = functional->apply_updates();
  if (any_updates) {
    functional->regenerate_from_base();
  }
  return functional->value();
}

Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
  return makeTensorWrapper(self, level);
}

Tensor _unwrap_for_grad(const Tensor& self, int64_t level) {
  auto* result = maybeGetTensorWrapper(self);
  if (!result) {
    return self;
  }
  TORCH_INTERNAL_ASSERT(result->level().has_value());
  if (result->level() == level) {
    return result->value();
  }
  return self;
}

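// Debugging helper: returns the wrapper level of `tensor`, 0 if it is not a
// TensorWrapper, or -1 if the wrapper is no longer alive.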
int64_t dlevel(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return 0;
  }
  if (!wrapped->is_alive()) {
    return -1;
  }
  return wrapped->level().value();
}

bool dump_tensor(const Tensor& self) {
  dumpTensorCout(self);
  return true;
}

RandomnessType get_randomness_enum(const std::string& randomness) {
  if (randomness == "error") {
    return RandomnessType::Error;
  } else if (randomness == "same") {
    return RandomnessType::Same;
  } else if (randomness == "different") {
    return RandomnessType::Different;
  } else {
    TORCH_CHECK(false, "randomness argument must be error, same, or different.");
  }
}

void set_fwd_grad_enabled(bool enabled) {
  AutogradState::get_tls_state().set_fw_grad_mode(enabled);
}

bool get_fwd_grad_enabled() {
  return AutogradState::get_tls_state().get_fw_grad_mode();
}

int64_t _grad_increment_nesting() {
  // See NOTE [grad and vjp interaction with no_grad]
  bool prev_grad_mode = c10::GradMode::is_enabled();
  return initAndPushDynamicLayer(TransformType::Grad, c10::nullopt, c10::nullopt, prev_grad_mode);
}

int64_t _grad_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Grad);
  return layer.layerId();
}
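
// A rough sketch (in Python, for illustration only, not the exact
// implementation) of how the grad transform pairs these nesting helpers with
// the wrappers defined above; `f` and `x` are hypothetical placeholders:
//   level = _grad_increment_nesting()
//   try:
//       wrapped = _wrap_for_grad(x, level)
//       out = f(wrapped)
//       grad_input, = torch.autograd.grad(out, wrapped)
//       result = _unwrap_for_grad(grad_input, level)
//   finally:
//       _grad_decrement_nesting()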

int64_t _jvp_increment_nesting() {
  // See NOTE [grad and vjp interaction with no_grad]
  bool prev_fwd_grad_mode = get_fwd_grad_enabled();
  return initAndPushDynamicLayer(TransformType::Jvp, c10::nullopt, c10::nullopt, c10::nullopt, prev_fwd_grad_mode);
}

int64_t _jvp_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Jvp);
  return layer.layerId();
}

int64_t _vmap_increment_nesting(int64_t batch_size, const std::string& randomness) {
  return initAndPushDynamicLayer(TransformType::Vmap, batch_size, get_randomness_enum(randomness));
}

int64_t _vmap_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Vmap);
  return layer.layerId();
}
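
// Likewise, a rough sketch (Python, illustration only) of how vmap brackets a
// call with the batch-dim helpers; `f`, `x`, `in_dim`, and `out_dim` are
// placeholders:
//   level = _vmap_increment_nesting(batch_size, "error")
//   try:
//       batched = _add_batch_dim(x, in_dim, level)
//       out = f(batched)
//       result = _remove_batch_dim(out, level, batch_size, out_dim)
//   finally:
//       _vmap_decrement_nesting()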

int64_t _func_increment_nesting(bool reapply_views) {
  return initAndPushDynamicLayer(
      TransformType::Functionalize,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      /*functionalize_add_back_views=*/reapply_views);
}

int64_t _func_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Functionalize);
  return layer.layerId();
}

static bool is_batchedtensor(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  return batched != nullptr;
}

static bool is_gradtrackingtensor(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  return wrapped != nullptr;
}

static bool is_functionaltensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}

static Tensor get_unwrapped(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->value();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    return wrapped->value();
  }
  auto* functional = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  if (functional) {
    return functional->value();
  }
  TORCH_CHECK(false, "No wrappers present!");
}

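// Returns the level of the outermost wrapper on `tensor`, -1 if there is no
// wrapper, or -2 for a TensorWrapper without a level (a weird special case;
// see the TODO below).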
static int64_t maybe_get_level(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->level();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    if (wrapped->level()) {
      return *wrapped->level();
    }
    // TODO: this is a weird special case...
    return -2;
  }
  auto* functional = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  if (functional) {
    return functional->level();
  }
  return -1;
}

static int64_t maybe_get_bdim(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->bdim();
  }
  return -1;
}

static int64_t currentLevel() {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t current_level = maybe_layer->layerId();
  return current_level;
}

static void tls_set_vmap_excluded(bool excluded) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::FuncTorchBatched, excluded);
}

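// Note: despite the "set" in the name, this is a query; it reports whether
// the dynamic-layer front-mode key is included in the thread-local dispatch
// key set. The name is kept as-is because the binding below exposes it verbatim.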
static bool tls_set_is_included() {
  return c10::impl::tls_is_dispatch_key_included(DispatchKey::FuncTorchDynamicLayerFrontMode);
}

static void _set_dynamic_layer_keys_included(bool value) {
  return setDynamicLayerFrontBackKeysIncluded(value);
}

static void dump_dls() {
  std::cout << getDynamicLayerStack() << std::endl;
}

static void dump_local_tls() {
  auto tls = c10::impl::tls_local_dispatch_key_set();
  std::cout << "[Local Include] " << tls.included_ << std::endl;
  std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
}
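
// These bindings are consumed from Python via the compiled extension module
// (conventionally importable as `functorch._C`; the exact packaging is an
// assumption that depends on the build):
//   from functorch._C import _vmap_increment_nesting, dlevel, dump_tensor
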
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim");
  m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim");
  m.def("_wrap_functional_tensor", &at::functorch::_wrap_functional_tensor, "add functional tensor");
  m.def("_assert_wrapped_functional", &at::functorch::_assert_wrapped_functional, "assert wrapped functional");
  m.def("_propagate_functional_input_mutation", &at::functorch::_propagate_functional_input_mutation, "propagate functional input mutations");
  m.def("_unwrap_functional_tensor", &at::functorch::_unwrap_functional_tensor, "remove functional tensor");
  m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "begin vmap level");
  m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "end vmap level");
  m.def("_func_increment_nesting", &at::functorch::_func_increment_nesting, "functionalization start");
  m.def("_func_decrement_nesting", &at::functorch::_func_decrement_nesting, "functionalization end");
  m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "begin grad level");
  m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "end grad level");
  m.def("_jvp_increment_nesting", &at::functorch::_jvp_increment_nesting);
  m.def("_jvp_decrement_nesting", &at::functorch::_jvp_decrement_nesting);
  m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "wrap as gradtrackingtensor");
  m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "unwrap from gradtrackingtensor");
  m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "set vmap fallback warnings");
  m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
  m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
  m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
  m.def("get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed);
  m.def("dlevel", &at::functorch::dlevel, "dlevel");
  m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
  m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
  m.def("reshape_dim_outof", &at::functorch::reshape_dim_outof);
  m.def("are_transforms_active", &at::functorch::areTransformsActive);
  // Various debugging things. Maybe we should offer these as first-class APIs
  // on Tensors?
  m.def("is_batchedtensor", &at::functorch::is_batchedtensor);
  m.def("is_gradtrackingtensor", &at::functorch::is_gradtrackingtensor);
  m.def("is_functionaltensor", &at::functorch::is_functionaltensor);
  m.def("get_unwrapped", &at::functorch::get_unwrapped);
  m.def("maybe_get_level", &at::functorch::maybe_get_level);
  m.def("maybe_get_bdim", &at::functorch::maybe_get_bdim);
  m.def("current_level", &at::functorch::currentLevel);
  m.def("tls_set_vmap_excluded", &at::functorch::tls_set_vmap_excluded);
  m.def("tls_set_is_included", &at::functorch::tls_set_is_included);
  m.def("_set_dynamic_layer_keys_included", &at::functorch::_set_dynamic_layer_keys_included);
  m.def("dump_dls", &at::functorch::dump_dls);
  m.def("dump_local_tls", &at::functorch::dump_local_tls);
  m.def("set_fwd_grad_enabled", &at::functorch::set_fwd_grad_enabled);
  m.def("get_fwd_grad_enabled", &at::functorch::get_fwd_grad_enabled);

  torch::functorch::initCompileCacheBindings(m.ptr());

  // Initialize first-class dims and install them as a submodule of _C.
  auto dim = Dim_init();
  if (!dim) {
    throw py::error_already_set();
  }
  py::setattr(m, "dim", py::reinterpret_steal<py::object>(dim));
}

} // namespace functorch
} // namespace at