[Watchdog Timer] Clear timer for already terminated process (#122324)
Summary:
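Handle the case where a worker process terminates without releasing its timer request. Previously the stale timer stayed registered, so the server tried to reap the already-exited process when the timer expired.
Timers that belong to a process that no longer exists are now removed when timers are cleared.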
Test Plan: unit tests
Differential Revision: D55099773
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122324
Approved by: https://github.com/d4l3k
diff --git a/test/distributed/elastic/timer/file_based_local_timer_test.py b/test/distributed/elastic/timer/file_based_local_timer_test.py
index 785cc97..198c57f 100644
--- a/test/distributed/elastic/timer/file_based_local_timer_test.py
+++ b/test/distributed/elastic/timer/file_based_local_timer_test.py
@@ -120,6 +120,44 @@
self.server.run_once() # Allows the server to process all requests
self.assertEqual(2 * num_clients * num_requests_per_client, self.server._request_count)
+ @mock.patch("torch.distributed.elastic.timer.FileTimerServer._reap_worker")
+ def test_exit_before_release(self, mock_reap):
+ """
+ A child process acquires a timer and exits without releasing it; once the
+ expiry has passed, a single server pass is expected to drop the timer
+ without calling ``_reap_worker``.
+ """
+ def func1(file_path):
+ client = timer.FileTimerClient(file_path)
+ timer.configure(client)
+ expire = time.time() + 2
+ client.acquire("test_scope", expire)
+ time.sleep(1)  # return (and exit) before the 2-second expiry, without a release
+
+ p = mp.Process(target=func1, args=(self.file_path,))
+ p.start()
+ p.join()
+
+ time.sleep(2)  # the child already ran ~1s, so the 2-second expiry has passed by now
+ self.server.run_once() # Allows the server to process all requests
+ mock_reap.assert_not_called()
+ self.assertEqual(0, len(self.server._timers))
+
+ @mock.patch("torch.distributed.elastic.timer.FileTimerServer._reap_worker")
+ @mock.patch("torch.distributed.elastic.timer.FileTimerServer.is_process_running")
+ def test_exit_before_release_reap(self, mock_pid_exists, mock_reap):
+ """
+ Same scenario as a child exiting without releasing its timer, but with
+ ``is_process_running`` patched to report True; the expired timer is then
+ expected to trigger ``_reap_worker`` and leave no timers registered.
+ """
+ def func1(file_path):
+ client = timer.FileTimerClient(file_path)
+ timer.configure(client)
+ expire = time.time() + 2
+ client.acquire("test_scope", expire)
+ time.sleep(1)  # return (and exit) before the 2-second expiry, without a release
+
+ # Force the liveness probe to answer True regardless of the child's real state.
+ mock_pid_exists.return_value = True
+ p = mp.Process(target=func1, args=(self.file_path,))
+ p.start()
+ p.join()
+
+ time.sleep(2)  # the child already ran ~1s, so the 2-second expiry has passed by now
+ self.server.run_once() # Allows the server to process all requests
+ mock_reap.assert_called()
+ self.assertEqual(0, len(self.server._timers))
+
@staticmethod
def _run(file_path, timeout, duration):
client = timer.FileTimerClient(file_path)
@@ -240,12 +278,14 @@
self.assertEqual(0, len(self.server._timers))
mock_os_kill.assert_not_called()
+ @mock.patch("torch.distributed.elastic.timer.FileTimerServer.is_process_running")
@mock.patch("os.kill")
- def test_valid_timers(self, mock_os_kill):
+ def test_valid_timers(self, mock_os_kill, mock_pid_exists):
"""
tests that valid timers are processed correctly and the process is left alone
"""
self.server.start()
+ mock_pid_exists.return_value = True
client = timer.FileTimerClient(self.file_path)
client._send_request(self._valid_timer(pid=-3, scope="test1"))
diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py
index 26ebce3..f2c5d44 100644
--- a/torch/distributed/elastic/timer/file_based_local_timer.py
+++ b/torch/distributed/elastic/timer/file_based_local_timer.py
@@ -6,7 +6,6 @@
import io
import json
-import logging
import os
import select
import signal
@@ -16,10 +15,11 @@
from typing import Callable, Dict, List, Optional, Set, Tuple
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
+from torch.distributed.elastic.utils.logging import get_logger
__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
-log = logging.getLogger(__name__)
+log = get_logger(__name__)
class FileTimerRequest(TimerRequest):
"""
@@ -212,6 +212,18 @@
if os.path.exists(self._file_path):
os.remove(self._file_path)
+ @staticmethod
+ def is_process_running(pid: int):
+ """
+ Report whether signal 0 can be sent to ``pid``.
+
+ Returns True when ``os.kill(pid, 0)`` succeeds and False when it raises
+ ``OSError``.
+ """
+ try:
+ # Signal 0 delivers nothing; the call only runs the existence and
+ # permission checks for the target.
+ os.kill(pid, 0)
+ return True
+ except OSError:
+ # NOTE(review): PermissionError is an OSError subclass, so a live process
+ # that we are not permitted to signal is also reported as not running.
+ # Non-positive pids address process groups in os.kill -- confirm callers
+ # only pass real worker pids.
+ return False
+
def _watchdog_loop(self) -> None:
# Open the pipe in blocking mode blocks the server thread.
# This is fine for the following reasons:
@@ -309,7 +321,7 @@
def clear_timers(self, worker_pids: Set[int]) -> None:
for (pid, scope_id) in list(self._timers.keys()):
- if pid in worker_pids:
+ # Also drop timers whose owning pid fails the is_process_running probe,
+ # even when that pid was not listed in worker_pids.
+ if pid in worker_pids or not FileTimerServer.is_process_running(pid):
del self._timers[(pid, scope_id)]
def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]: