# This code should be shared between cwrap and ATen preprocessing.
# For now it lives in one place, but it is currently a copy of the cwrap code.
| |
| import copy |
| from typing import Any, Dict, Iterable, List, Union |
| |
Arg = Dict[str, Any]


def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
    """Normalize argument declarations into dicts with 'type' and 'name' keys.

    Each entry of ``args`` is either:
      * a string ``"<type> <name>"``, which becomes a new dict
        ``{'type': <type>, 'name': <name>}``; or
      * a dict. If it has an ``'arg'`` key holding ``"<type> <name>"``, that
        key is replaced *in place* by ``'type'`` and ``'name'`` keys (other
        keys are kept). The dict object itself is appended to the result.

    Strings are split at the first space only, so ``name`` receives the rest
    of the string; a string without a space yields an empty ``name``.

    Returns:
        The normalized argument dicts, in input order.

    Raises:
        AssertionError: if an entry is neither a ``str`` nor a ``dict``.
    """
    new_args: List[Arg] = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>"
            t, _, name = arg.partition(' ')
            new_args.append({'type': t, 'name': name})
        elif isinstance(arg, dict):
            if 'arg' in arg:
                # Mutates the caller's dict so it carries 'type'/'name' from now on.
                arg['type'], _, arg['name'] = arg['arg'].partition(' ')
                del arg['arg']
            new_args.append(arg)
        else:
            # Name the offending entry so a malformed declaration is easy to find.
            raise AssertionError(
                'expected an argument declaration of type str or dict, '
                'got {}: {!r}'.format(type(arg).__name__, arg))
    return new_args
| |
| |
Declaration = Dict[str, Any]


def set_declaration_defaults(declaration: Declaration) -> None:
    """Fill in default fields of ``declaration`` and normalize its options, in place.

    Top-level fields:
      * ``schema_string`` defaults to ``''``, ``arguments`` to ``[]``,
        ``return`` to ``'void'``, ``cname`` to ``name`` and ``backends`` to
        ``['CPU', 'CUDA']``.
      * ``api_name`` is set to ``name``; it must not already be present.
      * ``type_wrapper_name`` is ``name``, or ``"<name>_<overload_name>"`` when
        ``overload_name`` is non-empty.
      * ``operator_name_with_overload`` is the schema string up to its first
        ``'('``. ``unqual_schema_string`` and
        ``unqual_operator_name_with_overload`` are the schema string and that
        operator name with everything up to and including the first ``'::'``
        removed (both ``''`` when the schema string is empty).

    Options:
      * If ``options`` is absent, one option is built from deep copies of
        ``arguments`` and ``schema_order_arguments``, and both keys are then
        removed from the declaration. Unlike ``arguments``,
        ``schema_order_arguments`` is not defaulted and must be present here.
      * Each option's ``arguments`` and ``schema_order_arguments`` are
        normalized with ``parse_arguments``.
      * Every other top-level key is copied into each option that does not
        already define it. The values are shared with the declaration, not
        copied.

    Raises:
        AssertionError: if ``api_name`` is already set.
        KeyError: if a required key (e.g. ``name``) is missing.
        IndexError: if a non-empty schema string has no ``'::'`` before its
            first ``'('``.
    """
    if 'schema_string' not in declaration:
        # This happens for legacy TH bindings like
        # _thnn_conv_depthwise2d_backward
        declaration['schema_string'] = ''
    declaration.setdefault('arguments', [])
    declaration.setdefault('return', 'void')
    if 'cname' not in declaration:
        declaration['cname'] = declaration['name']
    if 'backends' not in declaration:
        declaration['backends'] = ['CPU', 'CUDA']
    assert 'api_name' not in declaration
    declaration['api_name'] = declaration['name']
    # NB: keep this in sync with gen_autograd.py
    if declaration.get('overload_name'):
        declaration['type_wrapper_name'] = "{}_{}".format(
            declaration['name'], declaration['overload_name'])
    else:
        declaration['type_wrapper_name'] = declaration['name']
    # TODO: Uggggh, parsing the schema string here, really???
    schema = declaration['schema_string']
    declaration['operator_name_with_overload'] = schema.split('(')[0]
    if schema:
        # Strip only the leading qualifier: split('::')[1] would also cut the
        # string off at any later '::' inside the schema body.
        declaration['unqual_schema_string'] = schema.split('::', 1)[1]
        declaration['unqual_operator_name_with_overload'] = \
            declaration['operator_name_with_overload'].split('::', 1)[1]
    else:
        declaration['unqual_schema_string'] = ''
        declaration['unqual_operator_name_with_overload'] = ''
    # Simulate multiple dispatch, even if it's not necessary
    if 'options' not in declaration:
        declaration['options'] = [{
            'arguments': copy.deepcopy(declaration['arguments']),
            'schema_order_arguments': copy.deepcopy(declaration['schema_order_arguments']),
        }]
        del declaration['arguments']
        del declaration['schema_order_arguments']
    # Parse arguments (some of them can be strings)
    for option in declaration['options']:
        option['arguments'] = parse_arguments(option['arguments'])
        option['schema_order_arguments'] = parse_arguments(option['schema_order_arguments'])
    # Propagate defaults from declaration to options
    for option in declaration['options']:
        for k, v in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if k != 'options':
                option.setdefault(k, v)
| |
# TODO(zach): added an option to disable keyword handling for C++, which
# cannot support it.
| |
Option = Dict[str, Any]


