Expose Stream Recording APIs in Python (#96384)

Differential Revision: [D43999891](https://our.internmc.facebook.com/intern/diff/D43999891)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96384
Approved by: https://github.com/zdevito
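
This PR exposes three private bindings — `torch._C._cuda_beginAllocateCurrentStreamToPool`, `torch._C._cuda_endAllocateCurrentStreamToPool`, and `torch._C._cuda_releasePool` — that let Python route allocations made on the current stream into an existing CUDA graph memory pool. A minimal sketch of how they compose, adapted from the context manager in the test below (the wrapper name `use_mem_pool` is ours; the pool is assumed to come from `torch.cuda.graph_pool_handle()` and to have already been used in a capture):

```python
import contextlib

import torch

@contextlib.contextmanager
def use_mem_pool(device, mem_pool):
    # Allocate on a fresh side stream so that only the allocations made
    # inside this block are routed into the graph pool.
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
        try:
            yield
        finally:
            # Stop routing, then drop our use of the pool so the allocator
            # can reclaim its blocks once no live graph still holds them.
            torch._C._cuda_endAllocateCurrentStreamToPool(device)
            torch._C._cuda_releasePool(device, mem_pool)
```

Allocations made under `use_mem_pool(device, pool)` land in the pool's segments rather than the global caching allocator, which is what the threaded test below asserts via segment counts.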
diff --git a/test/test_cuda.py b/test/test_cuda.py
index d097da5..00190c0 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -5537,6 +5537,81 @@
         self.checkFunction(m, [inp])
 
 
+    def test_allocate_in_thread_to_pool(self):
+
+        def foo():
+            return torch.rand([4], device="cuda")
+
+        pool = torch.cuda.graph_pool_handle()
+        graph, outputs = cudagraphify(foo, [], pool=pool)
+        device = outputs[0].device.index
+        del outputs
+
+        @contextlib.contextmanager
+        def _use_cuda_memory_pool_manager(device, mem_pool):
+            """
+            Context manager to use cuda graph pool for new allocations. If you use this manager
+            all cudagraph tensors in use should be reflected in the allocator or they will be overwritten.
+            existing_graph should already have been used in a capture, and the mem_pool must already exist.
+            """
+            torch.cuda.synchronize()
+            stream = torch.cuda.Stream()
+            stream.wait_stream(torch.cuda.current_stream())
+            stream_context = torch.cuda.stream(stream)
+            stream_context.__enter__()
+            torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
+            try:
+                yield
+            finally:
+                torch._C._cuda_endAllocateCurrentStreamToPool(device)
+                torch._C._cuda_releasePool(device, mem_pool)
+                stream_context.__exit__(None, None, None)
+
+
+        segments = get_cudagraph_segments(pool)
+        self.assertEqual(len(segments), 1)
+
+        def use_pool():
+            def alloc_three():
+                a = int8_cuda(LARGE_BUFFER)
+                b = int8_cuda(LARGE_BUFFER)
+                c = a + b
+
+            with _use_cuda_memory_pool_manager(device, pool):
+                # alloc_three() makes three allocations per call, all in the pool
+                for _ in range(10):
+                    alloc_three()
+
+            # three more allocations, this time not in the pool
+            alloc_three()
+
+
+        def no_pool():
+            # two allocations per iteration, outside the pool
+            for _ in range(10):
+                a = int8_cuda(LARGE_BUFFER)
+                b = int8_cuda(LARGE_BUFFER)
+                del a, b
+
+        graph_thread = threading.Thread(target=use_pool)
+        no_graph_thread = threading.Thread(target=no_pool)
+        graph_thread.start()
+        no_graph_thread.start()
+
+        graph_thread.join()
+        no_graph_thread.join()
+
+        self.assertEqual(len(get_cudagraph_segments(pool)), 4)
+
+        del graph
+
+        torch.cuda.synchronize()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self.assertEqual(len(get_cudagraph_segments(pool)), 0)
+
+
 instantiate_parametrized_tests(TestCuda)
 
 if __name__ == '__main__':
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 7b203b6..0e3f48a 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1045,6 +1045,25 @@
   });
 
   m.def(
+      "_cuda_beginAllocateCurrentStreamToPool",
+      [](int device, at::cuda::MempoolId_t mempool_id) {
+        auto stream = at::cuda::getCurrentCUDAStream(device);
+        TORCH_CHECK(stream, "Expected stream capture to be under way");
+        c10::cuda::CUDACachingAllocator::beginAllocateStreamToPool(
+            device, stream, mempool_id);
+      });
+
+  m.def("_cuda_endAllocateCurrentStreamToPool", [](int device) {
+    auto stream = at::cuda::getCurrentCUDAStream(device);
+    TORCH_CHECK(stream, "Expected stream capture to be under way");
+    c10::cuda::CUDACachingAllocator::endAllocateStreamToPool(device, stream);
+  });
+
+  m.def("_cuda_releasePool", [](int device, at::cuda::MempoolId_t mempool_id) {
+    c10::cuda::CUDACachingAllocator::releasePool(device, mempool_id);
+  });
+
+  m.def(
       "_cuda_setCheckpointPoolState",
       [](int device,
          std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps,