blob: c223e0ad844e12387e62b9992750c5a40e35fb29 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import unittest
import torch
from executorch.exir import to_edge
from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.schema import Tensor, TensorList
from executorch.exir.verification.interpreter import Interpreter
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch._export.verifier import SpecViolationError
from torch.export import export
class WrapperModule(torch.nn.Module):
    """Adapt a plain callable into an ``nn.Module`` so it can be passed to ``export``.

    Every positional and keyword argument given to ``forward`` is handed to
    the wrapped callable unchanged, and its result is returned as-is.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        wrapped = self.fn
        return wrapped(*args, **kwargs)
class TestVerification(unittest.TestCase):
    """Tests for the ExecuTorch program ``Interpreter`` and for end-to-end lowering."""

    def test_constant_buffer(self) -> None:
        """Constant tensors produced by ``ConstPropPass`` load as ``torch.ones(2)``.

        After lowering, every loadable tensor in the execution plan is expected
        to equal ``torch.ones(2)``, and exactly two such tensors are expected.
        """

        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.ones(2) + x + torch.ones(2)

        # Lower to an ExecuTorch program, running constant propagation on the
        # Edge program first.
        program = (
            to_edge(export(WrapperModule(f), (torch.randn(2),)))
            .transform([ConstPropPass()])
            .to_executorch()
            ._emitter_output.program
        )

        interpreter = Interpreter(program)
        for val_idx, value in enumerate(interpreter.execution_plan.values):
            val = value.val
            # Tensor lists are not loaded.
            if isinstance(val, TensorList):
                continue
            # Tensors with data_buffer_idx == 0 are skipped -- presumably these
            # carry no constant data (TODO confirm against the schema).
            if isinstance(val, Tensor) and val.data_buffer_idx == 0:
                continue
            interpreter.load_value(val_idx)

        tensors = [
            e for e in interpreter.get_value_list() if isinstance(e, torch.Tensor)
        ]
        for t in tensors:
            self.assertTrue(torch.allclose(t, torch.ones(2)))
        # Only the 2 constant tensors should exist in the value list.
        self.assertEqual(len(tensors), 2)

    def test_operator_list(self) -> None:
        """``Interpreter.get_operators_list`` reports exactly the ops a program uses."""

        class Op1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x  # mul
                    y = z - self.b  # sub
                return y

        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        inputs = (torch.ones(2, 2),)

        # Generate a program with Op1's operations (mul, sub).
        program = to_edge(export(Op1(), inputs)).to_executorch()._emitter_output.program
        # The interpreter should report the out-variants of exactly those ops.
        interpreter = Interpreter(program)
        self.assertEqual(
            set(interpreter.get_operators_list()),
            {torch.ops.aten.mul.out, torch.ops.aten.sub.out},
        )

        # Generate a program with Op2's operations (remainder, div, add).
        program = to_edge(export(Op2(), inputs)).to_executorch()._emitter_output.program
        interpreter = Interpreter(program)
        self.assertEqual(
            set(interpreter.get_operators_list()),
            {
                torch.ops.aten.remainder.Tensor_out,
                torch.ops.aten.div.out,
                torch.ops.aten.add.out,
            },
        )

    def test_verification(self) -> None:
        """Lowering a module to ExecuTorch passes every built-in verifier."""

        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        # Generate a program with Op2's operations (remainder, div, add).
        model2 = Op2()
        inputs = torch.ones(2, 2)
        exec_prog = to_edge(export(model2, (inputs,))).to_executorch()
        # Also run the lowered exported program once to make sure it executes.
        exec_prog.exported_program().module()(inputs)
        # Verifiers are run internally in export, to_edge, and to_executorch.
        # If we make it this far then no errors were thrown in verification.
class TestEdgeVerification(unittest.TestCase):
    """Tests for ``EXIREdgeDialectVerifier`` on Edge- and ATen-dialect graphs."""

    @staticmethod
    def _make_buffer_add_model() -> torch.nn.Module:
        """Build a module that adds a registered buffer to its input and calls ``_to_cpu``."""

        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                b = self.a + x
                return torch._to_cpu([b, x])

        return TestModel()

    def _assert_valid_edge_graph(
        self, m: torch.nn.Module, example_inputs: tuple
    ) -> None:
        """Lower ``m`` to Edge dialect and assert the Edge verifier accepts it.

        Args:
            m: module to export and convert with ``to_edge``.
            example_inputs: positional example inputs passed to ``export``.
        """
        egm = to_edge(export(m, example_inputs)).exported_program().graph_module
        verifier = EXIREdgeDialectVerifier()
        # Calling the verifier raises SpecViolationError on an invalid graph
        # (see test_edge_sad); a valid graph must also satisfy is_valid.
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_happy(self) -> None:
        self._assert_valid_edge_graph(
            self._make_buffer_add_model(),
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        )

    def test_edge_happy_with_optional_tensor_input(self) -> None:
        class TestModel(torch.nn.Module):
            def forward(self, x, weight, bias):
                # weight and bias here are optional tensor inputs.
                return torch.group_norm(x, 4, weight, bias)

        self._assert_valid_edge_graph(
            TestModel(),
            (torch.rand(16, 8, 32, 32), torch.rand(8), torch.rand(8)),
        )

    def test_edge_happy_with_empty_tensorlist_input(self) -> None:
        class TestModel(torch.nn.Module):
            def forward(self, x):
                return torch._to_cpu(x)

        self._assert_valid_edge_graph(TestModel(), ([],))

    def test_edge_sad(self) -> None:
        # The graph is exported but NOT converted with to_edge, so the Edge
        # dialect verifier is expected to reject it.
        egm = export(
            self._make_buffer_add_model(),
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        ).graph_module
        verifier = EXIREdgeDialectVerifier()
        with self.assertRaises(SpecViolationError):
            verifier(egm)

    def test_edge_happy_with_edge_ops(self) -> None:
        class TestModel(torch.nn.Module):
            def forward(self, x):
                return x + x

        self._assert_valid_edge_graph(
            TestModel(),
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        )

    def test_edge_sad_with_edge_ops(self) -> None:
        # log_softmax only takes float or double Tensor, so lowering a bfloat16
        # input is expected to fail verification.
        m = torch.nn.LogSoftmax(dim=1)
        with self.assertRaises(SpecViolationError):
            _ = (
                to_edge(
                    export(
                        m,
                        (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),),
                    )
                )
                .exported_program()
                .graph_module
            )