"""
To run this file by hand from the root of the PyTorch
repository, run:
python -m tools.jit.gen_jit_dispatch \
build/aten/src/ATen/Declarations.yaml \
$OUTPUT_DIR \
tools/jit/templates
Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/jit/generated/
"""
import argparse
import copy
import re
import yaml
from itertools import groupby
from ..autograd.utils import CodeTemplate, YamlLoader, write
from ..autograd.gen_autograd import load_aten_declarations
from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
# JIT has a type system of
# Scalar = int | float | bool # int is the largest int (int64_t) and
# float is the largest float (double); we don't have the others because they are never held in tensors
# Type = Scalar # primitive numbers
# | Tensor # any tensor, as defined by at::Tensor
# | Type[] # a dynamically sized list of a type
# | Scalar[N] # a homogeneous fixed-size scalar list; single scalars can expand to this list
# | (Type1, Type2, ...) # a heterogeneous tuple
# | Layout | ScalarType | Device | Generator # special singleton types for built-in concepts in tensor lib
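# Illustrative example (not taken from Declarations.yaml): an ATen schema such as
#   aten::add(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
# is expressed in this type system as (Tensor, Tensor, Scalar) -> Tensor.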
# clean up the variety of C++ types in the ATen declarations
# to be in the restricted set of types that the IR represents
# note: no default values for this map, to make it clear what types
# can be passed through
TYPE_MAP = {
'std::array<bool,2>': 'bool[2]',
'std::array<bool,3>': 'bool[3]',
'std::array<bool,4>': 'bool[4]',
'std::string': 'str',
'Scalar': 'Scalar',
'MemoryFormat': 'MemoryFormat',
'MemoryFormat?': 'MemoryFormat?',
'QScheme': 'QScheme',
'Scalar?': 'Scalar?',
'Tensor': 'Tensor',
'Tensor?': 'Tensor?',
'TensorList': 'Tensor[]',
# this appears in return values instead of TensorList
# since TensorList is an ArrayRef in arguments but a vector
# in returns
'std::vector<Tensor>': 'Tensor[]',
'IntArrayRef': 'int[]',
'Layout': 'Layout',
'Layout?': 'Layout?',
'Device': 'Device',
'Device?': 'Device?',
'ScalarType': 'ScalarType',
'ScalarType?': 'ScalarType?',
'int64_t': 'int',
'int64_t?': 'int?',
'double': 'float',
'double?': 'float?',
'bool': 'bool',
'bool?': 'bool?',
'Generator': 'Generator?',
}
def optional_type_of(arg, typ):
# optional type special handling for Tensor?[] and Tensor
# types that are missing an optional annotation
if arg.get('is_nullable') and '?' not in typ:
if typ == 'TensorList' or typ == 'Tensor[]':
typ = 'Tensor?[]'
else:
typ = '{}?'.format(typ)
return typ
def annotated_type_of(arg, typ):
anno = arg.get('annotation')
if anno:
typ = '{}({})'.format(typ, anno)
return typ
def jit_type_of(arg):
jit_type = arg.get('jit_type')
if not jit_type:
jit_type = TYPE_MAP[arg['simple_type']]
if is_sized_intlist_arg(arg):
jit_type = 'int[{}]'.format(arg['size'])
jit_type = optional_type_of(arg, jit_type)
jit_type = annotated_type_of(arg, jit_type)
arg['jit_type'] = jit_type
return jit_type
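# Illustrative examples (hypothetical arg dicts, not real declarations):
#   {'simple_type': 'IntArrayRef', 'size': 2}       -> 'int[2]'
#   {'simple_type': 'Tensor', 'is_nullable': True}  -> 'Tensor?'
#   {'simple_type': 'Tensor', 'annotation': 'a!'}   -> 'Tensor(a!)'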
# map from aten 'simple_type' to the format string that will turn an IValue into
# that type
FROM_IVALUE = {
'Device': '{}.toDevice()',
'Device?': '{}.toOptional<c10::Device>()',
'IntArrayRef': '{}.toIntVector()',
'Layout': '{}.toLayout()',
'Layout?': '{}.toOptional<c10::Layout>()',
'MemoryFormat': '{}.toMemoryFormat()',
'MemoryFormat?': '{}.toOptional<c10::MemoryFormat>()',
'QScheme': '{}.toQScheme()',
'Scalar': '{}.toScalar()',
'Scalar?': '{}.toOptional<Scalar>()',
'ScalarType': '{}.toScalarType()',
'ScalarType?': '{}.toOptional<ScalarType>()',
'Tensor': '{}.toTensor()',
'Tensor?': 'toOptionalTensor({})',
'Tensor?[]': 'toListOfOptionalTensor({})',
'TensorList': '{}.toTensorVector()',
'bool': '{}.toBool()',
'bool?': '{}.toOptional<bool>()',
'double': '{}.toDouble()',
'double?': '{}.toOptional<double>()',
'int64_t': '{}.toInt()',
'int64_t?': '{}.toOptional<int64_t>()',
'std::string': '{}.toStringRef()',
'Generator': 'nullptr',
'std::array<bool,2>': 'as_bool_array<2>({}.toBoolList())',
'std::array<bool,3>': 'as_bool_array<3>({}.toBoolList())',
'std::array<bool,4>': 'as_bool_array<4>({}.toBoolList())',
}
def from_ivalue(arg, value):
typ = optional_type_of(arg, arg['simple_type'])
return FROM_IVALUE[typ].format(value)
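# Illustrative example (hypothetical arg dict, not a real declaration):
#   from_ivalue({'simple_type': 'int64_t'}, 'ival') produces 'ival.toInt()',
#   and a nullable 'Tensor' argument produces 'toOptionalTensor(ival)'.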
CALL_UNBOXED_KERNEL = CodeTemplate("""\
auto result_ = callUnboxedKernel<${return_type}${formals_types_with_leading_comma}>(unboxedKernel${args_with_leading_comma});
""")
CALL_NAMESPACE = CodeTemplate("""\
auto result_ = at::${name}(
${args}
);
""")
CALL_METHOD = CodeTemplate("""\
auto result_ = (${first}).${name}(
${args}
);
""")
CALL_NAMESPACE_WITH_TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
.dtype(${dtype})
.layout(${layout})
.device(${device})
.pinned_memory(${pin_memory});
#ifdef USE_STATIC_DISPATCH
auto result_ = at::${name}(${args_with_tensor_options});
#else
auto result_ = torch::${name}(${args_with_tensor_options});
#endif
""")
CALL_METHOD_WITH_TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
.dtype(${dtype})
.layout(${layout})
.device(${device})
.pinned_memory(${pin_memory});
auto result_ = (${first}).${name}(${args_with_tensor_options});
""")
CONSTRUCTOR = CodeTemplate("""\
[](OperatorKernel* unboxedKernel, const OperatorHandle&, Stack* stack) {
using namespace at;
${lvalues}
${call}
drop(*stack, ${num_inputs});
pack(*stack, std::move(result_));
}
""")
CONSTRUCTOR_JITONLY = CodeTemplate("""\
[](Stack* stack) {
using namespace at;
${lvalues}
${call}
drop(*stack, ${num_inputs});
pack(*stack, std::move(result_));
return 0;
}
""")
OPERATOR = CodeTemplate("""\
.op("${signature}",
${op})
""")
OPERATOR_JITONLY = CodeTemplate("""\
.jitOnlyOp("${signature}",
${op})
""")
blacklisted_types = {
'Storage',
'DimnameList?',
'ConstQuantizerPtr',
'Dimname',
'DimnameList',
}
default_only_types = {'Generator'}
def is_jit_arg(i, arg):
simple_type = arg['simple_type']
if simple_type in blacklisted_types:
return False
if simple_type in default_only_types and 'default' not in arg:
return False
if simple_type == 'Type':
return False
return True
def is_jit_op(decl):
# We currently don't support functions that return nothing
assert all(r['type'] != 'void' for r in decl['returns'])
if len(decl['returns']) == 0:
return False
arguments = decl['arguments']
# out variants must have a single output argument
if is_out_variant(decl) and sum(bool(arg.get('output')) for arg in arguments) > 1:
return False
return (('namespace' in decl['method_of'] or 'Tensor' in decl['method_of']) and
all(is_jit_arg(i, arg) for i, arg in enumerate(decl['arguments'])) and
all(is_jit_arg(i, arg) for i, arg in enumerate(decl['returns'])))
def is_tensor_arg(arg):
return arg['simple_type'] in {'Tensor', 'TensorList'}
def is_sized_intlist_arg(arg):
"""Returns True for arguments declared as IntArrayRef[k], but False for IntArrayRef."""
return (arg['simple_type'] == 'IntArrayRef') and ('size' in arg)
def base_name(decl):
name = decl['name']
return name[:-1] if decl.get('inplace', False) else name[:-4] if name.endswith('_out') else name
def is_view(decl):
return base_name(decl) in RETURNS_VIEWS_OF_INPUT
def is_out_variant(decl):
return decl['name'].endswith('_out')
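# Illustrative examples (names are hypothetical): base_name maps an inplace decl
# named 'add_' and an out variant named 'add_out' both to 'add';
# is_out_variant is True only for the 'add_out' decl.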
# Copied from ..autograd.gen_python_functions.SKIP_PYTHON_BINDINGS
BACKWARD_OP_PATTERNS = [
'.*_backward',
'.*_backward_(out|input|weight|bias)',
]
def is_backward_op(decl):
for pattern in BACKWARD_OP_PATTERNS:
if re.match('^' + pattern + '$', decl['name']):
return True
return False
# for each argument in decl, the location it should appear in the
# jit schema declaration. e.g.
# arguments = [x, y, z] # the order in aten
# jit_argument_order = [2, 0, 1]
# aten::my_arg(Tensor y, Tensor z, Tensor x) # the order in schema
# used to move 'out' arguments to the end of the list
def argument_order(decl):
return decl.get('jit_argument_order') or list(range(len(decl['arguments'])))
def load_op_list(path):
with open(path, 'r') as f:
op_list = yaml.load(f, Loader=YamlLoader)
return op_list
def gen_jit_dispatch(
declarations,
out,
template_path,
disable_autograd=False,
selected_op_list_path=None,
selected_op_list=None,
force_schema_registration=False,
):
REGISTER_ATEN_OPS_CPP = CodeTemplate.from_file(template_path + '/register_aten_ops.cpp')
ops = []
def get_invocation(decl, args, num_inputs):
# because the arg list can get lengthy we put each argument on a separate line
def pack_arguments(args):
return ',\n'.join(args)
is_namespace_function = 'namespace' in decl['method_of']
tensor_options_arg_index = decl.get('tensor_options_arg_index', None)
if tensor_options_arg_index is not None:
dtype = args[tensor_options_arg_index]
layout = args[tensor_options_arg_index + 1]
device = args[tensor_options_arg_index + 2]
pin_memory = args[tensor_options_arg_index + 3]
args_with_tensor_options = args[:tensor_options_arg_index] + \
['options'] + args[(tensor_options_arg_index + 4):]
if is_namespace_function:
return CALL_NAMESPACE_WITH_TENSOR_OPTIONS.substitute(
name=decl['name'], dtype=dtype, layout=layout,
device=device, pin_memory=pin_memory,
args_with_tensor_options=pack_arguments(args_with_tensor_options))
else:
return CALL_METHOD_WITH_TENSOR_OPTIONS.substitute(
name=decl['name'], dtype=dtype, layout=layout,
device=device, pin_memory=pin_memory,
args_with_tensor_options=pack_arguments(args_with_tensor_options[1:]),
first=args_with_tensor_options[0], num_inputs=num_inputs)
# The use_c10_dispatcher setting in native_functions.yaml now has a new option,
# 'with_codegenerated_unboxing_wrapper', which means we take the codegenerated unboxing wrapper from
# register_aten_ops.cpp and stuff it into c10. This new value is the default; 'unboxed_only' is no
# longer the default. For the (very few) ops that don't support boxed dispatch yet (i.e. ops taking TensorOptions
# arguments), we set them to 'unboxed_only' and they follow the old behavior of having register_aten_ops.cpp
# register the jit op.
elif decl['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' and not needs_hacked_twin(decl):
if len(decl['returns']) == 0:
return_type = "void"
elif len(decl['returns']) == 1:
return_type = decl['returns'][0]['type']
else:
return_type = "std::tuple<{}>".format(", ".join([r['type'] for r in decl['returns']]))
for a in decl['arguments']:
if 'type' not in a:
raise Exception(decl)
argument_types_with_leading_comma = ", ".join([a['type'] for a in decl['arguments']])
if argument_types_with_leading_comma != "":
argument_types_with_leading_comma = ", " + argument_types_with_leading_comma
args_with_leading_comma = pack_arguments(args)
if args_with_leading_comma != "":
args_with_leading_comma = ", " + args_with_leading_comma
return CALL_UNBOXED_KERNEL.substitute(name=decl['name'],
args_with_leading_comma=args_with_leading_comma,
num_inputs=num_inputs,
return_type=return_type,
formals_types_with_leading_comma=argument_types_with_leading_comma)
else:
assert decl['use_c10_dispatcher'] in ['unboxed_only', 'full'] or needs_hacked_twin(decl)
if is_namespace_function:
return CALL_NAMESPACE.substitute(name=decl['name'],
args=pack_arguments(args),
num_inputs=num_inputs)
else:
return CALL_METHOD.substitute(
name=decl['name'], first=args[0],
args=pack_arguments(args[1:]), num_inputs=num_inputs)
def requires_lvalue(arg):
jit_type = jit_type_of(arg)
return jit_type.startswith('Tensor') and '!' in jit_type
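# Illustrative example: a mutable 'Tensor(a!)' argument (e.g. the out tensor of an
# out variant) requires an lvalue; a plain 'Tensor' argument does not.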
def emit_decl_variant(decl):
if 'emit_dummy_placeholder' in decl:
if decl['use_c10_dispatcher'] == 'unboxed_only' or needs_hacked_twin(decl):
return "DUMMY_OPERATION_JITONLY"
else:
return "DUMMY_OPERATION"
kw_assignments = []
# mutable arguments in aten are passed as non-const references;
# these must be lvalues, so we have to put them in variables
# before calling the function
lvalues = []
arguments = []
num_inputs = len(decl['arguments'])
op_capture = ''
order = argument_order(decl)
for i, arg in enumerate(decl['arguments']):
value = from_ivalue(arg, '(std::move(peek(*stack, {}, {})))'.format(order[i], num_inputs))
if requires_lvalue(arg):
lvalues.append('auto {} = {};\n'.format(arg['name'], value))
value = arg['name']
arguments.append(value)
call = get_invocation(decl, arguments, num_inputs)
returns = decl['returns']
if decl['use_c10_dispatcher'] == 'unboxed_only' or needs_hacked_twin(decl):
# Ops taking TensorOptions aren't supported in this mechanism yet because boxed dispatch doesn't
# work for them. They use the old mechanism of registering a jitonly op for now.
# TODO We should get rid of this once TensorOptions are supported.
constructor = CONSTRUCTOR_JITONLY.substitute(name=decl['name'],
call=call,
kw_assignments=kw_assignments,
num_inputs=num_inputs,
op_capture=op_capture,
lvalues=lvalues)
elif decl['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper':
constructor = CONSTRUCTOR.substitute(name=decl['name'],
call=call,
kw_assignments=kw_assignments,
num_inputs=num_inputs,
op_capture=op_capture,
lvalues=lvalues)
else:
assert decl['use_c10_dispatcher'] == 'full'
return constructor
def filter_decls(jit_decls, disable_autograd, selected_op_list, force_schema_registration):
result = []
for decl in jit_decls:
if disable_autograd and is_backward_op(decl):
continue
op_name = signature_without_args(decl)
if selected_op_list and op_name not in selected_op_list:
if force_schema_registration:
decl['emit_dummy_placeholder'] = True
else:
continue
result.append(decl)
return result
# This function declares an order on declarations. This is necessary because
# there is some ambiguity in the choice of overload: if an argument is overloaded
# to accept both Scalar and Tensor, the schema with the Tensor should come first
# TODO: this can (probably) be removed when we remove the implicit conversion
# from Tensor -> Number.
def sort_decls(jit_decls):
def declkey(decl):
# key = sum_{i < len(args)} {1 if arg is tensor else 2} * (3 ** i)
# This is a ternary encoding where
# 0: No argument at this position
# 1: Tensor argument at this position
# 2: Some other argument at this position.
args = decl['arguments']
result = 0
for i in range(len(args)):
result += (3 ** i) * (1 if args[i]['simple_type'] == 'Tensor' else 2)
return result
# NB: itertools.groupby requires the list be sorted.
sorted_decls = sorted(jit_decls, key=lambda decl: decl['name'])
grouped_decls = [list(g) for _, g in
groupby(sorted_decls, key=lambda decl: decl['name'])]
return [sorted(g, key=declkey) for g in grouped_decls]
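# Illustrative example: for two overloads with arguments (Tensor, Tensor) and
# (Tensor, Scalar), declkey returns 1*1 + 1*3 = 4 and 1*1 + 2*3 = 7 respectively,
# so the all-Tensor overload sorts first within its group.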
# We need to add methods implemented manually in TensorImpl
# TODO: This seems to claim sizes() returns an int64_t. Really?
tensor_impl_methods = [{
'name': name,
'api_name': name,
'schema_string': schema_string,
'overload_name': '',
'method_of': ['Tensor'],
'arguments': [{'name': 'self', 'simple_type': 'Tensor'}],
'returns': [{'name': 'result', 'type': 'int64_t', 'dynamic_type': 'int64_t', 'simple_type': 'int64_t'}],
'use_c10_dispatcher': 'unboxed_only',
} for name, schema_string in [
('sizes', 'aten::sizes(Tensor self) -> int'),
('strides', 'aten::strides(Tensor self) -> int'),
('dim', 'aten::dim(Tensor self) -> int'),
('numel', 'aten::numel(Tensor self) -> int'),
('element_size', 'aten::element_size(Tensor self) -> int'),
]]
aten_decls = load_aten_declarations(declarations) + tensor_impl_methods
jit_decls = [d for d in aten_decls if is_jit_op(d)]
# expand TensorOptions into dtype, layout, device and pin_memory arguments for factory functions like zeros
def expand_options(decl, i, arg):
if arg['simple_type'] != 'TensorOptions':
return [arg]
assert decl.get('tensor_options_arg_index') != i
decl['tensor_options_arg_index'] = i
tensor_options_expansion = [
# XXX - until we actually have first-class interpreter types for these
# concepts, the default values to be encoded in Tensors
# If you change this, you also need to update [TensorOptions in script]
# in the tracer code.
# dtype is specified as an int64_t of at::ScalarType
{'name': 'dtype', 'simple_type': 'ScalarType'},
# layout is specified as an int64_t of at::Layout
{'name': 'layout', 'simple_type': 'Layout'},
# device is specified as an IntArrayRef of { at::Device::Type, device_id }
{'name': 'device', 'simple_type': 'Device'},
# pin_memory is specified as a boolean
{'name': 'pin_memory', 'simple_type': 'bool', 'default': False},
]
# TODO: Don't repack this into TensorOptions. Needs various changes in downstream code.
if 'default' in arg:
for el in tensor_options_expansion:
el['simple_type'] += '?'
el['default'] = 'None'
if 'default' in arg and arg['default'] == 'at::kLong':
tensor_options_expansion[0]['default'] = 'long'
if 'kwarg_only' in arg and arg['kwarg_only']:
for el in tensor_options_expansion:
el['kwarg_only'] = True
return tensor_options_expansion
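# Illustrative example: a factory op such as aten::zeros, whose ATen declaration
# carries a single TensorOptions argument, ends up with four jit arguments
# (dtype, layout, device, pin_memory) in its place.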
additional_jit_decls = []
for decl in jit_decls:
decl['arguments'] = [a for i, arg in enumerate(decl['arguments']) for a in expand_options(decl, i, arg)]
if is_out_variant(decl):
reorder_out_args(decl)
if needs_hacked_twin(decl):
additional_jit_decls.append(hacked_twin(decl))
jit_decls.extend(additional_jit_decls)
if not selected_op_list:
selected_op_list = []
selected_op_list += load_op_list(selected_op_list_path) if selected_op_list_path else []
jit_decls = filter_decls(jit_decls, disable_autograd, selected_op_list, force_schema_registration)
# sort the declarations so that generation is deterministic
jit_decl_groups = sort_decls(jit_decls)
# NOTE: see Note [Sharded File] at the top of the register_aten_ops.cpp
# template regarding sharding of the generated files.
#
# If you edit the number of shards here, you will also have to
# modify generate_code.py, torch/CMakeLists.txt, and the TARGETS
# files.
num_shards = 3
shards = [[] for _ in range(num_shards)]
# ops are assigned arbitrarily but stably to a shard based on a hash of their name
for group in jit_decl_groups:
x = sum(ord(c) for c in group[0]['name']) % num_shards
for decl in group:
if decl['use_c10_dispatcher'] == 'unboxed_only' or needs_hacked_twin(decl):
shards[x].append(OPERATOR_JITONLY.substitute(signature=decl['schema_string'],
op=emit_decl_variant(decl)))
elif decl['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper':
shards[x].append(OPERATOR.substitute(signature=decl['schema_string'],
op=emit_decl_variant(decl)))
else:
assert decl['use_c10_dispatcher'] == 'full'
for i, shard in enumerate(shards):
env = {
'constructors': shard,
}
write(out, 'register_aten_ops_%d.cpp' % i, REGISTER_ATEN_OPS_CPP, env)
default_map = {'{}': 'None', 'nullptr': 'None', 'c10::nullopt': 'None'}
def reorder_out_args(decl):
first_arg = decl['arguments'][0]
assert first_arg['output']
# the output argument must go at the end
# note: this is an annoying side effect of using a single '*'
# to denote kwarg_only
nargs = len(decl['arguments'])
decl['jit_argument_order'] = [nargs - 1] + list(range(nargs - 1))
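# Illustrative example: if the ATen argument order is [out, self, other],
# jit_argument_order becomes [2, 0, 1], i.e. the jit schema reads (self, other, out).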
def is_kwarg_only(a):
return a.get('kwarg_only') or a.get('output')
#
# create a clone of these declarations
# with nullability scrubbed from TensorList arg types
# TODO: find out why this exists and how to do it without the hack
#
NEEDS_HACKED_TWIN_NAMES = [
"aten::_index_put_impl_",
"aten::index.Tensor",
"aten::index_put",
"aten::index_put_",
]
def needs_hacked_twin(decl):
schema_string = decl['schema_string']
return any(schema_string.startswith(name) for name in NEEDS_HACKED_TWIN_NAMES)
def hacked_twin(decl):
decl_copy = copy.deepcopy(decl)
old_overload_name = decl['overload_name']
schema_string = decl['schema_string']
name = decl['name']
schema_string = schema_string.replace('Tensor?[]', 'Tensor[]')
if old_overload_name:
new_overload_name = old_overload_name + "_hacked_twin"
decl_copy['overload_name'] = new_overload_name
decl_copy['schema_string'] = schema_string.replace(name + "." + old_overload_name,
name + "." + new_overload_name)
else:
new_overload_name = "hacked_twin"
decl_copy['overload_name'] = new_overload_name
decl_copy['schema_string'] = schema_string.replace(name, name + "." + new_overload_name)
for arg in decl_copy['arguments']:
if arg['simple_type'] == 'TensorList' and arg.get('is_nullable'):
arg['is_nullable'] = False
return decl_copy
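# Illustrative example: for aten::index.Tensor, the cloned declaration is registered
# as aten::index.Tensor_hacked_twin, with its 'Tensor?[]' argument relaxed to 'Tensor[]'.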
def signature_without_args(decl):
name = decl['name'] if not is_out_variant(decl) else decl['name'][:-4]
overload_name = '.' + decl['overload_name'] if decl['overload_name'] != '' else ''
return 'aten::{}{}'.format(name, overload_name)
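# Illustrative example (hypothetical decls): a decl named 'add_out' with
# overload_name 'out' yields 'aten::add.out'; one named 'add' with an empty
# overload_name yields 'aten::add'.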
def main():
parser = argparse.ArgumentParser(
description='Generate JIT op dispatch')
parser.add_argument('declarations', metavar='DECL',
help='path to Declarations.yaml')
parser.add_argument('out', metavar='OUT',
help='path to output directory')
parser.add_argument('template_path', metavar='TEMPLATE_PATH',
help='path to templates directory')
args = parser.parse_args()
gen_jit_dispatch(args.declarations, args.out, args.template_path)
if __name__ == '__main__':
main()