[Torchelastic][Logging] Pluggable LogsSpecs using Python entrypoints and an option to specify one by name. (#120942)
Summary:
Expose an option that lets users specify, by name, which LogsSpecs implementation to use.
- The implementation has to be registered as an entrypoint under the `torchrun.logs_specs` group.
- It must implement the `LogsSpecs` interface defined in the prior PR/diff.
Test Plan: unit tests + local tests
Reviewed By: ezyang
Differential Revision: D54180838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120942
Approved by: https://github.com/ezyang
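
For illustration, a minimal sketch of how an external package could register its own implementation; the package, module, and class names used below (`my-torchrun-logs-plugin`, `my_pkg.specs:JsonLogsSpecs`) are hypothetical:

```python
# Hypothetical setup.py of a third-party package shipping a custom LogsSpecs.
# The entrypoint key ("mylogs") is the name users pass to torchrun via --logs-specs.
from setuptools import setup

setup(
    name="my-torchrun-logs-plugin",
    packages=["my_pkg"],
    entry_points={
        "torchrun.logs_specs": [
            # the value must resolve to a class implementing LogsSpecs
            "mylogs = my_pkg.specs:JsonLogsSpecs",
        ],
    },
)
```

Once installed, the implementation is selected by name, e.g. `torchrun --logs-specs=mylogs ...`; omitting the flag falls back to the built-in `default` entry, which maps to `DefaultLogsSpecs`.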
diff --git a/setup.py b/setup.py
index 7eec481..84cd3fd 100644
--- a/setup.py
+++ b/setup.py
@@ -1055,7 +1055,10 @@
"convert-caffe2-to-onnx = caffe2.python.onnx.bin.conversion:caffe2_to_onnx",
"convert-onnx-to-caffe2 = caffe2.python.onnx.bin.conversion:onnx_to_caffe2",
"torchrun = torch.distributed.run:main",
- ]
+ ],
+ "torchrun.logs_specs": [
+ "default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs",
+ ],
}
return extensions, cmdclass, packages, entry_points, extra_install_requires
diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/run_test.py
index be30e0f..f33d075 100644
--- a/test/distributed/launcher/run_test.py
+++ b/test/distributed/launcher/run_test.py
@@ -17,17 +17,18 @@
import uuid
from contextlib import closing
from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
import torch.distributed.run as launch
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
- TEST_WITH_DEV_DBG_ASAN,
skip_but_pass_in_sandcastle_if,
+ TEST_WITH_DEV_DBG_ASAN,
)
@@ -504,6 +505,55 @@
is_torchelastic_launched = fp.readline()
self.assertEqual("True", is_torchelastic_launched)
+ @patch("torch.distributed.run.metadata")
+ @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+ def test_is_torchelastic_launched_with_logs_spec_defined(self, metadata_mock):
+ # Mock the entry_points API to avoid Python version differences.
+ entrypoints = MagicMock()
+ metadata_mock.entry_points.return_value = entrypoints
+
+ group = MagicMock()
+ entrypoints.select.return_value = group
+
+ ep = MagicMock()
+ ep.load.return_value = DefaultLogsSpecs
+
+ group.select.return_value = ep
+ group.__getitem__.return_value = ep
+
+ out_file = os.path.join(self.test_dir, "out")
+ if os.path.exists(out_file):
+ os.remove(out_file)
+ launch.main(
+ [
+ "--run-path",
+ "--nnodes=1",
+ "--nproc-per-node=1",
+ "--monitor-interval=1",
+ "--logs_specs=default",
+ path("bin/test_script_is_torchelastic_launched.py"),
+ f"--out-file={out_file}",
+ ]
+ )
+
+ with open(out_file) as fp:
+ is_torchelastic_launched = fp.readline()
+ self.assertEqual("True", is_torchelastic_launched)
+
+ @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+ def test_logs_logs_spec_entrypoint_must_be_defined(self):
+ with self.assertRaises(ValueError):
+ launch.main(
+ [
+ "--run-path",
+ "--nnodes=1",
+ "--nproc-per-node=1",
+ "--monitor-interval=1",
+ "--logs_specs=DOESNOT_EXIST",
+ path("bin/test_script_is_torchelastic_launched.py"),
+ ]
+ )
+
def test_is_not_torchelastic_launched(self):
# launch test script without torchelastic and validate that
# torch.distributed.is_torchelastic_launched() returns False
diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py
index db60a66..d7e6a55 100644
--- a/torch/distributed/elastic/multiprocessing/__init__.py
+++ b/torch/distributed/elastic/multiprocessing/__init__.py
@@ -68,6 +68,7 @@
from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
_validate_full_rank,
DefaultLogsSpecs,
+ LogsDest,
LogsSpecs,
MultiprocessContext,
PContext,
@@ -88,6 +89,7 @@
"RunProcsResult",
"SignalException",
"Std",
+ "LogsDest",
"LogsSpecs",
"DefaultLogsSpecs",
"SubprocessContext",
diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py
index c53477e..b0bd8d1 100644
--- a/torch/distributed/elastic/multiprocessing/api.py
+++ b/torch/distributed/elastic/multiprocessing/api.py
@@ -184,10 +184,18 @@
class LogsSpecs(ABC):
"""
Defines logs processing and redirection for each worker process.
+
Args:
- log_dir: base directory where logs will be written
- redirects: specifies which streams to redirect to files.
- tee: specifies which streams to duplicate to stdout/stderr
+ log_dir:
+ Base directory where logs will be written.
+ redirects:
+ Streams to redirect to files. Pass a single ``Std``
+ enum to redirect for all workers, or a mapping keyed
+ by local_rank to selectively redirect.
+ tee:
+ Streams to duplicate to stdout/stderr.
+ Pass a single ``Std`` enum to duplicate streams for all workers,
+ or a mapping keyed by local_rank to selectively duplicate.
"""
def __init__(
@@ -220,7 +228,8 @@
class DefaultLogsSpecs(LogsSpecs):
"""
Default LogsSpecs implementation:
- - `log_dir` will be created if it doesn't exist and it is not set to os.devnull
+
+ - `log_dir` will be created if it doesn't exist and it is not set to `os.devnull`
- Generates nested folders for each attempt and rank.
"""
def __init__(
diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py
index 214d043..f2b4aca 100644
--- a/torch/distributed/launcher/api.py
+++ b/torch/distributed/launcher/api.py
@@ -54,12 +54,6 @@
as a period of monitoring workers.
start_method: The method is used by the elastic agent to start the
workers (spawn, fork, forkserver).
- log_dir: base log directory where log files are written. If not set,
- one is created in a tmp dir but NOT removed on exit.
- redirects: configuration to redirect stdout/stderr to log files.
- Pass a single ``Std`` enum to redirect all workers,
- or a mapping keyed by local_rank to selectively redirect.
- tee: configuration to "tee" stdout/stderr to console + log file.
metrics_cfg: configuration to initialize metrics.
local_addr: address of the local node if any. If not set, a lookup on the local
machine's FQDN will be performed.
@@ -248,9 +242,9 @@
agent = LocalElasticAgent(
spec=spec,
+ logs_specs=config.logs_specs, # type: ignore[arg-type]
start_method=config.start_method,
log_line_prefix_template=config.log_line_prefix_template,
- logs_specs=config.logs_specs, # type: ignore[arg-type]
)
shutdown_rdzv = True
diff --git a/torch/distributed/run.py b/torch/distributed/run.py
index b6f2fdd..4928f6c 100644
--- a/torch/distributed/run.py
+++ b/torch/distributed/run.py
@@ -375,12 +375,13 @@
import os
import sys
import uuid
+import importlib.metadata as metadata
from argparse import REMAINDER, ArgumentParser
-from typing import Callable, List, Tuple, Union, Optional, Set
+from typing import Callable, List, Tuple, Type, Union, Optional, Set
import torch
from torch.distributed.argparse_util import check_env, env
-from torch.distributed.elastic.multiprocessing import Std, DefaultLogsSpecs
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
@@ -602,6 +603,15 @@
"machine's FQDN.",
)
+ parser.add_argument(
+ "--logs-specs",
+ "--logs_specs",
+ default=None,
+ type=str,
+ help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. "
+ "Can be used to override custom logging behavior.",
+ )
+
#
# Positional arguments.
#
@@ -699,6 +709,36 @@
return args.use_env
+def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
+ """
+ Attempts to load the `torchrun.logs_specs` entrypoint with the key given by the `logs_specs_name` param.
+ Provides a plugin mechanism for custom `LogsSpecs` implementations.
+
+ Returns `DefaultLogsSpecs` when `logs_specs_name` is None.
+ Raises ValueError when no entrypoint for `logs_specs_name` can be found.
+ """
+ logs_specs_cls = None
+ if logs_specs_name is not None:
+ eps = metadata.entry_points()
+ if hasattr(eps, "select"): # >= 3.10
+ group = eps.select(group="torchrun.logs_specs")
+ if group.select(name=logs_specs_name):
+ logs_specs_cls = group[logs_specs_name].load()
+
+ elif specs := eps.get("torchrun.logs_specs"): # < 3.10
+ if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
+ logs_specs_cls = entrypoint_list[0].load()
+
+ if logs_specs_cls is None:
+ raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
+
+ logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
+ else:
+ logs_specs_cls = DefaultLogsSpecs
+
+ return logs_specs_cls
+
+
def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
# If ``args`` not passed, defaults to ``sys.argv[:1]``
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
@@ -745,7 +785,8 @@
"--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
) from e
- logs_specs = DefaultLogsSpecs(
+ logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
+ logs_specs = logs_specs_cls(
log_dir=args.log_dir,
redirects=Std.from_str(args.redirects),
tee=Std.from_str(args.tee),
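
As a usage note (not part of the diff), the entrypoint lookup above can be reproduced in a standalone snippet using only `importlib.metadata`; this is a sketch mirroring `_get_logs_specs_class`, assuming a regular install of torch with this change applied:

```python
# Sketch: list the LogsSpecs implementations visible to torchrun and load one by name.
# Mirrors the lookup in _get_logs_specs_class; not part of the PR itself.
import importlib.metadata as metadata


def available_logs_specs():
    eps = metadata.entry_points()
    if hasattr(eps, "select"):  # Python >= 3.10: EntryPoints.select API
        group = eps.select(group="torchrun.logs_specs")
    else:  # Python < 3.10: entry_points() returns a dict keyed by group
        group = eps.get("torchrun.logs_specs", [])
    return {ep.name: ep for ep in group}


specs = available_logs_specs()
# With this PR applied, the built-in entry registered in setup.py should appear:
#   "default" -> torch.distributed.elastic.multiprocessing.DefaultLogsSpecs
print(sorted(specs))
if "default" in specs:
    logs_specs_cls = specs["default"].load()
    print(logs_specs_cls)
```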