Workaround for cuBLAS bug in issue #45724 (#46001)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/45724
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46001
Reviewed By: mruberry
Differential Revision: D24184058
Pulled By: ngimel
fbshipit-source-id: 7d2bab3206ddbc10a7cae3efd9b5e253f38400a9
diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu
index 859d904..3f16eec 100644
--- a/aten/src/THC/THCBlas.cu
+++ b/aten/src/THC/THCBlas.cu
@@ -133,6 +133,56 @@
at::cuda::blas::gemm<double>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
+#ifndef __HIP_PLATFORM_HCC__
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
+#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
+#else
+// Workaround for https://github.com/pytorch/pytorch/issues/45724: split very large batchCounts into chunks of at most 63*1024 on compute-capability-7.x GPUs, where cublasGemmStridedBatchedEx misbehaves for huge batch counts (fixed in cuBLAS shipped with CUDA 11.2, see guard above).
+cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m,
+ int n,
+ int k,
+ const void *alpha,
+ const void *A,
+ cudaDataType Atype,
+ int lda,
+ long long int strideA,
+ const void *B,
+ cudaDataType Btype,
+ int ldb,
+ long long int strideB,
+ const void *beta,
+ void *C,
+ cudaDataType Ctype,
+ int ldc,
+ long long int strideC,
+ int64_t batchCount,
+ cudaDataType computeType,
+ cublasGemmAlgo_t algo)
+{
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ if (prop->major != 7) { // bug only reproduces on compute capability 7.x; other archs take the direct path
+ return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, (int)batchCount, computeType, algo);
+ }
+ cublasStatus_t result = CUBLAS_STATUS_SUCCESS; // initialized so batchCount == 0 returns success instead of an uninitialized value
+ constexpr int64_t split = 63 * 1024;
+ for(int64_t i = 0; i < batchCount; i += split) {
+ int64_t count = std::min<int64_t>(split, batchCount - i);
+ result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
+ (char *)A + i * strideA * 2, Atype, lda, strideA, // * 2 == element size in bytes; callers pass CUDA_R_16F/CUDA_R_16BF (2-byte) types only
+ (char *)B + i * strideB * 2, Btype, ldb, strideB,
+ beta,
+ (char *)C + i * strideC * 2, Ctype, ldc, strideC,
+ (int)count, computeType, algo),
+ THCublasCheck(result);
+ }
+ return result;
+}
+#endif
+#endif
+
void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB,
at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount)
@@ -167,7 +217,7 @@
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
- THCublasCheck(cublasGemmStridedBatchedEx(handle,
+ THCublasCheck(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
b, CUDA_R_16F, (int)ldb, strideB,
@@ -207,7 +257,7 @@
if (prop->major < 8) {
TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
}
- THCublasCheck(cublasGemmStridedBatchedEx(handle,
+ THCublasCheck(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16BF, (int)lda, strideA,
b, CUDA_R_16BF, (int)ldb, strideB,
diff --git a/test/test_torch.py b/test/test_torch.py
index 312943d..39f2f92 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -16814,6 +16814,16 @@
m2 = torch.randn(k, m, device=device).to(dtype)
self._test_addmm_addmv(torch.addmm, M, m1, m2)
+ @onlyCUDA
+ def test_matmul_45724(self, device):
+ # https://github.com/pytorch/pytorch/issues/45724
+ a = torch.rand(65537, 22, 64).cuda().half()
+ b = torch.rand(65537, 64, 22).cuda().half()
+ c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device='cuda')
+ cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
+ torch.matmul(a, b, out=c)
+ self.assertEqual(c, cpu_result)
+
def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
y_np = y.cpu().numpy()