# blob: fa2b63ce0740e0afef7fde9d41d4a49f25cfe7c7
from string import Template
import copy
from copy import deepcopy
from . import CWrapPlugin
from itertools import product
class CuDNNPlugin(CWrapPlugin):
    """cwrap plugin that generates CPython bindings for cuDNN declarations.

    Each processed declaration is renamed to ``THCUDNN_<name>``, exposed to
    Python under ``python_name`` (default ``_<name>``), and registered in the
    ``_THCUDNN_methods`` table that :meth:`process_full_file` appends to the
    generated code. The ``get_*`` / ``process_*`` methods are hooks
    presumably invoked by the cwrap driver (defined outside this file) --
    TODO confirm the exact call order there.
    """

    # C++ expression templates converting the PyObject* named by ``$arg``
    # into a value of the declared argument type.
    TYPE_UNPACK = {
        'THTensor*': Template('createTensor($arg)'),
        'int': Template('((int) THPUtils_unpackLong($arg))'),
        'std::vector<int>': Template('THPUtils_unpackIntTuple($arg)'),
        'cudnnDataType_t': Template('$arg'),
        'cudnnHandle_t': Template('$arg'),
        'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'),
        'bool': Template('$arg == Py_True'),
        'double': Template('THPDoubleUtils_unpackReal($arg)'),
    }

    # Declared types replaced in the argument copies returned by
    # get_assign_args(); tensors are bound as const references.
    INPUT_ARGUMENT_MAP = {
        'THTensor*': 'const at::Tensor&',
    }

    # C boolean expressions testing whether ``$arg`` matches a declared type.
    # Types missing here get no check (get_type_check returns None).
    TYPE_CHECK = {
        'Convolution*': Template('THPWrapper_check($arg)'),
        'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'),
        'int': Template('THPUtils_checkLong($arg)'),
        'std::vector<int>': Template('THPUtils_checkIntTuple($arg)'),
        'bool': Template('PyBool_Check($arg)'),
        'double': Template('THPDoubleUtils_checkReal($arg)'),
    }

    # How a returned value of a given C++ type is handed back to Python:
    # a Convolution* is boxed in a THPWrapper whose deleter frees it.
    RETURN_WRAPPER = {
        'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'),
    }

    # Method table plus an accessor returning it; ``$methods`` is filled in
    # by declare_methods(). The {NULL} entry terminates the table.
    METHODS_DECLARATION = Template("""
static PyMethodDef _THCUDNN_methods[] = {
$methods
{NULL}
};
PyMethodDef* THCUDNN_methods()
{
return _THCUDNN_methods;
}
""")

    # Skeleton of every generated wrapper function. get_wrapper_template()
    # fills ``$readable_name``, ``$num_options`` and ``$expected_args`` via
    # safe_substitute, leaving ``$name`` and ``$options`` for later. The ``}``
    # after ``$options`` presumably closes the last option's ``if`` block
    # emitted by the driver -- verify against the cwrap option template.
    WRAPPER_TEMPLATE = Template("""\
static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
int __tuplecount = args ? (int) PyTuple_Size(args) : 0;
int __dictcount = kwargs ? (int) PyDict_Size(kwargs) : 0;
int __argcount = __tuplecount + __dictcount;
PyObject* tensorClass = getTensorClass(args);
THCPAutoGPU __autogpu_guard = THCPAutoGPU(args);
$options
}
THPUtils_invalidArguments(args, kwargs, "$readable_name", $num_options, $expected_args);
return NULL;
END_HANDLE_TH_ERRORS
}
""")

    # Statement releasing an argument's guard object. It is not referenced
    # inside this class; presumably the cwrap driver consumes it.
    RELEASE_ARG = Template("_${name}_guard.release();")

    # Human-readable type names for the "invalid arguments" message. Each
    # signature is later wrapped in double quotes, so '" THPTensorStr "'
    # closes and reopens that C string literal, splicing in THPTensorStr
    # (presumably a string-literal macro) by adjacent-literal concatenation.
    TYPE_NAMES = {
        'THTensor*': '" THPTensorStr "',
        'long': 'int',
        'bool': 'bool',
        'int': 'int',
    }

    def __init__(self):
        # Every declaration seen by process_declarations(), including
        # only_register ones; declare_methods() builds the table from it.
        self.declarations = []

    def get_type_unpack(self, arg, option):
        """Return the unpack Template for ``arg['type']``, or None if unknown."""
        return self.TYPE_UNPACK.get(arg['type'], None)

    def get_type_check(self, arg, option):
        """Return the type-check Template for ``arg['type']``, or None if unknown."""
        return self.TYPE_CHECK.get(arg['type'], None)

    def get_assign_args(self, arguments):
        """Return shallow copies of ``arguments`` with types remapped.

        Types found in INPUT_ARGUMENT_MAP are replaced; all others are kept.
        The caller's argument dicts are not modified.
        """
        assign_args = []
        for arg in arguments:
            arg = copy.copy(arg)
            new_type = self.INPUT_ARGUMENT_MAP.get(arg['type'])
            if new_type is not None:
                arg['type'] = new_type
            assign_args.append(arg)
        return assign_args

    def get_wrapper_template(self, declaration):
        """Return WRAPPER_TEMPLATE with the error-message fields filled in.

        One human-readable signature is built per option from the arguments
        that are not ``ignore_check``. Options with none of those arguments
        read "no arguments". The signatures are sorted by length and emitted
        as C string literals for THPUtils_invalidArguments.
        """
        arg_desc = []
        for option in declaration['options']:
            option_desc = [self.TYPE_NAMES.get(arg['type'], arg['type']) + ' ' + arg['name']
                           for arg in option['arguments']
                           if not arg.get('ignore_check', False)]
            # TODO: this should probably go to THPLongArgsPlugin
            if option_desc:
                arg_desc.append('({})'.format(', '.join(option_desc)))
            else:
                arg_desc.append('no arguments')
        arg_desc.sort(key=len)
        arg_desc = ['"' + desc + '"' for desc in arg_desc]
        arg_str = ', '.join(arg_desc)
        readable_name = declaration['python_name']
        # safe_substitute leaves $name / $options intact for later passes.
        return Template(self.WRAPPER_TEMPLATE.safe_substitute(
            readable_name=readable_name, num_options=len(arg_desc),
            expected_args=arg_str))

    def get_return_wrapper(self, option):
        """Return the RETURN_WRAPPER Template for ``option['return']``, or None."""
        return self.RETURN_WRAPPER.get(option['return'], None)

    def get_arg_accessor(self, arg, option):
        """Return a C expression for arguments this plugin supplies itself.

        ``self`` is passed through, and ``dataType`` / ``handle`` are obtained
        from helper calls. For any other name this returns None, presumably
        deferring to another plugin or to the driver's default accessor.
        """
        name = arg['name']
        if name == 'self':
            return 'self'
        elif name == 'dataType':
            return 'getCudnnDataType(tensorClass)'
        elif name == 'handle':
            return 'getCudnnHandle()'
        return None

    def process_declarations(self, declarations):
        """Normalize declarations and record them for the method table.

        Each declaration is changed in place:
        - ``python_name`` defaults to ``_<name>``;
        - ``name`` gets the ``THCUDNN_`` prefix;
        - arguments supplied implicitly (self/state/dataType/handle) are
          marked ``ignore_check``;
        - options with duplicate checked-type signatures are dropped.

        Returns the declarations that are not ``only_register``. Those that
        are remain recorded in ``self.declarations`` and still get a
        method-table entry.
        """
        for declaration in declarations:
            declaration.setdefault('python_name', '_{}'.format(declaration['name']))
            declaration['name'] = 'THCUDNN_{}'.format(declaration['name'])
            self.declarations.append(declaration)
            for option in declaration['options']:
                for arg in option['arguments']:
                    if arg['name'] in ['self', 'state', 'dataType', 'handle']:
                        arg['ignore_check'] = True
            declaration['options'] = self.filter_unique_options(declaration['options'])
        return [d for d in declarations if not d.get('only_register', False)]

    def filter_unique_options(self, options):
        """Drop options whose checked argument types repeat an earlier option's.

        The signature is the '#'-joined types of the non-``ignore_check``
        arguments. The first occurrence wins and the original order is kept.
        """
        def signature(option):
            return '#'.join(arg['type'] for arg in option['arguments']
                            if not arg.get('ignore_check', False))

        seen_signatures = set()
        unique = []
        for option in options:
            sig = signature(option)
            if sig not in seen_signatures:
                unique.append(option)
                seen_signatures.add(sig)
        return unique

    def preprocessor_guard(self, code, condition):
        """Wrap ``code`` in ``#if condition`` ... ``#endif``.

        If non-empty ``code`` lacks a trailing newline, one is added so that
        ``#endif`` begins its own line; a preprocessor directive fused onto
        the end of a code line would not be recognized.
        """
        if code and not code.endswith('\n'):
            code += '\n'
        return '#if ' + condition + '\n' + code + '#endif\n'

    def process_wrapper(self, code, declaration):
        """Guard the generated wrapper with ``defined_if`` when present."""
        if 'defined_if' in declaration:
            return self.preprocessor_guard(code, declaration['defined_if'])
        return code

    def process_all_call_arg(self, code, option):
        """Prepend the ``state`` argument to the comma-separated call arguments.

        With an empty argument list this returns just ``state``. Emitting
        ``state, `` would leave a trailing comma and invalid C/C++.
        """
        return ('state, ' + code) if code else 'state'

    def declare_methods(self):
        """Return the C method table registering every recorded declaration.

        Each entry is METH_VARARGS, plus ``method_flags`` when given, plus
        METH_KEYWORDS unless the declaration is ``only_register``. Entries
        that carry ``defined_if`` are wrapped in a preprocessor guard.
        """
        entry_template = Template(
            '  {"$python_name", (PyCFunction)$name, METH_VARARGS$extra_flags, NULL},\n')
        entries = []
        for declaration in self.declarations:
            extra_flags = ''
            if 'method_flags' in declaration:
                extra_flags += ' | ' + declaration['method_flags']
            # process_declarations() drops only_register declarations from
            # its result, so presumably no (self, args, kwargs) wrapper is
            # generated here for them and METH_KEYWORDS must not be set.
            if not declaration.get('only_register'):
                extra_flags += ' | METH_KEYWORDS'
            entry = entry_template.substitute(
                python_name=declaration['python_name'],
                name=declaration['name'],
                extra_flags=extra_flags,
            )
            if 'defined_if' in declaration:
                entry = self.preprocessor_guard(entry, declaration['defined_if'])
            entries.append(entry)
        return self.METHODS_DECLARATION.substitute(methods=''.join(entries))

    def process_full_file(self, code):
        """Append the method-table declaration to the generated file."""
        return code + self.declare_methods()