Reapply "c10d: add Collectives abstraction (#125978)" (#126695)
This reverts commit d9c3485146913324ab4b3e211d2a4517e138f4af.
Reapplies #125978.
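This adds a `ControlCollectives` abstraction for small control-plane collectives (barrier, broadcast, gather, scatter, all_gather, all_sum), a Store-backed implementation (`StoreCollectives`), Python bindings exposed as `_ControlCollectives` / `_StoreCollectives`, and a thread-based test suite.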
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126695
Approved by: https://github.com/c-p-i-o
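For reviewers, a minimal single-process sketch of the new Python surface, modeled on the tests added in this PR (key names and payloads are illustrative):

```python
from datetime import timedelta

import torch.distributed as dist

store = dist.HashStore()
collectives = dist._StoreCollectives(store, rank=0, world_size=1)

# Each key may be used at most once per collectives instance.
collectives.broadcast_send("config", b"payload", timedelta(seconds=10))
total = collectives.all_sum("steps", 1, timedelta(seconds=10))
assert total == 1
```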
diff --git a/BUILD.bazel b/BUILD.bazel
index 3f7e632..831d64b 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -772,7 +772,7 @@
[
"torch/*.h",
"torch/csrc/**/*.h",
- "torch/csrc/distributed/c10d/*.hpp",
+ "torch/csrc/distributed/c10d/**/*.hpp",
"torch/lib/libshm/*.h",
],
exclude = [
diff --git a/build_variables.bzl b/build_variables.bzl
index 3f16f9b..152324a 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -487,6 +487,7 @@
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
+ "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/Functional.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py
new file mode 100644
index 0000000..fb0067f
--- /dev/null
+++ b/test/distributed/test_control_collectives.py
@@ -0,0 +1,189 @@
+# Owner(s): ["oncall: distributed"]
+
+from datetime import timedelta
+from multiprocessing.pool import ThreadPool
+
+import torch
+import torch.distributed as dist
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+
+class TestCollectives(TestCase):
+ def test_barrier(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 2
+
+ def f(rank: int) -> None:
+ collectives = dist._StoreCollectives(store, rank, world_size)
+ collectives.barrier("foo", timedelta(seconds=10), True)
+
+ with ThreadPool(world_size) as pool:
+ pool.map(f, range(world_size))
+
+ def test_broadcast(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(seconds=10)
+
+ def f(rank: int) -> None:
+ collectives = dist._StoreCollectives(store, rank, world_size)
+ if rank == 2:
+ collectives.broadcast_send("foo", b"data", timeout)
+ else:
+ out = collectives.broadcast_recv("foo", timeout)
+ self.assertEqual(out, b"data")
+
+ with ThreadPool(world_size) as pool:
+ pool.map(f, range(world_size))
+
+ def test_gather(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(seconds=10)
+
+ def f(rank: int) -> None:
+ collectives = dist._StoreCollectives(store, rank, world_size)
+ if rank == 2:
+ out = collectives.gather_recv("foo", str(rank), timeout)
+ self.assertEqual(out, [b"0", b"1", b"2", b"3"])
+ else:
+ collectives.gather_send("foo", str(rank), timeout)
+
+ with ThreadPool(world_size) as pool:
+ pool.map(f, range(world_size))
+
+ def test_scatter(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(seconds=10)
+
+ def f(rank: int) -> None:
+ collectives = dist._StoreCollectives(store, rank, world_size)
+ if rank == 2:
+ out = collectives.scatter_send(
+ "foo", [str(i) for i in range(world_size)], timeout
+ )
+ else:
+ out = collectives.scatter_recv("foo", timeout)
+ self.assertEqual(out, str(rank).encode())
+
+ with ThreadPool(world_size) as pool:
+ pool.map(f, range(world_size))
+
+ def test_all_sum(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(seconds=10)
+
+ def f(rank: int) -> None:
+ collectives = dist._StoreCollectives(store, rank, world_size)
+ out = collectives.all_sum("foo", rank, timeout)
+ self.assertEqual(out, sum(range(world_size)))
+
+ with ThreadPool(world_size) as pool:
+ pool.map(f, range(world_size))
+
+ def test_broadcast_timeout(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(milliseconds=1)
+ collectives = dist._StoreCollectives(store, 1, world_size)
+ with self.assertRaisesRegex(Exception, "Wait timeout"):
+ collectives.broadcast_recv("foo", timeout)
+
+ def test_gather_timeout(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(milliseconds=1)
+ collectives = dist._StoreCollectives(store, 1, world_size)
+ with self.assertRaisesRegex(
+ Exception, "gather failed -- missing ranks: 0, 2, 3"
+ ):
+ collectives.gather_recv("foo", "data", timeout)
+
+ def test_scatter_timeout(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(milliseconds=1)
+ collectives = dist._StoreCollectives(store, 1, world_size)
+ with self.assertRaisesRegex(Exception, "Wait timeout"):
+ collectives.scatter_recv("foo", timeout)
+
+ def test_all_gather_timeout(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(milliseconds=1)
+ collectives = dist._StoreCollectives(store, 1, world_size)
+ with self.assertRaisesRegex(
+ Exception, "all_gather failed -- missing ranks: 0, 2, 3"
+ ):
+ collectives.all_gather("foo", "data", timeout)
+
+ def test_barrier_timeout(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(milliseconds=1)
+ collectives = dist._StoreCollectives(store, 1, world_size)
+ with self.assertRaisesRegex(
+ Exception, "barrier failed -- missing ranks: 0, 2, 3"
+ ):
+ collectives.barrier("foo", timeout, True)
+
+ def test_all_sum_timeout(self) -> None:
+ store = dist.HashStore()
+
+ world_size = 4
+ timeout = timedelta(milliseconds=1)
+ collectives = dist._StoreCollectives(store, 1, world_size)
+ # all_sum barriers internally, so its timeout surfaces as a barrier error
+ with self.assertRaisesRegex(
+ Exception, "barrier failed -- missing ranks: 0, 2, 3"
+ ):
+ collectives.all_sum("foo", 1, timeout)
+
+ def test_unique(self) -> None:
+ store = dist.HashStore()
+
+ collectives = dist._StoreCollectives(store, 1, 1)
+ collectives.broadcast_send("foo", "bar")
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.broadcast_send("foo", "bar")
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.broadcast_recv("foo")
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.gather_send("foo", "bar")
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.gather_recv("foo", "asdf")
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.scatter_send("foo", ["asdf"])
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.scatter_recv("foo")
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.all_gather("foo", "bar")
+
+ with self.assertRaisesRegex(Exception, "Key foo has already been used"):
+ collectives.all_sum("foo", 2)
+
+
+if __name__ == "__main__":
+ assert (
+ not torch.cuda._initialized
+ ), "test_distributed must not have initialized CUDA context on main process"
+
+ run_tests()
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 28d790e..74a73a3 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -210,6 +210,20 @@
@property
def underlying_store(self) -> Store: ...
+class _ControlCollectives:
+ def barrier(self, key: str, timeout: timedelta = ..., block: bool = ...) -> None: ...
+ def broadcast_send(self, key: str, data: str, timeout: timedelta = ...) -> None: ...
+ def broadcast_recv(self, key: str, timeout: timedelta = ...) -> bytes: ...
+ def gather_send(self, key: str, data: str, timeout: timedelta = ...) -> None: ...
+ def gather_recv(self, key: str, data: str, timeout: timedelta = ...) -> List[bytes]: ...
+ def scatter_send(self, key: str, data: List[str], timeout: timedelta = ...) -> bytes: ...
+ def scatter_recv(self, key: str, timeout: timedelta = ...) -> bytes: ...
+ def all_gather(self, key: str, data: str, timeout: timedelta = ...) -> List[bytes]: ...
+ def all_sum(self, key: str, data: int, timeout: timedelta = ...) -> int: ...
+
+class _StoreCollectives(_ControlCollectives):
+ def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
+
class _DistributedBackendOptions:
def __init__(self): ...
@property
diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp
index 1453c0a..3697d62 100644
--- a/torch/csrc/distributed/c10d/HashStore.hpp
+++ b/torch/csrc/distributed/c10d/HashStore.hpp
@@ -22,7 +22,7 @@
std::vector<uint8_t> get(const std::string& key) override;
void wait(const std::vector<std::string>& keys) override {
- wait(keys, Store::kDefaultTimeout);
+ wait(keys, timeout_);
}
void wait(
diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp
index af715ba..993284f 100644
--- a/torch/csrc/distributed/c10d/Store.hpp
+++ b/torch/csrc/distributed/c10d/Store.hpp
@@ -97,4 +97,33 @@
std::chrono::milliseconds timeout_;
};
+/*
+StoreTimeoutGuard is a RAII guard that sets the store timeout on construction
+and restores the previous timeout when it goes out of scope.
+*/
+class StoreTimeoutGuard {
+ public:
+ explicit StoreTimeoutGuard(
+ Store& store,
+ const std::chrono::milliseconds& timeout)
+ : store_(store) {
+ oldTimeout_ = store.getTimeout();
+ store.setTimeout(timeout);
+ }
+
+ ~StoreTimeoutGuard() {
+ store_.setTimeout(oldTimeout_);
+ }
+
+ /* Disabling copy and move semantics */
+ StoreTimeoutGuard(const StoreTimeoutGuard&) = delete;
+ StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete;
+ StoreTimeoutGuard(StoreTimeoutGuard&&) = delete;
+ StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete;
+
+ private:
+ Store& store_;
+ std::chrono::milliseconds oldTimeout_;
+};
+
} // namespace c10d
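An illustrative Python analog of `StoreTimeoutGuard` (a hypothetical helper, not part of this change; it assumes the caller tracks the previous timeout, since the Python `Store` API exposes `set_timeout()` but no getter):

```python
from contextlib import contextmanager
from datetime import timedelta

import torch.distributed as dist

@contextmanager
def store_timeout(store: dist.Store, timeout: timedelta, old_timeout: timedelta):
    """Temporarily override the store timeout, restoring old_timeout on exit."""
    store.set_timeout(timeout)
    try:
        yield store
    finally:
        store.set_timeout(old_timeout)
```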
diff --git a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp
new file mode 100644
index 0000000..b98f9a7
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp
@@ -0,0 +1,59 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <chrono>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include <c10/macros/Macros.h>
+#include <torch/custom_class.h>
+
+namespace c10d {
+
+using namespace std::chrono_literals;
+
+class TORCH_API ControlCollectives : public torch::CustomClassHolder {
+ public:
+ virtual void barrier(
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min,
+ bool block = true) = 0;
+
+ virtual void broadcastSend(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) = 0;
+ virtual std::vector<uint8_t> broadcastRecv(
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min) = 0;
+
+ virtual void gatherSend(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) = 0;
+ virtual std::vector<std::vector<uint8_t>> gatherRecv(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) = 0;
+
+ virtual std::vector<uint8_t> scatterSend(
+ const std::string& key,
+ const std::vector<std::vector<uint8_t>>& data,
+ std::chrono::milliseconds timeout = 5min) = 0;
+ virtual std::vector<uint8_t> scatterRecv(
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min) = 0;
+
+ virtual std::vector<std::vector<uint8_t>> allGather(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) = 0;
+
+ virtual int64_t allSum(
+ const std::string& key,
+ int64_t data,
+ std::chrono::milliseconds timeout = 5min) = 0;
+};
+
+} // namespace c10d
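Note: these camelCase methods are exposed to Python under snake_case names (`broadcastSend` becomes `broadcast_send`, `allSum` becomes `all_sum`, and so on) by the bindings added to `torch/csrc/distributed/c10d/init.cpp` below.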
diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp
new file mode 100644
index 0000000..9958994
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp
@@ -0,0 +1,222 @@
+#include <c10/util/Exception.h>
+#include <fmt/format.h>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
+#include <chrono>
+#include <exception>
+#include <vector>
+
+namespace {
+std::string getRankKey(const std::string& key, int rank) {
+ return fmt::format("{}/{}", key, rank);
+}
+} // namespace
+
+namespace c10d {
+
+StoreCollectives::StoreCollectives(
+ c10::intrusive_ptr<::c10d::Store> store,
+ int rank,
+ int worldSize)
+ : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {}
+
+void StoreCollectives::barrier(
+ const std::string& key,
+ std::chrono::milliseconds timeout,
+ bool blocking) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ auto num_members_key = fmt::format("{}/num_members", key);
+ auto last_members_key = fmt::format("{}/last_members", key);
+
+ auto idx = store_->add(num_members_key, 1);
+ store_->set(getRankKey(key, rank_), "joined");
+
+ if (idx == worldSize_) {
+ store_->set(last_members_key, "<val_ignored>");
+ } else if (blocking) {
+ try {
+ store_->wait({last_members_key});
+ } catch (const std::exception& e) {
+ std::string msg = "barrier failed -- missing ranks: ";
+ for (int i = 0; i < worldSize_; i++) {
+ if (i == rank_) {
+ continue;
+ }
+ auto rank_key = getRankKey(key, i);
+ if (!store_->check({rank_key})) {
+ msg += fmt::format("{}, ", i);
+ }
+ }
+ throw std::runtime_error(msg + e.what());
+ }
+ }
+}
+
+void StoreCollectives::broadcastSend(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ store_->set(key, data);
+}
+
+std::vector<uint8_t> StoreCollectives::broadcastRecv(
+ const std::string& key,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ return store_->get(key);
+}
+
+void StoreCollectives::gatherSend(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ auto rank_key = getRankKey(key, rank_);
+ store_->set(rank_key, data);
+}
+
+std::vector<std::vector<uint8_t>> StoreCollectives::gatherRecv(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ std::vector<std::string> keys;
+ keys.reserve(worldSize_);
+
+ for (int i = 0; i < worldSize_; i++) {
+ if (i == rank_) {
+ continue;
+ }
+ auto rank_key = getRankKey(key, i);
+ keys.emplace_back(rank_key);
+ }
+
+ std::vector<std::vector<uint8_t>> results;
+ results.reserve(worldSize_);
+
+ try {
+ results = store_->multiGet(keys);
+ } catch (const std::exception& e) {
+ std::string msg = "gather failed -- missing ranks: ";
+ for (int i = 0; i < worldSize_; i++) {
+ if (i == rank_) {
+ continue;
+ }
+ auto rank_key = getRankKey(key, i);
+ if (!store_->check({rank_key})) {
+ msg += fmt::format("{}, ", i);
+ }
+ }
+ throw std::runtime_error(msg + e.what());
+ }
+
+ // insert local data
+ results.insert(results.begin() + rank_, data);
+ return results;
+}
+
+std::vector<uint8_t> StoreCollectives::scatterSend(
+ const std::string& key,
+ const std::vector<std::vector<uint8_t>>& data,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ std::vector<std::string> keys;
+ keys.reserve(worldSize_);
+ for (int i = 0; i < worldSize_; i++) {
+ if (i == rank_) {
+ continue;
+ }
+ auto rank_key = getRankKey(key, i);
+ keys.emplace_back(rank_key);
+ }
+ auto local = data.at(rank_);
+
+ std::vector<std::vector<uint8_t>> toSend{data};
+
+ toSend.erase(toSend.begin() + rank_);
+
+ store_->multiSet(keys, toSend);
+
+ return local;
+}
+
+std::vector<uint8_t> StoreCollectives::scatterRecv(
+ const std::string& key,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ auto rank_key = getRankKey(key, rank_);
+ return store_->get(rank_key);
+}
+
+std::vector<std::vector<uint8_t>> StoreCollectives::allGather(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ auto localKey = getRankKey(key, rank_);
+ store_->set(localKey, data);
+
+ std::vector<std::string> keys;
+ keys.reserve(worldSize_);
+
+ for (int i = 0; i < worldSize_; i++) {
+ auto rank_key = getRankKey(key, i);
+ keys.emplace_back(rank_key);
+ }
+
+ try {
+ return store_->multiGet(keys);
+ } catch (const std::exception& e) {
+ std::string msg = "all_gather failed -- missing ranks: ";
+ for (int i = 0; i < worldSize_; i++) {
+ if (i == rank_) {
+ continue;
+ }
+ auto rank_key = getRankKey(key, i);
+ if (!store_->check({rank_key})) {
+ msg += fmt::format("{}, ", i);
+ }
+ }
+ throw std::runtime_error(msg + e.what());
+ }
+}
+
+int64_t StoreCollectives::allSum(
+ const std::string& key,
+ int64_t value,
+ std::chrono::milliseconds timeout) {
+ enforceUnique(key);
+ StoreTimeoutGuard g{*store_, timeout};
+
+ store_->add(key, value);
+
+ barrier(key + "/barrier", timeout);
+
+ return store_->add(key, 0);
+}
+
+void StoreCollectives::enforceUnique(const std::string& key) {
+ auto it = seenKeys_.find(key);
+ TORCH_INTERNAL_ASSERT(
+ it == seenKeys_.end(), "Key ", key, " has already been used.");
+ seenKeys_.emplace(key);
+}
+
+} // namespace c10d
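The barrier above coordinates purely through store keys; as a reading aid, here is the same key layout restated as a Python sketch over the public `Store` API (`store_barrier` is a hypothetical name, and the missing-rank error reporting is omitted):

```python
import torch.distributed as dist

def store_barrier(store: dist.Store, key: str, rank: int, world_size: int) -> None:
    num_members_key = f"{key}/num_members"    # arrival counter
    last_members_key = f"{key}/last_members"  # sentinel set by the last arriver

    idx = store.add(num_members_key, 1)       # returns the post-increment count
    store.set(f"{key}/{rank}", "joined")      # per-rank marker used in error reports

    if idx == world_size:
        store.set(last_members_key, "<val_ignored>")  # release everyone else
    else:
        store.wait([last_members_key])  # blocks until released or the store timeout fires
```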
diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp
new file mode 100644
index 0000000..7d3eb50
--- /dev/null
+++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp
@@ -0,0 +1,68 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <c10/util/FbcodeMaps.h>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
+
+namespace c10d {
+
+class TORCH_API StoreCollectives : public ControlCollectives {
+ public:
+ explicit StoreCollectives(
+ c10::intrusive_ptr<Store> store,
+ int rank,
+ int worldSize);
+
+ void barrier(
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min,
+ bool block = true) override;
+
+ void broadcastSend(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) override;
+ std::vector<uint8_t> broadcastRecv(
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min) override;
+
+ void gatherSend(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) override;
+ std::vector<std::vector<uint8_t>> gatherRecv(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) override;
+
+ std::vector<uint8_t> scatterSend(
+ const std::string& key,
+ const std::vector<std::vector<uint8_t>>& data,
+ std::chrono::milliseconds timeout = 5min) override;
+ std::vector<uint8_t> scatterRecv(
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min) override;
+
+ std::vector<std::vector<uint8_t>> allGather(
+ const std::string& key,
+ const std::vector<uint8_t>& data,
+ std::chrono::milliseconds timeout = 5min) override;
+
+ int64_t allSum(
+ const std::string& key,
+ int64_t data,
+ std::chrono::milliseconds timeout = 5min) override;
+
+ private:
+ void enforceUnique(const std::string& key);
+
+ private:
+ c10::intrusive_ptr<Store> store_;
+ int rank_;
+ int worldSize_;
+
+ c10::FastSet<std::string> seenKeys_{};
+};
+
+} // namespace c10d
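`allSum` (implemented in `StoreCollectives.cpp` above) reduces to three store operations; a sketch reusing the hypothetical `store_barrier` from the earlier sketch:

```python
def store_all_sum(store, key: str, value: int, rank: int, world_size: int) -> int:
    store.add(key, value)  # every rank accumulates into one shared counter
    store_barrier(store, f"{key}/barrier", rank, world_size)  # all contributions are in
    return store.add(key, 0)  # adding zero reads back the final total
```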
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 483becb..505b64e 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -6,6 +6,9 @@
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
+#include <vector>
#ifndef _WIN32
#include <torch/csrc/distributed/c10d/HashStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp>
@@ -136,6 +139,34 @@
namespace {
+py::bytes toPyBytes(const std::vector<uint8_t>& data) {
+ return py::bytes(reinterpret_cast<const char*>(data.data()), data.size());
+}
+
+std::vector<py::bytes> toPyBytes(
+ const std::vector<std::vector<uint8_t>>& data) {
+ std::vector<py::bytes> out;
+ out.reserve(data.size());
+ for (const std::vector<uint8_t>& data_ : data) {
+ out.emplace_back(reinterpret_cast<const char*>(data_.data()), data_.size());
+ }
+ return out;
+}
+
+std::vector<uint8_t> toVec8(const std::string& data) {
+ std::vector<uint8_t> out{data.begin(), data.end()};
+ return out;
+}
+
+std::vector<std::vector<uint8_t>> toVec8(const std::vector<std::string>& data) {
+ std::vector<std::vector<uint8_t>> out;
+ out.reserve(data.size());
+ for (auto& data_ : data) {
+ out.emplace_back(toVec8(data_));
+ }
+ return out;
+}
+
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
@@ -166,8 +197,7 @@
pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
// Call function with a py::bytes object for the value.
- fn(key,
- py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
+ fn(key, toPyBytes(value));
}
// Note: this function manually calls the Python-side overload
@@ -184,7 +214,7 @@
// std::vector<uint8_t>. There is no API for directly accessing
// the contents of the py::bytes object.
std::string str = pybind11::cast<py::bytes>(fn(key));
- return std::vector<uint8_t>(str.begin(), str.end());
+ return toVec8(str);
}
// Note: this function manually calls the Python-side overload
@@ -204,14 +234,8 @@
// std::vector<uint8_t>. There is no API for directly accessing
// the contents of the py::bytes object.
std::string str = pybind11::cast<py::bytes>(
- fn(key,
- py::bytes(
- reinterpret_cast<const char*>(expectedValue.data()),
- expectedValue.size()),
- py::bytes(
- reinterpret_cast<const char*>(desiredValue.data()),
- desiredValue.size())));
- return std::vector<uint8_t>(str.begin(), str.end());
+ fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue)));
+ return toVec8(str);
}
int64_t add(const std::string& key, int64_t value) override {
@@ -253,8 +277,7 @@
return Store::append(key, value);
}
// Call function with a py::bytes object for the value.
- fn(key,
- py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
+ fn(key, toPyBytes(value));
}
std::vector<std::vector<uint8_t>> multiGet(
@@ -287,14 +310,7 @@
return Store::multiSet(keys, values);
}
- std::vector<py::bytes> bytes;
- bytes.reserve(values.size());
- for (auto& value : values) {
- bytes.emplace_back(
- reinterpret_cast<const char*>(value.data()), value.size());
- }
-
- fn(keys, bytes);
+ fn(keys, toPyBytes(values));
}
bool hasExtendedApi() const override {
@@ -973,10 +989,7 @@
"set",
[](::c10d::Store& store,
const std::string& key,
- const std::string& value) {
- std::vector<uint8_t> value_(value.begin(), value.end());
- store.set(key, value_);
- },
+ const std::string& value) { store.set(key, toVec8(value)); },
py::call_guard<py::gil_scoped_release>(),
R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
@@ -1001,14 +1014,9 @@
const std::string& key,
const std::string& expected_value,
const std::string& desired_value) -> py::bytes {
- std::vector<uint8_t> expectedValue_(
- expected_value.begin(), expected_value.end());
- std::vector<uint8_t> desiredValue_(
- desired_value.begin(), desired_value.end());
- auto value =
- store.compareSet(key, expectedValue_, desiredValue_);
- return py::bytes(
- reinterpret_cast<char*>(value.data()), value.size());
+ auto value = store.compareSet(
+ key, toVec8(expected_value), toVec8(desired_value));
+ return toPyBytes(value);
},
py::call_guard<py::gil_scoped_release>(),
R"(
@@ -1040,8 +1048,7 @@
py::gil_scoped_release guard;
return store.get(key);
}();
- return py::bytes(
- reinterpret_cast<char*>(value.data()), value.size());
+ return toPyBytes(value);
},
R"(
Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
@@ -1240,8 +1247,7 @@
[](::c10d::Store& store,
const std::string& key,
const std::string& value) {
- std::vector<uint8_t> value_(value.begin(), value.end());
- store.append(key, value_);
+ store.append(key, toVec8(value));
},
py::call_guard<py::gil_scoped_release>(),
R"(
@@ -1268,14 +1274,7 @@
py::gil_scoped_release guard;
return store.multiGet(keys);
}();
- std::vector<py::bytes> res;
- for (auto& value : values) {
- auto bytes = py::bytes(
- reinterpret_cast<const char*>(value.data()),
- value.size());
- res.push_back(bytes);
- }
- return res;
+ return toPyBytes(values);
},
R"(
Retrieve all values in ``keys``. If any key in ``keys`` is not
@@ -1298,12 +1297,7 @@
[](::c10d::Store& store,
const std::vector<std::string>& keys,
const std::vector<std::string>& values) {
- std::vector<std::vector<uint8_t>> vals;
- vals.reserve(values.size());
- for (auto& value : values) {
- vals.emplace_back(value.begin(), value.end());
- }
- store.multiSet(keys, vals);
+ store.multiSet(keys, toVec8(values));
},
py::call_guard<py::gil_scoped_release>(),
R"(
@@ -1487,6 +1481,212 @@
&::c10d::PrefixStore::getUnderlyingNonPrefixStore,
R"(Recursively to get the store before layers of wrapping with PrefixStore.)");
+ using namespace std::chrono_literals;
+
+ auto collectives =
+ py::class_<
+ ::c10d::ControlCollectives,
+ c10::intrusive_ptr<::c10d::ControlCollectives>>(
+ module,
+ "_ControlCollectives",
+ R"(
+Base class for all ControlCollectives implementations.
+)")
+ .def(
+ "barrier",
+ &::c10d::ControlCollectives::barrier,
+ py::arg("key"),
+ py::arg("timeout") = 5min,
+ py::arg("block") = true,
+ py::call_guard<py::gil_scoped_release>(),
+ R"(
+Blocks until all workers have entered this function.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ timeout (duration): The timeout for this operation.
+ block (bool): whether to block this worker waiting on the results of the barrier.
+)")
+ .def(
+ "all_sum",
+ &::c10d::ControlCollectives::allSum,
+ py::arg("key"),
+ py::arg("data"),
+ py::arg("timeout") = 5min,
+ py::call_guard<py::gil_scoped_release>(),
+ R"(
+Computes a sum across all workers and returns the final value.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ data (int): The data to sum.
+ timeout (duration): The timeout for this operation.
+)")
+ .def(
+ "broadcast_send",
+ [](::c10d::ControlCollectives& collectives,
+ const std::string& key,
+ const std::string& data,
+ std::chrono::milliseconds timeout = 5min) {
+ collectives.broadcastSend(key, toVec8(data), timeout);
+ },
+ py::arg("key"),
+ py::arg("data"),
+ py::arg("timeout") = 5min,
+ py::call_guard<py::gil_scoped_release>(),
+ R"(
+Sends data to all other workers. Must only be called from one worker.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ data (str): The data to send.
+ timeout (duration): The timeout for this operation.
+)")
+ .def(
+ "broadcast_recv",
+ [](::c10d::ControlCollectives& collectives,
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min) {
+ auto out = [&]() {
+ py::gil_scoped_release guard;
+ return collectives.broadcastRecv(key, timeout);
+ }();
+ return toPyBytes(out);
+ },
+ py::arg("key"),
+ py::arg("timeout") = 5min,
+ R"(
+Receives data broadcast from one worker.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ timeout (duration): The timeout for this operation.
+)")
+ .def(
+ "gather_send",
+ [](::c10d::ControlCollectives& collectives,
+ const std::string& key,
+ const std::string& data,
+ std::chrono::milliseconds timeout = 5min) {
+ collectives.gatherSend(key, toVec8(data), timeout);
+ },
+ py::arg("key"),
+ py::arg("data"),
+ py::arg("timeout") = 5min,
+ py::call_guard<py::gil_scoped_release>(),
+ R"(
+Sends data to one other worker.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ data (str): The data to send.
+ timeout (duration): The timeout for this operation.
+)")
+ .def(
+ "gather_recv",
+ [](::c10d::ControlCollectives& collectives,
+ const std::string& key,
+ const std::string& data,
+ std::chrono::milliseconds timeout = 5min) {
+ auto out = [&]() {
+ py::gil_scoped_release guard;
+ return collectives.gatherRecv(key, toVec8(data), timeout);
+ }();
+ return toPyBytes(out);
+ },
+ py::arg("key"),
+ py::arg("data"),
+ py::arg("timeout") = 5min,
+ R"(
+Receives data gathered from all workers, including this worker's own contribution.
+Must only be called by one worker.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ data (str): This worker's contribution to the gathered result.
+ timeout (duration): The timeout for this operation.
+)")
+
+ .def(
+ "scatter_send",
+ [](::c10d::ControlCollectives& collectives,
+ const std::string& key,
+ const std::vector<std::string>& data,
+ std::chrono::milliseconds timeout = 5min) {
+ auto out = [&]() {
+ py::gil_scoped_release guard;
+ return collectives.scatterSend(key, toVec8(data), timeout);
+ }();
+ return toPyBytes(out);
+ },
+ py::arg("key"),
+ py::arg("data"),
+ py::arg("timeout") = 5min,
+ R"(
+Sends rank-specific data to all other workers and returns this rank's own entry.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ data (List[str]): The data to send, one entry per rank.
+ timeout (duration): The timeout for this operation.
+)")
+ .def(
+ "scatter_recv",
+ [](::c10d::ControlCollectives& collectives,
+ const std::string& key,
+ std::chrono::milliseconds timeout = 5min) {
+ auto out = [&]() {
+ py::gil_scoped_release guard;
+ return collectives.scatterRecv(key, timeout);
+ }();
+ return toPyBytes(out);
+ },
+ py::arg("key"),
+ py::arg("timeout") = 5min,
+ R"(
+Receives rank-specific data from one worker.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ timeout (duration): The timeout for this operation.
+)")
+
+ .def(
+ "all_gather",
+ [](::c10d::ControlCollectives& collectives,
+ const std::string& key,
+ const std::string& data,
+ std::chrono::milliseconds timeout = 5min) {
+ auto out = [&]() {
+ py::gil_scoped_release guard;
+ return collectives.allGather(key, toVec8(data), timeout);
+ }();
+ return toPyBytes(out);
+ },
+ py::arg("key"),
+ py::arg("data"),
+ py::arg("timeout") = 5min,
+ R"(
+Sends data to all workers and receives the data gathered from every worker, including this one.
+
+Arguments:
+ key (str): The unique key used to identify this operation.
+ data (str): The data to send.
+ timeout (duration): The timeout for this operation.
+)");
+
+ intrusive_ptr_class_<::c10d::StoreCollectives>(
+ module,
+ "_StoreCollectives",
+ collectives,
+ R"(
+An implementation of ControlCollectives that uses the provided store as the underlying
+communication mechanism.
+ )")
+ .def(
+ py::init<c10::intrusive_ptr<::c10d::Store>, int, int>(),
+ py::arg("store"),
+ py::arg("rank"),
+ py::arg("world_size"));
+
auto processGroup =
py::class_<
::c10d::ProcessGroup,
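As a sanity check of the bindings above: send-side methods accept `str` (converted through `toVec8`) and recv-side methods return `bytes` (through `toPyBytes`). A two-rank scatter sketch with illustrative key and payload names:

```python
from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch.distributed as dist

store = dist.HashStore()
world_size = 2

def run(rank: int) -> bytes:
    c = dist._StoreCollectives(store, rank, world_size)
    if rank == 1:
        # The sender gets its own shard back as the return value.
        return c.scatter_send(
            "shard", ["for-rank-0", "for-rank-1"], timedelta(seconds=10)
        )
    return c.scatter_recv("shard", timedelta(seconds=10))

with ThreadPool(world_size) as pool:
    print(pool.map(run, range(world_size)))  # [b'for-rank-0', b'for-rank-1']
```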
diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py
index eb7a690..3e7dce9 100644
--- a/torch/distributed/__init__.py
+++ b/torch/distributed/__init__.py
@@ -54,6 +54,8 @@
set_debug_level,
set_debug_level_from_env,
_make_nccl_premul_sum,
+ _ControlCollectives,
+ _StoreCollectives,
)
class _DistributedPdb(pdb.Pdb):
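With the re-export above, the new bindings are importable directly from `torch.distributed`:

```python
from torch.distributed import _ControlCollectives, _StoreCollectives
```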