| #!/usr/bin/env python3 |
| # Owner(s): ["oncall: r2p"] |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| import multiprocessing as mp |
| import os |
| import shutil |
| import signal |
| import sys |
| import tempfile |
| import time |
| import unittest |
| import uuid |
| from contextlib import closing |
| from typing import Any, Dict, Optional |
| from unittest import mock |
| from unittest.mock import MagicMock, Mock, patch |
| |
| import torch |
| import torch.distributed as dist |
| from torch.distributed.elastic.agent.server.api import RunResult, WorkerState |
| from torch.distributed.elastic.multiprocessing.api import SignalException |
| from torch.distributed.elastic.multiprocessing.errors import ChildFailedError |
| from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer |
| from torch.distributed.elastic.utils import get_socket_with_port |
| from torch.distributed.launcher.api import ( |
| LaunchConfig, |
| _get_entrypoint_name, |
| elastic_launch, |
| launch_agent, |
| ) |
| from torch.testing._internal.common_utils import ( |
| TEST_WITH_DEV_DBG_ASAN, |
| sandcastle_skip_if, |
| ) |
| |
| |
def path(script):
    """Return ``script`` joined onto the directory containing this test module."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, script)
| |
| |
def simple_rank_scale():
    """Worker entrypoint: return this worker's global ``RANK`` plus 10."""
    return int(os.environ["RANK"]) + 10
| |
| |
def function_with_bug():
    """Worker entrypoint that always fails with ``RuntimeError("test error")``."""
    message = "test error"
    raise RuntimeError(message)
| |
| |
def get_test_launch_config(
    rdzv_endpoint: str,
    min_nodes: int,
    max_nodes: int,
    nproc_per_node: int,
    run_id: str = "",
    rdzv_backend: str = "etcd",
    config: Optional[Dict[str, Any]] = None,
) -> LaunchConfig:
    """Build the ``LaunchConfig`` shared by the tests in this module.

    Workers use the ``spawn`` start method, restarts are disabled
    (``max_restarts=0``) and ``monitor_interval`` is 1. When ``config`` is
    given it is shallow-copied into ``rdzv_configs``, so the caller's dict is
    never handed to ``LaunchConfig`` itself.
    """
    return LaunchConfig(
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        nproc_per_node=nproc_per_node,
        run_id=run_id,
        rdzv_endpoint=rdzv_endpoint,
        monitor_interval=1,
        rdzv_backend=rdzv_backend,
        start_method="spawn",
        max_restarts=0,
        rdzv_configs=dict(config) if config else {},
    )
| |
| |
def elastic_launch_wrapper(
    test_dir: str,
    rdzv_endpoint: str,
    min_nodes: int,
    max_nodes: int,
    nproc_per_node: int,
    run_id: str,
):
    """Run ``bin/test_script.py`` under ``elastic_launch`` as one agent.

    This is a plain module-level function so it can serve as a
    ``multiprocessing`` target. Any launch failure propagates as an exception,
    which gives the hosting process a non-zero exit code.
    """
    launch_config = get_test_launch_config(
        rdzv_endpoint, min_nodes, max_nodes, nproc_per_node, run_id
    )
    launcher = elastic_launch(launch_config, sys.executable)
    launcher("-u", path("bin/test_script.py"), f"--touch_file_dir={test_dir}")
| |
| |
def _dist_sum(wait=0):
    """Worker entrypoint: all-reduce (sum) the global ranks of every worker.

    Initializes a ``gloo`` process group from the environment and optionally
    sleeps before the collective. It returns the reduced value and always
    tears the process group down afterwards.

    Args:
        wait: seconds to sleep before calling ``all_reduce``.

    Returns:
        The sum of all participating workers' ``RANK`` values, as a Python number.
    """
    rank = int(os.environ["RANK"])
    dist.init_process_group(backend="gloo")
    try:
        t = torch.tensor(rank)

        time.sleep(wait)
        # ``dist.reduce_op`` is a deprecated alias; ``ReduceOp`` is the supported enum.
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t.item()
    finally:
        # Release the process group instead of leaving it for interpreter exit.
        dist.destroy_process_group()
