[re-land] Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (#114001) (#116125)

This is an attempt to re-land https://github.com/pytorch/pytorch/pull/114001. The previous attempt used `std::array` in CUDA kernels, which was not compatible with Meta's internal build.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116125
Approved by: https://github.com/yf225
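
For reference, the sketch below shows how the new fast path can be exercised. It is a minimal illustration modeled on the test added in this PR, assuming a single node with NVLink-connected sm80+ GPUs and the NCCL backend. The fast path currently applies only to single-tensor bf16 `SUM` allreduces of at most 10 MiB, and only when `ENABLE_INTRA_NODE_COMM=1` is set before process group initialization; anything else falls through to the regular NCCL allreduce.

```python
# Minimal sketch (not part of this diff), modeled on test_intra_node_comm_all_reduce.
import os
import torch
import torch.distributed as dist

os.environ["ENABLE_INTRA_NODE_COMM"] = "1"  # opt-in flag read during rendezvous

dist.init_process_group(backend="nccl")     # rank/world size come from the usual env vars
rank = dist.get_rank()
torch.cuda.set_device(rank)

# bf16 + SUM + <= 10 MiB: eligible for the one-shot / two-shot / HCM kernels.
t = torch.full((512 * 1024 // 2,), rank, dtype=torch.bfloat16, device="cuda")
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# fp32 (or AVG, or > 10 MiB) takes the regular NCCL path.
big = torch.ones(16 * 1024 * 1024, device="cuda")
dist.all_reduce(big, op=dist.ReduceOp.SUM)

dist.destroy_process_group()
```
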
diff --git a/BUILD.bazel b/BUILD.bazel
index f5739a9..59d2ea8 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1452,7 +1452,10 @@
     # https://github.com/pytorch/pytorch/issues/79236
     # To solve it we add it into the `caffe2_cuda`,
     # this is also aligned with the CMake build.
-    srcs = [":caffe2_cu_srcs"] + ["torch/csrc/distributed/c10d/quantization/quantization_gpu.cu"],
+    srcs = [":caffe2_cu_srcs"] + [
+        "torch/csrc/distributed/c10d/intra_node_comm.cu",
+        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
+    ],
     copts = CAFFE2_COPTS + torch_cuda_half_options,
     visibility = ["//visibility:public"],
     deps = [
@@ -1619,6 +1622,7 @@
         exclude = [
             "torch/csrc/cuda/python_nccl.cpp",
             "torch/csrc/cuda/nccl.cpp",
+            "torch/csrc/distributed/c10d/intra_node_comm.cu",
             "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
         ],
     )) + torch_sources,
diff --git a/build_variables.bzl b/build_variables.bzl
index 9d61861..b028a9b 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -674,6 +674,8 @@
     "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
     "torch/csrc/distributed/c10d/UCCTracing.cpp",
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
+    "torch/csrc/distributed/c10d/intra_node_comm.cpp",
+    "torch/csrc/distributed/c10d/intra_node_comm.cu",
     "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
 ]
diff --git a/c10/cuda/driver_api.cpp b/c10/cuda/driver_api.cpp
index a90081f..56243e6 100644
--- a/c10/cuda/driver_api.cpp
+++ b/c10/cuda/driver_api.cpp
@@ -37,7 +37,7 @@
   return nvml_hanle;
 }
 
-DriverAPI* DriverAPI::get() {
+C10_EXPORT DriverAPI* DriverAPI::get() {
   static DriverAPI singleton = create_driver_api();
   return &singleton;
 }
diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h
index 6f5e46f..f4054c2 100644
--- a/c10/cuda/driver_api.h
+++ b/c10/cuda/driver_api.h
@@ -28,9 +28,11 @@
   _(cuMemCreate)                  \
   _(cuGetErrorString)
 
-#define C10_NVML_DRIVER_API(_)        \
-  _(nvmlInit_v2)                      \
-  _(nvmlDeviceGetHandleByPciBusId_v2) \
+#define C10_NVML_DRIVER_API(_)           \
+  _(nvmlInit_v2)                         \
+  _(nvmlDeviceGetHandleByPciBusId_v2)    \
+  _(nvmlDeviceGetNvLinkRemoteDeviceType) \
+  _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
   _(nvmlDeviceGetComputeRunningProcesses)
 
 namespace c10 {
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 7483637..f2acc61 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -641,6 +641,10 @@
     append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)
     if(NOT WIN32)
       append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
+      set_source_files_properties(
+        ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
+        PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
+      )
     endif()
   endif()
   set_source_files_properties(
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 2992a26..15fc8e3 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -15,7 +15,7 @@
 from contextlib import contextmanager
 from datetime import datetime, timedelta
 from itertools import chain, product
-from unittest import mock
+from unittest import SkipTest, mock
 
 import torch
 import torch.distributed as c10d
@@ -3115,6 +3115,65 @@
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_intra_node_comm_all_reduce(self):
+        from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
+        from torch.testing._internal.common_cuda import SM80OrLater
+        for peer in range(self.world_size):
+            if peer == self.rank:
+                continue
+            if not torch._C._cuda_canDeviceAccessPeer(self.rank, peer):
+                raise SkipTest("Test requires p2p access")
+
+        if not SM80OrLater:
+            raise SkipTest("Test requires sm>=80")
+
+        store = c10d.FileStore(self.file_name, self.world_size)
+        os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
+        os.environ["TEST_INTRA_NODE_COMM"] = "1"
+        torch.cuda.set_device(self.rank)
+        c10d.init_process_group(
+            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+        expect = self.world_size * (self.world_size - 1) // 2
+
+        # IntraNodeComm currently only supports sum and bf16.
+        # Verify that it is not used in the next two configurations.
+        t = torch.full((4 * 1024 // 2,), self.rank).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
+
+        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.AVG)
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
+
+        # Verify that IntraNodeComm is used up to 10MB
+        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 1)
+
+        t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 2)
+
+        t = torch.full((10 * 1024 ** 2 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
+
+        # Verify that IntraNodeComm is not used beyond 10MB
+        t = torch.full((10 * 1024 ** 2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
+
+        c10d.destroy_process_group()
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
     def test_sequence_num_set_default_pg_nccl(self):
         torch.cuda.set_device(self.rank)
         self._test_sequence_num_set_default_pg(backend="nccl")
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 026576a..0580ea3 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -712,7 +712,8 @@
       terminateProcessGroup_(false),
       terminateHeartbeatMonitorThread_(false),
       collectiveDebugInfoMode_(false),
-      uid_(process_group_id++) {
+      uid_(process_group_id++),
+      intraNodeComm_(initIntraNodeComm()) {
   TORCH_CHECK_WITH(
       ValueError,
       at::cuda::getNumGPUs() != 0,
@@ -896,6 +897,12 @@
 #endif
 }
 
+c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
+    initIntraNodeComm() {
+  return intra_node_comm::IntraNodeComm::rendezvous(
+      store_, std::to_string(uid_), rank_, size_);
+}
+
 void ProcessGroupNCCL::runHealthCheck() {
   // Run health check in a separate thread and wait on CV to handle timeouts,
   // since majority of getNCCLComm failures are hangs.
@@ -2842,6 +2849,16 @@
 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
     std::vector<at::Tensor>& tensors,
     const AllreduceOptions& opts) {
+  if (intraNodeComm_ != nullptr && tensors.size() == 1 &&
+      opts.reduceOp == ReduceOp::SUM) {
+    using namespace intra_node_comm;
+    auto algo = intraNodeComm_->selectAllReduceAlgo(tensors[0]);
+    if (algo != intra_node_comm::AllReduceAlgo::NONE) {
+      intraNodeComm_->allReduce(tensors[0], algo);
+      return c10::make_intrusive<IntraNodeCommWork>();
+    }
+  }
+
   check_gpu_tensors_different_devices(tensors);
 
   // @lint-ignore CLANGTIDY
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 00022b1..4c07053 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -13,6 +13,7 @@
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 
 #include <ATen/DynamicLibrary.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -552,6 +553,8 @@
           ncclCommsMap,
       c10::optional<std::string> abortReason);
 
+  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
+
   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
   // instead of relying on ProcessGroupNCCL destructor.
   void abort(c10::optional<std::string> abortReason = c10::nullopt);
@@ -950,6 +953,8 @@
   std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;
 
   size_t uid_;
+
+  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
 };
 
 TORCH_API std::string dump_nccl_trace();
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 252cf4a..22b02f0 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -21,6 +21,7 @@
 #ifdef USE_C10D_NCCL
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 #endif
 
 #ifdef USE_C10D_MPI
@@ -2328,6 +2329,10 @@
               "perform_nocolor_split",
               &::c10d::ProcessGroupNCCL::performNocolorSplit);
 
+  module.def(
+      "_get_intra_node_comm_usage_counter",
+      &::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
+
 #ifdef NCCL_HAS_COMM_CTA_CGA
   py::class_<ncclConfig_t>(
       processGroupNCCL,
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp
new file mode 100644
index 0000000..50b0147
--- /dev/null
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp
@@ -0,0 +1,485 @@
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/util/Logging.h>
+#include <torch/csrc/distributed/c10d/Utils.hpp>
+
+#include <iostream>
+#include <random>
+
+#include <fcntl.h>
+#include <pthread.h>
+#include <semaphore.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+#include <c10/cuda/driver_api.h>
+#include <nvml.h>
+#endif
+
+#include <cuda_runtime.h>
+
+namespace c10d {
+namespace intra_node_comm {
+
+static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
+    "ENABLE_INTRA_NODE_COMM"};
+// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
+// IntraNodeComm can be used even without NVLink connections. This is only
+// used for testing purposes.
+static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};
+
+////////////////////////////////////////////////////////////////////////////////
+// CUDA Functions
+////////////////////////////////////////////////////////////////////////////////
+
+bool isIntraNodeCommSupported();
+
+std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);
+
+void* initP2pState();
+
+void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
+
+AllReduceAlgo selectAllReduceAlgo(
+    const at::Tensor& input,
+    Topology topology,
+    size_t worldSize);
+
+at::Tensor allReduce(
+    const at::Tensor& input,
+    std::array<void*, kMaxDevices> p2pStates,
+    std::array<void*, kMaxDevices> buffers,
+    void* p2pStatesDev,
+    void* buffersDev,
+    void* topoInfo,
+    size_t rank,
+    size_t worldSize,
+    AllReduceAlgo algo,
+    at::cuda::CUDAStream& stream);
+
+////////////////////////////////////////////////////////////////////////////////
+// Topology Detection
+////////////////////////////////////////////////////////////////////////////////
+
+// TODO: find a better way to determine this
+static constexpr size_t kMaxNvLinks = 20;
+
+static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
+  std::ostringstream oss;
+  for (size_t i = 0; i < kMaxDevices; ++i) {
+    for (size_t j = 0; j < kMaxDevices; ++j) {
+      oss << nvlMesh[i][j] << " ";
+    }
+    oss << std::endl;
+  }
+  os << oss.str();
+  return os;
+}
+
+static bool isSame(NvlMesh lhs, NvlMesh rhs) {
+  for (size_t i = 0; i < kMaxDevices; ++i) {
+    for (size_t j = 0; j < kMaxDevices; ++j) {
+      if (lhs[i][j] != rhs[i][j]) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+/**
+ * Query the NVLink connections among devices.
+ */
+static NvlMesh getNvlMesh(std::vector<std::string> rankToBusId) {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+  using namespace c10::cuda;
+
+  NvlMesh nvlMesh = {};
+  auto driverApi = DriverAPI::get();
+  if (driverApi == nullptr) {
+    return nvlMesh;
+  }
+
+  const auto worldSize = rankToBusId.size();
+  std::vector<nvmlDevice_t> devices(worldSize, 0);
+  std::unordered_map<std::string, size_t> busIdToRank;
+  std::vector<size_t> switchLinkCount(worldSize, 0);
+
+  for (size_t r = 0; r < worldSize; ++r) {
+    busIdToRank.emplace(std::make_pair(rankToBusId[r], r));
+    TORCH_CHECK(
+        driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
+            rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
+  }
+
+  // For each device, loop over devices connected to it via NVLink
+  for (size_t idx = 0; idx < worldSize; ++idx) {
+    for (size_t link = 0; link < kMaxNvLinks; ++link) {
+      nvmlReturn_t ret;
+      nvmlIntNvLinkDeviceType_t deviceType;
+      ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
+          devices[idx], link, &deviceType);
+      if (ret != NVML_SUCCESS) {
+        // We've exhausted the NVLinks connected to this device.
+        // This error is benign. There doesn't seem to be a reliable
+        // way to obtain the maximum link value that can be passed to
+        // the API, so we simply increment the link value until the
+        // API fails or we hit a predefined maximum value.
+        break;
+      }
+      // Remote device is GPU
+      if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
+        nvmlPciInfo_t pciInfo;
+        ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
+            devices[idx], link, &pciInfo);
+        if (ret != NVML_SUCCESS) {
+          // Unexpected error. Return an empty NvlMesh
+          return {};
+        }
+        auto it = busIdToRank.find(pciInfo.busId);
+        if (it != busIdToRank.end()) {
+          if (idx != it->second) {
+            nvlMesh[idx][it->second] += 1;
+          }
+        }
+        // Remote device is NVSwitch
+      } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
+        switchLinkCount[idx] += 1;
+      }
+    }
+  }
+  // Process NVSwitch connections. For simplicity, we assume
+  // all NVSwitches are interconnected.
+  for (size_t i = 0; i < worldSize; ++i) {
+    for (size_t j = 0; j < worldSize; ++j) {
+      if (i == j) {
+        continue;
+      }
+      nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
+    }
+  }
+  return nvlMesh;
+#else
+  return {};
+#endif
+}
+
+/**
+ * Determine if the devices form a hybrid cube mesh
+ * topology given a NvlMesh.
+ */
+static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
+  std::array<size_t, kMaxDevices> numNeighbors = {};
+  for (size_t i = 0; i < kMaxDevices; ++i) {
+    for (size_t j = 0; j < kMaxDevices; ++j) {
+      if (nvlMesh[i][j] > 0) {
+        numNeighbors[i] += 1;
+      }
+    }
+  }
+  for (size_t i = 0; i < kMaxDevices; ++i) {
+    // TODO: this is insufficient and needs revisiting
+    if (numNeighbors[i] != 4) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/**
+ * Detect the topology given a NvlMesh.
+ */
+static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
+  if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
+    return Topology::FULLY_CONNECTED;
+  }
+  bool fullyConnected = true;
+  for (size_t i = 0; i < worldSize - 1; ++i) {
+    for (size_t j = i + 1; j < worldSize; ++j) {
+      if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
+        fullyConnected = false;
+      }
+    }
+  }
+  if (fullyConnected) {
+    LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
+    return Topology::FULLY_CONNECTED;
+  }
+  if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
+    LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
+    return Topology::HYBRID_CUBE_MESH;
+  }
+  LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
+  return Topology::UNKNOWN;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// Rendezvous and Initialization
+////////////////////////////////////////////////////////////////////////////////
+
+IntraNodeComm::IntraNodeComm(
+    Topology topology,
+    std::array<void*, kMaxDevices> p2pStates,
+    std::array<void*, kMaxDevices> buffers,
+    void* p2pStatesDev,
+    void* buffersDev,
+    void* topoInfo,
+    size_t rank,
+    size_t worldSize)
+    : topology_(topology),
+      p2pStates_(p2pStates),
+      buffers_(buffers),
+      p2pStatesDev_(p2pStatesDev),
+      buffersDev_(buffersDev),
+      topoInfo_(topoInfo),
+      rank_(rank),
+      worldSize_(worldSize) {}
+
+IntraNodeComm::~IntraNodeComm() {
+  // Intentionally releasing resources without synchronizing devices. The
+  // teardown logic is safe for properly synchronized user programs. We don't
+  // want improperly synchronized user programs to hang here.
+  for (size_t r = 0; r < worldSize_; ++r) {
+    if (r == rank_) {
+      continue;
+    }
+    AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r]));
+    AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r]));
+  }
+  AT_CUDA_CHECK(cudaFree(p2pStates_[rank_]));
+  AT_CUDA_CHECK(cudaFree(buffers_[rank_]));
+  if (topoInfo_ != nullptr) {
+    AT_CUDA_CHECK(cudaFree(topoInfo_));
+  }
+  AT_CUDA_CHECK(cudaFree(p2pStatesDev_));
+  AT_CUDA_CHECK(cudaFree(buffersDev_));
+}
+
+/**
+ * Use c10d::Store to perform allgather on a trivially copyable type.
+ */
+template <typename T>
+std::vector<T> storeAllGather(
+    c10::intrusive_ptr<c10d::Store> store,
+    const std::string& prefix,
+    size_t rank,
+    size_t worldSize,
+    T val) {
+  static_assert(std::is_trivially_copyable<T>::value);
+
+  std::vector<std::string> peerKeys;
+  for (size_t r = 0; r < worldSize; ++r) {
+    std::ostringstream oss;
+    oss << prefix << "-" << r;
+    peerKeys.push_back(oss.str());
+  }
+
+  {
+    std::vector<uint8_t> payload(
+        reinterpret_cast<uint8_t*>(&val),
+        reinterpret_cast<uint8_t*>(&val) + sizeof(T));
+    store->set(peerKeys[rank], payload);
+  }
+
+  std::vector<T> peerVals;
+  for (size_t r = 0; r < worldSize; ++r) {
+    if (r == rank) {
+      peerVals.push_back(val);
+      continue;
+    }
+    store->wait({peerKeys[r]});
+    auto payload = store->get(peerKeys[r]);
+    TORCH_CHECK(payload.size() == sizeof(T));
+    T peerVal;
+    std::memcpy(&peerVal, payload.data(), sizeof(T));
+    peerVals.push_back(peerVal);
+  }
+  return peerVals;
+}
+
+c10::intrusive_ptr<IntraNodeComm> IntraNodeComm::rendezvous(
+    c10::intrusive_ptr<c10d::Store> store,
+    const std::string& prefix,
+    size_t rank,
+    size_t worldSize) {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+  if (!isIntraNodeCommSupported() ||
+      !getCvarBool(ENABLE_INTRA_NODE_COMM, false) || worldSize < 2 ||
+      worldSize > kMaxDevices) {
+    return nullptr;
+  }
+
+  int deviceIdx = at::cuda::current_device();
+  c10::cuda::CUDAGuard guard(deviceIdx);
+
+  // First hand shake: exchange hostname and device bus ID
+  struct DevInfo {
+    char hostname[HOST_NAME_MAX + 1];
+    char busId[80];
+  };
+
+  DevInfo devInfo{};
+  gethostname(devInfo.hostname, sizeof(devInfo.hostname));
+  cudaDeviceProp prop{};
+  AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
+  snprintf(
+      devInfo.busId,
+      sizeof(devInfo.busId),
+      NVML_DEVICE_PCI_BUS_ID_FMT,
+      prop.pciDomainID,
+      prop.pciBusID,
+      prop.pciDeviceID);
+
+  auto peerDevInfos = storeAllGather(
+      store, prefix + "-IntraNodeCommHandShake-0", rank, worldSize, devInfo);
+
+  std::vector<std::string> rankToBusId;
+  for (const auto& info : peerDevInfos) {
+    if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
+      LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
+                      "participants are not on the same host ("
+                   << info.hostname << ", " << devInfo.hostname << ")";
+      return nullptr;
+    }
+    rankToBusId.emplace_back(info.busId);
+  }
+
+  // Verify unique devices
+  {
+    std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
+    TORCH_CHECK(
+        uniqueBusIds.size() == worldSize,
+        "IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
+        "Please properly set device via torch.cuda.set_device() before "
+        "initiating rendezvous.");
+  }
+
+  // Query nvlink connection
+  auto nvlMesh = getNvlMesh(rankToBusId);
+
+  // Detect topology
+  Topology topology = detectTopology(nvlMesh, worldSize);
+
+  // Initialize p2p state
+  auto p2pState = initP2pState();
+
+  // Allocate buffer
+  void* buffer = nullptr;
+  AT_CUDA_CHECK(cudaMalloc(&buffer, kMaxIntraNodeSize * 2));
+
+  // Second handshake: exchange topology and CUDA IPC handles
+  struct IpcInfo {
+    NvlMesh nvlMesh;
+    Topology topology;
+    cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
+  };
+
+  // Make p2p state and buffer available for IPC
+  cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
+  AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState));
+  AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer));
+
+  IpcInfo ipcInfo{
+      .nvlMesh = nvlMesh,
+      .topology = topology,
+      .p2pStateHandle = p2pStateHandle,
+      .bufferHandle = bufferHandle};
+
+  auto peerIpcInfos = storeAllGather(
+      store, prefix + "-IntraNodeCommHandShake-2", rank, worldSize, ipcInfo);
+
+  for (const auto& info : peerIpcInfos) {
+    if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) ||
+        info.topology != peerIpcInfos.front().topology) {
+      LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
+                      "participants are observing different topologies ("
+                   << int(info.topology) << " and " << int(topology) << ")";
+      AT_CUDA_CHECK(cudaFree(p2pState));
+      AT_CUDA_CHECK(cudaFree(buffer));
+      return nullptr;
+    }
+  }
+
+  std::array<void*, kMaxDevices> p2pStates = {}, buffers = {};
+  for (size_t r = 0; r < peerIpcInfos.size(); ++r) {
+    if (r == rank) {
+      p2pStates[r] = p2pState;
+      buffers[r] = buffer;
+    } else {
+      AT_CUDA_CHECK(cudaIpcOpenMemHandle(
+          &p2pStates[r],
+          peerIpcInfos[r].p2pStateHandle,
+          cudaIpcMemLazyEnablePeerAccess));
+      AT_CUDA_CHECK(cudaIpcOpenMemHandle(
+          &buffers[r],
+          peerIpcInfos[r].bufferHandle,
+          cudaIpcMemLazyEnablePeerAccess));
+    }
+  }
+  void* p2pStatesDev = nullptr;
+  AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates)));
+  AT_CUDA_CHECK(cudaMemcpy(
+      p2pStatesDev,
+      p2pStates.data(),
+      sizeof(p2pStates),
+      cudaMemcpyHostToDevice));
+
+  void* buffersDev = nullptr;
+  AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers)));
+  AT_CUDA_CHECK(cudaMemcpy(
+      buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice));
+
+  void* topoInfo = initTopoInfo(topology, nvlMesh, rank);
+  return c10::make_intrusive<IntraNodeComm>(
+      topology,
+      p2pStates,
+      buffers,
+      p2pStatesDev,
+      buffersDev,
+      topoInfo,
+      rank,
+      worldSize);
+#else
+  return nullptr;
+#endif
+}
+
+AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
+  return c10d::intra_node_comm::selectAllReduceAlgo(
+      input, topology_, worldSize_);
+}
+
+static int64_t usageCounter = 0;
+
+at::Tensor IntraNodeComm::allReduce(
+    const at::Tensor& input,
+    AllReduceAlgo algo) {
+  // Report usage for testing purposes.
+  // We don't care about overflowing.
+  ++usageCounter;
+  auto stream = at::cuda::getCurrentCUDAStream();
+  c10::cuda::CUDACachingAllocator::recordStream(
+      input.storage().data_ptr(), stream);
+  return c10d::intra_node_comm::allReduce(
+      input,
+      p2pStates_,
+      buffers_,
+      p2pStatesDev_,
+      buffersDev_,
+      topoInfo_,
+      rank_,
+      worldSize_,
+      algo,
+      stream);
+}
+
+int64_t getIntraNodeCommUsageCounter() {
+  return usageCounter;
+}
+
+} // namespace intra_node_comm
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu
new file mode 100644
index 0000000..7723140
--- /dev/null
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cu
@@ -0,0 +1,729 @@
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
+
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+namespace c10d {
+namespace intra_node_comm {
+
+static constexpr size_t kBytesPerThread = 16;
+static constexpr size_t kMaxAllReduceBlocks = 24;
+static constexpr size_t kThreadsPerBlock = 1024;
+static constexpr size_t kWarpSize = 32;
+
+static constexpr size_t kHcmThreshBytes = 256 * 1024;
+static constexpr size_t kOneShotThreshBytes = 256 * 1024;
+static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
+
+#if defined(USE_ROCM)
+using __nv_bfloat162 = uint32_t;
+#endif
+
+struct __align__(16) bf16x8 {
+  __nv_bfloat162 vals[4];
+};
+
+#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
+
+DEVICE_INLINE __nv_bfloat162
+bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+  CUDA_KERNEL_ASSERT(false);
+#else
+  return __hadd2(x, y);
+#endif
+}
+
+DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
+  bf16x8 c;
+  c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]);
+  c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]);
+  c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]);
+  c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]);
+  return c;
+}
+
+/**
+ * NOTE [cross device memory synchronization]
+ *
+ * The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes
+ * of a thread to be visible by threads with the same block/thread ID on other
+ * devices. To satisfy CUDA's memory consistency model, every thread has to
+ * release its writes at the system scope, and the consuming thread has to
+ * acquire the writes at the system scope. This incurs high overhead, and
+ * attempts at optimizing this process are prone to race conditions.
+ *
+ * Instead, we bypass caching by having each thread:
+ *
+ * - Directly write to global memory via st.cs (cache-streaming).
+ * - Synchronize with threads within the block.
+ * - Perform cross device synchronization at block level (via system scope
+ *   atomic ops).
+ * - Synchronize with threads within the block.
+ * - Directly read from global memory via ld.nc (non-coherent/non-cached).
+ */
+template <typename T>
+DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+  CUDA_KERNEL_ASSERT(false);
+#else
+  unsigned long long int low, high;
+  asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
+      : "=l"(low), "=l"(high)
+      : "l"(addr));
+  reinterpret_cast<unsigned long long int*>(&val)[0] = low;
+  reinterpret_cast<unsigned long long int*>(&val)[1] = high;
+#endif
+}
+
+__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+  CUDA_KERNEL_ASSERT(false);
+#else
+  unsigned long long int low, high;
+  low = reinterpret_cast<const unsigned long long int*>(&val)[0];
+  high = reinterpret_cast<const unsigned long long int*>(&val)[1];
+  asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high));
+#endif
+}
+
+template <typename T>
+DEVICE_INLINE void load128(bf16x8& val, const T* addr) {
+  *reinterpret_cast<uint4*>(&val) = reinterpret_cast<const uint4*>(addr)[0];
+}
+
+template <typename T>
+DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
+  *reinterpret_cast<uint4*>(addr) = reinterpret_cast<const uint4*>(&val)[0];
+}
+
+DEVICE_INLINE void releaseSignal(uint32_t* addr) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+  CUDA_KERNEL_ASSERT(false);
+#else
+  atomicAdd_system(addr, 1);
+#endif
+}
+
+DEVICE_INLINE void acquireSignal(uint32_t* addr) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+  CUDA_KERNEL_ASSERT(false);
+#else
+  volatile uint32_t* signal = addr;
+  uint32_t val;
+  do {
+    val = *signal;
+  } while (val == 0 || atomicCAS_system(addr, val, val - 1) != val);
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Fully Connected Algos
+////////////////////////////////////////////////////////////////////////////////
+
+struct P2pState {
+  uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices];
+  uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
+};
+
+template <uint32_t kWorldSize, bool kAligned>
+static __global__ void oneShotAllReduceKernel(
+    at::BFloat16* input,
+    size_t N,
+    size_t N_aligned,
+    P2pState** p2pStates,
+    at::BFloat16** buffers,
+    size_t rank) {
+  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
+  const size_t offset =
+      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
+  const size_t stride = blockDim.x * gridDim.x * numelPerThread;
+
+  // Wait for all other ranks to enter the kernel
+  if (threadIdx.x < kWorldSize) {
+    auto targetRank = threadIdx.x;
+    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
+    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
+  }
+  __syncthreads();
+
+  // The source pointers. Distributed round-robin for the different warps
+  const at::BFloat16* srcs[kWorldSize];
+#pragma unroll kWorldSize
+  for (int ii = 0; ii < kWorldSize; ++ii) {
+    int srcRank = (rank + ii) % kWorldSize;
+    srcs[ii] = buffers[srcRank];
+  }
+
+  for (size_t i = offset; i < N_aligned; i += stride) {
+    bf16x8 vals[kWorldSize];
+#pragma unroll kWorldSize
+    for (size_t ii = 0; ii < kWorldSize; ++ii) {
+      streamLoad128(vals[ii], &srcs[ii][i]);
+    }
+
+    bf16x8 sums;
+    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
+
+#pragma unroll kWorldSize
+    for (size_t ii = 0; ii < kWorldSize; ++ii) {
+      sums = add_bf16x8(sums, vals[ii]);
+    }
+    if constexpr (kAligned) {
+      streamStore128(&input[i], sums);
+    } else {
+      for (size_t ii = 0; ii < numelPerThread; ++ii) {
+        if (i + ii < N) {
+          input[i + ii] = reinterpret_cast<at::BFloat16*>(&sums)[ii];
+        }
+      }
+    }
+  }
+}
+
+template <uint32_t kWorldSize>
+static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel(
+    at::BFloat16* input,
+    size_t N_aligned,
+    P2pState** p2pStates,
+    at::BFloat16** buffers,
+    size_t rank) {
+  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
+  const size_t offset =
+      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
+  const size_t stride = blockDim.x * gridDim.x * numelPerThread;
+  const size_t N_per_rank = N_aligned / kWorldSize;
+  const size_t N_start = N_per_rank * rank;
+
+  // Wait for all other ranks to enter the kernel
+  if (threadIdx.x < kWorldSize) {
+    auto targetRank = threadIdx.x;
+    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
+    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
+  }
+  __syncthreads();
+
+  // The source pointers. Distributed round-robin for the different warps
+  at::BFloat16* srcs[kWorldSize];
+  size_t srcRanks[kWorldSize];
+#pragma unroll kWorldSize
+  for (int ii = 0; ii < kWorldSize; ++ii) {
+    int srcRank = (rank + ii) % kWorldSize;
+    srcs[ii] = buffers[srcRank];
+    srcRanks[ii] = srcRank;
+  }
+
+  for (size_t i = offset; i < N_per_rank; i += stride) {
+    bf16x8 vals[kWorldSize];
+#pragma unroll kWorldSize
+    for (size_t ii = 0; ii < kWorldSize; ++ii) {
+      streamLoad128(vals[ii], &srcs[ii][N_start + i]);
+    }
+
+    bf16x8 sums;
+    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
+
+#pragma unroll kWorldSize
+    for (size_t ii = 0; ii < kWorldSize; ++ii) {
+      sums = add_bf16x8(sums, vals[ii]);
+    }
+    streamStore128(&srcs[0][N_start + i], sums);
+    // Store local sums into input now so we can avoid
+    // a global memory access later for it.
+    streamStore128(&input[N_start + i], sums);
+  }
+  __syncthreads();
+
+  if (threadIdx.x < kWorldSize) {
+    auto targetRank = threadIdx.x;
+    releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]);
+    acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]);
+  }
+  __syncthreads();
+
+  for (size_t i = offset; i < N_per_rank; i += stride) {
+#pragma unroll kWorldSize - 1
+    for (size_t ii = 1; ii < kWorldSize; ++ii) {
+      size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank;
+      bf16x8 val;
+      streamLoad128(val, &srcs[ii][k]);
+      streamStore128(&input[k], val);
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Hybrid Cube Mesh Algos
+////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * NOTE [hybrid cube mesh]
+ *
+ * In a hybrid cube mesh topology, every device has exactly 4 neighbors
+ * (directly connected via NVLink). Every device X has exactly 1
+ * neighbor Y that is a neighbor of the 3 non-neighbors of X. We call Y the
+ * relay neighbor of X. This property is symmetrical: X is also guaranteed to
+ * be the relay neighbor of Y.
+ *
+ * With this property, we can perform a variant of one-shot allreduce algo that
+ * only moves data across NVLinks:
+ *
+ * - Each device performs a one-shot allreduce among itself and its 3 non-relay neighbors.
+ * - Each device exchanges data with its relay neighbor.
+ *
+ * HybridCubeMesh is a data structure for describing the topology:
+ *
+ * - hcm[X][0:3] are the 3 neighbors of X.
+ * - hcm[X][3] is the relay neighbor of X.
+ * - For load balancing purposes, we also ensure that if hcm[X][k] = Y,
+ *   hcm[Y][k] = X.
+ */
+std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh) {
+  std::array<std::unordered_set<size_t>, kMaxDevices> neighbors = {};
+  std::array<size_t, kMaxDevices> neighborMasks = {};
+  for (size_t i = 0; i < kMaxDevices; ++i) {
+    for (size_t j = 0; j < kMaxDevices; ++j) {
+      if (nvlMesh[i][j] > 0) {
+        neighbors[i].insert(j);
+        neighborMasks[i] |= (1ul << j);
+      }
+    }
+  }
+  HybridCubeMesh hcm = {};
+  for (auto& row : hcm) {
+    row.fill(-1);
+  }
+  // A topology is an HCM if:
+  // - Every device has exactly 4 neighbors.
+  // - Every device has exactly 1 relay neighbor that is a neighbor
+  //   of the device's 3 non-neighbors.
+  for (size_t i = 0; i < kMaxDevices; ++i) {
+    // Condition 1: check the number of neighbors
+    if (neighbors[i].size() != 4) {
+      return std::nullopt;
+    }
+    std::vector<size_t> relayNeighbors;
+    for (size_t j = 0; j < kMaxDevices; ++j) {
+      if ((neighborMasks[i] & neighborMasks[j]) == 0) {
+        relayNeighbors.push_back(j);
+      }
+    }
+    // Condition 2: check the number of relay neighbors
+    if (relayNeighbors.size() != 1) {
+      return std::nullopt;
+    }
+    neighbors[i].erase(relayNeighbors[0]);
+    hcm[i][3] = relayNeighbors[0];
+  }
+
+  for (size_t i = 0; i < kMaxDevices; ++i) {
+    for (size_t k = 0; k < 3; ++k) {
+      // We can only fill hcm[i][k] with j if hcm[j][k] is not filled
+      for (size_t j : neighbors[i]) {
+        if (hcm[j][k] == -1) {
+          hcm[i][k] = j;
+          hcm[j][k] = i;
+          break;
+        }
+      }
+      TORCH_CHECK(hcm[i][k] != -1);
+      neighbors[i].erase(hcm[i][k]);
+    }
+  }
+  return hcm;
+}
+
+template <bool kAligned>
+static __global__ void hybridCubeMeshAllReduceKernel(
+    at::BFloat16* input,
+    size_t N,
+    size_t N_aligned,
+    P2pState** p2pStates,
+    at::BFloat16** buffers,
+    int hcmInfo[4],
+    size_t rank) {
+  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
+  const size_t offset =
+      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
+  const size_t stride = blockDim.x * gridDim.x * numelPerThread;
+  const int relayRank = hcmInfo[3];
+
+  // Wait for HCM neighbors to enter the kernel
+  if (threadIdx.x < 3) {
+    auto targetRank = hcmInfo[threadIdx.x];
+    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
+    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
+  }
+  __syncthreads();
+
+  const at::BFloat16* srcs[4] = {
+      buffers[rank],
+      buffers[hcmInfo[0]],
+      buffers[hcmInfo[1]],
+      buffers[hcmInfo[2]],
+  };
+  at::BFloat16* localRelay = buffers[rank] + kMaxIntraNodeSize / 2;
+  at::BFloat16* remoteRelay = buffers[relayRank] + kMaxIntraNodeSize / 2;
+
+  for (size_t i = offset; i < N_aligned; i += stride) {
+    bf16x8 vals[4];
+
+#pragma unroll 4
+    for (size_t ii = 0; ii < 4; ++ii) {
+      streamLoad128(vals[ii], &srcs[ii][i]);
+    }
+
+    bf16x8 sums;
+    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
+
+#pragma unroll 4
+    for (size_t ii = 0; ii < 4; ++ii) {
+      sums = add_bf16x8(sums, vals[ii]);
+    }
+    // Cached store for local sums
+    store128(&localRelay[i], sums);
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]);
+    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]);
+  }
+  __syncthreads();
+
+  for (size_t i = offset; i < N_aligned; i += stride) {
+    bf16x8 localSum, remoteSum;
+    // Cached load for local sums
+    load128(localSum, &localRelay[i]);
+    streamLoad128(remoteSum, &remoteRelay[i]);
+    localSum = add_bf16x8(localSum, remoteSum);
+    if constexpr (kAligned) {
+      streamStore128(&input[i], localSum);
+    } else {
+      for (size_t ii = 0; ii < numelPerThread; ++ii) {
+        if (i + ii < N) {
+          input[i + ii] = reinterpret_cast<at::BFloat16*>(&localSum)[ii];
+        }
+      }
+    }
+  }
+}
+
+static inline size_t divUp(uint32_t a, uint32_t b) {
+  return (a + b - 1) / b;
+}
+
+static inline size_t alignUp(uint32_t a, uint32_t b) {
+  return divUp(a, b) * b;
+}
+
+static void checkInput(const at::Tensor& input, size_t rank) {
+  TORCH_CHECK(
+      input.dtype() == at::kBFloat16,
+      "oneShotAllReduce only supports bf16 for now");
+  TORCH_CHECK(input.is_non_overlapping_and_dense());
+  TORCH_CHECK(input.device().is_cuda());
+  TORCH_CHECK(static_cast<size_t>(input.get_device()) == rank);
+}
+
+static void getLaunchConfig(
+    size_t N_aligned,
+    size_t elemSize,
+    dim3& blocks,
+    dim3& threads) {
+  blocks = dim3(0, 1, 1);
+  threads = dim3(0, 1, 1);
+
+  const auto numelPerThread = kBytesPerThread / elemSize;
+  const auto numelPerWarp = numelPerThread * kWarpSize;
+  TORCH_CHECK(N_aligned % numelPerThread == 0);
+  TORCH_CHECK(N_aligned % numelPerWarp == 0);
+  if (N_aligned < numelPerThread * kThreadsPerBlock) {
+    threads.x = N_aligned / numelPerWarp * kWarpSize;
+    blocks.x = 1;
+  } else {
+    auto warpsRequired = N_aligned / numelPerWarp;
+    auto threadsRequired = N_aligned / numelPerThread;
+    blocks.x =
+        std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
+    auto warpsPerBlock = divUp(warpsRequired, blocks.x);
+    threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize);
+  }
+}
+
+bool isIntraNodeCommSupported() {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+  return false;
+#else
+  return true;
+#endif
+}
+
+void* initP2pState() {
+  void* state = nullptr;
+  AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState)));
+  AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState)));
+  return state;
+}
+
+void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
+  void* topoInfo = nullptr;
+  if (topology != Topology::HYBRID_CUBE_MESH) {
+    return topoInfo;
+  }
+  auto hcm = getHybridCubeMesh(nvlMesh);
+  int hcmInfo[4];
+  std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo);
+  AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo)));
+  AT_CUDA_CHECK(
+      cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice));
+  return topoInfo;
+}
+
+at::Tensor oneShotAllReduce(
+    const at::Tensor& input,
+    std::array<void*, kMaxDevices> p2pStates,
+    std::array<void*, kMaxDevices> buffers,
+    void* p2pStatesDev,
+    void* buffersDev,
+    size_t rank,
+    size_t worldSize,
+    at::cuda::CUDAStream& stream) {
+  checkInput(input, rank);
+
+  size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+  size_t N_aligned = alignUp(input.numel(), numelPerWarp);
+  TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
+
+  dim3 blocks, threads;
+  getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
+
+  at::cuda::OptionalCUDAGuard guard(input.get_device());
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      buffers[rank],
+      input.data_ptr(),
+      input.numel() * input.element_size(),
+      cudaMemcpyDeviceToDevice,
+      stream));
+
+#define X(kWorldSize, kAligned)                           \
+  if (worldSize == kWorldSize) {                          \
+    oneShotAllReduceKernel<kWorldSize, kAligned>          \
+        <<<blocks, threads, 0, stream>>>(                 \
+            input.data_ptr<at::BFloat16>(),               \
+            input.numel(),                                \
+            N_aligned,                                    \
+            reinterpret_cast<P2pState**>(p2pStatesDev),   \
+            reinterpret_cast<at::BFloat16**>(buffersDev), \
+            rank);                                        \
+    C10_CUDA_KERNEL_LAUNCH_CHECK();                       \
+  }
+
+#define DISPATCH_ALL_WORLD_SIZES(kAligned) \
+  X(2, kAligned);                          \
+  X(3, kAligned);                          \
+  X(4, kAligned);                          \
+  X(5, kAligned);                          \
+  X(6, kAligned);                          \
+  X(7, kAligned);                          \
+  X(8, kAligned);
+
+  if (N_aligned == static_cast<size_t>(input.numel())) {
+    DISPATCH_ALL_WORLD_SIZES(true);
+  } else {
+    DISPATCH_ALL_WORLD_SIZES(false);
+  }
+
+#undef DISPATCH_ALL_WORLD_SIZES
+#undef X
+  return input;
+}
+
+at::Tensor twoShotAllReduce(
+    const at::Tensor& input,
+    std::array<void*, kMaxDevices> p2pStates,
+    std::array<void*, kMaxDevices> buffers,
+    void* p2pStatesDev,
+    void* buffersDev,
+    size_t rank,
+    size_t worldSize,
+    at::cuda::CUDAStream& stream) {
+  checkInput(input, rank);
+
+  size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+  size_t N_aligned = alignUp(input.numel(), worldSize * numelPerWarp);
+  size_t N_per_rank = N_aligned / worldSize;
+  TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
+
+  dim3 blocks, threads;
+  getLaunchConfig(N_per_rank, input.element_size(), blocks, threads);
+
+  auto output = N_aligned == static_cast<size_t>(input.numel())
+      ? input
+      : input.new_empty(N_aligned);
+
+  at::cuda::OptionalCUDAGuard guard(input.get_device());
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      buffers[rank],
+      input.data_ptr(),
+      input.numel() * input.element_size(),
+      cudaMemcpyDeviceToDevice,
+      stream));
+
+#define X(kWorldSize)                                                   \
+  if (worldSize == kWorldSize) {                                        \
+    twoShotAllReduceKernel<kWorldSize><<<blocks, threads, 0, stream>>>( \
+        output.data_ptr<at::BFloat16>(),                                \
+        N_aligned,                                                      \
+        reinterpret_cast<P2pState**>(p2pStatesDev),                     \
+        reinterpret_cast<at::BFloat16**>(buffersDev),                   \
+        rank);                                                          \
+    C10_CUDA_KERNEL_LAUNCH_CHECK();                                     \
+  }
+  X(2);
+  X(3);
+  X(4);
+  X(5);
+  X(6);
+  X(7);
+  X(8);
+#undef X
+
+  if (output.data_ptr() != input.data_ptr()) {
+    AT_CUDA_CHECK(cudaMemcpyAsync(
+        input.data_ptr(),
+        output.data_ptr(),
+        input.numel() * input.element_size(),
+        cudaMemcpyDeviceToDevice,
+        stream));
+  }
+  return input;
+}
+
+at::Tensor hybridCubeMeshAllReduce(
+    const at::Tensor& input,
+    std::array<void*, kMaxDevices> p2pStates,
+    std::array<void*, kMaxDevices> buffers,
+    void* p2pStatesDev,
+    void* buffersDev,
+    int hcmInfo[4],
+    size_t rank,
+    size_t worldSize,
+    at::cuda::CUDAStream& stream) {
+  checkInput(input, rank);
+
+  size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+  size_t N_aligned = alignUp(input.numel(), numelPerWarp);
+  TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
+
+  dim3 blocks, threads;
+  getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
+
+  at::cuda::OptionalCUDAGuard guard(input.get_device());
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      buffers[rank],
+      input.data_ptr(),
+      input.numel() * input.element_size(),
+      cudaMemcpyDeviceToDevice,
+      stream));
+
+#define X(kAligned)                                                        \
+  hybridCubeMeshAllReduceKernel<kAligned><<<blocks, threads, 0, stream>>>( \
+      input.data_ptr<at::BFloat16>(),                                      \
+      input.numel(),                                                       \
+      N_aligned,                                                           \
+      reinterpret_cast<P2pState**>(p2pStatesDev),                          \
+      reinterpret_cast<at::BFloat16**>(buffersDev),                        \
+      hcmInfo,                                                             \
+      rank);                                                               \
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  if (N_aligned == static_cast<size_t>(input.numel())) {
+    X(true);
+  } else {
+    X(false);
+  }
+#undef X
+  return input;
+}
+
+AllReduceAlgo selectAllReduceAlgo(
+    const at::Tensor& input,
+    Topology topology,
+    size_t worldSize) {
+  // Only support bf16 for now
+  if (input.dtype() != at::kBFloat16 ||
+      input.numel() * input.element_size() > kMaxIntraNodeSize) {
+    return AllReduceAlgo::NONE;
+  }
+  const auto numel = input.numel();
+  const auto numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+  if (topology == Topology::HYBRID_CUBE_MESH) {
+    TORCH_CHECK(
+        worldSize == 8, "hyperCubeAllReduce only supports exactly 8 GPUs");
+    if (alignUp(numel, numelPerWarp) <= kHcmThreshBytes) {
+      return AllReduceAlgo::HCM;
+    }
+  }
+  if (topology == Topology::FULLY_CONNECTED) {
+    if (alignUp(numel, numelPerWarp) <= kOneShotThreshBytes) {
+      return AllReduceAlgo::ONE_SHOT;
+    }
+    if (alignUp(numel, numelPerWarp * worldSize) <= kTwoShotThreshBytes) {
+      return AllReduceAlgo::TWO_SHOT;
+    }
+  }
+  return AllReduceAlgo::NONE;
+}
+
+at::Tensor allReduce(
+    const at::Tensor& input,
+    std::array<void*, kMaxDevices> p2pStates,
+    std::array<void*, kMaxDevices> buffers,
+    void* p2pStatesDev,
+    void* buffersDev,
+    void* topoInfo,
+    size_t rank,
+    size_t worldSize,
+    AllReduceAlgo algo,
+    at::cuda::CUDAStream& stream) {
+  switch (algo) {
+    case AllReduceAlgo::ONE_SHOT:
+      return oneShotAllReduce(
+          input,
+          p2pStates,
+          buffers,
+          p2pStatesDev,
+          buffersDev,
+          rank,
+          worldSize,
+          stream);
+    case AllReduceAlgo::TWO_SHOT:
+      return twoShotAllReduce(
+          input,
+          p2pStates,
+          buffers,
+          p2pStatesDev,
+          buffersDev,
+          rank,
+          worldSize,
+          stream);
+    case AllReduceAlgo::HCM:
+      return hybridCubeMeshAllReduce(
+          input,
+          p2pStates,
+          buffers,
+          p2pStatesDev,
+          buffersDev,
+          (int*)topoInfo,
+          rank,
+          worldSize,
+          stream);
+    default:
+      C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
+  }
+}
+
+} // namespace intra_node_comm
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp
new file mode 100644
index 0000000..b494990
--- /dev/null
+++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp
@@ -0,0 +1,106 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAEvent.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/Work.hpp>
+
+namespace c10d {
+namespace intra_node_comm {
+
+constexpr size_t kMaxDevices = 8;
+constexpr size_t kMaxIntraNodeSize = 10 * 1024 * 1024;
+
+using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
+using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
+
+enum class Topology { UNKNOWN = 0, FULLY_CONNECTED = 1, HYBRID_CUBE_MESH = 2 };
+
+enum class AllReduceAlgo { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, HCM = 3 };
+
+class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
+ public:
+  IntraNodeComm(
+      Topology topology,
+      std::array<void*, kMaxDevices> p2pStates,
+      std::array<void*, kMaxDevices> buffers,
+      void* p2pStatesDev,
+      void* buffersDev,
+      void* topoInfo,
+      size_t rank,
+      size_t worldSize);
+
+  ~IntraNodeComm();
+
+  /**
+   * Rendezvous via a c10d::Store.
+   * This function may return nullptr if intra-node comm is not applicable.
+   * It guarantees that all participants either succeed or abort.
+   */
+  static c10::intrusive_ptr<IntraNodeComm> rendezvous(
+      c10::intrusive_ptr<c10d::Store> store,
+      const std::string& prefix,
+      size_t rank,
+      size_t worldSize);
+
+  /**
+   * Selects an AllReduceAlgo that we expect to outperform NCCL.
+   * Returns AllReduceAlgo::NONE if we don't expect to outperform NCCL.
+   */
+  AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input);
+
+  at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo);
+
+ private:
+  Topology topology_;
+  std::array<void*, kMaxDevices> p2pStates_;
+  std::array<void*, kMaxDevices> buffers_;
+  void* p2pStatesDev_;
+  void* buffersDev_;
+  void* topoInfo_;
+  size_t rank_;
+  size_t worldSize_;
+};
+
+/**
+ * NOTE [IntraNodeComm Stream Semantics]
+ *
+ * ProcessGroupNCCL launches kernels differently from the conventional PyTorch
+ * CUDA semantics: it always launches collective kernels onto a dedicated
+ * communication stream. Therefore, it needs to:
+ *
+ * - Synchronize the calling stream and the comm stream.
+ * - Ensure the memory safety of the operands (via record_stream or stashing).
+ * - Synchronize the waiting stream with the comm stream.
+ *
+ * Unconditionally performing these tasks makes sense when we expect most of the
+ * communication to benefit from compute/comm overlap. However, IntraNodeComm
+ * primarily aims to optimize small, latency-sensitive, blocking communication,
+ * in which the overhead incurred by the above steps can be quite pronounced.
+ *
+ * Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and
+ * launches kernels onto the stream specified by the user. Although the user
+ * can perform the necessary synchronization via wait_stream, to provide a UX
+ * consistent with that of ProcessGroupNCCL, the necessary stream
+ * synchronization can also be performed via IntraNodeCommWork::wait().
+ */
+class IntraNodeCommWork : public c10d::Work {
+ public:
+  IntraNodeCommWork() : c10d::Work() {
+    event_.record();
+  }
+
+  bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
+    event_.block(at::cuda::getCurrentCUDAStream());
+    return true;
+  }
+
+ private:
+  at::cuda::CUDAEvent event_;
+};
+
+TORCH_API int64_t getIntraNodeCommUsageCounter();
+
+} // namespace intra_node_comm
+} // namespace c10d
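
A brief hedged illustration of the stream-semantics note in intra_node_comm.hpp: because IntraNodeComm launches its kernels on the caller's current stream rather than a dedicated comm stream, downstream work only needs the ordinary work-object wait or an explicit `wait_stream`. The snippet below is a sketch under the same eligibility assumptions as above, not part of this diff.

```python
# Hedged sketch: consuming the result of an intra-node allreduce
# (assumes the process group is already initialized as shown earlier).
import torch
import torch.distributed as dist

t = torch.zeros(1024, dtype=torch.bfloat16, device="cuda")

work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
work.wait()  # IntraNodeCommWork::wait() blocks the current stream on the recorded event

# Or, when consuming on a different stream, follow conventional PyTorch CUDA
# semantics and synchronize that stream with the one that issued the allreduce.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    result = t.sum()
```
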