| # Generates Python bindings for ATen functions |
| # |
| # The bindings are generated as methods on python_variable or functions on the |
| # torch._C._nn object. |
| # |
| |
| # Code tries to stick to the following rules: |
| # |
| # - templates should be colocated with the functions that use them. |
| # no templates are currently shared between functions, but if that |
| # happens, maybe put the template with the first one |
| # |
| # - don't use environment dictionaries when calling template.substitute(). |
| # pass named arguments directly for everything, otherwise it's much too |
| # hard to track what's actually being used and by who |
| # |
| # - colocate any new hacks/adjustments with existing ones of the same kind. |
| # ideally in a data structure rather than code if possible. See e.g. |
| # SCHEMA_DEFAULT_CONVERSION_HACKS, etc. |
| # |
| # - similarly, conversions from one format to another should ideally happen |
| # all at once in a single place. |
| # |
| # - no nontrivial nested functions. couple-liners are ok but please no more. |
| # especially avoid functions that read/write outer variables defined far away. |
| # |
| # - raise RuntimeError instead of asserting, and put as much |
| # information as is available into the message. I.e. no need to |
| # plumb in new params whose only purpose is to fill out an error |
| # message, but use what's there |
| # |
| |
| from collections import defaultdict |
| import re |
| from .gen_variable_type import should_trace |
| from .utils import write, is_tensor_method |
| |
| try: |
| from src.ATen.code_template import CodeTemplate |
| except ImportError: |
| from tools.shared.module_loader import import_module |
| CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate |
| |
| # |
| # declarations blacklist |
| # We skip codegen for these functions, for various reasons. |
| # Future PRs will categorize this list and eliminate or hoist |
| # them out of eager-only codegen. |
| # See https://github.com/pytorch/pytorch/issues/30788 |
| # |
| |
# These functions require manual Python bindings or are not exposed to Python.
# Every entry is a regular expression that has to match the whole op name.
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
    'index', 'unique_dim_consecutive',
    '_indexCopy_', 'max_values', 'min_values',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_solve.*', '_inverse.*',
    'full(_out)?',
    '_cholesky.*', '_triangular_solve.*', '_qr.*', '_symeig.*', '_svd.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense', 'to',
    'copy_sparse_to_sparse_', 'copy_',
    'numpy_T',  # this needs to be an attribute in Python, not a function
    'nonzero(_(out|numpy))?',
    'set_quantizer_',  # return types not supported yet
    'set_data',
    '.*_overrideable',  # overrideable functions for backend extension
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad'
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

# python module name -> name of the C++ object handed to handle_torch_function
NATIVE_NAMESPACE_MAPPING = {
    "torch": "THPVariableFunctionsModule",
    "torch.nn": "THPNNVariableFunctionsModule"
}

def should_generate_python_binding(declaration):
    """
    Decide whether a Python binding is generated for `declaration`.

    - declaration: dict with the op 'name' and its 'arguments' (each argument
      carrying a 'simple_type')

    Returns False when the name fully matches a SKIP_PYTHON_BINDINGS pattern or
    when 'name(simple_type, ...)' is listed verbatim in
    SKIP_PYTHON_BINDINGS_SIGNATURES, True otherwise.
    """
    name = declaration['name']
    for pattern in SKIP_PYTHON_BINDINGS:
        # wrap the pattern in a group before anchoring: without it a top-level
        # alternation 'a|b' would turn into '^a|b$' and match mere prefixes
        # or suffixes of the name
        if re.match('^(?:' + pattern + ')$', name):
            return False

    simple_types = [arg['simple_type'] for arg in declaration['arguments']]
    signature = '{}({})'.format(name, ', '.join(simple_types))
    return signature not in SKIP_PYTHON_BINDINGS_SIGNATURES
| |
| # |
| # top-level codegen functions, called from gen_autograd |
| # |
| |
def get_py_variable_methods(declarations):
    """
    Select the declarations that are bound as Tensor methods and return them
    grouped by op name.
    """
    selected = []
    for decl in declarations:
        if not should_generate_python_binding(decl):
            continue
        if is_nn_module_function(decl):
            continue
        if is_tensor_method(decl):
            selected.append(decl)
    return group_declarations_by_op_name(selected)
| |
| |
def gen_py_variable_methods(out, declarations, template_path):
    """
    Write python_variable_methods.cpp (the Tensor method bindings) to `out`,
    using the template of the same name found under `template_path`.
    """
    template = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    methods_by_name = get_py_variable_methods(declarations)
    # methods receive self directly and have no module for __torch_function__
    bindings_env = create_python_bindings(methods_by_name, is_python_method=True, module=None)
    write(out, 'python_variable_methods.cpp', template, bindings_env)
| |
| |
def get_py_nn_functions(declarations):
    """
    Select the declarations that are bound as functions in the "nn" module
    and return them grouped by op name.
    """
    selected = []
    for decl in declarations:
        if not should_generate_python_binding(decl):
            continue
        if is_nn_module_function(decl):
            selected.append(decl)
    return group_declarations_by_op_name(selected)
| |
| |
def gen_py_nn_functions(out, declarations, template_path):
    """
    Write python_nn_functions.cpp (the "nn" module function bindings) to `out`,
    using the template of the same name found under `template_path`.
    """
    template = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    functions_by_name = get_py_nn_functions(declarations)
    bindings_env = create_python_bindings(functions_by_name, is_python_method=False, module="torch.nn")
    write(out, 'python_nn_functions.cpp', template, bindings_env)
| |
| |
def get_py_torch_functions(declarations):
    """
    Select the declarations that are bound as functions in the "torch" module
    and return them grouped by op name.
    """
    selected = []
    for decl in declarations:
        if not should_generate_python_binding(decl):
            continue
        if is_nn_module_function(decl):
            continue
        if is_torch_function(decl):
            selected.append(decl)
    return group_declarations_by_op_name(selected)
| |
| |
def gen_py_torch_functions(out, declarations, template_path):
    """
    Write python_torch_functions.cpp (the "torch" module function bindings) to
    `out`, using the template of the same name found under `template_path`.
    """
    template = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    functions_by_name = get_py_torch_functions(declarations)
    bindings_env = create_python_bindings(functions_by_name, is_python_method=False, module="torch")
    write(out, 'python_torch_functions.cpp', template, bindings_env)
| |
| |
def group_declarations_by_op_name(declarations):
    """
    Bucket declarations by op_name(declaration), keeping their input order
    inside each bucket. Returns a defaultdict(list).
    """
    buckets = defaultdict(list)
    for decl in declarations:
        key = op_name(decl)
        buckets[key].append(decl)
    return buckets
| |
| |
def create_python_bindings(python_functions, is_python_method, module):
    """
    Generates Python bindings to ATen functions.

    - python_functions: mapping from op name to the list of its overload declarations
    - is_python_method: forwarded to the per-op generators
    - module: forwarded to the per-op generators

    Returns the template environment with 'py_forwards', 'py_methods' and
    'py_method_defs', each a list filled in sorted op-name order.
    """
    impls, defs, forwards = [], [], []

    for op in sorted(python_functions):
        overloads = python_functions[op]
        impls.append(method_impl(op, overloads, is_python_method, module))
        defs.append(method_def(op, overloads, is_python_method, module))
        forwards.extend(forward_decls(op, overloads, is_python_method, module))

    return {
        'py_forwards': forwards,
        'py_methods': impls,
        'py_method_defs': defs,
    }
| |
| |
| # |
| # extracting and storing parsed args |
| # |
| |
# C++ arg type -> name of the parser accessor used as _r.<accessor>(<index>)
UNPACK_METHODS = {
    'const Tensor &': 'tensor',
    'Tensor &': 'tensor',
    'c10::optional<Generator>': 'generator',
    'Storage': 'storage',
    'Storage &': 'storage',
    'const ScalarType &': 'scalartype',
    'const Device &': 'device',
    'c10::optional<DimnameList>': 'toDimnameListOptional',
    'c10::optional<ScalarType>': 'scalartypeOptional',
    'c10::optional<Layout>': 'layoutOptional',
    'c10::optional<MemoryFormat>': 'memoryformatOptional',
    'c10::optional<Scalar>': 'scalarOptional',
    'c10::optional<int64_t>': 'toInt64Optional',
    'c10::optional<bool>': 'toBoolOptional',
    'c10::optional<double>': 'toDoubleOptional',
    'IntArrayRef': 'intlist',
    'Scalar': 'scalar',
    'ScalarType': 'scalartype',
    'Dimname': 'dimname',
    'DimnameList': 'dimnamelist',
    'TensorList': 'tensorlist',
    'int64_t': 'toInt64',
    'bool': 'toBool',
    'double': 'toDouble',
    'std::string': 'string',
}

# accessors for args that declare a definite 'size'; '{}' receives that size
UNPACK_WITH_SIZE_METHODS = {
    'TensorList': 'tensorlist_n<{}>',
    'DimnameList': 'dimnamelist',
    'IntArrayRef': 'intlist',
}

# accessors for args that carry a 'python_default_init' expression
UNPACK_WITH_DEFAULT_METHODS = {
    'const ScalarType &': 'scalartypeWithDefault',
    'const Device &': 'deviceWithDefault',
    'c10::optional<Layout>': 'layoutWithDefault',
}

def parsed_arg_expr(arg, arg_index):
    """
    Return the C++ expression that reads parsed argument number `arg_index`
    out of the parser result `_r`.

    e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'

    Lookup precedence: a 'python_default_init' entry first, then a 'size'
    entry, then the plain per-type accessor. Raises RuntimeError when the
    arg type has no accessor in the table that applies.
    """
    typename = arg['type']

    # Note: python_default_init is only introduced by make_python_binding_args
    default_init = arg.get('python_default_init')
    if default_init is not None:
        with_default = UNPACK_WITH_DEFAULT_METHODS.get(typename)
        if with_default is None:
            raise RuntimeError(
                "type '{}' is not supported in python_default_init".format(typename))
        return '_r.{}({}, {})'.format(with_default, arg_index, default_init)

    size = arg.get('size')
    if size is not None:
        sized = UNPACK_WITH_SIZE_METHODS.get(typename)
        if sized is None:
            raise RuntimeError(
                "type '{}' with definite size ({}) is not supported".format(typename, size))
        return '_r.{}({})'.format(sized.format(size), arg_index)

    plain = UNPACK_METHODS.get(typename)
    if plain is None:
        raise RuntimeError("type '{}' is not supported".format(typename))
    return '_r.{}({})'.format(plain, arg_index)
| |
| |
# TODO make this part of something more general, or get rid of it
def unpack_optional_dimname_list_hack(name, expr):
    """
    Return the C++ statements that unpack an optional DimnameList arg into a
    local called `name`.

    optional<ArrayRef<T>> are special. The PythonArgParser returns an
    optional<vector<T>>, which cannot be implicitly converted to
    optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
    """
    # the trailing empty entry reproduces the final newline of the snippet
    snippet = '\n'.join([
        'auto __{name} = {expr};',
        'c10::optional<{typ}> {name} = __{name} ? c10::make_optional({typ}(__{name}.value())) : c10::nullopt;',
        '',
    ])
    rendered = snippet.format(name=name, expr=expr, typ='DimnameList')
    return [piece.strip() for piece in rendered.split('\n')]
| |
| |
def parse_arg(arg, arg_index, unpack_to_local=False):
    """
    Build the expression for a parsed arg plus any statements that have to
    run before it.

    Returns (expr, inits). When the arg is unpacked into a local - always for
    'c10::optional<DimnameList>', otherwise only if `unpack_to_local` is set -
    expr is the arg name and inits declares that local. In every other case
    expr is the raw parser expression and inits is empty.
    """
    rhs = parsed_arg_expr(arg, arg_index)
    local_name = arg['name']

    if arg['type'] == 'c10::optional<DimnameList>':
        return local_name, unpack_optional_dimname_list_hack(local_name, rhs)
    if unpack_to_local:
        return local_name, ['auto {} = {};'.format(local_name, rhs)]
    return rhs, []
| |
| |
| # |
| # schema type to cpp type conversions |
| # some of these are to prevent dangling refs to temps, others are more obscure |
# TODO don't know if these fold into more general conversions somewhere, hope so
| # |
| |
# declared types that get replaced by a by-value type when temp safety is requested
TEMP_SAFE_CPP_DECL_TYPE = {
    'Tensor &': 'Tensor',
}

def get_cpp_decl_type(typename, ensure_temp_safe=True):
    """
    Return the C++ type to declare for `typename`. With `ensure_temp_safe`
    the TEMP_SAFE_CPP_DECL_TYPE replacement is applied; types without an
    entry, and all types when the flag is off, come back unchanged.
    """
    if not ensure_temp_safe:
        return typename
    return TEMP_SAFE_CPP_DECL_TYPE.get(typename, typename)
| |
| |
def get_cpp_formal(arg, ensure_temp_safe=True):
    """
    Return '<decl type> <name>' for `arg`, i.e. its C++ formal parameter.
    """
    return '{decl_type} {name}'.format(
        decl_type=get_cpp_decl_type(arg['type'], ensure_temp_safe),
        name=arg['name'],
    )
| |
| |
# XXX: if you got here because get_simple_return_type raised, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,Tensor,int64_t>',
    'std::tuple<double,int64_t>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void',
    'QScheme', 'double',
    'IntArrayRef',
    'ScalarType'
}

def get_simple_return_type(declaration):
    """
    Return the declaration's return type with every ' &' removed, raising
    RuntimeError if the result is not in SUPPORTED_RETURN_TYPES.
    """
    # The by-value type (Tensor) is used rather than the fancy return type
    # (Tensor &) because the dispatch lambdas take mutable arguments *by
    # value*, not by reference. Returning a reference to such an argument
    # would leave a pointer to a dangling stack entry.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to temporary. Maybe we could assign it to a
    # variable itself.)
    by_value_type = declaration['return_type'].replace(' &', '')
    if by_value_type in SUPPORTED_RETURN_TYPES:
        return by_value_type
    raise RuntimeError(declaration['name'] + " returns unsupported type " + by_value_type)
| |
| # |
| # dispatch codegen |
| # |
| |
def get_dispatch_callee(declaration):
    """
    Format the name of the receiving function or method: 'self.<name>' for a
    Tensor method, '<namespace>::<name>' for a namespace function.

    Raises RuntimeError when the declaration is neither.
    """
    if is_tensor_method(declaration):
        return 'self.{}'.format(declaration['name'])
    if not is_torch_function(declaration):
        raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')
    return '{}::{}'.format(function_namespace(declaration), declaration['name'])
| |
| |
def get_op_args(declaration, argmap):
    """
    Return a list of argmap values in op call order, with two wrinkles:
    1. 'self' is eliminated for methods, it's baked into the callee expression elsewhere
    2. declaration['call_args'] shims legacy overrides and may contain constant values,
       not just names (see load_deprecated_signatures() in gen_autograd.py)
    """
    override = declaration.get('call_args')
    if override:
        # names or constants
        keys = override
    else:
        # only names
        keys = [param['name'] for param in declaration['arguments']]

    if is_tensor_method(declaration):
        # exclude self for method calls
        keys = [key for key in keys if key != 'self']

    if not override:
        return [argmap[key] for key in keys]
    # assume missing keys are constants and pass them through untouched
    return [argmap.get(key, key) for key in keys]
| |
| |
# declares a const TensorOptions local called ${name}, built from the
# dtype/device/layout/requires_grad/pin_memory expressions substituted in
TENSOR_OPTIONS_DECL = CodeTemplate("""\
const auto ${name} = TensorOptions()
.dtype(${dtype})
.device(${device})
.layout(${layout})
.requires_grad(${requires_grad})
.pinned_memory(${pin_memory});
""")

# addition to output-variant handler in which tensor options params
# (if present) are checked against properties of a tensor output param
# TODO remove hardcoding, use unpack logic from emit_single_dispatch
# placeholders are parsed-arg indices: ${out_idx} for the output tensor,
# ${type_idx}/${layout_idx}/${device_idx} for the options being checked
PY_VARIABLE_CHECK_OUT_TYPE_HACK = CodeTemplate("""\
check_out_type_matches(_r.tensor(${out_idx}), _r.scalartype(${type_idx}),
_r.isNone(${type_idx}), _r.layoutOptional(${layout_idx}),
_r.device(${device_idx}), _r.isNone(${device_idx}));
""")

# Unpack parsed args to locals, call the op, and wrap the result.
# Lambda is so GIL is back on by wrap() time (wrap can allocate)
# ${inits} are the statements emitted ahead of the lambda, ${lambda_formals} /
# ${lambda_args} its parameters and the values passed for them
PY_VARIABLE_WRAP = CodeTemplate("""\
${inits}
auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} {
${auto_no_gil}
return ${dispatch_callee}(${dispatch_args});
};
return wrap(${namedtuple_typeref}dispatch_${name}(${lambda_args})${set_requires_grad});
""")

# void return variant: nothing to wrap, the binding returns None after the call
PY_VARIABLE_RETURN_VOID = CodeTemplate("""\
${inits}
auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} {
${auto_no_gil}
${dispatch_callee}(${dispatch_args});
};
dispatch_${name}(${lambda_args})${set_requires_grad};
Py_RETURN_NONE;
""")
| |
| |
def emit_single_dispatch(declaration, is_python_method, output_gap=0):
    """
    Emit dispatch code for a single declared overload.

    - declaration: the overload; its 'python_arglists' must already be set
    - is_python_method: when true, self is taken from the binding's own
      argument instead of being parsed
    - output_gap: forwarded to handle_python_binding_args, which adds it to
      the parsed-arg index of the python binding args

    Returns the rendered PY_VARIABLE_WRAP snippet, or PY_VARIABLE_RETURN_VOID
    for ops returning void.
    """
    # the generated code opens with the op schema as a comment
    deprecated_tag = '[deprecated] ' if declaration.get('deprecated', False) else ''
    inits = ['// ' + deprecated_tag + declaration['schema_string']]

    arglists = declaration['python_arglists']
    parsed_args = arglists['input_args'] + arglists['input_kwargs'] + arglists['output_args']
    has_options = has_tensor_options(declaration)

    # op arg name -> {'value': expression passed in, 'formal': lambda parameter}
    argmap = {}

    if is_python_method:
        # self is passed directly to python binding, rather than parsed
        argmap['self'] = {'value': 'self', 'formal': 'Tensor & self'}

    for index, arg in enumerate(parsed_args):
        scattered = is_scatter(arg)
        to_local = scattered or (has_options and is_tensor_self(arg))
        value, unpack_stmts = parse_arg(arg, index, unpack_to_local=to_local)
        inits.extend(unpack_stmts)
        if not scattered:
            argmap[arg['name']] = {'value': value, 'formal': get_cpp_formal(arg)}
            continue
        # one parsed arg feeds several op args, each reading one element
        for position, elem in enumerate(arg['scatter_args']):
            argmap[elem['name']] = {
                'value': '{}[{}]'.format(value, position),
                'formal': get_cpp_formal(elem, ensure_temp_safe=False),
            }

    # synthetic python binding args deliver op args
    binding_argmap, binding_inits, set_requires_grad = \
        handle_python_binding_args(declaration, output_gap)
    argmap.update(binding_argmap)
    inits.extend(binding_inits)

    op_args = declaration['arguments']
    lambda_formals = [argmap[op_arg['name']]['formal'] for op_arg in op_args]
    lambda_args = [argmap[op_arg['name']]['value'] for op_arg in op_args]

    dispatch_callee = get_dispatch_callee(declaration)
    # inside the lambda every op arg is referred to by its own name
    dispatch_args = get_op_args(declaration, {key: key for key in argmap})

    if declaration['with_gil']:
        auto_no_gil = []
    else:
        auto_no_gil = ['pybind11::gil_scoped_release no_gil;']

    simple_return_type = get_simple_return_type(declaration)
    template = PY_VARIABLE_RETURN_VOID if simple_return_type == 'void' else PY_VARIABLE_WRAP

    return template.substitute(
        name=declaration['name'],
        inits=inits,
        lambda_formals=lambda_formals,
        lambda_args=lambda_args,
        dispatch_callee=dispatch_callee,
        dispatch_args=dispatch_args,
        auto_no_gil=auto_no_gil,
        set_requires_grad=set_requires_grad,
        simple_return_type=simple_return_type,
        namedtuple_typeref=declaration['namedtuple_typeref'],
    )
| |
| |
# arg['name'] to arg['simple_type'] for scattered tensor options fields
TENSOR_OPTIONS_FIELDS = {
    'dtype': 'ScalarType',
    'device': 'Device',
    'layout': 'Layout',
    'pin_memory': 'bool',
    'requires_grad': 'bool',
}

