blob: c0744146c94cd9e0c67311eaf1084ec7cc2ea232 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import inspect
import os
import sys
from typing import Any, Dict, List, Type
import torch
from executorch.exir import CaptureConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.test.end2end.exported_module import ExportedModule
from torch import nn
from torch._export import dynamic_dim
"""Traces and exports nn.Modules to ExecuTorch .pte program files.
This tool mainly exists to export programs for C++ tests, but can also
be used to export models manually.
"""
#
# Module definitions.
#
# If we ever have more than a handful, consider splitting into multiple files.
#
class ModuleBasic(nn.Module):
    """Tiny module that reduces the elementwise sine of its input to a scalar."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Scalar result: the largest value of sin(x).
        sines = torch.sin(x)
        return sines.max()

    def get_random_inputs(self):
        # One 1-D tensor of 100 random values.
        return (torch.randn(100),)

    @staticmethod
    def get_export_kwargs() -> Dict[str, Any]:
        """Returns custom trace params for ExportedModule."""
        # aten::max.default does not have an out variant.
        return {"ignore_to_out_var_failure": True}
class ModuleIndex(nn.Module):
    """Exercises mixed slice/advanced tensor indexing for program tests."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Weird index that happens to generate a None in torch.index.Tensor_out
        # which is desirable for deserialization testing. A modified form of
        # an example index from https://pytorch.org/cppdocs/notes/tensor_indexing.html.
        row_selector = slice(1, None, 2)
        col_selector = torch.tensor([1, 2])
        return x[row_selector, col_selector]

    def get_random_inputs(self):
        return (torch.randn(10, 10, 10),)
class ModuleNoOp(nn.Module):
    """Passes both inputs straight through without any computation."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Identity on a pair of tensors.
        return (x, y)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2))
class ModuleAdd(nn.Module):
    """Computes the scaled sum x + alpha * y."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, alpha):
        result = torch.add(x, y, alpha=alpha)
        return result

    def get_random_inputs(self):
        # Two random 2x2 tensors plus the default scale factor.
        return (torch.randn(2, 2), torch.randn(2, 2), 1.0)
class ModuleDynamicCatUnallocatedIO(nn.Module):
    """Concatenates a row of ones onto a dynamically-sized input tensor."""

    def __init__(self):
        super().__init__()
        # TODO(T163238401)
        self._inputs = (torch.randn(3, 4),)

    def forward(self, k):
        # Appends one extra (1, 4) row of ones along dim 0.
        return torch.cat((k, torch.ones(1, 4)))

    def get_random_inputs(self):
        return self._inputs

    def get_constraints(self):
        # The leading dimension of the input is dynamic, capped at 3.
        return [
            dynamic_dim(self._inputs[0], 0) <= 3,
        ]

    def get_memory_planning_pass(self):
        # Leave graph inputs/outputs unallocated by the memory planner.
        return MemoryPlanningPass(
            memory_planning_algo="greedy",
            alloc_graph_input=False,
            alloc_graph_output=False,
        )

    @staticmethod
    def get_export_kwargs():
        return {"capture_config": CaptureConfig(pt2_mode=True, enable_aot=True)}
class ModuleLinear(torch.nn.Module):
    """Affine map with constant coefficients: out = 3 * x + 2."""

    def __init__(self):
        super().__init__()
        # Constant weight and bias tensors (not registered parameters).
        self.a = torch.full((2, 2), 3.0, dtype=torch.float)
        self.b = torch.full((2, 2), 2.0, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        return torch.add(torch.mul(self.a, x), self.b)

    def get_random_inputs(self):
        return (torch.ones(2, 2, dtype=torch.float),)
class ModuleMultipleEntry(torch.nn.Module):
    """Module with two exported entry points: forward and forward2."""

    def __init__(self):
        super().__init__()
        # Constant offsets added by the two entry points.
        self.a = torch.full((2, 2), 3.0, dtype=torch.float)
        self.b = torch.full((2, 2), 2.0, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        return x + self.a

    def forward2(self, x: torch.Tensor):
        # Second entry point: adds both constants.
        return x + self.a + self.b

    def get_random_inputs(self):
        return (torch.ones(2, 2, dtype=torch.float),)

    @staticmethod
    def get_method_names_to_export() -> List[str]:
        """Both methods become separate entry points in the program."""
        return ["forward", "forward2"]
#
# Main logic.
#
def export_module_to_program(module_class: Type[nn.Module]):
    """Exports the module and returns the serialized program data."""
    # Modules may define optional @staticmethod hooks to customize tracing.
    export_kwargs: Dict[str, Any] = {}
    kwargs_hook = getattr(module_class, "get_export_kwargs", None)
    if kwargs_hook is not None:
        # pyre-ignore[16]: pyre doesn't know about get_export_kwargs.
        export_kwargs = kwargs_hook()
    methods_hook = getattr(module_class, "get_method_names_to_export", None)
    if methods_hook is not None:
        # pyre-ignore[16]: pyre doesn't know about get_method_names_to_export.
        methods = methods_hook()
    else:
        methods = ["forward"]
    module = ExportedModule.export(module_class, methods, **export_kwargs)
    return module.executorch_program.buffer
def main() -> None:
    """Parses CLI args and writes one <classname>.pte file per requested module."""
    # These args are optimized for genrule usage. There's a lot of startup
    # overhead for this tool, so it's faster to export multiple models at once
    # when possible.
    parser = argparse.ArgumentParser(
        prog="export_program",
        description="Exports nn.Module models to ExecuTorch .pte files",
    )
    parser.add_argument(
        "--modules",
        help="Comma-separated list of model class names to export; "
        + "e.g., '--modules=ModuleBasic,ModuleAdd'",
        type=lambda s: [item.strip() for item in s.split(",")],
        # Required: without this, omitting --modules left args.modules as None
        # and the loop below raised an opaque TypeError instead of a clear
        # argparse usage error.
        required=True,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="Path to the directory to write <classname>.pte files to",
    )
    args = parser.parse_args()

    # Find the classes to export. Only looks in this module for now, but could
    # be extended to look in other modules if helpful.
    module_names_to_classes: Dict[str, Type[nn.Module]] = {}
    for module in args.modules:
        module_class = getattr(sys.modules[__name__], module, None)
        if not (inspect.isclass(module_class) and issubclass(module_class, nn.Module)):
            raise NameError(f"Could not find nn.Module class named '{module}'")
        module_names_to_classes[module] = module_class

    # Export and write to the output files.
    os.makedirs(args.outdir, exist_ok=True)
    for module_name, module_class in module_names_to_classes.items():
        outfile = os.path.join(args.outdir, f"{module_name}.pte")
        with open(outfile, "wb") as fp:
            fp.write(export_module_to_program(module_class))
        print(f"Exported {module_name} and wrote program data to {outfile}")
# Script entry point: run the exporter when invoked directly.
if __name__ == "__main__":
    main()