blob: 35124429fb81743b3cfb855dbf2c32642906d620 [file] [log] [blame]
import os
import argparse
from collections import defaultdict
from tools.shared.module_loader import import_module
from itertools import count
from ..autograd.gen_variable_type import load_aten_declarations, CodeTemplate, write, \
FALLTHROUGH_RETURN_TYPES, FALLTHROUGH_FUNCTIONS, GENERATED_COMMENT
# Directory holding the template files, located next to this script.
template_path = os.path.join(os.path.dirname(__file__), 'templates')
# File-level templates; both are rendered by gen_jit_dispatch via write().
ATEN_DISPATCH_H = CodeTemplate.from_file(template_path + '/aten_dispatch.h')
ATEN_DISPATCH_CPP = CodeTemplate.from_file(template_path + '/aten_dispatch.cpp')
# Maps an argument's 'simple_type' to the name substituted for ${method} in
# ATTR_ASSIGNMENT (i.e. the accessor called on `node`). gen_jit_dispatch indexes
# this map directly, so a non-tensor argument type missing here raises KeyError.
ATTR_METHOD_MAP = {
'int64_t': 'i',
'IntList': 'is',
'Scalar': 't',
'bool': 'i',
'double': 'f',
'std::array<bool,2>': 'is',
'std::array<bool,3>': 'is',
}
# Maps an argument's 'simple_type' to the text substituted for ${type_cast} in
# ATTR_ASSIGNMENT. Types without an entry fall back to the 'simple_type' string
# itself (see the TYPE_CASTS.get(...) call in gen_jit_dispatch).
TYPE_CASTS = {
'std::array<bool,2>': 'as_bool_array<2>',
'std::array<bool,3>': 'as_bool_array<3>',
'Scalar': 'Scalar',
'IntList': 'std::vector<int64_t>',
}
# One C++ statement that reads the attribute called ${name} from `node` and
# binds it to a local variable of the same name.
ATTR_ASSIGNMENT = CodeTemplate("""\
auto ${name} = ${type_cast}(node->${method}(stringToSymbol("${name}")));\
""")
# Call expression for ops whose 'method_of' contains 'namespace'.
CALL_NAMESPACE = CodeTemplate("at::${name}(${args})")
# Call expression for the remaining ops: invoked as a method on inputs[0].
CALL_METHOD = CodeTemplate("inputs[0].${name}(${args})")
# One table entry per op overload: the descriptor string paired with a lambda
# that performs the attribute assignments and returns a TensorOp wrapping the call.
CONSTRUCTOR = CodeTemplate("""\
{"${descriptor}", [](Node *node) {
${assignments}
return TensorOp([=](const std::vector<Tensor> & inputs, std::vector<Tensor> & outputs) {
autograd::profiler::RecordFunction record("${name}");
pack_list(outputs, ${call});
}, "${name}", ${num_inputs});
}},
""")
def is_jit_op(decl):
    """Return True if the declaration ``decl`` should get a JIT dispatch entry.

    A declaration is rejected when any of the following holds (checked in
    this order, stopping at the first match):

    * its ``api_name`` ends with ``_``;
    * its ``name`` ends with ``_out`` or ``_forward``;
    * any of its arguments has ``simple_type`` ``Generator`` or
      ``SparseTensor``;
    * its ``return_type`` is in ``FALLTHROUGH_RETURN_TYPES``;
    * its ``name`` is in ``FALLTHROUGH_FUNCTIONS``.

    Args:
        decl: mapping with the keys ``api_name``, ``name``, ``arguments``
            (an iterable of mappings that each have ``simple_type``) and
            ``return_type``.

    Returns:
        bool: True when none of the rejection rules above applies.
    """
    if decl['api_name'].endswith('_'):
        return False
    # str.endswith accepts a tuple of suffixes: one call covers both names.
    if decl['name'].endswith(('_out', '_forward')):
        return False
    # A single pass over the arguments covers both unsupported argument types.
    if any(arg['simple_type'] in ('Generator', 'SparseTensor')
           for arg in decl['arguments']):
        return False
    return (decl['return_type'] not in FALLTHROUGH_RETURN_TYPES and
            decl['name'] not in FALLTHROUGH_FUNCTIONS)
