Implement batching rules for some unary ops (#43059)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43059

This PR implements batching rules for some unary ops. In particular, it
implements the batching rules for the unary ops that take a single
tensor as input (and nothing else).

The batching rule for a unary op is:
(1) grab the physical tensor straight out of the BatchedTensor
(2) call the unary op
(3) rewrap the physical tensor in a BatchedTensor

Test Plan: - new tests `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D23132277

Pulled By: zou3519

fbshipit-source-id: 24b9d7535338207531d767155cdefd2c373ada77
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index 0248d3b..9e1554d 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -226,6 +226,14 @@
   return self_physical.newLogicalFromPhysical(result);
 }
 
+template <Tensor (*Op)(const Tensor&)>
+Tensor unary_pointwise_batching_rule(const Tensor& input) {
+  auto* input_batched = unsafeGetBatchedImpl(input);
+  auto output_physical = Op(input_batched->value());
+  auto old_bdims = input_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
 TORCH_LIBRARY_IMPL(_, Batched, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
 }
@@ -268,6 +276,40 @@
   m.impl("unsqueeze", unsqueeze_batching_rule);
   m.impl("view", view_batching_rule);
   m.impl("view_as", native::view_as); // composite wrt autograd
+
+  // unary pointwise, out-of-place, no additional arguments.
+#define UNARY_POINTWISE(op) m.impl(#op, unary_pointwise_batching_rule<at::op>);
+  UNARY_POINTWISE(abs);
+  UNARY_POINTWISE(acos);
+  UNARY_POINTWISE(asin);
+  UNARY_POINTWISE(atan);
+  UNARY_POINTWISE(ceil);
+  UNARY_POINTWISE(cos);
+  UNARY_POINTWISE(cosh);
+  UNARY_POINTWISE(digamma);
+  UNARY_POINTWISE(exp);
+  UNARY_POINTWISE(expm1);
+  UNARY_POINTWISE(floor);
+  UNARY_POINTWISE(frac);
+  UNARY_POINTWISE(lgamma);
+  UNARY_POINTWISE(log);
+  UNARY_POINTWISE(log10);
+  UNARY_POINTWISE(log1p);
+  UNARY_POINTWISE(log2);
+  UNARY_POINTWISE(neg);
+  UNARY_POINTWISE(reciprocal);
+  UNARY_POINTWISE(relu);
+  UNARY_POINTWISE(round);
+  UNARY_POINTWISE(rsqrt);
+  UNARY_POINTWISE(sigmoid);
+  UNARY_POINTWISE(sign);
+  UNARY_POINTWISE(sin);
+  UNARY_POINTWISE(sinh);
+  UNARY_POINTWISE(sqrt);
+  UNARY_POINTWISE(tan);
+  UNARY_POINTWISE(tanh);
+  UNARY_POINTWISE(trunc);
+#undef UNARY_POINTWISE
 }
 
 } // namespace at
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 4831668..164905a 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -654,6 +654,79 @@
     def _vmap_view_test(self, *args, **kwargs):
         self._vmap_test(*args, **kwargs, check_view=True)
 
+    def _assert_doesnt_use_vmap_fallback(self, vmap_args, inputs):
+        regex = r'falling back to slow \(for loop and stack\) implementation'
+        with warnings.catch_warnings(record=True) as wa:
+            result = vmap(*vmap_args)(*inputs)
+            for captured_warning in wa:
+                self.assertNotRegex(str(captured_warning.message), regex)
+
+    def test_assert_doesnt_use_vmap_fallback(self):
+        with self.assertRaises(AssertionError):
+            # One day we'll implement a batching rule for torch.var_mean.
+            # When that happens, please change the example to use an
+            # operator that doesn't have a batching rule implemented.
+            self._assert_doesnt_use_vmap_fallback([torch.var_mean], [torch.rand(3)])
+
+    def test_unary_pointwise_ops(self):
+        def get_rand(size, device):
+            return [torch.rand(size, device=device)]
+
+        def get_randp1(size, device):
+            return [torch.rand(size, device=device) + 1]
+
+        def get_randn(size, device):
+            return [torch.randn(size, device=device)]
+
+        cases = [
+            (torch.abs, get_randn),
+            (torch.acos, get_rand),
+            (torch.asin, get_rand),
+            (torch.atan, get_rand),
+            (torch.ceil, get_randn),
+            (torch.cos, get_rand),
+            (torch.cosh, get_rand),
+            (torch.digamma, get_rand),
+            (torch.exp, get_randn),
+            (torch.expm1, get_randn),
+            (torch.floor, get_randn),
+            (torch.frac, get_randn),
+            (torch.lgamma, get_rand),
+            (torch.log, get_randp1),
+            (torch.log10, get_randp1),
+            (torch.log1p, get_randp1),
+            (torch.log2, get_randp1),
+            (torch.neg, get_randn),
+            (torch.reciprocal, get_randp1),
+            (torch.relu, get_randn),
+            (torch.round, get_randn),
+            (torch.rsqrt, get_randp1),
+            (torch.sigmoid, get_randn),
+            (torch.sign, get_randn),
+            (torch.sin, get_rand),
+            (torch.sinh, get_rand),
+            (torch.sqrt, get_rand),
+            (torch.tan, get_rand),
+            (torch.tanh, get_rand),
+            (torch.trunc, get_randn),
+        ]
+        test = self._vmap_test
+        B0, B1 = 7, 11
+        for op, getter in cases:
+            device = 'cpu'
+
+            self._assert_doesnt_use_vmap_fallback([op], getter([B0], device))
+
+            # Single vmap, various in_dims / out_dims
+            test(op, getter([B0, 3], device))
+            test(op, getter([2, 5, B0, 3], device), in_dims=2)
+            test(op, getter([2, 5, B0, 3], device), in_dims=2, out_dims=2)
+
+            # Doubly nested vmap
+            test(vmap(op), getter([B0, B1], device))
+            test(vmap(op), getter([B1, 2, 5, B0, 3], device), in_dims=2)
+            test(vmap(op, in_dims=2), getter([2, 5, B0, B1, 3], device), in_dims=2, out_dims=2)
+
     def test_chunk(self):
         test = self._vmap_view_test
         op = torch.chunk