| # -*- coding: utf-8 -*- |
| # Owner(s): ["oncall: pt2"] |
| |
| import itertools |
| import sys |
| |
| import sympy |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| TestCase, |
| ) |
| from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges |
| from torch.utils._sympy.reference import ReferenceAnalysis |
| from torch.utils._sympy.interp import sympy_interp |
| |
| |
# Unary arithmetic operations shared by ValueRangeAnalysis and ReferenceAnalysis.
UNARY_OPS = "reciprocal square abs neg exp log sqrt floor ceil".split()
# Binary arithmetic operations shared by both analyses.
BINARY_OPS = (
    "truediv div floordiv truncdiv add mul sub pow minimum maximum mod"
).split()

# Boolean logic and comparison operations.
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
COMPARE_OPS = "eq ne lt gt le ge".split()

# A mix of small values, powers of two, primes and very large values
# (up to sys.maxsize).
CONSTANTS = [
    -1, 0, 1, 2, 3, 4, 5,
    8, 16, 32, 64,
    100, 101,
    2**24, 2**32, 2**37 - 1,
    sys.maxsize - 1, sys.maxsize,
]
# Fewer constants, for tests whose cost grows quadratically with the list.
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
| |
| |
def valid_unary(fn, v):
    """Return whether the unary op ``fn`` should be evaluated at ``v``.

    ``log`` needs a strictly positive argument, ``reciprocal`` a non-zero
    argument and ``sqrt`` a non-negative argument; every other op accepts
    any value.
    """
    # Maps op name -> predicate that is truthy when ``v`` is outside the domain.
    out_of_domain = {
        "log": lambda x: x <= 0,
        "reciprocal": lambda x: x == 0,
        "sqrt": lambda x: x < 0,
    }.get(fn)
    return out_of_domain is None or not out_of_domain(v)
| |
| |
def valid_binary(fn, a, b):
    """Return whether the binary op ``fn`` should be evaluated on ``(a, b)``.

    Rejected inputs:

    * ``pow``: exponents greater than 4, a non-positive base with exponent
      ``-1``, and ``0 ** 0``.
    * ``mod``, ``div`` and ``truediv``: a zero divisor.

    Every other op/argument combination is accepted.
    """
    if fn == "pow":
        # sympy expands x**b into x*x*... for integral b; don't do it if b is big.
        if b > 4:
            return False
        # Reciprocal of a non-positive base: 0**-1 is complex infinity (zoo);
        # negative bases are excluded along with it.
        if a <= 0 and b == -1:
            return False
        # 0**0 is undefined.
        if a == b == 0:
            return False
        return True
    # Division-like ops are undefined for a zero divisor.
    if fn in ("mod", "div", "truediv") and b == 0:
        return False
    return True
| |
| |
def generate_range(vals):
    """Yield ``ValueRanges(lo, hi)`` for ordered pairs of endpoints from ``vals``.

    Boolean endpoints: only the pair ``(true, false)`` is skipped.
    Numeric endpoints: pairs with ``lo > hi`` are skipped, and so are ranges
    that could only contain infinite values (``lo == oo`` or ``hi == -oo``).
    """
    bool_endpoints = [sympy.true, sympy.false]
    for lo, hi in itertools.product(vals, repeat=2):
        if lo in bool_endpoints:
            keep = not (lo == sympy.true and hi == sympy.false)
        else:
            # Ranges that only admit infinite values are not interesting.
            keep = not (lo > hi or lo == sympy.oo or hi == -sympy.oo)
        if keep:
            yield ValueRanges(lo, hi)
| |
| |
class TestValueRanges(TestCase):
    """Cross-check ``ValueRangeAnalysis`` against ``ReferenceAnalysis``.

    Naming convention in every test: ``r`` is a result of
    ``ValueRangeAnalysis`` and ``ref_r`` is a result of ``ReferenceAnalysis``.

    * ``test_*_ref``: both analyses are evaluated on constants, and the
      computed range must be the singleton ``[ref_r, ref_r]``.  Binary ops
      whose range comes back unknown are skipped.
    * ``test_*_ref_range``: ``ValueRangeAnalysis`` is evaluated on a range,
      and every reference result for constants inside that range must lie in
      the computed range.
    """

    @parametrize("fn", UNARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_unary_ref(self, fn, dtype):
        dtype = {"int": sympy.Integer, "float": sympy.Float}[dtype]
        for v in CONSTANTS:
            if not valid_unary(fn, v):
                continue
            with self.subTest(v=v):
                v = dtype(v)
                ref_r = getattr(ReferenceAnalysis, fn)(v)
                r = getattr(ValueRangeAnalysis, fn)(v)
                # A constant input must give a singleton range whose value and
                # integrality both match the reference result.
                self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(ref_r, r.lower)

    def test_pow_half(self):
        # Smoke test: pow with an unknown base and exponent 0.5 must not raise.
        ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))

    @parametrize("fn", BINARY_OPS)
    @parametrize("dtype_a", ("int", "float"))
    @parametrize("dtype_b", ("int", "float"))
    def test_binary_ref(self, fn, dtype_a, dtype_b):
        to_dtype = {"int": sympy.Integer, "float": sympy.Float}
        dtype_a = to_dtype[dtype_a]
        dtype_b = to_dtype[dtype_b]
        for a, b in itertools.product(CONSTANTS, repeat=2):
            if not valid_binary(fn, a, b):
                continue
            a = dtype_a(a)
            b = dtype_b(b)
            with self.subTest(a=a, b=b):
                r = getattr(ValueRangeAnalysis, fn)(a, b)
                # The range analysis may give up; there is nothing to compare then.
                if r == ValueRanges.unknown():
                    continue
                ref_r = getattr(ReferenceAnalysis, fn)(a, b)

                # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0, so the
                # integrality checks are skipped for floordiv.
                if fn != "floordiv":
                    self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                    self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r, r.lower)

    def test_mul_zero_unknown(self):
        # Multiplying by the constant 0 yields exactly 0, even when the other
        # factor is completely unknown.
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
            ValueRanges.wrap(0),
        )

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a in generate_range(vals):
            with self.subTest(a=a):
                r = getattr(ValueRangeAnalysis, fn)(a)
                unique = set()
                for a0 in vals:
                    if a0 not in a:
                        continue
                    with self.subTest(a0=a0):
                        ref_r = getattr(ReferenceAnalysis, fn)(a0)
                        self.assertIn(ref_r, r)
                        unique.add(ref_r)
                # The computed range must be tight: a singleton range is hit by
                # exactly one reference value, a wider one by both booleans.
                if r.lower == r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", BINARY_BOOL_OPS)
    def test_binary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                r = getattr(ValueRangeAnalysis, fn)(a, b)
                unique = set()
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        ref_r = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(ref_r, r)
                        unique.add(ref_r)
                # The computed range must be tight: a singleton range is hit by
                # exactly one reference value, a wider one by both booleans.
                if r.lower == r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        vals = [-sympy.oo, *CONSTANTS, sympy.oo]
        for a in generate_range(vals):
            with self.subTest(a=a):
                r = getattr(ValueRangeAnalysis, fn)(a)
                for a0 in CONSTANTS:
                    if a0 not in a:
                        continue
                    if not valid_unary(fn, a0):
                        continue
                    with self.subTest(a0=a0):
                        ref_r = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                        self.assertIn(ref_r, r)

    # This takes about 4s for all the variants
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            # don't attempt pow on exponents that are too large (but oo is OK)
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                r = getattr(ValueRangeAnalysis, fn)(a, b)
                for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    if not valid_binary(fn, a0, b0):
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        ref_r = getattr(ReferenceAnalysis, fn)(
                            sympy.Integer(a0), sympy.Integer(b0)
                        )
                        # Only finite reference results are checked.
                        if ref_r.is_finite:
                            self.assertIn(ref_r, r)

    def test_rational_bounds(self):
        # Repro from https://github.com/pytorch/pytorch/issues/105097
        from sympy import floor, Eq

        shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
        # Shared sub-expressions of the original repro, named for readability;
        # the resulting sympy expression is the same as the inlined version.
        q = (shape_0 + 1) // 96
        d = (shape_0 + 62017) // (q + 646)
        new_expr = Eq(
            30 * floor(4 * q * d / 647 + 2584 * d / 647),
            2880 * floor(q * d / 15528 + 323 * d / 7764),
        )
        new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
        # The equality holds at shape_0 = 95, which lies inside [1, 190], so
        # the boolean range computed for the whole expression must admit True.
        self.assertTrue(new_expr.subs({shape_0: 95}))
        self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
| |
| |
class TestSympyInterp(TestCase):
    """Check that ``sympy_interp`` over a symbolic expression agrees with
    applying ``ReferenceAnalysis`` directly to concrete arguments."""

    @parametrize(
        "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS
    )
    def test_interp(self, fn):
        # SymPy does not implement truncation for Expressions (div, truncdiv).
        # NOTE(review): minimum/maximum are skipped as well; the reason is not
        # recorded here.
        if fn in ("div", "truncdiv", "minimum", "maximum"):
            return

        from sympy.abc import x, y

        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            vals = [True, False]
        else:
            vals = CONSTANTS
        arity = 2 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS} else 1
        symbols = [x, y][:arity]
        # The symbolic expression does not depend on the concrete arguments,
        # so build it once instead of once per argument tuple.
        sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
        for args in itertools.product(vals, repeat=arity):
            if arity == 1 and not valid_unary(fn, *args):
                continue
            if arity == 2 and not valid_binary(fn, *args):
                continue
            with self.subTest(args=args):
                sargs = [sympy.sympify(a) for a in args]
                ref_r = getattr(ReferenceAnalysis, fn)(*sargs)
                # Yes, I know this is a longwinded way of saying xreplace; the
                # point is to test sympy_interp
                r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr)
                self.assertEqual(ref_r, r)
| |
| |
# Expand the @parametrize decorators into concrete test methods.
for _test_cls in (TestValueRanges, TestSympyInterp):
    instantiate_parametrized_tests(_test_cls)
# Do not leave a second module-level reference to a TestCase subclass behind;
# a loader that scans module globals for TestCase classes would run it twice.
del _test_cls


if __name__ == "__main__":
    run_tests()