blob: 115bb22b66368cda2c441b61a3850a4cf91c71d2 [file] [log] [blame]
# This code should be shared between cwrap and ATen preprocessing.
# For now it lives in one place, but it is currently a copy of the cwrap code.
from copy import deepcopy
from itertools import product
def parse_arguments(args):
    """Normalize a list of argument declarations into a list of dicts.

    Each item of ``args`` is one of:

    * a dict -- kept as the same object (modified in place); if it has an
      ``'arg'`` key holding ``"<type> <name>"``, that key is replaced by
      ``'type'`` and ``'name'`` entries;
    * a string ``"<type> <name>"`` -- turned into a new
      ``{'type': <type>, 'name': <name>}`` dict.

    The text is split at the first space, so ``'int'`` yields an empty name.
    Any other item type trips an ``assert``.
    """
    normalized = []
    for entry in args:
        if isinstance(entry, dict):
            if 'arg' in entry:
                # Split "<type> <name>" before dropping the combined key.
                entry['type'], _, entry['name'] = entry['arg'].partition(' ')
                del entry['arg']
            normalized.append(entry)
        elif isinstance(entry, str):
            arg_type, _, arg_name = entry.partition(' ')
            normalized.append({'type': arg_type, 'name': arg_name})
        else:
            assert False
    return normalized
def set_declaration_defaults(declaration):
    """Fill in default fields of a declaration and normalize its options.

    Mutates ``declaration`` in place:

    * ``arguments`` defaults to ``[]`` and ``return`` to ``'void'``;
    * ``cname`` defaults to ``name``;
    * ``backends`` defaults to ``['CPU', 'CUDA']``;
    * ``api_name`` defaults to ``python_name`` when present, else ``name``;
    * without an ``options`` key, a single option is built from
      ``arguments`` (which is then removed from the declaration);
    * every option's ``arguments`` is passed through ``parse_arguments``;
    * every declaration key except ``options`` is copied (by reference, not
      deep-copied) into each option that does not already define it.
    """
    declaration.setdefault('arguments', [])
    declaration.setdefault('return', 'void')
    # 'name' is only read when it is actually needed as a fallback.
    if 'cname' not in declaration:
        declaration['cname'] = declaration['name']
    declaration.setdefault('backends', ['CPU', 'CUDA'])
    if 'api_name' not in declaration:
        source_key = 'python_name' if 'python_name' in declaration else 'name'
        declaration['api_name'] = declaration[source_key]

    # Simulate multiple dispatch even for single-signature declarations, so
    # later stages can always iterate over 'options'.
    if 'options' not in declaration:
        declaration['options'] = [{'arguments': declaration.pop('arguments')}]

    # First pass: normalize every option's argument list (entries may be
    # plain "<type> <name>" strings).
    for option in declaration['options']:
        option['arguments'] = parse_arguments(option['arguments'])

    # Second pass: options inherit declaration-level fields they lack.
    # TODO(zach): why does cwrap not propagate 'name'? I need it
    # propagated for ATen
    inherited = [(key, value) for key, value in declaration.items()
                 if key != 'options']
    for option in declaration['options']:
        for key, value in inherited:
            option.setdefault(key, value)
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    """Return the options whose call signatures are distinguishable.

    Options are visited in order. For each one a signature string is built
    from its arguments, skipping those with a truthy ``'ignore_check'`` or
    type ``'CONSTANT'``; positional types are remapped through
    ``type_to_signature``. If that signature was already produced by an
    earlier option and ``allow_kwarg`` is true, the trailing 1, 2, ...
    arguments are tried as keyword-only (their names and raw types become
    part of the signature) until an unseen signature is found. The first
    unseen signature wins: the option is kept and, if keyword-only
    arguments were needed, those trailing argument dicts get
    ``'kwarg_only': True`` (mutated in place). Options for which no unseen
    signature exists are dropped.

    Args:
        options: list of option dicts, each with an ``'arguments'`` list.
        allow_kwarg: if false, only the all-positional signature is tried.
        type_to_signature: mapping from argument type to the type used in
            the positional signature. A falsy value (``None``, ``{}`` or an
            empty list) means no remapping.
        remove_self: if true, an argument named ``'self'`` is left out of
            the positional part of the signature.

    Returns:
        The kept options, in their original order.
    """
    # A list has no .get(); an empty list (or None) passed as a placeholder
    # used to crash the signature builder, so normalize to an empty dict.
    if not type_to_signature:
        type_to_signature = {}

    def exclude_arg(arg):
        return arg.get('ignore_check') or arg['type'] == 'CONSTANT'

    def exclude_arg_with_self_check(arg):
        return exclude_arg(arg) or (remove_self and arg['name'] == 'self')

    def signature(option, kwarg_only_count):
        # The last ``kwarg_only_count`` arguments are treated as keyword-only.
        arguments = option['arguments']
        split = len(arguments) - kwarg_only_count
        arg_signature = '#'.join(
            type_to_signature.get(arg['type'], arg['type'])
            for arg in arguments[:split]
            if not exclude_arg_with_self_check(arg))
        if kwarg_only_count == 0:
            return arg_signature
        kwarg_only_signature = '#'.join(
            arg['name'] + '#' + arg['type']
            for arg in arguments[split:]
            if not exclude_arg(arg))
        return arg_signature + "#-#" + kwarg_only_signature

    seen_signatures = set()
    unique = []
    for option in options:
        # When allow_kwarg is false, only the all-positional signature
        # (num_kwarg_only == 0) is checked.
        limit = len(option['arguments']) if allow_kwarg else 0
        for num_kwarg_only in range(limit + 1):
            sig = signature(option, num_kwarg_only)
            if sig in seen_signatures:
                continue
            if num_kwarg_only > 0:
                for arg in option['arguments'][-num_kwarg_only:]:
                    arg['kwarg_only'] = True
            unique.append(option)
            seen_signatures.add(sig)
            break
    return unique
