squeeze: allow squeezing multiple dimensions at once (#89017)
Ref #70924
This addresses part 1 of the issue by allowing `torch.squeeze` to be
passed a tuple of dimensions. For example,
```python
x.squeeze(0).squeeze(0)
```
can now be written
```python
x.squeeze((0, 1))
```
(assuming `x` has at least 2 dimensions)
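
A minimal usage sketch of the new call (shapes chosen purely for illustration): dimensions listed in the tuple that do not have size 1 are left untouched, matching the existing single-`dim` semantics, and negative indices are wrapped as usual.
```python
import torch

x = torch.randn(1, 3, 1, 4)

# Squeeze several dimensions in one call; negative indices are wrapped.
y = x.squeeze((0, -2))
print(y.shape)  # torch.Size([3, 4])

# Dimensions whose size is not 1 are silently left in place,
# mirroring the single-dim behavior of squeeze(dim).
z = x.squeeze((0, 1))
print(z.shape)  # torch.Size([3, 1, 4])
```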
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89017
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index 17f13f2..8a68503 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -3,6 +3,7 @@
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
+#include <ATen/WrapDimUtilsMulti.h>
#include <utility>
namespace at {
@@ -42,18 +43,26 @@
return result;
}
-Tensor unsqueeze_copy_to(const Tensor & self, int64_t dim, c10::SymIntArrayRef sizes, bool reapply_views) {
- dim = at::maybe_wrap_dim(dim, sizes.size());
+Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, bool reapply_views) {
+ const auto ndim = sizes.size();
+ const auto mask = at::dim_list_to_bitset(dim, ndim);
// in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
// unsqueezing in the backward.
- if (sizes.size() > 0 && sizes[dim] == 1) {
- if (reapply_views) {
- return at::unsqueeze(self, dim);
- } else {
- return at::unsqueeze_copy(self, dim);
+ if (ndim == 0) {
+ return self;
+ }
+
+ Tensor result = self;
+ for (const auto d : c10::irange(ndim)) {
+ if (mask.test(d) && sizes[d] == 1) {
+ if (reapply_views) {
+ result = at::unsqueeze(result, d);
+ } else {
+ result = at::unsqueeze_copy(result, d);
+ }
}
}
- return self;
+ return result;
}
// Note [Functionalization Pass: View Inverses].
@@ -215,6 +224,10 @@
return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
}
+Tensor FunctionalInverses::squeeze_copy_dims_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, IntArrayRef dim) {
+ return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
+}
+
Tensor FunctionalInverses::t_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
if (reapply_views) {
return at::t(mutated_view);
diff --git a/aten/src/ATen/LegacyBatchingRegistrations.cpp b/aten/src/ATen/LegacyBatchingRegistrations.cpp
index 445c4a0..77c6410 100644
--- a/aten/src/ATen/LegacyBatchingRegistrations.cpp
+++ b/aten/src/ATen/LegacyBatchingRegistrations.cpp
@@ -296,6 +296,13 @@
return self_physical.getPhysicalToLogicalMap().apply(result);
}
+Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
+ auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+ auto dims_physical = self_physical.getPhysicalDims(dims);
+ auto result = self_physical.tensor().squeeze(dims_physical);
+ return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
Tensor trace_batching_rule(const Tensor& self) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
// Batched Diagonal View
@@ -1116,6 +1123,7 @@
m.impl("split_with_sizes", split_with_sizes_batching_rule);
m.impl("squeeze", squeeze_batching_rule);
m.impl("squeeze.dim", squeeze_dim_batching_rule);
+ m.impl("squeeze.dims", squeeze_dims_batching_rule);
m.impl("t", native::t); // composite wrt autograd
m.impl("trace", trace_batching_rule);
m.impl("transpose.int", transpose_int_batching_rule);
diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp
index 13d5ddb..9e0fc0f 100644
--- a/aten/src/ATen/NamedTensorUtils.cpp
+++ b/aten/src/ATen/NamedTensorUtils.cpp
@@ -241,6 +241,20 @@
return outnames;
}
+std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor, std::bitset<dim_bitset_size> dims) {
+ if (!tensor.has_names()) {
+ return {};
+ }
+ std::vector<Dimname> outnames;
+ auto tensor_names = tensor.names();
+ for (const auto d : c10::irange(tensor.dim())) {
+ if (!dims.test(d) || tensor.sym_sizes()[d] != 1) {
+ outnames.push_back(tensor_names[d]);
+ }
+ }
+ return outnames;
+}
+
std::vector<Dimname> compute_diagonal_outnames(
const Tensor& tensor,
int64_t dim1,
diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h
index c9ff27c..5051cf1 100644
--- a/aten/src/ATen/NamedTensorUtils.h
+++ b/aten/src/ATen/NamedTensorUtils.h
@@ -1,6 +1,7 @@
#pragma once
#include <ATen/NamedTensor.h>
#include <ATen/TensorNames.h>
+#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/DimVector.h>
#include <ATen/core/Tensor.h>
@@ -144,6 +145,9 @@
const Tensor& other);
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
+TORCH_API std::vector<Dimname> compute_squeeze_outnames(
+ const Tensor& tensor,
+ std::bitset<dim_bitset_size> dims);
std::vector<Dimname> compute_diagonal_outnames(
const Tensor& tensor,
diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp
index 1cedd7d..c171fcc 100644
--- a/aten/src/ATen/functorch/BatchRulesViews.cpp
+++ b/aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -239,34 +239,41 @@
return std::make_tuple(result, c10::optional<int64_t>(new_batch_idx));
}
-std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+std::tuple<Tensor, optional<int64_t>> squeeze_dims_batch_rule(
+ const Tensor& self, optional<int64_t> bdim, IntArrayRef dims) {
TORCH_INTERNAL_ASSERT(bdim.has_value());
// Special case for scalar arrays to replicate PyTorch behavior.
- if (self.dim() == 1) {
- TORCH_CHECK(dim == 0, "Dimension is out of range (expected to be in range of [-1, 0], but got ", dim);
+ auto ndim = self.dim();
+ if (ndim == 1) {
+ TORCH_CHECK(
+ dims.size() == 0 || (dims.size() == 1 && dims[0] == 0),
+ "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
return std::make_tuple(self.alias(), bdim);
}
- // Calculate the proper offset if dim is negative.
- auto actual_dim = dim;
- if (dim < 0) {
- actual_dim = self.dim() + dim - 1;
- }
- if (actual_dim < bdim) {
- // Since dimension to be squeezed is before the batch dimension pass as-is.
- auto original_size = self.dim();
- auto result = self.squeeze(actual_dim);
- auto updated_batch_idx = *bdim;
- if (result.dim() != original_size) {
- // A column before batch dimension has been dropped so adjust accordingly.
- --updated_batch_idx;
+ // Adjust any dimensions higher than the batch dimension
+ DimVector adjusted_dims(dims.begin(), dims.end());
+ int64_t updated_batch_idx = *bdim;
+ for (auto &d : adjusted_dims) {
+ auto actual_dim = c10::maybe_wrap_dim(d, ndim - 1);
+ if (actual_dim < *bdim) {
+ d = actual_dim;
+ if (self.sym_size(actual_dim) == 1) {
+ // A column before batch dimension will be dropped so adjust accordingly.
+ --updated_batch_idx;
+ }
+ } else {
+ // Since dimension to be squeezed is after the batch dimension adjust by one to account
+ // for the original batch dimension. In this case batch dimension won't move.
+ d = actual_dim + 1;
}
- return std::make_tuple(result, optional<int64_t>(updated_batch_idx));
- } else {
- // Since dimension to be squeezed is after the batch dimension adjust by one to account
- // for the original batch dimension. In this case batch dimension won't move.
- return std::make_tuple(self.squeeze(actual_dim + 1), bdim);
}
+ return std::make_tuple(self.squeeze(adjusted_dims), optional<int64_t>(updated_batch_idx));
+}
+
+std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(
+ const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+ return squeeze_dims_batch_rule(self, bdim, {dim});
}
std::tuple<std::vector<Tensor>, optional<int64_t>> chunk_batching_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t chunks, int64_t dim) {
@@ -548,6 +555,7 @@
VMAP_SUPPORT2(select, int, select_batching_rule);
VMAP_SUPPORT(squeeze, squeeze_batch_rule);
VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule);
+ VMAP_SUPPORT2(squeeze, dims, squeeze_dims_batch_rule);
VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule);
VMAP_SUPPORT(roll, roll_batch_rule);
VMAP_SUPPORT(permute, permute_batching_rule);
diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
index c50f124..d9f6ed2 100644
--- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
+++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
@@ -144,40 +144,52 @@
return result;
}
-Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
+Tensor& squeeze_dims__batching_rule(Tensor& self, IntArrayRef dims) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
- return self.squeeze_(dim);
+ return self.squeeze_(dims);
}
auto* batched = maybeGetBatchedImpl(self);
const auto bdim = batched->bdim();
auto logical_dim = self.dim();
- // If logically a scalar tensor, then Tensor.squeeze_(dim) is a no-op
if (logical_dim == 0) {
+ TORCH_CHECK(
+ dims.size() == 0 || (dims.size() == 1 && dims[0] == 0),
+ "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
return self;
}
- dim = maybe_wrap_dim(dim, logical_dim);
- if (dim >= bdim) {
- dim = dim + 1;
- batched->value().squeeze_(dim);
- batched->refreshTensorMetadata();
- return self;
+ // Adjust any dimensions higher than the batch dimension
+ DimVector adjusted_dims(dims.begin(), dims.end());
+ int64_t updated_batch_idx = bdim;
+ for (auto &d : adjusted_dims) {
+ auto actual_dim = c10::maybe_wrap_dim(d, logical_dim);
+ if (actual_dim < bdim) {
+ d = actual_dim;
+ if (batched->value().sym_size(actual_dim) == 1) {
+ // A column before batch dimension will be dropped so adjust accordingly.
+ --updated_batch_idx;
+ }
+ } else {
+ // Since dimension to be squeezed is after the batch dimension adjust by one to account
+ // for the original batch dimension. In this case batch dimension won't move.
+ d = actual_dim + 1;
+ }
}
- // Tensor.squeeze_(0) is a no-op if dim 0 has a size other than 1
- if (batched->value().size(dim) != 1) {
- return self;
+ batched->value().squeeze_(adjusted_dims);
+ if (updated_batch_idx != bdim) {
+ batched->unsafe_set_bdim(updated_batch_idx);
}
-
- // dim < bdim, so we need to adjust bdim
- batched->value().squeeze_(dim);
- batched->unsafe_set_bdim(bdim - 1);
batched->refreshTensorMetadata();
return self;
}
+Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
+ return squeeze_dims__batching_rule(self, {dim});
+}
+
Tensor& squeeze__batching_rule(Tensor& self) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
@@ -816,6 +828,7 @@
// still legacy b/c needs special inplace rules
m.impl("squeeze_", squeeze__batching_rule);
m.impl("squeeze_.dim", squeeze_dim__batching_rule);
+ m.impl("squeeze_.dims", squeeze_dims__batching_rule);
m.impl("unsqueeze_", unsqueeze__batching_rule);
m.impl("transpose_", transpose__batching_rule);
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 05a90e3..cc33bda 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -90,6 +90,7 @@
#include <ATen/ops/slice.h>
#include <ATen/ops/special_logsumexp_native.h>
#include <ATen/ops/sqrt.h>
+#include <ATen/ops/squeeze.h>
#include <ATen/ops/stack.h>
#include <ATen/ops/std.h>
#include <ATen/ops/std_mean.h>
@@ -1381,23 +1382,11 @@
return at::nansum(self, dim, keepdim, opt_dtype).div(factor);
}
-static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
- int ndims = self.sizes().size();
- auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
- Tensor result = self;
- for (int i = ndims - 1; i >= 0; --i) {
- if (dims_to_squeeze[i]) {
- result = result.squeeze(i);
- }
- }
- return result;
-}
-
static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
// can't take max of empty tensor
if (self.numel() != 0) {
auto maxes = at::amax(self, dims, true);
- auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
+ auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims));
maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
result.log_().add_(maxes_squeezed);
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 7634796..2127968 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -162,6 +162,7 @@
#include <ATen/ops/split_with_sizes_native.h>
#include <ATen/ops/squeeze_copy_native.h>
#include <ATen/ops/squeeze_native.h>
+#include <ATen/ops/squeeze.h>
#include <ATen/ops/stack_native.h>
#include <ATen/ops/sub.h>
#include <ATen/ops/sum.h>
@@ -3095,6 +3096,22 @@
return std::make_tuple(std::move(sizes), std::move(strides));
}
+std::tuple<SymDimVector, SymDimVector>
+inferSqueezeGeometry(const Tensor &tensor, std::bitset<dim_bitset_size> dim_mask) {
+ const auto ndim = tensor.dim();
+ const auto sym_sizes = tensor.sym_sizes();
+ const auto sym_strides = tensor.sym_strides();
+
+ SymDimVector out_sizes, out_strides;
+ for (const auto d: c10::irange(ndim)) {
+ if (!dim_mask.test(d) || sym_sizes[d] != 1) {
+ out_sizes.push_back(sym_sizes[d]);
+ out_strides.push_back(sym_strides[d]);
+ }
+ }
+ return std::make_tuple(std::move(out_sizes), std::move(out_strides));
+}
+
namespace {
// Named type instead of a pair/tuple so that we can be sure to
// construct the vectors in place and get NRVO.
@@ -3117,18 +3134,21 @@
}
// dim is present if squeezing a single dimension and absent if squeezing all dimensions
-Tensor squeeze_qtensor(const Tensor& self, c10::optional<int64_t> dim) {
+Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) {
auto quantizer = get_qtensorimpl(self)->quantizer();
SymDimVector sizes;
SymDimVector strides;
- std::tie(sizes, strides) = dim.has_value() ? inferSqueezeGeometry(self, dim.value()) : inferSqueezeGeometry(self);
+ const auto ndim = self.dim();
+ auto mask = dims.has_value()
+ ? dim_list_to_bitset(dims, self.dim())
+ : std::bitset<dim_bitset_size>((1ull << self.dim()) - 1);
+ std::tie(sizes, strides) = inferSqueezeGeometry(self, mask);
if (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE) {
const auto* per_channel_quantizer = static_cast<at::PerChannelAffineQuantizer*>(quantizer.get());
auto axis = per_channel_quantizer->axis();
int64_t shift = 0;
- integer_range<int64_t> dims = dim.has_value() ? integer_range<int64_t>{dim.value(), dim.value() + 1} : c10::irange(0, self.dim());
- for (const auto d : dims) {
- if (self.sizes()[d] == 1) {
+ for (const auto d : c10::irange(ndim)) {
+ if (mask.test(d) && self.sizes()[d] == 1) {
TORCH_CHECK(axis != d, "Squeeze is only possible on non-axis dimension for Per-Channel Quantized Tensors.");
if (d < axis) {
++shift;
@@ -3144,13 +3164,8 @@
// TODO: quantized Tensor support for SymInt needs to be added but basic building blocs
// are missing for now.
auto result = make_qtensor(self, C10_AS_INTARRAYREF_SLOW(sizes), C10_AS_INTARRAYREF_SLOW(strides), std::move(quantizer));
- if (dim.has_value()) {
- namedinference::propagate_names_except(result, self, {dim.value()});
- } else {
- auto maybe_outnames = namedinference::compute_squeeze_outnames(self);
- namedinference::propagate_names_if_nonempty(result, maybe_outnames);
- }
-
+ auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
+ namedinference::propagate_names_if_nonempty(result, maybe_outnames);
return result;
}
@@ -3163,10 +3178,7 @@
}
Tensor squeeze_quantized(const Tensor& self) {
- at::Tensor result = squeeze_qtensor(self, c10::nullopt);
- auto maybe_outnames = namedinference::compute_squeeze_outnames(self);
- namedinference::propagate_names_if_nonempty(result, maybe_outnames);
- return result;
+ return squeeze_qtensor(self, c10::nullopt);
}
Tensor squeeze(const Tensor& self, int64_t dim) {
@@ -3182,8 +3194,19 @@
}
Tensor squeeze_quantized(const Tensor& self, int64_t dim) {
- int64_t dims = self.dim();
- dim = maybe_wrap_dim(dim, dims);
+ return squeeze_qtensor(self, dim);
+}
+
+Tensor squeeze(const Tensor& self, IntArrayRef dims) {
+ auto mask = dim_list_to_bitset(dims, self.dim());
+ auto g = inferSqueezeGeometry(self, mask);
+ at::Tensor result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
+ auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
+ namedinference::propagate_names_if_nonempty(result, maybe_outnames);
+ return result;
+}
+
+Tensor squeeze_quantized(const Tensor& self, IntArrayRef dim) {
return squeeze_qtensor(self, dim);
}
@@ -3206,6 +3229,13 @@
return self;
}
+Tensor & squeeze_(Tensor &self, IntArrayRef dims) {
+ auto mask = dim_list_to_bitset(dims, self.dim());
+ auto g = inferSqueezeGeometry(self, mask);
+ self.as_strided__symint(std::get<0>(g), std::get<1>(g));
+ return self;
+}
+
// NOTE [ Unsafe View ]
// _unsafe_view() differs from view() in that the returned tensor isn't treated
// as a view for the purposes of automatic differentiation. (It's not listed in
@@ -4044,6 +4074,13 @@
}
+at::Tensor& squeeze_copy_dims_out(const at::Tensor & self, IntArrayRef dims, at::Tensor & out) {
+ auto tmp = self.squeeze(dims);
+ out.copy_(tmp);
+ return out;
+}
+
+
at::Tensor& t_copy_out(const at::Tensor & self, at::Tensor & out) {
auto tmp = self.t();
out.copy_(tmp);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5f3e1ab..d52c8df 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5212,6 +5212,17 @@
device_check: NoCheck
device_guard: False
+
+- func: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+ variants: function, method
+ device_check: NoCheck
+ device_guard: False
+ dispatch:
+ CompositeExplicitAutograd: squeeze
+ QuantizedCPU, QuantizedCUDA: squeeze_quantized
+ NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested
+ tags: canonical
+
- func: squeeze_(Tensor(a!) self) -> Tensor(a!)
variants: method
device_check: NoCheck
@@ -5228,6 +5239,14 @@
dispatch:
CompositeExplicitAutograd: squeeze_
+- func: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+ variants: method
+ device_check: NoCheck
+ device_guard: False
+ tags: inplace_view
+ dispatch:
+ CompositeExplicitAutograd: squeeze_
+
- func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)
variants: method
device_check: NoCheck
@@ -13600,6 +13619,12 @@
CompositeExplicitAutogradNonFunctional: squeeze_copy_dim
tags: view_copy
+- func: squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
+ variants: function
+ dispatch:
+ CompositeExplicitAutogradNonFunctional: squeeze_copy_dims
+ tags: view_copy
+
- func: t_copy(Tensor self) -> Tensor
variants: function
dispatch:
@@ -13812,6 +13837,12 @@
CompositeExplicitAutograd: squeeze_copy_dim_out
+- func: squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+ variants: function
+ dispatch:
+ CompositeExplicitAutograd: squeeze_copy_dims_out
+
+
- func: t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 5842c3b..287e861 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -9,6 +9,7 @@
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
+#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/nested/NestedTensorUtils.h>
@@ -682,37 +683,53 @@
return self;
}
-Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
+Tensor squeeze_dim_nested(const Tensor& self, IntArrayRef dims) {
auto self_ptr = get_nested_tensor_impl(self);
int64_t ndim = self_ptr->dim();
- int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim);
- TORCH_CHECK(wrapped_dim > 0,
+ auto mask = at::dim_list_to_bitset(dims, ndim);
+ TORCH_CHECK(!mask.test(0),
"squeeze(): For nested tensors, squeezing dimension 0 is not supported at the moment ",
"if you need this feature, please open an issue on github describing your use case.");
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
// if tensor.size(dim) != 1 torch.squeeze will return the result, we do the same here
- c10::optional<int64_t> size_dim = self_ptr->opt_size(dim);
- if (!(size_dim.has_value() && size_dim.value() == 1)) {
+ for (const auto d : c10::irange(ndim)) {
+ if (mask.test(d)) {
+ c10::optional<int64_t> size_dim = self_ptr->opt_size(d);
+ if (!(size_dim.has_value() && *size_dim == 1)) {
+ mask.reset(d);
+ }
+ }
+ }
+
+ if (!mask.any()) {
// detach to avoid triggering throw_error_if_base_and_tensor_are_same
return self.detach();
}
// if ndim == 2 and we pass the above if statement we should have a
// nested tensor of singleton tensors
- TORCH_CHECK(ndim != 2,
+ TORCH_CHECK(ndim > static_cast<int64_t>(1 + dims.size()),
"squeeze(): For nested tensors, squeezing a nested tensor of singleton tensors is not ",
"supported at the moment, if you need this feature, please open an issue on github",
"describing your use case.");
- auto column_indices = sizemat.new_empty(ndim - 2);
+ const auto new_ndim = ndim - mask.count();
+ auto column_indices = sizemat.new_empty(new_ndim - 1);
int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
- std::iota(column_indices_ptr, column_indices_ptr + wrapped_dim - 1, 0);
- std::iota(column_indices_ptr + wrapped_dim - 1, column_indices_ptr + ndim - 2, wrapped_dim);
+ for (const auto d : c10::irange(1, ndim)) {
+ if (!mask.test(d)) {
+ *column_indices_ptr++ = d - 1;
+ }
+ }
auto sizemat_squeezed = at::index_select(sizemat, 1, column_indices);
auto stridemat_squeezed = at::index_select(stridemat, 1, column_indices);
return create_nested_view_tensor(
self, sizemat_squeezed, stridemat_squeezed, std::vector<int64_t>(self_ptr->get_storage_offsets()));
}
+Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
+ return squeeze_dim_nested(self, IntArrayRef{dim});
+}
+
Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
auto self_ptr = get_nested_tensor_impl(self);
int64_t ndim = self_ptr->dim();
diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml
index f4c3ee8..85ac57e 100644
--- a/aten/src/ATen/native/ts_native_functions.yaml
+++ b/aten/src/ATen/native/ts_native_functions.yaml
@@ -157,6 +157,7 @@
#- unbind_copy.int
- squeeze_copy
- squeeze_copy.dim
+ - squeeze_copy.dims
- t_copy
- transpose_copy.int
- unsqueeze_copy
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 3c59bfc..67f6a44 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3322,7 +3322,7 @@
def vmap_outplace_test(self, func, args, kwargs, in_dims, check_shape_only=False,
postprocess_fn=None):
- for loop_out, vmap_out in compute_quantities_for_vmap_test(func, args, kwargs, in_dims):
+ for vmap_out, loop_out in compute_quantities_for_vmap_test(func, args, kwargs, in_dims):
if postprocess_fn is not None:
loop_out = postprocess_fn(loop_out)
vmap_out = postprocess_fn(vmap_out)
@@ -3343,7 +3343,7 @@
func, args, kwargs, in_dims, compute_loop_out=False, clone_inputs=True):
pass
return
- for loop_out, vmap_out in compute_quantities_for_vmap_test(
+ for vmap_out, loop_out in compute_quantities_for_vmap_test(
func, args, kwargs, in_dims, clone_inputs=True):
if postprocess_fn is not None:
loop_out = postprocess_fn(loop_out)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 10a0080..a69b8b4 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1506,6 +1506,14 @@
AutogradNestedTensor:
self: grad.unsqueeze(dim)
+- name: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+ dispatch:
+ Default:
+ self: unsqueeze_to(grad, dim, self.sym_sizes())
+ result: auto_linear
+ AutogradNestedTensor:
+ self: unsqueeze_multiple(grad, dim, self.dim())
+
- name: squeeze_(Tensor(a!) self) -> Tensor(a!)
self: unsqueeze_to(grad, self.sym_sizes())
result: auto_linear
@@ -1514,6 +1522,10 @@
self: unsqueeze_to(grad, dim, self.sym_sizes())
result: auto_linear
+- name: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+ self: unsqueeze_to(grad, dim, self.sym_sizes())
+ result: auto_linear
+
- name: std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor
self: std_backward(result, grad, self, dim, correction, keepdim)
# pointwise (variance) + sum + sqrt
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index e6dd438..c3aa8bf 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -12,6 +12,7 @@
import torch.utils._pytree as pytree
from torch import _prims_common
from torch._prims_common import (
+ canonicalize_dims,
dtype_to_type,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
@@ -488,16 +489,17 @@
assert isinstance(x, TensorBox)
if dim is None:
return TensorBox(SqueezeView.create(x.data))
- offset = len(x.get_size()) == 0
- dim = _validate_dim(x, dim, offset)
- new_shape = list(x.get_size())
- if len(new_shape) > 0:
- removed = new_shape.pop(dim)
- if V.graph.sizevars.maybe_guard_equals(removed, 1):
- return view(x, new_shape)
+ dim = canonicalize_dims(len(x.get_size()), dim)
+ dims = set((dim,) if not isinstance(dim, tuple) else dim)
+
+ new_shape = [
+ s
+ for d, s in enumerate(x.get_size())
+ if not (d in dims and V.graph.sizevars.maybe_guard_equals(s, 1))
+ ]
# squeeze does nothing if the size isn't 1
- return x
+ return view(x, new_shape) if new_shape != x.get_size() else x
@register_lowering([aten.squeeze_])
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 52aa4b8..f0c07dd 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -1706,13 +1706,6 @@
return a.as_strided(new_shape, new_strides, a.storage_offset())
-def _squeeze_aten(a: Tensor, dimensions: Sequence) -> Tensor:
- for idx in reversed(sorted(dimensions)):
- a = torch.squeeze(a, dim=idx)
-
- return a
-
-
_squeeze_doc = """
Creates a view of the tensor with the specified dimensions removed.
@@ -1722,7 +1715,7 @@
squeeze = _make_prim(
schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
meta=_squeeze_meta,
- impl_aten=_squeeze_aten,
+ impl_aten=torch.squeeze,
return_type=RETURN_TYPE.VIEW,
doc=_squeeze_doc,
)
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index e792e35..e581de0 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -3513,21 +3513,22 @@
@register_decomposition(aten.squeeze)
-def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType:
- if dim is not None:
- dim = utils.canonicalize_dim(a.ndim, dim)
- # Short-circuits if the tensor has no dimensions
- if len(a.shape) == 0:
- assert dim == 0
- return prims.view_of(a)
+def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
+ if dim is None:
+ dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
+ return prims.squeeze(a, dims) if dims else prims.view_of(a)
- # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
- if a.shape[dim] != 1:
- return prims.view_of(a)
- return prims.squeeze(a, (dim,))
+ ndim = a.ndim
+ dim = utils.canonicalize_dims(ndim, dim)
+ dims = (dim,) if isinstance(dim, Dim) else dim
+ # Short-circuits if the tensor has no dimensions
+ if ndim == 0:
+ assert len(dims) == 0 or dims == (0,)
+ return prims.view_of(a)
- dims = tuple(idx for idx in range(len(a.shape)) if a.shape[idx] == 1)
- return prims.squeeze(a, dims)
+ # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
+ dims = tuple(d for d in dims if a.shape[d] == 1)
+ return prims.squeeze(a, dims) if dims else prims.view_of(a)
# Note: does not work with TensorMetas because of data-dependent control-flow
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 42292d6..ee0254e 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -10640,14 +10640,14 @@
r"""
squeeze(input, dim=None) -> Tensor
-Returns a tensor with all the dimensions of :attr:`input` of size `1` removed.
+Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed.
For example, if `input` is of shape:
-:math:`(A \times 1 \times B \times C \times 1 \times D)` then the `out` tensor
+:math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()`
will be of shape: :math:`(A \times B \times C \times D)`.
When :attr:`dim` is given, a squeeze operation is done only in the given
-dimension. If `input` is of shape: :math:`(A \times 1 \times B)`,
+dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`,
``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)``
will squeeze the tensor to the shape :math:`(A \times B)`.
@@ -10656,12 +10656,15 @@
.. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)`
will also remove the batch dimension, which can lead to unexpected
- errors.
+ errors. Consider specifying only the dims you wish to be squeezed.
Args:
{input}
- dim (int, optional): if given, the input will be squeezed only in
- this dimension
+ dim (int or tuple of ints, optional): if given, the input will be squeezed
+ only in the specified dimensions.
+
+ .. versionchanged:: 2.0
+ :attr:`dim` now accepts tuples of dimensions.
Example::
@@ -10677,6 +10680,8 @@
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
+ >>> y = torch.squeeze(x, (1, 2, 3))
+ torch.Size([2, 2, 2])
""".format(
**common_args
),
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index c874eb8..1135a82 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -863,15 +863,25 @@
Tensor unsqueeze_to(
const Tensor& self,
+ IntArrayRef dims,
+ c10::SymIntArrayRef sym_sizes) {
+ const auto ndim = sym_sizes.size();
+ auto mask = at::dim_list_to_bitset(dims, ndim);
+
+ Tensor result = self;
+ for (const auto d : c10::irange(ndim)) {
+ if (mask.test(d) && sym_sizes[d] == 1) {
+ result = result.unsqueeze(d);
+ }
+ }
+ return result;
+}
+
+Tensor unsqueeze_to(
+ const Tensor& self,
int64_t dim,
c10::SymIntArrayRef sym_sizes) {
- dim = at::maybe_wrap_dim(dim, sym_sizes.size());
- // in NumPy it's not an error to unsqueeze a scalar, but we still need to
- // avoided unsqueezing in the backward.
- if (sym_sizes.size() > 0 && sym_sizes[dim] == 1) {
- return self.unsqueeze(dim);
- }
- return self;
+ return unsqueeze_to(self, IntArrayRef{dim}, sym_sizes);
}
std::vector<Tensor> cat_tensors_backward(
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index ddeb08c..1279f0a 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -220,6 +220,10 @@
const at::Tensor& self,
int64_t dim,
c10::SymIntArrayRef sym_sizes);
+at::Tensor unsqueeze_to(
+ const at::Tensor& self,
+ IntArrayRef dim,
+ c10::SymIntArrayRef sym_sizes);
std::vector<at::Tensor> cat_tensors_backward(
const at::Tensor& grad,
const std::vector<std::vector<c10::SymInt>>& sizes,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c99eb43..32f11f9 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13,7 +13,7 @@
import numpy as np
from torch._six import inf, nan
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union, Sequence
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
@@ -4771,6 +4771,7 @@
shapes_and_args = (
((S, 1, S, 1), ()),
((1, 1, 1, 1), ()),
+ ((1, 1, 1, 1), (0,)),
((S, 1, S, 1), (1,)),
((S, 1, S, 1), (-1,)),
((S, 1, S, 1), (2,)),
@@ -4785,6 +4786,37 @@
yield SampleInput(tensor, args=args)
+def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs):
+ shapes_and_args = (
+ ((1, 1, 1, 1), ()),
+ ((S, 1, S, 1), (1,)),
+ ((S, 1, S, 1), (-1,)),
+ ((S, 1, S, 1), (1, 3)),
+ ((S, 1, S, 1), (1, 2,)),
+ ((), (0,)),
+ )
+
+ for shape, dims in shapes_and_args:
+ tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
+ requires_grad=requires_grad)
+
+ yield SampleInput(tensor, dims)
+
+
+def _squeeze_ref(x, axis=None):
+ # NumPy doesn't allow squeezing scalars
+ if x.ndim == 0:
+ return x
+
+ if isinstance(axis, Sequence):
+ # Numpy doesn't allow specifying non-singular dimensions
+ axis = tuple(a for a in axis if x.shape[a] == 1)
+
+ if isinstance(axis, int) and x.shape[axis] != 1:
+ return x
+
+ return np.squeeze(x, axis)
+
def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
assert mode in ('constant', 'reflect', 'replicate', 'circular')
if mode in ['reflect', 'replicate']:
@@ -15654,6 +15686,7 @@
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
)),
OpInfo('squeeze',
+ ref=_squeeze_ref,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
assert_autodiffed=True,
@@ -15667,6 +15700,21 @@
# https://github.com/pytorch/pytorch/issues/66357
check_batched_forward_grad=False,
sample_inputs_func=sample_inputs_squeeze),
+ OpInfo('squeeze',
+ ref=_squeeze_ref,
+ variant_test_name="multiple",
+ dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+ supports_out=False,
+ assert_autodiffed=True,
+ autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
+ autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
+ supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
+ # vmap does not support inplace views
+ check_inplace_batched_forward_grad=False,
+ # https://github.com/pytorch/pytorch/issues/66357
+ check_batched_forward_grad=False,
+ sample_inputs_func=sample_inputs_squeeze_multiple),
UnaryUfuncInfo(
'fill',
ref=_fill_np,
@@ -19016,6 +19064,11 @@
torch_opinfo_name="squeeze",
),
PythonRefInfo(
+ "_refs.squeeze",
+ torch_opinfo_name="squeeze",
+ torch_opinfo_variant_name="multiple",
+ ),
+ PythonRefInfo(
"_refs.tensor_split",
torch_opinfo_name="tensor_split",
skips=(