def handle_python_binding_args(declaration, output_gap):
    """
    Map synthetic python binding args to op args and misc other stuff.

    - declaration: the overload; reads declaration['python_arglists']
    - output_gap: added to the parsed-arg index of every python binding arg

    Returns a tuple (argmap, inits, set_requires_grad):
    - argmap: extra op arg entries ({'value': ..., 'formal': ...}); holds
      'options' when the op has a tensor options arg, else it is empty
    - inits: C++ statements to emit ahead of the dispatch
    - set_requires_grad: '' or a '.set_requires_grad(...)' suffix for the
      returned value

    Raises RuntimeError when the python binding args are inconsistent with
    the declaration.
    """
    # note: this logic shares arcane knowledge with make_python_binding_args
    # and isn't completely airtight w.r.t. the possible contents of
    # python_binding_args. TODO

    argmap = {}
    inits = []
    set_requires_grad = ''

    pa = declaration['python_arglists']
    python_binding_args = pa['python_binding_args']

    if not python_binding_args:
        # nothing to see here
        return argmap, inits, set_requires_grad

    # binding args are parsed after all regular args (plus the output gap)
    args = pa['input_args'] + pa['input_kwargs'] + pa['output_args']
    binding_arg_base = len(args) + output_gap
    binding_arg_offsets = {arg['name']: i for i, arg in enumerate(python_binding_args)}
    python_binding_argnames = [arg['name'] for arg in python_binding_args]

    def binding_arg_index(name):
        return binding_arg_base + binding_arg_offsets[name]

    def parse_binding_arg(name):
        binding_arg = python_binding_args[binding_arg_offsets[name]]
        expr, _ = parse_arg(binding_arg, binding_arg_index(name))
        return expr

    has_output = len(pa['output_args']) == 1
    tensor_options_arg = get_tensor_options(declaration)

    if tensor_options_arg is not None:
        # if our op has a tensor options arg, these are its scattered fields.
        # first some checks
        if has_output:
            raise RuntimeError('{}: tensor options with output arg'.format(declaration['name']))
        for arg in python_binding_args:
            typename = TENSOR_OPTIONS_FIELDS.get(arg['name'])
            if typename is None:
                raise RuntimeError(
                    '{}: unrecognized tensor options field \'{}\' in python binding arguments'.
                    format(declaration['name'], arg['name']))
            if typename != arg['simple_type']:
                raise RuntimeError(
                    '{}: unrecognized type \'{}\' for tensor options field \'{}\' in python binding arguments'.
                    format(declaration['name'], arg['type'], arg['name']))
        if not all(key in python_binding_argnames for key in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                '{}: incomplete tensor options args: {}'.
                format(declaration['name'], python_binding_argnames))
        # generate a gathering initialization of options struct
        # (named arguments, per this file's rule against env dicts in substitute())
        argname = tensor_options_arg['name']
        inits.append(TENSOR_OPTIONS_DECL.substitute(
            name=argname,
            dtype=parse_binding_arg('dtype'),
            layout=parse_binding_arg('layout'),
            device=parse_binding_arg('device'),
            requires_grad=parse_binding_arg('requires_grad'),
            pin_memory=parse_binding_arg('pin_memory'),
        ))
        inits.append('torch::utils::maybe_initialize_cuda({});'.format(argname))
        # and add to op arg map
        argmap['options'] = {
            'value': argname,
            'formal': get_cpp_formal(tensor_options_arg),
        }

    else:
        # not the scattered fields of a tensor options - sort of a grab bag
        if 'dtype' in binding_arg_offsets:
            # we're an output-arg variant, check these args against output tensor
            if not has_output:
                raise RuntimeError(
                    '{}: dtype in python_binding_args without output arg'.
                    format(declaration['name']))
            if not all(name in binding_arg_offsets for name in ['layout', 'device']):
                raise RuntimeError(
                    '{}: incomplete tensor options for output check'.
                    format(declaration['name']))
            check_type = PY_VARIABLE_CHECK_OUT_TYPE_HACK.substitute(
                out_idx=get_python_output_index(declaration),
                type_idx=binding_arg_index('dtype'),
                layout_idx=binding_arg_index('layout'),
                device_idx=binding_arg_index('device'),
            )
            inits.append(check_type)
        # we'll set requires_grad on outgoing tensor
        if 'requires_grad' not in binding_arg_offsets:
            raise RuntimeError(
                '{}: expected "requires_grad" in python_binding_args absent tensor options arg but found [{}]'.
                format(declaration['name'], python_binding_argnames))
        requires_grad = parse_binding_arg('requires_grad')
        set_requires_grad = '.set_requires_grad({})'.format(requires_grad)

    return argmap, inits, set_requires_grad
| |
| |
# handler for output/no-output overload pair
# (plugged into PY_VARIABLE_CASE as ${call_dispatch})
# ${out_idx} is the parsed-arg index tested with _r.isNone(); ${call_dispatch}
# is emitted for the None branch and ${call_dispatch_out} for the other one
PY_VARIABLE_OUT = CodeTemplate("""\
if (_r.isNone(${out_idx})) {
${call_dispatch}
} else {
${call_dispatch_out}
}
""")

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures only differ in output params
# ${i} is the case label, ${body} the dispatch code placed inside the case
PY_VARIABLE_CASE = CodeTemplate("""\
case ${i}: {
${body}
}
""")
| |
| |
def emit_dispatch_case(i, dictionary, is_python_method):
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single overload, or a pair that differ only in output params. In the latter
    case, a single signature is used for both and dispatching switches on the
    presence/absence of passed output args.
    - i: this signature's position in generated binding's signature list if number of
      signatures > 1, otherwise None
    - dictionary: contains a no-output overload declaration under 'base', and optionally
      a second overload with outputs under 'out'
    - is_python_method: true if we're generating a python method, in which case self
      is not parsed but passed directly
    """
    base_decl = dictionary['base']

    if 'out' not in dictionary:
        # no-output version only
        body = emit_single_dispatch(base_decl, is_python_method)
    else:
        out_decl = dictionary['out']
        out_idx = get_python_output_index(out_decl)
        output_gap = get_python_argc(out_decl) - get_python_argc(base_decl)

        # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
        body = PY_VARIABLE_OUT.substitute(
            out_idx=out_idx,
            call_dispatch=emit_single_dispatch(base_decl, is_python_method, output_gap),
            call_dispatch_out=emit_single_dispatch(out_decl, is_python_method),
        )

    if i is None:
        # only one overload, omit case wrapper
        return body
    # generate case for ith overload
    return PY_VARIABLE_CASE.substitute(i=i, body=body)
| |
| # |
| # named tuple codegen |
| # |
| |
def namedtuple_fieldnames(declaration):
    """
    Return the namedtuple field names of the declaration's returns, or [] when
    there is at most one return or none of the returns carries a 'field_name'.

    Raises ValueError if only some of the returns are named.
    """
    returns = declaration['returns']
    if len(returns) <= 1:
        return []

    named = [ret for ret in returns if 'field_name' in ret]
    if not named:
        return []

    if len(named) != len(returns):
        # See Note [field_name versus name]
        # When building on Windows, `PyStructSequence_UnnamedField` could not be
        # resolved by the linker for some reason, which cause error in building:
        #
        # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
        # PyStructSequence_UnnamedField
        #
        # Thus, at this point in time, we do not support unnamed
        # fields in namedtuple; you must either name all fields,
        # or none of them.
        raise ValueError("Unnamed field is not supported by codegen")

    return [ret['field_name'] for ret in returns]
| |
# static array of PyStructSequence_Field called ${fieldsname}, holding the
# ${fields} entries followed by a {nullptr} terminator
PY_NAMEDTUPLE_FIELDSDEF = CodeTemplate("""\
static PyStructSequence_Field ${fieldsname}[] = { ${fields,} {nullptr} };
""")

# static PyTypeObject called ${typename} plus code that initializes it the
# first time through, guarded by the ${typename}_initialized flag; the type is
# named "torch.return_types.${name}" and uses the ${fieldsname} field array
PY_NAMEDTUPLE_TYPEDEF = CodeTemplate("""\
static PyTypeObject ${typename};
static bool ${typename}_initialized = false;
if (!${typename}_initialized) {
${typename}_initialized = true;
static PyStructSequence_Desc desc = { "torch.return_types.${name}", nullptr, ${fieldsname}, ${size} };
PyStructSequence_InitType(&${typename}, &desc);
${typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}
""")
| |
| |
def emit_namedtuple_typedefs(declarations):
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them.

    - declarations: the overloads of one op; every one of them gets a
      'namedtuple_typeref' entry: '' when it has no named return fields,
      otherwise '&<typename>, '

    Returns the field array definitions followed by the type definitions.
    Identical field name lists share one field array, and identical
    (op name, field name list) pairs share one type.
    """
    flddefnames = {}    # map from unique field name lists to field def name
    flddefs = []        # field def declarations
    typenames = {}      # map from unique name + field name lists to typedef name
    typedefs = []       # typedef declarations and init code

    for decl in declarations:
        fieldnames = namedtuple_fieldnames(decl)
        if not fieldnames:
            decl['namedtuple_typeref'] = ''
            continue

        # tuples as keys: joining the names with '_' could make two different
        # field lists collide (e.g. ['a_b', 'c'] and ['a', 'b_c'])
        fn_key = tuple(fieldnames)
        fieldsname = flddefnames.get(fn_key)
        if fieldsname is None:
            # first definition has no suffix, later ones are numbered 1, 2, ...
            fieldsname = 'NamedTuple_fields{}'.format('' if not flddefs else len(flddefs))
            fields = ['{{"{}", ""}}'.format(fn) for fn in fieldnames]
            fieldsdef = PY_NAMEDTUPLE_FIELDSDEF.substitute(
                fieldsname=fieldsname,
                fields=fields
            )
            flddefnames[fn_key] = fieldsname
            flddefs.append(fieldsdef)

        name = decl['name']
        key = (name, fn_key)
        typename = typenames.get(key)
        if typename is None:
            typename = 'NamedTuple{}'.format('' if not typedefs else len(typedefs))
            typedef = PY_NAMEDTUPLE_TYPEDEF.substitute(
                name=name,
                typename=typename,
                size=len(fieldnames),
                fieldsname=fieldsname
            )
            typenames[key] = typename
            typedefs.append(typedef)

        decl['namedtuple_typeref'] = '&{}, '.format(typename)

    return flddefs + typedefs
| |
| # |
| # method impl codegen |
| # |
| |
def get_pycname(name):
    """Return the name of the C function that implements the binding for `name`."""
    return 'THPVariable_{op}'.format(op=name)
| |
| |
def is_noarg_binding(overloads):
    """True iff there is exactly one overload and its python arg count is 0."""
    if len(overloads) != 1:
        return False
    return get_python_argc(overloads[0]) == 0
| |
| |
# python binding for all overloads of a particular function/method
# ${signatures} feeds the PythonArgParser and ${dispatch} holds one case per
# signature, selected by the index of the signature that matched (_r.idx)
# NB: this must not be a raw string - in a raw literal the backslash-newline
# after the opening quotes is kept, which put a stray '\' line at the top of
# every generated binding
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});

ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(args, kwargs, parsed_args);
${check_has_torch_function}
switch (_r.idx) {
${dispatch}
}
${method_footer}
}

