Revert "Reland "[C10] PG observability hooks. (#108815)" (#110907)"
This reverts commit 7678cd22af46c9df4fb47a409d3e8ad71a6127ea.
Reverted https://github.com/pytorch/pytorch/pull/110907 on behalf of https://github.com/huydhn due to Sorry for reverting this, but macos job in trunk starts failing after this https://hud.pytorch.org/pytorch/pytorch/commit/7678cd22af46c9df4fb47a409d3e8ad71a6127ea ([comment](https://github.com/pytorch/pytorch/pull/110907#issuecomment-1756497387))
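
For reference, the change being backed out added observability hooks for process-group creation and collective start/end, exposed through `torch.distributed._hooks`. A minimal usage sketch, reconstructed from the deleted `test/distributed/test_hooks.py` below, looks roughly like this; none of these entry points exist once this revert lands:

```python
# Sketch of the reverted API, pieced together from the deleted test file.
# torch.distributed._hooks is removed by this commit, so this no longer runs.
import tempfile

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks  # removed by this revert


def on_pg_created(pg, pg_name):
    # Fired only on ranks that are members of the newly created process group.
    print(f"created process group {pg_name}")


def on_coll_start(status):
    # status is a dhooks.CollectiveStatus dataclass: pg_name, backend,
    # sequence_number, operation, timestamp, duration, drop_count.
    print(f"start {status.operation} seq={status.sequence_number}")


def on_coll_end(status):
    print(f"end {status.operation} duration_ms={status.duration}")


# Hooks run on a background thread; exceptions they raise are swallowed.
dhooks.register_process_group_hook(on_pg_created)
dhooks.register_collective_start_hook(on_coll_start)
dhooks.register_collective_end_hook(on_coll_end)

store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
dist.init_process_group(backend="gloo", rank=0, world_size=1, store=store)
dist.all_reduce(torch.ones(2, 3))
dist.destroy_process_group()
```

Under the hood (see the deleted `Hooks.cpp` and `torch/distributed/_hooks.py` below), collectives enqueued events into a mutex-protected queue capped at 512 entries, wrote a byte to a pipe to wake a daemon consumer thread, and recorded overflow via `drop_count` rather than blocking the collective.
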
diff --git a/build_variables.bzl b/build_variables.bzl
index c6592e0..00baad9 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -521,7 +521,6 @@
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
- "torch/csrc/distributed/c10d/Hooks.cpp",
"torch/csrc/distributed/c10d/Ops.cpp",
"torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
"torch/csrc/distributed/c10d/PrefixStore.cpp",
diff --git a/test/distributed/test_hooks.py b/test/distributed/test_hooks.py
deleted file mode 100644
index aa3c489..0000000
--- a/test/distributed/test_hooks.py
+++ /dev/null
@@ -1,270 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-import os
-import sys
-import tempfile
-import threading
-from functools import partial, wraps
-
-import torch
-import torch.distributed as dist
-import torch.distributed._hooks as dhooks
-
-if not dist.is_available():
- print("torch.distributed not available, skipping tests", file=sys.stderr)
- sys.exit(0)
-
-
-from torch.testing._internal.common_distributed import (
- MultiProcessTestCase,
- skip_if_lt_x_gpu,
-)
-
-from torch.testing._internal.common_utils import run_tests, TestCase
-
-
-class PgHooks(MultiProcessTestCase):
- @property
- def world_size(self) -> int:
- return 4
-
- def setUp(self) -> None:
- super().setUp()
- self._spawn_processes()
-
- def tearDown(self):
- super().tearDown()
- try:
- os.remove(self.file_name)
- except OSError:
- pass
-
- def test_pg_hook(self):
- pgs = []
-
- def pg_hook(pg, pg_name):
- pgs.append((pg, pg_name))
-
- dhooks.register_process_group_hook(pg_hook)
- dist.init_process_group(
- backend="gloo",
- rank=self.rank,
- world_size=self.world_size,
- store=dist.FileStore(self.file_name, self.world_size),
- )
- self.assertEqual(len(pgs), 1)
- self.assertEqual(pgs[0][0], dist.group.WORLD)
-
- # create two partial world PGs
- pg0 = dist.new_group(ranks=[0, 1])
- pg1 = dist.new_group(ranks=[2, 3])
-
- # Each rank only observes two PGs being created: the default PG and the one covering its ranks.
- # We don't emit events for PG creation if the current rank doesn't belong to the group.
- # For example, if you're rank 1, you'll get an event for pg0 but not pg1, even though the API contract
- # dictates that you call new_group for both.
- self.assertEqual(len(pgs), 2)
- self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)
-
-
-def with_comms(func=None):
- if func is None:
- return partial(
- with_comms,
- )
-
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- self.init_comms()
- func(self, *args, **kwargs)
- self.destroy_comms()
-
- return wrapper
-
-
-class CollectiveHooks:
- @property
- def world_size(self) -> int:
- return 4
-
- def _collective_hooks(self):
- # It's OK to access these lists directly since a single background thread consumes them.
- starts = []
- ends = []
- cv = threading.Condition()
-
- def coll_start(status):
- starts.append(status)
- print(f"col_start {len(starts)} rank{self.rank}")
-
- def coll_end(status):
- ends.append(status)
- print(f"col_end {len(ends)} rank{self.rank}")
- if len(ends) == 2:
- with cv:
- cv.notify()
-
- dhooks.register_collective_start_hook(coll_start)
- dhooks.register_collective_end_hook(coll_end)
-
- tensor = torch.ones([2, 3]).to(self.device) * self.rank
- tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
-
- dist.all_gather(tensor_list, tensor)
-
- tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
- dist.all_reduce(tensor2)
-
- with cv:
- cv.wait(1)
-
- default_pg_name = dist.group.WORLD.group_name
- self.assertEqual(2, len(starts))
- self.assertEqual(2, len(ends))
-
- def check_op(idx, coll_name):
- self.assertEqual(default_pg_name, starts[idx].pg_name)
- self.assertEqual(self.backend_name, starts[idx].backend)
- self.assertGreaterEqual(starts[idx].sequence_number, 0)
- self.assertGreaterEqual(starts[idx].timestamp, 0)
- self.assertEqual(coll_name, starts[idx].operation)
-
- self.assertEqual(default_pg_name, ends[idx].pg_name)
- self.assertEqual(self.backend_name, ends[idx].backend)
-
- self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
- self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
- self.assertEqual(coll_name, ends[idx].operation)
-
- check_op(0, "ALLGATHER")
- check_op(1, "ALLREDUCE")
-
-
-class GlooHooks(MultiProcessTestCase, CollectiveHooks):
- def setUp(self) -> None:
- super().setUp()
- self._spawn_processes()
-
- def tearDown(self):
- super().tearDown()
- try:
- os.remove(self.file_name)
- except OSError:
- pass
-
- def init_comms(self):
- dist.init_process_group(
- backend="gloo",
- rank=self.rank,
- world_size=self.world_size,
- store=dist.FileStore(self.file_name, self.world_size),
- )
-
- def destroy_comms(self):
- dist.destroy_process_group()
-
- @property
- def backend_name(self):
- return "gloo"
-
- @property
- def device(self):
- return "cpu"
-
- @with_comms
- def test_collective_hooks(self):
- self._collective_hooks()
-
-
-class NcclHooks(MultiProcessTestCase, CollectiveHooks):
- def setUp(self) -> None:
- super().setUp()
- self._spawn_processes()
-
- def tearDown(self):
- super().tearDown()
- try:
- os.remove(self.file_name)
- except OSError:
- pass
-
- def init_comms(self):
- dist.init_process_group(
- backend="nccl",
- rank=self.rank,
- world_size=self.world_size,
- store=dist.FileStore(self.file_name, self.world_size),
- )
-
- def destroy_comms(self):
- dist.destroy_process_group()
-
- @property
- def backend_name(self):
- return "nccl"
-
- @property
- def device(self):
- return f"cuda:{self.rank}"
-
- @skip_if_lt_x_gpu(4)
- @with_comms
- def test_collective_hooks(self):
- self._collective_hooks()
-
-
-class SingleRankTests(TestCase):
- def setUp(self) -> None:
- super().setUp()
- self.rank = 0
- self.file_name = tempfile.NamedTemporaryFile(delete=False).name
- dist.init_process_group(
- backend="gloo",
- rank=0,
- world_size=1,
- store=dist.FileStore(self.file_name, 1),
- )
-
- def tearDown(self) -> None:
- dist.destroy_process_group()
-
- def test_queue_overflow(self) -> None:
- cv_done_colls = threading.Condition()
- cv_done_cb = threading.Condition()
- colls_done = False
- starts = []
- status_with_dropped = None
-
- def coll_start(status: dhooks.CollectiveStatus):
- starts.append(status)
- with cv_done_colls:
- while not colls_done:
- cv_done_colls.wait()
- if status.drop_count > 0:
- nonlocal status_with_dropped
- status_with_dropped = status
- with cv_done_cb:
- cv_done_cb.notify()
-
- dhooks.register_collective_start_hook(coll_start)
-
- # the native event queue starts dropping events after 512 entries
- for i in range(600):
- dist.all_reduce(torch.ones([2, 3]))
- colls_done = True
- with cv_done_colls:
- cv_done_colls.notify()
-
- with cv_done_cb:
- cv_done_cb.wait(10)
-
- self.assertTrue(status_with_dropped is not None)
- self.assertTrue(status_with_dropped.drop_count > 0)
-
-
-if __name__ == "__main__":
- assert (
- not torch.cuda._initialized
- ), "test_distributed must not have initialized CUDA context on main process"
-
- run_tests()
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 7672421..c0c5aed 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -11,10 +11,6 @@
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta
-class EventKind(Enum):
- START = ...
- END = ...
-
class BuiltinCommHookType(Enum):
ALLREDUCE = ...
FP16_COMPRESS = ...
@@ -24,8 +20,6 @@
reducer: Reducer,
comm_hook_type: BuiltinCommHookType,
): ...
-def _dequeue_c10d_event() -> Dict[str, object]: ...
-def _enable_event_collection(pipe_fs: int) -> None: ...
class GradBucket:
def index(self) -> int: ...
diff --git a/torch/csrc/distributed/c10d/Backend.cpp b/torch/csrc/distributed/c10d/Backend.cpp
index aa6db14..9382c9c 100644
--- a/torch/csrc/distributed/c10d/Backend.cpp
+++ b/torch/csrc/distributed/c10d/Backend.cpp
@@ -1,26 +1,9 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
-#include <torch/csrc/distributed/c10d/logging.h>
namespace c10d {
-namespace {
-void commonEventinit(
- details::EventInfo& evt,
- const Backend& backend,
- const Work& work) {
- evt.timestamp =
- std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
- evt.pg_name = backend.getGroupName();
- evt.backend = backend.getBackendName();
- evt.sequence_number = work.getSequencenumber();
- evt.operation = c10d::opTypeToString(work.retrieveOpType());
- evt.drop_count = 0;
-}
-} // namespace
-
Backend::Backend(int rank, int size)
: rank_(rank), size_(size), dist_debug_level_(debug_level()) {
C10_LOG_API_USAGE_ONCE("c10d.backend");
@@ -32,21 +15,4 @@
C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}
-void Backend::emitCollectiveStart(const Work& work) {
- details::EventInfo evt;
- commonEventinit(evt, *this, work);
-
- evt.event_kind = ::c10d::EventKind::CollectiveStart;
- details::enqueue_c10d_event(std::move(evt));
-}
-
-void Backend::emitCollectiveEnd(const Work& work) {
- details::EventInfo evt;
- commonEventinit(evt, *this, work);
-
- evt.event_kind = ::c10d::EventKind::CollectiveEnd;
- evt.duration_ms = work.getDuration();
- details::enqueue_c10d_event(std::move(evt));
-}
-
} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp
index 60415db..58d06ba 100644
--- a/torch/csrc/distributed/c10d/Backend.hpp
+++ b/torch/csrc/distributed/c10d/Backend.hpp
@@ -366,8 +366,6 @@
// Implementations of this interface need to call this to setup
// appropriate logging etc.
void init();
- void emitCollectiveStart(const Work& work);
- void emitCollectiveEnd(const Work& work);
const int rank_;
const int size_;
diff --git a/torch/csrc/distributed/c10d/Hooks.cpp b/torch/csrc/distributed/c10d/Hooks.cpp
deleted file mode 100644
index 485b896..0000000
--- a/torch/csrc/distributed/c10d/Hooks.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-#include <atomic>
-
-#include <deque>
-#include <memory>
-#include <mutex>
-
-#ifndef _WIN32
-#include <unistd.h>
-#else
-#include <io.h>
-#endif
-
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
-namespace c10d {
-
-namespace {
-
-std::atomic<bool> event_queue_enabled = false;
-int sync_pipe;
-std::mutex event_queue_lock;
-std::deque<details::EventInfo> event_queue;
-
-} // namespace
-
-void enable_event_collection(int pipe) {
- sync_pipe = pipe;
- event_queue_enabled.store(true);
-}
-
-namespace details {
-
-// we start dropping events after this
-const size_t MAX_QUEUE_SIZE = 512;
-
-bool dequeue_c10d_event(EventInfo& evt) {
- std::unique_lock<std::mutex> lock(event_queue_lock);
- if (event_queue.size() == 0) {
- return false;
- }
- evt = event_queue.front();
- event_queue.pop_front();
- return true;
-}
-
-void enqueue_c10d_event(EventInfo&& evt) {
- if (!event_queue_enabled.load())
- return;
-
- std::unique_lock<std::mutex> lock(event_queue_lock);
- if (event_queue.size() >= MAX_QUEUE_SIZE) {
- event_queue.back().drop_count++;
- } else {
- event_queue.push_back(std::move(evt));
- char m = 'x';
- write(sync_pipe, &m, 1);
- }
-}
-
-} // namespace details
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/Hooks.hpp b/torch/csrc/distributed/c10d/Hooks.hpp
deleted file mode 100644
index ad15940..0000000
--- a/torch/csrc/distributed/c10d/Hooks.hpp
+++ /dev/null
@@ -1,30 +0,0 @@
-#pragma once
-
-#include <c10/util/Optional.h>
-#include <string>
-
-namespace c10d {
-
-enum class EventKind { CollectiveStart, CollectiveEnd };
-
-TORCH_API void enable_event_collection(int sync_pipe);
-
-namespace details {
-
-struct TORCH_API EventInfo {
- EventKind event_kind;
- std::string pg_name;
- std::string backend;
- int64_t sequence_number;
- std::string operation;
- int64_t timestamp;
- c10::optional<float> duration_ms;
- int64_t drop_count;
-};
-
-// TODO do we want to expose something else here?
-TORCH_API bool dequeue_c10d_event(EventInfo& evt);
-TORCH_API void enqueue_c10d_event(EventInfo&& evt);
-
-} // namespace details
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp
index 4256110..ffec42a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp
@@ -82,14 +82,10 @@
return "RECVANYSOURCE";
case OpType::BARRIER:
return "BARRIER";
- case OpType::_REDUCE_SCATTER_BASE:
- return "_REDUCE_SCATTER_BASE";
- case OpType::COALESCED:
- return "COALESCED";
- case OpType::_ALLREDUCE_SPARSE:
- return "_ALLREDUCE_SPARSE";
case OpType::UNKNOWN:
return "UNKNOWN";
+ case OpType::_REDUCE_SCATTER_BASE:
+ return "_REDUCE_SCATTER_BASE";
default:
TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
}
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index 60ea6e3..13b76a6 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -855,10 +855,6 @@
}
void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
- emitCollectiveStart(*work.get());
- work->getFuture()->addCallback(
- [=](auto& f) { this->emitCollectiveEnd(*work.get()); });
-
std::unique_lock<std::mutex> lock(workMutex_);
workQueue_.push_back(std::move(work));
lock.unlock();
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 54681a4..ca8dd08 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -942,34 +942,27 @@
}
}
-void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
- if (terminateProcessGroup_.load() || work.startTraceUpdated_)
+void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
+ if (work.startTraceUpdated_)
return;
+
+ if (terminateProcessGroup_.load() || storeError_)
+ return;
+
work.startTraceUpdated_ = true;
-
- emitCollectiveStart(work);
-
- if (!emitDesyncInfo || storeError_)
- return;
-
storeError_ = !c10d::traceUpdate(
store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
}
-void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
- if (terminateProcessGroup_.load())
+void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
+ if (terminateProcessGroup_.load() || storeError_)
return;
// In case the start of the work hasn't been logged
if (!work.startTraceUpdated_) {
- logWorkStart(work, emitDesyncInfo);
+ logWorkStart(work);
}
- emitCollectiveEnd(work);
-
- if (!emitDesyncInfo || storeError_)
- return;
-
storeError_ = !c10d::traceUpdate(
store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
}
@@ -1021,11 +1014,13 @@
}
// Work status logging for desync debug
- if (work.isStarted()) {
- logWorkStart(work, desyncDebug_);
- }
- if (work.isCompleted()) {
- logWorkEnd(work, desyncDebug_);
+ if (desyncDebug_) {
+ if (work.isStarted()) {
+ logWorkStart(work);
+ }
+ if (work.isCompleted()) {
+ logWorkEnd(work);
+ }
}
// Clean up completed work
@@ -1082,7 +1077,7 @@
timeStarted, // timeStarted
std::chrono::system_clock::now(), // timeFinished
std::chrono::duration<float, std::milli>(
- work.getDuration().value()) // activeDuration
+ work.getDuration()) // activeDuration
));
lock.lock();
@@ -1585,19 +1580,19 @@
return future_;
}
-c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
- if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
- return c10::optional<float>();
- }
+float ProcessGroupNCCL::WorkNCCL::getDuration() const {
+ TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
TORCH_CHECK(
ncclStartEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
TORCH_CHECK(
ncclEndEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
+ TORCH_CHECK(
+ (*ncclEndEvents_)[0].query(),
+ "getDuration can only be called after work is succeeded.")
return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
}
-
uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
return seq_;
}
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index d912d20..782b55f 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -167,7 +167,7 @@
// Get a Future object that will be marked as completed internally.
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
- c10::optional<float> getDuration() const override;
+ float getDuration() const override;
uint64_t getSequencenumber() const override;
@@ -615,10 +615,10 @@
void runHookLoop();
// Desync debug helper
- void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);
+ void logWorkStart(WorkNCCL& work);
// Desync debug helper
- void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);
+ void logWorkEnd(WorkNCCL& work);
protected:
static const int64_t kWatchdogThreadSleepMillis;
diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp
index fb5b0f0..66c35b1 100644
--- a/torch/csrc/distributed/c10d/Work.cpp
+++ b/torch/csrc/distributed/c10d/Work.cpp
@@ -127,8 +127,8 @@
}
}
-c10::optional<float> Work::getDuration() const {
- return c10::optional<float>();
+float Work::getDuration() const {
+ TORCH_CHECK(false, "This Backend doesn't support getDuration.");
}
uint64_t Work::getSequencenumber() const {
diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp
index f1c695d..50c6ae0 100644
--- a/torch/csrc/distributed/c10d/Work.hpp
+++ b/torch/csrc/distributed/c10d/Work.hpp
@@ -107,7 +107,7 @@
// work. Only NCCL backend is currently supported.
virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
- virtual c10::optional<float> getDuration() const;
+ virtual float getDuration() const;
virtual uint64_t getSequencenumber() const;
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 680ccb4..f1ae7de 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -32,7 +32,6 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/comm.hpp>
@@ -292,26 +291,6 @@
reducer.register_builtin_comm_hook(comm_hook_type);
}
-py::object c10d_dequeue_python_event() {
- ::c10d::details::EventInfo evt;
- if (!::c10d::details::dequeue_c10d_event(evt)) {
- return py::none();
- }
-
- py::dict data;
- data["event_kind"] = (int)evt.event_kind;
-
- data["pg_name"] = evt.pg_name;
- data["backend"] = evt.backend;
- data["sequence_number"] = evt.sequence_number;
- data["operation"] = evt.operation;
- data["timestamp"] = evt.timestamp;
- data["duration"] = evt.duration_ms.value_or(-1);
- data["drop_count"] = evt.drop_count;
-
- return std::move(data);
-}
-
// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
// struct from enum, sacrificing some of the Python built-in function supports
@@ -661,11 +640,6 @@
&::c10d::Logger::set_static_graph,
py::call_guard<py::gil_scoped_release>());
- py::enum_<::c10d::EventKind>(module, "EventKind", R"(
-An enum for collective hooks event types.)")
- .value("START", ::c10d::EventKind::CollectiveStart)
- .value("END", ::c10d::EventKind::CollectiveEnd);
-
py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
An enum whose values correspond to different debug levels of the
torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@@ -689,16 +663,7 @@
"set_debug_level_from_env",
::c10d::setDebugLevelFromEnvironment,
R"(Sets the debug level of the torch.distributed package from the
- ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
- .def(
- "_enable_event_collection",
- &::c10d::enable_event_collection,
- "(Enables events collection).",
- py::call_guard<py::gil_scoped_release>())
- .def(
- "_dequeue_c10d_event",
- &c10d_dequeue_python_event,
- "(Blocks until a c10d event is available and return it as a python dictionary).");
+ ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
// TODO(crcrpar): Hardening `ReduceOp`.
// While keeping most op types as enum value,
diff --git a/torch/distributed/_hooks.py b/torch/distributed/_hooks.py
deleted file mode 100644
index 76cede5..0000000
--- a/torch/distributed/_hooks.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import logging
-import os
-import threading
-import time
-from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional
-
-import torch.distributed as dist
-import torch.distributed.distributed_c10d as c10d
-
-from torch._C._distributed_c10d import (
- _dequeue_c10d_event,
- _enable_event_collection,
- EventKind,
-)
-
-__all__ = [
- "CollectiveStatus",
- "COLLECTIVE_HOOK_TYPE",
- "PG_HOOK_TYPE",
- "register_collective_start_hook",
- "register_collective_end_hook",
- "register_process_group_hook",
-]
-
-
-@dataclass
-class CollectiveStatus:
- r"""Status of a collective operation.
-
- Drop count indicates events have been dropped at the producer side, which means they
- were not consumed fast enough by hooks.
- """
- pg_name: str = "unknown" # Matches the name passed to the process group registration callback
- backend: str = "unknown" # Name of the backend used
- sequence_number: int = -1 # Sequence number of the collective within this process group
- operation: str = "unknown" # Name of the collective operation
- timestamp: int = 0 # Earliest time at which we observed this event
- duration: Optional[float] = None # Execution time in milliseconds, if available
- drop_count: int = 0 # Number of events dropped after this one
-
-
-COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
-PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]
-
-# This controls the number of internal failures we'll tolerate before giving up
-_MAX_INTERNAL_FAILURES = 10
-
-logger = logging.getLogger(__name__)
-_cb_thread: Optional[threading.Thread] = None
-_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
-_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
-_pp_r = -1
-_pp_w = -1
-
-
-def _c10d_pg_hooks_loops():
- internal_failures = 0
- while True:
- # we don't care about the result, this is how we implement notification
- _ = os.read(_pp_r, 1)
- evt: Dict[str, object] = _dequeue_c10d_event()
- try:
- event_kind = evt.pop("event_kind", None)
- if event_kind is None:
- logger.warning(
- "c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
- evt,
- )
- internal_failures += 1
- if internal_failures >= _MAX_INTERNAL_FAILURES:
- logger.warning(
- "too many internal c10d failures processing callback loop. stopping"
- )
- return
- time.sleep(1)
- continue
-
- if event_kind == int(EventKind.START): # type: ignore[call-overload]
- cb_list = _start_callbacks
- elif event_kind == int(EventKind.END): # type: ignore[call-overload]
- cb_list = _end_callbacks
- else:
- logger.warning(
- "c10d event %s with invalid 'event_kind' with value %d",
- evt,
- event_kind,
- )
- internal_failures += 1
- if internal_failures >= _MAX_INTERNAL_FAILURES:
- logger.warning(
- "too many internal c10d failures processing callback loop. stopping"
- )
- return
- time.sleep(1)
- continue
-
- status = CollectiveStatus(**evt) # type: ignore[arg-type]
- for cb in cb_list:
- try:
- cb(status)
- except Exception as e:
- logger.info(
- "c10d event callback %s with event %s threw exception %s",
- cb,
- status,
- e,
- )
- except Exception as e:
- # We have to keep processing, otherwise the queue will grow infinitely large
- logger.warning(
- "c10d callback thread when processing event %s raised exception %s.",
- evt,
- e,
- )
- internal_failures += 1
- if internal_failures >= _MAX_INTERNAL_FAILURES:
- logger.warning(
- "too many internal c10d failures processing callback loop. stopping"
- )
- return
-
- # Sleep for a second to avoid hogging the GIL in case of a persistent failure
- time.sleep(1)
-
-
-def _lazy_init():
- global _cb_thread
- if _cb_thread is not None:
- return
- global _pp_r
- global _pp_w
- _pp_r, _pp_w = os.pipe()
- _enable_event_collection(_pp_w)
- c10d._enable_collectives_timing()
- _cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
- _cb_thread.start()
- logger.info("c10d::hooks thread enabled")
-
-
-def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
- r"""Register a hook that is called every time a collective starts.
-
- The hook is invoked on a background thread.
- Exceptions raised by the callback are ignored and non-fatal.
- """
- _start_callbacks.append(hook)
- _lazy_init()
-
-
-def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
- r"""Register a hook that is called every time a collective finishes.
-
-
- The hook is invoked on a background thread.
- Exceptions raised by the callback are ignored and non-fatal.
- """
- _end_callbacks.append(hook)
- _lazy_init()
-
-
-def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
- r"""Register a hook that is called every time a process group is created on this rank.
-
- This hook is only invoked if the current rank is part of the PG being created.
- The pg_name is unique to the whole cluster and should be treated as an opaque identifier that is subject to change.
-
- The hook is invoked on a background thread.
- Exceptions raised by the callback are ignored and non-fatal.
- """
- c10d._register_creation_hook(hook)
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 27e960a..232fb42 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -421,19 +421,6 @@
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}
-class _HookState:
- def __init__(self):
- self.creation_hooks = []
-
- def register_creation_hook(self, hook) -> None:
- self.creation_hooks.append(hook)
-
- def fire_creation_hook(self, pg, name) -> None:
- for hook in self.creation_hooks:
- try:
- hook(pg, name)
- except Exception as e:
- logger.info("hook %s failed with %s", hook, e)
class _World:
"""
@@ -447,8 +434,6 @@
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
- self._hook_state = _HookState()
- self.enable_collectives_timing = False
@property
def default_pg(self):
@@ -558,9 +543,6 @@
)
return config_info
- @property
- def pg_hook_state(self) -> _HookState:
- return self._hook_state
_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -1382,9 +1364,6 @@
pg._set_group_name(group_name)
_world.pg_backend_config[pg] = str(backend_config)
- if _world.enable_collectives_timing:
- pg._enable_collectives_timing()
-
# "" is the default tag for user PGs
if pg_tag in [None, ""]:
pg_tag = f"ptd:{group_name}"
@@ -1394,8 +1373,6 @@
_world.tags_to_pg.setdefault(pg_tag, []).append(pg)
_world.pg_to_tag[pg] = pg_tag
- _world.pg_hook_state.fire_creation_hook(pg, group_name)
-
return pg, prefix_store
def destroy_process_group(group: Optional[ProcessGroup] = None):
@@ -4338,11 +4315,3 @@
reduce_scatter_tensor,
send,
]
-
-def _register_creation_hook(hook):
- _world.pg_hook_state.register_creation_hook(hook)
-
-def _enable_collectives_timing():
- _world.enable_collectives_timing = True
- for pg in _world.pg_map:
- pg._enable_collectives_timing()
diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py
index 11b0993..a1044e6 100644
--- a/torch/testing/_internal/distributed/multi_threaded_pg.py
+++ b/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -401,11 +401,6 @@
class ThreadLocalWorld:
_world = threading.local()
- def __init__(self):
- self.enable_collectives_timing = False
- self._hook_state = dist.distributed_c10d._HookState()
-
-
def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@@ -459,9 +454,6 @@
def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
return self._get_world().pg_default_device
- @property
- def pg_hook_state(self) -> dist.distributed_c10d._HookState:
- return self._hook_state
_old_pg_world = None
_ctx_manager = None