Batching rule for torch.squeeze(tensor) (#47632)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47632

This one is fun because we have to be careful not to squeeze out any of
the batch dims: only the size-1 dims of the per-example tensor should be
squeezed, never the dims that vmap itself added (see the example below).
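
For illustration, a quick sketch of the intended semantics, assuming the
prototype torch.vmap API (exposed as torch.func.vmap in newer releases):

    import torch

    # Per-example tensors have shape (1, 5, 1); vmap adds a batch dim of
    # size 2 in front. squeeze must only touch the per-example dims.
    x = torch.rand(2, 1, 5, 1)
    out = torch.vmap(torch.squeeze)(x)
    print(out.shape)  # torch.Size([2, 5])

    # Even a size-1 batch dim must survive; squeezing it away would change
    # the number of examples vmap sees.
    y = torch.rand(1, 3, 1)
    print(torch.vmap(torch.squeeze)(y).shape)  # torch.Size([1, 3])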

Test Plan: - new tests

Reviewed By: anjali411

Differential Revision: D24859022

Pulled By: zou3519

fbshipit-source-id: 8adbd80963081efb683f62ea074a286a10da288f
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index d6de7a6..c30ddb6 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -215,6 +215,27 @@
   return self_physical.newLogicalFromPhysical(result);
 }
 
+Tensor squeeze_batching_rule(const Tensor& self) {
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto physical_sizes = self_physical.tensor().sizes();
+
+  // Don't squeeze the batch dims!
+  VmapDimVector squeezed_sizes;
+  int64_t num_batch_dims = self_physical.numBatchDims();
+  squeezed_sizes.insert(
+      squeezed_sizes.end(),
+      physical_sizes.begin(),
+      physical_sizes.begin() + num_batch_dims);
+  for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
+    if (*it != 1) {
+      squeezed_sizes.push_back(*it);
+    }
+  }
+
+  auto result = self_physical.tensor().view(squeezed_sizes);
+  return self_physical.newLogicalFromPhysical(result);
+}
+
 Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   auto dim_physical = self_physical.getPhysicalDim(dim);
@@ -950,6 +971,7 @@
   m.impl("slice.Tensor", slice_batching_rule);
   m.impl("split.Tensor", split_batching_rule);
   m.impl("split_with_sizes", split_with_sizes_batching_rule);
+  m.impl("squeeze", squeeze_batching_rule);
   m.impl("squeeze.dim", squeeze_dim_batching_rule);
   m.impl("t", native::t); // composite wrt autograd
   m.impl("transpose.int", transpose_int_batching_rule);
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 9f57258..5f67498 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -1705,6 +1705,18 @@
         test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
              (torch.rand(3, 5, B0, B1, B2),), in_dims=2)
 
+    def test_squeeze(self):
+        test = self._vmap_view_test
+        op = torch.squeeze
+        B0, B1 = 1, 11
+        test(op, (torch.rand(B0),))
+        test(op, (torch.rand(B0, 3, 5),))
+        test(op, (torch.rand(1, B0, 5),), in_dims=1)
+        test(op, (torch.rand(B0, 0, 1, 5, 1),))
+        test(op, (torch.rand(B0, 1, 1, 1, 1),))
+        test(vmap(op), (torch.rand(B0, B1, 1),))
+        test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
+
     def test_sum_dim(self):
         test = self._vmap_test
         B0, B1 = 5, 7
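
As a sanity check on the nested case exercised by the new tests (again
assuming the prototype torch.vmap): both levels of batch dims are kept, and
only the per-example size-1 dim is squeezed away:

    import torch

    z = torch.rand(11, 1, 1)  # batch dims 11 and 1; per-example shape (1,)
    out = torch.vmap(torch.vmap(torch.squeeze))(z)
    print(out.shape)  # torch.Size([11, 1]); each per-example result is a scalar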