Migrate test to internal base class, fixes (#128367)
Summary:
## Remove etc deps
converted tests to non-etcd based rdzv handler so that tests don't have dependency on etcd server
## Adopt pytorch test convetions
- test starts with `test_TESTS.py`
- Test base class is torch.testing._internal.common_utils.TestCase
- include __main__ handler
## reduce test timing (used to take > 300 seconds):
3.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_init_method_env_with_torchelastic
2.59s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_init_method_tcp_with_torchelastic
2.33s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_worker_raise_exception
2.33s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_run_path
2.30s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_launch_auto_configurations
2.24s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_is_torchelastic_launched_with_logs_spec_defined
2.24s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_is_torchelastic_launched
2.17s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_multiple_agents
2.12s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic
2.08s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_gpu_launch_configurations
1.32s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_standalone
1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_launch_number_configurations
1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_with_env_vars
1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_python
1.05s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_python_caffe2_bc
1.04s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_bash
1.03s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_default_nproc
0.04s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_logs_logs_spec_entrypoint_must_be_defined
0.01s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_agent_raise_exception
0.01s call test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_shutdown
Test Plan: pytest --durations=0 test/distributed/launcher/run_test.py
Differential Revision: D58388182
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128367
Approved by: https://github.com/d4l3k
diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/test_run.py
similarity index 89%
rename from test/distributed/launcher/run_test.py
rename to test/distributed/launcher/test_run.py
index c816042..ba58aec 100644
--- a/test/distributed/launcher/run_test.py
+++ b/test/distributed/launcher/test_run.py
@@ -13,7 +13,6 @@
import subprocess
import sys
import tempfile
-import unittest
import uuid
from contextlib import closing
from unittest import mock
@@ -23,12 +22,13 @@
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
-from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
+ run_tests,
skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN,
+ TestCase,
)
@@ -63,19 +63,7 @@
pass
-class ElasticLaunchTest(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- # start a standalone, single process etcd server to use for all tests
- cls._etcd_server = EtcdServer()
- cls._etcd_server.start()
- cls._etcd_endpoint = cls._etcd_server.get_endpoint()
-
- @classmethod
- def tearDownClass(cls):
- # stop the standalone etcd server
- cls._etcd_server.stop()
-
+class ElasticLaunchTest(TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
@@ -103,8 +91,6 @@
args = [
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@@ -156,8 +142,6 @@
args = [
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@@ -187,8 +171,6 @@
world_size = 1
args = [
f"--nnodes={nnodes}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@@ -220,8 +202,6 @@
os.environ["PET_NNODES"] = str(nnodes)
os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node)
- os.environ["PET_RDZV_BACKEND"] = "etcd"
- os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
os.environ["PET_RDZV_ID"] = run_id
os.environ["PET_MONITOR_INTERVAL"] = "1"
os.environ["PET_START_METHOD"] = "spawn"
@@ -250,8 +230,6 @@
args = [
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_type}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@@ -272,7 +250,8 @@
@skip_but_pass_in_sandcastle_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
- def test_nproc_launch_auto_configurations(self):
+ @patch("torch.cuda.is_available", return_value=False)
+ def test_nproc_launch_auto_configurations(self, _mock1):
self._test_nproc_launch_configuration("auto", os.cpu_count())
@skip_but_pass_in_sandcastle_if(
@@ -310,8 +289,9 @@
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{get_free_port()}",
+ "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@@ -343,8 +323,9 @@
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{get_free_port()}",
+ "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--max-restarts=0",
@@ -376,8 +357,9 @@
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{get_free_port()}",
+ "--rdzv_conf=timeout=5",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--max-restarts=0",
@@ -452,8 +434,9 @@
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
- "--rdzv-backend=etcd",
- f"--rdzv-endpoint={self._etcd_endpoint}",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{get_free_port()}",
+ "--rdzv_conf=timeout=5",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@@ -608,21 +591,6 @@
is_torchelastic_launched = fp.readline()
self.assertEqual("False", is_torchelastic_launched)
- def test_init_method_tcp(self):
- port = get_free_port()
- with patch.object(
- sys,
- "argv",
- [
- path("bin/test_script_init_method.py"),
- f"--init-method=tcp://localhost:{port}",
- "--rank=0",
- "--world-size=1",
- ],
- ):
- runpy.run_path(sys.argv[0], run_name="__main__")
- # nothing to validate, just make sure it runs
-
@skip_but_pass_in_sandcastle_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@@ -642,27 +610,6 @@
)
# nothing to validate, just make sure it runs
- def test_init_method_env(self):
- port = get_free_port()
- with patch.dict(
- os.environ,
- {
- "RANK": "0",
- "WORLD_SIZE": "1",
- "MASTER_ADDR": "localhost",
- "MASTER_PORT": str(port),
- },
- ), patch.object(
- sys,
- "argv",
- [
- path("bin/test_script_init_method.py"),
- "--init-method=env://",
- ],
- ):
- runpy.run_path(sys.argv[0], run_name="__main__")
- # nothing to validate, just make sure it runs
-
@skip_but_pass_in_sandcastle_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@@ -681,3 +628,7 @@
]
)
# nothing to validate, just make sure it runs
+
+
+if __name__ == "__main__":
+ run_tests()