blob: 1997d7c4749e21cb45b52a5bb6c0327633ffeeab [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import functools
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._C._dynamo import guards
from torch.testing._internal.common_utils import set_default_dtype
RootGuardManager = guards.RootGuardManager
GetAttrGuardAccessor = guards.GetAttrGuardAccessor
def id_type(x):
    """Return the ``id`` of the exact type of ``x``.

    Used as the expected value when constructing TYPE_MATCH guards.
    """
    exact_type = type(x)
    return id(exact_type)
def equals_match(x, expected):
    """Lambda-guard predicate: return the result of ``x == expected``."""
    is_equal = x == expected
    return is_equal
def equals_match_verbose_code_parts(expected):
    """Build the verbose code parts (a one-element list) for an ``x == expected`` guard."""
    code_part = f"x == {expected}"
    return [code_part]
def ge_match(x, expected):
    """Lambda-guard predicate: return the result of ``x >= expected``."""
    at_least = x >= expected
    return at_least
def ge_match_verbose_code_parts(expected):
    """Build the verbose code parts for an ``x >= expected`` lambda guard.

    Returns a one-element list of strings, the same shape the sibling
    ``*_verbose_code_parts`` helpers return. The original returned a bare
    string, and its text named the wrong operand (``expected >= N`` instead of
    the guarded value ``x``).
    """
    return [f"x >= {expected}"]
def less_match(x, expected):
    """Lambda-guard predicate: return the result of ``x < expected``."""
    below = x < expected
    return below
def less_match_verbose_code_parts(expected):
    """Build the verbose code parts for an ``x < expected`` lambda guard.

    The text names the guarded value ``x``, consistent with
    ``equals_match_verbose_code_parts``. The original said ``expected < N``,
    which misdescribes the check that ``less_match`` performs.
    """
    return [f"x < {expected}"]
class GuardManagerTests(torch._dynamo.test_case.TestCase):
    """Tests for the leaf guards and guard managers in ``torch._C._dynamo.guards``."""

    def test_global_state_guard(self):
        # The guard passes in the starting state, fails while the default dtype
        # or the deterministic-algorithms flag is changed, and passes again once
        # the setting is restored.
        guard = guards.GLOBAL_STATE(["global_state_check"])
        self.assertTrue(guard(None))
        with set_default_dtype(torch.double):
            self.assertFalse(guard(None))
        self.assertTrue(guard(None))
        _orig = torch.are_deterministic_algorithms_enabled()
        try:
            torch.use_deterministic_algorithms(not _orig)
            self.assertFalse(guard(None))
        finally:
            # Restore the process-wide flag even if the assertion fails, so
            # later tests are unaffected.
            torch.use_deterministic_algorithms(_orig)
        self.assertTrue(guard(None))

    def test_python_lambda_leaf_guard(self):
        # LAMBDA_GUARD wraps an arbitrary Python predicate.
        const_guard = guards.LAMBDA_GUARD(
            functools.partial(equals_match, expected=5),
            equals_match_verbose_code_parts(5),
        )
        self.assertTrue(const_guard(5))
        self.assertFalse(const_guard(4))
        self.assertFalse(const_guard("foo"))

    def test_type_guard(self):
        # TYPE_MATCH compares against the id of the exact type of the example.
        foo = 4
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"])
        self.assertTrue(guard(5))
        self.assertTrue(guard(4))
        self.assertFalse(guard("foo"))

        foo = {"a": 1}
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard({}))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        foo = Foo(1, 2)
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard({}))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

    def test_id_guard(self):
        # ID_MATCH is identity-based: an equal but distinct object must fail.
        foo = 4
        guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        foo = {"a": 1}
        guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard({"a": 1}))
        self.assertFalse(guard({}))
        self.assertFalse(guard(5))

    def test_equals_guard(self):
        # EQUALS_MATCH is value-based: equal but distinct objects pass.
        foo = 4
        guard = guards.EQUALS_MATCH(foo, ["x == 4"])
        self.assertTrue(guard(4))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        # tuple
        foo = (1, 2, 3)
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard((1, 2, 3)))
        self.assertFalse(guard((1, 2, 3, 4)))
        self.assertFalse(guard({}))

        # list
        foo = [1, 2, 3]
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard([1, 2, 3]))
        self.assertFalse(guard([1, 2, 3, 4]))

        # type
        foo = int
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard(int))
        self.assertFalse(guard(float))

    def test_default_device_guard(self):
        # DEFAULT_DEVICE fails once the default device is changed after the
        # guard is created.
        foo = 1
        guard = guards.DEFAULT_DEVICE(["cpu device"])
        self.assertTrue(guard(foo))
        try:
            torch.set_default_device("cuda")
            self.assertFalse(guard(foo))
        finally:
            # Reset the process-wide default device for subsequent tests.
            torch.set_default_device(None)

    def test_data_ptr_match_guard(self):
        # DATA_PTR_MATCH passes for the same tensor and fails for an
        # equal-valued tensor with different storage.
        foo = torch.tensor([1, 2, 3])
        guard = guards.DATA_PTR_MATCH(foo, ["x.data_ptr() == foo.data_ptr()"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard(torch.tensor([1, 2, 3])))

    def test_guard_manager_leaf_guard(self):
        # A root manager with three leaf guards (type, >= 5, < 10) and no
        # accessors passes only when every leaf guard passes.
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"])
        guard_manager.add_lambda_guard(
            functools.partial(ge_match, expected=5),
            ge_match_verbose_code_parts(expected=5),
        )
        guard_manager.add_lambda_guard(
            functools.partial(less_match, expected=10),
            less_match_verbose_code_parts(expected=10),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 3)
        self.assertEqual(len(guard_manager.get_accessors()), 0)
        self.assertTrue(guard_manager.check(6))
        self.assertFalse(guard_manager.check(4))
        self.assertFalse(guard_manager.check("foo"))

    def test_attr_guard_manager(self):
        # getattr_manager creates one child manager per attribute name. The
        # root's own leaf guard and the children's guards must all pass.
        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        foo = Foo(1, 2)
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
        guard_manager.getattr_manager("x", 1).add_lambda_guard(
            functools.partial(equals_match, expected=foo.x),
            equals_match_verbose_code_parts(foo.x),
        )
        guard_manager.getattr_manager("y", 2).add_lambda_guard(
            functools.partial(equals_match, expected=foo.y),
            equals_match_verbose_code_parts(foo.y),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
        # 2 child managers, one for x and one for y
        self.assertEqual(len(guard_manager.get_accessors()), 2)
        self.assertIsInstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor)
        self.assertIsInstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor)
        # Check leaf guards on child managers. Asking for an existing attribute
        # name returns the same child (the counts stay at 1, not 2).
        self.assertEqual(
            len(guard_manager.getattr_manager("x", None).get_leaf_guards()), 1
        )
        self.assertEqual(
            len(guard_manager.getattr_manager("y", None).get_leaf_guards()), 1
        )
        self.assertTrue(guard_manager.check(foo))
        self.assertFalse(guard_manager.check(Foo(3, 4)))
        self.assertFalse(guard_manager.check("foo"))
if __name__ == "__main__":
    # Run this module's tests through dynamo's test-case runner.
    from torch._dynamo.test_case import run_tests as _dynamo_run_tests

    _dynamo_run_tests()