| # HEY! Trying to understand what this file does? Read |
| # "what has to be done to add a Operation ..." first! |
| |
| import re |
| from code_template import CodeTemplate |
| |
| try: |
| import typing # noqa: F401 |
| except ImportError: |
| raise RuntimeError( |
| 'Missing build dependency: Unable to import the `typing` module. ' |
| 'Please install it via `conda install typing` or `pip install typing`') |
| |
| # flake8 doesn't take into account usages in type annotations. |
| from typing import Union, Set # noqa: F401 |
| from typing import Any, Dict, List, Optional, Tuple, NamedTuple |
| |
| try: |
| from mypy_extensions import TypedDict |
| except ImportError: |
| # Avoid the dependency on the mypy_extensions package. |
| # It is required, however, for type checking. |
| def TypedDict(name, attrs, total=True): # type: ignore |
| return Dict[Any, Any] |
| |
| import sys |
| if sys.version_info[0] == 3: |
| string_type = str |
| else: |
| string_type = basestring |
| |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # |
| # what has to be done to add an Operation ... |
| # |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # |
| # 1. if the function broadcasts its arguments or can be called without the |
| # full list of arguments, add a non-virtual declaration under Type.h |
| # (right now, we call this template BROADCAST but it also handles |
| # default arguments) |
| TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\ |
| ${return_type} ${api_name}(${type_method_formals_with_defaults}) const; |
| """) |
| # 2. broadcasting functions are implemented in Type.cpp |
| TYPE_METHOD_DEFINITION_BROADCAST = CodeTemplate("""\ |
| ${return_type} Type::${api_name}(${type_method_formals}) const { |
| ${device_guard_declaration} |
| Tensor ${broadcast_returns}; |
| std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}"); |
| return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals}); |
| } |
| """) |
| # 3. add a virtual dispatch declaration to Type.h and an impl to Type.cpp; |
| # method_prefix_derived lets the base class provide the definition for a |
| # derived-type method whose name carries a prefix (e.g. the "s_" prefix |
| # used by the broadcasting wrappers). |
| # |
| # If the declaration is abstract, then the actual implementation will |
| # be in a derived type; we put in a simple default "not implemented" |
| # stub. However, if the declaration is concrete, we dispatch to the |
| # actual implementation. At the moment, this situation *only* occurs |
| # for 'native' declarations (so the native dispatch is hardcoded into |
| # the template here.) |
| TYPE_METHOD_DECLARATION_ABSTRACT = CodeTemplate("""\ |
| virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals_with_defaults}) const; |
| """) |
| TYPE_METHOD_DEFINITION_ABSTRACT = CodeTemplate("""\ |
| ${return_type} Type::${method_prefix_derived}${api_name}(${type_method_formals}) const { |
| AT_ERROR("${method_prefix_derived}${api_name} is not implemented for type ", toString()); |
| } |
| """) |
| TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate("""\ |
| virtual ${return_type} ${api_name}(${type_method_formals_with_defaults}) const; |
| """) |
| DEPRECATED_TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate("""\ |
| AT_DEPRECATED(virtual ${return_type} ${api_name}(${type_method_formals_with_defaults}) const); |
| """) |
| TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\ |
| ${return_type} Type::${api_name}(${type_method_formals}) const { |
| ${device_guard_declaration} |
| ${type_definition_body} |
| } |
| """) |
| DEPRECATED_TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\ |
| ${return_type} Type::${api_name}(${type_method_formals}) const { |
| TensorOptions options(*this); |
| ${device_guard_declaration} |
| return at::native::${api_name}(${type_method_actuals}, options); |
| } |
| """) |
| # 4. add virtual override to TypeDerived.h |
| TYPE_DERIVED_DECLARATION = CodeTemplate("""\ |
| virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override; |
| """) |
| # 5. add override definition to TypeDerived.cpp |
| TYPE_DERIVED_DEFINITION = CodeTemplate("""\ |
| ${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const { |
| ${device_guard_declaration} |
| ${type_definition_body} |
| } |
| """) |
| # NB: As far as ezyang can tell, we don't *have* to codegen this, |
| # because we will inherit it from the TYPE_METHOD_DEFINITION_CONCRETE in |
| # the superclass. But it doesn't seem to be harmful. |
| TYPE_DERIVED_DEFINITION_NATIVE = CodeTemplate("""\ |
| ${return_type} ${Type}::${api_name}(${type_method_formals}) const { |
| ${device_guard_declaration} |
| const auto& self_ty = *this; |
| (void)self_ty; |
| ${return_call} at::native::${native_type_method_dispatch}(/* actuals */ ${actuals}); |
| } |
| """) |
| TYPE_DERIVED_DEFINITION_NATIVE_MISSING = CodeTemplate("""\ |
| ${return_type} ${Type}::${api_name}(${type_method_formals}) const { |
| AT_ERROR("${api_name} not supported on ${Type}"); |
| } |
| """) |
| TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\ |
| ${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals}); |
| """) |
| |
| # add non-virtual declaration to Tensor.h |
| TENSOR_METHOD_DECLARATION = CodeTemplate("""\ |
| ${return_type} ${api_name}(${method_formals_with_defaults})${const_mark}; |
| """) |
| # add non-virtual definition to Tensor.cpp |
| TENSOR_METHOD_DEFINITION = CodeTemplate("""\ |
| inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} { |
| return type().${api_name}(${method_actuals}); |
| } |
| """) |
| # add a function declaration in Functions.h |
| FUNCTION_DECLARATION = CodeTemplate("""\ |
| static inline ${return_type} ${api_name}(${formals_with_defaults}); |
| """) |
| # add a *deprecated* function declaration in Functions.h |
| DEPRECATED_FUNCTION_DECLARATION = CodeTemplate("""\ |
| AT_DEPRECATED(static inline ${return_type} ${api_name}(${formals_with_defaults})); |
| """) |
| # add a function definition in Functions.h |
| FUNCTION_DEFINITION = CodeTemplate("""\ |
| static inline ${return_type} ${api_name}(${formals}) { |
| return ${inferred_type}.${api_name}(${type_method_actuals}); |
| } |
| """) |
| # add a native declaration for a native function |
| NATIVE_DECLARATION = CodeTemplate("""\ |
| AT_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults}); |
| """) |
| |
| # special definition for factory functions in Functions.h |
| FACTORY_DEFINITION = CodeTemplate("""\ |
| static inline ${return_type} ${api_name}(${formals}) { |
| const DeviceGuard guard(options.device()); |
| return at::native::${api_name}(${type_method_actuals}); |
| } |
| """) |
| |
| # special definition for *deprecated* factory functions in Functions.h |
| DEPRECATED_FACTORY_DEFINITION = CodeTemplate("""\ |
| static inline ${return_type} ${api_name}(${formals}) { |
| return at::${api_name}(${type_method_actuals}, TensorOptions(${inferred_type})); |
| } |
| """) |
| |
| # We need to cast to the base type because C++ may hide the base class |
| # implementation of ${api_name} if we have overloaded a function with |
| # the same name (but different signature) already |
| ZERO_DIM_CHECK = CodeTemplate("""\ |
| if (${check_name}.dim() == 0) { |
| return static_cast<const Type*>(this)->${api_name}(${zero_dim_actuals}); |
| }""") |
| |
| ZERO_DIM_ONLY = CodeTemplate("""\ |
| AT_ERROR("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor " |
| "with ", ${check_name}.dim(), " dimension(s)."); |
| """) |
| |
| SPARSE_CHECK = CodeTemplate("""\ |
| if(${check_name}.type().is_sparse()) { |
| return static_cast<const Type*>(this)->${api_name}(${sparse_actuals}); |
| }""") |
| |
| BUFFER_DEFINITION = CodeTemplate("""\ |
| auto ${name}_ = new ${Tensor}(${THTensor}_new()); |
| auto ${name} = Tensor(${name}_, false);""") |
| |
| CONDITIONAL_INITIALIZER = CodeTemplate("""\ |
| if (${name}.defined()) { |
| ${initializer} |
| }""") |
| |
| CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})") |
| |
| HALF_CONVERSION = CodeTemplate("convert<half>(${value})") |
| |
| |
| class NYIError(Exception): |
| """Indicates we don't support this declaration yet""" |
| |
| def __init__(self, reason): |
| self.reason = reason |
| |
| |
| TYPE_FORMAL_GENERIC = { |
| 'THTensor*': 'Tensor &', |
| 'THSTensor*': 'SparseTensorRef', |
| 'THBoolTensor*': 'Tensor &', |
| 'THIndexTensor*': 'Tensor &', |
| 'THIntegerTensor*': 'Tensor &', |
| 'THDenseTensor*': 'Tensor &', |
| 'THDenseIndexTensor*': 'Tensor &', |
| 'THStorage*': 'Storage &', |
| 'THGenerator*': 'Generator *', |
| 'IntListSize': 'IntList', |
| 'IntListStride': 'IntList', |
| 'accreal': 'Scalar', |
| 'real': 'Scalar', |
| 'long': 'int64_t', |
| } |
| |
| DYNAMIC_TYPE = { |
| 'THTensor*': 'Tensor', |
| 'THSTensor*': 'SparseTensorRef', |
| 'THBoolTensor*': 'BoolTensor', |
| 'THIndexTensor*': 'IndexTensor', |
| 'THIntegerTensor*': 'IntegerTensor', |
| 'THDenseTensor*': 'Tensor', |
| 'THDenseIndexTensor*': 'IndexTensor', |
| 'THStorage*': 'Storage', |
| 'THGenerator*': 'Generator*', |
| 'IntListSize': 'IntList', |
| 'IntListStride': 'IntList', |
| 'accreal': 'accreal', |
| 'real': 'real', |
| 'long': 'int64_t', |
| } |
| |
| NATIVE_DYNAMIC_TYPE = { |
| 'Tensor &': 'Tensor', |
| 'const Tensor &': 'Tensor', |
| } |
| |
| TYPE_RETURN = { |
| 'THTensor*': 'Tensor', |
| 'THIndexTensor*': 'Tensor', |
| 'THBoolTensor*': 'Tensor', |
| 'THIntegerTensor*': 'Tensor', |
| 'THSTensor*': 'Tensor', |
| 'THDenseTensor*': 'Tensor', |
| 'THDenseIndexTensor*': 'Tensor', |
| 'real': 'Tensor', |
| 'accreal': 'Tensor', |
| 'long': 'int64_t', |
| } |
| |
| CHECKED_CAST = { |
| 'THTensor*': |
| CodeTemplate( |
| 'checked_cast_tensor<${Tensor}>(' |
| '${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay}, ' |
| 'Backend::${Backend}, ScalarType::${ScalarName})'), |
| 'THSTensor*': |
| CodeTemplate( |
| 'checked_cast_tensor<Sparse${Tensor}>(' |
| '${arg_name}.tref.pImpl,"${arg_name}",${arg_pos},false, ' |
| 'Backend::${Backend}, ScalarType::${ScalarName})'), |
| 'THBoolTensor*': |
| CodeTemplate( |
| 'checked_cast_tensor<${Backend}ByteTensor>(' |
| '${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay}, ' |
| 'Backend::${Backend}, ScalarType::Byte)'), |
| 'THIndexTensor*': |
| CodeTemplate( |
| 'checked_cast_tensor<${Backend}LongTensor>(' |
| '${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay}, ' |
| 'Backend::${Backend}, ScalarType::Long)'), |
| 'THIntegerTensor*': |
| CodeTemplate( |
| 'checked_cast_tensor<${Backend}IntTensor>(' |
| '${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay}, ' |
| 'Backend::${Backend}, ScalarType::Int)'), |
| 'THDenseTensor*': |
| CodeTemplate( |
| 'checked_cast_tensor<${DenseTensor}>(' |
| '${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay}, ' |
| 'Backend::${DenseBackend}, ScalarType::${ScalarName})'), |
| 'THDenseIndexTensor*': |
| CodeTemplate( |
| 'checked_cast_tensor<${DenseBackend}LongTensor>(' |
| '${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay}, ' |
| 'Backend::${DenseBackend}, ScalarType::Long)'), |
| 'THStorage*': |
| CodeTemplate( |
| 'checked_cast_storage<Storage>(' |
| '&${arg_name},"${arg_name}",${arg_pos}, ' |
| 'Backend::${Backend}, ScalarType::${ScalarName})'), |
| 'THGenerator*': |
| CodeTemplate( |
| 'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(backend()))'), |
| # This is a cast done via direct-construction |
| 'IntListSize': CodeTemplate('at::IntList ${result_name} = get_intlist_size_th(${arg_name});'), |
| 'IntListStride': CodeTemplate('at::IntList ${result_name} = get_intlist_stride_th(${arg_name});'), |
| 'real': CodeTemplate('${arg_name}.to${ScalarName}()'), |
| 'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'), |
| 'TensorList': CodeTemplate( |
| 'tensor_list_checked_cast<${Tensor}, Tensor, ' |
| '${THTensor}>(${arg_name},"${arg_name}",${arg_pos}, ' |
| 'Backend::${Backend}, ScalarType::${ScalarName})'), |
| 'IntList': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})') |
| } |
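| # For illustration (hypothetical values): CHECKED_CAST['THTensor*'] with |
| # arg_name "self", arg_pos 2, null_okay "false", and an environment that |
| # maps ${Tensor} to CPUFloatTensor, Backend to CPU and ScalarName to Float |
| # expands to roughly |
| # |
| #   checked_cast_tensor<CPUFloatTensor>(self.pImpl,"self",2, false, Backend::CPU, ScalarType::Float) |
| # |
| # emit_body below assigns the result to a "self_" local, which CHECKED_USE |
| # then reads back (e.g. "self_->tensor"). |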
| |
| DIRECT_CONSTRUCTION_CHECKED_CAST = {'IntListSize', 'IntListStride'} |
| |
| CHECKED_USE = { |
| 'THTensor*': '{}_->tensor', |
| 'THSTensor*': '{}_->tensor', |
| 'THIndexTensor*': '{}_->tensor', |
| 'THBoolTensor*': '{}_->tensor', |
| 'THIntegerTensor*': '{}_->tensor', |
| 'THDenseTensor*': '{}_->tensor', |
| 'THDenseIndexTensor*': '{}_->tensor', |
| 'THStorage*': '{}_->pImpl()', |
| 'THGenerator*': '{}_->generator', |
| 'TensorList': "{0}_.data(), {0}_.size()", |
| } |
| |
| CHECKED_USE_NULLABLE = CodeTemplate('${arg_name}_ ? ${usage} : NULL') |
| |
| ALLOC_NOARGS_WRAP = { |
| 'THTensor*': 'detail::new_${Tensor}()', |
| 'THBoolTensor*': 'detail::new_${Backend}ByteTensor()', |
| 'THIndexTensor*': 'detail::new_${Backend}LongTensor()', |
| 'THIntegerTensor*': 'detail::new_${Backend}IntTensor()', |
| 'THSTensor*': 'detail::new_Sparse${Tensor}()', |
| 'THDenseTensor*': 'detail::new_${DenseTensor}()', |
| 'THDenseIndexTensor*': 'detail::new_${DenseBackend}LongTensor()', |
| } |
| |
| ALLOC_WRAP = { |
| 'THTensor*': 'new ${Tensor}(${arguments})', |
| 'THBoolTensor*': 'new ${Backend}ByteTensor(${arguments})', |
| 'THIndexTensor*': 'new ${Backend}LongTensor(${arguments})', |
| 'THIntegerTensor*': 'new ${Backend}IntTensor(${arguments})', |
| 'THSTensor*': 'new Sparse${Tensor}(${arguments})', |
| 'THDenseTensor*': 'new ${DenseTensor}(${arguments})', |
| 'THDenseIndexTensor*': 'new ${DenseBackend}LongTensor(${arguments})', |
| } |
| |
| # Replacements for constants when calling into TH |
| CONSTANT_REPLACEMENTS = [ |
| ('AS_REAL', '${AS_REAL}'), |
| ('__last_dim', 'self.ndimension()-1'), |
| ] |
| |
| # Replacements for constants in header file function definitions |
| HEADER_CONSTANT_REPLACEMENTS = [ |
| (r'AS_REAL\((.*)\)', r'\1'), |
| ('__last_dim', '-1'), |
| ] |
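| # For example, translate_default below applies these with re.sub, so a |
| # header default of "AS_REAL(1)" becomes "1" and "__last_dim" becomes "-1": |
| # |
| #   re.sub(r'AS_REAL\((.*)\)', r'\1', 'AS_REAL(1)')  # -> '1' |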
| |
| |
| class nested_dict(object): |
| def __init__(self, base, parent): |
| self.base, self.parent = base, parent |
| |
| def __getitem__(self, x): |
| r = self.base.get(x) |
| if r is not None: |
| return r |
| return self.parent[x] |
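| # nested_dict lets CodeTemplate.substitute resolve keys from a per-option |
| # environment first and fall back to a shared parent environment. A minimal |
| # sketch with hypothetical values: |
| # |
| #   env = nested_dict({'api_name': 'add'}, {'Backend': 'CPU'}) |
| #   env['api_name']  # 'add' -- found in the option itself |
| #   env['Backend']   # 'CPU' -- falls through to the parent |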
| |
| |
| Environment = TypedDict('Environment', { |
| 'ScalarName': str, |
| 'THTensor': str, |
| 'THType': str, |
| 'Backend': str, |
| 'AccScalarName': str, |
| }) |
| |
| TopEnvironment = TypedDict('TopEnvironment', { |
| 'type_registrations': List[str], |
| 'type_headers': List[str], |
| 'type_method_declarations': List[str], |
| 'type_method_definitions': List[str], |
| 'type_method_inline_definitions': List[str], |
| 'tensor_method_declarations': List[str], |
| 'tensor_method_definitions': List[str], |
| 'function_declarations': List[str], |
| 'function_definitions': List[str], |
| 'type_ids': List[str], |
| 'native_function_declarations': List[str], |
| }) |
| |
| # A Declarations.cwrap formal argument |
| # type can contain THTensor* types |
| THFormal = TypedDict('THFormal', { |
| 'name': str, |
| 'type': str, |
| 'dynamic_type': str, |
| 'kwarg_only': bool, |
| 'is_nullable': bool, |
| 'default': str, |
| 'default_init': str, |
| 'python_default_init': str, |
| 'output': bool, |
| 'size': int, |
| 'declared_type': str, |
| 'ignore_check': bool, |
| 'allocate': bool, |
| 'mask': bool, |
| 'if_true': bool, |
| 'if_false': bool, |
| 'wrap_dim': str, |
| # Broadcast is originally a str but gets unwrapped to a List or Dict in-place |
| 'broadcast': Any, |
| 'resize': str, |
| 'cpu_zero': bool, |
| 'zero': bool, |
| 'is_type_dispatched': bool, |
| }, total=False) |
| |
| # Generic ATen formal or native_functions.yaml formal argument. |
| # type can contain Tensor& reference types. |
| AtFormal = TypedDict('AtFormal', { |
| 'name': str, |
| 'type': str, |
| 'dynamic_type': str, |
| 'kwarg_only': bool, |
| 'is_nullable': bool, |
| 'default': str, |
| 'default_init': str, |
| 'python_default_init': str, |
| 'output': bool, |
| 'size': int, |
| 'is_type_dispatched': bool, |
| }, total=False) |
| |
| ReturnType = TypedDict('ReturnType', { |
| 'name': str, |
| 'type': str, |
| 'dynamic_type': str, |
| }, total=False) |
| |
| ReturnDecl = TypedDict('ReturnDecl', { |
| 'kind': str, |
| 'type': str, |
| 'arguments': List[int], |
| }, total=False) |
| |
| # Represents a buffer in nn.yaml |
| NNBuffer = TypedDict('NNBuffer', { |
| 'name': str, |
| }) |
| |
| FunctionOption = TypedDict('FunctionOption', { |
| 'actuals': List[str], |
| 'api_name': str, |
| 'arguments': List[THFormal], |
| 'aten_custom_call': str, |
| 'aten_dense_sparse': bool, |
| 'backend_type_pairs': List[Tuple[str, str]], |
| 'backends': List[str], |
| 'broadcast_actuals': List[str], |
| 'broadcast_function': str, |
| 'broadcast_modified_actuals': List[str], |
| 'broadcast_returns': List[str], |
| 'buffers': List[NNBuffer], |
| # cimpls is really a List[FunctionOption] |
| 'cimpls': List[Any], |
| 'cname': str, |
| 'condition': str, |
| 'const_mark': str, |
| 'device_guard': bool, |
| 'device_guard_declaration': str, |
| 'cpu_half': bool, |
| 'deprecated': bool, |
| 'formals_list': List[AtFormal], |
| 'formals_with_defaults': List[str], |
| 'formals': List[str], |
| 'inferred_type': str, |
| 'inplace': bool, |
| 'method_actuals': List[str], |
| 'method_formals_with_defaults': List[str], |
| 'method_formals': List[str], |
| 'method_prefix_derived': str, |
| 'mode': str, |
| 'name': str, |
| 'native_actuals': List[str], |
| 'native_type_method_dispatch': str, |
| # options should be List[FunctionOption] |
| 'options': Any, |
| 'return_call': str, |
| 'return_type': str, |
| 'return': ReturnDecl, |
| 'returns': List[ReturnType], |
| 'scalar_check': str, |
| 'sparse': bool, |
| 'type_definition_body': List[str], |
| 'type_method_actuals': List[str], |
| 'type_method_definition_dispatch': str, |
| 'type_method_formals_with_defaults': List[str], |
| 'type_method_formals': List[str], |
| 'variants': str, |
| 'when_sparse_dispatch': str, |
| 'with_gil': bool, |
| 'zero_dim_dispatch_when_scalar': str, |
| 'zero_dim_tensor_only': bool, |
| }) |
| |
| OutputDeclaration = NamedTuple('OutputDeclaration', [ |
| ('name', str), |
| ('method_prefix_derived', str), |
| ('arguments', List[AtFormal]), |
| ('method_of', List[str]), |
| ('mode', str), |
| ('buffers', Optional[List[str]]), |
| ('returns', List[ReturnType]), |
| ('inplace', bool), |
| ('abstract', bool), |
| ('device_guard', bool), |
| ('with_gil', bool), |
| ('deprecated', bool), |
| ]) |
| |
| |
| def device_guard(option, formals, is_factory_method=False): |
| # type: (FunctionOption, List[AtFormal], bool) -> str |
| # For factory methods the `DeviceGuard` is already in the template. |
| if option.get('device_guard', True) and not is_factory_method: |
| tensor_arguments = [f for f in formals if f['dynamic_type'] in {'Tensor', 'TensorList'}] |
| if tensor_arguments: |
| tensor_argument = tensor_arguments[0]['name'] |
| return 'const DeviceGuard device_guard({});'.format(tensor_argument) |
| return '// DeviceGuard omitted' |
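| # For illustration (hypothetical formals): if the first Tensor/TensorList |
| # formal is named "self", this returns "const DeviceGuard device_guard(self);"; |
| # factory methods and declarations with device_guard=False get the |
| # "// DeviceGuard omitted" placeholder instead. |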
| |
| |
| def is_real_argument_to_wrapper(argument): |
| # type: (THFormal) -> bool |
| return not argument.get('output', False) and\ |
| argument['type'] != 'CONSTANT' and\ |
| argument['type'] != 'argument' |
| |
| |
| def is_mutable_formal_argument(argument, option): |
| # type: (THFormal, FunctionOption) -> bool |
| return argument.get('output') or (option['inplace'] and argument['name'] == 'self') |
| |
| |
| def to_return_type(arg, option): |
| # type: (THFormal, FunctionOption) -> ReturnType |
| t = arg['type'] |
| rt = TYPE_RETURN.get(t, t) |
| if rt == 'Tensor' and not arg.get('allocate'): |
| rt = rt + ' &' |
| if not is_mutable_formal_argument(arg, option): |
| rt = 'const ' + rt |
| return { |
| 'name': arg['name'], |
| 'type': rt, |
| 'dynamic_type': DYNAMIC_TYPE.get(arg['type'], arg['type']), |
| } |
| |
| |
| def create_generic(top_env, declarations): |
| # type: (TopEnvironment, List[FunctionOption]) -> List[OutputDeclaration] |
| # translates defaults from cwrap types to C++ values |
| def translate_default(argument, type_str, default): |
| # type: (THFormal, str, Any) -> Any |
| if default is None: |
| # cause the default constructor for the object to run |
| return '{}' |
| if 'if_true' in argument: |
| return argument['default'] == argument['if_true'] |
| for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS: |
| default = re.sub(pattern, replacement, str(default)) |
| if type_str in {'Scalar', 'int64_t', 'double'}: |
| try: |
| return int(default) |
| except Exception: |
| try: |
| return float(default) |
| except Exception: |
| return default |
| elif type_str == 'bool': |
| assert default.lower() in ['true', 'false'] |
| return default.lower() == 'true' |
| else: |
| return default |
| |
| # change from THTensor* to Tensor & so we get how it will appear |
| # in the aten argument list... |
| def translate_formal(argument, option): |
| # type: (THFormal, FunctionOption) -> AtFormal |
| type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type']) |
| if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option): |
| type_str = 'const ' + type_str |
| translated = { |
| 'name': argument['name'], |
| 'type': type_str, |
| 'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']), |
| } # type: AtFormal |
| if 'kwarg_only' in argument: |
| translated['kwarg_only'] = argument['kwarg_only'] |
| if 'default' in argument: |
| default = translate_default(argument, type_str, argument['default']) |
| translated['default'] = default |
| translated['default_init'] = argument.get('default_init', default) |
| if 'python_default_init' in argument: |
| assert 'default' not in argument |
| default = translate_default(argument, type_str, argument['python_default_init']) |
| translated['python_default_init'] = default |
| if argument.get('output'): |
| translated['output'] = True |
| if argument.get('size'): |
| translated['size'] = argument['size'] |
| if argument.get('is_nullable') is not None: |
| translated['is_nullable'] = argument['is_nullable'] |
| return translated |
| |
| def get_formals(option, include_constants=False): |
| # type: (FunctionOption, bool) -> List[AtFormal] |
| seen = set() # type: Set[str] |
| pos_args = [] # type: List[THFormal] |
| kwd_args = [] # type: List[THFormal] |
| |
| def insert(argument): |
| # type: (THFormal) -> None |
| if argument['name'] not in seen: |
| seen.add(argument['name']) |
| if argument.get('kwarg_only', False): |
| kwd_args.append(argument) |
| else: |
| pos_args.append(argument) |
| |
| def has_output_mask(argument): |
| # type: (THFormal) -> bool |
| return argument.get('allocate', False) and argument.get('mask', False) |
| |
| for argument in option['arguments']: |
| if argument.get('output') and not argument.get('allocate', False): |
| insert(argument) |
| for argument in option['arguments']: |
| if argument['type'] == 'THSTensor*': |
| # only enable for a subset of Dense/Sparse ops |
| if not (option.get('aten_dense_sparse', False)): |
| raise NYIError("Sparse Tensor") |
| |
| if include_constants and argument['type'] == 'CONSTANT': |
| insert(argument) |
| elif is_real_argument_to_wrapper(argument): |
| insert(argument) |
| if any(has_output_mask(arg) for arg in option['arguments']): |
| mask_size = sum(has_output_mask(arg) for arg in option['arguments']) |
| insert({ |
| 'name': 'output_mask', |
| # NB: Lack of space after the comma works around a parsing |
| # problem in gen_variable_type.py |
| 'type': 'std::array<bool,{}>'.format(mask_size), |
| 'default': '{{' + ', '.join(['true'] * mask_size) + '}}', |
| }) |
| |
| result = pos_args + kwd_args |
| return [translate_formal(argument, option) for argument in result] |
| |
| def get_return_types(option): |
| # type: (FunctionOption) -> List[ReturnType] |
| ret = option['return'] |
| if ret['kind'] == 'arguments': |
| argument_indices = ret['arguments'] |
| if len(argument_indices) == 1: |
| the_arg = option['arguments'][argument_indices[0]] |
| return [to_return_type(the_arg, option)] |
| else: |
| return [to_return_type(option['arguments'][idx], option) |
| for idx in argument_indices] |
| elif ret['kind'] == 'type': |
| return [{ |
| 'type': TYPE_RETURN.get(ret['type'], ret['type']), |
| 'dynamic_type': DYNAMIC_TYPE.get(ret['type'], ret['type']), |
| }] |
| else: |
| raise Exception("format_return_type") |
| |
| def format_return_type(return_types): |
| # type: (List[ReturnType]) -> str |
| if len(return_types) == 1: |
| return return_types[0]['type'] |
| return "std::tuple<{}>".format(','.join(r['type'] for r in return_types)) |
| |
| def find_dispatch_tensor(formals): |
| # type: (List[AtFormal]) -> Optional[str] |
| # dispatch to self if it's a parameter |
| for formal in formals: |
| if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor': |
| return formal['name'] |
| # otherwise dispatch to the first Tensor or TensorList |
| for formal in formals: |
| if 'TensorList' == formal['dynamic_type'] or formal['dynamic_type'] == 'Tensor': |
| return formal['name'] |
| return None |
| |
| def format_formal(f): |
| # type: (AtFormal) -> str |
| return '{} {}'.format(f['type'], f['name']) |
| |
| def formal_with_default(f): |
| # type: (AtFormal) -> str |
| s = format_formal(f) |
| v = f.get('default') |
| if v is None: |
| return s |
| if isinstance(v, bool): |
| v = str(v).lower() |
| return '{}={}'.format(s, v) |
| |
| def get_broadcast_argument(option): |
| # type: (FunctionOption) -> Optional[THFormal] |
| for argument in option['arguments']: |
| if argument.get('broadcast'): |
| return argument |
| return None |
| |
| def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims): |
| # type: (THFormal, bool, bool) -> List[str] |
| # Note: broadcast_dims can change type... |
| # return the actuals that will be passed to the broadcast function. |
| # 1) in the common case, this is the broadcasted argument (e.g. "self") followed by the tensors |
| # that it is broadcasted against (comma-separated) (e.g. "self, tensor1, tensor2"). |
| # 2) in the broadcast_dims case, this is the broadcasted argument (e.g. "self") followed by the sizes |
| # it is broadcasted to (as an initializer list), so e.g. the specification |
| # "mat1.dim0,mat2.dim1" gets transformed to "self, {mat1.size(0),mat2.size(1)}" |
| if not broadcast_dims: |
| broadcast_actuals = [broadcast_arg['name']] + broadcast_arg['broadcast'].split()[0].split(",") |
| else: |
| broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',') |
| # generate size call for each dimension |
| broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore |
| for x in broadcast_dims_spec]) |
| broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore |
| broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list] |
| |
| return broadcast_actuals |
| |
| def emit_nn_body(option): |
| # type: (FunctionOption) -> Union[str, List[str]] |
| # Concrete definition in Type.cpp for NN functions. Delegates to the |
| # xxx_forward variant after creating any necessary buffers. |
| actuals = option['actuals'] |
| base_name = option['name'][:-1] if option['inplace'] else option['name'] |
| fwd_name = option['api_name'].replace(base_name, base_name + '_forward') |
| |
| if len(option['buffers']) == 0: |
| return 'return {}({});'.format(fwd_name, ', '.join(actuals)) |
| |
| body = [] # type: List[str] |
| if option['api_name'].endswith('_out'): |
| # _out variants must create buffers and insert them in the |
| # arguments list between output and input arguments |
| for buffer in option['buffers']: |
| body.append('Tensor {} = tensor();'.format(buffer['name'])) |
| actuals = [arg['name'] for arg in option['arguments'] if arg.get('output')] |
| actuals += [buffer['name'] for buffer in option['buffers']] |
| actuals += [arg['name'] for arg in option['arguments'] if not arg.get('output')] |
| |
| body.append('return std::get<0>({}({}));'.format(fwd_name, ', '.join(actuals))) |
| return body |
| |
| def process_option(option, output_options): |
| # type: (FunctionOption, List[OutputDeclaration]) -> None |
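| # In-place operations are detected by name: e.g. "add_" and "__iand__" |
| # match the pattern below, while "add" and "add_out" do not |
| # (illustrative names). |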
| option['inplace'] = re.search( |
| '(^__i|[^_]_$)', option['api_name']) is not None |
| |
| # print(yaml.dump(option)) |
| formals = get_formals(option) |
| option['formals_list'] = formals |
| option['formals'] = [format_formal(f) for f in formals] |
| option['formals_with_defaults'] = [formal_with_default(f) for f in formals] |
| option['returns'] = get_return_types(option) |
| option['return_type'] = format_return_type(option['returns']) |
| option['return_call'] = 'return ' if option['return_type'] != 'void' else '' |
| option['actuals'] = [f['name'] for f in formals] |
| |
| option['method_formals'] = [format_formal(f) for f in formals |
| if f['name'] != 'self'] |
| option['method_formals_with_defaults'] = ( |
| [formal_with_default(f) for f in formals if f['name'] != 'self']) |
| option['method_actuals'] = [ |
| f['name'] if f['name'] != 'self' else '*this' for f in formals] |
| |
| # These never differ for the declarations processed here, but they do |
| # for native functions (see process_native) |
| option['type_method_formals'] = option['formals'] |
| option['type_method_formals_with_defaults'] = option['formals_with_defaults'] |
| option['type_method_actuals'] = option['actuals'] |
| |
| option['const_mark'] = '' if option['inplace'] else ' const' |
| |
| is_method = 'method' in option['variants'] |
| is_function = 'function' in option['variants'] |
| dispatch_tensor = find_dispatch_tensor(formals) |
| is_namespace_function = is_function and dispatch_tensor is not None |
| |
| broadcast_arg = get_broadcast_argument(option) |
| # "s_" for "same size". |
| option['method_prefix_derived'] = '' if broadcast_arg is None else 's_' |
| option['device_guard_declaration'] = device_guard(option, formals) |
| |
| env = nested_dict(option, top_env) |
| |
| mode = option['mode'] |
| abstract = True |
| if mode == 'NN' and option.get('cimpls') is None: |
| # NN functions with no _forward/_backward suffix don't have cimpls. |
| # They call the _forward function and discard any buffer returns |
| abstract = False |
| top_env['type_method_declarations'].append( |
| TYPE_METHOD_DECLARATION_CONCRETE.substitute(env)) |
| body = emit_nn_body(option) |
| top_env['type_method_definitions'].append( |
| TYPE_METHOD_DEFINITION_CONCRETE.substitute( |
| env, type_definition_body=body)) |
| elif broadcast_arg is None: |
| top_env['type_method_declarations'].append( |
| TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env)) |
| top_env['type_method_definitions'].append( |
| TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env)) |
| else: |
| top_env['type_method_declarations'].append( |
| TYPE_METHOD_DECLARATION_BROADCAST.substitute(env)) |
| top_env['type_method_declarations'].append( |
| TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env)) |
| top_env['type_method_definitions'].append( |
| TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env)) |
| |
| broadcast_inplace = 'inplace' in broadcast_arg['broadcast'] |
| broadcast_dims = 'dims:' in broadcast_arg['broadcast'] |
| option['broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims) |
| if not broadcast_dims: |
| option['broadcast_returns'] = (["b_" + x for x in option['broadcast_actuals'] |
| if x != broadcast_arg['name'] or not broadcast_inplace]) |
| else: |
| option['broadcast_returns'] = ["b_" + broadcast_arg['name']] |
| |
| option['broadcast_function'] = 'expand_' + ('inplace' if broadcast_inplace |
| else 'size' if broadcast_dims else 'outplace') |
| option['broadcast_modified_actuals'] = ['b_' + y if 'b_' + y in option['broadcast_returns'] else y |
| for y in option['actuals']] |
| top_env['type_method_definitions'].append( |
| TYPE_METHOD_DEFINITION_BROADCAST.substitute(env)) |
| |
| method_of = ['Type'] |
| if is_method: |
| top_env['tensor_method_declarations'].append( |
| TENSOR_METHOD_DECLARATION.substitute(env)) |
| top_env['tensor_method_definitions'].append( |
| TENSOR_METHOD_DEFINITION.substitute(env)) |
| method_of.append('Tensor') |
| |
| if is_namespace_function: |
| option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor) |
| top_env['function_declarations'].append( |
| FUNCTION_DECLARATION.substitute(env)) |
| top_env['function_definitions'].append( |
| FUNCTION_DEFINITION.substitute(env)) |
| method_of.append('namespace') |
| |
| buffer_names = [buffer['name'] for buffer in option.get('buffers', [])] |
| |
| output_options.append(OutputDeclaration( |
| name=option['api_name'], |
| method_prefix_derived=option['method_prefix_derived'], |
| arguments=formals, |
| method_of=method_of, |
| mode=mode, |
| buffers=buffer_names, |
| returns=option['returns'], |
| inplace=option['inplace'], |
| # See Note [Abstract ATen methods] |
| abstract=abstract, |
| device_guard=option.get('device_guard', True), |
| with_gil=option.get('with_gil', False), |
| deprecated=option.get('deprecated', False) |
| )) |
| |
| def native_get_formals(option, include_constants=False): |
| # type: (FunctionOption, bool) -> List[AtFormal] |
| seen = set() # type: Set[str] |
| pos_args = [] |
| kwd_args = [] |
| |
| def insert(argument): |
| # type: (AtFormal) -> None |
| if argument['name'] not in seen: |
| seen.add(argument['name']) |
| if argument.get('kwarg_only', False): |
| kwd_args.append(argument) |
| else: |
| pos_args.append(argument) |
| |
| for argument in option['arguments']: |
| insert(argument) |
| |
| # not clear we need dynamic_type translation as we can specify the correct type |
| # directly in native functions |
| def add_dynamic_type(argument, option): |
| # type: (AtFormal, FunctionOption) -> AtFormal |
| argument['dynamic_type'] = NATIVE_DYNAMIC_TYPE.get(argument['type'], argument['type']) |
| return argument |
| |
| result = pos_args + kwd_args |
| result = [add_dynamic_type(argument, option) for argument in result] |
| |
| # ensure we get reference-type formals when appropriate |
| def native_translate_formals(argument, option): |
| # type: (AtFormal, FunctionOption) -> AtFormal |
| def translate_map(const): |
| # type: (bool) -> Dict[str, str] |
| return { |
| 'Tensor': 'const Tensor &' if const else 'Tensor &', |
| 'BoolTensor': 'const Tensor &' if const else 'Tensor &', |
| 'IndexTensor': 'const Tensor &' if const else 'Tensor &', |
| 'Type': 'const Type &' if const else 'Type &', |
| 'TensorOptions': 'const TensorOptions &' if const else 'TensorOptions &', |
| } |
| |
| if (option['inplace'] and argument['name'] == 'self') or argument.get('output', False): |
| argument['type'] = translate_map(False).get(argument['type'], argument['type']) |
| else: |
| argument['type'] = translate_map(True).get(argument['type'], argument['type']) |
| |
| return argument |
| |
| result = [native_translate_formals(argument, option) for argument in result] |
| return result |
| |
| # this can return multiple return types in a list, e.g. ['Tensor', 'Tensor'] |
| def native_get_return_types(option): |
| # type: (FunctionOption) -> List[ReturnType] |
| ret = option['return'] |
| |
| return_types = [] # type: List[ReturnType] |
| for t_raw in ret: |
| if isinstance(t_raw, string_type): |
| t = t_raw |
| name = None |
| elif t_raw is None: |
| t = 'void' |
| name = None |
| else: |
| t = t_raw['type'] |
| name = t_raw['name'] |
| |
| # can't actually return a TensorList (since it's a reference object) |
| actual_return_type = {'TensorList': 'std::vector<Tensor>'}.get(t, t) |
| |
| if actual_return_type == 'Tensor' and (option['inplace'] or option['api_name'].endswith('_out')): |
| # follow normal ATen convention of returning Tensor & for inplace functions. |
| actual_return_type = 'Tensor &' |
| |
| rtype = { |
| 'type': actual_return_type, |
| 'dynamic_type': NATIVE_DYNAMIC_TYPE.get(t, t), |
| } # type: ReturnType |
| if name is not None: |
| rtype['name'] = name |
| return_types.append(rtype) |
| |
| return return_types |
| |
| def process_native(option, output_options): |
| # type: (FunctionOption, List[OutputDeclaration]) -> None |
| option['inplace'] = re.search( |
| '(^__i|[^_]_$)', option['api_name']) is not None |
| |
| formals = native_get_formals(option) |
| option['formals_list'] = formals |
| option['formals'] = [format_formal(f) for f in formals] |
| option['formals_with_defaults'] = [formal_with_default(f) for f in formals] |
| option['returns'] = native_get_return_types(option) |
| option['return_type'] = format_return_type(option['returns']) |
| option['return_call'] = 'return ' if option['return_type'] != 'void' else '' |
| option['actuals'] = [f['name'] for f in formals] |
| |
| option['method_formals'] = [format_formal(f) for f in formals |
| if f['name'] != 'self'] |
| option['method_formals_with_defaults'] = ( |
| [formal_with_default(f) for f in formals if f['name'] != 'self']) |
| option['method_actuals'] = [ |
| f['name'] if f['name'] != 'self' else '*this' for f in formals] |
| |
| def find_formal(formal_name, formals): |
| for formal in formals: |
| if formal_name == formal['dynamic_type']: |
| return formal |
| return None |
| |
| dispatch_tensor = find_dispatch_tensor(formals) |
| dispatch_type = None if dispatch_tensor else find_formal('Type', formals) |
| if dispatch_type: |
| dispatch_type['is_type_dispatched'] = True |
| |
| option['type_method_formals'] = [format_formal(f) for f in formals if f != dispatch_type] |
| option['type_method_formals_with_defaults'] = [formal_with_default(f) for f in formals if f != dispatch_type] |
| option['type_method_actuals'] = [f['name'] for f in formals if f != dispatch_type] |
| option['native_actuals'] = [f['name'] if f != dispatch_type else '*this' for f in formals] |
| |
| option['const_mark'] = '' if option['inplace'] else ' const' |
| |
| is_method = 'method' in option['variants'] |
| is_namespace_function = 'function' in option['variants'] |
| is_factory_method = find_formal('TensorOptions', formals) |
| is_deprecated_factory_method = len(formals) > 0 and \ |
| formals[0]['dynamic_type'] == 'Type' and \ |
| option['return_type'] == 'Tensor' and option['deprecated'] |
| needs_native_definition = not is_deprecated_factory_method |
| |
| has_dispatch = dispatch_tensor or dispatch_type |
| |
| option['method_prefix_derived'] = '' |
| option['device_guard_declaration'] = device_guard(option, formals, is_factory_method) |
| |
| env = nested_dict(option, top_env) |
| |
| broadcast_arg = get_broadcast_argument(option) |
| if broadcast_arg is not None: |
| raise Exception("broadcasting is not yet supported for native functions, " |
| "but specified for function {}", option['name']) |
| |
| # Factory methods are not dispatched over `Type`. |
| if not is_factory_method: |
| if option['deprecated']: |
| top_env['type_method_declarations'].append(DEPRECATED_TYPE_METHOD_DECLARATION_CONCRETE.substitute(env)) |
| else: |
| top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env)) |
| dispatch = option['type_method_definition_dispatch'] |
| option['native_type_method_dispatch'] = dispatch |
| |
| # Note [Abstract ATen methods] |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # An abstract ATen method is one whose dispatch differs between |
| # types. These are implemented in derived types (with a |
| # standard (throwing) definition in Type). A concrete ATen |
| # method is one which has the same dispatch for all types; |
| # we just implement it in the base Type. This is exposed |
| # in Declarations.yaml via a field named 'abstract'. |
| abstract = False |
| if isinstance(dispatch, dict): |
| abstract = True |
| top_env['type_method_definitions'].append( |
| TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env)) |
| elif is_deprecated_factory_method: |
| top_env['type_method_definitions'].append( |
| DEPRECATED_TYPE_METHOD_DEFINITION_CONCRETE.substitute(env)) |
| elif not is_factory_method: |
| body = TYPE_DEFINITION_BODY_NATIVE.substitute(env) |
| top_env['type_method_definitions'].append( |
| TYPE_METHOD_DEFINITION_CONCRETE.substitute( |
| env, type_definition_body=body)) |
| |
| # generate the at::native function declarations (i.e. what the user will implement) |
| if needs_native_definition: |
| if isinstance(dispatch, dict): |
| generated_native_functions = [] # type: List[str] |
| for key in sorted(dispatch.keys()): |
| value = dispatch[key] |
| if value not in generated_native_functions: |
| option['native_type_method_dispatch'] = value |
| top_env['native_function_declarations'].append( |
| NATIVE_DECLARATION.substitute(env)) |
| generated_native_functions.append(value) |
| else: |
| top_env['native_function_declarations'].append( |
| NATIVE_DECLARATION.substitute(env)) |
| |
| method_of = ['Type'] |
| if is_method: |
| top_env['tensor_method_declarations'].append( |
| TENSOR_METHOD_DECLARATION.substitute(env)) |
| top_env['tensor_method_definitions'].append( |
| TENSOR_METHOD_DEFINITION.substitute(env)) |
| method_of.append('Tensor') |
| |
| if is_namespace_function: |
| if dispatch_type: |
| option['inferred_type'] = dispatch_type['name'] |
| elif dispatch_tensor: |
| option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor) |
| else: |
| # doesn't depend on a specific type, use undefined float |
| option['inferred_type'] = 'at::getType(at::Backend::Undefined, at::ScalarType::Float)' |
| declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION |
| top_env['function_declarations'].append(declaration.substitute(env)) |
| if is_factory_method: |
| top_env['function_definitions'].append(FACTORY_DEFINITION.substitute(env)) |
| elif is_deprecated_factory_method: |
| top_env['function_definitions'].append(DEPRECATED_FACTORY_DEFINITION.substitute(env)) |
| else: |
| top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env)) |
| method_of.append('namespace') |
| |
| output_options.append(OutputDeclaration( |
| name=option['api_name'], |
| method_prefix_derived=option['method_prefix_derived'], |
| arguments=formals, |
| method_of=method_of, |
| mode=option['mode'], |
| buffers=None, |
| returns=option['returns'], |
| inplace=option['inplace'], |
| # See Note [Abstract ATen methods] |
| abstract=abstract, |
| device_guard=option.get('device_guard', True), |
| with_gil=option.get('with_gil', False), |
| deprecated=option['deprecated'], |
| )) |
| |
| output_declarations = [] # type: List[OutputDeclaration] |
| for declaration in declarations: |
| output_options = [] # type: List[OutputDeclaration] |
| for option in declaration['options']: |
| try: |
| if option['mode'] != 'native': |
| process_option(option, output_options) |
| else: |
| process_native(option, output_options) |
| except NYIError: |
| option['skip'] = True |
| output_declarations.extend(output_options) |
| return output_declarations |
| |
| |
| def create_derived(backend_type_env, declarations): |
| # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]] |
| type_object_declarations = [] |
| type_object_definitions = [] |
| |
| is_cuda = 'CUDA' in backend_type_env['Backend'] |
| |
| real_is_half = backend_type_env['ScalarName'] == 'Half' |
| |
| def replace_with_null(argument): |
| # type: (THFormal) -> bool |
| return (argument['type'] == 'THGenerator*' and |
| backend_type_env['Backend'] == 'CUDA') |
| |
| def requires_checked_cast(argument): |
| # type: (THFormal) -> bool |
| if argument['type'] == 'IntList': |
| return 'size' in argument |
| return argument['type'] in CHECKED_CAST |
| |
| def nullable_argument(argument): |
| # type: (THFormal) -> bool |
| return argument.get('is_nullable', False) |
| |
| def bool_option_is_string(argument): |
| # type: (THFormal) -> bool |
| return 'if_true' in argument and isinstance(argument['if_true'], string_type) |
| |
| def get_argument(argument, option): |
| # type: (THFormal, FunctionOption) -> str |
| if replace_with_null(argument): |
| return 'NULL' |
| elif requires_checked_cast(argument): |
| checked_use = CHECKED_USE.get( |
| argument['type'], '{}_').format(argument['name']) |
| if real_is_half and argument['type'] == 'real': |
| checked_use = HALF_CONVERSION.substitute(value=checked_use) |
| if nullable_argument(argument): |
| checked_use = CHECKED_USE_NULLABLE.substitute( |
| env={}, arg_name=argument['name'], usage=checked_use) |
| return checked_use |
| elif argument['type'] == 'bool' and 'if_true' in argument: |
| if bool_option_is_string(argument): |
| tpl = '({}) ? "{}" : "{}"' |
| else: |
| tpl = '({}) ? {} : {}' |
| return tpl.format(argument['name'], |
| argument['if_true'], argument['if_false']) |
| elif argument['type'] == 'CONSTANT': |
| # this is a bool that is actually a string... |
| if bool_option_is_string(argument): |
| return '"{}"'.format(argument['name']) |
| v = str(argument.get('default', argument['name'])) |
| for pattern, replacement in CONSTANT_REPLACEMENTS: |
| v = re.sub(pattern, replacement, v) |
| return CodeTemplate(v).substitute(backend_type_env) |
| # e.g. argument 0, i.e. repeat the 0th argument in this position... |
| elif argument['type'] == 'argument': |
| index = int(argument['name']) |
| return get_argument(option['arguments'][index], option) |
| else: |
| return argument['name'] |
| |
| def drop_argument(argument, option): |
| # type: (THFormal, FunctionOption) -> bool |
| # Devices are handled in the body of the function. |
| if argument['name'] == 'device': |
| return True |
| return 'CUDA' in backend_type_env['Backend'] and ( |
| option['mode'] == 'TH' and argument['type'] == 'THGenerator*') |
| |
| def get_arguments(arguments, option): |
| # type: (List[THFormal], FunctionOption) -> List[str] |
| return [get_argument(argument, option) |
| for argument in arguments if not drop_argument(argument, option)] |
| |
| def is_actual_return_long(ret): |
| # type: (ReturnDecl) -> bool |
| if ret['type'] == 'long': |
| return True |
| if ret['type'] == 'real': |
| return backend_type_env['ScalarName'] == 'Long' |
| if ret['type'] == 'accreal': |
| return backend_type_env['AccScalarName'] == 'Long' |
| return False |
| |
| def handle_zero_dim(env, option): |
| # type: (Environment, FunctionOption) -> List[str] |
| zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '') |
| if not zero_dim_dispatch: |
| return [] |
| broadcasts_arg = zero_dim_dispatch in option.get('broadcast_actuals', '') |
| zero_dim_only = option.get('zero_dim_tensor_only', False) |
| # this combination doesn't seem to make sense |
| assert not (broadcasts_arg and zero_dim_only) |
| # if the argument broadcasts, then this would only affect cases where all broadcasted |
| # tensors were zero-dim, which is inconsistent with the scalar handling. |
| if broadcasts_arg: |
| return [] |
| zero_dim_actuals = [arg['name'] |
| if arg['name'] != zero_dim_dispatch else "Scalar({})".format(arg['name']) |
| for arg in option['formals_list']] |
| return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)] |
| |
| def handle_only_zero_dim(env, option): |
| # type: (Environment, FunctionOption) -> Optional[List[str]] |
| if option.get('zero_dim_tensor_only', False): |
| check_name = option['zero_dim_dispatch_when_scalar'] |
| return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)] |
| else: |
| return None |
| |
| def handle_sparse(env, option): |
| # type: (Environment, FunctionOption) -> List[str] |
| if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']: |
| return [] |
| check_name = option['when_sparse_dispatch'] |
| sparse_actuals = [arg['name'] |
| if arg['name'] != check_name else "SparseTensorRef({})".format(arg['name']) |
| for arg in option['formals_list']] |
| return [SPARSE_CHECK.substitute(env, check_name=check_name, sparse_actuals=sparse_actuals)] |
| |
| def allocate_arg(env, arg, output_count): |
| # type: (Environment, THFormal, int) -> List[str] |
| name = arg['name'] |
| state = '' |
| if is_cuda: |
| state = 'globalContext().getTHCState()' |
| allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(env) |
| tensor_arg = '{}_'.format(name) |
| if arg.get('mask', False): |
| allocation = 'output_mask[{}] ? {} : nullptr'.format(output_count, allocation) |
| tensor_arg = ('{}_ == nullptr ? (TensorImpl*)UndefinedTensor::singleton() : (TensorImpl*){}_' |
| .format(name, name)) |
| return [ |
| 'auto {}_ = {};'.format(name, allocation), |
| 'auto {} = Tensor({}, false);'.format(name, tensor_arg), |
| ] |
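| # For illustration (hypothetical values): allocating an output "result" of |
| # type THTensor* in a CPU Float environment emits roughly |
| # |
| #   auto result_ = detail::new_CPUFloatTensor(); |
| #   auto result = Tensor(result_, false); |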
| |
| def resize_arg(arg): |
| # type: (THFormal) -> str |
| resize = arg['resize'] |
| if isinstance(resize, str): |
| return "{}.resize_({}.sizes());".format(arg['name'], resize) |
| else: |
| resize_scalar = arg.get('resize_scalar', False) |
| if resize_scalar: |
| dims = ['{}.dim() == 0 ? 1 : {}.size({})'.format(name, name, dim) for name, dim in resize] |
| else: |
| dims = ['{}.size({})'.format(name, dim) for name, dim in resize] |
| return "{}.resize_({{ {} }});".format(arg['name'], ','.join(dims)) |
| |
| def handle_call(env, option, cimpl): |
| # type: (Environment, FunctionOption, FunctionOption) -> str |
| is_nn = option['mode'] == 'NN' |
| actuals = get_arguments(cimpl['arguments'], option) |
| if is_cuda or is_nn: |
| actuals = ['globalContext().getTHCState()'] + actuals |
| |
| cname = cimpl['cname'] |
| if option.get('sparse', False): |
| if is_cuda: |
| cname = 'THCS' + env['ScalarName'] + "Tensor_" + cname |
| else: |
| cname = env['THTensor'].replace('TH', 'THS') + '_' + cname |
| elif is_nn: |
| cname = 'THNN_{}'.format(env['THType']) + cname |
| else: |
| cname = env['THTensor'] + '_' + cname |
| |
| call = CALL_TEMPLATE.substitute(actuals=actuals, cname=cname) |
| if cimpl.get('condition') is not None: |
| call = 'if ({}) {}'.format(cimpl['condition'], call) |
| return call |
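| # For illustration (hypothetical values): a TH cimpl with cname "addmm" in a |
| # CPU Float environment yields the call "THFloatTensor_addmm(...)"; NN and |
| # CUDA calls additionally get "globalContext().getTHCState()" prepended to |
| # their actuals, and NN cnames are prefixed with "THNN_${THType}". |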
| |
| def emit_body(env, option): |
| # type: (Environment, FunctionOption) -> List[str] |
| body = [] # type: List[str] |
| body += handle_sparse(env, option) |
| body += handle_zero_dim(env, option) |
| only_zero_dim_check = handle_only_zero_dim(env, option) |
| if only_zero_dim_check is not None: |
| # code below only_zero_dim_check is unreachable so we do not need to generate the rest. |
| body += only_zero_dim_check |
| return body |
| |
| # arguments are potentially duplicated because of one argument |
| # referencing another |
| seen_names = set() # type: Set[str] |
| seen_tensorlists = set() # type: Set[str] |
| count = 0 |
| output_count = 0 |
| |
| # scalar_check is the heuristic condition under which a result may be a scalar. |
| # if there is an IntListSize argument, then its dimensions are used to determine scalar-ness; |
| # otherwise, the result is a scalar if all the input tensors are scalars. |
| scalar_check_is_from_size = False |
| scalar_check_is_from_option = False |
| scalar_check = None |
| scalar_check_opt = option.get('scalar_check') |
| if scalar_check_opt is not None: |
| if isinstance(scalar_check_opt, bool): |
| scalar_check = str(scalar_check_opt).lower() |
| else: |
| scalar_check = scalar_check_opt |
| scalar_check_is_from_option = True |
| |
| for arg in option['arguments']: |
| if is_real_argument_to_wrapper(arg): |
| count += 1 |
| if arg['type'] == 'IntListSize' and not scalar_check_is_from_option: |
| scalar_check_is_from_size = True |
| scalar_check = '{}.size() == 0'.format(arg['name']) |
| if arg['type'] == 'TensorList': |
| seen_tensorlists.add(arg['name']) |
| |
| wrap_dim_target = arg.get('wrap_dim', None) |
| if wrap_dim_target is not None: |
| # for Tensors, "name_" is the TensorImpl, but for TensorLists, it is an |
| # std::vector of TH*s. Since TH*s have different dimension rules, we used |
| # "name" instead, but keep "name_" for tensor to avoid an extra function call. |
| if wrap_dim_target not in seen_tensorlists: |
| wrap_dim_target = wrap_dim_target + "_" |
| body.append("{} = maybe_wrap_dim({}, {});" |
| .format(arg['name'], arg['name'], wrap_dim_target)) |
| |
| # only generate checked casts the first time we see an argument |
| if arg['name'] not in seen_names and requires_checked_cast(arg): |
| seen_names.add(arg['name']) |
| |
| # make a new allocation of TensorImpl, then wrap a Tensor around it. |
| if arg.get('allocate', False): |
| body += allocate_arg(env, arg, output_count) |
| output_count += 1 |
| # extract the TensorImpl from an existing tensor (or Storage, etc.) |
| else: |
| # special case where we allow undefined Tensors, and thus |
| # the checked cast succeeds even if the Tensor is not |
| # defined |
| null_okay = 'true' if nullable_argument(arg) else 'false' |
| default_init = [] |
| if 'default_init' in arg: |
| default_init.append(arg['default_init']) |
| |
| if arg['type'] in DIRECT_CONSTRUCTION_CHECKED_CAST: |
| body.append(CHECKED_CAST[arg['type']].substitute( |
| env, arg_name=arg['name'], arg_pos=count, |
| null_okay=null_okay, default_init=default_init, |
| size=arg.get('size'), |
| result_name=arg['name'] + '_')) |
| else: |
| check_cast = CHECKED_CAST[arg['type']].substitute( |
| env, arg_name=arg['name'], arg_pos=count, |
| null_okay=null_okay, default_init=default_init, |
| size=arg.get('size')) |
| body.append("auto {}_ = {};".format( |
| arg['name'], check_cast)) |
| if drop_argument(arg, option) or replace_with_null(arg): |
| body.append( |
| "(void) {}_; //silence unused warning".format(arg['name'])) |
| |
| initializers = [] |
| |
| # resize tensors for special ops that require it |
| if 'resize' in arg: |
| initializers.append(resize_arg(arg)) |
| |
| # also special handling where we zero some outputs. |
| if arg.get('zero', False) or (arg.get('cpu_zero', False) and not is_cuda): |
| initializers.append("{}.zero_();".format(arg['name'])) |
| |
| # only initialize non-null arguments |
| if nullable_argument(arg) and len(initializers) > 0: |
| body.append(CONDITIONAL_INITIALIZER.substitute({ |
| 'name': arg['name'], |
| 'initializer': initializers |
| })) |
| else: |
| body += initializers |
| |
| # for out-of-place: dim() == 0 for all input tensors is and'd to form |
| # the test for whether the output is also a scalar |
| # for in-place: dim() == 0 shouldn't change as a result of the operation |
| if (not arg.get('output') and 'Tensor' in arg['type'] and |
| 'TensorList' not in arg['type'] and |
| 'THS' not in arg['type'] and |
| not scalar_check_is_from_size and |
| not scalar_check_is_from_option and |
| not option['inplace']): |
| check = '{}->dim() == 0'.format(arg['name'] + '_') |
| if nullable_argument(arg): |
| check = '(!{} || {})'.format(arg['name'] + '_', check) |
| scalar_check = (check if scalar_check is None |
| else scalar_check + ' && ' + check) |
| |
| # cimpls, if it exists, contains the underlying C function names and |
| # arguments. Otherwise use option |
| cimpls = option.get('cimpls', [option]) |
| calls = [handle_call(env, option, cimpl) for cimpl in cimpls] |
| |
| ret = option['return'] |
| |
| if ret['kind'] == 'arguments': |
| if 'aten_custom_call' in option: |
| # all aten_custom_call bodies handle settings on their own. |
| scalar_check = None |
| body.append(CodeTemplate( |
| option['aten_custom_call']).substitute(env)) |
| else: |
| body.extend([call + ';' for call in calls]) |
| arguments_indices = ret['arguments'] |
| arguments = [option['arguments'][argi] |
| for argi in arguments_indices] |
| if scalar_check is not None: |
| if not isinstance(scalar_check, dict): |
| if len(arguments) > 1: |
| body.append("bool maybe_scalar = {};".format(scalar_check)) |
| scalar_check = 'maybe_scalar' |
| for arg in arguments: |
| scalar_check_arg = (scalar_check if not isinstance(scalar_check, dict) |
| else scalar_check.get(arg['name'])) # type: ignore |
| if scalar_check_arg is not None: |
| stmt = "{}_->maybe_zero_dim({});".format(arg['name'], scalar_check_arg) |
| if nullable_argument(arg): |
| stmt = "if ({}_) {}".format(arg['name'], stmt) |
| body.append(stmt) |
| if len(arguments_indices) == 1: |
| arg = arguments[0] |
| body.append("return {};".format(arg['name'])) |
| else: |
| types = [to_return_type(arg, option)['type'] |
| for arg in arguments] |
| # TODO: check for move semantics... |
| names = [arg['name'] for arg in arguments] |
| body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute( |
| types=types, names=names)) |
| elif ret['kind'] == 'type': |
| assert len(calls) == 1 |
| call = calls[0] |
| if 'aten_custom_call' in option: |
| # all aten_custom_call bodies handle settings on their own. |
| scalar_check = None |
| body.append(CodeTemplate( |
| option['aten_custom_call']).substitute(env)) |
| |
| if ret['type'] in ALLOC_WRAP.keys(): |
| maybe_scalar = "->maybe_zero_dim({})".format(scalar_check) \ |
| if scalar_check is not None \ |
| else "" |
| wrapped_tensor = CodeTemplate(ALLOC_WRAP[ret['type']]).substitute( |
| env, arguments=[call]) |
| return_tensor = "return Tensor((${wrapped_tensor})${maybe_scalar},false);" |
| body.append(CodeTemplate(return_tensor).substitute( |
| env, wrapped_tensor=wrapped_tensor, maybe_scalar=maybe_scalar)) |
| # return the same underlying Tensor type for both real and accreal; this ensures |
| # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the |
| # ScalarType before constructing the scalarTensor to avoid overflow checking. |
| elif ret['type'] == 'accreal' or ret['type'] == 'real': |
| return_scalar = 'return scalarTensor(convert<${ScalarType}>(${call}));' |
| body.append(CodeTemplate(return_scalar).substitute(env, call=call)) |
| else: |
| # we use int64_t for long in the API, so correct it here... |
| if is_actual_return_long(ret): |
| call = "static_cast<int64_t>({})".format(call) |
| body.append("return {};".format(call)) |
| else: |
| raise Exception("NYI - return handling") |
| return body |
| |
| def process_option(option): |
| # type: (FunctionOption) -> None |
| pair = (backend_type_env['Backend'], |
| backend_type_env['ScalarName']) |
| if pair in option['backend_type_pairs']: |
| env = nested_dict(option, backend_type_env) |
| body = emit_body(env, option) # type: ignore |
| option['type_definition_body'] = body |
| type_object_declarations.append( |
| TYPE_DERIVED_DECLARATION.substitute(env)) |
| type_object_definitions.append( |
| TYPE_DERIVED_DEFINITION.substitute(env)) |
| |
| def process_native(option): |
| # type: (FunctionOption) -> None |
| dispatch = option['type_method_definition_dispatch'] |
| env = nested_dict(option, backend_type_env) |
| |
| if isinstance(dispatch, dict): |
| pair = (backend_type_env['Backend'], |
| backend_type_env['ScalarName']) |
| if pair in option['backend_type_pairs']: |
| native_dispatch = dispatch.get(pair[0]) |
| type_object_declarations.append( |
| TYPE_DERIVED_DECLARATION.substitute(env)) |
| if native_dispatch is None: |
| type_object_definitions.append( |
| TYPE_DERIVED_DEFINITION_NATIVE_MISSING.substitute(env)) |
| else: |
| option['native_type_method_dispatch'] = native_dispatch |
| type_object_definitions.append( |
| TYPE_DERIVED_DEFINITION_NATIVE.substitute(env)) |
| |
| for declaration in declarations: |
| for option in declaration['options']: |
| if not option.get('skip', False): |
| try: |
| if option['mode'] == 'NN' and option.get('cimpls') is None: |
| continue |
| if option['mode'] != 'native': |
| process_option(option) |
| else: |
| process_native(option) |
| except NYIError: |
| pass |
| return type_object_declarations, type_object_definitions |