| # Owner(s): ["oncall: r2p"] |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| import unittest |
| import unittest.mock as mock |
| |
| from torch.distributed.elastic.timer import TimerServer |
| from torch.distributed.elastic.timer.api import RequestQueue, TimerRequest |
| |
| |
| class MockRequestQueue(RequestQueue): |
| def size(self): |
| return 2 |
| |
| def get(self, size, timeout): |
| return [TimerRequest(1, "test_1", 0), TimerRequest(2, "test_2", 0)] |
| |
| |
| class MockTimerServer(TimerServer): |
| """ |
| Mock implementation of TimerServer for testing purposes. |
| This mock has the following behavior: |
| |
| 1. reaping worker 1 throws |
| 2. reaping worker 2 succeeds |
| 3. reaping worker 3 fails (caught exception) |
| |
| For each workers 1 - 3 returns 2 expired timers |
| """ |
| |
| def __init__(self, request_queue, max_interval): |
| super().__init__(request_queue, max_interval) |
| |
| def register_timers(self, timer_requests): |
| pass |
| |
| def clear_timers(self, worker_ids): |
| pass |
| |
| def get_expired_timers(self, deadline): |
| return { |
| i: [TimerRequest(i, f"test_{i}_0", 0), TimerRequest(i, f"test_{i}_1", 0)] |
| for i in range(1, 4) |
| } |
| |
| def _reap_worker(self, worker_id): |
| if worker_id == 1: |
| raise RuntimeError("test error") |
| elif worker_id == 2: |
| return True |
| elif worker_id == 3: |
| return False |
| |
| |
| class TimerApiTest(unittest.TestCase): |
| @mock.patch.object(MockTimerServer, "register_timers") |
| @mock.patch.object(MockTimerServer, "clear_timers") |
| def test_run_watchdog(self, mock_clear_timers, mock_register_timers): |
| """ |
| tests that when a ``_reap_worker()`` method throws an exception |
| for a particular worker_id, the timers for successfully reaped workers |
| are cleared properly |
| """ |
| max_interval = 1 |
| request_queue = mock.Mock(wraps=MockRequestQueue()) |
| timer_server = MockTimerServer(request_queue, max_interval) |
| timer_server._run_watchdog() |
| |
| request_queue.size.assert_called_once() |
| request_queue.get.assert_called_with(request_queue.size(), max_interval) |
| mock_register_timers.assert_called_with(request_queue.get(2, 1)) |
| mock_clear_timers.assert_called_with({1, 2}) |