| #!/usr/bin/env python3 |
| |
| """ Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentation. |
| The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml. |
| |
| Usage: |
| |
| python -m tools.onnx.gen_diagnostics \ |
| torch/onnx/_internal/diagnostics/rules.yaml \ |
| torch/onnx/_internal/diagnostics \ |
| torch/csrc/onnx/diagnostics/generated \ |
| torch/docs/source |
| """ |
| |
| import argparse |
| import os |
| import string |
| import subprocess |
| import textwrap |
| from typing import Any, Mapping, Sequence |
| |
| import yaml |
| |
| from torchgen import utils as torchgen_utils |
| from torchgen.yaml_utils import YamlLoader |
| |
| |
# Header text passed as "generated_comment" to both the Python and the C++
# output templates.
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.

Diagnostic rules for PyTorch ONNX export.
"""

# Text passed as "generated_rule_class_comment" to the Python output template.
_PY_RULE_CLASS_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""

# Source template for one generated rule class. The body lines must keep their
# nested indentation: the rendered text is emitted verbatim as Python source, so
# an unindented class/method body would be a syntax error in the generated file.
_PY_RULE_CLASS_TEMPLATE = """\
class _{pascal_case_name}(infra.Rule):
    \"\"\"{short_description}\"\"\"
    def format_message(  # type: ignore[override]
        self,
        {message_arguments}
    ) -> str:
        \"\"\"Returns the formatted default message of this Rule.

        Message template: {message_template}
        \"\"\"
        return self.message_default_template.format({message_arguments_assigned})

    def format(  # type: ignore[override]
        self,
        level: infra.Level,
        {message_arguments}
    ) -> Tuple[infra.Rule, infra.Level, str]:
        \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.

        Message template: {message_template}
        \"\"\"
        return self, level, self.format_message({message_arguments_assigned})

"""

# Source template for one dataclass field of the generated rule collection.
# The rendered fields are indented as a whole by the caller.
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
    default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
    init=False,
)
\"\"\"{short_description}\"\"\"
"""

# Source template for one rule entry in the generated C++ header, preceded by
# a doxygen-style brief comment.
_CPP_RULE_TEMPLATE = """\
/**
 * @brief {short_description}
 */
{name},
"""

# A single rule as loaded from the rules YAML file. The generators in this file
# read rule["name"], rule["short_description"]["text"] and
# rule["message_strings"]["default"]["text"].
_RuleType = Mapping[str, Any]
| |
| |
def _kebab_case_to_snake_case(name: str) -> str:
    """Return ``name`` with every "-" separator turned into "_"."""
    return "_".join(name.split("-"))
| |
| |
def _kebab_case_to_pascal_case(name: str) -> str:
    """Return ``name`` with each "-"-separated word capitalized and joined.

    Each word goes through ``str.capitalize``, so characters after the first
    one are lower-cased.
    """
    capitalized_words = map(str.capitalize, name.split("-"))
    return "".join(capitalized_words)
| |
| |
def _format_rule_for_python_class(rule: _RuleType) -> str:
    """Render the source of the generated Python rule class for one rule.

    The keyword fields found in the rule's default message template become the
    explicit parameters of the generated ``format_message`` and ``format``
    methods.

    Args:
        rule: Rule mapping providing ``rule["name"]``,
            ``rule["short_description"]["text"]`` and
            ``rule["message_strings"]["default"]["text"]``.

    Returns:
        ``_PY_RULE_CLASS_TEMPLATE`` rendered for this rule.

    Raises:
        AssertionError: If the message template contains a numeric (positional)
            format field; only keyword fields are supported.
    """
    rule_name = rule["name"]
    short_description = rule["short_description"]["text"]
    message_template = rule["message_strings"]["default"]["text"]

    # dict.fromkeys removes repeated fields while keeping first-seen order.
    # A field used more than once in the template must still produce a single
    # parameter, otherwise the generated signature and call would repeat a
    # name, which is a syntax error.
    field_names = list(
        dict.fromkeys(
            field_name
            for _, field_name, _, _ in string.Formatter().parse(message_template)
            if field_name is not None
        )
    )
    for field_name in field_names:
        # The message is parenthesized so that all parts belong to the assert;
        # as separate statements the trailing literals would be silently dropped.
        assert isinstance(field_name, str), (
            f"Unexpected field type {type(field_name)} from {field_name}. "
            "Field name must be string.\n"
            f"Full message template: {message_template}"
        )
        assert not field_name.isnumeric(), (
            f"Unexpected numeric field name {field_name}. "
            "Only keyword name formatting is supported.\n"
            f"Full message template: {message_template}"
        )

    message_arguments = ", ".join(field_names)
    message_arguments_assigned = ", ".join(
        f"{field_name}={field_name}" for field_name in field_names
    )
    return _PY_RULE_CLASS_TEMPLATE.format(
        pascal_case_name=_kebab_case_to_pascal_case(rule_name),
        short_description=short_description,
        # repr() keeps the template on one line inside the generated docstring.
        message_template=repr(message_template),
        message_arguments=message_arguments,
        message_arguments_assigned=message_arguments_assigned,
    )
