blob: c514592fa9b7cc1d74e460548d833b76d4e9b327 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pye-strict
import unittest
from typing import Any, Dict
import torch
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.error import ExportError
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.pass_base import ExportPass
from executorch.exir.program._program import (
EdgeProgramManager,
ExecutorchProgramManager,
to_edge,
)
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)
from torch.export import export, ExportedProgram
def get_exported_programs() -> Dict[str, ExportedProgram]:
def forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
z = torch.mul(x, y)
return torch.add(z, x)
def foo(x: torch.Tensor) -> torch.Tensor:
return torch.add(x, torch.ones(1))
programs = {}
programs["forward"] = export(
forward,
args=(
torch.ones(1),
torch.zeros(1),
),
).run_decompositions()
programs["foo"] = export(
foo,
(torch.ones(1),),
).run_decompositions()
return programs
def get_config_methods() -> Dict[str, Any]:
    """Return the constant "config methods" bundled next to the programs.

    Each key is a method name and each value is the constant that method
    yields: ``bam`` -> ``3`` and ``bar`` -> ``"bar"``.
    """
    constants: Dict[str, Any] = {}
    constants["bam"] = 3
    constants["bar"] = "bar"
    return constants
class AddToMulPassEdge(ExportPass):
    """Edge-dialect pass that turns every ``aten.add.Tensor`` into ``aten.mul.Tensor``.

    Only the operator target is swapped. ``args``, ``kwargs`` and ``meta`` are
    forwarded untouched, and every other operator goes through ``ExportPass``
    as-is.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Redirect add nodes to mul; keep every other op unchanged.
        target = (
            exir_ops.edge.aten.mul.Tensor
            if op == exir_ops.edge.aten.add.Tensor
            else op
        )
        return super().call_operator(target, args, kwargs, meta)
class TestProgramManagers(unittest.TestCase):
    """Tests for ``EdgeProgramManager`` and ``ExecutorchProgramManager``.

    Every test builds its managers from the programs returned by
    ``get_exported_programs`` (``forward`` and ``foo``) together with the
    constant config methods from ``get_config_methods`` (``bam`` and ``bar``).
    """

    @staticmethod
    def _count_add_nodes(program: ExportedProgram) -> int:
        """Count the edge ``aten.add.Tensor`` call_function nodes in ``program``."""
        return sum(
            1
            for node in program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        )

    def test_edge_manager_basic_api(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        # test basic apis
        self.assertEqual(edge_manager.methods, {"forward", "foo"})
        self.assertEqual(edge_manager.config_methods, {"bam", "bar"})

        # test dialect is correct
        try:
            EXIREdgeDialectVerifier()(
                edge_manager.exported_program("forward").graph_module
            )
            EXIREdgeDialectVerifier()(edge_manager.exported_program("foo").graph_module)
        except ExportError as e:
            # self.fail reports a proper test failure. str(e) is used instead of
            # an exception-specific attribute, so building the message cannot
            # itself raise and hide the verifier error.
            self.fail(f"Graph not in edge dialect : {e}")

    def test_executorch_manager_basic_api(self):
        executorch_manager: ExecutorchProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        ).to_executorch()

        # test basic apis
        self.assertEqual(executorch_manager.methods, {"forward", "foo"})
        self.assertEqual(executorch_manager.config_methods, {"bam", "bar"})

        # test that the emitted output is correct: one execution plan per
        # method (forward, foo) and per config method (bam, bar)
        self.assertEqual(
            len(executorch_manager._emitter_output.program.execution_plan), 4
        )

        # test that the buffer is correct
        executorch_module = _load_for_executorch_from_buffer(executorch_manager.buffer)
        self.assertEqual(
            executorch_module.run_method("forward", (torch.ones(1), torch.zeros(1)))[0],
            torch.ones(1),  # 1 * 0 + 1
        )
        self.assertEqual(
            executorch_module.run_method("foo", (torch.ones(1),))[0],
            torch.ones(1) + torch.ones(1),
        )
        self.assertEqual(
            executorch_module.run_method("bar", ())[0],
            "bar",
        )
        self.assertEqual(
            executorch_module.run_method("bam", ())[0],
            3,
        )

    def test_edge_manager_transform(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        original_res = edge_manager.exported_program("forward")(
            torch.ones(1), torch.ones(1)
        )

        # perform transformation
        transformed_edge = edge_manager.transform(
            [
                AddToMulPassEdge(),
            ]
        )

        # still have all our methods
        self.assertEqual(len(transformed_edge.methods), 2)
        self.assertEqual(len(transformed_edge.config_methods), 2)

        # transformation was applied
        self.assertEqual(
            transformed_edge.exported_program("forward")(torch.ones(1), torch.ones(1)),
            torch.ones(1),  # x * y * x
        )

        # original unchanged
        self.assertEqual(
            edge_manager.exported_program("forward")(torch.ones(1), torch.ones(1)),
            original_res,  # x * y + x
        )

    def test_transform_dict_api(self):
        edge_manager = to_edge(get_exported_programs(), get_config_methods())

        # Only "forward" is listed, so only it should be transformed.
        transformed_edge = edge_manager.transform(
            {
                "forward": [
                    AddToMulPassEdge(),
                ]
            }
        )

        self.assertEqual(
            transformed_edge.exported_program("forward")(torch.ones(1), torch.ones(1)),
            torch.ones(1),  # x * y * x
        )

        self.assertEqual(
            transformed_edge.exported_program("foo")(
                torch.ones(1),
            ),
            torch.ones(1) + 1,  # x + 1
        )

    def test_edge_to_backend_replaces_subgraph(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            AddMulPartitionerDemo
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )
        self.assertEqual(self._count_add_nodes(forward_program), 0)

        foo_program = delegate_manager.exported_program("foo")
        self.assertEqual(self._count_add_nodes(foo_program), 0)

        lowered_submods = get_lowered_submodules(foo_program.graph_module)
        self.assertEqual(len(lowered_submods), 1)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # one delegate blob in each of forward and foo (two in total); lower
        # once and inspect both execution plans from the same result
        program = delegate_manager.to_executorch(
            ExecutorchBackendConfig(extract_segments=True)
        )._emitter_output.program
        self.assertEqual(len(program.execution_plan[0].delegates), 1)
        self.assertEqual(len(program.execution_plan[1].delegates), 1)

    def test_edge_to_backend_selective(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            {"forward": AddMulPartitionerDemo}
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )
        self.assertEqual(self._count_add_nodes(forward_program), 0)

        # foo unchanged
        lowered_submods = get_lowered_submodules(
            delegate_manager.exported_program("foo").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # one delegate blob for forward only; lower once and inspect both plans
        program = delegate_manager.to_executorch(
            ExecutorchBackendConfig()
        )._emitter_output.program
        self.assertEqual(len(program.execution_plan[0].delegates), 0)  # foo
        self.assertEqual(len(program.execution_plan[1].delegates), 1)  # forward