# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two pairs of files:
#  Functions.h/cpp: subclasses of autograd::Node
#  python_functions.h/cpp: Python bindings for the above classes
#
from .gen_inplace_or_view_type import VIEW_FUNCTIONS

from typing import List, Sequence, Tuple

from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    SavedAttribute,
    uses_retain_variables,
    uses_single_grad,
)
from torchgen.api.types import (
    Binding,
    BaseCType,
    OptionalCType,
    tensorT,
    longT,
    doubleT,
    scalarT,
    stringT,
    boolT,
    intArrayRefT,
    tensorListT,
    MutRefCType,
    ListCType,
    ArrayRefCType,
    optionalIntArrayRefT,
)
from torchgen.code_template import CodeTemplate
from torchgen.utils import FileManager
from torchgen.model import Argument

FUNCTION_DECLARATION = CodeTemplate(
    """\
struct TORCH_API ${op} : public ${superclass} {
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${thread_lock}
    ${release_variables}
  }
  ${will_release_variables}
  ${saved_variables}
  ${saved_list_sizes}
};
"""
)

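# Illustratively (not verbatim output, and abridged), for an op such as
# MulBackward0 this template expands to roughly:
#
#   struct TORCH_API MulBackward0 : public TraceableFunction {
#     using TraceableFunction::TraceableFunction;
#     variable_list apply(variable_list&& grads) override;
#     std::string name() const override { return "MulBackward0"; }
#     void release_variables() override {
#       std::lock_guard<std::mutex> lock(mutex_);
#       self_.reset_data();
#       other_.reset_data();
#     }
#     SavedVariable self_;
#     SavedVariable other_;
#   };
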
WILL_RELEASE_VARIABLES = CodeTemplate(
    """\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
"""
)

FUNCTION_DEFINITION = CodeTemplate(
    """\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
"""
)

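# Roughly (illustrative, abridged), the generated apply() for MulBackward0
# looks like:
#
#   variable_list MulBackward0::apply(variable_list&& grads) {
#     std::lock_guard<std::mutex> lock(mutex_);
#     IndexRangeGenerator gen;
#     auto self_ix = gen.range(1);
#     auto other_ix = gen.range(1);
#     variable_list grad_inputs(gen.size());
#     const auto& grad = grads[0];
#     auto self = self_.unpack();
#     auto other = other_.unpack();
#     bool any_grad_defined = any_variable_defined(grads);
#     if (should_compute_output({ self_ix })) {
#       auto grad_result = ...;  // derivative formula from derivatives.yaml
#       copy_range(grad_inputs, self_ix, grad_result);
#     }
#     // ...same for other_ix...
#     return grad_inputs;
#   }
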
GRAD_INPUT_MASK = CodeTemplate(
    """\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
"""
)

DERIVATIVE_SINGLE = CodeTemplate(
    """\
if (should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
  if (should_compute_output({ ${name}_ix })) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
"""
)

DERIVATIVE_MULTI = CodeTemplate(
    """\
if (should_compute_output({ ${idx_ranges} })) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
"""
)

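# For a formula that computes several gradients at once, DERIVATIVE_MULTI and
# GRAD_INPUT_MASK compose to roughly the following (illustrative; the argument
# names are hypothetical):
#
#   if (should_compute_output({ input_ix, weight_ix, bias_ix })) {
#     auto grad_input_mask = std::array<bool, 3>{
#       should_compute_output({ input_ix }),
#       should_compute_output({ weight_ix }),
#       should_compute_output({ bias_ix }),
#     };
#     auto grad_result = <formula using grad_input_mask>;
#     if (should_compute_output({ input_ix })) {
#       copy_range(grad_inputs, input_ix, std::get<0>(grad_result));
#     }
#     // ...same for weight_ix (get<1>) and bias_ix (get<2>)
#   }
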
# Generates Python bindings
#
# This generates the definitions for:
#   (1) The PyTypeObject for each backward grad_fn subclassing Node
#   (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs)
#       We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs
#       Each PyGetSetDef has a function ptr to a getter, also defined here (3).
#   (3) Getters for each of grad_fn's saved inputs and outputs.
#
PY_FUNCTION_DEFINITION = CodeTemplate(
    """\
static PyTypeObject ${op}Class;
addClass<${op}>(${op}Class, "${op}", ${op}_properties);
"""
)

PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
    """\
${all_getter_definitions}

static struct PyGetSetDef ${op}_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  ${all_getsetdef_structs}
  {nullptr} /* sentinel */
};

"""
)

PY_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
)

PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
)

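# From Python, these properties surface as attributes on a tensor's grad_fn;
# e.g. (illustrative):
#
#   >>> x = torch.randn(2, requires_grad=True)
#   >>> y = x.exp()
#   >>> y.grad_fn._saved_result        # unpacked SavedVariable (a Tensor)
#   >>> y.grad_fn._raw_saved_result    # the raw SavedVariable wrapper
#
# Accessing a _saved_* attribute after the saved buffers have been released
# (e.g. after backward() without retain_graph=True) raises the
# ERR_BACKWARD_TWICE error.
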
# Getter templates
GETTER_DEFINITION = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.list.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.list.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

# Getter body
GETTER_BODY_SAVEDVAR = """\
return THPVariable_Wrap(prop.unpack(self->cdata));
"""

GETTER_BODY_RAW_SAVEDVAR = """\
pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference);
return obj.release().ptr();
"""

GETTER_BODY_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
}
return tup;
"""

GETTER_BODY_RAW_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference);
  PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr());
}
return tup;
"""

GETTER_BODY_ARRAYREF_LONG = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i]));
}
return tup;
"""

GETTER_BODY_ARRAYREF_DOUBLE = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
}
return tup;
"""

GETTER_BODY_INT64_T = """\
return PyLong_FromUnsignedLong((int64_t) prop);
"""

GETTER_BODY_DOUBLE = """\
return PyFloat_FromDouble((double) prop);
"""

GETTER_BODY_BOOL = """\
if (prop) {
  Py_RETURN_TRUE;
} else {
  Py_RETURN_FALSE;
}
"""

GETTER_BODY_STRING = """\
return PyUnicode_FromStringAndSize(prop.data(), prop.size());
"""

GETTER_BODY_SCALAR = """\
if (prop.isComplex()) {
  auto cprop = prop.to<c10::complex<double>>();
  return PyComplex_FromDoubles(cprop.real(), cprop.imag());
} else if (prop.isFloatingPoint()) {
  return PyFloat_FromDouble(prop.to<double>());
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

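# Substituting one of these (template, body) pairs, e.g. for a hypothetical
# double attribute "alpha" saved on a hypothetical FooBackward0, yields:
#
#   PyObject* THPFooBackward0_alpha_getter(THPCppFunction *self, void *_unused) {
#     HANDLE_TH_ERRORS
#     auto prop = static_cast<FooBackward0*>(self->cdata.get())->alpha;
#     return PyFloat_FromDouble((double) prop);
#     END_HANDLE_TH_ERRORS
#   }
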
# These functions have backward formulas that cannot be traced, and so their
# backward functions must be traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backward, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """

    # only create an autograd function if we are actually going to calculate a derivative
    infos = list(
        filter(lambda info: info.args_with_derivatives, differentiability_infos)
    )
    declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos))
    definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos))

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                # "@" + "generated" is split so that this source file itself
                # doesn't carry the @generated marker that tooling looks for.
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )

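# A minimal sketch of how this entry point is driven (the paths are
# illustrative, not authoritative; in PyTorch the caller is the autograd
# codegen driver, tools/autograd/gen_autograd.py):
#
#   gen_autograd_functions_lib(
#       out="torch/csrc/autograd/generated",   # hypothetical output dir
#       differentiability_infos=infos,         # parsed from derivatives.yaml
#       template_path="templates",             # hypothetical template dir
#   )
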
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}();" for i in range(num_shards)
            ],
        },
    )

    infos = list(
        filter(lambda info: info.args_with_derivatives, differentiability_infos)
    )
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )

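# Sharding note: with num_shards = 5, write_sharded is expected to split the
# generated bindings across per-shard python_functions .cpp files, each
# defining one of the initialize_autogenerated_functions_{i}() entry points
# that the header above forward-declares; this keeps the size and compile
# time of each translation unit manageable.
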
def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
    saved_variables: List[str] = []
    release_variables: List[str] = []
    saved_list_sizes: List[str] = []
    unpack: List[str] = []
    asserts: List[str] = []
    compute_index_ranges: List[str] = []
    getter_definitions: List[str] = []
    py_getsetdef_structs: List[str] = []

    for arg in info.args_with_derivatives:
        if (
            arg.type == "at::TensorList"
            or arg.type == "const c10::List<c10::optional<at::Tensor>> &"
        ):
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")

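    # For reference, the loop above emits lines such as (illustrative):
    #   auto self_ix = gen.range(1);                 // plain Tensor argument
    #   auto tensors_ix = gen.range(tensors_size_);  // TensorList argument
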
    def save_var(var: SavedAttribute, is_output: bool) -> None:
        name = var.nctype.name
        type = var.nctype.type
        should_append_getsetdef = True
        should_append_raw_getsetdef = False

        if (
            type == BaseCType(tensorT)
            or type == OptionalCType(BaseCType(tensorT))
            or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
            or (type == BaseCType(scalarT) and is_output)
        ):
            saved_variables.append(f"SavedVariable {name}_;")
            release_variables.append(f"{name}_.reset_data();")
            ptr = "shared_from_this()" if is_output else ""
            unpack.append(f"auto {name} = {name}_.unpack({ptr});")
            getter_definitions.append(
                GETTER_DEFINITION_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
        elif type == BaseCType(tensorListT):
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # clear() alone is sufficient; there is no need to loop and reset
            # each variable: every SavedVariable owns its tensor and grad_fn,
            # so destroying the SavedVariable releases them as well.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = unpack_list({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # clear() alone is sufficient here as well; see the comment above.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = unpack_opt_list({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(optionalIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == OptionalCType(BaseCType(intArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f"c10::OptionalArray<double> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
                )
            )
        elif type == BaseCType(longT):
            saved_variables.append(f"{type.cpp_type()} {name} = 0;")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_INT64_T
                )
            )
        elif type == BaseCType(stringT):
            saved_variables.append(f"std::string {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f"c10::optional<std::string> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        else:
            saved_variables.append(f"{type.cpp_type()} {name};")

            if type in MISC_GETTER_DEFS:
                getter_def, body = MISC_GETTER_DEFS[type]
                getter_definitions.append(
                    getter_def.substitute(op=info.op, name=name, body=body)
                )
            else:
                # Types we don't expose python bindings to yet:
                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                should_append_getsetdef = False

        if should_append_getsetdef:
            py_getsetdef_structs.append(
                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )
        if should_append_raw_getsetdef:
            py_getsetdef_structs.append(
                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )

    for var in info.all_saved_inputs:
        save_var(var, is_output=False)
    for var in info.all_saved_outputs:
        save_var(var, is_output=True)

    # lock the mutex when we release variables and in Node::apply to protect thread safety
    # see Note [Thread Safety on Autograd Node]
    if len(release_variables) > 0:
        thread_lock = "std::lock_guard<std::mutex> lock(mutex_);"
    else:
        thread_lock = ""

    if uses_retain_variables(info):
        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
    else:
        will_release_variables = ""

    body: List[str] = []

    if uses_single_grad(info):
        body.append("const auto& grad = grads[0];")
    else:
        # Generate aliases for gradients named for returned values.
        body.extend(
            f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
            for name in info.used_named_gradients
        )

    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
    ) -> Tuple[bool, str]:
        formula = derivative.formula
        var_names = derivative.var_names
        if len(var_names) == 1:
            checks_any_grad_defined = False
            if "not_implemented" not in formula:
                matching_args = [
                    arg for arg in args_with_derivatives if arg.name == var_names[0]
                ]
                if len(matching_args) == 1:
                    # We can add undefined grad support if the input variable is a Tensor
                    arg = matching_args[0]
                    if isinstance(arg.argument, Argument) and str(
                        arg.argument.type
                    ) in ("Tensor", "Tensor?"):
                        formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                        checks_any_grad_defined = True
            return (
                checks_any_grad_defined,
                DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
            )
        else:
            if "grad_input_mask" in formula:
                masks = [f"should_compute_output({{ {n}_ix }})," for n in var_names]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                )
            else:
                grad_input_mask = ""
            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            copy_ranges: List[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges,
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in info.derivatives:
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
        )
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check if grads are
    # defined, only perform the check once, before all the formulas
    if need_any_grad_defined_var:
        body.insert(
            -len(info.derivatives),
            "bool any_grad_defined = any_variable_defined(grads);",
        )

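    # When the guard is needed, the generated body therefore contains, just
    # before the derivative blocks (illustrative):
    #   bool any_grad_defined = any_variable_defined(grads);
    # and each eligible single-output formula is wrapped as:
    #   auto grad_result = any_grad_defined ? (<formula>) : Tensor();
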
    if info.name in UNTRACEABLE_FUNCTIONS:
        superclass = "Node"
    else:
        superclass = "TraceableFunction"

    all_getsetdef_structs = (
        ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
    )
    all_getter_definitions = "\n".join(getter_definitions)

    return template.substitute(
        op=info.op,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
        saved_list_sizes=saved_list_sizes,
        asserts=asserts,
        thread_lock=thread_lock,
        will_release_variables=will_release_variables,
        body=body,
        superclass=superclass,
        all_getter_definitions=all_getter_definitions,
        all_getsetdef_structs=all_getsetdef_structs,
    )