Forward AD formulas for activation backwards (#70460)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70460
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33405363
Pulled By: soulitzer
fbshipit-source-id: f68b59857a609ff593e9e399b9287d58dacef9e2
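
This change adds forward-mode (`result:`) formulas for several activation backward functions in derivatives.yaml and flips the corresponding `supports_fwgrad_bwgrad` OpInfo flags. As a rough, hypothetical sketch (not taken from this PR; variable names are illustrative), one way such a formula is exercised is by running reverse-mode autograd with a dual-tensor cotangent inside a forward-AD level, which dispatches e.g. hardtanh_backward on dual tensors so the new `result:` entry supplies the tangent of the returned gradient:

    import torch
    import torch.autograd.forward_ad as fwAD
    import torch.nn.functional as F

    x = torch.randn(4, requires_grad=True)
    y = F.hardtanh(x)  # reverse-mode graph records hardtanh_backward

    with fwAD.dual_level():
        # Cotangent carrying a forward-mode tangent.
        cotangent = fwAD.make_dual(torch.randn(4), torch.randn(4))
        # The grad call runs hardtanh_backward on a dual tensor; its tangent
        # comes from the forward-AD formula added in derivatives.yaml.
        (gx,) = torch.autograd.grad(y, x, grad_outputs=cotangent)
        gx_primal, gx_tangent = fwAD.unpack_dual(gx)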
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 26323f7..72cc8c9 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -692,6 +692,7 @@
- name: hardsigmoid(Tensor self) -> Tensor
self: hardsigmoid_backward(grad, self)
+ result: auto_element_wise
- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
output_differentiability: [False]
@@ -1792,6 +1793,7 @@
- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
grad_out: hardshrink_backward(grad, self, lambd)
self: zeros_like(grad)
+ result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result))
- name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
self: hardtanh_backward(grad, self, min_val, max_val)
@@ -2142,6 +2144,7 @@
- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor
grad_output: hardtanh_backward(grad, self, min_val, max_val)
self: zeros_like(grad)
+ result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result))
- name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
grad_output: kl_div_double_backward_grad_output(grad, self, target, reduction, log_target)
@@ -2258,6 +2261,7 @@
- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor
grad_output: softshrink_backward(grad, self, lambd)
self: zeros_like(grad)
+ result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result))
- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
grad_output: threshold_backward(grad, self, threshold)
@@ -2352,10 +2356,12 @@
- name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor
grad_output: sigmoid_backward(grad, output.conj())
output: grad.conj() * grad_output * (-2 * output.conj() + 1)
+ result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1)
- name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
grad_output: tanh_backward(grad, output.conj())
output: grad.conj() * (-2 * output.conj() * grad_output)
+ result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p)
# cudnn
- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1da49d9..fd14ca0 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11419,6 +11419,7 @@
supports_autograd=True,
assert_autodiffed=False,
supports_gradgrad=False,
+ supports_forward_ad=True,
supports_out=False,
inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True),
decorators=[
@@ -11487,7 +11488,7 @@
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
- supports_fwgrad_bwgrad=False, # Need: tanh_backward
+ supports_fwgrad_bwgrad=True,
supports_autograd=True,
assert_autodiffed=False,
supports_gradgrad=True,
@@ -11612,6 +11613,7 @@
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_autograd=True,
supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
assert_autodiffed=False,
sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh,
supports_gradgrad=True,
@@ -11627,7 +11629,7 @@
supports_gradgrad=True,
supports_out=False,
supports_forward_ad=True,
- supports_fwgrad_bwgrad=False, # Need: hardshrink_backward
+ supports_fwgrad_bwgrad=True,
autodiff_nonfusible_nodes=["aten::hardshrink"]),
OpInfo('nn.functional.hardtanh',
aten_name="hardtanh",
@@ -11641,7 +11643,7 @@
supports_gradgrad=True,
supports_out=False,
supports_forward_ad=True,
- supports_fwgrad_bwgrad=False, # Need: hardtanh_backward
+ supports_fwgrad_bwgrad=True,
autodiff_nonfusible_nodes=["aten::hardtanh"],
),
OpInfo('nn.functional.gelu',
@@ -11668,7 +11670,7 @@
supports_gradgrad=True,
supports_out=False,
supports_forward_ad=True,
- supports_fwgrad_bwgrad=False, # Need: hardtanh_backward
+ supports_fwgrad_bwgrad=True,
autodiff_nonfusible_nodes=["aten::relu6"]),
OpInfo('mm',
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
@@ -12219,7 +12221,7 @@
safe_casts_outputs=True,
assert_jit_shape_analysis=True,
supports_forward_ad=True,
- supports_fwgrad_bwgrad=False, # Need: tanh_backward
+ supports_fwgrad_bwgrad=True,
supports_sparse=True,
supports_sparse_csr=True,
skips=(
@@ -14032,7 +14034,7 @@
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
safe_casts_outputs=True,
supports_forward_ad=True,
- supports_fwgrad_bwgrad=False, # Need: sigmoid_backward
+ supports_fwgrad_bwgrad=True,
assert_autodiffed=True,
# sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero
reference_numerics_filter=NumericsFilter(