# blob: a67dbb9091193e6f07cbdd9e5f45d266d5cd97f4 [file] [log] [blame]
import pathlib
import argparse
import os
import yaml
from typing import List, Dict, Union, Tuple, Sequence
from tools.codegen.gen import FileManager, get_grouped_native_functions, LineLoader, parse_native_yaml
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
NativeFunction, NativeFunctionsGroup, OperatorName,
ExternalBackendMetadata, assert_never)
from tools.codegen.selective_build.selector import SelectiveBuilder
from tools.codegen.utils import Target, concatMap
import tools.codegen.dest as dest
def parse_backend_yaml(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
    """Parse an external backend's yaml file and attach its metadata to the native functions.

    The yaml file must be a mapping with the keys ``cpp_namespace`` and ``backend``. It may
    optionally contain ``supported`` and ``autograd`` lists of operator names. Any other
    top-level key is rejected.

    Args:
        backend_yaml_path: path to the backend yaml file.
        grouped_native_functions: the (possibly grouped) native functions parsed from
            native_functions.yaml.

    Returns:
        A ``(cpp_namespace, external_functions)`` pair. ``external_functions`` has one entry
        per element of ``grouped_native_functions``, in the same order. A standalone
        NativeFunction that the backend does not list gets ``None`` metadata.

    Raises:
        AssertionError: if the yaml is not a mapping, if ``supported``/``autograd`` are not
            lists, if unexpected top-level keys are present, or if an operator name is not
            found among ``grouped_native_functions``.
        KeyError: if ``cpp_namespace`` or ``backend`` is missing.
    """
    with open(backend_yaml_path, 'r') as yaml_file:
        yaml_values = yaml.load(yaml_file, Loader=LineLoader)
    if not isinstance(yaml_values, dict):
        raise AssertionError(f'expected {backend_yaml_path} to contain a mapping, but got: {yaml_values}')

    cpp_namespace = yaml_values.pop('cpp_namespace')
    backend = yaml_values.pop('backend')

    supported = yaml_values.pop('supported', [])
    if not isinstance(supported, list):
        raise AssertionError(f'expected "supported" to be a list, but got: {supported}')

    supported_autograd = yaml_values.pop('autograd', [])
    # Validate the autograd list itself (previously this re-checked "supported").
    if not isinstance(supported_autograd, list):
        raise AssertionError(f'expected "autograd" to be a list, but got: {supported_autograd}')

    # NOTE(review): LineLoader presumably annotates mappings with a '__line__' key for error
    # reporting -- confirm in tools.codegen.gen. Drop it if present so it is not reported as
    # an unexpected key; this is a no-op otherwise.
    yaml_values.pop('__line__', None)

    # Every recognized key has been popped above, so anything left over is unexpected.
    if yaml_values:
        unexpected = ", ".join(str(key) for key in yaml_values)
        raise AssertionError(f'{backend_yaml_path} contains unexpected keys: {unexpected}')

    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
    # "autograd" entries are processed after "supported", so an operator listed in both
    # sections ends up with its autograd metadata.
    for op_list, is_autograd in ((supported, False), (supported_autograd, True)):
        for op in op_list:
            op_name = OperatorName.parse(op)
            m = ExternalBackendMetadata(op_name, backend, is_autograd=is_autograd)
            metadata[m.operator] = m

    # Every operator name known to native_functions.yaml, including each member of a group.
    native_op_names = {
        fn.func.name
        for fn in concatMap(
            lambda g: [g] if isinstance(g, NativeFunction) else list(g.functions()),
            grouped_native_functions)
    }

    def native_to_external(
            g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
        # Wrap a single function or a whole group with the backend metadata collected above.
        if isinstance(g, NativeFunction):
            return ExternalBackendFunction(g, metadata.get(g.func.name, None))
        elif isinstance(g, NativeFunctionsGroup):
            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
        else:
            assert_never(g)

    # Reject operators the backend names but that do not exist in native_functions.yaml.
    for op_name in metadata:
        if op_name not in native_op_names:
            raise AssertionError(f"Found an invalid operator name: {op_name}")

    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
def main() -> None:
    """Generate the external-backend stub files from a backend yaml file.

    Command-line options:
        -s/--source_yaml: backend yaml file, parsed by ``parse_backend_yaml``.
        -o/--output_dir: directory handed to ``FileManager`` as its install dir.
        --dry_run: forwarded to ``FileManager(dry_run=...)``. Accepts true/false-style
            values; a bare ``--dry_run`` means True.

    Writes ``aten_xla_type.h``, ``aten_xla_type_default.h`` and
    ``aten_xla_type_default.cpp`` using the templates under ``aten/src/ATen/templates``.
    """
    def parse_bool(value: str) -> bool:
        # ``type=bool`` treats every non-empty string (including "False") as True,
        # so spell out the accepted spellings explicitly.
        normalized = value.strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, but got: {value!r}')

    parser = argparse.ArgumentParser(description='Generate backend stub files')
    parser.add_argument(
        '-s',
        '--source_yaml',
        required=True,
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '-o', '--output_dir', required=True, help='output directory')
    # nargs='?' + const=True: a bare "--dry_run" means True, while
    # "--dry_run True" / "--dry_run False" keep working as explicit values.
    parser.add_argument(
        '--dry_run', type=parse_bool, nargs='?', const=True, default=False,
        help='run the generator in dry-run mode (forwarded to FileManager)')
    options = parser.parse_args()

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py.
    # Make the path absolute *before* walking up: if __file__ is relative (e.g. just
    # "gen_backend_stubs.py"), taking .parent three times first would yield '.' (the cwd).
    pytorch_root = pathlib.Path(__file__).absolute().parent.parent.parent
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
    fm = FileManager(install_dir=options.output_dir, template_dir=template_dir, dry_run=options.dry_run)

    native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    grouped_native_functions = get_grouped_native_functions(native_yaml_path)
    cpp_namespace, external_backend_functions = parse_backend_yaml(options.source_yaml, grouped_native_functions)

    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'

    fm.write('aten_xla_type.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
    })

    fm.write('aten_xla_type_default.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_aten_fallback_declarations': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION), external_backend_functions
        )),
    })

    fm.write('aten_xla_type_default.cpp', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        # TODO: after cpu fallbacks are moved to a boxed kernel,
        # merge registrations / definitions into RegisterDispatchKey
        'dispatch_aten_fallback_definitions': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
        )),
        # Backend kernels and autograd kernels are registered separately.
        'dispatch_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION),
            [e for e in external_backend_functions if not e.is_autograd_kernel]
        )),
        'dispatch_autograd_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION),
            [e for e in external_backend_functions if e.is_autograd_kernel]
        )),
    })
if __name__ == '__main__':
    # Script entry point: run the generator only when this file is executed directly,
    # not when it is imported as a module.
    main()