Reapply "c10d: add Collectives abstraction (#125978)" (#126695)

This reverts commit d9c3485146913324ab4b3e211d2a4517e138f4af.

Reapplies #125978.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126695
Approved by: https://github.com/c-p-i-o
diff --git a/BUILD.bazel b/BUILD.bazel
index 3f7e632..831d64b 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -772,7 +772,7 @@
         [
             "torch/*.h",
             "torch/csrc/**/*.h",
-            "torch/csrc/distributed/c10d/*.hpp",
+            "torch/csrc/distributed/c10d/**/*.hpp",
             "torch/lib/libshm/*.h",
         ],
         exclude = [
diff --git a/build_variables.bzl b/build_variables.bzl
index 3f16f9b..152324a 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -487,6 +487,7 @@
 # These files are the only ones that are supported on Windows.
 libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/Backend.cpp",
+    "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
     "torch/csrc/distributed/c10d/Functional.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py
new file mode 100644
index 0000000..fb0067f
--- /dev/null
+++ b/test/distributed/test_control_collectives.py
@@ -0,0 +1,189 @@
+# Owner(s): ["oncall: distributed"]
+
+from datetime import timedelta
+from multiprocessing.pool import ThreadPool
+
+import torch
+import torch.distributed as dist
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+
+class TestCollectives(TestCase):
+    def test_barrier(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 2
+
+        def f(rank: int) -> None:
+            collectives = dist._StoreCollectives(store, rank, world_size)
+            collectives.barrier("foo", timedelta(seconds=10), True)
+
+        with ThreadPool(world_size) as pool:
+            pool.map(f, range(world_size))
+
+    def test_broadcast(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(seconds=10)
+
+        def f(rank: int) -> None:
+            collectives = dist._StoreCollectives(store, rank, world_size)
+            if rank == 2:
+                collectives.broadcast_send("foo", b"data", timeout)
+            else:
+                out = collectives.broadcast_recv("foo", timeout)
+                self.assertEqual(out, b"data")
+
+        with ThreadPool(world_size) as pool:
+            pool.map(f, range(world_size))
+
+    def test_gather(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(seconds=10)
+
+        def f(rank: int) -> None:
+            collectives = dist._StoreCollectives(store, rank, world_size)
+            if rank == 2:
+                out = collectives.gather_recv("foo", str(rank), timeout)
+                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
+            else:
+                collectives.gather_send("foo", str(rank), timeout)
+
+        with ThreadPool(world_size) as pool:
+            pool.map(f, range(world_size))
+
+    def test_scatter(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(seconds=10)
+
+        def f(rank: int) -> None:
+            collectives = dist._StoreCollectives(store, rank, world_size)
+            if rank == 2:
+                out = collectives.scatter_send(
+                    "foo", [str(i) for i in range(world_size)], timeout
+                )
+            else:
+                out = collectives.scatter_recv("foo", timeout)
+            self.assertEqual(out, str(rank).encode())
+
+        with ThreadPool(world_size) as pool:
+            pool.map(f, range(world_size))
+
+    def test_all_sum(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(seconds=10)
+
+        def f(rank: int) -> None:
+            collectives = dist._StoreCollectives(store, rank, world_size)
+            out = collectives.all_sum("foo", rank, timeout)
+            self.assertEqual(out, sum(range(world_size)))
+
+        with ThreadPool(world_size) as pool:
+            pool.map(f, range(world_size))
+
+    def test_broadcast_timeout(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(milliseconds=1)
+        collectives = dist._StoreCollectives(store, 1, world_size)
+        with self.assertRaisesRegex(Exception, "Wait timeout"):
+            collectives.broadcast_recv("foo", timeout)
+
+    def test_gather_timeout(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(milliseconds=1)
+        collectives = dist._StoreCollectives(store, 1, world_size)
+        with self.assertRaisesRegex(
+            Exception, "gather failed -- missing ranks: 0, 2, 3"
+        ):
+            collectives.gather_recv("foo", "data", timeout)
+
+    def test_scatter_timeout(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(milliseconds=1)
+        collectives = dist._StoreCollectives(store, 1, world_size)
+        with self.assertRaisesRegex(Exception, "Wait timeout"):
+            collectives.scatter_recv("foo", timeout)
+
+    def test_all_gather_timeout(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(milliseconds=1)
+        collectives = dist._StoreCollectives(store, 1, world_size)
+        with self.assertRaisesRegex(
+            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
+        ):
+            collectives.all_gather("foo", "data", timeout)
+
+    def test_barrier_timeout(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(milliseconds=1)
+        collectives = dist._StoreCollectives(store, 1, world_size)
+        with self.assertRaisesRegex(
+            Exception, "barrier failed -- missing ranks: 0, 2, 3"
+        ):
+            collectives.barrier("foo", timeout, True)
+
+    def test_all_sum_timeout(self) -> None:
+        store = dist.HashStore()
+
+        world_size = 4
+        timeout = timedelta(milliseconds=1)
+        collectives = dist._StoreCollectives(store, 1, world_size)
+        with self.assertRaisesRegex(
+            Exception, "barrier failed -- missing ranks: 0, 2, 3"
+        ):
+            collectives.all_sum("foo", 1, timeout)
+
+    def test_unique(self) -> None:
+        store = dist.HashStore()
+
+        collectives = dist._StoreCollectives(store, 1, 1)
+        collectives.broadcast_send("foo", "bar")
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.broadcast_send("foo", "bar")
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.broadcast_recv("foo")
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.gather_send("foo", "bar")
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.gather_recv("foo", "asdf")
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.scatter_send("foo", ["asdf"])
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.scatter_recv("foo")
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.all_gather("foo", "bar")
+
+        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+            collectives.all_sum("foo", 2)
+
+
+if __name__ == "__main__":
+    assert (
+        not torch.cuda._initialized
+    ), "test_distributed must not have initialized CUDA context on main process"
+
+    run_tests()
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 28d790e..74a73a3 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -210,6 +210,20 @@
     @property
     def underlying_store(self) -> Store: ...
 
+class _ControlCollectives:
+    def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
+    def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+    def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
+    def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+    def gather_recv(self, key: str, timeout: timedelta) -> str: ...
+    def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+    def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
+    def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
+    def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ...
+
+class _StoreCollectives(_ControlCollectives):
+    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
+
 class _DistributedBackendOptions:
     def __init__(self): ...
     @property
diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp
index 1453c0a..3697d62 100644
--- a/torch/csrc/distributed/c10d/HashStore.hpp
+++ b/torch/csrc/distributed/c10d/HashStore.hpp
@@ -22,7 +22,7 @@
   std::vector<uint8_t> get(const std::string& key) override;
 
   void wait(const std::vector<std::string>& keys) override {
-    wait(keys, Store::kDefaultTimeout);
+    wait(keys, timeout_);
   }
 
   void wait(
diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp
index af715ba..993284f 100644
--- a/torch/csrc/distributed/c10d/Store.hpp
+++ b/torch/csrc/distributed/c10d/Store.hpp
@@ -97,4 +97,33 @@
   std::chrono::milliseconds timeout_;
 };
 
+/*
+StoreTimeoutGuard is a RAII guard that will set the store timeout and restore it
+when it returns.
+*/
+class StoreTimeoutGuard {
+ public:
+  explicit StoreTimeoutGuard(
+      Store& store,
+      const std::chrono::milliseconds& timeout)
+      : store_(store) {
+    oldTimeout_ = store.getTimeout();
+    store.setTimeout(timeout);
+  }
+
+  ~StoreTimeoutGuard() {
+    store_.setTimeout(oldTimeout_);
+  }
+
+  /* Disabling copy and move semantics */
+  StoreTimeoutGuard(const StoreTimeoutGuard&) = delete;
+  StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete;
+  StoreTimeoutGuard(StoreTimeoutGuard&&) = delete;
+  StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete;
+
+ private:
+  Store& store_;
+  std::chrono::milliseconds oldTimeout_;
+};
+
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp
new file mode 100644
index 0000000..b98f9a7
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp
@@ -0,0 +1,59 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <chrono>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include <c10/macros/Macros.h>
+#include <torch/custom_class.h>
+
+namespace c10d {
+
+using namespace std::chrono_literals;
+
+class TORCH_API ControlCollectives : public torch::CustomClassHolder {
+ public:
+  virtual void barrier(
+      const std::string& key,
+      std::chrono::milliseconds timeout = 5min,
+      bool block = true) = 0;
+
+  virtual void broadcastSend(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) = 0;
+  virtual std::vector<uint8_t> broadcastRecv(
+      const std::string& key,
+      std::chrono::milliseconds timeout = 5min) = 0;
+
+  virtual void gatherSend(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) = 0;
+  virtual std::vector<std::vector<uint8_t>> gatherRecv(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) = 0;
+
+  virtual std::vector<uint8_t> scatterSend(
+      const std::string& key,
+      const std::vector<std::vector<uint8_t>>& data,
+      std::chrono::milliseconds timeout = 5min) = 0;
+  virtual std::vector<uint8_t> scatterRecv(
+      const std::string& key,
+      std::chrono::milliseconds timeout = 5min) = 0;
+
+  virtual std::vector<std::vector<uint8_t>> allGather(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) = 0;
+
+  virtual int64_t allSum(
+      const std::string& key,
+      int64_t data,
+      std::chrono::milliseconds timeout = 5min) = 0;
+};
+
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp
new file mode 100644
index 0000000..9958994
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp
@@ -0,0 +1,222 @@
+#include <c10/util/Exception.h>
+#include <fmt/format.h>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
+#include <chrono>
+#include <exception>
+#include <vector>
+
+namespace {
+std::string getRankKey(const std::string& key, int rank) {
+  return fmt::format("{}/{}", key, rank);
+}
+} // namespace
+
+namespace c10d {
+
+StoreCollectives::StoreCollectives(
+    c10::intrusive_ptr<::c10d::Store> store,
+    int rank,
+    int worldSize)
+    : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {}
+
+void StoreCollectives::barrier(
+    const std::string& key,
+    std::chrono::milliseconds timeout,
+    bool blocking) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  auto num_members_key = fmt::format("{}/num_members", key);
+  auto last_members_key = fmt::format("{}/last_members", key);
+
+  auto idx = store_->add(num_members_key, 1);
+  store_->set(getRankKey(key, rank_), "joined");
+
+  if (idx == worldSize_) {
+    store_->set(last_members_key, "<val_ignored>");
+  } else if (blocking) {
+    try {
+      store_->wait({last_members_key});
+    } catch (const std::exception& e) {
+      std::string msg = "barrier failed -- missing ranks: ";
+      for (int i = 0; i < worldSize_; i++) {
+        if (i == rank_) {
+          continue;
+        }
+        auto rank_key = getRankKey(key, i);
+        if (!store_->check({rank_key})) {
+          msg += fmt::format("{}, ", i);
+        }
+      }
+      throw std::runtime_error(msg + e.what());
+    }
+  }
+}
+
+void StoreCollectives::broadcastSend(
+    const std::string& key,
+    const std::vector<uint8_t>& data,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  store_->set(key, data);
+}
+
+std::vector<uint8_t> StoreCollectives::broadcastRecv(
+    const std::string& key,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  return store_->get(key);
+}
+
+void StoreCollectives::gatherSend(
+    const std::string& key,
+    const std::vector<uint8_t>& data,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  auto rank_key = getRankKey(key, rank_);
+  store_->set(rank_key, data);
+}
+
+std::vector<std::vector<uint8_t>> StoreCollectives::gatherRecv(
+    const std::string& key,
+    const std::vector<uint8_t>& data,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  std::vector<std::string> keys;
+  keys.reserve(worldSize_);
+
+  for (int i = 0; i < worldSize_; i++) {
+    if (i == rank_) {
+      continue;
+    }
+    auto rank_key = getRankKey(key, i);
+    keys.emplace_back(rank_key);
+  }
+
+  std::vector<std::vector<uint8_t>> results;
+  results.reserve(worldSize_);
+
+  try {
+    results = store_->multiGet(keys);
+  } catch (const std::exception& e) {
+    std::string msg = "gather failed -- missing ranks: ";
+    for (int i = 0; i < worldSize_; i++) {
+      if (i == rank_) {
+        continue;
+      }
+      auto rank_key = getRankKey(key, i);
+      if (!store_->check({rank_key})) {
+        msg += fmt::format("{}, ", i);
+      }
+    }
+    throw std::runtime_error(msg + e.what());
+  }
+
+  // insert local data
+  results.insert(results.begin() + rank_, data);
+  return results;
+}
+
+std::vector<uint8_t> StoreCollectives::scatterSend(
+    const std::string& key,
+    const std::vector<std::vector<uint8_t>>& data,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  std::vector<std::string> keys;
+  keys.reserve(worldSize_);
+  for (int i = 0; i < worldSize_; i++) {
+    if (i == rank_) {
+      continue;
+    }
+    auto rank_key = getRankKey(key, i);
+    keys.emplace_back(rank_key);
+  }
+  auto local = data.at(rank_);
+
+  std::vector<std::vector<uint8_t>> toSend{data};
+
+  toSend.erase(toSend.begin() + rank_);
+
+  store_->multiSet(keys, toSend);
+
+  return local;
+}
+
+std::vector<uint8_t> StoreCollectives::scatterRecv(
+    const std::string& key,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  auto rank_key = getRankKey(key, rank_);
+  return store_->get(rank_key);
+}
+
+std::vector<std::vector<uint8_t>> StoreCollectives::allGather(
+    const std::string& key,
+    const std::vector<uint8_t>& data,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  auto localKey = getRankKey(key, rank_);
+  store_->set(localKey, data);
+
+  std::vector<std::string> keys;
+  keys.reserve(worldSize_);
+
+  for (int i = 0; i < worldSize_; i++) {
+    auto rank_key = getRankKey(key, i);
+    keys.emplace_back(rank_key);
+  }
+
+  try {
+    return store_->multiGet(keys);
+  } catch (const std::exception& e) {
+    std::string msg = "all_gather failed -- missing ranks: ";
+    for (int i = 0; i < worldSize_; i++) {
+      if (i == rank_) {
+        continue;
+      }
+      auto rank_key = getRankKey(key, i);
+      if (!store_->check({rank_key})) {
+        msg += fmt::format("{}, ", i);
+      }
+    }
+    throw std::runtime_error(msg + e.what());
+  }
+}
+
+int64_t StoreCollectives::allSum(
+    const std::string& key,
+    int64_t value,
+    std::chrono::milliseconds timeout) {
+  enforceUnique(key);
+  StoreTimeoutGuard g{*store_, timeout};
+
+  store_->add(key, value);
+
+  barrier(key + "/barrier", timeout);
+
+  return store_->add(key, 0);
+}
+
+void StoreCollectives::enforceUnique(const std::string& key) {
+  auto it = seenKeys_.find(key);
+  TORCH_INTERNAL_ASSERT(
+      it == seenKeys_.end(), "Key ", key, " has already been used.");
+  seenKeys_.emplace(key);
+}
+
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp
new file mode 100644
index 0000000..7d3eb50
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp
@@ -0,0 +1,68 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <c10/util/FbcodeMaps.h>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
+
+namespace c10d {
+
+class TORCH_API StoreCollectives : public ControlCollectives {
+ public:
+  explicit StoreCollectives(
+      c10::intrusive_ptr<Store> store,
+      int rank,
+      int worldSize);
+
+  void barrier(
+      const std::string& key,
+      std::chrono::milliseconds timeout = 5min,
+      bool block = true) override;
+
+  void broadcastSend(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) override;
+  std::vector<uint8_t> broadcastRecv(
+      const std::string& key,
+      std::chrono::milliseconds timeout = 5min) override;
+
+  void gatherSend(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) override;
+  std::vector<std::vector<uint8_t>> gatherRecv(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) override;
+
+  std::vector<uint8_t> scatterSend(
+      const std::string& key,
+      const std::vector<std::vector<uint8_t>>& data,
+      std::chrono::milliseconds timeout = 5min) override;
+  std::vector<uint8_t> scatterRecv(
+      const std::string& key,
+      std::chrono::milliseconds timeout = 5min) override;
+
+  std::vector<std::vector<uint8_t>> allGather(
+      const std::string& key,
+      const std::vector<uint8_t>& data,
+      std::chrono::milliseconds timeout = 5min) override;
+
+  int64_t allSum(
+      const std::string& key,
+      int64_t data,
+      std::chrono::milliseconds timeout = 5min) override;
+
+ private:
+  void enforceUnique(const std::string& key);
+
+ private:
+  c10::intrusive_ptr<Store> store_;
+  int rank_;
+  int worldSize_;
+
+  c10::FastSet<std::string> seenKeys_{};
+};
+
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 483becb..505b64e 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -6,6 +6,9 @@
 #include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
 #include <torch/csrc/distributed/c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
+#include <vector>
 #ifndef _WIN32
 #include <torch/csrc/distributed/c10d/HashStore.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp>
@@ -136,6 +139,34 @@
 
 namespace {
 
+py::bytes toPyBytes(const std::vector<uint8_t>& data) {
+  return py::bytes(reinterpret_cast<const char*>(data.data()), data.size());
+}
+
+std::vector<py::bytes> toPyBytes(
+    const std::vector<std::vector<uint8_t>>& data) {
+  std::vector<py::bytes> out;
+  out.reserve(data.size());
+  for (const std::vector<uint8_t>& data_ : data) {
+    out.emplace_back(reinterpret_cast<const char*>(data_.data()), data_.size());
+  }
+  return out;
+}
+
+std::vector<uint8_t> toVec8(const std::string& data) {
+  std::vector<uint8_t> out{data.begin(), data.end()};
+  return out;
+}
+
+std::vector<std::vector<uint8_t>> toVec8(const std::vector<std::string>& data) {
+  std::vector<std::vector<uint8_t>> out;
+  out.reserve(data.size());
+  for (auto& data_ : data) {
+    out.emplace_back(toVec8(data_));
+  }
+  return out;
+}
+
 template <typename T>
 using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
 
@@ -166,8 +197,7 @@
         pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
     TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
     // Call function with a py::bytes object for the value.
-    fn(key,
-       py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
+    fn(key, toPyBytes(value));
   }
 
   // Note: this function manually calls the Python-side overload
@@ -184,7 +214,7 @@
     // std::vector<uint8_t>. There is no API for directly accessing
     // the contents of the py::bytes object.
     std::string str = pybind11::cast<py::bytes>(fn(key));
-    return std::vector<uint8_t>(str.begin(), str.end());
+    return toVec8(str);
   }
 
   // Note: this function manually calls the Python-side overload
@@ -204,14 +234,8 @@
     // std::vector<uint8_t>. There is no API for directly accessing
     // the contents of the py::bytes object.
     std::string str = pybind11::cast<py::bytes>(
-        fn(key,
-           py::bytes(
-               reinterpret_cast<const char*>(expectedValue.data()),
-               expectedValue.size()),
-           py::bytes(
-               reinterpret_cast<const char*>(desiredValue.data()),
-               desiredValue.size())));
-    return std::vector<uint8_t>(str.begin(), str.end());
+        fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue)));
+    return toVec8(str);
   }
 
   int64_t add(const std::string& key, int64_t value) override {
@@ -253,8 +277,7 @@
       return Store::append(key, value);
     }
     // Call function with a py::bytes object for the value.
-    fn(key,
-       py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
+    fn(key, toPyBytes(value));
   }
 
   std::vector<std::vector<uint8_t>> multiGet(
@@ -287,14 +310,7 @@
       return Store::multiSet(keys, values);
     }
 
-    std::vector<py::bytes> bytes;
-    bytes.reserve(values.size());
-    for (auto& value : values) {
-      bytes.emplace_back(
-          reinterpret_cast<const char*>(value.data()), value.size());
-    }
-
-    fn(keys, bytes);
+    fn(keys, toPyBytes(values));
   }
 
   bool hasExtendedApi() const override {
@@ -973,10 +989,7 @@
               "set",
               [](::c10d::Store& store,
                  const std::string& key,
-                 const std::string& value) {
-                std::vector<uint8_t> value_(value.begin(), value.end());
-                store.set(key, value_);
-              },
+                 const std::string& value) { store.set(key, toVec8(value)); },
               py::call_guard<py::gil_scoped_release>(),
               R"(
 Inserts the key-value pair into the store based on the supplied ``key`` and
@@ -1001,14 +1014,9 @@
                  const std::string& key,
                  const std::string& expected_value,
                  const std::string& desired_value) -> py::bytes {
-                std::vector<uint8_t> expectedValue_(
-                    expected_value.begin(), expected_value.end());
-                std::vector<uint8_t> desiredValue_(
-                    desired_value.begin(), desired_value.end());
-                auto value =
-                    store.compareSet(key, expectedValue_, desiredValue_);
-                return py::bytes(
-                    reinterpret_cast<char*>(value.data()), value.size());
+                auto value = store.compareSet(
+                    key, toVec8(expected_value), toVec8(desired_value));
+                return toPyBytes(value);
               },
               py::call_guard<py::gil_scoped_release>(),
               R"(
@@ -1040,8 +1048,7 @@
                   py::gil_scoped_release guard;
                   return store.get(key);
                 }();
-                return py::bytes(
-                    reinterpret_cast<char*>(value.data()), value.size());
+                return toPyBytes(value);
               },
               R"(
 Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
@@ -1240,8 +1247,7 @@
               [](::c10d::Store& store,
                  const std::string& key,
                  const std::string& value) {
-                std::vector<uint8_t> value_(value.begin(), value.end());
-                store.append(key, value_);
+                store.append(key, toVec8(value));
               },
               py::call_guard<py::gil_scoped_release>(),
               R"(
@@ -1268,14 +1274,7 @@
                   py::gil_scoped_release guard;
                   return store.multiGet(keys);
                 }();
-                std::vector<py::bytes> res;
-                for (auto& value : values) {
-                  auto bytes = py::bytes(
-                      reinterpret_cast<const char*>(value.data()),
-                      value.size());
-                  res.push_back(bytes);
-                }
-                return res;
+                return toPyBytes(values);
               },
               R"(
 Retrieve all values in ``keys``. If any key in ``keys`` is not
@@ -1298,12 +1297,7 @@
               [](::c10d::Store& store,
                  const std::vector<std::string>& keys,
                  const std::vector<std::string>& values) {
-                std::vector<std::vector<uint8_t>> vals;
-                vals.reserve(values.size());
-                for (auto& value : values) {
-                  vals.emplace_back(value.begin(), value.end());
-                }
-                store.multiSet(keys, vals);
+                store.multiSet(keys, toVec8(values));
               },
               py::call_guard<py::gil_scoped_release>(),
               R"(
@@ -1487,6 +1481,212 @@
           &::c10d::PrefixStore::getUnderlyingNonPrefixStore,
           R"(Recursively to get the store before layers of wrapping with PrefixStore.)");
 
+  using namespace std::chrono_literals;
+
+  auto collectives =
+      py::class_<
+          ::c10d::ControlCollectives,
+          c10::intrusive_ptr<::c10d::ControlCollectives>>(
+          module,
+          "_ControlCollectives",
+          R"(
+Base class for all ControlCollectives implementations.
+)")
+          .def(
+              "barrier",
+              &::c10d::ControlCollectives::barrier,
+              py::arg("key"),
+              py::arg("timeout") = 5min,
+              py::arg("block") = true,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Blocks until all workers have entered this function.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    timeout (duration): The timeout for this operation.
+    block (bool): whether to block this working waiting on the results of the barrier.
+)")
+          .def(
+              "all_sum",
+              &::c10d::ControlCollectives::allSum,
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Computes a sum across all workers and returns the final value.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (int): The data to sum.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "broadcast_send",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                collectives.broadcastSend(key, toVec8(data), timeout);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Sends data to all other workers. Must be only called from one worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): The data to send.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "broadcast_recv",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.broadcastRecv(key, timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("timeout") = 5min,
+              R"(
+Receives data broadcasted from 1 worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "gather_send",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                collectives.gatherSend(key, toVec8(data), timeout);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Sends data to one other worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): The data to send.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "gather_recv",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.gatherRecv(key, toVec8(data), timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              R"(
+Receives data broadcasted from all workers. Must only be called by one worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    timeout (duration): The timeout for this operation.
+)")
+
+          .def(
+              "scatter_send",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::vector<std::string>& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.scatterSend(key, toVec8(data), timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              R"(
+Sends rank specific data to all other workers.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): The data to send.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "scatter_recv",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.scatterRecv(key, timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("timeout") = 5min,
+              R"(
+Receives rank specific data from one worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    timeout (duration): The timeout for this operation.
+)")
+
+          .def(
+              "all_gather",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.allGather(key, toVec8(data), timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              R"(
+Sends data to all workers and receives data from all other workers.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): The data to send.
+    timeout (duration): The timeout for this operation.
+)");
+
+  intrusive_ptr_class_<::c10d::StoreCollectives>(
+      module,
+      "_StoreCollectives",
+      collectives,
+      R"(
+An implementation of ControlCollectives that uses the provided store as the underlying
+communication mechanism.
+      )")
+      .def(
+          py::init<c10::intrusive_ptr<::c10d::Store>, int, int>(),
+          py::arg("store"),
+          py::arg("rank"),
+          py::arg("world_size"));
+
   auto processGroup =
       py::class_<
           ::c10d::ProcessGroup,
diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py
index eb7a690..3e7dce9 100644
--- a/torch/distributed/__init__.py
+++ b/torch/distributed/__init__.py
@@ -54,6 +54,8 @@
         set_debug_level,
         set_debug_level_from_env,
         _make_nccl_premul_sum,
+        _ControlCollectives,
+        _StoreCollectives,
     )
 
     class _DistributedPdb(pdb.Pdb):