blob: d9d95df100018234ddf13d04ad25789c53ed0af4 [file] [log] [blame]
# HEY! Trying to understand what this file does? Read
# "what has to be done to add a Operation ..." first!
import re
from code_template import CodeTemplate
try:
import typing # noqa: F401
except ImportError:
raise RuntimeError(
'Missing build dependency: Unable to import the `typing` module. '
'Please install it via `conda install typing` or `pip install typing`')
# flake8 doesn't take into account usages in type annotations.
from typing import Union, Set # noqa: F401
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
# TypedDict is only needed for static (mypy) type checking; at runtime we
# can fall back to a factory that just returns Dict[Any, Any].
try:
    from mypy_extensions import TypedDict
except ImportError:
    # Avoid the dependency on the mypy_extensions package.
    # It is required, however, for type checking.
    def TypedDict(name, attrs, total=True):  # type: ignore
        return Dict[Any, Any]

import sys

# Python 2/3 compatibility: `basestring` only exists on Python 2.
if sys.version_info[0] == 3:
    string_type = str
else:
    string_type = basestring
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# what has to be done to add a Operation ...
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TH functions are generated into at::legacy::cpu and at::legacy::cuda,
# where they can be called directly by a native function, they can be wrapped
# by a native function that handles dispatch
# Handle broadcasting for TH functions that need it
LEGACY_TH_DECLARATION_BROADCAST = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals});
""")
LEGACY_TH_DEFINITION_BROADCAST = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) {
${named_guard_declaration}
${device_guard_declaration}
Tensor ${broadcast_returns};
std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
}
""")
LEGACY_TH_DECLARATION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals});
""")
LEGACY_TH_DEFINITION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) {
${named_guard_declaration}
${device_guard_declaration}
${type_definition_body}
}
""")
LEGACY_TH_DEFINITION_SWITCH_STATEMENT = CodeTemplate("""\
${dispatch_scalar_type_declaration}
switch (dispatch_scalar_type) {
${cases}
default:
AT_ERROR("${api_name} not supported on ${Type} for ", dispatch_scalar_type);
}
""")
LEGACY_TH_DEFINITION_CASE = CodeTemplate("""\
case ScalarType::${ScalarName}: {
${case_body}
break;
}
""")
# Native functions are generated and registered on the dispatcher. We register the
# function on Backend::Undefined if it does not have backend dependent dispatch.
# In this case, it will be called for all backends, but can be overwritten on a
# per backend basis.
NATIVE_DISPATCH_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals});
""")
NATIVE_DISPATCH_DEFINITION_DEFAULT = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) {
${named_guard_declaration}
${device_guard_declaration}
${return_call} at::native::${native_type_method_dispatch}(${native_actuals});
}
""")
NATIVE_DISPATCH_DEFINITION_BACKEND = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) {
${named_guard_declaration}
${device_guard_declaration}
${return_call} at::native::${native_type_method_dispatch}(${native_actuals});
}
""")
DEFAULT_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
.op(torch::RegisterOperators::options()
.schema("${schema_string}")
.impl_unboxedOnlyCatchAllKernel<${return_type} (${formals_types}), &TypeDefault::${api_name}>()
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
""")
BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
.op(torch::RegisterOperators::options()
.schema("${schema_string}")
.impl_unboxedOnlyKernel<${return_type} (${formals_types}), &${Type}::${api_name}>(DispatchKey::${Backend}TensorId)
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
""")
DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\
.op(torch::RegisterOperators::options()
.schema("${schema_string}")
.catchAllKernel<${return_type} (${formals_types})>(&TypeDefault::${api_name})
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
""")
DEFAULT_SCHEMA_REGISTRATION = CodeTemplate("""\
.op(torch::RegisterOperators::options()
.schema("${schema_string}")
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
""")
BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\
.op(torch::RegisterOperators::options()
.schema("${schema_string}")
.kernel<${return_type} (${formals_types})>(DispatchKey::${Backend}TensorId, &${Type}::${api_name})
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
""")
# add non-virtual declaration to TensorBody.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults}) const;
""")
# add non-virtual declaration to Tensor.cpp
C10_TENSOR_METHOD_DEFINITION = CodeTemplate("""\
inline ${return_type} Tensor::${api_name}(${method_formals}) const {
#ifdef USE_STATIC_DISPATCH
${static_dispatch_method_body}
#else
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::${operator_name}", "${overload_name}");
return op.callUnboxed<${formals_types_with_return}>(${method_actuals});
#endif
}
""")
# add a method declaration in Functions.h
FUNCTION_DECLARATION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals_with_defaults});
""")
# add a method declaration in Functions.h
DEPRECATED_FUNCTION_DECLARATION = CodeTemplate("""\
C10_DEPRECATED static inline ${return_type} ${api_name}(${formals_with_defaults});
""")
# add method definition in Functions.h
C10_FUNCTION_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
#ifdef USE_STATIC_DISPATCH
${static_dispatch_function_body}
#else
static c10::OperatorHandle op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::${operator_name}", "${overload_name}");
return op.callUnboxed<${formals_types_with_return}>(${native_actuals});
#endif
}
""")
# In order to rely on the linker to strip unused ops, it requires us to dispatch statically
# in Functions.h and TensorMethods.h.
#
# NB: The default body also needs to apply a variable guard, as in some
# situations what we think is a default body actually does have an
# explicit derivative, and thereby would have gotten unwrapped by
# the time you get to the implementation.
STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\
at::AutoNonVariableTypeMode _var_guard(true);
${return_call} TypeDefault::${native_type_method_dispatch}(${native_arguments});
""")
STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\
at::AutoNonVariableTypeMode _var_guard(true);
switch(dispatchKeyToBackend(c10::impl::dispatchTypeId(${key_set}, c10::DispatchKeySet(c10::DispatchKeySet::FULL)))) {
${static_dispatch_function_switches}
default:
AT_ERROR("${api_name} not implemented for ", at::toString(${key_set}));
}
""")
STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT = CodeTemplate("""\
case Backend::${backend}:
${return_call} ${backend}Type::${api_name}(${native_arguments});
break;
""")
# add a native declaration for a native function
NATIVE_DECLARATION = CodeTemplate("""\
CAFFE2_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults});
""")
# special method definition for factory functions in Functions.h that initializes backends
C10_FACTORY_DEFINITION = CodeTemplate("""\
static inline ${return_type} ${api_name}(${formals}) {
#ifdef USE_STATIC_DISPATCH
${static_dispatch_function_body}
#else
globalLegacyTypeDispatch().initForDispatchKeySet(${inferred_key_set});
static c10::OperatorHandle op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::${operator_name}", "${overload_name}");
return op.callUnboxed<${formals_types_with_return}>(${native_actuals});
#endif
}
""")
ZERO_DIM_CHECK = CodeTemplate("""\
if (${check_name}.dim() == 0) {
return ${api_name}(${zero_dim_actuals});
}""")
SPARSE_CHECK = CodeTemplate("""\
if(${check_name}.is_sparse()) {
return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals});
}""")
CONDITIONAL_INITIALIZER = CodeTemplate("""\
if (${name}.defined()) {
${initializer}
}""")
CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})")
OPERATOR_NAME = CodeTemplate("aten::${operator_name}")
OPERATOR_NAME_FULL = CodeTemplate("""\
{"aten::${operator_name}", "${overload_name}"},
""")
# scalar_name, c_type, accreal, is_floating_type
scalar_types = [
    ('Bool', 'bool', 'BoolAccrealNotDefined', False),
    ('Byte', 'uint8_t', 'Long', False),
    ('Char', 'int8_t', 'Long', False),
    ('Double', 'double', 'Double', True),
    ('Float', 'float', 'Double', True),
    ('Int', 'int', 'Long', False),
    ('Long', 'int64_t', 'Long', False),
    ('Short', 'int16_t', 'Long', False),
    ('Half', 'Half', 'Double', True),
    ('BFloat16', 'BFloat16', 'BFloat16AccrealNotDefined', True),
]

# Backends considered when generating static dispatch code
# (see the USE_STATIC_DISPATCH templates above).
static_dispatch_backends = ['CPU', 'QuantizedCPU', 'SparseCPU']
class NYIError(Exception):
    """Indicates we don't support this declaration yet.

    Attributes:
        reason: human-readable explanation of why the declaration is
            not supported.
    """
    __slots__ = ['reason']

    def __init__(self, reason):
        # Forward to Exception.__init__ so str(e), repr(e) and e.args
        # carry the reason instead of rendering an empty message.
        # (py2/3 compatible super call, matching this file's compat style.)
        super(NYIError, self).__init__(reason)
        self.reason = reason
# Maps TH/cwrap argument types to the C++ type used in the generated
# generic ATen signature.
TYPE_FORMAL_GENERIC = {
    'THTensor*': 'Tensor &',
    'THByteTensor*': 'Tensor &',
    'THIndexTensor*': 'Tensor &',
    'THBoolTensor*': 'Tensor &',
    'THStorage*': 'Storage',
    'THGenerator*': 'Generator *',
    'IntArrayRefSize': 'IntArrayRef',
    'accreal': 'Scalar',
    'real': 'Scalar',
    'long': 'int64_t',
}
# Maps TH/cwrap argument types to the "dynamic type" names recorded for
# downstream consumers of the generated declarations.
DYNAMIC_TYPE = {
    'THTensor*': 'Tensor',
    'THByteTensor*': 'ByteTensor',
    'THBoolTensor*': 'BoolTensor',
    'THIndexTensor*': 'IndexTensor',
    'THStorage*': 'Storage',
    'THGenerator*': 'Generator*',
    'IntArrayRefSize': 'IntArrayRef',
    'accreal': 'accreal',
    'real': 'real',
    'long': 'int64_t',
}
# Tensor reference formals collapse to a plain 'Tensor' dynamic type
# for native functions.
NATIVE_DYNAMIC_TYPE = {
    'Tensor &': 'Tensor',
    'const Tensor &': 'Tensor',
}
# Maps TH/cwrap types to the C++ type used when the value is returned.
TYPE_RETURN = {
    'THTensor*': 'Tensor',
    'THIndexTensor*': 'Tensor',
    'THByteTensor*': 'Tensor',
    'THBoolTensor*': 'Tensor',
    'real': 'Tensor',
    'accreal': 'Tensor',
    'long': 'int64_t',
}
# Maps a TH argument type to the expression that checks/unwraps the
# incoming ATen value before it is handed to TH.
CHECKED_CAST = {
    'THTensor*':
        CodeTemplate(
            'checked_dense_tensor_unwrap('
            '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
            'DeviceType::${DeviceType}, ScalarType::${ScalarName})'),
    'THByteTensor*':
        CodeTemplate(
            'checked_dense_tensor_unwrap('
            '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
            'DeviceType::${DeviceType}, ScalarType::Byte)'),
    'THBoolTensor*':
        CodeTemplate(
            'checked_dense_tensor_unwrap('
            '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
            'DeviceType::${DeviceType}, ScalarType::Bool)'),
    'THIndexTensor*':
        CodeTemplate(
            'checked_dense_tensor_unwrap('
            '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
            'DeviceType::${DeviceType}, ScalarType::Long)'),
    'THStorage*':
        CodeTemplate(
            'checked_storage('
            '${arg_name}, "${arg_name}", ${arg_pos}, '
            # We're punning here (Backend and DeviceType constructors coincide)
            # but DeviceType is the correct way to classify storages
            'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'),
    # This is a cast done via direct-construction
    'IntArrayRefStride': CodeTemplate('at::IntArrayRef ${result_name} = get_intlist_stride_th(${arg_name});'),
    'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
    'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
    'TensorList': CodeTemplate(
        'checked_tensor_list_unwrap(${arg_name},"${arg_name}",${arg_pos}, '
        'Backend::${Backend}, ScalarType::${ScalarName})'),
    'IntArrayRef': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos})')
}
# How to reference the unwrapped value; "{}"/"{0}" is filled with the
# argument name (the "{}_" suffix is the unwrapped-local convention).
CHECKED_USE = {
    'THTensor*': '{}_',
    'THIndexTensor*': '{}_',
    'THByteTensor*': '{}_',
    'THBoolTensor*': '{}_',
    'THStorage*': '{}_.unsafeGetStorageImpl()',
    'TensorList': "{0}_.data(), {0}_.size()",
}
# Wraps a CHECKED_USE expression for arguments that may be null.
CHECKED_USE_NULLABLE = CodeTemplate('${arg_name}_ ? ${usage} : NULL')
# Expressions that allocate a fresh empty TensorImpl of the appropriate
# dtype/backend for outputs the wrapper must allocate itself.
ALLOC_NOARGS_WRAP = {
    'THTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                 '(c10::Storage(caffe2::TypeMeta::Make<${ScalarType}>(), 0, allocator(), true),'
                 'DispatchKey::${Backend}TensorId).release()',
    'THByteTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                     '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Byte), 0, allocator(), true),'
                     'DispatchKey::${Backend}TensorId).release()',
    'THBoolTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                     '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Bool), 0, allocator(), true),'
                     'DispatchKey::${Backend}TensorId).release()',
    'THIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
                      '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Long), 0, allocator(), true),'
                      'DispatchKey::${Backend}TensorId).release()',
}
ALLOC_WRAP = {
    'THTensor*': '${arguments}',
    'THByteTensor*': '${arguments}',
    'THBoolTensor*': '${arguments}',
    'THIndexTensor*': '${arguments}',
}
# Replacements for constants when calling into TH
CONSTANT_REPLACEMENTS = [
    ('AS_REAL', '${ScalarType}'),
]
# Replacements for constants in header file function definitions
HEADER_CONSTANT_REPLACEMENTS = [
    (r'AS_REAL\((.*)\)', r'\1'),
]
class nested_dict(object):
    """Two-level chained lookup: try `base` first, then fall back to `parent`.

    NB: a key explicitly mapped to None in `base` falls through to
    `parent`, because the lookup uses `.get()` and tests for None.
    """

    def __init__(self, base, parent):
        self.base = base
        self.parent = parent

    def __getitem__(self, x):
        found = self.base.get(x)
        if found is None:
            return self.parent[x]
        return found
# Per-instantiation template environment: the substitutions available
# when generating code for one (backend, scalar type) combination.
Environment = TypedDict('Environment', {
    'state': str,
    'ScalarType': str,
    'ScalarName': str,
    'THTensor': str,
    'THType': str,
    'Backend': str,
    'DeviceType': str,
    'AccScalarName': str,
})
# File-level accumulator: lists of generated code fragments appended to
# while processing declarations, later spliced into the output files.
TopEnvironment = TypedDict('TopEnvironment', {
    'type_registrations': List[str],
    'type_headers': List[str],
    'function_registrations': List[str],
    'list_of_aten_ops': List[str],
    'type_method_declarations': List[str],
    'type_method_definitions': List[str],
    'tensor_method_declarations': List[str],
    'tensor_method_definitions': List[str],
    'function_declarations': List[str],
    'function_definitions': List[str],
    'type_ids': List[str],
    'native_function_declarations': List[str],
})
# A Declarations.cwrap formal argument
# type can contain THTensor* types
THFormal = TypedDict('THFormal', {
    'name': str,
    'type': str,
    'dynamic_type': str,
    'kwarg_only': bool,
    'is_nullable': bool,
    'default': str,
    'output': bool,
    'size': int,
    'allocate': bool,
    'mask': bool,
    'wrap_dim': str,
    # Broadcast is originally a str but gets unwrapped to a List or Dict in-place
    'broadcast': Any,
    'resize': str,
    'cpu_zero': bool,
    'zero': bool,
}, total=False)

# Generic ATen formal or native_functions.yaml formal argument.
# type can contain Tensor& reference types.
AtFormal = TypedDict('AtFormal', {
    'name': str,
    'type': str,
    'dynamic_type': str,
    'kwarg_only': bool,
    'is_nullable': bool,
    'default': str,
    'output': bool,
    'size': int,
}, total=False)

# Note [field_name versus name]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# What is the difference between "field_name" and "name"?
#
# Return values of ATen operators always have a name: if it is not
# explicitly assigned a name inside native_functions.yaml like func:
# myop() -> (Tensor indices, Tensor value), then the codegen will
# automatically assign it a name like result0, or name might be
# specified inside Declarations.cwrap. We don't want these assigned
# names to become part of the public API when we return a namedtuple for
# any such multiple-return function.
#
# Thus field_name is like name, but it is defined only when there is a
# name specified in native_functions.yaml. If field_name is defined,
# then the codegen would generate code to return namedtuple. Otherwise,
# it would just return tuple.
ReturnType = TypedDict('ReturnType', {
    'name': str,
    # See Note [field_name versus name]
    'field_name': str,
    'type': str,
    'dynamic_type': str,
}, total=False)

# How an option describes its return: either by 'type', or by pointing
# at output 'arguments' (by index).
ReturnDecl = TypedDict('ReturnDecl', {
    'kind': str,
    'type': str,
    'arguments': List[int],
}, total=False)

# Represents a buffer in nn.yaml
NNBuffer = TypedDict('NNBuffer', {
    'name': str,
})
# The big record threaded through code generation: one FunctionOption per
# cwrap/native_functions.yaml option, mutably populated as processing
# proceeds.  total=False: virtually every key is optional.
FunctionOption = TypedDict('FunctionOption', {
    'actuals': List[str],
    'api_name': str,
    'arguments': List[THFormal],
    'backend_types': Dict[str, List[str]],
    'backends': List[str],
    'broadcast_actuals': List[str],
    'broadcast_function': str,
    'broadcast_modified_actuals': List[str],
    'broadcast_returns': List[str],
    'buffers': List[NNBuffer],
    # cimpls is really a List[FunctionOption]
    'cimpls': List[Any],
    'cname': str,
    # explicitly specify whether the function is a factory function or other special category
    'category_override': str,
    'condition': str,
    'device_guard': bool,
    'device_guard_declaration': str,
    'dispatch_scalar_type_declaration': str,
    'use_c10_dispatcher': str,
    'manual_kernel_registration': bool,
    'with_gil': bool,
    'cpu_half': bool,
    'cpu_bfloat16': bool,
    'cuda_bfloat16': bool,
    'deprecated': bool,
    'cpu_bool': bool,
    'cuda_bool': bool,
    # See Note [field_name versus name]
    'field_name': str,
    'formals_list': List[AtFormal],
    'formals_with_defaults': List[str],
    'formals': List[str],
    'formals_types': List[str],
    'formals_types_with_return': List[str],
    'inferred_key_set': str,
    'inplace': bool,
    'matches_jit_signature': bool,
    # This controls whether or not we generate the interface in Type or
    # TypeExtendedInterface
    'extended_method': bool,
    'method_actuals': List[str],
    'method_formals_with_defaults': List[str],
    'method_formals': List[str],
    'method_prefix_derived': str,
    'named_guard_declaration': str,
    'mode': str,
    'python_module': str,
    'name': str,
    'operator_name': str,
    'overload_name': str,
    'native_actuals': List[str],
    'native_type_method_dispatch': str,
    # options should be List[FunctionOption]
    'options': Any,
    'schema_string': str,
    'requires_tensor': bool,
    'return_call': str,
    'return_type': str,
    'return': ReturnDecl,
    'returns': List[ReturnType],
    'sparse': bool,
    'type_definition_body': List[str],
    'type_method_actuals': List[str],
    'type_method_definition_dispatch': str,
    'type_method_formals': List[str],
    'variants': str,
    # NOTE(review): 'when_spares_dispatch' looks like a typo of
    # 'when_sparse_dispatch'; both keys are kept in case existing
    # declarations still spell it the old way -- confirm before removing.
    'when_spares_dispatch': str,
    'when_sparse_dispatch': str,
    # (a duplicate 'with_gil' entry used to appear here; it is already
    # declared above, so the redundant key was removed)
    'zero_dim_dispatch_when_scalar': str,
})
# Final, fully-processed description of one generated operator, as it is
# surfaced to downstream consumers of the codegen.
OutputDeclaration = NamedTuple('OutputDeclaration', [
    ('name', str),
    ('operator_name', str),
    ('overload_name', str),
    ('use_c10_dispatcher', str),
    ('manual_kernel_registration', bool),
    ('category_override', str),
    ('matches_jit_signature', bool),
    ('schema_string', str),
    ('method_prefix_derived', str),
    ('arguments', List[AtFormal]),
    ('method_of', List[str]),
    ('mode', str),
    ('python_module', str),
    ('buffers', Optional[List[str]]),
    ('returns', List[ReturnType]),
    ('inplace', bool),
    ('is_factory_method', bool),
    ('abstract', bool),
    ('requires_tensor', bool),
    ('device_guard', bool),
    ('with_gil', bool),
    ('deprecated', bool),
])
# A (definition, declaration) pair of generated C++ code for one function.
FunctionCode = NamedTuple('FunctionCode', [
    ('definition', str),
    ('declaration', str),
])
# An operator name plus the snippet that registers it on the dispatcher.
OpRegistration = NamedTuple('OpRegistration', [
    ('operator_name', str),
    ('registration_code', str),
])
def device_guard(option, dispatch_options, dispatch_tensor):
    """Emit the C++ device-guard statement for a generated wrapper.

    For factory methods the `DeviceGuard` is already in the template, and
    an option may opt out entirely via `device_guard: False`.
    """
    if not option.get('device_guard', True):
        return '// DeviceGuard omitted'
    if dispatch_options:
        return 'const DeviceGuard device_guard({}.device());'.format(dispatch_options['name'])
    if dispatch_tensor:
        return 'const OptionalDeviceGuard device_guard(device_of({}));'.format(dispatch_tensor)
    return '// DeviceGuard omitted'
def named_guard(option, tensors, tensorlists):
    """Emit a C++ guard that rejects named-tensor inputs for ops that do
    not support them; returns '' when no guard is needed."""
    no_tensor_args = len(tensors) + len(tensorlists) == 0
    if option.get('supports_named_tensor', False) or no_tensor_args:
        return ''
    # Override: supports_named_tensor = False for _th_ functions. This is because:
    # There is always some at:: function that calls the _th_ function.
    if option['name'].startswith('_th_'):
        return ''
    named_conditions = (['{}.has_names()'.format(t) for t in tensors] +
                        ['at::has_names({})'.format(tl) for tl in tensorlists])
    return ("""\
if ({named_conditions}) {{
    AT_ERROR(
        "{op} is not yet supported with named tensors. Please drop names via "
        "`tensor = tensor.rename(None)`, call the op with an unnamed tensor, "
        "and set names on the result of the operation.");
}}""".format(named_conditions=' || '.join(named_conditions), op=option['name']))
def dispatch_scalar_type(option, dispatch_options, dispatch_tensor):
    """Emit the C++ statement computing the scalar type to dispatch on,
    taken from the options dtype if present, else inferred from a tensor."""
    if dispatch_options:
        stmt = 'auto dispatch_scalar_type = typeMetaToScalarType({}.dtype());'
        return stmt.format(dispatch_options['name'])
    if dispatch_tensor:
        stmt = 'auto dispatch_scalar_type = infer_scalar_type({});'
        return stmt.format(dispatch_tensor)
    return '// dispatch_scalar_type omitted'
def is_real_argument_to_wrapper(argument):
    # type: (THFormal) -> bool
    """True if the formal is a genuine input to the wrapper (not an
    output, a constant, or a reference to another argument)."""
    if argument.get('output', False):
        return False
    return argument['type'] not in ('CONSTANT', 'argument')
def is_mutable_formal_argument(argument, option):
    # type: (THFormal, FunctionOption) -> bool
    """A formal is mutable if it is an output, or is `self` of an
    in-place operation."""
    out = argument.get('output')
    if out:
        return out
    return option['inplace'] and argument['name'] == 'self'
def check_methods_do_not_start_with_underscore(name, is_method):
    """Raise RuntimeError if a Tensor method's name starts with a single
    underscore.

    Single-underscore functions are internal and should only exist in the
    at:: namespace; dunder methods, _th_ functions, and a grandfathered
    list of legacy names are exempt.
    """
    # Grandfathered legacy names that predate this rule.
    if name in {'_values', '_indices', '_nnz', '_dimI', '_dimV', '_coalesced_',
                '_version'}:
        return
    if is_method and name.startswith('_') and not name.startswith('__') and not name.startswith('_th_'):
        message = "Function '{}' starts with a single underscore and is ".format(name)
        message += "configured to have a method on Tensor. Functions that start with "
        # (fixed: this fragment previously began with a stray space, which
        # produced a doubled space in the rendered error message)
        message += "a single underscore should only be functions in the at:: "
        message += "namespace and not methods on Tensor!"
        raise RuntimeError(message)
def to_return_type(arg, option):
    # type: (THFormal, FunctionOption) -> ReturnType
    """Translate a TH formal into the ReturnType record used by codegen."""
    th_type = arg['type']
    cpp_type = TYPE_RETURN.get(th_type, th_type)
    if cpp_type == 'Tensor' and not arg.get('allocate'):
        # A non-allocated tensor return aliases an argument, so it is
        # returned by reference; immutable arguments get a const ref.
        cpp_type += ' &'
        if not is_mutable_formal_argument(arg, option):
            cpp_type = 'const ' + cpp_type
    return {
        'name': arg['name'],
        'type': cpp_type,
        'dynamic_type': DYNAMIC_TYPE.get(th_type, th_type),
    }
def create_generic(top_env, declarations):
# type: (TopEnvironment, List[FunctionOption]) -> Tuple[List[OutputDeclaration], List[OpRegistration]]
# translates defaults from cwrap types to C++ values
def translate_default(argument, type_str, default):
# type: (THFormal, str, Any) -> Any
if default is None:
# cause the default constructor for the object to run
return '{}'
for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS:
default = re.sub(pattern, replacement, str(default))
if type_str in {'Scalar', 'int64_t', 'double'}:
try:
return int(default)
except Exception:
try:
return float(default)
except Exception:
return default
elif type_str == 'bool':
assert default.lower() in ['true', 'false']
return default.lower() == 'true'
else:
return default
# change from THTensor* to Tensor & so we get how it will appear
# in the aten argument list...
def translate_formal(argument, option):
# type: (THFormal, FunctionOption) -> AtFormal
type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
type_str = 'const ' + type_str
translated = {
'name': argument['name'],
'type': type_str,
'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
} # type: AtFormal
if 'kwarg_only' in argument:
translated['kwarg_only'] = argument['kwarg_only']
if 'default' in argument:
default = translate_default(argument, type_str, argument['default'])
translated['default'] = default
if argument.get('output'):
translated['output'] = True
if argument.get('size'):
translated['size'] = argument['size']
if argument.get('is_nullable') is not None:
translated['is_nullable'] = argument['is_nullable']
return translated
def get_formals(option, include_constants=False):
# type: (FunctionOption, bool) -> List[AtFormal]
seen = set() # type: Set[str]
pos_args = [] # type: List[THFormal]
kwd_args = [] # type: List[THFormal]
def insert(argument):
# type: (THFormal) -> None
if argument['name'] not in seen:
seen.add(argument['name'])
if argument.get('kwarg_only', False):
kwd_args.append(argument)
else:
pos_args.append(argument)
def has_output_mask(argument):
# type: (THFormal) -> bool
return argument.get('allocate', False) and argument.get('mask', False)
for argument in option['arguments']:
if argument.get('output') and not argument.get('allocate', False):
insert(argument)
for argument in option['arguments']:
if include_constants and argument['type'] == 'CONSTANT':
insert(argument)
elif is_real_argument_to_wrapper(argument):
insert(argument)
if any(has_output_mask(arg) for arg in option['arguments']):
mask_size = sum(has_output_mask(arg) for arg in option['arguments'])
insert({
'name': 'output_mask',
# NB: Lack of space in comma works around parsing
# problem in gen_variable_type.py
'type': 'std::array<bool,{}>'.format(mask_size),
'default': '{{' + ', '.join(['true'] * mask_size) + '}}',
})
result = pos_args + kwd_args
return [translate_formal(argument, option) for argument in result]
def get_return_types(option):
# type: (FunctionOption) -> List[ReturnType]
ret = option['return']
if ret['kind'] == 'arguments':
argument_indices = ret['arguments']
if len(argument_indices) == 1:
the_arg = option['arguments'][argument_indices[0]]
return [to_return_type(the_arg, option)]
else:
return [to_return_type(option['arguments'][idx], option)
for idx in argument_indices]
elif ret['kind'] == 'type':
return [{
'type': TYPE_RETURN.get(ret['type'], ret['type']),
'dynamic_type': DYNAMIC_TYPE.get(ret['type'], ret['type']),
}]
else:
raise Exception("format_return_type")
def format_return_type(return_types):
# type: (List[ReturnType]) -> str
if len(return_types) == 0:
return "void"
elif len(return_types) == 1:
return return_types[0]['type']
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
def is_any_tensor_type(formal):
return (formal['dynamic_type'] == 'Tensor' or formal['dynamic_type'] == 'ByteTensor'
or formal['dynamic_type'] == 'IndexTensor' or formal['dynamic_type'] == 'BoolTensor')
def find_tensors(formals):
# type: (List[AtFormal]) -> List[str]
return [formal['name'] for formal in formals if is_any_tensor_type(formal)]
def find_tensorlists(formals):
# type: (List[AtFormal]) -> List[str]
return [formal['name'] for formal in formals if formal['dynamic_type'] == 'TensorList']
def find_dispatch_tensor(formals):
# type: (List[AtFormal]) -> Optional[str]
# Determine legacy TH-style single dispatch tensor.
#
# Also used to determine what tensor should be used to provide a default
# DeviceGuard. Unlike dispatch, we don't guard on ALL tensor arguments
# (because this is not actually a thing you can do.) Guarding on the
# first argument is best effort to help people avoid doing this
# themselves.
for formal in formals:
if formal['name'] == 'self' and is_any_tensor_type(formal) and not formal.get('is_nullable', False):
return formal['name']
# otherwise dispatch to the first Tensor or TensorList
for formal in formals:
if 'TensorList' == formal['dynamic_type'] or is_any_tensor_type(formal) and \
not formal.get('is_nullable', False):
return formal['name']
return None
def find_multidispatch_tensors(formals):
# type: (List[AtFormal]) -> List[str]
# Compute the list of all tensor arguments which should be considered
# for multiple dispatch. Note that this doesn't completely replace
# find_dispatch_tensor because we use the "dispatch tensor" to determine
# device guards. TensorOptions is included as part of this calculation.
#
# The interaction of multiple dispatch with TensorOptions
# is quite interesting. In particular, suppose I have:
#
# cuda_tensor.new_like(1, device='cpu')
#
# Multiple dispatch will attempt a dispatch to CUDA, even though
# the end tensor that should be produced here is a CPU one. The
# upshot is that if you have an operator with mixed TensorOptions
# and Tensor arguments, you MUST only ever register it generically.
r = []
for formal in formals:
if formal['dynamic_type'] in ['TensorOptions', 'TensorList'] or is_any_tensor_type(formal):
r.append(formal['name'])
return r
def format_formal(f):
# type: (AtFormal) -> str
return '{} {}'.format(f['type'], f['name'])
def formal_with_default(f):
# type: (AtFormal) -> str
s = format_formal(f)
v = f.get('default')
if v is None:
return s
if isinstance(v, bool):
v = str(v).lower()
return '{}={}'.format(s, v)
def get_broadcast_argument(option):
# type: (FunctionOption) -> Optional[THFormal]
for argument in option['arguments']:
if argument.get('broadcast'):
return argument
return None
def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
# type: (THFormal, bool, bool) -> List[str]
# Note: broadcast_dims can change type...
# return the actuals that will be passed to the broadcast function.
# 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors
# that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2").
# 2) in the broadcast_dims case, this is the broadcasted argument (e.g. "self") followed by the sizes
# it is broadcasted to (as an initializer list), so e.g. the specification
# "mat1.dim0,mat2.dim1" gets transformed to "self, {mat1.size(0),mat2.size(1)}"
if not broadcast_dims:
broadcast_actuals = [broadcast_arg['name']] + broadcast_arg['broadcast'].split()[0].split(",")
else:
broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
# generate size call for each dimension
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore
for x in broadcast_dims_spec])
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore
broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]
return broadcast_actuals
def process_legacy_th_option(option):
# type: (FunctionOption) -> None
# Mutably populate option with derived values computed from values
# passed in to option.
option['inplace'] = re.search(
'(^__i|[^_]_$)', option['api_name']) is not None
# print(yaml.dump(option))
formals = get_formals(option)
option['formals_list'] = formals
option['formals'] = [format_formal(f) for f in formals]
option['formals_with_defaults'] = [formal_with_default(f) for f in formals]
option['returns'] = get_return_types(option)
option['return_type'] = format_return_type(option['returns'])
option['return_call'] = 'return ' if option['return_type'] != 'void' else ''
option['actuals'] = [f['name'] for f in formals]
option['method_formals'] = [format_formal(f) for f in formals
if f['name'] != 'self']
option['method_formals_with_defaults'] = (
[formal_with_default(f) for f in formals if f['name'] != 'self'])
# *this is 'const Tensor&' since all Tensor methods are const and must
# be const_casted to be accepted as native function's non-const argument
option['method_actuals'] = [
f['name'] if f['name'] != 'self' else 'const_cast<Tensor&>(*this)' for f in formals]
# There are no cases where these differ, but they do in native_functions
option['type_method_formals'] = option['formals']
option['type_method_actuals'] = option['actuals']
assert 'method' not in option['variants'], 'TH functions cannot be methods'
is_function = 'function' in option['variants']
# NB: TH functions don't support multiple dispatch
dispatch_tensor = find_dispatch_tensor(formals)
is_namespace_function = is_function and dispatch_tensor is not None
broadcast_arg = get_broadcast_argument(option)
# "s_" for "same size".
option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
if option['mode'] == 'TH':
option['device_guard'] = False
option['device_guard_declaration'] = device_guard(option, False, dispatch_tensor)
option['named_guard_declaration'] = named_guard(option, find_tensors(formals),
find_tensorlists(formals))
option['dispatch_scalar_type_declaration'] = dispatch_scalar_type(option, False, dispatch_tensor)
assert option['extended_method'], 'Expected legacy operator to be an extended method'
if broadcast_arg is not None:
broadcast_inplace = 'inplace' in broadcast_arg['broadcast']
broadcast_dims = 'dims:' in broadcast_arg['broadcast']
option['broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims)
if not broadcast_dims:
option['broadcast_returns'] = (["b_" + x for x in option['broadcast_actuals']
if x != broadcast_arg['name'] or not broadcast_inplace])
else:
option['broadcast_returns'] = ["b_" + broadcast_arg['name']]
option['broadcast_function'] = 'expand_' + ('inplace' if broadcast_inplace
else 'size' if broadcast_dims else 'outplace')
option['broadcast_modified_actuals'] = ['b_' + y if 'b_' + y in option['broadcast_returns'] else y
for y in option['actuals']]
def native_get_formals(option, include_constants=False):
    # type: (FunctionOption, bool) -> List[AtFormal]
    """Collect the formal arguments for a native function.

    Returns the arguments de-duplicated by name, positional arguments
    first and kwarg-only arguments after, each mutated in place to carry
    a 'dynamic_type' and a translated C++ 'type' (const/non-const
    reference, or c10::optional wrapper for nullable scalars).
    """
    positional = []  # type: List[AtFormal]
    keyword_only = []  # type: List[AtFormal]
    seen_names = set()  # type: Set[str]
    for candidate in option['arguments']:
        if candidate['name'] in seen_names:
            continue
        seen_names.add(candidate['name'])
        bucket = keyword_only if candidate.get('kwarg_only', False) else positional
        bucket.append(candidate)

    def translate(argument):
        # type: (AtFormal) -> AtFormal
        # not clear we need dynamic_type translation as we can specify the correct type
        # directly in native functions
        argument['dynamic_type'] = NATIVE_DYNAMIC_TYPE.get(argument['type'], argument['type'])
        const_types = {
            'Tensor': 'const Tensor &',
            'Type': 'const Type &',
            'TensorOptions': 'const TensorOptions &',
            'TensorList': 'TensorList',
        }
        mutable_types = {
            'Tensor': 'Tensor &',
            'Type': 'Type &',
            'TensorOptions': 'TensorOptions &',
            'TensorList': 'TensorList',
        }
        # Nullable non-tensor types are passed as c10::optional<T>.
        if argument.get('is_nullable') and argument['type'] not in mutable_types:
            argument['type'] = "c10::optional<{}>".format(argument['type'])
        # self of an inplace op and output arguments need mutable references.
        wants_mutable = ((option['inplace'] and argument['name'] == 'self')
                         or argument.get('output', False))
        table = mutable_types if wants_mutable else const_types
        argument['type'] = table.get(argument['type'], argument['type'])
        return argument

    return [translate(argument) for argument in positional + keyword_only]
# this can return multiple return types in a list, e.g. ['Tensor', 'Tensor']
def native_get_return_types(option):
    # type: (FunctionOption) -> List[ReturnType]
    """Build the list of ReturnType dicts for a native function.

    Each entry of option['return'] is either a bare type string or a dict
    with 'type'/'name' (and optionally 'field_name').  TensorList returns
    become std::vector<Tensor> (can't return a reference object), and
    Tensor returns of inplace/_out functions become 'Tensor &' per the
    usual ATen convention.
    """
    ret = option['return']
    # FIX: comment previously lacked the 'type:' prefix, so the annotation
    # was invisible to mypy (cf. the correctly-annotated lists elsewhere).
    return_types = []  # type: List[ReturnType]
    for t_raw in ret:
        # See Note [field_name versus name]
        field_name = None
        if isinstance(t_raw, string_type):
            t = t_raw
            name = None
        else:
            t = t_raw['type']
            name = t_raw['name']
            if 'field_name' in t_raw:
                field_name = t_raw['field_name']
        # can't actually return a TensorList (since it's a reference object)
        actual_return_type = {'TensorList': 'std::vector<Tensor>'}.get(t, t)
        if actual_return_type == 'Tensor' and (option['inplace'] or option['api_name'].endswith('_out')):
            # follow normal ATen convention of returning Tensor & for inplace functions.
            actual_return_type = 'Tensor &'
        rtype = {
            'type': actual_return_type,
            'dynamic_type': NATIVE_DYNAMIC_TYPE.get(t, t),
        }  # type: ReturnType
        if name is not None:
            rtype['name'] = name
        if field_name is not None:
            rtype['field_name'] = field_name
        return_types.append(rtype)
    return return_types
def process_native(option):
    # type: (FunctionOption) -> Optional[OutputDeclaration]
    """Process one 'native'-mode declaration.

    Mutably populates `option` with the formatted formal/actual lists the
    code templates consume, then emits declarations/definitions/registrations
    and returns the corresponding OutputDeclaration.
    """
    assert option['python_module'] == '' or option['python_module'] == 'nn', \
        "Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
            option['python_module'], option['name'])
    formals = native_get_formals(option)
    option['formals_list'] = formals
    option['formals'] = [format_formal(f) for f in formals]
    option['formals_with_defaults'] = [formal_with_default(f) for f in formals]
    option['returns'] = native_get_return_types(option)
    option['return_type'] = format_return_type(option['returns'])
    option['return_call'] = 'return ' if option['return_type'] != 'void' else ''
    option['actuals'] = [f['name'] for f in formals]
    option['formals_types'] = [f['type'] for f in option['formals_list']]
    option['native_actuals'] = [f['name'] for f in option['formals_list']]
    # return type first, then the argument types (used for function pointers).
    option['formals_types_with_return'] = [option['return_type']]
    if len(option['formals_types']) > 0:
        option['formals_types_with_return'].extend(option['formals_types'])
    # Method variants drop 'self' (it becomes *this).
    option['method_formals'] = [format_formal(f) for f in formals
                                if f['name'] != 'self']
    option['method_formals_with_defaults'] = (
        [formal_with_default(f) for f in formals if f['name'] != 'self'])
    # *this is 'const Tensor&' since all Tensor methods are const and must
    # be const_casted to be accepted as native function's non-const argument
    option['method_actuals'] = [
        f['name'] if f['name'] != 'self' else 'const_cast<Tensor&>(*this)' for f in formals]
def find_formal(formal_name, formals):
    """Return the first formal whose 'dynamic_type' equals `formal_name`,
    or None when no formal matches."""
    matches = (formal for formal in formals
               if formal['dynamic_type'] == formal_name)
    return next(matches, None)
def gen_tensor_method(option, multidispatch_tensors):
    # type: (Any, List[str]) -> FunctionCode
    """Generate the Tensor method declaration/definition for this option.

    Substitutes a static-dispatch body (a backend switch when the dispatch
    is a per-backend dict, otherwise a default call) into the method
    templates.  Uses `type_method_dispatch` from the enclosing scope.
    """
    def swizzle_self(t):  # blegh
        # In a method body, the 'self' actual is spelled *this.
        if t == 'self':
            return '*this'
        else:
            return t
    option['inferred_key_set'] = 'c10::detail::multi_dispatch_key_set({})'.format(
        ', '.join(swizzle_self(t) for t in multidispatch_tensors)
    )
    if isinstance(type_method_dispatch, dict):
        static_dispatch_function_switches = []
        # NB: As this code is currently written, there will NEVER be
        # a backend generated for variable dispatch. There is nothing
        # stopping us from actually implementing this, however, if you
        # really wanted variable on mobile, there's nothing stopping
        # you from implementing this (however, you would have an
        # annoying phase problem, since code generation for variable
        # happens in tools/ which happens later than here.)
        #
        # If you pass in a variable to the dispatch, and variable is
        # enabled, this switch will fail. This is intentional: you
        # probably need to disable variable globally in the mobile
        # calling code.
        for backend in static_dispatch_backends:
            if backend in type_method_dispatch:
                static_dispatch_function_switches.append(STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT.substitute(
                    option,
                    backend=backend,
                    backend_function=type_method_dispatch[backend],
                    native_arguments=option['method_actuals']))
        static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
            option,
            key_set='key_set()',
            static_dispatch_function_switches=static_dispatch_function_switches)
    else:
        static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute(
            option, native_arguments=option['method_actuals'])
    method_definition = C10_TENSOR_METHOD_DEFINITION
    return FunctionCode(
        declaration=TENSOR_METHOD_DECLARATION.substitute(
            option, static_dispatch_method_body=static_dispatch_method_body),
        definition=method_definition.substitute(
            option, static_dispatch_method_body=static_dispatch_method_body))
