blob: a455afb50ee03902dd45853def37beff0be57553 [file] [log] [blame]
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Generates a template `functions.yaml` from a model binary. Custom (non-aten) ops are ignored.
import argparse
import os
import sys
from typing import Any, List
import torch
import yaml
from executorch.codegen.tools.yaml_util import BlankLineDumper
from executorch.exir._serialize import _deserialize_pte_binary
from executorch.exir.schema import Operator
def get_operators(model_file: str) -> List[Operator]:
    """Load a serialized ExecuTorch program and return its operator table.

    Args:
        model_file: Path to the serialized program file; it is read as raw bytes.

    Returns:
        The ``operators`` list of the program's first execution plan. Operators
        referenced only by later execution plans (if any) are not included.
    """
    print("Processing model file: ", model_file)
    with open(model_file, "rb") as pte_stream:
        raw_bytes = pte_stream.read()
    program = _deserialize_pte_binary(raw_bytes)
    print(f"Program loaded from model file: {model_file}")
    # NOTE(review): only execution_plan[0] is inspected; programs with several
    # methods may use additional operators in later plans — confirm intended.
    first_plan = program.execution_plan[0]
    return first_plan.operators
def _collect_aten_schemas(ops: "List[Operator]") -> List[Any]:
    """Resolve each ``aten::`` operator to the torch schema(s) matching its overload.

    Non-aten operators (e.g. custom ops) are skipped with a warning. An aten
    operator whose overload matches none of the schemas torch returns is also
    reported, instead of being dropped silently from the template.
    """
    schemas: List[Any] = []
    for op in ops:
        if not op.name.startswith("aten::"):
            print(f"Warning: not generating template for {op.name}")
            continue
        matches = [
            s
            for s in torch._C._jit_get_schemas_for_operator(op.name)
            if s.overload_name == op.overload
        ]
        if not matches:
            # Without this warning the generated file would simply lack an entry
            # the model needs, with no hint as to why.
            qualified = f"{op.name}.{op.overload}" if op.overload else op.name
            print(
                f"Warning: no schema found for {qualified}; "
                "not generating template"
            )
        schemas.extend(matches)
    return schemas


def _schema_to_entry(schema: Any) -> dict:
    """Build one functions.yaml entry for ``schema`` with a placeholder CPU kernel."""
    kernel_name = schema.name.replace("aten::", "torch::executor::")
    return {
        "func": str(schema),
        "variants": "function",
        "dispatch": {
            # Placeholder kernel name: <torch::executor::op>_<overload>; it is
            # expected to be replaced by hand with the real kernel name.
            "CPU": f"{kernel_name}_{schema.overload_name}",
        },
    }


def dump_yaml(model_file: str, output_file: str) -> None:
    """Write a template functions.yaml for the aten operators used by a model.

    Each schema is printed to stdout as it is emitted. Custom ops and aten ops
    without a matching schema produce warnings and are left out.

    Args:
        model_file: Path to the serialized ExecuTorch program.
        output_file: Destination path of the generated YAML (overwritten).
    """
    schemas = _collect_aten_schemas(get_operators(model_file))
    output = []
    for s in schemas:
        print(str(s))
        output.append(_schema_to_entry(s))
    with open(output_file, "w", encoding="utf-8") as f:
        yaml.dump(
            output,
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,  # large width so long schema strings are not wrapped
        )
def main(args: List[Any]) -> None:
    """Generate a template functions.yaml to be consumed by ExecuTorch codegen.

    Reads the model file, deserializes it, and dumps its aten operators into a
    new functions.yaml (custom ops are skipped). The generated file contains
    placeholder kernel names that must be updated with the proper kernel names.

    Args:
        args: Command-line arguments, excluding the program name.

    Raises:
        SystemExit: (via argparse) when arguments are missing or invalid.
    """
    parser = argparse.ArgumentParser(
        description="Generate operator list from a model file"
    )
    parser.add_argument(
        "--output_path",
        default="./functions.yaml",
        help=(
            "The path to the output yaml file (functions.yaml). "
            "Defaults to ./functions.yaml."
        ),
    )
    parser.add_argument(
        "--model_file_path",
        required=True,
        help="Path to an executorch program",
    )
    options = parser.parse_args(args)
    # Use parser.error rather than assert: asserts are stripped under
    # `python -O`, and parser.error prints usage and exits with status 2.
    if not os.path.isfile(options.model_file_path):
        parser.error("The value for --model_file_path needs to be a valid file.")
    dump_yaml(
        model_file=options.model_file_path,
        # An explicitly empty --output_path still falls back to the default.
        output_file=options.output_path or "./functions.yaml",
    )
if __name__ == "__main__":
    # Hand everything after the program name to the CLI entry point.
    main(args=sys.argv[1:])