| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| import unittest |
| |
| import torch |
| from executorch.exir import to_edge |
| from executorch.exir.passes.const_prop_pass import ConstPropPass |
| from executorch.exir.schema import Tensor, TensorList |
| |
| from executorch.exir.verification.interpreter import Interpreter |
| from executorch.exir.verification.verifier import EXIREdgeDialectVerifier |
| from torch._export.verifier import SpecViolationError |
| from torch.export import export |
| |
| |
| class WrapperModule(torch.nn.Module): |
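    """Wraps an arbitrary callable so it can be exported with torch.export as an nn.Module."""
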
| def __init__(self, fn): |
| super().__init__() |
| self.fn = fn |
| |
| def forward(self, *args, **kwargs): |
| return self.fn(*args, **kwargs) |
| |
| |
| class TestVerification(unittest.TestCase): |
| def test_constant_buffer(self) -> None: |
| def f(x: torch.Tensor) -> torch.Tensor: |
| return torch.ones(2) + x + torch.ones(2) |
| |
        # Generate an ExecuTorch program, folding the torch.ones(2) constants
        # with ConstPropPass
| program = ( |
| to_edge(export(WrapperModule(f), (torch.randn(2),))) |
| .transform( |
| [ |
| ConstPropPass(), |
| ] |
| ) |
| .to_executorch() |
| ._emitter_output.program |
| ) |
| |
| test = Interpreter(program) |
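        # Load every serialized value except Tensors with data_buffer_idx == 0
        # (tensors not backed by constant data) and TensorLists.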
| for val_idx in range(len(test.execution_plan.values)): |
| val = test.execution_plan.values[val_idx].val |
| if not ( |
| isinstance(val, Tensor) and val.data_buffer_idx == 0 |
| ) and not isinstance(val, TensorList): |
| test.load_value(val_idx) |
| vlist = test.get_value_list() |
| for e in vlist: |
| if isinstance(e, torch.Tensor): |
| self.assertTrue(torch.allclose(e, torch.ones(2))) |
| |
        # Assert that exactly 2 constant Tensors exist in the value list
| self.assertEqual(len([e for e in vlist if isinstance(e, torch.Tensor)]), 2) |
| |
| def test_operator_list(self) -> None: |
| class Op1(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.a = torch.ones(2, 2) |
| self.b = 2 * torch.ones(2, 2) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for _ in range(10): |
| z = self.a * x # mul |
| y = z - self.b # sub |
| return y |
| |
| class Op2(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.a = torch.ones(2, 2) |
| self.b = 2 * torch.ones(2, 2) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for _ in range(10): |
| z = self.a % x # remainder |
| y = z / self.b # div |
| z = z + z # add |
| return y + z |
| |
| # Generate a program with Op1's operations (mul, sub) |
| model1 = Op1() |
| inputs = (torch.ones(2, 2),) |
| program = ( |
| to_edge(export(model1, inputs)).to_executorch()._emitter_output.program |
| ) |
| |
        # Initialize the Interpreter and assert that its operator list matches
        # Op1's operators (mul, sub)
| test = Interpreter(program) |
| self.assertEqual( |
| set(test.get_operators_list()), |
| {torch.ops.aten.mul.out, torch.ops.aten.sub.out}, |
| ) |
| |
        # Generate a program with Op2's operations (remainder, div, add)
| model2 = Op2() |
| inputs = (torch.ones(2, 2),) |
| program = ( |
| to_edge(export(model2, inputs)).to_executorch()._emitter_output.program |
| ) |
| |
        # Initialize the Interpreter and assert that its operator list matches
        # Op2's operators (remainder, div, add)
| test = Interpreter(program) |
| self.assertEqual( |
| set(test.get_operators_list()), |
| { |
| torch.ops.aten.remainder.Tensor_out, |
| torch.ops.aten.div.out, |
| torch.ops.aten.add.out, |
| }, |
| ) |
| |
| def test_verification(self) -> None: |
| class Op2(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.a = torch.ones(2, 2) |
| self.b = 2 * torch.ones(2, 2) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for _ in range(10): |
| z = self.a % x # remainder |
| y = z / self.b # div |
| z = z + z # add |
| return y + z |
| |
| # Generate a program with Op2's operations (remainder, div, add) |
| model2 = Op2() |
| inputs = torch.ones(2, 2) |
| exec_prog = to_edge(export(model2, (inputs,))).to_executorch() |
| |
| exported_prog = exec_prog.exported_program() |
| res = exported_prog.module()(inputs)[0] # noqa |
        # Verifiers are run internally in export, to_edge, and to_executorch.
        # If we make it this far, then no errors were raised during verification.
| |
| |
| class TestEdgeVerification(unittest.TestCase): |
| def test_edge_happy(self) -> None: |
| class TestModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("a", torch.randn(1, 3, 100, 100)) |
| |
| def forward(self, x): |
| b = self.a + x |
| return torch._to_cpu([b, x]) |
| |
| m = TestModel() |
| egm = ( |
| to_edge( |
| export( |
| m, |
| (torch.randn(1, 3, 100, 100).to(dtype=torch.int),), |
| ) |
| ) |
| .exported_program() |
| .graph_module |
| ) |
| verifier = EXIREdgeDialectVerifier() |
| verifier(egm) |
| self.assertTrue(verifier.is_valid(egm)) |
| |
| def test_edge_happy_with_optional_tensor_input(self) -> None: |
| class TestModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x, weight, bias): |
| # weight and bias here are optional tensor inputs. |
| return torch.group_norm(x, 4, weight, bias) |
| |
| m = TestModel() |
| egm = ( |
| to_edge( |
| export( |
| m, |
| (torch.rand(16, 8, 32, 32), torch.rand(8), torch.rand(8)), |
| ) |
| ) |
| .exported_program() |
| .graph_module |
| ) |
| verifier = EXIREdgeDialectVerifier() |
| verifier(egm) |
| self.assertTrue(verifier.is_valid(egm)) |
| |
| def test_edge_happy_with_empty_tensorlist_input(self) -> None: |
| class TestModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| return torch._to_cpu(x) |
| |
| m = TestModel() |
| egm = ( |
| to_edge( |
| export( |
| m, |
| ([],), |
| ) |
| ) |
| .exported_program() |
| .graph_module |
| ) |
| verifier = EXIREdgeDialectVerifier() |
| verifier(egm) |
| self.assertTrue(verifier.is_valid(egm)) |
| |
| def test_edge_sad(self) -> None: |
| class TestModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("a", torch.randn(1, 3, 100, 100)) |
| |
| def forward(self, x): |
| b = self.a + x |
| return torch._to_cpu([b, x]) |
| |
| m = TestModel() |
| egm = export( |
| m, |
| (torch.randn(1, 3, 100, 100).to(dtype=torch.int),), |
| ).graph_module |
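        # The graph module is still in ATen dialect (to_edge was never called),
        # so the edge dialect verifier is expected to reject it.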
| verifier = EXIREdgeDialectVerifier() |
| with self.assertRaises(SpecViolationError): |
| verifier(egm) |
| |
| def test_edge_happy_with_edge_ops(self) -> None: |
| class TestModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| return x + x |
| |
| m = TestModel() |
| egm = ( |
| to_edge( |
| export( |
| m, |
| (torch.randn(1, 3, 100, 100).to(dtype=torch.int),), |
| ) |
| ) |
| .exported_program() |
| .graph_module |
| ) |
| verifier = EXIREdgeDialectVerifier() |
| verifier(egm) |
| self.assertTrue(verifier.is_valid(egm)) |
| |
| def test_edge_sad_with_edge_ops(self) -> None: |
        # log_softmax only accepts float or double Tensors, so a bfloat16 input
        # should be rejected during edge verification
| m = torch.nn.LogSoftmax(dim=1) |
| with self.assertRaises(SpecViolationError): |
| _ = ( |
| to_edge( |
| export( |
| m, |
| (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),), |
| ) |
| ) |
| .exported_program() |
| .graph_module |
| ) |