[pytorch] Disable new autograd fallback for mobile builds (#105750)
Summary:
To save on binary size, some of the mobile configs don't include the
autograd kernels for built-in operators (VariableTypeEverything.cpp).
For the mobile build:
- we don't care about having a nice autograd fallback that warns if
an operator has incorrect autograd support. If you're running
a custom operator on mobile then it's already too late for us to warn
or error on it.
- for perf reasons, we do not want mobile to go through autograd_fallback
for all operators (the boxing/unboxing adds overhead).
As a result, on mobile we set the fallback to the fallthrough.
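
To make the trade-off concrete, here is a minimal, self-contained sketch of
the two registration paths. It is illustrative only, not the code in this PR:
naive_autograd_fallback is a hypothetical stand-in for the real boxed
fallback, and the redispatch through c10::after_autograd_keyset is an
assumption about what such a fallback typically does.

    #include <ATen/core/dispatch/Dispatcher.h>
    #include <torch/library.h>

    // Hypothetical boxed fallback: arguments arrive boxed on the stack, so
    // every call pays a pack/unpack round-trip before redispatching.
    static void naive_autograd_fallback(
        const c10::OperatorHandle& op,
        c10::DispatchKeySet dispatch_keys,
        torch::jit::Stack* stack) {
      // Skip past the autograd keys and call the underlying kernel.
      op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
    }

    TORCH_LIBRARY_IMPL(_, AutogradOther, m) {
    #ifdef C10_MOBILE
      // Fallthrough: the dispatcher acts as if no kernel were registered
      // for this key and moves on, with no boxing overhead.
      m.fallback(torch::CppFunction::makeFallthrough());
    #else
      m.fallback(
          torch::CppFunction::makeFromBoxedFunction<&naive_autograd_fallback>());
    #endif
    }

The boxed path is what lets non-mobile builds warn when an operator has
incorrect autograd support; the fallthrough gives that diagnostic up so that
mobile dispatch never pays for the boxing.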
Test Plan: existing tests and benchmarks
Differential Revision: D47674272
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105750
Approved by: https://github.com/soulitzer
diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp
index 88dda52..b801eb2 100644
--- a/aten/src/ATen/core/VariableFallbackKernel.cpp
+++ b/aten/src/ATen/core/VariableFallbackKernel.cpp
@@ -33,7 +33,22 @@
     c10::DispatchKeySet dispatch_keys,
     torch::jit::Stack* stack);
 
+#ifdef C10_MOBILE
+// NOTE [mobile/edge builds and the autograd fallback]
+// To save on binary size, some of the mobile configs don't include the
+// autograd kernels for built-in operators (VariableTypeEverything.cpp).
+// For the mobile build:
+// - we don't care about having a nice autograd fallback that warns if
+// an operator has incorrect autograd support. If you're running
+// a custom operator on mobile then it's already too late for us to warn
+// or error on it.
+// - for perf reasons, we do not want mobile to go through autograd_fallback
+// for all operators (the boxing/unboxing adds overhead).
+// As a result, on mobile we set the fallback to the fallthrough.
+#define AUTOGRAD_FALLBACK torch::CppFunction::makeFallthrough()
+#else
 #define AUTOGRAD_FALLBACK torch::CppFunction::makeFromBoxedFunction<&autograd_fallback>()
+#endif
 
 TORCH_LIBRARY_IMPL(_, AutogradOther, m) {
   m.fallback(AUTOGRAD_FALLBACK);