| |
| |
# Dotted ``mock.patch`` targets used by the launcher tests below.
_LAUNCHER_API = "torch.distributed.launcher.api"
ELASTIC_AGENT_RUN = f"{_LAUNCHER_API}.LocalElasticAgent.run"
EVENTS_RECORD = f"{_LAUNCHER_API}.events.record"
GET_RDZV_HANDLER = ".".join(
    ("torch.distributed.elastic.rendezvous.registry", "get_rendezvous_handler")
)
| |
| |
class MockException(Exception):
    """Sentinel error used as a mock ``side_effect`` so tests can check that
    the launcher re-raises exactly the exception the mocked call produced."""
| |
| |
def short_hash():
    """Return a random 8-character lowercase hex id (first group of a UUID4)."""
    return uuid.uuid4().hex[:8]
| |
| |
class ElasticLaunchTest(unittest.TestCase):
    """Tests for ``elastic_launch`` / ``launch_agent`` against a local etcd server."""

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests.
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server.
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

        # remove any lingering environment variables. Iterate over a snapshot
        # of the keys because entries are deleted inside the loop.
        for env in list(os.environ):
            if env.startswith("PET_"):
                del os.environ[env]

        # set a sentinel env var on the parent proc.
        # this should be present on the child and gets
        # asserted in ``bin/test_script.py``.
        os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR"
        os.environ["OMP_NUM_THREADS"] = str(1)

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def check_works_ran(self, world_size: int):
        """Assert ``test_dir`` holds exactly one file per global rank in ``[0, world_size)``."""
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_launch_script_python(self):
        nnodes = 1
        nproc_per_node = 4

        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

        # make sure all the workers ran.
        # each worker touches a file with its global rank as the name.
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_launch_script_python_local_rank_transfer(self):
        # NOTE(review): this body is currently identical to
        # ``test_launch_script_python``; nothing here exercises a
        # local-rank-specific code path. Confirm the intended coverage.
        nnodes = 1
        nproc_per_node = 4

        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

        # make sure all the workers ran.
        # each worker touches a file with its global rank as the name.
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_launch_script_bash(self):
        nnodes = 1
        nproc_per_node = 4

        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            path("bin/test_script.sh"),
        )(f"{self.test_dir}")

        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_launch_function(self):
        nnodes = 1
        nproc_per_node = 4

        res = elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            simple_rank_scale,
        )()

        expected_res = [10, 11, 12, 13]
        actual_res = sorted(value for value in res.values())
        self.assertEqual(expected_res, actual_res)

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_launch_dist_sum_with_static_rdzv(self):
        nnodes = 1
        nproc_per_node = 4
        # grab a free port, then release it so the static rendezvous can bind it.
        sock = get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]
        rdzv_endpoint = f"127.0.0.1:{master_port}"
        rank = 0
        rdzv_config = {
            "rank": rank,
        }

        res = elastic_launch(
            get_test_launch_config(
                rdzv_endpoint,
                nnodes,
                nnodes,
                nproc_per_node,
                rdzv_backend="static",
                config=rdzv_config,
            ),
            _dist_sum,
        )()

        expected_res = [sum(range(nproc_per_node))] * nproc_per_node
        actual_res = sorted(value for value in res.values())
        self.assertEqual(expected_res, actual_res)

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_launch_elastic(self):
        nproc_per_node = 4

        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, 1, 2, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

        world_size = nproc_per_node
        self.check_works_ran(world_size)

    @mock.patch("torch.distributed.elastic.events.record")
    def test_launch_elastic_worker_raise_exception(self, record_mock):
        """
        Asserts that when the worker program fails, the launcher raises
        ``ChildFailedError`` to indicate that a worker process failed.
        """
        nproc_per_node = 4

        with self.assertRaises(ChildFailedError):
            elastic_launch(
                get_test_launch_config(self._etcd_endpoint, 1, 2, nproc_per_node),
                sys.executable,
            )("-u", path("bin/test_script.py"), "--fail")

        record_mock.assert_called_once()

    @mock.patch("torch.distributed.elastic.events.record")
    @mock.patch(
        "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run"
    )
    def test_launch_elastic_agent_raise_exception(self, mock_agent_run, record_mock):
        """
        Asserts that when the agent raises an exception
        the launcher re-raises the original exception.
        """
        # Stacked ``mock.patch`` decorators inject mocks bottom-up, so the
        # ``LocalElasticAgent.run`` mock is the first argument and
        # ``events.record`` the second.
        mock_agent_run.side_effect = MockException
        with self.assertRaises(MockException):
            elastic_launch(
                get_test_launch_config(self._etcd_endpoint, 1, 2, 4),
                sys.executable,
            )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")
        record_mock.assert_called_once()

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_launch_elastic_multiple_agents(self):
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        # run ``nnodes - 1`` agents in spawned processes and the last one inline.
        procs = []
        ctx = mp.get_context("spawn")
        for _ in range(nnodes - 1):
            p = ctx.Process(
                target=elastic_launch_wrapper,
                args=(
                    self.test_dir,
                    self._etcd_endpoint,
                    min_nodes,
                    max_nodes,
                    nproc_per_node,
                    run_id,
                ),
            )
            procs.append(p)
            p.start()

        try:
            elastic_launch_wrapper(
                self.test_dir,
                self._etcd_endpoint,
                min_nodes,
                max_nodes,
                nproc_per_node,
                run_id,
            )
        except BaseException:
            # don't leave the spawned agents running if the inline one failed.
            for p in procs:
                p.terminate()
                p.join()
            raise

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @patch("torch.distributed.launcher.api.LocalElasticAgent")
    def test_launch_shutdown(self, agent_mock_cls):
        agent_mock = Mock()
        agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
        agent_mock_cls.return_value = agent_mock
        rdzv_handler_mock = Mock()
        with patch(GET_RDZV_HANDLER) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            elastic_launch(
                get_test_launch_config(self._etcd_endpoint, 1, 1, 4),
                sys.executable,
            )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

            rdzv_handler_mock.shutdown.assert_called_once()

    def test_get_entrypoint_name(self):
        self.assertEqual(
            "simple_rank_scale", _get_entrypoint_name(simple_rank_scale, [])
        )
        self.assertEqual("", _get_entrypoint_name(sys.executable, []))
        self.assertEqual("", _get_entrypoint_name(sys.executable, ["-u"]))
        self.assertEqual(
            "test_script.py",
            _get_entrypoint_name(sys.executable, ["-u", "test_script.py"]),
        )
        self.assertEqual("", _get_entrypoint_name(None, []))

    @patch(ELASTIC_AGENT_RUN)
    @patch(GET_RDZV_HANDLER)
    def test_rdzv_handler_shutdown_on_agent_signal(self, mock_get_rdzv, mock_agent_run):
        config = get_test_launch_config(
            self._etcd_endpoint, min_nodes=1, max_nodes=1, nproc_per_node=1
        )

        for sigval in [signal.SIGTERM, signal.SIGINT]:
            with patch(EVENTS_RECORD) as record_event_mock:
                rdzv_handler_mock = MagicMock()
                rdzv_handler_mock.get_run_id.return_value = short_hash()
                mock_get_rdzv.return_value = rdzv_handler_mock

                mock_agent_run.side_effect = SignalException("test", sigval)
                with self.assertRaises(SignalException):
                    launch_agent(config, simple_rank_scale, [])
                rdzv_handler_mock.shutdown.assert_not_called()
                record_event_mock.assert_called_once()

    @patch(ELASTIC_AGENT_RUN)
    @patch(GET_RDZV_HANDLER)
    def test_rdzv_handler_shutdown_on_agent_error(self, mock_get_rdzv, mock_agent_run):
        config = get_test_launch_config(
            self._etcd_endpoint, min_nodes=1, max_nodes=1, nproc_per_node=1
        )

        with patch(EVENTS_RECORD) as record_event_mock:
            rdzv_handler_mock = MagicMock()
            rdzv_handler_mock.get_run_id.return_value = short_hash()
            mock_get_rdzv.return_value = rdzv_handler_mock

            mock_agent_run.side_effect = RuntimeError("any other exception")
            with self.assertRaises(RuntimeError):
                launch_agent(config, simple_rank_scale, [])
            rdzv_handler_mock.shutdown.assert_called_once()
            record_event_mock.assert_called_once()