# blob: 67e0efde5e291dbbac5ecfe405b2a284580e0fd1 [file] [log] [blame]
import os
import argparse
from itertools import count
from ..autograd.utils import CodeTemplate, write, uninplace_api_name
from ..autograd.gen_autograd import load_aten_declarations
# Directory holding the template sources; it sits next to this script.
template_path = os.path.join(os.path.dirname(__file__), 'templates')

# Templates for the three generated files, parsed once at import time.
ATEN_DISPATCH_H = CodeTemplate.from_file(
    '/'.join([template_path, 'aten_dispatch.h']))
ATEN_DISPATCH_CPP = CodeTemplate.from_file(
    '/'.join([template_path, 'aten_dispatch.cpp']))
ATEN_INTERNED_STRINGS_H = CodeTemplate.from_file(
    '/'.join([template_path, 'aten_interned_strings.h']))
# Argument type -> name of the node accessor used to read the argument when it
# is supplied as a constant attribute (rendered as `node-><method>(...)` by
# KW_ASSIGNMENT).
ATTR_METHOD_MAP = {
    'int64_t': 'i',
    'IntList': 'is',
    'Scalar': 't',
    'bool': 'i',
    'double': 'f',
}
# Fixed-size bool arrays go through the same 'is' accessor as IntList; the
# matching TYPE_CASTS entry converts the value back to the array type.
ATTR_METHOD_MAP.update(
    ('std::array<bool,{}>'.format(size), 'is') for size in (2, 3, 4))

# Argument type -> C++ callable/type wrapped around the attribute value in
# KW_ASSIGNMENT.  Types missing from this table are wrapped in their own name.
TYPE_CASTS = {
    'std::array<bool,{}>'.format(size): 'as_bool_array<{}>'.format(size)
    for size in (2, 3, 4)
}
TYPE_CASTS.update([
    ('Scalar', 'Scalar'),
    ('IntList', 'std::vector<int64_t>'),
])
# C++ statement that reads one constant attribute off the node into a local
# variable.  ${method} is looked up in ATTR_METHOD_MAP and ${type_cast} in
# TYPE_CASTS (falling back to the argument's own type).  The backslashes at
# the line ends are line continuations inside the literal, so the rendered
# snippet carries no leading or trailing newline.
KW_ASSIGNMENT = CodeTemplate("""\
auto ${name} = ${type_cast}(node->${method}(Symbol::attr("${name}")));\
""")
# C++ statement that converts the ${i}-th of the last ${N} stack entries into
# a local of type ${type}; used for non-tensor arguments that arrive as
# positional inputs rather than attributes.
POS_ASSIGNMENT = CodeTemplate("""\
auto ${name} = tensor_as<${type}>(std::move(peek(stack, ${i}, ${N})));\
""")
# Call expression used when the declaration lists 'namespace' in 'method_of'.
CALL_NAMESPACE = CodeTemplate("at::${name}(${args})")
# Call expression used otherwise: the first argument becomes the receiver and
# the remaining ones are passed as method arguments.
CALL_METHOD = CodeTemplate("(${first}).${name}(${args})")
# One `{"descriptor", factory},` entry per emitted overload variant.
# ${kw_assignments} sit in the outer lambda that receives the Node, while
# ${pos_assignments} sit inside the TensorOp lambda that receives the stack.
# The rendered entries are handed to the dispatch templates under the
# 'constructors' key; how those templates splice them in is not visible here.
CONSTRUCTOR = CodeTemplate("""\
{"${descriptor}", [](Node *node) {
${kw_assignments}
return TensorOp([=](Stack & stack) {
autograd::profiler::RecordFunction record("${name}");
AutoGPU device_guard(deviceForInputs(stack, ${num_dynamic_inputs}));
${pos_assignments}
auto result = ${call};
drop(stack, ${num_dynamic_inputs});
pack(stack, std::move(result));
return 0;
}, "${name}", ${num_dynamic_inputs});
}},
""")
def is_magic_method(api_name):
    """Return True when ``api_name`` both starts and ends with ``__``."""
    return api_name[:2] == '__' and api_name[-2:] == '__'
def is_jit_op(decl):
    """Decide whether a declaration gets a JIT dispatch entry.

    A declaration is accepted when it involves tensors (a Tensor/TensorList
    argument, or it is a method of Tensor) and none of these hold:

    * its ``api_name`` ends with ``_`` without being a magic method,
    * its ``name`` ends with ``_out``,
    * any argument is a Generator, SparseTensor, Storage or Type.
    """
    arg_types = [arg['simple_type'] for arg in decl['arguments']]
    uses_tensors = (any(t in {'Tensor', 'TensorList'} for t in arg_types) or
                    'Tensor' in decl['method_of'])

    api_name = decl['api_name']
    if api_name.endswith('_') and not is_magic_method(api_name):
        return False
    if decl['name'].endswith('_out'):
        return False

    unsupported = {'Generator', 'SparseTensor', 'Storage', 'Type'}
    if any(t in unsupported for t in arg_types):
        return False
    return uses_tensors
# Scalar overloads like add(Tensor self, Scalar other) are not supported at
# the moment.
# TODO: Why are they not supported?
#
# Maps an overload descriptor to the argument indices to inspect: when any of
# those arguments has type Scalar, the overload is not emitted.  'pow-2' is
# the only entry that inspects both operands; all others inspect index 1.
skip_scalar_overload = {
    descriptor: [0, 1] if descriptor == 'pow-2' else [1]
    for descriptor in (
        # comparisons
        'lt-2', 'gt-2', 'le-2', 'ge-2', 'eq-2', 'ne-2',
        # arithmetic
        'pow-2', 'add-3', 'sub-3', 'mul-2', 'div-2',
        'fmod-2', 'remainder-2',
        # bitwise operators and shifts
        '__and__-2', '__or__-2',
        '__iand__-2', '__ior__-2', '__xor__-2', '__ixor__-2',
        '__lshift__-2', '__ilshift__-2', '__rshift__-2', '__irshift__-2',
    )
}
def gen_jit_dispatch(declarations, out):
    """Render the JIT's ATen dispatch table and the interned-strings header.

    Args:
        declarations: forwarded to ``load_aten_declarations``; ``main`` passes
            the path to Declarations.yaml.
        out: forwarded to ``write`` as the output location; ``main`` passes
            the output directory.

    Three files are emitted through ``write``: ``aten_dispatch.h``,
    ``aten_dispatch.cpp`` and ``aten_interned_strings.h``.
    """
    # descriptor -> rendered CONSTRUCTOR snippet, filled by emit_decl_variant.
    ops = {}
    tensor_types = {'Tensor', 'TensorList'}

    def is_tensor_arg(arg):
        return arg['simple_type'] in tensor_types

    def get_invocation(decl, args):
        # Namespace functions receive every argument; otherwise the op is
        # invoked as a method on the first argument.
        op_name = decl['name']
        if 'namespace' not in decl['method_of']:
            return CALL_METHOD.substitute(name=op_name, first=args[0], args=args[1:])
        return CALL_NAMESPACE.substitute(name=op_name, args=args)

    def emit_decl_variant(decl, is_positional_arg, has_tensorlist):
        # is_positional_arg holds one boolean per entry of decl['arguments'].
        # True means the argument is read from the positional list of inputs;
        # False means it is read from the node's constant attributes.
        decl_args = decl['arguments']
        num_positional = sum(is_positional_arg)
        attr_reads = []    # statements run once per node (attribute reads)
        input_reads = []   # statements run per call (scalar inputs)
        attr_names = []
        call_args = []

        if has_tensorlist:
            attr_reads.append('size_t varargs_length = node->inputs().size();')
            # Arguments look like: [tensor list], arg1, arg2, arg3.
            # peek(<i>, static_inputs) reads the non-vararg inputs from the
            # end of the stack, so the list itself is not counted here.
            static_inputs = num_positional - 1
            num_dynamic_inputs = 'varargs_length'
        else:
            static_inputs = num_dynamic_inputs = num_positional

        # Hands out consecutive stack positions to inputs read through peek().
        stack_slots = count()
        for positional, arg in zip(is_positional_arg, decl_args):
            arg_type = arg['simple_type']
            # XXX: only TensorList ops whose TensorList is the first argument,
            # followed by a number of positional args, are supported.
            if arg_type == 'TensorList':
                call_args.append(
                    'peekSlice(stack, 0, varargs_length - {}, varargs_length)'.format(static_inputs))
                continue
            if is_tensor_arg(arg):
                call_args.append(
                    'std::move(peek(stack, {}, {}))'.format(next(stack_slots), static_inputs))
                continue

            arg_name = arg['name']
            if positional:
                input_reads.append(POS_ASSIGNMENT.substitute(type=arg_type,
                                                             name=arg_name,
                                                             i=next(stack_slots),
                                                             N=static_inputs))
            else:
                attr_reads.append(KW_ASSIGNMENT.substitute(type_cast=TYPE_CASTS.get(arg_type, arg_type),
                                                           name=arg_name,
                                                           method=ATTR_METHOD_MAP[arg_type]))
                attr_names.append(arg_name)
            call_args.append(arg_name)

        call = get_invocation(decl, call_args)

        # The descriptor uniquely identifies one overload of an op:
        # <name>-<input count or '*'>[-<sorted attribute names>...]
        num_inputs = '*' if has_tensorlist else static_inputs
        descriptor = '-'.join([decl['name'], str(num_inputs)] + sorted(attr_names))

        # If two overloads share a descriptor and differ only in the type of a
        # single argument, where one takes a tensor and the other takes an
        # at::Scalar as a positional scalar arg, prefer the tensor overload.
        # It should get broadcasted correctly.
        scalar_positions = skip_scalar_overload.get(descriptor, ())
        if any(decl_args[idx]['simple_type'] == 'Scalar' for idx in scalar_positions):
            return

        constructor = CONSTRUCTOR.substitute(descriptor=descriptor,
                                             name=decl['name'],
                                             call=call,
                                             kw_assignments=attr_reads,
                                             pos_assignments=input_reads,
                                             num_dynamic_inputs=num_dynamic_inputs)
        assert descriptor not in ops, descriptor
        ops[descriptor] = constructor

    def emit_decl(decl):
        arguments = decl['arguments']
        tensor_flags = tuple(is_tensor_arg(arg) for arg in arguments)
        has_tensorlist = any(arg['simple_type'] == 'TensorList' for arg in arguments)

        # Vararg tensor lists are only supported when the list is the _first_
        # argument and the only tensor argument.
        if has_tensorlist:
            sole_leading_list = (sum(tensor_flags) == 1 and
                                 arguments[0]['simple_type'] == 'TensorList')
            if not sole_leading_list:
                return

        # Right now the generated dispatch methods either take all non-tensor
        # arguments as attributes, or don't use any attributes at all. In the
        # future something in the middle might be useful too (e.g. constant
        # propagation into attributes, which would avoid reparsing tensors
        # into scalar args at every invocation).
        all_arguments_are_inputs = (True,) * len(arguments)
        # NB: without scalar args both variants are equal, so the set
        # deduplicates them.
        for variant in {all_arguments_are_inputs, tensor_flags}:
            emit_decl_variant(decl, variant, has_tensorlist)

    # Methods implemented manually in TensorImpl have to be added by hand.
    def tensor_impl_decl(method_name):
        return {
            'name': method_name,
            'api_name': method_name,
            'method_of': ['Tensor'],
            'arguments': [{'name': 'self', 'simple_type': 'Tensor'}],
        }

    tensor_impl_methods = [tensor_impl_decl(n) for n in ['sizes', 'strides', 'dim']]
    aten_decls = load_aten_declarations(declarations) + tensor_impl_methods

    for decl in aten_decls:
        if is_jit_op(decl):
            emit_decl(decl)

    # Sorting the snippets keeps the generated files deterministic.
    dispatch_env = {'constructors': sorted(ops.values())}
    write(out, 'aten_dispatch.h', ATEN_DISPATCH_H, dispatch_env)
    write(out, 'aten_dispatch.cpp', ATEN_DISPATCH_CPP, dispatch_env)

    # NB: Operate on aten_decls, not the is_jit_op-filtered subset, because
    #     VariableType is a client for these symbols as well
    # NB: This means we DON'T generate interned strings for inplace ops.
    #     Change this when you do!
    # NB: Keep this code synchronized with the code in
    #     tool/autograd/gen_variable_type.py
    # NB: Some operations have inplace versions, but NOT non-inplace
    #     versions! Thus uninplace_api_name() is mandatory (if you remove
    #     it, you will get missing symbols.)
    names = {uninplace_api_name(decl['api_name']) for decl in aten_decls}
    # NB: This grabs non keyword arguments too, but it's harmless
    attrs = {arg['name'] for decl in aten_decls for arg in decl['arguments']}

    aten_symbols = ["_(aten, {}) \\".format(symbol) for symbol in sorted(names)]
    attr_symbols = ["_(attr, {}) \\".format(symbol) for symbol in sorted(attrs)]
    strings_env = {
        'aten_symbols': aten_symbols,
        'attr_symbols': attr_symbols,
    }
    write(out, 'aten_interned_strings.h', ATEN_INTERNED_STRINGS_H, strings_env)
def main():
    """Parse the command line and run ``gen_jit_dispatch``.

    Expects two positional arguments: the path to Declarations.yaml and the
    path to the output directory.
    """
    parser = argparse.ArgumentParser(description='Generate JIT op dispatch')
    positionals = (
        ('declarations', 'DECL', 'path to Declarations.yaml'),
        ('out', 'OUT', 'path to output directory'),
    )
    for dest, metavar, help_text in positionals:
        parser.add_argument(dest, metavar=metavar, help=help_text)
    options = parser.parse_args()
    gen_jit_dispatch(options.declarations, options.out)
# Run the generator only when this file is executed as a program, not when it
# is imported.  NOTE(review): the relative imports at the top of the file mean
# it presumably has to be launched as a module of its package (python -m ...)
# rather than by file path -- confirm against the build scripts.
if __name__ == '__main__':
    main()