Update sparse_funcs to include primtorch types (#107421)

Fixes #107335.

A few issues have been identified while enabling this test and filed:
https://github.com/pytorch/pytorch/issues/105986
https://github.com/pytorch/pytorch/issues/108204
https://github.com/pytorch/pytorch/issues/108205

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107421
Approved by: https://github.com/ezyang
diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py
index 6f4bbe3..f5a0481 100644
--- a/test/test_spectral_ops.py
+++ b/test/test_spectral_ops.py
@@ -325,12 +325,18 @@
         # TODO: Remove torch.half error when complex32 is fully implemented
         sample = first_sample(self, op.sample_inputs(device, dtype))
         device_type = torch.device(device).type
+        # FIXME: https://github.com/pytorch/pytorch/issues/108204
+        default_msg = (
+            r"(Unsupported dtype|"
+            r"FFT doesn't support (tensors*|transforms) of type|"
+            r"expected scalar type \w+ but found|)"
+        )
         if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
-            err_msg = "Unsupported dtype "
+            err_msg = default_msg
         elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
             err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
         else:
-            err_msg = "Unsupported dtype "
+            err_msg = default_msg
         with self.assertRaisesRegex(RuntimeError, err_msg):
             op(sample.input, *sample.args, **sample.kwargs)
 
@@ -444,11 +450,12 @@
          allowed_dtypes=[torch.float, torch.cfloat])
     def test_fftn_invalid(self, device, dtype, op):
         a = torch.rand(10, 10, 10, device=device, dtype=dtype)
-
-        with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
+        # FIXME: https://github.com/pytorch/pytorch/issues/108205
+        errMsg = r"(dims must be unique|duplicate value in the list of dims)"
+        with self.assertRaisesRegex(RuntimeError, errMsg):
             op(a, dim=(0, 1, 0))
 
-        with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
+        with self.assertRaisesRegex(RuntimeError, errMsg):
             op(a, dim=(2, -1))
 
         with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index d1a248d..7ce26ae 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -21585,7 +21585,7 @@
 unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)]
 binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)]
 binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo))
-spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
+spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)]
 sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
 sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
 sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse]
diff --git a/torch/testing/_internal/opinfo/definitions/fft.py b/torch/testing/_internal/opinfo/definitions/fft.py
index dc8f7ae..48b7c7c 100644
--- a/torch/testing/_internal/opinfo/definitions/fft.py
+++ b/torch/testing/_internal/opinfo/definitions/fft.py
@@ -650,34 +650,110 @@
     SpectralFuncPythonRefInfo(
         "_refs.fft.fft",
         torch_opinfo_name="fft.fft",
+        skips=(
+            # _refs.fft.* functions have inconsistent behavior for empty tensors
+            # https://github.com/pytorch/pytorch/issues/105986
+            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ifft",
         torch_opinfo_name="fft.ifft",
+        skips=(
+            # _refs.fft.* functions have inconsistent behavior for empty tensors
+            # https://github.com/pytorch/pytorch/issues/105986
+            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.rfft",
         torch_opinfo_name="fft.rfft",
+        skips=(
+            # _refs.fft.* functions have inconsistent behavior for empty tensors
+            # https://github.com/pytorch/pytorch/issues/105986
+            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.irfft",
         torch_opinfo_name="fft.irfft",
+        skips=(
+            # _refs.fft.* functions have inconsistent behavior for empty tensors
+            # https://github.com/pytorch/pytorch/issues/105986
+            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+            # TODO: internally promoted to complex64 so not rejected
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fft_half_and_bfloat16_errors",
+                dtypes=[torch.bfloat16],
+            ),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.hfft",
         torch_opinfo_name="fft.hfft",
+        skips=(
+            # _refs.fft.* functions have inconsistent behavior for empty tensors
+            # https://github.com/pytorch/pytorch/issues/105986
+            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+            # FIXME: https://github.com/pytorch/pytorch/issues/108204
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fft_half_and_bfloat16_errors",
+                dtypes=[torch.bfloat16],
+            ),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ihfft",
         torch_opinfo_name="fft.ihfft",
+        skips=(
+            # _refs.fft.* functions have inconsistent behavior for empty tensors
+            # https://github.com/pytorch/pytorch/issues/105986
+            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.fftn",
         torch_opinfo_name="fft.fftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+        skips=(
+            # FIXME: https://github.com/pytorch/pytorch/issues/108204
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fft_half_and_bfloat16_errors",
+                dtypes=[torch.bfloat16],
+            ),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ifftn",
         torch_opinfo_name="fft.ifftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+        skips=(
+            # FIXME: https://github.com/pytorch/pytorch/issues/108204
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fft_half_and_bfloat16_errors",
+                dtypes=[torch.bfloat16],
+            ),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.rfftn",
@@ -686,14 +762,67 @@
     SpectralFuncPythonRefInfo(
         "_refs.fft.irfftn",
         torch_opinfo_name="fft.irfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+        skips=(
+            # FIXME: https://github.com/pytorch/pytorch/issues/108204
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fft_half_and_bfloat16_errors",
+                dtypes=[torch.bfloat16],
+            ),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.hfftn",
         torch_opinfo_name="fft.hfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+        skips=(
+            # FIXME: https://github.com/pytorch/pytorch/issues/108204
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fft_half_and_bfloat16_errors",
+                dtypes=[torch.bfloat16],
+            ),
+            # FIXME: https://github.com/pytorch/pytorch/issues/108205
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fftn_invalid",
+            ),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ihfftn",
         torch_opinfo_name="fft.ihfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+        skips=(
+            # FIXME: https://github.com/pytorch/pytorch/issues/108205
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFFT",
+                "test_fftn_invalid",
+            ),
+        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.fft2",
@@ -702,6 +831,13 @@
     SpectralFuncPythonRefInfo(
         "_refs.fft.ifft2",
         torch_opinfo_name="fft.ifft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.rfft2",
@@ -710,14 +846,35 @@
     SpectralFuncPythonRefInfo(
         "_refs.fft.irfft2",
         torch_opinfo_name="fft.irfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.hfft2",
         torch_opinfo_name="fft.hfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ihfft2",
         torch_opinfo_name="fft.ihfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
     ),
     PythonRefInfo(
         "_refs.fft.fftshift",