sym_int simplification for integer args, attempt 3 (#94799)
Per title, now propagates to inductor codegen.
Where should I put the test and how should test look like?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94799
Approved by: https://github.com/ezyang
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index a30f17c..480b83d 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -388,6 +388,24 @@
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")
+ r = math.floor(3.0 * a0)
+ self.assertEqual(r, 15)
+ self.assertIsInstance(r, torch.SymInt, msg=type(r))
+ self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
+
+ @skipIfNoSympy
+ def test_sym_ceil(self):
+ shape_env = ShapeEnv()
+ a0 = create_symint(shape_env, 5)
+ r = math.ceil(a0 / 2)
+ self.assertEqual(r, 3)
+ self.assertIsInstance(r, torch.SymInt, msg=type(r))
+ self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""")
+ r = math.floor(3.0 * a0)
+ self.assertEqual(r, 15)
+ self.assertIsInstance(r, torch.SymInt, msg=type(r))
+ self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
+
@skipIfNoSympy
def test_int_conversion(self):
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index c7bbba4..6b88fa0 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -546,6 +546,23 @@
def error():
raise AssertionError("shouldn't be hit")
+def floor_ceil_helper(a, fn):
+ if isinstance(a, sympy.Mul):
+ aa = a.args
+ if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
+ coef = sympy.Integer(aa[0])
+ if aa[0] == coef: # structural equality test
+ return coef * aa[1]
+ if isinstance(a, sympy.Float) and a == sympy.Integer(a) or isinstance(a, sympy.Integer):
+ return sympy.Integer(a)
+ return fn(a)
+
+def floor_impl(a):
+ return floor_ceil_helper(a, sympy.floor)
+
+def ceil_impl(a):
+ return floor_ceil_helper(a, sympy.ceiling)
+
magic_methods = {
**reflectable_magic_methods,
@@ -556,9 +573,9 @@
'lt': lambda a, b: sympy.Lt(a, b),
'le': lambda a, b: sympy.Le(a, b),
'ge': lambda a, b: sympy.Ge(a, b),
- 'floor': lambda a: sympy.floor(a),
+ 'floor': floor_impl,
'sym_float': lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals
- 'ceil': lambda a: sympy.ceiling(a),
+ 'ceil': ceil_impl,
'neg': lambda a: -a,
'sym_min': lambda a, b: sympy.Min(a, b),
'sym_max': lambda a, b: sympy.Max(a, b),
@@ -737,25 +754,11 @@
# TODO: consider constant prop here
expr = self.shape_env.replace(self.expr)
- # Attempt some extra simplification on floor/ceil
- out = None
- if method == "floor" or method == "ceil":
- if isinstance(expr, sympy.Mul):
- aa = expr.args
- if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
- coef = sympy.Integer(aa[0])
- if aa[0] == coef: # structural equality test
- out = coef * aa[1]
- elif isinstance(expr, sympy.Float) and expr == sympy.Integer(expr) or isinstance(expr, sympy.Integer):
- out = sympy.Integer(expr)
-
- # Do the regular evaluation otherwise
- if out is None:
- try:
- out = func(expr)
- except Exception:
- log.warning(f"failed to eval {method}({expr})")
- raise
+ try:
+ out = func(expr)
+ except Exception:
+ log.warning(f"failed to eval {method}({expr})")
+ raise
out_hint = None
if self.hint is not None: