[re-land] Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (#114001) (#116125)
This is an attempt to re-land https://github.com/pytorch/pytorch/pull/114001. The previous attempt used `std::array` in CUDA kernels, which wasn't compatible with Meta's internal build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116125
Approved by: https://github.com/yf225
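
A minimal usage sketch (not taken verbatim from this PR; it mirrors the test added in test/distributed/test_c10d_nccl.py and assumes a single node with >=2 P2P/NVLink-connected GPUs of compute capability >= 8.0, launched via torchrun):

    import os
    import torch
    import torch.distributed as dist

    # Opt-in flag; must be set before the process group is created, since the
    # IntraNodeComm rendezvous happens during ProcessGroupNCCL construction.
    os.environ["ENABLE_INTRA_NODE_COMM"] = "1"

    local_rank = int(os.environ["LOCAL_RANK"])  # provided by torchrun
    torch.cuda.set_device(local_rank)           # required before rendezvous
    dist.init_process_group(backend="nccl")

    # Small bf16 SUM allreduces (up to 10MB) are eligible for the intra-node
    # path; everything else silently falls back to NCCL.
    t = torch.full((256 * 1024,), local_rank, dtype=torch.bfloat16, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)

    # Optional: confirm the fast path was taken (the counter increments per use).
    from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
    assert _get_intra_node_comm_usage_counter() > 0
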
diff --git a/BUILD.bazel b/BUILD.bazel
index f5739a9..59d2ea8 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1452,7 +1452,10 @@
# https://github.com/pytorch/pytorch/issues/79236
# To solve it we add it into the `caffe2_cuda`,
# this is also aligned with the CMake build.
- srcs = [":caffe2_cu_srcs"] + ["torch/csrc/distributed/c10d/quantization/quantization_gpu.cu"],
+ srcs = [":caffe2_cu_srcs"] + [
+ "torch/csrc/distributed/c10d/intra_node_comm.cu",
+ "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
+ ],
copts = CAFFE2_COPTS + torch_cuda_half_options,
visibility = ["//visibility:public"],
deps = [
@@ -1619,6 +1622,7 @@
exclude = [
"torch/csrc/cuda/python_nccl.cpp",
"torch/csrc/cuda/nccl.cpp",
+ "torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
)) + torch_sources,
diff --git a/build_variables.bzl b/build_variables.bzl
index 9d61861..b028a9b 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -674,6 +674,8 @@
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",
"torch/csrc/distributed/c10d/UCCUtils.cpp",
+ "torch/csrc/distributed/c10d/intra_node_comm.cpp",
+ "torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
]
diff --git a/c10/cuda/driver_api.cpp b/c10/cuda/driver_api.cpp
index a90081f..56243e6 100644
--- a/c10/cuda/driver_api.cpp
+++ b/c10/cuda/driver_api.cpp
@@ -37,7 +37,7 @@
return nvml_hanle;
}
-DriverAPI* DriverAPI::get() {
+C10_EXPORT DriverAPI* DriverAPI::get() {
static DriverAPI singleton = create_driver_api();
return &singleton;
}
diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h
index 6f5e46f..f4054c2 100644
--- a/c10/cuda/driver_api.h
+++ b/c10/cuda/driver_api.h
@@ -28,9 +28,11 @@
_(cuMemCreate) \
_(cuGetErrorString)
-#define C10_NVML_DRIVER_API(_) \
- _(nvmlInit_v2) \
- _(nvmlDeviceGetHandleByPciBusId_v2) \
+#define C10_NVML_DRIVER_API(_) \
+ _(nvmlInit_v2) \
+ _(nvmlDeviceGetHandleByPciBusId_v2) \
+ _(nvmlDeviceGetNvLinkRemoteDeviceType) \
+ _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
_(nvmlDeviceGetComputeRunningProcesses)
namespace c10 {
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 7483637..f2acc61 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -641,6 +641,10 @@
append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)
if(NOT WIN32)
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
+ set_source_files_properties(
+ ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
+ PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
+ )
endif()
endif()
set_source_files_properties(
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 2992a26..15fc8e3 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -15,7 +15,7 @@
from contextlib import contextmanager
from datetime import datetime, timedelta
from itertools import chain, product
-from unittest import mock
+from unittest import SkipTest, mock
import torch
import torch.distributed as c10d
@@ -3115,6 +3115,65 @@
@requires_nccl()
@skip_if_lt_x_gpu(2)
+ @skip_if_rocm
+ def test_intra_node_comm_all_reduce(self):
+ from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
+ from torch.testing._internal.common_cuda import SM80OrLater
+ for peer in range(self.world_size):
+ if peer == self.rank:
+ continue
+ if not torch._C._cuda_canDeviceAccessPeer(self.rank, peer):
+ raise SkipTest("Test requires p2p access")
+
+ if not SM80OrLater:
+ raise SkipTest("Test requires sm>=80")
+
+ store = c10d.FileStore(self.file_name, self.world_size)
+ os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
+ os.environ["TEST_INTRA_NODE_COMM"] = "1"
+ torch.cuda.set_device(self.rank)
+ c10d.init_process_group(
+ backend="nccl", rank=self.rank, world_size=self.world_size, store=store
+ )
+ expect = self.world_size * (self.world_size - 1) // 2
+
+ # IntraNodeComm currently only supports sum and bf16.
+ # Verify that it is not used in the next two configurations.
+ t = torch.full((4 * 1024 // 2,), self.rank).cuda()
+ c10d.all_reduce(t, c10d.ReduceOp.SUM)
+ self.assertTrue(t.eq(expect).all())
+ self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
+
+ t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+ c10d.all_reduce(t, c10d.ReduceOp.AVG)
+ self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
+
+ # Verify that IntraNodeComm is used up to 10MB
+ t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+ c10d.all_reduce(t, c10d.ReduceOp.SUM)
+ self.assertTrue(t.eq(expect).all())
+ self.assertEqual(_get_intra_node_comm_usage_counter(), 1)
+
+ t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+ c10d.all_reduce(t, c10d.ReduceOp.SUM)
+ self.assertTrue(t.eq(expect).all())
+ self.assertEqual(_get_intra_node_comm_usage_counter(), 2)
+
+ t = torch.full((10 * 1024 ** 2 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+ c10d.all_reduce(t, c10d.ReduceOp.SUM)
+ self.assertTrue(t.eq(expect).all())
+ self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
+
+ # Verify that IntraNodeComm is not used beyond 10MB
+ t = torch.full((10 * 1024 ** 2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
+ c10d.all_reduce(t, c10d.ReduceOp.SUM)
+ self.assertTrue(t.eq(expect).all())
+ self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
+
+ c10d.destroy_process_group()
+
+ @requires_nccl()
+ @skip_if_lt_x_gpu(2)
def test_sequence_num_set_default_pg_nccl(self):
torch.cuda.set_device(self.rank)
self._test_sequence_num_set_default_pg(backend="nccl")
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 026576a..0580ea3 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -712,7 +712,8 @@
terminateProcessGroup_(false),
terminateHeartbeatMonitorThread_(false),
collectiveDebugInfoMode_(false),
- uid_(process_group_id++) {
+ uid_(process_group_id++),
+ intraNodeComm_(initIntraNodeComm()) {
TORCH_CHECK_WITH(
ValueError,
at::cuda::getNumGPUs() != 0,
@@ -896,6 +897,12 @@
#endif
}
+c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
+ initIntraNodeComm() {
+ return intra_node_comm::IntraNodeComm::rendezvous(
+ store_, std::to_string(uid_), rank_, size_);
+}
+
void ProcessGroupNCCL::runHealthCheck() {
// Run health check in a separate thread and wait on CV to handle timeouts,
// since majority of getNCCLComm failures are hangs.
@@ -2842,6 +2849,16 @@
c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
+ if (intraNodeComm_ != nullptr && tensors.size() == 1 &&
+ opts.reduceOp == ReduceOp::SUM) {
+ using namespace intra_node_comm;
+ auto algo = intraNodeComm_->selectAllReduceAlgo(tensors[0]);
+ if (algo != intra_node_comm::AllReduceAlgo::NONE) {
+ intraNodeComm_->allReduce(tensors[0], algo);
+ return c10::make_intrusive<IntraNodeCommWork>();
+ }
+ }
+
check_gpu_tensors_different_devices(tensors);
// @lint-ignore CLANGTIDY
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 00022b1..4c07053 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -13,6 +13,7 @@
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
@@ -552,6 +553,8 @@
ncclCommsMap,
c10::optional<std::string> abortReason);
+ c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
+
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
// instead of relying on ProcessGroupNCCL destructor.
void abort(c10::optional<std::string> abortReason = c10::nullopt);
@@ -950,6 +953,8 @@
std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;
size_t uid_;
+
+ c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
};
TORCH_API std::string dump_nccl_trace();
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 252cf4a..22b02f0 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -21,6 +21,7 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif
#ifdef USE_C10D_MPI
@@ -2328,6 +2329,10 @@
"perform_nocolor_split",
&::c10d::ProcessGroupNCCL::performNocolorSplit);
+ module.def(
+ "_get_intra_node_comm_usage_counter",
+ &::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
+
#ifdef NCCL_HAS_COMM_CTA_CGA
py::class_<ncclConfig_t>(
processGroupNCCL,
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp
new file mode 100644
index 0000000..50b0147
--- /dev/null
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp
@@ -0,0 +1,485 @@
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/util/Logging.h>
+#include <torch/csrc/distributed/c10d/Utils.hpp>
+
+#include <iostream>
+#include <random>
+
+#include <fcntl.h>
+#include <pthread.h>
+#include <semaphore.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+#include <c10/cuda/driver_api.h>
+#include <nvml.h>
+#endif
+
+#include <cuda_runtime.h>
+
+namespace c10d {
+namespace intra_node_comm {
+
+static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
+ "ENABLE_INTRA_NODE_COMM"};
+// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
+// IntraNodeComm can be used even without NVLink connection. This is only used
+// for testing purposes.
+static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};
+
+////////////////////////////////////////////////////////////////////////////////
+// CUDA Functions
+////////////////////////////////////////////////////////////////////////////////
+
+bool isIntraNodeCommSupported();
+
+std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);
+
+void* initP2pState();
+
+void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
+
+AllReduceAlgo selectAllReduceAlgo(
+ const at::Tensor& input,
+ Topology topology,
+ size_t worldSize);
+
+at::Tensor allReduce(
+ const at::Tensor& input,
+ std::array<void*, kMaxDevices> p2pStates,
+ std::array<void*, kMaxDevices> buffers,
+ void* p2pStatesDev,
+ void* buffersDev,
+ void* topoInfo,
+ size_t rank,
+ size_t worldSize,
+ AllReduceAlgo algo,
+ at::cuda::CUDAStream& stream);
+
+////////////////////////////////////////////////////////////////////////////////
+// Topology Detection
+////////////////////////////////////////////////////////////////////////////////
+
+// TODO: find a better way to determine this
+static constexpr size_t kMaxNvLinks = 20;
+
+static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
+ std::ostringstream oss;
+ for (size_t i = 0; i < kMaxDevices; ++i) {
+ for (size_t j = 0; j < kMaxDevices; ++j) {
+ oss << nvlMesh[i][j] << " ";
+ }
+ oss << std::endl;
+ }
+ os << oss.str();
+ return os;
+}
+
+static bool isSame(NvlMesh lhs, NvlMesh rhs) {
+ for (size_t i = 0; i < kMaxDevices; ++i) {
+ for (size_t j = 0; j < kMaxDevices; ++j) {
+ if (lhs[i][j] != rhs[i][j]) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+/**
+ * Query the NVLink connections among the devices.
+ */
+static NvlMesh getNvlMesh(std::vector<std::string> rankToBusId) {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+ using namespace c10::cuda;
+
+ NvlMesh nvlMesh = {};
+ auto driverApi = DriverAPI::get();
+ if (driverApi == nullptr) {
+ return nvlMesh;
+ }
+
+ const auto worldSize = rankToBusId.size();
+ std::vector<nvmlDevice_t> devices(worldSize, 0);
+ std::unordered_map<std::string, size_t> busIdToRank;
+ std::vector<size_t> switchLinkCount(worldSize, 0);
+
+ for (size_t r = 0; r < worldSize; ++r) {
+ busIdToRank.emplace(std::make_pair(rankToBusId[r], r));
+ TORCH_CHECK(
+ driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
+ rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
+ }
+
+ // For each device, loop over devices connected to it via NVLink
+ for (size_t idx = 0; idx < worldSize; ++idx) {
+ for (size_t link = 0; link < kMaxNvLinks; ++link) {
+ nvmlReturn_t ret;
+ nvmlIntNvLinkDeviceType_t deviceType;
+ ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
+ devices[idx], link, &deviceType);
+ if (ret != NVML_SUCCESS) {
+ // We've exhausted the NVLinks connected to this device.
+ // This error is benign. There doesn't seem to be a reliable
+ // way to obtain the maximum link value that can be passed to
+ // the API, so we simply increment the link value until the
+ // API fails or we hit a predefined maximum value.
+ break;
+ }
+ // Remote device is GPU
+ if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
+ nvmlPciInfo_t pciInfo;
+ ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
+ devices[idx], link, &pciInfo);
+ if (ret != NVML_SUCCESS) {
+ // Unexpected error. Return an empty NvlMesh
+ return {};
+ }
+ auto it = busIdToRank.find(pciInfo.busId);
+ if (it != busIdToRank.end()) {
+ if (idx != it->second) {
+ nvlMesh[idx][it->second] += 1;
+ }
+ }
+ // Remote device is NVSwitch
+ } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
+ switchLinkCount[idx] += 1;
+ }
+ }
+ }
+ // Process NVSwitch connections. For simplicity, we assume
+ // all NVSwitches are interconnected.
+ for (size_t i = 0; i < worldSize; ++i) {
+ for (size_t j = 0; j < worldSize; ++j) {
+ if (i == j) {
+ continue;
+ }
+ nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
+ }
+ }
+ return nvlMesh;
+#else
+ return {};
+#endif
+}
+
+/**
+ * Determine if the devices form a hybrid cube mesh
+ * topology given an NvlMesh.
+ */
+static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
+ std::array<size_t, kMaxDevices> numNeighbors = {};
+ for (size_t i = 0; i < kMaxDevices; ++i) {
+ for (size_t j = 0; j < kMaxDevices; ++j) {
+ if (nvlMesh[i][j] > 0) {
+ numNeighbors[i] += 1;
+ }
+ }
+ }
+ for (size_t i = 0; i < kMaxDevices; ++i) {
+ // TODO: this check is insufficient and needs to be revisited
+ if (numNeighbors[i] != 4) {
+ return false;
+ }
+ }
+ return true;
+}
+
+/**
+ * Detect the topology given an NvlMesh.
+ */
+static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
+ if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
+ return Topology::FULLY_CONNECTED;
+ }
+ bool fullyConnected = true;
+ for (size_t i = 0; i < worldSize - 1; ++i) {
+ for (size_t j = i + 1; j < worldSize; ++j) {
+ if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
+ fullyConnected = false;
+ }
+ }
+ }
+ if (fullyConnected) {
+ LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
+ return Topology::FULLY_CONNECTED;
+ }
+ if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
+ LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
+ return Topology::HYBRID_CUBE_MESH;
+ }
+ LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
+ return Topology::UNKNOWN;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Rendezvous and Initialization
+////////////////////////////////////////////////////////////////////////////////
+
+IntraNodeComm::IntraNodeComm(
+ Topology topology,
+ std::array<void*, kMaxDevices> p2pStates,
+ std::array<void*, kMaxDevices> buffers,
+ void* p2pStatesDev,
+ void* buffersDev,
+ void* topoInfo,
+ size_t rank,
+ size_t worldSize)
+ : topology_(topology),
+ p2pStates_(p2pStates),
+ buffers_(buffers),
+ p2pStatesDev_(p2pStatesDev),
+ buffersDev_(buffersDev),
+ topoInfo_(topoInfo),
+ rank_(rank),
+ worldSize_(worldSize) {}
+
+IntraNodeComm::~IntraNodeComm() {
+ // Intentionally releasing resources without synchronizing devices. The
+ // teardown logic is safe for a properly synchronized user program. We don't
+ // want an improperly synchronized user program to hang here.
+ for (size_t r = 0; r < worldSize_; ++r) {
+ if (r == rank_) {
+ continue;
+ }
+ AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r]));
+ AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r]));
+ }
+ AT_CUDA_CHECK(cudaFree(p2pStates_[rank_]));
+ AT_CUDA_CHECK(cudaFree(buffers_[rank_]));
+ if (topoInfo_ != nullptr) {
+ AT_CUDA_CHECK(cudaFree(topoInfo_));
+ }
+ AT_CUDA_CHECK(cudaFree(p2pStatesDev_));
+ AT_CUDA_CHECK(cudaFree(buffersDev_));
+}
+
+/**
+ * Use c10d::Store to perform allgather on a trivially copyable type.
+ */
+template <typename T>
+std::vector<T> storeAllGather(
+ c10::intrusive_ptr<c10d::Store> store,
+ const std::string& prefix,
+ size_t rank,
+ size_t worldSize,
+ T val) {
+ static_assert(std::is_trivially_copyable<T>::value);
+
+ std::vector<std::string> peerKeys;
+ for (size_t r = 0; r < worldSize; ++r) {
+ std::ostringstream oss;
+ oss << prefix << "-" << r;
+ peerKeys.push_back(oss.str());
+ }
+
+ {
+ std::vector<uint8_t> payload(
+ reinterpret_cast<uint8_t*>(&val),
+ reinterpret_cast<uint8_t*>(&val) + sizeof(T));
+ store->set(peerKeys[rank], payload);
+ }
+
+ std::vector<T> peerVals;
+ for (size_t r = 0; r < worldSize; ++r) {
+ if (r == rank) {
+ peerVals.push_back(val);
+ continue;
+ }
+ store->wait({peerKeys[r]});
+ auto payload = store->get(peerKeys[r]);
+ TORCH_CHECK(payload.size() == sizeof(T));
+ T peerVal;
+ std::memcpy(&peerVal, payload.data(), sizeof(T));
+ peerVals.push_back(peerVal);
+ }
+ return peerVals;
+}
+
+c10::intrusive_ptr<IntraNodeComm> IntraNodeComm::rendezvous(
+ c10::intrusive_ptr<c10d::Store> store,
+ const std::string& prefix,
+ size_t rank,
+ size_t worldSize) {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+ if (!isIntraNodeCommSupported() ||
+ !getCvarBool(ENABLE_INTRA_NODE_COMM, false) || worldSize < 2 ||
+ worldSize > kMaxDevices) {
+ return nullptr;
+ }
+
+ int deviceIdx = at::cuda::current_device();
+ c10::cuda::CUDAGuard guard(deviceIdx);
+
+ // First hand shake: exchange hostname and device bus ID
+ struct DevInfo {
+ char hostname[HOST_NAME_MAX + 1];
+ char busId[80];
+ };
+
+ DevInfo devInfo{};
+ gethostname(devInfo.hostname, sizeof(devInfo.hostname));
+ cudaDeviceProp prop{};
+ AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
+ snprintf(
+ devInfo.busId,
+ sizeof(devInfo.busId),
+ NVML_DEVICE_PCI_BUS_ID_FMT,
+ prop.pciDomainID,
+ prop.pciBusID,
+ prop.pciDeviceID);
+
+ auto peerDevInfos = storeAllGather(
+ store, prefix + "-IntraNodeCommHandShake-0", rank, worldSize, devInfo);
+
+ std::vector<std::string> rankToBusId;
+ for (const auto& info : peerDevInfos) {
+ if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
+ LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
+ "participants are not on the same host ("
+ << info.hostname << ", " << devInfo.hostname << ")";
+ return nullptr;
+ }
+ rankToBusId.emplace_back(info.busId);
+ }
+
+ // Verify unique devices
+ {
+ std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
+ TORCH_CHECK(
+ uniqueBusIds.size() == worldSize,
+ "IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
+ "Please properly set device via torch.cuda.set_device() before "
+ "initiating rendezvous.");
+ }
+
+ // Query nvlink connection
+ auto nvlMesh = getNvlMesh(rankToBusId);
+
+ // Detect topology
+ Topology topology = detectTopology(nvlMesh, worldSize);
+
+ // Initialize p2p state
+ auto p2pState = initP2pState();
+
+ // Allocate buffer
+ void* buffer = nullptr;
+ AT_CUDA_CHECK(cudaMalloc(&buffer, kMaxIntraNodeSize * 2));
+
+ // Second handshake: exchange topology and CUDA IPC handles
+ struct IpcInfo {
+ NvlMesh nvlMesh;
+ Topology topology;
+ cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
+ };
+
+ // Make p2p state and buffer available for IPC
+ cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
+ AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState));
+ AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer));
+
+ IpcInfo ipcInfo{
+ .nvlMesh = nvlMesh,
+ .topology = topology,
+ .p2pStateHandle = p2pStateHandle,
+ .bufferHandle = bufferHandle};
+
+ auto peerIpcInfos = storeAllGather(
+ store, prefix + "-IntraNodeCommHandShake-2", rank, worldSize, ipcInfo);
+
+ for (const auto& info : peerIpcInfos) {
+ if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) ||
+ info.topology != peerIpcInfos.front().topology) {
+ LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
+ "participants are observing different topologies ("
+ << int(info.topology) << " and " << int(topology) << ")";
+ AT_CUDA_CHECK(cudaFree(p2pState));
+ AT_CUDA_CHECK(cudaFree(buffer));
+ return nullptr;
+ }
+ }
+
+ std::array<void*, kMaxDevices> p2pStates = {}, buffers = {};
+ for (size_t r = 0; r < peerIpcInfos.size(); ++r) {
+ if (r == rank) {
+ p2pStates[r] = p2pState;
+ buffers[r] = buffer;
+ } else {
+ AT_CUDA_CHECK(cudaIpcOpenMemHandle(
+ &p2pStates[r],
+ peerIpcInfos[r].p2pStateHandle,
+ cudaIpcMemLazyEnablePeerAccess));
+ AT_CUDA_CHECK(cudaIpcOpenMemHandle(
+ &buffers[r],
+ peerIpcInfos[r].bufferHandle,
+ cudaIpcMemLazyEnablePeerAccess));
+ }
+ }
+ void* p2pStatesDev = nullptr;
+ AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates)));
+ AT_CUDA_CHECK(cudaMemcpy(
+ p2pStatesDev,
+ p2pStates.data(),
+ sizeof(p2pStates),
+ cudaMemcpyHostToDevice));
+
+ void* buffersDev = nullptr;
+ AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers)));
+ AT_CUDA_CHECK(cudaMemcpy(
+ buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice));
+
+ void* topoInfo = initTopoInfo(topology, nvlMesh, rank);
+ return c10::make_intrusive<IntraNodeComm>(
+ topology,
+ p2pStates,
+ buffers,
+ p2pStatesDev,
+ buffersDev,
+ topoInfo,
+ rank,
+ worldSize);
+#else
+ return nullptr;
+#endif
+}
+
+AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
+ return c10d::intra_node_comm::selectAllReduceAlgo(
+ input, topology_, worldSize_);
+}
+
+static int64_t usageCounter = 0;
+
+at::Tensor IntraNodeComm::allReduce(
+ const at::Tensor& input,
+ AllReduceAlgo algo) {
+ // Report usage for testing purposes.
+ // We don't care about overflowing.
+ ++usageCounter;
+ auto stream = at::cuda::getCurrentCUDAStream();
+ c10::cuda::CUDACachingAllocator::recordStream(
+ input.storage().data_ptr(), stream);
+ return c10d::intra_node_comm::allReduce(
+ input,
+ p2pStates_,
+ buffers_,
+ p2pStatesDev_,
+ buffersDev_,
+ topoInfo_,
+ rank_,
+ worldSize_,
+ algo,
+ stream);
+}
+
+int64_t getIntraNodeCommUsageCounter() {
+ return usageCounter;
+}
+
+} // namespace intra_node_comm
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu
new file mode 100644
index 0000000..7723140
--- /dev/null
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cu
@@ -0,0 +1,729 @@
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
+
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+namespace c10d {
+namespace intra_node_comm {
+
+static constexpr size_t kBytesPerThread = 16;
+static constexpr size_t kMaxAllReduceBlocks = 24;
+static constexpr size_t kThreadsPerBlock = 1024;
+static constexpr size_t kWarpSize = 32;
+
+static constexpr size_t kHcmThreshBytes = 256 * 1024;
+static constexpr size_t kOneShotThreshBytes = 256 * 1024;
+static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
+
+#if defined(USE_ROCM)
+using __nv_bfloat162 = uint32_t;
+#endif
+
+struct __align__(16) bf16x8 {
+ __nv_bfloat162 vals[4];
+};
+
+#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
+
+DEVICE_INLINE __nv_bfloat162
+bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+ CUDA_KERNEL_ASSERT(false);
+#else
+ return __hadd2(x, y);
+#endif
+}
+
+DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
+ bf16x8 c;
+ c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]);
+ c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]);
+ c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]);
+ c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]);
+ return c;
+}
+
+/**
+ * NOTE [cross device memory synchronization]
+ *
+ * The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes
+ * of a thread to be visible to threads with the same block/thread ID on other
+ * devices. To satisfy CUDA's memory consistency model, every thread has to
+ * release its writes at the system scope, and the consuming thread has to
+ * acquire the writes at the system scope. This incurs high overhead, and
+ * attempts at optimizing this process are prone to race conditions.
+ *
+ * Instead, we bypass the caches by having each thread:
+ *
+ * - Directly write to global memory via st.cs (cache-streaming).
+ * - Synchronize with threads within the block.
+ * - Perform cross device synchronization at block level (via system scope
+ * atomic ops).
+ * - Synchronize with threads within the block.
+ * - Directly read from global memory via ld.nc (non-coherent/non-cached).
+ */
+template <typename T>
+DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+ CUDA_KERNEL_ASSERT(false);
+#else
+ unsigned long long int low, high;
+ asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
+ : "=l"(low), "=l"(high)
+ : "l"(addr));
+ reinterpret_cast<unsigned long long int*>(&val)[0] = low;
+ reinterpret_cast<unsigned long long int*>(&val)[1] = high;
+#endif
+}
+
+__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+ CUDA_KERNEL_ASSERT(false);
+#else
+ unsigned long long int low, high;
+ low = reinterpret_cast<const unsigned long long int*>(&val)[0];
+ high = reinterpret_cast<const unsigned long long int*>(&val)[1];
+ asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high));
+#endif
+}
+
+template <typename T>
+DEVICE_INLINE void load128(bf16x8& val, const T* addr) {
+ *reinterpret_cast<uint4*>(&val) = reinterpret_cast<const uint4*>(addr)[0];
+}
+
+template <typename T>
+DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
+ *reinterpret_cast<uint4*>(addr) = reinterpret_cast<const uint4*>(&val)[0];
+}
+
+DEVICE_INLINE void releaseSignal(uint32_t* addr) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+ CUDA_KERNEL_ASSERT(false);
+#else
+ atomicAdd_system(addr, 1);
+#endif
+}
+
+DEVICE_INLINE void acquireSignal(uint32_t* addr) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+ CUDA_KERNEL_ASSERT(false);
+#else
+ volatile uint32_t* signal = addr;
+ uint32_t val;
+ do {
+ val = *signal;
+ } while (val == 0 || atomicCAS_system(addr, val, val - 1) != val);
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Fully Connected Algos
+////////////////////////////////////////////////////////////////////////////////
+
+struct P2pState {
+ uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices];
+ uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
+};
+
+template <uint32_t kWorldSize, bool kAligned>
+static __global__ void oneShotAllReduceKernel(
+ at::BFloat16* input,
+ size_t N,
+ size_t N_aligned,
+ P2pState** p2pStates,
+ at::BFloat16** buffers,
+ size_t rank) {
+ const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
+ const size_t offset =
+ (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
+ const size_t stride = blockDim.x * gridDim.x * numelPerThread;
+
+ // Wait for all other ranks to enter the kernel
+ if (threadIdx.x < kWorldSize) {
+ auto targetRank = threadIdx.x;
+ releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
+ acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
+ }
+ __syncthreads();
+
+ // The source pointers. Distributed round-robin for the different warps
+ const at::BFloat16* srcs[kWorldSize];
+#pragma unroll kWorldSize
+ for (int ii = 0; ii < kWorldSize; ++ii) {
+ int srcRank = (rank + ii) % kWorldSize;
+ srcs[ii] = buffers[srcRank];
+ }
+
+ for (size_t i = offset; i < N_aligned; i += stride) {
+ bf16x8 vals[kWorldSize];
+#pragma unroll kWorldSize
+ for (size_t ii = 0; ii < kWorldSize; ++ii) {
+ streamLoad128(vals[ii], &srcs[ii][i]);
+ }
+
+ bf16x8 sums;
+ memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
+
+#pragma unroll kWorldSize
+ for (size_t ii = 0; ii < kWorldSize; ++ii) {
+ sums = add_bf16x8(sums, vals[ii]);
+ }
+ if constexpr (kAligned) {
+ streamStore128(&input[i], sums);
+ } else {
+ for (size_t ii = 0; ii < numelPerThread; ++ii) {
+ if (i + ii < N) {
+ input[i + ii] = reinterpret_cast<at::BFloat16*>(&sums)[ii];
+ }
+ }
+ }
+ }
+}
+
+template <uint32_t kWorldSize>
+static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel(
+ at::BFloat16* input,
+ size_t N_aligned,
+ P2pState** p2pStates,
+ at::BFloat16** buffers,
+ size_t rank) {
+ const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
+ const size_t offset =
+ (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
+ const size_t stride = blockDim.x * gridDim.x * numelPerThread;
+ const size_t N_per_rank = N_aligned / kWorldSize;
+ const size_t N_start = N_per_rank * rank;
+
+ // Wait for all other ranks to enter the kernel
+ if (threadIdx.x < kWorldSize) {
+ auto targetRank = threadIdx.x;
+ releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
+ acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
+ }
+ __syncthreads();
+
+ // The source pointers. Distributed round-robin for the different warps
+ at::BFloat16* srcs[kWorldSize];
+ size_t srcRanks[kWorldSize];
+#pragma unroll kWorldSize
+ for (int ii = 0; ii < kWorldSize; ++ii) {
+ int srcRank = (rank + ii) % kWorldSize;
+ srcs[ii] = buffers[srcRank];
+ srcRanks[ii] = srcRank;
+ }
+
+ for (size_t i = offset; i < N_per_rank; i += stride) {
+ bf16x8 vals[kWorldSize];
+#pragma unroll kWorldSize
+ for (size_t ii = 0; ii < kWorldSize; ++ii) {
+ streamLoad128(vals[ii], &srcs[ii][N_start + i]);
+ }
+
+ bf16x8 sums;
+ memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
+
+#pragma unroll kWorldSize
+ for (size_t ii = 0; ii < kWorldSize; ++ii) {
+ sums = add_bf16x8(sums, vals[ii]);
+ }
+ streamStore128(&srcs[0][N_start + i], sums);
+ // Store local sums into input now so we can avoid
+ // a global memory access later for it.
+ streamStore128(&input[N_start + i], sums);
+ }
+ __syncthreads();
+
+ if (threadIdx.x < kWorldSize) {
+ auto targetRank = threadIdx.x;
+ releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]);
+ acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]);
+ }
+ __syncthreads();
+
+ for (size_t i = offset; i < N_per_rank; i += stride) {
+#pragma unroll kWorldSize - 1
+ for (size_t ii = 1; ii < kWorldSize; ++ii) {
+ size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank;
+ bf16x8 val;
+ streamLoad128(val, &srcs[ii][k]);
+ streamStore128(&input[k], val);
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Hybrid Cube Mesh Algos
+////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * NOTE [hybrid cube mesh]
+ *
+ * In a hybrid cube mesh topology, every device has exactly 4 neighbors
+ * (directly connected via NVLink). Every device X has exactly 1 neighbor Y
+ * that is a neighbor of the 3 non-neighbors of X. We call Y the
+ * relay neighbor of X. This property is symmetrical: X is also guaranteed to
+ * be the relay neighbor of Y.
+ *
+ * With this property, we can perform a variant of one-shot allreduce algo that
+ * only moves data across NVLinks:
+ *
+ * - Each device performs a one-shot allreduce among itself and its 3
+ * non-relay neighbors.
+ * - Each device exchanges data with its relay neighbor.
+ *
+ * HybridCubeMesh is a data structure for describing the topology:
+ *
+ * - hcm[X][0:3] are the 3 neighbors of X.
+ * - hcm[X][3] is the relay neighbor of X.
+ * - For load balancing purposes, we also ensure that if hcm[X][k] = Y,
+ * hcm[Y][k] = X.
+ */
+std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh) {
+ std::array<std::unordered_set<size_t>, kMaxDevices> neighbors = {};
+ std::array<size_t, kMaxDevices> neighborMasks = {};
+ for (size_t i = 0; i < kMaxDevices; ++i) {
+ for (size_t j = 0; j < kMaxDevices; ++j) {
+ if (nvlMesh[i][j] > 0) {
+ neighbors[i].insert(j);
+ neighborMasks[i] |= (1ul << j);
+ }
+ }
+ }
+ HybridCubeMesh hcm = {};
+ for (auto& row : hcm) {
+ row.fill(-1);
+ }
+ // A topology is an HCM if:
+ // - Every device has exactly 4 neighbors.
+ // - Every device has exactly 1 relay neighbor that is a neighbor of
+ // its 3 non-neighbors.
+ for (size_t i = 0; i < kMaxDevices; ++i) {
+ // Condition 1: check the number of neighbors
+ if (neighbors[i].size() != 4) {
+ return std::nullopt;
+ }
+ std::vector<size_t> relayNeighbors;
+ for (size_t j = 0; j < kMaxDevices; ++j) {
+ if ((neighborMasks[i] & neighborMasks[j]) == 0) {
+ relayNeighbors.push_back(j);
+ }
+ }
+ // Condition 2: check the number of relay neighbors
+ if (relayNeighbors.size() != 1) {
+ return std::nullopt;
+ }
+ neighbors[i].erase(relayNeighbors[0]);
+ hcm[i][3] = relayNeighbors[0];
+ }
+
+ for (size_t i = 0; i < kMaxDevices; ++i) {
+ for (size_t k = 0; k < 3; ++k) {
+ // We can only fill hcm[i][k] with j if hcm[j][k] is not filled
+ for (size_t j : neighbors[i]) {
+ if (hcm[j][k] == -1) {
+ hcm[i][k] = j;
+ hcm[j][k] = i;
+ break;
+ }
+ }
+ TORCH_CHECK(hcm[i][k] != -1);
+ neighbors[i].erase(hcm[i][k]);
+ }
+ }
+ return hcm;
+}
+
+template <bool kAligned>
+static __global__ void hybridCubeMeshAllReduceKernel(
+ at::BFloat16* input,
+ size_t N,
+ size_t N_aligned,
+ P2pState** p2pStates,
+ at::BFloat16** buffers,
+ int hcmInfo[4],
+ size_t rank) {
+ const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
+ const size_t offset =
+ (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
+ const size_t stride = blockDim.x * gridDim.x * numelPerThread;
+ const int relayRank = hcmInfo[3];
+
+ // Wait for HCM neighbors to enter the kernel
+ if (threadIdx.x < 3) {
+ auto targetRank = hcmInfo[threadIdx.x];
+ releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
+ acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
+ }
+ __syncthreads();
+
+ const at::BFloat16* srcs[4] = {
+ buffers[rank],
+ buffers[hcmInfo[0]],
+ buffers[hcmInfo[1]],
+ buffers[hcmInfo[2]],
+ };
+ at::BFloat16* localRelay = buffers[rank] + kMaxIntraNodeSize / 2;
+ at::BFloat16* remoteRelay = buffers[relayRank] + kMaxIntraNodeSize / 2;
+
+ for (size_t i = offset; i < N_aligned; i += stride) {
+ bf16x8 vals[4];
+
+#pragma unroll 4
+ for (size_t ii = 0; ii < 4; ++ii) {
+ streamLoad128(vals[ii], &srcs[ii][i]);
+ }
+
+ bf16x8 sums;
+ memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
+
+#pragma unroll 4
+ for (size_t ii = 0; ii < 4; ++ii) {
+ sums = add_bf16x8(sums, vals[ii]);
+ }
+ // Cached store for local sums
+ store128(&localRelay[i], sums);
+ }
+ __syncthreads();
+
+ if (threadIdx.x == 0) {
+ releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]);
+ acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]);
+ }
+ __syncthreads();
+
+ for (size_t i = offset; i < N_aligned; i += stride) {
+ bf16x8 localSum, remoteSum;
+ // Cached load for local sums
+ load128(localSum, &localRelay[i]);
+ streamLoad128(remoteSum, &remoteRelay[i]);
+ localSum = add_bf16x8(localSum, remoteSum);
+ if constexpr (kAligned) {
+ streamStore128(&input[i], localSum);
+ } else {
+ for (size_t ii = 0; ii < numelPerThread; ++ii) {
+ if (i + ii < N) {
+ input[i + ii] = reinterpret_cast<at::BFloat16*>(&localSum)[ii];
+ }
+ }
+ }
+ }
+}
+
+static inline size_t divUp(uint32_t a, uint32_t b) {
+ return (a + b - 1) / b;
+}
+
+static inline size_t alignUp(uint32_t a, uint32_t b) {
+ return divUp(a, b) * b;
+}
+
+static void checkInput(const at::Tensor& input, size_t rank) {
+ TORCH_CHECK(
+ input.dtype() == at::kBFloat16,
+ "oneShotAllReduce only supports bf16 for now");
+ TORCH_CHECK(input.is_non_overlapping_and_dense());
+ TORCH_CHECK(input.device().is_cuda());
+ TORCH_CHECK(static_cast<size_t>(input.get_device()) == rank);
+}
+
+static void getLaunchConfig(
+ size_t N_aligned,
+ size_t elemSize,
+ dim3& blocks,
+ dim3& threads) {
+ blocks = dim3(0, 1, 1);
+ threads = dim3(0, 1, 1);
+
+ const auto numelPerThread = kBytesPerThread / elemSize;
+ const auto numelPerWarp = numelPerThread * kWarpSize;
+ TORCH_CHECK(N_aligned % numelPerThread == 0);
+ TORCH_CHECK(N_aligned % numelPerWarp == 0);
+ if (N_aligned < numelPerThread * kThreadsPerBlock) {
+ threads.x = N_aligned / numelPerWarp * kWarpSize;
+ blocks.x = 1;
+ } else {
+ auto warpsRequired = N_aligned / numelPerWarp;
+ auto threadsRequired = N_aligned / numelPerThread;
+ blocks.x =
+ std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
+ auto warpsPerBlock = divUp(warpsRequired, blocks.x);
+ threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize);
+ }
+}
+
+bool isIntraNodeCommSupported() {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+ return false;
+#else
+ return true;
+#endif
+}
+
+void* initP2pState() {
+ void* state = nullptr;
+ AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState)));
+ AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState)));
+ return state;
+}
+
+void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
+ void* topoInfo = nullptr;
+ if (topology != Topology::HYBRID_CUBE_MESH) {
+ return topoInfo;
+ }
+ auto hcm = getHybridCubeMesh(nvlMesh);
+ int hcmInfo[4];
+ std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo);
+ AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo)));
+ AT_CUDA_CHECK(
+ cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice));
+ return topoInfo;
+}
+
+at::Tensor oneShotAllReduce(
+ const at::Tensor& input,
+ std::array<void*, kMaxDevices> p2pStates,
+ std::array<void*, kMaxDevices> buffers,
+ void* p2pStatesDev,
+ void* buffersDev,
+ size_t rank,
+ size_t worldSize,
+ at::cuda::CUDAStream& stream) {
+ checkInput(input, rank);
+
+ size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+ size_t N_aligned = alignUp(input.numel(), numelPerWarp);
+ TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
+
+ dim3 blocks, threads;
+ getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
+
+ at::cuda::OptionalCUDAGuard guard(input.get_device());
+ AT_CUDA_CHECK(cudaMemcpyAsync(
+ buffers[rank],
+ input.data_ptr(),
+ input.numel() * input.element_size(),
+ cudaMemcpyDeviceToDevice,
+ stream));
+
+#define X(kWorldSize, kAligned) \
+ if (worldSize == kWorldSize) { \
+ oneShotAllReduceKernel<kWorldSize, kAligned> \
+ <<<blocks, threads, 0, stream>>>( \
+ input.data_ptr<at::BFloat16>(), \
+ input.numel(), \
+ N_aligned, \
+ reinterpret_cast<P2pState**>(p2pStatesDev), \
+ reinterpret_cast<at::BFloat16**>(buffersDev), \
+ rank); \
+ C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+ }
+
+#define DISPATCH_ALL_WORLD_SIZES(kAligned) \
+ X(2, kAligned); \
+ X(3, kAligned); \
+ X(4, kAligned); \
+ X(5, kAligned); \
+ X(6, kAligned); \
+ X(7, kAligned); \
+ X(8, kAligned);
+
+ if (N_aligned == static_cast<size_t>(input.numel())) {
+ DISPATCH_ALL_WORLD_SIZES(true);
+ } else {
+ DISPATCH_ALL_WORLD_SIZES(false);
+ }
+
+#undef DISPATCH_ALL_WORLD_SIZES
+#undef X
+ return input;
+}
+
+at::Tensor twoShotAllReduce(
+ const at::Tensor& input,
+ std::array<void*, kMaxDevices> p2pStates,
+ std::array<void*, kMaxDevices> buffers,
+ void* p2pStatesDev,
+ void* buffersDev,
+ size_t rank,
+ size_t worldSize,
+ at::cuda::CUDAStream& stream) {
+ checkInput(input, rank);
+
+ size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+ size_t N_aligned = alignUp(input.numel(), worldSize * numelPerWarp);
+ size_t N_per_rank = N_aligned / worldSize;
+ TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
+
+ dim3 blocks, threads;
+ getLaunchConfig(N_per_rank, input.element_size(), blocks, threads);
+
+ auto output = N_aligned == static_cast<size_t>(input.numel())
+ ? input
+ : input.new_empty(N_aligned);
+
+ at::cuda::OptionalCUDAGuard guard(input.get_device());
+ AT_CUDA_CHECK(cudaMemcpyAsync(
+ buffers[rank],
+ input.data_ptr(),
+ input.numel() * input.element_size(),
+ cudaMemcpyDeviceToDevice,
+ stream));
+
+#define X(kWorldSize) \
+ if (worldSize == kWorldSize) { \
+ twoShotAllReduceKernel<kWorldSize><<<blocks, threads, 0, stream>>>( \
+ output.data_ptr<at::BFloat16>(), \
+ N_aligned, \
+ reinterpret_cast<P2pState**>(p2pStatesDev), \
+ reinterpret_cast<at::BFloat16**>(buffersDev), \
+ rank); \
+ C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+ }
+ X(2);
+ X(3);
+ X(4);
+ X(5);
+ X(6);
+ X(7);
+ X(8);
+#undef X
+
+ if (output.data_ptr() != input.data_ptr()) {
+ AT_CUDA_CHECK(cudaMemcpyAsync(
+ input.data_ptr(),
+ output.data_ptr(),
+ input.numel() * input.element_size(),
+ cudaMemcpyDeviceToDevice,
+ stream));
+ }
+ return input;
+}
+
+at::Tensor hybridCubeMeshAllReduce(
+ const at::Tensor& input,
+ std::array<void*, kMaxDevices> p2pStates,
+ std::array<void*, kMaxDevices> buffers,
+ void* p2pStatesDev,
+ void* buffersDev,
+ int hcmInfo[4],
+ size_t rank,
+ size_t worldSize,
+ at::cuda::CUDAStream& stream) {
+ checkInput(input, rank);
+
+ size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+ size_t N_aligned = alignUp(input.numel(), numelPerWarp);
+ TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
+
+ dim3 blocks, threads;
+ getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
+
+ at::cuda::OptionalCUDAGuard guard(input.get_device());
+ AT_CUDA_CHECK(cudaMemcpyAsync(
+ buffers[rank],
+ input.data_ptr(),
+ input.numel() * input.element_size(),
+ cudaMemcpyDeviceToDevice,
+ stream));
+
+#define X(kAligned) \
+ hybridCubeMeshAllReduceKernel<kAligned><<<blocks, threads, 0, stream>>>( \
+ input.data_ptr<at::BFloat16>(), \
+ input.numel(), \
+ N_aligned, \
+ reinterpret_cast<P2pState**>(p2pStatesDev), \
+ reinterpret_cast<at::BFloat16**>(buffersDev), \
+ hcmInfo, \
+ rank); \
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+ if (N_aligned == static_cast<size_t>(input.numel())) {
+ X(true);
+ } else {
+ X(false);
+ }
+#undef X
+ return input;
+}
+
+AllReduceAlgo selectAllReduceAlgo(
+ const at::Tensor& input,
+ Topology topology,
+ size_t worldSize) {
+ // Only support bf16 for now
+ if (input.dtype() != at::kBFloat16 ||
+ input.numel() * input.element_size() > kMaxIntraNodeSize) {
+ return AllReduceAlgo::NONE;
+ }
+ const auto numel = input.numel();
+ const auto numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+ if (topology == Topology::HYBRID_CUBE_MESH) {
+ TORCH_CHECK(
+ worldSize == 8, "hybridCubeMeshAllReduce only supports exactly 8 GPUs");
+ if (alignUp(numel, numelPerWarp) <= kHcmThreshBytes) {
+ return AllReduceAlgo::HCM;
+ }
+ }
+ if (topology == Topology::FULLY_CONNECTED) {
+ if (alignUp(numel, numelPerWarp) <= kOneShotThreshBytes) {
+ return AllReduceAlgo::ONE_SHOT;
+ }
+ if (alignUp(numel, numelPerWarp * worldSize) <= kTwoShotThreshBytes) {
+ return AllReduceAlgo::TWO_SHOT;
+ }
+ }
+ return AllReduceAlgo::NONE;
+}
+
+at::Tensor allReduce(
+ const at::Tensor& input,
+ std::array<void*, kMaxDevices> p2pStates,
+ std::array<void*, kMaxDevices> buffers,
+ void* p2pStatesDev,
+ void* buffersDev,
+ void* topoInfo,
+ size_t rank,
+ size_t worldSize,
+ AllReduceAlgo algo,
+ at::cuda::CUDAStream& stream) {
+ switch (algo) {
+ case AllReduceAlgo::ONE_SHOT:
+ return oneShotAllReduce(
+ input,
+ p2pStates,
+ buffers,
+ p2pStatesDev,
+ buffersDev,
+ rank,
+ worldSize,
+ stream);
+ case AllReduceAlgo::TWO_SHOT:
+ return twoShotAllReduce(
+ input,
+ p2pStates,
+ buffers,
+ p2pStatesDev,
+ buffersDev,
+ rank,
+ worldSize,
+ stream);
+ case AllReduceAlgo::HCM:
+ return hybridCubeMeshAllReduce(
+ input,
+ p2pStates,
+ buffers,
+ p2pStatesDev,
+ buffersDev,
+ (int*)topoInfo,
+ rank,
+ worldSize,
+ stream);
+ default:
+ C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
+ }
+}
+
+} // namespace intra_node_comm
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp
new file mode 100644
index 0000000..b494990
--- /dev/null
+++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp
@@ -0,0 +1,106 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAEvent.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/Work.hpp>
+
+namespace c10d {
+namespace intra_node_comm {
+
+constexpr size_t kMaxDevices = 8;
+constexpr size_t kMaxIntraNodeSize = 10 * 1024 * 1024;
+
+using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
+using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
+
+enum class Topology { UNKNOWN = 0, FULLY_CONNECTED = 1, HYBRID_CUBE_MESH = 2 };
+
+enum class AllReduceAlgo { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, HCM = 3 };
+
+class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
+ public:
+ IntraNodeComm(
+ Topology topology,
+ std::array<void*, kMaxDevices> p2pStates,
+ std::array<void*, kMaxDevices> buffers,
+ void* p2pStatesDev,
+ void* buffersDev,
+ void* topoInfo,
+ size_t rank,
+ size_t worldSize);
+
+ ~IntraNodeComm();
+
+ /**
+ * Rendezvous via a c10d::Store.
+ * This function may return nullptr if intra-node comm is not applicable.
+ * It guarantees that all participants either succeed or abort.
+ */
+ static c10::intrusive_ptr<IntraNodeComm> rendezvous(
+ c10::intrusive_ptr<c10d::Store> store,
+ const std::string& prefix,
+ size_t rank,
+ size_t worldSize);
+
+ /**
+ * Selects an AllReduceAlgo that we expect to outperform NCCL.
+ * Returns AllReduceAlgo::NONE if we don't expect to outperform NCCL.
+ */
+ AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input);
+
+ at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo);
+
+ private:
+ Topology topology_;
+ std::array<void*, kMaxDevices> p2pStates_;
+ std::array<void*, kMaxDevices> buffers_;
+ void* p2pStatesDev_;
+ void* buffersDev_;
+ void* topoInfo_;
+ size_t rank_;
+ size_t worldSize_;
+};
+
+/**
+ * NOTE [IntraNodeComm Stream Semantics]
+ *
+ * ProcessGroupNCCL launches kernels differently from the conventional PyTorch
+ * CUDA semantics: it always launches collective kernels onto a dedicated
+ * communication stream. Therefore, it needs to:
+ *
+ * - Synchronize the calling stream and the comm stream.
+ * - Ensure the memory safety of the operands (via record_stream or stashing).
+ * - Synchronize the waiting stream with the comm stream.
+ *
+ * Unconditionally performing these tasks makes sense when we expect most of the
+ * communication to benefit from compute/comm overlap. However, IntraNodeComm
+ * primarily aims to optimize small, latency-sensitive, blocking communication,
+ * in which the overhead incurred by the above steps can be quite pronounced.
+ *
+ * Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and
+ * launches kernels onto the stream specified by the user. Although the user
+ * can perform the necessary synchronization via wait_stream, to provide a UX
+ * consistent with that of ProcessGroupNCCL, the necessary stream
+ * synchronization can also be performed via IntraNodeCommWork::wait().
+ */
+class IntraNodeCommWork : public c10d::Work {
+ public:
+ IntraNodeCommWork() : c10d::Work() {
+ event_.record();
+ }
+
+ bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
+ event_.block(at::cuda::getCurrentCUDAStream());
+ return true;
+ }
+
+ private:
+ at::cuda::CUDAEvent event_;
+};
+
+TORCH_API int64_t getIntraNodeCommUsageCounter();
+
+} // namespace intra_node_comm
+} // namespace c10d
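
For reference, a small hedged sketch of the stream semantics described above in NOTE [IntraNodeComm Stream Semantics], as seen from Python (it assumes the process-group setup from the earlier sketch and an allreduce that takes the intra-node path):

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group(...) and torch.cuda.set_device(...) were
    # already called as in the earlier sketch.
    rank = dist.get_rank()
    t = torch.full((1024,), rank, dtype=torch.bfloat16, device="cuda")

    side_stream = torch.cuda.Stream()
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # The intra-node kernel is launched on the *current* stream, so later ops
    # issued on this stream are already ordered after it. A consumer on a
    # different stream calls work.wait(), which blocks that stream on the event
    # recorded at launch, mirroring ProcessGroupNCCL's wait() UX.
    with torch.cuda.stream(side_stream):
        work.wait()
        y = t * 2  # safe: side_stream now waits for the allreduce to complete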