[inductor] Add constant_to_device for ir.Constant (#108087)
Fixes error with:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 ./benchmarks/dynamo/torchbench.py --inference --performance --no-skip --inductor --freezing --only pyhpc_turbulent_kinetic_energy
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108087
Approved by: https://github.com/eellison
ghstack dependencies: #108096
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index 4b49d31..20695c2 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -347,6 +347,23 @@
actual = torch.compile(forward, fullgraph=True)(x)
self.assertEqual(actual, correct)
+ def test_full_copy(self):
+ def forward(x):
+ full_10 = torch.ops.aten.full.default(
+ [204, 204, 28],
+ 0,
+ dtype=torch.float64,
+ layout=torch.strided,
+ device="cuda",
+ pin_memory=False,
+ )
+ return x + full_10.to("cpu")
+
+ o = torch.randn([204, 204, 28], dtype=torch.float64)
+ correct = forward(o)
+ actual = torch.compile(forward, fullgraph=True)(o)
+ self.assertEqual(actual, correct)
+
def test_autotune_inplace_kernel(self):
"""
This UT tests autotune on an inplace kernel. The autotune should not contaminate
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 6bab0bd..fd62a28 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -1971,6 +1971,9 @@
def realize(self):
pass
+ def constant_to_device(self, device):
+ return Constant(self.value, self.dtype, device)
+
@dataclasses.dataclass
class IndexingConstant(BaseConstant):
@@ -1984,6 +1987,9 @@
return loader
+ def constant_to_device(self, device):
+ return IndexingConstant(self.index, self.dtype, device)
+
@dataclasses.dataclass
class Layout(IRNode):