from __future__ import print_function
import re
import yaml
import pprint
import sys
import copy
try:
# use faster C loader if available
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
# [temp translations]
# We're incrementally moving from the custom func schema to the JIT
# signature schema. This will reduce overall complexity and increase
# consistency between these components. So for now we do simple type
# translations to continue to emit the legacy func schema for further
# processing by downstream tools. This helps us avoid having to prematurely
# change all downstream tools to detect these new types.
def type_argument_translations(arg):
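    """Translate one JIT-schema argument string into the legacy func schema.

    Returns (type, name, default, is_nullable, size, annotation).

    Hand-traced examples (illustrative only, not exhaustive):
        type_argument_translations('Tensor(a!) self')
            -> ('Tensor', 'self', None, False, None, 'a!')
        type_argument_translations('int[2] stride=1')
            -> ('IntArrayRef', 'stride', 1, False, 2, None)
    """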
type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
name = ''
if len(type_and_name) > 1:
name = type_and_name[1]
t = type_and_name[0]
name = name.split('=')
default = None
nullable = False
size = None # Only applies to int[\d+] and Tensor[\d+] arguments
if len(name) > 1:
default = name[1]
name = name[0]
match = re.match(r'(Tensor.*)\((.+)\)(.*)', t)
annotation = None
if match:
t = match.group(1) + match.group(3)
annotation = match.group(2)
# XXX: is_nullable flag can only annotate entire type as optional type,
# need to special case Generator? logic to make ? only available in jit
# TODO: deprecate is_nullable global flag, and parse the type
# to support annotating complicated types with optional annotation
nullable = (t != 'Generator?' and '?' in t)
# This enables "Generator? x = None and translates to legacy
# "Generator* x = nullptr". See [temp translations].
if t == 'Generator?' and default == 'None':
t = 'Generator*'
default = 'nullptr'
# Enables Generator? by translating to legacy Generator*.
elif t == "Generator?":
t = 'Generator*'
# Enables Tensor[] by translating to legacy TensorList.
elif t == 'Tensor[]' or t == 'Tensor?[]':
t = 'TensorList'
# Enables int[] by translating to legacy IntArrayRef.
elif t == 'int[]':
t = 'IntArrayRef'
# Enables int by translating to legacy int64_t.
elif t == 'int':
t = 'int64_t'
elif t == 'int?':
t = 'int64_t?'
elif t == 'int64_t':
raise RuntimeError("Please use int and not int64_t. "
"See [temp translations] for details.")
elif t == 'int64_t?':
raise RuntimeError("Please use int? and not int64_t?. "
"See [temp translations] for details.")
elif t == 'Dimname[]?':
t = 'DimnameList?'
# Enables float by translating to legacy double.
elif t == 'float':
t = 'double'
# Enables str by translating to legacy std::string.
elif t == 'str':
t = 'std::string'
elif t == 'double':
raise RuntimeError("Please use float and not double. "
"See [temp translations] for details.")
    # Enables int[x] by translating to legacy IntArrayRef with size x. See [temp translations]
elif re.match(r'int\[(\d+)\]', t):
match = re.match(r'int\[(\d+)\]', t)
t = 'IntArrayRef'
size = int(match.group(1))
# Enables bool[x] by translating to legacy std::array<bool,x>. See [temp translations]
elif re.match(r'bool\[(\d+)\]', t):
match = re.match(r'bool\[(\d+)\]', t)
t = 'std::array<bool,{}>'.format(match.group(1))
elif re.match(r'std::array', t):
raise RuntimeError("Please use array notation, e.g. bool[3] and not std::array."
"See [temp translations] for details.")
# Legacy type sanitization. TODO: Do we really need this?
if t == 'Generator*':
t = 'Generator *'
if not default:
pass
# This enables Tensor? x=None and translates to legacy
# "Tensor? x={}". See [temp translations].
elif t.startswith('Tensor?') and default == 'None':
default = "{}"
elif default == 'True':
default = True
elif default == 'False':
default = False
elif default == 'true':
raise RuntimeError("Please use True and not true. "
"See [temp translations] for details.")
elif default == 'false':
raise RuntimeError("Please use False and not false. "
"See [temp translations] for details.")
# Enables default argument [] by translating to legacy {}.
# See [temp translations]
elif default == '[]':
default = '{}'
# Enables lists by translating to legacy {.*}.
# See [temp translations]
elif re.match(r'\[.*\]', default):
default = "{" + default[1:-1] + "}"
elif default == 'None':
default = 'c10::nullopt'
    # The JIT signature schema uses Mean, but C++ in particular needs
    # the legacy Reduction::Mean. So we'll continue emitting that until
    # we change this at either the JIT schema or the C++ level.
elif default == 'Mean':
default = 'Reduction::Mean'
elif default == 'contiguous_format':
default = 'MemoryFormat::Contiguous'
elif default == 'per_tensor_affine':
default = 'QScheme::PER_TENSOR_AFFINE'
else:
try:
default = int(default)
except ValueError:
try:
default = float(default)
except ValueError:
pass
return t, name, default, nullable, size, annotation
def parse_arguments(args, func_variants, declaration, func_return):
arguments = []
kwarg_only = False
if len(args.strip()) == 0:
return arguments
# TODO: Use a real parser here; this will get bamboozled
# by signatures that contain things like std::array<bool, 2> (note the space)
for arg_idx, arg in enumerate(args.split(', ')):
type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
if type_and_name == ['*']:
assert not kwarg_only
kwarg_only = True
continue
t, name, default, nullable, size, annotation = type_argument_translations(arg)
argument_dict = {'type': t.rstrip('?'), 'name': name, 'is_nullable': nullable, 'annotation': annotation}
if size:
argument_dict['size'] = size
if default is not None:
argument_dict['default'] = default
if kwarg_only:
argument_dict['kwarg_only'] = True
arguments.append(argument_dict)
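    # Detect output arguments: keyword-only Tensor arguments whose annotation
    # ends in '!' (i.e. they are written to) are moved to the front of the
    # argument list and the function is renamed to an _out variant.
    # Schematic example (assumed schema shape, for illustration only):
    #   "add(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
    # would mark 'out' as an output argument and emit the name add_out.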
is_out_fn = False
arguments_out = []
arguments_other = []
for argument in arguments:
if argument['type'] == "Tensor" and \
argument['annotation'] and \
re.match(r'^(.*!)$', argument['annotation']) and \
argument.get('kwarg_only'):
argument['output'] = True
argument['kwarg_only'] = False
arguments_out.append(argument)
is_out_fn = True
else:
arguments_other.append(argument)
arguments = arguments_out + arguments_other
name = declaration['name']
if is_out_fn:
declaration['name'] += "_out"
    # Reverse splat of TensorOptions
    # As we move towards the JIT function schema for native_functions.yaml we need to support
    # the expanded version of TensorOptions. For now we detect whether there are four specific
    # keyword arguments: "ScalarType dtype", "Layout layout", "Device device" and "bool pin_memory".
    # They must appear in this order, and only in this order, for us to be able to process them;
    # the supported combinations of defaults and nullability are enumerated in
    # supported_topt_arguments below.
    # In the future we will get rid of this specific processing as downstream consumers start relying
    # less on the content of Declarations.yaml. If you want to support more than this you'll
    # potentially have to extend the JIT.
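    # Schematic example (illustrative only): four keyword arguments such as
    #   "ScalarType dtype, Layout layout, Device device, bool pin_memory=False"
    # are collapsed by the matching logic below into a single argument
    #   {'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None}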
supported_topt_arguments = [
[
{'name': 'dtype', 'type': 'ScalarType', 'is_nullable': False, 'annotation': None},
{'name': 'layout', 'type': 'Layout', 'is_nullable': False, 'annotation': None},
{'name': 'device', 'type': 'Device', 'is_nullable': False, 'annotation': None},
{'name': 'pin_memory', 'type': 'bool', 'is_nullable': False, 'annotation': None, 'default': False},
]
]
supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[0]))
for arg in supported_topt_arguments[1]:
arg.update({'kwarg_only': True})
supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[1]))
for arg in supported_topt_arguments[2]:
arg.update({'default': 'c10::nullopt', 'is_nullable': True})
# add explicit support for what is needed for tril_indices / triu_indices
supported_topt_arguments.append(
[
{'name': 'dtype', 'type': 'ScalarType', 'annotation': None, 'kwarg_only': True,
'default': 'long', 'is_nullable': True},
{'name': 'layout', 'type': 'Layout', 'annotation': None, 'kwarg_only': True,
'default': 'c10::nullopt', 'is_nullable': True},
{'name': 'device', 'type': 'Device', 'annotation': None, 'kwarg_only': True,
'default': 'c10::nullopt', 'is_nullable': True},
{'name': 'pin_memory', 'type': 'bool', 'annotation': None, 'kwarg_only': True,
'default': 'c10::nullopt', 'is_nullable': True},
]
)
corresponding_topts = [
{'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None},
]
corresponding_topts.append(corresponding_topts[0].copy())
corresponding_topts[1]['kwarg_only'] = True
corresponding_topts.append(corresponding_topts[1].copy())
corresponding_topts[2]['default'] = '{}'
corresponding_topts.append(
{'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None,
'kwarg_only': True, 'default': 'at::kLong'})
def check_topt_representation(topt_representation):
for idx, supported_topt in enumerate(supported_topt_arguments):
matches = all(topt_representation[i] == topt for i, topt in enumerate(supported_topt))
if matches:
return corresponding_topts[idx]
return None
def is_tensor_option(argument):
return argument['name'] in ['dtype', 'layout', 'device', 'pin_memory']
new_arguments = []
idx = 0
while idx < len(arguments):
argument = arguments[idx]
number_of_arguments = len(supported_topt_arguments[0])
if is_tensor_option(argument) and len(arguments) - idx >= number_of_arguments:
topt_representation = []
for i in range(number_of_arguments):
argument = arguments[idx]
if not is_tensor_option(argument):
break
topt_representation.append(argument)
idx += 1
if len(topt_representation) == number_of_arguments:
merged_argument = check_topt_representation(topt_representation)
assert merged_argument, \
"Unsupported combination of TensorOptions {}, the only currently supported combinations are {}"\
.format(str(topt_representation), str(supported_topt_arguments))
new_arguments.append(merged_argument)
else:
new_arguments += topt_representation
else:
new_arguments.append(argument)
idx += 1
arguments = new_arguments
# Sanity checks
    # TODO: The convention is that the i-th output argument corresponds to the i-th return,
    # but it would be better if we just named everything and matched by name.
for arg_idx, argument in enumerate(arguments_out):
assert argument['annotation'] == func_return[arg_idx]['annotation'], \
"For func {} writeable keyword Tensor arguments need to have a matching return Tensor. Further, " \
"the ith-argument needs to correspond to the i-th return.".format(name)
assert len(arguments_out) <= len(func_return), "func {} must return at least as many Tensors " \
"as can be passed as output.".format(name)
if name.endswith('_out'):
raise RuntimeError("Native function {} may not be suffixed with _out as we transition to a unified schema. "
"Otherwise you will cause confusion amongst consumers of native functions.".format(name))
if is_out_fn and func_variants not in [[], 'function', ['function']]:
raise RuntimeError("Native functions with output MUST be declared with only the function variant; "
"e.g., variants: function; otherwise you will tickle a Python argument binding bug "
"(which usually manifests itself as the result variable being undefined.) "
"The culprit was: {}".format(name))
if not is_out_fn:
assert len(arguments_out) == 0, "func {} is not marked as output yet contains output " \
"keyword arguments".format(name)
# TODO: Explicit checking for void is a hack and should disappear after a more
# functionally complete implementation of Tensor aliases.
if declaration['inplace'] and len(func_return) > 0 and func_return[0]['type'] != "void":
found_self = False
for arg_idx, argument in enumerate(arguments):
if argument['name'] == "self":
assert argument['annotation'] and argument['annotation'].endswith("!"), \
"Inplace function \"{}\" needs to annotate Tensor argument named self " \
"as mutable.".format(name)
found_self = True
assert argument['annotation'] == func_return[arg_idx]['annotation'], \
"Inplace function annotations of function {} need to match between " \
"input and correponding output.".format(name)
assert argument['name'] == func_return[arg_idx]['name'] or \
argument['name'] == func_return[arg_idx]['name'] + "_return"
assert argument['type'] == func_return[arg_idx]['type']
assert found_self, "Inplace function \"{}\" needs Tensor argument named self.".format(name)
return arguments
def parse_return_arguments(return_decl, inplace, func_decl):
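    """Parse the return declaration of a func schema into argument dicts.

    Schematic example (hand-traced, assuming no collision with input names):
        parse_return_arguments('(Tensor values, Tensor indices)', False, func_decl)
    yields two output dicts whose name/field_name are 'values' and 'indices'.
    If a return name collides with an input argument name, the return is
    renamed to '<name>_return'.
    """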
arguments = []
# TODO: Use a real parser here; this will get bamboozled
# by signatures that contain things like std::array<bool, 2> (note the space)
if return_decl[0] == '(' and return_decl[-1] == ')':
return_decl = return_decl[1:-1]
multiple_args = len(return_decl.split(', ')) > 1
for arg_idx, arg in enumerate(return_decl.split(', ')):
t, name, default, nullable, size, annotation = type_argument_translations(arg)
        # Argument names and return names sometimes collide;
        # in that case, we rename the return to <name>_return.
return_name = name
if name in func_decl['func'].split('->')[0]:
return_name = name + "_return"
argument_dict = {'type': t, 'name': return_name, 'annotation': annotation}
if name:
# See Note [field_name versus name]
argument_dict['field_name'] = name
else:
if t == "Tensor" and inplace:
assert annotation and annotation.endswith("!"), \
"Return Tensor of function \"{}\" flagged as inplace needs to be " \
"annotated as mutable".format(func_decl['func'])
argument_dict['name'] = 'self'
else:
argument_dict['name'] = 'result' if not multiple_args else 'result' + str(arg_idx)
argument_dict['output'] = True
arguments.append(argument_dict)
return arguments
def parse_native_yaml(path):
with open(path, 'r') as f:
return yaml.load(f, Loader=Loader)
def propagate_field_names(output_arguments, return_arguments):
if output_arguments:
for i, r in enumerate(return_arguments):
if 'field_name' in r:
output_arguments[i]['field_name'] = r['field_name']
def is_named_tensor_only(declaration):
    return any('Dimname' in arg['type'] for arg in declaration['arguments'])
def run(paths):
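    """Parse one or more native_functions.yaml-style files into declaration dicts.

    Schematic example of an entry (illustrative, not copied from any real file):
        - func: add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
          variants: function, method
    Returns a list of declaration dicts consumed by downstream code generation.
    """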
declarations = []
for path in paths:
for func in parse_native_yaml(path):
declaration = {'mode': 'native'}
try:
declaration['schema_string'] = "aten::" + func['func']
if '->' in func['func']:
func_decl, return_decl = [x.strip() for x in func['func'].split('->')]
else:
raise Exception('Expected return declaration')
fn_name, arguments = func_decl.split('(', 1)
assert arguments[-1] == ")", "Expecting closing ) for {}".format(func['func'])
arguments = arguments[:-1] # Expect closing )
declaration['name'] = func.get('name', fn_name)
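                # Heuristic: a function is considered in-place if its name starts with
                # '__i' (e.g. __iadd__) or ends in a single trailing underscore
                # (e.g. add_, but not __add__).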
declaration['inplace'] = re.search('(^__i|[^_]_$)', fn_name) is not None
return_arguments = parse_return_arguments(return_decl, declaration['inplace'], func)
arguments = parse_arguments(arguments, func.get('variants', []), declaration, return_arguments)
output_arguments = [x for x in arguments if x.get('output')]
propagate_field_names(output_arguments, return_arguments)
declaration['return'] = return_arguments if len(output_arguments) == 0 else output_arguments
declaration['variants'] = func.get('variants', ['function'])
declaration['requires_tensor'] = func.get('requires_tensor', False)
declaration['matches_jit_signature'] = func.get('matches_jit_signature', True)
declaration['cpu_half'] = func.get('cpu_half', False)
declaration['cpu_bfloat16'] = func.get('cpu_bfloat16', False)
declaration['cpu_bool'] = func.get('cpu_bool', False)
declaration['cuda_bool'] = func.get('cuda_bool', False)
declaration['deprecated'] = func.get('deprecated', False)
declaration['device_guard'] = func.get('device_guard', True)
declaration['named_guard'] = func.get('named_guard', True)
declaration['arguments'] = func.get('arguments', arguments)
declaration['type_method_definition_dispatch'] = func.get('dispatch', declaration['name'])
declaration['python_module'] = func.get('python_module', '')
declarations.append(declaration)
except Exception as e:
msg = '''Exception raised in processing function:
{func}
Generated partial declaration:
{decl}'''.format(func=pprint.pformat(func), decl=pprint.pformat(declaration))
print(msg, file=sys.stderr)
raise e
return declarations