Revert D27442325: [torch/elastic] Revise the rendezvous handler registry logic.

Test Plan: revert-hammer

Differential Revision:
D27442325 (https://github.com/pytorch/pytorch/commit/df299dbd7d8241669aaac7ce07ed0c034f219b4f)

Original commit changeset: 8519a2caacbe

fbshipit-source-id: f10452567f592c23ae79ca31556a2a77546726b1
diff --git a/test/distributed/elastic/rendezvous/api_test.py b/test/distributed/elastic/rendezvous/api_test.py
index 71be33d..be20f55 100644
--- a/test/distributed/elastic/rendezvous/api_test.py
+++ b/test/distributed/elastic/rendezvous/api_test.py
@@ -7,14 +7,67 @@
 from typing import Any, Dict, SupportsInt, Tuple, cast
 from unittest import TestCase
 
-from torch.distributed import Store
 from torch.distributed.elastic.rendezvous import (
     RendezvousHandler,
-    RendezvousHandlerRegistry,
-    RendezvousParameters
+    RendezvousHandlerFactory,
+    RendezvousParameters,
 )
 
 
+def create_mock_rdzv_handler(ignored: RendezvousParameters) -> RendezvousHandler:
+    return MockRendezvousHandler()
+
+
+class MockRendezvousHandler(RendezvousHandler):
+    def next_rendezvous(
+        self,
+        # pyre-ignore[11]: Annotation `Store` is not defined as a type.
+    ) -> Tuple["torch.distributed.Store", int, int]:  # noqa F821
+        raise NotImplementedError()
+
+    def get_backend(self) -> str:
+        return "mock"
+
+    def is_closed(self) -> bool:
+        return False
+
+    def set_closed(self):
+        pass
+
+    def num_nodes_waiting(self) -> int:
+        return -1
+
+    def get_run_id(self) -> str:
+        return ""
+
+
+class RendezvousHandlerFactoryTest(TestCase):
+    def test_double_registration(self):
+        factory = RendezvousHandlerFactory()
+        factory.register("mock", create_mock_rdzv_handler)
+        with self.assertRaises(ValueError):
+            factory.register("mock", create_mock_rdzv_handler)
+
+    def test_no_factory_method_found(self):
+        factory = RendezvousHandlerFactory()
+        rdzv_params = RendezvousParameters(
+            backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
+        )
+
+        with self.assertRaises(ValueError):
+            factory.create_handler(rdzv_params)
+
+    def test_create_handler(self):
+        rdzv_params = RendezvousParameters(
+            backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
+        )
+
+        factory = RendezvousHandlerFactory()
+        factory.register("mock", create_mock_rdzv_handler)
+        mock_rdzv_handler = factory.create_handler(rdzv_params)
+        self.assertTrue(isinstance(mock_rdzv_handler, MockRendezvousHandler))
+
+
 class RendezvousParametersTest(TestCase):
     def setUp(self) -> None:
         self._backend = "dummy_backend"
@@ -183,91 +236,3 @@
                     r"valid integer value.$",
                 ):
                     params.get_as_int("dummy_param")
-
-
-class _DummyRendezvousHandler(RendezvousHandler):
-    def __init__(self, params: RendezvousParameters) -> None:
-        self.params = params
-
-    def get_backend(self) -> str:
-        return "dummy_backend"
-
-    def next_rendezvous(self) -> Tuple[Store, int, int]:
-        raise NotImplementedError()
-
-    def is_closed(self) -> bool:
-        return False
-
-    def set_closed(self) -> None:
-        pass
-
-    def num_nodes_waiting(self) -> int:
-        return -1
-
-    def get_run_id(self) -> str:
-        return ""
-
-    def shutdown(self) -> bool:
-        return False
-
-
-class RendezvousHandlerRegistryTest(TestCase):
-    def setUp(self) -> None:
-        self._params = RendezvousParameters(
-            backend="dummy_backend",
-            endpoint="dummy_endpoint",
-            run_id="dummy_run_id",
-            min_nodes=1,
-            max_nodes=1,
-        )
-
-        self._registry = RendezvousHandlerRegistry()
-
-    @staticmethod
-    def _create_handler(params: RendezvousParameters) -> RendezvousHandler:
-        return _DummyRendezvousHandler(params)
-
-    def test_register_registers_once_if_called_twice_with_same_creator(self) -> None:
-        self._registry.register("dummy_backend", self._create_handler)
-        self._registry.register("dummy_backend", self._create_handler)
-
-    def test_register_raises_error_if_called_twice_with_different_creators(self) -> None:
-        self._registry.register("dummy_backend", self._create_handler)
-
-        other_create_handler = lambda p: _DummyRendezvousHandler(p)  # noqa: E731
-
-        with self.assertRaisesRegex(
-            ValueError,
-            r"^The rendezvous backend 'dummy_backend' cannot be registered with "
-            rf"'{other_create_handler}' as it is already registered with '{self._create_handler}'.$",
-        ):
-            self._registry.register("dummy_backend", other_create_handler)
-
-    def test_create_handler_returns_handler(self) -> None:
-        self._registry.register("dummy_backend", self._create_handler)
-
-        handler = self._registry.create_handler(self._params)
-
-        self.assertIsInstance(handler, _DummyRendezvousHandler)
-
-        self.assertIs(handler.params, self._params)
-
-    def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None:
-        with self.assertRaisesRegex(
-            ValueError,
-            r"^The rendezvous backend 'dummy_backend' is not registered. Did you forget to call "
-            r"`register`\?$",
-        ):
-            self._registry.create_handler(self._params)
-
-    def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None:
-        self._registry.register("dummy_backend_2", self._create_handler)
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"^The rendezvous backend 'dummy_backend' does not match the requested backend "
-            r"'dummy_backend_2'.$",
-        ):
-            self._params.backend = "dummy_backend_2"
-
-            self._registry.create_handler(self._params)
diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py
index d3ec2b4..6c54935 100644
--- a/torch/distributed/elastic/rendezvous/__init__.py
+++ b/torch/distributed/elastic/rendezvous/__init__.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env/python3
+
 # Copyright (c) Facebook, Inc. and its affiliates.
 # All rights reserved.
 #
@@ -99,22 +101,13 @@
    to participate in next rendezvous.
 """
 
-from .api import *
-from .registry import _register_default_handlers
-
-
-_register_default_handlers()
-
-
-__all__ = [
-    "RendezvousClosedError",
-    "RendezvousConnectionError",
-    "RendezvousError",
-    "RendezvousHandler",
-    "RendezvousHandlerCreator",
-    "RendezvousHandlerRegistry",
-    "RendezvousParameters",
-    "RendezvousStateError",
-    "RendezvousTimeoutError",
-    "rendezvous_handler_registry",
-]
+from .api import (  # noqa: F401
+    RendezvousClosedError,
+    RendezvousConnectionError,
+    RendezvousError,
+    RendezvousHandler,
+    RendezvousHandlerFactory,
+    RendezvousParameters,
+    RendezvousStateError,
+    RendezvousTimeoutError,
+)
diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py
index c1a3bd3..363fe1b 100644
--- a/torch/distributed/elastic/rendezvous/api.py
+++ b/torch/distributed/elastic/rendezvous/api.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import abc
-from typing import Any, Callable, Dict, Optional, Tuple, final
+from typing import Any, Callable, Dict, Optional, Tuple
 
 from torch.distributed import Store
 
@@ -219,49 +219,42 @@
 RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
 
 
-@final
-class RendezvousHandlerRegistry:
-    """Represents a registry of `RendezvousHandler` backends."""
+class RendezvousHandlerFactory:
+    """
+    Creates ``RendezvousHandler`` instances for supported rendezvous backends.
+    """
 
-    _registry: Dict[str, RendezvousHandlerCreator]
+    def __init__(self):
+        self._registry: Dict[str, RendezvousHandlerCreator] = {}
 
-    def __init__(self) -> None:
-        self._registry = {}
-
-    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
-        """Registers a new rendezvous backend.
-
-        Args:
-            backend:
-                The name of the backend.
-            creater:
-                The callback to invoke to construct the `RendezvousHandler`.
+    def register(self, backend: str, creator: RendezvousHandlerCreator):
         """
-        if not backend:
-            raise ValueError("The rendezvous backend name must be a non-empty string.")
-
-        current_creator: Optional[RendezvousHandlerCreator]
+        Registers a new rendezvous backend.
+        """
         try:
             current_creator = self._registry[backend]
         except KeyError:
-            current_creator = None
+            current_creator = None  # type: ignore[assignment]
 
-        if current_creator is not None and current_creator != creator:
+        if current_creator is not None:
             raise ValueError(
-                f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
-                f"is already registered with '{current_creator}'."
+                f"The rendezvous backend '{backend}' cannot be registered with"
+                f" '{creator.__module__}.{creator.__name__}' as it is already"
+                f" registered with '{current_creator.__module__}.{current_creator.__name__}'."
             )
 
         self._registry[backend] = creator
 
     def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
-        """Creates a new `RendezvousHandler`."""
+        """
+        Creates a new ``RendezvousHandler`` instance for the specified backend.
+        """
         try:
             creator = self._registry[params.backend]
         except KeyError:
             raise ValueError(
                 f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
-                f"to call `{self.register.__name__}`?"
+                f"to call {self.register.__name__}?"
             )
 
         handler = creator(params)
@@ -269,13 +262,8 @@
         # Do some sanity check.
         if handler.get_backend() != params.backend:
             raise RuntimeError(
-                f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
-                f"backend '{params.backend}'."
+                f"The rendezvous handler backend '{handler.get_backend()}' does not match the "
+                f"requested backend '{params.backend}'."
             )
 
         return handler
-
-
-# The default global registry instance used by launcher scripts to instantiate
-# rendezvous handlers.
-rendezvous_handler_registry = RendezvousHandlerRegistry()
diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py
index 3a8f931..0b81159 100644
--- a/torch/distributed/elastic/rendezvous/registry.py
+++ b/torch/distributed/elastic/rendezvous/registry.py
@@ -4,20 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .api import RendezvousHandler, RendezvousParameters
-from .api import rendezvous_handler_registry as handler_registry
+from . import etcd_rendezvous
+from .api import (
+    RendezvousHandler,
+    RendezvousHandlerFactory,
+    RendezvousParameters,
+)
+
+_factory = RendezvousHandlerFactory()
+_factory.register("etcd", etcd_rendezvous.create_rdzv_handler)
 
 
-def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
-    from . import etcd_rendezvous
-
-    return etcd_rendezvous.create_rdzv_handler(params)
-
-
-def _register_default_handlers() -> None:
-    handler_registry.register("etcd", _create_etcd_handler)
-
-
-# The legacy function kept for backwards compatibility.
 def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
-    return handler_registry.create_handler(params)
+    return _factory.create_handler(params)