Add symbolic singleton int (#110370)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110370
Approved by: https://github.com/ezyang
ghstack dependencies: #110044, #110369
diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py
index 72bcd35..474d70d 100644
--- a/test/test_sympy_utils.py
+++ b/test/test_sympy_utils.py
@@ -18,6 +18,10 @@
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
+from torch.utils._sympy.singleton_int import SingletonInt
+from sympy.core.relational import is_ge, is_le, is_gt, is_lt
+import functools
+
UNARY_OPS = [
@@ -520,6 +524,89 @@
r = solver.check()
self.assertEqual(r, z3.unsat)
+class TestSingletonInt(TestCase):
+ def test_basic(self):
+ j1 = SingletonInt(1, coeff=1)
+ j1_copy = SingletonInt(1, coeff=1)
+ j2 = SingletonInt(2, coeff=1)
+ j1x2 = SingletonInt(1, coeff=2)
+
+ def test_eq(a, b, expected):
+ self.assertEqual(sympy.Eq(a, b), expected)
+ self.assertEqual(sympy.Ne(b, a), not expected)
+
+ # eq, ne
+ test_eq(j1, j1, True)
+ test_eq(j1, j1_copy, True)
+ test_eq(j1, j2, False)
+ test_eq(j1, j1x2, False)
+ test_eq(j1, sympy.Integer(1), False)
+ test_eq(j1, sympy.Integer(3), False)
+
+ def test_ineq(a, b, expected, *, strict=True):
+ greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge)
+ less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le)
+
+ if isinstance(expected, bool):
+ # expected is always True
+ for fn in greater:
+ self.assertEqual(fn(a, b), expected)
+ self.assertEqual(fn(b, a), not expected)
+ for fn in less:
+ self.assertEqual(fn(b, a), expected)
+ self.assertEqual(fn(a, b), not expected)
+ else:
+ for fn in greater:
+ with self.assertRaisesRegex(ValueError, expected):
+ fn(a, b)
+ for fn in less:
+ with self.assertRaisesRegex(ValueError, expected):
+ fn(b, a)
+
+ # ge, le, gt, lt
+ for strict in (True, False):
+ _test_ineq = functools.partial(test_ineq, strict=strict)
+ _test_ineq(j1, sympy.Integer(0), True)
+ _test_ineq(j1, sympy.Integer(3), "indeterminate")
+ _test_ineq(j1, j2, "indeterminate")
+ _test_ineq(j1x2, j1, True)
+
+ # Special cases for ge, le, gt, lt:
+ for ge in (sympy.Ge, is_ge):
+ self.assertTrue(ge(j1, j1))
+ self.assertTrue(ge(j1, sympy.Integer(2)))
+ with self.assertRaisesRegex(ValueError, "indeterminate"):
+ ge(sympy.Integer(2), j1)
+ for le in (sympy.Le, is_le):
+ self.assertTrue(le(j1, j1))
+ self.assertTrue(le(sympy.Integer(2), j1))
+ with self.assertRaisesRegex(ValueError, "indeterminate"):
+ le(j1, sympy.Integer(2))
+
+ for gt in (sympy.Gt, is_gt):
+ self.assertFalse(gt(j1, j1))
+ self.assertFalse(gt(sympy.Integer(2), j1))
+ # it is only known to be that j1 >= 2, j1 > 2 is indeterminate
+ with self.assertRaisesRegex(ValueError, "indeterminate"):
+ gt(j1, sympy.Integer(2))
+
+ for lt in (sympy.Lt, is_lt):
+ self.assertFalse(lt(j1, j1))
+ self.assertFalse(lt(j1, sympy.Integer(2)))
+ with self.assertRaisesRegex(ValueError, "indeterminate"):
+ lt(sympy.Integer(2), j1)
+
+ # mul
+ self.assertEqual(j1 * 2, j1x2)
+ # Unfortunately, this doesn't not automatically simplify to 2*j1
+ # since sympy.Mul doesn't trigger __mul__ unlike the above.
+ self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)
+
+ with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
+ j1 * j2
+
+ self.assertEqual(j1.free_symbols, set())
+
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py
new file mode 100644
index 0000000..d67e373
--- /dev/null
+++ b/torch/utils/_sympy/singleton_int.py
@@ -0,0 +1,94 @@
+import sympy
+from sympy.multipledispatch import dispatch
+
+__all__ = ["SingletonInt"]
+
+
+class SingletonInt(sympy.AtomicExpr):
+ # This is probably not super important unless we are in multiple dispatch
+ # situations with other more exotic Expr types.
+ _op_priority = 99999
+
+ def __new__(cls, *args, coeff=None, **kwargs):
+ instance = super().__new__(cls, *args, **kwargs)
+ return instance
+
+ # The semantics of this class should match that of SingletonSymNodeImpl in
+ # c10/core/SingletonSymNodeImpl.h
+ def __init__(self, val, *, coeff=1):
+ self._val = val
+ self._coeff = coeff
+ super().__init__()
+
+ # See NOTE [ Inequalities with SingletonInt ]
+ def _eval_Eq(self, other):
+ if (
+ isinstance(other, SingletonInt)
+ and other._val == self._val
+ and self._coeff == other._coeff
+ ):
+ return sympy.true
+ else:
+ return sympy.false
+
+ # This is necessary so that calling expr.free_symbols on exprs that contain
+ # this Singleton does not error
+ @property
+ def free_symbols(self):
+ return set()
+
+ def __mul__(self, other):
+ if isinstance(other, SingletonInt):
+ raise ValueError(
+ "SingletonInt cannot be multiplied by another SingletonInt"
+ )
+ return SingletonInt(self._val, coeff=self._coeff * other)
+
+ def __rmul__(self, other):
+ if isinstance(other, SingletonInt):
+ raise ValueError(
+ "SingletonInt cannot be multiplied by another SingletonInt"
+ )
+ return SingletonInt(self._val, coeff=self._coeff * other)
+
+ # Make sure we promptly raise an error instead of falling back to building
+ # an expression tree. There are probably more ops, how can we be exhaustive?
+ def __add__(self, other):
+ raise NotImplementedError("NYI")
+
+ def __sub__(self, other):
+ raise NotImplementedError("NYI")
+
+ def __truediv__(self, other):
+ raise NotImplementedError("NYI")
+
+ def __floordiv__(self, other):
+ raise NotImplementedError("NYI")
+
+ def __mod__(self, other):
+ raise NotImplementedError("NYI")
+
+
+# See NOTE [ Inequalities with SingletonInt ]
+@dispatch(sympy.Integer, SingletonInt)
+def _eval_is_ge(a, b):
+ if a < 2:
+ return sympy.false
+ raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
+
+
+@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef]
+def _eval_is_ge(a, b): # noqa: F811
+ if b <= 2:
+ return sympy.true
+ raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
+
+
+@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef]
+def _eval_is_ge(a, b): # noqa: F811
+ if a._val == b._val:
+ if a._coeff >= b._coeff:
+ return sympy.true
+ else:
+ return sympy.false
+ raise ValueError("Symbolic SingletonInt: Relation is indeterminate")