Fix `full` on symbolic value. (#108166)
Fix: #108067
This PR adds checks for `sympy.Expr` when extracting the dtype from a value inside the
`full` lowering.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108166
Approved by: https://github.com/lezcano
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index 8dcb244..fb150ef 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -268,6 +268,15 @@
actual = cfn(3)
self.assertEqual(expect, actual)
+ def test_full(self, device):
+ def fn(a):
+ return torch.full((3,), a)
+
+ cfn = self.compile_fn(fn)
+ expect = fn(5)
+ actual = cfn(5)
+ self.assertEqual(expect, actual)
+
instantiate_device_type_tests(TestInductorDynamic, globals())
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 3caab47..2274d8b 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -131,6 +131,15 @@
return dtype
+def value_to_dtype(value: Any) -> torch.dtype:
+ if isinstance(value, sympy.Expr):
+ if value.is_integer: # type: ignore[attr-defined]
+ return torch.long
+ if value.is_real:
+ return torch.get_default_dtype()
+ return type_to_dtype(type(value))
+
+
def is_integer_type(x):
if isinstance(x, TensorBox):
return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
@@ -2452,7 +2461,7 @@
@register_lowering([torch.full, aten.full])
def full(size, fill_value, **kwargs):
dtype = kwargs.get("dtype")
- kwargs["dtype"] = dtype if dtype is not None else type_to_dtype(type(fill_value))
+ kwargs["dtype"] = dtype if dtype is not None else value_to_dtype(fill_value)
return tensor_constructor(fill_value)(size, **kwargs)