| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from typing import Any, Dict, SupportsInt, Tuple, cast |
| from unittest import TestCase |
| |
| from torch.distributed import Store |
| from torch.distributed.elastic.rendezvous import ( |
| RendezvousHandler, |
| RendezvousHandlerRegistry, |
| RendezvousParameters, |
| ) |
| |
| |
| class RendezvousParametersTest(TestCase): |
| def setUp(self) -> None: |
| self._backend = "dummy_backend" |
| self._endpoint = "dummy_endpoint" |
| self._run_id = "dummy_run_id" |
| self._min_nodes = 3 |
| self._max_nodes = 6 |
| self._kwargs: Dict[str, Any] = {} |
| |
| def _create_params(self) -> RendezvousParameters: |
| return RendezvousParameters( |
| backend=self._backend, |
| endpoint=self._endpoint, |
| run_id=self._run_id, |
| min_nodes=self._min_nodes, |
| max_nodes=self._max_nodes, |
| **self._kwargs, |
| ) |
| |
| def test_init_initializes_params(self) -> None: |
| self._kwargs["dummy_param"] = "x" |
| |
| params = self._create_params() |
| |
| self.assertEqual(params.backend, self._backend) |
| self.assertEqual(params.endpoint, self._endpoint) |
| self.assertEqual(params.run_id, self._run_id) |
| self.assertEqual(params.min_nodes, self._min_nodes) |
| self.assertEqual(params.max_nodes, self._max_nodes) |
| |
| self.assertEqual(params.get("dummy_param"), "x") |
| |
| def test_init_initializes_params_if_min_nodes_equals_to_1(self) -> None: |
| self._min_nodes = 1 |
| |
| params = self._create_params() |
| |
| self.assertEqual(params.min_nodes, self._min_nodes) |
| self.assertEqual(params.max_nodes, self._max_nodes) |
| |
| def test_init_initializes_params_if_min_and_max_nodes_are_equal(self) -> None: |
| self._max_nodes = 3 |
| |
| params = self._create_params() |
| |
| self.assertEqual(params.min_nodes, self._min_nodes) |
| self.assertEqual(params.max_nodes, self._max_nodes) |
| |
| def test_init_raises_error_if_backend_is_none_or_empty(self) -> None: |
| for backend in [None, ""]: |
| with self.subTest(backend=backend): |
| self._backend = backend # type: ignore[assignment] |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| r"^The rendezvous backend name must be a non-empty string.$", |
| ): |
| self._create_params() |
| |
| def test_init_raises_error_if_min_nodes_is_less_than_1(self) -> None: |
| for min_nodes in [0, -1, -5]: |
| with self.subTest(min_nodes=min_nodes): |
| self._min_nodes = min_nodes |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| rf"^The minimum number of rendezvous nodes \({min_nodes}\) must be greater " |
| rf"than zero.$", |
| ): |
| self._create_params() |
| |
| def test_init_raises_error_if_max_nodes_is_less_than_min_nodes(self) -> None: |
| for max_nodes in [2, 1, -2]: |
| with self.subTest(max_nodes=max_nodes): |
| self._max_nodes = max_nodes |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| rf"^The maximum number of rendezvous nodes \({max_nodes}\) must be greater " |
| "than or equal to the minimum number of rendezvous nodes " |
| rf"\({self._min_nodes}\).$", |
| ): |
| self._create_params() |
| |
| def test_get_returns_none_if_key_does_not_exist(self) -> None: |
| params = self._create_params() |
| |
| self.assertIsNone(params.get("dummy_param")) |
| |
| def test_get_returns_default_if_key_does_not_exist(self) -> None: |
| params = self._create_params() |
| |
| self.assertEqual(params.get("dummy_param", default="x"), "x") |
| |
| def test_get_as_bool_returns_none_if_key_does_not_exist(self) -> None: |
| params = self._create_params() |
| |
| self.assertIsNone(params.get_as_bool("dummy_param")) |
| |
| def test_get_as_bool_returns_default_if_key_does_not_exist(self) -> None: |
| params = self._create_params() |
| |
| self.assertTrue(params.get_as_bool("dummy_param", default=True)) |
| |
| def test_get_as_bool_returns_true_if_value_represents_true(self) -> None: |
| for value in ["1", "True", "tRue", "T", "t", "yEs", "Y", 1, True]: |
| with self.subTest(value=value): |
| self._kwargs["dummy_param"] = value |
| |
| params = self._create_params() |
| |
| self.assertTrue(params.get_as_bool("dummy_param")) |
| |
| def test_get_as_bool_returns_false_if_value_represents_false(self) -> None: |
| for value in ["0", "False", "faLse", "F", "f", "nO", "N", 0, False]: |
| with self.subTest(value=value): |
| self._kwargs["dummy_param"] = value |
| |
| params = self._create_params() |
| |
| self.assertFalse(params.get_as_bool("dummy_param")) |
| |
| def test_get_as_bool_raises_error_if_value_is_invalid(self) -> None: |
| for value in ["01", "Flse", "Ture", "g", "4", "_", "truefalse", 2, -1]: |
| with self.subTest(value=value): |
| self._kwargs["dummy_param"] = value |
| |
| params = self._create_params() |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| r"^The rendezvous configuration option 'dummy_param' does not represent a " |
| r"valid boolean value.$", |
| ): |
| params.get_as_bool("dummy_param") |
| |
| def test_get_as_int_returns_none_if_key_does_not_exist(self) -> None: |
| params = self._create_params() |
| |
| self.assertIsNone(params.get_as_int("dummy_param")) |
| |
| def test_get_as_int_returns_default_if_key_does_not_exist(self) -> None: |
| params = self._create_params() |
| |
| self.assertEqual(params.get_as_int("dummy_param", default=5), 5) |
| |
| def test_get_as_int_returns_integer_if_value_represents_integer(self) -> None: |
| for value in ["0", "-10", "5", " 4", "4 ", " 4 ", 0, -4, 3]: |
| with self.subTest(value=value): |
| self._kwargs["dummy_param"] = value |
| |
| params = self._create_params() |
| |
| self.assertEqual(params.get_as_int("dummy_param"), int(cast(SupportsInt, value))) |
| |
| def test_get_as_int_raises_error_if_value_is_invalid(self) -> None: |
| for value in ["a", "0a", "3b", "abc"]: |
| with self.subTest(value=value): |
| self._kwargs["dummy_param"] = value |
| |
| params = self._create_params() |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| r"^The rendezvous configuration option 'dummy_param' does not represent a " |
| r"valid integer value.$", |
| ): |
| params.get_as_int("dummy_param") |
| |
| |
| class _DummyRendezvousHandler(RendezvousHandler): |
| def __init__(self, params: RendezvousParameters) -> None: |
| self.params = params |
| |
| def get_backend(self) -> str: |
| return "dummy_backend" |
| |
| def next_rendezvous(self) -> Tuple[Store, int, int]: |
| raise NotImplementedError() |
| |
| def is_closed(self) -> bool: |
| return False |
| |
| def set_closed(self) -> None: |
| pass |
| |
| def num_nodes_waiting(self) -> int: |
| return 0 |
| |
| def get_run_id(self) -> str: |
| return "" |
| |
| def shutdown(self) -> bool: |
| return False |
| |
| |
| class RendezvousHandlerRegistryTest(TestCase): |
| def setUp(self) -> None: |
| self._params = RendezvousParameters( |
| backend="dummy_backend", |
| endpoint="dummy_endpoint", |
| run_id="dummy_run_id", |
| min_nodes=1, |
| max_nodes=1, |
| ) |
| |
| self._registry = RendezvousHandlerRegistry() |
| |
| @staticmethod |
| def _create_handler(params: RendezvousParameters) -> RendezvousHandler: |
| return _DummyRendezvousHandler(params) |
| |
| def test_register_registers_once_if_called_twice_with_same_creator(self) -> None: |
| self._registry.register("dummy_backend", self._create_handler) |
| self._registry.register("dummy_backend", self._create_handler) |
| |
| def test_register_raises_error_if_called_twice_with_different_creators(self) -> None: |
| self._registry.register("dummy_backend", self._create_handler) |
| |
| other_create_handler = lambda p: _DummyRendezvousHandler(p) # noqa: E731 |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| r"^The rendezvous backend 'dummy_backend' cannot be registered with " |
| rf"'{other_create_handler}' as it is already registered with '{self._create_handler}'.$", |
| ): |
| self._registry.register("dummy_backend", other_create_handler) |
| |
| def test_create_handler_returns_handler(self) -> None: |
| self._registry.register("dummy_backend", self._create_handler) |
| |
| handler = self._registry.create_handler(self._params) |
| |
| self.assertIsInstance(handler, _DummyRendezvousHandler) |
| |
| self.assertIs(handler.params, self._params) |
| |
| def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None: |
| with self.assertRaisesRegex( |
| ValueError, |
| r"^The rendezvous backend 'dummy_backend' is not registered. Did you forget to call " |
| r"`register`\?$", |
| ): |
| self._registry.create_handler(self._params) |
| |
| def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None: |
| self._registry.register("dummy_backend_2", self._create_handler) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"^The rendezvous backend 'dummy_backend' does not match the requested backend " |
| r"'dummy_backend_2'.$", |
| ): |
| self._params.backend = "dummy_backend_2" |
| |
| self._registry.create_handler(self._params) |