Revert "Introduce ProcessGroupCudaP2P (#122163)"

This reverts commit 2dd269986027ea25c092f769ef8e9524920aaef6.

Reverted https://github.com/pytorch/pytorch/pull/122163 on behalf of https://github.com/jithunnair-amd due to This is breaking ROCm distributed CI on trunk ([comment](https://github.com/pytorch/pytorch/pull/122163#issuecomment-2127518473))
diff --git a/.ci/pytorch/multigpu-test.sh b/.ci/pytorch/multigpu-test.sh
index 71c7400..7e04e92 100755
--- a/.ci/pytorch/multigpu-test.sh
+++ b/.ci/pytorch/multigpu-test.sh
@@ -18,7 +18,6 @@
 time python test/run_test.py --verbose -i distributed/test_c10d_nccl
 time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
 time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
-time python test/run_test.py --verbose -i distributed/test_cuda_p2p
 time python test/run_test.py --verbose -i distributed/test_store
 time python test/run_test.py --verbose -i distributed/test_pg_wrapper
 time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
diff --git a/benchmarks/distributed/intra_node_comm/allgather_matmul.py b/benchmarks/distributed/intra_node_comm/allgather_matmul.py
index 19c6521..f3fddd5 100644
--- a/benchmarks/distributed/intra_node_comm/allgather_matmul.py
+++ b/benchmarks/distributed/intra_node_comm/allgather_matmul.py
@@ -1,16 +1,17 @@
 #!/usr/bin/env python3
-# This file contains an example for using cuda_p2p backend to implement efficient fused
+# This file contains an example for using IntraNodeComm to implement efficient fused
 # allgather_matmul (inspired by https://dl.acm.org/doi/pdf/10.1145/3567955.3567959 and
 # @lw's efficient GPU implementation in xformers). Its purpose is to help guide the
 # development of relevant primitives and serve as an example for interested users.
 #
 # The benchmark can be executed as follows:
 #   torchrun --nproc-per-node 8 allgather_matmul.py
+#
+# NOTE: _IntraNodeComm is a prototype API which WILL change over time.
 import os
 
 import torch
-import torch.distributed as dist
-from torch.distributed._cuda_p2p import ProcessGroupCudaP2P
+import torch._C._distributed_c10d as c10d
 
 M = 16384
 N = 8192
@@ -20,60 +21,55 @@
 BENCH_ITERS = 50
 
 
-def allgather_matmul(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
-    group = dist.group.WORLD
-    group_size = group.size()
-    A = torch.ops._c10d_functional.all_gather_into_tensor(A_shard, group_size, "0")
-    A = torch.ops._c10d_functional.wait_tensor(A)
-    return A @ B
+comm = None
+internal_stream = None
+internal_event = None
 
 
-def allgather_matmul_p2p(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+def allgather_matmul(A_shard, B, out, rank, world_size):
     """
     Equivalent to `torch.matmul(dist.all_gather(A_shard), B)`.
     """
-    group = dist.group.WORLD
-    group_size = group.size()
-    rank = group.rank()
-    backend = group._get_backend(torch.device("cuda"))
-
-    out = torch.empty(
-        (A_shard.shape[0] * group.size(), B.shape[1]),
-        dtype=A_shard.dtype,
-        device="cuda",
-    )
-    out_shards = out.chunk(group_size)
-    local_p2p_buf = backend.get_p2p_buffer(rank, A_shard.shape, A_shard.dtype)
+    buf_0 = torch.empty_like(A_shard)
+    buf_1 = torch.empty_like(A_shard)
+    out_shards = [
+        out[i : i + A_shard.shape[0]]
+        for i in range(0, world_size * A_shard.shape[0], A_shard.shape[0])
+    ]
 
     # Perform matmul with the local input shard
     torch.matmul(A_shard, B, out=out_shards[rank])
 
-    with torch.cuda.stream(backend.stream()):
-        local_p2p_buf.copy_(A_shard)
-        work = backend.intra_node_barrier()
-    work.wait()
+    # In another stream, copy the local input shard into the intra-node
+    # buffer. After the barrier, all peers' input shards are accessible
+    # via their intra-node buffer without requiring synchronization.
+    with torch.cuda.stream(internal_stream):
+        comm.put(A_shard)
+        comm.barrier()
+        internal_event.record()
+    internal_event.wait()
 
-    buf_0 = torch.empty_like(A_shard)
-    buf_1 = torch.empty_like(A_shard)
-    for i in range(1, group_size):
+    # Copy input shard from remote buffer and perform matmul.
+    # Alternate between two streams to offset the wave quantization
+    # effect of smaller matmuls.
+    for i in range(1, world_size):
         if i % 2 == 0:
             buf = buf_0
             stream = torch.cuda.current_stream()
         else:
             buf = buf_1
-            stream = backend.stream()
-        remote_rank = (i + rank) % group_size
-        remote_p2p_buf = backend.get_p2p_buffer(
-            remote_rank, A_shard.shape, A_shard.dtype
-        )
+            stream = internal_stream
+        remote = (i + rank) % world_size
         with torch.cuda.stream(stream):
-            buf.copy_(remote_p2p_buf)
-            torch.matmul(buf, B, out=out_shards[remote_rank])
+            comm.get(remote, buf)
+            torch.matmul(buf, B, out=out_shards[remote])
 
-    with torch.cuda.stream(backend.stream()):
-        work = backend.intra_node_barrier()
-    work.wait()
-    return out
+    # Perform another barrier to ensure all peers have completed consuming the
+    # intra-node buffer so it can be reused.
+    with torch.cuda.stream(internal_stream):
+        comm.barrier()
+        internal_event.record()
+    internal_event.wait()
 
 
 def do_bench(fn):
@@ -93,6 +89,8 @@
 
 
 def main():
+    os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
+
     rank = int(os.environ["RANK"])
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
@@ -100,32 +98,33 @@
     assert M % world_size == 0
 
     torch.cuda.set_device(local_rank)
+    store, _, _ = next(torch.distributed.rendezvous("env://", rank, world_size))
 
-    options = ProcessGroupCudaP2P.Options()
-    options.buffer_size = M * N * 2 // world_size
-    dist.init_process_group("cuda_p2p", pg_options=options)
+    global comm, internal_stream, internal_event
+    comm = c10d._IntraNodeComm(
+        store=store,
+        rank=rank,
+        world_size=world_size,
+        buffer_size=M * K * torch.finfo(torch.bfloat16).bits // 8 // world_size,
+    )
+    internal_stream = torch.cuda.Stream()
+    internal_event = torch.cuda.Event()
 
     torch.manual_seed(42)
     A = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
     B = torch.randn((K, N), dtype=torch.bfloat16, device="cuda")
+    out = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")
 
     stride = M // world_size
     A_shard = A[rank * stride : (rank + 1) * stride]
 
-    assert torch.allclose(
-        allgather_matmul(A_shard, B),
-        allgather_matmul_p2p(A_shard, B),
+    comm.barrier()
+    torch.cuda.synchronize()
+    allgather_matmul_ms = do_bench(
+        lambda: allgather_matmul(A_shard, B, out, rank, world_size)
     )
 
-    dist.barrier()
-    torch.cuda.synchronize()
-    allgather_matmul_ms = do_bench(lambda: allgather_matmul(A_shard, B))
-
-    dist.barrier()
-    torch.cuda.synchronize()
-    allgather_matmul_p2p_ms = do_bench(lambda: allgather_matmul_p2p(A_shard, B))
-
-    dist.barrier()
+    comm.barrier()
     torch.cuda.synchronize()
     matmul_ms = do_bench(lambda: torch.matmul(A, B))
 
@@ -135,15 +134,8 @@
             f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
             f"{allgather_matmul_ms:.4} ms/iter"
         )
-        print(
-            "allgather_matmul_p2p "
-            f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
-            f"{allgather_matmul_p2p_ms:.4} ms/iter"
-        )
         print(f"matmul (M={M}, N={N}, K={K}): {matmul_ms:.4} ms/iter")
 
-    dist.destroy_process_group()
-
 
 if __name__ == "__main__":
     main()
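
For context, the core exchange the restored benchmark builds on is a put -> barrier -> get round trip over the intra-node p2p buffer. A minimal sketch, assuming the prototype _IntraNodeComm bindings this revert restores and a fully p2p-connected node; launch with `torchrun --nproc-per-node 2 sketch.py`:

    import os

    import torch
    import torch.distributed  # for torch.distributed.rendezvous
    import torch._C._distributed_c10d as c10d


    def main():
        # Rendezvous is gated on this toggle (see the intra_node_comm.cpp hunk below).
        os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        store, _, _ = next(torch.distributed.rendezvous("env://", rank, world_size))

        t = torch.full((1024,), float(rank), dtype=torch.bfloat16, device="cuda")
        comm = c10d._IntraNodeComm(
            store=store,
            rank=rank,
            world_size=world_size,
            buffer_size=t.numel() * t.element_size(),
        )
        comm.put(t)  # copy the local tensor into this rank's p2p buffer
        comm.barrier()  # after this, every peer's buffer is safe to read
        peer = (rank + 1) % world_size
        recv = torch.empty_like(t)
        comm.get(peer, recv)  # p2p copy out of the neighbor's buffer
        assert recv.eq(float(peer)).all()


    if __name__ == "__main__":
        main()
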
diff --git a/build_variables.bzl b/build_variables.bzl
index 1d63d06..152324a 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -675,7 +675,6 @@
 # These files are only supported on Linux (and others) but not on Windows.
 libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/NCCLUtils.cpp",
-    "torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
     "torch/csrc/distributed/c10d/UCCTracing.cpp",
diff --git a/test/distributed/test_cuda_p2p.py b/test/distributed/test_cuda_p2p.py
deleted file mode 100644
index 4e80a6f..0000000
--- a/test/distributed/test_cuda_p2p.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Owner(s): ["module: c10d"]
-import os
-from typing import List
-
-import torch
-
-import torch.distributed as dist
-from torch.distributed._cuda_p2p import (
-    get_cuda_p2p_backend,
-    get_p2p_buffer_size,
-    is_cuda_p2p_group,
-)
-from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
-    requires_nccl,
-    skip_if_lt_x_gpu,
-)
-from torch.testing._internal.common_utils import (
-    run_tests,
-    skip_but_pass_in_sandcastle_if,
-)
-
-
-def requires_cuda_p2p_access():
-    cuda_p2p_access_available = (
-        torch.cuda.is_available()
-        and torch.cuda.device_count() >= 2
-        and dist.is_nccl_available()
-    )
-    num_devices = torch.cuda.device_count()
-    for i in range(num_devices - 1):
-        for j in range(i + 1, num_devices):
-            if not torch.cuda.can_device_access_peer(i, j):
-                cuda_p2p_access_available = False
-                break
-        if not cuda_p2p_access_available:
-            break
-
-    return skip_but_pass_in_sandcastle_if(
-        not cuda_p2p_access_available,
-        "cuda p2p access is not available",
-    )
-
-
-@requires_nccl()
-@requires_cuda_p2p_access()
-class ProcessGroupCudaP2PTest(MultiProcessTestCase):
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    @property
-    def world_size(self) -> int:
-        return 2
-
-    @property
-    def ranks(self) -> List[int]:
-        return list(range(self.world_size))
-
-    @property
-    def device(self) -> torch.device:
-        return torch.device(f"cuda:{self.rank}")
-
-    def _init_process_group(self, buffer_size: int) -> None:
-        os.environ["TEST_INTRA_NODE_COMM"] = "1"
-        torch.cuda.set_device(self.device)
-
-        # Verify cuda p2p specific APIs on ProcessGroupCudaP2P
-        store = dist.FileStore(self.file_name, self.world_size)
-        options = dist.ProcessGroupCudaP2P.Options()
-        options.buffer_size = buffer_size
-        dist.init_process_group(
-            backend="cuda_p2p",
-            world_size=self.world_size,
-            rank=self.rank,
-            store=store,
-            pg_options=options,
-        )
-
-    @skip_if_lt_x_gpu(2)
-    def test_p2p_apis(self) -> None:
-        BUFFER_SIZE = 4 * 1024
-
-        self._init_process_group(BUFFER_SIZE)
-
-        # Verify cuda p2p specific APIs on ProcessGroupCudaP2P
-        assert is_cuda_p2p_group(dist.group.WORLD)
-        assert get_p2p_buffer_size(dist.group.WORLD) == BUFFER_SIZE
-
-        backend = get_cuda_p2p_backend(dist.group.WORLD)
-        assert isinstance(backend, dist.ProcessGroupCudaP2P)
-        assert backend.get_buffer_size() == BUFFER_SIZE
-
-        backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float)
-        with self.assertRaises(RuntimeError):
-            backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4 + 1,), torch.float)
-        with self.assertRaises(RuntimeError):
-            backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float, 1)
-
-        # Verify cuda p2p specific APIs on non-cuda p2p process group
-        non_cuda_p2p_pg = dist.new_group(backend="nccl")
-
-        assert not is_cuda_p2p_group(non_cuda_p2p_pg)
-        assert get_p2p_buffer_size(non_cuda_p2p_pg) == 0
-        with self.assertRaises(TypeError):
-            get_cuda_p2p_backend(non_cuda_p2p_pg)
-
-        dist.barrier()
-        torch.cuda.synchronize()
-        dist.destroy_process_group()
-
-    @skip_if_lt_x_gpu(2)
-    def test_p2p_buffer(self) -> None:
-        BUFFER_SIZE = 4 * 1024
-
-        self._init_process_group(BUFFER_SIZE)
-        rank = self.rank
-        world_size = self.world_size
-
-        assert is_cuda_p2p_group(dist.group.WORLD)
-        backend = get_cuda_p2p_backend(dist.group.WORLD)
-        local_buffer = backend.get_p2p_buffer(
-            (rank) % world_size, (BUFFER_SIZE // 4,), torch.float
-        )
-        remote_buffer = backend.get_p2p_buffer(
-            (rank + 1) % world_size, (BUFFER_SIZE // 4,), torch.float
-        )
-
-        local_buffer.fill_(rank)
-        backend.intra_node_barrier()
-        assert remote_buffer.eq((rank + 1) % world_size).all()
-
-        dist.barrier()
-        torch.cuda.synchronize()
-        dist.destroy_process_group()
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 1a3e4ea..74a73a3 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -605,30 +605,3 @@
 def _resolve_process_group(group_name: str) -> ProcessGroup: ...
 def _unregister_all_process_groups() -> None: ...
 def _unregister_process_group(group_name: str) -> None: ...
-
-class ProcessGroupCudaP2P(Backend):
-    class Options:
-        nccl_options: Optional[ProcessGroupNCCL.Options]
-        buffer_size: Optional[int]
-
-        def __init__(self) -> None: ...
-
-    def __init__(
-        self,
-        store: Store,
-        rank: int,
-        size: int,
-        options: ProcessGroupCudaP2P.Options,
-    ) -> None: ...
-    def is_p2p_available(self) -> bool: ...
-    def get_buffer_size(self) -> int: ...
-    def stream(self) -> torch.cuda.Stream: ...
-    def intra_node_barrier(self) -> Work: ...
-    def get_p2p_buffer(
-        self,
-        rank: int,
-        sizes: torch.Size,
-        dtype: torch.dtype,
-        storage_offset: Optional[int] = 0,
-    ) -> torch.Tensor: ...
-    def _shutdown(self) -> None: ...
diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp
deleted file mode 100644
index 6280eb1..0000000
--- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp
+++ /dev/null
@@ -1,206 +0,0 @@
-#ifdef USE_C10D_NCCL
-#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>
-
-#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
-
-namespace c10d {
-
-using namespace c10d::intra_node_comm;
-
-ProcessGroupCudaP2P::ProcessGroupCudaP2P(
-    const c10::intrusive_ptr<Store>& store,
-    int rank,
-    int size,
-    c10::intrusive_ptr<Options> options)
-    : Backend(rank, size), stream_(c10::cuda::getStreamFromPool()) {
-  nccl_backend_ = c10::make_intrusive<ProcessGroupNCCL>(
-      c10::make_intrusive<PrefixStore>("nccl", store),
-      rank,
-      size,
-      options->nccl_options);
-  nccl_backend_->setSequenceNumberForGroup();
-
-  p2p_backend_ = c10::make_intrusive<IntraNodeComm>(
-      c10::make_intrusive<PrefixStore>("p2p", store),
-      rank,
-      size,
-      options->buffer_size);
-  if (!p2p_backend_->rendezvous()) {
-    p2p_backend_ = nullptr;
-  }
-}
-
-bool ProcessGroupCudaP2P::is_p2p_available() {
-  return p2p_backend_ != nullptr &&
-      p2p_backend_->getTopology() == Topology::FULLY_CONNECTED;
-}
-
-size_t ProcessGroupCudaP2P::get_buffer_size() {
-  if (p2p_backend_ == nullptr) {
-    return 0;
-  }
-  return p2p_backend_->getBufferSize();
-}
-
-c10::Stream ProcessGroupCudaP2P::stream() {
-  return stream_;
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::broadcast(
-    std::vector<at::Tensor>& tensors,
-    const BroadcastOptions& opts) {
-  return nccl_backend_->broadcast(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce(
-    std::vector<at::Tensor>& tensors,
-    const AllreduceOptions& opts) {
-  return nccl_backend_->allreduce(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_sparse(
-    std::vector<at::Tensor>& tensors,
-    const AllreduceOptions& opts) {
-  return nccl_backend_->allreduce_sparse(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_coalesced(
-    std::vector<at::Tensor>& tensors,
-    const AllreduceCoalescedOptions& opts) {
-  return nccl_backend_->allreduce_coalesced(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce(
-    std::vector<at::Tensor>& tensors,
-    const ReduceOptions& opts) {
-  return nccl_backend_->reduce(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather(
-    std::vector<std::vector<at::Tensor>>& outputTensors,
-    std::vector<at::Tensor>& inputTensors,
-    const AllgatherOptions& opts) {
-  return nccl_backend_->allgather(outputTensors, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_allgather_base(
-    at::Tensor& outputBuffer,
-    at::Tensor& inputBuffer,
-    const AllgatherOptions& opts) {
-  return nccl_backend_->_allgather_base(outputBuffer, inputBuffer, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_coalesced(
-    std::vector<std::vector<at::Tensor>>& outputTensorLists,
-    std::vector<at::Tensor>& inputTensors,
-    const AllgatherOptions& opts) {
-  return nccl_backend_->allgather_coalesced(
-      outputTensorLists, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_into_tensor_coalesced(
-    std::vector<at::Tensor>& outputs,
-    std::vector<at::Tensor>& inputs,
-    const AllgatherOptions& opts) {
-  return nccl_backend_->allgather_into_tensor_coalesced(outputs, inputs, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::gather(
-    std::vector<std::vector<at::Tensor>>& outputTensors,
-    std::vector<at::Tensor>& inputTensors,
-    const GatherOptions& opts) {
-  return nccl_backend_->gather(outputTensors, inputTensors);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::scatter(
-    std::vector<at::Tensor>& outputTensors,
-    std::vector<std::vector<at::Tensor>>& inputTensors,
-    const ScatterOptions& opts) {
-  return nccl_backend_->scatter(outputTensors, inputTensors);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter(
-    std::vector<at::Tensor>& outputTensors,
-    std::vector<std::vector<at::Tensor>>& inputTensors,
-    const ReduceScatterOptions& opts) {
-  return nccl_backend_->reduce_scatter(outputTensors, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_reduce_scatter_base(
-    at::Tensor& outputBuffer,
-    at::Tensor& inputBuffer,
-    const ReduceScatterOptions& opts) {
-  return nccl_backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter_tensor_coalesced(
-    std::vector<at::Tensor>& outputs,
-    std::vector<at::Tensor>& inputs,
-    const ReduceScatterOptions& opts) {
-  return nccl_backend_->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall_base(
-    at::Tensor& outputBuffer,
-    at::Tensor& inputBuffer,
-    std::vector<int64_t>& outputSplitSizes,
-    std::vector<int64_t>& inputSplitSizes,
-    const AllToAllOptions& opts) {
-  return nccl_backend_->alltoall_base(
-      outputBuffer, inputBuffer, outputSplitSizes, inputSplitSizes);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall(
-    std::vector<at::Tensor>& outputTensors,
-    std::vector<at::Tensor>& inputTensors,
-    const AllToAllOptions& opts) {
-  return nccl_backend_->alltoall(outputTensors, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::send(
-    std::vector<at::Tensor>& tensors,
-    int dstRank,
-    int tag) {
-  return nccl_backend_->send(tensors, dstRank, tag);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recv(
-    std::vector<at::Tensor>& tensors,
-    int srcRank,
-    int tag) {
-  return nccl_backend_->recv(tensors, srcRank, tag);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recvAnysource(
-    std::vector<at::Tensor>& tensors,
-    int tag) {
-  return nccl_backend_->recvAnysource(tensors, tag);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::barrier(
-    const BarrierOptions& opts) {
-  return nccl_backend_->barrier(opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::intra_node_barrier(
-    c10::optional<std::vector<int64_t>> ranks) {
-  TORCH_CHECK(p2p_backend_ != nullptr);
-  p2p_backend_->barrier(ranks);
-  return c10::make_intrusive<IntraNodeCommWork>();
-}
-
-at::Tensor ProcessGroupCudaP2P::get_p2p_buffer(
-    size_t rank,
-    const std::vector<int64_t>& sizes,
-    c10::ScalarType dtype,
-    int64_t storage_offset) {
-  TORCH_CHECK(p2p_backend_ != nullptr);
-  return p2p_backend_->getBuffer(rank, sizes, dtype, storage_offset);
-}
-
-void ProcessGroupCudaP2P::shutdown(c10::optional<std::string> reason) {
-  nccl_backend_->shutdown(reason);
-}
-
-} // namespace c10d
-#endif // USE_C10D_NCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp
deleted file mode 100644
index cff4ad0..0000000
--- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp
+++ /dev/null
@@ -1,148 +0,0 @@
-#pragma once
-
-#ifdef USE_C10D_NCCL
-#include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
-#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
-
-constexpr auto kProcessGroupCudaP2PDefaultTimeout =
-    std::chrono::milliseconds(10 * 60 * 1000);
-
-namespace c10d {
-
-class TORCH_API ProcessGroupCudaP2P : public Backend {
- public:
-  struct Options : Backend::Options {
-    c10::intrusive_ptr<ProcessGroupNCCL::Options> nccl_options;
-    c10::optional<size_t> buffer_size;
-
-    explicit Options()
-        : Backend::Options("cuda_p2p", kProcessGroupCudaP2PDefaultTimeout) {}
-  };
-
-  bool is_p2p_available();
-  size_t get_buffer_size();
-
-  c10::Stream stream();
-
-  ProcessGroupCudaP2P(
-      const c10::intrusive_ptr<Store>& store,
-      int rank,
-      int size,
-      c10::intrusive_ptr<Options> options);
-
-  c10::intrusive_ptr<Work> broadcast(
-      std::vector<at::Tensor>& tensors,
-      const BroadcastOptions& opts = BroadcastOptions()) override;
-
-  c10::intrusive_ptr<Work> allreduce(
-      std::vector<at::Tensor>& tensors,
-      const AllreduceOptions& opts = AllreduceOptions()) override;
-
-  c10::intrusive_ptr<Work> allreduce_sparse(
-      std::vector<at::Tensor>& tensors,
-      const AllreduceOptions& opts = AllreduceOptions()) override;
-
-  c10::intrusive_ptr<Work> allreduce_coalesced(
-      std::vector<at::Tensor>& tensors,
-      const AllreduceCoalescedOptions& opts =
-          AllreduceCoalescedOptions()) override;
-
-  c10::intrusive_ptr<Work> reduce(
-      std::vector<at::Tensor>& tensors,
-      const ReduceOptions& opts = ReduceOptions()) override;
-
-  c10::intrusive_ptr<Work> allgather(
-      std::vector<std::vector<at::Tensor>>& outputTensors,
-      std::vector<at::Tensor>& inputTensors,
-      const AllgatherOptions& opts = AllgatherOptions()) override;
-
-  c10::intrusive_ptr<Work> _allgather_base(
-      at::Tensor& outputBuffer,
-      at::Tensor& inputBuffer,
-      const AllgatherOptions& opts = AllgatherOptions()) override;
-
-  c10::intrusive_ptr<Work> allgather_coalesced(
-      std::vector<std::vector<at::Tensor>>& outputTensorLists,
-      std::vector<at::Tensor>& inputTensors,
-      const AllgatherOptions& opts = AllgatherOptions()) override;
-
-  c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
-      std::vector<at::Tensor>& outputs,
-      std::vector<at::Tensor>& inputs,
-      const AllgatherOptions& opts = AllgatherOptions()) override;
-
-  c10::intrusive_ptr<Work> gather(
-      std::vector<std::vector<at::Tensor>>& outputTensors,
-      std::vector<at::Tensor>& inputTensors,
-      const GatherOptions& opts = GatherOptions()) override;
-
-  c10::intrusive_ptr<Work> scatter(
-      std::vector<at::Tensor>& outputTensors,
-      std::vector<std::vector<at::Tensor>>& inputTensors,
-      const ScatterOptions& opts = ScatterOptions()) override;
-
-  c10::intrusive_ptr<Work> reduce_scatter(
-      std::vector<at::Tensor>& outputTensors,
-      std::vector<std::vector<at::Tensor>>& inputTensors,
-      const ReduceScatterOptions& opts) override;
-
-  c10::intrusive_ptr<Work> _reduce_scatter_base(
-      at::Tensor& outputBuffer,
-      at::Tensor& inputBuffer,
-      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
-
-  c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
-      std::vector<at::Tensor>& outputs,
-      std::vector<at::Tensor>& inputs,
-      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
-
-  c10::intrusive_ptr<Work> alltoall_base(
-      at::Tensor& outputBuffer,
-      at::Tensor& inputBuffer,
-      std::vector<int64_t>& outputSplitSizes,
-      std::vector<int64_t>& inputSplitSizes,
-      const AllToAllOptions& opts = AllToAllOptions()) override;
-
-  c10::intrusive_ptr<Work> alltoall(
-      std::vector<at::Tensor>& outputTensors,
-      std::vector<at::Tensor>& inputTensors,
-      const AllToAllOptions& opts = AllToAllOptions()) override;
-
-  c10::intrusive_ptr<Work> send(
-      std::vector<at::Tensor>& tensors,
-      int dstRank,
-      int tag) override;
-
-  c10::intrusive_ptr<Work> recv(
-      std::vector<at::Tensor>& tensors,
-      int srcRank,
-      int tag) override;
-
-  c10::intrusive_ptr<Work> recvAnysource(
-      std::vector<at::Tensor>& tensors,
-      int tag) override;
-
-  /* P2P-only */
-  c10::intrusive_ptr<Work> barrier(
-      const BarrierOptions& opts = BarrierOptions()) override;
-
-  c10::intrusive_ptr<Work> intra_node_barrier(
-      c10::optional<std::vector<int64_t>> ranks = c10::nullopt);
-
-  at::Tensor get_p2p_buffer(
-      size_t rank,
-      const std::vector<int64_t>& sizes,
-      c10::ScalarType dtype,
-      int64_t storage_offset = 0);
-
-  void shutdown(c10::optional<std::string> reason = c10::nullopt);
-
- private:
-  c10::intrusive_ptr<ProcessGroupNCCL> nccl_backend_;
-  c10::intrusive_ptr<c10d::intra_node_comm::IntraNodeComm> p2p_backend_;
-  c10::Stream stream_;
-};
-
-} // namespace c10d
-#endif // USE_C10D_NCCL
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 2aaf900..505b64e 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -24,7 +24,6 @@
 
 #ifdef USE_C10D_NCCL
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
-#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 #endif
@@ -2645,7 +2644,14 @@
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("buffer_size") = c10::nullopt)
-      .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none());
+      .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none())
+      .def("put", &IntraNodeComm::put, py::arg("input"), py::arg("offset") = 0)
+      .def(
+          "get",
+          &IntraNodeComm::get,
+          py::arg("rank"),
+          py::arg("tensor"),
+          py::arg("offset") = 0);
 
 #ifdef NCCL_HAS_COMM_CTA_CGA
   py::class_<ncclConfig_t>(
@@ -2721,54 +2727,6 @@
       .def_readwrite(
           "group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
 
-  auto processGroupCudaP2P =
-      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupCudaP2P>(
-          module, "ProcessGroupCudaP2P", backend)
-          .def(py::init<
-               const c10::intrusive_ptr<::c10d::Store>&,
-               int,
-               int,
-               c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P::Options>>())
-          .def(
-              "is_p2p_available",
-              &::c10d::ProcessGroupCudaP2P::is_p2p_available)
-          .def("get_buffer_size", &::c10d::ProcessGroupCudaP2P::get_buffer_size)
-          .def("stream", &::c10d::ProcessGroupCudaP2P::stream)
-          .def(
-              "intra_node_barrier",
-              &::c10d::ProcessGroupCudaP2P::intra_node_barrier,
-              py::arg("ranks") = py::none())
-          .def(
-              "get_p2p_buffer",
-              [](c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P> self,
-                 size_t rank,
-                 const std::vector<int64_t>& sizes,
-                 py::object data_type_obj,
-                 int64_t storage_offset) {
-                auto scalar_type =
-                    reinterpret_cast<THPDtype*>(data_type_obj.ptr())
-                        ->scalar_type;
-                return self->get_p2p_buffer(
-                    rank, sizes, scalar_type, storage_offset);
-              },
-              py::arg("rank"),
-              py::arg("sizes"),
-              py::arg("dtype"),
-              py::arg("storage_offset") = 0)
-          .def(
-              "_shutdown",
-              [](const c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P>& self) {
-                return self->shutdown();
-              });
-
-  intrusive_ptr_class_<::c10d::ProcessGroupCudaP2P::Options>(
-      processGroupCudaP2P, "Options", processGroupOptions)
-      .def(py::init<>())
-      .def_readwrite(
-          "nccl_options", &::c10d::ProcessGroupCudaP2P::Options::nccl_options)
-      .def_readwrite(
-          "buffer_size", &::c10d::ProcessGroupCudaP2P::Options::buffer_size);
-
 #endif
 
 #ifdef USE_C10D_MPI
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp
index 85136a9..ceec7bb 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.cpp
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp
@@ -211,8 +211,9 @@
     : store_(std::move(store)),
       rank_(rank),
       worldSize_(worldSize),
-      bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize),
-      barrierReady_(at::cuda::CUDAEvent()) {}
+      bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize) {
+  rendezvous();
+}
 
 IntraNodeComm::~IntraNodeComm() {
   if (!isInitialized_) {
@@ -288,7 +289,7 @@
     return true;
   }
 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
-  if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
+  if (!isIntraNodeCommSupported() || !isEnabled() || worldSize_ < 2 ||
       worldSize_ > kMaxDevices) {
     return false;
   }
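
Note the behavioral bits this hunk restores: rendezvous() runs inside the IntraNodeComm constructor again, and it is additionally gated on isEnabled() (the ENABLE_INTRA_NODE_COMM toggle the benchmark exports), so the toggle has to be set before the comm object is constructed. A hedged sketch, with make_comm being a hypothetical convenience wrapper rather than an API in this patch:

    import os

    import torch._C._distributed_c10d as c10d


    def make_comm(store, rank: int, world_size: int, buffer_size=None):
        # The constructor performs the (gated) rendezvous itself, so the
        # toggle must be in the environment before construction.
        os.environ.setdefault("ENABLE_INTRA_NODE_COMM", "1")
        kwargs = {} if buffer_size is None else {"buffer_size": buffer_size}
        return c10d._IntraNodeComm(
            store=store, rank=rank, world_size=world_size, **kwargs
        )
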
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu
index 51fc625..ce479cd 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.cu
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cu
@@ -504,8 +504,7 @@
     at::cuda::CUDAStream& stream) {
   checkInput(input, rank_);
 
-  const size_t numelPerWarp =
-      kBytesPerThread / input.element_size() * kWarpSize;
+  const size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
   const size_t N_aligned = alignUp(input.numel(), numelPerWarp);
   const bool isAligned = (N_aligned == static_cast<size_t>(input.numel()));
   TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());
@@ -734,7 +733,6 @@
 }
 
 void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
-  barrierReady_.block(at::cuda::getCurrentCUDAStream());
   if (!ranks.has_value()) {
     ranks = std::vector<int64_t>(worldSize_);
     std::iota(ranks->begin(), ranks->end(), 0);
@@ -747,23 +745,44 @@
   barrierKernel<<<1, kWarpSize, 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<P2pState**>(p2pStatesDev_), mask, rank_, worldSize_);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
-  barrierReady_.record();
 }
 
-at::Tensor IntraNodeComm::getBuffer(
-    size_t rank,
-    const std::vector<int64_t>& sizes,
-    c10::ScalarType dtype,
-    int64_t storageOffset) {
-  const auto numel = std::accumulate(sizes.begin(), sizes.end(), int64_t(1), std::multiplies<int64_t>());
-  const auto elementSize = c10::elementSize(dtype);
-  TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_);
-  auto options = at::TensorOptions().dtype(dtype).device(
-      at::kCUDA, at::cuda::current_device());
-  return at::for_blob(buffers_[rank], sizes)
-      .storage_offset(storageOffset)
-      .options(options)
-      .make_tensor();
+void IntraNodeComm::put(const at::Tensor& tensor, int64_t offset) {
+  TORCH_CHECK(
+      tensor.is_non_overlapping_and_dense(),
+      "IntraNodeComm::put(): tensor must be non-overlapping and dense");
+  size_t sz = tensor.numel() * tensor.element_size();
+  TORCH_CHECK(
+      offset + sz <= bufferSize_,
+      "IntraNodeComm::put(): offset + tensor size exceeded "
+      "p2p buffer size");
+  // This results in "Memcpy PtoP" which does not use SMs for copying
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      static_cast<char*>(buffers_[rank_]) + offset,
+      static_cast<char*>(tensor.data_ptr()),
+      sz,
+      cudaMemcpyDeviceToDevice,
+      at::cuda::getCurrentCUDAStream()));
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+void IntraNodeComm::get(size_t rank, at::Tensor tensor, int64_t offset) {
+  TORCH_CHECK(
+      tensor.is_non_overlapping_and_dense(),
+      "IntraNodeComm::get(): tensor must be non-overlapping and dense");
+  size_t sz = tensor.numel() * tensor.element_size();
+  TORCH_CHECK(
+      offset + sz <= bufferSize_,
+      "IntraNodeComm::get(): offset + tensor size exceeded "
+      "p2p buffer size");
+  // This results in "Memcpy PtoP" which does not use SMs for copying
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      static_cast<char*>(tensor.data_ptr()),
+      static_cast<char*>(buffers_[rank]) + offset,
+      sz,
+      cudaMemcpyDeviceToDevice,
+      at::cuda::getCurrentCUDAStream()));
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 } // namespace intra_node_comm
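
A detail of the restored put()/get() worth spelling out: `offset` is a byte offset into the p2p buffer, bounds-checked as offset + numel * element_size <= buffer_size, and the copies are stream-ordered "Memcpy PtoP" operations that consume no SMs. A sketch of packing two shards into one buffer (an illustration, not part of the patch; `comm` and `peer` as in the earlier sketch, with buffer_size assumed large enough for both tensors):

    import torch


    def exchange_two_shards(comm, peer: int, a: torch.Tensor, b: torch.Tensor):
        # Offsets are in bytes; put()/get() check offset + tensor_bytes
        # against the comm's buffer size.
        nbytes = a.numel() * a.element_size()
        comm.put(a, offset=0)       # first shard at the start of the buffer
        comm.put(b, offset=nbytes)  # second shard packed right after it
        comm.barrier()              # peers' buffers are now safe to read
        out_a = torch.empty_like(a)
        out_b = torch.empty_like(b)
        comm.get(peer, out_a, offset=0)
        comm.get(peer, out_b, offset=nbytes)
        return out_a, out_b
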
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp
index 5d7e2d4..fe59197 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.hpp
+++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp
@@ -46,10 +46,6 @@
    */
   bool rendezvous();
 
-  Topology getTopology() {
-    return topology_;
-  }
-
   size_t getBufferSize() {
     return bufferSize_;
   }
@@ -67,11 +63,17 @@
    */
   void barrier(std::optional<std::vector<int64_t>> ranks = c10::nullopt);
 
-  at::Tensor getBuffer(
-      size_t rank,
-      const std::vector<int64_t>& sizes,
-      c10::ScalarType dtype,
-      int64_t storageOffset);
+  /**
+   * Puts the given tensor into the p2p buffer of the current rank at the
+   * specified offset.
+   */
+  void put(const at::Tensor& tensor, int64_t offset = 0);
+
+  /**
+   * Fills the given tensor with the data from the specified rank's p2p buffer
+   * at the specified offset.
+   */
+  void get(size_t rank, at::Tensor tensor, int64_t offset = 0);
 
  private:
   at::Tensor oneShotAllReduce(
@@ -90,7 +92,6 @@
   size_t rank_;
   size_t worldSize_;
   size_t bufferSize_;
-  at::cuda::CUDAEvent barrierReady_;
 
   /**
    * Members initialized after rendezvous
diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py
deleted file mode 100644
index f91d1f2..0000000
--- a/torch/distributed/_cuda_p2p/__init__.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from contextlib import contextmanager
-
-from functools import partial
-from typing import Callable, cast, List, Tuple, Union
-
-import torch
-import torch.distributed._functional_collectives as funcol
-
-import torch.distributed.distributed_c10d as c10d
-from torch._C._distributed_c10d import _DistributedBackendOptions, Backend
-
-
-"""
-This file contains the registration logic and Python APIs for
-``ProcessGroupCudaP2P`` (experimental).
-
-``ProcessGroupCudaP2P`` is a thin wrapper around ``ProcessGroupNCCL``. By
-default, it routes all collectives to the underlying ``ProcessGroupNCCL``. In
-addition, ``ProcessGroupCudaP2P`` initializes a P2P workspace that allows
-direct GPU memory access among the members. The workspace can be used in Python
-to optimize intra-node communication patterns or to create custom intra-node
-collectives in CUDA.
-
-``ProcessGroupCudaP2P`` aims to bridge the gap where certain important patterns
-can be better optimized via fine-grained P2P memory access than with
-collectives in the latest version of NCCL. It is meant to complement NCCL
-rather than replace it.
-
-Usage:
-
-    # Using ProcessGroupCudaP2P
-    dist.init_process_group(backend="cuda_p2p", ...)
-
-    # Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
-    pg_options = ProcessGroupCudaP2P.Options()
-    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
-
-    # Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
-    pg_options = ProcessGroupNCCL.Options()
-    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
-
-    # Using ProcessGroupCudaP2P while specifying both
-    # ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
-    pg_options = ProcessGroupCudaP2P.Options()
-    pg_options.nccl_options = ProcessGroupNCCL.Options()
-    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
-
-    # Down-casting the backend to access p2p buffers for cuda_p2p specific
-    # optimizations
-    if is_cuda_p2p_group(group):
-        backend = get_cuda_p2p_backend(group)
-        if required_p2p_buffer_size > backend.get_buffer_size():
-            # fallback
-        p2p_buffer = backend.get_p2p_buffer(...)
-    else:
-        # fallback
-"""
-
-
-def _create_cuda_p2p_group(
-    dist_backend_opts: "_DistributedBackendOptions",
-    options: Union[
-        "c10d.ProcessGroupCudaP2P.Options", "c10d.ProcessGroupNCCL.Options", None
-    ],
-) -> "Backend":
-    if not c10d.is_nccl_available():
-        raise RuntimeError("The cuda_p2p backend is not available")
-    if options is None:
-        options = c10d.ProcessGroupCudaP2P.Options()
-        options.nccl_options = c10d.ProcessGroupNCCL.Options()
-    elif isinstance(options, c10d.ProcessGroupNCCL.Options):
-        nccl_options = options
-        options = c10d.ProcessGroupCudaP2P.Options()
-        options.nccl_options = nccl_options
-    elif isinstance(options, c10d.ProcessGroupCudaP2P.Options):
-        if options.nccl_options is None:
-            options.nccl_options = c10d.ProcessGroupNCCL.Options()
-    else:
-        raise TypeError(
-            "options for cuda_p2p must be ProcessGroupCudaP2P.Options "
-            f"or ProcessGroupNCCL.Options (got: {type(options)})"
-        )
-
-    return c10d.ProcessGroupCudaP2P(
-        dist_backend_opts.store,
-        dist_backend_opts.group_rank,
-        dist_backend_opts.group_size,
-        options,
-    )
-
-
-def is_cuda_p2p_group(group: c10d.ProcessGroup) -> bool:
-    if not c10d.is_nccl_available():
-        return False
-    try:
-        backend = group._get_backend(torch.device("cuda"))
-    except Exception:
-        return False
-    return isinstance(backend, c10d.ProcessGroupCudaP2P) and backend.is_p2p_available()
-
-
-def get_cuda_p2p_backend(group: c10d.ProcessGroup) -> "c10d.ProcessGroupCudaP2P":
-    if not is_cuda_p2p_group(group):
-        raise TypeError("group is not a cuda_p2p process group.")
-    return cast(
-        c10d.ProcessGroupCudaP2P,
-        group._get_backend(torch.device("cuda")),
-    )
-
-
-def get_p2p_buffer_size(group: c10d.ProcessGroup) -> int:
-    if not is_cuda_p2p_group(group):
-        return 0
-    backend = get_cuda_p2p_backend(group)
-    return backend.get_buffer_size()
-
-
-c10d.Backend.register_backend(
-    "cuda_p2p",
-    _create_cuda_p2p_group,
-    extended_api=True,
-    devices=["cuda"],
-)
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 3ea412f..c6fc227 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -110,10 +110,8 @@
 
 try:
     from torch._C._distributed_c10d import ProcessGroupNCCL
-    from torch._C._distributed_c10d import ProcessGroupCudaP2P
     ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
-    ProcessGroupCudaP2P.__module__ = "torch.distributed.distributed_c10d"
-    __all__ += ["ProcessGroupNCCL", "ProcessGroupCudaP2P"]
+    __all__ += ["ProcessGroupNCCL"]
 except ImportError:
     _NCCL_AVAILABLE = False
 
@@ -1446,7 +1444,7 @@
         backend = pg._get_backend(torch.device("cuda"))
     except RuntimeError:
         pass
-    if is_nccl_available() and isinstance(backend, (ProcessGroupNCCL, ProcessGroupCudaP2P)):
+    if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
         # explicitly call shutdown to ensure that NCCL resources are released
         backend._shutdown()