| #!/usr/bin/env fbpython |
| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import argparse |
| import os |
| import sys |
| from typing import Any, List |
| |
| import yaml |
| |
| from torchgen.code_template import CodeTemplate |
| |
| |
# C++ expression emitted once per selected operator: it is true when the
# runtime operator name equals $operator_name and $dtype_checks holds.
ops_and_dtypes_template_str = (
    '((exec_aten::string_view(operator_name).compare("$operator_name") == 0)'
    "\n && ($dtype_checks))"
)
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)
| |
# Template for the generated C++ header. $body is replaced with the boolean
# expression returned by should_include_kernel_dtype(); the rendered text is
# what write_selected_op_variants() writes to selected_op_variants.h.
selected_kernel_dtypes_h_template_str = """#pragma once
/**
* Generated by executorch/codegen/tools/gen_selected_op_variants.py
*/

inline constexpr bool should_include_kernel_dtype(
const char *operator_name,
exec_aten::ScalarType scalar_type
) {
return $body;
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
| |
# enum from: https://github.com/pytorch/executorch/blob/main/runtime/core/portable_type/scalar_type.h
# Maps the decimal string of a ScalarType enum value to the enumerator name.
# The names below are listed in enum order, so a name's position in the tuple
# is its enum value ("0" -> "Byte", ..., "22" -> "Bits16").
dtype_enum_to_type = {
    str(enum_value): type_name
    for enum_value, type_name in enumerate(
        (
            "Byte",
            "Char",
            "Short",
            "Int",
            "Long",
            "Half",
            "Float",
            "Double",
            "ComplexHalf",
            "ComplexFloat",
            "ComplexDouble",
            "Bool",
            "QInt8",
            "QUInt8",
            "QInt32",
            "BFloat16",
            "QUInt4x2",
            "QUInt2x4",
            "Bits1x8",
            "Bits2x4",
            "Bits4x2",
            "Bits8",
            "Bits16",
        )
    )
}
| |
| |
def _dtype_conditions(kernel_keys: Any) -> List[str]:
    """Return the C++ dtype checks for one operator's kernel metadata entries."""
    # Tolerate a single key given as a bare string instead of a list; iterating
    # the string would otherwise walk it character by character.
    if isinstance(kernel_keys, str):
        kernel_keys = [kernel_keys]

    # A "default" entry accepts every dtype. Check for it up front so the
    # result does not depend on where "default" sits in the list.
    if "default" in kernel_keys:
        return ["true"]

    dtype_enums = set()
    for kernel_key in kernel_keys:
        # Key layout: <version>/<tensor>|<tensor>|..., each tensor written as
        # <dtype enum>;<rest>. Only the dtype enum is used here.
        tensors = kernel_key.split("/")[1]
        dtype_enums.update(tensor.split(";")[0] for tensor in tensors.split("|"))

    if not dtype_enums:
        return ["true"]

    # Sort by type name so the generated header is deterministic.
    dtype_names = sorted(dtype_enum_to_type[enum] for enum in dtype_enums)
    return ["scalar_type == exec_aten::ScalarType::" + name for name in dtype_names]


def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
    """Generate ``selected_op_variants.h`` from a selected-operators YAML file.

    Reads the ``et_kernel_metadata`` mapping (operator name -> list of kernel
    metadata entries) and renders the body of the C++ function
    ``should_include_kernel_dtype``: one clause per operator that matches the
    operator name (with any ``aten::`` text removed) and the dtypes recorded
    for it. Operators whose entries include ``"default"``, or that list no
    entries, accept every dtype. When the mapping is missing or empty the
    function body is simply ``true``.

    Example kernel metadata entry: ``v1/6;0,1|6;0,1|6;0,1|6;0,1`` (dtype enum
    6 is Float).

    Args:
        yaml_file_path: Path of the YAML file to read.
        output_dir: Directory in which ``selected_op_variants.h`` is written.

    Raises:
        AssertionError: If ``et_kernel_metadata`` is present but not a mapping.
        IndexError: If a kernel metadata entry contains no ``/``.
        KeyError: If an entry names a dtype enum value that is not in
            ``dtype_enum_to_type``.
    """
    with open(yaml_file_path, "r") as selected_operators_file:
        # safe_load returns None for an empty document; treat that as "no
        # metadata" instead of failing on None.get().
        selected_operators_dict = yaml.safe_load(selected_operators_file) or {}

    et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata", {})
    assert isinstance(et_kernel_metadata, dict)

    body_parts = [
        ops_and_dtypes_template.substitute(
            operator_name=operator_name.replace("aten::", ""),
            dtype_checks=" || ".join(_dtype_conditions(kernel_keys)),
        )
        for operator_name, kernel_keys in et_kernel_metadata.items()
    ]
    # With no operators, joining would yield "" and the header would contain
    # `return ;`, which does not compile; fall back to accepting everything.
    body = "\n || ".join(body_parts) if body_parts else "true"

    header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
    selected_op_variants_path = os.path.join(output_dir, "selected_op_variants.h")
    with open(selected_op_variants_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))
| |
| |
def main(argv: List[Any]) -> None:
    """Parse command-line options and generate ``selected_op_variants.h``.

    Args:
        argv: Command-line tokens, without the program name.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="Generate operator lists")
    parser.add_argument(
        "--yaml-file-path",
        "--yaml_file_path",
        # The value is opened as a file, so describe it as a file path.
        help="Path to the selected_operators.yaml file to read",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        # Only selected_op_variants.h is written by this tool.
        help="The directory in which to write selected_op_variants.h",
        required=True,
    )

    options = parser.parse_args(argv)
    write_selected_op_variants(options.yaml_file_path, options.output_dir)
| |
| |
if __name__ == "__main__":
    # main() expects only the option tokens, so drop the program name.
    cli_args = sys.argv[1:]
    main(cli_args)