Revert "Do not use unsafe restriding for subclasses (#87610)"
This reverts commit 73379acaf3865379aed0a1bab1320616772152f3.
Reverted https://github.com/pytorch/pytorch/pull/87610 on behalf of https://github.com/mehtanirav due to [Internal breakages](https://www.internalfb.com/intern/sandcastle/job/36028797828925790/insights)
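Note: the reverted PR made `Tensor.reshape` fall back to `view` for tensor subclasses instead of taking the `_reshape_alias` (unsafe restriding) fast path; this revert restores the fast path for subclasses. A minimal way to observe which op is chosen, using a hypothetical `LoggingTensor` helper built on `__torch_dispatch__` (the helper is illustrative, not part of this change):

```python
import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    """Illustrative subclass that prints each aten op reaching dispatch."""
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print("dispatched:", func)
        # Unwrap to plain tensors so the inner call doesn't re-enter dispatch.
        unwrap = lambda t: t.as_subclass(torch.Tensor) if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

x = torch.randn(2, 4).as_subclass(LoggingTensor)
x.reshape(8)
# With this revert: prints aten._reshape_alias.default (unsafe restriding).
# Under #87610: subclasses printed aten.view.default instead.
```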
diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
index fc51e9d..5eecbed 100644
--- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -928,11 +928,6 @@
return at::scatter(self, dim, index_, source);
}
-// Note [Fix vmap slice_scatter]
-// This registers a decomposition for `slice_scatter` that calls into `slice.src`.
-// *_scatter operators have some special semantics, though, that we can't easily
-// express through a decomposition: slice_scatter's output needs to have the same
-// size, strides, and storage_offset as the input (see the sketch after this hunk).
Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src,
int64_t dim, c10::optional<int64_t> start,
c10::optional<int64_t> end, int64_t step)
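For reference, the metadata contract the removed note describes is observable from Python: `slice_scatter` returns a tensor whose size, strides, and storage_offset match the input's, which a naive decomposition through `slice` plus a copy cannot easily guarantee. A minimal sketch:

```python
import torch

base = torch.zeros(4, 2)
src = torch.ones(2, 2)
out = torch.slice_scatter(base, src, dim=0, start=0, end=2)

# slice_scatter's output must carry the input's metadata exactly.
assert out.shape == base.shape
assert out.stride() == base.stride()
assert out.storage_offset() == base.storage_offset()
```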
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 282eec8..2051cda 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -3,7 +3,6 @@
#include <ATen/core/DimVector.h>
#include <ATen/core/functional.h>
#include <ATen/core/IListRef.h>
-#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
@@ -1566,7 +1565,7 @@
//
// We need to do the checks here instead of in `native_functions.yaml`
// to preserve backwards compatibility.
- if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) {
+ if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) {
return self._reshape_alias_symint(shape, stride.value());
} else {
return self.view_symint(shape);
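For context, `_reshape_alias` is the "unsafe restriding" primitive this condition guards: it reinterprets the same storage with strides precomputed by the caller, rather than dispatching a `view`. A small sketch of the aliasing behavior, calling the private op directly (for illustration only; real callers go through `reshape`):

```python
import torch

x = torch.arange(8.).reshape(2, 4)  # contiguous
# For a viewable case, _reshape_alias is morally an as_strided with
# precomputed strides; it aliases x's storage rather than copying.
a = torch.ops.aten._reshape_alias.default(x, [4, 2], [2, 1])
b = x.view(4, 2)
assert torch.equal(a, b) and a.data_ptr() == b.data_ptr()
```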
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index ff69ed9..26b64c5 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -3130,16 +3130,13 @@
return fx_g
class TestFunctionalize(TestCase):
- def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False):
+ def _check_functionalize_correctness(self, f, inpt):
inpt1 = inpt.clone()
inpt2 = inpt.clone()
inpt3 = inpt.clone()
expected_outputs = f(inpt1)
- if skip_vmap:
- actual_outputs = functionalize(f)(inpt2)
- else:
- actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
+ actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
# Right now the flavor of functionalize that also removes view ops
# isn't being used with vmap
# That's because {view}_copy ops don't have batching rules yet
@@ -3209,8 +3206,7 @@
z2, z3 = z1.split(2)
z2.add_(tmp)
return x
- # See Note [Fix vmap slice_scatter]
- self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device), skip_vmap=True)
+ self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
# Ensure functionalize works with List[Optional[Tensor]] arguments.
# See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085
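The comment above refers to the two flavors of `functionalize`: the default removes only mutations, while `remove='mutations_and_views'` also replaces view ops with their `{view}_copy` variants, which is the flavor that can't be composed with `vmap` yet. A minimal sketch of the two flavors run eagerly, without `vmap`:

```python
import torch
from functorch import functionalize

def f(x):
    y = x.view(-1)  # view + in-place mutation, which functionalize removes
    y.add_(1)
    return x

x = torch.zeros(4, 2)
out1 = functionalize(f)(x.clone())                                # keeps view ops
out2 = functionalize(f, remove='mutations_and_views')(x.clone())  # emits {view}_copy ops
assert torch.equal(out1, out2)  # same values either way
```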
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index e2cca26..bfb7967 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -156,11 +156,11 @@
as_strided_copy_4 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
clone_1 = torch.ops.aten.clone.default(as_strided_copy_4, memory_format = torch.contiguous_format); as_strided_copy_4 = None
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None
- view_copy_2 = torch.ops.aten.view_copy.default(as_strided_copy_2, [16, 64, 128, 128])
- detach_copy = torch.ops.aten.detach_copy.default(view_copy_2); view_copy_2 = None
+ _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1])
+ detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy); _reshape_alias_copy = None
as_strided_scatter_1 = torch.ops.aten.as_strided_scatter.default(as_strided_copy_2, threshold_backward, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0); as_strided_copy_2 = threshold_backward = None
- view_copy_3 = torch.ops.aten.view_copy.default(as_strided_scatter_1, [16, 64, 128, 128]); as_strided_scatter_1 = None
- detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_3); view_copy_3 = None
+ _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(as_strided_scatter_1, [16, 64, 128, 128], [1048576, 16384, 128, 1]); as_strided_scatter_1 = None
+ detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1); _reshape_alias_copy_1 = None
return detach_copy_1
""") # noqa: B950
@@ -713,40 +713,40 @@
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [8])
- view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None
- transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
+ _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None
+ transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0)
unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None
squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None
split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
- select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None
- view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4])
- view_copy_3 = torch.ops.aten.view_copy.default(add, [8]); add = None
- view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [2, 4]); view_copy_3 = None
- transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_4, 1, 0); view_copy_4 = None
+ select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None
+ _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1])
+ view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None
+ _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
+ transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0); _reshape_alias_copy_2 = None
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
- view_copy_5 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None
- view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [4, 2]); view_copy_5 = None
- view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [8])
- view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [2, 4]); view_copy_7 = None
- select_copy_1 = torch.ops.aten.select_copy.int(view_copy_8, 0, 0); view_copy_8 = None
- view_copy_9 = torch.ops.aten.view_copy.default(view_copy_6, [8]); view_copy_6 = None
- view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None
- transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_10, 1, 0); view_copy_10 = None
+ _reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
+ view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]); _reshape_alias_copy_3 = None
+ view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8])
+ _reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
+ select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0); _reshape_alias_copy_4 = None
+ view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
+ _reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]); view_copy_4 = None
+ transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0); _reshape_alias_copy_5 = None
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
- view_copy_11 = torch.ops.aten.view_copy.default(getitem_2, [4]); getitem_2 = None
- add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_11); select_copy_1 = view_copy_11 = None
+ _reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]); getitem_2 = None
+ add_2 = torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6); select_copy_1 = _reshape_alias_copy_6 = None
return add_1
""") # noqa: B950
@@ -759,30 +759,30 @@
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view = torch.ops.aten.view.default(add, [8])
- view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None
- transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
+ _reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None
+ transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0)
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None
squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
getitem = split[0]
getitem_1 = split[1]; split = None
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
- select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
+ select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None
clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
- view_2 = torch.ops.aten.view.default(add, [8]); add = None
- view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None
- transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None
+ view_1 = torch.ops.aten.view.default(add, [8]); add = None
+ _reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None
+ transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None
squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None
squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
- view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None
- view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None
- view_6 = torch.ops.aten.view.default(view_5, [8]); view_5 = None
- view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None
- select_1 = torch.ops.aten.select.int(view_7, 0, 0); view_7 = None
+ _reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None
+ view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None
+ view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None
+ _reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None
+ select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
return getitem
""")