def gen_namespace_function(option, multidispatch_tensors):
    # type: (Any, List[str]) -> FunctionCode
    """Generate the at:: namespace function declaration/definition.

    Mirrors gen_tensor_method but uses native_actuals (no *this swizzle)
    and a factory-specific definition template for factory methods.
    NB: relies on `type_method_dispatch` and `is_factory_method` from the
    enclosing scope; the latter is assigned after this def but before any
    call.
    """
    option['inferred_key_set'] = (
        'c10::detail::multi_dispatch_key_set({})'.format(', '.join(multidispatch_tensors)))
    declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
    fn_declaration = declaration.substitute(option)
    if isinstance(type_method_dispatch, dict):
        static_dispatch_function_switches = []
        for backend in static_dispatch_backends:
            if backend in type_method_dispatch:
                static_dispatch_function_switches.append(STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT.substitute(
                    option,
                    backend=backend,
                    backend_function=type_method_dispatch[backend],
                    native_arguments=option['native_actuals']))
        static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
            option,
            key_set=option['inferred_key_set'],
            static_dispatch_function_switches=static_dispatch_function_switches)
    else:
        static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute(
            option, native_arguments=option['native_actuals'])
    if is_factory_method:
        fn_definition = C10_FACTORY_DEFINITION.substitute(
            option, static_dispatch_function_body=static_dispatch_function_body)
    else:
        fn_definition = C10_FUNCTION_DEFINITION.substitute(
            option, static_dispatch_function_body=static_dispatch_function_body)
    return FunctionCode(definition=fn_definition, declaration=fn_declaration)
assert find_formal('Type', formals) is None, \
"Found Type argument in {}({}). Use TensorOptions instead.".format(
option['name'], ", ".join(option['method_formals_with_defaults']))
type_method_dispatch = option['type_method_definition_dispatch']
multidispatch_tensors = find_multidispatch_tensors(formals)
option['type_method_formals'] = [format_formal(f) for f in formals]
option['type_method_actuals'] = [f['name'] for f in formals]
option['native_actuals'] = [f['name'] for f in formals]
is_method = 'method' in option['variants']
is_namespace_function = 'function' in option['variants']
# For method-only entries, the first argument should be self
if is_method and not is_namespace_function:
assert formals[0]['name'] == 'self'
is_factory_method = find_formal('TensorOptions', formals) and 'method' not in option['variants']
check_methods_do_not_start_with_underscore(option['name'], is_method)
option['method_prefix_derived'] = ''
# NB: Device guard and scalar type generated code is still based on the
# first argument. Scalar type test will be removed once TH is removed.
# If you need more complex device guard behavior, you should disable
# device guard and then manually add the guards you need.
dispatch_options = find_formal('TensorOptions', formals)
guard_tensor = None if dispatch_options else find_dispatch_tensor(formals)
option['device_guard_declaration'] = device_guard(option, dispatch_options, guard_tensor)
option['named_guard_declaration'] = named_guard(option, find_tensors(formals),
find_tensorlists(formals))
option['dispatch_scalar_type_declaration'] = dispatch_scalar_type(option, dispatch_options, guard_tensor)
broadcast_arg = get_broadcast_argument(option)
if broadcast_arg is not None:
raise Exception("broadcasting is not yet supported for native functions, "
"but specified for function {}", option['name'])
top_env['list_of_aten_ops'].append(OPERATOR_NAME_FULL.substitute(option))
option['native_type_method_dispatch'] = type_method_dispatch
# Note [Abstract ATen methods]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# An abstract ATen method is one whose dispatch differs between
# types. These are implemented in derived types (with a
# standard (throwing) definition in Type). A concrete ATen
# method is one which has the same dispatch for all types;
# we just implement it in the base Type. This is exposed
# in Declarations.yaml via a field named 'abstract'.
abstract = False
if isinstance(type_method_dispatch, dict):
abstract = True
# Having manual_kernel_registration for an abstract method doesn't make sense.
assert not option['manual_kernel_registration']
else:
top_env['type_method_declarations'].append(NATIVE_DISPATCH_DECLARATION.substitute(option))
top_env['type_method_definitions'].append(NATIVE_DISPATCH_DEFINITION_DEFAULT.substitute(option))
if option['manual_kernel_registration']:
op_registrations.append(OpRegistration(
operator_name=OPERATOR_NAME.substitute(option),
registration_code=DEFAULT_SCHEMA_REGISTRATION.substitute(option)))
else:
if option['use_c10_dispatcher'] == 'full':
op_registrations.append(OpRegistration(
operator_name=OPERATOR_NAME.substitute(option),
registration_code=DEFAULT_FUNCTION_REGISTRATION.substitute(option)))
else:
assert option['use_c10_dispatcher'] == 'unboxed_only'
op_registrations.append(OpRegistration(
operator_name=OPERATOR_NAME.substitute(option),
registration_code=DEFAULT_UNBOXEDONLY_FUNCTION_REGISTRATION.substitute(option)))
# generate the at::native function declarations (i.e. what the user will implement)
if isinstance(type_method_dispatch, dict):
generated_native_functions = [] # type: List[str]
for key in sorted(type_method_dispatch.keys()):
value = type_method_dispatch[key]
# skip functions in different namespace, e.g. legacy::cpu
if "::" in value:
continue
if value not in generated_native_functions:
option['native_type_method_dispatch'] = value
top_env['native_function_declarations'].append(NATIVE_DECLARATION.substitute(option))
generated_native_functions.append(value)
else:
top_env['native_function_declarations'].append(NATIVE_DECLARATION.substitute(option))
method_of = ['Type']
if is_method:
code = gen_tensor_method(option, multidispatch_tensors)
top_env['tensor_method_declarations'].append(code.declaration)
top_env['tensor_method_definitions'].append(code.definition)
method_of.append('Tensor')
if is_namespace_function:
code = gen_namespace_function(option, multidispatch_tensors)
top_env['function_definitions'].append(code.definition)
top_env['function_declarations'].append(code.declaration)
method_of.append('namespace')
return OutputDeclaration(
name=option['api_name'],
operator_name=option['operator_name'],
overload_name=option['overload_name'],
use_c10_dispatcher=option['use_c10_dispatcher'],
manual_kernel_registration=option['manual_kernel_registration'],
category_override=option['category_override'],
matches_jit_signature=option["matches_jit_signature"],
schema_string=option["schema_string"],
method_prefix_derived=option['method_prefix_derived'],
arguments=formals,
method_of=method_of,
mode=option['mode'],
python_module=option['python_module'],
buffers=None,
returns=option['returns'],
inplace=option['inplace'],
is_factory_method=is_factory_method,
# See Note [Abstract ATen methods]
abstract=abstract,
requires_tensor=option.get('requires_tensor', False),
device_guard=option.get('device_guard', True),
with_gil=option.get('with_gil', False),
deprecated=option['deprecated'],
)
output_declarations = []  # type: List[OutputDeclaration]
op_registrations = []  # type: List[OpRegistration]
for declaration in declarations:
    output_options = []  # type: List[OutputDeclaration]
    for option in declaration['options']:
        option["matches_jit_signature"] = declaration["matches_jit_signature"]
        option["schema_string"] = declaration["schema_string"]
        try:
            if option['mode'] != 'native':
                # Mutably populate option with values
                process_legacy_th_option(option)
            else:
                output_option = process_native(option)
                if output_option:
                    output_options.append(output_option)
        except NYIError:
            # Not yet implemented: mark so later (derived) passes skip it.
            option['skip'] = True
    output_declarations.extend(output_options)
