Revert D28412496: Revert "Revert D28387767: Add forward AD test for op info"
Test Plan: revert-hammer
Differential Revision: D28412496 (https://github.com/pytorch/pytorch/commit/4f28c0b5909d5122b0beffa49a579fc0d5fe0f80)
Original commit changeset: 5b8e30b5e807
fbshipit-source-id: 5a47aad4d5428e97e2d2b4acb4192909360870cd
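
For context, the change being reverted wired OpInfo's supports_forward_ad flag into
gradcheck's check_forward_ad option (see the removed test_forward_mode_AD in the diff
below). A minimal standalone sketch of that kind of check, assuming an op such as
torch.sin whose forward-mode formula is implemented:

    import torch
    from torch.autograd import gradcheck

    # check_forward_ad=True asks gradcheck to verify the forward-mode (JVP)
    # formula against numerical derivatives, in addition to the usual
    # reverse-mode check. torch.sin is only an illustrative op assumed here
    # to support forward-mode AD; double precision is required by gradcheck.
    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    assert gradcheck(torch.sin, (x,), check_forward_ad=True)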
diff --git a/test/test_ops.py b/test/test_ops.py
index f36fc8b..4bfd723 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -99,7 +99,7 @@
return _fn
- def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False):
+ def _check_helper(self, device, dtype, op, variant, check):
if variant is None:
self.skipTest("Skipped! Variant not implemented.")
if not op.supports_dtype(dtype, torch.device(device).type):
@@ -139,10 +139,8 @@
check_batched_grad=op.check_batched_grad,
check_grad_dtypes=True,
nondet_tol=op.gradcheck_nondet_tol,
- fast_mode=op.gradcheck_fast_mode,
- check_forward_ad=check_forward_ad))
+ fast_mode=op.gradcheck_fast_mode))
elif check == 'gradgradcheck':
- self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
self.assertTrue(gradgradcheck(fn, gradcheck_args,
gen_non_contig_grad_outputs=False,
check_batched_grad=op.check_batched_gradgrad,
@@ -158,8 +156,8 @@
else:
self.assertTrue(False, msg="Unknown check requested!")
- def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False):
- return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad)
+ def _grad_test_helper(self, device, dtype, op, variant):
+ return self._check_helper(device, dtype, op, variant, 'gradcheck')
def _gradgrad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
@@ -223,19 +221,6 @@
self.skipTest("Skipped! Operation does not support inplace autograd.")
self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
- @_gradcheck_ops(op_db)
- def test_forward_mode_AD(self, device, dtype, op):
- self._skip_helper(op, device, dtype)
-
- if op.supports_forward_ad:
- self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
- else:
- err_msg = r"Trying to use forward AD with .* that does not support it\."
- hint_msg = ("Running forward AD for an OP that does not support it did not "
- "raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
- with self.assertRaisesRegex(RuntimeError, err_msg, msg=hint_msg):
- self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
-
# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a8bddaa..fd45a76 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -145,6 +145,7 @@
else:
raise
+
# Classes and methods for the operator database
class OpInfo(object):
"""Operator information and helper functions for acquiring it."""
@@ -183,9 +184,6 @@
supports_gradgrad=True, # support second order gradients (this value is ignored if supports_autograd=False)
supports_inplace_autograd=None, # whether the operation supports inplace autograd
# defaults to supports_autograd's value
- supports_forward_ad=False, # Whether the operation support forward mode AD
- # If the value is True, we check that the gradients are correct
- # If the value is False, we test that forward grad is not implemented
supports_sparse=False, # whether the op supports sparse inputs
gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs), # wrapper function for gradcheck
check_batched_grad=True, # check batched grad when doing gradcheck
@@ -250,7 +248,6 @@
self.gradcheck_wrapper = gradcheck_wrapper
self.supports_gradgrad = supports_gradgrad
- self.supports_forward_ad = supports_forward_ad
self.check_batched_grad = check_batched_grad
self.check_batched_gradgrad = check_batched_gradgrad
self.gradcheck_nondet_tol = gradcheck_nondet_tol
@@ -3662,8 +3659,7 @@
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
dtypesIfCPU=all_types_and_complex_and(torch.bfloat16, torch.half),
dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
- safe_casts_outputs=False,
- supports_forward_ad=True)
+ safe_casts_outputs=False)
]
def reference_sign(x):
@@ -3787,8 +3783,7 @@
dtypes=[torch.cfloat, torch.cdouble]),
),
supports_inplace_autograd=False,
- assert_autodiffed=True,
- supports_forward_ad=True),
+ assert_autodiffed=True),
# NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952)
UnaryUfuncInfo('acos',
aliases=('arccos', ),
@@ -3850,13 +3845,11 @@
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
assert_autodiffed=True,
sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2),
- supports_inplace_autograd=False,
- supports_forward_ad=True),
+ supports_inplace_autograd=False),
OpInfo('mul',
aliases=('multiply',),
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
assert_autodiffed=True,
- supports_forward_ad=True,
sample_inputs_func=sample_inputs_binary_pwise),
OpInfo('sub',
aliases=('subtract',),
@@ -4139,9 +4132,7 @@
skips=(
# cuda gradchecks are slow
# see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
- SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
- # Gradcheck for complex generates invalid inputs for this function
- SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)),
+ SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)),
OpInfo('cholesky_inverse',
dtypes=floating_and_complex_types(),
backward_dtypes=floating_types(),
@@ -4196,7 +4187,6 @@
ref=np.positive,
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
supports_out=False,
- supports_forward_ad=True,
),
UnaryUfuncInfo('conj',
ref=np.conj,
@@ -4213,18 +4203,14 @@
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
dtypes=[torch.int],
active_if=IS_WINDOWS),
- # TODO fix the formula for complex forward AD
- SkipInfo('TestGradients', 'test_forward_mode_AD'),
)),
OpInfo('view_as_real',
dtypes=complex_types(),
- supports_forward_ad=True,
sample_inputs_func=sample_inputs_view_as_real,
),
OpInfo('view_as_complex',
dtypes=floating_types_and(torch.half),
supports_out=False,
- supports_forward_ad=True,
skips=(
# "sum_cpu/sum_cuda" not implemented for 'ComplexHalf'
SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.half,)),
@@ -4646,9 +4632,7 @@
skips=(
# cuda gradchecks are slow
# see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
- SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
- # Gradcheck for complex generates invalid inputs for this function
- SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),)
+ SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)
),
OpInfo('linalg.cholesky_ex',
aten_name='linalg_cholesky_ex',
@@ -4721,7 +4705,8 @@
dtypes=floating_and_complex_types(),
supports_out=True,
sample_inputs_func=sample_inputs_linalg_lstsq,
- supports_autograd=False,
+ check_batched_grad=False,
+ check_batched_gradgrad=False,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=(
# skip because `linalg_lstsq` is not differentiable
@@ -5110,7 +5095,6 @@
OpInfo('narrow',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
- supports_forward_ad=True,
sample_inputs_func=sample_inputs_narrow),
UnaryUfuncInfo('neg',
aliases=('negative', ),
@@ -5290,7 +5274,6 @@
supports_out=False,
skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),),
assert_autodiffed=True,
- supports_forward_ad=True,
autodiff_nonfusible_nodes=['aten::add'],),
OpInfo('__rdiv__',
op=torch.Tensor.__rdiv__,
@@ -5307,7 +5290,6 @@
supports_out=False,
skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),),
assert_autodiffed=True,
- supports_forward_ad=True,
autodiff_nonfusible_nodes=['aten::mul'],),
OpInfo('__rpow__',
op=torch.Tensor.__rpow__,
@@ -5357,7 +5339,6 @@
assert_autodiffed=True,),
OpInfo('select',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
- supports_forward_ad=True,
sample_inputs_func=sample_inputs_select,
supports_out=False),
UnaryUfuncInfo('signbit',
@@ -5435,23 +5416,19 @@
dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
- supports_forward_ad=True,
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
sample_inputs_func=sample_inputs_tensor_split,),
OpInfo('hsplit',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
- supports_forward_ad=True,
sample_inputs_func=sample_inputs_hsplit,),
OpInfo('vsplit',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
- supports_forward_ad=True,
sample_inputs_func=sample_inputs_vsplit,),
OpInfo('dsplit',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
- supports_forward_ad=True,
sample_inputs_func=sample_inputs_dsplit,),
OpInfo('triangular_solve',
op=torch.triangular_solve,