Workaround for cuBLAS bug causing #45724 (#46001)

Summary:
cublasGemmStridedBatchedEx can return incorrect results for large batch counts on compute capability 7.x devices (see the linked issue). This PR routes the fp16 and bf16 strided-batched GEMM calls in THCBlas through a thin wrapper, cublasGemmStridedBatchedExFix, which on those devices splits the call into chunks of at most 63 * 1024 batches; all other devices, and builds where the version check indicates a fixed cuBLAS, call cublasGemmStridedBatchedEx directly. A regression test with a batch dimension of 65537 is added.

Fixes https://github.com/pytorch/pytorch/issues/45724
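
For reference, a minimal repro sketch of the failure mode on an affected GPU, adapted from the regression test added below (shapes and dtype are taken from that test; the tolerances are illustrative and not part of this PR):

    import torch

    # Batched fp16 matmul with more than 65535 batches, run on a
    # compute capability 7.x GPU (e.g. V100).
    a = torch.rand(65537, 22, 64, device='cuda', dtype=torch.half)
    b = torch.rand(65537, 64, 22, device='cuda', dtype=torch.half)

    # fp32 CPU reference.
    expected = torch.matmul(a.cpu().float(), b.cpu().float())
    actual = torch.matmul(a, b).cpu().float()

    # Without the workaround, the GPU result diverges from the reference.
    print(torch.allclose(actual, expected, rtol=1e-2, atol=1e-2))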

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46001

Reviewed By: mruberry

Differential Revision: D24184058

Pulled By: ngimel

fbshipit-source-id: 7d2bab3206ddbc10a7cae3efd9b5e253f38400a9
diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu
index 859d904..3f16eec 100644
--- a/aten/src/THC/THCBlas.cu
+++ b/aten/src/THC/THCBlas.cu
@@ -133,6 +133,62 @@
   at::cuda::blas::gemm<double>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
 }
 
+#ifndef __HIP_PLATFORM_HCC__
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
+#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
+#else
+// Workaround for https://github.com/pytorch/pytorch/issues/45724
+cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
+  cublasOperation_t transa,
+  cublasOperation_t transb,
+  int m,
+  int n,
+  int k,
+  const void    *alpha,
+  const void     *A,
+  cudaDataType Atype,
+  int lda,
+  long long int strideA,
+  const void     *B,
+  cudaDataType Btype,
+  int ldb,
+  long long int strideB,
+  const void    *beta,
+  void           *C,
+  cudaDataType Ctype,
+  int ldc,
+  long long int strideC,
+  int64_t batchCount,
+  cudaDataType computeType,
+  cublasGemmAlgo_t algo)
+{
+  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+  // The workaround is only applied on compute capability 7.x devices;
+  // everything else calls cuBLAS directly.
+  if (prop->major != 7) {
+    return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo);
+  }
+  cublasStatus_t result;
+  // Process the batch in chunks of at most 63 * 1024 to stay below the
+  // batch count that triggers the cuBLAS bug referenced above.
+  constexpr int64_t split = 63 * 1024;
+  for(int64_t i = 0; i < batchCount; i += split) {
+    int64_t count = std::min<int64_t>(split, batchCount - i);
+    // Strides are in elements; within this file the wrapper is only used for
+    // 2-byte element types (half / bfloat16), hence the factor of 2 below.
+    result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
+      (char *)A + i * strideA * 2, Atype, lda, strideA,
+      (char *)B + i * strideB * 2, Btype, ldb, strideB,
+      beta,
+      (char *)C + i * strideC * 2, Ctype, ldc, strideC,
+      (int)count, computeType, algo);
+    THCublasCheck(result);
+  }
+  return result;
+}
+#endif
+#endif
+
 void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
                              at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB,
                              at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount)
@@ -167,7 +217,7 @@
   // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 #endif  // CUDA_VERSION < 11000 
-  THCublasCheck(cublasGemmStridedBatchedEx(handle,
+  THCublasCheck(cublasGemmStridedBatchedExFix(handle,
                                    opa, opb, (int)m, (int)n, (int)k,
                                    (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
                                    b, CUDA_R_16F, (int)ldb, strideB,
@@ -207,7 +257,7 @@
   if (prop->major < 8) {
     TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
   }
-  THCublasCheck(cublasGemmStridedBatchedEx(handle,
+  THCublasCheck(cublasGemmStridedBatchedExFix(handle,
                                    opa, opb, (int)m, (int)n, (int)k,
                                    (void*)&fAlpha, a, CUDA_R_16BF, (int)lda, strideA,
                                    b, CUDA_R_16BF, (int)ldb, strideB,
diff --git a/test/test_torch.py b/test/test_torch.py
index 312943d..39f2f92 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -16814,6 +16814,16 @@
                     m2 = torch.randn(k, m, device=device).to(dtype)
                     self._test_addmm_addmv(torch.addmm, M, m1, m2)
 
+    @onlyCUDA
+    def test_matmul_45724(self, device):
+        # https://github.com/pytorch/pytorch/issues/45724
+        a = torch.rand(65537, 22, 64).cuda().half()
+        b = torch.rand(65537, 64, 22).cuda().half()
+        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device='cuda')
+        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
+        torch.matmul(a, b, out=c)
+        self.assertEqual(c, cpu_result)
+
     def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
         def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
             y_np = y.cpu().numpy()