blob: 76d2e974f1a6467f92d01e39489904bd075b196d [file] [log] [blame]
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import sys
from typing import Any, List
import yaml
from torchgen.code_template import CodeTemplate
# C++ condition fragment for one operator: matches the operator name, then any
# of its allowed scalar types ($dtype_checks is an "||"-joined list of
# "scalar_type == exec_aten::ScalarType::X" comparisons).
ops_and_dtypes_template_str = """((exec_aten::string_view(operator_name).compare("$operator_name") == 0)\n && ($dtype_checks))"""
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)
# Skeleton of the generated header: a constexpr predicate whose $body is the
# "||"-joined per-operator fragments built from ops_and_dtypes_template.
selected_kernel_dtypes_h_template_str = """#pragma once
/**
* Generated by executorch/codegen/tools/gen_selected_op_variants.py
*/
inline constexpr bool should_include_kernel_dtype(
const char *operator_name,
exec_aten::ScalarType scalar_type
) {
return $body;
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
# enum from: https://github.com/pytorch/executorch/blob/main/runtime/core/portable_type/scalar_type.h
# Maps the string form of a ScalarType ordinal — as it appears before the ";"
# in et_kernel_metadata entries, e.g. the "6" in "v1/6;0,1" — to the enum name
# emitted in the generated "exec_aten::ScalarType::<Name>" comparisons.
dtype_enum_to_type = {
    "0": "Byte",
    "1": "Char",
    "2": "Short",
    "3": "Int",
    "4": "Long",
    "5": "Half",
    "6": "Float",
    "7": "Double",
    "8": "ComplexHalf",
    "9": "ComplexFloat",
    "10": "ComplexDouble",
    "11": "Bool",
    "12": "QInt8",
    "13": "QUInt8",
    "14": "QInt32",
    "15": "BFloat16",
    "16": "QUInt4x2",
    "17": "QUInt2x4",
    "18": "Bits1x8",
    "19": "Bits2x4",
    "20": "Bits4x2",
    "21": "Bits8",
    "22": "Bits16",
}
def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
    """Generate selected_op_variants.h from a selected_operators.yaml file.

    Reads the ``et_kernel_metadata`` mapping (operator name -> list of kernel
    metadata keys such as ``"v1/6;0,1|6;0,1"``), extracts the dtype ordinal of
    each tensor argument, and writes a constexpr C++ predicate
    ``should_include_kernel_dtype(operator_name, scalar_type)`` to
    ``<output_dir>/selected_op_variants.h``.

    :param yaml_file_path: Path to the generated selected_operators.yaml file.
    :param output_dir: Directory that receives selected_op_variants.h.
    """
    with open(yaml_file_path, "r") as selected_operators_file:
        # Example metadata key: v1/6;0,1|6;0,1|6;0,1|6;0,1
        # The number before each ";" is the ScalarType ordinal (6 == Float);
        # the numbers after ";" describe dim order and are ignored here.
        selected_operators_dict = yaml.safe_load(selected_operators_file)
        et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata", {})
        assert isinstance(et_kernel_metadata, dict)
        body = "true"
        body_parts = []
        for operator_name, kernel_metadata_keys in et_kernel_metadata.items():
            tensor_meta = []
            for kernel_metadata in kernel_metadata_keys:
                # "default" means no dtype specialization was recorded for this
                # operator: include all dtypes, so stop collecting metadata.
                if kernel_metadata == "default":
                    break
                # Strip the version prefix ("v1/") and split the per-tensor
                # entries separated by "|".
                tensor_meta.extend(kernel_metadata.split("/")[1].split("|"))
            conditions = ["true"]
            if tensor_meta:
                # Each entry looks like "<dtype-ordinal>;<dim-order>"; keep the
                # distinct dtypes and sort for deterministic generated code.
                dtype_set = {entry.split(";")[0] for entry in tensor_meta}
                dtype_list = sorted(dtype_enum_to_type[d] for d in dtype_set)
                conditions = [
                    "scalar_type == exec_aten::ScalarType::" + x for x in dtype_list
                ]
            body_parts.append(
                ops_and_dtypes_template.substitute(
                    operator_name=operator_name.replace("aten::", ""),
                    dtype_checks=" || ".join(conditions),
                ),
            )
        # Bug fix: the previous unconditional '"\n || ".join(body_parts)'
        # produced an empty $body (i.e. "return ;", invalid C++) when no
        # operators were selected. Keep the "true" default in that case so the
        # generated header always compiles.
        if body_parts:
            body = "\n || ".join(body_parts)
        header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
    selected_op_variants_path = os.path.join(output_dir, "selected_op_variants.h")
    with open(selected_op_variants_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))
def main(argv: List[Any]) -> None:
    """CLI entry point: parse arguments and generate selected_op_variants.h.

    :param argv: Command-line arguments, typically ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Generate operator lists")
    # Each option accepts both dashed and underscored spellings; argparse
    # derives the dest from the first spelling (yaml_file_path / output_dir).
    parser.add_argument(
        "--yaml-file-path",
        "--yaml_file_path",
        # Fixed help text: the value is the path to the generated
        # selected_operators.yaml file itself (it is opened as a file by
        # write_selected_op_variants), not a directory; also removed the
        # stray unbalanced ")".
        help="Path to the selected_operators.yaml file that was generated",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help=(
            "The directory to store the output yaml files (selected_op_variants.h, "
            + "selected_kernel_dtypes.h, selected_operators.yaml)"
        ),
        required=True,
    )
    options = parser.parse_args(argv)
    write_selected_op_variants(options.yaml_file_path, options.output_dir)
if __name__ == "__main__":
    # Script entry point: forward the CLI args (minus the program name).
    main(sys.argv[1:])