def gen_jit_dispatch(declarations, out):
    """Render ``aten_dispatch.h`` and ``aten_dispatch.cpp`` into ``out``.

    Loads the ATen declarations, keeps those accepted by ``is_jit_op`` and
    builds one ``CONSTRUCTOR`` snippet per op overload, keyed by a descriptor
    string. The snippets are sorted and handed to both file templates.

    Args:
        declarations: forwarded to ``load_aten_declarations`` (the command
            line passes the path to Declarations.yaml).
        out: output location forwarded to ``write``.

    Raises:
        AssertionError: if two declarations produce the same descriptor.
        KeyError: if a non-tensor argument has a ``simple_type`` with no
            entry in ``ATTR_METHOD_MAP``.
    """
    aten_decls = load_aten_declarations(declarations)
    jit_decls = [d for d in aten_decls if is_jit_op(d)]

    def is_tensor_arg(arg):
        return arg['simple_type'] in {'Tensor', 'TensorList'}

    def render_call_args(call_arguments, first_input):
        # Tensor arguments become consecutive `inputs[i]` references starting
        # at `first_input`; every other argument is referenced by its name
        # (the local variable created by ATTR_ASSIGNMENT).
        input_index = count(start=first_input)
        return ['inputs[{}]'.format(next(input_index)) if is_tensor_arg(arg) else arg['name']
                for arg in call_arguments]

    ops = {}
    for decl in jit_decls:
        arguments = decl['arguments']
        name = decl['name']
        scalar_args = [arg for arg in arguments if not is_tensor_arg(arg)]
        has_tensorlist = any(arg['simple_type'] == 'TensorList' for arg in arguments)

        # Descriptor is a unique identifier for a particular overload of an op
        attr_names = sorted(arg['name'] for arg in scalar_args)
        num_inputs = "*" if has_tensorlist else len(arguments) - len(scalar_args)
        descriptor = '-'.join([name, str(num_inputs)] + attr_names)

        # All scalar args need to be assigned, so they can be captured by a lambda
        assignments = [ATTR_ASSIGNMENT.substitute(type=arg['simple_type'],
                                                  type_cast=TYPE_CASTS.get(arg['simple_type'],
                                                                           arg['simple_type']),
                                                  name=arg['name'],
                                                  method=ATTR_METHOD_MAP[arg['simple_type']])
                       for arg in scalar_args]

        # Generate the actual ATen call. This gets a bit tricky because of
        # TensorList arguments, and functions that are only available as methods.
        if 'namespace' in decl['method_of']:
            if has_tensorlist:
                if sum(map(is_tensor_arg, arguments)) != 1:
                    # TODO: support this
                    continue
                # The single TensorList argument receives the whole input vector.
                args = ['inputs' if is_tensor_arg(arg) else arg['name']
                        for arg in arguments]
            else:
                args = render_call_args(arguments, 0)
            call = CALL_NAMESPACE.substitute(name=name, args=args)
        else:
            # Method call: inputs[0] is the receiver in CALL_METHOD, so the
            # first argument is skipped and tensor numbering starts at 1.
            args = render_call_args(arguments[1:], 1)
            call = CALL_METHOD.substitute(name=name, args=args)

        constructor = CONSTRUCTOR.substitute(descriptor=descriptor, name=name, call=call,
                                             assignments=assignments,
                                             # num_inputs is only used in AutogradClosure, which
                                             # is going to be removed soon anyway. There's no good value
                                             # we can provide for cat.
                                             num_inputs=num_inputs if num_inputs != "*" else 0)
        if descriptor in ops:
            # Raised explicitly rather than via `assert` so that a descriptor
            # clash is still detected when Python runs with -O; otherwise the
            # later overload would silently replace the earlier one.
            raise AssertionError(descriptor)
        ops[descriptor] = constructor

    # Sort the generated snippets to ensure that the generation is deterministic
    env = {'constructors': sorted(ops.values())}
    write(out, 'aten_dispatch.h', ATEN_DISPATCH_H, env)
    write(out, 'aten_dispatch.cpp', ATEN_DISPATCH_CPP, env)
def main():
    """Command-line entry point.

    Parses the two positional arguments (DECL and OUT) and forwards them to
    ``gen_jit_dispatch``.
    """
    cli = argparse.ArgumentParser(description='Generate JIT op dispatch')
    positionals = (
        ('declarations', 'DECL', 'path to Declarations.yaml'),
        ('out', 'OUT', 'path to output directory'),
    )
    for dest, metavar, help_text in positionals:
        cli.add_argument(dest, metavar=metavar, help=help_text)
    parsed = cli.parse_args()
    gen_jit_dispatch(parsed.declarations, parsed.out)


# Run the generator only when this file is executed as a script.
if __name__ == '__main__':
    main()