Forward AD formulas for activation backwards (#70460)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70460

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33405363

Pulled By: soulitzer

fbshipit-source-id: f68b59857a609ff593e9e399b9287d58dacef9e2
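
For reference, a minimal sketch (not part of this PR, names and shapes are illustrative only) of what the new `result:` formulas enable: propagating forward-mode tangents through the backward of an activation, i.e. the forward-over-reverse case that the flipped `supports_fwgrad_bwgrad=True` flags exercise in the OpInfo tests. Here tanh is used, whose double backward goes through tanh_backward:

    import torch
    import torch.autograd.forward_ad as fwAD

    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    v = torch.randn(4, dtype=torch.double)  # tangent direction

    with fwAD.dual_level():
        # Attach the tangent v to x, then run the op and a reverse-mode pass.
        dual_x = fwAD.make_dual(x, v)
        out = torch.tanh(dual_x)
        # This reverse pass calls tanh_backward on dual tensors, so it relies
        # on the forward-AD `result:` formula added to derivatives.yaml below.
        grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
        # The tangent carried by grad_x is the Hessian-vector product:
        # d/deps [ grad of sum(tanh(x + eps * v)) ] at eps = 0.
        _, hvp = fwAD.unpack_dual(grad_x)

Before this change, the reverse pass above would fail with "Trying to use forward AD with tanh_backward that does not support it" (and likewise for the other activation backwards touched here).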
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 26323f7..72cc8c9 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -692,6 +692,7 @@
 
 - name: hardsigmoid(Tensor self) -> Tensor
   self: hardsigmoid_backward(grad, self)
+  result: auto_element_wise
 
 - name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
   output_differentiability: [False]
@@ -1792,6 +1793,7 @@
 - name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
   grad_out: hardshrink_backward(grad, self, lambd)
   self: zeros_like(grad)
+  result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result))
 
 - name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
   self: hardtanh_backward(grad, self, min_val, max_val)
@@ -2142,6 +2144,7 @@
 - name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor
   grad_output: hardtanh_backward(grad, self, min_val, max_val)
   self: zeros_like(grad)
+  result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result))
 
 - name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
   grad_output: kl_div_double_backward_grad_output(grad, self, target, reduction, log_target)
@@ -2258,6 +2261,7 @@
 - name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor
   grad_output: softshrink_backward(grad, self, lambd)
   self: zeros_like(grad)
+  result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result))
 
 - name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
   grad_output: threshold_backward(grad, self, threshold)
@@ -2352,10 +2356,12 @@
 - name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor
   grad_output: sigmoid_backward(grad, output.conj())
   output: grad.conj() * grad_output * (-2 * output.conj() + 1)
+  result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1)
 
 - name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
   grad_output: tanh_backward(grad, output.conj())
   output: grad.conj() * (-2 * output.conj() * grad_output)
+  result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p)
 
 # cudnn
 - name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1da49d9..fd14ca0 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11419,6 +11419,7 @@
         supports_autograd=True,
         assert_autodiffed=False,
         supports_gradgrad=False,
+        supports_forward_ad=True,
         supports_out=False,
         inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True),
         decorators=[
@@ -11487,7 +11488,7 @@
         dtypes=all_types_and_complex_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
         supports_forward_ad=True,
-        supports_fwgrad_bwgrad=False,  # Need: tanh_backward
+        supports_fwgrad_bwgrad=True,
         supports_autograd=True,
         assert_autodiffed=False,
         supports_gradgrad=True,
@@ -11612,6 +11613,7 @@
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_autograd=True,
            supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
            assert_autodiffed=False,
            sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh,
            supports_gradgrad=True,
@@ -11627,7 +11629,7 @@
            supports_gradgrad=True,
            supports_out=False,
            supports_forward_ad=True,
-           supports_fwgrad_bwgrad=False,  # Need: hardshrink_backward
+           supports_fwgrad_bwgrad=True,
            autodiff_nonfusible_nodes=["aten::hardshrink"]),
     OpInfo('nn.functional.hardtanh',
            aten_name="hardtanh",
@@ -11641,7 +11643,7 @@
            supports_gradgrad=True,
            supports_out=False,
            supports_forward_ad=True,
-           supports_fwgrad_bwgrad=False,  # Need: hardtanh_backward
+           supports_fwgrad_bwgrad=True,
            autodiff_nonfusible_nodes=["aten::hardtanh"],
            ),
     OpInfo('nn.functional.gelu',
@@ -11668,7 +11670,7 @@
            supports_gradgrad=True,
            supports_out=False,
            supports_forward_ad=True,
-           supports_fwgrad_bwgrad=False,  # Need: hardtanh_backward
+           supports_fwgrad_bwgrad=True,
            autodiff_nonfusible_nodes=["aten::relu6"]),
     OpInfo('mm',
            dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
@@ -12219,7 +12221,7 @@
                    safe_casts_outputs=True,
                    assert_jit_shape_analysis=True,
                    supports_forward_ad=True,
-                   supports_fwgrad_bwgrad=False,  # Need: tanh_backward
+                   supports_fwgrad_bwgrad=True,
                    supports_sparse=True,
                    supports_sparse_csr=True,
                    skips=(
@@ -14032,7 +14034,7 @@
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    safe_casts_outputs=True,
                    supports_forward_ad=True,
-                   supports_fwgrad_bwgrad=False,  # Need: sigmoid_backward
+                   supports_fwgrad_bwgrad=True,
                    assert_autodiffed=True,
                    # sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero
                    reference_numerics_filter=NumericsFilter(