[Torchelastic][Logging] Pluggable logsspecs using python entrypoints and option to specify one by name. (#120942)

Summary:
Expose an option to users to specify name of the LogsSpec implementation to use.
- Has to be defined in entrypoints under `torchrun.logs_specs` group.
- Must implement LogsSpec defined in prior PR/diff.

Test Plan: unit test+local tests

Reviewed By: ezyang

Differential Revision: D54180838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120942
Approved by: https://github.com/ezyang
diff --git a/setup.py b/setup.py
index 7eec481..84cd3fd 100644
--- a/setup.py
+++ b/setup.py
@@ -1055,7 +1055,10 @@
             "convert-caffe2-to-onnx = caffe2.python.onnx.bin.conversion:caffe2_to_onnx",
             "convert-onnx-to-caffe2 = caffe2.python.onnx.bin.conversion:onnx_to_caffe2",
             "torchrun = torch.distributed.run:main",
-        ]
+        ],
+        "torchrun.logs_specs": [
+            "default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs",
+        ],
     }
 
     return extensions, cmdclass, packages, entry_points, extra_install_requires
diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/run_test.py
index be30e0f..f33d075 100644
--- a/test/distributed/launcher/run_test.py
+++ b/test/distributed/launcher/run_test.py
@@ -17,17 +17,18 @@
 import uuid
 from contextlib import closing
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import torch.distributed.run as launch
 from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
 from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
 from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 from torch.distributed.elastic.utils import get_socket_with_port
 from torch.distributed.elastic.utils.distributed import get_free_port
 from torch.testing._internal.common_utils import (
-    TEST_WITH_DEV_DBG_ASAN,
     skip_but_pass_in_sandcastle_if,
+    TEST_WITH_DEV_DBG_ASAN,
 )
 
 
@@ -504,6 +505,55 @@
             is_torchelastic_launched = fp.readline()
             self.assertEqual("True", is_torchelastic_launched)
 
+    @patch("torch.distributed.run.metadata")
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_is_torchelastic_launched_with_logs_spec_defined(self, metadata_mock):
+        # mock the entrypoint API to avoid version issues.
+        entrypoints = MagicMock()
+        metadata_mock.entry_points.return_value = entrypoints
+
+        group = MagicMock()
+        entrypoints.select.return_value = group
+
+        ep = MagicMock()
+        ep.load.return_value = DefaultLogsSpecs
+
+        group.select.return_value = (ep)
+        group.__getitem__.return_value = ep
+
+        out_file = f"{os.path.join(self.test_dir, 'out')}"
+        if os.path.exists(out_file):
+            os.remove(out_file)
+        launch.main(
+            [
+                "--run-path",
+                "--nnodes=1",
+                "--nproc-per-node=1",
+                "--monitor-interval=1",
+                "--logs_specs=default",
+                path("bin/test_script_is_torchelastic_launched.py"),
+                f"--out-file={out_file}",
+            ]
+        )
+
+        with open(out_file) as fp:
+            is_torchelastic_launched = fp.readline()
+            self.assertEqual("True", is_torchelastic_launched)
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_logs_logs_spec_entrypoint_must_be_defined(self):
+        with self.assertRaises(ValueError):
+            launch.main(
+                [
+                    "--run-path",
+                    "--nnodes=1",
+                    "--nproc-per-node=1",
+                    "--monitor-interval=1",
+                    "--logs_specs=DOESNOT_EXIST",
+                    path("bin/test_script_is_torchelastic_launched.py"),
+                ]
+            )
+
     def test_is_not_torchelastic_launched(self):
         # launch test script without torchelastic and validate that
         # torch.distributed.is_torchelastic_launched() returns False
diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py
index db60a66..d7e6a55 100644
--- a/torch/distributed/elastic/multiprocessing/__init__.py
+++ b/torch/distributed/elastic/multiprocessing/__init__.py
@@ -68,6 +68,7 @@
 from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
     _validate_full_rank,
     DefaultLogsSpecs,
+    LogsDest,
     LogsSpecs,
     MultiprocessContext,
     PContext,
@@ -88,6 +89,7 @@
     "RunProcsResult",
     "SignalException",
     "Std",
+    "LogsDest",
     "LogsSpecs",
     "DefaultLogsSpecs",
     "SubprocessContext",
diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py
index c53477e..b0bd8d1 100644
--- a/torch/distributed/elastic/multiprocessing/api.py
+++ b/torch/distributed/elastic/multiprocessing/api.py
@@ -184,10 +184,18 @@
 class LogsSpecs(ABC):
     """
     Defines logs processing and redirection for each worker process.
+
     Args:
-        log_dir: base directory where logs will be written
-        redirects: specifies which streams to redirect to files.
-        tee: specifies which streams to duplicate to stdout/stderr
+        log_dir:
+            Base directory where logs will be written.
+        redirects:
+            Streams to redirect to files. Pass a single ``Std``
+            enum to redirect for all workers, or a mapping keyed
+            by local_rank to selectively redirect.
+        tee:
+            Streams to duplicate to stdout/stderr.
+            Pass a single ``Std`` enum to duplicate streams for all workers,
+            or a mapping keyed by local_rank to selectively duplicate.
     """
 
     def __init__(
@@ -220,7 +228,8 @@
 class DefaultLogsSpecs(LogsSpecs):
     """
     Default LogsSpecs implementation:
-    - `log_dir` will be created if it doesn't exist and it is not set to os.devnull
+
+    - `log_dir` will be created if it doesn't exist and it is not set to `os.devnull`
     - Generates nested folders for each attempt and rank.
     """
     def __init__(
diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py
index 214d043..f2b4aca 100644
--- a/torch/distributed/launcher/api.py
+++ b/torch/distributed/launcher/api.py
@@ -54,12 +54,6 @@
                         as a period of monitoring workers.
         start_method: The method is used by the elastic agent to start the
                     workers (spawn, fork, forkserver).
-        log_dir: base log directory where log files are written. If not set,
-                one is created in a tmp dir but NOT removed on exit.
-        redirects: configuration to redirect stdout/stderr to log files.
-                Pass a single ``Std`` enum to redirect all workers,
-                or a mapping keyed by local_rank to selectively redirect.
-        tee: configuration to "tee" stdout/stderr to console + log file.
         metrics_cfg: configuration to initialize metrics.
         local_addr: address of the local node if any. If not set, a lookup on the local
                 machine's FQDN will be performed.
@@ -248,9 +242,9 @@
 
     agent = LocalElasticAgent(
         spec=spec,
+        logs_specs=config.logs_specs,  # type: ignore[arg-type]
         start_method=config.start_method,
         log_line_prefix_template=config.log_line_prefix_template,
-        logs_specs=config.logs_specs,  # type: ignore[arg-type]
     )
 
     shutdown_rdzv = True
diff --git a/torch/distributed/run.py b/torch/distributed/run.py
index b6f2fdd..4928f6c 100644
--- a/torch/distributed/run.py
+++ b/torch/distributed/run.py
@@ -375,12 +375,13 @@
 import os
 import sys
 import uuid
+import importlib.metadata as metadata
 from argparse import REMAINDER, ArgumentParser
-from typing import Callable, List, Tuple, Union, Optional, Set
+from typing import Callable, List, Tuple, Type, Union, Optional, Set
 
 import torch
 from torch.distributed.argparse_util import check_env, env
-from torch.distributed.elastic.multiprocessing import Std, DefaultLogsSpecs
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
 from torch.distributed.elastic.utils import macros
@@ -602,6 +603,15 @@
         "machine's FQDN.",
     )
 
+    parser.add_argument(
+        "--logs-specs",
+        "--logs_specs",
+        default=None,
+        type=str,
+        help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. "
+        "Can be used to override custom logging behavior.",
+    )
+
     #
     # Positional arguments.
     #
@@ -699,6 +709,36 @@
     return args.use_env
 
 
+def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
+    """
+    Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param.
+    Provides plugin mechanism to provide custom implementation of LogsSpecs.
+
+    Returns `DefaultLogsSpecs` when logs_spec_name is None.
+    Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints.
+    """
+    logs_specs_cls = None
+    if logs_specs_name is not None:
+        eps = metadata.entry_points()
+        if hasattr(eps, "select"):  # >= 3.10
+            group = eps.select(group="torchrun.logs_specs")
+            if group.select(name=logs_specs_name):
+                logs_specs_cls = group[logs_specs_name].load()
+
+        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
+            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
+                logs_specs_cls = entrypoint_list[0].load()
+
+        if logs_specs_cls is None:
+            raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
+
+        logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
+    else:
+        logs_specs_cls = DefaultLogsSpecs
+
+    return logs_specs_cls
+
+
 def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
     # If ``args`` not passed, defaults to ``sys.argv[:1]``
     min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
@@ -745,7 +785,8 @@
                 "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
             ) from e
 
-    logs_specs = DefaultLogsSpecs(
+    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
+    logs_specs = logs_specs_cls(
         log_dir=args.log_dir,
         redirects=Std.from_str(args.redirects),
         tee=Std.from_str(args.tee),