Batching rule for torch.squeeze(tensor) (#47632)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47632
This one is fun because we have to be careful not to squeeze out any of
the batch dims (it is the dims of the per-example tensor that are being squeezed).
Test Plan: - new tests
Reviewed By: anjali411
Differential Revision: D24859022
Pulled By: zou3519
fbshipit-source-id: 8adbd80963081efb683f62ea074a286a10da288f
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index d6de7a6..c30ddb6 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -215,6 +215,27 @@
return self_physical.newLogicalFromPhysical(result);
}
+Tensor squeeze_batching_rule(const Tensor& self) {
+ auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+ auto physical_sizes = self_physical.tensor().sizes();
+
+ // Don't squeeze the batch dims!
+ VmapDimVector squeezed_sizes;
+ int64_t num_batch_dims = self_physical.numBatchDims();
+ squeezed_sizes.insert(
+ squeezed_sizes.end(),
+ physical_sizes.begin(),
+ physical_sizes.begin() + num_batch_dims);
+ for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
+ if (*it != 1) {
+ squeezed_sizes.push_back(*it);
+ }
+ }
+
+ auto result = self_physical.tensor().view(squeezed_sizes);
+ return self_physical.newLogicalFromPhysical(result);
+}
+
Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
@@ -950,6 +971,7 @@
m.impl("slice.Tensor", slice_batching_rule);
m.impl("split.Tensor", split_batching_rule);
m.impl("split_with_sizes", split_with_sizes_batching_rule);
+ m.impl("squeeze", squeeze_batching_rule);
m.impl("squeeze.dim", squeeze_dim_batching_rule);
m.impl("t", native::t); // composite wrt autograd
m.impl("transpose.int", transpose_int_batching_rule);
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 9f57258..5f67498 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -1705,6 +1705,18 @@
test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
(torch.rand(3, 5, B0, B1, B2),), in_dims=2)
+ def test_squeeze(self):
+ test = self._vmap_view_test
+ op = torch.squeeze
+ B0, B1 = 1, 11
+ test(op, (torch.rand(B0),))
+ test(op, (torch.rand(B0, 3, 5),))
+ test(op, (torch.rand(1, B0, 5),), in_dims=1)
+ test(op, (torch.rand(B0, 0, 1, 5, 1),))
+ test(op, (torch.rand(B0, 1, 1, 1, 1),))
+ test(vmap(op), (torch.rand(B0, B1, 1),))
+ test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
+
def test_sum_dim(self):
test = self._vmap_test
B0, B1 = 5, 7