blob: 9e8c5b40979f6456d5bcf359c8c6deedfc8d18d4 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Example script for exporting simple models to flatbuffer
import argparse
import executorch.exir as exir
from ..models import MODEL_NAME_TO_MODEL
from .utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
def export_to_pte(model_name, model, example_inputs):
m = model.eval()
edge = exir.capture(m, example_inputs, _CAPTURE_CONFIG).to_edge(
_EDGE_COMPILE_CONFIG
)
print("Exported graph:\n", edge.exported_program.graph)
exec_prog = edge.to_executorch()
buffer = exec_prog.buffer
filename = f"{model_name}.pte"
print(f"Saving exported program to {filename}")
with open(filename, "wb") as file:
file.write(buffer)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=True,
help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
)
args = parser.parse_args()
if args.model_name not in MODEL_NAME_TO_MODEL:
raise RuntimeError(
f"Model {args.model_name} is not a valid name. "
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
)
model, example_inputs = MODEL_NAME_TO_MODEL[args.model_name]()
export_to_pte(args.model_name, model, example_inputs)