Revert "Introduce ProcessGroupCudaP2P (#122163)"
This reverts commit 2dd269986027ea25c092f769ef8e9524920aaef6.
Reverted https://github.com/pytorch/pytorch/pull/122163 on behalf of https://github.com/jithunnair-amd due to This is breaking ROCm distributed CI on trunk ([comment](https://github.com/pytorch/pytorch/pull/122163#issuecomment-2127518473))
diff --git a/.ci/pytorch/multigpu-test.sh b/.ci/pytorch/multigpu-test.sh
index 71c7400..7e04e92 100755
--- a/.ci/pytorch/multigpu-test.sh
+++ b/.ci/pytorch/multigpu-test.sh
@@ -18,7 +18,6 @@
time python test/run_test.py --verbose -i distributed/test_c10d_nccl
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
-time python test/run_test.py --verbose -i distributed/test_cuda_p2p
time python test/run_test.py --verbose -i distributed/test_store
time python test/run_test.py --verbose -i distributed/test_pg_wrapper
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
diff --git a/benchmarks/distributed/intra_node_comm/allgather_matmul.py b/benchmarks/distributed/intra_node_comm/allgather_matmul.py
index 19c6521..f3fddd5 100644
--- a/benchmarks/distributed/intra_node_comm/allgather_matmul.py
+++ b/benchmarks/distributed/intra_node_comm/allgather_matmul.py
@@ -1,16 +1,17 @@
#!/usr/bin/env python3
-# This file contains an example for using cuda_p2p backend to implement efficient fused
+# This file contains an example for using IntraNodeComm to implement efficient fused
# allgather_matmul (inspired by https://dl.acm.org/doi/pdf/10.1145/3567955.3567959 and
# @lw's efficient GPU implementation in xformers). Its purpose to help guide the
# development of relevant primitives and serve as an example for interested users.
#
# The benchmark can be executed as follows:
# torchrun --nproc-per-node 8 allgather_matmul.py
+#
+# NOTE: _IntraNodeComm is a prototype API which WILL change over time.
import os
import torch
-import torch.distributed as dist
-from torch.distributed._cuda_p2p import ProcessGroupCudaP2P
+import torch._C._distributed_c10d as c10d
M = 16384
N = 8192
@@ -20,60 +21,55 @@
BENCH_ITERS = 50
-def allgather_matmul(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
- group = dist.group.WORLD
- group_size = group.size()
- A = torch.ops._c10d_functional.all_gather_into_tensor(A_shard, group_size, "0")
- A = torch.ops._c10d_functional.wait_tensor(A)
- return A @ B
+comm = None
+internal_stream = None
+internal_event = None
-def allgather_matmul_p2p(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+def allgather_matmul(A_shard, B, out, rank, world_size):
"""
Equivalent to `torch.matmul(dist.all_gather(A_shard), B)`.
"""
- group = dist.group.WORLD
- group_size = group.size()
- rank = group.rank()
- backend = group._get_backend(torch.device("cuda"))
-
- out = torch.empty(
- (A_shard.shape[0] * group.size(), B.shape[1]),
- dtype=A_shard.dtype,
- device="cuda",
- )
- out_shards = out.chunk(group_size)
- local_p2p_buf = backend.get_p2p_buffer(rank, A_shard.shape, A_shard.dtype)
+ buf_0 = torch.empty_like(A_shard)
+ buf_1 = torch.empty_like(A_shard)
+ out_shards = [
+ out[i : i + A_shard.shape[0]]
+ for i in range(0, world_size * A_shard.shape[0], A_shard.shape[0])
+ ]
# Perform matmul with the local input shard
torch.matmul(A_shard, B, out=out_shards[rank])
- with torch.cuda.stream(backend.stream()):
- local_p2p_buf.copy_(A_shard)
- work = backend.intra_node_barrier()
- work.wait()
+ # In another stream, copy the local input shard into the intra-node
+ # buffer. After the barrier, all peers' input shards are accessible
+ # via their intra-node buffer without requiring synchronization.
+ with torch.cuda.stream(internal_stream):
+ comm.put(A_shard)
+ comm.barrier()
+ internal_event.record()
+ internal_event.wait()
- buf_0 = torch.empty_like(A_shard)
- buf_1 = torch.empty_like(A_shard)
- for i in range(1, group_size):
+ # Copy input shard from remote buffer and perform matmul.
+ # Alternate between two streams to offset the wave quantization
+ # effect of smaller matmuls.
+ for i in range(1, world_size):
if i % 2 == 0:
buf = buf_0
stream = torch.cuda.current_stream()
else:
buf = buf_1
- stream = backend.stream()
- remote_rank = (i + rank) % group_size
- remote_p2p_buf = backend.get_p2p_buffer(
- remote_rank, A_shard.shape, A_shard.dtype
- )
+ stream = internal_stream
+ remote = (i + rank) % world_size
with torch.cuda.stream(stream):
- buf.copy_(remote_p2p_buf)
- torch.matmul(buf, B, out=out_shards[remote_rank])
+ comm.get(remote, buf)
+ torch.matmul(buf, B, out=out_shards[remote])
- with torch.cuda.stream(backend.stream()):
- work = backend.intra_node_barrier()
- work.wait()
- return out
+ # Perform another barrier to ensure all peers have completed consuming the
+ # intra-node buffer so it can be reused.
+ with torch.cuda.stream(internal_stream):
+ comm.barrier()
+ internal_event.record()
+ internal_event.wait()
def do_bench(fn):
@@ -93,6 +89,8 @@
def main():
+ os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
+
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
@@ -100,32 +98,33 @@
assert M % world_size == 0
torch.cuda.set_device(local_rank)
+ store, _, _ = next(torch.distributed.rendezvous("env://", rank, world_size))
- options = ProcessGroupCudaP2P.Options()
- options.buffer_size = M * N * 2 // world_size
- dist.init_process_group("cuda_p2p", pg_options=options)
+ global comm, internal_stream, internal_event
+ comm = c10d._IntraNodeComm(
+ store=store,
+ rank=rank,
+ world_size=world_size,
+ buffer_size=M * K * torch.finfo(torch.bfloat16).bits // 8 // world_size,
+ )
+ internal_stream = torch.cuda.Stream()
+ internal_event = torch.cuda.Event()
torch.manual_seed(42)
A = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
B = torch.randn((K, N), dtype=torch.bfloat16, device="cuda")
+ out = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")
stride = M // world_size
A_shard = A[rank * stride : (rank + 1) * stride]
- assert torch.allclose(
- allgather_matmul(A_shard, B),
- allgather_matmul_p2p(A_shard, B),
+ comm.barrier()
+ torch.cuda.synchronize()
+ allgather_matmul_ms = do_bench(
+ lambda: allgather_matmul(A_shard, B, out, rank, world_size)
)
- dist.barrier()
- torch.cuda.synchronize()
- allgather_matmul_ms = do_bench(lambda: allgather_matmul(A_shard, B))
-
- dist.barrier()
- torch.cuda.synchronize()
- allgather_matmul_p2p_ms = do_bench(lambda: allgather_matmul_p2p(A_shard, B))
-
- dist.barrier()
+ comm.barrier()
torch.cuda.synchronize()
matmul_ms = do_bench(lambda: torch.matmul(A, B))
@@ -135,15 +134,8 @@
f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
f"{allgather_matmul_ms:.4} ms/iter"
)
- print(
- "allgather_matmul_p2p "
- f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
- f"{allgather_matmul_p2p_ms:.4} ms/iter"
- )
print(f"matmul (M={M}, N={N}, K={K}): {matmul_ms:.4} ms/iter")
- dist.destroy_process_group()
-
if __name__ == "__main__":
main()
diff --git a/build_variables.bzl b/build_variables.bzl
index 1d63d06..152324a 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -675,7 +675,6 @@
# These files are only supported on Linux (and others) but not on Windows.
libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/NCCLUtils.cpp",
- "torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp",
"torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",
diff --git a/test/distributed/test_cuda_p2p.py b/test/distributed/test_cuda_p2p.py
deleted file mode 100644
index 4e80a6f..0000000
--- a/test/distributed/test_cuda_p2p.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Owner(s): ["module: c10d"]
-import os
-from typing import List
-
-import torch
-
-import torch.distributed as dist
-from torch.distributed._cuda_p2p import (
- get_cuda_p2p_backend,
- get_p2p_buffer_size,
- is_cuda_p2p_group,
-)
-from torch.testing._internal.common_distributed import (
- MultiProcessTestCase,
- requires_nccl,
- skip_if_lt_x_gpu,
-)
-from torch.testing._internal.common_utils import (
- run_tests,
- skip_but_pass_in_sandcastle_if,
-)
-
-
-def requires_cuda_p2p_access():
- cuda_p2p_access_available = (
- torch.cuda.is_available()
- and torch.cuda.device_count() >= 2
- and dist.is_nccl_available()
- )
- num_devices = torch.cuda.device_count()
- for i in range(num_devices - 1):
- for j in range(i + 1, num_devices):
- if not torch.cuda.can_device_access_peer(i, j):
- cuda_p2p_access_available = False
- break
- if not cuda_p2p_access_available:
- break
-
- return skip_but_pass_in_sandcastle_if(
- not cuda_p2p_access_available,
- "cuda p2p access is not available",
- )
-
-
-@requires_nccl()
-@requires_cuda_p2p_access()
-class ProcessGroupCudaP2PTest(MultiProcessTestCase):
- def setUp(self) -> None:
- super().setUp()
- self._spawn_processes()
-
- @property
- def world_size(self) -> int:
- return 2
-
- @property
- def ranks(self) -> List[int]:
- return list(range(self.world_size))
-
- @property
- def device(self) -> torch.device:
- return torch.device(f"cuda:{self.rank}")
-
- def _init_process_group(self, buffer_size: int) -> None:
- os.environ["TEST_INTRA_NODE_COMM"] = "1"
- torch.cuda.set_device(self.device)
-
- # Verify cuda p2p specific APIs on ProcessGroupCudaP2P
- store = dist.FileStore(self.file_name, self.world_size)
- options = dist.ProcessGroupCudaP2P.Options()
- options.buffer_size = buffer_size
- dist.init_process_group(
- backend="cuda_p2p",
- world_size=self.world_size,
- rank=self.rank,
- store=store,
- pg_options=options,
- )
-
- @skip_if_lt_x_gpu(2)
- def test_p2p_apis(self) -> None:
- BUFFER_SIZE = 4 * 1024
-
- self._init_process_group(BUFFER_SIZE)
-
- # Verify cuda p2p specific APIs on ProcessGroupCudaP2P
- assert is_cuda_p2p_group(dist.group.WORLD)
- assert get_p2p_buffer_size(dist.group.WORLD) == BUFFER_SIZE
-
- backend = get_cuda_p2p_backend(dist.group.WORLD)
- assert isinstance(backend, dist.ProcessGroupCudaP2P)
- assert backend.get_buffer_size() == BUFFER_SIZE
-
- backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float)
- with self.assertRaises(RuntimeError):
- backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4 + 1,), torch.float)
- with self.assertRaises(RuntimeError):
- backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float, 1)
-
- # Verify cuda p2p specific APIs on non-cuda p2p process group
- non_cuda_p2p_pg = dist.new_group(backend="nccl")
-
- assert not is_cuda_p2p_group(non_cuda_p2p_pg)
- assert get_p2p_buffer_size(non_cuda_p2p_pg) == 0
- with self.assertRaises(TypeError):
- get_cuda_p2p_backend(non_cuda_p2p_pg)
-
- dist.barrier()
- torch.cuda.synchronize()
- dist.destroy_process_group()
-
- @skip_if_lt_x_gpu(2)
- def test_p2p_buffer(self) -> None:
- BUFFER_SIZE = 4 * 1024
-
- self._init_process_group(BUFFER_SIZE)
- rank = self.rank
- world_size = self.world_size
-
- assert is_cuda_p2p_group(dist.group.WORLD)
- backend = get_cuda_p2p_backend(dist.group.WORLD)
- local_buffer = backend.get_p2p_buffer(
- (rank) % world_size, (BUFFER_SIZE // 4,), torch.float
- )
- remote_buffer = backend.get_p2p_buffer(
- (rank + 1) % world_size, (BUFFER_SIZE // 4,), torch.float
- )
-
- local_buffer.fill_(rank)
- backend.intra_node_barrier()
- assert remote_buffer.eq((rank + 1) % world_size).all()
-
- dist.barrier()
- torch.cuda.synchronize()
- dist.destroy_process_group()
-
-
-if __name__ == "__main__":
- run_tests()
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 1a3e4ea..74a73a3 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -605,30 +605,3 @@
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...
-
-class ProcessGroupCudaP2P(Backend):
- class Options:
- nccl_options: Optional[ProcessGroupNCCL.Options]
- buffer_size: Optional[int]
-
- def __init__(self) -> None: ...
-
- def __init__(
- self,
- store: Store,
- rank: int,
- size: int,
- options: ProcessGroupCudaP2P.Options,
- ) -> None: ...
- def is_p2p_available(self) -> bool: ...
- def get_buffer_size(self) -> int: ...
- def stream(self) -> torch.cuda.Stream: ...
- def intra_node_barrier(self) -> Work: ...
- def get_p2p_buffer(
- self,
- rank: int,
- sizes: torch.Size,
- dtype: torch.dtype,
- storage_offset: Optional[int] = 0,
- ) -> torch.Tensor: ...
- def _shutdown(self) -> None: ...
diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp
deleted file mode 100644
index 6280eb1..0000000
--- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp
+++ /dev/null
@@ -1,206 +0,0 @@
-#ifdef USE_C10D_NCCL
-#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>
-
-#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
-
-namespace c10d {
-
-using namespace c10d::intra_node_comm;
-
-ProcessGroupCudaP2P::ProcessGroupCudaP2P(
- const c10::intrusive_ptr<Store>& store,
- int rank,
- int size,
- c10::intrusive_ptr<Options> options)
- : Backend(rank, size), stream_(c10::cuda::getStreamFromPool()) {
- nccl_backend_ = c10::make_intrusive<ProcessGroupNCCL>(
- c10::make_intrusive<PrefixStore>("nccl", store),
- rank,
- size,
- options->nccl_options);
- nccl_backend_->setSequenceNumberForGroup();
-
- p2p_backend_ = c10::make_intrusive<IntraNodeComm>(
- c10::make_intrusive<PrefixStore>("p2p", store),
- rank,
- size,
- options->buffer_size);
- if (!p2p_backend_->rendezvous()) {
- p2p_backend_ = nullptr;
- }
-}
-
-bool ProcessGroupCudaP2P::is_p2p_available() {
- return p2p_backend_ != nullptr &&
- p2p_backend_->getTopology() == Topology::FULLY_CONNECTED;
-}
-
-size_t ProcessGroupCudaP2P::get_buffer_size() {
- if (p2p_backend_ == nullptr) {
- return 0;
- }
- return p2p_backend_->getBufferSize();
-}
-
-c10::Stream ProcessGroupCudaP2P::stream() {
- return stream_;
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::broadcast(
- std::vector<at::Tensor>& tensors,
- const BroadcastOptions& opts) {
- return nccl_backend_->broadcast(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce(
- std::vector<at::Tensor>& tensors,
- const AllreduceOptions& opts) {
- return nccl_backend_->allreduce(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_sparse(
- std::vector<at::Tensor>& tensors,
- const AllreduceOptions& opts) {
- return nccl_backend_->allreduce_sparse(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_coalesced(
- std::vector<at::Tensor>& tensors,
- const AllreduceCoalescedOptions& opts) {
- return nccl_backend_->allreduce_coalesced(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce(
- std::vector<at::Tensor>& tensors,
- const ReduceOptions& opts) {
- return nccl_backend_->reduce(tensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather(
- std::vector<std::vector<at::Tensor>>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts) {
- return nccl_backend_->allgather(outputTensors, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_allgather_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- const AllgatherOptions& opts) {
- return nccl_backend_->_allgather_base(outputBuffer, inputBuffer, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_coalesced(
- std::vector<std::vector<at::Tensor>>& outputTensorLists,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts) {
- return nccl_backend_->allgather_coalesced(
- outputTensorLists, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_into_tensor_coalesced(
- std::vector<at::Tensor>& outputs,
- std::vector<at::Tensor>& inputs,
- const AllgatherOptions& opts) {
- return nccl_backend_->allgather_into_tensor_coalesced(outputs, inputs, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::gather(
- std::vector<std::vector<at::Tensor>>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const GatherOptions& opts) {
- return nccl_backend_->gather(outputTensors, inputTensors);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::scatter(
- std::vector<at::Tensor>& outputTensors,
- std::vector<std::vector<at::Tensor>>& inputTensors,
- const ScatterOptions& opts) {
- return nccl_backend_->scatter(outputTensors, inputTensors);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter(
- std::vector<at::Tensor>& outputTensors,
- std::vector<std::vector<at::Tensor>>& inputTensors,
- const ReduceScatterOptions& opts) {
- return nccl_backend_->reduce_scatter(outputTensors, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_reduce_scatter_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- const ReduceScatterOptions& opts) {
- return nccl_backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter_tensor_coalesced(
- std::vector<at::Tensor>& outputs,
- std::vector<at::Tensor>& inputs,
- const ReduceScatterOptions& opts) {
- return nccl_backend_->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- std::vector<int64_t>& outputSplitSizes,
- std::vector<int64_t>& inputSplitSizes,
- const AllToAllOptions& opts) {
- return nccl_backend_->alltoall_base(
- outputBuffer, inputBuffer, outputSplitSizes, inputSplitSizes);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall(
- std::vector<at::Tensor>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const AllToAllOptions& opts) {
- return nccl_backend_->alltoall(outputTensors, inputTensors, opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::send(
- std::vector<at::Tensor>& tensors,
- int dstRank,
- int tag) {
- return nccl_backend_->send(tensors, dstRank, tag);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recv(
- std::vector<at::Tensor>& tensors,
- int srcRank,
- int tag) {
- return nccl_backend_->recv(tensors, srcRank, tag);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recvAnysource(
- std::vector<at::Tensor>& tensors,
- int tag) {
- return nccl_backend_->recvAnysource(tensors, tag);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::barrier(
- const BarrierOptions& opts) {
- return nccl_backend_->barrier(opts);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupCudaP2P::intra_node_barrier(
- c10::optional<std::vector<int64_t>> ranks) {
- TORCH_CHECK(p2p_backend_ != nullptr);
- p2p_backend_->barrier(ranks);
- return c10::make_intrusive<IntraNodeCommWork>();
-}
-
-at::Tensor ProcessGroupCudaP2P::get_p2p_buffer(
- size_t rank,
- const std::vector<int64_t>& sizes,
- c10::ScalarType dtype,
- int64_t storage_offset) {
- TORCH_CHECK(p2p_backend_ != nullptr);
- return p2p_backend_->getBuffer(rank, sizes, dtype, storage_offset);
-}
-
-void ProcessGroupCudaP2P::shutdown(c10::optional<std::string> reason) {
- nccl_backend_->shutdown(reason);
-}
-
-} // namespace c10d
-#endif // USE_C10D_NCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp
deleted file mode 100644
index cff4ad0..0000000
--- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp
+++ /dev/null
@@ -1,148 +0,0 @@
-#pragma once
-
-#ifdef USE_C10D_NCCL
-#include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
-#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
-
-constexpr auto kProcessGroupCudaP2PDefaultTimeout =
- std::chrono::milliseconds(10 * 60 * 1000);
-
-namespace c10d {
-
-class TORCH_API ProcessGroupCudaP2P : public Backend {
- public:
- struct Options : Backend::Options {
- c10::intrusive_ptr<ProcessGroupNCCL::Options> nccl_options;
- c10::optional<size_t> buffer_size;
-
- explicit Options()
- : Backend::Options("cuda_p2p", kProcessGroupCudaP2PDefaultTimeout) {}
- };
-
- bool is_p2p_available();
- size_t get_buffer_size();
-
- c10::Stream stream();
-
- ProcessGroupCudaP2P(
- const c10::intrusive_ptr<Store>& store,
- int rank,
- int size,
- c10::intrusive_ptr<Options> options);
-
- c10::intrusive_ptr<Work> broadcast(
- std::vector<at::Tensor>& tensors,
- const BroadcastOptions& opts = BroadcastOptions()) override;
-
- c10::intrusive_ptr<Work> allreduce(
- std::vector<at::Tensor>& tensors,
- const AllreduceOptions& opts = AllreduceOptions()) override;
-
- c10::intrusive_ptr<Work> allreduce_sparse(
- std::vector<at::Tensor>& tensors,
- const AllreduceOptions& opts = AllreduceOptions()) override;
-
- c10::intrusive_ptr<Work> allreduce_coalesced(
- std::vector<at::Tensor>& tensors,
- const AllreduceCoalescedOptions& opts =
- AllreduceCoalescedOptions()) override;
-
- c10::intrusive_ptr<Work> reduce(
- std::vector<at::Tensor>& tensors,
- const ReduceOptions& opts = ReduceOptions()) override;
-
- c10::intrusive_ptr<Work> allgather(
- std::vector<std::vector<at::Tensor>>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts = AllgatherOptions()) override;
-
- c10::intrusive_ptr<Work> _allgather_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- const AllgatherOptions& opts = AllgatherOptions()) override;
-
- c10::intrusive_ptr<Work> allgather_coalesced(
- std::vector<std::vector<at::Tensor>>& outputTensorLists,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts = AllgatherOptions()) override;
-
- c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
- std::vector<at::Tensor>& outputs,
- std::vector<at::Tensor>& inputs,
- const AllgatherOptions& opts = AllgatherOptions()) override;
-
- c10::intrusive_ptr<Work> gather(
- std::vector<std::vector<at::Tensor>>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const GatherOptions& opts = GatherOptions()) override;
-
- c10::intrusive_ptr<Work> scatter(
- std::vector<at::Tensor>& outputTensors,
- std::vector<std::vector<at::Tensor>>& inputTensors,
- const ScatterOptions& opts = ScatterOptions()) override;
-
- c10::intrusive_ptr<Work> reduce_scatter(
- std::vector<at::Tensor>& outputTensors,
- std::vector<std::vector<at::Tensor>>& inputTensors,
- const ReduceScatterOptions& opts) override;
-
- c10::intrusive_ptr<Work> _reduce_scatter_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
-
- c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
- std::vector<at::Tensor>& outputs,
- std::vector<at::Tensor>& inputs,
- const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
-
- c10::intrusive_ptr<Work> alltoall_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- std::vector<int64_t>& outputSplitSizes,
- std::vector<int64_t>& inputSplitSizes,
- const AllToAllOptions& opts = AllToAllOptions()) override;
-
- c10::intrusive_ptr<Work> alltoall(
- std::vector<at::Tensor>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const AllToAllOptions& opts = AllToAllOptions()) override;
-
- c10::intrusive_ptr<Work> send(
- std::vector<at::Tensor>& tensors,
- int dstRank,
- int tag) override;
-
- c10::intrusive_ptr<Work> recv(
- std::vector<at::Tensor>& tensors,
- int srcRank,
- int tag) override;
-
- c10::intrusive_ptr<Work> recvAnysource(
- std::vector<at::Tensor>& tensors,
- int tag) override;
-
- /* P2P-only */
- c10::intrusive_ptr<Work> barrier(
- const BarrierOptions& opts = BarrierOptions()) override;
-
- c10::intrusive_ptr<Work> intra_node_barrier(
- c10::optional<std::vector<int64_t>> ranks = c10::nullopt);
-
- at::Tensor get_p2p_buffer(
- size_t rank,
- const std::vector<int64_t>& sizes,
- c10::ScalarType dtype,
- int64_t storage_offest = 0);
-
- void shutdown(c10::optional<std::string> reason = c10::nullopt);
-
- private:
- c10::intrusive_ptr<ProcessGroupNCCL> nccl_backend_;
- c10::intrusive_ptr<c10d::intra_node_comm::IntraNodeComm> p2p_backend_;
- c10::Stream stream_;
-};
-
-} // namespace c10d
-#endif // USE_C10D_NCCL
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 2aaf900..505b64e 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -24,7 +24,6 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
-#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif
@@ -2645,7 +2644,14 @@
py::arg("rank"),
py::arg("world_size"),
py::arg("buffer_size") = c10::nullopt)
- .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none());
+ .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none())
+ .def("put", &IntraNodeComm::put, py::arg("input"), py::arg("offset") = 0)
+ .def(
+ "get",
+ &IntraNodeComm::get,
+ py::arg("rank"),
+ py::arg("tensor"),
+ py::arg("offset") = 0);
#ifdef NCCL_HAS_COMM_CTA_CGA
py::class_<ncclConfig_t>(
@@ -2721,54 +2727,6 @@
.def_readwrite(
"group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
- auto processGroupCudaP2P =
- intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupCudaP2P>(
- module, "ProcessGroupCudaP2P", backend)
- .def(py::init<
- const c10::intrusive_ptr<::c10d::Store>&,
- int,
- int,
- c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P::Options>>())
- .def(
- "is_p2p_available",
- &::c10d::ProcessGroupCudaP2P::is_p2p_available)
- .def("get_buffer_size", &::c10d::ProcessGroupCudaP2P::get_buffer_size)
- .def("stream", &::c10d::ProcessGroupCudaP2P::stream)
- .def(
- "intra_node_barrier",
- &::c10d::ProcessGroupCudaP2P::intra_node_barrier,
- py::arg("ranks") = py::none())
- .def(
- "get_p2p_buffer",
- [](c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P> self,
- size_t rank,
- const std::vector<int64_t>& sizes,
- py::object data_type_obj,
- int64_t storage_offset) {
- auto scalar_type =
- reinterpret_cast<THPDtype*>(data_type_obj.ptr())
- ->scalar_type;
- return self->get_p2p_buffer(
- rank, sizes, scalar_type, storage_offset);
- },
- py::arg("rank"),
- py::arg("sizes"),
- py::arg("dtype"),
- py::arg("storage_offset") = 0)
- .def(
- "_shutdown",
- [](const c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P>& self) {
- return self->shutdown();
- });
-
- intrusive_ptr_class_<::c10d::ProcessGroupCudaP2P::Options>(
- processGroupCudaP2P, "Options", processGroupOptions)
- .def(py::init<>())
- .def_readwrite(
- "nccl_options", &::c10d::ProcessGroupCudaP2P::Options::nccl_options)
- .def_readwrite(
- "buffer_size", &::c10d::ProcessGroupCudaP2P::Options::buffer_size);
-
#endif
#ifdef USE_C10D_MPI
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp
index 85136a9..ceec7bb 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.cpp
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp
@@ -211,8 +211,9 @@
: store_(std::move(store)),
rank_(rank),
worldSize_(worldSize),
- bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize),
- barrierReady_(at::cuda::CUDAEvent()) {}
+ bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize) {
+ rendezvous();
+}
IntraNodeComm::~IntraNodeComm() {
if (!isInitialized_) {
@@ -288,7 +289,7 @@
return true;
}
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
- if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
+ if (!isIntraNodeCommSupported() || !isEnabled() || worldSize_ < 2 ||
worldSize_ > kMaxDevices) {
return false;
}
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu
index 51fc625..ce479cd 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.cu
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cu
@@ -504,8 +504,7 @@
at::cuda::CUDAStream& stream) {
checkInput(input, rank_);
- const size_t numelPerWarp =
- kBytesPerThread / input.element_size() * kWarpSize;
+ const size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
const size_t N_aligned = alignUp(input.numel(), numelPerWarp);
const bool isAligned = (N_aligned == static_cast<size_t>(input.numel()));
TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());
@@ -734,7 +733,6 @@
}
void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
- barrierReady_.block(at::cuda::getCurrentCUDAStream());
if (!ranks.has_value()) {
ranks = std::vector<int64_t>(worldSize_);
std::iota(ranks->begin(), ranks->end(), 0);
@@ -747,23 +745,44 @@
barrierKernel<<<1, kWarpSize, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<P2pState**>(p2pStatesDev_), mask, rank_, worldSize_);
C10_CUDA_KERNEL_LAUNCH_CHECK();
- barrierReady_.record();
}
-at::Tensor IntraNodeComm::getBuffer(
- size_t rank,
- const std::vector<int64_t>& sizes,
- c10::ScalarType dtype,
- int64_t storageOffset) {
- const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0);
- const auto elementSize = c10::elementSize(dtype);
- TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_);
- auto options = at::TensorOptions().dtype(dtype).device(
- at::kCUDA, at::cuda::current_device());
- return at::for_blob(buffers_[rank], sizes)
- .storage_offset(storageOffset)
- .options(options)
- .make_tensor();
+void IntraNodeComm::put(const at::Tensor& tensor, int64_t offset) {
+ TORCH_CHECK(
+ tensor.is_non_overlapping_and_dense(),
+ "IntraNodeComm::put(): tensor must be non-overlapping and dense");
+ size_t sz = tensor.numel() * tensor.element_size();
+ TORCH_CHECK(
+ offset + sz <= bufferSize_,
+ "IntraNodeComm::put(): offset + tensor size exceeded "
+ "p2p buffer size");
+ // This results in "Memcpy PtoP" which does not use SMs for copying
+ AT_CUDA_CHECK(cudaMemcpyAsync(
+ static_cast<char*>(buffers_[rank_]) + offset,
+ static_cast<char*>(tensor.data_ptr()),
+ sz,
+ cudaMemcpyDeviceToDevice,
+ at::cuda::getCurrentCUDAStream()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+void IntraNodeComm::get(size_t rank, at::Tensor tensor, int64_t offset) {
+ TORCH_CHECK(
+ tensor.is_non_overlapping_and_dense(),
+ "IntraNodeComm::get(): tensor must be non-overlapping and dense");
+ size_t sz = tensor.numel() * tensor.element_size();
+ TORCH_CHECK(
+ offset + sz <= bufferSize_,
+ "IntraNodeComm::get(): offset + tensor size exceeded "
+ "p2p buffer size");
+ // This results in "Memcpy PtoP" which does not use SMs for copying
+ AT_CUDA_CHECK(cudaMemcpyAsync(
+ static_cast<char*>(tensor.data_ptr()),
+ static_cast<char*>(buffers_[rank]) + offset,
+ sz,
+ cudaMemcpyDeviceToDevice,
+ at::cuda::getCurrentCUDAStream()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace intra_node_comm
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp
index 5d7e2d4..fe59197 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.hpp
+++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp
@@ -46,10 +46,6 @@
*/
bool rendezvous();
- Topology getTopology() {
- return topology_;
- }
-
size_t getBufferSize() {
return bufferSize_;
}
@@ -67,11 +63,17 @@
*/
void barrier(std::optional<std::vector<int64_t>> ranks = c10::nullopt);
- at::Tensor getBuffer(
- size_t rank,
- const std::vector<int64_t>& sizes,
- c10::ScalarType dtype,
- int64_t storageOffset);
+ /**
+ * Puts the given tensor into the p2p buffer of the current rank at the
+ * specified offset.
+ */
+ void put(const at::Tensor& tensor, int64_t offset = 0);
+
+ /**
+ * Fills the given tensor with the data from the specified rank's p2p buffer
+ * at the specified offset.
+ */
+ void get(size_t rank, at::Tensor tensor, int64_t offset = 0);
private:
at::Tensor oneShotAllReduce(
@@ -90,7 +92,6 @@
size_t rank_;
size_t worldSize_;
size_t bufferSize_;
- at::cuda::CUDAEvent barrierReady_;
/**
* Members initialized after rendezvous
diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py
deleted file mode 100644
index f91d1f2..0000000
--- a/torch/distributed/_cuda_p2p/__init__.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from contextlib import contextmanager
-
-from functools import partial
-from typing import Callable, cast, List, Tuple, Union
-
-import torch
-import torch.distributed._functional_collectives as funcol
-
-import torch.distributed.distributed_c10d as c10d
-from torch._C._distributed_c10d import _DistributedBackendOptions, Backend
-
-
-"""
-This file contains the registration logic and Python APIs for
-``ProcessGroupCudaP2P`` (experimental).
-
-``ProcessGroupCudaP2P`` is a thin wrapper around ``ProcessGroupNCCL``. By
-default, it routes all collectives to the underlying ``ProcessGroupNCCL``. In
-addition, ``ProcessGroupCudaP2P`` initializes a P2P workspace that allows
-direct GPU memory access among the members. The workspace can be used in Python
-to optimize intra-node communication patterns or to create custom intra-node
-collectives in CUDA.
-
-``ProcessGroupCudaP2P`` aims to bridge the gap where certain important patterns
-can be better optimized via fine-grained P2P memory access than with
-collectives in the latest version of NCCL. It is meant to complement NCCL
-rather than replacing it.
-
-Usage:
-
- # Using ProcessGroupCudaP2P
- dist.init_process_group(backend="cuda_p2p", ...)
-
- # Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
- pg_options = ProcessGroupCudaP2P.Options()
- dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
-
- # Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
- pg_options = ProcessGroupNCCL.Options()
- dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
-
- # Using ProcessGroupCudaP2P while specifying both
- # ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
- pg_options = ProcessGroupCudaP2P.Options()
- pg_options.nccl_options = ProcessGroupNCCL.Options()
- dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
-
- # Down-casting the backend to access p2p buffers for cuda_p2p specific
- # optimizations
- if is_cuda_p2p_group(group):
- backend = get_cuda_p2p_backend(group)
- if required_p2p_buffer_size > backend.get_buffer_size():
- # fallback
- p2p_buffer = backend.get_p2p_buffer(...)
- else:
- # fallback
-"""
-
-
-def _create_cuda_p2p_group(
- dist_backend_opts: "_DistributedBackendOptions",
- options: Union[
- "c10d.ProcessGroupCudaP2P.Options", "c10d.ProcessGroupNCCL.Options", None
- ],
-) -> "Backend":
- if not c10d.is_nccl_available():
- raise RuntimeError("The cuda_p2p backend is not available")
- if options is None:
- options = c10d.ProcessGroupCudaP2P.Options()
- options.nccl_options = c10d.ProcessGroupNCCL.Options()
- elif isinstance(options, c10d.ProcessGroupNCCL.Options):
- nccl_options = options
- options = c10d.ProcessGroupCudaP2P.Options()
- options.nccl_options = nccl_options
- elif isinstance(options, c10d.ProcessGroupCudaP2P.Options):
- if options.nccl_options is None:
- options.nccl_options = c10d.ProcessGroupNCCL.Options()
- else:
- raise TypeError(
- "options for cuda_p2p must be ProcessGroupCudaP2P.Options "
- f"or ProcessGroupNCCL.Options (got: {type(options)})"
- )
-
- return c10d.ProcessGroupCudaP2P(
- dist_backend_opts.store,
- dist_backend_opts.group_rank,
- dist_backend_opts.group_size,
- options,
- )
-
-
-def is_cuda_p2p_group(group: c10d.ProcessGroup) -> bool:
- if not c10d.is_nccl_available():
- return False
- try:
- backend = group._get_backend(torch.device("cuda"))
- except Exception:
- return False
- return isinstance(backend, c10d.ProcessGroupCudaP2P) and backend.is_p2p_available()
-
-
-def get_cuda_p2p_backend(group: c10d.ProcessGroup) -> "c10d.ProcessGroupCudaP2P":
- if not is_cuda_p2p_group(group):
- raise TypeError("group is not a cuda_p2p process group.")
- return cast(
- c10d.ProcessGroupCudaP2P,
- group._get_backend(torch.device("cuda")),
- )
-
-
-def get_p2p_buffer_size(group: c10d.ProcessGroup) -> int:
- if not is_cuda_p2p_group(group):
- return 0
- backend = get_cuda_p2p_backend(group)
- return backend.get_buffer_size()
-
-
-c10d.Backend.register_backend(
- "cuda_p2p",
- _create_cuda_p2p_group,
- extended_api=True,
- devices=["cuda"],
-)
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 3ea412f..c6fc227 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -110,10 +110,8 @@
try:
from torch._C._distributed_c10d import ProcessGroupNCCL
- from torch._C._distributed_c10d import ProcessGroupCudaP2P
ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
- ProcessGroupCudaP2P.__module__ = "torch.distributed.distributed_c10d"
- __all__ += ["ProcessGroupNCCL", "ProcessGroupCudaP2P"]
+ __all__ += ["ProcessGroupNCCL"]
except ImportError:
_NCCL_AVAILABLE = False
@@ -1446,7 +1444,7 @@
backend = pg._get_backend(torch.device("cuda"))
except RuntimeError:
pass
- if is_nccl_available() and isinstance(backend, (ProcessGroupNCCL, ProcessGroupCudaP2P)):
+ if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
# explictly call shutdown to ensure that NCCL resources are released
backend._shutdown()