# blob: 22e29e4e7f7c32d05cc951cc96c71531300a4570 [file] [log] [blame]
import re
from copy import deepcopy
from function_wrapper import TYPE_FORMAL_GENERIC
import common_with_cwrap
# Scalar type names grouped by category. A category key ('floating_point',
# 'integral', or 'all') may appear where a concrete type name is expected in
# an option's types; process_types_and_backends expands it.
type_map = dict(
    floating_point=['Float', 'Double', 'Half'],
    integral=['Byte', 'Char', 'Short', 'Int', 'Long'],
)

# Every concrete scalar type: the floating-point ones first, then integral.
all_types = type_map['floating_point'] + type_map['integral']
# The 'all' alias refers to the very same list object as all_types.
type_map['all'] = all_types

# Every backend an option may target. Options with 'aten_sparse' set use all
# of them; otherwise an option without 'backends' gets default_backends.
all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA']
default_backends = ['CPU', 'CUDA']
def process_types_and_backends(option):
    """Compute the sorted (backend, scalar type) pairs an option is built for.

    If the option has no explicit 'backend_type_pairs', the pairs are the
    cross product of option['backends'] and option['types'].
    'backends' defaults to default_backends and is replaced by all_backends
    when 'aten_sparse' is set. 'types' defaults to all_types.

    Type aliases that are keys of type_map ('floating_point', 'integral',
    'all') are expanded into their concrete types. Afterwards:

    * ('CUDA', 'Half') is removed if any argument has type 'THSTensor*';
    * ('CPU', 'Half') is removed unless the option sets 'cpu_half'.

    The result is written back to option['backend_type_pairs'] as a sorted
    list of (backend, type) tuples.

    Raises:
        ValueError: if a pair names a backend that is not in all_backends, or
            a type that is neither in all_types nor an alias in type_map.
    """
    if 'backend_type_pairs' in option:
        requested = option['backend_type_pairs']
    else:
        backends = option.get('backends', default_backends)
        if option.get('aten_sparse', False):
            # Sparse-enabled options are generated for every backend,
            # overriding any explicit 'backends' list.
            backends = all_backends
        types = option.get('types', all_types)
        requested = [(b, t) for b in backends for t in types]

    def expand(pair):
        # Turn one (backend, type-or-alias) pair into concrete pairs.
        # Explicit raises (not asserts) so validation survives `python -O`.
        backend, scalar_type = pair
        if backend not in all_backends:
            raise ValueError('unknown backend {!r} in backend/type pair {!r}'
                             .format(backend, pair))
        if scalar_type in type_map:
            return [(backend, t) for t in type_map[scalar_type]]
        if scalar_type not in all_types:
            raise ValueError('unknown scalar type {!r} in backend/type pair '
                             '{!r}'.format(scalar_type, pair))
        return [(backend, scalar_type)]

    pairs = set(p for pair in requested for p in expand(pair))

    # CUDA Half is disabled for options taking a sparse tensor argument.
    if any(arg['type'] == 'THSTensor*'
           for arg in option.get('arguments', [])):
        pairs.discard(('CUDA', 'Half'))

    # CPU Half is removed unless the option explicitly enables it.
    if not option.get('cpu_half', False):
        pairs.discard(('CPU', 'Half'))

    # Sorted for deterministic, easy-to-read output.
    option['backend_type_pairs'] = sorted(pairs)
def exclude(declaration):
    """Return True for a declaration that run() should filter out.

    A declaration is excluded when it has an 'only_register' key, or when
    its python_name is 'ndimension'.
    """
    if 'only_register' in declaration:
        return True
    return declaration.get('python_name') == 'ndimension'
def add_variants(option):
    """Give the option the default variant list ['method'] if it has none."""
    if 'variants' not in option:
        option['variants'] = ['method']
