Revert "Do not use unsafe restriding for subclasses (#87610)"

This reverts commit 73379acaf3865379aed0a1bab1320616772152f3.

Reverted https://github.com/pytorch/pytorch/pull/87610 on behalf of https://github.com/mehtanirav due to [Internal breakages](https://www.internalfb.com/intern/sandcastle/job/36028797828925790/insights)
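
In effect, this revert restores the `_reshape_alias` fast path for `reshape` on tensor subclasses: the `!at::isTensorSubclassLike(self)` guard added by #87610 is removed below, which is also why the expected functionalization traces switch from `view_copy` back to `_reshape_alias_copy`. As a minimal sketch of that fast path on a plain strided tensor (the subclass case is the one #87610 had disabled; op names match the restored traces below):

```python
import torch

# When reshape can compute valid strides for the target shape, it returns an
# alias of the input via aten::_reshape_alias instead of copying; the traces
# in test_functionalization.py record the *_copy variant of this op.
x = torch.arange(8.)
y = x.reshape(2, 4)           # dispatches to _reshape_alias(x, [2, 4], [4, 1])
assert y._base is x           # y aliases x's storage; no copy was made
assert y.stride() == (4, 1)   # strides computed for the new shape
```
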
diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
index fc51e9d..5eecbed 100644
--- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -928,11 +928,6 @@
   return at::scatter(self, dim, index_, source);
 }
 
-// Note [Fix vmap slice_scatter]
-// This registers a decomposition for `slice_scatter` that calls into `slice.src`.
-// *_scatter operators have some special semantics, though, that we can't
-// easily replicate through a decomposition: slice_scatter's output needs to
-// have the same size, strides, and storage_offset as the input.
 Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src,
                             int64_t dim, c10::optional<int64_t> start,
                             c10::optional<int64_t> end, int64_t step)
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 282eec8..2051cda 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -3,7 +3,6 @@
 #include <ATen/core/DimVector.h>
 #include <ATen/core/functional.h>
 #include <ATen/core/IListRef.h>
-#include <ATen/TensorSubclassLikeUtils.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
 #include <ATen/ExpandUtils.h>
@@ -1566,7 +1565,7 @@
     //
     // We need to do the checks here instead of in `native_functions.yaml`
     // to preserve backwards compatibility.
-    if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) {
+    if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) {
       return self._reshape_alias_symint(shape, stride.value());
     } else {
       return self.view_symint(shape);
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index ff69ed9..26b64c5 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -3130,16 +3130,13 @@
     return fx_g
 
 class TestFunctionalize(TestCase):
-    def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False):
+    def _check_functionalize_correctness(self, f, inpt):
         inpt1 = inpt.clone()
         inpt2 = inpt.clone()
         inpt3 = inpt.clone()
 
         expected_outputs = f(inpt1)
-        if skip_vmap:
-            actual_outputs = functionalize(f)(inpt2)
-        else:
-            actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
+        actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
         # Right now, the flavor of functionalize that also removes view ops
         # isn't being used with vmap.
         # That's because {view}_copy ops don't have batching rules yet.
@@ -3209,8 +3206,7 @@
             z2, z3 = z1.split(2)
             z2.add_(tmp)
             return x
-        # See Note [Fix vmap slice_scatter]
-        self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device), skip_vmap=True)
+        self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
 
     # Ensure functionalize works with List[Optional[Tensor]] arguments.
     # See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index e2cca26..bfb7967 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -156,11 +156,11 @@
     as_strided_copy_4 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
     clone_1 = torch.ops.aten.clone.default(as_strided_copy_4, memory_format = torch.contiguous_format);  as_strided_copy_4 = None
     threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0);  clone_1 = relu = None
-    view_copy_2 = torch.ops.aten.view_copy.default(as_strided_copy_2, [16, 64, 128, 128])
-    detach_copy = torch.ops.aten.detach_copy.default(view_copy_2);  view_copy_2 = None
+    _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1])
+    detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy);  _reshape_alias_copy = None
     as_strided_scatter_1 = torch.ops.aten.as_strided_scatter.default(as_strided_copy_2, threshold_backward, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0);  as_strided_copy_2 = threshold_backward = None
-    view_copy_3 = torch.ops.aten.view_copy.default(as_strided_scatter_1, [16, 64, 128, 128]);  as_strided_scatter_1 = None
-    detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_3);  view_copy_3 = None
+    _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(as_strided_scatter_1, [16, 64, 128, 128], [1048576, 16384, 128, 1]);  as_strided_scatter_1 = None
+    detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1);  _reshape_alias_copy_1 = None
     return detach_copy_1
     """)  # noqa: B950
 
@@ -713,40 +713,40 @@
     ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
     add = torch.ops.aten.add.Tensor(a_1, a_1);  a_1 = None
     view_copy = torch.ops.aten.view_copy.default(add, [8])
-    view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]);  view_copy = None
-    transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
+    _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]);  view_copy = None
+    transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0)
     unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0);  transpose_copy = None
     squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy);  unsqueeze_copy = None
     split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2);  squeeze_copy = None
     getitem = split_copy[0]
     getitem_1 = split_copy[1];  split_copy = None
     add_1 = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
-    select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0);  view_copy_1 = None
-    view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4])
-    view_copy_3 = torch.ops.aten.view_copy.default(add, [8]);  add = None
-    view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [2, 4]);  view_copy_3 = None
-    transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_4, 1, 0);  view_copy_4 = None
+    select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0);  _reshape_alias_copy = None
+    _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1])
+    view_copy_1 = torch.ops.aten.view_copy.default(add, [8]);  add = None
+    _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]);  view_copy_1 = None
+    transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0);  _reshape_alias_copy_2 = None
     unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0);  transpose_copy_1 = None
     squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1);  unsqueeze_copy_1 = None
     slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2);  squeeze_copy_1 = None
     unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0);  slice_scatter = None
     squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0);  unsqueeze_copy_2 = None
     transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0);  squeeze_copy_2 = None
-    view_copy_5 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]);  transpose_copy_2 = None
-    view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [4, 2]);  view_copy_5 = None
-    view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [8])
-    view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [2, 4]);  view_copy_7 = None
-    select_copy_1 = torch.ops.aten.select_copy.int(view_copy_8, 0, 0);  view_copy_8 = None
-    view_copy_9 = torch.ops.aten.view_copy.default(view_copy_6, [8]);  view_copy_6 = None
-    view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]);  view_copy_9 = None
-    transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_10, 1, 0);  view_copy_10 = None
+    _reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]);  transpose_copy_2 = None
+    view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]);  _reshape_alias_copy_3 = None
+    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8])
+    _reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]);  view_copy_3 = None
+    select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0);  _reshape_alias_copy_4 = None
+    view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]);  view_copy_2 = None
+    _reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]);  view_copy_4 = None
+    transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0);  _reshape_alias_copy_5 = None
     unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0);  transpose_copy_3 = None
     squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3);  unsqueeze_copy_3 = None
     split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2);  squeeze_copy_3 = None
     getitem_2 = split_copy_1[0]
     getitem_3 = split_copy_1[1];  split_copy_1 = None
-    view_copy_11 = torch.ops.aten.view_copy.default(getitem_2, [4]);  getitem_2 = None
-    add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_11);  select_copy_1 = view_copy_11 = None
+    _reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]);  getitem_2 = None
+    add_2 = torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6);  select_copy_1 = _reshape_alias_copy_6 = None
     return add_1
     """)  # noqa: B950
 
@@ -759,30 +759,30 @@
     ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
     add = torch.ops.aten.add.Tensor(a_1, a_1);  a_1 = None
     view = torch.ops.aten.view.default(add, [8])
-    view_1 = torch.ops.aten.view.default(view, [2, 4]);  view = None
-    transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
+    _reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]);  view = None
+    transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0)
     unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0);  transpose = None
     squeeze = torch.ops.aten.squeeze.default(unsqueeze);  unsqueeze = None
     split = torch.ops.aten.split.Tensor(squeeze, 2);  squeeze = None
     getitem = split[0]
     getitem_1 = split[1];  split = None
     add_1 = torch.ops.aten.add_.Tensor(getitem, ones);  ones = None
-    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = None
+    select = torch.ops.aten.select.int(_reshape_alias, 0, 0);  _reshape_alias = None
     clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
     _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]);  clone = None
-    view_2 = torch.ops.aten.view.default(add, [8]);  add = None
-    view_3 = torch.ops.aten.view.default(view_2, [2, 4]);  view_2 = None
-    transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0);  view_3 = None
+    view_1 = torch.ops.aten.view.default(add, [8]);  add = None
+    _reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]);  view_1 = None
+    transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0);  _reshape_alias_1 = None
     unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0);  transpose_1 = None
     squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1);  unsqueeze_1 = None
     unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0);  squeeze_1 = None
     squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0);  unsqueeze_2 = None
     transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0);  squeeze_2 = None
-    view_4 = torch.ops.aten.view.default(transpose_2, [8]);  transpose_2 = None
-    view_5 = torch.ops.aten.view.default(view_4, [4, 2]);  view_4 = None
-    view_6 = torch.ops.aten.view.default(view_5, [8]);  view_5 = None
-    view_7 = torch.ops.aten.view.default(view_6, [2, 4]);  view_6 = None
-    select_1 = torch.ops.aten.select.int(view_7, 0, 0);  view_7 = None
+    _reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]);  transpose_2 = None
+    view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]);  _reshape_alias_2 = None
+    view_3 = torch.ops.aten.view.default(view_2, [8]);  view_2 = None
+    _reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]);  view_3 = None
+    select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0);  _reshape_alias_3 = None
     add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view);  select_1 = _unsafe_view = None
     return getitem
     """)