[torch] Various improvements to `torch.distributed.launch` and `torch.distributed.run` (#61294)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61294
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60925
* Set the default number of restarts in `torch.distributed.launch` to 0
* Remove the unnecessary `--use_env` warning from `torch.distributed.run`
* Move the `--use_env` warnings to `torch.distributed.launch`
* Make default log level WARNING
* Add new doc section around transitioning to `torch.distributed.run`
* Make `torch.distributed.launch` not use error-propagation
* Set the default events handler to `null`, which does not print events to the console
* Add reference from `torch.distributed.launch` to `torch.distributed.run`
* Set a correct preexec function that sends SIGTERM to child processes when the parent dies (see the sketch after this list)
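For context on the last bullet: the previous code passed `mp._prctl_pr_set_pdeathsig(signal.SIGTERM)` (i.e. the helper's return value, evaluated in the parent) as `preexec_fn`, so the death signal was never armed in the children; the fix passes a callable instead (see `_pr_set_pdeathsig` in the `multiprocessing/api.py` hunk below). A minimal illustrative sketch of the pattern, using plain `ctypes`/`prctl` rather than the internal `torch.multiprocessing` helper (names below are hypothetical):

    import ctypes
    import signal
    import subprocess
    import sys

    PR_SET_PDEATHSIG = 1  # constant from <linux/prctl.h>

    def _set_pdeathsig() -> None:
        # Runs in the child between fork() and exec(); asks the kernel to send
        # SIGTERM to the child if the launching (parent) process dies.
        libc = ctypes.CDLL("libc.so.6", use_errno=True)
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM)

    # Pass the callable itself; calling it here would run it in the parent and
    # hand Popen `None` as preexec_fn. prctl is Linux-only, so skip it on Windows.
    proc = subprocess.Popen(
        [sys.executable, "-c", "import time; time.sleep(60)"],
        preexec_fn=_set_pdeathsig if sys.platform != "win32" else None,
    )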
Issues resolved:
https://github.com/pytorch/pytorch/issues/60716
https://github.com/pytorch/pytorch/issues/60754
Test Plan:
sandcastle
python -m torch.distributed.launch --nproc_per_node 2 main.py -> uses 0 restarts
python -m torch.distributed.run --nproc_per_node 2 main.py -> uses default for torchelastic, 0 restarts
python -m torch.distributed.launch --nproc_per_node=4 --use_env --no_python main.py -> produces error
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py -> no warning
python -m torch.distributed.launch --nproc_per_node=4 --no_python main.py -> warning
Output of running torch.distributed.launch without --use_env:
$path/torch/distributed/launch.py:173: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torch.distributed.run.
Note that --use_env is set by default in torch.distributed.run.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead.
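For reference, the migration this warning asks for boils down to reading the local rank from the environment instead of an argparse flag; this mirrors the table added to torch/distributed/run.py further down (the snippet is illustrative only):

    import os

    # Before (torch.distributed.launch without --use_env):
    #   import argparse
    #   parser = argparse.ArgumentParser()
    #   parser.add_argument("--local_rank", type=int)
    #   local_rank = parser.parse_args().local_rank

    # After (torch.distributed.run exports LOCAL_RANK for every worker):
    local_rank = int(os.environ["LOCAL_RANK"])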
New section:
{F628923078}
{F628974089}
Reviewed By: cbalioglu
Differential Revision: D29559553
fbshipit-source-id: 03ed9ba638bf154354e1530ffc964688431edf6b
diff --git a/docs/source/elastic/errors.rst b/docs/source/elastic/errors.rst
index 1105d1b..7734bcd 100644
--- a/docs/source/elastic/errors.rst
+++ b/docs/source/elastic/errors.rst
@@ -1,3 +1,5 @@
+.. _elastic_errors-api:
+
Error Propagation
==================
diff --git a/docs/source/elastic/run.rst b/docs/source/elastic/run.rst
index 6e1eac8..fb870fa 100644
--- a/docs/source/elastic/run.rst
+++ b/docs/source/elastic/run.rst
@@ -1,9 +1,6 @@
.. _launcher-api:
-Elastic Launch
-============================
-
-torch.distributed.run
-----------------------
+torch.distributed.run (Elastic Launch)
+======================================
.. automodule:: torch.distributed.run
diff --git a/docs/source/elastic/train_script.rst b/docs/source/elastic/train_script.rst
index 4d9eea8..ab63159 100644
--- a/docs/source/elastic/train_script.rst
+++ b/docs/source/elastic/train_script.rst
@@ -1,3 +1,5 @@
+.. _elastic_train_script:
+
Train script
-------------
@@ -7,18 +9,20 @@
1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
``MASTER_ADDR``, and ``MASTER_PORT``.
-2. ``rdzv_backend`` and ``rdzv_endpoint`` must be provided. For most users
- this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_).
+2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
+ this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The default
+ ``rdzv_backend`` creates a non-elastic rendezvous where ``rdzv_endpoint`` holds
+ the master address.
3. Make sure you have a ``load_checkpoint(path)`` and
- ``save_checkpoint(path)`` logic in your script. When workers fail
- we restart all the workers with the same program arguments so you will
- lose progress up to the most recent checkpoint
+ ``save_checkpoint(path)`` logic in your script. When any number of
+ workers fail we restart all the workers with the same program
+ arguments so you will lose progress up to the most recent checkpoint
(see `elastic launch <distributed.html>`_).
4. ``use_env`` flag has been removed. If you were parsing local rank by parsing
the ``--local_rank`` option, you need to get the local rank from the
- environment variable ``LOCAL_RANK`` (e.g. ``os.environ["LOCAL_RANK"]``).
+ environment variable ``LOCAL_RANK`` (e.g. ``int(os.environ["LOCAL_RANK"])``).
Below is an expository example of a training script that checkpoints on each
epoch, hence the worst-case progress lost on failure is one full epoch worth
@@ -31,7 +35,7 @@
state = load_checkpoint(args.checkpoint_path)
initialize(state)
- # torch.distributed.run ensure that this will work
+ # torch.distributed.run ensures that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend)
diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py
index 8f068a5..44d9c5a 100644
--- a/torch/distributed/elastic/agent/server/local_elastic_agent.py
+++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py
@@ -205,7 +205,6 @@
result = self._pcontext.wait(0)
if result:
if result.is_failed():
- log.error(f"[{role}] Worker group failed")
# map local rank failure to global rank
worker_failures = {}
for local_rank, failure in result.failures.items():
diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py
index 3e4dff1..4132ca7 100644
--- a/torch/distributed/elastic/events/__init__.py
+++ b/torch/distributed/elastic/events/__init__.py
@@ -19,6 +19,7 @@
"""
+import os
import logging
from torch.distributed.elastic.events.handlers import get_logging_handler
@@ -46,12 +47,12 @@
return _events_logger
logging_handler = get_logging_handler(destination)
_events_logger = logging.getLogger(f"torchelastic-events-{destination}")
- _events_logger.setLevel(logging.DEBUG)
+ _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
# Do not propagate message to the root logger
_events_logger.propagate = False
_events_logger.addHandler(logging_handler)
return _events_logger
-def record(event: Event, destination: str = "console") -> None:
+def record(event: Event, destination: str = "null") -> None:
_get_or_create_logger(destination).info(event.serialize())
diff --git a/torch/distributed/elastic/events/handlers.py b/torch/distributed/elastic/events/handlers.py
index bfdc1c1..63df63d 100644
--- a/torch/distributed/elastic/events/handlers.py
+++ b/torch/distributed/elastic/events/handlers.py
@@ -12,8 +12,9 @@
_log_handlers: Dict[str, logging.Handler] = {
"console": logging.StreamHandler(),
+ "null": logging.NullHandler(),
}
-def get_logging_handler(destination: str = "console") -> logging.Handler:
+def get_logging_handler(destination: str = "null") -> logging.Handler:
return _log_handlers[destination]
diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py
index 341a033..11ef96f 100644
--- a/torch/distributed/elastic/multiprocessing/api.py
+++ b/torch/distributed/elastic/multiprocessing/api.py
@@ -465,24 +465,32 @@
entrypoint: str,
args: Tuple,
env: Dict[str, str],
- preexec_fn: Callable,
+ preexec_fn: Optional[Callable],
stdout: str,
stderr: str,
):
self._stdout = open(stdout, "w") if stdout else None
self._stderr = open(stderr, "w") if stderr else None
- args_str = [str(e) for e in args]
-
# inherit parent environment vars
env_vars = os.environ.copy()
env_vars.update(env)
- self.proc: subprocess.Popen = subprocess.Popen(
+ args_str = (entrypoint, *[str(e) for e in args])
+ self.proc: subprocess.Popen = self._popen(args_str, env_vars, preexec_fn)
+
+ def _popen(
+ self, args: Tuple, env: Dict[str, str], preexec_fn: Optional[Callable]
+ ) -> subprocess.Popen:
+ if IS_WINDOWS:
+ # Reset preexec_fn on windows, since windows does not support it
+ preexec_fn = None
+
+ return subprocess.Popen(
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
# `Tuple[str, *Tuple[Any, ...]]`.
- args=(entrypoint, *args_str),
- env=env_vars,
+ args=args,
+ env=env,
preexec_fn=preexec_fn,
stdout=self._stdout,
stderr=self._stderr,
@@ -497,6 +505,17 @@
self._stderr.close()
+def _pr_set_pdeathsig() -> None:
+ """
+ Sets PR_SET_PDEATHSIG to ensure a child process is
+ terminated appropriately.
+
+ See http://stackoverflow.com/questions/1884941/ for more information.
+ For libc.so.6 read http://www.linux-m68k.org/faq/glibcinfo.html
+ """
+ mp._prctl_pr_set_pdeathsig(signal.SIGTERM) # type: ignore[attr-defined]
+
+
class SubprocessContext(PContext):
"""
``PContext`` holding worker processes invoked as a binary.
@@ -541,7 +560,7 @@
entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
args=self.args[local_rank],
env=self.envs[local_rank],
- preexec_fn=mp._prctl_pr_set_pdeathsig(signal.SIGTERM), # type: ignore[attr-defined]
+ preexec_fn=_pr_set_pdeathsig,
stdout=self.stdouts[local_rank],
stderr=self.stderrs[local_rank],
)
diff --git a/torch/distributed/elastic/utils/logging.py b/torch/distributed/elastic/utils/logging.py
index 526836d..19c68c0 100644
--- a/torch/distributed/elastic/utils/logging.py
+++ b/torch/distributed/elastic/utils/logging.py
@@ -17,7 +17,7 @@
"""
Util function to set up a simple logger that writes
into stderr. The loglevel is fetched from the LOGLEVEL
- env. variable or INFO as default. The function will use the
+ env. variable or WARNING as default. The function will use the
module name of the caller if no name is provided.
Args:
@@ -32,7 +32,7 @@
def _setup_logger(name: Optional[str] = None):
log = logging.getLogger(name)
- log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
+ log.setLevel(os.environ.get("LOGLEVEL", "WARNING"))
return log
diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py
index 7879a35..81395bb 100644
--- a/torch/distributed/elastic/utils/store.py
+++ b/torch/distributed/elastic/utils/store.py
@@ -6,7 +6,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-import warnings
from datetime import timedelta
from typing import List
@@ -64,8 +63,5 @@
Note: Since the data is not removed from the store, the barrier can be used
once per unique ``key_prefix``.
"""
- warnings.warn(
- "This is an experimental API and will be changed in future.", FutureWarning
- )
data = f"{rank}".encode(encoding="UTF-8")
synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py
index c6aa0ec..5fcb3eb 100644
--- a/torch/distributed/launch.py
+++ b/torch/distributed/launch.py
@@ -1,8 +1,10 @@
r"""
-`torch.distributed.launch` is a module that spawns up multiple distributed
+``torch.distributed.launch`` is a module that spawns up multiple distributed
training processes on each of the training nodes.
-NOTE: This module is deprecated, use torch.distributed.run.
+.. warning::
+
+ This module is going to be deprecated in favor of :ref:`torch.distributed.run <launcher-api>`.
The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
@@ -136,9 +138,12 @@
https://github.com/pytorch/pytorch/issues/12042 for an example of
how things can go wrong if you don't do this correctly.
+
+
"""
import logging
+import warnings
from torch.distributed.run import get_args_parser, run
@@ -159,14 +164,27 @@
return parser.parse_args(args)
+def launch(args):
+ if args.no_python and not args.use_env:
+ raise ValueError(
+ "When using the '--no_python' flag,"
+ " you must also set the '--use_env' flag."
+ )
+ run(args)
+
+
def main(args=None):
- logger.warning(
- "The module torch.distributed.launch is deprecated "
- "and going to be removed in future."
- "Migrate to torch.distributed.run"
+ warnings.warn(
+ "The module torch.distributed.launch is deprecated\n"
+ "and will be removed in future. Use torch.distributed.run.\n"
+ "Note that --use_env is set by default in torch.distributed.run.\n"
+ "If your script expects `--local_rank` argument to be set, please\n"
+ "change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
+ "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
+ "further instructions\n", FutureWarning
)
args = parse_args(args)
- run(args)
+ launch(args)
if __name__ == "__main__":
diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py
index 336474b..1ee5abd 100644
--- a/torch/distributed/launcher/api.py
+++ b/torch/distributed/launcher/api.py
@@ -15,7 +15,7 @@
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import Std
-from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record
+from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
@@ -172,7 +172,6 @@
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# torch.distributed.elastic.multiprocessing.errors.record.
-@record
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
diff --git a/torch/distributed/run.py b/torch/distributed/run.py
index acdbf0e..0dc3b05 100644
--- a/torch/distributed/run.py
+++ b/torch/distributed/run.py
@@ -7,8 +7,8 @@
# LICENSE file in the root directory of this source tree.
"""
-This module provides similar functionality as ``torch.distributed.launch`` with the following
-additional functionalities:
+``torch.distributed.run`` provides a superset of the functionality of ``torch.distributed.launch``
+with the following additional functionalities:
1. Worker failures are handled gracefully by restarting all workers.
@@ -16,7 +16,60 @@
3. Number of nodes is allowed to change between minimum and maximum sizes (elasticity).
-**Usage:**
+
+
+Transitioning from torch.distributed.launch to torch.distributed.run
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+``torch.distributed.run`` supports the same arguments as ``torch.distributed.launch`` **except**
+for ``--use_env`` which is now deprecated. To migrate from ``torch.distributed.launch``
+to ``torch.distributed.run`` follow these steps:
+
+1. If your training script already reads ``local_rank`` from the ``LOCAL_RANK`` environment variable,
+   then you simply need to omit the ``--use_env`` flag, e.g.:
+
+ +--------------------------------------------------------------------+------------------------------------------------------+
+ | ``torch.distributed.launch`` | ``torch.distributed.run`` |
+ +====================================================================+======================================================+
+ | | |
+ | .. code-block:: shell-session | .. code-block:: shell-session |
+ | | |
+ | $ python -m torch.distributed.launch --use_env train_script.py | $ python -m torch.distributed.run train_script.py |
+ | | |
+ +--------------------------------------------------------------------+------------------------------------------------------+
+
+2. If your training script reads the local rank from a ``--local_rank`` cmd argument,
+   change your training script to read from the ``LOCAL_RANK`` environment variable as
+ demonstrated by the following code snippet:
+
+ +-------------------------------------------------------+----------------------------------------------------+
+ | ``torch.distributed.launch`` | ``torch.distributed.run`` |
+ +=======================================================+====================================================+
+ | | |
+ | .. code-block:: python | .. code-block:: python |
+ | | |
+ | | |
+ | import argparse | import os |
+ | parser = argparse.ArgumentParser() | local_rank = int(os.environ["LOCAL_RANK"]) |
+ | parser.add_argument("--local_rank", type=int) | |
+ | args = parser.parse_args() | |
+ | | |
+ | local_rank = args.local_rank | |
+ | | |
+ +-------------------------------------------------------+----------------------------------------------------+
+
+The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torch.distributed.run``.
+To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torch.distributed.run``
+please refer to:
+
+* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torch.distributed.run`` compliant.
+* the rest of this page for more information on the features of ``torch.distributed.run``.
+
+
+
+Usage
+~~~~~~
1. Single-node multi-worker
@@ -188,8 +241,10 @@
**Important Notices:**
-1. All the items in the important notices section of ``torch.distributed.launch`` apply to this
- module as well.
+1. This utility and multi-process distributed (single-node or
+ multi-node) GPU training currently only achieves the best performance using
+ the NCCL distributed backend. Thus NCCL backend is the recommended backend to
+ use for GPU training.
2. The environment variables necessary to initialize a Torch process group are provided to you by
this module, no need for you to pass ``RANK`` manually. To initialize a process group in your
@@ -200,21 +255,41 @@
>>> import torch.distributed as dist
>>> dist.init_process_group(backend="gloo|nccl")
-3. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
+3. In your training program, you can either use regular distributed functions
+ or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
+ training program uses GPUs for training and you would like to use
+ :func:`torch.nn.parallel.DistributedDataParallel` module,
+ here is how to configure it.
+
+::
+
+ local_rank = int(os.environ["LOCAL_RANK"])
+ model = torch.nn.parallel.DistributedDataParallel(model,
+ device_ids=[local_rank],
+ output_device=local_rank)
+
+Please ensure that the ``device_ids`` argument is set to the only GPU device id
+that your code will be operating on. This is generally the local rank of the
+process. In other words, ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``
+and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
+utility.
+
+
+4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
for lost work.
-4. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
+5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
nodes run the same number of local workers (per role).
-5. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assgined a
+6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
different range of ranks than before. NEVER hard code any assumptions about the stable-ness of
ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.
-6. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
+7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join.
-7. It is recommended for your script to have the following structure:
+8. It is recommended for your script to have the following structure:
::
@@ -244,7 +319,7 @@
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-
+from torch.distributed.elastic.multiprocessing.errors import record
log = get_logger()
@@ -322,7 +397,7 @@
"--max_restarts",
action=env,
type=int,
- default=3,
+ default=0,
help="Maximum number of worker group restarts before failing.",
)
parser.add_argument(
@@ -570,11 +645,6 @@
cmd_args.append("-m")
cmd_args.append(args.training_script)
else:
- if not use_env:
- raise ValueError(
- "When using the '--no_python' flag,"
- " you must also set the '--use_env' flag."
- )
if args.module:
raise ValueError(
"Don't use both the '--no_python' flag"
@@ -582,10 +652,6 @@
)
cmd = args.training_script
if not use_env:
- log.warning(
- "--use_env is deprecated and will be removed in future releases.\n"
- " Please read local_rank from `os.environ['LOCAL_RANK']` instead."
- )
cmd_args.append(f"--local_rank={macros.local_rank}")
cmd_args.extend(args.training_script_args)
@@ -625,14 +691,11 @@
)(*cmd_args)
+@record
def main(args=None):
args = parse_args(args)
run(args)
if __name__ == "__main__":
- logging.basicConfig(
- level=logging.INFO, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
- )
- log.info(f"Running torch.distributed.run with args: {sys.argv}")
main()