[CUDA][cuSPARSE] Work around IMA in cuSPARSE ALG1 on SM 8.9 devices (#119610)
Originally surfaced from the discuss forum:
https://discuss.pytorch.org/t/issue-with-torch-sparse-mm-while-running-on-gpu/188669
This has been forwarded to cuSPARSE but we have not yet received a commitment on their end to fix this issue directly.
CC @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119610
Approved by: https://github.com/jeffdaily, https://github.com/jcaip
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp
index 22efdbd..0251857 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp
@@ -154,6 +154,13 @@
auto handle = at::cuda::getCurrentCUDASparseHandle();
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ // ALG1 is broken on SM89 as of CUDA 11.8+
+#if !defined(USE_ROCM)
+ auto default_alg = prop->major == 8 && prop->minor == 9 ? CUSPARSE_SPMM_CSR_ALG2 : CUSPARSE_SPMM_CSR_ALG1;
+#else
+ auto default_alg = CUSPARSE_SPMM_CSR_ALG1;
+#endif
// cusparseSpMM_bufferSize returns the bufferSize that can be used by cusparseSpMM
size_t bufferSize;
@@ -164,7 +171,7 @@
beta,
descC,
cusparse_value_type, /* data type in which the computation is executed */
- CUSPARSE_SPMM_CSR_ALG1, /* default computing algorithm for CSR sparse matrix format */
+ default_alg, /* default computing algorithm for CSR sparse matrix format */
&bufferSize /* output */
));
@@ -178,7 +185,7 @@
beta,
descC,
cusparse_value_type, /* data type in which the computation is executed */
- CUSPARSE_SPMM_CSR_ALG1, /* default computing algorithm for CSR sparse matrix format */
+ default_alg, /* default computing algorithm for CSR sparse matrix format */
dataPtr.get() /* external buffer */
));