def filter_unique_options(
    options: Iterable[Option],
    allow_kwarg: bool,
    type_to_signature: Dict[str, str],
    remove_self: bool,
) -> List[Option]:
    """Drop options whose call signature duplicates an earlier option's.

    The positional part of a signature is the '#'-joined argument types, each
    mapped through ``type_to_signature`` when it has an entry. Arguments of
    type ``'CONSTANT'`` are skipped, as is the argument named ``'self'`` when
    ``remove_self`` is true.

    When ``allow_kwarg`` is false, only the fully positional signature is
    tried. Otherwise the trailing 0, 1, ..., len(arguments) arguments are
    treated as keyword-only in turn. Keyword-only arguments are encoded as
    ``name#type`` (raw type) after a ``'#-#'`` separator, and non-CONSTANT
    ``self`` is not skipped there. The first unseen signature wins, and its
    trailing arguments are marked in place with ``'kwarg_only': True``. An
    option with no unseen signature is dropped.

    Returns:
        The retained options, in input order.
    """
    def is_constant(arg: Arg) -> bool:
        return arg['type'] == 'CONSTANT'  # type: ignore[no-any-return]

    def skip_positional(arg: Arg) -> bool:
        if is_constant(arg):
            return True
        return remove_self and arg['name'] == 'self'

    def positional_part(arguments: List[Arg]) -> str:
        return '#'.join(
            type_to_signature.get(a['type'], a['type'])
            for a in arguments
            if not skip_positional(a))

    def signature(option: Option, num_kwarg_only: int) -> str:
        arguments = option['arguments']
        if not num_kwarg_only:
            return positional_part(arguments)
        split_at = -num_kwarg_only
        keyword_part = '#'.join(
            a['name'] + '#' + a['type']
            for a in arguments[split_at:]
            if not is_constant(a))
        return positional_part(arguments[:split_at]) + "#-#" + keyword_part

    seen = set()
    kept = []
    for option in options:
        # With allow_kwarg disabled, only the all-positional form (0) is tried.
        max_kwarg_only = len(option['arguments']) if allow_kwarg else 0
        for n in range(max_kwarg_only + 1):
            sig = signature(option, n)
            if sig in seen:
                continue
            if n:
                for a in option['arguments'][-n:]:
                    a['kwarg_only'] = True
            kept.append(option)
            seen.add(sig)
            break
    return kept
| |
| |
def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
    """Sort ``declaration['options']`` in place by each option's argument count.

    With the default ``reverse=True``, options with more arguments come first.
    ``list.sort`` is stable, so options with equal counts keep their order.
    """
    declaration['options'].sort(
        key=lambda opt: len(opt['arguments']),
        reverse=reverse,
    )
| |
| |
class Function:
    """A parsed C function declaration: a name plus an ordered argument list."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Arguments in the order they were added via add_argument().
        self.arguments: List['Argument'] = []

    def add_argument(self, arg: 'Argument') -> None:
        """Append ``arg`` to the argument list (asserts it is an ``Argument``)."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self) -> str:
        # Rendered as "name(type1 name1, type2 name2, ...)".
        rendered_args = ', '.join(repr(a) for a in self.arguments)
        return ''.join((self.name, '(', rendered_args, ')'))


class Argument:
    """One function argument: a type string, a name and an optional flag."""

    def __init__(self, _type: str, name: str, is_optional: bool):
        self.type = _type
        self.name = name
        self.is_optional = is_optional

    def __repr__(self) -> str:
        # Rendered as "type name".
        return ' '.join((self.type, self.name))
| |
| |
def parse_header(path: str) -> List[Function]:
    """Parse THNN function declarations from the C header at ``path``.

    The parser is line-based. Empty lines and lines starting with ``'#'``
    (preprocessor directives) are skipped. ``//`` comments are split off and
    kept alongside the code. Each line is split on ``','`` into items after
    stripping trailing ``')'``/``';'`` and ``','`` characters.

    An item starting with one of the recognized ``<export macro> void THNN_``
    prefixes opens a new ``Function``. Its name is taken from ``(Name)(`` or
    ``Name(``. Every other item must be ``"<type> <name>"`` and becomes an
    ``Argument`` of the most recently opened function. A ``'*'`` in the name
    is moved onto the type. The argument is optional when its line's comment
    contains ``[OPTIONAL]``.

    Returns:
        The parsed functions, in file order.

    Raises:
        IndexError: if an argument item appears before any function header.
        ValueError: if an argument item does not split into exactly two words.
    """
    # Export-macro prefixes that introduce a THNN function declaration.
    fn_prefixes = (
        'TH_API void THNN_',
        'TORCH_CUDA_CPP_API void THNN_',
        'TORCH_CUDA_CU_API void THNN_',
    )

    def strip_fn_name(rest: str) -> str:
        # Text after the prefix is "(Name)(" (macro form) or "Name(".
        if rest[0] == '(' and rest[-2] == ')':
            return rest[1:-2]
        return rest[:-1]

    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # Flatten the header into (code item, line comment) pairs.
    items = []
    for raw in raw_lines:
        # Remove empty lines and preprocessor directives
        if not raw or raw.startswith('#'):
            continue
        code, _, comment = raw.partition('//')
        # Remove trailing special signs
        code = code.strip().rstrip(');').rstrip(',')
        comment = comment.strip()
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                items.append((piece, comment))

    generic_functions: List[Function] = []
    for code, comment in items:
        prefix = next((p for p in fn_prefixes if code.startswith(p)), None)
        if prefix is not None:
            generic_functions.append(Function(strip_fn_name(code[len(prefix):])))
        else:
            t, name = code.split()
            if '*' in name:
                # "THTensor *input" -> type "THTensor*", name "input"
                t = t + '*'
                name = name[1:]
            generic_functions[-1].add_argument(
                Argument(t, name, '[OPTIONAL]' in comment))
    return generic_functions