Revert "Reland "[C10] PG observability hooks. (#108815)" (#110907)"

This reverts commit 7678cd22af46c9df4fb47a409d3e8ad71a6127ea.

Reverted https://github.com/pytorch/pytorch/pull/110907 on behalf of https://github.com/huydhn due to Sorry for reverting this, but the macOS job in trunk started failing after this https://hud.pytorch.org/pytorch/pytorch/commit/7678cd22af46c9df4fb47a409d3e8ad71a6127ea ([comment](https://github.com/pytorch/pytorch/pull/110907#issuecomment-1756497387))
diff --git a/build_variables.bzl b/build_variables.bzl
index c6592e0..00baad9 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -521,7 +521,6 @@
     "torch/csrc/distributed/c10d/Backend.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
-    "torch/csrc/distributed/c10d/Hooks.cpp",
     "torch/csrc/distributed/c10d/Ops.cpp",
     "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
     "torch/csrc/distributed/c10d/PrefixStore.cpp",
diff --git a/test/distributed/test_hooks.py b/test/distributed/test_hooks.py
deleted file mode 100644
index aa3c489..0000000
--- a/test/distributed/test_hooks.py
+++ /dev/null
@@ -1,270 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-import os
-import sys
-import tempfile
-import threading
-from functools import partial, wraps
-
-import torch
-import torch.distributed as dist
-import torch.distributed._hooks as dhooks
-
-if not dist.is_available():
-    print("torch.distributed not available, skipping tests", file=sys.stderr)
-    sys.exit(0)
-
-
-from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
-    skip_if_lt_x_gpu,
-)
-
-from torch.testing._internal.common_utils import run_tests, TestCase
-
-
-class PgHooks(MultiProcessTestCase):
-    @property
-    def world_size(self) -> int:
-        return 4
-
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def test_pg_hook(self):
-        pgs = []
-
-        def pg_hook(pg, pg_name):
-            pgs.append((pg, pg_name))
-
-        dhooks.register_process_group_hook(pg_hook)
-        dist.init_process_group(
-            backend="gloo",
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-        self.assertEqual(len(pgs), 1)
-        self.assertEqual(pgs[0][0], dist.group.WORLD)
-
-        # create two partial world PGs
-        pg0 = dist.new_group(ranks=[0, 1])
-        pg1 = dist.new_group(ranks=[2, 3])
-
-        # Each rank only observes two PGs being created: the default PG and the one covering its ranks.
-        # We don't emit events for PG creation if the current rank doesn't belong to it.
-        # For example, if you're rank 1, you'll get an event for pg0 but not pg1, even though the API contract
-        # dictates you must call new_group for both.
-        self.assertEqual(len(pgs), 2)
-        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)
-
-
-def with_comms(func=None):
-    if func is None:
-        return partial(
-            with_comms,
-        )
-
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        self.init_comms()
-        func(self, *args, **kwargs)
-        self.destroy_comms()
-
-    return wrapper
-
-
-class CollectiveHooks:
-    @property
-    def world_size(self) -> int:
-        return 4
-
-    def _collective_hooks(self):
-        # It's OK to access these directly since a single background thread pokes at them.
-        starts = []
-        ends = []
-        cv = threading.Condition()
-
-        def coll_start(status):
-            starts.append(status)
-            print(f"col_start {len(starts)} rank{self.rank}")
-
-        def coll_end(status):
-            ends.append(status)
-            print(f"col_end {len(ends)} rank{self.rank}")
-            if len(ends) == 2:
-                with cv:
-                    cv.notify()
-
-        dhooks.register_collective_start_hook(coll_start)
-        dhooks.register_collective_end_hook(coll_end)
-
-        tensor = torch.ones([2, 3]).to(self.device) * self.rank
-        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
-
-        dist.all_gather(tensor_list, tensor)
-
-        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
-        dist.all_reduce(tensor2)
-
-        with cv:
-            cv.wait(1)
-
-        default_pg_name = dist.group.WORLD.group_name
-        self.assertEqual(2, len(starts))
-        self.assertEqual(2, len(ends))
-
-        def check_op(idx, coll_name):
-            self.assertEqual(default_pg_name, starts[idx].pg_name)
-            self.assertEqual(self.backend_name, starts[idx].backend)
-            self.assertGreaterEqual(starts[idx].sequence_number, 0)
-            self.assertGreaterEqual(starts[idx].timestamp, 0)
-            self.assertEqual(coll_name, starts[idx].operation)
-
-            self.assertEqual(default_pg_name, ends[idx].pg_name)
-            self.assertEqual(self.backend_name, ends[idx].backend)
-
-            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
-            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
-            self.assertEqual(coll_name, ends[idx].operation)
-
-        check_op(0, "ALLGATHER")
-        check_op(1, "ALLREDUCE")
-
-
-class GlooHooks(MultiProcessTestCase, CollectiveHooks):
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def init_comms(self):
-        dist.init_process_group(
-            backend="gloo",
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-
-    def destroy_comms(self):
-        dist.destroy_process_group()
-
-    @property
-    def backend_name(self):
-        return "gloo"
-
-    @property
-    def device(self):
-        return "cpu"
-
-    @with_comms
-    def test_collective_hooks(self):
-        self._collective_hooks()
-
-
-class NcclHooks(MultiProcessTestCase, CollectiveHooks):
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def init_comms(self):
-        dist.init_process_group(
-            backend="nccl",
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-
-    def destroy_comms(self):
-        dist.destroy_process_group()
-
-    @property
-    def backend_name(self):
-        return "nccl"
-
-    @property
-    def device(self):
-        return f"cuda:{self.rank}"
-
-    @skip_if_lt_x_gpu(4)
-    @with_comms
-    def test_collective_hooks(self):
-        self._collective_hooks()
-
-
-class SingleRankTests(TestCase):
-    def setUp(self) -> None:
-        super().setUp()
-        self.rank = 0
-        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
-        dist.init_process_group(
-            backend="gloo",
-            rank=0,
-            world_size=1,
-            store=dist.FileStore(self.file_name, 1),
-        )
-
-    def tearDown(self) -> None:
-        dist.destroy_process_group()
-
-    def test_queue_overflow(self) -> None:
-        cv_done_colls = threading.Condition()
-        cv_done_cb = threading.Condition()
-        colls_done = False
-        starts = []
-        status_with_dropped = None
-
-        def coll_start(status: dhooks.CollectiveStatus):
-            starts.append(status)
-            with cv_done_colls:
-                while not colls_done:
-                    cv_done_colls.wait()
-            if status.drop_count > 0:
-                nonlocal status_with_dropped
-                status_with_dropped = status
-                with cv_done_cb:
-                    cv_done_cb.notify()
-
-        dhooks.register_collective_start_hook(coll_start)
-
-        # native limit is 512
-        for i in range(600):
-            dist.all_reduce(torch.ones([2, 3]))
-        colls_done = True
-        with cv_done_colls:
-            cv_done_colls.notify()
-
-        with cv_done_cb:
-            cv_done_cb.wait(10)
-
-        self.assertTrue(status_with_dropped is not None)
-        self.assertTrue(status_with_dropped.drop_count > 0)
-
-
-if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
-
-    run_tests()
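
For orientation, this is roughly how the deleted tests above drove the removed torch.distributed._hooks API in a single-rank setup. A minimal sketch mirroring SingleRankTests; it only runs on a pre-revert build, since torch.distributed._hooks no longer exists after this commit:

    import tempfile
    import time

    import torch
    import torch.distributed as dist
    import torch.distributed._hooks as dhooks  # removed by this commit

    def log_start(status):
        # status is a CollectiveStatus dataclass (see torch/distributed/_hooks.py below)
        print(f"start: {status.operation} seq={status.sequence_number} pg={status.pg_name}")

    dhooks.register_collective_start_hook(log_start)

    store_file = tempfile.NamedTemporaryFile(delete=False).name
    dist.init_process_group(
        backend="gloo", rank=0, world_size=1,
        store=dist.FileStore(store_file, 1),
    )
    dist.all_reduce(torch.ones(2, 3))
    time.sleep(1)  # hooks fire on a background thread; give it a moment
    dist.destroy_process_group()
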
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 7672421..c0c5aed 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -11,10 +11,6 @@
 _DEFAULT_NO_TIMEOUT: timedelta
 _DEFAULT_PG_TIMEOUT: timedelta
 
-class EventKind(Enum):
-    START = ...
-    END = ...
-
 class BuiltinCommHookType(Enum):
     ALLREDUCE = ...
     FP16_COMPRESS = ...
@@ -24,8 +20,6 @@
     reducer: Reducer,
     comm_hook_type: BuiltinCommHookType,
 ): ...
-def _dequeue_c10d_event() -> Dict[str, object]: ...
-def _enable_event_collection(pipe_fs: int) -> None: ...
 
 class GradBucket:
     def index(self) -> int: ...
diff --git a/torch/csrc/distributed/c10d/Backend.cpp b/torch/csrc/distributed/c10d/Backend.cpp
index aa6db14..9382c9c 100644
--- a/torch/csrc/distributed/c10d/Backend.cpp
+++ b/torch/csrc/distributed/c10d/Backend.cpp
@@ -1,26 +1,9 @@
 #include <c10/util/Logging.h>
 #include <fmt/format.h>
 #include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
-#include <torch/csrc/distributed/c10d/logging.h>
 
 namespace c10d {
 
-namespace {
-void commonEventinit(
-    details::EventInfo& evt,
-    const Backend& backend,
-    const Work& work) {
-  evt.timestamp =
-      std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
-  evt.pg_name = backend.getGroupName();
-  evt.backend = backend.getBackendName();
-  evt.sequence_number = work.getSequencenumber();
-  evt.operation = c10d::opTypeToString(work.retrieveOpType());
-  evt.drop_count = 0;
-}
-} // namespace
-
 Backend::Backend(int rank, int size)
     : rank_(rank), size_(size), dist_debug_level_(debug_level()) {
   C10_LOG_API_USAGE_ONCE("c10d.backend");
@@ -32,21 +15,4 @@
   C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
 }
 
-void Backend::emitCollectiveStart(const Work& work) {
-  details::EventInfo evt;
-  commonEventinit(evt, *this, work);
-
-  evt.event_kind = ::c10d::EventKind::CollectiveStart;
-  details::enqueue_c10d_event(std::move(evt));
-}
-
-void Backend::emitCollectiveEnd(const Work& work) {
-  details::EventInfo evt;
-  commonEventinit(evt, *this, work);
-
-  evt.event_kind = ::c10d::EventKind::CollectiveEnd;
-  evt.duration_ms = work.getDuration();
-  details::enqueue_c10d_event(std::move(evt));
-}
-
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp
index 60415db..58d06ba 100644
--- a/torch/csrc/distributed/c10d/Backend.hpp
+++ b/torch/csrc/distributed/c10d/Backend.hpp
@@ -366,8 +366,6 @@
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
   void init();
-  void emitCollectiveStart(const Work& work);
-  void emitCollectiveEnd(const Work& work);
 
   const int rank_;
   const int size_;
diff --git a/torch/csrc/distributed/c10d/Hooks.cpp b/torch/csrc/distributed/c10d/Hooks.cpp
deleted file mode 100644
index 485b896..0000000
--- a/torch/csrc/distributed/c10d/Hooks.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-#include <atomic>
-
-#include <deque>
-#include <memory>
-#include <mutex>
-
-#ifndef _WIN32
-#include <unistd.h>
-#else
-#include <io.h>
-#endif
-
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
-namespace c10d {
-
-namespace {
-
-std::atomic<bool> event_queue_enabled = false;
-int sync_pipe;
-std::mutex event_queue_lock;
-std::deque<details::EventInfo> event_queue;
-
-} // namespace
-
-void enable_event_collection(int pipe) {
-  sync_pipe = pipe;
-  event_queue_enabled.store(true);
-}
-
-namespace details {
-
-// we start dropping events after this
-const size_t MAX_QUEUE_SIZE = 512;
-
-bool dequeue_c10d_event(EventInfo& evt) {
-  std::unique_lock<std::mutex> lock(event_queue_lock);
-  if (event_queue.size() == 0) {
-    return false;
-  }
-  evt = event_queue.front();
-  event_queue.pop_front();
-  return true;
-}
-
-void enqueue_c10d_event(EventInfo&& evt) {
-  if (!event_queue_enabled.load())
-    return;
-
-  std::unique_lock<std::mutex> lock(event_queue_lock);
-  if (event_queue.size() >= MAX_QUEUE_SIZE) {
-    event_queue.back().drop_count++;
-  } else {
-    event_queue.push_back(std::move(evt));
-    char m = 'x';
-    write(sync_pipe, &m, 1);
-  }
-}
-
-} // namespace details
-} // namespace c10d
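
The deleted Hooks.cpp implements a bounded producer queue: once MAX_QUEUE_SIZE (512) is reached, new events are not enqueued; instead the loss is recorded on the newest retained event's drop_count, so consumers learn how many events were dropped without the producer ever blocking a collective. A Python sketch of that accounting (an illustration of the pattern, not the C++ code itself):

    from collections import deque
    from dataclasses import dataclass
    from threading import Lock
    from typing import Optional

    MAX_QUEUE_SIZE = 512  # matches the native limit above

    @dataclass
    class EventInfo:
        operation: str
        drop_count: int = 0  # number of events dropped after this one

    _queue: deque = deque()
    _lock = Lock()

    def enqueue(evt: EventInfo) -> None:
        with _lock:
            if len(_queue) >= MAX_QUEUE_SIZE:
                # Record the loss on the last retained event instead of blocking.
                _queue[-1].drop_count += 1
            else:
                _queue.append(evt)

    def dequeue() -> Optional[EventInfo]:
        with _lock:
            return _queue.popleft() if _queue else None
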
diff --git a/torch/csrc/distributed/c10d/Hooks.hpp b/torch/csrc/distributed/c10d/Hooks.hpp
deleted file mode 100644
index ad15940..0000000
--- a/torch/csrc/distributed/c10d/Hooks.hpp
+++ /dev/null
@@ -1,30 +0,0 @@
-#pragma once
-
-#include <c10/util/Optional.h>
-#include <string>
-
-namespace c10d {
-
-enum class EventKind { CollectiveStart, CollectiveEnd };
-
-TORCH_API void enable_event_collection(int sync_pipe);
-
-namespace details {
-
-struct TORCH_API EventInfo {
-  EventKind event_kind;
-  std::string pg_name;
-  std::string backend;
-  int64_t sequence_number;
-  std::string operation;
-  int64_t timestamp;
-  c10::optional<float> duration_ms;
-  int64_t drop_count;
-};
-
-// TODO do we want to expose something else here?
-TORCH_API bool dequeue_c10d_event(EventInfo& evt);
-TORCH_API void enqueue_c10d_event(EventInfo&& evt);
-
-} // namespace details
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp
index 4256110..ffec42a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp
@@ -82,14 +82,10 @@
       return "RECVANYSOURCE";
     case OpType::BARRIER:
       return "BARRIER";
-    case OpType::_REDUCE_SCATTER_BASE:
-      return "_REDUCE_SCATTER_BASE";
-    case OpType::COALESCED:
-      return "COALESCED";
-    case OpType::_ALLREDUCE_SPARSE:
-      return "_ALLREDUCE_SPARSE";
     case OpType::UNKNOWN:
       return "UNKNOWN";
+    case OpType::_REDUCE_SCATTER_BASE:
+      return "_REDUCE_SCATTER_BASE";
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
   }
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index 60ea6e3..13b76a6 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -855,10 +855,6 @@
 }
 
 void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
-  emitCollectiveStart(*work.get());
-  work->getFuture()->addCallback(
-      [=](auto& f) { this->emitCollectiveEnd(*work.get()); });
-
   std::unique_lock<std::mutex> lock(workMutex_);
   workQueue_.push_back(std::move(work));
   lock.unlock();
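
The three deleted lines emitted the start event at enqueue time and attached a callback to the work's future so the end event fired on completion. A sketch of that start/attach-end pattern using torch.futures.Future as a stand-in for the work's future (everything except the Future API is illustrative):

    import torch

    def emit(kind: str, tag: str) -> None:
        # Hypothetical sink standing in for emitCollectiveStart/emitCollectiveEnd.
        print(f"{kind}: {tag}")

    def enqueue_with_hooks(fut: torch.futures.Future, tag: str) -> None:
        emit("collective_start", tag)
        # The callback runs once the future completes, emitting the end event.
        fut.add_done_callback(lambda f: emit("collective_end", tag))

    fut = torch.futures.Future()
    enqueue_with_hooks(fut, "ALLREDUCE#1")
    fut.set_result(None)  # completion fires the end-event callback
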
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 54681a4..ca8dd08 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -942,34 +942,27 @@
   }
 }
 
-void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
-  if (terminateProcessGroup_.load() || work.startTraceUpdated_)
+void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
+  if (work.startTraceUpdated_)
     return;
+
+  if (terminateProcessGroup_.load() || storeError_)
+    return;
+
   work.startTraceUpdated_ = true;
-
-  emitCollectiveStart(work);
-
-  if (!emitDesyncInfo || storeError_)
-    return;
-
   storeError_ = !c10d::traceUpdate(
       store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
 }
 
-void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
-  if (terminateProcessGroup_.load())
+void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
+  if (terminateProcessGroup_.load() || storeError_)
     return;
 
   // In case the start of the work hasn't been logged
   if (!work.startTraceUpdated_) {
-    logWorkStart(work, emitDesyncInfo);
+    logWorkStart(work);
   }
 
-  emitCollectiveEnd(work);
-
-  if (!emitDesyncInfo || storeError_)
-    return;
-
   storeError_ = !c10d::traceUpdate(
       store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
 }
@@ -1021,11 +1014,13 @@
       }
 
       // Work status logging for desync debug
-      if (work.isStarted()) {
-        logWorkStart(work, desyncDebug_);
-      }
-      if (work.isCompleted()) {
-        logWorkEnd(work, desyncDebug_);
+      if (desyncDebug_) {
+        if (work.isStarted()) {
+          logWorkStart(work);
+        }
+        if (work.isCompleted()) {
+          logWorkEnd(work);
+        }
       }
 
       // Clean up completed work
@@ -1082,7 +1077,7 @@
             timeStarted, // timeStarted
             std::chrono::system_clock::now(), // timeFinished
             std::chrono::duration<float, std::milli>(
-                work.getDuration().value()) // activeDuration
+                work.getDuration()) // activeDuration
             ));
 
         lock.lock();
@@ -1585,19 +1580,19 @@
   return future_;
 }
 
-c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
-  if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
-    return c10::optional<float>();
-  }
+float ProcessGroupNCCL::WorkNCCL::getDuration() const {
+  TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
   TORCH_CHECK(
       ncclStartEvents_->size() == 1,
       "getDuration only works for single device per ProcessGroup.");
   TORCH_CHECK(
       ncclEndEvents_->size() == 1,
       "getDuration only works for single device per ProcessGroup.");
+  TORCH_CHECK(
+      (*ncclEndEvents_)[0].query(),
+      "getDuration can only be called after work is succeeded.")
   return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
 }
-
 uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
   return seq_;
 }
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index d912d20..782b55f 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -167,7 +167,7 @@
     // Get a Future object that will be marked as completed internally.
     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
 
-    c10::optional<float> getDuration() const override;
+    float getDuration() const override;
 
     uint64_t getSequencenumber() const override;
 
@@ -615,10 +615,10 @@
   void runHookLoop();
 
   // Desync debug helper
-  void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);
+  void logWorkStart(WorkNCCL& work);
 
   // Desync debug helper
-  void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);
+  void logWorkEnd(WorkNCCL& work);
 
  protected:
   static const int64_t kWatchdogThreadSleepMillis;
diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp
index fb5b0f0..66c35b1 100644
--- a/torch/csrc/distributed/c10d/Work.cpp
+++ b/torch/csrc/distributed/c10d/Work.cpp
@@ -127,8 +127,8 @@
   }
 }
 
-c10::optional<float> Work::getDuration() const {
-  return c10::optional<float>();
+float Work::getDuration() const {
+  TORCH_CHECK(false, "This Backend doesn't support getDuration.");
 }
 
 uint64_t Work::getSequencenumber() const {
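
This hunk flips getDuration's contract: the hooks version returned c10::optional<float> and signalled "not measurable" with an empty optional, while the restored version returns a plain float and TORCH_CHECKs its preconditions (timing enabled, single device, work finished). A toy Python contrast of the two caller contracts (the classes are illustrative stand-ins, not the real bindings):

    from typing import Optional

    class OptionalDurationWork:  # the reverted-away contract
        def get_duration(self) -> Optional[float]:
            return None  # empty when timing is off or the end event hasn't fired

    class ThrowingDurationWork:  # the contract restored by this commit
        def get_duration(self) -> float:
            # Mirrors TORCH_CHECK: fail loudly when preconditions are not met.
            raise RuntimeError("getDuration only works if timing was enabled")

    d = OptionalDurationWork().get_duration()
    if d is not None:
        print(f"took {d} ms")

    try:
        print(ThrowingDurationWork().get_duration())
    except RuntimeError as e:
        print(f"duration unavailable: {e}")
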
diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp
index f1c695d..50c6ae0 100644
--- a/torch/csrc/distributed/c10d/Work.hpp
+++ b/torch/csrc/distributed/c10d/Work.hpp
@@ -107,7 +107,7 @@
   // work. Only NCCL backend is currently supported.
   virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
 
-  virtual c10::optional<float> getDuration() const;
+  virtual float getDuration() const;
 
   virtual uint64_t getSequencenumber() const;
 
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 680ccb4..f1ae7de 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -32,7 +32,6 @@
 
 #include <fmt/format.h>
 #include <pybind11/chrono.h>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 
 #include <torch/csrc/distributed/c10d/comm.hpp>
@@ -292,26 +291,6 @@
   reducer.register_builtin_comm_hook(comm_hook_type);
 }
 
-py::object c10d_dequeue_python_event() {
-  ::c10d::details::EventInfo evt;
-  if (!::c10d::details::dequeue_c10d_event(evt)) {
-    return py::none();
-  }
-
-  py::dict data;
-  data["event_kind"] = (int)evt.event_kind;
-
-  data["pg_name"] = evt.pg_name;
-  data["backend"] = evt.backend;
-  data["sequence_number"] = evt.sequence_number;
-  data["operation"] = evt.operation;
-  data["timestamp"] = evt.timestamp;
-  data["duration"] = evt.duration_ms.value_or(-1);
-  data["drop_count"] = evt.drop_count;
-
-  return std::move(data);
-}
-
 // Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
 // https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
 // struct from enum, sacrificing some of the Python built-in function supports
@@ -661,11 +640,6 @@
           &::c10d::Logger::set_static_graph,
           py::call_guard<py::gil_scoped_release>());
 
-  py::enum_<::c10d::EventKind>(module, "EventKind", R"(
-An enum for collective hooks event types.)")
-      .value("START", ::c10d::EventKind::CollectiveStart)
-      .value("END", ::c10d::EventKind::CollectiveEnd);
-
   py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
       An enum whose values correspond to different debug levels of the
       torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@@ -689,16 +663,7 @@
           "set_debug_level_from_env",
           ::c10d::setDebugLevelFromEnvironment,
           R"(Sets the debug level of the torch.distributed package from the
-          ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
-      .def(
-          "_enable_event_collection",
-          &::c10d::enable_event_collection,
-          "(Enables events collection).",
-          py::call_guard<py::gil_scoped_release>())
-      .def(
-          "_dequeue_c10d_event",
-          &c10d_dequeue_python_event,
-          "(Blocks until a c10d event is available and return it as a python dictionary).");
+          ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
 
   // TODO(crcrpar): Hardening `ReduceOp`.
   //    While keeping most op types as enum value,
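
For reference, the removed _dequeue_c10d_event binding produced a plain dict shaped as below (reconstructed from c10d_dequeue_python_event above). Note that the C++ side flattens a missing duration to -1 via value_or(-1); the field values here are made up for illustration:

    event = {
        "event_kind": 0,           # int(EventKind.START) or int(EventKind.END)
        "pg_name": "0",            # Backend::getGroupName()
        "backend": "gloo",         # Backend::getBackendName()
        "sequence_number": 7,      # Work::getSequencenumber()
        "operation": "ALLREDUCE",  # opTypeToString(...)
        "timestamp": 1696900000,   # seconds, system_clock::to_time_t
        "duration": -1,            # milliseconds; -1 means "not measured"
        "drop_count": 0,           # events dropped after this one
    }

    kind = event.pop("event_kind")  # the Python loop dispatched on this key
    print(kind, event["operation"])
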
diff --git a/torch/distributed/_hooks.py b/torch/distributed/_hooks.py
deleted file mode 100644
index 76cede5..0000000
--- a/torch/distributed/_hooks.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import logging
-import os
-import threading
-import time
-from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional
-
-import torch.distributed as dist
-import torch.distributed.distributed_c10d as c10d
-
-from torch._C._distributed_c10d import (
-    _dequeue_c10d_event,
-    _enable_event_collection,
-    EventKind,
-)
-
-__all__ = [
-    "CollectiveStatus",
-    "COLLECTIVE_HOOK_TYPE",
-    "PG_HOOK_TYPE",
-    "register_collective_start_hook",
-    "register_collective_end_hook",
-    "register_process_group_hook",
-]
-
-
-@dataclass
-class CollectiveStatus:
-    r"""Status of a collective operation.
-
-    A non-zero drop count indicates that events were dropped on the producer side
-    because hooks did not consume them fast enough.
-    """
-    pg_name: str = "unknown"  # Matches the name passed to the process group registration hook
-    backend: str = "unknown"  # Name of the backend used
-    sequence_number: int = -1  # Sequence number of the collective within its process group
-    operation: str = "unknown"  # Name of the collective operation
-    timestamp: int = 0  # Timestamp of the earliest time we noticed this event
-    duration: Optional[float] = None  # Execution time in milliseconds, if measured
-    drop_count: int = 0  # Number of events dropped after this one
-
-
-COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
-PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]
-
-# This controls the number of internal failures we'll tolerate before giving up
-_MAX_INTERNAL_FAILURES = 10
-
-logger = logging.getLogger(__name__)
-_cb_thread: Optional[threading.Thread] = None
-_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
-_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
-_pp_r = -1
-_pp_w = -1
-
-
-def _c10d_pg_hooks_loops():
-    internal_failures = 0
-    while True:
-        # We don't care about the byte read; the read itself is the notification.
-        _ = os.read(_pp_r, 1)
-        evt: Dict[str, object] = _dequeue_c10d_event()
-        try:
-            event_kind = evt.pop("event_kind", None)
-            if event_kind is None:
-                logger.warning(
-                    "c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
-                    evt,
-                )
-                internal_failures += 1
-                if internal_failures >= _MAX_INTERNAL_FAILURES:
-                    logger.warning(
-                        "too many internal c10d failures processing callback loop. stopping"
-                    )
-                    return
-                time.sleep(1)
-                continue
-
-            if event_kind == int(EventKind.START):  # type: ignore[call-overload]
-                cb_list = _start_callbacks
-            elif event_kind == int(EventKind.END):  # type: ignore[call-overload]
-                cb_list = _end_callbacks
-            else:
-                logger.warning(
-                    "c10d event %s with invalid 'event_kind' with value %d",
-                    evt,
-                    event_kind,
-                )
-                internal_failures += 1
-                if internal_failures >= _MAX_INTERNAL_FAILURES:
-                    logger.warning(
-                        "too many internal c10d failures processing callback loop. stopping"
-                    )
-                    return
-                time.sleep(1)
-                continue
-
-            status = CollectiveStatus(**evt)  # type: ignore[arg-type]
-            for cb in cb_list:
-                try:
-                    cb(status)
-                except Exception as e:
-                    logger.info(
-                        "c10d event callback %s with event %s threw exception %s",
-                        cb,
-                        status,
-                        e,
-                    )
-        except Exception as e:
-            # We have to keep processing; otherwise the queue will grow without bound.
-            logger.warning(
-                "c10d callback thread when processing event %s raised exception %s.",
-                evt,
-                e,
-            )
-            internal_failures += 1
-            if internal_failures >= _MAX_INTERNAL_FAILURES:
-                logger.warning(
-                    "too many internal c10d failures processing callback loop. stopping"
-                )
-                return
-
-            # Sleep for a second to avoid hogging the GIL in case of a persistent failure
-            time.sleep(1)
-
-
-def _lazy_init():
-    global _cb_thread
-    if _cb_thread is not None:
-        return
-    global _pp_r
-    global _pp_w
-    _pp_r, _pp_w = os.pipe()
-    _enable_event_collection(_pp_w)
-    c10d._enable_collectives_timing()
-    _cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
-    _cb_thread.start()
-    logger.info("c10d::hooks thread enabled")
-
-
-def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
-    r"""Register a hook that is called every time a collective starts.
-
-    The hook is invoked on a background thread.
-    Exceptions raised by the callback are ignored and non-fatal.
-    """
-    _start_callbacks.append(hook)
-    _lazy_init()
-
-
-def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
-    r"""Register a hook that is called every time a collective finishes.
-
-    The hook is invoked on a background thread.
-    Exceptions raised by the callback are ignored and non-fatal.
-    """
-    _end_callbacks.append(hook)
-    _lazy_init()
-
-
-def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
-    r"""Register a hook that is called every time a process group is created on this rank.
-
-    This hook is only invoked if the current rank is part of the PG being created.
-    The pg_name is unique across the whole cluster and should be treated as an opaque identifier subject to change.
-
-    The hook is invoked on a background thread.
-    Exceptions raised by the callback are ignored and non-fatal.
-    """
-    c10d._register_creation_hook(hook)
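
The module above pairs the C++ event queue with a pipe used purely as a wakeup signal: enqueue_c10d_event writes one byte per event, and _c10d_pg_hooks_loops blocks in os.read before dequeuing. A pipe is used rather than a Python condition variable because the producer lives in C++. A standalone sketch of that handshake (all names here are illustrative):

    import os
    import threading
    from collections import deque

    queue: deque = deque()
    lock = threading.Lock()
    r, w = os.pipe()

    def produce(item: str) -> None:
        with lock:
            queue.append(item)
        os.write(w, b"x")  # one wakeup byte per enqueued event

    def consumer_loop() -> None:
        while True:
            os.read(r, 1)  # block until a producer signals
            with lock:
                item = queue.popleft()
            if item == "stop":
                return
            print("consumed", item)

    t = threading.Thread(target=consumer_loop, daemon=True)
    t.start()
    produce("ALLREDUCE end")
    produce("stop")
    t.join()
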
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 27e960a..232fb42 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -421,19 +421,6 @@
 _tags_to_pg: Dict[str, List[ProcessGroup]] = {}
 _pg_to_tag: Dict[ProcessGroup, str] = {}
 
-class _HookState:
-    def __init__(self):
-        self.creation_hooks = []
-
-    def register_creation_hook(self, hook) -> None:
-        self.creation_hooks.append(hook)
-
-    def fire_creation_hook(self, pg, name) -> None:
-        for hook in self.creation_hooks:
-            try:
-                hook(pg, name)
-            except Exception as e:
-                logger.info("hook %s failed with %s", hook, e)
 
 class _World:
     """
@@ -447,8 +434,6 @@
         self._default_pg = None
         self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
         self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
-        self._hook_state = _HookState()
-        self.enable_collectives_timing = False
 
     @property
     def default_pg(self):
@@ -558,9 +543,6 @@
             )
         return config_info
 
-    @property
-    def pg_hook_state(self) -> _HookState:
-        return self._hook_state
 
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -1382,9 +1364,6 @@
     pg._set_group_name(group_name)
 
     _world.pg_backend_config[pg] = str(backend_config)
-    if _world.enable_collectives_timing:
-        pg._enable_collectives_timing()
-
     # "" is the default tag for user PGs
     if pg_tag in [None, ""]:
         pg_tag = f"ptd:{group_name}"
@@ -1394,8 +1373,6 @@
 
     _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
     _world.pg_to_tag[pg] = pg_tag
-    _world.pg_hook_state.fire_creation_hook(pg, group_name)
-
     return pg, prefix_store
 
 def destroy_process_group(group: Optional[ProcessGroup] = None):
@@ -4338,11 +4315,3 @@
     reduce_scatter_tensor,
     send,
 ]
-
-def _register_creation_hook(hook):
-    _world.pg_hook_state.register_creation_hook(hook)
-
-def _enable_collectives_timing():
-    _world.enable_collectives_timing = True
-    for pg in _world.pg_map:
-        pg._enable_collectives_timing()
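
The deleted _HookState.fire_creation_hook deliberately swallows hook exceptions so a buggy observer can never break process-group creation. A minimal sketch of that never-raise dispatch pattern (illustrative, not the removed class itself):

    import logging

    logger = logging.getLogger(__name__)
    creation_hooks = []

    def fire_creation_hooks(pg_name: str) -> None:
        for hook in creation_hooks:
            try:
                hook(pg_name)
            except Exception as e:
                # Log and continue: observability must not break PG creation.
                logger.info("hook %s failed with %s", hook, e)

    creation_hooks.append(lambda name: print("created", name))
    creation_hooks.append(lambda name: 1 / 0)  # a buggy hook must not stop the others
    fire_creation_hooks("world")
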
diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py
index 11b0993..a1044e6 100644
--- a/torch/testing/_internal/distributed/multi_threaded_pg.py
+++ b/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -401,11 +401,6 @@
 class ThreadLocalWorld:
     _world = threading.local()
 
-    def __init__(self):
-        self.enable_collectives_timing = False
-        self._hook_state = dist.distributed_c10d._HookState()
-
-
     def _get_world(self) -> WorldData:
         if not hasattr(ThreadLocalWorld._world, "world"):
             ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@@ -459,9 +454,6 @@
     def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
         return self._get_world().pg_default_device
 
-    @property
-    def pg_hook_state(self) -> dist.distributed_c10d._HookState:
-        return self._hook_state
 
 _old_pg_world = None
 _ctx_manager = None