blob: 0d02f1ec4e8f0afb98abf98af12e452d087b0960 [file] [log] [blame]
#!/usr/bin/env python3
"""Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentation.
The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml.
Usage:
python -m tools.onnx.gen_diagnostics \
torch/onnx/_internal/diagnostics/rules.yaml \
torch/onnx/_internal/diagnostics \
torch/csrc/onnx/diagnostics/generated \
torch/docs/source
"""
import argparse
import os
import string
import subprocess
import textwrap
from typing import Any, Mapping, Sequence
import yaml
from torchgen import utils as torchgen_utils
from torchgen.yaml_utils import YamlLoader
# Banner text handed to both the Python and the C++ templates under the
# `generated_comment` key; the C++ generator prefixes every line with " * ".
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.
Diagnostic rules for PyTorch ONNX export.
"""
# Explanatory text handed to the Python template under the
# `generated_rule_class_comment` key.
_PY_RULE_CLASS_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""
_PY_RULE_CLASS_TEMPLATE = """\
class _{pascal_case_name}(infra.Rule):
\"\"\"{short_description}\"\"\"
def format_message( # type: ignore[override]
self,
{message_arguments}
) -> str:
\"\"\"Returns the formatted default message of this Rule.
Message template: {message_template}
\"\"\"
return self.message_default_template.format({message_arguments_assigned})
def format( # type: ignore[override]
self,
level: infra.Level,
{message_arguments}
) -> Tuple[infra.Rule, infra.Level, str]:
\"\"\"Returns a tuple of (Rule, Level, message) for this Rule.
Message template: {message_template}
\"\"\"
return self, level, self.format_message({message_arguments_assigned})
"""
# Template for one dataclass field of the generated rule collection, filled in
# by _format_rule_for_python_field. `sarif_dict` receives the rule mapping
# itself, so its str() form is what ends up in the rendered text.
# gen_diagnostics_python indents the rendered fields by four spaces.
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
init=False,
)
\"\"\"{short_description}\"\"\"
"""
# Template for one rule entry in the generated C++ header, filled in by
# _format_rule_for_cpp: a `@brief` comment followed by `<name>,`.
_CPP_RULE_TEMPLATE = """\
/**
* @brief {short_description}
*/
{name},
"""
# Type alias for a single rule entry as consumed by the formatters below:
# a string-keyed mapping (the rules are read from the YAML file in gen_diagnostics).
_RuleType = Mapping[str, Any]
def _kebab_case_to_snake_case(name: str) -> str:
return name.replace("-", "_")
def _kebab_case_to_pascal_case(name: str) -> str:
return "".join(word.capitalize() for word in name.split("-"))
def _format_rule_for_python_class(rule: _RuleType) -> str:
pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
short_description = rule["short_description"]["text"]
message_template = rule["message_strings"]["default"]["text"]
field_names = [
field_name
for _, field_name, _, _ in string.Formatter().parse(message_template)
if field_name is not None
]
for field_name in field_names:
assert isinstance(
field_name, str
), f"Unexpected field type {type(field_name)} from {field_name}. "
"Field name must be string.\nFull message template: {message_template}"
assert (
not field_name.isnumeric()
), f"Unexpected numeric field name {field_name}. "
"Only keyword name formatting is supported.\nFull message template: {message_template}"
message_arguments = ", ".join(field_names)
message_arguments_assigned = ", ".join(
[f"{field_name}={field_name}" for field_name in field_names]
)
return _PY_RULE_CLASS_TEMPLATE.format(
pascal_case_name=pascal_case_name,
short_description=short_description,
message_template=repr(message_template),
message_arguments=message_arguments,
message_arguments_assigned=message_arguments_assigned,
)
def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Render the rule-collection field declaration for one rule.

    Fills ``_PY_RULE_COLLECTION_FIELD_TEMPLATE`` with the snake_case and
    PascalCase forms of the rule name, the rule mapping itself (embedded via
    its string form) and the rule's short description.
    """
    kebab_name = rule["name"]
    substitutions = {
        "snake_case_name": _kebab_case_to_snake_case(kebab_name),
        "pascal_case_name": _kebab_case_to_pascal_case(kebab_name),
        "sarif_dict": rule,
        "short_description": rule["short_description"]["text"],
    }
    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(**substitutions)
def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Render the C++ entry for one rule.

    The C++ name is the PascalCase rule name prefixed with ``k``; it is
    substituted into ``_CPP_RULE_TEMPLATE`` together with the rule's short
    description.
    """
    cpp_name = "k" + _kebab_case_to_pascal_case(rule["name"])
    return _CPP_RULE_TEMPLATE.format(
        name=cpp_name,
        short_description=rule["short_description"]["text"],
    )
