import re
from code_template import CodeTemplate
# temporary things we cannot handle
EXCLUDE_PATTERN = "bernoulli.*|normal.*|exponential.*|random.*|arange.*"
# what has to be done to add an Operation ...
# 1. add virtual dispatch declaration to Type.h and default impl to Type.cpp
TYPE_METHOD_DECLARATION = CodeTemplate("""\
virtual ${return_type} ${method_prefix}${api_name}(${formals});
""")
TYPE_METHOD_DEFINITION = CodeTemplate("""\
${return_type} Type::${method_prefix}${api_name}(${formals}) {
    throw std::runtime_error(std::string("${api_name} is not implemented for type ") + toString());
}
""")
# 2. add virtual override to TypeDerived.h
TYPE_DERIVED_DECLARATION = CodeTemplate("""\
virtual ${return_type} ${method_prefix}${api_name}(${formals}) override;
""")
# 3. add override definition to TypeDerived.cpp
TYPE_DERIVED_DEFINITION = CodeTemplate("""\
${return_type} ${Type}::${method_prefix}${api_name}(${formals}) {
    ${type_definition_body}
}
""")
# 4. add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals})${const_mark};
""")
# 5. add non-virtual declaration to Tensor.cpp
TENSOR_METHOD_DEFINITION = CodeTemplate("""\
inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} {
    return type().${method_prefix}${api_name}(${method_actuals});
}
""")
# 6. add a method declaration in Functions.h
FUNCTION_DECLARATION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals});
""")
# 7. add a method definition in Functions.cpp
FUNCTION_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
    return ${inferred_type}.${api_name}(${actuals});
}
""")
ZERO_DIM_CHECK = CodeTemplate("""\
if(${check_name}.dim() == 0) {
    return ${method_prefix}${api_name}(${zero_dim_actuals});
}""")
SCALAR_EXPAND = CodeTemplate("""\
Tensor ${name}__;
if(${name}_->isScalar()) {
    ${name}__ = ${name}.expand(${other}.sizes());
    ${name}_ = static_cast<${Tensor}*>(${name}__.pImpl);
}
""")
class NYIError(Exception):
"""Indicates we don't support this declaration yet"""
def __init__(self, reason):
self.reason = reason
TYPE_FORMAL_GENERIC = {
    'THTensor*': 'Tensor &',
    'THBoolTensor*': 'Tensor &',
    'THIndexTensor*': 'Tensor &',
    'THIntegerTensor*': 'Tensor &',
    'THStorage*': 'Storage &',
    'THGenerator*': 'Generator &',
    'THSize*': 'IntList',
    'THStride*': 'IntList',
    'accreal': 'Scalar',
    'real': 'Scalar',
    'long': 'int64_t',
}
TYPE_RETURN = {
    'THTensor*': 'Tensor',
    'THIndexTensor*': 'Tensor',
    'THBoolTensor*': 'Tensor',
    'THIntegerTensor*': 'Tensor',
    'real': 'Scalar',
    'accreal': 'Scalar',
    'long': 'int64_t',
}
CHECKED_CAST = {
    'THTensor*': CodeTemplate('checked_cast<${Tensor}>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
    'THBoolTensor*': CodeTemplate('checked_cast<${Backend}ByteTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
    'THIndexTensor*': CodeTemplate('checked_cast<${Backend}LongTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
    'THIntegerTensor*': CodeTemplate('checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
    'THStorage*': CodeTemplate('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos})'),
    'THGenerator*': CodeTemplate('check_generator(&${arg_name})'),
    'THSize*': CodeTemplate('THLongStorageView::make(${arg_name},true)'),
    'THStride*': CodeTemplate('THLongStorageView::make(${arg_name},true)'),
    'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
    'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
    'TensorList': CodeTemplate('tensor_list_checked_cast<${Tensor}, Tensor, ${THTensor}>(${arg_name},"${arg_name}",${arg_pos})'),
}
CHECKED_USE = {
    'THTensor*': '{}_->tensor',
    'THIndexTensor*': '{}_->tensor',
    'THBoolTensor*': '{}_->tensor',
    'THIntegerTensor*': '{}_->tensor',
    'THStorage*': '{}_->storage',
    'THGenerator*': '{}_->generator',
    'TensorList': "{0}_.data(), {0}_.size()",
}
ALLOC_WRAP = {
    'THTensor*': 'new ${Tensor}(context)',
    'THBoolTensor*': 'new ${Backend}ByteTensor(context)',
    'THIndexTensor*': 'new ${Backend}LongTensor(context)',
    'THIntegerTensor*': 'new ${Backend}IntTensor(context)',
}
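# Illustration only (argument name 'self', backend substitution 'CPUFloatTensor',
# and position 1 are invented for this example):
#   CHECKED_CAST expands to  checked_cast<CPUFloatTensor>(self.pImpl,"self",1)
#                            (emit_body below assigns this to a local 'self_')
#   CHECKED_USE  expands to  self_->tensor   (the value handed to the TH call)
#   ALLOC_WRAP   expands to  new CPUFloatTensor(context)   (for 'allocate'd outputs)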
CONSTANT_REPLACEMENTS = [
    ('AS_REAL', '${AS_REAL}'),
    ('THPDefaultGenerator->cdata',
     'dynamic_cast<${Generator}&>(context->defaultGenerator(backend())).generator'),
    ('__storage_size.get\\(\\)',
     'THLongStorageView::make(static_cast<int64_t>(storage.size()))'),
    ('__last_dim', 'self.ndimension()-1'),
]
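# Illustration only: a CONSTANT argument named 'THPDefaultGenerator->cdata' is
# rewritten by the replacements above and then substituted against the backend
# environment; with a hypothetical ${Generator} of 'CPUGenerator' it becomes
#   dynamic_cast<CPUGenerator&>(context->defaultGenerator(backend())).generator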
class nested_dict(object):
    def __init__(self, base, parent):
        self.base, self.parent = base, parent

    def __getitem__(self, x):
        r = self.base.get(x)
        if r is not None:
            return r
        return self.parent[x]
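# Usage sketch (values invented for illustration): lookups fall through from the
# per-option dict to the shared parent environment, e.g.
#   env = nested_dict({'api_name': 'neg'}, {'Backend': 'CPU'})
#   env['api_name']  -> 'neg'   (found in the option itself)
#   env['Backend']   -> 'CPU'   (falls through to the parent)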
def is_real_argument_to_wrapper(argument):
    return not argument.get('output', False) and\
        argument['type'] != 'CONSTANT' and\
        argument['type'] != 'argument'
def is_mutable_formal_argument(argument, option):
    return argument.get('output') or option['inplace'] and argument['name'] == 'self'
def to_return_type(arg, option):
    t = arg['type']
    rt = TYPE_RETURN.get(t, t)
    if rt == 'Tensor' and not arg.get('allocate'):
        rt = rt + ' &'
        if not is_mutable_formal_argument(arg, option):
            rt = 'const ' + rt
    return rt
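# Illustration only (hand-written argument dicts, not real declaration entries):
#   to_return_type({'type': 'THTensor*', 'name': 'self'}, {'inplace': True})
#       -> 'Tensor &'   (a mutable result is returned by reference)
#   to_return_type({'type': 'THTensor*', 'name': 'result', 'allocate': True}, {'inplace': False})
#       -> 'Tensor'     (freshly allocated results are returned by value)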
def create_generic(top_env, declarations):
    def get_formals(option):
        seen = set()
        result = []

        def insert(argument):
            if argument['name'] not in seen:
                seen.add(argument['name'])
                result.append(argument)
        for argument in option['arguments']:
            if argument['type'] == 'THSTensor*':
                raise NYIError("Sparse Tensor")
            if is_real_argument_to_wrapper(argument):
                insert(argument)
        for argument in option['arguments']:
            if argument.get('output') and not argument.get('allocate', False):
                insert(argument)
        return result
    def format_formal(argument, option):
        type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
        if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
            type_str = 'const ' + type_str
        return '{} {}'.format(type_str, argument['name'])
    def format_return_type(option):
        ret = option['return']
        if ret['kind'] == 'arguments':
            argument_indices = ret['arguments']
            if len(argument_indices) == 1:
                the_arg = option['arguments'][argument_indices[0]]
                return to_return_type(the_arg, option)
            else:
                types = [to_return_type(option['arguments'][idx], option)
                         for idx in argument_indices]
                return "std::tuple<{}>".format(','.join(types))
        elif ret['kind'] == 'type':
            return TYPE_RETURN.get(ret['type'], ret['type'])
        else:
            raise Exception("format_return_type: unknown return kind '{}'".format(ret['kind']))
    def find_first_tensor(formals):
        for argument in formals:
            if argument['type'] == "THTensor*":
                return argument['name']
        return None
    def process_option(option):
        option['inplace'] = re.search(
            '(^__i|[^_]_$)', option['api_name']) is not None
        if re.match(EXCLUDE_PATTERN, option['name']):
            print("Excluding {}".format(option['name']))
            raise NYIError("NYI")
        # print(yaml.dump(option))
        formals = get_formals(option)
        option['formals_list'] = formals
        option['formals'] = [format_formal(f, option) for f in formals]
        option['actuals'] = [f['name'] for f in formals]
        option['method_formals'] = [format_formal(f, option) for f in formals
                                    if f['name'] != 'self']
        option['method_actuals'] = [
            f['name'] if f['name'] != 'self' else '*this' for f in formals]
        option['return_type'] = format_return_type(option)
        option['const_mark'] = '' if option['inplace'] else ' const'

        is_method = 'method' in option['variants']
        is_function = 'function' in option['variants']
        # method-only things are prefixed with m_ in Type so that
        # another function-only variant can exist without the name colliding
        option['method_prefix'] = 'm_' if is_method and not is_function else ''
        env = nested_dict(option, top_env)
        top_env['type_method_declarations'].append(
            TYPE_METHOD_DECLARATION.substitute(env))
        top_env['type_method_definitions'].append(
            TYPE_METHOD_DEFINITION.substitute(env))

        if is_method:
            top_env['tensor_method_declarations'].append(
                TENSOR_METHOD_DECLARATION.substitute(env))
            top_env['tensor_method_definitions'].append(
                TENSOR_METHOD_DEFINITION.substitute(env))

        if is_function:
            first_tensor = find_first_tensor(formals)
            if first_tensor is not None:
                option['inferred_type'] = '{}.type()'.format(first_tensor)
                top_env['function_declarations'].append(
                    FUNCTION_DECLARATION.substitute(env))
                top_env['function_definitions'].append(
                    FUNCTION_DEFINITION.substitute(env))
    for declaration in declarations:
        for option in declaration['options']:
            try:
                process_option(option)
            except NYIError:
                option['skip'] = True
def create_derived(backend_type_env, declarations):
    type_object_declarations = []
    type_object_definitions = []

    def requires_checked_cast(argument):
        return argument['type'] in CHECKED_CAST
    def get_argument(argument, option):
        if requires_checked_cast(argument):
            return CHECKED_USE.get(argument['type'], '{}_').format(argument['name'])
        elif argument['type'] == 'bool' and 'if_true' in argument:
            return '({}) ? "{}" : "{}"'.format(argument['name'],
                                               argument['if_true'], argument['if_false'])
        elif argument['type'] == "CONSTANT":
            if 'if_true' in argument:  # this was a bool that is actually a string...
                return '"{}"'.format(argument['name'])
            v = str(argument['name'])
            for pattern, replacement in CONSTANT_REPLACEMENTS:
                v = re.sub(pattern, replacement, v)
            return CodeTemplate(v).substitute(backend_type_env)
        # e.g. argument 0, i.e. repeat the 0th argument in this position...
        elif argument['type'] == 'argument':
            index = int(argument['name'])
            return get_argument(option['arguments'][index], option)
        else:
            return argument['name']
    def drop_argument(argument, option):
        return backend_type_env['Backend'] == 'CUDA' and (
            (option['mode'] == 'TH' and argument['type'] == 'THGenerator*') or
            argument['name'] == 'THPDefaultGenerator->cdata')

    def get_arguments(option):
        return [get_argument(argument, option)
                for argument in option['arguments'] if not drop_argument(argument, option)]
    def is_actual_return_long(ret):
        return ret['type'] == 'long' or (backend_type_env['ScalarName'] == 'Long' and
                                         (ret['type'] == 'real' or ret['type'] == 'accreal'))
    def handle_zero_dim(env, option):
        if 'zero_dim_dispatch_when_scalar' not in option:
            return []
        check_name = option['zero_dim_dispatch_when_scalar']
        zero_dim_actuals = [arg['name']
                            if arg['name'] != check_name else "Scalar({})".format(arg['name'])
                            for arg in option['formals_list']]
        return [ZERO_DIM_CHECK.substitute(env, check_name=check_name, zero_dim_actuals=zero_dim_actuals)]
    def emit_body(env, option):
        body = []
        body += handle_zero_dim(env, option)
        # arguments are potentially duplicated because of one argument
        # referencing another
        seen_names = set()
        count = 0
        is_cuda = backend_type_env['Backend'] == 'CUDA'

        # scalar_check is the heuristic condition under which the result may be a scalar:
        # if there is a THSize* argument, its dimensions are used to determine scalar-ness;
        # otherwise, it is true if all of the input tensors are scalars.
        scalar_check_is_from_size = False
        scalar_check = None
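        # Illustration only (argument names invented): scalar_check typically ends up
        # as a string like 'size.size() == 0' when a THSize* argument drives the check,
        # or 'self.dim() == 0 && other.dim() == 0' when built from the input tensors below.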
        for arg in option['arguments']:
            if is_real_argument_to_wrapper(arg):
                count += 1
            if arg['type'] == 'THSize*':
                scalar_check_is_from_size = True
                scalar_check = '{}.size() == 0'.format(arg['name'])

            # only generate checked casts the first time we see an argument
            if arg['name'] not in seen_names and requires_checked_cast(arg):
                seen_names.add(arg['name'])

                # make a new allocation of TensorImpl, then wrap a Tensor around it.
                if arg.get('allocate', False):
                    allocation = CodeTemplate(
                        ALLOC_WRAP[arg['type']]).substitute(env)
                    body.append('auto {}_ = {};'.format(
                        arg['name'], allocation))
                    body.append('auto {} = Tensor({}_,false);'.format(
                        arg['name'], arg['name']))
                # extract the TensorImpl from an existing tensor (or Storage, etc.)
                else:
                    check_cast = CHECKED_CAST[arg['type']].substitute(
                        env, arg_name=arg['name'], arg_pos=count)
                    body.append("auto {}_ = {};".format(
                        arg['name'], check_cast))
                if drop_argument(arg, option):
                    body.append("(void) {}_; //silence unused warning".format(arg['name']))
                # resize tensors for special ops that require it
                if 'resize' in arg:
                    resize = arg['resize']
                    if isinstance(resize, str):
                        body.append("{}.resize_({}.sizes());".format(
                            arg['name'], resize))
                    else:
                        dims = ['{}.size({})'.format(name, dim)
                                for name, dim in resize]
                        body.append("{}.resize_({{ {} }});".format(
                            arg['name'], ','.join(dims)))
                # special handling where we zero some outputs (CPU only).
                if arg.get('cpu_zero', False) and not is_cuda:
                    body.append("{}.zero_();".format(arg['name']))
                # handle scalars that occur on the LHS of things like a - b
                if 'broadcast' in arg and 'inplace' not in arg['broadcast']:
                    other = arg['broadcast'].split(' ')[0].split(',')[0]
                    body.append(SCALAR_EXPAND.substitute(env,
                                                         name=arg['name'],
                                                         other=other))
                # dim() == 0 of all input tensors is and'd to form
                # the test for whether the output is also a scalar
                if (not arg.get('output') and 'Tensor' in arg['type'] and
                        'TensorList' not in arg['type'] and not scalar_check_is_from_size):
                    check = '{}.dim() == 0'.format(arg['name'])
                    scalar_check = (check if scalar_check is None
                                    else scalar_check + ' && ' + check)
        option['derived_actuals'] = get_arguments(option)
        is_nn = option['mode'] == 'NN'
        if is_cuda or is_nn:
            option['derived_actuals'] = ['context->thc_state'] + option['derived_actuals']

        if is_nn:
            prefix = 'THNN_{}'.format(env['THType'])
        else:
            prefix = env['THTensor'] + '_'

        call = prefix + CodeTemplate("${cname}(${derived_actuals})").substitute(env)
        ret = option['return']

        if ret['kind'] == 'arguments':
            if 'aten_custom_call' in option:
                scalar_check = None  # all aten_custom_call bodies handle settings on their own.
                body.append(CodeTemplate(option['aten_custom_call']).substitute(env))
            else:
                body.append(call + ";")
            arguments_indices = ret['arguments']
            arguments = [option['arguments'][argi]
                         for argi in arguments_indices]
            if scalar_check is not None:
                if len(arguments) > 1:
                    body.append("bool maybe_scalar = {};".format(scalar_check))
                    scalar_check = 'maybe_scalar'
                for arg in arguments:
                    body.append("{}_->maybeScalar({});".format(arg['name'], scalar_check))
            if len(arguments_indices) == 1:
                arg = arguments[0]
                body.append("return {};".format(arg['name']))
            else:
                types = [to_return_type(arg, option) for arg in arguments]
                # TODO: check for move semantics...
                names = [arg['name'] for arg in arguments]
                body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
                    types=types, names=names))
        elif ret['kind'] == 'type':
            if ret['type'] == 'THTensor*':
                maybe_scalar = "->maybeScalar({})".format(scalar_check) \
                    if scalar_check is not None \
                    else ""
                return_tensor = "return Tensor((new ${Tensor}(context,${arg_name}))${maybe_scalar},false);"
                body.append(CodeTemplate(return_tensor).substitute(env, arg_name=call, maybe_scalar=maybe_scalar))
            else:
                # we use int64_t for long in the API, so correct it here...
                if is_actual_return_long(ret):
                    call = "static_cast<int64_t>({})".format(call)
                body.append("return {};".format(call))
        else:
            raise Exception("NYI - return handling")
        return body
    def process_option(option):
        pair = (backend_type_env['Backend'],
                backend_type_env['ScalarName'])
        if pair in option['backend_type_pairs']:
            env = nested_dict(option, backend_type_env)
            body = emit_body(env, option)
            option['type_definition_body'] = body
            type_object_declarations.append(
                TYPE_DERIVED_DECLARATION.substitute(env))
            type_object_definitions.append(
                TYPE_DERIVED_DEFINITION.substitute(env))
    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    process_option(option)
                except NYIError:
                    pass
    return type_object_declarations, type_object_definitions
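# Usage sketch (assumptions: a gen.py-style driver parses the declarations
# elsewhere and owns the output environments; the key names below mirror the
# lists appended to in create_generic):
#   top_env = {
#       'type_method_declarations': [], 'type_method_definitions': [],
#       'tensor_method_declarations': [], 'tensor_method_definitions': [],
#       'function_declarations': [], 'function_definitions': [],
#   }
#   create_generic(top_env, declarations)
#   decls, defns = create_derived(backend_type_env, declarations)
# where backend_type_env supplies the per-backend / per-scalar substitutions
# such as 'Backend', 'ScalarName', 'Tensor' and 'THTensor' used by the
# templates above.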