[ROCm] Enable bf16-related tests in test_c10d_nccl.py and test_grad_layout_1devicemodule_1replicaperprocess (#82020)

### Description
Enable bf16-related unit tests in test_c10d_nccl.py and test_grad_layout_1devicemodule_1replicaperprocess as follows:

- distributed/test_c10d_nccl test_bf16_compress_wrapper_is_view (__main__.DistributedDataParallelTest)
- distributed/test_c10d_nccl test_bf16_compress_wrapper_nccl (__main__.DistributedDataParallelTest)
- distributed/test_c10d_nccl test_grad_layout_1devicemodule_1replicaperprocess (__main__.DistributedDataParallelTest)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82020
Approved by: https://github.com/ezyang
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 766c5ce..f858424 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -61,8 +61,12 @@
 # bfloat16 is only supported by CUDA 11+
 BFLOAT16_AVAILABLE = (
     torch.cuda.is_available()
-    and torch.version.cuda is not None
-    and int(torch.version.cuda.split('.')[0]) >= 11)
+    and
+    (
+        (torch.version.cuda is not None and int(torch.version.cuda.split('.')[0]) >= 11)
+        or torch.version.hip is not None
+    )
+)
 
 class RendezvousEnvTest(TestCase):
     @retry_on_connect_failures
@@ -2096,7 +2100,6 @@
         "BFloat16 is only supported by CUDA 11+",
     )
     @skip_if_lt_x_gpu(2)
-    @skip_if_rocm
     def test_bf16_compress_wrapper_nccl(self):
         self._test_bf16_compress_wrapper()
 
@@ -2137,7 +2140,6 @@
         "BFloat16 is only supported by CUDA 11+",
     )
     @skip_if_lt_x_gpu(2)
-    @skip_if_rocm
     def test_bf16_compress_wrapper_is_view(self):
         self._test_bf16_compress_wrapper(gradient_as_bucket_view=True)
 
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 20fe354..cedb006 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1649,8 +1649,8 @@
             hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"]
             and
             (
-                torch.version.cuda is None
-                or int(torch.version.cuda.split('.')[0]) < 11
+                (torch.version.cuda is None and torch.version.hip is None)
+                or (torch.version.cuda is not None and int(torch.version.cuda.split('.')[0]) < 11)
                 or not dist.is_available()
                 or not dist.is_nccl_available()
                 or torch.cuda.nccl.version() < (2, 10)