| # Owner(s): ["module: dynamo"] |
| import contextlib |
| import functools |
| |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| import torch._functorch.config |
| import torch.utils._pytree as pytree |
| import torch.utils.checkpoint |
| from torch._dynamo.testing import normalize_gm |
| from torch._functorch.aot_autograd import to_fun |
| from torch._higher_order_ops.wrap import wrap |
| |
| from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv |
| |
| |
| class MockSubclass(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return func(*args, **kwargs) |
| |
| |
| class EagerRecordGraphAndInputs: |
| def __init__(self): |
| self.graphs = [] |
| self.example_inputs = [] |
| |
| def __call__(self, gm: torch.fx.GraphModule, example_inputs): |
| self.graphs.append(gm) |
| self.example_inputs.append(example_inputs) |
| return gm |
| |
| |
| @contextlib.contextmanager |
| def preserve_subclass_config(): |
| old_subclass_set = set(torch._dynamo.config.traceable_tensor_subclasses) |
| try: |
| torch._dynamo.config.traceable_tensor_subclasses.add(MockSubclass) |
| yield |
| finally: |
| torch._dynamo.config.traceable_tensor_subclasses.clear() |
| torch._dynamo.config.traceable_tensor_subclasses.update(old_subclass_set) |
| |
| |
| class SubclassTests(torch._dynamo.test_case.TestCase): |
| @classmethod |
| def setUpClass(cls): |
| super().setUpClass() |
| cls._exit_stack.enter_context(preserve_subclass_config()) |
| |
| @classmethod |
| def tearDownClass(cls): |
| cls._exit_stack.close() |
| |
| def test_torch_function_state_graph_break(self): |
| @torch.compile(backend="eager") |
| def fn(x): |
| with torch._C.DisableTorchFunctionSubclass(): |
| torch._dynamo.graph_break() |
| return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) |
| |
| input = torch.ones(2, 2) |
| res, _ = fn(input) |
| self.assertFalse(res) |
| |
| def test_torch_function_state_tracing(self): |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| with torch._C.DisableTorchFunctionSubclass(): |
| torch.add(x, 1.0) |
| |
| input = torch.ones(2, 2) |
| |
| res = fn(input) |
| |
| def test_torch_function_state_guards(self): |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch.compile(backend=cnt, fullgraph=True) |
| def fn(x): |
| torch.add(x, 1.0) |
| |
| input = torch.ones(2, 2) |
| |
| with torch._C.DisableTorchFunctionSubclass(): |
| res = fn(input) |
| |
| res = fn(input) |
| |
| self.assertEqual(cnt.frame_count, 2) |
| |
| def test_return_subclass(self): |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return MockSubclass(torch.add(x, 1.0)) |
| |
| input = torch.ones(2, 2) |
| |
| res = fn(input) |
| self.assertIsInstance(res, MockSubclass) |
| |
| def test_return_local_subclass(self): |
| class LocalSubclass(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return func(*args, **kwargs) |
| |
| torch._dynamo.config.traceable_tensor_subclasses.add(LocalSubclass) |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return LocalSubclass(torch.add(x, 1.0)) |
| |
| input = torch.ones(2, 2) |
| |
| res = fn(input) |
| self.assertIsInstance(res, LocalSubclass) |
| |
| def test_compile_with_fake_tensor_dynamic_dim(self): |
| x = torch.randn([3, 4]) |
| |
| def f(x): |
| return torch.sin(x) |
| |
| def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): |
| torch._dynamo.reset() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| opt_f = torch.compile(f, backend=cnt, fullgraph=True) |
| |
| x1 = torch.rand_like(x) |
| f(x) |
| f(torch.randn([4, 3])) |
| shape_env = ShapeEnv() |
| with torch._subclasses.fake_tensor.FakeTensorMode( |
| shape_env=shape_env |
| ) as fake_mode: |
| x_fake = fake_mode.from_tensor( |
| x, dynamic_dims=[dim_dynamic for i in range(x.dim())] |
| ) |
| x1_fake = fake_mode.from_tensor( |
| x1, dynamic_dims=[dim_dynamic for i in range(x.dim())] |
| ) |
| opt_f(x_fake) |
| opt_f(x1_fake) |
| |
| self.assertEqual(cnt.frame_count, exp_frame_count) |
| self.assertEqual(cnt.op_count, exp_op_count) |
| |
| test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1) |
| test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1) |
| test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1) |
| |
| def test_compile_with_fake_tensor_automatic_dynamic(self): |
| def f(x): |
| return torch.sin(x) |
| |
| def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): |
| torch._dynamo.reset() |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_f = torch.compile(f, backend=cnt, fullgraph=True) |
| |
| shape_env = ShapeEnv() |
| with torch._subclasses.fake_tensor.FakeTensorMode( |
| shape_env=shape_env |
| ) as fake_mode: |
| for inp in inps: |
| fake_inp = fake_mode.from_tensor( |
| inp, dynamic_dims=[dim_dynamic for i in range(x.dim())] |
| ) |
| opt_f(fake_inp) |
| self.assertEqual(cnt.frame_count, exp_frame_count) |
| self.assertEqual(cnt.op_count, exp_op_count) |
| |
| x = torch.randn([3, 4]) |
| y = torch.randn([4, 5]) |
| z = torch.randn([5, 6]) |
| a = torch.randn([3, 5]) |
| b = torch.randn([4, 4]) |
| for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK, DimDynamic.STATIC]: |
| # Recompile once, first with dim 0 and 1 become Dynamic |
| test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2) |
| # Recompile 2 times, first with dim 1 become Dynamic, second with dim 0 becomes Dynamic. |
| test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3) |
| # Recompile 2 times, first with dim 0 become Dynamic, second with dim 1 becomes Dynamic. |
| test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3) |
| |
| def test_compile_with_functionalization(self): |
| x = torch.randn([3, 4]) |
| x_clone = x.clone() |
| x_clone2 = x.clone() |
| backend = EagerRecordGraphAndInputs() |
| cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) |
| |
| @torch.compile(backend=cnt, fullgraph=True) |
| def f(x): |
| return x.add_(1.0) + torch.nn.functional.relu_(x) |
| |
| f_out = f(x) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertEqual(cnt.op_count, 3) |
| self.assertEqual(len(backend.graphs), 1) |
| self.assertEqual(len(backend.example_inputs), 1) |
| |
| expected = """\ |
| class GraphModule(torch.nn.Module): |
| def forward(self, L_x_ : torch.Tensor): |
| l_x_ = L_x_ |
| |
| add_ = l_x_.add_(1.0) |
| relu_ = torch.relu_(l_x_); l_x_ = None |
| add = add_ + relu_; add_ = relu_ = None |
| return (add,) |
| """ |
| actual = normalize_gm(backend.graphs[0].print_readable(print_output=False)) |
| self.assertEqual(actual, expected) |
| |
| ff = torch.func.functionalize(f) |
| ff_out = ff(x_clone) |
| |
| self.assertEqual(cnt.frame_count, 2) |
| self.assertEqual(cnt.op_count, 6) |
| self.assertEqual(len(backend.graphs), 2) |
| self.assertEqual(len(backend.example_inputs), 2) |
| actual = normalize_gm(backend.graphs[1].print_readable(print_output=False)) |
| self.assertEqual(actual, expected) |
| self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) |
| |
| def aot_f_wrapper(func): |
| @functools.wraps(func) |
| def wrapper(*args, **kwargs): |
| torch._enable_functionalization(reapply_views=False) |
| try: |
| func_args = pytree.tree_map(to_fun, args) |
| func_kwargs = pytree.tree_map(to_fun, kwargs) |
| return func(*func_args, **func_kwargs) |
| finally: |
| torch._disable_functionalization() |
| |
| return wrapper |
| |
| aot_ff = aot_f_wrapper(f) |
| aot_ff_out = aot_ff(x_clone2) |
| |
| self.assertEqual(cnt.frame_count, 3) |
| self.assertEqual(cnt.op_count, 9) |
| self.assertEqual(len(backend.graphs), 3) |
| self.assertEqual(len(backend.example_inputs), 3) |
| actual = normalize_gm(backend.graphs[2].print_readable(print_output=False)) |
| self.assertEqual(actual, expected) |
| self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) |
| |
| self.assertEqual(f_out, ff_out) |
| self.assertEqual(f_out, aot_ff_out) |
| |
| try: |
| torch._enable_functionalization(reapply_views=False) |
| xf = pytree.tree_map(to_fun, x) |
| x_view = xf.t() |
| with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"): |
| f(x_view) |
| finally: |
| torch._disable_functionalization() |
| |
| def test_compile_higher_order_with_functionalization(self): |
| backend = EagerRecordGraphAndInputs() |
| cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) |
| |
| @torch.compile(backend=cnt, fullgraph=True) |
| def f(x): |
| return wrap(lambda x: x.add_(1.0), x) |
| |
| def check_count_and_graph( |
| exp_frame_count, exp_op_count, exp_n_graph, exp_graph |
| ): |
| self.assertEqual(cnt.frame_count, exp_frame_count) |
| self.assertEqual(cnt.op_count, exp_op_count) |
| self.assertEqual(len(backend.graphs), exp_n_graph) |
| actual = normalize_gm( |
| backend.graphs[exp_n_graph - 1].print_readable(print_output=False) |
| ) |
| self.assertExpectedInline(actual, exp_graph) |
| |
| t = torch.randn([3, 4]) |
| t_clone = t.clone() |
| t_clone2 = t.clone() |
| f(t) |
| |
| expected_graph = """\ |
| class GraphModule(torch.nn.Module): |
| def forward(self, L_x_ : torch.Tensor): |
| l_x_ = L_x_ |
| |
| wrap_body_0 = self.wrap_body_0 |
| wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None |
| return (wrap,) |
| |
| class GraphModule(torch.nn.Module): |
| def forward(self, l_x_): |
| add_ = l_x_.add_(1.0); l_x_ = None |
| return add_ |
| """ |
| check_count_and_graph(1, 1, 1, expected_graph) |
| |
| ff = torch.func.functionalize(f) |
| ff_out = ff(t_clone) |
| # frame count and op count are incremented due to re-compilation |
| check_count_and_graph(2, 2, 2, expected_graph) |
| |
| try: |
| x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True) |
| torch._enable_functionalization(reapply_views=False) |
| aot_f_out = f(x) |
| finally: |
| torch._disable_functionalization() |
| |
| # frame count and op count are incremented due to re-compilation |
| check_count_and_graph(3, 3, 3, expected_graph) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |