[CUBLAS][CUDA GRAPHS] Explicitly set the workspace for cuBLAS handles (#83461)

We're seeing an issue where repeatedly capturing CUDA graphs incurs increasing memory usage, because cuBLAS internally allocates a new workspace for each capture even when the same handle is used:
https://gist.github.com/tomconerlyanth/a20c04a4a46a0f6e9ce18f5280729b36
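
The failure mode boils down to the pattern below. This is an illustrative sketch (sizes and iteration count are arbitrary), not the exact gist code:

```python
import torch

# Repeatedly capture a graph that runs a matmul through the same cuBLAS
# handle and watch device memory grow before this fix.
a = torch.rand(1024, 512, device="cuda")
b = torch.rand(512, 64, device="cuda")
torch.mm(a, b)  # warm up so the cuBLAS handle exists before capture

free_before, _ = torch.cuda.mem_get_info()
for _ in range(100):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        torch.mm(a, b)
    g.replay()
free_after, _ = torch.cuda.mem_get_info()

# Without the fix, each capture can leave a fresh cuBLAS workspace behind,
# so used memory keeps climbing with the number of captures.
print("used delta (MiB):", (free_before - free_after) / 2**20)
```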

This PR works around the issue by parsing the `CUBLAS_WORKSPACE_CONFIG` environment variable and explicitly allocating a workspace for each cuBLAS handle/stream pair via `cublasSetWorkspace`, so the same workspace is reused across graph captures.
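
For reference, the workspace size comes from the same `:SIZE:COUNT` format that cuBLAS documents for `CUBLAS_WORKSPACE_CONFIG`. The sketch below is a rough Python rendering of the parsing logic added to `CublasHandlePool.cpp` in this PR (each `:SIZE:COUNT` chunk contributes SIZE KiB, with a 4096 KiB fallback); it is only meant to illustrate the behavior:

```python
import os
import re


def chosen_cublas_workspace_size() -> int:
    """Rough Python equivalent of parseChosenWorkspaceSize() in this PR:
    each :SIZE:COUNT chunk contributes SIZE KiB; the default is 4096 KiB."""
    val = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
    if not val:
        return 4096 * 1024  # matches the implicit :4096:8 default
    matches = re.findall(r":([0-9]+):([0-9]+)", val)
    if not matches:
        # mirrors the TORCH_WARN fallback in the C++ code
        return 4096 * 1024
    return sum(int(size) * 1024 for size, _count in matches)


# e.g. CUBLAS_WORKSPACE_CONFIG=":4096:8" -> 4194304 bytes (4 MiB) per
# handle/stream pair, which is where the +4 MiB in the updated tests comes from.
```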

CC @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83461
Approved by: https://github.com/ngimel
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 08fa4e4..6c5db20 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -1,6 +1,10 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/DeviceThreadHandles.h>
 
+#include <c10/cuda/CUDACachingAllocator.h>
+
+#include <regex>
+
 namespace at { namespace cuda {
 namespace {
 
@@ -25,6 +29,42 @@
 
 } // namespace
 
+static std::map<std::tuple<void *, void *>, at::DataPtr> handle_stream_to_workspace;
+
+size_t parseChosenWorkspaceSize() {
+  const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
+  if (val) {
+    size_t total_size = 0;
+    const std::string config(val);
+    std::regex exp(":([0-9]+):([0-9]+)");
+    std::sregex_iterator next(config.begin(), config.end(), exp);
+    std::sregex_iterator end;
+    if (next == end) {
+      TORCH_WARN("Could not parse CUBLAS_WORKSPACE_CONFIG, using default workspace size of 4096 KiB.");
+      return 4096 * 1024;
+    }
+    while (next != end) {
+      std::smatch match = *next;
+      TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
+      size_t curr_size = (size_t) std::stoi(match.str(1));
+      total_size += curr_size * 1024;
+      next++;
+    }
+    return total_size;
+  } else /* :4096:8 */ {
+    return 4096 * 1024;
+  }
+}
+
+size_t getChosenWorkspaceSize() {
+  static size_t pool_size = parseChosenWorkspaceSize();
+  return pool_size;
+}
+
+at::DataPtr getNewWorkspace() {
+  return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize());
+}
+
 cublasHandle_t getCurrentCUDABlasHandle() {
   int device;
   AT_CUDA_CHECK(cudaGetDevice(&device));
@@ -47,6 +87,15 @@
   auto handle = myPoolWindow->reserve(device);
   auto stream = c10::cuda::getCurrentCUDAStream();
   TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
+#ifndef USE_ROCM
+  cudaStream_t _stream = stream;
+  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
+  if (handle_stream_to_workspace.find(key) == handle_stream_to_workspace.end()) {
+      auto workspace_ptr = getNewWorkspace();
+      handle_stream_to_workspace[key] = std::move(workspace_ptr);
+  }
+  TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, handle_stream_to_workspace[key].get(), getChosenWorkspaceSize()));
+#endif
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
   // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
   // FP32 data type calculations based on the value of the allow_tf32 flag.
diff --git a/test/distributed/fsdp/test_fsdp_memory.py b/test/distributed/fsdp/test_fsdp_memory.py
index b463216..e05dbd9 100644
--- a/test/distributed/fsdp/test_fsdp_memory.py
+++ b/test/distributed/fsdp/test_fsdp_memory.py
@@ -180,43 +180,45 @@
         expected = {}
 
         for iteration in range(iterations):
+            # another 4 MiB per thread/stream/cuBLAS handle expected
+            # after initial cuBLAS workspace allocation change #83461
             if iteration == 0:
                 # sharded model size + 1MB temp memory
                 expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
                 # it is hard to calculate this memory size, get it from printed memory usage
                 if ckpt == "ckpt":
-                    expected[f"iter {iteration}: after fwd"] = 51
-                    expected[f"iter {iteration}: after loss"] = 51
+                    expected[f"iter {iteration}: after fwd"] = 51 + 4
+                    expected[f"iter {iteration}: after loss"] = 51 + 4
                 else:
-                    expected[f"iter {iteration}: after fwd"] = 340
-                    expected[f"iter {iteration}: after loss"] = 340
+                    expected[f"iter {iteration}: after fwd"] = 340 + 4
+                    expected[f"iter {iteration}: after loss"] = 340 + 4
                 # sharded model size + sharded grad size + 1M temp memory
-                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
+                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1 + 4 + 4
             else:
                 # after optimizer step in the first iteraiton, memory usage increased by
                 # sharded_model_size_mb becasue of increased optimizer states memory usage
-                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
+                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1 + 4 + 4
                 if ckpt == "ckpt":
                     expected[f"iter {iteration}: after fwd"] = (
-                        51 + sharded_model_size_mb
+                        51 + sharded_model_size_mb + 4 + 4
                     )
                     expected[f"iter {iteration}: after loss"] = (
-                        51 + sharded_model_size_mb
+                        51 + sharded_model_size_mb + 4 + 4
                     )
                 else:
                     expected[f"iter {iteration}: after fwd"] = (
-                        340 + sharded_model_size_mb
+                        340 + sharded_model_size_mb + 4 + 4
                     )
                     expected[f"iter {iteration}: after loss"] = (
-                        340 + sharded_model_size_mb
+                        340 + sharded_model_size_mb + 4 + 4
                     )
-                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1
+                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1 + 4 + 4
 
             # sharded model size + sharded grad size + optimizer states + 1M temp memory
-            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
+            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1 + 4 + 4
             # grad memory is claimed after setting grad = None
             # sharded model size + optimizer states + 1M temp memory
-            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1
+            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1 + 4 + 4
 
         # Get the fsdp and checkpoint flags.
         with_ckpt = ckpt == "ckpt"
diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py
index 0072573..8d31f1a 100644
--- a/test/distributed/pipeline/sync/test_balance.py
+++ b/test/distributed/pipeline/sync/test_balance.py
@@ -99,13 +99,13 @@
 
 @skip_if_no_cuda
 def test_balance_by_size_param():
-    model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)])
-    sample = torch.rand(7, 1)
+    model = nn.Sequential(*[nn.Linear((i + 1) * 1000, (i + 2) * 1000) for i in range(6)])
+    sample = torch.rand(7, 1000)
     balance = balance_by_size(2, model, sample, param_scale=100)
     assert balance == [4, 2]
 
-    model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))])
-    sample = torch.rand(1, 7)
+    model = nn.Sequential(*[nn.Linear((i + 2) * 1000, (i + 1) * 1000) for i in reversed(range(6))])
+    sample = torch.rand(1, 7000)
     balance = balance_by_size(2, model, sample, param_scale=100)
     assert balance == [2, 4]
 
@@ -124,15 +124,15 @@
             return x
 
     model = nn.Sequential(
-        Tradeoff(param_size=1, latent_size=6),
-        Tradeoff(param_size=2, latent_size=5),
-        Tradeoff(param_size=3, latent_size=4),
-        Tradeoff(param_size=4, latent_size=3),
-        Tradeoff(param_size=5, latent_size=2),
-        Tradeoff(param_size=6, latent_size=1),
+        Tradeoff(param_size=1000, latent_size=6),
+        Tradeoff(param_size=2000, latent_size=5),
+        Tradeoff(param_size=3000, latent_size=4),
+        Tradeoff(param_size=4000, latent_size=3),
+        Tradeoff(param_size=5000, latent_size=2),
+        Tradeoff(param_size=6000, latent_size=1),
     )
 
-    sample = torch.rand(1, requires_grad=True)
+    sample = torch.rand(1000, requires_grad=True)
 
     balance = balance_by_size(2, model, sample, param_scale=0)
     assert balance == [2, 4]
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 601089c..e04f0dc 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -3154,6 +3154,32 @@
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+    @skipCUDAMemoryLeakCheckIf(True)  # This test may incur an expected allocation for a cuBLAS workspace
+    def test_repeat_graph_capture_cublas_workspace_memory(self):
+        (x, y, z) = 1024, 512, 64
+        a = torch.rand((x, y), device='cuda')
+        b = torch.rand((y, z), device='cuda')
+
+        # warmup
+        torch.mm(a, b)
+
+        free_bytes_before, total_bytes = torch.cuda.mem_get_info()
+        used_gb_before = (total_bytes - free_bytes_before) / 1e9
+
+        for i in range(100):
+            torch_graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(torch_graph):
+                torch.mm(a, b)
+            torch_graph.replay()
+
+        free_bytes_after, _ = torch.cuda.mem_get_info()
+        used_gb_after = (total_bytes - free_bytes_after) / 1e9
+
+        self.assertFalse(used_gb_before + 0.1 < used_gb_after)
+
+    @unittest.skipIf((not TEST_CUDA) or
+                     TEST_WITH_ROCM or
+                     int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_rng_functional(self):
         ops_with_kwargs = ((torch.nn.functional.dropout, {"p": 0.1}),
                            (torch.nn.functional.rrelu, {"training": True}),)