blob: 2c68920ff34758d3829bb5a736597b17489f057a [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import io
import unittest
from typing import Tuple
import executorch.exir as exir
import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import CompileSpec, to_backend
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.serde.serialize import deserialize, serialize
from torch import nn
from torch.export import export
from torch.export.exported_program import ExportedProgram as TorchExportedProgram
from torch.utils import _pytree as pytree
# Tests for serializing to json and back
class TestSerde(unittest.TestCase):
    """Round-trip tests for exported programs through serde and exir.save/load.

    Each check exports a model, round-trips the resulting program at the
    ATen, Edge and ExecuTorch stages, and compares the numeric outputs of the
    original and reloaded programs on the same inputs.
    """

    def check_ep(
        self,
        ep1: TorchExportedProgram,
        ep2: TorchExportedProgram,
        inputs: Tuple[exir.Value, ...],
    ) -> None:
        """
        Checks that two exported programs produce matching outputs on ``inputs``.

        Both programs are turned into runnable modules via ``.module()``, their
        outputs are flattened with pytree, and every pair of flattened outputs
        must satisfy ``torch.allclose``.

        Args:
            ep1: Reference exported program.
            ep2: Program compared against ``ep1`` (e.g. a deserialized copy).
            inputs: Positional inputs fed to both programs.
        """
        orig_outputs = ep1.module()(*inputs)
        loaded_outputs = ep2.module()(*inputs)

        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

        # Compare the output counts first so a mismatch is reported as a test
        # failure showing both counts, rather than a ValueError from
        # zip(strict=True).
        self.assertEqual(
            len(flat_orig_outputs),
            len(flat_loaded_outputs),
            "original and reloaded programs produced a different number of outputs",
        )
        for idx, (orig, loaded) in enumerate(
            zip(flat_orig_outputs, flat_loaded_outputs, strict=True)
        ):
            self.assertTrue(
                torch.allclose(orig, loaded),
                f"output {idx} differs between original and reloaded program",
            )

    def _save_and_load(self, ep: TorchExportedProgram) -> TorchExportedProgram:
        """Round-trip ``ep`` through ``exir.save``/``exir.load`` via an in-memory buffer."""
        buffer = io.BytesIO()
        exir.save(ep, buffer)
        # Rewind so exir.load reads from the start of what was just written.
        buffer.seek(0)
        return exir.load(buffer)

    # pyre-ignore
    def check_serde(self, m, inputs, check_executorch=True) -> None:
        """Export ``m`` and verify each lowering stage survives a round trip.

        Stages:
          1. ATen program: serialize -> deserialize.
          2. Edge program: serialize -> deserialize, and exir.save -> exir.load.
          3. ExecuTorch program: always serialized/deserialized; outputs (and the
             exir.save -> exir.load round trip) are compared only when
             ``check_executorch`` is True.

        Args:
            m: The module to export.
            inputs: Example positional inputs used for export and comparison.
            check_executorch: Whether to compare outputs at the ExecuTorch stage.
        """
        aten = export(m, inputs)
        aten_new = deserialize(serialize(aten))
        self.check_ep(aten, aten_new, inputs)

        edge = to_edge(aten)
        edge_ep = edge.exported_program()
        edge_new = deserialize(serialize(edge_ep))
        self.check_ep(edge_ep, edge_new, inputs)
        self.check_ep(edge_ep, self._save_and_load(edge_ep), inputs)

        executorch_ep = edge.to_executorch().exported_program()
        # Serialization is exercised even when the output comparison is skipped.
        executorch_new = deserialize(serialize(executorch_ep))
        if check_executorch:
            with torch.no_grad():
                self.check_ep(executorch_ep, executorch_new, inputs)
                self.check_ep(
                    executorch_ep, self._save_and_load(executorch_ep), inputs
                )

    def test_basic(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_serde(MyModule(), inputs)

    def test_to_out_variant_singleon_tensor_list(self) -> None:
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        # We set check_executorch to false for this test because this triggers
        # an edge case where calling .module() on the executorch exported program
        # will cause an unlift pass to be run on the graph and dead code elimination
        # will be subsequently run, which essentially causes the split_copy op to be
        # removed.
        self.check_serde(model, inputs, check_executorch=False)

    def test_to_out_variant_multiple_out(self) -> None:
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                values, indices = torch.topk(x, 5)
                return (values, indices)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        self.check_serde(model, inputs)

    def test_delegate(self) -> None:
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        # Eager sanity run: the composite module must execute before export.
        composite_model(*model_inputs)

        edge = to_edge(export(composite_model, model_inputs))
        edge_new = deserialize(serialize(edge.exported_program()))
        self.check_ep(edge.exported_program(), edge_new, model_inputs)

    def test_model_with_weights(self) -> None:
        class LinearAdd(nn.Module):
            def __init__(self, M: int, N: int):
                super().__init__()
                self.M = M
                self.N = N
                self.linear = torch.nn.Linear(M, N)

            def forward(self, x, y):
                x = self.linear(x)
                y = self.linear(y)
                return torch.add(x, y)

            @classmethod
            def _get_random_inputs(cls):
                return (torch.rand(128, 20), torch.rand(128, 20))

        linear_add = LinearAdd(20, 30)
        model_inputs = LinearAdd._get_random_inputs()
        self.check_serde(linear_add, model_inputs)

    def test_delegate_partitioner(self) -> None:
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        ep = to_edge(export(m, inputs))
        edge = ep.to_backend(AddMulPartitionerDemo())
        edge_new = deserialize(serialize(edge.exported_program()))
        self.check_ep(edge.exported_program(), edge_new, inputs)

    def test_meta_stack_trace_module_hierarchy(self) -> None:
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_layer = nn.Conv2d(
                    in_channels=1, out_channels=64, kernel_size=3, padding=1
                )

            def forward(self, x):
                return self.conv_layer(x)

        m = Model()
        inputs = (torch.randn(1, 1, 32, 32),)

        # (stack_trace, nn_module_stack) of the convolution node before serde;
        # if several convolution nodes exist, the last one visited wins.
        metadata = ()
        edge = to_edge(export(m, inputs))
        for node in edge.exported_program().graph_module.graph.nodes:
            if "convolution" in str(node.target):
                metadata = (
                    node.meta.get("stack_trace"),
                    node.meta.get("nn_module_stack"),
                )

        # Same metadata, collected from the deserialized program.
        metadata_serde = ()
        edge_new = deserialize(serialize(edge.exported_program()))
        for node in edge_new.graph_module.graph.nodes:
            if "convolution" in str(node.target):
                metadata_serde = (
                    node.meta.get("stack_trace"),
                    node.meta.get("nn_module_stack"),
                )

        self.assertNotEqual(
            len(metadata), 0, "no convolution node found before serialization"
        )
        self.assertNotEqual(
            len(metadata_serde), 0, "no convolution node found after deserialization"
        )
        for val in (*metadata, *metadata_serde):
            self.assertIsNotNone(val)
        self.assertEqual(metadata[0], metadata_serde[0])
        self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))