sym_int simplification for integer args, attempt 3 (#94799)

Per title, now propagates to inductor codegen.
Where should I put the test and how should test look like?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94799
Approved by: https://github.com/ezyang
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index a30f17c..480b83d 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -388,6 +388,24 @@
         self.assertEqual(r, 2)
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")
+        r = math.floor(3.0 * a0)
+        self.assertEqual(r, 15)
+        self.assertIsInstance(r, torch.SymInt, msg=type(r))
+        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
+
+    @skipIfNoSympy
+    def test_sym_ceil(self):
+        shape_env = ShapeEnv()
+        a0 = create_symint(shape_env, 5)
+        r = math.ceil(a0 / 2)
+        self.assertEqual(r, 3)
+        self.assertIsInstance(r, torch.SymInt, msg=type(r))
+        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""")
+        r = math.floor(3.0 * a0)
+        self.assertEqual(r, 15)
+        self.assertIsInstance(r, torch.SymInt, msg=type(r))
+        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
+
 
     @skipIfNoSympy
     def test_int_conversion(self):
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index c7bbba4..6b88fa0 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -546,6 +546,23 @@
 def error():
     raise AssertionError("shouldn't be hit")
 
+def floor_ceil_helper(a, fn):
+    if isinstance(a, sympy.Mul):
+        aa = a.args
+        if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
+            coef = sympy.Integer(aa[0])
+            if aa[0] == coef:  # structural equality test
+                return coef * aa[1]
+    if isinstance(a, sympy.Float) and a == sympy.Integer(a) or isinstance(a, sympy.Integer):
+        return sympy.Integer(a)
+    return fn(a)
+
+def floor_impl(a):
+    return floor_ceil_helper(a, sympy.floor)
+
+def ceil_impl(a):
+    return floor_ceil_helper(a, sympy.ceiling)
+
 
 magic_methods = {
     **reflectable_magic_methods,
@@ -556,9 +573,9 @@
     'lt': lambda a, b: sympy.Lt(a, b),
     'le': lambda a, b: sympy.Le(a, b),
     'ge': lambda a, b: sympy.Ge(a, b),
-    'floor': lambda a: sympy.floor(a),
+    'floor': floor_impl,
     'sym_float': lambda a: a,  # Cannot use sympy.Float(a) here, coz it expects python literals
-    'ceil': lambda a: sympy.ceiling(a),
+    'ceil': ceil_impl,
     'neg': lambda a: -a,
     'sym_min': lambda a, b: sympy.Min(a, b),
     'sym_max': lambda a, b: sympy.Max(a, b),
@@ -737,25 +754,11 @@
         # TODO: consider constant prop here
         expr = self.shape_env.replace(self.expr)
 
-        # Attempt some extra simplification on floor/ceil
-        out = None
-        if method == "floor" or method == "ceil":
-            if isinstance(expr, sympy.Mul):
-                aa = expr.args
-                if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
-                    coef = sympy.Integer(aa[0])
-                    if aa[0] == coef:  # structural equality test
-                        out = coef * aa[1]
-            elif isinstance(expr, sympy.Float) and expr == sympy.Integer(expr) or isinstance(expr, sympy.Integer):
-                out = sympy.Integer(expr)
-
-        # Do the regular evaluation otherwise
-        if out is None:
-            try:
-                out = func(expr)
-            except Exception:
-                log.warning(f"failed to eval {method}({expr})")
-                raise
+        try:
+            out = func(expr)
+        except Exception:
+            log.warning(f"failed to eval {method}({expr})")
+            raise
 
         out_hint = None
         if self.hint is not None: