| # Owner(s): ["module: inductor"] |
| |
| from sympy import Symbol |
| |
| from torch._inductor.test_case import run_tests, TestCase |
| from torch._inductor.utils import sympy_subs |
| |
| |
class TestUtils(TestCase):
    """Tests for helpers in ``torch._inductor.utils``."""

    def testSympySubs(self):
        """Check the replacement and assumption-propagation rules of ``sympy_subs``.

        Behavior pinned by the assertions below:

        * A string replacement becomes a new ``Symbol`` whose ``is_integer`` /
          ``is_nonnegative`` assumptions match those of the replaced expression.
        * A ``Symbol`` key only matches when its assumptions are identical to
          those of the symbol in ``expr``; otherwise ``expr`` comes back unchanged.
        * String keys are rejected with ``AssertionError``.
        * Keys may be arbitrary expressions (e.g. ``abs(x)``), not only symbols.
        """
        # String replacement: the new symbol inherits the integer and
        # nonnegative assumptions of the replaced symbol (here: unknown).
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        # String replacement: explicitly set assumptions are carried over too.
        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # No replacement: the key Symbol("x") has different assumptions from
        # Symbol("x", integer=True), so it does not match expr.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # Replacement happens: the key's assumptions match expr exactly.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # No replacement: integer=None (unknown) does not match integer=False.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # The replaced key can't be a string.
        self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"})

        # The replaced key can be an arbitrary expression, not just a symbol.
        expr = Symbol("x")
        expr = abs(expr)
        self.assertEqual(expr.is_integer, None)
        self.assertEqual(expr.is_nonnegative, None)

        # Replace abs(x) with an explicit Symbol: the result is that symbol
        # as given, so its assumptions are its own (all unknown).
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        # Replace abs(x) with a string: this is the path that propagates the
        # replaced expression's assumptions onto the newly created symbol.
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, expr.is_integer)
        self.assertEqual(result.is_nonnegative, expr.is_nonnegative)
| |
| |
if __name__ == "__main__":
    # Executed directly (not imported): hand off to the inductor test runner.
    run_tests()