Revert D27442325: [torch/elastic] Revise the rendezvous handler registry logic.
Test Plan: revert-hammer
Differential Revision:
D27442325 (https://github.com/pytorch/pytorch/commit/df299dbd7d8241669aaac7ce07ed0c034f219b4f)
Original commit changeset: 8519a2caacbe
fbshipit-source-id: f10452567f592c23ae79ca31556a2a77546726b1
diff --git a/test/distributed/elastic/rendezvous/api_test.py b/test/distributed/elastic/rendezvous/api_test.py
index 71be33d..be20f55 100644
--- a/test/distributed/elastic/rendezvous/api_test.py
+++ b/test/distributed/elastic/rendezvous/api_test.py
@@ -7,14 +7,67 @@
from typing import Any, Dict, SupportsInt, Tuple, cast
from unittest import TestCase
-from torch.distributed import Store
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
- RendezvousHandlerRegistry,
- RendezvousParameters
+ RendezvousHandlerFactory,
+ RendezvousParameters,
)
+def create_mock_rdzv_handler(ignored: RendezvousParameters) -> RendezvousHandler:
+ return MockRendezvousHandler()
+
+
+class MockRendezvousHandler(RendezvousHandler):
+ def next_rendezvous(
+ self,
+ # pyre-ignore[11]: Annotation `Store` is not defined as a type.
+ ) -> Tuple["torch.distributed.Store", int, int]: # noqa F821
+ raise NotImplementedError()
+
+ def get_backend(self) -> str:
+ return "mock"
+
+ def is_closed(self) -> bool:
+ return False
+
+ def set_closed(self):
+ pass
+
+ def num_nodes_waiting(self) -> int:
+ return -1
+
+ def get_run_id(self) -> str:
+ return ""
+
+
+class RendezvousHandlerFactoryTest(TestCase):
+ def test_double_registration(self):
+ factory = RendezvousHandlerFactory()
+ factory.register("mock", create_mock_rdzv_handler)
+ with self.assertRaises(ValueError):
+ factory.register("mock", create_mock_rdzv_handler)
+
+ def test_no_factory_method_found(self):
+ factory = RendezvousHandlerFactory()
+ rdzv_params = RendezvousParameters(
+ backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
+ )
+
+ with self.assertRaises(ValueError):
+ factory.create_handler(rdzv_params)
+
+ def test_create_handler(self):
+ rdzv_params = RendezvousParameters(
+ backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
+ )
+
+ factory = RendezvousHandlerFactory()
+ factory.register("mock", create_mock_rdzv_handler)
+ mock_rdzv_handler = factory.create_handler(rdzv_params)
+ self.assertTrue(isinstance(mock_rdzv_handler, MockRendezvousHandler))
+
+
class RendezvousParametersTest(TestCase):
def setUp(self) -> None:
self._backend = "dummy_backend"
@@ -183,91 +236,3 @@
r"valid integer value.$",
):
params.get_as_int("dummy_param")
-
-
-class _DummyRendezvousHandler(RendezvousHandler):
- def __init__(self, params: RendezvousParameters) -> None:
- self.params = params
-
- def get_backend(self) -> str:
- return "dummy_backend"
-
- def next_rendezvous(self) -> Tuple[Store, int, int]:
- raise NotImplementedError()
-
- def is_closed(self) -> bool:
- return False
-
- def set_closed(self) -> None:
- pass
-
- def num_nodes_waiting(self) -> int:
- return -1
-
- def get_run_id(self) -> str:
- return ""
-
- def shutdown(self) -> bool:
- return False
-
-
-class RendezvousHandlerRegistryTest(TestCase):
- def setUp(self) -> None:
- self._params = RendezvousParameters(
- backend="dummy_backend",
- endpoint="dummy_endpoint",
- run_id="dummy_run_id",
- min_nodes=1,
- max_nodes=1,
- )
-
- self._registry = RendezvousHandlerRegistry()
-
- @staticmethod
- def _create_handler(params: RendezvousParameters) -> RendezvousHandler:
- return _DummyRendezvousHandler(params)
-
- def test_register_registers_once_if_called_twice_with_same_creator(self) -> None:
- self._registry.register("dummy_backend", self._create_handler)
- self._registry.register("dummy_backend", self._create_handler)
-
- def test_register_raises_error_if_called_twice_with_different_creators(self) -> None:
- self._registry.register("dummy_backend", self._create_handler)
-
- other_create_handler = lambda p: _DummyRendezvousHandler(p) # noqa: E731
-
- with self.assertRaisesRegex(
- ValueError,
- r"^The rendezvous backend 'dummy_backend' cannot be registered with "
- rf"'{other_create_handler}' as it is already registered with '{self._create_handler}'.$",
- ):
- self._registry.register("dummy_backend", other_create_handler)
-
- def test_create_handler_returns_handler(self) -> None:
- self._registry.register("dummy_backend", self._create_handler)
-
- handler = self._registry.create_handler(self._params)
-
- self.assertIsInstance(handler, _DummyRendezvousHandler)
-
- self.assertIs(handler.params, self._params)
-
- def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None:
- with self.assertRaisesRegex(
- ValueError,
- r"^The rendezvous backend 'dummy_backend' is not registered. Did you forget to call "
- r"`register`\?$",
- ):
- self._registry.create_handler(self._params)
-
- def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None:
- self._registry.register("dummy_backend_2", self._create_handler)
-
- with self.assertRaisesRegex(
- RuntimeError,
- r"^The rendezvous backend 'dummy_backend' does not match the requested backend "
- r"'dummy_backend_2'.$",
- ):
- self._params.backend = "dummy_backend_2"
-
- self._registry.create_handler(self._params)
diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py
index d3ec2b4..6c54935 100644
--- a/torch/distributed/elastic/rendezvous/__init__.py
+++ b/torch/distributed/elastic/rendezvous/__init__.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
@@ -99,22 +101,13 @@
to participate in next rendezvous.
"""
-from .api import *
-from .registry import _register_default_handlers
-
-
-_register_default_handlers()
-
-
-__all__ = [
- "RendezvousClosedError",
- "RendezvousConnectionError",
- "RendezvousError",
- "RendezvousHandler",
- "RendezvousHandlerCreator",
- "RendezvousHandlerRegistry",
- "RendezvousParameters",
- "RendezvousStateError",
- "RendezvousTimeoutError",
- "rendezvous_handler_registry",
-]
+from .api import ( # noqa: F401
+ RendezvousClosedError,
+ RendezvousConnectionError,
+ RendezvousError,
+ RendezvousHandler,
+ RendezvousHandlerFactory,
+ RendezvousParameters,
+ RendezvousStateError,
+ RendezvousTimeoutError,
+)
diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py
index c1a3bd3..363fe1b 100644
--- a/torch/distributed/elastic/rendezvous/api.py
+++ b/torch/distributed/elastic/rendezvous/api.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import abc
-from typing import Any, Callable, Dict, Optional, Tuple, final
+from typing import Any, Callable, Dict, Optional, Tuple
from torch.distributed import Store
@@ -219,49 +219,42 @@
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
-@final
-class RendezvousHandlerRegistry:
- """Represents a registry of `RendezvousHandler` backends."""
+class RendezvousHandlerFactory:
+ """
+ Creates ``RendezvousHandler`` instances for supported rendezvous backends.
+ """
- _registry: Dict[str, RendezvousHandlerCreator]
+ def __init__(self):
+ self._registry: Dict[str, RendezvousHandlerCreator] = {}
- def __init__(self) -> None:
- self._registry = {}
-
- def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
- """Registers a new rendezvous backend.
-
- Args:
- backend:
- The name of the backend.
- creater:
- The callback to invoke to construct the `RendezvousHandler`.
+ def register(self, backend: str, creator: RendezvousHandlerCreator):
"""
- if not backend:
- raise ValueError("The rendezvous backend name must be a non-empty string.")
-
- current_creator: Optional[RendezvousHandlerCreator]
+ Registers a new rendezvous backend.
+ """
try:
current_creator = self._registry[backend]
except KeyError:
- current_creator = None
+ current_creator = None # type: ignore[assignment]
- if current_creator is not None and current_creator != creator:
+ if current_creator is not None:
raise ValueError(
- f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
- f"is already registered with '{current_creator}'."
+ f"The rendezvous backend '{backend}' cannot be registered with"
+ f" '{creator.__module__}.{creator.__name__}' as it is already"
+ f" registered with '{current_creator.__module__}.{current_creator.__name__}'."
)
self._registry[backend] = creator
def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
- """Creates a new `RendezvousHandler`."""
+ """
+ Creates a new ``RendezvousHandler`` instance for the specified backend.
+ """
try:
creator = self._registry[params.backend]
except KeyError:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
- f"to call `{self.register.__name__}`?"
+ f"to call {self.register.__name__}?"
)
handler = creator(params)
@@ -269,13 +262,8 @@
# Do some sanity check.
if handler.get_backend() != params.backend:
raise RuntimeError(
- f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
- f"backend '{params.backend}'."
+ f"The rendezvous handler backend '{handler.get_backend()}' does not match the "
+ f"requested backend '{params.backend}'."
)
return handler
-
-
-# The default global registry instance used by launcher scripts to instantiate
-# rendezvous handlers.
-rendezvous_handler_registry = RendezvousHandlerRegistry()
diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py
index 3a8f931..0b81159 100644
--- a/torch/distributed/elastic/rendezvous/registry.py
+++ b/torch/distributed/elastic/rendezvous/registry.py
@@ -4,20 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from .api import RendezvousHandler, RendezvousParameters
-from .api import rendezvous_handler_registry as handler_registry
+from . import etcd_rendezvous
+from .api import (
+ RendezvousHandler,
+ RendezvousHandlerFactory,
+ RendezvousParameters,
+)
+
+_factory = RendezvousHandlerFactory()
+_factory.register("etcd", etcd_rendezvous.create_rdzv_handler)
-def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
- from . import etcd_rendezvous
-
- return etcd_rendezvous.create_rdzv_handler(params)
-
-
-def _register_default_handlers() -> None:
- handler_registry.register("etcd", _create_etcd_handler)
-
-
-# The legacy function kept for backwards compatibility.
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
- return handler_registry.create_handler(params)
+ return _factory.create_handler(params)