def enumerate_options_due_to_default(declaration,
                                     allow_kwarg=True, type_to_signature=None,
                                     remove_self=True):
    """Expand each option into one variant per subset of its defaulted arguments.

    For every option in ``declaration['options']``, each argument that has a
    ``'default'`` key may either be kept or replaced by its default. All
    ``2**k`` combinations (``k`` = number of defaulted arguments) are built
    on deep copies of the option. In a variant where an argument is dropped,
    its original type is saved in ``'declared_type'``, its ``'type'``
    becomes ``'CONSTANT'`` and it is marked ``'ignore_check'``. Each variant
    records ``'has_full_argument_list'`` (true only when every defaulted
    argument is kept). The variants are de-duplicated with
    ``filter_unique_options`` and written back to ``declaration['options']``.

    Args:
        declaration: declaration dict with an ``'options'`` list; modified
            in place.
        allow_kwarg: forwarded to ``filter_unique_options``.
        type_to_signature: mapping from argument type to signature type,
            forwarded to ``filter_unique_options``. ``None`` (the default)
            means no remapping.
        remove_self: forwarded to ``filter_unique_options``.
    """
    # The previous default was a list, which has no .get() and crashed the
    # signature builder; any falsy value now means "no type remapping".
    type_to_signature = type_to_signature or {}

    # TODO(zach): in cwrap this is shared among all declarations
    # but seems to assume that all declarations will have the same
    # (original note is truncated here).
    # NOTE(review): a helper recognising nullable tensor arguments (type
    # 'THTensor*' with default 'nullptr') used to be defined here but was
    # never called; every defaulted argument is enumerated below.
    new_options = []
    for option in declaration['options']:
        optional_args = [i for i, arg in enumerate(option['arguments'])
                         if 'default' in arg]
        for permutation in product((True, False), repeat=len(optional_args)):
            option_copy = deepcopy(option)
            option_copy['has_full_argument_list'] = all(permutation)
            for i, keep in zip(optional_args, permutation):
                arg = option_copy['arguments'][i]
                # PyYAML interprets NULL as None...
                if arg['default'] is None:
                    arg['default'] = 'NULL'
                if not keep:
                    arg['declared_type'] = arg['type']
                    arg['type'] = 'CONSTANT'
                    arg['ignore_check'] = True
            new_options.append(option_copy)
    declaration['options'] = filter_unique_options(
        new_options, allow_kwarg, type_to_signature, remove_self)