# Options with 'output' arguments are split into an '_out' variant that takes
# the result tensors as arguments and an allocating variant that creates them.
def handle_outputs_taken_as_arguments(options):
    """Split every option that has output arguments into two options.

    First, every argument whose type is 'THIntegerTensor*' or 'THTensor*'
    and whose default is 'NULL' or 'nullptr' is marked is_nullable = True.

    For an option with at least one argument carrying an 'output' key, two
    options are emitted, in this order:

    * the original option, which receives the results as arguments. If its
      variants include 'function', the 'method' variant is removed and
      '_out' is appended to its api_name;
    * a deep copy taken before that change, in which every output argument
      is marked allocate = True.

    Options without output arguments are passed through unchanged. The
    given option dicts are mutated in place.

    Returns the new list of options.
    """
    nullable_types = {'THIntegerTensor*', 'THTensor*'}
    null_defaults = {'NULL', 'nullptr'}

    result = []
    for option in options:
        arguments = option['arguments']
        for argument in arguments:
            if (argument['type'] in nullable_types and
                    argument.get('default', '') in null_defaults):
                argument['is_nullable'] = True

        if not any('output' in argument for argument in arguments):
            result.append(option)
            continue

        # Copied before the rename below, so the allocating variant keeps
        # the original api_name and variants (and the is_nullable marks).
        allocating = deepcopy(option)
        for argument in allocating['arguments']:
            if 'output' in argument:
                argument['allocate'] = True

        variants = option['variants']
        if 'function' in variants:
            if 'method' in variants:
                variants.remove('method')
            option['api_name'] += '_out'

        result.extend([option, allocating])
    return result
def sanitize_return(option):
    """Rewrite option['return'] in place into a dict describing the result.

    * 'argument N' or 'argument N,M,...' becomes
      {'kind': 'arguments', 'arguments': [N, M, ...]} (integer positions);
    * 'self' becomes {'kind': 'arguments', 'arguments': [i]}, where i is the
      position of the first argument named 'self' (an empty list if no
      argument has that name);
    * any other value v becomes {'kind': 'type', 'type': v}.

    NOTE: re.match only anchors at the start, so trailing text after the
    numbers is ignored (e.g. 'argument 0 foo' yields [0]).
    """
    ret = option['return']
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    m = re.match(r'argument (\d+(,\d+)*)', ret)
    if m is not None:
        arguments = [int(x) for x in m.group(1).split(',')]
        option['return'] = {'kind': 'arguments', 'arguments': arguments}
    elif ret == 'self':
        positions = []
        for i, arg in enumerate(option['arguments']):
            if arg['name'] == 'self':
                positions.append(i)
                break
        option['return'] = {'kind': 'arguments', 'arguments': positions}
    else:
        option['return'] = {'kind': 'type', 'type': ret}
def set_mode(option):
    """Default option['mode'] to 'TH' when the option does not set one."""
    if 'mode' not in option:
        option['mode'] = 'TH'
# To enable 0-dim support in TH operations, we find every place where
# replacing a single scalar argument with a Tensor still yields a valid
# function (i.e. another option of the same declaration). We then mark the
# tensor variant with the key zero_dim_dispatch_when_scalar: name, where
# 'name' is the name of the argument that should be a scalar. During
# dispatch, if that argument is marked internally as holding a scalar, the
# method dispatches to the scalar function instead.
def discover_zero_dim_tensor_operations(declaration):
    """Tag tensor options of a declaration that have a scalar ('real') twin.

    For every option and every checked argument of type 'real', build the
    signature obtained by replacing that argument with 'Tensor &'. If an
    option of the declaration has exactly that signature, set its
    'zero_dim_dispatch_when_scalar' entry to the name of its argument at the
    same position.

    Arguments with 'ignore_check' set are left out of signatures. Positions
    are therefore counted among checked arguments only, so that they line
    up between the two options.
    """
    def exclude(arg):
        return arg.get('ignore_check')

    def signature(option, i=None, value=None):
        # '#'-joined generic types of the checked arguments; when i is given,
        # the argument at raw index i is replaced by `value`.
        elements = [TYPE_FORMAL_GENERIC.get(arg['type'], arg['type'])
                    if i is None or j != i else value
                    for j, arg in enumerate(option['arguments'])
                    if not exclude(arg)]
        return '#'.join(elements)

    signature_to_option = {signature(option): option
                           for option in declaration['options']}

    for option in declaration['options']:
        # Index among checked (non-excluded) arguments. This, not the raw
        # index i, is the index that lines up with `names` below.
        checked_index = -1
        for i, arg in enumerate(option['arguments']):
            if exclude(arg):
                # Excluded args are absent from the signature, so replacing
                # one would just reproduce this option's own signature.
                continue
            checked_index += 1
            if arg['type'] != 'real':
                continue
            signature_of_tensor_version = signature(option, i, 'Tensor &')
            if signature_of_tensor_version not in signature_to_option:
                continue
            tensor_version = signature_to_option[signature_of_tensor_version]
            names = [a['name'] for a in tensor_version['arguments']
                     if not exclude(a)]
            tensor_version['zero_dim_dispatch_when_scalar'] = \
                names[checked_index]
