| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import io |
| import unittest |
| from typing import Tuple |
| |
| import executorch.exir as exir |
| |
| import torch |
| from executorch.exir import to_edge |
| from executorch.exir.backend.backend_api import CompileSpec, to_backend |
| from executorch.exir.backend.test.backend_with_compiler_demo import ( |
| BackendWithCompilerDemo, |
| ) |
| |
| from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo |
| from executorch.exir.serde.serialize import deserialize, serialize |
| from torch import nn |
| from torch.export import export |
| from torch.export.exported_program import ExportedProgram as TorchExportedProgram |
| from torch.utils import _pytree as pytree |
| |
| |
# Tests for serializing exported programs to JSON and back
| class TestSerde(unittest.TestCase): |
| def check_ep( |
| self, |
| ep1: TorchExportedProgram, |
| ep2: TorchExportedProgram, |
| inputs: Tuple[exir.Value, ...], |
| ) -> None: |
| """ |
| Checks if two graphs are equivalent |
| """ |
| orig_outputs = ep1.module()(*inputs) |
| loaded_outputs = ep2.module()(*inputs) |
| |
| flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs) |
| flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs) |
| |
| for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True): |
| self.assertTrue(torch.allclose(orig, loaded)) |
| |
| # pyre-ignore |
| def check_serde(self, m, inputs, check_executorch=True) -> None: |
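        """
        Round-trips ``m`` through serialize/deserialize at the ATen, Edge,
        and ExecuTorch stages (plus exir.save/exir.load), checking that the
        outputs match at every stage.
        """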
| aten = export(m, inputs) |
| aten_new = deserialize(serialize(aten)) |
| self.check_ep(aten, aten_new, inputs) |
| |
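        # Repeat the round trip at the Edge dialect level.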
| edge = to_edge(aten) |
| edge_new = deserialize(serialize(edge.exported_program())) |
| self.check_ep(edge.exported_program(), edge_new, inputs) |
| |
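        # Exercise the file-like API (exir.save / exir.load) on the Edge program.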
| buffer = io.BytesIO() |
| exir.save(edge.exported_program(), buffer) |
| buffer.seek(0) |
| loaded_ep = exir.load(buffer) |
| self.check_ep(edge.exported_program(), loaded_ep, inputs) |
| |
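        # Round-trip the ExecuTorch-level program as well.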
| executorch = edge.to_executorch().exported_program() |
| executorch_new = deserialize(serialize(executorch)) |
| if check_executorch: |
| with torch.no_grad(): |
| self.check_ep(executorch, executorch_new, inputs) |
| |
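        # Save/load the ExecuTorch program through a buffer too.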
| buffer = io.BytesIO() |
| exir.save(executorch, buffer) |
| buffer.seek(0) |
| loaded_ep = exir.load(buffer) |
| self.check_ep(executorch, loaded_ep, inputs) |
| |
| def test_basic(self) -> None: |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| x = x + x |
| x = x * x |
| x = x / x |
| return x, x.clone() |
| |
| inputs = (torch.ones([512], requires_grad=True),) |
| self.check_serde(MyModule(), inputs) |
| |
    def test_to_out_variant_singleton_tensor_list(self) -> None:
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| return torch.split(x, 10) |
| |
| def get_random_inputs(self): |
| return (torch.randn(10),) |
| |
| model = MyModel() |
| inputs = model.get_random_inputs() |
        # We set check_executorch to False for this test because it triggers
        # an edge case: calling .module() on the ExecuTorch exported program
        # runs an unlift pass followed by dead code elimination, which
        # removes the split_copy op from the graph.
| self.check_serde(model, inputs, check_executorch=False) |
| |
| def test_to_out_variant_multiple_out(self) -> None: |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| values, indices = torch.topk(x, 5) |
| return (values, indices) |
| |
| def get_random_inputs(self): |
| return (torch.randn(10),) |
| |
| model = MyModel() |
| inputs = model.get_random_inputs() |
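        # topk returns two tensors, so this exercises out variants with
        # multiple outputs.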
| self.check_serde(model, inputs) |
| |
| def test_delegate(self) -> None: |
| class SinModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| return torch.sin(x) |
| |
| sin_module = SinModule() |
| model_inputs = (torch.ones(1),) |
| edgeir_m = to_edge(export(sin_module, model_inputs)) |
| max_value = model_inputs[0].shape[0] |
| compile_specs = [CompileSpec("max_value", bytes([max_value]))] |
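        # Ahead-of-time lower the exported sin module to the demo backend.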
| lowered_sin_module = to_backend( |
| BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs |
| ) |
| |
| class CompositeModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
                self.lowered_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_sin(x)
| |
| composite_model = CompositeModule() |
| model_inputs = (torch.ones(1),) |
| |
| composite_model(*model_inputs) |
| |
| edge = to_edge(export(composite_model, model_inputs)) |
| edge_new = deserialize(serialize(edge.exported_program())) |
| self.check_ep(edge.exported_program(), edge_new, model_inputs) |
| |
| def test_model_with_weights(self) -> None: |
| class LinearAdd(nn.Module): |
| def __init__(self, M: int, N: int): |
| super().__init__() |
| self.M = M |
| self.N = N |
| self.linear = torch.nn.Linear(M, N) |
| |
| def forward(self, x, y): |
| x = self.linear(x) |
| y = self.linear(y) |
| return torch.add(x, y) |
| |
| @classmethod |
| def _get_random_inputs(cls): |
| return (torch.rand(128, 20), torch.rand(128, 20)) |
| |
| linear_add = LinearAdd(20, 30) |
| model_inputs = LinearAdd._get_random_inputs() |
| |
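        # The Linear layer's weight and bias must survive the serde round trip.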
| self.check_serde(linear_add, model_inputs) |
| |
| def test_delegate_partitioner(self) -> None: |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, a, x, b): |
| y = torch.mm(a, x) |
| z = y + b |
| a = z - a |
| y = torch.mm(a, x) |
| z = y + b |
| return z |
| |
| m = Model() |
| inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)) |
| |
| ep = to_edge(export(m, inputs)) |
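        # Delegate the add/mul subgraphs to the demo backend via the partitioner.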
| edge = ep.to_backend(AddMulPartitionerDemo()) |
| edge_new = deserialize(serialize(edge.exported_program())) |
| self.check_ep(edge.exported_program(), edge_new, inputs) |
| |
| def test_meta_stack_trace_module_hierarchy(self) -> None: |
| class Model(nn.Module): |
| def __init__(self): |
                super().__init__()
| self.conv_layer = nn.Conv2d( |
| in_channels=1, out_channels=64, kernel_size=3, padding=1 |
| ) |
| |
| def forward(self, x): |
| return self.conv_layer(x) |
| |
| m = Model() |
| inputs = (torch.randn(1, 1, 32, 32),) |
| |
| metadata = () |
| edge = to_edge(export(m, inputs)) |
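        # Record the convolution node's debug metadata before serde.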
| for node in edge.exported_program().graph_module.graph.nodes: |
| if "convolution" in str(node.target): |
| metadata = ( |
| node.meta.get("stack_trace"), |
| node.meta.get("nn_module_stack"), |
| ) |
| |
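        # Collect the same metadata after a serde round trip.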
| metadata_serde = () |
| edge_new = deserialize(serialize(edge.exported_program())) |
| for node in edge_new.graph_module.graph.nodes: |
| if "convolution" in str(node.target): |
| metadata_serde = ( |
| node.meta.get("stack_trace"), |
| node.meta.get("nn_module_stack"), |
| ) |
| self.assertTrue(len(metadata) != 0 and len(metadata_serde) != 0) |
| self.assertTrue( |
| all(val is not None for val in metadata) |
| and all(val is not None for val in metadata_serde) |
| ) |
| self.assertEqual(metadata[0], metadata_serde[0]) |
| self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys())) |
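

if __name__ == "__main__":
    # Convenience entry point for running this file directly; test runners
    # typically discover TestSerde on their own.
    unittest.main()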