Revert "[Meta Tensor] fix meta inplace set storage (#123880)"
This reverts commit cccae9355191a807040fb40a65178c4d7fe3f084.
Reverted https://github.com/pytorch/pytorch/pull/123880 on behalf of https://github.com/izaitsevfb because it breaks cpu_inductor_torchbench (detectron2_fasterrcnn) ([comment](https://github.com/pytorch/pytorch/pull/123880#issuecomment-2083366385))
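
For context, the change being reverted guarded the storage resize inside meta `set_()`: `storage.set_nbytes()` was only called when the newly computed byte size was (size-obliviously) known to exceed the current one, whereas the restored code overwrites the size unconditionally. A minimal Python sketch of that decision, with illustrative names (not the real ATen API) and the unbacked-SymInt hint checks omitted:

```python
# Hypothetical sketch only: mirrors the logic in set_storage_meta__symint,
# not the actual ATen implementation or its names.

def set_storage_nbytes(storage_nbytes: int, new_nbytes: int, guarded: bool = True) -> int:
    """Return the storage byte size after a set_()-style metadata swap.

    guarded=True mirrors the behavior being reverted (only grow the storage);
    guarded=False mirrors the pre-#123880, post-revert behavior (always overwrite).
    """
    if not guarded:
        return new_nbytes            # restored behavior: unconditional set_nbytes
    if new_nbytes > storage_nbytes:
        return new_nbytes            # reverted behavior: grow only when provably larger
    return storage_nbytes            # otherwise leave the storage untouched

# The removed test_inplace_set_storage asserted the guarded outcome: the 16-byte
# storage of a two-element int64 tensor stays 16 bytes after set_(storage, 0, (), ())
# computes an 8-byte size for the scalar view.
assert set_storage_nbytes(16, 8, guarded=True) == 16
assert set_storage_nbytes(16, 8, guarded=False) == 8
```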
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index f547992..b7d8eeb 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -421,19 +421,9 @@
// it. TODO: Actually this might not quite be correct if we use special
// pointers to track whether or not fake cuda tensors are pinned or not
const auto itemsize = result.dtype().itemsize();
- c10::SymInt new_size_bytes = at::detail::computeStorageNbytes(
+ c10::SymInt size_bytes = at::detail::computeStorageNbytes(
size, stride, itemsize, std::move(storage_offset));
- // TODO: When there are unbacked SymInts, we unconditionally skip the
- // setter. This is technically wrong, but we cannot conveniently test
- // the real condition in many cases, because a lot of people are using
- // set_ just to swizzle metadata on a tensor, they didn't actually want
- // to see if they need to resize the storage.
- //
- // The old behavior was to unconditionally set_nbytes, but I think not
- // setting it is more safe.
- if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && TORCH_GUARD_SIZE_OBLIVIOUS(new_size_bytes.sym_gt(storage.sym_nbytes()))) {
- storage.set_nbytes(std::move(new_size_bytes));
- }
+ storage.set_nbytes(std::move(size_bytes));
}
return result;
}
diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 8adcd04..8005d6e 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -3,8 +3,6 @@
import itertools
import unittest
-from functools import partial
-
import torch
import torch._dynamo.test_case
@@ -39,105 +37,6 @@
return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
-def get_jagged_tensor(nested_size, offsets, requires_grad=True):
- # Makes a jagged tensor with N constituent tensors with size
- # as specified ((S0, S1, S2), D)
- D = nested_size[1]
- out = []
- for s in nested_size[0]:
- out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
- return jagged_from_list(out, offsets)
-
-
-def get_view_test_cases():
- # Test all cases with both an NT base and a dense base
- # Subclass -> Subclass
- # Dense -> Subclass
-
- # NB: Don't close over loop variables, they will not get copied into the
- # closure
- #
- # NB: These return functions so we don't generate tensors during test
- # collection time
-
- def mk_basic(base_is_nt):
- # There are three cases to consider here based on the logic in
- # meta_utils.py
- #
- # (1) basic case:
- # view is not a leaf and has the same requires grad as its basic case
- x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
- x = x.clone() if base_is_nt else x
- assert not x.is_leaf
- return x.unsqueeze(-1)
-
- def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
- x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
- x = x.clone() if base_is_nt else x
- with torch.no_grad():
- x_view = x.unsqueeze(-1)
- # The issue is this doesn't quite work
- x_view.requires_grad_(requires_grad_2)
-
- return x_view
-
- def mk_obscure(base_is_nt):
- x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
- x = x.clone() if base_is_nt else x
- # intermediate leaf view
- with torch.no_grad():
- x_view = x.unsqueeze(-1)
- x_view.requires_grad_(True)
- x_view_view = x_view.unsqueeze(-1)
- return x_view_view
-
- for base_is_nt in [False, True]:
- prefix = f"base_is_nt_{base_is_nt}"
-
- yield partial(mk_basic, base_is_nt), f"{prefix}_basic"
-
- # (2) leaf view case:
- # the view has to be a leaf (w/ requires_grad True or requires_grad False)
- # base w/ requires_grad True or requires_grad False
- for requires_grad_1, requires_grad_2 in itertools.product(
- [True, False], repeat=2
- ):
- yield partial(
- mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
- ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
-
- # (3) obscure case:
- # view is not a leaf (implies requires_grad True)
- # base w/ requires_grad False)
- yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
-
- # Subclass -> Dense
- yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
- 0
- ].clone(), "subclass_dense"
-
- # Dense -> Subclass -> Dense -> Subclass
- def mk_dense_subclass_dense_subclass():
- values = torch.randn(10, 5)
- offsets = torch.tensor([0, 3, 6, 10])
- offsets2 = offsets.clone().detach()
- return nested_view_from_values_offsets(
- nested_view_from_values_offsets(values, offsets).values(), offsets
- )
-
- yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"
-
- def mk_subclass_dense_subclass_dense():
- x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
- offsets2 = x.offsets().clone().detach()
- nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
-
- yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"
-
-
-VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}
-
-
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
compile_full_eager = torch.compile(backend="eager", fullgraph=True)
@@ -1308,7 +1207,15 @@
class TestNestedTensor(torch._dynamo.test_case.TestCase):
def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
- return get_jagged_tensor(nested_size, offsets, requires_grad)
+ # Makes a jagged tensor with N constituent tensors with size
+ # as specified ((S0, S1, S2), D)
+ D = nested_size[1]
+ out = []
+ for s in nested_size[0]:
+ out.append(
+ torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
+ )
+ return jagged_from_list(out, offsets)
def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
# Makes a jagged tensor with N constituent tensors with size
@@ -1462,9 +1369,62 @@
torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
- def _input_view_test(self, nt_view_name):
- nt_view = VIEW_TEST_CASES[nt_view_name]()
+ def _get_views(self):
+ # Test all cases with both an NT base and a dense base
+ # Subclass -> Subclass
+ # Dense -> Subclass
+ for base_is_nt in [False, True]:
+ # There are three cases to consider here based on the logic in
+ # meta_utils.py
+ #
+ # (1) basic case:
+ # view is not a leaf and has the same requires grad as its basic case
+ x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
+ x = x.clone() if base_is_nt else x
+ self.assertEqual(x.is_leaf, False)
+ yield x.unsqueeze(-1)
+ # (2) leaf view case:
+ # the view has to be a leaf (w/ requires_grad True or requires_grad False)
+ # base w/ requires_grad True or requires_grad False
+ for requires_grad_1, requires_grad_2 in itertools.product(
+ [True, False], repeat=2
+ ):
+ x, _ = self._get_jagged_tensor(
+ ((2, 3, 4), 3), None, requires_grad=requires_grad_1
+ )
+ x = x.clone() if base_is_nt else x
+ with torch.no_grad():
+ x_view = x.unsqueeze(-1)
+ # The issue is this doesn't quite work
+ x_view.requires_grad_(requires_grad_2)
+ yield x_view
+
+ # (3) obscure case:
+ # view is not a leaf (implies requires_grad True)
+ # base w/ requires_grad False)
+ x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
+ x = x.clone() if base_is_nt else x
+ # intermediate leaf view
+ with torch.no_grad():
+ x_view = x.unsqueeze(-1)
+ x_view.requires_grad_(True)
+ x_view_view = x_view.unsqueeze(-1)
+ yield x_view_view
+
+ # Subclass -> Dense
+ x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
+ yield x.values()
+
+ # Dense -> Subclass -> Dense -> Subclass
+ values = torch.randn(10, 5)
+ offsets = torch.tensor([0, 3, 6, 10])
+ offsets2 = offsets.clone().detach()
+ yield nested_view_from_values_offsets(
+ nested_view_from_values_offsets(values, offsets).values(), offsets
+ )
+
+ def _input_view_test(self, nt_view):
def fn(x):
return x.sin()
@@ -1490,15 +1450,8 @@
# varies based on the type of view
guard_str = "\n".join(guards)
- if (
- isinstance(nt_view._base, NestedTensor)
- or nt_view_name == "subclass_dense"
- ):
+ if isinstance(nt_view._base, NestedTensor):
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
- elif nt_view_name.startswith("base_is_nt_False_"):
- # TODO: this is a "do I need to resize storage" guard,
- # probably don't actually want to see this
- self.assertExpectedInline(guard_str, """8*s1*s3 <= 8*s0*s1""")
else:
self.assertExpectedInline(guard_str, """""")
return gm
@@ -1507,12 +1460,9 @@
compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
out = compile_fn(nt_view)
- @parametrize(
- "nt_view_name",
- [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
- )
- def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
- self._input_view_test(nt_view_name)
+ def test_inputs_to_compiled_fn_are_views(self):
+ for nt_view in self._get_views():
+ self._input_view_test(nt_view)
def test_subclass_gives_static_shapes_when_dynamic_false(self):
def check_graph(gm, *args):
@@ -1540,10 +1490,10 @@
# are cached onto fake offsets to solve this problem.
@unittest.expectedFailure
def test_subclass_dense_subclass_dense_view(self):
- self._input_view_test("subclass_dense_subclass_dense")
-
-
-instantiate_parametrized_tests(TestNestedTensor)
+ x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
+ offsets2 = x.offsets().clone().detach()
+ nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
+ self._input_view_test(nt_view)
if __name__ == "__main__":
diff --git a/test/test_meta.py b/test/test_meta.py
index 93d7bb8..af1a5fb 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -286,14 +286,6 @@
m = MetaConverter()(y)
self.assertMetadataMatches(m, y)
- def test_inplace_set_storage(self):
- x = torch.tensor([0, 1], dtype=torch.int64)
- storage = x.untyped_storage()
- ssize = storage.size()
- meta = torch.empty((), dtype=torch.int64)
- meta.set_(storage, 0, (), ())
- self.assertEqual(storage.size(), ssize)
-
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
def test_weakref(self):
x = torch.randn(4, 4, 4)
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index fdea8a3..d291605 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -8,7 +8,6 @@
from dataclasses import dataclass
from typing import (
Any,
- Callable,
cast,
Dict,
List,
@@ -1214,20 +1213,16 @@
if metadata.is_neg:
torch._C._set_neg(empty, True)
- maybe_suppress: Callable[[], Any] = contextlib.nullcontext
- if self.shape_env is not None:
- maybe_suppress = self.shape_env.suppress_guards
-
if func.is_view:
# For view ops, the storage should be the same as the tensor input.
storage = args[cast(int, entry.view_idx)].untyped_storage()
- with in_kernel_invocation_manager(self), maybe_suppress():
+ with in_kernel_invocation_manager(self):
empty.set_(
storage, metadata.storage_offset, metadata.shape, metadata.stride
)
elif metadata.storage_offset != 0:
storage = empty.untyped_storage()
- with in_kernel_invocation_manager(self), maybe_suppress():
+ with in_kernel_invocation_manager(self):
empty.set_(
storage, metadata.storage_offset, metadata.shape, metadata.stride
)