[ROCm] unskip FFT tests
This pr enables a lot of fft tests on rocm, see the full list here https://github.com/ROCmSoftwarePlatform/pytorch/issues/924.
After enabling the tests we found that 3 tests , test_reference_1d, test_reference_nd and test_fn_grad have issue. We skip those tests on ROCM in this pr as well. We will address those skipped tests in a subsequent pr.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74713
Approved by: https://github.com/malfet
diff --git a/test/test_ops_gradients.py b/test/test_ops_gradients.py
index 71bb243..d50b0d3 100644
--- a/test/test_ops_gradients.py
+++ b/test/test_ops_gradients.py
@@ -7,7 +7,7 @@
(TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, first_sample)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import \
- (instantiate_device_type_tests, ops, OpDTypes)
+ (instantiate_device_type_tests, ops, OpDTypes, skipCUDAIfRocm)
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)
@@ -113,6 +113,7 @@
self.skipTest("Skipped! Complex autograd not supported.")
# Tests that gradients are computed correctly
+ @skipCUDAIfRocm
@_gradcheck_ops(op_db)
def test_fn_grad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py
index c11b87b..344c810 100644
--- a/test/test_spectral_ops.py
+++ b/test/test_spectral_ops.py
@@ -13,7 +13,7 @@
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
- skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
+ skipCPUIfNoFFT, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
from torch.testing._internal.common_methods_invocations import (
spectral_funcs, SpectralFuncInfo, SpectralFuncType)
@@ -204,6 +204,7 @@
else:
return (input, s, dim, norm)
+ @skipCUDAIfRocm
@onlyNativeDeviceTypes
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD])
def test_reference_1d(self, device, dtype, op):
@@ -367,6 +368,7 @@
op(x)
# nd-fft tests
+ @skipCUDAIfRocm
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND])
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c0c4ec8..c873b1e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -5330,7 +5330,6 @@
decorators = list(decorators) if decorators is not None else []
decorators += [
skipCPUIfNoFFT,
- skipCUDAIfRocm,
]
super().__init__(name=name,