blob: 39895f3abc556c6899b62c7bdfe061a95a05c36f [file] [log] [blame]
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/DynamicLayer.h>
namespace at { namespace functorch {
// Runs `op` at this vmap interpreter's level.
//
// The dispatch keys reported by
// keysToExcludeWhenEnteringDynamicLayer(TransformType::Vmap) are passed to
// setup_dispatch_key_tls as the exclude set, together with an include set
// containing only DispatchKey::FuncTorchVmapMode. After that, the operator is
// invoked in boxed form on `stack`.
// NOTE(review): setup_dispatch_key_tls is defined in DynamicLayer; it is
// assumed to update the thread-local dispatch key state. Confirm there.
//
// @param op     the operator being dispatched
// @param stack  boxed argument stack; handed to op.callBoxed unchanged
void VmapInterpreterPtr::processImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  // Computed first so the call order matches the original implementation.
  DispatchKeySet keys_to_exclude =
      keysToExcludeWhenEnteringDynamicLayer(TransformType::Vmap);
  // The only key explicitly switched on for this level.
  DispatchKeySet keys_to_include(DispatchKey::FuncTorchVmapMode);
  setup_dispatch_key_tls(keys_to_exclude, keys_to_include);

  op.callBoxed(stack);
}
// Forwards `op` past this vmap interpreter by re-dispatching it boxed on
// `stack`.
//
// When the functorch dynamic layer stack is empty, sanityCheckStack(op, stack)
// runs before the re-dispatch.
// NOTE(review): what sanityCheckStack verifies is defined elsewhere. Confirm
// in DynamicLayer.
//
// @param op     the operator being dispatched
// @param stack  boxed argument stack; handed to op.callBoxed unchanged
void VmapInterpreterPtr::sendToNextInterpreterImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  // Re-dispatch.
  // Use empty() rather than size() == 0: it states the intent directly.
  if (getDynamicLayerStack().empty()) {
    sanityCheckStack(op, stack);
  }
  op.callBoxed(stack);
}
}} // namespace at::functorch