def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Generate ``_rules.py`` in ``out_py_dir`` and lint the result.

    Args:
        rules: Rule entries to render.
        out_py_dir: Directory passed to the file manager as ``install_dir``.
        template_dir: Directory passed to the file manager as ``template_dir``;
            the ``rules.py.in`` template is requested from it.
    """
    # Render every class first, then every field, before touching any file.
    class_sources = list(map(_format_rule_for_python_class, rules))
    field_sources = list(map(_format_rule_for_python_field, rules))

    output_name = "_rules.py"
    file_manager = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(
        output_name,
        "rules.py.in",
        lambda: dict(
            generated_comment=_RULES_GENERATED_COMMENT,
            generated_rule_class_comment=_PY_RULE_CLASS_COMMENT,
            rule_classes="\n".join(class_sources),
            # Fields are indented by four spaces before substitution.
            rules=textwrap.indent("\n".join(field_sources), "    "),
        ),
    )
    _lint_file(os.path.join(out_py_dir, output_name))
def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Generate ``rules.h`` in ``out_cpp_dir`` and lint the result.

    Args:
        rules: Rule entries to render.
        out_cpp_dir: Directory passed to the file manager as ``install_dir``.
        template_dir: Directory passed to the file manager as ``template_dir``;
            the ``rules.h.in`` template is requested from it.
    """
    cpp_entries = list(map(_format_rule_for_cpp, rules))
    # One quoted snake_case name per rule, each followed by a comma.
    quoted_names = [
        '"' + _kebab_case_to_snake_case(rule["name"]) + '",' for rule in rules
    ]

    header_name = "rules.h"
    file_manager = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(
        header_name,
        "rules.h.in",
        lambda: dict(
            generated_comment=textwrap.indent(
                _RULES_GENERATED_COMMENT,
                " * ",
                # Always-true predicate so empty lines receive the prefix too.
                predicate=lambda _line: True,
            ),
            rules=textwrap.indent("\n".join(cpp_entries), "  "),
            py_rule_names=textwrap.indent("\n".join(quoted_names), "    "),
        ),
    )
    _lint_file(os.path.join(out_cpp_dir, header_name))
def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing.

    All three arguments are accepted and ignored.
    """
    # TODO: Add doc generation in a follow-up PR.
    return None
def _lint_file(file_path: str) -> None:
    """Run ``lintrunner -a`` on ``file_path`` and wait for it to exit.

    The linter's exit status is not inspected.
    """
    lint_command = ["lintrunner", "-a", file_path]
    subprocess.Popen(lint_command).wait()
def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Load the rules file and run the Python, C++ and docs generators in order.

    Args:
        rules_path: Path of the YAML file holding the rules.
        out_py_dir: Output directory handed to ``gen_diagnostics_python``.
        out_cpp_dir: Output directory handed to ``gen_diagnostics_cpp``.
        out_docs_dir: Output directory handed to ``gen_diagnostics_docs``.
    """
    # NOTE(review): yaml.load is only as safe as its Loader; confirm that
    # torchgen's YamlLoader is a safe loader.
    with open(rules_path) as rules_file:
        rules = yaml.load(rules_file, Loader=YamlLoader)

    # Templates are looked up in a "templates" directory next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    template_dir = os.path.join(script_dir, "templates")

    targets = (
        (gen_diagnostics_python, out_py_dir),
        (gen_diagnostics_cpp, out_cpp_dir),
        (gen_diagnostics_docs, out_docs_dir),
    )
    for generate, out_dir in targets:
        generate(rules, out_dir, template_dir)
def main() -> None:
    """Parse the four positional command-line arguments and run ``gen_diagnostics``."""
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")

    # (dest, metavar, help) for each positional argument, in command-line order.
    positionals = (
        ("rules_path", "RULES", "path to rules.yaml"),
        ("out_py_dir", "OUT_PY", "path to output directory for Python"),
        ("out_cpp_dir", "OUT_CPP", "path to output directory for C++"),
        ("out_docs_dir", "OUT_DOCS", "path to output directory for docs"),
    )
    for dest, metavar, help_text in positionals:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    args = parser.parse_args()
    gen_diagnostics(
        rules_path=args.rules_path,
        out_py_dir=args.out_py_dir,
        out_cpp_dir=args.out_cpp_dir,
        out_docs_dir=args.out_docs_dir,
    )
# Run the generator only when this file is executed as a script (or via
# `python -m`), not when it is imported.
if __name__ == "__main__":
    main()