squeeze: allow squeezing multiple dimensions at once (#89017)

Ref #70924

This addresses part 1 of the issue, allowing `torch.squeeze` to be
passed a tuple of dimensions. For example,
```python
x.squeeze(0).squeeze(0)
```
can now be written
```python
x.squeeze((0, 1))
```
(assuming the first two dimensions of `x` are both of size 1)
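
A minimal sketch of the new behaviour (illustrative only, not part of the diff below; the shapes are made up). Dimensions listed in the tuple that do not have size 1 are left untouched, just like the existing single-`dim` overload:
```python
import torch

x = torch.zeros(1, 1, 3)

# Both listed dims have size 1, so both are removed.
assert x.squeeze((0, 1)).shape == (3,)

# Dim 2 has size 3, so it is a no-op; only dim 0 is removed.
assert x.squeeze((0, 2)).shape == (1, 3)
```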

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89017
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index 17f13f2..8a68503 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -3,6 +3,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/ExpandUtils.h>
+#include <ATen/WrapDimUtilsMulti.h>
 
 #include <utility>
 namespace at {
@@ -42,18 +43,26 @@
   return result;
 }
 
-Tensor unsqueeze_copy_to(const Tensor & self, int64_t dim, c10::SymIntArrayRef sizes, bool reapply_views) {
-  dim = at::maybe_wrap_dim(dim, sizes.size());
+Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, bool reapply_views) {
+  const auto ndim = sizes.size();
+  const auto mask = at::dim_list_to_bitset(dim, ndim);
   // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
   // unsqueezing in the backward.
-  if (sizes.size() > 0 && sizes[dim] == 1) {
-    if (reapply_views) {
-      return at::unsqueeze(self, dim);
-    } else {
-      return at::unsqueeze_copy(self, dim);
+  if (ndim == 0) {
+    return self;
+  }
+
+  Tensor result = self;
+  for (const auto d : c10::irange(ndim)) {
+    if (mask.test(d) && sizes[d] == 1) {
+      if (reapply_views) {
+        result = at::unsqueeze(result, d);
+      } else {
+        result = at::unsqueeze_copy(result, d);
+      }
     }
   }
-  return self;
+  return result;
 }
 
 // Note [Functionalization Pass: View Inverses].
@@ -215,6 +224,10 @@
     return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
 }
 
+Tensor FunctionalInverses::squeeze_copy_dims_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, IntArrayRef dim) {
+    return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
+}
+
 Tensor FunctionalInverses::t_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
     if (reapply_views) {
       return at::t(mutated_view);
diff --git a/aten/src/ATen/LegacyBatchingRegistrations.cpp b/aten/src/ATen/LegacyBatchingRegistrations.cpp
index 445c4a0..77c6410 100644
--- a/aten/src/ATen/LegacyBatchingRegistrations.cpp
+++ b/aten/src/ATen/LegacyBatchingRegistrations.cpp
@@ -296,6 +296,13 @@
   return self_physical.getPhysicalToLogicalMap().apply(result);
 }
 
+Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dims_physical = self_physical.getPhysicalDims(dims);
+  auto result = self_physical.tensor().squeeze(dims_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
 Tensor trace_batching_rule(const Tensor& self) {
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   // Batched Diagonal View
@@ -1116,6 +1123,7 @@
   m.impl("split_with_sizes", split_with_sizes_batching_rule);
   m.impl("squeeze", squeeze_batching_rule);
   m.impl("squeeze.dim", squeeze_dim_batching_rule);
+  m.impl("squeeze.dims", squeeze_dims_batching_rule);
   m.impl("t", native::t); // composite wrt autograd
   m.impl("trace", trace_batching_rule);
   m.impl("transpose.int", transpose_int_batching_rule);
diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp
index 13d5ddb..9e0fc0f 100644
--- a/aten/src/ATen/NamedTensorUtils.cpp
+++ b/aten/src/ATen/NamedTensorUtils.cpp
@@ -241,6 +241,20 @@
   return outnames;
 }
 
+std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor, std::bitset<dim_bitset_size> dims) {
+  if (!tensor.has_names()) {
+    return {};
+  }
+  std::vector<Dimname> outnames;
+  auto tensor_names = tensor.names();
+  for (const auto d : c10::irange(tensor.dim())) {
+    if (!dims.test(d) || tensor.sym_sizes()[d] != 1) {
+      outnames.push_back(tensor_names[d]);
+    }
+  }
+  return outnames;
+}
+
 std::vector<Dimname> compute_diagonal_outnames(
     const Tensor& tensor,
     int64_t dim1,
diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h
index c9ff27c..5051cf1 100644
--- a/aten/src/ATen/NamedTensorUtils.h
+++ b/aten/src/ATen/NamedTensorUtils.h
@@ -1,6 +1,7 @@
 #pragma once
 #include <ATen/NamedTensor.h>
 #include <ATen/TensorNames.h>
+#include <ATen/WrapDimUtilsMulti.h>
 
 #include <ATen/core/DimVector.h>
 #include <ATen/core/Tensor.h>
@@ -144,6 +145,9 @@
     const Tensor& other);
 
 TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
+TORCH_API std::vector<Dimname> compute_squeeze_outnames(
+    const Tensor& tensor,
+    std::bitset<dim_bitset_size> dims);
 
 std::vector<Dimname> compute_diagonal_outnames(
     const Tensor& tensor,
diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp
index 1cedd7d..c171fcc 100644
--- a/aten/src/ATen/functorch/BatchRulesViews.cpp
+++ b/aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -239,34 +239,41 @@
   return std::make_tuple(result, c10::optional<int64_t>(new_batch_idx));
 }
 
-std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+std::tuple<Tensor, optional<int64_t>> squeeze_dims_batch_rule(
+    const Tensor& self, optional<int64_t> bdim, IntArrayRef dims) {
   TORCH_INTERNAL_ASSERT(bdim.has_value());
   // Special case for scalar arrays to replicate PyTorch behavior.
-  if (self.dim() == 1) {
-    TORCH_CHECK(dim == 0, "Dimension is out of range (expected to be in range of [-1, 0], but got ", dim);
+  auto ndim = self.dim();
+  if (ndim == 1) {
+    TORCH_CHECK(
+        dims.size() == 0 || (dims.size() == 1 && dims[0] == 0),
+        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
     return std::make_tuple(self.alias(), bdim);
   }
 
-  // Calculate the proper offset if dim is negative.
-  auto actual_dim = dim;
-  if (dim < 0) {
-    actual_dim = self.dim() + dim - 1;
-  }
-  if (actual_dim < bdim) {
-    // Since dimension to be squeezed is before the batch dimension pass as-is.
-    auto original_size = self.dim();
-    auto result = self.squeeze(actual_dim);
-    auto updated_batch_idx = *bdim;
-    if (result.dim() != original_size) {
-      // A column before batch dimension has been dropped so adjust accordingly.
-      --updated_batch_idx;
+  // Adjust any dimensions higher than the batch dimension
+  DimVector adjusted_dims(dims.begin(), dims.end());
+  int64_t updated_batch_idx = *bdim;
+  for (auto &d : adjusted_dims) {
+    auto actual_dim = c10::maybe_wrap_dim(d, ndim - 1);
+    if (actual_dim < *bdim) {
+      d = actual_dim;
+      if (self.sym_size(actual_dim) == 1) {
+        // A column before batch dimension will be dropped so adjust accordingly.
+        --updated_batch_idx;
+      }
+    } else {
+      // Since dimension to be squeezed is after the batch dimension adjust by one to account
+      // for the original batch dimension. In this case batch dimension won't move.
+      d = actual_dim + 1;
     }
-    return std::make_tuple(result, optional<int64_t>(updated_batch_idx));
-  } else {
-    // Since dimension to be squeezed is after the batch dimension adjust by one to account
-    // for the original batch dimension. In this case batch dimension won't move.
-    return std::make_tuple(self.squeeze(actual_dim + 1), bdim);
   }
+  return std::make_tuple(self.squeeze(adjusted_dims), optional<int64_t>(updated_batch_idx));
+}
+
+std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(
+    const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+  return squeeze_dims_batch_rule(self, bdim, {dim});
 }
 
 std::tuple<std::vector<Tensor>, optional<int64_t>> chunk_batching_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t chunks, int64_t dim) {
@@ -548,6 +555,7 @@
   VMAP_SUPPORT2(select, int, select_batching_rule);
   VMAP_SUPPORT(squeeze, squeeze_batch_rule);
   VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule);
+  VMAP_SUPPORT2(squeeze, dims, squeeze_dims_batch_rule);
   VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule);
   VMAP_SUPPORT(roll, roll_batch_rule);
   VMAP_SUPPORT(permute, permute_batching_rule);
diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
index c50f124..d9f6ed2 100644
--- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
+++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
@@ -144,40 +144,52 @@
   return result;
 }
 
-Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
+Tensor& squeeze_dims__batching_rule(Tensor& self, IntArrayRef dims) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    return self.squeeze_(dim);
+    return self.squeeze_(dims);
   }
   auto* batched = maybeGetBatchedImpl(self);
   const auto bdim = batched->bdim();
   auto logical_dim = self.dim();
 
-  // If logically a scalar tensor, then Tensor.squeeze_(dim) is a no-op
   if (logical_dim == 0) {
+    TORCH_CHECK(
+        dims.size() == 0 || (dims.size() == 1 && dims[0] == 0),
+        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
     return self;
   }
 
-  dim = maybe_wrap_dim(dim, logical_dim);
-  if (dim >= bdim) {
-    dim = dim + 1;
-    batched->value().squeeze_(dim);
-    batched->refreshTensorMetadata();
-    return self;
+  // Adjust any dimensions higher than the batch dimension
+  DimVector adjusted_dims(dims.begin(), dims.end());
+  int64_t updated_batch_idx = bdim;
+  for (auto &d : adjusted_dims) {
+    auto actual_dim = c10::maybe_wrap_dim(d, logical_dim);
+    if (actual_dim < bdim) {
+      d = actual_dim;
+      if (batched->value().sym_size(actual_dim) == 1) {
+        // A column before batch dimension will be dropped so adjust accordingly.
+        --updated_batch_idx;
+      }
+    } else {
+      // Since dimension to be squeezed is after the batch dimension adjust by one to account
+      // for the original batch dimension. In this case batch dimension won't move.
+      d = actual_dim + 1;
+    }
   }
 
-  // Tensor.squeeze_(0) is a no-op if dim 0 has a size other than 1
-  if (batched->value().size(dim) != 1) {
-    return self;
+  batched->value().squeeze_(adjusted_dims);
+  if (updated_batch_idx != bdim) {
+    batched->unsafe_set_bdim(updated_batch_idx);
   }
-
-  // dim < bdim, so we need to adjust bdim
-  batched->value().squeeze_(dim);
-  batched->unsafe_set_bdim(bdim - 1);
   batched->refreshTensorMetadata();
   return self;
 }
 
+Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
+  return squeeze_dims__batching_rule(self, {dim});
+}
+
 Tensor& squeeze__batching_rule(Tensor& self) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
@@ -816,6 +828,7 @@
   // still legacy b/c needs special inplace rules
   m.impl("squeeze_", squeeze__batching_rule);
   m.impl("squeeze_.dim", squeeze_dim__batching_rule);
+  m.impl("squeeze_.dims", squeeze_dims__batching_rule);
   m.impl("unsqueeze_", unsqueeze__batching_rule);
   m.impl("transpose_", transpose__batching_rule);
 
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 05a90e3..cc33bda 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -90,6 +90,7 @@
 #include <ATen/ops/slice.h>
 #include <ATen/ops/special_logsumexp_native.h>
 #include <ATen/ops/sqrt.h>
+#include <ATen/ops/squeeze.h>
 #include <ATen/ops/stack.h>
 #include <ATen/ops/std.h>
 #include <ATen/ops/std_mean.h>
@@ -1381,23 +1382,11 @@
   return at::nansum(self, dim, keepdim, opt_dtype).div(factor);
 }
 
-static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
-  int ndims = self.sizes().size();
-  auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
-  Tensor result = self;
-  for (int i = ndims - 1; i >= 0; --i) {
-    if (dims_to_squeeze[i]) {
-      result = result.squeeze(i);
-    }
-  }
-  return result;
-}
-
 static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
   // can't take max of empty tensor
   if (self.numel() != 0) {
     auto maxes = at::amax(self, dims, true);
-    auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
+    auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims));
     maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
     at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
     result.log_().add_(maxes_squeezed);
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 7634796..2127968 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -162,6 +162,7 @@
 #include <ATen/ops/split_with_sizes_native.h>
 #include <ATen/ops/squeeze_copy_native.h>
 #include <ATen/ops/squeeze_native.h>
+#include <ATen/ops/squeeze.h>
 #include <ATen/ops/stack_native.h>
 #include <ATen/ops/sub.h>
 #include <ATen/ops/sum.h>
@@ -3095,6 +3096,22 @@
   return std::make_tuple(std::move(sizes), std::move(strides));
 }
 
+std::tuple<SymDimVector, SymDimVector>
+inferSqueezeGeometry(const Tensor &tensor, std::bitset<dim_bitset_size> dim_mask) {
+  const auto ndim = tensor.dim();
+  const auto sym_sizes = tensor.sym_sizes();
+  const auto sym_strides = tensor.sym_strides();
+
+  SymDimVector out_sizes, out_strides;
+  for (const auto d: c10::irange(ndim)) {
+    if (!dim_mask.test(d) || sym_sizes[d] != 1) {
+      out_sizes.push_back(sym_sizes[d]);
+      out_strides.push_back(sym_strides[d]);
+    }
+  }
+  return std::make_tuple(std::move(out_sizes), std::move(out_strides));
+}
+
 namespace {
 // Named type instead of a pair/tuple so that we can be sure to
 // construct the vectors in place and get NRVO.
@@ -3117,18 +3134,21 @@
 }
 
 // dim is present if squeezing a single dimension and absent if squeezing all dimensions
-Tensor squeeze_qtensor(const Tensor& self, c10::optional<int64_t> dim) {
+Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) {
   auto quantizer = get_qtensorimpl(self)->quantizer();
   SymDimVector sizes;
   SymDimVector strides;
-  std::tie(sizes, strides) = dim.has_value() ? inferSqueezeGeometry(self, dim.value()) : inferSqueezeGeometry(self);
+  const auto ndim = self.dim();
+  auto mask = dims.has_value()
+      ? dim_list_to_bitset(dims, self.dim())
+      : std::bitset<dim_bitset_size>((1ull << self.dim()) - 1);
+  std::tie(sizes, strides) = inferSqueezeGeometry(self, mask);
   if (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE) {
     const auto* per_channel_quantizer = static_cast<at::PerChannelAffineQuantizer*>(quantizer.get());
     auto axis = per_channel_quantizer->axis();
     int64_t shift = 0;
-    integer_range<int64_t> dims = dim.has_value() ? integer_range<int64_t>{dim.value(), dim.value() + 1} : c10::irange(0, self.dim());
-    for (const auto d : dims) {
-      if (self.sizes()[d] == 1) {
+    for (const auto d : c10::irange(ndim)) {
+      if (mask.test(d) && self.sizes()[d] == 1) {
         TORCH_CHECK(axis != d, "Squeeze is only possible on non-axis dimension for Per-Channel Quantized Tensors.");
         if (d < axis) {
           ++shift;
@@ -3144,13 +3164,8 @@
   // TODO: quantized Tensor support for SymInt needs to be added but basic building blocs
   // are missing for now.
   auto result = make_qtensor(self, C10_AS_INTARRAYREF_SLOW(sizes), C10_AS_INTARRAYREF_SLOW(strides), std::move(quantizer));
-  if (dim.has_value()) {
-    namedinference::propagate_names_except(result, self, {dim.value()});
-  } else {
-    auto maybe_outnames = namedinference::compute_squeeze_outnames(self);
-    namedinference::propagate_names_if_nonempty(result, maybe_outnames);
-  }
-
+  auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
+  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
   return result;
 }
 
@@ -3163,10 +3178,7 @@
 }
 
 Tensor squeeze_quantized(const Tensor& self) {
-  at::Tensor result = squeeze_qtensor(self, c10::nullopt);
-  auto maybe_outnames = namedinference::compute_squeeze_outnames(self);
-  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
-  return result;
+  return squeeze_qtensor(self, c10::nullopt);
 }
 
 Tensor squeeze(const Tensor& self, int64_t dim) {
@@ -3182,8 +3194,19 @@
 }
 
 Tensor squeeze_quantized(const Tensor& self, int64_t dim) {
-  int64_t dims = self.dim();
-  dim = maybe_wrap_dim(dim, dims);
+  return squeeze_qtensor(self, dim);
+}
+
+Tensor squeeze(const Tensor& self, IntArrayRef dims) {
+  auto mask = dim_list_to_bitset(dims, self.dim());
+  auto g = inferSqueezeGeometry(self, mask);
+  at::Tensor result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
+  auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
+  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
+  return result;
+}
+
+Tensor squeeze_quantized(const Tensor& self, IntArrayRef dim) {
   return squeeze_qtensor(self, dim);
 }
 
@@ -3206,6 +3229,13 @@
   return self;
 }
 
+Tensor & squeeze_(Tensor &self, IntArrayRef dims) {
+  auto mask = dim_list_to_bitset(dims, self.dim());
+  auto g = inferSqueezeGeometry(self, mask);
+  self.as_strided__symint(std::get<0>(g), std::get<1>(g));
+  return self;
+}
+
 // NOTE [ Unsafe View ]
 // _unsafe_view() differs from view() in that the returned tensor isn't treated
 // as a view for the purposes of automatic differentiation. (It's not listed in
@@ -4044,6 +4074,13 @@
 }
 
 
+at::Tensor& squeeze_copy_dims_out(const at::Tensor & self, IntArrayRef dims, at::Tensor & out) {
+  auto tmp = self.squeeze(dims);
+  out.copy_(tmp);
+  return out;
+}
+
+
 at::Tensor& t_copy_out(const at::Tensor & self, at::Tensor & out) {
   auto tmp = self.t();
   out.copy_(tmp);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5f3e1ab..d52c8df 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5212,6 +5212,17 @@
   device_check: NoCheck
   device_guard: False
 
+
+- func: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: squeeze
+    QuantizedCPU, QuantizedCUDA: squeeze_quantized
+    NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested
+  tags: canonical
+
 - func: squeeze_(Tensor(a!) self) -> Tensor(a!)
   variants: method
   device_check: NoCheck
@@ -5228,6 +5239,14 @@
   dispatch:
     CompositeExplicitAutograd: squeeze_
 
+- func: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: squeeze_
+
 - func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)
   variants: method
   device_check: NoCheck
@@ -13600,6 +13619,12 @@
     CompositeExplicitAutogradNonFunctional: squeeze_copy_dim
   tags: view_copy
 
+- func: squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: squeeze_copy_dims
+  tags: view_copy
+
 - func: t_copy(Tensor self) -> Tensor
   variants: function
   dispatch:
@@ -13812,6 +13837,12 @@
     CompositeExplicitAutograd: squeeze_copy_dim_out
 
 
+- func: squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: squeeze_copy_dims_out
+
+
 - func: t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 5842c3b..287e861 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -9,6 +9,7 @@
 #include <ATen/TensorIndexing.h>
 #include <ATen/TensorOperators.h>
 #include <ATen/TensorUtils.h>
+#include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/layer_norm.h>
 #include <ATen/native/nested/NestedTensorUtils.h>
@@ -682,37 +683,53 @@
   return self;
 }
 
-Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
+Tensor squeeze_dim_nested(const Tensor& self, IntArrayRef dims) {
   auto self_ptr = get_nested_tensor_impl(self);
   int64_t ndim = self_ptr->dim();
-  int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim);
-  TORCH_CHECK(wrapped_dim > 0,
+  auto mask = at::dim_list_to_bitset(dims, ndim);
+  TORCH_CHECK(!mask.test(0),
   "squeeze(): For nested tensors, squeezing dimension 0 is not supported at the moment ",
   "if you need this feature, please open an issue on github describing your use case.");
   const Tensor& sizemat = self_ptr->get_nested_size_tensor();
   const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
   // if tensor.size(dim) != 1 torch.squeeze will return the result, we do the same here
-  c10::optional<int64_t> size_dim = self_ptr->opt_size(dim);
-  if (!(size_dim.has_value() && size_dim.value() == 1)) {
+  for (const auto d : c10::irange(ndim)) {
+    if (mask.test(d)) {
+      c10::optional<int64_t> size_dim = self_ptr->opt_size(d);
+      if (!(size_dim.has_value() && *size_dim == 1)) {
+        mask.reset(d);
+      }
+    }
+  }
+
+  if (!mask.any()) {
     // detach to avoid triggering throw_error_if_base_and_tensor_are_same
     return self.detach();
   }
   // if ndim == 2 and we pass the above if statement we should have a
   // nested tensor of singleton tensors
-  TORCH_CHECK(ndim != 2,
+  TORCH_CHECK(ndim > static_cast<int64_t>(1 + dims.size()),
   "squeeze(): For nested tensors, squeezing a nested tensor of singleton tensors is not ",
   "supported at the moment, if you need this feature, please open an issue on github",
   "describing your use case.");
-  auto column_indices = sizemat.new_empty(ndim - 2);
+  const auto new_ndim = ndim - mask.count();
+  auto column_indices = sizemat.new_empty(new_ndim - 1);
   int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
-  std::iota(column_indices_ptr, column_indices_ptr + wrapped_dim - 1, 0);
-  std::iota(column_indices_ptr + wrapped_dim - 1, column_indices_ptr + ndim - 2, wrapped_dim);
+  for (const auto d : c10::irange(1, ndim)) {
+    if (!mask.test(d)) {
+      *column_indices_ptr++ = d - 1;
+    }
+  }
   auto sizemat_squeezed = at::index_select(sizemat, 1, column_indices);
   auto stridemat_squeezed = at::index_select(stridemat, 1, column_indices);
   return create_nested_view_tensor(
       self, sizemat_squeezed, stridemat_squeezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
 }
 
+Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
+  return squeeze_dim_nested(self, IntArrayRef{dim});
+}
+
 Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
   auto self_ptr = get_nested_tensor_impl(self);
   int64_t ndim = self_ptr->dim();
diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml
index f4c3ee8..85ac57e 100644
--- a/aten/src/ATen/native/ts_native_functions.yaml
+++ b/aten/src/ATen/native/ts_native_functions.yaml
@@ -157,6 +157,7 @@
   #- unbind_copy.int
   - squeeze_copy
   - squeeze_copy.dim
+  - squeeze_copy.dims
   - t_copy
   - transpose_copy.int
   - unsqueeze_copy
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 3c59bfc..67f6a44 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3322,7 +3322,7 @@
 
     def vmap_outplace_test(self, func, args, kwargs, in_dims, check_shape_only=False,
                            postprocess_fn=None):
-        for loop_out, vmap_out in compute_quantities_for_vmap_test(func, args, kwargs, in_dims):
+        for vmap_out, loop_out in compute_quantities_for_vmap_test(func, args, kwargs, in_dims):
             if postprocess_fn is not None:
                 loop_out = postprocess_fn(loop_out)
                 vmap_out = postprocess_fn(vmap_out)
@@ -3343,7 +3343,7 @@
                         func, args, kwargs, in_dims, compute_loop_out=False, clone_inputs=True):
                     pass
             return
-        for loop_out, vmap_out in compute_quantities_for_vmap_test(
+        for vmap_out, loop_out in compute_quantities_for_vmap_test(
                 func, args, kwargs, in_dims, clone_inputs=True):
             if postprocess_fn is not None:
                 loop_out = postprocess_fn(loop_out)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 10a0080..a69b8b4 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1506,6 +1506,14 @@
     AutogradNestedTensor:
       self: grad.unsqueeze(dim)
 
+- name: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+  dispatch:
+    Default:
+      self: unsqueeze_to(grad, dim, self.sym_sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: unsqueeze_multiple(grad, dim, self.dim())
+
 - name: squeeze_(Tensor(a!) self) -> Tensor(a!)
   self: unsqueeze_to(grad, self.sym_sizes())
   result: auto_linear
@@ -1514,6 +1522,10 @@
   self: unsqueeze_to(grad, dim, self.sym_sizes())
   result: auto_linear
 
+- name: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+  self: unsqueeze_to(grad, dim, self.sym_sizes())
+  result: auto_linear
+
 - name: std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor
   self: std_backward(result, grad, self, dim, correction, keepdim)
   # pointwise (variance) + sum + sqrt
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index e6dd438..c3aa8bf 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -12,6 +12,7 @@
 import torch.utils._pytree as pytree
 from torch import _prims_common
 from torch._prims_common import (
+    canonicalize_dims,
     dtype_to_type,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
@@ -488,16 +489,17 @@
     assert isinstance(x, TensorBox)
     if dim is None:
         return TensorBox(SqueezeView.create(x.data))
-    offset = len(x.get_size()) == 0
-    dim = _validate_dim(x, dim, offset)
-    new_shape = list(x.get_size())
-    if len(new_shape) > 0:
-        removed = new_shape.pop(dim)
-        if V.graph.sizevars.maybe_guard_equals(removed, 1):
-            return view(x, new_shape)
 
+    dim = canonicalize_dims(len(x.get_size()), dim)
+    dims = set((dim,) if not isinstance(dim, tuple) else dim)
+
+    new_shape = [
+        s
+        for d, s in enumerate(x.get_size())
+        if not (d in dims and V.graph.sizevars.maybe_guard_equals(s, 1))
+    ]
     # squeeze does nothing if the size isn't 1
-    return x
+    return view(x, new_shape) if new_shape != x.get_size() else x
 
 
 @register_lowering([aten.squeeze_])
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 52aa4b8..f0c07dd 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -1706,13 +1706,6 @@
     return a.as_strided(new_shape, new_strides, a.storage_offset())
 
 
-def _squeeze_aten(a: Tensor, dimensions: Sequence) -> Tensor:
-    for idx in reversed(sorted(dimensions)):
-        a = torch.squeeze(a, dim=idx)
-
-    return a
-
-
 _squeeze_doc = """
   Creates a view of the tensor with the specified dimensions removed.
 
@@ -1722,7 +1715,7 @@
 squeeze = _make_prim(
     schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
     meta=_squeeze_meta,
-    impl_aten=_squeeze_aten,
+    impl_aten=torch.squeeze,
     return_type=RETURN_TYPE.VIEW,
     doc=_squeeze_doc,
 )
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index e792e35..e581de0 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -3513,21 +3513,22 @@
 
 
 @register_decomposition(aten.squeeze)
-def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType:
-    if dim is not None:
-        dim = utils.canonicalize_dim(a.ndim, dim)
-        # Short-circuits if the tensor has no dimensions
-        if len(a.shape) == 0:
-            assert dim == 0
-            return prims.view_of(a)
+def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
+    if dim is None:
+        dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
+        return prims.squeeze(a, dims) if dims else prims.view_of(a)
 
-        # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
-        if a.shape[dim] != 1:
-            return prims.view_of(a)
-        return prims.squeeze(a, (dim,))
+    ndim = a.ndim
+    dim = utils.canonicalize_dims(ndim, dim)
+    dims = (dim,) if isinstance(dim, Dim) else dim
+    # Short-circuits if the tensor has no dimensions
+    if ndim == 0:
+        assert len(dims) == 0 or dims == (0,)
+        return prims.view_of(a)
 
-    dims = tuple(idx for idx in range(len(a.shape)) if a.shape[idx] == 1)
-    return prims.squeeze(a, dims)
+    # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
+    dims = tuple(d for d in dims if a.shape[d] == 1)
+    return prims.squeeze(a, dims) if dims else prims.view_of(a)
 
 
 # Note: does not work with TensorMetas because of data-dependent control-flow
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 42292d6..ee0254e 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -10640,14 +10640,14 @@
     r"""
 squeeze(input, dim=None) -> Tensor
 
-Returns a tensor with all the dimensions of :attr:`input` of size `1` removed.
+Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed.
 
 For example, if `input` is of shape:
-:math:`(A \times 1 \times B \times C \times 1 \times D)` then the `out` tensor
+:math:`(A \times 1 \times B \times C \times 1 \times D)` then `input.squeeze()`
 will be of shape: :math:`(A \times B \times C \times D)`.
 
 When :attr:`dim` is given, a squeeze operation is done only in the given
-dimension. If `input` is of shape: :math:`(A \times 1 \times B)`,
+dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`,
 ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)``
 will squeeze the tensor to the shape :math:`(A \times B)`.
 
@@ -10656,12 +10656,15 @@
 
 .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)`
           will also remove the batch dimension, which can lead to unexpected
-          errors.
+          errors. Consider specifying only the dims you wish to be squeezed.
 
 Args:
     {input}
-    dim (int, optional): if given, the input will be squeezed only in
-           this dimension
+    dim (int or tuple of ints, optional): if given, the input will be squeezed
+           only in the specified dimensions.
+
+        .. versionchanged:: 2.0
+           :attr:`dim` now accepts tuples of dimensions.
 
 Example::
 
@@ -10677,6 +10680,9 @@
     >>> y = torch.squeeze(x, 1)
     >>> y.size()
     torch.Size([2, 2, 1, 2])
+    >>> y = torch.squeeze(x, (1, 2, 3))
+    >>> y.size()
+    torch.Size([2, 2, 2])
 """.format(
         **common_args
     ),
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index c874eb8..1135a82 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -863,15 +863,25 @@
 
 Tensor unsqueeze_to(
     const Tensor& self,
+    IntArrayRef dims,
+    c10::SymIntArrayRef sym_sizes) {
+  const auto ndim = sym_sizes.size();
+  auto mask = at::dim_list_to_bitset(dims, ndim);
+
+  Tensor result = self;
+  for (const auto d : c10::irange(ndim)) {
+    if (mask.test(d) && sym_sizes[d] == 1) {
+      result = result.unsqueeze(d);
+    }
+  }
+  return result;
+}
+
+Tensor unsqueeze_to(
+    const Tensor& self,
     int64_t dim,
     c10::SymIntArrayRef sym_sizes) {
-  dim = at::maybe_wrap_dim(dim, sym_sizes.size());
-  // in NumPy it's not an error to unsqueeze a scalar, but we still need to
-  // avoided unsqueezing in the backward.
-  if (sym_sizes.size() > 0 && sym_sizes[dim] == 1) {
-    return self.unsqueeze(dim);
-  }
-  return self;
+  return unsqueeze_to(self, IntArrayRef{dim}, sym_sizes);
 }
 
 std::vector<Tensor> cat_tensors_backward(
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index ddeb08c..1279f0a 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -220,6 +220,10 @@
     const at::Tensor& self,
     int64_t dim,
     c10::SymIntArrayRef sym_sizes);
+at::Tensor unsqueeze_to(
+    const at::Tensor& self,
+    IntArrayRef dim,
+    c10::SymIntArrayRef sym_sizes);
 std::vector<at::Tensor> cat_tensors_backward(
     const at::Tensor& grad,
     const std::vector<std::vector<c10::SymInt>>& sizes,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c99eb43..32f11f9 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13,7 +13,7 @@
 import numpy as np
 from torch._six import inf, nan
 
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union, Sequence
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
@@ -4771,6 +4771,7 @@
     shapes_and_args = (
         ((S, 1, S, 1), ()),
         ((1, 1, 1, 1), ()),
+        ((1, 1, 1, 1), (0,)),
         ((S, 1, S, 1), (1,)),
         ((S, 1, S, 1), (-1,)),
         ((S, 1, S, 1), (2,)),
@@ -4785,6 +4786,37 @@
         yield SampleInput(tensor, args=args)
 
 
+def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs):
+    shapes_and_args = (
+        ((1, 1, 1, 1), ()),
+        ((S, 1, S, 1), (1,)),
+        ((S, 1, S, 1), (-1,)),
+        ((S, 1, S, 1), (1, 3)),
+        ((S, 1, S, 1), (1, 2,)),
+        ((), (0,)),
+    )
+
+    for shape, dims in shapes_and_args:
+        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
+                             requires_grad=requires_grad)
+
+        yield SampleInput(tensor, dims)
+
+
+def _squeeze_ref(x, axis=None):
+    # NumPy doesn't allow squeezing scalars
+    if x.ndim == 0:
+        return x
+
+    if isinstance(axis, Sequence):
+        # Numpy doesn't allow specifying non-singular dimensions
+        axis = tuple(a for a in axis if x.shape[a] == 1)
+
+    if isinstance(axis, int) and x.shape[axis] != 1:
+        return x
+
+    return np.squeeze(x, axis)
+
 def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
     assert mode in ('constant', 'reflect', 'replicate', 'circular')
     if mode in ['reflect', 'replicate']:
@@ -15654,6 +15686,7 @@
                       DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
                   )),
     OpInfo('squeeze',
+           ref=_squeeze_ref,
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
            supports_out=False,
            assert_autodiffed=True,
@@ -15667,6 +15700,21 @@
            # https://github.com/pytorch/pytorch/issues/66357
            check_batched_forward_grad=False,
            sample_inputs_func=sample_inputs_squeeze),
+    OpInfo('squeeze',
+           ref=_squeeze_ref,
+           variant_test_name="multiple",
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           assert_autodiffed=True,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_squeeze_multiple),
     UnaryUfuncInfo(
         'fill',
         ref=_fill_np,
@@ -19016,6 +19064,11 @@
         torch_opinfo_name="squeeze",
     ),
     PythonRefInfo(
+        "_refs.squeeze",
+        torch_opinfo_name="squeeze",
+        torch_opinfo_variant_name="multiple",
+    ),
+    PythonRefInfo(
         "_refs.tensor_split",
         torch_opinfo_name="tensor_split",
         skips=(