[DTensor] remove redundant device mesh test code (#92069)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92069
Approved by: https://github.com/wanchaol
diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/_tensor/test_device_mesh.py
index 947b2bd..27c7582 100644
--- a/test/distributed/_tensor/test_device_mesh.py
+++ b/test/distributed/_tensor/test_device_mesh.py
@@ -70,17 +70,13 @@
             mesh = DeviceMesh(device_type, mesh_tensor)
 
     def test_init_process_group(self):
-        device_type = "cuda" if torch.cuda.is_available() else "cpu"
-        backend = "nccl" if device_type == "cuda" else "gloo"
+        device_type, backend = _get_device_type_and_backend()
         # skip the test if not enough GPUs
         if backend == "nccl" and torch.cuda.device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
         mesh_tensor = torch.arange(4).reshape(2, 2)
         self.assertTrue(not is_initialized())
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "25364"
-        os.environ["WORLD_SIZE"] = f"{self.world_size}"
-        os.environ["RANK"] = f"{self.rank}"
+        _set_env_var(world_size=self.world_size, rank=self.rank)
         mesh = DeviceMesh(device_type, mesh_tensor)
         self.assertTrue(is_initialized())
         self.destroy_pg()
@@ -187,16 +183,12 @@
         return 8
 
     def test_mesh_size_requirement_error(self):
-        device_type = "cuda" if torch.cuda.is_available() else "cpu"
-        backend = "nccl" if device_type == "cuda" else "gloo"
+        device_type, backend = _get_device_type_and_backend()
         # skip the test if not enough GPUs
         if backend == "nccl" and torch.cuda.device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
         mesh_tensor = torch.arange(4).reshape(2, 2)
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "25364"
-        os.environ["WORLD_SIZE"] = f"{self.world_size}"
-        os.environ["RANK"] = f"{self.rank}"
+        _set_env_var(world_size=self.world_size, rank=self.rank)
         with self.assertRaisesRegex(RuntimeError, "DeviceMesh must include every process in WORLD"):
             mesh = DeviceMesh(device_type, mesh_tensor)
         self.assertTrue(not is_initialized())