[mta] Backward of unary foreach functions (#89591)
As per the title, this PR defines the backward of the unary foreach functions.
It doesn't implement forward-mode automatic differentiation, as [the current codegen](https://github.com/pytorch/pytorch/blob/a747326423ed4731996769e3b8eb73eecbdee2d4/tools/autograd/gen_variable_type.py#L1513) doesn't seem to handle `ArrayRef<Tensor>`.
Related:
- https://github.com/pytorch/pytorch/issues/53796
- https://github.com/pytorch/pytorch/issues/58833
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89591
Approved by: https://github.com/albanD
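
For a sense of what this enables, here is a minimal sketch (assuming a build that includes this change; `_foreach_exp` stands in for any of the unary foreach ops documented below):

```python
import torch

# Backward through an out-of-place unary foreach op.
xs = [torch.randn(3, requires_grad=True) for _ in range(4)]
ys = torch._foreach_exp(xs)              # returns a list of new tensors
sum(y.sum() for y in ys).backward()
# d/dx exp(x) = exp(x), so each grad should match exp of the input.
for x in xs:
    assert torch.allclose(x.grad, x.detach().exp())
```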
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
index 3b332cc..ec625e1 100644
--- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -7,6 +7,14 @@
namespace {
+// TODO(crcrpar): Handle version bump in codegen.
+// rel: https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
+inline void increment_version(TensorList tensors) {
+ for (const auto & t : tensors) {
+ t.unsafeGetTensorImpl()->bump_version();
+ }
+}
+
// Initializes args and checks if all args are aligned
template<int depth, typename T>
__device__ bool init_args(
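
The version bump added above is what lets autograd notice in-place modification through the CUDA fast path. A small Python sketch of the observable effect, using the internal `Tensor._version` counter (illustrative only; the in-place op may dispatch to either the fast or the slow path depending on the inputs):

```python
import torch

# Each tensor's version counter should advance after an in-place foreach op,
# which is how autograd detects mutation of tensors saved for backward.
xs = [torch.randn(3) for _ in range(2)]
before = [t._version for t in xs]
torch._foreach_exp_(xs)                  # in-place foreach op
assert all(t._version > v for t, v in zip(xs, before))
```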
diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
index 35e0077..693a0d8 100644
--- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu
@@ -73,6 +73,7 @@
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>());
+ increment_version(tensors);
}
template <template<class> class Op>
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 95dfad7..bbec47f 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -597,6 +597,73 @@
triangular_solve
vdot
+Foreach Operations
+~~~~~~~~~~~~~~~~~~
+
+.. warning::
+ This API is in beta and subject to future changes.
+ Forward-mode AD is not supported.
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ _foreach_abs
+ _foreach_abs_
+ _foreach_acos
+ _foreach_acos_
+ _foreach_asin
+ _foreach_asin_
+ _foreach_atan
+ _foreach_atan_
+ _foreach_ceil
+ _foreach_ceil_
+ _foreach_cos
+ _foreach_cos_
+ _foreach_cosh
+ _foreach_cosh_
+ _foreach_erf
+ _foreach_erf_
+ _foreach_erfc
+ _foreach_erfc_
+ _foreach_exp
+ _foreach_exp_
+ _foreach_expm1
+ _foreach_expm1_
+ _foreach_floor
+ _foreach_floor_
+ _foreach_log
+ _foreach_log_
+ _foreach_log10
+ _foreach_log10_
+ _foreach_log1p
+ _foreach_log1p_
+ _foreach_log2
+ _foreach_log2_
+ _foreach_neg
+ _foreach_neg_
+ _foreach_tan
+ _foreach_tan_
+ _foreach_sin
+ _foreach_sin_
+ _foreach_sinh
+ _foreach_sinh_
+ _foreach_round
+ _foreach_round_
+ _foreach_sqrt
+ _foreach_sqrt_
+ _foreach_lgamma
+ _foreach_lgamma_
+ _foreach_frac
+ _foreach_frac_
+ _foreach_reciprocal
+ _foreach_reciprocal_
+ _foreach_sigmoid
+ _foreach_sigmoid_
+ _foreach_trunc
+ _foreach_trunc_
+ _foreach_zero_
+
Utilities
----------------------------------
.. autosummary::
diff --git a/test/test_foreach.py b/test/test_foreach.py
index ac2a754..130f010 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -222,6 +222,24 @@
inplace_ref(copied_inputs),
self.assertEqual(copied_inputs, inputs)
+ def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
+ op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
+ inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
+ # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
+ if opinfo.name == "_foreach_abs" and dtype in complex_types():
+ is_fastpath = False
+ self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
+ self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
+
+ if opinfo.supports_autograd and dtype in floating_types():
+ tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath, same_size=True)
+ tensors = [t.requires_grad_() for t in tensors]
+ ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
+
+ sum(op.func(tensors)).mean().backward()
+ sum([ref.func(t) for t in ref_tensors]).mean().backward()
+ self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+
@skipMeta
@ops(foreach_unary_op_db)
@parametrize("is_fastpath", (True, False))
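
The backward check added to `_test_unary` boils down to the following standalone sketch, here comparing `_foreach_sin` against per-tensor `torch.sin` as the reference (same-sized inputs so the `sum(...)` reduction is well defined):

```python
import torch

# Run backward through the foreach op and through the per-tensor reference,
# then compare the resulting gradients.
tensors = [torch.randn(5, requires_grad=True) for _ in range(3)]
ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]

sum(torch._foreach_sin(tensors)).mean().backward()
sum(torch.sin(t) for t in ref_tensors).mean().backward()

for t, rt in zip(tensors, ref_tensors):
    assert torch.allclose(t.grad, rt.grad)
```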
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 0d56682..9ec2bb3 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -42,7 +42,7 @@
# to that argument could exist. You should either:
# - Specify the formula for that gradient
# - Specify not_implemented("function_name") as a formula to say that this is not
-# implement yet (but might be in the future and the user can request that on an issue)
+# implemented yet (but might be in the future and the user can request that on an issue)
# - If that argument is not differentiable, because it is not a floating point dtype or the
# function is not differentiable with respect to that argument for
# example. You should either:
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index 19e4809..f7b30cf 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -98,6 +98,23 @@
"""
)
+# note(crcrpar): The `self` argument and the other optional positional arguments
+# of foreach functions are basically lists of n `Tensor`s, so we iterate over
+# `grads` in order to apply the existing per-`Tensor` derivative definitions
+# to each `Tensor` of `self` and of the other arguments.
+DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
+ """\
+if (task_should_compute_output({ ${name}_ix })) {
+ std::vector<Tensor> grad_result;
+ grad_result.reserve(grads.size());
+ for (const auto & i : c10::irange(grads.size())) {
+ grad_result.emplace_back(${derivative});
+ }
+ copy_range(grad_inputs, ${name}_ix, grad_result);
+}
+"""
+)
+
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
"""\
if (task_should_compute_output({ ${name}_ix })) {
@@ -709,9 +726,13 @@
) in ("Tensor", "Tensor?"):
formula = "any_grad_defined ? (" + formula + ") : Tensor()"
checks_any_grad_defined = True
+ if info.name.startswith("_foreach_"):
+ derivative_template = DERIVATIVE_SINGLE_FOREACH
+ else:
+ derivative_template = DERIVATIVE_SINGLE
return (
checks_any_grad_defined,
- DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
+ derivative_template.substitute(name=var_names[0], derivative=formula),
)
else:
if "grad_input_mask" in formula:
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index ee0254e..7d0a3c3 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -14003,3 +14003,59 @@
are freshly created instead of aliasing the input.
""",
)
+
+for unary_base_func_name in (
+ "exp",
+ "sqrt",
+ "abs",
+ "acos",
+ "asin",
+ "atan",
+ "ceil",
+ "cos",
+ "cosh",
+ "erf",
+ "erfc",
+ "expm1",
+ "floor",
+ "log",
+ "log10",
+ "log1p",
+ "log2",
+ "neg",
+ "tan",
+ "tanh",
+ "sin",
+ "sinh",
+ "round",
+ "lgamma",
+ "frac",
+ "reciprocal",
+ "sigmoid",
+ "trunc",
+ "zero",
+):
+ unary_foreach_func_name = f"_foreach_{unary_base_func_name}"
+ if hasattr(torch, unary_foreach_func_name):
+ add_docstr(
+ getattr(torch, unary_foreach_func_name),
+ r"""
+{}(self: List[Tensor]) -> List[Tensor]
+
+Apply :func:`torch.{}` to each Tensor of the input list.
+ """.format(
+ unary_foreach_func_name, unary_base_func_name
+ ),
+ )
+ unary_inplace_foreach_func_name = f"{unary_foreach_func_name}_"
+ if hasattr(torch, unary_inplace_foreach_func_name):
+ add_docstr(
+ getattr(torch, unary_inplace_foreach_func_name),
+ r"""
+{}(self: List[Tensor]) -> None
+
+Apply :func:`torch.{}` to each Tensor of the input list.
+ """.format(
+ unary_inplace_foreach_func_name, unary_base_func_name
+ ),
+ )
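
Once registered, the docstrings are visible at runtime; a quick sanity check in a build that includes this change:

```python
import torch

# The loop above registers a docstring for every unary foreach op present
# in the build, so it shows up via __doc__ / help().
print(torch._foreach_sqrt.__doc__)
print(torch._foreach_sqrt_.__doc__)

xs = [torch.rand(3) for _ in range(2)]
print(torch._foreach_sqrt(xs))   # out-of-place: returns a new list of tensors
torch._foreach_sqrt_(xs)         # in-place: mutates xs, returns None
```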
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2eb7e8c..45661db 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8074,25 +8074,26 @@
foreach_unary_op_db: List[OpInfo] = [
- ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
- ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
+ ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+ ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo(
'neg',
dtypes=all_types_and_complex(),
dtypesIfCUDA=all_types_and_complex(),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8100,6 +8101,7 @@
dtypes=floating_and_complex_types_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8107,6 +8109,7 @@
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8114,6 +8117,7 @@
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8121,6 +8125,7 @@
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8128,6 +8133,7 @@
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8135,6 +8141,7 @@
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8142,6 +8149,7 @@
dtypes=floating_and_complex_types_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8149,6 +8157,7 @@
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8156,6 +8165,7 @@
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8163,6 +8173,7 @@
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8170,6 +8181,7 @@
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8177,6 +8189,7 @@
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
ForeachFuncInfo(
@@ -8186,6 +8199,7 @@
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+ supports_autograd=True,
),
]
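
The flag set here is what `_test_unary` checks before running its backward comparison; it can be inspected directly from the op database (internal testing utilities, so treat the import path as unstable):

```python
from torch.testing._internal.common_methods_invocations import foreach_unary_op_db

# Print each unary foreach OpInfo's name and its supports_autograd flag.
for opinfo in foreach_unary_op_db:
    print(opinfo.name, opinfo.supports_autograd)
```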
diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py
index 36522f1..313c54a 100644
--- a/torch/testing/_internal/opinfo/core.py
+++ b/torch/testing/_internal/opinfo/core.py
@@ -2571,6 +2571,7 @@
dtypesIfROCM=None,
supports_alpha_param=False,
sample_inputs_func=sample_inputs_foreach,
+ supports_autograd=False,
**kwargs,
):
super().__init__(
@@ -2579,6 +2580,7 @@
dtypesIfCUDA=dtypesIfCUDA,
dtypesIfROCM=dtypesIfROCM,
sample_inputs_func=sample_inputs_func,
+ supports_autograd=supports_autograd,
**kwargs,
)
diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py
index affce13..bb3998f 100644
--- a/torchgen/api/autograd.py
+++ b/torchgen/api/autograd.py
@@ -1,9 +1,10 @@
+import copy
import re
from dataclasses import dataclass
from typing import Dict, List, Match, Optional, Sequence, Set, Tuple
from torchgen.api import cpp
-from torchgen.api.types import Binding, NamedCType
+from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
FunctionSchema,
NativeFunction,
@@ -357,6 +358,94 @@
this is not currently supported (we'd need to fix up the formula in the codegen)."""
return info_dict, False
+ # (4) Generate derivative information of unary foreach functions if none is defined in `derivatives.yaml`
+ base_op_name = f.func.name.name
+ if (
+ base_op_name.base.startswith("_foreach")
+ and not base_op_name.inplace
+ and len(f.func.arguments.post_self_positional) == 0
+ ):
+ ref_native_op_name = base_op_name.base.split("_foreach_")[-1]
+ for function_schema in functional_info_by_signature:
+ if (
+ function_schema.name.name.base == ref_native_op_name
+ and not function_schema.name.name.inplace
+ ):
+ all_saved_inputs = []
+ all_saved_outputs = []
+ diff_info_dict = copy.deepcopy(
+ differentiability_infos[function_schema]
+ )
+ diff_info = diff_info_dict["Default"]
+ modified_derivative_formulas = []
+ for derivative in diff_info.derivatives:
+ saved_inputs = []
+ saved_outputs = []
+ modified_formula = (
+ derivative.formula.replace("grad", "grads[i]")
+ .replace("self", "self[i]")
+ .replace("result", "result[i]")
+ )
+ if "self" in modified_formula:
+ saved_inputs.append(
+ SavedAttribute(
+ nctype=NamedCType(
+ name="self", type=BaseCType(tensorListT)
+ ),
+ expr="self",
+ )
+ )
+ all_saved_inputs.append(saved_inputs[-1])
+ if "result" in modified_formula:
+ saved_outputs.append(
+ SavedAttribute(
+ nctype=NamedCType(
+ name="result", type=BaseCType(tensorListT)
+ ),
+ expr="result",
+ )
+ )
+ all_saved_outputs.append(saved_outputs[-1])
+ modified_derivative = Derivative(
+ formula=modified_formula,
+ original_formula=derivative.original_formula,
+ var_names=("self",),
+ saved_inputs=tuple(saved_inputs),
+ saved_outputs=tuple(saved_outputs),
+ named_gradients=set(),
+ )
+ modified_derivative_formulas.append(modified_derivative)
+ assert f.func.arguments.self_arg is not None
+ diff_info = DifferentiabilityInfo(
+ name=base_op_name.base,
+ func=f,
+ op=f"Foreach{diff_info.op}",
+ derivatives=modified_derivative_formulas,
+ forward_derivatives=[],
+ all_saved_inputs=tuple(set(all_saved_inputs)),
+ all_saved_outputs=tuple(set(all_saved_outputs)),
+ available_named_gradients=(),
+ used_named_gradients=set(),
+ args_with_derivatives=[
+ Binding(
+ name="self",
+ nctype=NamedCType(
+ name="self", type=BaseCType(tensorListT)
+ ),
+ argument=f.func.arguments.self_arg.argument,
+ default=None,
+ )
+ ],
+ non_differentiable_arg_names=[],
+ output_differentiability=None,
+ output_differentiability_conditions=None,
+ )
+ diff_info_dict["Default"] = diff_info
+ if f.func not in differentiability_infos:
+ differentiability_infos[f.func] = diff_info_dict
+ functional_info_by_signature[f.func] = diff_info_dict
+ return diff_info_dict, True
+
return None, False
result: List[NativeFunctionWithDifferentiabilityInfo] = []
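
The heart of the generated derivative info above is a plain string rewrite of the reference op's formula; for example, applied to the `sin` derivative from `derivatives.yaml` at the time of this PR:

```python
# Each occurrence of grad/self/result is redirected to the i-th list element,
# so the per-tensor formula can be evaluated inside the generated loop.
formula = "grad * self.cos().conj()"
modified = (
    formula.replace("grad", "grads[i]")
    .replace("self", "self[i]")
    .replace("result", "result[i]")
)
print(modified)  # grads[i] * self[i].cos().conj()
```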