blob: 142f59beed4ce731c850ea6ccabe6f6be8069e73 [file] [log] [blame]
from string import Template
from copy import deepcopy
from . import CWrapPlugin
from itertools import product, chain
from collections import OrderedDict
class THPPlugin(CWrapPlugin):
    """cwrap plugin that generates the CPython wrapper layer (THP*) for
    TH/THC/THS tensor functions: argument type checking and unpacking,
    optional output-tensor allocation, return-value wrapping, and
    PyMethodDef method-table registration.
    """

    # C expression (Template over $arg, a PyObject*) extracting the underlying
    # C value / TH* pointer from an object that already passed its TYPE_CHECK.
    TYPE_UNPACK = {
        'THFloatTensor*': Template('((THPFloatTensor*)$arg)->cdata'),
        'THDoubleTensor*': Template('((THPDoubleTensor*)$arg)->cdata'),
        'THLongTensor*': Template('((THPLongTensor*)$arg)->cdata'),
        'THIntTensor*': Template('((THPIntTensor*)$arg)->cdata'),
        'THTensor*': Template('((THPTensor*)$arg)->cdata'),
        'THBoolTensor*': Template('((THPBoolTensor*)$arg)->cdata'),
        'THIndexTensor*': Template('((THPIndexTensor*)$arg)->cdata'),
        'THIntegerTensor*': Template('((THPIntegerTensor*)$arg)->cdata'),
        'THCudaTensor*': Template('((THCPFloatTensor*)$arg)->cdata'),
        'THCudaDoubleTensor*': Template('((THCPDoubleTensor*)$arg)->cdata'),
        'THCudaIntTensor*': Template('((THCPIntTensor*)$arg)->cdata'),
        'THCudaLongTensor*': Template('((THCPLongTensor*)$arg)->cdata'),
        'THSFloatTensor*': Template('((THSPFloatTensor*)$arg)->cdata'),
        'THSDoubleTensor*': Template('((THSPDoubleTensor*)$arg)->cdata'),
        'THSLongTensor*': Template('((THSPLongTensor*)$arg)->cdata'),
        'THSIntTensor*': Template('((THSPIntTensor*)$arg)->cdata'),
        'THSTensor*': Template('((THSPTensor*)$arg)->cdata'),
        'THSBoolTensor*': Template('((THSPBoolTensor*)$arg)->cdata'),
        'THSIndexTensor*': Template('((THSPIndexTensor*)$arg)->cdata'),
        'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'),
        'THStorage*': Template('((THPStorage*)$arg)->cdata'),
        'THGenerator*': Template('THPGenerator_TH_CData($arg)'),
        # __size / __stride are THLongStoragePtr locals declared by
        # process_declarations() and filled by the THSize*/THStride* checks.
        'THSize*': Template('__size.get()'),
        'THStride*': Template('__stride.get()'),
        'void*': Template('THPUtils_unpackLong($arg)'),
        'long': Template('THPUtils_unpackLong($arg)'),
        'int': Template('((int) THPUtils_unpackLong($arg))'),
        'int64_t': Template('THPUtils_unpackLong($arg)'),
        'bool': Template('($arg == Py_True ? true : false)'),
        'float': Template('THPFloatUtils_unpackReal($arg)'),
        'double': Template('THPDoubleUtils_unpackReal($arg)'),
        'real': Template('THPUtils_(unpackReal)($arg)'),
        'accreal': Template('THPUtils_(unpackAccreal)($arg)'),
    }

    # C boolean expression (Template over $arg) testing whether a PyObject* is
    # acceptable for the given C type.  The THSize*/THStride* checks also fill
    # the __size/__stride locals as a side effect of a successful match.
    TYPE_CHECK = {
        'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
        'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
        'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
        'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
        'THTensor*': Template('(PyObject*)Py_TYPE($arg) == THPTensorClass'),
        'THBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THPBoolTensorClass'),
        'THIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIndexTensorClass'),
        'THIntegerTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntegerTensorClass'),
        'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
        'THCudaDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPDoubleTensorClass'),
        'THCudaIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPIntTensorClass'),
        'THCudaLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPLongTensorClass'),
        'THSDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPDoubleTensorClass'),
        'THSFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPFloatTensorClass'),
        'THSLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPLongTensorClass'),
        'THSIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIntTensorClass'),
        'THSTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPTensorClass'),
        'THSBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPBoolTensorClass'),
        'THSIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIndexTensorClass'),
        'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'),
        'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'),
        'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
        'THSize*': Template('THPUtils_tryUnpackLongs($arg, __size)'),
        'THStride*': Template('THPUtils_tryUnpackLongs($arg, __stride)'),
        'void*': Template('THPUtils_checkLong($arg)'),
        'long': Template('THPUtils_checkLong($arg)'),
        'int64_t': Template('THPUtils_checkLong($arg)'),
        'int': Template('THPUtils_checkLong($arg)'),
        'bool': Template('PyBool_Check($arg)'),
        'float': Template('THPFloatUtils_checkReal($arg)'),
        'double': Template('THPDoubleUtils_checkReal($arg)'),
        'real': Template('THPUtils_(checkReal)($arg)'),
        'accreal': Template('THPUtils_(checkAccreal)($arg)'),
    }

    # Alternative THSize* check used when the size may be passed as a variable
    # number of plain int arguments starting at positional index $idx
    # (see get_type_check()).
    SIZE_VARARG_CHECK = Template('THPUtils_tryUnpackLongVarArgs(args, $idx, __size)')

    # Statement (Template over $result) converting the C return value of the
    # wrapped call into the PyObject* returned from the wrapper.
    RETURN_WRAPPER = {
        'THTensor*': Template('return THPTensor_(New)($result);'),
        'THSTensor*': Template('return THSPTensor_(New)($result);'),
        'THIndexTensor*': Template('return THPIndexTensor_(New)($result);'),
        'THLongTensor*': Template('return THPLongTensor_New($result);'),
        'THLongStorage*': Template('return THPLongStorage_New($result);'),
        'THCudaIntTensor*': Template('return THCPIntTensor_New($result);'),
        'THCudaLongTensor*': Template('return THCPLongTensor_New($result);'),
        # TODO: make it smarter - it should return python long if result doesn't fit into an int
        'long': Template('return PyInt_FromLong($result);'),
        'int64_t': Template('return PyInt_FromLong($result);'),
        'int': Template('return PyLong_FromLong($result);'),
        'accreal': Template('return THPUtils_(newAccreal)($result);'),
        # 'self' means the method returns its receiver (in-place ops).
        'self': Template('Py_INCREF(self);\nreturn (PyObject*)self;'),
        'real': Template('return THPUtils_(newReal)($result);'),
    }

    # PyMethodDef table skeleton; $methods is filled by declare_methods().
    TENSOR_METHODS_DECLARATION = Template("""
static PyMethodDef TH${sparse}PTensor_$stateless(methods)[] = {
$methods
{NULL}
};
""")

    # Skeleton of one generated wrapper function.  $options expands to a chain
    # of if/else-if blocks, one per overload; falling past them reaches the
    # invalid-arguments error path.
    WRAPPER_TEMPLATE = Template("""\
PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
int __tuplecount = args ? (int) PyTuple_Size(args) : 0;
int __dictcount = kwargs ? (int) PyDict_Size(kwargs) : 0;
int __argcount = __tuplecount + __dictcount;
$variables
$init
$options
}
THPUtils_invalidArguments(args, kwargs, "$readable_name", $num_options, $expected_args);
return NULL;
END_HANDLE_TH_ERRORS
}
""")

    # Allocates a fresh output tensor held by a smart-pointer guard; used by
    # process_pre_arg_assign() when the caller did not pass `out`.
    ALLOCATE_TMPL = Template("""\
THP${type}TensorPtr _${name}_guard((THP${type}Tensor*) THP${type}Tensor_NewEmpty());
if (!_${name}_guard.get()) return NULL;
THP${type}Tensor* $name = _${name}_guard.get();
""")

    # Compile-time switch between CUDA ($cuda) and CPU ($cpu) allocation code.
    ALLOCATE_CUDA = Template("""\
#if IS_CUDA
${cuda}
#else
${cpu}
#endif
""")
    def _allocate(typename, tmpl, cuda_tmpl=None, sparse=False):
        """Build the allocation Template for one tensor type.

        NOTE: a plain function (no self) on purpose -- it is only called right
        below, while the class body is being evaluated, to fill ALLOCATE_TYPE.
        """
        code = tmpl.safe_substitute(type=typename)
        if typename == '':
            # The generic tensor goes through the TH name-expansion macro:
            # THPTensor_(NewEmpty) instead of a literal THPTensor_NewEmpty.
            code = code.replace('NewEmpty', '(NewEmpty)')
        if cuda_tmpl:
            # Emit an #if IS_CUDA switch between THCP and THP allocation.
            cuda_code = code.replace('THP', 'THCP')
            code = cuda_tmpl.substitute(cuda=cuda_code, cpu=code)
        if sparse:
            code = code.replace('THP', 'THSP')
            code = code.replace('THCP', 'THCSP')
        return Template(code)

    # Allocation code (Template over $name) for each output-capable arg type;
    # consumed by process_pre_arg_assign().
    ALLOCATE_TYPE = {
        'THTensor*': _allocate('', ALLOCATE_TMPL),
        'THLongTensor*': _allocate('Long', ALLOCATE_TMPL),
        'THIntTensor*': _allocate('Int', ALLOCATE_TMPL),
        'THBoolTensor*': _allocate('Byte', ALLOCATE_TMPL, ALLOCATE_CUDA),
        'THIndexTensor*': _allocate('Long', ALLOCATE_TMPL, ALLOCATE_CUDA),
        'THIntegerTensor*': _allocate('Int', ALLOCATE_TMPL, ALLOCATE_CUDA),
        'THSTensor*': _allocate('', ALLOCATE_TMPL, sparse=True),
    }
TYPE_NAMES = {
'THTensor*': '" THPTensorStr "',
'THSTensor*': '" THSPTensorStr "',
'THStorage*': '" THPStorageStr "',
'THGenerator*': 'torch.Generator',
'THLongStorage*': '" THPModuleStr "LongStorage',
'THLongTensor*': '" THPModuleStr "LongTensor',
'THIntTensor*': '" THPModuleStr "IntTensor',
'THBoolTensor*': '" THPModuleStr "ByteTensor',
'THIndexTensor*': '" THPModuleStr "LongTensor',
'THIntegerTensor*': '" THPModuleStr "IntTensor',
'THFloatTensor*': '" THPModuleStr "FloatTensor',
'THDoubleTensor*': '" THPModuleStr "DoubleTensor',
'THCudaTensor*': 'torch.cuda.FloatTensor',
'THCudaDoubleTensor*': 'torch.cuda.DoubleTensor',
'THCudaIntTensor*': 'torch.cuda.IntTensor',
'THCudaLongTensor*': 'torch.cuda.LongTensor',
'THSize*': 'torch.Size',
'THStride*': 'tuple',
'long': 'int',
'int64_t': 'int',
'int': 'int',
'real': '" RealStr "',
'double': 'float',
'accreal': '" RealStr "',
'bool': 'bool',
'const char*': 'bool', # Can come only from bool option.
}
OUT_INIT = """
___out = kwargs ? PyDict_GetItemString(kwargs, "out") : NULL;
if (___out == Py_None) { ___out = NULL; __dictcount--; __argcount--; }
"""
def __init__(self):
self.declarations = []
self.stateless_declarations = []
self.docstrings = []
BACKEND_SUBSTITUTIONS = {
'CPU': 'TH',
'CUDA': 'THCuda',
}
def substitute_tensor_backend(self, arg, option):
if 'Backend' in arg['type']:
arg['type'] = arg['type'].replace('Backend',
self.BACKEND_SUBSTITUTIONS.get(option['backends'][0]))
# handle the fact that THCudaTensor isn't THCudaFloatTensor
if option['backends'][0] == 'CUDA' and 'Float' in arg['type']:
arg['type'] = arg['type'].replace('Float', '')
def get_type_unpack(self, arg, option):
self.substitute_tensor_backend(arg, option)
return self.TYPE_UNPACK.get(arg['type'], None)
def get_type_check(self, arg, option):
if arg['type'] == 'THSize*' and arg.get('long_args', False):
return self.SIZE_VARARG_CHECK
self.substitute_tensor_backend(arg, option)
return self.TYPE_CHECK.get(arg['type'], None)
    # TODO: argument descriptions shouldn't be part of THP, but rather a general cwrap thing
    def get_wrapper_template(self, declaration):
        """Build the wrapper-function Template for `declaration`, with the
        human-readable name, option count, and expected-argument strings
        (used by THPUtils_invalidArguments) already substituted in.
        """
        # OrderedDict used as an ordered set: deduplicates identical
        # signatures while keeping first-seen order before sorting.
        arg_desc = OrderedDict()

        def format_arg(arg, var_args=False):
            # One "<type> <name>" fragment; long_args in vararg mode render
            # as "int ... name".
            if var_args and arg.get('long_args', False):
                return 'int ... ' + arg['name']
            else:
                return self.TYPE_NAMES[arg['type']] + ' ' + arg['name']

        def format_args(args, var_args=False):
            # Non-output, checked args form the visible signature; output
            # args are folded into a single trailing '#...  out' entry
            # (tuple[...] when there are several outputs).
            option_desc = [format_arg(arg, var_args)
                           for arg in args
                           if not arg.get('ignore_check', False) and
                           not arg.get('output')]
            output_args = list(filter(lambda a: a.get('output'), args))
            if output_args:
                if len(output_args) > 1:
                    out_type = 'tuple['
                    out_type += ', '.join(
                        self.TYPE_NAMES[arg['type']] for arg in output_args)
                    out_type += ']'
                    option_desc += ['#' + out_type + ' out']
                else:
                    arg = output_args[0]
                    option_desc += ['#' + self.TYPE_NAMES[arg['type']] + ' out']
            if option_desc:
                return '({})'.format(', '.join(option_desc))
            else:
                return 'no arguments'

        # Each option contributes both its normal and its vararg rendering.
        for option in declaration['options']:
            arg_desc[format_args(option['arguments'], False)] = True
            arg_desc[format_args(option['arguments'], True)] = True
        arg_desc = sorted(list(arg_desc.keys()), key=len)
        arg_desc = ['"' + desc + '"' for desc in arg_desc]
        arg_str = ', '.join(arg_desc)
        variables_str = '\n'.join(declaration.get('variables', []))
        init_str = '\n'.join(declaration.get('init', []))
        # Stateless wrappers are reported under the torch.* namespace.
        if 'stateless' in declaration['name']:
            readable_name = 'torch.' + declaration['python_name']
        else:
            readable_name = declaration['python_name']
        # safe_substitute: $name/$options etc. are left for later passes.
        return Template(self.WRAPPER_TEMPLATE.safe_substitute(
            readable_name=readable_name, num_options=len(arg_desc),
            expected_args=arg_str, variables=variables_str, init=init_str))
def get_return_wrapper(self, option):
return self.RETURN_WRAPPER.get(option['return'], None)
def get_arg_accessor(self, arg, option):
if arg['name'] == 'self':
return 'self'
if arg.get('output'):
if not option['output_provided']:
return arg['name']
if option['output_count'] == 1:
return '___out'
else:
return 'PyTuple_GET_ITEM(___out, {})'.format(arg['output_idx'])
def process_docstrings(self):
for declaration in self.declarations:
docstr = declaration.get('docstring_method')
if docstr is None:
continue
declaration['docstring_content'] = docstr.replace('\n', '\\n')
declaration['docstring_var'] = 'docstr_' + declaration['python_name']
for declaration in self.stateless_declarations:
docstr = declaration.get('docstring_stateless')
if docstr is None:
continue
declaration['docstring_content'] = docstr.replace('\n', '\\n')
declaration['docstring_var'] = 'stateless_docstr_' + declaration['python_name']
    def generate_out_options(self, declaration):
        """Expand every option that has output arguments into two variants:
        one where the caller supplied `out` and one where the wrapper
        allocates the outputs itself.  Also injects the OUT_INIT prologue.
        Mutates `declaration` in place.
        """
        new_options = []
        declaration.setdefault('init', [])
        declaration['init'] += [self.OUT_INIT]
        for option in declaration['options']:
            out_idx = []
            for i, arg in enumerate(option['arguments']):
                if arg.get('output'):
                    out_idx.append(i)
            if not out_idx:
                # No outputs in this option: keep a single variant, marked as
                # "out not given" so the ___out == NULL check is emitted.
                option['has_output'] = True
                option['output_provided'] = False
                new_options.append(option)
                continue
            for output_provided in (True, False):
                option_copy = deepcopy(option)
                option_copy['has_output'] = True
                option_copy['output_provided'] = output_provided
                option_copy['output_count'] = len(out_idx)
                for i, idx in enumerate(out_idx):
                    arg = option_copy['arguments'][idx]
                    arg['output_idx'] = i
                    if not output_provided:
                        # Wrapper allocates this tensor; no type check needed.
                        arg['ignore_check'] = True
                    else:
                        # All outputs arrive packed in the single `out` kwarg,
                        # so they count as one argument, not len(out_idx).
                        option_copy['argcount_offset'] = -len(out_idx) + 1
                        arg['no_kwargs'] = True
                        arg['no_idx'] = True
                new_options.append(option_copy)
        declaration['options'] = new_options
    def process_declarations(self, declarations):
        """Main cwrap hook: normalize every declaration in place (names,
        defined_if guards, helper variables, out-variants, stateless
        variants) and return the list of declarations to generate code for.
        """
        new_declarations = []

        def has_arg_type(declaration, type_name):
            return any(arg['type'] == type_name
                       for option in declaration['options']
                       for arg in option['arguments'])

        def has_long_args(declaration):
            return any(arg.get('long_args', False)
                       for option in declaration['options']
                       for arg in option['arguments'])

        def has_output_args(declaration):
            return any(arg.get('output')
                       for option in declaration['options']
                       for arg in option['arguments'])

        def backends_types_to_defined_if_string(declaration):
            # A declaration has two fields: 'backend', which stores a list of
            # backends (currently 'cpu' and 'cuda') the declaration applies
            # to, and 'types', which stores a list of real types the
            # declaration applies to. In PyTorch, when a function is only
            # supported by a subset of types, we wrap it in macro definition
            # checks.
            #
            # Previously, we manually required the cwrap declaration to
            # specify for which backend/type combinations a function was
            # defined for. Now, we explicitly list the types and backends for
            # a declaration, if it should only be supported for a specific
            # subset of types, backends, or type-backend pairs.
            types = declaration.get('types', [])
            backends = declaration['backends']
            all_backends = ['CPU', 'CUDA']

            def get_defined_string(backend, real):
                # One preprocessor condition for a (backend, real-type) pair;
                # 'all' means every real type of that backend.
                if backend == 'CUDA':
                    if real == 'all':
                        return "IS_CUDA"
                    else:
                        return 'CUDA_{0}'.format(real.upper())
                else:
                    if real == 'all':
                        return "!IS_CUDA"
                    else:
                        return 'defined(TH_REAL_IS_{0})'.format(real.upper())

            def expand_composite_type(p, t):
                # 'floating_point' / 'integral' expand to their concrete real
                # types (half only exists on CUDA); anything else is literal.
                if t == 'floating_point':
                    result = ['double', 'float']
                    if p == 'CUDA':
                        result.append('half')
                elif t == 'integral':
                    result = ['byte', 'char', 'short', 'int', 'long']
                else:
                    result = [t]
                return result

            defineds = []
            # The logic below does not handle corner cases well. We allow the
            # declaration to have a field 'backend_type_pairs' that stores a
            # dictionary from type --> backend representing allowed
            # combinations. Let's use these first.
            for pair in declaration.get('backend_type_pairs', []):
                p, t = pair
                defineds.extend([get_defined_string(p, et) for et in
                                 expand_composite_type(p, t)])
            # In the base case, types is empty and backends contains both
            # 'CPU' and 'CUDA' --> this means we support all types, and our
            # string should be empty, or simply the list of explict type
            # backend pairs
            if (len(types) == 0 and all([proc in backends for proc in
                                         all_backends])):
                return " || ".join(defineds)
            # Case 2: types is empty, but only one backend type is specified
            if len(types) == 0 and len(backends) == 1:
                defineds.append('IS_CUDA' if backends[0] == 'CUDA' else
                                "!IS_CUDA")
                return " || ".join(defineds)
            # Else, we loop overall all of the backend, type pairs and add
            # them
            for p in backends:
                for t in types:
                    defineds.extend([get_defined_string(p, et) for et in
                                     expand_composite_type(p, t)])
            return " || ".join(defineds)

        for declaration in declarations:
            dfstr = backends_types_to_defined_if_string(declaration)
            if len(dfstr) > 0:
                # for now, need to check for distributed defined if as well
                if 'defined_if' in declaration:
                    declaration['defined_if'] += ' && (' + dfstr + ')'
                else:
                    declaration['defined_if'] = dfstr
            # Disable all methods for THHalfTensor, unless cpu_half is True
            if not declaration.get('cpu_half', False):
                defined_if = '!defined(TH_REAL_IS_HALF)'
                if 'defined_if' in declaration:
                    defined_if += ' && (' + declaration['defined_if'] + ')'
                declaration['defined_if'] = defined_if
            # only_register declarations are listed in the method table but
            # get no generated wrapper body.
            if declaration.get('only_register', False):
                continue
            declaration.setdefault('python_name', declaration['name'])
            declaration.setdefault('variables', [])
            # Helper locals referenced by the THSize*/THStride*/out checks.
            if has_arg_type(declaration, 'THSize*'):
                declaration['variables'] += ['THLongStoragePtr __size;']
            if has_arg_type(declaration, 'THStride*'):
                declaration['variables'] += ['THLongStoragePtr __stride;']
            if has_output_args(declaration):
                declaration['variables'] += ['PyObject *___out;']
                self.generate_out_options(declaration)
            if has_long_args(declaration):
                # Vararg ints are positional by construction.
                for option in declaration['options']:
                    for arg in option['arguments']:
                        if arg.get('long_args', False):
                            arg['no_kwargs'] = True
            for option in declaration['options']:
                option['cname'] = 'TH{}Tensor_({})'.format(
                    'S' if option.get('sparse', False) else '', option['cname'])
                if option.get('sparse', False):
                    # Sparse code is compiled out of distributed builds.
                    defined_if = option.get('defined_if', '')
                    option['defined_if'] = '!IS_DISTRIBUTED' + (' && ' if defined_if else '') + defined_if

            variants = declaration.get('variants', ['method'])
            if 'function' in variants:
                # A torch.* counterpart, generated from a deep copy.
                stateless_declaration = self.make_stateless(declaration)
                new_declarations.append(stateless_declaration)
                self.stateless_declarations.append(stateless_declaration)
            if 'method' not in variants:
                continue
            self.declarations.append(declaration)
            declaration['name'] = 'TH{}PTensor_({})'.format(
                'S' if declaration.get('sparse', False) else '', declaration['name'])
            # `self` is the method receiver: it is always the right type.
            for option in declaration['options']:
                for arg in option['arguments']:
                    if arg['name'] == 'self':
                        arg['ignore_check'] = True

        # only_register declarations still go into the method tables.
        register_only = [d for d in declarations if d.get('only_register', False)]
        declarations = [d for d in declarations
                        if (('method' in d.get('variants', ['method'])) and
                            (not d.get('only_register', False)))]
        self.declarations.extend(filter(lambda x: 'method' in x.get('variants',
                                                                    ['method']), register_only))
        self.stateless_declarations.extend(filter(lambda x: 'method' not in
                                                  x.get('variants', ['method']),
                                                  register_only))
        self.process_docstrings()
        all_declarations = declarations + new_declarations
        return all_declarations