| |
| |
def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Render the rule-collection dataclass field for one rule.

    The whole ``rule`` mapping is interpolated into the template as the
    ``sarif_dict`` value, next to the snake_case and PascalCase forms of the
    rule name and the rule's short description text.
    """
    rule_name = rule["name"]
    template_values = {
        "snake_case_name": _kebab_case_to_snake_case(rule_name),
        "pascal_case_name": _kebab_case_to_pascal_case(rule_name),
        "sarif_dict": rule,
        "short_description": rule["short_description"]["text"],
    }
    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(**template_values)
| |
| |
def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Render the C++ entry for one rule.

    The C++ identifier is the PascalCase rule name prefixed with ``k``; the
    short description text fills the preceding brief comment.
    """
    cpp_name = "k" + _kebab_case_to_pascal_case(rule["name"])
    return _CPP_RULE_TEMPLATE.format(
        name=cpp_name,
        short_description=rule["short_description"]["text"],
    )
| |
| |
def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Write ``_rules.py`` into ``out_py_dir`` and run the linter on it.

    Args:
        rules: Rules to render, one class and one collection field per rule.
        out_py_dir: Directory the generated ``_rules.py`` is written to.
        template_dir: Directory containing the ``rules.py.in`` template.
    """
    output_name = "_rules.py"

    # Render every rule class first, then every collection field.
    rule_classes = "\n".join(_format_rule_for_python_class(rule) for rule in rules)
    rule_fields = "\n".join(_format_rule_for_python_field(rule) for rule in rules)
    # Fields are indented by four spaces before substitution into the template.
    indented_rule_fields = textwrap.indent(rule_fields, " " * 4)

    file_manager = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(
        output_name,
        "rules.py.in",
        lambda: {
            "generated_comment": _RULES_GENERATED_COMMENT,
            "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT,
            "rule_classes": rule_classes,
            "rules": indented_rule_fields,
        },
    )
    _lint_file(os.path.join(out_py_dir, output_name))
| |
| |
def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Write ``rules.h`` into ``out_cpp_dir`` and run the linter on it.

    Args:
        rules: Rules to render, one C++ entry and one quoted snake_case name
            per rule.
        out_cpp_dir: Directory the generated ``rules.h`` is written to.
        template_dir: Directory containing the ``rules.h.in`` template.
    """
    output_name = "rules.h"

    cpp_entries = "\n".join(_format_rule_for_cpp(rule) for rule in rules)
    quoted_names = "\n".join(
        '"' + _kebab_case_to_snake_case(rule["name"]) + '",' for rule in rules
    )

    # Prefix every line of the header comment with " * ". The always-true
    # predicate makes textwrap.indent prefix empty lines as well.
    header_comment = textwrap.indent(
        _RULES_GENERATED_COMMENT,
        " * ",
        predicate=lambda _line: True,
    )
    indented_entries = textwrap.indent(cpp_entries, " " * 2)
    indented_names = textwrap.indent(quoted_names, " " * 4)

    file_manager = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(
        output_name,
        "rules.h.in",
        lambda: {
            "generated_comment": header_comment,
            "rules": indented_entries,
            "py_rule_names": indented_names,
        },
    )
    _lint_file(os.path.join(out_cpp_dir, output_name))
| |
| |
def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing."""
    # TODO: Add doc generation in a follow-up PR.
    return None
| |
| |
def _lint_file(file_path: str) -> None:
    """Run ``lintrunner -a`` on ``file_path`` and wait for it to finish.

    The exit status of the linter is not inspected.
    """
    subprocess.run(["lintrunner", "-a", file_path], check=False)
| |
| |
def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Generate the Python, C++ and docs outputs from a rules YAML file.

    Args:
        rules_path: Path to the YAML file defining the diagnostic rules.
        out_py_dir: Output directory for the generated Python file.
        out_cpp_dir: Output directory for the generated C++ header.
        out_docs_dir: Output directory for docs; doc generation is currently
            a no-op.
    """
    # Decode explicitly as UTF-8 so that parsing the rules does not depend on
    # the platform's locale default encoding.
    with open(rules_path, encoding="utf-8") as f:
        # NOTE(review): yaml.load is only as safe as its Loader; confirm that
        # torchgen's YamlLoader derives from a safe loader before pointing
        # this at files that are not part of the repository.
        rules = yaml.load(f, Loader=YamlLoader)

    # Templates are looked up in the "templates" directory next to this script.
    template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")

    gen_diagnostics_python(
        rules,
        out_py_dir,
        template_dir,
    )

    gen_diagnostics_cpp(
        rules,
        out_cpp_dir,
        template_dir,
    )

    gen_diagnostics_docs(rules, out_docs_dir, template_dir)
| |
| |
def main() -> None:
    """Parse the four positional command-line paths and run the generator."""
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")

    # (dest, metavar, help) for each required positional argument, in order.
    positional_specs = (
        ("rules_path", "RULES", "path to rules.yaml"),
        ("out_py_dir", "OUT_PY", "path to output directory for Python"),
        ("out_cpp_dir", "OUT_CPP", "path to output directory for C++"),
        ("out_docs_dir", "OUT_DOCS", "path to output directory for docs"),
    )
    for dest, metavar, help_text in positional_specs:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    options = parser.parse_args()
    gen_diagnostics(
        rules_path=options.rules_path,
        out_py_dir=options.out_py_dir,
        out_cpp_dir=options.out_cpp_dir,
        out_docs_dir=options.out_docs_dir,
    )


if __name__ == "__main__":
    main()