[Torch][Timer] Adding debug info logging interface for expired timers (#123883)
Summary:
Adding function to log additional debug information before killing the expired watchdog timers.
Additional information like stack trace can be added in the debug function using worker process IDs from expired timers.
Test Plan: buck test mode/opt caffe2/test/distributed/elastic/timer:file_based_timer_test
Differential Revision: D56044153
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123883
Approved by: https://github.com/kurman
diff --git a/docs/source/elastic/timer.rst b/docs/source/elastic/timer.rst
index f64597c..3f124a0 100644
--- a/docs/source/elastic/timer.rst
+++ b/docs/source/elastic/timer.rst
@@ -50,3 +50,11 @@
.. autoclass:: TimerClient
:members:
+
+
+Debug info logging
+-------------------
+
+.. automodule:: torch.distributed.elastic.timer.debug_info_logging
+
+.. autofunction:: torch.distributed.elastic.timer.debug_info_logging.log_debug_info_for_expired_timers
diff --git a/test/distributed/elastic/timer/file_based_local_timer_test.py b/test/distributed/elastic/timer/file_based_local_timer_test.py
index 6c7a92c..4616ae0 100644
--- a/test/distributed/elastic/timer/file_based_local_timer_test.py
+++ b/test/distributed/elastic/timer/file_based_local_timer_test.py
@@ -38,7 +38,9 @@
super().setUp()
self.max_interval = 0.01
self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4())
- self.server = timer.FileTimerServer(self.file_path, self.max_interval)
+ self.server = timer.FileTimerServer(
+ self.file_path, "test", self.max_interval
+ )
self.server.start()
def tearDown(self):
@@ -204,7 +206,9 @@
super().setUp()
self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4())
self.max_interval = 0.01
- self.server = timer.FileTimerServer(self.file_path, self.max_interval)
+ self.server = timer.FileTimerServer(
+ self.file_path, "test", self.max_interval
+ )
def tearDown(self):
super().tearDown()
@@ -260,7 +264,8 @@
)
@mock.patch("os.kill")
- def test_expired_timers(self, mock_os_kill):
+ @mock.patch("torch.distributed.elastic.timer.log_debug_info_for_expired_timers")
+ def test_expired_timers(self, mock_debug_info, mock_os_kill):
"""
tests that a single expired timer on a process should terminate
the process and clean up all pending timers that was owned by the process
@@ -275,6 +280,7 @@
self.server.run_once() # Allows the server to process all requests
self.assertEqual(0, len(self.server._timers))
mock_os_kill.assert_called_once_with(test_pid, signal.SIGKILL)
+ mock_debug_info.assert_called()
@mock.patch("os.kill")
def test_send_request_release(self, mock_os_kill):
diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py
index 60469c0..e308d53 100644
--- a/torch/distributed/elastic/agent/server/local_elastic_agent.py
+++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py
@@ -165,8 +165,14 @@
if watchdog_file_path is None:
watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
+ if not envs:
+ logger.warning("Empty envs variables, using empty run_id for FileTimerServer")
+ run_id = ''
+ else:
+ run_id = envs[0]["TORCHELASTIC_RUN_ID"]
self._worker_watchdog = timer.FileTimerServer(
file_path=watchdog_file_path,
+ run_id=run_id,
max_interval=0.1,
daemon=True,
log_event=self._log_watchdog_event)
diff --git a/torch/distributed/elastic/timer/debug_info_logging.py b/torch/distributed/elastic/timer/debug_info_logging.py
new file mode 100644
index 0000000..8c8645d
--- /dev/null
+++ b/torch/distributed/elastic/timer/debug_info_logging.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, List
+
+from torch.distributed.elastic.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def log_debug_info_for_expired_timers(
+ run_id: str,
+ expired_timers: Dict[int, List[str]],
+):
+ logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)
diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py
index 2842c72..f2ded8b 100644
--- a/torch/distributed/elastic/timer/file_based_local_timer.py
+++ b/torch/distributed/elastic/timer/file_based_local_timer.py
@@ -15,6 +15,7 @@
from typing import Callable, Dict, List, Optional, Set, Tuple
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
+from torch.distributed.elastic.timer.debug_info_logging import log_debug_info_for_expired_timers
from torch.distributed.elastic.utils.logging import get_logger
__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
@@ -156,11 +157,13 @@
def __init__(
self,
file_path: str,
+ run_id: str,
max_interval: float = 10,
daemon: bool = True,
log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None
) -> None:
self._file_path = file_path
+ self._run_id = run_id
self._max_interval = max_interval
self._daemon = daemon
self._timers: Dict[Tuple[int, str], FileTimerRequest] = {}
@@ -247,7 +250,14 @@
self.register_timers(timer_requests)
now = time.time()
reaped_worker_pids = set()
- for worker_pid, expired_timers in self.get_expired_timers(now).items():
+
+ all_expired_timers = self.get_expired_timers(now)
+ log_debug_info_for_expired_timers(
+ self._run_id,
+ {pid: self._get_scopes(expired_timers) for pid, expired_timers in all_expired_timers.items()},
+ )
+
+ for worker_pid, expired_timers in all_expired_timers.items():
logger.info("Reaping worker_pid=[%s]. Expired timers: %s", worker_pid, self._get_scopes(expired_timers))
reaped_worker_pids.add(worker_pid)
# In case we have multiple expired timers, we find the first timer