[functorch] Refactor DynamicLayer so the logic for each transform is separate (pytorch/functorch#756)

See functorch/csrc/Interpreter.h for the rundown of what happened.
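
At a high level: the per-transform logic that used to live in large switch statements inside
dynamicLayerFrontFallback and dynamicLayerBackFallback now lives behind an Interpreter object
held by each DynamicLayer, and the fallbacks simply delegate. Condensed from the diff below
(a sketch of the control flow, not the literal code):

    // Front fallback: the top layer's interpreter decides how to run the op.
    void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
      auto& layer = dynamicLayerStackAccessor().back();
      layer.interpreter().process(op, stack);
      // Interpreter::process switches on key() and calls, e.g.,
      // VmapInterpreterPtr::processImpl or GradInterpreterPtr::processImpl.
    }

The back fallback is symmetric: it restores the saved dispatch-key TLS, temporarily pops the
top layer, and calls interpreter().sendToNextInterpreter(op, stack).
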
diff --git a/functorch/functorch/csrc/ADInterpreters.cpp b/functorch/functorch/csrc/ADInterpreters.cpp
new file mode 100644
index 0000000..310df42
--- /dev/null
+++ b/functorch/functorch/csrc/ADInterpreters.cpp
@@ -0,0 +1,192 @@
+#include <functorch/csrc/ADInterpreters.h>
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/TensorWrapper.h>
+
+namespace at { namespace functorch {
+
+static void checkForInvalidMutationOnCaptures(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    int64_t cur_level) {
+  if (!isInplaceOp(op.schema())) {
+    return;
+  }
+  auto args = torch::jit::last(stack, op.schema().arguments().size());
+  auto mutated_arg = unwrapIfDead(args[0].toTensor());
+  auto* wrapper = maybeGetTensorWrapper(mutated_arg);
+  if (wrapper && wrapper->level().has_value() && wrapper->level().value() == cur_level) {
+    return;
+  }
+  TORCH_CHECK(false,
+      "During a grad (vjp, jvp, grad, etc) transform, the function provided ",
+      "attempted to call in-place operation (", op.schema().operator_name(), ") ",
+      "that would mutate a captured Tensor. This is not supported; please rewrite ",
+      "the function being transformed to explicitly accept the mutated Tensor(s) ",
+      "as inputs.");
+}
+
+static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) {
+  if (!tensor.defined()) {
+    return tensor;
+  }
+  auto* wrapper = maybeGetTensorWrapper(tensor);
+  if (!wrapper) {
+    return makeTensorWrapper(tensor, current_level);
+  }
+  TORCH_INTERNAL_ASSERT(wrapper->level().value() <= current_level, "escaped?");
+  if (wrapper->level().value() == current_level) {
+    TORCH_INTERNAL_ASSERT(tensor.defined());
+    return tensor;
+  }
+  return makeTensorWrapper(tensor, current_level);
+}
+
+static void autogradBasedTransformProcess(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    int64_t current_level,
+    TransformType transform_type) {
+  // If this is a grad transform, the operation is in-place, and the mutated
+  // argument is not currently wrapped in a TensorWrapper, then we need to
+  // error out; otherwise the result would be silently incorrect.
+  checkForInvalidMutationOnCaptures(op, stack, current_level);
+
+  // materialize live GradWrappers
+  auto maybeTransformGradWrappers = [&](const Tensor& tensor) {
+    return materializeGradWrappers(tensor, current_level);
+  };
+  auto num_args = op.schema().arguments().size();
+  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers);
+
+  auto exclude = keysToExcludeWhenEnteringDynamicLayer(transform_type);
+  setup_dispatch_key_tls(exclude, {});
+  op.callBoxed(stack);
+}
+
+static void autogradBasedTransformSendToNext(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    int64_t current_level,
+    TransformType transform_type,
+    optional<bool> prev_grad_mode,
+    optional<bool> prev_fwd_grad_mode) {
+  if (transform_type == TransformType::Grad) {
+    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
+  }
+  if (transform_type == TransformType::Jvp) {
+    TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
+  }
+  auto unwrap = [&](const Tensor& tensor) {
+    if (!tensor.defined()) {
+      return tensor;
+    }
+    auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
+    if (!maybe_tensor_wrapper) {
+      return tensor;
+    }
+    auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
+    TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= current_level);
+    if (tensor_wrapper_level == current_level) {
+      return maybe_tensor_wrapper->value();
+    }
+    return tensor;
+  };
+  auto wrap = [&](const Tensor& tensor) {
+    if (!tensor.defined()) {
+      return tensor;
+    }
+    // if (c10::show_dispatch_trace_enabled()) {
+    //   std::cout << "wrap " << current_level << std::endl;
+    // }
+    return makeTensorWrapper(tensor, current_level);
+  };
+
+  // TODO: we only need to do the following (marked with !) on in-place functions
+  // that modify sizes or strides. There aren't many of them.
+  // If autograd dispatch key:
+  // 1. (!) Put a copy of all of the args onto the stack
+  // 2. Unwrap all the args in the copy set
+  // 3. Call the operator
+  // 4. Wrap the output
+  // 5. (!) refreshMetadata for all the args in the original set
+  // 6. (!) Pop those args off.
+
+  // Step 1 & 2
+  auto args_size = op.schema().arguments().size();
+  // Step 1
+  auto front = stack->size() - args_size;
+  for (const auto arg_idx : c10::irange(0, args_size)) {
+    stack->push_back((*stack)[front + arg_idx]);
+  }
+  // Step 2
+  foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
+
+  // See NOTE [grad and vjp interaction with no_grad]
+  optional<c10::AutoGradMode> grad_guard;
+  if (transform_type == TransformType::Grad && prev_grad_mode.has_value() && *prev_grad_mode == false) {
+    grad_guard.emplace(*prev_grad_mode);
+  }
+  optional<c10::AutoFwGradMode> fw_grad_guard;
+  if (transform_type == TransformType::Jvp &&
+      prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
+    fw_grad_guard.emplace(*prev_fwd_grad_mode);
+  }
+
+  // Re-dispatch
+  if (getDynamicLayerStack().size() == 0) {
+    sanityCheckStack(op, stack);
+  }
+  op.callBoxed(stack);
+
+  // Step 4, 5, 6
+  auto ret_size = op.schema().returns().size();
+  // Step 4
+  foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);
+
+  // Step 5
+  auto args_front = stack->size() - args_size - ret_size;
+  for (const auto arg_idx : c10::irange(0, args_size)) {
+    auto& ivalue = (*stack)[args_front + arg_idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
+    if (!maybe_tensor_wrapper) {
+      continue;
+    }
+    maybe_tensor_wrapper->refreshMetadata();
+  }
+
+  // Step 6
+  stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
+}
+
+void GradInterpreterPtr::processImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  autogradBasedTransformProcess(op, stack, level(), TransformType::Grad);
+}
+
+void GradInterpreterPtr::sendToNextInterpreterImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  autogradBasedTransformSendToNext(
+      op, stack, level(),
+      TransformType::Grad, prevGradMode(), nullopt);
+}
+
+void JvpInterpreterPtr::processImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  autogradBasedTransformProcess(op, stack, level(), TransformType::Jvp);
+}
+
+void JvpInterpreterPtr::sendToNextInterpreterImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  autogradBasedTransformSendToNext(
+      op, stack, level(),
+      TransformType::Jvp, nullopt, prevFwdGradMode());
+}
+
+}} // namespace at::functorch
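
The stack bookkeeping in autogradBasedTransformSendToNext is easier to follow with a concrete
trace. For a hypothetical op(a, b) -> r (names illustrative, not from the patch) with
args_size = 2 and ret_size = 1, the dispatcher stack evolves like this:

    // Start:                        [..., a, b]
    // Step 1 (push copies of args): [..., a, b, a, b]
    // Step 2 (unwrap the copies):   [..., a, b, unwrap(a), unwrap(b)]
    // Step 3 (op.callBoxed consumes the copies and pushes the result):
    //                               [..., a, b, r]
    // Step 4 (wrap the returns):    [..., a, b, wrap(r)]
    // Step 5 refreshes metadata on the original a and b, since an in-place op
    //        may have changed their sizes or strides.
    // Step 6 erases the originals:  [..., wrap(r)]
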
diff --git a/functorch/functorch/csrc/ADInterpreters.h b/functorch/functorch/csrc/ADInterpreters.h
new file mode 100644
index 0000000..6f79afc
--- /dev/null
+++ b/functorch/functorch/csrc/ADInterpreters.h
@@ -0,0 +1,32 @@
+#pragma once
+#include <functorch/csrc/Interpreter.h>
+
+namespace at { namespace functorch {
+
+struct GradInterpreterPtr {
+  explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  bool prevGradMode() const {
+    return c10::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
+  }
+ private:
+  const Interpreter* base_;
+};
+
+struct JvpInterpreterPtr {
+  explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  bool prevFwdGradMode() const {
+    return c10::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
+  }
+ private:
+  const Interpreter* base_;
+};
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp
index cab56e6..6d96d52 100644
--- a/functorch/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/functorch/csrc/DynamicLayer.cpp
@@ -19,40 +19,6 @@
 namespace at {
 namespace functorch {
 
-std::ostream& operator<<(std::ostream& os, const TransformType& t) {
-  switch (t) {
-    case TransformType::Torch:
-      os << "Torch";
-      break;
-    case TransformType::Vmap:
-      os << "Vmap";
-      break;
-    case TransformType::Grad:
-      os << "Grad";
-      break;
-    case TransformType::Jvp:
-      os << "Jvp";
-      break;
-    case TransformType::Functionalize:
-      os << "Functionalize";
-      break;
-    default:
-      TORCH_INTERNAL_ASSERT(false);
-  }
-  return os;
-}
-
-constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
-  kDynamicLayerFrontModeKey,
-  kDynamicLayerBackModeKey,
-  kGradWrapperKey,
-  DispatchKey::Functionalize,
-  // DispatchKey::Batched,
-  kBatchedKey,
-  DispatchKey::PythonTLSSnapshot,
-  DispatchKey::ADInplaceOrView
-}) | autograd_dispatch_keyset;
-
 void setDynamicLayerFrontBackKeysIncluded(bool included) {
   c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, included);
   c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, included);
@@ -66,14 +32,6 @@
     optional<bool> prev_grad_mode,
     optional<bool> prev_fwd_grad_mode,
     optional<bool> functionalize_add_back_views)
-  :
-    transform_type_(transform_type),
-    layerId_(layerId),
-    batchSize_(batchSize),
-    randomness_(randomness),
-    prevGradMode_(prev_grad_mode),
-    prevFwdGradMode_(prev_fwd_grad_mode),
-    functionalizeAddBackViews_(functionalize_add_back_views)
 {
   if (transform_type == TransformType::Grad) {
     TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
@@ -81,55 +39,42 @@
   if (transform_type == TransformType::Jvp) {
     TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
   }
+  switch (transform_type) {
+    case TransformType::Vmap:
+      interpreter_ = Interpreter::Vmap(layerId, batchSize.value(), randomness.value());
+      break;
+    case TransformType::Grad:
+      interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value());
+      break;
+    case TransformType::Jvp:
+      interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value());
+      break;
+    case TransformType::Functionalize:
+      interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value());
+      break;
+    default:
+      TORCH_INTERNAL_ASSERT(false);
+  }
 }
 
 TransformType DynamicLayer::key() const {
-  return transform_type_;
+  return interpreter_.key();
 }
 
 int64_t DynamicLayer::layerId() const {
-  return layerId_;
+  return interpreter_.level();
 }
 
 int64_t DynamicLayer::batchSize() const {
-  TORCH_INTERNAL_ASSERT(batchSize_);
-  return *batchSize_;
+  return VmapInterpreterPtr(&interpreter_).batchSize();
 }
 
 RandomnessType DynamicLayer::randomness() const {
-  TORCH_INTERNAL_ASSERT(randomness_);
-  return *randomness_;
-}
-
-optional<bool> DynamicLayer::prevGradMode() const {
-  return prevGradMode_;
-}
-
-optional<bool> DynamicLayer::prevFwdGradMode() const {
-  return prevFwdGradMode_;
-}
-
-void DynamicLayer::saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
-  TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
-  savedLocalDispatchKeySet_ = std::move(keyset);
-}
-
-void DynamicLayer::clearSavedLocalDispatchKeySet() {
-  TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
-  savedLocalDispatchKeySet_ = c10::nullopt;
-}
-
-c10::impl::LocalDispatchKeySet DynamicLayer::getSavedLocalDispatchKeySet() const {
-  TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
-  return *savedLocalDispatchKeySet_;
+  return VmapInterpreterPtr(&interpreter_).randomness();
 }
 
 constexpr DispatchKeySet kFrontBackKeys({kDynamicLayerBackModeKey, kDynamicLayerFrontModeKey});
 
-optional<bool> DynamicLayer::functionalizeAddBackViews() const {
-  return functionalizeAddBackViews_;
-}
-
 using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
 DynmetaData kDynMetaDataSingleton;
 
@@ -200,14 +145,14 @@
     TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
     auto& layer = dynamicLayerStack.back();
     auto tmp = c10::impl::tls_local_dispatch_key_set();
-    layer.saveLocalDispatchKeySet(tmp);
+    layer.interpreter().saveLocalDispatchKeySet(tmp);
   }
   ~SaveLocalDispatchKeySet() {
     auto& dynamicLayerStack = dynamicLayerStackAccessor();
     TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
     auto& layer = dynamicLayerStack.back();
-    auto tmp = layer.getSavedLocalDispatchKeySet();
-    layer.clearSavedLocalDispatchKeySet();
+    auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
+    layer.interpreter().clearSavedLocalDispatchKeySet();
     c10::impl::_force_tls_local_dispatch_key_set(tmp);
   }
   SaveLocalDispatchKeySet(const SaveLocalDispatchKeySet&) = delete;
@@ -216,7 +161,7 @@
 
 static c10::impl::ForceDispatchKeyGuard
 restoreLocalDispatchKeySetRAII(const DynamicLayer& layer) {
-  auto tmp = layer.getSavedLocalDispatchKeySet();
+  auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
   return c10::impl::ForceDispatchKeyGuard(tmp);
 }
 
@@ -316,27 +261,7 @@
   return result;
 }
 
-static Tensor materializeGradWrappers(const Tensor& tensor, const std::vector<DynamicLayer>& dynlayerStack) {
-  if (!tensor.defined()) {
-    return tensor;
-  }
-  if (dynlayerStack.back().key() != TransformType::Grad && dynlayerStack.back().key() != TransformType::Jvp) {
-    return tensor;
-  }
-  auto cur_level = dynlayerStack.back().layerId();
-  auto* wrapper = maybeGetTensorWrapper(tensor);
-  if (!wrapper) {
-    return makeTensorWrapper(tensor, cur_level);
-  }
-  TORCH_INTERNAL_ASSERT(wrapper->level().value() <= cur_level, "escaped?");
-  if (wrapper->level().value() == cur_level) {
-    TORCH_INTERNAL_ASSERT(tensor.defined());
-    return tensor;
-  }
-  return makeTensorWrapper(tensor, cur_level);
-}
-
-static Tensor unwrapIfDead(const Tensor& tensor) {
+Tensor unwrapIfDead(const Tensor& tensor) {
   auto* wrapped = maybeGetTensorWrapper(tensor);
   if (!wrapped) {
     return tensor;
@@ -405,14 +330,6 @@
   return os;
 }
 
-void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) {
-  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
-      [](const Tensor& tensor) {
-        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
-        return tensor;
-      });
-}
-
 static bool allTensors(
     ArrayRef<IValue> args,
     std::function<bool(const Tensor&)> pred) {
@@ -445,26 +362,6 @@
   return true;
 }
 
-static bool anyTensors(
-    ArrayRef<IValue> args,
-    std::function<bool(const Tensor&)> pred) {
-  // Demorgan's law
-  return !allTensors(args, [&](const Tensor& self) { return !pred(self); });
-}
-
-static void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
-  auto num_args = op.schema().arguments().size();
-  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
-      [](const Tensor& tensor) {
-
-        auto* wrapper = maybeGetTensorWrapper(tensor);
-        TORCH_INTERNAL_ASSERT(wrapper == nullptr);
-        auto* batched = maybeGetBatchedImpl(tensor);
-        TORCH_INTERNAL_ASSERT(batched == nullptr);
-        return tensor;
-      });
-}
-
 bool isInplaceOp(const FunctionSchema& schema) {
   if (!schema.is_mutable() || schema.returns().size() != 1) {
     return false;
@@ -486,45 +383,6 @@
   return return_alias_info && return_alias_info->isWrite();
 }
 
-static void checkForInvalidMutationOnCaptures(
-    const c10::OperatorHandle& op,
-    torch::jit::Stack* stack,
-    const std::vector<DynamicLayer>& dynamicLayerStack) {
-  if (dynamicLayerStack.back().key() != TransformType::Grad && dynamicLayerStack.back().key() != TransformType::Jvp) {
-    return;
-  }
-  if (!isInplaceOp(op.schema())) {
-    return;
-  }
-  auto args = torch::jit::last(stack, op.schema().arguments().size());
-  auto mutated_arg = unwrapIfDead(args[0].toTensor());
-  auto cur_level = dynamicLayerStack.back().layerId();
-  auto* wrapper = maybeGetTensorWrapper(mutated_arg);
-  if (wrapper && wrapper->level().has_value() && wrapper->level().value() == cur_level) {
-    return;
-  }
-  TORCH_CHECK(false,
-      "During a grad (vjp, jvp, grad, etc) transform, the function provided ",
-      "attempted to call in-place operation (", op.schema().operator_name(), ") ",
-      "that would mutate a captured Tensor. This is not supported; please rewrite ",
-      "the function being transformed to explicitly accept the mutated Tensor(s) ",
-      "as inputs.");
-}
-
-static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
-  if (key == TransformType::Vmap) {
-    // NB: Does not include kVmapModeKey. We may modulate the key when
-    // constructing the DynamicLayer, but we don't control it when entering/exiting
-    // the DynamicLayer.
-    return DispatchKeySet({kBatchedKey});
-  } else if (key == TransformType::Grad || key == TransformType::Jvp) {
-    return autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);
-  } else if (key == TransformType::Functionalize) {
-    return DispatchKeySet(DispatchKey::Functionalize);
-  } else {
-    TORCH_INTERNAL_ASSERT(false, "Unsupported key: ", key);
-  }
-}
 
 #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
 static void dump_local_tls() {
@@ -534,30 +392,15 @@
 }
 #endif
 
-static DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key) {
-  DispatchKeySet exclude = all_dynlayer_keyset;
-  exclude = exclude.remove(kDynamicLayerBackModeKey);
-  exclude = exclude - keysForEnteringDynamicLayer(key);
-  return exclude;
-}
-static bool isFunctionalTensorAtCurrentLevel(const Tensor& tensor) {
-  auto& dynamicLayerStack = dynamicLayerStackAccessor();
-  auto layer = dynamicLayerStack.back();
-  auto level = layer.layerId();
+struct WithoutTop {
+  WithoutTop();
+  ~WithoutTop();
+  DynamicLayer layer_;
+};
 
-  if (!at::functionalization::impl::isFunctionalTensor(tensor)) {
-    return false;
-  }
-  const auto* functional = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
-  auto functional_level = functional->level();
-  return functional_level == level;
-}
-
-static void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include) {
-  auto local_keyset = c10::impl::tls_local_dispatch_key_set();
-  local_keyset.excluded_ = local_keyset.excluded_ | exclude;
-  local_keyset.included_ = local_keyset.included_ | include;
-  c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+WithoutTop::WithoutTop(): layer_(popDynamicLayer()) {}
+WithoutTop::~WithoutTop() {
+  pushDynamicLayer(std::move(layer_));
 }
 
 void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
@@ -581,220 +424,20 @@
   foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), unwrapIfDead);
 
   auto& layer = dynamicLayerStack.back();
-  DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(layer.key());
-
-  switch (layer.key()) {
-    case TransformType::Grad:
-    case TransformType::Jvp:
-      {
-        // if is a grad transform, and the operation is in-place, and the mutated
-        // argument is not currently wrapped in a TensorWrapper, then we need to
-        // error out otherwise the result is silently incorrect
-        checkForInvalidMutationOnCaptures(op, stack, dynamicLayerStack);
-
-        // materialize live GradWrappers
-        auto maybeTransformGradWrappers = [](const Tensor& tensor) {
-          return materializeGradWrappers(tensor, getDynamicLayerStack());
-        };
-        foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers);
-
-        setup_dispatch_key_tls(exclude, {});
-        op.callBoxed(stack);
-        break;
-      }
-    case TransformType::Vmap:
-      {
-        setup_dispatch_key_tls(exclude, DispatchKeySet(kVmapModeKey));
-        op.callBoxed(stack);
-        break;
-      }
-    case TransformType::Functionalize:
-      {
-        // We always want to call the functionalization kernels if functionalize() is on the layer stack.
-        // It's the responsibility of the functionalization kernel to no-op and redispatch
-        // if none of the input tensors are functional.
-        setup_dispatch_key_tls(exclude, DispatchKeySet(DispatchKey::Functionalize));
-        auto functionalization_add_back_views = layer.functionalizeAddBackViews().has_value() && *(layer.functionalizeAddBackViews());
-        // We have some side-car TLS that we can set to toggle the functionaliation behavior.
-        // If set, then we functionalization will only remove mutations, instead of
-        // removing both mutations AND view operators.
-        at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views);
-
-        op.callBoxed(stack);
-
-        auto ret_size = op.schema().returns().size();
-        foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(),
-          [&](const Tensor& tensor) {
-            if (at::functionalization::impl::isFunctionalTensor(tensor)) {
-              auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
-              // Functorch is responsible for setting the level on the wrapper, since we don't
-              // have that info available in core (for now).
-              // We could just "propagate" the level from the input tensors inside of the functionalize kernels,
-              // but unfortunately we can't do that for factory operators.
-              wrapper->set_level(layer.layerId());
-            }
-            return tensor;
-          }
-        );
-        break;
-      }
-    default:
-      TORCH_INTERNAL_ASSERT(false);
-  }
+  layer.interpreter().process(op, stack);
 }
 
-struct WithoutTop {
-  WithoutTop(): layer_(popDynamicLayer()) {
-  }
-  ~WithoutTop() {
-    pushDynamicLayer(std::move(layer_));
-  }
-
-  DynamicLayer layer_;
-};
+static c10::impl::ForceDispatchKeyGuard
+restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
+  return c10::impl::ForceDispatchKeyGuard(key_set);
+}
 
 void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
-  auto cur_level = getDynamicLayerStack().back().layerId();
-  auto cur_key = getDynamicLayerStack().back().key();
+  auto& layer = dynamicLayerStackAccessor().back();
+  auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
+  WithoutTop guard;
 
-  switch (cur_key) {
-    case TransformType::Grad:
-    case TransformType::Jvp:
-      {
-        optional<bool> prev_grad_mode = getDynamicLayerStack().back().prevGradMode();
-        optional<bool> prev_fwd_grad_mode = getDynamicLayerStack().back().prevFwdGradMode();
-        if (cur_key == TransformType::Grad) {
-          TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
-        }
-        if (cur_key == TransformType::Jvp) {
-          TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
-        }
-        auto unwrap = [&](const Tensor& tensor) {
-          if (!tensor.defined()) {
-            return tensor;
-          }
-          auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
-          if (!maybe_tensor_wrapper) {
-            return tensor;
-          }
-          auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
-          TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level);
-          if (tensor_wrapper_level == cur_level) {
-            return maybe_tensor_wrapper->value();
-          }
-          return tensor;
-        };
-        auto wrap = [&](const Tensor& tensor) {
-          if (!tensor.defined()) {
-            return tensor;
-          }
-          // if (c10::show_dispatch_trace_enabled()) {
-          //   std::cout << "wrap " << cur_level << std::endl;
-          // }
-          return makeTensorWrapper(tensor, cur_level);
-        };
-
-        // TODO: we only need to do the following (marked with !) on in-place functions
-        // that modify sizes or strides. There aren't many of them.
-        // If autograd dispatch key:
-        // 1. (!) Put a copy of all of the args onto the stack
-        // 2. Unwrap all the args in the copy set
-        // 3. Call the operator
-        // 4. Wrap the output
-        // 5. (!) refreshMetadata for all the args in the original set
-        // 6. (!) Pop those args off.
-
-        // Step 1 & 2
-        auto args_size = op.schema().arguments().size();
-        // Step 1
-        auto front = stack->size() - args_size;
-        for (const auto arg_idx : c10::irange(0, args_size)) {
-          stack->push_back((*stack)[front + arg_idx]);
-        }
-        // Step 2
-        foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
-
-        auto restore_guard = restoreLocalDispatchKeySetRAII(getDynamicLayerStack().back());
-        WithoutTop guard;
-
-        // See NOTE [grad and vjp interaction with no_grad]
-        optional<c10::AutoGradMode> grad_guard;
-        if (cur_key == TransformType::Grad && prev_grad_mode.has_value() && *prev_grad_mode == false) {
-          grad_guard.emplace(*prev_grad_mode);
-        }
-        optional<c10::AutoFwGradMode> fw_grad_guard;
-        if (cur_key == TransformType::Jvp &&
-            prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
-          fw_grad_guard.emplace(*prev_fwd_grad_mode);
-        }
-
-        // Re-dispatch
-        if (dynamicLayerStackAccessor().size() == 0) {
-          sanityCheckStack(op, stack);
-        }
-        op.callBoxed(stack);
-
-        // Step 4, 5, 6
-        auto ret_size = op.schema().returns().size();
-        // Step 4
-        foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);
-
-        // Step 5
-        auto args_front = stack->size() - args_size - ret_size;
-        for (const auto arg_idx : c10::irange(0, args_size)) {
-          auto& ivalue = (*stack)[args_front + arg_idx];
-          if (!ivalue.isTensor()) {
-            continue;
-          }
-          auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
-          if (!maybe_tensor_wrapper) {
-            continue;
-          }
-          maybe_tensor_wrapper->refreshMetadata();
-        }
-
-        // Step 6
-        stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
-        break;
-      }
-    case TransformType::Vmap:
-      {
-        auto restore_guard = restoreLocalDispatchKeySetRAII(getDynamicLayerStack().back());
-        WithoutTop guard;
-
-        // Re-dispatch
-        if (dynamicLayerStackAccessor().size() == 0) {
-          sanityCheckStack(op, stack);
-        }
-        op.callBoxed(stack);
-
-        break;
-      }
-    case TransformType::Functionalize:
-      {
-        // For now, we don't support nested functionalization calls.
-        // This check just enforces that - after the functionalize kernel runs
-        // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors
-        // so we can check that the unwrapped thing is not another (nested) FunctionalTensor.
-        auto args_size = op.schema().arguments().size();
-        sanityCheckNotFunctional(op, stack, args_size);
-
-        auto restore_guard = restoreLocalDispatchKeySetRAII(getDynamicLayerStack().back());
-        WithoutTop guard;
-
-        // Re-dispatch
-        if (dynamicLayerStackAccessor().size() == 0) {
-          sanityCheckStack(op, stack);
-        }
-        op.callBoxed(stack);
-
-        auto ret_size = op.schema().returns().size();
-        sanityCheckNotFunctional(op, stack, ret_size);
-        break;
-      }
-    default:
-      TORCH_INTERNAL_ASSERT(false);
-  }
+  layer.interpreter().sendToNextInterpreter(op, stack);
 }
 
 TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
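
One subtlety in the rewritten dynamicLayerBackFallback above: restore_guard is constructed
before WithoutTop, so destruction runs in the reverse order. A condensed view (names taken
from the diff above):

    {
      auto restore_guard = restoreLocalDispatchKeySetRAII(...);  // 1. force the TLS saved on entry
      WithoutTop guard;                                          // 2. pop the top DynamicLayer
      layer.interpreter().sendToNextInterpreter(op, stack);      // 3. re-dispatch to the next interpreter
    }  // 4. ~WithoutTop pushes the layer back, then ~ForceDispatchKeyGuard restores the previous TLS
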
diff --git a/functorch/functorch/csrc/DynamicLayer.h b/functorch/functorch/csrc/DynamicLayer.h
index 8239ff4..cf84311 100644
--- a/functorch/functorch/csrc/DynamicLayer.h
+++ b/functorch/functorch/csrc/DynamicLayer.h
@@ -9,9 +9,14 @@
 #include <c10/core/DispatchKey.h>
 #include <ATen/core/function_schema.h>
 #include <c10/util/Optional.h>
+#include <c10/util/variant.h>
 #include <unordered_map>
 #include <mutex>
 #include <c10/core/impl/LocalDispatchKeySet.h>
+#include <functorch/csrc/Interpreter.h>
+#include <functorch/csrc/VmapInterpreter.h>
+#include <functorch/csrc/ADInterpreters.h>
+#include <functorch/csrc/FunctionalizeInterpreter.h>
 
 // Forward declared bc I am lazy
 namespace c10 { struct AutogradMetaInterface; }
@@ -19,23 +24,9 @@
 namespace at {
 namespace functorch {
 
-enum RandomnessType {
-    Error,      // always errors when calling a random function
-    Same,       // randomness appears the same across batches
-    Different,  // randomness appears different across batches
-    END
-};
-
-enum class TransformType {
-  Torch,  // Unused
-  Vmap,
-  Grad,  // reverse-mode AD, aka vjp
-  Jvp,  // forward-mode AD
-  Functionalize,
-};
-
-std::ostream& operator<<(std::ostream& os, const TransformType& t);
-
+// TODO: we can excise DynamicLayer in favor of Interpreter,
+// but I am going to leave it for now as a compatibility shim to avoid
+// needing to refactor a lot of callsites...
 struct FUNCTORCH_API DynamicLayer {
   explicit DynamicLayer(
       TransformType transform_type,
@@ -49,36 +40,15 @@
   TransformType key() const;
   int64_t layerId() const;
 
+  const Interpreter& interpreter() const { return interpreter_; }
+  Interpreter& interpreter() { return interpreter_; }
+
   // Only valid for vmap
   int64_t batchSize() const;
   RandomnessType randomness() const;
 
-  // only valid for grad-based transforms
-  optional<bool> prevGradMode() const;
-
-  // only valid for jvp transform
-  optional<bool> prevFwdGradMode() const;
-
-  void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset);
-  void clearSavedLocalDispatchKeySet();
-  c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const;
-
-  // only valid for functionalization
-  optional<bool> functionalizeAddBackViews() const;
  private:
-  TransformType transform_type_;
-  int64_t layerId_;
-
-  // Honestly these should be a union or some extendable metadata class.
-  // Not doing that for now because I don't think we'll use this mechanism for very long.
-  optional<int64_t> batchSize_;
-  optional<RandomnessType> randomness_;
-  optional<bool> prevGradMode_;
-  optional<bool> prevFwdGradMode_;
-
-  optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
-
-  optional<bool> functionalizeAddBackViews_;
+  Interpreter interpreter_;
 };
 
 FUNCTORCH_API int64_t initAndPushDynamicLayer(
@@ -109,15 +79,12 @@
 // add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
 bool isInplaceOp(const c10::FunctionSchema& schema);
 
-// Applies the following for-loop:
-// for i in range(begin, end):
-//   args[i] = func(args[i])
-void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
-    std::function<Tensor(const Tensor&)> func);
+Tensor unwrapIfDead(const Tensor& tensor);
 
 // Pretty printers
 std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
 
+
 }
 } // namespace at
diff --git a/functorch/functorch/csrc/FunctionalizeInterpreter.cpp b/functorch/functorch/csrc/FunctionalizeInterpreter.cpp
new file mode 100644
index 0000000..4242305
--- /dev/null
+++ b/functorch/functorch/csrc/FunctionalizeInterpreter.cpp
@@ -0,0 +1,68 @@
+#include <functorch/csrc/FunctionalizeInterpreter.h>
+#include <functorch/csrc/DynamicLayer.h>
+#include <ATen/FunctionalTensorWrapper.h>
+
+namespace at { namespace functorch {
+
+static void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) {
+  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
+      [](const Tensor& tensor) {
+        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
+        return tensor;
+      });
+}
+
+void FunctionalizeInterpreterPtr::processImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(TransformType::Functionalize);
+
+  // We always want to call the functionalization kernels if functionalize() is on the layer stack.
+  // It's the responsibility of the functionalization kernel to no-op and redispatch
+  // if none of the input tensors are functional.
+  setup_dispatch_key_tls(exclude, DispatchKeySet(DispatchKey::Functionalize));
+  auto functionalization_add_back_views = functionalizeAddBackViews();
+  // We have some side-car TLS that we can set to toggle the functionalization behavior.
+  // If set, then functionalization will only remove mutations, instead of
+  // removing both mutations AND view operators.
+  at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views);
+
+  op.callBoxed(stack);
+
+  auto ret_size = op.schema().returns().size();
+  foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(),
+    [&](const Tensor& tensor) {
+      if (at::functionalization::impl::isFunctionalTensor(tensor)) {
+        auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
+        // Functorch is responsible for setting the level on the wrapper, since we don't
+        // have that info available in core (for now).
+        // We could just "propagate" the level from the input tensors inside of the functionalize kernels,
+        // but unfortunately we can't do that for factory operators.
+        wrapper->set_level(level());
+      }
+      return tensor;
+    }
+  );
+}
+
+void FunctionalizeInterpreterPtr::sendToNextInterpreterImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  // For now, we don't support nested functionalization calls.
+  // This check just enforces that - after the functionalize kernel runs
+  // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors
+  // so we can check that the unwrapped thing is not another (nested) FunctionalTensor.
+  auto args_size = op.schema().arguments().size();
+  sanityCheckNotFunctional(op, stack, args_size);
+
+  // Re-dispatch
+  if (getDynamicLayerStack().size() == 0) {
+    sanityCheckStack(op, stack);
+  }
+  op.callBoxed(stack);
+
+  auto ret_size = op.schema().returns().size();
+  sanityCheckNotFunctional(op, stack, ret_size);
+}
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/FunctionalizeInterpreter.h b/functorch/functorch/csrc/FunctionalizeInterpreter.h
new file mode 100644
index 0000000..5475b38
--- /dev/null
+++ b/functorch/functorch/csrc/FunctionalizeInterpreter.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <functorch/csrc/Interpreter.h>
+
+namespace at { namespace functorch {
+
+struct FunctionalizeInterpreterPtr {
+  explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  bool functionalizeAddBackViews() const {
+    return c10::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
+  }
+ private:
+  const Interpreter* base_;
+};
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/Interpreter.cpp b/functorch/functorch/csrc/Interpreter.cpp
new file mode 100644
index 0000000..35c1150
--- /dev/null
+++ b/functorch/functorch/csrc/Interpreter.cpp
@@ -0,0 +1,108 @@
+#include <functorch/csrc/Interpreter.h>
+#include <functorch/csrc/BatchedTensorImpl.h>
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/VmapInterpreter.h>
+#include <functorch/csrc/FunctionalizeInterpreter.h>
+#include <functorch/csrc/ADInterpreters.h>
+
+namespace at { namespace functorch {
+
+constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
+  kDynamicLayerFrontModeKey,
+  kDynamicLayerBackModeKey,
+  kGradWrapperKey,
+  DispatchKey::Functionalize,
+  // DispatchKey::Batched,
+  kBatchedKey,
+  DispatchKey::PythonTLSSnapshot,
+  DispatchKey::ADInplaceOrView
+}) | autograd_dispatch_keyset;
+
+static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
+  if (key == TransformType::Vmap) {
+    // NB: Does not include kVmapModeKey. We may modulate the key when
+    // constructing the DynamicLayer, but we don't control it when entering/exiting
+    // the DynamicLayer.
+    return DispatchKeySet({kBatchedKey});
+  } else if (key == TransformType::Grad || key == TransformType::Jvp) {
+    return autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);
+  } else if (key == TransformType::Functionalize) {
+    return DispatchKeySet(DispatchKey::Functionalize);
+  } else {
+    TORCH_INTERNAL_ASSERT(false, "Unsupported key: ", key);
+  }
+}
+
+DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key) {
+  DispatchKeySet exclude = all_dynlayer_keyset;
+  exclude = exclude.remove(kDynamicLayerBackModeKey);
+  exclude = exclude - keysForEnteringDynamicLayer(key);
+  return exclude;
+}
+
+void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include) {
+  auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+  local_keyset.excluded_ = local_keyset.excluded_ | exclude;
+  local_keyset.included_ = local_keyset.included_ | include;
+  c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+}
+
+std::ostream& operator<<(std::ostream& os, const TransformType& t) {
+  switch (t) {
+    case TransformType::Torch:
+      os << "Torch";
+      break;
+    case TransformType::Vmap:
+      os << "Vmap";
+      break;
+    case TransformType::Grad:
+      os << "Grad";
+      break;
+    case TransformType::Jvp:
+      os << "Jvp";
+      break;
+    case TransformType::Functionalize:
+      os << "Functionalize";
+      break;
+    default:
+      TORCH_INTERNAL_ASSERT(false);
+  }
+  return os;
+}
+
+void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  auto num_args = op.schema().arguments().size();
+  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
+      [](const Tensor& tensor) {
+
+        auto* wrapper = maybeGetTensorWrapper(tensor);
+        TORCH_INTERNAL_ASSERT(wrapper == nullptr);
+        auto* batched = maybeGetBatchedImpl(tensor);
+        TORCH_INTERNAL_ASSERT(batched == nullptr);
+        return tensor;
+      });
+}
+
+#define INTERPRETER_DISPATCH(type, method) \
+  switch (key()) { \
+    case TransformType::Vmap: \
+      return VmapInterpreterPtr(this). method; \
+    case TransformType::Grad: \
+      return GradInterpreterPtr(this). method; \
+    case TransformType::Jvp: \
+      return JvpInterpreterPtr(this). method; \
+    case TransformType::Functionalize: \
+      return FunctionalizeInterpreterPtr(this). method; \
+    default: \
+      TORCH_INTERNAL_ASSERT(false, "Unrecognized transform"); \
+  }
+
+void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)));
+}
+
+void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack)));
+}
+
+}}
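
For reference, the INTERPRETER_DISPATCH macro trades virtual dispatch for a switch; for
Interpreter::process it expands to roughly the following (a sketch of the expansion, not
additional code in the patch):

    void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
      switch (key()) {
        case TransformType::Vmap:
          return VmapInterpreterPtr(this).processImpl(op, stack);
        case TransformType::Grad:
          return GradInterpreterPtr(this).processImpl(op, stack);
        case TransformType::Jvp:
          return JvpInterpreterPtr(this).processImpl(op, stack);
        case TransformType::Functionalize:
          return FunctionalizeInterpreterPtr(this).processImpl(op, stack);
        default:
          TORCH_INTERNAL_ASSERT(false, "Unrecognized transform");
      }
    }
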
diff --git a/functorch/functorch/csrc/Interpreter.h b/functorch/functorch/csrc/Interpreter.h
new file mode 100644
index 0000000..2a1a426
--- /dev/null
+++ b/functorch/functorch/csrc/Interpreter.h
@@ -0,0 +1,187 @@
+#pragma once
+
+// variant.h doesn't clean up after itself...
+#include <c10/util/variant.h>
+#undef DECLTYPE_AUTO
+
+#include <functorch/csrc/Macros.h>
+#include <functorch/csrc/Constants.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/core/impl/LocalDispatchKeySet.h>
+#include <c10/util/Optional.h>
+
+namespace at { namespace functorch {
+
+// NOTE: [functorch interpreter stack]
+//
+// functorch's dispatching system uses a stack of interpreters.
+// Historically we've referred to this as the "DynamicLayerStack".
+//
+// An interpreter is something that reads in the code it is passed
+// and then executes it. We have a different interpreter per-transform:
+// the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
+// and executing the batched version of it (the batching rule for aten::mv).
+//
+// Concretely, each interpreter is responsible for two things:
+//
+// 1) process(ophandle, stack)
+// Given an operator handle and a stack of arguments, the interpreter is
+// responsible for figuring out how to execute the operation under the semantics
+// of the interpreter. For example, for VmapInterpreter this means figuring out
+// how to call the batching rule.
+//
+// The batching rules are stored as kernels on the FuncTorchBatched key, so the way
+// VmapInterpreter calls the batching rule is roughly: (A) exclude all
+// dispatch keys aside from the Batched key, (B) redispatch so we get to the
+// Batched key.
+//
+// 2) sendToNextInterpreter(ophandle, stack)
+// The VmapInterpreter, when it sees aten::mv, will process it into a call to
+// aten::mm. It then needs to send the call to aten::mm to the next interpreter
+// in the interpreter stack.
+//
+// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
+// and most Interpreters will implement it this way.
+
+enum RandomnessType {
+    Error,      // always errors when calling a random function
+    Same,       // randomness appears the same across batches
+    Different,  // randomness appears different across batches
+    END
+};
+
+enum class TransformType {
+  Torch,  // Unused
+  Vmap,
+  Grad,  // reverse-mode AD, aka vjp
+  Jvp,  // forward-mode AD
+  Functionalize,
+};
+
+std::ostream& operator<<(std::ostream& os, const TransformType& t);
+
+// NOTE: [Interpreter "subclassing" design]
+//
+// How are various Interpreters for different transforms (vmap, grad, ...)
+// implemented?
+//
+// Accessing interpreters is in the hot-path of functorch so we have a constraint
+// that this code must be as fast as possible.
+//
+// As a result, we stay away from virtual methods and this causes our code
+// to look a little funny.
+//
+// `Interpreter` is the struct for Interpreters. It holds ALL of the
+// relevant information (what type of interpreter it is and the metadata).
+// Metadata for each interpreter is represented as a Union (c10::variant)
+// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
+//
+// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
+// if you want to access the metadata fields (like batchSize and randomness).
+//
+// Each type of interpreter (e.g. Vmap) has a convenience struct
+// (e.g. VmapInterpreterPtr) associated with it.
+//
+// Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
+// and then one can access methods on VmapInterpreterPtr like so:
+// >>> VmapInterpreterPtr(&interpreter).batchSize()
+//
+// Finally, Interpreter::process switches on the type of the interpreter
+// and calls one of {Transform}InterpreterPtr::processImpl under the hood.
+// Same for Interpreter::sendToNextInterpreter :)
+
+struct VmapInterpreterMeta {
+  explicit VmapInterpreterMeta(int64_t batchSize, RandomnessType randomness) :
+    batchSize_(batchSize), randomness_(randomness) {}
+  int64_t batchSize_;
+  RandomnessType randomness_;
+};
+
+struct GradInterpreterMeta {
+  explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
+  bool prevGradMode_;
+};
+
+struct JvpInterpreterMeta {
+  explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
+  bool prevFwdGradMode_;
+};
+
+struct FunctionalizeInterpreterMeta {
+  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
+    functionalizeAddBackViews_(functionalizeAddBackViews) {}
+  bool functionalizeAddBackViews_;
+};
+
+typedef c10::variant<
+  int64_t,
+  GradInterpreterMeta,
+  JvpInterpreterMeta,
+  VmapInterpreterMeta,
+  FunctionalizeInterpreterMeta
+> InterpreterMeta;
+
+
+struct Interpreter {
+  // factory functions
+  static Interpreter Vmap(int64_t level, int64_t batchSize, RandomnessType randomness) {
+    return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(batchSize, randomness));
+  }
+  static Interpreter Grad(int64_t level, bool prevGradMode) {
+    return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
+  }
+  static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
+    return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
+  }
+  static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
+    return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
+  }
+
+  // methods
+  TransformType key() const { return type_; }
+  int64_t level() const { return level_; }
+  const InterpreterMeta& meta() const { return meta_; }
+
+  void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+  void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
+    TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
+    savedLocalDispatchKeySet_ = std::move(keyset);
+  }
+  void clearSavedLocalDispatchKeySet() {
+    TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
+    savedLocalDispatchKeySet_ = c10::nullopt;
+  }
+  c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
+    TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
+    return *savedLocalDispatchKeySet_;
+  }
+
+  // Please don't use this
+  explicit Interpreter() = default;
+
+ private:
+  explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
+    type_(type), level_(level), meta_(meta) {}
+
+  // fields
+  TransformType type_;
+  int64_t level_;
+  optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
+  InterpreterMeta meta_;
+};
+
+// Applies the following for-loop:
+// for i in range(begin, end):
+//   args[i] = func(args[i])
+void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
+    std::function<Tensor(const Tensor&)> func);
+
+DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
+
+void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include);
+
+void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+}} // namespace at::functorch
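
A minimal usage sketch of the factory + Ptr pattern described in NOTE: [Interpreter
"subclassing" design] (values made up; assumes VmapInterpreter.h is included for the
Ptr struct):

    Interpreter interp = Interpreter::Vmap(/*level=*/1, /*batchSize=*/8, RandomnessType::Error);
    TORCH_INTERNAL_ASSERT(interp.key() == TransformType::Vmap);
    // Metadata is stored in the c10::variant; the Ptr struct extracts the right alternative.
    int64_t bs = VmapInterpreterPtr(&interp).batchSize();         // 8
    RandomnessType r = VmapInterpreterPtr(&interp).randomness();  // RandomnessType::Error
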
diff --git a/functorch/functorch/csrc/VmapInterpreter.cpp b/functorch/functorch/csrc/VmapInterpreter.cpp
new file mode 100644
index 0000000..a8f0283
--- /dev/null
+++ b/functorch/functorch/csrc/VmapInterpreter.cpp
@@ -0,0 +1,24 @@
+#include <functorch/csrc/VmapInterpreter.h>
+#include <functorch/csrc/DynamicLayer.h>
+
+namespace at { namespace functorch {
+
+void VmapInterpreterPtr::processImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(TransformType::Vmap);
+  setup_dispatch_key_tls(exclude, DispatchKeySet(kVmapModeKey));
+  op.callBoxed(stack);
+}
+
+void VmapInterpreterPtr::sendToNextInterpreterImpl(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  // Re-dispatch
+  if (getDynamicLayerStack().size() == 0) {
+    sanityCheckStack(op, stack);
+  }
+  op.callBoxed(stack);
+}
+
+}} // namespace at::functorch
diff --git a/functorch/functorch/csrc/VmapInterpreter.h b/functorch/functorch/csrc/VmapInterpreter.h
new file mode 100644
index 0000000..084cea9
--- /dev/null
+++ b/functorch/functorch/csrc/VmapInterpreter.h
@@ -0,0 +1,22 @@
+#pragma once
+#include <functorch/csrc/Interpreter.h>
+
+namespace at { namespace functorch {
+
+struct VmapInterpreterPtr {
+  explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  int64_t batchSize() const {
+    return c10::get<VmapInterpreterMeta>(base_->meta()).batchSize_;
+  }
+  RandomnessType randomness() const {
+    return c10::get<VmapInterpreterMeta>(base_->meta()).randomness_;
+  }
+ private:
+  const Interpreter* base_;
+};
+
+}} // namespace at::functorch