""")
| |
# python binding for single-overload function/method
# same layout as the multi-overload template, minus the switch on _r.idx
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});

ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(args, kwargs, parsed_args);
${check_has_torch_function}
${dispatch}
${method_footer}
}

""")

# python binding for a method with no args, shortcuts parsing
# (no PythonArgParser is set up and the C function takes no kwargs parameter)
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
${method_header}
${dispatch}
${method_footer}
}

""")

# early return through handle_torch_function when the parser reports
# _r.has_torch_function(); ${namespace} and ${modulename} are passed along
TORCH_FUNCTION_CHECK = CodeTemplate("""\
if(_r.has_torch_function()) {
return handle_torch_function(_r, args, kwargs, ${namespace}, ${modulename});
}
""")

# NOTE: we type the unpacked self as Tensor not Variable to avoid return type
# discrepancies on method resolution (e.g. Variable::detach_ returns void
# rather than Tensor &)
UNPACK_SELF = "Tensor& self = reinterpret_cast<THPVariable*>(self_)->cdata;"
| |
| |
def method_impl(name, declarations, is_python_method, module):
    """
    Generate a python binding for all overloads of an op.

    - name: op name, used for the comment and the C function name
    - declarations: all overloads of the op; each gets its 'python_arglists' set here
    - is_python_method: true when generating a Tensor method, in which case
      self is unpacked from the binding's own argument
    - module: python module name used for the torch function check, or a
      falsy value to omit that check

    Returns the rendered C++ function definition.
    """
    # formals for python binding signature
    for decl in declarations:
        decl['python_arglists'] = make_python_arglists(decl, is_python_method)

    pycname = get_pycname(name)

    method_header = ['HANDLE_TH_ERRORS'] + emit_namedtuple_typedefs(declarations)
    if is_python_method:
        method_header.append(UNPACK_SELF)

    method_footer = ['END_HANDLE_TH_ERRORS']

    # emit dispatch
    if is_noarg_binding(declarations):
        # a no-arg binding always has exactly one overload
        return PY_VARIABLE_METHOD_NOARGS.substitute(
            name=name,
            pycname=pycname,
            method_header=method_header,
            dispatch=emit_single_dispatch(declarations[0], is_python_method),
            method_footer=method_footer,
        )

    method_footer = ['Py_RETURN_NONE;'] + method_footer

    grouped = group_overloads(declarations, is_python_method)
    is_singleton = len(grouped) == 1

    signatures = []
    dispatch = []
    for position, group in enumerate(grouped):
        signatures.append('"{}",'.format(group['signature']))
        # a lone signature is emitted without a case wrapper
        case_index = None if is_singleton else position
        dispatch.append(emit_dispatch_case(case_index, group, is_python_method))

    template = PY_VARIABLE_METHOD_VARARGS_SINGLETON if is_singleton else PY_VARIABLE_METHOD_VARARGS

    check_has_torch_function = ''
    if module:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace=NATIVE_NAMESPACE_MAPPING[module],
            modulename='"' + module + '"',
        )

    # ParsedArgs has to be large enough for the widest overload
    max_args = max([get_python_argc(decl) for decl in declarations])
    if all(should_trace(decl) for decl in declarations):
        traceable = 'true'
    else:
        traceable = 'false'

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max_args,
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=check_has_torch_function,
        dispatch=dispatch,
        method_footer=method_footer,
    )
| |
| |
| # |
| # forward declarations |
| # |
| |
# Forward declaration for a binding that takes (self_, args, kwargs);
# forward_decls() picks it when the op is not a noarg binding.
PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""")

