[Watchdog Timer] Clear timer for already terminated process (#122324)

Summary:
Handle the case where a worker process terminates without releasing its timer request. In this scenario the stale timer stays registered, so the server tries to reap the process when the timer expires.

Fix: when clearing timers, also remove the timers of any process that no longer exists.

Test Plan: unit tests

Differential Revision: D55099773

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122324
Approved by: https://github.com/d4l3k
diff --git a/test/distributed/elastic/timer/file_based_local_timer_test.py b/test/distributed/elastic/timer/file_based_local_timer_test.py
index 785cc97..198c57f 100644
--- a/test/distributed/elastic/timer/file_based_local_timer_test.py
+++ b/test/distributed/elastic/timer/file_based_local_timer_test.py
@@ -120,6 +120,44 @@
             self.server.run_once()  # Allows the server to process all requests
             self.assertEqual(2 * num_clients * num_requests_per_client, self.server._request_count)
 
+        @mock.patch("torch.distributed.elastic.timer.FileTimerServer._reap_worker")
+        def test_exit_before_release(self, mock_reap):
+            def func1(file_path):
+                client = timer.FileTimerClient(file_path)
+                timer.configure(client)
+                expire = time.time() + 2
+                client.acquire("test_scope", expire)
+                time.sleep(1)
+
+            p = mp.Process(target=func1, args=(self.file_path,))
+            p.start()
+            p.join()
+
+            time.sleep(2)
+            self.server.run_once()  # Allows the server to process all requests
+            mock_reap.assert_not_called()
+            self.assertEqual(0, len(self.server._timers))
+
+        @mock.patch("torch.distributed.elastic.timer.FileTimerServer._reap_worker")
+        @mock.patch("torch.distributed.elastic.timer.FileTimerServer.is_process_running")
+        def test_exit_before_release_reap(self, mock_pid_exists, mock_reap):
+            def func1(file_path):
+                client = timer.FileTimerClient(file_path)
+                timer.configure(client)
+                expire = time.time() + 2
+                client.acquire("test_scope", expire)
+                time.sleep(1)
+
+            mock_pid_exists.return_value = True
+            p = mp.Process(target=func1, args=(self.file_path,))
+            p.start()
+            p.join()
+
+            time.sleep(2)
+            self.server.run_once()  # Allows the server to process all requests
+            mock_reap.assert_called()
+            self.assertEqual(0, len(self.server._timers))
+
         @staticmethod
         def _run(file_path, timeout, duration):
             client = timer.FileTimerClient(file_path)
@@ -240,12 +278,14 @@
             self.assertEqual(0, len(self.server._timers))
             mock_os_kill.assert_not_called()
 
+        @mock.patch("torch.distributed.elastic.timer.FileTimerServer.is_process_running")
         @mock.patch("os.kill")
-        def test_valid_timers(self, mock_os_kill):
+        def test_valid_timers(self, mock_os_kill, mock_pid_exists):
             """
             tests that valid timers are processed correctly and the process is left alone
             """
             self.server.start()
+            mock_pid_exists.return_value = True
 
             client = timer.FileTimerClient(self.file_path)
             client._send_request(self._valid_timer(pid=-3, scope="test1"))
diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py
index 26ebce3..f2c5d44 100644
--- a/torch/distributed/elastic/timer/file_based_local_timer.py
+++ b/torch/distributed/elastic/timer/file_based_local_timer.py
@@ -6,7 +6,6 @@
 
 import io
 import json
-import logging
 import os
 import select
 import signal
@@ -16,10 +15,11 @@
 from typing import Callable, Dict, List, Optional, Set, Tuple
 
 from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
+from torch.distributed.elastic.utils.logging import get_logger
 
 __all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
 
-log = logging.getLogger(__name__)
+log = get_logger(__name__)
 
 class FileTimerRequest(TimerRequest):
     """
@@ -212,6 +212,18 @@
         if os.path.exists(self._file_path):
             os.remove(self._file_path)
 
+    @staticmethod
+    def is_process_running(pid: int):
+        """
+        Return True if signal 0 can be sent to ``pid``; any OSError is treated as "not running".
+        """
+        try:
+            # Check if the process exists and we can send signals to it
+            os.kill(pid, 0)
+            return True
+        except OSError:
+            return False
+
     def _watchdog_loop(self) -> None:
         # Open the pipe in blocking mode blocks the server thread.
         # This is fine for the following reasons:
@@ -309,7 +321,7 @@
 
     def clear_timers(self, worker_pids: Set[int]) -> None:
         for (pid, scope_id) in list(self._timers.keys()):
-            if pid in worker_pids:
+            if pid in worker_pids or not FileTimerServer.is_process_running(pid):
                 del self._timers[(pid, scope_id)]
 
     def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]: