[mta] Backward of unary foreach functions (#89591)

As per the title, this PR defines the backward of unary foreach functions.

This PR doesn't implement forward-mode automatic differentiation, as [the current codegen](https://github.com/pytorch/pytorch/blob/a747326423ed4731996769e3b8eb73eecbdee2d4/tools/autograd/gen_variable_type.py#L1513) doesn't seem to handle `ArrayRef<Tensor>`.
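
A minimal sketch of what the backward support enables (illustrative; `torch._foreach_exp` is an existing entry point):

```python
import torch

xs = [torch.randn(4, requires_grad=True) for _ in range(3)]
ys = torch._foreach_exp(xs)              # returns a list of new tensors
sum(y.sum() for y in ys).backward()      # backward now flows through the foreach op

# d/dx exp(x) = exp(x), so each grad should match the per-tensor reference
for x in xs:
    assert torch.allclose(x.grad, x.detach().exp())
```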

Related:
- https://github.com/pytorch/pytorch/issues/53796
- https://github.com/pytorch/pytorch/issues/58833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89591
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
index 3b332cc..ec625e1 100644
--- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -7,6 +7,14 @@
 
 namespace {
 
+// TODO(crcrpar): Handle version bump in codegen.
+// rel: https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
+inline void increment_version(TensorList tensors) {
+  for (const auto & t : tensors) {
+    t.unsafeGetTensorImpl()->bump_version();
+  }
+}
+
 // Initializes args and checks if all args are aligned
 template<int depth, typename T>
 __device__ bool init_args(
diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
index 35e0077..693a0d8 100644
--- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
@@ -73,6 +73,7 @@
                                          /* r_args_depth */ 1,
                                          /* res_arg_index */ 0>(),
                           Op<opmath_t>());
+    increment_version(tensors);
 }
 
 template <template<class> class Op>
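
The two hunks above add an `increment_version` helper and call it after the in-place fast-path kernel, so autograd's saved-tensor checks can detect the mutation. A rough way to observe this, using the internal `_version` counter (CUDA is needed for the fast path):

```python
import torch

if torch.cuda.is_available():
    ts = [torch.rand(8, device="cuda") for _ in range(2)]
    before = [t._version for t in ts]
    torch._foreach_exp_(ts)              # in-place foreach op on the CUDA fast path
    after = [t._version for t in ts]
    # with this change, each version counter should have been bumped
    print(before, after)
```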
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 95dfad7..bbec47f 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -597,6 +597,73 @@
     triangular_solve
     vdot
 
+Foreach Operations
+~~~~~~~~~~~~~~~~~~
+
+.. warning::
+    This API is in beta and subject to future changes.
+    Forward-mode AD is not supported.
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    _foreach_abs
+    _foreach_abs_
+    _foreach_acos
+    _foreach_acos_
+    _foreach_asin
+    _foreach_asin_
+    _foreach_atan
+    _foreach_atan_
+    _foreach_ceil
+    _foreach_ceil_
+    _foreach_cos
+    _foreach_cos_
+    _foreach_cosh
+    _foreach_cosh_
+    _foreach_erf
+    _foreach_erf_
+    _foreach_erfc
+    _foreach_erfc_
+    _foreach_exp
+    _foreach_exp_
+    _foreach_expm1
+    _foreach_expm1_
+    _foreach_floor
+    _foreach_floor_
+    _foreach_log
+    _foreach_log_
+    _foreach_log10
+    _foreach_log10_
+    _foreach_log1p
+    _foreach_log1p_
+    _foreach_log2
+    _foreach_log2_
+    _foreach_neg
+    _foreach_neg_
+    _foreach_tan
+    _foreach_tan_
+    _foreach_sin
+    _foreach_sin_
+    _foreach_sinh
+    _foreach_sinh_
+    _foreach_round
+    _foreach_round_
+    _foreach_sqrt
+    _foreach_sqrt_
+    _foreach_lgamma
+    _foreach_lgamma_
+    _foreach_frac
+    _foreach_frac_
+    _foreach_reciprocal
+    _foreach_reciprocal_
+    _foreach_sigmoid
+    _foreach_sigmoid_
+    _foreach_trunc
+    _foreach_trunc_
+    _foreach_zero_
+
 Utilities
 ----------------------------------
 .. autosummary::
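
Each documented `_foreach_*` function applies the corresponding unary op to every tensor in the input list, and the `_foreach_*_` variants do so in place. A small usage sketch:

```python
import torch

xs = [torch.randn(3), torch.randn(5)]
ys = torch._foreach_sin(xs)                      # out-of-place: returns a new list
assert all(torch.allclose(y, x.sin()) for y, x in zip(ys, xs))

torch._foreach_sin_(xs)                          # in-place: mutates each tensor in xs
```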
diff --git a/test/test_foreach.py b/test/test_foreach.py
index ac2a754..130f010 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -222,6 +222,24 @@
             inplace_ref(copied_inputs),
             self.assertEqual(copied_inputs, inputs)
 
+    def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
+        op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
+        inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
+        # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
+        if opinfo.name == "_foreach_abs" and dtype in complex_types():
+            is_fastpath = False
+        self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
+        self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
+
+        if opinfo.supports_autograd and dtype in floating_types():
+            tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath, same_size=True)
+            tensors = [t.requires_grad_() for t in tensors]
+            ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
+
+            sum(op.func(tensors)).mean().backward()
+            sum([ref.func(t) for t in ref_tensors]).mean().backward()
+            self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+
     @skipMeta
     @ops(foreach_unary_op_db)
     @parametrize("is_fastpath", (True, False))
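
Outside the test harness, the gradient-parity check added in `_test_unary` boils down to something like the following sketch (same-size tensors so the outputs can be summed; `sqrt` chosen arbitrarily):

```python
import torch

xs = [torch.rand(5, requires_grad=True) for _ in range(4)]
refs = [x.clone().detach().requires_grad_() for x in xs]

sum(torch._foreach_sqrt(xs)).mean().backward()   # foreach path
sum(r.sqrt() for r in refs).mean().backward()    # per-tensor reference

assert all(torch.allclose(x.grad, r.grad) for x, r in zip(xs, refs))
```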
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 0d56682..9ec2bb3 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -42,7 +42,7 @@
 #     to that argument could exist. You should either:
 #       - Specify the formula for that gradient
 #       - Specify not_implemented("function_name") as a formula to say that this is not
-#         implement yet (but might be in the future and the user can request that on an issue)
+#         implemented yet (but might be in the future and the user can request that on an issue)
 #   - If that argument is not differentiable, because it is not a floating point dtype or the
 #     function is not differentiable with respect to that argument  for
 #     example. You should either:
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index 19e4809..f7b30cf 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -98,6 +98,23 @@
 """
 )
 
+# note(crcrpar): The `self` argument and the other optional positional arguments
+# of foreach functions are basically lists of n `Tensor`s, so we iterate over
+# `grads` in order to apply the existing derivative definitions to each `Tensor`
+# of `self` and of the other arguments.
+DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
+    """\
+if (task_should_compute_output({ ${name}_ix })) {
+  std::vector<Tensor> grad_result;
+  grad_result.reserve(grads.size());
+  for (const auto & i : c10::irange(grads.size())) {
+    grad_result.emplace_back(${derivative});
+  }
+  copy_range(grad_inputs, ${name}_ix, grad_result);
+}
+"""
+)
+
 DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
     """\
   if (task_should_compute_output({ ${name}_ix })) {
@@ -709,9 +726,13 @@
                     ) in ("Tensor", "Tensor?"):
                         formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                         checks_any_grad_defined = True
+            if info.name.startswith("_foreach_"):
+                derivative_template = DERIVATIVE_SINGLE_FOREACH
+            else:
+                derivative_template = DERIVATIVE_SINGLE
             return (
                 checks_any_grad_defined,
-                DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
+                derivative_template.substitute(name=var_names[0], derivative=formula),
             )
         else:
             if "grad_input_mask" in formula:
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index ee0254e..7d0a3c3 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -14003,3 +14003,59 @@
 are freshly created instead of aliasing the input.
 """,
 )
+
+for unary_base_func_name in (
+    "exp",
+    "sqrt",
+    "abs",
+    "acos",
+    "asin",
+    "atan",
+    "ceil",
+    "cos",
+    "cosh",
+    "erf",
+    "erfc",
+    "expm1",
+    "floor",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "neg",
+    "tan",
+    "tanh",
+    "sin",
+    "sinh",
+    "round",
+    "lgamma",
+    "frac",
+    "reciprocal",
+    "sigmoid",
+    "trunc",
+    "zero",
+):
+    unary_foreach_func_name = f"_foreach_{unary_base_func_name}"
+    if hasattr(torch, unary_foreach_func_name):
+        add_docstr(
+            getattr(torch, unary_foreach_func_name),
+            r"""
+{}(self: List[Tensor]) -> List[Tensor]
+
+Apply :func:`torch.{}` to each Tensor of the input list.
+            """.format(
+                unary_foreach_func_name, unary_base_func_name
+            ),
+        )
+    unary_inplace_foreach_func_name = f"{unary_foreach_func_name}_"
+    if hasattr(torch, unary_inplace_foreach_func_name):
+        add_docstr(
+            getattr(torch, unary_inplace_foreach_func_name),
+            r"""
+{}(self: List[Tensor]) -> None
+
+Apply :func:`torch.{}` to each Tensor of the input list.
+        """.format(
+                unary_inplace_foreach_func_name, unary_base_func_name
+            ),
+        )
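
After this loop runs at import time, each foreach function that exists on the `torch` module should carry the generated docstring, e.g.:

```python
import torch

print(torch._foreach_sin.__doc__)    # ...Apply :func:`torch.sin` to each Tensor of the input list.
print(torch._foreach_sin_.__doc__)
```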
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2eb7e8c..45661db 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8074,25 +8074,26 @@
 
 
 foreach_unary_op_db: List[OpInfo] = [
-    ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
+    ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
 
     ForeachFuncInfo(
         'neg',
         dtypes=all_types_and_complex(),
         dtypesIfCUDA=all_types_and_complex(),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8100,6 +8101,7 @@
         dtypes=floating_and_complex_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8107,6 +8109,7 @@
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8114,6 +8117,7 @@
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8121,6 +8125,7 @@
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8128,6 +8133,7 @@
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8135,6 +8141,7 @@
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8142,6 +8149,7 @@
         dtypes=floating_and_complex_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8149,6 +8157,7 @@
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8156,6 +8165,7 @@
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8163,6 +8173,7 @@
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8170,6 +8181,7 @@
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8177,6 +8189,7 @@
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
@@ -8186,6 +8199,7 @@
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 ]
 
diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py
index 36522f1..313c54a 100644
--- a/torch/testing/_internal/opinfo/core.py
+++ b/torch/testing/_internal/opinfo/core.py
@@ -2571,6 +2571,7 @@
         dtypesIfROCM=None,
         supports_alpha_param=False,
         sample_inputs_func=sample_inputs_foreach,
+        supports_autograd=False,
         **kwargs,
     ):
         super().__init__(
@@ -2579,6 +2580,7 @@
             dtypesIfCUDA=dtypesIfCUDA,
             dtypesIfROCM=dtypesIfROCM,
             sample_inputs_func=sample_inputs_func,
+            supports_autograd=supports_autograd,
             **kwargs,
         )
 
diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py
index affce13..bb3998f 100644
--- a/torchgen/api/autograd.py
+++ b/torchgen/api/autograd.py
@@ -1,9 +1,10 @@
+import copy
 import re
 from dataclasses import dataclass
 from typing import Dict, List, Match, Optional, Sequence, Set, Tuple
 
 from torchgen.api import cpp
-from torchgen.api.types import Binding, NamedCType
+from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
 from torchgen.model import (
     FunctionSchema,
     NativeFunction,
@@ -357,6 +358,94 @@
  this is not currently supported (we'd need to fix up the formula in the codegen)."""
             return info_dict, False
 
+        # (4) Generate derivative information for unary foreach functions if none is defined in `derivatives.yaml`
+        base_op_name = f.func.name.name
+        if (
+            base_op_name.base.startswith("_foreach")
+            and not base_op_name.inplace
+            and len(f.func.arguments.post_self_positional) == 0
+        ):
+            ref_native_op_name = base_op_name.base.split("_foreach_")[-1]
+            for function_schema in functional_info_by_signature:
+                if (
+                    function_schema.name.name.base == ref_native_op_name
+                    and not function_schema.name.name.inplace
+                ):
+                    all_saved_inputs = []
+                    all_saved_outputs = []
+                    diff_info_dict = copy.deepcopy(
+                        differentiability_infos[function_schema]
+                    )
+                    diff_info = diff_info_dict["Default"]
+                    modified_derivative_formulas = []
+                    for derivative in diff_info.derivatives:
+                        saved_inputs = []
+                        saved_outputs = []
+                        modified_formula = (
+                            derivative.formula.replace("grad", "grads[i]")
+                            .replace("self", "self[i]")
+                            .replace("result", "result[i]")
+                        )
+                        if "self" in modified_formula:
+                            saved_inputs.append(
+                                SavedAttribute(
+                                    nctype=NamedCType(
+                                        name="self", type=BaseCType(tensorListT)
+                                    ),
+                                    expr="self",
+                                )
+                            )
+                            all_saved_inputs.append(saved_inputs[-1])
+                        if "result" in modified_formula:
+                            saved_outputs.append(
+                                SavedAttribute(
+                                    nctype=NamedCType(
+                                        name="result", type=BaseCType(tensorListT)
+                                    ),
+                                    expr="result",
+                                )
+                            )
+                            all_saved_outputs.append(saved_outputs[-1])
+                        modified_derivative = Derivative(
+                            formula=modified_formula,
+                            original_formula=derivative.original_formula,
+                            var_names=("self",),
+                            saved_inputs=tuple(saved_inputs),
+                            saved_outputs=tuple(saved_outputs),
+                            named_gradients=set(),
+                        )
+                        modified_derivative_formulas.append(modified_derivative)
+                    assert f.func.arguments.self_arg is not None
+                    diff_info = DifferentiabilityInfo(
+                        name=base_op_name.base,
+                        func=f,
+                        op=f"Foreach{diff_info.op}",
+                        derivatives=modified_derivative_formulas,
+                        forward_derivatives=[],
+                        all_saved_inputs=tuple(set(all_saved_inputs)),
+                        all_saved_outputs=tuple(set(all_saved_outputs)),
+                        available_named_gradients=(),
+                        used_named_gradients=set(),
+                        args_with_derivatives=[
+                            Binding(
+                                name="self",
+                                nctype=NamedCType(
+                                    name="self", type=BaseCType(tensorListT)
+                                ),
+                                argument=f.func.arguments.self_arg.argument,
+                                default=None,
+                            )
+                        ],
+                        non_differentiable_arg_names=[],
+                        output_differentiability=None,
+                        output_differentiability_conditions=None,
+                    )
+                    diff_info_dict["Default"] = diff_info
+                    if f.func not in differentiability_infos:
+                        differentiability_infos[f.func] = diff_info_dict
+                        functional_info_by_signature[f.func] = diff_info_dict
+                    return diff_info_dict, True
+
         return None, False
 
     result: List[NativeFunctionWithDifferentiabilityInfo] = []
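
For concreteness, the string rewrite above turns the reference op's formula into the per-element formula used inside the generated loop. A minimal sketch, assuming `exp`'s `derivatives.yaml` entry is `grad * result`:

```python
formula = "grad * result"            # assumed derivatives.yaml formula for exp
modified = (
    formula.replace("grad", "grads[i]")
           .replace("self", "self[i]")
           .replace("result", "result[i]")
)
print(modified)                      # -> grads[i] * result[i]
```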