def discover_sparse_tensor_operations(declaration):
    """Tag dense options of 'add' / 'add_' that have a sparse twin.

    Only declarations named 'add' or 'add_' are examined. For every option
    with 'aten_dense_sparse' set and every checked argument of type
    'THSTensor*', build the signature obtained by replacing that argument
    with 'Tensor &'. If an option of the declaration has exactly that
    signature, set its 'when_sparse_dispatch' entry to the name of its
    argument at the same position.

    Arguments with 'ignore_check' set are left out of signatures. Positions
    are counted among checked arguments only. This gives the same result as
    the previous "raw index minus excluded count" arithmetic when all
    excluded arguments precede the sparse one, and stays correct otherwise.
    """
    def exclude(arg):
        return arg.get('ignore_check')

    def signature(option, i=None, value=None):
        # '#'-joined generic types of the checked arguments; when i is given,
        # the argument at raw index i is replaced by `value`.
        elements = [TYPE_FORMAL_GENERIC.get(arg['type'], arg['type'])
                    if i is None or j != i else value
                    for j, arg in enumerate(option['arguments'])
                    if not exclude(arg)]
        return '#'.join(elements)

    if declaration['name'] not in ('add', 'add_'):
        return

    signature_to_option = {signature(option): option
                           for option in declaration['options']}

    for option in declaration['options']:
        if not option.get('aten_dense_sparse', False):
            continue
        # Index among checked (non-excluded) arguments; it lines up with
        # `names` below regardless of where excluded arguments sit.
        checked_index = -1
        for i, arg in enumerate(option['arguments']):
            if exclude(arg):
                # Excluded args are absent from the signature, so replacing
                # one would just reproduce this option's own signature.
                continue
            checked_index += 1
            if arg['type'] != 'THSTensor*':
                continue
            signature_of_tensor_version = signature(option, i, 'Tensor &')
            if signature_of_tensor_version not in signature_to_option:
                continue
            tensor_version = signature_to_option[signature_of_tensor_version]
            names = [a['name'] for a in tensor_version['arguments']
                     if not exclude(a)]
            tensor_version['when_sparse_dispatch'] = names[checked_index]
def run(declarations):
    """Filter and preprocess declarations; return the ones that were kept.

    Declarations for which exclude() is true are dropped. Each remaining
    declaration then goes through these steps in order:

    1. common_with_cwrap.set_declaration_defaults is applied.
    2. Its options are deep-copied and de-duplicated
       (common_with_cwrap.filter_unique_options).
    3. It is passed to common_with_cwrap.sort_by_number_of_options.
    4. Zero-dim and sparse dispatch pairs are discovered.
    5. Every option is normalised (mode, return, backend/type pairs,
       variants).
    6. Options with output arguments are split into '_out' and allocating
       variants.
    """
    kept = [declaration for declaration in declarations
            if not exclude(declaration)]
    per_option_steps = (set_mode, sanitize_return,
                        process_types_and_backends, add_variants)

    for declaration in kept:
        common_with_cwrap.set_declaration_defaults(declaration)

        copied_options = list(map(deepcopy, declaration['options']))
        declaration['options'] = copied_options
        declaration['options'] = common_with_cwrap.filter_unique_options(
            copied_options,
            allow_kwarg=False,
            type_to_signature=TYPE_FORMAL_GENERIC,
            remove_self=True)
        common_with_cwrap.sort_by_number_of_options(declaration)

        discover_zero_dim_tensor_operations(declaration)
        discover_sparse_tensor_operations(declaration)

        for option in declaration['options']:
            for step in per_option_steps:
                step(option)

        declaration['options'] = handle_outputs_taken_as_arguments(
            declaration['options'])
    return kept