[ROCm] Enable some sparse tests on ROCm (#77877)
Enabling:
test_sampled_addmm_errors_cuda_complex128
test_sampled_addmm_errors_cuda_complex64
test_sampled_addmm_errors_cuda_float32
test_sampled_addmm_errors_cuda_float64
test_sparse_add_cuda_complex128
test_sparse_add_cuda_complex64
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77877
Approved by: https://github.com/pruthvistony, https://github.com/malfet
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index ff8b432..367d7df 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1538,9 +1538,6 @@
def test_sparse_add(self, device, dtype):
def run_test(m, n, index_dtype):
- if TEST_WITH_ROCM and dtype.is_complex:
- self.skipTest("ROCm doesn't work with complex dtype correctly.")
-
alpha = random.random()
nnz1 = random.randint(0, m * n)
nnz2 = random.randint(0, m * n)
@@ -1744,10 +1741,9 @@
b = make_tensor((k, n), dtype=dtype, device=device)
run_test(c, a, b)
- @skipCUDAIfRocm
@onlyCUDA
@skipCUDAIf(
- not _check_cusparse_sddmm_available(),
+ not (TEST_WITH_ROCM or _check_cusparse_sddmm_available()),
"cuSparse Generic API SDDMM is not available"
)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)