|  | # Owner(s): ["module: inductor"] | 
|  |  | 
|  | from sympy import Symbol | 
|  | from torch._inductor.test_case import run_tests, TestCase | 
|  | from torch._inductor.utils import sympy_subs | 
|  |  | 
|  |  | 
class TestUtils(TestCase):
    def testSympySubs(self):
        """Check how ``sympy_subs`` matches keys and builds replacements.

        Covered behaviors:
        * a string replacement becomes a Symbol that inherits the replaced
          expression's ``is_integer`` / ``is_nonnegative`` assumptions;
        * a key only matches when its assumptions match the expression's;
        * string keys are rejected with an ``AssertionError``;
        * a non-Symbol expression (``abs(x)``) can be used as a key.
        """
        # String replacement: the new Symbol inherits integer/nonnegative
        # assumptions from the replaced symbol (here both are unknown).
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertIsNone(result.is_integer)
        self.assertIsNone(result.is_nonnegative)

        # Explicit assumptions on the replaced symbol are carried over.
        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertIs(result.is_integer, True)
        self.assertIs(result.is_nonnegative, False)

        # Invalid replacement: the key lacks integer=True, so it does not
        # match expr and nothing is replaced.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # Valid replacement since the assumptions match.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # Invalid replacement: integer=None and integer=False do not match.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # The replaced key can't be a string.
        with self.assertRaises(AssertionError):
            sympy_subs(expr, {"x": "y"})

        # The replaced key can be an expression, e.g. abs(x).
        expr = abs(Symbol("x"))
        self.assertIsNone(expr.is_integer)
        self.assertIsNone(expr.is_nonnegative)
        # Replace abs(x) with an explicit Symbol y.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertIsNone(result.is_integer)
        self.assertIsNone(result.is_nonnegative)

        # A string replacement should propagate abs(x)'s assumptions; with
        # an integer x, abs(x) is an integer and nonnegative, so a plain
        # default Symbol would not satisfy these checks.
        expr = abs(Symbol("x", integer=True))
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertIs(result.is_integer, True)
        self.assertIs(result.is_nonnegative, True)
|  |  | 
|  |  | 
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the inductor test runner.
    run_tests()