[ROCm] Enable bf16-related tests in test_c10d_nccl.py and test_grad_layout_1devicemodule_1replicaperprocess (#82020)
### Description
Enable bf16-related unit tests in test_c10d_nccl.py and test_grad_layout_1devicemodule_1replicaperprocess as follows:
- distributed/test_c10d_nccl test_bf16_compress_wrapper_is_view (__main__.DistributedDataParallelTest)
- distributed/test_c10d_nccl test_bf16_compress_wrapper_nccl (__main__.DistributedDataParallelTest)
- distributed/test_c10d_nccl test_grad_layout_1devicemodule_1replicaperprocess (__main__.DistributedDataParallelTest)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82020
Approved by: https://github.com/ezyang
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 766c5ce..f858424 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -61,8 +61,12 @@
# bfloat16 is only supported by CUDA 11+
BFLOAT16_AVAILABLE = (
torch.cuda.is_available()
- and torch.version.cuda is not None
- and int(torch.version.cuda.split('.')[0]) >= 11)
+ and
+ (
+ (torch.version.cuda is not None and int(torch.version.cuda.split('.')[0]) >= 11)
+ or torch.version.hip is not None
+ )
+)
class RendezvousEnvTest(TestCase):
@retry_on_connect_failures
@@ -2096,7 +2100,6 @@
"BFloat16 is only supported by CUDA 11+",
)
@skip_if_lt_x_gpu(2)
- @skip_if_rocm
def test_bf16_compress_wrapper_nccl(self):
self._test_bf16_compress_wrapper()
@@ -2137,7 +2140,6 @@
"BFloat16 is only supported by CUDA 11+",
)
@skip_if_lt_x_gpu(2)
- @skip_if_rocm
def test_bf16_compress_wrapper_is_view(self):
self._test_bf16_compress_wrapper(gradient_as_bucket_view=True)
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 20fe354..cedb006 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1649,8 +1649,8 @@
hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"]
and
(
- torch.version.cuda is None
- or int(torch.version.cuda.split('.')[0]) < 11
+ (torch.version.cuda is None and torch.version.hip is None)
+ or (torch.version.cuda is not None and int(torch.version.cuda.split('.')[0]) < 11)
or not dist.is_available()
or not dist.is_nccl_available()
or torch.cuda.nccl.version() < (2, 10)