# Forward declaration for a binding that takes only (self_, args);
# forward_decls() picks it when is_noarg_binding(declarations) holds.
PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args);
""")
| |
| |
def forward_decls(name, declarations, is_python_method, module):
    """
    Return the C forward declaration(s) for an op's binding function.

    Python method bindings get none (an empty list). Otherwise a one-element
    list is returned, built from the noargs or varargs forward declaration
    template depending on is_noarg_binding(declarations). `module` is accepted
    but not used here.
    """
    if is_python_method:
        return []

    template = (
        PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION
        if is_noarg_binding(declarations)
        else PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION
    )
    return [template.substitute(pycname=get_pycname(name))]
| |
| |
| # |
| # method def (binding table entry) codegen |
| # |
| |
# Python binary operator dunder methods
# (comparison, arithmetic, bitwise and shift operators, including their
# reflected `__r*__` and in-place `__i*__` forms).
# method_def() looks names up in this list to decide whether the PyMethodDef
# entry is generated from PY_VARIABLE_METHOD_BINOP_DEF instead of the plain
# PY_VARIABLE_METHOD_DEF.
BINARY_OP_NAMES = [
    '__lt__', '__le__',
    '__gt__', '__ge__',
    '__eq__', '__ne__',

    '__add__', '__radd__', '__iadd__',
    '__sub__', '__rsub__', '__isub__',
    '__mul__', '__rmul__', '__imul__',
    '__matmul__', '__rmatmul__', '__imatmul__',
    '__truediv__', '__rtruediv__', '__itruediv__',
    '__floordiv__', '__rfloordiv__', '__ifloordiv__',
    '__mod__', '__rmod__', '__imod__',
    '__divmod__', '__rdivmod__', '__idivmod__',
    '__pow__', '__rpow__', '__ipow__',
    '__lshift__', '__rlshift__', '__ilshift__',
    '__rshift__', '__rrshift__', '__irshift__',
    '__and__', '__rand__', '__iand__',
    '__xor__', '__rxor__', '__ixor__',
    '__or__', '__ror__', '__ior__',
]
| |
# PyMethodDef entry for binary op, throws not implemented error
# (the binding function is wrapped in TypeError_to_NotImplemented_<...>).
# Placeholders ${name}, ${pycfunc_voidcast}, ${pycname} and ${flags} are
# supplied by method_def().
PY_VARIABLE_METHOD_BINOP_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycfunc_voidcast}TypeError_to_NotImplemented_<${pycname}>, ${flags}, NULL},""")

# PyMethodDef entry
# Same placeholders as the binary op variant, without the wrapper.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycfunc_voidcast}${pycname}, ${flags}, NULL},""")
| |
| |
def method_def(name, declarations, is_python_method, module):
    """
    Generate method def entry.

    Produces the PyMethodDef table line for the op `name`. A noarg binding is
    emitted without the `(void(*)(void))` cast, and is flagged METH_NOARGS
    only when it is also a Python method; every other combination is flagged
    `METH_VARARGS | METH_KEYWORDS`. Names listed in BINARY_OP_NAMES use the
    binary-op template. `module` is accepted but not used here.
    """
    pycname = get_pycname(name)

    takes_no_args = is_noarg_binding(declarations)
    pycfunc_voidcast = '' if takes_no_args else '(void(*)(void))'
    if takes_no_args and is_python_method:
        flags = 'METH_NOARGS'
    else:
        flags = 'METH_VARARGS | METH_KEYWORDS'

    def_template = (
        PY_VARIABLE_METHOD_BINOP_DEF
        if name in BINARY_OP_NAMES
        else PY_VARIABLE_METHOD_DEF
    )
    return def_template.substitute(
        name=name,
        pycname=pycname,
        pycfunc_voidcast=pycfunc_voidcast,
        flags=flags,
    )
| |
| # |
| # overload sorting and grouping |
| # |
| |
def group_overloads(declarations, is_python_method):
    """Returns a list of dictionaries containing the optional keys:

       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out)
       "signature": the signature used for Python argument parsing

    Declarations whose python signatures are identical once output arguments
    are ignored are merged into one dictionary, so a single entry in the
    python_arg_parser signature list serves both variants (the output
    arguments become optional).

    Raises:
        RuntimeError: if a group ends up with no "base" declaration.
    """
    by_signature = defaultdict(dict)

    # bucket on the out-less signature so base and out variants meet
    for decl in declarations:
        key = get_python_signature(decl, is_python_method, skip_outputs=True)
        entry = by_signature[key]
        if not decl['name'].endswith('_out'):
            entry['base'] = decl
            # keep an out-variant signature if one was already recorded
            entry.setdefault('signature', key)
            continue
        entry['out'] = decl
        # prefer the signature with optional out=... arguments
        entry['signature'] = get_python_signature(decl, is_python_method)

    # visit groups in sorted signature order for a deterministic result
    ordered = []
    for key in sorted(by_signature):
        entry = by_signature[key]
        if 'base' not in entry:
            raise RuntimeError(
                "'base' not in dictionary for {}. keys are {}".format(
                    key, list(entry.keys())))
        ordered.append(entry)
    return sort_declarations(ordered)
| |
| |
| # This function declares a partial order on declarations, and sorts them according |
| # to its linear extension. This is necessary, because there's some ambiguity in the |
| # choice of overload, and we want a different order. |
| # |
| # See Note[Order of overloads matters] |
def sort_declarations(grouped_decls):
    """
    Order grouped declarations by a linear extension of a partial order.

    Group d1 is smaller than d2 when their 'base' declarations have the same
    number of arguments, every argument pair either has the same (normalized)
    dynamic type or pairs a Scalar in d1 with a Tensor in d2, and at least
    one pair is such a Scalar/Tensor pair. Larger groups are placed before
    the groups that are smaller than them. If no two groups are related the
    input list object itself is returned.
    """

    # TODO: This is a hack!
    #
    # For some reason, when you specify a Scalar argument in a native
    # function, you get a Declarations.yaml entry that looks like this:
    #
    #   - default: 1
    #     dynamic_type: Scalar
    #     is_nullable: false
    #     kwarg_only: true
    #     name: alpha
    #     type: Scalar
    #
    # This is in contrast to when there is a 'real' argument in TH
    # Declarations.cwrap; this gets (correctly?) translated into
    # dynamic_type: real, and type: Scalar.  I would like to fix this
    # at the source but I have never understood what dynamic_type is
    # supposed to be.
    def normalized_dynamic_type(arg):
        dynamic_type = arg['dynamic_type']
        return 'Scalar' if dynamic_type == 'real' else dynamic_type

    def is_coord_smaller(arg1, arg2):
        return normalized_dynamic_type(arg1) == 'Scalar' and arg2['dynamic_type'] == 'Tensor'

    def is_smaller(d1, d2):
        """Returns True if d1 < d2 in the partial order."""
        args1, args2 = d1['base']['arguments'], d2['base']['arguments']
        if len(args1) != len(args2):
            return False
        found_strictly_smaller = False
        for arg1, arg2 in zip(args1, args2):
            if is_coord_smaller(arg1, arg2):
                found_strictly_smaller = True
            elif normalized_dynamic_type(arg1) != normalized_dynamic_type(arg2):
                # neither equal nor smaller in this coordinate: unrelated
                return False
        return found_strictly_smaller

    # Construct the relation graph: index -> indices of strictly larger groups
    dominators = {}
    for low_index, low in enumerate(grouped_decls):
        larger = {high_index
                  for high_index, high in enumerate(grouped_decls)
                  if is_smaller(low, high)}
        if larger:
            dominators[low_index] = larger

    if not dominators:
        return grouped_decls

    # Topological sort. `emitted` doubles as the work queue: it starts with
    # every group that has nothing larger than it, and a group is appended
    # once all of the groups larger than it have been processed.
    emitted = [index for index in range(len(grouped_decls)) if index not in dominators]
    cursor = 0
    while cursor < len(emitted):
        processed = emitted[cursor]
        cursor += 1
        # iterate over a sorted snapshot, since entries are deleted below
        for waiting in sorted(dominators):
            still_larger = dominators[waiting]
            still_larger.discard(processed)
            if not still_larger:
                del dominators[waiting]
                emitted.append(waiting)

    return [grouped_decls[index] for index in emitted]
