Properly unwrap_storage tensors sent to DynamicScalar (#117444)
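
DynamicScalar's constructor passed its `data` argument into the
ExternKernel inputs list still wrapped in a TensorBox/StorageBox,
rather than unwrapping it to the underlying buffer the way other
ExternKernel subclasses do. That breaks when the tensor feeding the
scalar read has to be materialized first, which is what the new test
exercises.

A standalone version of the new test (a minimal repro, assuming a
default CPU build; mirrors test_item_materialize below):

    import torch
    import torch._dynamo.config

    # tolist() produces data-dependent Python scalars, so Dynamo must
    # be allowed to capture scalar outputs instead of graph-breaking.
    torch._dynamo.config.capture_scalar_outputs = True

    def fn(x):
        # sum() creates an intermediate buffer whose elements are read
        # back as Python ints (DynamicScalar nodes in Inductor IR).
        return x.sum(dim=0).view(4).tolist()

    cfn = torch.compile(fullgraph=True)(fn)

    a = torch.ones(3, 4, dtype=torch.int64)
    assert cfn(a) == fn(a)  # both return [3, 3, 3, 3]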
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117444
Approved by: https://github.com/Skylion007
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index df8c30b..c825b5c 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -418,6 +418,16 @@
cfn = self.compile_fn(fn)
self.assertEqual(fn(a), cfn(a))
+ @torch._dynamo.config.patch(capture_scalar_outputs=True)
+ def test_item_materialize(self, device):
+ def fn(x):
+ return x.sum(dim=0).view(4).tolist()
+
+ cfn = torch.compile(fullgraph=True)(fn)
+
+ a = torch.ones(3, 4, dtype=torch.int64, device=device)
+ self.assertEqual(cfn(a), fn(a))
+
def test_abs(self, device):
def fn(x, y):
y0, y1 = y.shape
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index ad5d2e1..fad4cec 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -4366,7 +4366,7 @@
# TODO: handle bools carefully
def __init__(self, sym, data):
- super().__init__(None, NoneLayout(torch.device("cpu")), [data]) # type: ignore[arg-type]
+ super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type]
if isinstance(sym, sympy.Symbol):
self.sym = sym
self.is_bool = False
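
For context, `unwrap_storage` is the ExternKernel helper that strips
the TensorBox/StorageBox wrappers off inputs so that the inputs list
holds realized buffers rather than opaque boxes. A simplified sketch
of the idea, not the exact implementation (the real helper in
torch/_inductor/ir.py also realizes views and recurses into
list/tuple inputs):

    from torch._inductor.ir import StorageBox, TensorBox

    # Simplified sketch: peel the box wrappers so downstream passes
    # see the underlying buffer instead of an opaque box.
    def unwrap_storage(inputs):
        unwrapped = []
        for x in inputs:
            if isinstance(x, TensorBox):
                x = x.data  # TensorBox wraps a StorageBox (or a view)
            if isinstance(x, StorageBox):
                x = x.data  # StorageBox wraps the underlying buffer
            unwrapped.append(x)
        return unwrapped

With the input unwrapped, the buffer that produces the scalar shows
up as a real dependency of the DynamicScalar node, so it can be
materialized before the scalar is read.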