TestVmapOperators: add structured tests that batching rules get invoked (#43731)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43731

After this PR, TestVmapOperators checks that none of its tests ever
invoke the slow vmap fallback path. The rationale behind this change is
that TestVmapOperators is used for testing batching rules, and we want
confidence that the batching rules actually get invoked.

We set this up using a mechanism similar to the CUDA memory leak check:
(https://github.com/pytorch/pytorch/blob/bff741a8497887c8ee22ffa9f0208565072a74dc/torch/testing/_internal/common_utils.py#L506-L511)
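
In rough sketch form, the idea is that the base test class rebinds each
test method to a wrapper that records warnings and fails if the vmap
fallback warning appears. This is a minimal standalone sketch with
made-up names (`wrap_with_fallback_check`, `FALLBACK_REGEX`); the real
hook is `Namespace.TestVmapBase._wrap_method_with_vmap_fallback_check`
in the diff below:

```python
import functools
import types
import warnings

FALLBACK_REGEX = r'falling back to slow \(for loop and stack\) implementation'

def wrap_with_fallback_check(test_case, method):
    # `method` is a bound test method; the wrapper re-runs it while recording
    # warnings and fails if any of them match the vmap fallback message.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter('always')  # make sure the warning isn't filtered out
            method(*args, **kwargs)
        for w in captured:
            self.assertNotRegex(str(w.message), FALLBACK_REGEX)
    # Rebind onto the test instance so unittest calls the wrapped version.
    return types.MethodType(wrapper, test_case)
```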

This PR also implements the batching rule for `to.dtype_layout`; the new
check caught that we were testing vmap on `to.dtype_layout` even though
it didn't actually have a batching rule implemented!

Test Plan: - New tests in `test/test_vmap.py` (run with `pytest test/test_vmap.py -v`) that exercise the mechanism.
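
From a test author's point of view the check is opt-out: a test added to
any `Namespace.TestVmapBase` subclass fails if it hits the fallback,
unless it is decorated with `allowVmapFallbackUsage`. A hypothetical
example, assuming the names defined in `test/test_vmap.py` below (the
class and test names here are made up):

```python
import torch

# Assumes vmap, allowVmapFallbackUsage, and Namespace.TestVmapBase as defined
# in test/test_vmap.py (see the diff below).
class TestVmapExamples(Namespace.TestVmapBase):
    def test_covered_by_batching_rule(self):
        # Fails if the op falls back to the slow (for loop and stack) path.
        # Assumes torch.abs has a batching rule registered.
        vmap(torch.abs)(torch.rand(3, 5))

    @allowVmapFallbackUsage
    def test_intentionally_uses_fallback(self):
        # torch.var_mean has no batching rule yet, so the fallback warning is
        # expected; the decorator tells TestVmapBase not to fail on it.
        vmap(torch.var_mean)(torch.rand(3, 5))
```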

Reviewed By: ezyang

Differential Revision: D23380729

Pulled By: zou3519

fbshipit-source-id: 6a4b97a7fa7b4e1c5be6ad80d6761e0d5b97bb8c
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index 0b339f0..36aef83 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -314,6 +314,29 @@
   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
 }
 
+// I am quite sad that we need to register operators with exploded TensorOptions,
+// even though the native:: implementations can use TensorOptions&.
+// This also makes it hard to metaprogram: i.e., we can't use
+// unary_pointwise_batching_rule<..., at::to> because at::to takes TensorOptions& (!!)
+Tensor to_dtype_layout_batching_rule(
+    const Tensor& self,
+    optional<ScalarType> dtype,
+    optional<Layout> layout,
+    optional<Device> device,
+    optional<bool> pin_memory,
+    bool non_blocking, bool copy,
+    optional<MemoryFormat> memory_format) {
+  auto options = TensorOptions()
+    .dtype(dtype)
+    .layout(layout)
+    .device(device)
+    .pinned_memory(pin_memory);
+  auto* input_batched = unsafeGetBatchedImpl(self);
+  auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format);
+  auto old_bdims = input_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
 TORCH_LIBRARY_IMPL(_, Batched, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
 }
@@ -399,6 +422,7 @@
   TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, optional<MemoryFormat>)
   TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, optional<MemoryFormat>)
   TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional<MemoryFormat>)
+  m.impl("to.dtype_layout", to_dtype_layout_batching_rule);
 #undef TO_BATCHING_RULE
 
   using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 0cb1831..8902ee6 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -5,6 +5,7 @@
 import warnings
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import TEST_WITH_ROCM
+import types
 
 class TestVmapAPI(TestCase):
     def test_non_tensor_output_raises(self):
@@ -675,33 +676,95 @@
     result_as_tuple = (result,) if op_has_single_return else result
     self.assertTrue(result[0].requires_grad)
 
-class TestVmapOperators(TestCase):
+def should_allow_vmap_fallback_usage(fn):
+    return getattr(fn, '_allow_vmap_fallback_usage', False)
+
+def allowVmapFallbackUsage(fn):
+    fn._allow_vmap_fallback_usage = True
+    return fn
+
+# All tests of TestVmapBase check that the slow vmap fallback is never invoked.
+# This is so that we can incrementally add batching rules for operators to
+# replace the slow vmap fallback path for said operators. To skip this check,
+# please use the allowVmapFallbackUsage decorator.
+#
+# NB: Don't add tests to TestVmapBase directly, unless you want them to run
+# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators.
+#
+# NB: TestVmapBase is a nested class. This prevents test runners from picking
+# it up and running it.
+class Namespace:
+    class TestVmapBase(TestCase):
+        def __init__(self, method_name='runTest'):
+            super().__init__(method_name)
+
+            test_method = getattr(self, method_name, None)
+            if test_method is None:
+                return
+
+            if not should_allow_vmap_fallback_usage(test_method):
+                setattr(self, method_name,
+                        self._wrap_method_with_vmap_fallback_check(test_method))
+
+        def _wrap_method_with_vmap_fallback_check(self, method):
+            msg = (
+                'Expected the test to not invoke the vmap fallback path, i.e., '
+                'all of the operators being tested in this test should have batching '
+                'rules implemented. If you are intentionally testing something to '
+                'do with the fallback path, use allowVmapFallbackUsage. Otherwise, '
+                'please make sure that batching rules are implemented for the '
+                'operator(s) being tested.'
+            )
+
+            @functools.wraps(method)
+            def wrapper(self, *args, **kwargs):
+                regex = r'falling back to slow \(for loop and stack\) implementation'
+                with warnings.catch_warnings(record=True) as wa:
+                    warnings.simplefilter('always')
+                    method(*args, **kwargs)
+                    for captured_warning in wa:
+                        self.assertNotRegex(str(captured_warning.message), regex, msg)
+            return types.MethodType(wrapper, self)
+
+        @allowVmapFallbackUsage
+        def test_vmap_fallback_check_ok(self):
+            # One day we'll implement a batching rule for torch.var_mean.
+            # When that happens, please change the example to use an
+            # operator that doesn't have a batching rule implemented.
+            op_using_fallback = torch.var_mean
+            vmap(op_using_fallback)(torch.rand(3))
+
+        def test_vmap_fallback_check(self):
+            @self._wrap_method_with_vmap_fallback_check
+            def no_fallback(self):
+                pass
+
+            # One day we'll implement a batching rule for torch.var_mean.
+            # When that happens, please change the example to use an
+            # operator that doesn't have a batching rule implemented.
+            op_using_fallback = torch.var_mean
+
+            @self._wrap_method_with_vmap_fallback_check
+            def uses_fallback(self):
+                vmap(op_using_fallback)(torch.rand(3))
+
+            no_fallback(self)
+
+            with self.assertRaises(AssertionError):
+                uses_fallback(self)
+
+
+class TestVmapOperators(Namespace.TestVmapBase):
     def _vmap_test(self, *args, **kwargs):
         return _vmap_test(self, *args, **kwargs)
 
     def _vmap_view_test(self, *args, **kwargs):
         self._vmap_test(*args, **kwargs, check_view=True)
 
-    def _assert_doesnt_use_vmap_fallback(self, vmap_args, inputs):
-        regex = r'falling back to slow \(for loop and stack\) implementation'
-        with warnings.catch_warnings(record=True) as wa:
-            result = vmap(*vmap_args)(*inputs)
-            for captured_warning in wa:
-                self.assertNotRegex(str(captured_warning.message), regex)
-
-    def test_assert_doesnt_use_vmap_fallback(self):
-        with self.assertRaises(AssertionError):
-            # One day we'll implement a batching rule for torch.var_mean.
-            # When that happens, please change the example to use an
-            # operator that doesn't have a batching rule implemented.
-            self._assert_doesnt_use_vmap_fallback([torch.var_mean], [torch.rand(3)])
-
     def _test_unary(self, op, getter, device):
         test = self._vmap_test
         B0, B1 = 7, 11
 
-        self._assert_doesnt_use_vmap_fallback([op], [getter([B0], device)])
-
         # Single vmap, various in_dims / out_dims
         test(op, [getter([B0, 3], device)])
         test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
@@ -775,9 +838,6 @@
             device = 'cpu'
             B0, B1 = 7, 11
 
-            self._assert_doesnt_use_vmap_fallback(
-                [op], (getter([B0], device), getter([B0], device)))
-
             # Single vmap: op(Tensor, Tensor)
             test(op, (getter([B0, 3], device), getter([B0, 3], device)))
             test(op, (getter([B0], device), getter([B0, 2, 3], device)))
@@ -1175,7 +1235,7 @@
     return tuple(arg for arg in as_tuple(args)
                  if isinstance(arg, torch.Tensor) and arg.requires_grad)
 
-class TestVmapBatchedGradient(TestCase):
+class TestVmapBatchedGradient(Namespace.TestVmapBase):
     def _vmap_test(self, *args, **kwargs):
         return _vmap_test(self, *args, **kwargs)
 
@@ -1233,18 +1293,9 @@
                         check_propagates_grad=False)
 
     def test_sigmoid(self, device):
-        # Maybe we can make the "check that the slow fallback was not invoked"
-        # into a context manager, because it's used a lot. I'll leave that for
-        # future work.
-        regex = r'falling back to slow \(for loop and stack\) implementation'
-        with warnings.catch_warnings(record=True) as wa:
-            warnings.simplefilter('always')
-            x = torch.randn(2, 3, requires_grad=True, device=device)
-            self._batched_grad_test(Tensor.sigmoid, (x,), {})
-            self._batched_grad_grad_test(Tensor.sigmoid, (x,), {})
-
-            for captured_warning in wa:
-                self.assertNotRegex(str(captured_warning.message), regex)
+        x = torch.randn(2, 3, requires_grad=True, device=device)
+        self._batched_grad_test(Tensor.sigmoid, (x,), {})
+        self._batched_grad_grad_test(Tensor.sigmoid, (x,), {})
 
 instantiate_device_type_tests(
     TestVmapBatchedGradient,