return output_declarations, op_registrations
def create_derived(backend_type_env, declarations):
    # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str], List[OpRegistration], List[str], List[str]]
    """Generate backend-specific ("derived") code for one backend/type env.

    Returns (type object declarations, type object definitions,
    op registrations, legacy TH declarations, legacy TH definitions),
    accumulated into the lists below by the nested helpers.
    """
    type_object_declarations = []  # type: List[str]
    type_object_definitions = []  # type: List[str]
    op_registrations = []  # type: List[OpRegistration]
    legacy_th_declarations = []  # type: List[str]
    legacy_th_definitions = []  # type: List[str]
    # Whether this backend environment targets CUDA (affects TH naming/state).
    is_cuda = 'CUDA' in backend_type_env['Backend']
def requires_checked_cast(argument):
    # type: (THFormal) -> bool
    """Return True when this formal needs a generated checked cast."""
    arg_type = argument['type']
    if arg_type == 'IntArrayRef':
        # IntArrayRef is only checked when an explicit 'size' is declared.
        return 'size' in argument
    return arg_type in CHECKED_CAST
def nullable_argument(argument):
    # type: (THFormal) -> bool
    """Return the formal's 'is_nullable' flag (False when absent)."""
    if 'is_nullable' in argument:
        return argument['is_nullable']
    return False
def get_argument(env, argument, option):
    # type: (Environment, THFormal, FunctionOption) -> str
    """Render the C++ actual-argument expression for one TH call argument."""
    if requires_checked_cast(argument):
        # Checked-cast arguments are referred to via their 'name_' alias
        # (or a type-specific usage pattern from CHECKED_USE).
        checked_use = CHECKED_USE.get(
            argument['type'], '{}_').format(argument['name'])
        if nullable_argument(argument):
            # Wrap usage so null maps to an undefined/null TH argument.
            checked_use = CHECKED_USE_NULLABLE.substitute(
                env={}, arg_name=argument['name'], usage=checked_use)
        return checked_use
    elif argument['type'] == 'CONSTANT':
        # Constants may contain placeholder patterns that are rewritten
        # before being substituted as a template against env.
        v = str(argument.get('default', argument['name']))
        for pattern, replacement in CONSTANT_REPLACEMENTS:
            v = re.sub(pattern, replacement, v)
        return CodeTemplate(v).substitute(env)
    # e.g. argument 0, i.e. repeat the 0th argument in this position...
    elif argument['type'] == 'argument':
        index = int(argument['name'])
        return get_argument(env, option['arguments'][index], option)
    else:
        return argument['name']
def drop_argument(argument, option):
    # type: (THFormal, FunctionOption) -> bool
    """Return True when the formal should not be forwarded to the TH call.

    Devices are handled in the body of the function, so the 'device'
    formal is dropped from the generated argument list.
    """
    return argument['name'] == 'device'
def get_arguments(env, arguments, option):
    # type: (Environment, List[THFormal], FunctionOption) -> List[str]
    """Render all non-dropped formals into C++ actual expressions, in order."""
    rendered = []
    for argument in arguments:
        if drop_argument(argument, option):
            continue
        rendered.append(get_argument(env, argument, option))
    return rendered
def is_actual_return_long(env, ret):
    # type: (Environment, ReturnDecl) -> bool
    """Return True when the declared TH return type is 'long' for this
    scalar-type environment ('real'/'accreal' depend on the generated
    scalar name)."""
    ret_type = ret['type']
    if ret_type == 'long':
        return True
    scalar_key = {'real': 'ScalarName', 'accreal': 'AccScalarName'}.get(ret_type)
    return scalar_key is not None and env[scalar_key] == 'Long'
def handle_zero_dim(env, option):
    # type: (Environment, FunctionOption) -> List[str]
    """Emit the zero-dim-scalar dispatch check lines (possibly empty)."""
    zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '')
    if not zero_dim_dispatch:
        return []
    # NOTE(review): with the '' default this is a substring test rather
    # than a list-membership test when 'broadcast_actuals' is absent;
    # harmless here since the guard above rules out empty names — confirm.
    broadcasts_arg = zero_dim_dispatch in option.get('broadcast_actuals', '')
    # if the argument broadcasts, then this would only affect cases where all broadcasted
    # tensors were zero-dim, which is inconsistent with the scalar handling.
    if broadcasts_arg:
        return []
    # Replace the dispatching tensor actual with its .item() scalar value.
    zero_dim_actuals = [arg['name']
                        if arg['name'] != zero_dim_dispatch else "{}.item()".format(arg['name'])
                        for arg in option['formals_list']]
    return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
def allocate_arg(env, arg, output_count):
    # type: (Environment, THFormal, int) -> List[str]
    """Emit C++ lines allocating a fresh TensorImpl for an output formal
    and wrapping it in a Tensor ('name_' is the impl, 'name' the Tensor)."""
    name = arg['name']
    allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(env)
    tensor_arg = '{}_'.format(name)
    if arg.get('mask', False):
        # Masked outputs are only allocated when requested via output_mask;
        # otherwise the undefined-tensor singleton is used.
        allocation = 'output_mask[{}] ? {} : nullptr'.format(output_count, allocation)
        tensor_arg = ('{}_ == nullptr ? (TensorImpl*)UndefinedTensorImpl::singleton() : (TensorImpl*){}_'
                      .format(name, name))
    intrusive_ptr_type = 'c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>'
    return [
        'auto {}_ = {};'.format(name, allocation),
        # reclaim: take ownership of the raw impl pointer without bumping refcount.
        'auto {} = Tensor({}::reclaim({}));'.format(name, intrusive_ptr_type, tensor_arg),
    ]
def resize_arg(arg):
    # type: (THFormal) -> str
    """Emit a C++ statement resizing `arg` per its 'resize' spec.

    The spec is either the name of another tensor (resize to its sizes)
    or a list of (tensor_name, dim) pairs forming an explicit size list;
    with 'resize_scalar', 0-dim tensors contribute a size of 1.
    """
    resize = arg['resize']
    if isinstance(resize, str):
        return "{}.resize_({}.sizes());".format(arg['name'], resize)
    if arg.get('resize_scalar', False):
        dims = ['{}.dim() == 0 ? 1 : {}.size({})'.format(name, name, dim)
                for name, dim in resize]
    else:
        dims = ['{}.size({})'.format(name, dim) for name, dim in resize]
    return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims))
def handle_call(env, option, cimpl):
    # type: (Environment, FunctionOption, FunctionOption) -> str
    """Render the call expression to the underlying TH/THC/THNN C function."""
    is_nn = option['mode'] == 'NN'
    actuals = get_arguments(env, cimpl['arguments'], option)
    if is_cuda or is_nn:
        # CUDA and NN kernels take the global THC state as their first argument.
        actuals = ['globalContext().getTHCState()'] + actuals
    cname = cimpl['cname']
    # Build the fully-qualified C function name by library/backend prefix.
    if option.get('sparse', False):
        if is_cuda:
            cname = 'THCS' + env['ScalarName'] + "Tensor_" + cname
        else:
            cname = env['THTensor'].replace('TH', 'THS') + '_' + cname
    elif is_nn:
        cname = 'THNN_{}'.format(env['THType']) + cname
    else:
        cname = env['THTensor'] + '_' + cname
    call = CALL_TEMPLATE.substitute(actuals=actuals, cname=cname)
    if cimpl.get('condition') is not None:
        # Guard conditional cimpls with their declared condition.
        call = 'if ({}) {}'.format(cimpl['condition'], call)
    return call
