[ROCm] Enable test_jit_c10.py tests for ROCm (#52410)
Summary:
Re-enabling these test cases for ROCm because they are passing.
jeffdaily
Signed-off-by: Kyle Chen <kylechen@amd.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52410
Reviewed By: glaringlee
Differential Revision: D26516757
Pulled By: malfet
fbshipit-source-id: 49921ee724a50f19afd8e6884a5f3ecd9291fa5c
diff --git a/test/distributed/test_jit_c10d.py b/test/distributed/test_jit_c10d.py
index ca43c25..f21c4ec 100644
--- a/test/distributed/test_jit_c10d.py
+++ b/test/distributed/test_jit_c10d.py
@@ -8,7 +8,7 @@
from typing import List
import torch.testing._internal.common_utils as common
-from torch.testing._internal.common_distributed import requires_nccl, skip_if_rocm_single_process
+from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_utils import load_tests, TEST_WITH_TSAN, run_tests, IS_WINDOWS
from torch.testing._internal.jit_utils import JitTestCase
@@ -71,12 +71,10 @@
self.world_size, self.rank, [], "nccl", tcp_store, name, 0)
@requires_nccl()
- @skip_if_rocm_single_process
def test_init_process_group_nccl_torchbind(self):
self._create_nccl_pg("raw_process_group_nccl_torchbind")
@requires_nccl()
- @skip_if_rocm_single_process
def test_process_group_nccl_torchbind_alltoall(self):
nccl_pg = self._create_nccl_pg("process_group_nccl_as_base_class")
@@ -98,13 +96,11 @@
run_pg_nccl_alltoall(nccl_pg, output, input)
@requires_nccl()
- @skip_if_rocm_single_process
def test_init_process_group_nccl_as_base_process_group_torchbind(self):
name = unique_process_group_name("creation_test_process_group")
self._create_nccl_pg_as_base_process_group(name)
@requires_nccl()
- @skip_if_rocm_single_process
def test_process_group_nccl_as_base_process_group_torchbind_alltoall(self):
name = unique_process_group_name("alltoall_test_process_group")
nccl_pg = self._create_nccl_pg_as_base_process_group(name)
@@ -127,7 +123,6 @@
run_pg_nccl_alltoall(nccl_pg, output, input)
@requires_nccl()
- @skip_if_rocm_single_process
def test_process_group_nccl_serialization(self):
class TestModule(torch.nn.Module):
def __init__(self, pg_nccl):
@@ -185,7 +180,6 @@
raise unittest.SkipTest("NCCL test requires 2+ GPUs")
@requires_nccl()
- @skip_if_rocm_single_process
def test_frontend_singleton(self):
frontend1 = torch.classes.dist_c10d.frontend()
frontend2 = torch.classes.dist_c10d.frontend()
@@ -208,7 +202,6 @@
raise unittest.SkipTest("NCCL test requires 2+ GPUs")
@requires_nccl()
- @skip_if_rocm_single_process
def test_process_group_as_module_member(self):
class TestModule(torch.nn.Module):
def __init__(self):