Expose Stream Recording APIs in Python (#96384)
Differential Revision: [D43999891](https://our.internmc.facebook.com/intern/diff/D43999891)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96384
Approved by: https://github.com/zdevito
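
These bindings let Python code route new allocations made on the current stream into an existing CUDA graph memory pool, then detach and release that pool. A minimal sketch of the intended flow, mirroring the new test below (the capture step shown here with torch.cuda.graph is an illustrative stand-in for the test's cudagraphify helper, not part of this diff):

    import torch

    pool = torch.cuda.graph_pool_handle()

    # Capture a graph whose allocations live in `pool`.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        out = torch.rand([4], device="cuda")

    device = out.device.index
    torch.cuda.synchronize()

    # Routing is per (device, stream): tag a side stream to the pool,
    # allocate, then detach and release the pool.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        torch._C._cuda_beginAllocateCurrentStreamToPool(device, pool)
        try:
            # Allocations on this stream are now served from `pool`.
            tmp = torch.rand([4], device="cuda")
        finally:
            torch._C._cuda_endAllocateCurrentStreamToPool(device)
            torch._C._cuda_releasePool(device, pool)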
diff --git a/test/test_cuda.py b/test/test_cuda.py
index d097da5..00190c0 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -5537,6 +5537,81 @@
         self.checkFunction(m, [inp])
+    def test_allocate_in_thread_to_pool(self):
+
+        def foo():
+            return torch.rand([4], device="cuda")
+
+        pool = torch.cuda.graph_pool_handle()
+        graph, outputs = cudagraphify(foo, [], pool=pool)
+        device = outputs[0].device.index
+        del outputs
+
+        @contextlib.contextmanager
+        def _use_cuda_memory_pool_manager(device, mem_pool):
+            """
+            Context manager that routes new allocations to the given CUDA graph pool.
+            While active, all cudagraph tensors in use must be reflected in the allocator
+            or they may be overwritten. mem_pool must already exist, e.g. from a capture.
+            """
+            torch.cuda.synchronize()
+            stream = torch.cuda.Stream()
+            stream.wait_stream(torch.cuda.current_stream())
+            stream_context = torch.cuda.stream(stream)
+            stream_context.__enter__()
+            torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
+            try:
+                yield
+            finally:
+                torch._C._cuda_endAllocateCurrentStreamToPool(device)
+                torch._C._cuda_releasePool(device, mem_pool)
+                stream_context.__exit__(None, None, None)
+
+
+        segments = get_cudagraph_segments(pool)
+        self.assertEqual(len(segments), 1)
+
+        def use_pool():
+            def alloc_three():
+                a = int8_cuda(LARGE_BUFFER)
+                b = int8_cuda(LARGE_BUFFER)
+                c = a + b
+
+            with _use_cuda_memory_pool_manager(device, pool):
+                # three allocations per call, served from the pool
+                for _ in range(10):
+                    alloc_three()
+
+            # three more allocations, made outside the pool
+            alloc_three()
+
+
+        def no_pool():
+            # two allocations per iteration, outside the pool
+            for _ in range(10):
+                a = int8_cuda(LARGE_BUFFER)
+                b = int8_cuda(LARGE_BUFFER)
+                del a, b
+
+        graph_thread = threading.Thread(target=use_pool)
+        no_graph_thread = threading.Thread(target=no_pool)
+        graph_thread.start()
+        no_graph_thread.start()
+
+        graph_thread.join()
+        no_graph_thread.join()
+
+        self.assertEqual(len(get_cudagraph_segments(pool)), 4)
+
+        del graph
+
+        torch.cuda.synchronize()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self.assertEqual(len(get_cudagraph_segments(pool)), 0)
+
+
 instantiate_parametrized_tests(TestCuda)
 if __name__ == '__main__':
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 7b203b6..0e3f48a 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1045,6 +1045,29 @@
       });
+  // Route allocations made on `device`'s current stream into the graph pool
+  // `mempool_id`; paired with the end/release bindings below.
   m.def(
+      "_cuda_beginAllocateCurrentStreamToPool",
+      [](int device, at::cuda::MempoolId_t mempool_id) {
+        auto stream = at::cuda::getCurrentCUDAStream(device);
+        TORCH_CHECK(stream, "Expected stream capture to be under way");
+        c10::cuda::CUDACachingAllocator::beginAllocateStreamToPool(
+            device, stream, mempool_id);
+      });
+
+  // Stop routing the current stream's allocations into the pool.
+  m.def("_cuda_endAllocateCurrentStreamToPool", [](int device) {
+    auto stream = at::cuda::getCurrentCUDAStream(device);
+    TORCH_CHECK(stream, "Expected stream capture to be under way");
+    c10::cuda::CUDACachingAllocator::endAllocateStreamToPool(device, stream);
+  });
+
+  // Release this use of the pool so its memory can be reclaimed once unused.
+  m.def("_cuda_releasePool", [](int device, at::cuda::MempoolId_t mempool_id) {
+    c10::cuda::CUDACachingAllocator::releasePool(device, mempool_id);
+  });
+
+  m.def(
       "_cuda_setCheckpointPoolState",
       [](int device,
          std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps,