[DTensor] remove redundant device mesh test code (#92069)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92069
Approved by: https://github.com/wanchaol
diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/_tensor/test_device_mesh.py
index 947b2bd..27c7582 100644
--- a/test/distributed/_tensor/test_device_mesh.py
+++ b/test/distributed/_tensor/test_device_mesh.py
@@ -70,17 +70,13 @@
mesh = DeviceMesh(device_type, mesh_tensor)
def test_init_process_group(self):
- device_type = "cuda" if torch.cuda.is_available() else "cpu"
- backend = "nccl" if device_type == "cuda" else "gloo"
+ device_type, backend = _get_device_type_and_backend()
# skip the test if not enough GPUs
if backend == "nccl" and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
mesh_tensor = torch.arange(4).reshape(2, 2)
self.assertTrue(not is_initialized())
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = "25364"
- os.environ["WORLD_SIZE"] = f"{self.world_size}"
- os.environ["RANK"] = f"{self.rank}"
+ _set_env_var(world_size=self.world_size, rank=self.rank)
mesh = DeviceMesh(device_type, mesh_tensor)
self.assertTrue(is_initialized())
self.destroy_pg()
@@ -187,16 +183,12 @@
return 8
def test_mesh_size_requirement_error(self):
- device_type = "cuda" if torch.cuda.is_available() else "cpu"
- backend = "nccl" if device_type == "cuda" else "gloo"
+ device_type, backend = _get_device_type_and_backend()
# skip the test if not enough GPUs
if backend == "nccl" and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
mesh_tensor = torch.arange(4).reshape(2, 2)
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = "25364"
- os.environ["WORLD_SIZE"] = f"{self.world_size}"
- os.environ["RANK"] = f"{self.rank}"
+ _set_env_var(world_size=self.world_size, rank=self.rank)
with self.assertRaisesRegex(RuntimeError, "DeviceMesh must include every process in WORLD"):
mesh = DeviceMesh(device_type, mesh_tensor)
self.assertTrue(not is_initialized())