Update sparse_funcs to include primtorch types (#107421)
Fixes #107335.
A few issues have been identified while enabling this test and filed:
https://github.com/pytorch/pytorch/issues/105986
https://github.com/pytorch/pytorch/issues/108204
https://github.com/pytorch/pytorch/issues/108205
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107421
Approved by: https://github.com/ezyang
diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py
index 6f4bbe3..f5a0481 100644
--- a/test/test_spectral_ops.py
+++ b/test/test_spectral_ops.py
@@ -325,12 +325,18 @@
# TODO: Remove torch.half error when complex32 is fully implemented
sample = first_sample(self, op.sample_inputs(device, dtype))
device_type = torch.device(device).type
+ # FIXME: https://github.com/pytorch/pytorch/issues/108204
+ default_msg = (
+ r"(Unsupported dtype|"
+ r"FFT doesn't support (tensors*|transforms) of type|"
+ r"expected scalar type \w+ but found|)"
+ )
if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
- err_msg = "Unsupported dtype "
+ err_msg = default_msg
elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
else:
- err_msg = "Unsupported dtype "
+ err_msg = default_msg
with self.assertRaisesRegex(RuntimeError, err_msg):
op(sample.input, *sample.args, **sample.kwargs)
@@ -444,11 +450,12 @@
allowed_dtypes=[torch.float, torch.cfloat])
def test_fftn_invalid(self, device, dtype, op):
a = torch.rand(10, 10, 10, device=device, dtype=dtype)
-
- with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
+ # FIXME: https://github.com/pytorch/pytorch/issues/108205
+ errMsg = r"(dims must be unique|duplicate value in the list of dims)"
+ with self.assertRaisesRegex(RuntimeError, errMsg):
op(a, dim=(0, 1, 0))
- with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
+ with self.assertRaisesRegex(RuntimeError, errMsg):
op(a, dim=(2, -1))
with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index d1a248d..7ce26ae 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -21585,7 +21585,7 @@
unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)]
binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)]
binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo))
-spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
+spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)]
sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse]
diff --git a/torch/testing/_internal/opinfo/definitions/fft.py b/torch/testing/_internal/opinfo/definitions/fft.py
index dc8f7ae..48b7c7c 100644
--- a/torch/testing/_internal/opinfo/definitions/fft.py
+++ b/torch/testing/_internal/opinfo/definitions/fft.py
@@ -650,34 +650,110 @@
SpectralFuncPythonRefInfo(
"_refs.fft.fft",
torch_opinfo_name="fft.fft",
+ skips=(
+ # _refs.fft.* functions have inconsistent behavior for empty tensors
+ # https://github.com/pytorch/pytorch/issues/105986
+ DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifft",
torch_opinfo_name="fft.ifft",
+ skips=(
+ # _refs.fft.* functions have inconsistent behavior for empty tensors
+ # https://github.com/pytorch/pytorch/issues/105986
+ DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfft",
torch_opinfo_name="fft.rfft",
+ skips=(
+ # _refs.fft.* functions have inconsistent behavior for empty tensors
+ # https://github.com/pytorch/pytorch/issues/105986
+ DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.irfft",
torch_opinfo_name="fft.irfft",
+ skips=(
+ # _refs.fft.* functions have inconsistent behavior for empty tensors
+ # https://github.com/pytorch/pytorch/issues/105986
+ DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+ # TODO: internally promoted to complex64 so not rejected
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fft_half_and_bfloat16_errors",
+ dtypes=[torch.bfloat16],
+ ),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfft",
torch_opinfo_name="fft.hfft",
+ skips=(
+ # _refs.fft.* functions have inconsistent behavior for empty tensors
+ # https://github.com/pytorch/pytorch/issues/105986
+ DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+ # FIXME: https://github.com/pytorch/pytorch/issues/108204
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fft_half_and_bfloat16_errors",
+ dtypes=[torch.bfloat16],
+ ),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfft",
torch_opinfo_name="fft.ihfft",
+ skips=(
+ # _refs.fft.* functions have inconsistent behavior for empty tensors
+ # https://github.com/pytorch/pytorch/issues/105986
+ DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.fftn",
torch_opinfo_name="fft.fftn",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
+ skips=(
+ # FIXME: https://github.com/pytorch/pytorch/issues/108204
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fft_half_and_bfloat16_errors",
+ dtypes=[torch.bfloat16],
+ ),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifftn",
torch_opinfo_name="fft.ifftn",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
+ skips=(
+ # FIXME: https://github.com/pytorch/pytorch/issues/108204
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fft_half_and_bfloat16_errors",
+ dtypes=[torch.bfloat16],
+ ),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfftn",
@@ -686,14 +762,67 @@
SpectralFuncPythonRefInfo(
"_refs.fft.irfftn",
torch_opinfo_name="fft.irfftn",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
+ skips=(
+ # FIXME: https://github.com/pytorch/pytorch/issues/108204
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fft_half_and_bfloat16_errors",
+ dtypes=[torch.bfloat16],
+ ),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfftn",
torch_opinfo_name="fft.hfftn",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
+ skips=(
+ # FIXME: https://github.com/pytorch/pytorch/issues/108204
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fft_half_and_bfloat16_errors",
+ dtypes=[torch.bfloat16],
+ ),
+ # FIXME: https://github.com/pytorch/pytorch/issues/108205
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fftn_invalid",
+ ),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfftn",
torch_opinfo_name="fft.ihfftn",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 2e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
+ skips=(
+ # FIXME: https://github.com/pytorch/pytorch/issues/108205
+ DecorateInfo(
+ unittest.expectedFailure,
+ "TestFFT",
+ "test_fftn_invalid",
+ ),
+ ),
),
SpectralFuncPythonRefInfo(
"_refs.fft.fft2",
@@ -702,6 +831,13 @@
SpectralFuncPythonRefInfo(
"_refs.fft.ifft2",
torch_opinfo_name="fft.ifft2",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfft2",
@@ -710,14 +846,35 @@
SpectralFuncPythonRefInfo(
"_refs.fft.irfft2",
torch_opinfo_name="fft.irfft2",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfft2",
torch_opinfo_name="fft.hfft2",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfft2",
torch_opinfo_name="fft.ihfft2",
+ decorators=[
+ DecorateInfo(
+ precisionOverride({torch.float: 2e-4}),
+ "TestFFT",
+ "test_reference_nd",
+ )
+ ],
),
PythonRefInfo(
"_refs.fft.fftshift",