[CUDA][cuBLAS] Remove explicit cuBLAS workspace allocation for CUDA 12.2+ (#113994)

cuBLAS should be using `cudaMallocAsync` internally on CUDA 12.2+, so the explicit workspace allocation that was added to keep memory usage from growing across multiple graph captures is no longer needed there.

CC @ptrblck @malfet

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113994
Approved by: https://github.com/ezyang, https://github.com/malfet
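
For context, a minimal sketch (not the ATen code) of the mechanism this change gates behind `CUDA_VERSION < 12200`: before 12.2, the caller binds each cuBLAS handle to a caller-owned workspace via `cublasSetWorkspace`, one per (handle, stream) pair, so that repeated graph captures do not each force cuBLAS to grow a fresh internal allocation. The `kWorkspaceBytes` size and the `attachExplicitWorkspace` helper below are hypothetical.

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>

#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12200
// Hypothetical fixed size, for illustration only.
constexpr size_t kWorkspaceBytes = 4096 * 1024;

void attachExplicitWorkspace(cublasHandle_t handle, cudaStream_t stream) {
  // Bind the handle to the stream, then hand cuBLAS a caller-owned
  // buffer so it does not allocate a workspace of its own per capture.
  cublasSetStream(handle, stream);
  void* workspace = nullptr;
  cudaMalloc(&workspace, kWorkspaceBytes);
  cublasSetWorkspace(handle, workspace, kWorkspaceBytes);
}
#endif
// On CUDA 12.2+ none of this is needed: cuBLAS allocates through
// cudaMallocAsync itself, so captures no longer multiply workspace memory.
```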
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index e495dff..dae61a4 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -40,7 +40,9 @@
 } // namespace
 
 void clearCublasWorkspaces() {
-  cublas_handle_stream_to_workspace().clear();
+#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12200
+  cublas_handle_stream_to_workspace().clear();
+#endif
 }
 
 size_t parseChosenWorkspaceSize() {
@@ -105,8 +107,10 @@
   auto handle = myPoolWindow->reserve(device);
   auto stream = c10::cuda::getCurrentCUDAStream();
   TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
-#if !defined(USE_ROCM)
-  // cublasSetWorkspace not available on CUDA 10.2
+#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12200
+  // Prior to CUDA 12.2, explicitly allocate the cuBLAS workspace to prevent
+  // memory usage from growing across graph captures; on 12.2+ cuBLAS manages
+  // this itself. Original fix: https://github.com/pytorch/pytorch/pull/83461
   cudaStream_t _stream = stream;
   auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
   auto workspace_it = cublas_handle_stream_to_workspace().find(key);
diff --git a/test/test_cuda.py b/test/test_cuda.py
index e4bb314..6bd2d41 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -29,7 +29,8 @@
     NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_WINDOWS, \
     slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_CUDA, TEST_CUDA_GRAPH, TEST_WITH_ROCM, TEST_NUMPY, \
     get_cycles_per_ms, parametrize, instantiate_parametrized_tests, subtest, IS_JETSON, gcIfJetson, NoTest, IS_LINUX
-from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_MULTIGPU, _create_scaling_case, _create_scaling_models_optimizers
+from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_MULTIGPU, \
+    _create_scaling_case, _create_scaling_models_optimizers, _get_torch_cuda_version
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 from torch.utils.viz._cycles import observe_tensor_cycles
 
@@ -296,6 +297,7 @@
         self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
 
     @unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async")
+    @unittest.skipIf(_get_torch_cuda_version() >= (12, 2), "explicit workspace allocation is removed on CUDA 12.2+")
     def test_cublas_workspace_explicit_allocation(self):
         a = torch.randn(7, 7, device='cuda', requires_grad=False)
         default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
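
For reference, the `:4096:2:16:8` notation in `default_workspace_size` above decodes as alternating size-in-KiB/count pairs (two 4096 KiB workspaces plus eight 16 KiB ones). A hedged sketch of that decoding follows; `decodeWorkspaceConfig` is an illustrative stand-in, not the actual `parseChosenWorkspaceSize` parser.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Illustration only: ":<size_kib>:<count>..." -> total bytes.
size_t decodeWorkspaceConfig(const std::string& cfg) {
  std::vector<size_t> fields;
  std::stringstream ss(cfg);
  std::string tok;
  while (std::getline(ss, tok, ':')) {
    if (!tok.empty()) {
      fields.push_back(std::stoul(tok));
    }
  }
  size_t total = 0;
  for (size_t i = 0; i + 1 < fields.size(); i += 2) {
    total += fields[i] * 1024 * fields[i + 1];  // sizes are given in KiB
  }
  return total;
}
// decodeWorkspaceConfig(":4096:2:16:8") == 4096 * 2 * 1024 + 16 * 8 * 1024
```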