| |
| |
| # |
| # python signature codegen |
| # |
| |
# Declaration defaults that are rewritten to 'None' in the python schema
# string; any default not listed here is passed through unchanged.
SCHEMA_DEFAULT_CONVERSION_HACKS = {
    'nullptr': 'None',
    'c10::nullopt': 'None',
    '{}': 'None',
}


def get_schema_formal(arg, is_python_method):
    """
    Render one python arg as a formal for the arg parser schema string.

    The result is `<type> <name>` or `<type> <name>=<default>`, where the
    type is arg['simple_type'] with an optional `?` (nullable) and `[size]`
    suffix, and the default is mapped through SCHEMA_DEFAULT_CONVERSION_HACKS.
    """
    name = arg['name']
    typename = arg['simple_type']

    # TODO: remove this and make optional types in simple_type to be consistent across
    # tensor and other types after make Tensor? be optional instead of undefined
    if arg.get('is_nullable') and '?' not in typename:
        typename += '?'

    # s/self/input/ outside method bindings.
    # TODO remove this? doesn't rename in codegen, it's just for the parse string
    if not is_python_method and name == 'self' and typename == 'Tensor':
        name = 'input'

    size = arg.get('size')
    if size is not None:
        typename = '{}[{}]'.format(typename, size)

    formal = '{} {}'.format(typename, name)

    default = arg.get('default')
    if default is None:
        return formal
    converted = SCHEMA_DEFAULT_CONVERSION_HACKS.get(default, default)
    return '{}={}'.format(formal, converted)
| |
| |
# Shape of one python arg parser signature string:
# `name(formal, formal, ...)` followed by an optional suffix.
# get_python_signature() fills ${deprecated} with either '|deprecated' or ''.
PYTHON_ARG_PARSER_SCHEMA = CodeTemplate("""\
${name}(${schema_formals})${deprecated}""")
| |
| |
def get_python_signature(declaration, is_python_method, skip_outputs=False):
    """
    Build the signature string handed to the Python argument parser.

    Reads declaration['python_arglists'], so make_python_arglists() must have
    been applied to the declaration beforehand. With skip_outputs=True the
    output args are left out of the signature.
    """
    # Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # tools/gen_pyi.py.  If you change any logic here, please
    # check that file too.

    python_args = get_python_args(declaration)
    if skip_outputs:
        python_args = [arg for arg in python_args if not is_output(arg)]

    schema_formals = [get_schema_formal(arg, is_python_method) for arg in python_args]

    # everything past the positional inputs is keyword-only: mark the
    # boundary with a bare '*', but only if something follows it
    positional_argc = len(declaration['python_arglists']['input_args'])
    if positional_argc < len(python_args):
        schema_formals[positional_argc:positional_argc] = ['*']

    # Python function signature.
    # This is the string that we give to FunctionParameter, which is
    # then parsed into the actual structure which we do parsing with.
    name = op_name(declaration)
    if declaration.get('deprecated', False):
        deprecated = '|deprecated'
    else:
        deprecated = ''
    return PYTHON_ARG_PARSER_SCHEMA.substitute(
        name=name,
        schema_formals=schema_formals,
        deprecated=deprecated,
    )
| |
| # |
| # op args to python parsed args transform |
| # |
| |
def get_python_args(decl):
    """
    Flatten decl['python_arglists'] into one new list, in python signature
    order: positional inputs, keyword inputs, outputs, python binding args.
    """
    arglists = decl['python_arglists']
    flattened = []
    for category in ('input_args', 'input_kwargs', 'output_args', 'python_binding_args'):
        flattened.extend(arglists[category])
    return flattened
| |
| |
def get_python_argc(decl):
    """Return the total number of args across all of decl['python_arglists']."""
    argc = 0
    for arglist in decl['python_arglists'].values():
        argc += len(arglist)
    return argc
| |
| |
def get_python_output_index(decl):
    """
    Return the combined count of positional and keyword input args, i.e. the
    position right after the inputs in the flattened list produced by
    get_python_args().
    """
    arglists = decl['python_arglists']
    input_count = len(arglists['input_args'])
    kwarg_count = len(arglists['input_kwargs'])
    return input_count + kwarg_count
| |
| |
def make_python_arglists(declaration, is_python_method):
    """
    Convert declaration['arguments'] into python-ready args, split by category.

    Returns a dict with the keys 'input_args', 'input_kwargs', 'output_args'
    and 'python_binding_args'. The sublists are ordered, so the final python
    arglist can be recovered by simple flattening (see get_python_args()).

    Raises:
        RuntimeError: if there are several output args and one of them does
            not have simple_type 'Tensor'.
    """
    # partition args into sublists: outputs go to their own list; inputs are
    # positional until the first kwarg_only arg, keyword from there on

    positional = []
    keyword = []
    outputs = []

    seen_kwarg_only = False
    for arg in declaration['arguments']:
        if is_output(arg):
            outputs.append(arg)
            continue
        if arg.get('kwarg_only', False):
            seen_kwarg_only = True
        bucket = keyword if seen_kwarg_only else positional
        bucket.append(arg)

    # adjustments

    # positional inputs:
    # - filter self when we're generating a method binding - there, it comes
    #   in as a separate Python param, not in args array
    if is_python_method:
        positional = [arg for arg in positional if not is_tensor_self(arg)]

    # keyword inputs:
    # - filter options. after loading the yaml, an upstream step has gathered dtype,
    #   layout et al into a single tensor options arg. the scattered originals are
    #   re-created further down as python binding args
    keyword = [arg for arg in keyword if not is_tensor_options(arg)]

    # outputs:
    # - coalesce multiple output args into a single 'out' arg w/type TensorList.
    # - force a default. This is so we can use this sig for both out and non-out variants
    num_outputs = len(outputs)
    if num_outputs == 1:
        # copy so the declaration's own arg dict is left untouched
        lone_output = outputs[0].copy()
        lone_output['default'] = 'None'
        outputs = [lone_output]
    elif num_outputs > 1:
        for arg in outputs:
            if arg['simple_type'] != 'Tensor':
                raise RuntimeError(
                    '{}: unsupported output argument type {}'.
                    format(declaration['name'], arg['type']))
        typename = 'TensorList'
        outputs = [{
            'default': 'None',
            'kwarg_only': True,
            'name': 'out',
            'output': True,
            'scatter_args': outputs,
            'simple_type': typename,
            'size': num_outputs,
            'type': typename,
        }]

    # make python binding args
    # these are the (re)scattered versions of the options arg omitted above.
    # TODO because these aren't guaranteed to be 100% faithful to the original
    # versions in the yaml, this recreation is a potential source of drift between
    # eager and JIT. Pull this logic out to a shared place.
    binding_args = make_python_binding_args(declaration)

    return {
        'input_args': positional,
        'input_kwargs': keyword,
        'output_args': outputs,
        'python_binding_args': binding_args,
    }
