| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import copy |
| import unittest |
| from typing import Dict, List, Tuple |
| |
| import executorch.exir as exir |
| import executorch.exir.tests.models as models |
| |
| import torch |
| |
| from executorch.exir import CaptureConfig |
| from executorch.exir.dialects._ops import ops as exir_ops |
| from executorch.exir.tests.common import register_additional_test_aten_ops |
| from executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo |
| from functorch.experimental.control_flow import cond, map |
| |
| from parameterized import parameterized |
| from torch._export.verifier import SpecViolationError |
| from torch.fx.experimental.symbolic_shapes import is_concrete_int |
| from torch.testing import FileCheck |
| |
| |
| class TestTorchDispatchFXTracer(unittest.TestCase): |
| @classmethod |
| def setUpClass(cls) -> None: |
| register_additional_test_aten_ops() |
| |
| def test_simple(self) -> None: |
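        # Capture a simple model, lower it to the edge dialect, and verify
        # that aten.sin appears under its edge-dialect name.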
| f = models.BasicSinMax() |
| f = ( |
| exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| |
| FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").run(f.code) |
| |
| def test_static_control_flow(self) -> None: |
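        # Python-level control flow is specialized at trace time, so each
        # value of pred yields a different graph.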
| def f(pred: bool, x: torch.Tensor) -> torch.Tensor: |
| if pred: |
| return torch.sin(x).max() |
| else: |
| return torch.sin(x) |
| |
| pred = True |
| x = torch.randn(100) |
| f_true = ( |
| exir.capture(f, (pred, x), exir.CaptureConfig()) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| |
| FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check( |
| "executorch_exir_dialects_edge__ops_aten_max" |
| ).run(f_true.code) |
| |
| pred = False |
| f_false = ( |
| exir.capture(f, (pred, x), exir.CaptureConfig()) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check_not( |
| "executorch_exir_dialects_edge__ops_aten_max" |
| ).run(f_false.code) |
| |
| def test_copy(self) -> None: |
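        # A captured GraphModule should survive copy.deepcopy as a GraphModule.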
| f = models.BasicSinMax() |
| f = ( |
| exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| |
| self.assertTrue(isinstance(f, torch.fx.GraphModule)) |
| g = copy.deepcopy(f) |
| self.assertTrue(isinstance(g, torch.fx.GraphModule)) |
| |
| def test_stacktrace(self) -> None: |
| def f(x: torch.Tensor) -> torch.Tensor: |
| return x + x |
| |
| traced_f = ( |
| exir.capture(f, (torch.rand(2, 2),), exir.CaptureConfig()) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| # Check that stacktrace is populated and retained (by checking twice) |
| self.assertTrue( |
| any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes) |
| ) |
| self.assertTrue( |
| any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes) |
| ) |
| |
| def test_possible_input_mutation(self) -> None: |
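        # The out= variant mutates its input, so the verifier should reject
        # the non-functional op during to_edge().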
| def f(x: torch.Tensor) -> torch.Tensor: |
| return torch.add(torch.ones(5), torch.ones(5), out=x) |
| |
| with self.assertRaisesRegex( |
| SpecViolationError, |
| r"operator .* is not functional", |
| ): |
| exir.capture(f, (torch.zeros(5),), exir.CaptureConfig()).to_edge() |
| |
| def test_tensor_spec_for_const_tensors(self) -> None: |
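        # get_attr nodes for the linear weight and bias should each carry
        # fake-tensor "val" metadata.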
| class Module(torch.nn.Module): |
| def __init__(self) -> None: |
                super().__init__()
| self.linear = torch.nn.Linear(2, 3) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.linear(x) |
| |
| def get_random_inputs(self) -> Tuple[torch.Tensor, ...]: |
| return (torch.randn(2),) |
| |
| model = Module() |
| graph_module = ( |
| exir.capture(model, model.get_random_inputs(), exir.CaptureConfig()) |
            # aten.t.default is not a registered edge op, so IR validity
            # checking is disabled
| .to_edge( |
| exir.EdgeCompileConfig(_check_ir_validity=False) |
| ).exported_program.graph_module |
| ) |
| num_get_attr_node = 0 |
| num_get_attr_node_with_tensorspec = 0 |
| for nd in graph_module.graph.nodes: |
| if nd.op == "get_attr": |
| num_get_attr_node += 1 |
| if nd.meta.get("val") is not None: |
| num_get_attr_node_with_tensorspec += 1 |
| |
| self.assertEqual(2, num_get_attr_node) |
| self.assertEqual(2, num_get_attr_node_with_tensorspec) |
| |
| def test_multiple_returns_spec(self) -> None: |
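        # aten.max.dim returns (values, indices), so the node's "val"
        # metadata should be a tuple.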
| def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| return torch.ops.aten.max.dim(x, 0, False) |
| |
| cnt = 0 |
| module = ( |
| exir.capture(f, (torch.zeros(1, 2, 3),), exir.CaptureConfig()) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| for node in module.graph.nodes: |
| if node.target == exir_ops.edge.aten.max.dim: |
| cnt += 1 |
| self.assertIsInstance(node.meta["val"], tuple) |
| self.assertEqual(cnt, 1) |
| |
| def test_multiple_returns_pt2_mode(self) -> None: |
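        # The output node's "val" metadata should be a list with one entry
        # per returned tensor.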
| def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| a = x * x |
| b = x + a |
| return a, b |
| |
| inputs = (torch.ones(1, 2, 3),) |
| orig_res = f(*inputs) |
| module = ( |
| exir.capture( |
| f, |
| inputs, |
| exir.CaptureConfig(), |
| ) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| new_res = module(*inputs) |
| for node in module.graph.nodes: |
| if node.op == "output": |
| self.assertIsInstance(node.meta["val"], list) |
| self.assertEqual(len(node.meta["val"]), 2) |
| |
| self.assertTrue(torch.allclose(orig_res[0], new_res[0])) |
| self.assertTrue(torch.allclose(orig_res[1], new_res[1])) |
| |
| def test_dynamo_capture_scalar_outputs(self) -> None: |
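        # Tracing a function that returns a Python scalar via .item() should
        # not raise.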
| def f(x: torch.Tensor) -> float: |
| return x.item() |
| |
| gm, guards = dynamo_trace( |
| f, |
| (torch.ones(1),), |
| False, |
| "real", |
| ExirDynamoConfig(), |
| ) |
| |
| # pyre-ignore |
| @parameterized.expand([("stock_tensor",)]) |
| def test_embedding_dynamic_shape(self, input_type: str) -> None: |
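        # Smoke test: capturing a simple add with dynamic shapes enabled and
        # functionalization disabled should produce a GraphModule.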
| class Module(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, x): |
| return x + x |
| |
| example_input = torch.ones(10, dtype=torch.int64) |
| m = Module() |
| gm = ( |
| exir.capture( |
| m, |
| (example_input,), |
| exir.CaptureConfig( |
| enable_functionalization=False, |
| enable_dynamic_shape=True, |
| ), |
| ) |
| .to_edge() |
| .exported_program.graph_module |
| ) |
| |
        self.assertIsInstance(gm, torch.fx.GraphModule)
| |
| def test_dynamic_shape(self) -> None: |
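        # With dynamic shapes enabled, every placeholder and call_function
        # node should carry "val" metadata.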
| def forward(x: torch.Tensor) -> torch.Tensor: |
| x = x.view(x.shape[0] - 1, -1) |
| return torch.cat([x, x]) |
| |
| gm = ( |
| exir.capture( |
| forward, |
| (torch.ones(3, 2, dtype=torch.int64),), |
| exir.CaptureConfig( |
| enable_functionalization=False, |
| enable_dynamic_shape=True, |
| _dynamo_config=ExirDynamoConfig(assume_static_by_default=True), |
| ), |
            # sym_size is not a registered op, so IR validity checking is
            # disabled below
| ) |
| .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) |
| .exported_program.graph_module |
| ) |
| |
| for node in gm.graph.nodes: |
| if node.op in ("placeholder", "call_function"): |
| self.assertIn("val", node.meta) |
| |
| def test_dynamo_frontend_container_input(self) -> None: |
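        # Nested tuple inputs should flow through the dynamo frontend and
        # produce results matching eager mode.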
| class Module(torch.nn.Module): |
| def __init__(self) -> None: |
                super().__init__()
| |
| def forward( |
| self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] |
| ) -> torch.Tensor: |
| a = x[0] |
| b = x[1] |
| cum = 0 |
| for i in b: |
| cum += i.sum() |
| return a.cos() + cum.sin() |
| |
| with using_dynamo(True): |
| inp = ((torch.ones(6), (torch.ones(6), torch.ones(6))),) |
| gm = exir.capture(Module(), inp, exir.CaptureConfig()) |
| self.assertTrue(torch.allclose(Module()(*inp), gm(*inp))) |
| |
| # TODO (tmanlaibaatar) remove this test |
| def test_pt2_mode_with_dynamo_config(self) -> None: |
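        # Slicing with a shape-dependent bound should capture under the
        # default config and produce the expected output shape.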
| def f(x: torch.Tensor) -> torch.Tensor: |
| return x[: x.shape[0] - 1] |
| |
| inp = (torch.randn(4, 5),) |
| prog = exir.capture( |
| f, |
| inp, |
| # missing dispatch key |
| ).to_edge() |
        self.assertEqual(prog(torch.randn(4, 5)).shape[0], 3)
| |
| def test_input_container_type(self) -> None: |
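        # A list-of-tensors input and a dict output should round-trip through
        # capture unchanged.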
| def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]: |
| # pyre-ignore |
| return {"a": x.sum() + sum(y).sum()} |
| |
| inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)]) |
| |
| # pyre-fixme[23]: Unable to unpack `(...) -> Tuple[GraphModule, |
| # Set[torch._guards.Guard]]` into 2 values. |
| gm, _ = torch._dynamo.export(f, *inp, aten_graph=True, tracing_mode="symbolic") |
| prog = exir.capture(f, inp, config=exir.CaptureConfig()).to_edge() |
| |
| self.assertEqual(prog(*inp), f(*inp)) |
| |
| def test_aot_buffer_mutation(self) -> None: |
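        # With enable_aot, capturing a module that mutates a buffer should
        # fail when an input requires gradients, and succeed otherwise.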
| class Module(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
                self.register_buffer(
                    "_bin_num_examples",
                    torch.empty([42]).fill_(0.0),
                )
| |
| def forward(self, x, y, z): |
| self._bin_num_examples.index_copy_( |
| dim=0, |
| index=y, |
| source=z, |
| ) |
| self._bin_num_examples.index_add_( |
| dim=0, index=torch.arange(4), source=x |
| ) |
| return self._bin_num_examples - 1, x * z |
| |
| model = Module() |
| example_inputs = ( |
| torch.randn(4, requires_grad=True), |
| torch.tensor(0), |
| torch.tensor(3.14), |
| ) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "Found a graph input that requires gradients, and received a mutation.", |
| ): |
| _ = exir.capture( |
| model, |
| example_inputs, |
| exir.CaptureConfig( |
| enable_aot=True, |
| ), |
| ) |
| |
| # Note that model._bin_num_examples is mutated during exir.capture |
| # We need to create a new_model |
| new_model = Module() |
| example_inputs = ( |
| torch.randn(4), |
| torch.tensor(0), |
| torch.tensor(3.14), |
| ) |
| |
| ep = exir.capture( |
| new_model, |
| example_inputs, |
| exir.CaptureConfig( |
| enable_aot=True, |
| ), |
| ) |
| |
| test_inputs = ( |
| torch.randn(4), |
| torch.tensor(0), |
| torch.tensor(2.1), |
| ) |
| graph_outputs = ep(*test_inputs) |
| eager_outputs = Module()(*test_inputs) |
| self.assertEqual(len(graph_outputs), 2) |
| self.assertEqual(len(eager_outputs), 2) |
| self.assertTrue(torch.allclose(graph_outputs[0], eager_outputs[0])) |
| self.assertTrue(torch.allclose(graph_outputs[1], eager_outputs[1])) |
| |
| def test_assume_constant_by_default_prop(self) -> None: |
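        # With assume_static_by_default, at least one fake-tensor dim should
        # be specialized to a concrete integer.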
| def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| if x.shape[0] > 3: |
| return x.cos() |
| return x.sin() |
| |
| dynamo_config = ExirDynamoConfig(assume_static_by_default=True) |
| capture_config = exir.CaptureConfig( |
| enable_dynamic_shape=True, _dynamo_config=dynamo_config |
| ) |
| captured = exir.capture( |
| foo, (torch.ones(6, 2), torch.ones(6, 3)), capture_config |
| ).exported_program.graph_module |
| found = False |
| for node in captured.graph.nodes: |
| # at least one input needs to have concrete dims |
| if "val" in node.meta: |
| fake_val = node.meta["val"] |
| for dim in fake_val.shape: |
| if is_concrete_int(dim): |
| found = True |
| |
| self.assertTrue(found) |
| |
| def test_aot_config(self) -> None: |
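        # enable_aot lifts buffers into placeholders, so the captured graph
        # should contain no get_attr nodes.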
| class FooWithBuffer(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.zeros(42)) |
| |
| def forward(self, x): |
| return x.cos() + self.buffer.sum() |
| |
| capture_config = exir.CaptureConfig(enable_aot=True) |
| captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config) |
| captured_gm = captured_ep.exported_program.graph_module |
| |
| placeholder_nodes = set() |
| for node in captured_gm.graph.nodes: |
            self.assertNotEqual(node.op, "get_attr")
| if node.op == "placeholder": |
| placeholder_nodes.add(node) |
| if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| # make sure the placeholders are used |
| arg_0, arg_1 = node.args |
| self.assertEqual( |
| placeholder_nodes, |
| { |
| list(arg_0._input_nodes.keys())[0], |
| list(arg_1._input_nodes.keys())[0], |
| }, |
| ) |
| |
| self.assertEqual(len(placeholder_nodes), 2) |
| captured_ep.to_edge() |
| |
| def test_export_unlift(self) -> None: |
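        # _unlift restores lifted buffers onto the module, so the captured
        # program can be called like the eager module.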
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.ones(6, 4)) |
| |
| def forward(self, x): |
| return x.cos() + self.buffer.sin() |
| |
| ep = exir.capture( |
| Foo(), |
| (torch.ones(6, 4),), |
| exir.CaptureConfig(enable_aot=True, _unlift=True), |
| ) |
| |
| self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))) |
| |
| def test_export_container_unlift(self) -> None: |
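        # _unlift should also handle a container (tuple-of-tensors) input.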
| class FooContainerInputOutput(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.ones(6, 4)) |
| |
| def forward(self, x): |
| return x[0][0].cos() + x[0][1].sin() + self.buffer.sin() |
| |
| inp = ((torch.ones(6, 4), torch.ones(6, 4)),) |
| ep = exir.capture( |
| FooContainerInputOutput(), |
| (inp,), |
| CaptureConfig(enable_aot=True, _unlift=True), |
| ) |
| self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp))) |
| |
| def test_export_container_input_unlift(self) -> None: |
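        # _unlift should handle multiple container inputs.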
| class FooContainerInputOutputV2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.ones(6, 4)) |
| |
| def forward(self, x, y): |
| return x[0].cos() + y[0].sin() + self.buffer.sin() |
| |
| inp = ((torch.ones(6, 4),), (torch.ones(6, 4),)) |
| ep = exir.capture( |
| FooContainerInputOutputV2(), |
| inp, |
| CaptureConfig(enable_aot=True, _unlift=True), |
| ) |
| self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp))) |
| |
| def test_export_cond(self) -> None: |
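        # cond whose true branch reads a buffer through a submodule should
        # capture (AOT + unlift) and match eager results.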
| class A(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.ones(6, 4)) |
| |
| def forward(self): |
| return self.buffer.cos() |
| |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = A() |
| |
| def forward(self, x): |
| def true_fn(x): |
| return x.cos() + self.a().sum() |
| |
| def false_fn(x): |
| return x.sin() |
| |
| return cond(x.shape[0] > 4, true_fn, false_fn, [x]) |
| |
| inp = torch.ones(6, 4) |
| ep = exir.capture( |
| Foo(), |
| (inp,), |
| CaptureConfig(enable_aot=True, _unlift=True), |
| ) |
| self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))) |
| |
| def test_export_cond_map(self) -> None: |
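        # map over a body that calls cond (both reading a buffer) should
        # capture and match eager results for both predicate values.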
| class A(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.ones(6, 4)) |
| |
| def forward(self): |
| return self.buffer.sum() |
| |
| class Module(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = A() |
| |
| def inner(self, x, pred): |
| def true_fn(x): |
| return x + x + self.a() |
| |
| def false_fn(x): |
| return x * x - self.a() |
| |
| return cond(pred, true_fn, false_fn, [x]) |
| |
| def forward(self, pred, xs): |
| def body(x, pred): |
| return self.inner(x, pred) + self.a() |
| |
| return map(body, xs, pred) |
| |
| inp = torch.randn(3, 2, 1) |
| ep = exir.capture( |
| Module(), |
| (torch.tensor(True), inp), |
| CaptureConfig(enable_aot=True, _unlift=True), |
| ) |
| |
| inp_test = torch.randn(3, 2, 1) |
| self.assertTrue( |
| torch.allclose( |
| ep(torch.tensor(True), inp_test), |
| Module()(torch.tensor(True), inp_test), |
| ) |
| ) |
| self.assertTrue( |
| torch.allclose( |
| ep(torch.tensor(False), inp_test), |
| Module()(torch.tensor(False), inp_test), |
| ) |
| ) |