| # Owner(s): ["module: dynamo"] |
| import functools |
| |
| import torch |
| import torch._dynamo |
| import torch._dynamo.test_case |
| from torch._C._dynamo import guards |
| from torch.testing._internal.common_utils import set_default_dtype |
| |
| RootGuardManager = guards.RootGuardManager |
| GetAttrGuardAccessor = guards.GetAttrGuardAccessor |
| |
| |
def id_type(x):
    """Return ``id()`` of the type object of ``x``."""
    value_type = type(x)
    return id(value_type)
| |
| |
def equals_match(x, expected):
    """Return the result of comparing ``x == expected``."""
    outcome = x == expected
    return outcome
| |
| |
def equals_match_verbose_code_parts(expected):
    """Return a one-element list with the ``x == <expected>`` description."""
    description = f"x == {expected}"
    return [description]
| |
| |
def ge_match(x, expected):
    """Return the result of comparing ``x >= expected``."""
    outcome = x >= expected
    return outcome
| |
| |
def ge_match_verbose_code_parts(expected):
    """Return the verbose code parts describing a ``ge_match`` lambda guard.

    The description is returned as a one-element list of strings, matching
    the shape produced by the other ``*_verbose_code_parts`` helpers in this
    file, so the guard receives a single code part instead of a bare string.
    """
    return [f"expected >= {expected}"]
| |
| |
def less_match(x, expected):
    """Return the result of comparing ``x < expected``."""
    outcome = x < expected
    return outcome
| |
| |
def less_match_verbose_code_parts(expected):
    """Return a one-element list with the ``expected < <expected>`` description."""
    description = f"expected < {expected}"
    return [description]
| |
| |
class GuardManagerTests(torch._dynamo.test_case.TestCase):
    """Tests for the leaf guards and guard managers exposed by ``guards``."""

    def test_global_state_guard(self):
        """GLOBAL_STATE should fail only while the global state is modified."""
        state_guard = guards.GLOBAL_STATE(["global_state_check"])
        self.assertTrue(state_guard(None))

        # A temporary default-dtype change is expected to trip the guard.
        with set_default_dtype(torch.double):
            self.assertFalse(state_guard(None))
        self.assertTrue(state_guard(None))

        # Flip the deterministic-algorithms flag, then always restore it.
        was_deterministic = torch.are_deterministic_algorithms_enabled()
        try:
            torch.use_deterministic_algorithms(not was_deterministic)
            self.assertFalse(state_guard(None))
        finally:
            torch.use_deterministic_algorithms(was_deterministic)
        self.assertTrue(state_guard(None))

    def test_python_lambda_leaf_guard(self):
        """A LAMBDA_GUARD wrapping ``equals_match`` passes for 5 only."""
        is_five = functools.partial(equals_match, expected=5)
        const_guard = guards.LAMBDA_GUARD(
            is_five,
            equals_match_verbose_code_parts(5),
        )
        self.assertTrue(const_guard(5))
        for rejected in (4, "foo"):
            self.assertFalse(const_guard(rejected))

    def test_type_guard(self):
        """TYPE_MATCH should pass for values of the guarded type only."""
        # int
        int_guard = guards.TYPE_MATCH(id_type(4), ["type(x) == int"])
        for accepted in (5, 4):
            self.assertTrue(int_guard(accepted))
        self.assertFalse(int_guard("foo"))

        # dict
        sample_dict = {"a": 1}
        dict_guard = guards.TYPE_MATCH(id_type(sample_dict), ["type(x) == dict"])
        for accepted in (sample_dict, {}):
            self.assertTrue(dict_guard(accepted))
        for rejected in (5, "foo"):
            self.assertFalse(dict_guard(rejected))

        # user-defined class
        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        instance = Foo(1, 2)
        foo_guard = guards.TYPE_MATCH(id_type(instance), ["type(x) == Foo"])
        self.assertTrue(foo_guard(instance))
        for rejected in ({}, 5, "foo"):
            self.assertFalse(foo_guard(rejected))

    def test_id_guard(self):
        """ID_MATCH should pass only for the very object it was built from."""
        number = 4
        number_guard = guards.ID_MATCH(id(number), ["id(x) == id(foo)"])
        self.assertTrue(number_guard(number))
        for rejected in (5, "foo"):
            self.assertFalse(number_guard(rejected))

        mapping = {"a": 1}
        mapping_guard = guards.ID_MATCH(id(mapping), ["id(x) == id(foo)"])
        self.assertTrue(mapping_guard(mapping))
        # An equal-but-distinct dict is a different object and must be rejected.
        for rejected in ({"a": 1}, {}, 5):
            self.assertFalse(mapping_guard(rejected))

    def test_equals_guard(self):
        """EQUALS_MATCH should pass for values equal to the guarded value."""
        int_guard = guards.EQUALS_MATCH(4, ["x == 4"])
        self.assertTrue(int_guard(4))
        for rejected in (5, "foo"):
            self.assertFalse(int_guard(rejected))

        # tuple
        expected_tuple = (1, 2, 3)
        tuple_guard = guards.EQUALS_MATCH(expected_tuple, ["x == foo"])
        for accepted in (expected_tuple, (1, 2, 3)):
            self.assertTrue(tuple_guard(accepted))
        for rejected in ((1, 2, 3, 4), {}):
            self.assertFalse(tuple_guard(rejected))

        # list
        expected_list = [1, 2, 3]
        list_guard = guards.EQUALS_MATCH(expected_list, ["x == foo"])
        for accepted in (expected_list, [1, 2, 3]):
            self.assertTrue(list_guard(accepted))
        self.assertFalse(list_guard([1, 2, 3, 4]))

        # type
        expected_type = int
        type_guard = guards.EQUALS_MATCH(expected_type, ["x == foo"])
        for accepted in (expected_type, int):
            self.assertTrue(type_guard(accepted))
        self.assertFalse(type_guard(float))

    def test_default_device_guard(self):
        """DEFAULT_DEVICE should fail once the default device is changed."""
        placeholder = 1
        device_guard = guards.DEFAULT_DEVICE(["cpu device"])
        self.assertTrue(device_guard(placeholder))

        try:
            torch.set_default_device("cuda")
            self.assertFalse(device_guard(placeholder))
        finally:
            # Reset the default device so later tests are unaffected.
            torch.set_default_device(None)

    def test_data_ptr_match_guard(self):
        """DATA_PTR_MATCH should reject a different tensor with equal values."""
        original = torch.tensor([1, 2, 3])
        ptr_guard = guards.DATA_PTR_MATCH(
            original, ["x.data_ptr() == foo.data_ptr()"]
        )
        self.assertTrue(ptr_guard(original))
        self.assertFalse(ptr_guard(torch.tensor([1, 2, 3])))

    def test_guard_manager_leaf_guard(self):
        """A root manager with three leaf guards accepts ints in [5, 10)."""
        root = RootGuardManager()
        root.add_type_match_guard(id_type(5), ["type(x) == int"])

        at_least_five = functools.partial(ge_match, expected=5)
        root.add_lambda_guard(
            at_least_five,
            ge_match_verbose_code_parts(expected=5),
        )
        below_ten = functools.partial(less_match, expected=10)
        root.add_lambda_guard(
            below_ten,
            less_match_verbose_code_parts(expected=10),
        )

        self.assertEqual(len(root.get_leaf_guards()), 3)
        self.assertEqual(len(root.get_accessors()), 0)
        self.assertTrue(root.check(6))
        for rejected in (4, "foo"):
            self.assertFalse(root.check(rejected))

    def test_attr_guard_manager(self):
        """Attribute child managers guard ``x`` and ``y`` of a ``Foo`` instance."""

        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        instance = Foo(1, 2)
        root = RootGuardManager()
        root.add_type_match_guard(id_type(instance), ["type(x) == Foo"])

        # One child manager per attribute, each with an equality lambda guard.
        for attr_name, attr_value in (("x", instance.x), ("y", instance.y)):
            root.getattr_manager(attr_name, attr_value).add_lambda_guard(
                functools.partial(equals_match, expected=attr_value),
                equals_match_verbose_code_parts(attr_value),
            )

        self.assertEqual(len(root.get_leaf_guards()), 1)
        # 2 child managers, one for x and one for y
        self.assertEqual(len(root.get_accessors()), 2)
        for position in (0, 1):
            self.assertTrue(
                isinstance(root.get_accessors()[position], GetAttrGuardAccessor)
            )

        # Check leaf guards on child managers
        for attr_name in ("x", "y"):
            self.assertEqual(
                len(root.getattr_manager(attr_name, None).get_leaf_guards()), 1
            )

        self.assertTrue(root.check(instance))
        for rejected in (Foo(3, 4), "foo"):
            self.assertFalse(root.check(rejected))
| |
| |
if __name__ == "__main__":
    # Run this file's tests through the dynamo test-case runner.
    torch._dynamo.test_case.run_tests()