| |
| # |
| # python binding args |
| # |
| |
| # TODO blowtorch |
def dtype_default_type_hack(name):
    """
    Return the schema default for the scattered `dtype` arg of op `name`:
    'torch.int64' for names starting with 'randperm' and for tril_indices /
    triu_indices, the string 'None' for everything else.
    """
    defaults_to_int64 = (
        name.startswith('randperm')
        or name in ('tril_indices', 'triu_indices')
    )
    return 'torch.int64' if defaults_to_int64 else 'None'
| |
| |
def make_python_binding_args(declaration):
    """
    Given various properties of a declaration, build a set of scattered python binding args.

    Depending on how the declaration is classified (factory / like / new
    function, presence of a TensorOptions arg, tensor return), the result
    holds some of: dtype, layout, device, pin_memory, requires_grad - in that
    order, all keyword-only.

    Raises:
        ValueError: if a non-output argument is itself named 'requires_grad'.
    """
    name = declaration['name']
    tensor_types = ('Tensor', 'TensorList')

    # scan the non-output arguments
    has_tensor_input_arg = False
    has_options_arg = False
    for arg in declaration['arguments']:
        if is_output(arg):
            continue
        simple_type = arg['simple_type']
        if simple_type == 'TensorOptions':
            has_options_arg = True
        elif simple_type in tensor_types:
            has_tensor_input_arg = True
        if arg['name'] == 'requires_grad':
            raise ValueError("argument named requires_grad not supported")

    # this probably won't work if one of the returns is not a tensor, but it will
    # produce a compile-time error that is obvious
    has_tensor_return = any(
        [ret['dynamic_type'] in tensor_types for ret in declaration['returns']])

    # classify the declaration
    category_override = declaration['category_override']
    is_like_function = category_override == 'like' or name.endswith('_like')
    is_new_function = category_override == 'new' or name.startswith('new_')
    is_factory_function = (
        (has_tensor_return and not has_tensor_input_arg)
        or category_override == 'factory'
    )
    is_like_or_new_function_with_options = (
        has_options_arg and (is_like_function or is_new_function)
    )
    takes_requires_grad = has_tensor_return and (
        is_factory_function or is_like_function or is_new_function
    )

    binding_args = []

    if is_factory_function or has_options_arg:
        binding_args.append({
            'default': dtype_default_type_hack(name),
            'dynamic_type': 'ScalarType',
            'kwarg_only': True,
            'name': 'dtype',
            'type': 'const ScalarType &',
            'simple_type': 'ScalarType',
            'python_default_init': (
                'self.scalar_type()' if is_like_or_new_function_with_options else None
            ),
        })

    if is_factory_function or is_like_or_new_function_with_options:
        # like/new functions with options take their defaults from `self`
        if is_like_or_new_function_with_options:
            py_default_layout = 'layout_from_backend(self.options().backend())'
            py_default_device = 'self.device()'
        else:
            py_default_layout = None
            py_default_device = None

        binding_args.append({
            'default': 'torch.strided',
            'dynamic_type': 'Layout',
            'kwarg_only': True,
            'name': 'layout',
            'type': 'c10::optional<Layout>',
            'simple_type': 'Layout',
            'python_default_init': py_default_layout,
        })
        binding_args.append({
            'default': 'None',
            'dynamic_type': 'Device',
            'kwarg_only': True,
            'name': 'device',
            'type': 'const Device &',
            'simple_type': 'Device',
            'python_default_init': py_default_device,
        })
        binding_args.append({
            'default': False,
            'dynamic_type': 'bool',
            'kwarg_only': True,
            'name': 'pin_memory',
            'type': 'bool',
            'simple_type': 'bool',
        })

    if takes_requires_grad:
        binding_args.append({
            'default': False,
            'dynamic_type': 'bool',
            'kwarg_only': True,
            'name': 'requires_grad',
            'type': 'bool',
            'simple_type': 'bool',
        })

    return binding_args
| |
| # |
| # declaration derived props, utils, etc. |
| # declarations are dicts loaded from Declarations.yaml, |
| # passed to our codegen methods by callers in gen_autograd |
| # |
| |
def is_tensor_self(arg):
    """Return True if `arg` is named 'self' and has simple_type 'Tensor'."""
    if arg['name'] != 'self':
        return False
    return arg['simple_type'] == 'Tensor'
| |
| |
def is_tensor_options(arg):
    """Return True if `arg` has simple_type 'TensorOptions'."""
    simple_type = arg['simple_type']
    return simple_type == 'TensorOptions'
| |
| |
def is_scatter(arg):
    """Return True if `arg` carries a non-None 'scatter_args' entry."""
    scatter_args = arg.get('scatter_args')
    return scatter_args is not None
| |
def is_output(arg):
    """Return the arg's 'output' entry, or False when it has none."""
    if 'output' in arg:
        return arg['output']
    return False
| |
| |
def has_outputs(declaration):
    """Return True if any of the declaration's arguments is an output arg."""
    output_flags = [is_output(arg) for arg in declaration['arguments']]
    return any(output_flags)
| |
| |
def get_tensor_options(declaration):
    """
    Return the declaration's single tensor options argument, or None if it
    has none.

    Raises:
        RuntimeError: if more than one argument is a tensor options arg.
    """
    options_args = [arg for arg in declaration['arguments'] if is_tensor_options(arg)]
    if not options_args:
        return None
    if len(options_args) > 1:
        raise RuntimeError(
            '{}: multiple tensor options arguments'.
            format(declaration['name']))
    return options_args[0]
| |
| |
def has_tensor_options(declaration):
    """Return True if get_tensor_options() finds a tensor options argument."""
    options_arg = get_tensor_options(declaration)
    return options_arg is not None
| |
| |
def is_torch_function(declaration):
    """Return True if 'namespace' appears in the declaration's 'method_of'."""
    method_of = declaration['method_of']
    return 'namespace' in method_of
| |
| |
def is_nn_module_function(declaration):
    """Return True if the declaration's 'python_module' entry is 'nn'."""
    python_module = declaration.get('python_module')
    return python_module == 'nn'
| |
| |
def function_namespace(declaration):
    """
    Return 'torch' for declarations that have a tensor options arg or whose
    op name ends with '_like', and 'at' for all others.
    """
    # TODO look into why these can't all be 'torch' calls
    uses_torch = has_tensor_options(declaration) or op_name(declaration).endswith('_like')
    return 'torch' if uses_torch else 'at'
| |
| |
def op_name(declaration):
    """
    Return the op's name with any '_out' suffix removed.

    Raises:
        RuntimeError: if the declaration has output params but its name does
            not end with '_out', or its name ends with '_out' but it has no
            output params.
    """
    name = declaration['name']
    named_as_out = name.endswith("_out")

    if has_outputs(declaration):
        if named_as_out:
            # strip the trailing '_out'
            return name[:-4]
        raise RuntimeError(
            '{} has output params, expecting name ending with \'_out\''.
            format(declaration['name']))

    if named_as_out:
        raise RuntimeError(
            '{}: name ends with \'_out\', expecting output params'.
            format(declaration['name']))
    return name