Revert "Allow only one -1 in nested view/reshape (#85691)"
This reverts commit 4c4e5f6106b69960833d7766799fd4f246aa7cd7.
Reverted https://github.com/pytorch/pytorch/pull/85691 on behalf of https://github.com/atalman due to causing a merge conflict in the github-first workflow
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 98d6f7f..84d62e9 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -997,8 +997,6 @@
& stride = strides[itensor];
// compute reshaped size
std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
- // only allow one pre-existing dimension to have proposed shape == -1
- int64_t infer_index_old = -1;
// some negative sizes remain to be infered
if (ndims_underlying < ndims_underlying_reshaped) {
int64_t numel = 1, numel_reshaped = 1;
@@ -1007,9 +1005,7 @@
int64_t& size_reshaped = size_reshaped_vector[idim];
TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
if (size_reshaped == -1) {
- TORCH_CHECK(infer_index_old == -1, "only one dimension can be inferred");
size_reshaped = size[idim];
- infer_index_old = idim;
}
numel *= size[idim];
numel_reshaped *= size_reshaped;
@@ -1091,7 +1087,7 @@
// Note [Special size rule for nested tensor]
// Instead of infering size, -1 means "inherit the old size", so:
// * negative size is legal for a ragged dimension
-// * however, we only allow one -1
+// * multiple sizes can be -1
// In principle we could still infer a dimension,
// we are designing a better semantics to include both inheritance and inference
Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
@@ -1105,10 +1101,19 @@
ntensors > 0,
"empty nested tensor cannot be reshaped");
// basic information after reshaping
- int64_t ntensors_reshaped = proposed_shape[0];
+ int64_t ntensors_reshaped;
+ if (proposed_shape[0] >= 0) {
+ ntensors_reshaped = proposed_shape[0];
+ }
+ else if (proposed_shape[0] == -1) {
+ ntensors_reshaped = ntensors;
+ }
+ else {
+ AT_ERROR("invalid shape dimension ", proposed_shape[0]);
+ }
TORCH_CHECK(
ntensors == ntensors_reshaped,
- "view: For now nested view cannot change or infer the implicit batch dimension");
+ "for now view cannot change the implicit batch dimension");
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
strides = NestedTensor_get_strides(self_ptr);
// reshaping underlying tensor dimensions does not change offset
@@ -1192,10 +1197,19 @@
ntensors > 0,
"empty nested tensor cannot be reshaped");
// basic information after reshaping
- int64_t ntensors_reshaped = proposed_shape[0];
+ int64_t ntensors_reshaped{0};
+ if (proposed_shape[0] >= 0) {
+ ntensors_reshaped = proposed_shape[0];
+ }
+ else if (proposed_shape[0] == -1) {
+ ntensors_reshaped = ntensors;
+ }
+ else {
+ AT_ERROR("invalid shape dimension ", proposed_shape[0]);
+ }
TORCH_CHECK(
ntensors == ntensors_reshaped,
- "reshape: For now nested reshape cannot change or infer the implicit batch dimension");
+ "for now reshape cannot change the implicit batch dimension");
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
strides = NestedTensor_get_strides(self_ptr);
// reshaping underlying tensor dimensions does not change offset
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 4221231..9de07c9 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -1178,11 +1178,11 @@
"empty nested tensor cannot be reshaped",
lambda: nt_empty.view(-1)
)
- # error case: -1 for batch size
+ # error case: invalid proposed shape for underlying tensors
self.assertRaisesRegex(
RuntimeError,
- r"view: For now nested view cannot change or infer the implicit batch dimension",
- lambda: nt.view(-1, 2, 3)
+ r"invalid shape dimension -2",
+ lambda: nt.view(-2, 2, 3)
)
self.assertRaisesRegex(
RuntimeError,
@@ -1194,10 +1194,9 @@
x1 = torch.randn((3, 20), device=device, dtype=dtype)
nt = torch.nested_tensor([x0, x1])
pt = torch.nested.to_padded_tensor(nt, 0.0)
- # error case, trying to reshape batch dim to a legit shape
self.assertRaisesRegex(
RuntimeError,
- r"For now nested view cannot change or infer the implicit batch dimension",
+ r"for now view cannot change the implicit batch dimension",
lambda: nt.transpose(-1, -2).view(40, -1)
)
# inherit only the ragged dimension
@@ -1207,15 +1206,10 @@
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
pt1 = pt.view(2, -1, 5, 4)
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
-
- # more than one -1 (even for "old" dims), should fail
- # this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
- # but we ban "inherit old behavior" for >1 dimension
- self.assertRaisesRegex(
- RuntimeError,
- r"only one dimension can be inferred",
- lambda: nt1.view(2, -1, -1, 2, 2)
- )
+ # also inherit regular dimension
+ nt2 = nt1.view(2, -1, -1, 2, 2)
+ pt2 = pt1.view(2, -1, 5, 2, 2)
+ self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
@dtypes(torch.float, torch.float16, torch.double)
def test_view_inference_mode_interaction(self, device, dtype):
@@ -1252,11 +1246,11 @@
"empty nested tensor cannot be reshaped",
lambda: nt_empty.reshape(-1)
)
- # error case: -1 for batch size
+ # error case: invalid proposed shape for underlying tensors
self.assertRaisesRegex(
RuntimeError,
- r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
- lambda: nt.reshape(-1, 2, 3)
+ r"invalid shape dimension -2",
+ lambda: nt.reshape(-2, 2, 3)
)
self.assertRaisesRegex(
RuntimeError,
@@ -1266,12 +1260,11 @@
# normal case
x0 = torch.randn((2, 20), device=device, dtype=dtype)
x1 = torch.randn((3, 20), device=device, dtype=dtype)
- nt = torch.nested_tensor([x0, x1]) # (2, (2, 3), 20)
+ nt = torch.nested_tensor([x0, x1])
pt = torch.nested.to_padded_tensor(nt, 0.0)
- # error case, trying to reshape batch dim to a legit shape
self.assertRaisesRegex(
RuntimeError,
- r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
+ r"for now reshape cannot change the implicit batch dimension",
lambda: nt.transpose(-1, -2).reshape(40, -1)
)
# inherit only the ragged dimension
@@ -1281,15 +1274,10 @@
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
pt1 = pt.reshape(2, -1, 5, 4)
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
-
- # more than one -1 (even for "old" dims), should fail
- # this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
- # but we ban "inherit old behavior" for >1 dimension
- self.assertRaisesRegex(
- RuntimeError,
- r"only one dimension can be inferred",
- lambda: nt1.reshape(2, -1, -1, 2, 2)
- )
+ # also inherit regular dimension
+ nt2 = nt1.reshape(2, -1, -1, 2, 2)
+ pt2 = pt1.reshape(2, -1, 5, 2, 2)
+ self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
@parametrize("input_dim", [3, 4])
def test_scaled_dot_product_attention(self, device, input_dim):