import re
from code_template import CodeTemplate
import sys
if sys.version_info[0] == 3:
    string_type = str
else:
    string_type = basestring
# temporary things we cannot handle
EXCLUDE_PATTERN = "bernoulli.*"
# what has to be done to add an Operation ...
# 1. if broadcasting or without the full list of arguments, add a non-virtual
# declaration under Type.h
TYPE_METHOD_DECLARATION_NON_VIRTUAL = CodeTemplate("""\
${return_type} ${method_prefix}${api_name}(${formals_with_defaults}) const;
""")
# 2. broadcasting functions are implemented in Type.cpp
TYPE_METHOD_DEFINITION_BROADCAST = CodeTemplate("""\
${return_type} Type::${method_prefix}${api_name}(${formals}) const {
    Tensor ${broadcast_returns};
    std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals});
    return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
}
""")
# 3. add virtual dispatch declaration to Type.h and default impl to Type.cpp
TYPE_METHOD_DECLARATION = CodeTemplate("""\
virtual ${return_type} ${method_prefix}${api_name}(${formals_with_defaults}) const;
""")
TYPE_METHOD_DEFINITION = CodeTemplate("""\
${return_type} Type::${method_prefix}${api_name}(${formals}) const {
    throw std::runtime_error(std::string("${api_name} is not implemented for type ") + toString());
}
""")
# 4. add virtual override to TypeDerived.h
TYPE_DERIVED_DECLARATION = CodeTemplate("""\
virtual ${return_type} ${method_prefix_derived}${api_name}(${formals}) const override;
""")
# 5. add override definition to TypeDerived.cpp
TYPE_DERIVED_DEFINITION = CodeTemplate("""\
${return_type} ${Type}::${method_prefix_derived}${api_name}(${formals}) const {
    ${type_definition_body}
}
""")
# 6. add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
""")
# 7. add the inline non-virtual definition to Tensor.cpp
TENSOR_METHOD_DEFINITION = CodeTemplate("""\
inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} {
    return type().${method_prefix}${api_name}(${method_actuals});
}
""")
# 8. add a function declaration in Functions.h
FUNCTION_DECLARATION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals_with_defaults});
""")
# 9. add the function definition in Functions.h
FUNCTION_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
    return ${inferred_type}.${api_name}(${actuals});
}
""")
# We need to cast to the base type because C++ may hide the base class
# implementation of ${api_name} if we have overloaded a function with
# the same name (but different signature) already
ZERO_DIM_CHECK = CodeTemplate("""\
if(${check_name}.dim() == 0) {
    return static_cast<const Type*>(this)->${method_prefix}${api_name}(${zero_dim_actuals});
}""")
SPARSE_CHECK = CodeTemplate("""\
if(${check_name}.type().isSparse()) {
    return static_cast<const Type*>(this)->${method_prefix}${api_name}(${sparse_actuals});
}""")
BUFFER_DEFINITION = CodeTemplate("""\
auto ${name}_ = new ${Tensor}(context);
auto ${name} = Tensor(${name}_, false);""")
CONDITIONAL_INITIALIZER = CodeTemplate("""\
if (${name}.defined()) {
    ${initializer}
}""")
CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})")
class NYIError(Exception):
"""Indicates we don't support this declaration yet"""
def __init__(self, reason):
self.reason = reason
TYPE_FORMAL_GENERIC = {
    'THTensor*': 'Tensor &',
    'THSTensor*': 'SparseTensor',
    'THBoolTensor*': 'Tensor &',
    'THIndexTensor*': 'Tensor &',
    'THIntegerTensor*': 'Tensor &',
    'THStorage*': 'Storage &',
    'THGenerator*': 'Generator *',
    'THSize*': 'IntList',
    'THStride*': 'IntList',
    'accreal': 'Scalar',
    'real': 'Scalar',
    'long': 'int64_t',
}
DYNAMIC_TYPE = {
    'THTensor*': 'Tensor',
    'THBoolTensor*': 'BoolTensor',
    'THIndexTensor*': 'IndexTensor',
    'THIntegerTensor*': 'IntegerTensor',
    'THStorage*': 'Storage',
    'THGenerator*': 'Generator*',
    'THSize*': 'IntList',
    'THStride*': 'IntList',
    'accreal': 'accreal',
    'real': 'real',
    'long': 'int64_t',
}
TYPE_RETURN = {
    'THTensor*': 'Tensor',
    'THIndexTensor*': 'Tensor',
    'THBoolTensor*': 'Tensor',
    'THIntegerTensor*': 'Tensor',
    'real': 'Scalar',
    'accreal': 'Scalar',
    'long': 'int64_t',
}
CHECKED_CAST = {
    'THTensor*': CodeTemplate('checked_cast<${Tensor}>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
    'THSTensor*':
        CodeTemplate(
            'checked_cast<Sparse${Tensor}>(${arg_name}.tref.pImpl,"${arg_name}",${arg_pos},false)'),
    'THBoolTensor*':
        CodeTemplate(
            'checked_cast<${Backend}ByteTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
    'THIndexTensor*':
        CodeTemplate(
            'checked_cast<${Backend}LongTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
    'THIntegerTensor*':
        CodeTemplate(
            'checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
    'THStorage*': CodeTemplate('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos}, false)'),
    'THGenerator*':
        CodeTemplate(
            'check_generator<${Backend}Generator>(${arg_name}, &context->defaultGenerator(backend()))'),
    'THSize*': CodeTemplate('THLongStorageView::make(${arg_name}, true)'),
    'THStride*': CodeTemplate('THLongStorageView::make(${arg_name}, false, true)'),
    'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
    'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
    'TensorList': CodeTemplate('tensor_list_checked_cast<${Tensor}, Tensor, '
                               '${THTensor}>(${arg_name},"${arg_name}",${arg_pos})'),
    'IntList': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})')
}
CHECKED_USE = {
    'THTensor*': '{}_->tensor',
    'THSTensor*': '{}_->tensor',
    'THIndexTensor*': '{}_->tensor',
    'THBoolTensor*': '{}_->tensor',
    'THIntegerTensor*': '{}_->tensor',
    'THStorage*': '{}_->storage',
    'THGenerator*': '{}_->generator',
    'TensorList': "{0}_.data(), {0}_.size()",
}
CHECKED_USE_NULLABLE = CodeTemplate('${arg_name}_ ? ${usage} : NULL')
ALLOC_WRAP = {
    'THTensor*': 'new ${Tensor}(context)',
    'THBoolTensor*': 'new ${Backend}ByteTensor(context)',
    'THIndexTensor*': 'new ${Backend}LongTensor(context)',
    'THIntegerTensor*': 'new ${Backend}IntTensor(context)',
}
# Replacements for constants when calling into TH
CONSTANT_REPLACEMENTS = [
    ('AS_REAL', '${AS_REAL}'),
    ('THPDefaultGenerator->cdata',
     'dynamic_cast<${Generator}&>().generator'),
    ('__storage_size.get\\(\\)',
     'THLongStorageView::make(static_cast<int64_t>(storage.size()))'),
    ('__last_dim', 'self.ndimension()-1'),
]
# Replacements for constants in header file function definitions
HEADER_CONSTANT_REPLACEMENTS = [
    (r'AS_REAL\((.*)\)', r'\1'),
    ('THPDefaultGenerator->cdata', 'nullptr'),
    ('__last_dim', '-1'),
]
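# e.g. applying these replacements turns a cwrap default of "AS_REAL(0)" into
# "0" and "THPDefaultGenerator->cdata" into "nullptr" in the generated headers.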
class nested_dict(object):
    def __init__(self, base, parent):
        self.base, self.parent = base, parent

    def __getitem__(self, x):
        r = self.base.get(x)
        if r is not None:
            return r
        return self.parent[x]
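
# A minimal usage sketch (illustrative values): lookups fall through from the
# per-option dict to the shared parent environment.
#
#   env = nested_dict({'api_name': 'add'}, {'api_name': 'x', 'Backend': 'CPU'})
#   env['api_name']  # -> 'add'  (found in base)
#   env['Backend']   # -> 'CPU'  (falls through to parent)
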
def is_real_argument_to_wrapper(argument):
    return not argument.get('output', False) and \
        argument['type'] != 'CONSTANT' and \
        argument['type'] != 'argument'
def is_mutable_formal_argument(argument, option):
    return argument.get('output') or (option['inplace'] and argument['name'] == 'self')
def to_return_type(arg, option):
    t = arg['type']
    rt = TYPE_RETURN.get(t, t)
    if rt == 'Tensor' and not arg.get('allocate'):
        rt = rt + ' &'
        if not is_mutable_formal_argument(arg, option):
            rt = 'const ' + rt
    return {
        'name': arg['name'],
        'type': rt,
        'dynamic_type': DYNAMIC_TYPE.get(arg['type'], arg['type']),
    }
def create_generic(top_env, declarations):

    # translates defaults from cwrap types to C++ values
    def translate_default(argument, type_str, default):
        if default is None:
            # cause the default constructor for the object to run
            return '{}'
        if 'if_true' in argument:
            return argument['default'] == argument['if_true']
        for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS:
            default = re.sub(pattern, replacement, str(default))
        if type_str in {'Scalar', 'int64_t', 'double'}:
            return float(default) if '.' in default else int(default)
        elif type_str == 'bool':
            assert default.lower() in ['true', 'false']
            return default.lower() == 'true'
        else:
            return default
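
    # For example (illustrative): a cwrap default of "0.5" for a Scalar formal
    # becomes the float 0.5, and "True" for a bool formal becomes Python True
    # (which formal_with_default below renders as "true").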

    # change from THTensor* to Tensor & so we get how it will appear
    # in the aten argument list...
    def translate_formal(argument, option):
        type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
        if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
            type_str = 'const ' + type_str
        translated = {
            'name': argument['name'],
            'type': type_str,
            'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
        }
        if 'kwarg_only' in argument:
            translated['kwarg_only'] = argument['kwarg_only']
        if 'default' in argument:
            default = translate_default(argument, type_str, argument['default'])
            translated['default'] = default
            translated['default_init'] = argument.get('default_init', default)
        if argument.get('output'):
            translated['output'] = True
        if argument.get('size'):
            translated['size'] = argument['size']
        if argument.get('is_nullable') is not None:
            translated['is_nullable'] = argument['is_nullable']
        return translated

    def get_formals(option, include_constants=False):
        seen = set()
        pos_args = []
        kwd_args = []

        def insert(argument):
            if argument['name'] not in seen:
                seen.add(argument['name'])
                if argument.get('kwarg_only', False):
                    kwd_args.append(argument)
                else:
                    pos_args.append(argument)

        def has_output_mask(argument):
            return argument.get('allocate', False) and argument.get('mask', False)

        for argument in option['arguments']:
            if argument.get('output') and not argument.get('allocate', False):
                insert(argument)
        for argument in option['arguments']:
            if argument['type'] == 'THSTensor*':
                # only enable for a subset of Dense/Sparse ops
                if not option.get('aten_dense_sparse', False):
                    raise NYIError("Sparse Tensor")
            if include_constants and argument['type'] == 'CONSTANT':
                insert(argument)
            elif is_real_argument_to_wrapper(argument):
                insert(argument)
        if any(has_output_mask(arg) for arg in option['arguments']):
            mask_size = sum(has_output_mask(arg) for arg in option['arguments'])
            insert({
                'name': 'output_mask',
                'type': 'std::array<bool, {}>'.format(mask_size),
                'default': '{' + ', '.join(['true'] * mask_size) + '}',
            })

        result = pos_args + kwd_args
        return [translate_formal(argument, option) for argument in result]

    def get_return_types(option):
        ret = option['return']
        if ret['kind'] == 'arguments':
            argument_indices = ret['arguments']
            if len(argument_indices) == 1:
                the_arg = option['arguments'][argument_indices[0]]
                return [to_return_type(the_arg, option)]
            else:
                return [to_return_type(option['arguments'][idx], option)
                        for idx in argument_indices]
        elif ret['kind'] == 'type':
            return [{
                'type': TYPE_RETURN.get(ret['type'], ret['type']),
                'dynamic_type': DYNAMIC_TYPE.get(ret['type'], ret['type']),
            }]
        else:
            raise Exception("format_return_type")

    def format_return_type(return_types):
        if len(return_types) == 1:
            return return_types[0]['type']
        return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))

    def find_dispatch_tensor(formals):
        # dispatch to self if it's a parameter
        for formal in formals:
            if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor':
                return formal['name']
        # otherwise dispatch to the first Tensor or TensorList
        for formal in formals:
            if formal['dynamic_type'] in ('Tensor', 'TensorList'):
                return formal['name']
        return None

    def format_formal(f):
        return '{} {}'.format(f['type'], f['name'])

    def formal_with_default(f):
        s = format_formal(f)
        v = f.get('default')
        if v is None:
            return s
        if isinstance(v, bool):
            v = str(v).lower()
        return '{}={}'.format(s, v)
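
    # e.g. (illustrative) a formal {'type': 'Scalar', 'name': 'alpha',
    # 'default': 1} renders as "Scalar alpha=1"; boolean defaults render in
    # C++ spelling, e.g. "bool keepdim=false".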

    def get_broadcast_argument(option):
        for argument in option['arguments']:
            if argument.get('broadcast'):
                return argument

    def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
        # return the actuals that will be passed to the broadcast function.
        # 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors
        #    that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2").
        # 2) in the broadcast_dims case, this is the broadcasted argument (e.g. "self") followed by the sizes
        #    it is broadcasted to (as an initializer list), so e.g. the specification
        #    "mat1.dim0,mat2.dim1" gets transformed to "self, {mat1.size(0),mat2.size(1)}"
        if not broadcast_dims:
            broadcast_actuals = [broadcast_arg['name']] + broadcast_arg['broadcast'].split()[0].split(",")
        else:
            broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
            # generate size call for each dimension
            broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')'
                               for x in broadcast_dims_spec])
            broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}'
            broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]
        return broadcast_actuals

    excluded_names = set()

    def process_option(option, output_options):
        option['inplace'] = re.search(
            '(^__i|[^_]_$)', option['api_name']) is not None

        if re.match(EXCLUDE_PATTERN, option['name']):
            excluded_names.add(option['name'])
            raise NYIError("NYI")
        # print(yaml.dump(option))
        formals = get_formals(option)
        option['formals_list'] = formals
        option['formals'] = [format_formal(f) for f in formals]
        option['formals_with_defaults'] = [formal_with_default(f) for f in formals]
        option['returns'] = get_return_types(option)
        option['return_type'] = format_return_type(option['returns'])
        option['return_call'] = 'return ' if option['return_type'] != 'void' else ''
        option['actuals'] = [f['name'] for f in formals]

        option['method_formals'] = [format_formal(f) for f in formals
                                    if f['name'] != 'self']
        option['method_formals_with_defaults'] = (
            [formal_with_default(f) for f in formals if f['name'] != 'self'])
        option['method_actuals'] = [
            f['name'] if f['name'] != 'self' else '*this' for f in formals]

        option['const_mark'] = '' if option['inplace'] else ' const'

        is_method = 'method' in option['variants']
        is_function = 'function' in option['variants']
        dispatch_tensor = find_dispatch_tensor(formals)
        is_namespace_function = is_function and dispatch_tensor is not None

        # method-only things are prefixed with m_ in Type so that
        # another function-only variant can exist without the name colliding
        option['method_prefix'] = 'm_' if is_method and not is_function else ''
        option['method_prefix_derived'] = option['method_prefix']
        env = nested_dict(option, top_env)

        broadcast_arg = get_broadcast_argument(option)
        if broadcast_arg is None:
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION.substitute(env))
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION.substitute(env))
        else:
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION_NON_VIRTUAL.substitute(env))

            # "s_" for "same size".
            option['method_prefix_derived'] = 's_' + option['method_prefix']
            same_size_option = option.copy()
            same_size_option['method_prefix'] = option['method_prefix_derived']
            same_size_env = nested_dict(same_size_option, top_env)
            top_env['type_method_declarations'].append(
                TYPE_METHOD_DECLARATION.substitute(same_size_env))
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION.substitute(same_size_env))

            broadcast_inplace = 'inplace' in broadcast_arg['broadcast']
            broadcast_dims = 'dims:' in broadcast_arg['broadcast']
            option['broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims)
            if not broadcast_dims:
                option['broadcast_returns'] = (["b_" + x for x in option['broadcast_actuals']
                                                if x != broadcast_arg['name'] or not broadcast_inplace])
            else:
                option['broadcast_returns'] = ["b_" + broadcast_arg['name']]

            option['broadcast_function'] = 'expand_' + ('inplace' if broadcast_inplace
                                                        else 'size' if broadcast_dims else 'outplace')
            option['broadcast_modified_actuals'] = ['b_' + y if 'b_' + y in option['broadcast_returns'] else y
                                                    for y in option['actuals']]
            top_env['type_method_definitions'].append(
                TYPE_METHOD_DEFINITION_BROADCAST.substitute(env))

        method_of = ['Type']
        if is_method:
            top_env['tensor_method_declarations'].append(
                TENSOR_METHOD_DECLARATION.substitute(env))
            top_env['tensor_method_definitions'].append(
                TENSOR_METHOD_DEFINITION.substitute(env))
            method_of.append('Tensor')

        if is_namespace_function:
            option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor)
            top_env['function_declarations'].append(
                FUNCTION_DECLARATION.substitute(env))
            top_env['function_definitions'].append(
                FUNCTION_DEFINITION.substitute(env))
            method_of.append('namespace')

        output_options.append({
            'name': option['api_name'],
            'method_prefix': option['method_prefix_derived'],
            'arguments': formals,
            'method_of': method_of,
            'mode': option['mode'],
            'returns': option['returns'],
            'inplace': option['inplace'],
        })

    output_declarations = []
    for declaration in declarations:
        output_options = []
        for option in declaration['options']:
            try:
                process_option(option, output_options)
            except NYIError:
                option['skip'] = True
        output_declarations.extend(output_options)
    print("ATen Excluded: {}".format(excluded_names))
    return output_declarations
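
# A rough sketch of the input shape create_generic and create_derived expect
# (simplified and illustrative; the real options come from the cwrap/NN YAML
# processing and carry many more fields):
#
#   declarations = [{
#       'options': [{
#           'name': 'add', 'api_name': 'add', 'mode': 'TH',
#           'variants': ['method', 'function'],
#           'return': {'kind': 'arguments', 'arguments': [0]},
#           'arguments': [...],
#       }],
#   }]
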
def create_derived(backend_type_env, declarations):
    type_object_declarations = []
    type_object_definitions = []
    is_cuda = 'CUDA' in backend_type_env['Backend']

    def replace_with_null(argument):
        return (argument['type'] == 'THGenerator*' and
                backend_type_env['Backend'] == 'CUDA')

    def requires_checked_cast(argument):
        if argument['type'] == 'IntList':
            return 'size' in argument
        return argument['type'] in CHECKED_CAST

    def nullable_argument(argument):
        return argument.get('is_nullable', False)

    def bool_option_is_string(argument):
        return 'if_true' in argument and isinstance(argument['if_true'], string_type)

    def get_argument(argument, option):
        if replace_with_null(argument):
            return 'NULL'
        elif requires_checked_cast(argument):
            checked_use = CHECKED_USE.get(
                argument['type'], '{}_').format(argument['name'])
            if nullable_argument(argument):
                checked_use = CHECKED_USE_NULLABLE.substitute(
                    env={}, arg_name=argument['name'], usage=checked_use)
            return checked_use
        elif argument['type'] == 'bool' and 'if_true' in argument:
            if bool_option_is_string(argument):
                tpl = '({}) ? "{}" : "{}"'
            else:
                tpl = '({}) ? {} : {}'
            return tpl.format(argument['name'],
                              argument['if_true'], argument['if_false'])
        elif argument['type'] == 'CONSTANT':
            # this is a bool that is actually a string...
            if bool_option_is_string(argument):
                return '"{}"'.format(argument['name'])
            v = str(argument.get('default', argument['name']))
            for pattern, replacement in CONSTANT_REPLACEMENTS:
                v = re.sub(pattern, replacement, v)
            return CodeTemplate(v).substitute(backend_type_env)
        # e.g. argument 0, i.e. repeat the 0th argument in this position...
        elif argument['type'] == 'argument':
            index = int(argument['name'])
            return get_argument(option['arguments'][index], option)
        else:
            return argument['name']

    def drop_argument(argument, option):
        return 'CUDA' in backend_type_env['Backend'] and (
            (option['mode'] == 'TH' and argument['type'] == 'THGenerator*') or
            argument.get('default') == 'THPDefaultGenerator->cdata')

    def get_arguments(arguments, option):
        return [get_argument(argument, option)
                for argument in arguments if not drop_argument(argument, option)]

    def is_actual_return_long(ret):
        if ret['type'] == 'long':
            return True
        if ret['type'] == 'real':
            return backend_type_env['ScalarName'] == 'Long'
        if ret['type'] == 'accreal':
            return backend_type_env['AccScalarName'] == 'Long'
        return False

    def handle_zero_dim(env, option):
        if 'zero_dim_dispatch_when_scalar' not in option:
            return []
        check_name = option['zero_dim_dispatch_when_scalar']
        zero_dim_actuals = [arg['name']
                            if arg['name'] != check_name else "Scalar({})".format(arg['name'])
                            for arg in option['formals_list']]
        return [ZERO_DIM_CHECK.substitute(env, check_name=check_name, zero_dim_actuals=zero_dim_actuals)]

    def handle_sparse(env, option):
        if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
            return []
        check_name = option['when_sparse_dispatch']
        sparse_actuals = [arg['name']
                          if arg['name'] != check_name else "SparseTensor({})".format(arg['name'])
                          for arg in option['formals_list']]
        return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)]

    def handle_buffers(env, option):
        if 'buffers' not in option:
            return []
        return [BUFFER_DEFINITION.substitute(env, name=b['name'])
                for b in option['buffers']]

    def allocate_arg(env, arg, output_count):
        name = arg['name']
        allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env)
        if arg.get('mask', False):
            allocation = 'output_mask[{}] ? {} : nullptr'.format(output_count, allocation)
        return [
            'auto {}_ = {};'.format(name, allocation),
            'auto {} = Tensor({}_,false);'.format(name, name),
        ]
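
    # e.g. (illustrative) an allocated output "result" emits:
    #   auto result_ = new ${Tensor}(context);
    #   auto result = Tensor(result_,false);
    # with the allocation guarded by "output_mask[i] ? ... : nullptr" when the
    # argument is masked.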

    def resize_arg(arg):
        resize = arg['resize']
        if isinstance(resize, str):
            return "{}.resize_({}.sizes());".format(arg['name'], resize)
        else:
            dims = ['{}.size({})'.format(name, dim) for name, dim in resize]
            return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))

    def handle_call(env, option, cimpl):
        is_nn = option['mode'] == 'NN'
        actuals = get_arguments(cimpl['arguments'], option)
        if is_cuda or is_nn:
            actuals = ['context->thc_state'] + actuals

        cname = cimpl['cname']
        if option.get('sparse', False):
            if is_cuda:
                cname = 'THCS' + env['ScalarName'] + "Tensor_" + cname
            else:
                cname = env['THTensor'].replace('TH', 'THS') + '_' + cname
        elif is_nn:
            cname = 'THNN_{}'.format(env['THType']) + cname
        else:
            cname = env['THTensor'] + '_' + cname

        call = CALL_TEMPLATE.substitute(actuals=actuals, cname=cname)
        if cimpl.get('condition') is not None:
            call = 'if ({}) {}'.format(cimpl['condition'], call)
        return call
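
    # e.g. (illustrative) a plain TH-mode cimpl named "abs" dispatches to
    # ${THTensor}_abs(...); NN mode prepends THNN_${THType}, and CUDA/NN calls
    # receive context->thc_state as the first actual.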

    def emit_body(env, option):
        body = []
        body += handle_sparse(env, option)
        body += handle_zero_dim(env, option)
        body += handle_buffers(env, option)
        # arguments are potentially duplicated because of one argument
        # referencing another
        seen_names = set()
        seen_tensorlists = set()
        count = 0
        output_count = 0
        # scalar_check is the heuristic condition for when a result may be a scalar.
        # if there is a THSize* argument, then its dimensions are used to determine
        # scalar-ness; otherwise, it is true if all the input tensors are scalars.
        scalar_check_is_from_size = False
        scalar_check = None
        for arg in option['arguments']:
            if is_real_argument_to_wrapper(arg):
                count += 1
            if arg['type'] == 'THSize*':
                scalar_check_is_from_size = True
                scalar_check = '{}.size() == 0'.format(arg['name'])
            if arg['type'] == 'TensorList':
                seen_tensorlists.add(arg['name'])

            wrap_dim_arg = arg.get('wrap_dim', None)
            if wrap_dim_arg is not None:
                # wrap_dim specification can have (add) expressions, e.g. self+1
                wrap_dim_params = wrap_dim_arg.split("+")
                # for Tensors, "name_" is the TensorImpl, but for TensorLists, it is an
                # std::vector of TH*s. Since TH*s have different dimension rules, we use
                # "name" instead, but keep "name_" for tensors to avoid an extra function call.
                if wrap_dim_params[0] not in seen_tensorlists:
                    wrap_dim_params[0] = wrap_dim_params[0] + "_"
                wrap_dim_target = wrap_dim_params[0]
                wrap_dim_toadd = 0 if len(wrap_dim_params) == 1 else wrap_dim_params[1]
                body.append("{} = maybe_wrap_dim({}, {}, {});"
                            .format(arg['name'], arg['name'], wrap_dim_target, wrap_dim_toadd))

            # only generate checked casts the first time we see an argument
            if arg['name'] not in seen_names and requires_checked_cast(arg):
                seen_names.add(arg['name'])

                # make a new allocation of TensorImpl, then wrap a Tensor around it.
                if arg.get('allocate', False):
                    body += allocate_arg(env, arg, output_count)
                    output_count += 1
                # extract the TensorImpl from an existing tensor (or Storage, etc.)
                else:
                    # special case where we allow undefined Tensors, and thus
                    # the checked cast succeeds even if the Tensor is not
                    # defined
                    null_okay = 'true' if nullable_argument(arg) else 'false'
                    default_init = []
                    if 'default_init' in arg:
                        default_init.append(arg['default_init'])

                    check_cast = CHECKED_CAST[arg['type']].substitute(
                        env, arg_name=arg['name'], arg_pos=count,
                        null_okay=null_okay, default_init=default_init,
                        size=arg.get('size'))
                    body.append("auto {}_ = {};".format(
                        arg['name'], check_cast))
                if drop_argument(arg, option) or replace_with_null(arg):
                    body.append(
                        "(void) {}_; //silence unused warning".format(arg['name']))

                initializers = []

                # resize tensors for special ops that require it
                if 'resize' in arg:
                    initializers.append(resize_arg(arg))

                # also special handling where we zero some outputs.
                if arg.get('zero', False) or (arg.get('cpu_zero', False) and not is_cuda):
                    initializers.append("{}.zero_();".format(arg['name']))

                # only initialize non-null arguments
                if nullable_argument(arg) and len(initializers) > 0:
                    body.append(CONDITIONAL_INITIALIZER.substitute({
                        'name': arg['name'],
                        'initializer': initializers
                    }))
                else:
                    body += initializers

                # isScalar() for all input tensors is ANDed together to form
                # the test for whether the output is also a scalar
                if (not arg.get('output') and 'Tensor' in arg['type'] and
                        'TensorList' not in arg['type'] and
                        'THS' not in arg['type'] and
                        not scalar_check_is_from_size):
                    check = '{}->isScalar()'.format(arg['name'] + '_')
                    if nullable_argument(arg):
                        check = '(!{} || {})'.format(arg['name'] + '_', check)
                    scalar_check = (check if scalar_check is None
                                    else scalar_check + ' && ' + check)

        # cimpls, if it exists, contains the underlying C function names and
        # arguments. Otherwise use option itself.
        cimpls = option.get('cimpls', [option])
        calls = [handle_call(env, option, cimpl) for cimpl in cimpls]

        ret = option['return']

        if ret['kind'] == 'arguments':
            if 'aten_custom_call' in option:
                # aten_custom_call bodies handle the scalar check on their own.
                scalar_check = None
                body.append(CodeTemplate(
                    option['aten_custom_call']).substitute(env))
            else:
                body.extend([call + ';' for call in calls])
            arguments_indices = ret['arguments']
            arguments = [option['arguments'][argi]
                         for argi in arguments_indices]
            if scalar_check is not None:
                if len(arguments) > 1:
                    body.append("bool maybe_scalar = {};".format(scalar_check))
                    scalar_check = 'maybe_scalar'
                for arg in arguments:
                    stmt = "{}_->maybeScalar({});".format(arg['name'], scalar_check)
                    if nullable_argument(arg):
                        stmt = "if ({}_) {}".format(arg['name'], stmt)
                    body.append(stmt)
            if len(arguments_indices) == 1:
                arg = arguments[0]
                body.append("return {};".format(arg['name']))
            else:
                types = [to_return_type(arg, option)['type']
                         for arg in arguments]
                # TODO: check for move semantics...
                names = [arg['name'] for arg in arguments]
                body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
                    types=types, names=names))
        elif ret['kind'] == 'type':
            assert len(calls) == 1
            call = calls[0]
            if ret['type'] == 'THTensor*':
                maybe_scalar = ("->maybeScalar({})".format(scalar_check)
                                if scalar_check is not None else "")
                return_tensor = "return Tensor((new ${Tensor}(context,${arg_name}))${maybe_scalar},false);"
                body.append(CodeTemplate(return_tensor).substitute(
                    env, arg_name=call, maybe_scalar=maybe_scalar))
            else:
                # we use int64_t for long in the API, so correct it here...
                if is_actual_return_long(ret):
                    call = "static_cast<int64_t>({})".format(call)
                body.append("return {};".format(call))
        else:
            raise Exception("NYI - return handling")
        return body

    def process_option(option):
        pair = (backend_type_env['Backend'],
                backend_type_env['ScalarName'])
        if pair in option['backend_type_pairs']:
            env = nested_dict(option, backend_type_env)
            body = emit_body(env, option)
            option['type_definition_body'] = body
            type_object_declarations.append(
                TYPE_DERIVED_DECLARATION.substitute(env))
            type_object_definitions.append(
                TYPE_DERIVED_DEFINITION.substitute(env))

    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    process_option(option)
                except NYIError:
                    pass
    return type_object_declarations, type_object_definitions