#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import unittest
from unittest.mock import patch
from torch.distributed.elastic.events import _get_or_create_logger, Event, EventSource
from torch.testing._internal.common_utils import run_tests
class EventLibTest(unittest.TestCase):
    """Unit tests for the event helpers in ``torch.distributed.elastic.events``."""

    def assert_event(self, actual_event, expected_event):
        """Assert that two events agree on name, source, timestamp and metadata."""
        # Scalar fields are checked in a fixed order, so the first mismatch
        # reported is always the same one for a given pair of events.
        for field in ("name", "source", "timestamp"):
            self.assertEqual(
                getattr(actual_event, field),
                getattr(expected_event, field),
            )
        self.assertDictEqual(actual_event.metadata, expected_event.metadata)

    @patch("torch.distributed.elastic.events.get_logging_handler")
    def test_get_or_create_logger(self, logging_handler_mock):
        """The returned logger carries exactly the handler from ``get_logging_handler``."""
        # The handler factory is patched to hand back a NullHandler.
        logging_handler_mock.return_value = logging.NullHandler()

        events_logger = _get_or_create_logger("test_destination")
        self.assertIsNotNone(events_logger)

        attached = events_logger.handlers
        self.assertEqual(1, len(attached))
        self.assertIsInstance(attached[0], logging.NullHandler)

    def test_event_created(self):
        """Constructor arguments are exposed unchanged as attributes of the event."""
        expected_metadata = {"key1": "value1", "key2": 2}
        created = Event(
            name="test_event",
            source=EventSource.AGENT,
            # Pass a copy so the expectation stays independent of the event.
            metadata=dict(expected_metadata),
        )

        self.assertEqual("test_event", created.name)
        self.assertEqual(EventSource.AGENT, created.source)
        self.assertDictEqual(expected_metadata, created.metadata)

    def test_event_deser(self):
        """An event survives a ``serialize()`` / ``Event.deserialize()`` round trip."""
        original = Event(
            name="test_event",
            source=EventSource.AGENT,
            metadata={"key1": "value1", "key2": 2, "key3": 1.0},
        )

        restored = Event.deserialize(original.serialize())
        self.assert_event(original, restored)
# Run the tests through run_tests() only when this file is executed as a
# script; importing the module must not trigger a test run.
if __name__ == "__main__":
    run_tests()