Reland OpInfo support for forward AD (#58304)

Summary:
Third attempt to land this change; the ci-all label is used to make sure everything gets tested.

This relands OpInfo support for forward-mode AD: OpInfo entries can now declare `supports_forward_ad`, and TestGradients gains a `test_forward_mode_AD` test that runs gradcheck with `check_forward_ad=True` for ops that opt in, and asserts that a RuntimeError is raised for ops that do not.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58304

Reviewed By: heitorschueroff

Differential Revision: D28474343

Pulled By: albanD

fbshipit-source-id: 8230fa3c0a8d3633f09999e7c2f47dbdc5fe57e9
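
For illustration, a minimal sketch of what the new test exercises. The `torch.autograd.forward_ad` calls (`dual_level`, `make_dual`, `unpack_dual`) are not part of this diff and are assumed from the existing forward-mode AD API; the sample values and the choice of `mul` (which opts in below) are illustrative, and the `check_forward_ad` kwarg is the one passed to `gradcheck` in the diff:

```python
# Sketch only: mirrors what TestGradients.test_forward_mode_AD checks for an
# op that sets supports_forward_ad=True (here torch.mul, which opts in below).
import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import gradcheck

# Forward-mode AD computes a Jacobian-vector product (JVP) together with the
# primal output, instead of building a backward graph.
primal = torch.randn(3, dtype=torch.double)
tangent = torch.randn(3, dtype=torch.double)
with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 2.0
    _, jvp = fwAD.unpack_dual(out)  # jvp should equal 2.0 * tangent

# The OpInfo test runs gradcheck with check_forward_ad=True, which compares
# the forward-mode JVP against numerically estimated derivatives.
x = torch.randn(3, dtype=torch.double, requires_grad=True)
y = torch.randn(3, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.mul, (x, y), check_forward_ad=True)
```

For ops with `supports_forward_ad=False`, the test instead asserts that this path raises "Trying to use forward AD with ... that does not support it".
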
diff --git a/test/test_ops.py b/test/test_ops.py
index 4bfd723..f36fc8b 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -99,7 +99,7 @@
 
         return _fn
 
-    def _check_helper(self, device, dtype, op, variant, check):
+    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False):
         if variant is None:
             self.skipTest("Skipped! Variant not implemented.")
         if not op.supports_dtype(dtype, torch.device(device).type):
@@ -139,8 +139,10 @@
                                           check_batched_grad=op.check_batched_grad,
                                           check_grad_dtypes=True,
                                           nondet_tol=op.gradcheck_nondet_tol,
-                                          fast_mode=op.gradcheck_fast_mode))
+                                          fast_mode=op.gradcheck_fast_mode,
+                                          check_forward_ad=check_forward_ad))
             elif check == 'gradgradcheck':
+                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                 self.assertTrue(gradgradcheck(fn, gradcheck_args,
                                               gen_non_contig_grad_outputs=False,
                                               check_batched_grad=op.check_batched_gradgrad,
@@ -156,8 +158,8 @@
             else:
                 self.assertTrue(False, msg="Unknown check requested!")
 
-    def _grad_test_helper(self, device, dtype, op, variant):
-        return self._check_helper(device, dtype, op, variant, 'gradcheck')
+    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False):
+        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad)
 
     def _gradgrad_test_helper(self, device, dtype, op, variant):
         return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
@@ -221,6 +223,19 @@
             self.skipTest("Skipped! Operation does not support inplace autograd.")
         self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
 
+    @_gradcheck_ops(op_db)
+    def test_forward_mode_AD(self, device, dtype, op):
+        self._skip_helper(op, device, dtype)
+
+        if op.supports_forward_ad:
+            self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
+        else:
+            err_msg = r"Trying to use forward AD with .* that does not support it\."
+            hint_msg = ("Running forward AD for an op that does not support it did not "
+                        "raise any error. If your op supports forward AD, you should set supports_forward_ad=True.")
+            with self.assertRaisesRegex(RuntimeError, err_msg, msg=hint_msg):
+                self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
+
 
 # Tests operators for consistency between JIT and eager, also checks
 #   correctness of JIT specific alias schemas and intended
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fd45a76..06b3fb1 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -145,7 +145,6 @@
         else:
             raise
 
-
 # Classes and methods for the operator database
 class OpInfo(object):
     """Operator information and helper functions for acquiring it."""
@@ -184,6 +183,9 @@
                  supports_gradgrad=True,  # support second order gradients (this value is ignored if supports_autograd=False)
                  supports_inplace_autograd=None,  # whether the operation supports inplace autograd
                                                   # defaults to supports_autograd's value
+                 supports_forward_ad=False,  # Whether the operation supports forward mode AD
+                                             # If the value is True, we check that the gradients are correct
+                                             # If the value is False, we test that forward grad is not implemented
                  supports_sparse=False,  # whether the op supports sparse inputs
                  gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs),  # wrapper function for gradcheck
                  check_batched_grad=True,  # check batched grad when doing gradcheck
@@ -248,6 +250,7 @@
 
         self.gradcheck_wrapper = gradcheck_wrapper
         self.supports_gradgrad = supports_gradgrad
+        self.supports_forward_ad = supports_forward_ad
         self.check_batched_grad = check_batched_grad
         self.check_batched_gradgrad = check_batched_gradgrad
         self.gradcheck_nondet_tol = gradcheck_nondet_tol
@@ -3659,7 +3662,8 @@
                          dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
                          dtypesIfCPU=all_types_and_complex_and(torch.bfloat16, torch.half),
                          dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
-                         safe_casts_outputs=False)
+                         safe_casts_outputs=False,
+                         supports_forward_ad=True)
 ]
 
 def reference_sign(x):
@@ -3783,7 +3787,8 @@
                                 dtypes=[torch.cfloat, torch.cdouble]),
                    ),
                    supports_inplace_autograd=False,
-                   assert_autodiffed=True),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True),
     # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952)
     UnaryUfuncInfo('acos',
                    aliases=('arccos', ),
@@ -3808,6 +3813,8 @@
                                 dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                        SkipInfo('TestGradients', 'test_inplace_grad',
                                 dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       SkipInfo('TestGradients', 'test_forward_mode_AD',
+                                dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                    )),
     # NOTE: the derivative for inplace acosh is not implemented
     UnaryUfuncInfo('acosh',
@@ -3840,16 +3847,20 @@
                                 device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                        SkipInfo('TestGradients', 'test_method_grad',
                                 device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       SkipInfo('TestGradients', 'test_forward_mode_AD',
+                                dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                    )),
     OpInfo('add',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            assert_autodiffed=True,
            sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2),
-           supports_inplace_autograd=False),
+           supports_inplace_autograd=False,
+           supports_forward_ad=True),
     OpInfo('mul',
            aliases=('multiply',),
            dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
            assert_autodiffed=True,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_binary_pwise),
     OpInfo('sub',
            aliases=('subtract',),
@@ -4132,7 +4143,9 @@
            skips=(
                # cuda gradchecks are slow
                # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
-               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)),
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+               # Gradcheck for complex generates invalid inputs for this function
+               SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)),
     OpInfo('cholesky_inverse',
            dtypes=floating_and_complex_types(),
            backward_dtypes=floating_types(),
@@ -4187,6 +4200,7 @@
                    ref=np.positive,
                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
                    supports_out=False,
+                   supports_forward_ad=True,
                    ),
     UnaryUfuncInfo('conj',
                    ref=np.conj,
@@ -4203,14 +4217,18 @@
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
                                 dtypes=[torch.int],
                                 active_if=IS_WINDOWS),
+                       # TODO fix the formula for complex forward AD
+                       SkipInfo('TestGradients', 'test_forward_mode_AD'),
                    )),
     OpInfo('view_as_real',
            dtypes=complex_types(),
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_view_as_real,
            ),
     OpInfo('view_as_complex',
            dtypes=floating_types_and(torch.half),
            supports_out=False,
+           supports_forward_ad=True,
            skips=(
                # "sum_cpu/sum_cuda" not implemented for 'ComplexHalf'
                SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.half,)),
@@ -4632,7 +4650,9 @@
            skips=(
                # cuda gradchecks are slow
                # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
-               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+               # Gradcheck for complex generates invalid inputs for this function
+               SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)
            ),
     OpInfo('linalg.cholesky_ex',
            aten_name='linalg_cholesky_ex',
@@ -4705,8 +4725,7 @@
            dtypes=floating_and_complex_types(),
            supports_out=True,
            sample_inputs_func=sample_inputs_linalg_lstsq,
-           check_batched_grad=False,
-           check_batched_gradgrad=False,
+           supports_autograd=False,
            decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
            skips=(
                # skip because `linalg_lstsq` is not differentiable
@@ -5095,6 +5114,7 @@
     OpInfo('narrow',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_narrow),
     UnaryUfuncInfo('neg',
                    aliases=('negative', ),
@@ -5274,6 +5294,7 @@
            supports_out=False,
            skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),),
            assert_autodiffed=True,
+           supports_forward_ad=True,
            autodiff_nonfusible_nodes=['aten::add'],),
     OpInfo('__rdiv__',
            op=torch.Tensor.__rdiv__,
@@ -5290,6 +5311,7 @@
            supports_out=False,
            skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),),
            assert_autodiffed=True,
+           supports_forward_ad=True,
            autodiff_nonfusible_nodes=['aten::mul'],),
     OpInfo('__rpow__',
            op=torch.Tensor.__rpow__,
@@ -5339,6 +5361,7 @@
            assert_autodiffed=True,),
     OpInfo('select',
            dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_select,
            supports_out=False),
     UnaryUfuncInfo('signbit',
@@ -5416,19 +5439,23 @@
            dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_forward_ad=True,
            skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
            sample_inputs_func=sample_inputs_tensor_split,),
     OpInfo('hsplit',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_hsplit,),
     OpInfo('vsplit',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_vsplit,),
     OpInfo('dsplit',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_dsplit,),
     OpInfo('triangular_solve',
            op=torch.triangular_solve,