blob: 60db327d5b95c6097a10adba2026cd6044ce7500 [file] [log] [blame]
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest import mock
import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger
from torch.testing._internal.common_utils import run_tests, TestCase
class StoreUtilTest(TestCase):
def test_get_all_rank_0(self):
store = mock.MagicMock()
world_size = 3
store_util.get_all(store, 0, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store0.FIN"], actual_set_call_args)
actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
("test/store0.FIN",),
("test/store1.FIN",),
("test/store2.FIN",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
def test_get_all_rank_n(self):
store = mock.MagicMock()
world_size = 3
store_util.get_all(store, 1, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store1.FIN"], actual_set_call_args)
actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
def test_synchronize(self):
store_mock = mock.MagicMock()
data = b"data0"
store_util.synchronize(store_mock, data, 0, 3, key_prefix="torchelastic/test")
actual_set_call_args = store_mock.set.call_args_list
# omit empty kwargs
actual_set_call_args = [call_args[0] for call_args in actual_set_call_args]
expected_set_call_args = [
("torchelastic/test0", b"data0"),
("torchelastic/test0.FIN", b"FIN"),
]
self.assertListEqual(expected_set_call_args, actual_set_call_args)
expected_get_call_args = [
("torchelastic/test0",),
("torchelastic/test1",),
("torchelastic/test2",),
("torchelastic/test0.FIN",),
("torchelastic/test1.FIN",),
("torchelastic/test2.FIN",),
]
actual_get_call_args = store_mock.get.call_args_list
actual_get_call_args = [call_args[0] for call_args in actual_get_call_args]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
class UtilTest(TestCase):
def test_get_logger_different(self):
logger1 = get_logger("name1")
logger2 = get_logger("name2")
self.assertNotEqual(logger1.name, logger2.name)
def test_get_logger(self):
logger1 = get_logger()
self.assertEqual(__name__, logger1.name)
def test_get_logger_none(self):
logger1 = get_logger(None)
self.assertEqual(__name__, logger1.name)
def test_get_logger_custom_name(self):
logger1 = get_logger("test.module")
self.assertEqual("test.module", logger1.name)
if __name__ == "__main__":
run_tests()