def emit_body(env, option, scalar_type_cases):
    # type: (Environment, FunctionOption, List[str]) -> List[str]
    """Emit the body of a legacy TH wrapper: a scalar-type switch where
    each case checked-casts/allocates its arguments, calls the underlying
    TH function(s), and returns per the declaration's return kind."""
    body = []  # type: List[str]
    body += handle_zero_dim(env, option)
    cases = []
    for scalar_name, c_type, accreal, _ in scalar_types:
        if scalar_name in scalar_type_cases:
            case_body = []
            # arguments are potentially duplicated because of one argument
            # referencing another
            seen_names = set()  # type: Set[str]
            seen_tensorlists = set()  # type: Set[str]
            count = 0
            output_count = 0
            # Per-scalar-type environment used by the case templates.
            case_env = {
                'Backend': env['Backend'],
                'DeviceType': env['DeviceType'],
                'state': env['state'],
                'ScalarType': c_type,
                'ScalarName': scalar_name,
                'AccScalarName': accreal,
                'THType': scalar_name,
                'THTensor': 'TH{}Tensor'.format(scalar_name)
            }  # type: Environment
            if case_env['Backend'] == 'CUDA':
                # CUDA TH types omit the 'Float' suffix (THCudaTensor).
                sname = '' if scalar_name == "Float" else scalar_name
                case_env['THType'] = 'Cuda{}'.format(sname)
                case_env['THTensor'] = 'THCuda{}Tensor'.format(sname)
            for arg in option['arguments']:
                if is_real_argument_to_wrapper(arg):
                    count += 1
                if arg['type'] == 'TensorList':
                    seen_tensorlists.add(arg['name'])
                wrap_dim_target = arg.get('wrap_dim', None)
                if wrap_dim_target is not None:
                    # for Tensors, "name_" is the TensorImpl, but for TensorLists, it is an
                    # std::vector of TH*s. Since TH*s have different dimension rules, we used
                    # "name" instead, but keep "name_" for tensor to avoid an extra function call.
                    if wrap_dim_target not in seen_tensorlists:
                        wrap_dim_target = wrap_dim_target + "_"
                    case_body.append("{} = maybe_wrap_dim({}, {});".format(
                        arg['name'], arg['name'], wrap_dim_target))
                # only generated checked casts the first time we see it
                if arg['name'] not in seen_names and requires_checked_cast(arg):
                    seen_names.add(arg['name'])
                    # make a new allocation of TensorImpl, then wrap a Tensor around it.
                    if arg.get('allocate', False):
                        case_body += allocate_arg(case_env, arg, output_count)
                        output_count += 1
                    # extract the TensorImpl from an existing tensor (or Storage, etc.)
                    else:
                        # special case where we allow undefined Tensors, and thus
                        # the checked cast succeeds even if the Tensor is not
                        # defined
                        null_okay = 'true' if nullable_argument(arg) else 'false'
                        check_cast = CHECKED_CAST[arg['type']].substitute(
                            case_env, arg_name=arg['name'], arg_pos=count,
                            api_name=option['api_name'], null_okay=null_okay,
                            size=arg.get('size'))
                        case_body.append("auto {}_ = {};".format(
                            arg['name'], check_cast))
                    if drop_argument(arg, option):
                        case_body.append(
                            "(void) {}_; //silence unused warning".format(arg['name']))
                    initializers = []
                    # resize tensors for special ops that require it
                    if 'resize' in arg:
                        initializers.append(resize_arg(arg))
                    # also special handling where we zero some outputs.
                    if arg.get('zero', False) or (arg.get('cpu_zero', False) and not is_cuda):
                        initializers.append("{}.zero_();".format(arg['name']))
                    # only initialize non-null arguments
                    if nullable_argument(arg) and len(initializers) > 0:
                        case_body.append(CONDITIONAL_INITIALIZER.substitute({
                            'name': arg['name'],
                            'initializer': initializers
                        }))
                    else:
                        case_body += initializers
            # cimpls, if it exists, contains the underlying C function names and
            # arguments. Otherwise use option
            cimpls = option.get('cimpls', [option])
            calls = [handle_call(case_env, option, cimpl) for cimpl in cimpls]
            ret = option['return']
            if ret['kind'] == 'arguments':
                # Return (a tuple of) the designated argument tensors.
                case_body.extend([call + ';' for call in calls])
                arguments_indices = ret['arguments']
                arguments = [option['arguments'][argi]
                             for argi in arguments_indices]
                if len(arguments_indices) == 1:
                    arg = arguments[0]
                    case_body.append("return {};".format(arg['name']))
                else:
                    types = [to_return_type(arg, option)['type']
                             for arg in arguments]
                    # TODO: check for move semantics...
                    names = [arg['name'] for arg in arguments]
                    case_body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
                        types=types, names=names))
            elif ret['kind'] == 'type':
                assert len(calls) == 1
                call = calls[0]
                if ret['type'] in ALLOC_WRAP.keys():
                    # Wrap the raw TH return in a Tensor via reclaim.
                    wrapped_tensor = CodeTemplate(ALLOC_WRAP[ret['type']]).substitute(
                        case_env, arguments=[call])
                    return_tensor = (
                        "return Tensor(" +
                        "c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(" +
                        "(${wrapped_tensor})));")
                    case_body.append(CodeTemplate(return_tensor).substitute(
                        case_env, wrapped_tensor=wrapped_tensor))
                # return the same underlying Tensor type for both real and accreal; this ensures
                # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
                # ScalarType before constructing the scalar_tensor to avoid overflow checking.
                elif ret['type'] == 'accreal' or ret['type'] == 'real':
                    return_scalar = ('return at::scalar_tensor(convert<${ScalarType}>(${call}), '
                                     'options(ScalarType::${ScalarName}));')
                    case_body.append(CodeTemplate(return_scalar).substitute(case_env, call=call))
                else:
                    # we using int64_t for long in the API, so correct it here...
                    if is_actual_return_long(case_env, ret):
                        call = "static_cast<int64_t>({})".format(call)
                    case_body.append("return {};".format(call))
            else:
                raise Exception("NYI - return handling")
            cases.append(LEGACY_TH_DEFINITION_CASE.substitute(case_env, case_body=case_body))
    body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases))
    return body
def process_legacy_th_option(option):
    # type: (FunctionOption) -> None
    """Emit the legacy TH wrapper declaration/definition for this backend.

    Appends to the enclosing legacy_th_declarations/definitions lists;
    no-op when this backend is not listed in the option's backend_types.
    """
    backend = backend_type_env['Backend']
    if backend in option['backend_types']:
        env = nested_dict(option, backend_type_env)
        body = emit_body(env, option, option['backend_types'][backend])  # type: ignore
        option['type_definition_body'] = body
        if option.get('broadcast_actuals', None):
            # Broadcasting wrapper is emitted in addition to the plain one.
            legacy_th_declarations.append(
                LEGACY_TH_DECLARATION_BROADCAST.substitute(env))
            legacy_th_definitions.append(
                LEGACY_TH_DEFINITION_BROADCAST.substitute(env))
        legacy_th_declarations.append(
            LEGACY_TH_DECLARATION.substitute(env))
        legacy_th_definitions.append(
            LEGACY_TH_DEFINITION.substitute(env))
def process_native(option):
    # type: (FunctionOption) -> None
    """Emit backend dispatch declarations/definitions/registrations for a
    native function whose dispatch table names this backend."""
    dispatch = option['type_method_definition_dispatch']
    env = nested_dict(option, backend_type_env)
    if isinstance(dispatch, dict):
        # If we're here, then our native_functions.yaml entry has dispatch configuration.
        # Having manual kernel registration doesn't make sense.
        assert not option['manual_kernel_registration']
        backend = backend_type_env['Backend']
        if backend in option['backend_types']:
            native_dispatch = dispatch.get(backend)
            type_object_declarations.append(
                NATIVE_DISPATCH_DECLARATION.substitute(env))
            option['native_type_method_dispatch'] = native_dispatch
            type_object_definitions.append(
                NATIVE_DISPATCH_DEFINITION_BACKEND.substitute(env))
            if native_dispatch:
                # Register the kernel with the dispatcher, boxed or unboxed-only.
                if option['use_c10_dispatcher'] == 'full':
                    op_registrations.append(OpRegistration(
                        operator_name=OPERATOR_NAME.substitute(option),
                        registration_code=BACKEND_FUNCTION_REGISTRATION.substitute(env)))
                else:
                    assert option['use_c10_dispatcher'] == 'unboxed_only'
                    op_registrations.append(OpRegistration(
                        operator_name=OPERATOR_NAME.substitute(option),
                        registration_code=BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION.substitute(env)))
for declaration in declarations:
    for option in declaration['options']:
        # Options flagged 'skip' by the generic pass are not processed.
        if not option.get('skip', False):
            try:
                # NN options without cimpls have nothing to generate here.
                if option['mode'] == 'NN' and option.get('cimpls') is None:
                    continue
                if option['mode'] != 'native':
                    process_legacy_th_option(option)
                else:
                    process_native(option)
            except NYIError:
                # Not-yet-implemented combinations are silently skipped.
                pass
return (type_object_declarations, type_object_definitions, op_registrations,
        legacy_th_declarations, legacy_th_definitions)