[inductor] Fix recompiles bug for torch.full (#123811)
Fixes #123810
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123811
Approved by: https://github.com/peterbell10
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index c3b0147..b7f8047 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -578,7 +578,7 @@
actual = cfn(3)
self.assertEqual(expect, actual)
- def test_full(self, device):
+ def test_full_symbolic_value(self, device):
def fn(a):
return torch.full((3,), a), torch.full((3,), torch.sym_float(a))
@@ -587,6 +587,25 @@
actual = cfn(5)
self.assertEqual(expect, actual)
+ def test_full_recompiles(self, device):
+ def fn(x):
+ _, L = x.shape
+ return torch.full((L, L), torch.finfo(torch.float16).min, device=device)
+
+ cfn = self.compile_fn(fn)
+
+ import functools
+
+ input_fn = functools.partial(torch.randint, 10, 1000, device=device)
+
+ cfn(input_fn((2, 3)))
+ cfn(input_fn((2, 4))) # expect no recompilation here
+
+ # check the compile count of frame 0
+ from torch._dynamo.convert_frame import FRAME_COMPILE_COUNTER
+
+ self.assertEqual(FRAME_COMPILE_COUNTER[0], 1)
+
@parametrize(
"op",
[
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 27b1289..00640f6 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -134,7 +134,7 @@
dtype = kwargs.get("dtype")
if dtype is None:
kwargs["dtype"] = type_to_dtype(type(fill_value))
- return aten.full(size, fill_value, **kwargs)
+ return torch.full(size, fill_value, **kwargs)
return NotImplemented