Implement batching rules for some unary ops (#43059)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43059
This PR implements batching rules for some unary ops. In particular, it
implements the batching rules for the unary ops that take a single
tensor as input (and nothing else).
The batching rule for a unary op is:
(1) grab the physical tensor straight out of the BatchedTensor
(2) call the unary op
(3) rewrap the physical tensor in a BatchedTensor
Test Plan: - new tests `pytest test/test_vmap.py -v -k "Operators"`
Reviewed By: ezyang
Differential Revision: D23132277
Pulled By: zou3519
fbshipit-source-id: 24b9d7535338207531d767155cdefd2c373ada77
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index 0248d3b..9e1554d 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -226,6 +226,14 @@
return self_physical.newLogicalFromPhysical(result);
}
+// Batching rule for out-of-place unary pointwise ops with signature
+// (Tensor) -> Tensor.  Per the summary above, the rule is:
+//   (1) grab the physical tensor straight out of the BatchedTensor,
+//   (2) call the unary op on it,
+//   (3) rewrap the result in a BatchedTensor, reusing the input's batch
+//       dims unchanged (a pointwise op neither moves nor adds dimensions).
+template <Tensor (*Op)(const Tensor&)>
+Tensor unary_pointwise_batching_rule(const Tensor& input) {
+  auto* input_batched = unsafeGetBatchedImpl(input);
+  auto output_physical = Op(input_batched->value());
+  auto old_bdims = input_batched->bdims();
+  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
+}
+
TORCH_LIBRARY_IMPL(_, Batched, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
@@ -268,6 +276,40 @@
m.impl("unsqueeze", unsqueeze_batching_rule);
m.impl("view", view_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd
+
+  // Unary pointwise, out-of-place ops with no additional arguments.
+  // `#op` stringifies the operator name for the registration key, and
+  // `at::op` selects the (const Tensor&) -> Tensor overload required by
+  // unary_pointwise_batching_rule's template parameter.
+#define UNARY_POINTWISE(op) m.impl(#op, unary_pointwise_batching_rule<at::op>);
+  UNARY_POINTWISE(abs);
+  UNARY_POINTWISE(acos);
+  UNARY_POINTWISE(asin);
+  UNARY_POINTWISE(atan);
+  UNARY_POINTWISE(ceil);
+  UNARY_POINTWISE(cos);
+  UNARY_POINTWISE(cosh);
+  UNARY_POINTWISE(digamma);
+  UNARY_POINTWISE(exp);
+  UNARY_POINTWISE(expm1);
+  UNARY_POINTWISE(floor);
+  UNARY_POINTWISE(frac);
+  UNARY_POINTWISE(lgamma);
+  UNARY_POINTWISE(log);
+  UNARY_POINTWISE(log10);
+  UNARY_POINTWISE(log1p);
+  UNARY_POINTWISE(log2);
+  UNARY_POINTWISE(neg);
+  UNARY_POINTWISE(reciprocal);
+  UNARY_POINTWISE(relu);
+  UNARY_POINTWISE(round);
+  UNARY_POINTWISE(rsqrt);
+  UNARY_POINTWISE(sigmoid);
+  UNARY_POINTWISE(sign);
+  UNARY_POINTWISE(sin);
+  UNARY_POINTWISE(sinh);
+  UNARY_POINTWISE(sqrt);
+  UNARY_POINTWISE(tan);
+  UNARY_POINTWISE(tanh);
+  UNARY_POINTWISE(trunc);
+#undef UNARY_POINTWISE
}
} // namespace at
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 4831668..164905a 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -654,6 +654,79 @@
def _vmap_view_test(self, *args, **kwargs):
self._vmap_test(*args, **kwargs, check_view=True)
+    def _assert_doesnt_use_vmap_fallback(self, vmap_args, inputs):
+        # Fail if running `vmap(*vmap_args)(*inputs)` hits the slow
+        # for-loop-and-stack fallback (i.e. no batching rule fired).
+        regex = r'falling back to slow \(for loop and stack\) implementation'
+        with warnings.catch_warnings(record=True) as wa:
+            # Record every warning. Without this, a fallback warning that
+            # already fired earlier in the process is deduplicated by the
+            # warnings registry, `wa` stays empty, and the check below
+            # passes vacuously.
+            warnings.simplefilter('always')
+            vmap(*vmap_args)(*inputs)
+            for captured_warning in wa:
+                self.assertNotRegex(str(captured_warning.message), regex)
+    def test_assert_doesnt_use_vmap_fallback(self):
+        # Meta-test: the helper must actually raise when the op it runs
+        # does go through the vmap fallback path.
+        with self.assertRaises(AssertionError):
+            # One day we'll implement a batching rule for torch.var_mean.
+            # When that happens, please change the example to use an
+            # operator that doesn't have a batching rule implemented.
+            self._assert_doesnt_use_vmap_fallback([torch.var_mean], [torch.rand(3)])
+
+    def test_unary_pointwise_ops(self):
+        """Each unary pointwise op has a batching rule (no vmap fallback)
+        and produces correct results under single and nested vmap."""
+        def get_rand(size, device):
+            return [torch.rand(size, device=device)]
+
+        def get_randp1(size, device):
+            # Values in [1, 2): keeps log*, reciprocal, and rsqrt away from
+            # zero / negative inputs.
+            return [torch.rand(size, device=device) + 1]
+
+        def get_randn(size, device):
+            return [torch.randn(size, device=device)]
+
+        # (op, input generator) pairs; the generator picks a domain on
+        # which the op is well-defined.
+        cases = [
+            (torch.abs, get_randn),
+            (torch.acos, get_rand),
+            (torch.asin, get_rand),
+            (torch.atan, get_rand),
+            (torch.ceil, get_randn),
+            (torch.cos, get_rand),
+            (torch.cosh, get_rand),
+            (torch.digamma, get_rand),
+            (torch.exp, get_randn),
+            (torch.expm1, get_randn),
+            (torch.floor, get_randn),
+            (torch.frac, get_randn),
+            (torch.lgamma, get_rand),
+            (torch.log, get_randp1),
+            (torch.log10, get_randp1),
+            (torch.log1p, get_randp1),
+            (torch.log2, get_randp1),
+            (torch.neg, get_randn),
+            (torch.reciprocal, get_randp1),
+            (torch.relu, get_randn),
+            (torch.round, get_randn),
+            (torch.rsqrt, get_randp1),
+            (torch.sigmoid, get_randn),
+            (torch.sign, get_randn),
+            (torch.sin, get_rand),
+            (torch.sinh, get_rand),
+            (torch.sqrt, get_rand),
+            (torch.tan, get_rand),
+            (torch.tanh, get_rand),
+            (torch.trunc, get_randn),
+        ]
+        test = self._vmap_test
+        B0, B1 = 7, 11
+        device = 'cpu'  # loop-invariant; hoisted out of the loop below
+        for op, getter in cases:
+            self._assert_doesnt_use_vmap_fallback([op], getter([B0], device))
+
+            # Single vmap, various in_dims / out_dims
+            test(op, getter([B0, 3], device))
+            test(op, getter([2, 5, B0, 3], device), in_dims=2)
+            test(op, getter([2, 5, B0, 3], device), in_dims=2, out_dims=2)
+
+            # Doubly nested vmap
+            test(vmap(op), getter([B0, B1], device))
+            test(vmap(op), getter([B1, 2, 5, B0, 3], device), in_dims=2)
+            test(vmap(op, in_dims=2), getter([2, 5, B0, B1, 3], device), in_dims=2, out_dims=2)
+
def test_chunk(self):
test = self._vmap_view_test
op = torch.chunk