def sort_by_number_of_options(declaration, reverse=True):
    """Sort ``declaration['options']`` in place by number of checked arguments.

    An argument counts as checked unless it has a truthy ``'ignore_check'``.
    With the default ``reverse=True``, options with the most checked
    arguments come first. ``list.sort`` is stable, so ties keep their
    relative order.
    """
    def count_checked(option):
        return len([arg for arg in option['arguments']
                    if not arg.get('ignore_check', False)])

    declaration['options'].sort(key=count_checked, reverse=reverse)
class Function(object):
    """A function parsed from a header: its name and ordered argument list."""

    def __init__(self, name):
        self.name = name
        # Filled in order via add_argument().
        self.arguments = []

    def add_argument(self, arg):
        """Append ``arg``; it must be an ``Argument`` instance."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
        rendered_args = ', '.join(argument.__repr__()
                                  for argument in self.arguments)
        return self.name + '(' + rendered_args + ')'
class Argument(object):
    """A single typed, named parameter of a parsed ``Function``."""

    def __init__(self, _type, name, is_optional):
        self.type = _type
        self.name = name
        self.is_optional = is_optional

    def __repr__(self):
        # Rendered as "<type> <name>".
        return ' '.join((self.type, self.name))
def parse_header(path):
    """Parse a THNN header file into a list of ``Function`` objects.

    The parser is purely line-based and relies on the header's layout
    (inferred from the logic below -- confirm against the header parsed):

    * a declaration starts with an item beginning ``TH_API void THNN_``
      whose remainder is the function name, either parenthesized as in
      ``TH_API void THNN_(Abs_updateOutput)(`` or bare as in
      ``TH_API void THNN_Abs_updateOutput(``;
    * every other comma-separated item is a parameter of the most recent
      function, written as ``<type> <name>`` or ``<type> *<name>``;
    * a ``//`` comment containing ``[OPTIONAL]`` marks every parameter on
      that line as optional.

    Empty lines and lines starting with ``#`` are skipped, and ``//``
    comments are removed before parsing.

    Args:
        path: filesystem path of the header to read.

    Returns:
        list of ``Function``, in header order, each holding its
        ``Argument`` list.
    """
    prefix = 'TH_API void THNN_'
    with open(path, 'r') as f:
        lines = f.read().split('\n')

    generic_functions = []
    for line in lines:
        # Skip empty lines and preprocessor directives.
        if not line or line.startswith('#'):
            continue
        code, _, comment = line.partition('//')
        comment = comment.strip()
        # Strip trailing ')' / ';' characters, then trailing commas.
        code = code.strip().rstrip(');').rstrip(',')
        for item in code.split(','):
            item = item.strip()
            if not item:
                continue
            if item.startswith(prefix):
                # Slice the prefix off. str.lstrip(prefix) would strip any
                # leading *characters* that occur in the prefix, turning
                # e.g. 'THNN_Abs_updateOutput(' into 'bs_updateOutput'.
                fn_name = item[len(prefix):]
                if fn_name.startswith('(') and fn_name[-2:-1] == ')':
                    # '(Name)(' -> 'Name'
                    fn_name = fn_name[1:-2]
                else:
                    # 'Name(' -> 'Name'
                    fn_name = fn_name[:-1]
                generic_functions.append(Function(fn_name))
            else:
                # Parameter of the most recent function; assumes exactly
                # two whitespace-separated tokens.
                t, name = item.split()
                if '*' in name:
                    # Move a leading '*' from the name onto the type.
                    t = t + '*'
                    name = name[1:]
                generic_functions[-1].add_argument(
                    Argument(t, name, '[OPTIONAL]' in comment))
    return generic_functions