Revert "[Meta Tensor] fix meta inplace set storage (#123880)"

This reverts commit cccae9355191a807040fb40a65178c4d7fe3f084.

Reverted https://github.com/pytorch/pytorch/pull/123880 on behalf of https://github.com/izaitsevfb because it breaks cpu_inductor_torchbench (detectron2_fasterrcnn) ([comment](https://github.com/pytorch/pytorch/pull/123880#issuecomment-2083366385))
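
For context, the expectation that the reverted change encoded is easiest to see from the `test_inplace_set_storage` test this revert removes from `test/test_meta.py` below: after re-pointing a tensor at an existing untyped storage via in-place `set_()`, the storage's reported size was expected to stay unchanged, because #123880 only grew the storage's nbytes when the new view actually needed more room. A minimal standalone sketch of that check (mirroring the removed test; the final assertion is what the revert gives up):

```python
import torch

# Mirrors the test_inplace_set_storage test removed by this revert.
x = torch.tensor([0, 1], dtype=torch.int64)
storage = x.untyped_storage()
ssize = storage.size()          # 16 bytes: two int64 elements

t = torch.empty((), dtype=torch.int64)
t.set_(storage, 0, (), ())      # re-point t at the existing storage as a 0-d view

# #123880 only called set_nbytes when the new view needed more bytes than the
# storage already had, so the recorded size stayed at `ssize` here.
assert storage.size() == ssize
```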
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index f547992..b7d8eeb 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -421,19 +421,9 @@
     // it.  TODO: Actually this might not quite be correct if we use special
     // pointers to track whether or not fake cuda tensors are pinned or not
     const auto itemsize = result.dtype().itemsize();
-    c10::SymInt new_size_bytes = at::detail::computeStorageNbytes(
+    c10::SymInt size_bytes = at::detail::computeStorageNbytes(
         size, stride, itemsize, std::move(storage_offset));
-    // TODO: When there are unbacked SymInts, we unconditionally skip the
-    // setter.  This is technically wrong, but we cannot conveniently test
-    // the real condition in many cases, because a lot of people are using
-    // set_ just to swizzle metadata on a tensor, they didn't actually want
-    // to see if they need to resize the storage.
-    //
-    // The old behavior was to unconditionally set_nbytes, but I think not
-    // setting it is more safe.
-    if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && TORCH_GUARD_SIZE_OBLIVIOUS(new_size_bytes.sym_gt(storage.sym_nbytes()))) {
-      storage.set_nbytes(std::move(new_size_bytes));
-    }
+    storage.set_nbytes(std::move(size_bytes));
   }
   return result;
 }
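
For reference on the hunk above: `at::detail::computeStorageNbytes` derives the byte count that `set_()` wants the storage to have from the view's size, stride, itemsize, and storage offset. A rough Python sketch of that computation (a hypothetical stand-in for the ATen helper, assuming the usual largest-reachable-element formula):

```python
def compute_storage_nbytes(size, stride, itemsize, storage_offset=0):
    # Rough sketch of at::detail::computeStorageNbytes: the storage must reach
    # one past the largest element addressable by (size, stride), shifted by
    # the storage offset; any zero-sized dimension means no bytes are needed.
    if any(s == 0 for s in size):
        return 0
    last = sum((s - 1) * st for s, st in zip(size, stride))
    return (storage_offset + last + 1) * itemsize

compute_storage_nbytes((), (), itemsize=8)      # 0-d int64 view   -> 8 bytes
compute_storage_nbytes((2,), (1,), itemsize=8)  # the [0, 1] tensor -> 16 bytes
```

Under the reverted change, this computed value only replaced the storage's recorded nbytes when it was larger (and both sides had hints); the revert restores the unconditional `set_nbytes`.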
diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 8adcd04..8005d6e 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -3,8 +3,6 @@
 import itertools
 import unittest
 
-from functools import partial
-
 import torch
 
 import torch._dynamo.test_case
@@ -39,105 +37,6 @@
     return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
 
 
-def get_jagged_tensor(nested_size, offsets, requires_grad=True):
-    # Makes a jagged tensor with N constituent tensors with size
-    # as specified ((S0, S1, S2), D)
-    D = nested_size[1]
-    out = []
-    for s in nested_size[0]:
-        out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
-    return jagged_from_list(out, offsets)
-
-
-def get_view_test_cases():
-    # Test all cases with both an NT base and a dense base
-    # Subclass -> Subclass
-    # Dense -> Subclass
-
-    # NB: Don't close over loop variables, they will not get copied into the
-    # closure
-    #
-    # NB: These return functions so we don't generate tensors during test
-    # collection time
-
-    def mk_basic(base_is_nt):
-        # There are three cases to consider here based on the logic in
-        # meta_utils.py
-        #
-        # (1) basic case:
-        # view is not a leaf and has the same requires grad as its basic case
-        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
-        x = x.clone() if base_is_nt else x
-        assert not x.is_leaf
-        return x.unsqueeze(-1)
-
-    def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
-        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
-        x = x.clone() if base_is_nt else x
-        with torch.no_grad():
-            x_view = x.unsqueeze(-1)
-            # The issue is this doesn't quite work
-            x_view.requires_grad_(requires_grad_2)
-
-        return x_view
-
-    def mk_obscure(base_is_nt):
-        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
-        x = x.clone() if base_is_nt else x
-        # intermediate leaf view
-        with torch.no_grad():
-            x_view = x.unsqueeze(-1)
-        x_view.requires_grad_(True)
-        x_view_view = x_view.unsqueeze(-1)
-        return x_view_view
-
-    for base_is_nt in [False, True]:
-        prefix = f"base_is_nt_{base_is_nt}"
-
-        yield partial(mk_basic, base_is_nt), f"{prefix}_basic"
-
-        # (2) leaf view case:
-        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
-        # base w/ requires_grad True or requires_grad False
-        for requires_grad_1, requires_grad_2 in itertools.product(
-            [True, False], repeat=2
-        ):
-            yield partial(
-                mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
-            ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
-
-        # (3) obscure case:
-        # view is not a leaf (implies requires_grad True)
-        # base w/ requires_grad False)
-        yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
-
-    # Subclass -> Dense
-    yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
-        0
-    ].clone(), "subclass_dense"
-
-    # Dense -> Subclass -> Dense -> Subclass
-    def mk_dense_subclass_dense_subclass():
-        values = torch.randn(10, 5)
-        offsets = torch.tensor([0, 3, 6, 10])
-        offsets2 = offsets.clone().detach()
-        return nested_view_from_values_offsets(
-            nested_view_from_values_offsets(values, offsets).values(), offsets
-        )
-
-    yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"
-
-    def mk_subclass_dense_subclass_dense():
-        x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
-        offsets2 = x.offsets().clone().detach()
-        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
-
-    yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"
-
-
-VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}
-
-
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 
 compile_full_eager = torch.compile(backend="eager", fullgraph=True)
@@ -1308,7 +1207,15 @@
 
 class TestNestedTensor(torch._dynamo.test_case.TestCase):
     def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
-        return get_jagged_tensor(nested_size, offsets, requires_grad)
+        # Makes a jagged tensor with N constituent tensors with size
+        # as specified ((S0, S1, S2), D)
+        D = nested_size[1]
+        out = []
+        for s in nested_size[0]:
+            out.append(
+                torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
+            )
+        return jagged_from_list(out, offsets)
 
     def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
         # Makes a jagged tensor with N constituent tensors with size
@@ -1462,9 +1369,62 @@
 
         torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
 
-    def _input_view_test(self, nt_view_name):
-        nt_view = VIEW_TEST_CASES[nt_view_name]()
+    def _get_views(self):
+        # Test all cases with both an NT base and a dense base
+        # Subclass -> Subclass
+        # Dense -> Subclass
+        for base_is_nt in [False, True]:
+            # There are three cases to consider here based on the logic in
+            # meta_utils.py
+            #
+            # (1) basic case:
+            # view is not a leaf and has the same requires grad as its basic case
+            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
+            x = x.clone() if base_is_nt else x
+            self.assertEqual(x.is_leaf, False)
+            yield x.unsqueeze(-1)
 
+            # (2) leaf view case:
+            # the view has to be a leaf (w/ requires_grad True or requires_grad False)
+            # base w/ requires_grad True or requires_grad False
+            for requires_grad_1, requires_grad_2 in itertools.product(
+                [True, False], repeat=2
+            ):
+                x, _ = self._get_jagged_tensor(
+                    ((2, 3, 4), 3), None, requires_grad=requires_grad_1
+                )
+                x = x.clone() if base_is_nt else x
+                with torch.no_grad():
+                    x_view = x.unsqueeze(-1)
+                    # The issue is this doesn't quite work
+                    x_view.requires_grad_(requires_grad_2)
+                yield x_view
+
+            # (3) obscure case:
+            # view is not a leaf (implies requires_grad True)
+            # base w/ requires_grad False)
+            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
+            x = x.clone() if base_is_nt else x
+            # intermediate leaf view
+            with torch.no_grad():
+                x_view = x.unsqueeze(-1)
+            x_view.requires_grad_(True)
+            x_view_view = x_view.unsqueeze(-1)
+            yield x_view_view
+
+        # Subclass -> Dense
+        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
+        yield x.values()
+
+        # Dense -> Subclass -> Dense -> Subclass
+        values = torch.randn(10, 5)
+        offsets = torch.tensor([0, 3, 6, 10])
+        offsets2 = offsets.clone().detach()
+        yield nested_view_from_values_offsets(
+            nested_view_from_values_offsets(values, offsets).values(), offsets
+        )
+
+    def _input_view_test(self, nt_view):
         def fn(x):
             return x.sin()
 
@@ -1490,15 +1450,8 @@
 
             # varies based on the type of view
             guard_str = "\n".join(guards)
-            if (
-                isinstance(nt_view._base, NestedTensor)
-                or nt_view_name == "subclass_dense"
-            ):
+            if isinstance(nt_view._base, NestedTensor):
                 self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
-            elif nt_view_name.startswith("base_is_nt_False_"):
-                # TODO: this is a "do I need to resize storage" guard,
-                # probably don't actually want to see this
-                self.assertExpectedInline(guard_str, """8*s1*s3 <= 8*s0*s1""")
             else:
                 self.assertExpectedInline(guard_str, """""")
             return gm
@@ -1507,12 +1460,9 @@
         compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
         out = compile_fn(nt_view)
 
-    @parametrize(
-        "nt_view_name",
-        [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
-    )
-    def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
-        self._input_view_test(nt_view_name)
+    def test_inputs_to_compiled_fn_are_views(self):
+        for nt_view in self._get_views():
+            self._input_view_test(nt_view)
 
     def test_subclass_gives_static_shapes_when_dynamic_false(self):
         def check_graph(gm, *args):
@@ -1540,10 +1490,10 @@
     # are cached onto fake offsets to solve this problem.
     @unittest.expectedFailure
     def test_subclass_dense_subclass_dense_view(self):
-        self._input_view_test("subclass_dense_subclass_dense")
-
-
-instantiate_parametrized_tests(TestNestedTensor)
+        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
+        offsets2 = x.offsets().clone().detach()
+        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
+        self._input_view_test(nt_view)
 
 
 if __name__ == "__main__":
diff --git a/test/test_meta.py b/test/test_meta.py
index 93d7bb8..af1a5fb 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -286,14 +286,6 @@
         m = MetaConverter()(y)
         self.assertMetadataMatches(m, y)
 
-    def test_inplace_set_storage(self):
-        x = torch.tensor([0, 1], dtype=torch.int64)
-        storage = x.untyped_storage()
-        ssize = storage.size()
-        meta = torch.empty((), dtype=torch.int64)
-        meta.set_(storage, 0, (), ())
-        self.assertEqual(storage.size(), ssize)
-
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
     def test_weakref(self):
         x = torch.randn(4, 4, 4)
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index fdea8a3..d291605 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -8,7 +8,6 @@
 from dataclasses import dataclass
 from typing import (
     Any,
-    Callable,
     cast,
     Dict,
     List,
@@ -1214,20 +1213,16 @@
         if metadata.is_neg:
             torch._C._set_neg(empty, True)
 
-        maybe_suppress: Callable[[], Any] = contextlib.nullcontext
-        if self.shape_env is not None:
-            maybe_suppress = self.shape_env.suppress_guards
-
         if func.is_view:
             # For view ops, the storage should be the same as the tensor input.
             storage = args[cast(int, entry.view_idx)].untyped_storage()
-            with in_kernel_invocation_manager(self), maybe_suppress():
+            with in_kernel_invocation_manager(self):
                 empty.set_(
                     storage, metadata.storage_offset, metadata.shape, metadata.stride
                 )
         elif metadata.storage_offset != 0:
             storage = empty.untyped_storage()
-            with in_kernel_invocation_manager(self), maybe_suppress():
+            with in_kernel_invocation_manager(self):
                 empty.set_(
                     storage, metadata.storage_offset, metadata.shape, metadata.stride
                 )
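
The `torch/_subclasses/fake_tensor.py` hunk above removes the companion change: with #123880, the metadata-only `empty.set_(...)` calls during fake tensor propagation were additionally wrapped in the ShapeEnv's guard suppression so the nbytes comparison would not emit guards there; the revert restores the plain `in_kernel_invocation_manager` context. A small sketch of the removed pattern (illustrative only; assumes a ShapeEnv-like object exposing `suppress_guards()`):

```python
import contextlib
from typing import Any, Callable

def maybe_suppress(shape_env) -> Callable[[], Any]:
    # Pattern removed above: suppress guard creation around metadata-only
    # set_() calls when a ShapeEnv is present, otherwise do nothing.
    if shape_env is not None:
        return shape_env.suppress_guards
    return contextlib.nullcontext

# Usage, as in the reverted code:
#   with in_kernel_invocation_manager(self), maybe_suppress(self.shape_env)():
#       empty.set_(storage, storage_offset, shape, stride)
```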