Add symbolic singleton int (#110370)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110370
Approved by: https://github.com/ezyang
ghstack dependencies: #110044, #110369
diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py
index 72bcd35..474d70d 100644
--- a/test/test_sympy_utils.py
+++ b/test/test_sympy_utils.py
@@ -18,6 +18,10 @@
 from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
 from torch.utils._sympy.reference import ReferenceAnalysis
 from torch.utils._sympy.interp import sympy_interp
+from torch.utils._sympy.singleton_int import SingletonInt
+from sympy.core.relational import is_ge, is_le, is_gt, is_lt
+import functools
+
 
 
 UNARY_OPS = [
@@ -520,6 +524,89 @@
         r = solver.check()
         self.assertEqual(r, z3.unsat)
 
+class TestSingletonInt(TestCase):
+    def test_basic(self):
+        j1 = SingletonInt(1, coeff=1)
+        j1_copy = SingletonInt(1, coeff=1)
+        j2 = SingletonInt(2, coeff=1)
+        j1x2 = SingletonInt(1, coeff=2)
+
+        def test_eq(a, b, expected):
+            self.assertEqual(sympy.Eq(a, b), expected)
+            self.assertEqual(sympy.Ne(b, a), not expected)
+
+        # eq, ne
+        test_eq(j1, j1, True)
+        test_eq(j1, j1_copy, True)
+        test_eq(j1, j2, False)
+        test_eq(j1, j1x2, False)
+        test_eq(j1, sympy.Integer(1), False)
+        test_eq(j1, sympy.Integer(3), False)
+
+        def test_ineq(a, b, expected, *, strict=True):
+            greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge)
+            less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le)
+
+            if isinstance(expected, bool):
+                # expected is always True
+                for fn in greater:
+                    self.assertEqual(fn(a, b), expected)
+                    self.assertEqual(fn(b, a), not expected)
+                for fn in less:
+                    self.assertEqual(fn(b, a), expected)
+                    self.assertEqual(fn(a, b), not expected)
+            else:
+                for fn in greater:
+                    with self.assertRaisesRegex(ValueError, expected):
+                        fn(a, b)
+                for fn in less:
+                    with self.assertRaisesRegex(ValueError, expected):
+                        fn(b, a)
+
+        # ge, le, gt, lt
+        for strict in (True, False):
+            _test_ineq = functools.partial(test_ineq, strict=strict)
+            _test_ineq(j1, sympy.Integer(0), True)
+            _test_ineq(j1, sympy.Integer(3), "indeterminate")
+            _test_ineq(j1, j2, "indeterminate")
+            _test_ineq(j1x2, j1, True)
+
+        # Special cases for ge, le, gt, lt:
+        for ge in (sympy.Ge, is_ge):
+            self.assertTrue(ge(j1, j1))
+            self.assertTrue(ge(j1, sympy.Integer(2)))
+            with self.assertRaisesRegex(ValueError, "indeterminate"):
+                ge(sympy.Integer(2), j1)
+        for le in (sympy.Le, is_le):
+            self.assertTrue(le(j1, j1))
+            self.assertTrue(le(sympy.Integer(2), j1))
+            with self.assertRaisesRegex(ValueError, "indeterminate"):
+                le(j1, sympy.Integer(2))
+
+        for gt in (sympy.Gt, is_gt):
+            self.assertFalse(gt(j1, j1))
+            self.assertFalse(gt(sympy.Integer(2), j1))
+            # it is only known to be that j1 >= 2, j1 > 2 is indeterminate
+            with self.assertRaisesRegex(ValueError, "indeterminate"):
+                gt(j1, sympy.Integer(2))
+
+        for lt in (sympy.Lt, is_lt):
+            self.assertFalse(lt(j1, j1))
+            self.assertFalse(lt(j1, sympy.Integer(2)))
+            with self.assertRaisesRegex(ValueError, "indeterminate"):
+                lt(sympy.Integer(2), j1)
+
+        # mul
+        self.assertEqual(j1 * 2, j1x2)
+        # Unfortunately, this doesn't not automatically simplify to 2*j1
+        # since sympy.Mul doesn't trigger __mul__ unlike the above.
+        self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)
+
+        with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
+            j1 * j2
+
+        self.assertEqual(j1.free_symbols, set())
+
 
 instantiate_parametrized_tests(TestValueRanges)
 instantiate_parametrized_tests(TestSympyInterp)
diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py
new file mode 100644
index 0000000..d67e373
--- /dev/null
+++ b/torch/utils/_sympy/singleton_int.py
@@ -0,0 +1,94 @@
+import sympy
+from sympy.multipledispatch import dispatch
+
+__all__ = ["SingletonInt"]
+
+
+class SingletonInt(sympy.AtomicExpr):
+    # This is probably not super important unless we are in multiple dispatch
+    # situations with other more exotic Expr types.
+    _op_priority = 99999
+
+    def __new__(cls, *args, coeff=None, **kwargs):
+        instance = super().__new__(cls, *args, **kwargs)
+        return instance
+
+    # The semantics of this class should match that of SingletonSymNodeImpl in
+    # c10/core/SingletonSymNodeImpl.h
+    def __init__(self, val, *, coeff=1):
+        self._val = val
+        self._coeff = coeff
+        super().__init__()
+
+    # See NOTE [ Inequalities with SingletonInt ]
+    def _eval_Eq(self, other):
+        if (
+            isinstance(other, SingletonInt)
+            and other._val == self._val
+            and self._coeff == other._coeff
+        ):
+            return sympy.true
+        else:
+            return sympy.false
+
+    # This is necessary so that calling expr.free_symbols on exprs that contain
+    # this Singleton does not error
+    @property
+    def free_symbols(self):
+        return set()
+
+    def __mul__(self, other):
+        if isinstance(other, SingletonInt):
+            raise ValueError(
+                "SingletonInt cannot be multiplied by another SingletonInt"
+            )
+        return SingletonInt(self._val, coeff=self._coeff * other)
+
+    def __rmul__(self, other):
+        if isinstance(other, SingletonInt):
+            raise ValueError(
+                "SingletonInt cannot be multiplied by another SingletonInt"
+            )
+        return SingletonInt(self._val, coeff=self._coeff * other)
+
+    # Make sure we promptly raise an error instead of falling back to building
+    # an expression tree. There are probably more ops, how can we be exhaustive?
+    def __add__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __sub__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __truediv__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __floordiv__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __mod__(self, other):
+        raise NotImplementedError("NYI")
+
+
+# See NOTE [ Inequalities with SingletonInt ]
+@dispatch(sympy.Integer, SingletonInt)
+def _eval_is_ge(a, b):
+    if a < 2:
+        return sympy.false
+    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
+
+
+@dispatch(SingletonInt, sympy.Integer)  # type: ignore[no-redef]
+def _eval_is_ge(a, b):  # noqa: F811
+    if b <= 2:
+        return sympy.true
+    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
+
+
+@dispatch(SingletonInt, SingletonInt)  # type: ignore[no-redef]
+def _eval_is_ge(a, b):  # noqa: F811
+    if a._val == b._val:
+        if a._coeff >= b._coeff:
+            return sympy.true
+        else:
+            return sympy.false
+    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")