[19/N] Add monitored_barrier custom op with CPU implementation (#89318)
Differential Revision: [D41415324](https://our.internmc.facebook.com/intern/diff/D41415324)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89318
Approved by: https://github.com/kwen2501
diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py
index 545f125..bee76e7 100644
--- a/test/distributed/test_c10d_gloo.py
+++ b/test/distributed/test_c10d_gloo.py
@@ -2431,6 +2431,17 @@
dist.all_gather_coalesced([output_tensor_list], [input_tensor])
self.assertEqual(output_tensor_list, [input_tensor])
+ @requires_gloo()
+ def test_monitored_barrier(self):
+ store = dist.FileStore(self.file_name, self.world_size)
+ dist.init_process_group(
+ "gloo",
+ world_size=self.world_size,
+ rank=self.rank,
+ store=store,
+ )
+ dist.monitored_barrier()
+
class CompilerTest(test_c10d_common.CompilerTest):
@property
diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp
index 4edb70c..6b4717a 100644
--- a/torch/csrc/distributed/c10d/Ops.cpp
+++ b/torch/csrc/distributed/c10d/Ops.cpp
@@ -181,6 +181,17 @@
BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
}
+void monitored_barrier_(
+ at::Tensor /* unused */,
+ const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
+ const std::vector<int64_t>& device_ids,
+ int64_t timeout,
+ bool wait_all_ranks) {
+ process_group->monitoredBarrier(
+ BarrierOptions{device_ids, std::chrono::milliseconds(timeout)},
+ wait_all_ranks);
+}
+
c10::intrusive_ptr<Work> send(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -255,6 +266,10 @@
m.def(
"barrier",
dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier));
+ m.def(
+ "monitored_barrier_",
+ dispatch(
+ c10::DispatchKey::CompositeExplicitAutograd, monitored_barrier_));
m.def("send", dispatch(c10::DispatchKey::CompositeExplicitAutograd, send));
m.def("recv_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, recv_));
}
@@ -497,6 +512,28 @@
output_tensors, input_tensors, process_group, opts.timeout.count());
}
+void monitored_barrier(
+ const c10::intrusive_ptr<ProcessGroup>& process_group,
+ const BarrierOptions& opts,
+ bool wait_all_ranks) {
+ static auto op = c10::Dispatcher::singleton()
+ .findSchemaOrThrow("c10d::monitored_barrier_", "")
+ .typed<void(
+ at::Tensor,
+ const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+ const std::vector<int64_t>&,
+ int64_t,
+ bool)>();
+  // Default to the CPU implementation; monitored barrier is only supported by GLOO
+ at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU));
+ op.call(
+ tensor,
+ process_group,
+ opts.device_ids,
+ opts.timeout.count(),
+ wait_all_ranks);
+}
+
c10::intrusive_ptr<Work> barrier(
const c10::intrusive_ptr<ProcessGroup>& process_group,
const BarrierOptions& opts) {
diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp
index ad6e2d3..b542603 100644
--- a/torch/csrc/distributed/c10d/Ops.hpp
+++ b/torch/csrc/distributed/c10d/Ops.hpp
@@ -83,6 +83,11 @@
const c10::intrusive_ptr<ProcessGroup>& process_group,
const BarrierOptions& opts = {});
+TORCH_API void monitored_barrier(
+ const c10::intrusive_ptr<ProcessGroup>& process_group,
+ const BarrierOptions& opts,
+ bool waitAllRanks);
+
TORCH_API c10::intrusive_ptr<Work> send(
const c10::intrusive_ptr<ProcessGroup>& process_group,
at::TensorList tensors,
diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp
index 66269db..3138669 100644
--- a/torch/csrc/distributed/c10d/OpsImpl.cpp
+++ b/torch/csrc/distributed/c10d/OpsImpl.cpp
@@ -399,6 +399,17 @@
BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
}
+void monitored_barrier_cpu_(
+ at::Tensor /* unused */,
+ const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
+ const std::vector<int64_t>& device_ids,
+ int64_t timeout,
+ bool wait_all_ranks) {
+ process_group->monitoredBarrier(
+ BarrierOptions{device_ids, std::chrono::milliseconds(timeout)},
+ wait_all_ranks);
+}
+
// register functions to dispatcher
namespace {
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
@@ -531,6 +542,10 @@
m.impl("barrier", barrier_cuda);
}
+TORCH_LIBRARY_IMPL(c10d, CPU, m) {
+ m.impl("monitored_barrier_", monitored_barrier_cpu_);
+}
+
} // namespace
} // namespace ops
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index f65354d..9a9699c 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1539,7 +1539,7 @@
bool waitAllRanks) {
::c10d::BarrierOptions opts;
opts.timeout = timeout;
- return self->monitoredBarrier(opts, waitAllRanks);
+ return ::c10d::ops::monitored_barrier(self, opts, waitAllRanks);
},
py::arg("timeout") = ::c10d::kUnsetTimeout,
py::arg("wait_all_ranks") = false,