from __future__ import annotations

import re
from dataclasses import dataclass
from typing import cast, Sequence

from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX


# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # The NamedCType holds the updated name and cpp type of the attribute.
    # For the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`.
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str


# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: set[str]


# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (legit C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Names of the output arguments for which this formula calculates forward
    # derivatives
    var_names: tuple[str, ...]

    # Types of the output arguments for which this formula calculates forward
    # derivatives
    var_types: tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs for which the primal is required for this formula
    required_inputs_primal: tuple[str, ...] | None

    # Flag to specify if this formula requires the original value of self.
    # This is only used by inplace operations.
    required_original_self_value: bool

    # Whether we are re-using the out-of-place formula for the in-place variant,
    # as opposed to a formula explicitly specified in derivatives.yaml.
    is_reusing_outplace_formula: bool


# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunction having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: str | None

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: list[bool] | None

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions
    output_differentiability_conditions: list[str] | None

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
    # but with a new operator name.
    # This is used when generating "copy" variants of view ops,
    # which are able to use the exact same derivative formula as the original view op.
    # See Note [Codegen'd {view}_copy Operators]
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        if g.view_copy is None:
            return None
        f = g.view_copy

        name_split_by_period = self.name.split(".", maxsplit=2)
        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
            name_split_by_period[1:]
        )
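        # For example (illustrative): an entry named "slice.Tensor" becomes
        # "slice_copy.Tensor" here, keeping the overload name intact.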
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op
            name=view_copy_name,
            func=f,
            op=view_copy_op_name,
            # But keep all derivative info the same
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )


def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "retain_variables")


def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "grad")
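# For example (illustrative): since matching is done on whole identifiers via
# IDENT_REGEX, a formula containing "grad * 2" makes uses_single_grad() return True,
# while a formula that only mentions "grads[0]" does not.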


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It represents processed Arguments that are differentiable and is only used
#   in the context of the autograd codegen;
# - It can represent a SelfArgument or a regular Argument, but not a TensorOptionsArgument.
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It represents processed Returns that are differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and are only used in the context of the autograd codegen.
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
    info: dict[str, DifferentiabilityInfo] | None
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration?  There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance).  Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors.  If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)
    """
    # fn is derived as long as any of its per-key differentiability infos
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
    # and ADInplaceOrViewType. We want to generate these functions as long as a
    # derivative is defined for ANY dispatch key.
    if fn.func.is_abstract or (
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
    ):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.

        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer calling the derived
        # type's implementation with unpacked tensors because it is more
        # performant in some cases: any internal calls to other ATen functions
        # won't have the history tracked.

        # If the function has a type dispatched argument (i.e. is a factory),
        # we prefer calling the derived type's implementation both because it is
        # more performant and to ensure factory functions return tensors with _version
        # of 0 (probably not strictly necessary, but nice to have to keep versions simple
        # to understand).

        return "use_derived"
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we'll assume that it is
        # actually implemented out of differentiable functions. (This
        # assumption might not hold, but then you'll see gradcheck fail.)
        return "use_type"


def is_foreach_func(f: NativeFunction) -> bool:
    return f.func.name.name.base.startswith("_foreach_")
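# For example (illustrative): "_foreach_add.List" and "_foreach_zero_" are foreach
# functions by this definition, while "add.Tensor" is not.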


# note(crcrpar): Most foreach functions can reference an out-of-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such a function in `functional_info_by_signature`. There are, however, some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}
# The following do not support the alpha kwarg, which the non-foreach versions support.
_skip_argument_len_check = {
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
}
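# For example (illustrative): `_foreach_add.Scalar(Tensor[] self, Scalar scalar)` has one
# argument fewer than its reference `add.Scalar(Tensor self, Scalar other, Scalar alpha=1)`,
# so the argument-length check below would reject it without this allowlist.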


# Checks if `function_schema` is a native, non-foreach function that `f`, a foreach function,
# can reference to generate derivatives.
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    return (
        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
        and (
            not function_schema.name.name.inplace
            or str(f.func.name) in _foreach_with_inplace_ref
        )
        and (
            str(f.func.name) in _skip_argument_len_check
            or len(f.func.arguments.flat_non_out)
            == len(function_schema.arguments.flat_non_out)
        )
        and all(
            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
            for arg, ref_arg in zip(
                f.func.arguments.flat_non_out,
                function_schema.arguments.flat_non_out,
            )
        )
    )
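# For example (illustrative): `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)` would
# be accepted as a reference for `_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1)`:
# the base names match after stripping "_foreach_", the argument counts agree, and each reference
# argument type is either the corresponding foreach argument's type or its list element type.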


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
    """Generate DifferentiabilityInfo for an out-of-place foreach function; return the existing one for in-place.

    The second return value indicates whether the info is generated in this function.
    """
    ref_diff_info: DifferentiabilityInfo | None = None

    for function_schema, diff_info in functional_info_by_signature.items():
        if not is_reference_for_foreach(foreach_function, function_schema):
            continue
        ref_diff_info = diff_info[dispatch_key]
        if ref_diff_info is not None:
            break
    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
    # while the info of `zero_` is in non_functional_info_by_signature
    if (
        ref_diff_info is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        for function_schema, diff_info in non_functional_info_by_signature.items():
            if not is_reference_for_foreach(foreach_function, function_schema):
                continue
            ref_diff_info = diff_info[dispatch_key]
            if ref_diff_info is not None:
                break
    if ref_diff_info is None:
        return None, False

    # The in-place variant reuses the existing Derivative as-is.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    map_refarg2foreacharg, map_name2arg = {}, {}
    for i, (arg, ref_arg) in enumerate(
        zip(
            foreach_function.func.arguments.flat_non_out,
            function_schema.arguments.flat_non_out,
        )
    ):
        map_refarg2foreacharg[ref_arg.name] = arg.name
        map_name2arg[arg.name] = arg

    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
    modified_derivative_formulas = []
    for i, derivative in enumerate(ref_diff_info.derivatives):
        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
            "result", "result[i]"
        )
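        # For example (illustrative): a reference backward formula such as
        # "grad * result" becomes "grads[i] * result[i]" at this point; the
        # remaining input references are rewritten below using the saved inputs.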
        saved_inputs, saved_outputs = [], []
        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
        with local.parametrize(
            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
        ):
            for ref_input in derivative.saved_inputs:
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                if isinstance(map_name2arg[mapped_name].type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name == "result":
                    saved_outputs.append(
                        SavedAttribute(
                            nctype=NamedCType(
                                name="result", type=BaseCType(tensorListT)
                            ),
                            expr="result",
                        )
                    )
                else:
                    raise RuntimeError(
                        "Foreach derivatives only support saving the output as `result`; "
                        f"got saved output `{ref_output.nctype.name}`"
                    )
        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
        all_var_names.extend(var_names)
        all_saved_inputs.extend(saved_inputs)
        all_saved_outputs.extend(saved_outputs)
        modified_derivative = Derivative(
            formula=modified_formula,
            original_formula=derivative.formula,
            var_names=tuple(var_names),
            saved_inputs=tuple(saved_inputs),
            saved_outputs=tuple(saved_outputs),
            named_gradients=set(),
        )
        modified_derivative_formulas.append(modified_derivative)

    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_function.func.arguments.flat_non_out
            if arg.name in all_var_names
        ]

    forward_derivatives: list[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
        var_types: list[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: list[str] = []
        required_inputs_primal: list[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # Foreach's result is TensorList
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # note(crcrpar): there should exist a cooler way...
            for i, name in enumerate(var_names):
                if name == ref_arg.name:
                    var_names[i] = foreach_arg.name
                    var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(var_names),
                var_types=tuple(var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )


def match_differentiability_info(
    native_functions: list[NativeFunction],
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to the matching autograd function.

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }
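    # For example (illustrative): `mul.Tensor`, `mul_.Tensor`, and `mul.out` reduce to the
    # same signature here, which is what lets find_info() below match an in-place or out=
    # variant to the functional entry's derivative info.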

    def find_info(
        f: NativeFunction,
    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e. mul() for mul_() or mul_out()
        # note(crcrpar): Check whether `f` is a foreach function because in-place foreach
        # functions use the backward defined for the existing native functions instead of
        # their out-of-place counterparts.
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature and not is_foreach_func(f):
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info_dict = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
                for info in info_dict.values()
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
 to be used automatically by its functional variant ("{str(f.func)}").
 This is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
        if is_foreach_func(f):
            assert f.func not in differentiability_infos
            diff_info, is_generated = gen_foreach_derivativeinfo(
                f,
                functional_info_by_signature,
                non_functional_info_by_signature,
            )
            if diff_info is None:
                return None, False
            # TODO(crcrpar): Avoid hard coding "Default" ideally.
            diff_info_dict = {"Default": diff_info}
            if is_generated:
                differentiability_infos[f.func] = diff_info_dict
                functional_info_by_signature[f.func] = diff_info_dict
            return diff_info_dict, is_generated

        return None, False

    result: list[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if not info_dict:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            # For functions that have a single def for out-of-place and inplace (like abs())
            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified is not used:
                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                #      we make sure that the original value of the input that is being modified inplace (self_p) is
                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
                #      trigger a clone of the original input.
                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                #      populated by a cloned version of the original input (either the clone done by the backward AD
                #      logic if self is also used in a backward formula or a special clone that we add).
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                #     simply called self (as it is modified inplace).
                #  4) Update the required primals data in case it used to contain "result" but should now contain
                #     "self".
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
                #     already exists.

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: re.Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
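                # For example (illustrative):
                #   replace_self_with_original_self("self_p.mul(other_p)", "_p")
                #   returns "original_self_p.mul(other_p)".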

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # replace "result" from the formula by "self_p"
                def repl(m: re.Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
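                # For example (illustrative): "result.conj()" becomes "self_p.conj()" here.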

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: str | None = None
                    between_parens: str | None = None
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # We want to...
                        #   Match: self_t.op1(other_p.op2(arg))
                        #   Avoid: self_t.op1(args) + self_t.op2(args)
                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                        directly_do_inplace = (
                            is_single_method_on_self_t and op_name == info.name
                        )

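                    # For example (illustrative): if info.name == "mul" and the formula is
                    # "self_t.mul(other_p)", the branch below rewrites it to
                    # "self_t_raw.defined() ? self_t_raw.mul_(other_p) : self_t.mul(other_p)".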
                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result


def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    return type.is_tensor_like() and (
        info is None or name not in info.non_differentiable_arg_names
    )


def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    f = fn.func
    info = fn.info[key] if fn.info else None
    outputs: list[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}), "
                f"does not match the number of outputs ({len(outputs)})."
            )
        differentiable_outputs: list[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs
    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
    )
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs