| #!/usr/bin/env python3 |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import json |
| import logging |
| import unittest |
| from dataclasses import asdict |
| from unittest.mock import patch |
| |
| from torch.distributed.elastic.events import ( |
| Event, |
| EventSource, |
| NodeState, |
| RdzvEvent, |
| _get_or_create_logger, |
| construct_and_record_rdzv_event, |
| ) |
| from torch.testing._internal.common_utils import run_tests |
| |
| |
class EventLibTest(unittest.TestCase):
    """Tests for ``Event`` construction, JSON round-tripping and logger lookup."""

    def assert_event(self, actual_event, expected_event):
        """Assert that both events agree on name, source, timestamp and metadata."""
        for field_name in ("name", "source", "timestamp"):
            self.assertEqual(
                getattr(actual_event, field_name),
                getattr(expected_event, field_name),
            )
        self.assertDictEqual(actual_event.metadata, expected_event.metadata)

    @patch("torch.distributed.elastic.events.get_logging_handler")
    def test_get_or_create_logger(self, logging_handler_mock):
        """With the handler factory patched to return a NullHandler, the
        logger obtained for a destination holds exactly that one handler.
        """
        logging_handler_mock.return_value = logging.NullHandler()

        events_logger = _get_or_create_logger("test_destination")

        self.assertIsNotNone(events_logger)
        attached_handlers = events_logger.handlers
        self.assertEqual(1, len(attached_handlers))
        self.assertIsInstance(attached_handlers[0], logging.NullHandler)

    def test_event_created(self):
        """Constructor arguments are exposed unchanged as attributes."""
        constructor_args = {
            "name": "test_event",
            "source": EventSource.AGENT,
            "metadata": {"key1": "value1", "key2": 2},
        }
        created = Event(**constructor_args)

        self.assertEqual("test_event", created.name)
        self.assertEqual(EventSource.AGENT, created.source)
        self.assertDictEqual({"key1": "value1", "key2": 2}, created.metadata)

    def test_event_deser(self):
        """An event survives a serialize/deserialize round trip."""
        original = Event(
            name="test_event",
            source=EventSource.AGENT,
            metadata={"key1": "value1", "key2": 2, "key3": 1.0},
        )
        round_tripped = Event.deserialize(original.serialize())
        self.assert_event(original, round_tripped)
| |
class RdzvEventLibTest(unittest.TestCase):
    """Tests for ``RdzvEvent`` and ``construct_and_record_rdzv_event``."""

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    @patch("torch.distributed.elastic.events.get_logging_handler")
    def test_construct_and_record_rdzv_event(self, get_mock, record_mock):
        """When the patched handler factory yields a StreamHandler, the
        event is recorded exactly once.
        """
        get_mock.return_value = logging.StreamHandler()

        construct_and_record_rdzv_event(
            run_id="test_run_id",
            message="test_message",
            node_state=NodeState.RUNNING,
        )

        record_mock.assert_called_once()

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    @patch("torch.distributed.elastic.events.get_logging_handler")
    def test_construct_and_record_rdzv_event_does_not_run_if_invalid_dest(
        self, get_mock, record_mock
    ):
        """When the patched handler factory yields a NullHandler, nothing
        is recorded.
        """
        get_mock.return_value = logging.NullHandler()

        construct_and_record_rdzv_event(
            run_id="test_run_id",
            message="test_message",
            node_state=NodeState.RUNNING,
        )

        record_mock.assert_not_called()

    def assert_rdzv_event(self, actual_event: RdzvEvent, expected_event: RdzvEvent):
        """Assert that two rendezvous events match on every compared field."""
        compared_fields = (
            "name",
            "run_id",
            "message",
            "hostname",
            "pid",
            "node_state",
            "master_endpoint",
            "rank",
            "local_id",
            "error_trace",
        )
        for field_name in compared_fields:
            self.assertEqual(
                getattr(actual_event, field_name),
                getattr(expected_event, field_name),
            )

    def get_test_rdzv_event(self) -> RdzvEvent:
        """Build the fully populated ``RdzvEvent`` fixture shared by the tests."""
        fixture_fields = {
            "name": "test_name",
            "run_id": "test_run_id",
            "message": "test_message",
            "hostname": "test_hostname",
            "pid": 1,
            "node_state": NodeState.RUNNING,
            "master_endpoint": "test_master_endpoint",
            "rank": 3,
            "local_id": 4,
            "error_trace": "test_error_trace",
        }
        return RdzvEvent(**fixture_fields)

    def test_rdzv_event_created(self):
        """Every constructor argument is readable back from the fixture."""
        event = self.get_test_rdzv_event()
        expected_values = (
            ("name", "test_name"),
            ("run_id", "test_run_id"),
            ("message", "test_message"),
            ("hostname", "test_hostname"),
            ("pid", 1),
            ("node_state", NodeState.RUNNING),
            ("master_endpoint", "test_master_endpoint"),
            ("rank", 3),
            ("local_id", 4),
            ("error_trace", "test_error_trace"),
        )
        for field_name, expected in expected_values:
            self.assertEqual(getattr(event, field_name), expected)

    def test_rdzv_event_deserialize(self):
        """Deserializing either the JSON form or the event itself yields a
        matching event.
        """
        original = self.get_test_rdzv_event()

        from_json = RdzvEvent.deserialize(original.serialize())
        self.assert_rdzv_event(original, from_json)

        # deserialize is also handed the event object instead of its JSON text.
        from_instance = RdzvEvent.deserialize(original)
        self.assert_rdzv_event(original, from_instance)

    def test_rdzv_event_str(self):
        """``str(event)`` equals the JSON dump of the dataclass fields."""
        event = self.get_test_rdzv_event()
        expected_text = json.dumps(asdict(event))
        self.assertEqual(str(event), expected_text)
| |
| |
# Hand control to the PyTorch test runner only when this file is executed
# directly; importing the module must not start the suite.
if __name__ == "__main__":
    run_tests()