[functorch] Refactor DynamicLayer so the logic for each transform is separate (pytorch/functorch#756)
See functorch/csrc/Interpreter.h for a rundown of what happened: the per-transform logic that previously lived in large switch statements inside DynamicLayer.cpp now lives in dedicated interpreter files (ADInterpreters, VmapInterpreter, FunctionalizeInterpreter), and DynamicLayer delegates to an Interpreter object.
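At a high level, the dispatch flow after this refactor looks roughly like the following (a simplified sketch distilled from the diff below, not verbatim code; sendToNextInterpreter follows the same pattern):

    // The boxed fallbacks no longer switch on the transform type themselves;
    // they grab the top DynamicLayer and forward to its Interpreter.
    void dynamicLayerFrontFallback(const c10::OperatorHandle& op,
                                   torch::jit::Stack* stack) {
      auto& layer = dynamicLayerStackAccessor().back();
      layer.interpreter().process(op, stack);
    }

    // Interpreter::process switches on the TransformType and hands off to the
    // matching convenience wrapper (VmapInterpreterPtr, GradInterpreterPtr,
    // JvpInterpreterPtr, FunctionalizeInterpreterPtr), whose processImpl
    // contains the logic that used to be inlined in DynamicLayer.cpp.
    void Interpreter::process(const c10::OperatorHandle& op,
                              torch::jit::Stack* stack) {
      switch (key()) {
        case TransformType::Vmap:
          return VmapInterpreterPtr(this).processImpl(op, stack);
        case TransformType::Grad:
          return GradInterpreterPtr(this).processImpl(op, stack);
        case TransformType::Jvp:
          return JvpInterpreterPtr(this).processImpl(op, stack);
        case TransformType::Functionalize:
          return FunctionalizeInterpreterPtr(this).processImpl(op, stack);
        default:
          TORCH_INTERNAL_ASSERT(false, "Unrecognized transform");
      }
    }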
diff --git a/functorch/functorch/csrc/ADInterpreters.cpp b/functorch/functorch/csrc/ADInterpreters.cpp
new file mode 100644
index 0000000..310df42
--- /dev/null
+++ b/functorch/functorch/csrc/ADInterpreters.cpp
@@ -0,0 +1,192 @@
+#include <functorch/csrc/ADInterpreters.h>
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/TensorWrapper.h>
+
+namespace at { namespace functorch {
+
+static void checkForInvalidMutationOnCaptures(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack,
+ int64_t cur_level) {
+ if (!isInplaceOp(op.schema())) {
+ return;
+ }
+ auto args = torch::jit::last(stack, op.schema().arguments().size());
+ auto mutated_arg = unwrapIfDead(args[0].toTensor());
+ auto* wrapper = maybeGetTensorWrapper(mutated_arg);
+ if (wrapper && wrapper->level().has_value() && wrapper->level().value() == cur_level) {
+ return;
+ }
+ TORCH_CHECK(false,
+ "During a grad (vjp, jvp, grad, etc) transform, the function provided ",
+ "attempted to call in-place operation (", op.schema().operator_name(), ") ",
+ "that would mutate a captured Tensor. This is not supported; please rewrite ",
+ "the function being transformed to explicitly accept the mutated Tensor(s) ",
+ "as inputs.");
+}
+
+static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) {
+ if (!tensor.defined()) {
+ return tensor;
+ }
+ auto* wrapper = maybeGetTensorWrapper(tensor);
+ if (!wrapper) {
+ return makeTensorWrapper(tensor, current_level);
+ }
+ TORCH_INTERNAL_ASSERT(wrapper->level().value() <= current_level, "escaped?");
+ if (wrapper->level().value() == current_level) {
+ TORCH_INTERNAL_ASSERT(tensor.defined());
+ return tensor;
+ }
+ return makeTensorWrapper(tensor, current_level);
+}
+
+static void autogradBasedTransformProcess(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack,
+ int64_t current_level,
+ TransformType transform_type) {
+ // If this is a grad transform, the operation is in-place, and the mutated
+ // argument is not currently wrapped in a TensorWrapper, then we need to
+ // error out; otherwise the result is silently incorrect.
+ checkForInvalidMutationOnCaptures(op, stack, current_level);
+
+ // materialize live GradWrappers
+ auto maybeTransformGradWrappers = [&](const Tensor& tensor) {
+ return materializeGradWrappers(tensor, current_level);
+ };
+ auto num_args = op.schema().arguments().size();
+ foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers);
+
+ auto exclude = keysToExcludeWhenEnteringDynamicLayer(transform_type);
+ setup_dispatch_key_tls(exclude, {});
+ op.callBoxed(stack);
+}
+
+static void autogradBasedTransformSendToNext(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack,
+ int64_t current_level,
+ TransformType transform_type,
+ optional<bool> prev_grad_mode,
+ optional<bool> prev_fwd_grad_mode) {
+ if (transform_type == TransformType::Grad) {
+ TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
+ }
+ if (transform_type == TransformType::Jvp) {
+ TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
+ }
+ auto unwrap = [&](const Tensor& tensor) {
+ if (!tensor.defined()) {
+ return tensor;
+ }
+ auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
+ if (!maybe_tensor_wrapper) {
+ return tensor;
+ }
+ auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
+ TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= current_level);
+ if (tensor_wrapper_level == current_level) {
+ return maybe_tensor_wrapper->value();
+ }
+ return tensor;
+ };
+ auto wrap = [&](const Tensor& tensor) {
+ if (!tensor.defined()) {
+ return tensor;
+ }
+ // if (c10::show_dispatch_trace_enabled()) {
+ // std::cout << "wrap " << current_level << std::endl;
+ // }
+ return makeTensorWrapper(tensor, current_level);
+ };
+
+ // TODO: we only need to do the following (marked with !) on in-place functions
+ // that modify sizes or strides. There aren't many of them.
+ // If autograd dispatch key:
+ // 1. (!) Put a copy of all of the args onto the stack
+ // 2. Unwrap all the args in the copy set
+ // 3. Call the operator
+ // 4. Wrap the output
+ // 5. (!) refreshMetadata for all the args in the original set
+ // 6. (!) Pop those args off.
+
+ // Step 1 & 2
+ auto args_size = op.schema().arguments().size();
+ // Step 1
+ auto front = stack->size() - args_size;
+ for (const auto arg_idx : c10::irange(0, args_size)) {
+ stack->push_back((*stack)[front + arg_idx]);
+ }
+ // Step 2
+ foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
+
+ // See NOTE [grad and vjp interaction with no_grad]
+ optional<c10::AutoGradMode> grad_guard;
+ if (transform_type == TransformType::Grad && prev_grad_mode.has_value() && *prev_grad_mode == false) {
+ grad_guard.emplace(*prev_grad_mode);
+ }
+ optional<c10::AutoFwGradMode> fw_grad_guard;
+ if (transform_type == TransformType::Jvp &&
+ prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
+ fw_grad_guard.emplace(*prev_fwd_grad_mode);
+ }
+
+ // Re-dispatch
+ if (getDynamicLayerStack().size() == 0) {
+ sanityCheckStack(op, stack);
+ }
+ op.callBoxed(stack);
+
+ // Step 4, 5, 6
+ auto ret_size = op.schema().returns().size();
+ // Step 4
+ foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);
+
+ // Step 5
+ auto args_front = stack->size() - args_size - ret_size;
+ for (const auto arg_idx : c10::irange(0, args_size)) {
+ auto& ivalue = (*stack)[args_front + arg_idx];
+ if (!ivalue.isTensor()) {
+ continue;
+ }
+ auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
+ if (!maybe_tensor_wrapper) {
+ continue;
+ }
+ maybe_tensor_wrapper->refreshMetadata();
+ }
+
+ // Step 6
+ stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
+}
+
+void GradInterpreterPtr::processImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ autogradBasedTransformProcess(op, stack, level(), TransformType::Grad);
+}
+
+void GradInterpreterPtr::sendToNextInterpreterImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ autogradBasedTransformSendToNext(
+ op, stack, level(),
+ TransformType::Grad, prevGradMode(), nullopt);
+}
+
+void JvpInterpreterPtr::processImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ autogradBasedTransformProcess(op, stack, level(), TransformType::Jvp);
+}
+
+void JvpInterpreterPtr::sendToNextInterpreterImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ autogradBasedTransformSendToNext(
+ op, stack, level(),
+ TransformType::Jvp, nullopt, prevFwdGradMode());
+}
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/ADInterpreters.h b/functorch/functorch/csrc/ADInterpreters.h
new file mode 100644
index 0000000..6f79afc
--- /dev/null
+++ b/functorch/functorch/csrc/ADInterpreters.h
@@ -0,0 +1,32 @@
+#pragma once
+#include <functorch/csrc/Interpreter.h>
+
+namespace at { namespace functorch {
+
+struct GradInterpreterPtr {
+ explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
+ TransformType key() const { return base_->key(); }
+ int64_t level() const { return base_->level(); }
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ bool prevGradMode() const {
+ return c10::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
+ }
+ private:
+ const Interpreter* base_;
+};
+
+struct JvpInterpreterPtr {
+ explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
+ TransformType key() const { return base_->key(); }
+ int64_t level() const { return base_->level(); }
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ bool prevFwdGradMode() const {
+ return c10::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
+ }
+ private:
+ const Interpreter* base_;
+};
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp
index cab56e6..6d96d52 100644
--- a/functorch/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/functorch/csrc/DynamicLayer.cpp
@@ -19,40 +19,6 @@
namespace at {
namespace functorch {
-std::ostream& operator<<(std::ostream& os, const TransformType& t) {
- switch (t) {
- case TransformType::Torch:
- os << "Torch";
- break;
- case TransformType::Vmap:
- os << "Vmap";
- break;
- case TransformType::Grad:
- os << "Grad";
- break;
- case TransformType::Jvp:
- os << "Jvp";
- break;
- case TransformType::Functionalize:
- os << "Functionalize";
- break;
- default:
- TORCH_INTERNAL_ASSERT(false);
- }
- return os;
-}
-
-constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
- kDynamicLayerFrontModeKey,
- kDynamicLayerBackModeKey,
- kGradWrapperKey,
- DispatchKey::Functionalize,
- // DispatchKey::Batched,
- kBatchedKey,
- DispatchKey::PythonTLSSnapshot,
- DispatchKey::ADInplaceOrView
-}) | autograd_dispatch_keyset;
-
void setDynamicLayerFrontBackKeysIncluded(bool included) {
c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, included);
c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, included);
@@ -66,14 +32,6 @@
optional<bool> prev_grad_mode,
optional<bool> prev_fwd_grad_mode,
optional<bool> functionalize_add_back_views)
- :
- transform_type_(transform_type),
- layerId_(layerId),
- batchSize_(batchSize),
- randomness_(randomness),
- prevGradMode_(prev_grad_mode),
- prevFwdGradMode_(prev_fwd_grad_mode),
- functionalizeAddBackViews_(functionalize_add_back_views)
{
if (transform_type == TransformType::Grad) {
TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
@@ -81,55 +39,42 @@
if (transform_type == TransformType::Jvp) {
TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
}
+ switch (transform_type) {
+ case TransformType::Vmap:
+ interpreter_ = Interpreter::Vmap(layerId, batchSize.value(), randomness.value());
+ break;
+ case TransformType::Grad:
+ interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value());
+ break;
+ case TransformType::Jvp:
+ interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value());
+ break;
+ case TransformType::Functionalize:
+ interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value());
+ break;
+ default:
+ TORCH_INTERNAL_ASSERT(false);
+ }
}
TransformType DynamicLayer::key() const {
- return transform_type_;
+ return interpreter_.key();
}
int64_t DynamicLayer::layerId() const {
- return layerId_;
+ return interpreter_.level();
}
int64_t DynamicLayer::batchSize() const {
- TORCH_INTERNAL_ASSERT(batchSize_);
- return *batchSize_;
+ return VmapInterpreterPtr(&interpreter_).batchSize();
}
RandomnessType DynamicLayer::randomness() const {
- TORCH_INTERNAL_ASSERT(randomness_);
- return *randomness_;
-}
-
-optional<bool> DynamicLayer::prevGradMode() const {
- return prevGradMode_;
-}
-
-optional<bool> DynamicLayer::prevFwdGradMode() const {
- return prevFwdGradMode_;
-}
-
-void DynamicLayer::saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
- TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
- savedLocalDispatchKeySet_ = std::move(keyset);
-}
-
-void DynamicLayer::clearSavedLocalDispatchKeySet() {
- TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
- savedLocalDispatchKeySet_ = c10::nullopt;
-}
-
-c10::impl::LocalDispatchKeySet DynamicLayer::getSavedLocalDispatchKeySet() const {
- TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
- return *savedLocalDispatchKeySet_;
+ return VmapInterpreterPtr(&interpreter_).randomness();
}
constexpr DispatchKeySet kFrontBackKeys({kDynamicLayerBackModeKey, kDynamicLayerFrontModeKey});
-optional<bool> DynamicLayer::functionalizeAddBackViews() const {
- return functionalizeAddBackViews_;
-}
-
using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
DynmetaData kDynMetaDataSingleton;
@@ -200,14 +145,14 @@
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
auto& layer = dynamicLayerStack.back();
auto tmp = c10::impl::tls_local_dispatch_key_set();
- layer.saveLocalDispatchKeySet(tmp);
+ layer.interpreter().saveLocalDispatchKeySet(tmp);
}
~SaveLocalDispatchKeySet() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
auto& layer = dynamicLayerStack.back();
- auto tmp = layer.getSavedLocalDispatchKeySet();
- layer.clearSavedLocalDispatchKeySet();
+ auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
+ layer.interpreter().clearSavedLocalDispatchKeySet();
c10::impl::_force_tls_local_dispatch_key_set(tmp);
}
SaveLocalDispatchKeySet(const SaveLocalDispatchKeySet&) = delete;
@@ -216,7 +161,7 @@
static c10::impl::ForceDispatchKeyGuard
restoreLocalDispatchKeySetRAII(const DynamicLayer& layer) {
- auto tmp = layer.getSavedLocalDispatchKeySet();
+ auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
return c10::impl::ForceDispatchKeyGuard(tmp);
}
@@ -316,27 +261,7 @@
return result;
}
-static Tensor materializeGradWrappers(const Tensor& tensor, const std::vector<DynamicLayer>& dynlayerStack) {
- if (!tensor.defined()) {
- return tensor;
- }
- if (dynlayerStack.back().key() != TransformType::Grad && dynlayerStack.back().key() != TransformType::Jvp) {
- return tensor;
- }
- auto cur_level = dynlayerStack.back().layerId();
- auto* wrapper = maybeGetTensorWrapper(tensor);
- if (!wrapper) {
- return makeTensorWrapper(tensor, cur_level);
- }
- TORCH_INTERNAL_ASSERT(wrapper->level().value() <= cur_level, "escaped?");
- if (wrapper->level().value() == cur_level) {
- TORCH_INTERNAL_ASSERT(tensor.defined());
- return tensor;
- }
- return makeTensorWrapper(tensor, cur_level);
-}
-
-static Tensor unwrapIfDead(const Tensor& tensor) {
+Tensor unwrapIfDead(const Tensor& tensor) {
auto* wrapped = maybeGetTensorWrapper(tensor);
if (!wrapped) {
return tensor;
@@ -405,14 +330,6 @@
return os;
}
-void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) {
- foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
- [](const Tensor& tensor) {
- TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
- return tensor;
- });
-}
-
static bool allTensors(
ArrayRef<IValue> args,
std::function<bool(const Tensor&)> pred) {
@@ -445,26 +362,6 @@
return true;
}
-static bool anyTensors(
- ArrayRef<IValue> args,
- std::function<bool(const Tensor&)> pred) {
- // Demorgan's law
- return !allTensors(args, [&](const Tensor& self) { return !pred(self); });
-}
-
-static void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
- auto num_args = op.schema().arguments().size();
- foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
- [](const Tensor& tensor) {
-
- auto* wrapper = maybeGetTensorWrapper(tensor);
- TORCH_INTERNAL_ASSERT(wrapper == nullptr);
- auto* batched = maybeGetBatchedImpl(tensor);
- TORCH_INTERNAL_ASSERT(batched == nullptr);
- return tensor;
- });
-}
-
bool isInplaceOp(const FunctionSchema& schema) {
if (!schema.is_mutable() || schema.returns().size() != 1) {
return false;
@@ -486,45 +383,6 @@
return return_alias_info && return_alias_info->isWrite();
}
-static void checkForInvalidMutationOnCaptures(
- const c10::OperatorHandle& op,
- torch::jit::Stack* stack,
- const std::vector<DynamicLayer>& dynamicLayerStack) {
- if (dynamicLayerStack.back().key() != TransformType::Grad && dynamicLayerStack.back().key() != TransformType::Jvp) {
- return;
- }
- if (!isInplaceOp(op.schema())) {
- return;
- }
- auto args = torch::jit::last(stack, op.schema().arguments().size());
- auto mutated_arg = unwrapIfDead(args[0].toTensor());
- auto cur_level = dynamicLayerStack.back().layerId();
- auto* wrapper = maybeGetTensorWrapper(mutated_arg);
- if (wrapper && wrapper->level().has_value() && wrapper->level().value() == cur_level) {
- return;
- }
- TORCH_CHECK(false,
- "During a grad (vjp, jvp, grad, etc) transform, the function provided ",
- "attempted to call in-place operation (", op.schema().operator_name(), ") ",
- "that would mutate a captured Tensor. This is not supported; please rewrite ",
- "the function being transformed to explicitly accept the mutated Tensor(s) ",
- "as inputs.");
-}
-
-static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
- if (key == TransformType::Vmap) {
- // NB: Does not include kVmapModeKey. We may modulate the key when
- // constructing the DynamicLayer, but we don't control it when entering/exiting
- // the DynamicLayer.
- return DispatchKeySet({kBatchedKey});
- } else if (key == TransformType::Grad || key == TransformType::Jvp) {
- return autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);
- } else if (key == TransformType::Functionalize) {
- return DispatchKeySet(DispatchKey::Functionalize);
- } else {
- TORCH_INTERNAL_ASSERT(false, "Unsupported key: ", key);
- }
-}
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
static void dump_local_tls() {
@@ -534,30 +392,15 @@
}
#endif
-static DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key) {
- DispatchKeySet exclude = all_dynlayer_keyset;
- exclude = exclude.remove(kDynamicLayerBackModeKey);
- exclude = exclude - keysForEnteringDynamicLayer(key);
- return exclude;
-}
-static bool isFunctionalTensorAtCurrentLevel(const Tensor& tensor) {
- auto& dynamicLayerStack = dynamicLayerStackAccessor();
- auto layer = dynamicLayerStack.back();
- auto level = layer.layerId();
+struct WithoutTop {
+ WithoutTop();
+ ~WithoutTop();
+ DynamicLayer layer_;
+};
- if (!at::functionalization::impl::isFunctionalTensor(tensor)) {
- return false;
- }
- const auto* functional = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
- auto functional_level = functional->level();
- return functional_level == level;
-}
-
-static void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include) {
- auto local_keyset = c10::impl::tls_local_dispatch_key_set();
- local_keyset.excluded_ = local_keyset.excluded_ | exclude;
- local_keyset.included_ = local_keyset.included_ | include;
- c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+WithoutTop::WithoutTop(): layer_(popDynamicLayer()) {}
+WithoutTop::~WithoutTop() {
+ pushDynamicLayer(std::move(layer_));
}
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
@@ -581,220 +424,20 @@
foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), unwrapIfDead);
auto& layer = dynamicLayerStack.back();
- DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(layer.key());
-
- switch (layer.key()) {
- case TransformType::Grad:
- case TransformType::Jvp:
- {
- // if is a grad transform, and the operation is in-place, and the mutated
- // argument is not currently wrapped in a TensorWrapper, then we need to
- // error out otherwise the result is silently incorrect
- checkForInvalidMutationOnCaptures(op, stack, dynamicLayerStack);
-
- // materialize live GradWrappers
- auto maybeTransformGradWrappers = [](const Tensor& tensor) {
- return materializeGradWrappers(tensor, getDynamicLayerStack());
- };
- foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers);
-
- setup_dispatch_key_tls(exclude, {});
- op.callBoxed(stack);
- break;
- }
- case TransformType::Vmap:
- {
- setup_dispatch_key_tls(exclude, DispatchKeySet(kVmapModeKey));
- op.callBoxed(stack);
- break;
- }
- case TransformType::Functionalize:
- {
- // We always want to call the functionalization kernels if functionalize() is on the layer stack.
- // It's the responsibility of the functionalization kernel to no-op and redispatch
- // if none of the input tensors are functional.
- setup_dispatch_key_tls(exclude, DispatchKeySet(DispatchKey::Functionalize));
- auto functionalization_add_back_views = layer.functionalizeAddBackViews().has_value() && *(layer.functionalizeAddBackViews());
- // We have some side-car TLS that we can set to toggle the functionaliation behavior.
- // If set, then we functionalization will only remove mutations, instead of
- // removing both mutations AND view operators.
- at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views);
-
- op.callBoxed(stack);
-
- auto ret_size = op.schema().returns().size();
- foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(),
- [&](const Tensor& tensor) {
- if (at::functionalization::impl::isFunctionalTensor(tensor)) {
- auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
- // Functorch is responsible for setting the level on the wrapper, since we don't
- // have that info available in core (for now).
- // We could just "propagate" the level from the input tensors inside of the functionalize kernels,
- // but unfortunately we can't do that for factory operators.
- wrapper->set_level(layer.layerId());
- }
- return tensor;
- }
- );
- break;
- }
- default:
- TORCH_INTERNAL_ASSERT(false);
- }
+ layer.interpreter().process(op, stack);
}
-struct WithoutTop {
- WithoutTop(): layer_(popDynamicLayer()) {
- }
- ~WithoutTop() {
- pushDynamicLayer(std::move(layer_));
- }
-
- DynamicLayer layer_;
-};
+static c10::impl::ForceDispatchKeyGuard
+restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
+ return c10::impl::ForceDispatchKeyGuard(key_set);
+}
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
- auto cur_level = getDynamicLayerStack().back().layerId();
- auto cur_key = getDynamicLayerStack().back().key();
+ auto& layer = dynamicLayerStackAccessor().back();
+ auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
+ WithoutTop guard;
- switch (cur_key) {
- case TransformType::Grad:
- case TransformType::Jvp:
- {
- optional<bool> prev_grad_mode = getDynamicLayerStack().back().prevGradMode();
- optional<bool> prev_fwd_grad_mode = getDynamicLayerStack().back().prevFwdGradMode();
- if (cur_key == TransformType::Grad) {
- TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
- }
- if (cur_key == TransformType::Jvp) {
- TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
- }
- auto unwrap = [&](const Tensor& tensor) {
- if (!tensor.defined()) {
- return tensor;
- }
- auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
- if (!maybe_tensor_wrapper) {
- return tensor;
- }
- auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
- TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level);
- if (tensor_wrapper_level == cur_level) {
- return maybe_tensor_wrapper->value();
- }
- return tensor;
- };
- auto wrap = [&](const Tensor& tensor) {
- if (!tensor.defined()) {
- return tensor;
- }
- // if (c10::show_dispatch_trace_enabled()) {
- // std::cout << "wrap " << cur_level << std::endl;
- // }
- return makeTensorWrapper(tensor, cur_level);
- };
-
- // TODO: we only need to do the following (marked with !) on in-place functions
- // that modify sizes or strides. There aren't many of them.
- // If autograd dispatch key:
- // 1. (!) Put a copy of all of the args onto the stack
- // 2. Unwrap all the args in the copy set
- // 3. Call the operator
- // 4. Wrap the output
- // 5. (!) refreshMetadata for all the args in the original set
- // 6. (!) Pop those args off.
-
- // Step 1 & 2
- auto args_size = op.schema().arguments().size();
- // Step 1
- auto front = stack->size() - args_size;
- for (const auto arg_idx : c10::irange(0, args_size)) {
- stack->push_back((*stack)[front + arg_idx]);
- }
- // Step 2
- foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
-
- auto restore_guard = restoreLocalDispatchKeySetRAII(getDynamicLayerStack().back());
- WithoutTop guard;
-
- // See NOTE [grad and vjp interaction with no_grad]
- optional<c10::AutoGradMode> grad_guard;
- if (cur_key == TransformType::Grad && prev_grad_mode.has_value() && *prev_grad_mode == false) {
- grad_guard.emplace(*prev_grad_mode);
- }
- optional<c10::AutoFwGradMode> fw_grad_guard;
- if (cur_key == TransformType::Jvp &&
- prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
- fw_grad_guard.emplace(*prev_fwd_grad_mode);
- }
-
- // Re-dispatch
- if (dynamicLayerStackAccessor().size() == 0) {
- sanityCheckStack(op, stack);
- }
- op.callBoxed(stack);
-
- // Step 4, 5, 6
- auto ret_size = op.schema().returns().size();
- // Step 4
- foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);
-
- // Step 5
- auto args_front = stack->size() - args_size - ret_size;
- for (const auto arg_idx : c10::irange(0, args_size)) {
- auto& ivalue = (*stack)[args_front + arg_idx];
- if (!ivalue.isTensor()) {
- continue;
- }
- auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
- if (!maybe_tensor_wrapper) {
- continue;
- }
- maybe_tensor_wrapper->refreshMetadata();
- }
-
- // Step 6
- stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
- break;
- }
- case TransformType::Vmap:
- {
- auto restore_guard = restoreLocalDispatchKeySetRAII(getDynamicLayerStack().back());
- WithoutTop guard;
-
- // Re-dispatch
- if (dynamicLayerStackAccessor().size() == 0) {
- sanityCheckStack(op, stack);
- }
- op.callBoxed(stack);
-
- break;
- }
- case TransformType::Functionalize:
- {
- // For now, we don't support nested functionalization calls.
- // This check just enforces that - after the functionalize kernel runs
- // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors
- // so we can check that the unwrapped thing is not another (nested) FunctionalTensor.
- auto args_size = op.schema().arguments().size();
- sanityCheckNotFunctional(op, stack, args_size);
-
- auto restore_guard = restoreLocalDispatchKeySetRAII(getDynamicLayerStack().back());
- WithoutTop guard;
-
- // Re-dispatch
- if (dynamicLayerStackAccessor().size() == 0) {
- sanityCheckStack(op, stack);
- }
- op.callBoxed(stack);
-
- auto ret_size = op.schema().returns().size();
- sanityCheckNotFunctional(op, stack, ret_size);
- break;
- }
- default:
- TORCH_INTERNAL_ASSERT(false);
- }
+ layer.interpreter().sendToNextInterpreter(op, stack);
}
TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
diff --git a/functorch/functorch/csrc/DynamicLayer.h b/functorch/functorch/csrc/DynamicLayer.h
index 8239ff4..cf84311 100644
--- a/functorch/functorch/csrc/DynamicLayer.h
+++ b/functorch/functorch/csrc/DynamicLayer.h
@@ -9,9 +9,14 @@
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
+#include <c10/util/variant.h>
#include <unordered_map>
#include <mutex>
#include <c10/core/impl/LocalDispatchKeySet.h>
+#include <functorch/csrc/Interpreter.h>
+#include <functorch/csrc/VmapInterpreter.h>
+#include <functorch/csrc/ADInterpreters.h>
+#include <functorch/csrc/FunctionalizeInterpreter.h>
// Forward declared bc I am lazy
namespace c10 { struct AutogradMetaInterface; }
@@ -19,23 +24,9 @@
namespace at {
namespace functorch {
-enum RandomnessType {
- Error, // always errors when calling a random function
- Same, // randomness appears the same across batches
- Different, // randomness appears different across batches
- END
-};
-
-enum class TransformType {
- Torch, // Unused
- Vmap,
- Grad, // reverse-mode AD, aka vjp
- Jvp, // forward-mode AD
- Functionalize,
-};
-
-std::ostream& operator<<(std::ostream& os, const TransformType& t);
-
+// TODO: we can excise DynamicLayer in favor of Interpreter,
+// but I am going to leave it for now as a compatibility shim to avoid
+// needing to refactor a lot of call sites...
struct FUNCTORCH_API DynamicLayer {
explicit DynamicLayer(
TransformType transform_type,
@@ -49,36 +40,15 @@
TransformType key() const;
int64_t layerId() const;
+ const Interpreter& interpreter() const { return interpreter_; }
+ Interpreter& interpreter() { return interpreter_; }
+
// Only valid for vmap
int64_t batchSize() const;
RandomnessType randomness() const;
- // only valid for grad-based transforms
- optional<bool> prevGradMode() const;
-
- // only valid for jvp transform
- optional<bool> prevFwdGradMode() const;
-
- void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset);
- void clearSavedLocalDispatchKeySet();
- c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const;
-
- // only valid for functionalization
- optional<bool> functionalizeAddBackViews() const;
private:
- TransformType transform_type_;
- int64_t layerId_;
-
- // Honestly these should be a union or some extendable metadata class.
- // Not doing that for now because I don't think we'll use this mechanism for very long.
- optional<int64_t> batchSize_;
- optional<RandomnessType> randomness_;
- optional<bool> prevGradMode_;
- optional<bool> prevFwdGradMode_;
-
- optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
-
- optional<bool> functionalizeAddBackViews_;
+ Interpreter interpreter_;
};
FUNCTORCH_API int64_t initAndPushDynamicLayer(
@@ -109,15 +79,12 @@
// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
bool isInplaceOp(const c10::FunctionSchema& schema);
-// Applies the following for-loop:
-// for i in range(begin, end):
-// args[i] = func(args[i])
-void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
- std::function<Tensor(const Tensor&)> func);
+Tensor unwrapIfDead(const Tensor& tensor);
// Pretty printers
std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
+
}
} // namespace at
diff --git a/functorch/functorch/csrc/FunctionalizeInterpreter.cpp b/functorch/functorch/csrc/FunctionalizeInterpreter.cpp
new file mode 100644
index 0000000..4242305
--- /dev/null
+++ b/functorch/functorch/csrc/FunctionalizeInterpreter.cpp
@@ -0,0 +1,68 @@
+#include <functorch/csrc/FunctionalizeInterpreter.h>
+#include <functorch/csrc/DynamicLayer.h>
+#include <ATen/FunctionalTensorWrapper.h>
+
+namespace at { namespace functorch {
+
+static void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) {
+ foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
+ [](const Tensor& tensor) {
+ TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
+ return tensor;
+ });
+}
+
+void FunctionalizeInterpreterPtr::processImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(TransformType::Functionalize);
+
+ // We always want to call the functionalization kernels if functionalize() is on the layer stack.
+ // It's the responsibility of the functionalization kernel to no-op and redispatch
+ // if none of the input tensors are functional.
+ setup_dispatch_key_tls(exclude, DispatchKeySet(DispatchKey::Functionalize));
+ auto functionalization_add_back_views = functionalizeAddBackViews();
+ // We have some side-car TLS that we can set to toggle the functionalization behavior.
+ // If set, then functionalization will only remove mutations, instead of
+ // removing both mutations AND view operators.
+ at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views);
+
+ op.callBoxed(stack);
+
+ auto ret_size = op.schema().returns().size();
+ foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(),
+ [&](const Tensor& tensor) {
+ if (at::functionalization::impl::isFunctionalTensor(tensor)) {
+ auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
+ // Functorch is responsible for setting the level on the wrapper, since we don't
+ // have that info available in core (for now).
+ // We could just "propagate" the level from the input tensors inside of the functionalize kernels,
+ // but unfortunately we can't do that for factory operators.
+ wrapper->set_level(level());
+ }
+ return tensor;
+ }
+ );
+}
+
+void FunctionalizeInterpreterPtr::sendToNextInterpreterImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ // For now, we don't support nested functionalization calls.
+ // This check just enforces that - after the functionalize kernel runs
+ // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors
+ // so we can check that the unwrapped thing is not another (nested) FunctionalTensor.
+ auto args_size = op.schema().arguments().size();
+ sanityCheckNotFunctional(op, stack, args_size);
+
+ // Re-dispatch
+ if (getDynamicLayerStack().size() == 0) {
+ sanityCheckStack(op, stack);
+ }
+ op.callBoxed(stack);
+
+ auto ret_size = op.schema().returns().size();
+ sanityCheckNotFunctional(op, stack, ret_size);
+}
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/FunctionalizeInterpreter.h b/functorch/functorch/csrc/FunctionalizeInterpreter.h
new file mode 100644
index 0000000..5475b38
--- /dev/null
+++ b/functorch/functorch/csrc/FunctionalizeInterpreter.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <functorch/csrc/Interpreter.h>
+
+namespace at { namespace functorch {
+
+struct FunctionalizeInterpreterPtr {
+ explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
+ TransformType key() const { return base_->key(); }
+ int64_t level() const { return base_->level(); }
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ bool functionalizeAddBackViews() const {
+ return c10::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
+ }
+ private:
+ const Interpreter* base_;
+};
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/Interpreter.cpp b/functorch/functorch/csrc/Interpreter.cpp
new file mode 100644
index 0000000..35c1150
--- /dev/null
+++ b/functorch/functorch/csrc/Interpreter.cpp
@@ -0,0 +1,108 @@
+#include <functorch/csrc/Interpreter.h>
+#include <functorch/csrc/BatchedTensorImpl.h>
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/VmapInterpreter.h>
+#include <functorch/csrc/FunctionalizeInterpreter.h>
+#include <functorch/csrc/ADInterpreters.h>
+
+namespace at { namespace functorch {
+
+constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
+ kDynamicLayerFrontModeKey,
+ kDynamicLayerBackModeKey,
+ kGradWrapperKey,
+ DispatchKey::Functionalize,
+ // DispatchKey::Batched,
+ kBatchedKey,
+ DispatchKey::PythonTLSSnapshot,
+ DispatchKey::ADInplaceOrView
+}) | autograd_dispatch_keyset;
+
+static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
+ if (key == TransformType::Vmap) {
+ // NB: Does not include kVmapModeKey. We may modulate the key when
+ // constructing the DynamicLayer, but we don't control it when entering/exiting
+ // the DynamicLayer.
+ return DispatchKeySet({kBatchedKey});
+ } else if (key == TransformType::Grad || key == TransformType::Jvp) {
+ return autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);
+ } else if (key == TransformType::Functionalize) {
+ return DispatchKeySet(DispatchKey::Functionalize);
+ } else {
+ TORCH_INTERNAL_ASSERT(false, "Unsupported key: ", key);
+ }
+}
+
+DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key) {
+ DispatchKeySet exclude = all_dynlayer_keyset;
+ exclude = exclude.remove(kDynamicLayerBackModeKey);
+ exclude = exclude - keysForEnteringDynamicLayer(key);
+ return exclude;
+}
+
+void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include) {
+ auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+ local_keyset.excluded_ = local_keyset.excluded_ | exclude;
+ local_keyset.included_ = local_keyset.included_ | include;
+ c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+}
+
+std::ostream& operator<<(std::ostream& os, const TransformType& t) {
+ switch (t) {
+ case TransformType::Torch:
+ os << "Torch";
+ break;
+ case TransformType::Vmap:
+ os << "Vmap";
+ break;
+ case TransformType::Grad:
+ os << "Grad";
+ break;
+ case TransformType::Jvp:
+ os << "Jvp";
+ break;
+ case TransformType::Functionalize:
+ os << "Functionalize";
+ break;
+ default:
+ TORCH_INTERNAL_ASSERT(false);
+ }
+ return os;
+}
+
+void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+ auto num_args = op.schema().arguments().size();
+ foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
+ [](const Tensor& tensor) {
+
+ auto* wrapper = maybeGetTensorWrapper(tensor);
+ TORCH_INTERNAL_ASSERT(wrapper == nullptr);
+ auto* batched = maybeGetBatchedImpl(tensor);
+ TORCH_INTERNAL_ASSERT(batched == nullptr);
+ return tensor;
+ });
+}
+
+#define INTERPRETER_DISPATCH(type, method) \
+ switch (key()) { \
+ case TransformType::Vmap: \
+ return VmapInterpreterPtr(this). method; \
+ case TransformType::Grad: \
+ return GradInterpreterPtr(this). method; \
+ case TransformType::Jvp: \
+ return JvpInterpreterPtr(this). method; \
+ case TransformType::Functionalize: \
+ return FunctionalizeInterpreterPtr(this). method; \
+ default: \
+ TORCH_INTERNAL_ASSERT(false, "Unrecognized transform"); \
+ }
+
+void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+ INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)));
+}
+
+void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+ INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack)));
+}
+
+}}
diff --git a/functorch/functorch/csrc/Interpreter.h b/functorch/functorch/csrc/Interpreter.h
new file mode 100644
index 0000000..2a1a426
--- /dev/null
+++ b/functorch/functorch/csrc/Interpreter.h
@@ -0,0 +1,187 @@
+#pragma once
+
+// variant.h doesn't clean up after itself...
+#include <c10/util/variant.h>
+#undef DECLTYPE_AUTO
+
+#include <functorch/csrc/Macros.h>
+#include <functorch/csrc/Constants.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/core/impl/LocalDispatchKeySet.h>
+#include <c10/util/Optional.h>
+
+namespace at { namespace functorch {
+
+// NOTE: [functorch interpreter stack]
+//
+// functorch's dispatching system uses a stack of interpreters.
+// Historically we've referred to this as the "DynamicLayerStack".
+//
+// An interpreter is something that reads in the code it is passed
+// and then executes it. We have a different interpreter per-transform:
+// the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
+// and executing the batched version of it (the batching rule for aten::mv).
+//
+// Concretely, each interpreter is responsible for two things:
+//
+// 1) process(ophandle, stack)
+// Given an operator handle and a stack of arguments, the interpreter is
+// responsible for figuring out how to execute the operation under the semantics
+// of the interpreter. For the VmapInterpreter, for example, this is figuring out how to call
+// the batching rule.
+//
+// The batching rules are stored as kernels on the FuncTorchBatched key, so the way
+// VmapInterpreter calls the batching rule is roughly: (A) exclude all
+// dispatch keys aside from the Batched key, (B) redispatch so we get to the
+// Batched key.
+//
+// 2) sendToNextInterpreter(ophandle, stack)
+// The VmapInterpreter, when it sees aten::mv, will process it into a call to
+// aten::mm. It then needs to send the call to aten::mm to the next interpreter
+// in the interpreter stack.
+//
+// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
+// and most Interpreters will implement it this way.
+
+enum RandomnessType {
+ Error, // always errors when calling a random function
+ Same, // randomness appears the same across batches
+ Different, // randomness appears different across batches
+ END
+};
+
+enum class TransformType {
+ Torch, // Unused
+ Vmap,
+ Grad, // reverse-mode AD, aka vjp
+ Jvp, // forward-mode AD
+ Functionalize,
+};
+
+std::ostream& operator<<(std::ostream& os, const TransformType& t);
+
+// NOTE: [Interpreter "subclassing" design]
+//
+// How are various Interpreters for different transforms (vmap, grad, ...)
+// implemented?
+//
+// Accessing interpreters is in the hot-path of functorch so we have a constraint
+// that this code must be as fast as possible.
+//
+// As a result, we stay away from virtual methods and this causes our code
+// to look a little funny.
+//
+// `Interpreter` is the struct for Interpreters. It holds ALL of the
+// relevant information (what type of interpreter it is and the metadata).
+// Metadata for each interpreter is represented as a Union (c10::variant)
+// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
+//
+// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
+// if you want to access the metadata fields (like batchSize and randomness).
+//
+// Each type of interpreter (e.g. Vmap) has a convenience struct
+// (e.g. VmapInterpreterPtr) associated with it.
+//
+// Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
+// and then one can access methods on VmapInterpreterPtr like so:
+// >>> VmapInterpreterPtr(&interpreter).batchSize()
+//
+// Finally, Interpreter::process switches on the type of the interpreter
+// and calls one of {Transform}Interpreter::processImpl under the hood.
+// Same for Interpreter::sendToNextInterpreter :)
+
+struct VmapInterpreterMeta {
+ explicit VmapInterpreterMeta(int64_t batchSize, RandomnessType randomness) :
+ batchSize_(batchSize), randomness_(randomness) {}
+ int64_t batchSize_;
+ RandomnessType randomness_;
+};
+
+struct GradInterpreterMeta {
+ explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
+ bool prevGradMode_;
+};
+
+struct JvpInterpreterMeta {
+ explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
+ bool prevFwdGradMode_;
+};
+
+struct FunctionalizeInterpreterMeta {
+ explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
+ functionalizeAddBackViews_(functionalizeAddBackViews) {}
+ bool functionalizeAddBackViews_;
+};
+
+typedef c10::variant<
+ int64_t,
+ GradInterpreterMeta,
+ JvpInterpreterMeta,
+ VmapInterpreterMeta,
+ FunctionalizeInterpreterMeta
+> InterpreterMeta;
+
+
+struct Interpreter {
+ // factory functions
+ static Interpreter Vmap(int64_t level, int64_t batchSize, RandomnessType randomness) {
+ return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(batchSize, randomness));
+ }
+ static Interpreter Grad(int64_t level, bool prevGradMode) {
+ return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
+ }
+ static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
+ return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
+ }
+ static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
+ return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
+ }
+
+ // methods
+ TransformType key() const { return type_; }
+ int64_t level() const { return level_; }
+ const InterpreterMeta& meta() const { return meta_; }
+
+ void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+ void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
+ TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
+ savedLocalDispatchKeySet_ = std::move(keyset);
+ }
+ void clearSavedLocalDispatchKeySet() {
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
+ savedLocalDispatchKeySet_ = c10::nullopt;
+ }
+ c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
+ return *savedLocalDispatchKeySet_;
+ }
+
+ // Please don't use this
+ explicit Interpreter() = default;
+
+ private:
+ explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
+ type_(type), level_(level), meta_(meta) {}
+
+ // fields
+ TransformType type_;
+ int64_t level_;
+ optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
+ InterpreterMeta meta_;
+};
+
+// Applies the following for-loop:
+// for i in range(begin, end):
+// args[i] = func(args[i])
+void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
+ std::function<Tensor(const Tensor&)> func);
+
+DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
+
+void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include);
+
+void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/VmapInterpreter.cpp b/functorch/functorch/csrc/VmapInterpreter.cpp
new file mode 100644
index 0000000..a8f0283
--- /dev/null
+++ b/functorch/functorch/csrc/VmapInterpreter.cpp
@@ -0,0 +1,24 @@
+#include <functorch/csrc/VmapInterpreter.h>
+#include <functorch/csrc/DynamicLayer.h>
+
+namespace at { namespace functorch {
+
+void VmapInterpreterPtr::processImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(TransformType::Vmap);
+ setup_dispatch_key_tls(exclude, DispatchKeySet(kVmapModeKey));
+ op.callBoxed(stack);
+}
+
+void VmapInterpreterPtr::sendToNextInterpreterImpl(
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack) {
+ // Re-dispatch
+ if (getDynamicLayerStack().size() == 0) {
+ sanityCheckStack(op, stack);
+ }
+ op.callBoxed(stack);
+}
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/VmapInterpreter.h b/functorch/functorch/csrc/VmapInterpreter.h
new file mode 100644
index 0000000..084cea9
--- /dev/null
+++ b/functorch/functorch/csrc/VmapInterpreter.h
@@ -0,0 +1,22 @@
+#pragma once
+#include <functorch/csrc/Interpreter.h>
+
+namespace at { namespace functorch {
+
+struct VmapInterpreterPtr {
+ explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); }
+ TransformType key() const { return base_->key(); }
+ int64_t level() const { return base_->level(); }
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+ int64_t batchSize() const {
+ return c10::get<VmapInterpreterMeta>(base_->meta()).batchSize_;
+ }
+ RandomnessType randomness() const {
+ return c10::get<VmapInterpreterMeta>(base_->meta()).randomness_;
+ }
+ private:
+ const Interpreter* base_;
+};
+
+}} // namespace at::functorch
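
For reference, per-transform metadata that used to live as optional fields on DynamicLayer (batchSize_, prevGradMode_, ...) is now read back through the wrapper structs. A minimal consumer sketch, using only names introduced in the diff above:

    const DynamicLayer& layer = dynamicLayerStackAccessor().back();
    const Interpreter& interp = layer.interpreter();
    switch (interp.key()) {
      case TransformType::Vmap: {
        // Vmap-specific metadata comes out of VmapInterpreterMeta.
        int64_t batch_size = VmapInterpreterPtr(&interp).batchSize();
        RandomnessType randomness = VmapInterpreterPtr(&interp).randomness();
        break;
      }
      case TransformType::Grad: {
        // Grad-specific metadata comes out of GradInterpreterMeta.
        bool prev_grad_mode = GradInterpreterPtr(&interp).prevGradMode();
        break;
      }
      default:
        break;
    }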