[Torch][Timer] Adding debug info logging interface for expired timers (#123883)

Summary:
Adding function to log additional debug information before killing the expired watchdog timers.

Additional information like stack trace can be added in the debug function using worker process IDs from expired timers.

Test Plan: buck test mode/opt caffe2/test/distributed/elastic/timer:file_based_timer_test

Differential Revision: D56044153

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123883
Approved by: https://github.com/kurman
diff --git a/docs/source/elastic/timer.rst b/docs/source/elastic/timer.rst
index f64597c..3f124a0 100644
--- a/docs/source/elastic/timer.rst
+++ b/docs/source/elastic/timer.rst
@@ -50,3 +50,11 @@
 
 .. autoclass:: TimerClient
    :members:
+
+
+Debug info logging
+-------------------
+
+.. automodule:: torch.distributed.elastic.timer.debug_info_logging
+
+.. autofunction:: torch.distributed.elastic.timer.debug_info_logging.log_debug_info_for_expired_timers
diff --git a/test/distributed/elastic/timer/file_based_local_timer_test.py b/test/distributed/elastic/timer/file_based_local_timer_test.py
index 6c7a92c..4616ae0 100644
--- a/test/distributed/elastic/timer/file_based_local_timer_test.py
+++ b/test/distributed/elastic/timer/file_based_local_timer_test.py
@@ -38,7 +38,9 @@
             super().setUp()
             self.max_interval = 0.01
             self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4())
-            self.server = timer.FileTimerServer(self.file_path, self.max_interval)
+            self.server = timer.FileTimerServer(
+                self.file_path, "test", self.max_interval
+            )
             self.server.start()
 
         def tearDown(self):
@@ -204,7 +206,9 @@
             super().setUp()
             self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4())
             self.max_interval = 0.01
-            self.server = timer.FileTimerServer(self.file_path, self.max_interval)
+            self.server = timer.FileTimerServer(
+                self.file_path, "test", self.max_interval
+            )
 
         def tearDown(self):
             super().tearDown()
@@ -260,7 +264,8 @@
             )
 
         @mock.patch("os.kill")
-        def test_expired_timers(self, mock_os_kill):
+        @mock.patch("torch.distributed.elastic.timer.log_debug_info_for_expired_timers")
+        def test_expired_timers(self, mock_debug_info, mock_os_kill):
             """
             tests that a single expired timer on a process should terminate
             the process and clean up all pending timers that was owned by the process
@@ -275,6 +280,7 @@
             self.server.run_once()  # Allows the server to process all requests
             self.assertEqual(0, len(self.server._timers))
             mock_os_kill.assert_called_once_with(test_pid, signal.SIGKILL)
+            mock_debug_info.assert_called()
 
         @mock.patch("os.kill")
         def test_send_request_release(self, mock_os_kill):
diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py
index 60469c0..e308d53 100644
--- a/torch/distributed/elastic/agent/server/local_elastic_agent.py
+++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py
@@ -165,8 +165,14 @@
             if watchdog_file_path is None:
                 watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
             logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
+            if not envs:
+                logger.warning("Empty envs variables, using empty run_id for FileTimerServer")
+                run_id = ''
+            else:
+                run_id = envs[0]["TORCHELASTIC_RUN_ID"]
             self._worker_watchdog = timer.FileTimerServer(
                 file_path=watchdog_file_path,
+                run_id=run_id,
                 max_interval=0.1,
                 daemon=True,
                 log_event=self._log_watchdog_event)
diff --git a/torch/distributed/elastic/timer/debug_info_logging.py b/torch/distributed/elastic/timer/debug_info_logging.py
new file mode 100644
index 0000000..8c8645d
--- /dev/null
+++ b/torch/distributed/elastic/timer/debug_info_logging.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, List
+
+from torch.distributed.elastic.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def log_debug_info_for_expired_timers(
+    run_id: str,
+    expired_timers: Dict[int, List[str]],
+):
+    logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)
diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py
index 2842c72..f2ded8b 100644
--- a/torch/distributed/elastic/timer/file_based_local_timer.py
+++ b/torch/distributed/elastic/timer/file_based_local_timer.py
@@ -15,6 +15,7 @@
 from typing import Callable, Dict, List, Optional, Set, Tuple
 
 from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
+from torch.distributed.elastic.timer.debug_info_logging import log_debug_info_for_expired_timers
 from torch.distributed.elastic.utils.logging import get_logger
 
 __all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
@@ -156,11 +157,13 @@
     def __init__(
         self,
         file_path: str,
+        run_id: str,
         max_interval: float = 10,
         daemon: bool = True,
         log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None
     ) -> None:
         self._file_path = file_path
+        self._run_id = run_id
         self._max_interval = max_interval
         self._daemon = daemon
         self._timers: Dict[Tuple[int, str], FileTimerRequest] = {}
@@ -247,7 +250,14 @@
         self.register_timers(timer_requests)
         now = time.time()
         reaped_worker_pids = set()
-        for worker_pid, expired_timers in self.get_expired_timers(now).items():
+
+        all_expired_timers = self.get_expired_timers(now)
+        log_debug_info_for_expired_timers(
+            self._run_id,
+            {pid: self._get_scopes(expired_timers) for pid, expired_timers in all_expired_timers.items()},
+        )
+
+        for worker_pid, expired_timers in all_expired_timers.items():
             logger.info("Reaping worker_pid=[%s]. Expired timers: %s", worker_pid, self._get_scopes(expired_timers))
             reaped_worker_pids.add(worker_pid)
             # In case we have multiple expired timers, we find the first timer