[inductor] Allow sympy expressions to participate in type promotion (#115676)

In the test example we have `add(i64[10], sympy.Expr)` where
`sympy.Expr` is not considered a promoting arg so isn't factored into
the type promotion. However, in eager it would promote to float32.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115676
Approved by: https://github.com/lezcano
ghstack dependencies: #115677, #115699, #115700
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 91868e8..cb0781e 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -2708,7 +2708,8 @@
         m = M().eval()
         with torch.no_grad():
             metrics.reset()
-            self.common(m, (idx, x))
+            # FIXME: returns the wrong dtype
+            self.common(m, (idx, x), exact_dtype=False)
             assert metrics.generated_cpp_vec_kernel_count == 1
 
         # we are doing direct load/store, make sure we do not generate
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 9e7e1c4..bcd977f 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -281,6 +281,22 @@
     has_lowp_args = False
     original_lowp_dtype = torch.half
 
+    if reference_in_float and exact_dtype:
+        # Store expected dtypes so we can check actual result gives the correct types
+        torch.manual_seed(0)
+        try:
+            eager_result = model(*ref_inputs, **ref_kwargs)
+        except RuntimeError:
+            # Eager model may fail if the dtype is not supported
+            eager_result = None
+
+        ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
+        expect_dtypes = [
+            x.dtype if isinstance(x, torch.Tensor) else None
+            for x in pytree.tree_leaves(eager_result)
+        ]
+        del eager_result
+
     if reference_in_float:
         # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg
         def upcast_fn(x):
@@ -349,6 +365,13 @@
             for x, y in zip(actual_flat, correct_flat)
         )
 
+    if reference_in_float and exact_dtype:
+        for expect_dtype, actual_result in zip(expect_dtypes, actual_flat):
+            if expect_dtype is not None:
+                assert (
+                    actual_result.dtype == expect_dtype
+                ), f"dtype mismatch, expected {expect_dtype} but got {actual_result.dtype}"
+
     if reference_in_float:
         correct_flat = reference_to_expect(actual_flat, correct_flat)
         correct = tree_unflatten(correct_flat, correct_spec)
@@ -1628,7 +1651,8 @@
                 a // b,
             )
 
-        self.common(fn, (1024, 100))
+        # FIXME: returns the wrong dtype
+        self.common(fn, (1024, 100), exact_dtype=False)
 
     def test_div9(self):
         def fn(x):
@@ -4601,6 +4625,14 @@
                 re.search(pattern, code), msg="Found bad index_expr in code:\n" + code
             )
 
+    def test_float_index_expression_type_promotion(self):
+        # Test that float indexing expressions participate in type promotion
+        def fn(x):
+            return x + 1.0 / x.size(0)
+
+        x = torch.arange(10)
+        self.common(fn, (x,))
+
     def test_sort(self):
         def fn(a):
             return torch.sort(a)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 3f90de4..85c4ccd 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -162,7 +162,7 @@
 
 def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
     def construct_input(inp):
-        if isinstance(inp, (Number, sympy.Symbol)):
+        if isinstance(inp, (Number, sympy.Expr)):
             return inp
         else:
             assert hasattr(inp, "get_dtype")
@@ -199,7 +199,9 @@
         else:
             # FIXME that's a crude approximation for promoting args
             promoting_args = [
-                a for a in args if isinstance(a, Number) or hasattr(a, "get_dtype")
+                a
+                for a in args
+                if isinstance(a, (Number, sympy.Expr)) or hasattr(a, "get_dtype")
             ]
             dtype = get_promoted_dtype(
                 *promoting_args, type_promotion_kind=type_promotion_kind
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index 71a5d5f..db86730 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -1303,7 +1303,7 @@
         return type(x)
 
 
-def symbol_type(x: sympy.Symbol) -> Type:
+def expr_type(x: sympy.Expr) -> Type:
     if x.is_integer:  # type: ignore[attr-defined]
         return int
     else:
@@ -1411,14 +1411,14 @@
     import sympy
 
     for x in args:
-        if not isinstance(x, (Number, TensorLike, sympy.Symbol)):
+        if not isinstance(x, (Number, TensorLike, sympy.Expr)):
             msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
             raise ValueError(msg)
 
         if isinstance(x, Number):
             highest_type = get_higher_type(highest_type, number_type(x))
-        elif isinstance(x, sympy.Symbol):
-            highest_type = get_higher_type(highest_type, symbol_type(x))
+        elif isinstance(x, sympy.Expr):
+            highest_type = get_higher_type(highest_type, expr_type(x))
         else:
             # x is a TensorLike
             highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))