| from .nested_dict import nested_dict |
| from tools.shared.module_loader import import_module |
| |
| CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate |
| |
| |
# C++ template for a binding that receives positional and keyword arguments.
# The generated function builds a static PythonArgParser from ${prototypes},
# parses into a ${max_args}-slot array and then runs the ${dispatch} code.
# Py_RETURN_NONE is only reached when the dispatch code does not return.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
${prototypes}
});
${unpack_self}
PyObject* parsed_args[${max_args}];
auto r = parser.parse(args, kwargs, parsed_args);
${dispatch}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
""")
| |
# C++ template for a binding that does no argument parsing: after
# ${unpack_self} it calls ${dispatch_name} with ${actuals} directly and
# wraps the result.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args)
{
HANDLE_TH_ERRORS
${unpack_self}
return wrap(${dispatch_name}(${actuals}));
END_HANDLE_TH_ERRORS
}
""")
| |
# One branch of the overload-selection chain, taken when the parser matched
# prototype number ${i}.  ${cond} is 'if' for the first overload and
# '} else if' for later ones, so every case leaves its brace open; the code
# that assembles the chain appends the final '}'.
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
return wrap(${dispatch_name}(${actuals}));
""")
| |
# C++ template for the inline helper generated per overload: it emits the
# ${AutoNoGIL} and ${AutoGPU} statements and then forwards ${dispatch_args}
# to ${dispatch_call}.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${return_type} ${dispatch_name}(${formal_args}) {
${AutoNoGIL}
${AutoGPU}
return ${dispatch_call}(${dispatch_args});
}
""")
| |
# Method-table entry for one binding: Python-visible name, the C function
# cast to PyCFunction, the calling-convention flags, and a NULL docstring.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")
| |
# C++ statement substituted for ${unpack_self} in Tensor-method bindings:
# it binds self_ to the cdata member of the THPVariable behind `self`.
UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;"
| |
| |
# XXX: an assertion failure that points here is NOT fixed by simply adding
# the missing type to this set.  First add a matching wrap() overload in
# torch/csrc/autograd/utils/wrap_outputs.h, and only then list the type here.
SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::vector<Tensor>',
    'Scalar',
    'bool',
    'int64_t',
    'void*',
}
| |
| |
def create_python_bindings(
        python_functions, py_methods, py_method_defs, py_method_dispatch,
        is_class):
    """Emit the bindings that go into python_variable_methods.cpp.

    ``python_functions`` maps a binding name to the list of its overload
    declarations.  Names are handled in sorted order, and for each name:

    * one dispatch helper per overload is appended to ``py_method_dispatch``;
    * the rendered binding function is appended to ``py_methods``;
    * its method-table entry is appended to ``py_method_defs``.

    When ``is_class`` is false, any ``Tensor self`` left in a prototype is
    spelled ``Tensor input``.  When ``is_class`` is true, bindings that are
    not Tensor methods additionally get the ``METH_STATIC`` flag.

    The three output lists are mutated in place; nothing is returned.
    """

    # Parser accessor used to read each declared argument type.  Types that
    # are not listed fall back to their lower-cased type name.
    accessor_for_type = {
        'const Tensor &': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'int64_t': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
    }

    # Spelling of Tensor-typed parameters in the dispatch helper signature;
    # every other type is used unchanged.
    formal_type_overrides = {
        'Tensor': 'const Tensor &',
        'Tensor &': 'Tensor',
    }

    def first_tensor_arg(arguments):
        """Return the name of the first Tensor/TensorList argument, or None."""
        return next(
            (a['name'] for a in arguments
             if a['simple_type'] in {'Tensor', 'TensorList'}),
            None)

    def auto_gpu(option):
        """Return the AutoGPU statement for `option`, or '' without a tensor."""
        guarded = first_tensor_arg(option['arguments'])
        return '' if guarded is None else 'AutoGPU auto_gpu({});'.format(guarded)

    def emit_dispatch(i, function):
        """Render overload number `i` of a binding.

        Appends the overload's dispatch helper to ``py_method_dispatch`` and
        returns the ``if`` / ``else if`` case that calls that helper.
        """
        return_type = function['return_type'].replace(' &', '')
        assert return_type in SUPPORTED_RETURN_TYPES, \
            function['name'] + ' returns unsupported type: ' + return_type

        call_values = []    # expressions passed to the helper at the call site
        signature = []      # parameter declarations of the helper
        parsed_slot = 0     # index of the next parsed argument in `r`
        for arg in function['arguments']:
            if 'Tensor' in function['method_of'] and arg['name'] == 'self':
                # self is passed as the already unpacked self_ and does not
                # consume a parser slot.
                signature.append('Tensor & {}'.format(arg['name']))
                call_values.append('self_')
            else:
                typename = arg['type']
                if typename.startswith('IntList['):
                    typename = 'IntList'
                elif typename.startswith('LongTensor'):
                    typename = 'Tensor'

                accessor = accessor_for_type.get(typename, typename.lower())
                call_values.append('r.{}({})'.format(accessor, parsed_slot))
                signature.append('{} {}'.format(
                    formal_type_overrides.get(typename, typename),
                    arg['name']))
                parsed_slot += 1

        # An explicit 'call_args' entry overrides the declared argument names.
        if 'call_args' in function:
            forwarded = function['call_args']
        else:
            forwarded = [arg['name'] for arg in function['arguments']]

        if 'Tensor' in function['method_of']:
            # Method call on self: self is the receiver, not an argument.
            forwarded = [name for name in forwarded if name != 'self']
            callee = 'self.{}'.format(function['name'])
        else:
            callee = 'at::{}'.format(function['name'])

        env = nested_dict({
            'i': i,
            'actuals': call_values,
            'formal_args': signature,
            'dispatch_args': forwarded,
            'dispatch_call': callee,
            'AutoNoGIL': 'AutoNoGIL no_gil;',
            'AutoGPU': auto_gpu(function),
            # Later cases close the previous case's brace before opening theirs.
            'cond': 'if' if i == 0 else '} else if',
        }, function)
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return PY_VARIABLE_CASE.substitute(env)

    def process_function(name, functions):
        """Render the binding called `name` from its overload list."""
        max_args = max(len(o['arguments']) for o in functions)
        # The first overload decides whether the binding is a Tensor method.
        is_method = 'Tensor' in functions[0]['method_of']

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'prototypes': [],
            'max_args': max_args,
            'unpack_self': [UNPACK_SELF] if is_method else [],
            'dispatch': [],
        }

        for overload in functions:
            prototype = overload['prototype']
            if is_method:
                # self is unpacked from the bound object, so it is removed
                # from the parser prototype.
                for self_param in ('Tensor self, ', 'Tensor self'):
                    prototype = prototype.replace(self_param, '')
            if not is_class:
                # Use 'input' instead of 'self' for NN functions
                prototype = prototype.replace('Tensor self', 'Tensor input')
            if 'deprecated' in overload:
                prototype += '|deprecated'
            env['prototypes'].append('"{}",'.format(prototype))

        cases = env['dispatch']
        for index, overload in enumerate(functions):
            cases.append(emit_dispatch(index, nested_dict(env, overload)))
        cases.append('}')  # closes the if / else-if chain opened by the cases

        # A method with a single overload and exactly one entry in 'args'
        # uses the template that skips argument parsing.
        if len(functions) == 1 and len(functions[0]['args']) == 1 and is_method:
            template = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self_']
            flags = 'METH_NOARGS'
        else:
            template = PY_VARIABLE_METHOD_VARARGS
            flags = 'METH_VARARGS | METH_KEYWORDS'

        if is_class and not is_method:
            flags += ' | METH_STATIC'
        env['flags'] = flags

        py_methods.append(template.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    for name in sorted(python_functions):
        process_function(name, python_functions[name])