# blob: 92937486789183e01ca90c0f5e4e5608964c5237  (source-viewer header; not part of the module)
import math
from copy import deepcopy
from itertools import product
from .utils import argfilter
from .options import make_option
from .config import *
def make_function(lines, stateless):
    """Build the wrapper generator for one declaration.

    Args:
        lines: lines of the declaration, forwarded unchanged to the
            selected class constructor.
        stateless: when truthy a ``StatelessFunction`` is built, otherwise
            a ``Function``.

    Returns:
        The newly constructed ``Function`` or ``StatelessFunction``.
    """
    return StatelessFunction(lines) if stateless else Function(lines)
class Function(object):
    """Parses one cwrap function declaration and generates its C wrapper.

    The declaration is parsed once, in the constructor.  Afterwards
    ``self.name`` holds the function name, ``self.flags`` the flags captured
    from the header line and ``self.options`` the list of accepted argument
    configurations (empty when the function should be ignored).
    """

    # Opening of the generated C function; ``${name}`` is the function name.
    DEFINITION_START = Template("""
static PyObject * THPTensor_(${name})(THPTensor *self, PyObject *args)
{
HANDLE_TH_ERRORS
Py_ssize_t _argcount = args ? PyTuple_Size(args) : 0;
""")

    # Fall-through emitted after all options: report the accepted argument
    # configurations (``${expected_args}``) and return NULL.
    DEFINITION_END = Template("""
THPUtils_invalidArguments(args, ${expected_args});
return NULL;
END_HANDLE_TH_ERRORS
}
""")

    def __init__(self, lines):
        """Parse ``lines`` (header line followed by option/argument lines)."""
        self._parse_lines(lines)

    def generate(self):
        """Return the C source of the wrapper as a string.

        Returns an empty string when the function has no options, i.e. when
        it should be ignored.
        """
        if not self.options:
            return ''  # Ignore function

        definition = self.DEFINITION_START.substitute({'name': self.name})

        # Declare variables: one local per distinct (type, name) pair used by
        # any option.
        variables = set((arg.type, arg.name) for option in self.options
                        for arg in option.arguments)
        # NOTE(review): argfilter() comes from .utils; arguments it accepts
        # get no declaration, presumably because they are already available
        # inside the wrapper -- confirm against utils.
        is_already_provided = argfilter()
        for variable in variables:
            if not is_already_provided(Argument(*variable)):
                definition += ' {} {};\n'.format(*variable)

        # Generate function body
        definition += self._generate_body()

        # Prepare quick docs and end the declaration
        accepted_args_str = self._describe_options()
        definition += self.DEFINITION_END.substitute(expected_args=accepted_args_str)
        return definition

    def _generate_body(self):
        """Generates code implementing all argument options

        Options are sorted according to their argument count. Ones with equal
        counts are wrapped in the same if block, that checks how many
        arguments have been provided. This allows to ignore checking some
        argument configurations, and save a couple of cycles (PyArg_ParseTuple
        calls add some overhead).

        Options whose required argument count is ``math.inf`` are emitted
        last and are not wrapped in an argument-count check.
        """
        impl = ''
        prev_arg_count = -1
        for option in sorted(self.options, key=lambda o: o.num_required_args()):
            num_args = option.num_required_args()
            if num_args > prev_arg_count:
                # Nothing to close if it's the first option
                if prev_arg_count != -1 and prev_arg_count < math.inf:
                    impl += ' }\n'
                if num_args < math.inf:
                    impl += Template(' if (_argcount == $numargs) {') \
                        .substitute({'numargs': num_args})
                prev_arg_count = num_args
            else:
                # Same argument count as the previous option: the previous
                # attempt fell through, so clear the error it left behind.
                impl += ' PyErr_Clear();'
            impl += '\n {'
            impl += option.generate()
            impl += ' }\n'
        # Close last argcount block
        if prev_arg_count < math.inf:
            impl += ' }\n'
        return impl

    def _describe_options(self):
        """Generates a string describing accepted argument configurations.

        The result is wrapped in double quotes (it is substituted into the
        generated code as a C string literal) and lists one parenthesised
        argument list per option, separated by ``' or '``.  An option with no
        arguments left after filtering is described as ``no arguments``.
        """
        def describe_arg(arg):
            return TYPE_DESCRIPTIONS.get(arg.type, arg.type) + ' ' + arg.name

        descriptions = []
        for option in self.options:
            # A fresh filter per option, exactly as for variable declarations.
            is_provided = argfilter()
            args = [arg for arg in option.arguments if not is_provided(arg)]
            if args:
                descriptions.append('(' + ', '.join(map(describe_arg, args)) + ')')
            else:
                descriptions.append('no arguments')
        # join() instead of slicing off a trailing separator, so an empty
        # option list still yields a well-formed (empty) quoted string.
        return '"' + ' or '.join(descriptions) + '"'

    def _resolve_optional_args(self):
        """Expand options with optional arguments into concrete options.

        For an option with ``k`` optional arguments, ``2**k`` copies are
        produced, one per enabled/disabled combination; every disabled
        argument is replaced by ``Argument('CONSTANT', <default>)``.  Options
        without optional arguments are kept as they are, exactly once.
        """
        resolved_options = []
        for option, optional_args in zip(self.options, self.optional_args):
            if not optional_args:
                resolved_options.append(option)
                # Without this, product(..., repeat=0) below yields a single
                # empty configuration and the option would be added twice.
                continue
            # Generate options with all possible configurations of optional args
            for enabled_bits in product((True, False), repeat=len(optional_args)):
                new_option = option.copy()
                # Replace disabled args with their defaults
                for enabled, (arg_nr, default_value) in zip(enabled_bits, optional_args):
                    if not enabled:
                        new_option.arguments[arg_nr] = Argument('CONSTANT', default_value)
                resolved_options.append(new_option)
        self.options = resolved_options

    def _should_ignore(self):
        """Return True when the declaration is flagged ``STATELESS_ONLY``."""
        return 'STATELESS_ONLY' in self.flags

    def _parse_lines(self, lines):
        """Parses cwrap declaration.

        Accepts a sequence of lines: the first one holds the function name
        and flags, the remaining ones are option lines, each followed by the
        argument lines belonging to that option.

        Nothing is returned; the result is stored in ``self.name``,
        ``self.flags`` and ``self.options``.  If the option list is left
        empty, the function should be ignored.
        """
        assert len(lines) > 1
        self.options = []
        # Parallel to self.options: for each option, a list of
        # (argument index, default value) pairs for its optional arguments.
        self.optional_args = []
        self.name, self.flags = FUNCTION_NAME_REGEX.match(lines[0]).group(1, 2)
        if self._should_ignore():
            return

        for line in lines[1:]:
            match = OPTION_REGEX.match(line)
            if match:
                thname, rettype, flags = match.group(1, 2, 3)
                self.options.append(make_option(thname, rettype, flags))
                self.optional_args.append([])
            else:
                assert line.startswith(ARGUMENT_PREFIX)
                arg = line.replace(ARGUMENT_PREFIX, '').strip()
                option = self.options[-1]

                # Check for default values
                default_value = OPTIONAL_ARGUMENT_REGEX.match(arg)
                if default_value:
                    arg_nr = len(option.arguments)
                    self.optional_args[-1].append((arg_nr, default_value.group(1)))
                    arg = arg[:arg.find(' OPTIONAL')]

                # Parse argument; a bare 'self' is shorthand for 'THTensor self'
                t, name = arg.split() if arg != 'self' else ('THTensor', 'self')
                t = TYPE_TRANSFORMS.get(t, t)
                option.add_argument(Argument(t, name))

        self._resolve_optional_args()
        self._parse_options()

    def _parse_options(self):
        """Hook run after parsing; subclasses may post-process self.options."""
        pass
class StatelessFunction(Function):
    """Generator for the stateless variant of a wrapper.

    The generated C function does not receive the tensor as ``self``; every
    parsed option is converted by ``_make_stateless`` so that the result
    tensor is either freshly allocated or passed in by the caller.
    """

    # Opening of the generated C function; ``${name}`` is the function name.
    DEFINITION_START = Template("""
static PyObject * THPTensor_stateless_(${name})(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
Py_ssize_t _argcount = args ? PyTuple_Size(args) : 0;
""")

    def _should_ignore(self):
        """Return True when the declaration is flagged ``STATEFUL_ONLY``."""
        return 'STATEFUL_ONLY' in self.flags

    def _parse_options(self):
        """Convert the parsed options to stateless ones and drop duplicates."""
        self.options = self._make_stateless()
        # _filter_options() returns the filtered list; it has to be stored,
        # otherwise the unreachable duplicates would still be generated.
        self.options = self._filter_options()

    def _filter_options(self):
        """Filters out options that will never be reached.

        Only the first option for each ``signature_hash()`` value is kept;
        the original order is preserved.

        Returns:
            A new list with the surviving options (``self.options`` itself is
            not modified here).
        """
        signatures = set()

        def uniq_signatures(option):
            h = option.signature_hash()
            if h in signatures:
                return False
            signatures.add(h)
            return True

        return list(filter(uniq_signatures, self.options))

    def _make_stateless(self):
        """Converts stateful options to stateless options.

        There are two ways of performing this conversion:
        1. If self is only an output argument (it's optional) it can be allocated
        2. The user can also provide self, irrespective of its purpose

        Returns:
            A list of new option objects; the options in ``self.options`` are
            copied, not modified.
        """
        stateless_options = []

        def self_to_(new_name):
            """Build a mapper renaming the ``self`` argument to ``new_name``."""
            def self_to_new(arg):
                if arg.name != 'self':
                    return arg
                return Argument(arg.type, new_name)
            return self_to_new

        # First pass - try to allocate self, wherever possible
        # This has to go first, because the first option with a given
        # signature is the one kept when duplicates are filtered out.
        for option in self.options:
            # If self is optional, it can be allocated
            if option.is_self_optional():
                assert option.return_type == 'self'
                new = option.copy()
                new.map_arguments(self_to_('_res_new'))
                new.return_type = 'STATELESS NEW self'
                stateless_options.append(new)

        # Second pass - if self is actually needed, it can be provided
        for option in self.options:
            provided = option.copy()
            provided.map_arguments(self_to_('_res'))
            if provided.return_type == 'self':
                provided.return_type = 'STATELESS PROV self'
            # This is where it gets tricky. There are two cases:
            # 1. User only provides an output tensor
            # 2. User provides both an output tensor, as well as an index tensor
            if provided.return_type == 'new SelfIndexPair':
                # Case 1.
                provided.insert_argument(0, Argument('THPTensor*', '_res'))
                provided.return_type = 'STATELESS PROV new SelfIndexPair'
                stateless_options.append(provided.copy())
                # Reuse option from case 1. to make 2.
                provided.insert_argument(1, Argument('THPLongTensor*', '_res_ind'))
                provided.return_type = 'STATELESS PROV2 new SelfIndexPair'
            stateless_options.append(provided)
        return stateless_options