def make_stateless(self, declaration):
declaration = deepcopy(declaration)
declaration['name'] = 'TH{}PTensor_stateless_({})'.format(
'S' if declaration.get('sparse', False) else '', declaration['name'])
for option in declaration['options']:
for arg in option['arguments']:
if arg['name'] == 'self':
arg['assign_name'] = 'self'
arg['name'] = 'source'
return declaration
def declare_methods(self, stateless, sparse):
tensor_methods = ''
for declaration in (self.declarations if not stateless else self.stateless_declarations):
if declaration.get('sparse', False) != sparse:
continue
flags = 'METH_VARARGS'
flags += ' | ' + declaration.get('method_flags') if 'method_flags' in declaration else ''
if not declaration.get('only_register'):
flags += ' | METH_KEYWORDS'
if declaration.get('override_method_flags'):
flags = declaration['override_method_flags']
entry = Template(' {"$python_name", (PyCFunction)$name, $flags, $docstring},\n').substitute(
python_name=declaration['python_name'], name=declaration['name'], flags=flags,
docstring=declaration.get('docstring_var', 'NULL')
)
if 'defined_if' in declaration:
entry = self.preprocessor_guard(entry, declaration['defined_if'])
tensor_methods += entry
generated = self.TENSOR_METHODS_DECLARATION.substitute(
methods=tensor_methods,
stateless=('' if not stateless else 'stateless_'),
sparse=('' if not sparse else 'S'),
)
if sparse:
generated = '#if !defined(TH_REAL_IS_HALF) && !IS_DISTRIBUTED\n' + generated + '\n#endif\n\n'
return generated
def process_full_file(self, code):
# We have to find a place before all undefs
idx = code.find('// PUT DEFINITIONS IN HERE PLEASE')
return (code[:idx] +
self.declare_methods(False, False) +
self.declare_methods(True, False) +
self.declare_methods(False, True) +
self.declare_methods(True, True) +
code[idx:]
)
def preprocessor_guard(self, code, condition):
return '#if ' + condition + '\n' + code + '#endif\n'
def process_wrapper(self, code, declaration):
if 'defined_if' in declaration:
return self.preprocessor_guard(code, declaration['defined_if'])
return code
def process_all_call_arg(self, code, option):
return 'LIBRARY_STATE ' + code
    def process_all_checks(self, code, option):
        """Augment the generated argument-check expression `code` with the
        `out`-kwarg conditions and, for vararg options, a relaxed argcount
        plus an exact kwarg-count check.  Returns the new C expression."""
        if option.get('has_output'):
            # Continuation indent for the generated multi-line condition.
            indent = " " * 10
            if option['output_provided']:
                # `out` was passed: non-NULL, and for multiple outputs it
                # must be a tuple of exactly output_count elements.
                checks = "___out != NULL &&\n" + indent
                if option['output_count'] > 1:
                    checks += "PyTuple_Check(___out) &&\n" + indent
                    length_check = "PyTuple_GET_SIZE(___out) == {} &&\n".format(
                        option['output_count'])
                    checks += length_check + indent
                code = checks + code
            else:
                code = "___out == NULL &&\n" + indent + code
        if any(arg.get('long_args', False) for arg in option['arguments']):
            # Vararg ints: accept any argcount at or above the minimum, but
            # then the keyword-argument count must match exactly.
            code = code.replace('__argcount ==', '__argcount >=')
            expected = str(int(option.get('output_provided', False)) +
                           sum(not arg.get('no_kwargs', False) and not arg.get('ignore_check', False)
                               for arg in option['arguments']))
            code = '__dictcount == ' + expected + ' &&\n ' + code
        return code
    def process_option_code(self, code, option):
        """Wrap one option's dispatch code in its #if guard, if it has one.
        Returns the (possibly unchanged) code string."""
        if option.get('defined_if', ''):
            defined_if = option['defined_if']
            placeholder = ''
            # This means that it's a first option, so we need a dummy if,
            # so the next option can be an else if.
            if 'else if' not in code:
                placeholder = '\n #else\n if (false) {'
            return '#if ' + defined_if + '\n ' + code + placeholder + '\n #endif\n'
        return code
def process_pre_arg_assign(self, template, option):
new_args = []
for arg in option['arguments']:
if not option.get('output_provided', True) and arg.get('output'):
new_args.append(self.ALLOCATE_TYPE[arg['type']].substitute(name=arg['name']))
template = new_args + template
return template
def generate_docstrings_cpp(self):
template = Template('char* $name = "$content";')
return '\n\n'.join(
template.substitute(name=decl['docstring_var'], content=decl['docstring_content'])
for decl in chain(self.declarations, self.stateless_declarations)
if 'docstring_var' in decl)
def generate_docstrings_h(self):
template = Template('extern char* $name;')
return '\n\n'.join(
template.substitute(name=decl['docstring_var'])
for decl in chain(self.declarations, self.stateless_declarations)
if 'docstring_var' in decl)