Reland OpInfo support for forward AD (#58304)
Summary:
Third attempt to land this change. Applying the ci-all label to make sure everything is tested.

This adds OpInfo-based testing of forward-mode AD: OpInfo gains a `supports_forward_ad` flag (defaulting to False), and TestGradients gains a `test_forward_mode_AD` test that runs gradcheck with `check_forward_ad=True` for ops that set the flag, and checks that forward AD raises the expected "does not support it" error for ops that do not. A first batch of OpInfo entries (e.g. add, mul, select, narrow, and the split variants) is marked with `supports_forward_ad=True`.
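As a quick illustration of what the new test exercises, here is a minimal standalone sketch (not part of this diff; it assumes torch.autograd.gradcheck's `check_forward_ad` flag, which the test helper forwards to, and uses `torch.add`, one of the ops opted in below):

```python
import torch
from torch.autograd import gradcheck

# Double-precision inputs so gradcheck's finite-difference comparison is accurate.
x = torch.randn(3, dtype=torch.double, requires_grad=True)
y = torch.randn(3, dtype=torch.double, requires_grad=True)

# check_forward_ad=True makes gradcheck also validate forward-mode (dual-number)
# gradients; this is what test_forward_mode_AD turns on for opted-in ops.
assert gradcheck(torch.add, (x, y), check_forward_ad=True)
```

For ops that do not set the flag, the test instead asserts that the same check raises the "Trying to use forward AD with ... that does not support it" RuntimeError.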
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58304
Reviewed By: heitorschueroff
Differential Revision: D28474343
Pulled By: albanD
fbshipit-source-id: 8230fa3c0a8d3633f09999e7c2f47dbdc5fe57e9
diff --git a/test/test_ops.py b/test/test_ops.py
index 4bfd723..f36fc8b 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -99,7 +99,7 @@
return _fn
- def _check_helper(self, device, dtype, op, variant, check):
+ def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False):
if variant is None:
self.skipTest("Skipped! Variant not implemented.")
if not op.supports_dtype(dtype, torch.device(device).type):
@@ -139,8 +139,10 @@
check_batched_grad=op.check_batched_grad,
check_grad_dtypes=True,
nondet_tol=op.gradcheck_nondet_tol,
- fast_mode=op.gradcheck_fast_mode))
+ fast_mode=op.gradcheck_fast_mode,
+ check_forward_ad=check_forward_ad))
elif check == 'gradgradcheck':
+ self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
self.assertTrue(gradgradcheck(fn, gradcheck_args,
gen_non_contig_grad_outputs=False,
check_batched_grad=op.check_batched_gradgrad,
@@ -156,8 +158,8 @@
else:
self.assertTrue(False, msg="Unknown check requested!")
- def _grad_test_helper(self, device, dtype, op, variant):
- return self._check_helper(device, dtype, op, variant, 'gradcheck')
+ def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False):
+ return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad)
def _gradgrad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
@@ -221,6 +223,19 @@
self.skipTest("Skipped! Operation does not support inplace autograd.")
self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
+ @_gradcheck_ops(op_db)
+ def test_forward_mode_AD(self, device, dtype, op):
+ self._skip_helper(op, device, dtype)
+
+ if op.supports_forward_ad:
+ self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
+ else:
+ err_msg = r"Trying to use forward AD with .* that does not support it\."
+ hint_msg = ("Running forward AD for an op that does not support it did not "
+ "raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
+ with self.assertRaisesRegex(RuntimeError, err_msg, msg=hint_msg):
+ self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
+
# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fd45a76..06b3fb1 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -145,7 +145,6 @@
else:
raise
-
# Classes and methods for the operator database
class OpInfo(object):
"""Operator information and helper functions for acquiring it."""
@@ -184,6 +183,9 @@
supports_gradgrad=True, # support second order gradients (this value is ignored if supports_autograd=False)
supports_inplace_autograd=None, # whether the operation supports inplace autograd
# defaults to supports_autograd's value
+ supports_forward_ad=False, # Whether the operation supports forward mode AD
+ # If the value is True, we check that the gradients are correct
+ # If the value is False, we test that forward grad is not implemented
supports_sparse=False, # whether the op supports sparse inputs
gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs), # wrapper function for gradcheck
check_batched_grad=True, # check batched grad when doing gradcheck
@@ -248,6 +250,7 @@
self.gradcheck_wrapper = gradcheck_wrapper
self.supports_gradgrad = supports_gradgrad
+ self.supports_forward_ad = supports_forward_ad
self.check_batched_grad = check_batched_grad
self.check_batched_gradgrad = check_batched_gradgrad
self.gradcheck_nondet_tol = gradcheck_nondet_tol
@@ -3659,7 +3662,8 @@
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
dtypesIfCPU=all_types_and_complex_and(torch.bfloat16, torch.half),
dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
- safe_casts_outputs=False)
+ safe_casts_outputs=False,
+ supports_forward_ad=True)
]
def reference_sign(x):
@@ -3783,7 +3787,8 @@
dtypes=[torch.cfloat, torch.cdouble]),
),
supports_inplace_autograd=False,
- assert_autodiffed=True),
+ assert_autodiffed=True,
+ supports_forward_ad=True),
# NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952)
UnaryUfuncInfo('acos',
aliases=('arccos', ),
@@ -3808,6 +3813,8 @@
dtypes=[torch.cdouble], active_if=IS_WINDOWS),
SkipInfo('TestGradients', 'test_inplace_grad',
dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+ SkipInfo('TestGradients', 'test_forward_mode_AD',
+ dtypes=[torch.cdouble], active_if=IS_WINDOWS),
)),
# NOTE: the derivative for inplace acosh is not implemented
UnaryUfuncInfo('acosh',
@@ -3840,16 +3847,20 @@
device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
SkipInfo('TestGradients', 'test_method_grad',
device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+ SkipInfo('TestGradients', 'test_forward_mode_AD',
+ dtypes=[torch.cdouble], active_if=IS_WINDOWS),
)),
OpInfo('add',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
assert_autodiffed=True,
sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2),
- supports_inplace_autograd=False),
+ supports_inplace_autograd=False,
+ supports_forward_ad=True),
OpInfo('mul',
aliases=('multiply',),
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
assert_autodiffed=True,
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_binary_pwise),
OpInfo('sub',
aliases=('subtract',),
@@ -4132,7 +4143,9 @@
skips=(
# cuda gradchecks are slow
# see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
- SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)),
+ SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+ # Gradcheck for complex generates invalid inputs for this function
+ SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)),
OpInfo('cholesky_inverse',
dtypes=floating_and_complex_types(),
backward_dtypes=floating_types(),
@@ -4187,6 +4200,7 @@
ref=np.positive,
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
supports_out=False,
+ supports_forward_ad=True,
),
UnaryUfuncInfo('conj',
ref=np.conj,
@@ -4203,14 +4217,18 @@
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
dtypes=[torch.int],
active_if=IS_WINDOWS),
+ # TODO fix the formula for complex forward AD
+ SkipInfo('TestGradients', 'test_forward_mode_AD'),
)),
OpInfo('view_as_real',
dtypes=complex_types(),
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_view_as_real,
),
OpInfo('view_as_complex',
dtypes=floating_types_and(torch.half),
supports_out=False,
+ supports_forward_ad=True,
skips=(
# "sum_cpu/sum_cuda" not implemented for 'ComplexHalf'
SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.half,)),
@@ -4632,7 +4650,9 @@
skips=(
# cuda gradchecks are slow
# see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
- SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)
+ SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+ # Gradcheck for complex generates invalid inputs for this function
+ SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)
),
OpInfo('linalg.cholesky_ex',
aten_name='linalg_cholesky_ex',
@@ -4705,8 +4725,7 @@
dtypes=floating_and_complex_types(),
supports_out=True,
sample_inputs_func=sample_inputs_linalg_lstsq,
- check_batched_grad=False,
- check_batched_gradgrad=False,
+ supports_autograd=False,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=(
# skip because `linalg_lstsq` is not differentiable
@@ -5095,6 +5114,7 @@
OpInfo('narrow',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_narrow),
UnaryUfuncInfo('neg',
aliases=('negative', ),
@@ -5274,6 +5294,7 @@
supports_out=False,
skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),),
assert_autodiffed=True,
+ supports_forward_ad=True,
autodiff_nonfusible_nodes=['aten::add'],),
OpInfo('__rdiv__',
op=torch.Tensor.__rdiv__,
@@ -5290,6 +5311,7 @@
supports_out=False,
skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),),
assert_autodiffed=True,
+ supports_forward_ad=True,
autodiff_nonfusible_nodes=['aten::mul'],),
OpInfo('__rpow__',
op=torch.Tensor.__rpow__,
@@ -5339,6 +5361,7 @@
assert_autodiffed=True,),
OpInfo('select',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_select,
supports_out=False),
UnaryUfuncInfo('signbit',
@@ -5416,19 +5439,23 @@
dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
+ supports_forward_ad=True,
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
sample_inputs_func=sample_inputs_tensor_split,),
OpInfo('hsplit',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_hsplit,),
OpInfo('vsplit',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_vsplit,),
OpInfo('dsplit',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_dsplit,),
OpInfo('triangular_solve',
op=torch.triangular_solve,