[inductor] Allow sympy expressions to participate in type promotion (#115676)
In the test example we have `add(i64[10], sympy.Expr)`, where the `sympy.Expr` argument is not considered a promoting arg and so is not factored into the type promotion. In eager mode, however, the same operation would promote to float32.
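For illustration, a minimal eager sketch of the behaviour the lowering should match (it mirrors the `test_float_index_expression_type_promotion` test added below); the names here are just for the example:

```python
import torch

# In eager, an int64 tensor plus a Python float scalar promotes to the
# default float dtype.
x = torch.arange(10)        # dtype: torch.int64
y = x + 1.0 / x.size(0)     # 1.0 / 10 is a plain Python float here
print(y.dtype)              # torch.float32

# In the inductor lowering the same scalar can arrive as a sympy.Expr,
# which previously was dropped from the promotion, giving the wrong dtype.
```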
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115676
Approved by: https://github.com/lezcano
ghstack dependencies: #115677, #115699, #115700
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 91868e8..cb0781e 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -2708,7 +2708,8 @@
m = M().eval()
with torch.no_grad():
metrics.reset()
- self.common(m, (idx, x))
+ # FIXME: returns the wrong dtype
+ self.common(m, (idx, x), exact_dtype=False)
assert metrics.generated_cpp_vec_kernel_count == 1
# we are doing direct load/store, make sure we do not generate
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 9e7e1c4..bcd977f 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -281,6 +281,22 @@
has_lowp_args = False
original_lowp_dtype = torch.half
+ if reference_in_float and exact_dtype:
+ # Store expected dtypes so we can check actual result gives the correct types
+ torch.manual_seed(0)
+ try:
+ eager_result = model(*ref_inputs, **ref_kwargs)
+ except RuntimeError:
+ # Eager model may fail if the dtype is not supported
+ eager_result = None
+
+ ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
+ expect_dtypes = [
+ x.dtype if isinstance(x, torch.Tensor) else None
+ for x in pytree.tree_leaves(eager_result)
+ ]
+ del eager_result
+
if reference_in_float:
# check_lowp is ignored here, it's kept just to be able to call `common` with extra arg
def upcast_fn(x):
@@ -349,6 +365,13 @@
for x, y in zip(actual_flat, correct_flat)
)
+ if reference_in_float and exact_dtype:
+ for expect_dtype, actual_result in zip(expect_dtypes, actual_flat):
+ if expect_dtype is not None:
+ assert (
+ actual_result.dtype == expect_dtype
+ ), f"dtype mismatch, expected {expect_dtype} but got {actual_result.dtype}"
+
if reference_in_float:
correct_flat = reference_to_expect(actual_flat, correct_flat)
correct = tree_unflatten(correct_flat, correct_spec)
@@ -1628,7 +1651,8 @@
a // b,
)
- self.common(fn, (1024, 100))
+ # FIXME: returns the wrong dtype
+ self.common(fn, (1024, 100), exact_dtype=False)
def test_div9(self):
def fn(x):
@@ -4601,6 +4625,14 @@
re.search(pattern, code), msg="Found bad index_expr in code:\n" + code
)
+ def test_float_index_expression_type_promotion(self):
+ # Test that float indexing expressions participate in type promotion
+ def fn(x):
+ return x + 1.0 / x.size(0)
+
+ x = torch.arange(10)
+ self.common(fn, (x,))
+
def test_sort(self):
def fn(a):
return torch.sort(a)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 3f90de4..85c4ccd 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -162,7 +162,7 @@
def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
def construct_input(inp):
- if isinstance(inp, (Number, sympy.Symbol)):
+ if isinstance(inp, (Number, sympy.Expr)):
return inp
else:
assert hasattr(inp, "get_dtype")
@@ -199,7 +199,9 @@
else:
# FIXME that's a crude approximation for promoting args
promoting_args = [
- a for a in args if isinstance(a, Number) or hasattr(a, "get_dtype")
+ a
+ for a in args
+ if isinstance(a, (Number, sympy.Expr)) or hasattr(a, "get_dtype")
]
dtype = get_promoted_dtype(
*promoting_args, type_promotion_kind=type_promotion_kind
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index 71a5d5f..db86730 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -1303,7 +1303,7 @@
return type(x)
-def symbol_type(x: sympy.Symbol) -> Type:
+def expr_type(x: sympy.Expr) -> Type:
if x.is_integer: # type: ignore[attr-defined]
return int
else:
@@ -1411,14 +1411,14 @@
import sympy
for x in args:
- if not isinstance(x, (Number, TensorLike, sympy.Symbol)):
+ if not isinstance(x, (Number, TensorLike, sympy.Expr)):
msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
raise ValueError(msg)
if isinstance(x, Number):
highest_type = get_higher_type(highest_type, number_type(x))
- elif isinstance(x, sympy.Symbol):
- highest_type = get_higher_type(highest_type, symbol_type(x))
+ elif isinstance(x, sympy.Expr):
+ highest_type = get_higher_type(highest_type, expr_type(x))
else:
# x is a TensorLike
highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))
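As a standalone aside (not part of the patch), a small sketch of how the renamed `expr_type` classifies a sympy expression for promotion, assuming standard sympy assumption propagation:

```python
import sympy

s = sympy.Symbol("s", integer=True, positive=True)

# Mirrors the `if x.is_integer: int, else: float` logic above.
print((s + 1).is_integer)                # True  -> promotes like a Python int
print(sympy.Rational(1, 10).is_integer)  # False -> promotes like a Python float
```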