blob: 7085066dd535cf7166ec0607e65b314a3cb94b72 [file] [log] [blame]
import os
from string import Template
from . import CWrapPlugin
# C preamble emitted before the generated wrapper code: the headers that the
# generated wrappers are compiled against.
MODULE_HEAD = '\n'.join((
    '',
    '#include <Python.h>',
    '#include <exception>',
    '#include "THP_API.h"',
    '#include "torch/csrc/nn/type_checks.h"',
    '',
))
# Boilerplate appended after the generated code, read at import time from
# templates/module_tail.cpp next to this file. It is a string.Template that
# process_full_file fills with full_name and short_name.
with open(
    os.path.join(os.path.dirname(__file__), 'templates', 'module_tail.cpp'),
    'r',
) as f:
    MODULE_TAIL = Template(''.join(f))
# One PyMethodDef entry per wrapped function. $name is used both as the
# Python-visible method name and as the C function symbol.
REGISTER_METHOD_TEMPLATE = Template(''.join([
    ' {"$name", ',
    '(PyCFunction)$name, ',
    'METH_VARARGS, NULL},\n',
]))

# The module's PyMethodDef table. $METHODS receives the concatenated
# REGISTER_METHOD_TEMPLATE entries; the table is closed by the all-NULL
# sentinel entry.
MODULE_METHODS_TEMPLATE = Template('\n'.join((
    '',
    'static PyMethodDef module_methods[] = {',
    '$METHODS',
    '{NULL, NULL, 0, NULL}',
    '};',
    '',
)))
class StandaloneExtension(CWrapPlugin):
    """cwrap plugin that packages the generated wrappers as a standalone
    Python extension module.

    Every declaration passed through ``process_wrapper`` is recorded, and
    ``process_full_file`` surrounds the generated code with MODULE_HEAD, a
    PyMethodDef table listing the recorded declarations, and MODULE_TAIL.
    """

    # Declared C type -> C expression template converting the PyObject* in
    # $arg into a value of that type.
    TYPE_UNPACK = {
        'THFloatTensor*': Template('THNN_FloatTensor_Unpack($arg)'),
        'THDoubleTensor*': Template('THNN_DoubleTensor_Unpack($arg)'),
        'THLongTensor*': Template('THNN_LongTensor_Unpack($arg)'),
        'THIntTensor*': Template('THNN_IntTensor_Unpack($arg)'),
        'THCudaHalfTensor*': Template('THNN_CudaHalfTensor_Unpack($arg)'),
        'THCudaTensor*': Template('THNN_CudaFloatTensor_Unpack($arg)'),
        'THCudaDoubleTensor*': Template('THNN_CudaDoubleTensor_Unpack($arg)'),
        'THCudaLongTensor*': Template('THNN_CudaLongTensor_Unpack($arg)'),
        'half': Template('THPHalfUtils_unpackReal($arg)'),
        'float': Template('THPFloatUtils_unpackReal($arg)'),
        'double': Template('THPDoubleUtils_unpackReal($arg)'),
        'bool': Template('($arg == Py_True ? true : false)'),
        'int': Template('THPUtils_unpackLong($arg)'),
        'long': Template('THPUtils_unpackLong($arg)'),
        'int64_t': Template('THPUtils_unpackLong($arg)'),
        'void*': Template('(void*)THPUtils_unpackLong($arg)'),
        'THGenerator*': Template('THPGenerator_TH_CData((THPGenerator*)$arg)'),
    }

    # Declared C type -> C boolean expression template testing whether the
    # PyObject* in $arg is acceptable for that type.
    TYPE_CHECK = {
        'THFloatTensor*': Template('THNN_FloatTensor_Check($arg)'),
        'THDoubleTensor*': Template('THNN_DoubleTensor_Check($arg)'),
        'THLongTensor*': Template('THNN_LongTensor_Check($arg)'),
        'THIntTensor*': Template('THNN_IntTensor_Check($arg)'),
        'THCudaHalfTensor*': Template('THNN_CudaHalfTensor_Check($arg)'),
        'THCudaTensor*': Template('THNN_CudaFloatTensor_Check($arg)'),
        'THCudaDoubleTensor*': Template('THNN_CudaDoubleTensor_Check($arg)'),
        'THCudaLongTensor*': Template('THNN_CudaLongTensor_Check($arg)'),
        'half': Template('THPHalfUtils_checkReal($arg)'),
        'float': Template('THPFloatUtils_checkReal($arg)'),
        'double': Template('THPDoubleUtils_checkReal($arg)'),
        'bool': Template('PyBool_Check($arg)'),
        'int': Template('THPUtils_checkLong($arg)'),
        'long': Template('THPUtils_checkLong($arg)'),
        'int64_t': Template('THPUtils_checkLong($arg)'),
        'void*': Template('THPUtils_checkLong($arg)'),
        'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
    }

    # Skeleton of each generated wrapper function.
    # - $options must leave an `if` block open, because the template closes
    #   it with `} else {`. That else-branch reports invalid arguments.
    # - $num_options is the count of quoted option descriptions substituted
    #   into $expected_args. THPUtils_invalidArguments is assumed to read that
    #   many trailing option strings (it used to be hard-coded to 1, which
    #   dropped every option after the first). NOTE(review): confirm against
    #   its C declaration.
    WRAPPER_TEMPLATE = Template("""
PyObject * $name(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
int __argcount = args ? PyTuple_Size(args) : 0;
$options
} else {
THPUtils_invalidArguments(args, NULL, "$name", $num_options, $expected_args);
return NULL;
}
END_HANDLE_TH_ERRORS
}
""")

    # Declared C type -> user-facing Python type name, used only to build
    # the argument descriptions shown in the invalid-arguments error.
    TYPE_NAMES = {
        'THGenerator*': 'Generator',
        'THCudaHalfTensor*': 'torch.cuda.HalfTensor',
        'THCudaTensor*': 'torch.cuda.FloatTensor',
        'THCudaDoubleTensor*': 'torch.cuda.DoubleTensor',
        'THCudaLongTensor*': 'torch.cuda.LongTensor',
        'THDoubleTensor*': 'torch.DoubleTensor',
        'THFloatTensor*': 'torch.FloatTensor',
        'THBoolTensor*': 'torch.ByteTensor',
        'THLongTensor*': 'torch.LongTensor',
        'THIndexTensor*': 'torch.LongTensor',
        'THIntTensor*': 'torch.IntTensor',
        'THLongStorage*': 'torch.LongStorage',
        'long': 'int',
        'int64_t': 'int',
        'int': 'int',
        'real': 'float',
        'half': 'float',
        'double': 'float',
        'float': 'float',
        'accreal': 'float',
        'bool': 'bool',
        'void*': 'int',
    }

    def __init__(self, module_name):
        """Create the plugin for the extension module named *module_name*.

        *module_name* may be dotted; its last component is used as the
        module's short name when the tail template is filled in.
        """
        self.module_name = module_name
        # Declarations seen by process_wrapper, in the order they arrived.
        self.declarations = []

    def process_full_file(self, code):
        """Return *code* wrapped with the module head, method table and tail."""
        short_name = self.module_name.split('.')[-1]
        return ''.join([
            MODULE_HEAD,
            code,
            self.declare_module_methods(),
            MODULE_TAIL.substitute(full_name=self.module_name,
                                   short_name=short_name),
        ])

    def process_wrapper(self, code, declaration):
        """Record *declaration* for the method table; return *code* unchanged."""
        self.declarations.append(declaration)
        return code

    def declare_module_methods(self):
        """Return the C PyMethodDef table for every recorded declaration.

        Uses each declaration's ``'name'`` key, in recording order.
        """
        # join instead of repeated += keeps this linear in the number of entries.
        entries = ''.join(
            REGISTER_METHOD_TEMPLATE.substitute(name=declaration['name'])
            for declaration in self.declarations)
        return MODULE_METHODS_TEMPLATE.substitute(METHODS=entries)

    def get_type_unpack(self, arg, option):
        """Return the unpack Template for ``arg['type']``, or None if unknown.

        *option* is unused here; it is kept for the plugin hook signature.
        """
        return self.TYPE_UNPACK.get(arg['type'])

    def get_type_check(self, arg, option):
        """Return the type-check Template for ``arg['type']``, or None if unknown.

        *option* is unused here; it is kept for the plugin hook signature.
        """
        return self.TYPE_CHECK.get(arg['type'])

    def get_wrapper_template(self, declaration):
        """Return WRAPPER_TEMPLATE with the invalid-arguments details filled in.

        One human-readable description is built per entry of
        ``declaration['options']``. Each option's ``'arguments'`` are dicts
        with ``'type'`` and ``'name'`` and optional ``'nullable'`` /
        ``'ignore_check'`` flags. Arguments flagged ``ignore_check`` are left
        out of the description.

        Only ``$expected_args`` and ``$num_options`` are substituted.
        ``safe_substitute`` leaves ``$name`` and ``$options`` in place for
        later substitution.

        Raises:
            KeyError: if an argument's type has no entry in TYPE_NAMES.
        """
        def describe_arg(arg):
            desc = self.TYPE_NAMES[arg['type']] + ' ' + arg['name']
            if arg.get('nullable'):
                return '[{} or None]'.format(desc)
            return desc

        arg_desc = []
        for option in declaration['options']:
            option_desc = [describe_arg(arg)
                           for arg in option['arguments']
                           if not arg.get('ignore_check', False)]
            if option_desc:
                arg_desc.append('({})'.format(', '.join(option_desc)))
            else:
                arg_desc.append('no arguments')
        # Shortest descriptions first (stable sort keeps ties in option order).
        arg_desc.sort(key=len)
        arg_str = ', '.join('"' + desc + '"' for desc in arg_desc)
        return Template(self.WRAPPER_TEMPLATE.safe_substitute(
            num_options=len(arg_desc),
            expected_args=arg_str))