# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from types import MappingProxyType
import torch
from executorch import exir
from executorch.exir.backend.backend_details import CompileSpec, ExportedProgram
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_pattern_op_partitions,
)
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
AnyOperatorSupport,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
ExecutorBackend,
)
from executorch.exir.backend.test.op_partitioner_demo import (
AddAttributePartitionerDemo,
AllNodesPartitionerDemo,
)
from executorch.exir.backend.utils import get_delegates, tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.tests.models import MLP
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
_load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import export, export_for_training
from torch.fx.passes.operator_support import any_chain
class TestPartitioner(unittest.TestCase):
def test_partitioner_with_spec(self):
        # Create a custom partitioner with a spec and check that the spec can be accessed but not mutated.
class PartitionerWithSpec(Partitioner):
def __init__(self, spec) -> None:
super().__init__(spec)
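                # AnyOperatorSupport (from the RPC demo) accepts any operator, so
                # the whole graph is expected to end up in a single delegate.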
self.op_support = any_chain(AnyOperatorSupport())
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
partition_list = generate_pattern_op_partitions(
edge_exported_program.graph_module, op_support=self.op_support
)
for partition in partition_list:
for node in partition.nodes:
delegation_tag = f"tag{partition.id}"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
mlp = MLP()
example_inputs = mlp.get_random_inputs()
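        # Re-trace for training, then export to an ATen-dialect ExportedProgram
        # that exir.to_edge() can consume.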
model = export_for_training(mlp, example_inputs).module()
aten = export(model, example_inputs)
spec_key = "path"
spec_value = "/a/b/c/d"
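        # MappingProxyType is a read-only view of the dict, so the mutation
        # attempts asserted below are expected to raise.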
spec = MappingProxyType({spec_key: spec_value})
my_partitioner = PartitionerWithSpec(spec)
edge = exir.to_edge(aten).to_backend(my_partitioner)
lowered_module_nodes = get_delegates(edge.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
# Check the lowered module has correct compile spec
for lower_module_node in lowered_module_nodes:
lower_module = getattr(
edge.exported_program().graph_module, lower_module_node.name
)
self.assertEqual(lower_module.compile_specs[0].key, spec_key)
self.assertEqual(lower_module.compile_specs[0].value, spec_value)
# Check the custom partitioner has the correct spec
self.assertEqual(my_partitioner.spec[spec_key], spec_value)
with self.assertRaisesRegex(
TypeError,
"'mappingproxy' object does not support item assignment",
):
my_partitioner.spec[spec_key] = "new_value"
with self.assertRaisesRegex(
AttributeError,
"can't set attribute 'spec'",
):
my_partitioner.spec = {"new_key": "new_value"}
def test_bad_partitioner_tagged_output(self):
        # Create a bad partitioner that tags the output node, which is not allowed.
class PartitionerTagOutput(Partitioner):
def __init__(self) -> None:
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "output":
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
mlp = MLP()
example_inputs = mlp.get_random_inputs()
model = export_for_training(mlp, example_inputs).module()
aten = export(model, example_inputs)
edge = exir.to_edge(aten)
with self.assertRaisesRegex(
RuntimeError,
"output node output should not be tagged",
):
_ = edge.to_backend(PartitionerTagOutput())
def test_bad_partitioner_tagged_model_input(self):
        # Create a bad partitioner that tags an input which is neither a param nor a buffer, which is not allowed.
class PartitionerTagInput(Partitioner):
def __init__(self) -> None:
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "placeholder":
if not is_param(edge_exported_program, node) and not is_buffer(
edge_exported_program, node
):
delegation_tag = "tag_" + str(node.meta["debug_handle"])
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
mlp = MLP()
example_inputs = mlp.get_random_inputs()
model = export_for_training(mlp, example_inputs).module()
edge = exir.to_edge(export(model, example_inputs))
with self.assertRaisesRegex(
RuntimeError,
"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged",
):
_ = edge.to_backend(PartitionerTagInput())
class AddConst(torch.nn.Module):
def __init__(self):
super().__init__()
self.const1 = torch.ones(2, 2)
self.register_buffer("const2", torch.ones(2, 2), persistent=False)
self.register_parameter("const3", torch.nn.Parameter(torch.ones(2, 2)))
def forward(self, x):
return x + self.const1 + self.const2 + self.const3
def test_partitioner_not_tag_data(self):
"""
We test here that when partitioners do not explicitly tag constant data nodes,
then the partitioned ExportedProgram will not own the data. Instead the owning program
will still own the constant data and instead feed it as inputs to the partitioned
program
"""
class PartitionerNoTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
delegated = edge.to_backend(PartitionerNoTagData())
# Check Owning Program still owns all constant data
owning_program = delegated.exported_program()
self.assertEqual(
len(owning_program.state_dict) + len(owning_program.constants), 3
)
self.assertEqual(
len(owning_program.graph_signature.buffers)
+ len(owning_program.graph_signature.lifted_tensor_constants),
2,
)
self.assertEqual(len(owning_program.graph_signature.parameters), 1)
# Check Lowered Module Exported Program does not have any constant data
lowered_module_nodes = get_delegates(delegated.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
lowered_module_node = lowered_module_nodes[0]
# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
# 5 args to lowered module are: delegated_payload, x, const1, const2, const3
self.assertEqual(len(call_delegate_node.args), 5)
lower_module = getattr(
delegated.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lower_module.original_module
self.assertEqual(len(delegated_ep.state_dict), 0)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 0)
self.assertEqual(len(delegated_ep.graph_signature.parameters), 0)
# check exported program is still runnable
output = delegated.exported_program().module()(torch.ones(2, 2))
reference_output = model(torch.ones(2, 2))
self.assertTrue(torch.allclose(reference_output, output))
def test_partitioner_tag_data(self):
"""
We test here that when partitioners explicitly tag constant data nodes,
then the partitioned ExportedProgram will own the data, and the data will
be removed from the owning program.
"""
class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
if node.op == "placeholder" and (
is_param(edge_exported_program, node)
or is_buffer(edge_exported_program, node)
or is_lifted_tensor_constant(edge_exported_program, node)
):
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
delegated = edge.to_backend(PartitionerTagData())
        # Check Owning Program no longer owns any constant data
owning_program = delegated.exported_program()
self.assertEqual(len(owning_program.state_dict), 0)
self.assertEqual(len(owning_program.graph_signature.buffers), 0)
self.assertEqual(len(owning_program.graph_signature.parameters), 0)
        # Check Lowered Module Exported Program owns all the constant data
lowered_module_nodes = get_delegates(delegated.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
lowered_module_node = lowered_module_nodes[0]
# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 2 args to lowered module are: delegated_payload, x
self.assertEqual(len(call_delegate_node.args), 2)
lower_module = getattr(
delegated.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lower_module.original_module
self.assertEqual(len(delegated_ep.state_dict) + len(delegated_ep.constants), 3)
self.assertEqual(
len(delegated_ep.graph_signature.buffers)
+ len(delegated_ep.graph_signature.lifted_tensor_constants),
2,
)
self.assertEqual(len(delegated_ep.graph_signature.parameters), 1)
# check exported program is still runnable
output = delegated.exported_program().module()(torch.ones(2, 2))
reference_output = model(torch.ones(2, 2))
self.assertTrue(torch.allclose(reference_output, output))
def test_partitioner_tag_only_params(self):
"""
We test here that when partitioners explicitly tag constant data nodes,
then the partitioned ExportedProgram will own the data, and the data will
be removed from the owning program.
"""
class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
if node.op == "placeholder" and (
is_param(edge_exported_program, node)
):
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
delegated = edge.to_backend(PartitionerTagData())
# Check Owning Program still owns only buffers
owning_program = delegated.exported_program()
self.assertEqual(
len(owning_program.state_dict) + len(owning_program.constants), 2
)
self.assertEqual(
len(owning_program.graph_signature.buffers)
+ len(owning_program.graph_signature.lifted_tensor_constants),
2,
)
self.assertEqual(len(owning_program.graph_signature.parameters), 0)
        # Check Lowered Module Exported Program owns the parameter but no buffers
lowered_module_nodes = get_delegates(delegated.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
lowered_module_node = lowered_module_nodes[0]
# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 4 args to lowered module are: delegated_payload, x, buffer1, buffer2
self.assertEqual(len(call_delegate_node.args), 4)
lower_module = getattr(
delegated.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lower_module.original_module
self.assertEqual(len(delegated_ep.state_dict), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 0)
self.assertEqual(len(delegated_ep.graph_signature.parameters), 1)
# check exported program is still runnable
output = delegated.exported_program().module()(torch.ones(2, 2))
reference_output = model(torch.ones(2, 2))
self.assertTrue(torch.allclose(reference_output, output))
def test_partitioner_splits_constant_data(self):
"""
We test that we throw an error when constant data users are split
between different delegated payloads or owning program.
"""
class ReuseConstData(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.ones(2, 2)
def forward(self, x):
y = x + self.const
z = x - self.const
return y, z
class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
if node.op == "placeholder" and (
is_param(edge_exported_program, node)
or is_buffer(edge_exported_program, node)
):
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
inputs = (torch.ones(2, 2),)
model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
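        # Load the serialized program through the pybindings runtime and make
        # sure it still runs with the shared constant.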
executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
inputs_flattened, _ = tree_flatten(inputs)
        # Run the lowered program end to end through the ExecuTorch runtime
        _ = executorch_module.run_method("forward", inputs_flattened)
def test_partitioner_alert_split_constant_data(self):
"""
We test that we throw an error when constant data users are split
between different delegated payloads or owning program.
"""
class ReuseConstData(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.ones(2, 2)
def forward(self, x):
y = x + self.const
z = x - self.const
return y, z
class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
if node.op == "placeholder" and (
is_param(edge_exported_program, node)
or is_buffer(edge_exported_program, node)
or is_lifted_tensor_constant(edge_exported_program, node)
):
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
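                        # Marking the constant with "no_copy" (so it cannot simply
                        # be duplicated into the delegate) should trigger the
                        # split-user error asserted below.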
node.meta["no_copy"] = True
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
with self.assertRaises(RuntimeError) as error:
_ = edge.to_backend(PartitionerTagData())
self.assertTrue(
"is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)"
in str(error.exception),
)
def test_not_delegate_mutable_buffers(self) -> None:
"""
        A test case to check that a mutated buffer is not delegated. We'll need to add
        a test case for when the delegate can consume the mutable buffer.
"""
class MutableStateModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("my_state", torch.zeros(1))
def forward(self, x):
y = x + self.my_state
self.my_state.add_(1)
return y
edge = exir.to_edge(
torch.export.export(
MutableStateModule(),
(torch.zeros(1),),
)
)
self.assertGreater(
len(edge.exported_program().graph_signature.buffers_to_mutate),
0,
"The test case should at leaset one mutable buffer",
)
class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)
def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
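                # tag_constant_data propagates the delegation tag to constant data
                # consumed by tagged nodes; the mutable buffer is expected to stay
                # untagged so it remains in the owning program.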
tag_constant_data(edge_exported_program)
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
        # Check the edge program's initial buffers_to_mutate
mutate_op = "aten_add_tensor_1"
self.assertEqual(
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
"my_state",
)
edge = edge.to_backend(PartitionerTagData())
# After to_backend, add is delegated and is no longer in buffers_to_mutate.
self.assertNotIn(
mutate_op,
edge.exported_program().graph_signature.buffers_to_mutate,
)
mutate_op = "getitem_1"
# Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate)
self.assertEqual(
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
"my_state",
)
# Check the copy_ node is inserted
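        # Because the buffer mutation stays in the owning program, to_executorch()
        # is expected to insert an aten.copy_ that writes the updated value back
        # into the mutable buffer.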
edge = edge.to_executorch()
copy_node = [
node
for node in edge.exported_program().graph.nodes
if node.op == "call_function"
and node.target == torch.ops.aten.copy_.default
]
self.assertEqual(len(copy_node), 1)
def test_buffer_mutation1(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("b", torch.ones(3, 3))
def forward(self, x):
self.b.add_(x)
return x + self.b
model_inputs = (torch.ones(3, 3),)
orig_res = TestModule()(*model_inputs)
edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
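        # AddAttributePartitionerDemo tags add nodes together with their attribute
        # inputs, so the mutated buffer is expected to be consumed by the delegate
        # (checked below via the empty top-level buffers_to_mutate).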
lowered = edge_program.to_backend(AddAttributePartitionerDemo())
self.assertTrue(
torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
)
self.assertEqual(
len(lowered.exported_program().graph_signature.buffers_to_mutate),
0,
)
lowered_module_nodes = get_delegates(lowered.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
lowered_module_node = lowered_module_nodes[0]
# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
self.assertEqual(len(call_delegate_node.args), 2)
lower_module = getattr(
lowered.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lower_module.original_module
self.assertEqual(len(delegated_ep.state_dict), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
def test_buffer_mutation_llama_repro(self):
SHAPE = (2, 3)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))
def forward(self, q, k_val, input_pos):
q_T = q.transpose(0, 1)
k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
attn = k.mm(q_T)
return attn
q = torch.rand(1, 3)
k = torch.rand(1, 3)
example_inputs = (q, k, torch.tensor([1, 1]))
model = Model()
model.eval()
exir_program_aten = torch.export.export(model, example_inputs)
exir_program_aten.module()(*example_inputs)
edge_program_manager = exir.to_edge(exir_program_aten)
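        # AllNodesPartitionerDemo delegates the entire graph, so the in-place
        # index_put_ on the cache buffer is expected to be consumed by the
        # delegate along with its mutation.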
lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())
self.assertEqual(
len(lowered.exported_program().graph_signature.buffers_to_mutate),
0,
)
lowered_module_nodes = get_delegates(lowered.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
lowered_module_node = lowered_module_nodes[0]
# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
self.assertEqual(len(call_delegate_node.args), 4)
lower_module = getattr(
lowered.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lower_module.original_module
self.assertEqual(len(delegated_ep.state_dict), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
def test_buffer_mutation_unsupported(self):
SHAPE = (2, 3)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))
def forward(self, x):
add = self.state_1.add_(x)
return add
model = Model()
model.eval()
example_inputs = (torch.randn(SHAPE),)
exir_program_aten = torch.export.export(model, example_inputs)
edge_program_manager = exir.to_edge(exir_program_aten)
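        # This partitioner is not expected to be able to consume the in-place
        # buffer mutation, so lowering should fail.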
with self.assertRaises(AssertionError):
edge_program_manager.to_backend(AddAttributePartitionerDemo())