[CUBLAS][CUDA GRAPHS] Explicitly set the workspace for cuBLAS handles (#83461)
We're seeing an issue where repeatedly capturing CUDA graphs incurs steadily increasing memory usage, because cuBLAS internally allocates a new workspace for each capture even when the same handle is reused:
https://gist.github.com/tomconerlyanth/a20c04a4a46a0f6e9ce18f5280729b36
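A minimal repro sketch in the spirit of the linked gist (assuming a CUDA build; the sizes and loop count are illustrative):

```python
import torch

a = torch.rand(1024, 512, device="cuda")
b = torch.rand(512, 64, device="cuda")
torch.mm(a, b)  # warm up so the cuBLAS handle exists before capture

free_before, total = torch.cuda.mem_get_info()
for _ in range(100):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        torch.mm(a, b)  # each capture used to trigger a fresh cuBLAS workspace allocation
    g.replay()
free_after, _ = torch.cuda.mem_get_info()

# Without the fix, used device memory grows with the number of captures.
print(f"used before: {(total - free_before) / 1e9:.2f} GB, "
      f"used after:  {(total - free_after) / 1e9:.2f} GB")
```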
This PR works around the issue by parsing the `CUBLAS_WORKSPACE_CONFIG` environment variable, allocating a workspace for each cuBLAS handle/stream pair explicitly, and setting it via `cublasSetWorkspace`, so repeated captures reuse the same allocation.
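For reference, the size computation added in `CublasHandlePool.cpp` (`parseChosenWorkspaceSize`) behaves roughly like the Python sketch below (illustrative helper, not part of the PR): each `:SIZE:COUNT` entry contributes `SIZE` KiB, the `COUNT` field is ignored, and a missing or unparseable value falls back to 4096 KiB (4 MiB).

```python
import os
import re

def chosen_cublas_workspace_size() -> int:
    """Rough Python equivalent of parseChosenWorkspaceSize() (returns bytes)."""
    val = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
    if not val:
        return 4096 * 1024  # default: 4 MiB
    matches = re.findall(r":(\d+):(\d+)", val)
    if not matches:
        # the C++ code emits a TORCH_WARN and falls back to the default here
        return 4096 * 1024
    # each :SIZE:COUNT entry contributes SIZE KiB; COUNT is not used for sizing
    return sum(int(size) * 1024 for size, _count in matches)

# e.g. ":4096:8" -> 4 MiB total, ":16:8,:2048:2" -> (16 + 2048) KiB total
```

The resulting workspace is allocated through the CUDA caching allocator and cached per `(handle, stream)` pair, so subsequent captures on the same handle and stream reuse it instead of letting cuBLAS allocate a new one during capture.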
CC @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83461
Approved by: https://github.com/ngimel
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 08fa4e4..6c5db20 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -1,6 +1,10 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>
+#include <c10/cuda/CUDACachingAllocator.h>
+
+#include <regex>
+
namespace at { namespace cuda {
namespace {
@@ -25,6 +29,42 @@
} // namespace
+static std::map<std::tuple<void *, void *>, at::DataPtr> handle_stream_to_workspace;
+
+size_t parseChosenWorkspaceSize() {
+ const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
+ if (val) {
+ size_t total_size = 0;
+ const std::string config(val);
+ std::regex exp(":([0-9]+):([0-9]+)");
+ std::sregex_iterator next(config.begin(), config.end(), exp);
+ std::sregex_iterator end;
+ if (next == end) {
+ TORCH_WARN("Could not parse CUBLAS_WORKSPACE_CONFIG, using default workspace size of 4096.");
+ return 4096 * 1024;
+ }
+ while (next != end) {
+ std::smatch match = *next;
+ TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
+ size_t curr_size = (size_t) std::stoi(match.str(1));
+ total_size += curr_size * 1024;
+ next++;
+ }
+ return total_size;
+ } else /* default: 4096 KiB (as with CUBLAS_WORKSPACE_CONFIG=":4096:8") */ {
+ return 4096 * 1024;
+ }
+}
+
+size_t getChosenWorkspaceSize() {
+ static size_t pool_size = parseChosenWorkspaceSize();
+ return pool_size;
+}
+
+at::DataPtr getNewWorkspace() {
+ return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize());
+}
+
cublasHandle_t getCurrentCUDABlasHandle() {
int device;
AT_CUDA_CHECK(cudaGetDevice(&device));
@@ -47,6 +87,15 @@
auto handle = myPoolWindow->reserve(device);
auto stream = c10::cuda::getCurrentCUDAStream();
TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
+#ifndef USE_ROCM
+ cudaStream_t _stream = stream;
+ auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
+ if (handle_stream_to_workspace.find(key) == handle_stream_to_workspace.end()) {
+ auto workspace_ptr = getNewWorkspace();
+ handle_stream_to_workspace[key] = std::move(workspace_ptr);
+ }
+ TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, handle_stream_to_workspace[key].get(), getChosenWorkspaceSize()));
+#endif
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
diff --git a/test/distributed/fsdp/test_fsdp_memory.py b/test/distributed/fsdp/test_fsdp_memory.py
index b463216..e05dbd9 100644
--- a/test/distributed/fsdp/test_fsdp_memory.py
+++ b/test/distributed/fsdp/test_fsdp_memory.py
@@ -180,43 +180,45 @@
expected = {}
for iteration in range(iterations):
+ # another 4 MiB per thread/stream/cuBLAS handle is expected
+ # after the cuBLAS workspace allocation change in #83461
if iteration == 0:
# sharded model size + 1MB temp memory
expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
# it is hard to calculate this memory size, get it from printed memory usage
if ckpt == "ckpt":
- expected[f"iter {iteration}: after fwd"] = 51
- expected[f"iter {iteration}: after loss"] = 51
+ expected[f"iter {iteration}: after fwd"] = 51 + 4
+ expected[f"iter {iteration}: after loss"] = 51 + 4
else:
- expected[f"iter {iteration}: after fwd"] = 340
- expected[f"iter {iteration}: after loss"] = 340
+ expected[f"iter {iteration}: after fwd"] = 340 + 4
+ expected[f"iter {iteration}: after loss"] = 340 + 4
# sharded model size + sharded grad size + 1M temp memory
- expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
+ expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1 + 4 + 4
else:
# after optimizer step in the first iteration, memory usage increased by
# sharded_model_size_mb because of increased optimizer states memory usage
- expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
+ expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1 + 4 + 4
if ckpt == "ckpt":
expected[f"iter {iteration}: after fwd"] = (
- 51 + sharded_model_size_mb
+ 51 + sharded_model_size_mb + 4 + 4
)
expected[f"iter {iteration}: after loss"] = (
- 51 + sharded_model_size_mb
+ 51 + sharded_model_size_mb + 4 + 4
)
else:
expected[f"iter {iteration}: after fwd"] = (
- 340 + sharded_model_size_mb
+ 340 + sharded_model_size_mb + 4 + 4
)
expected[f"iter {iteration}: after loss"] = (
- 340 + sharded_model_size_mb
+ 340 + sharded_model_size_mb + 4 + 4
)
- expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1
+ expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1 + 4 + 4
# sharded model size + sharded grad size + optimizer states + 1M temp memory
- expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
+ expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1 + 4 + 4
# grad memory is claimed after setting grad = None
# sharded model size + optimizer states + 1M temp memory
- expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1
+ expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1 + 4 + 4
# Get the fsdp and checkpoint flags.
with_ckpt = ckpt == "ckpt"
diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py
index 0072573..8d31f1a 100644
--- a/test/distributed/pipeline/sync/test_balance.py
+++ b/test/distributed/pipeline/sync/test_balance.py
@@ -99,13 +99,13 @@
@skip_if_no_cuda
def test_balance_by_size_param():
- model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)])
- sample = torch.rand(7, 1)
+ model = nn.Sequential(*[nn.Linear((i + 1) * 1000, (i + 2) * 1000) for i in range(6)])
+ sample = torch.rand(7, 1000)
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [4, 2]
- model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))])
- sample = torch.rand(1, 7)
+ model = nn.Sequential(*[nn.Linear((i + 2) * 1000, (i + 1) * 1000) for i in reversed(range(6))])
+ sample = torch.rand(1, 7000)
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [2, 4]
@@ -124,15 +124,15 @@
return x
model = nn.Sequential(
- Tradeoff(param_size=1, latent_size=6),
- Tradeoff(param_size=2, latent_size=5),
- Tradeoff(param_size=3, latent_size=4),
- Tradeoff(param_size=4, latent_size=3),
- Tradeoff(param_size=5, latent_size=2),
- Tradeoff(param_size=6, latent_size=1),
+ Tradeoff(param_size=1000, latent_size=6),
+ Tradeoff(param_size=2000, latent_size=5),
+ Tradeoff(param_size=3000, latent_size=4),
+ Tradeoff(param_size=4000, latent_size=3),
+ Tradeoff(param_size=5000, latent_size=2),
+ Tradeoff(param_size=6000, latent_size=1),
)
- sample = torch.rand(1, requires_grad=True)
+ sample = torch.rand(1000, requires_grad=True)
balance = balance_by_size(2, model, sample, param_scale=0)
assert balance == [2, 4]
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 601089c..e04f0dc 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -3154,6 +3154,32 @@
@unittest.skipIf((not TEST_CUDA) or
TEST_WITH_ROCM or
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+ @skipCUDAMemoryLeakCheckIf(True) # This test may incur an expected allocation for a cuBLAS workspace
+ def test_repeat_graph_capture_cublas_workspace_memory(self):
+ (x, y, z) = 1024, 512, 64
+ a = torch.rand((x, y), device='cuda')
+ b = torch.rand((y, z), device='cuda')
+
+ # warmup
+ torch.mm(a, b)
+
+ free_bytes_before, total_bytes = torch.cuda.mem_get_info()
+ used_gb_before = (total_bytes - free_bytes_before) / 1e9
+
+ for i in range(100):
+ torch_graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(torch_graph):
+ torch.mm(a, b)
+ torch_graph.replay()
+
+ free_bytes_after, _ = torch.cuda.mem_get_info()
+ used_gb_after = (total_bytes - free_bytes_after) / 1e9
+
+ self.assertFalse(used_gb_before + 0.1 < used_gb_after)
+
+ @unittest.skipIf((not TEST_CUDA) or
+ TEST_WITH_ROCM or
+ int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
def test_graph_rng_functional(self):
ops_with_kwargs = ((torch.nn.functional.dropout, {"p": 0.1}),
(torch.nn.functional.rrelu, {"training": True}),)