[XLA:GPU] Add bf16 support to NCCL collectives.

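The new BF16 case follows the pattern already used for the other floating-point
types: the XLA PrimitiveType is mapped to the matching ncclDataType_t, and the
case is guarded by __CUDA_BF16_TYPES_EXIST__ to match the guard nccl.h itself
uses when declaring ncclBfloat16 (the macro comes from the CUDA 11+ headers
that define __nv_bfloat16).

A minimal sketch of the mapping pattern, for illustration only (the function
name and include paths below are assumptions; the real helper in nccl_utils.cc
returns a StatusOr and reports InvalidArgument for unsupported types):

    // Sketch: simplified PrimitiveType -> ncclDataType_t mapping mirroring the
    // guarded BF16 case added in this change. Not the real XLA helper.
    #include <cstdlib>
    #include "third_party/nccl/nccl.h"                 // assumed include path
    #include "tensorflow/compiler/xla/xla_data.pb.h"   // defines PrimitiveType

    ncclDataType_t ToNcclDataTypeSketch(xla::PrimitiveType element_type) {
      switch (element_type) {
        case xla::F16:
          return ncclFloat16;
        case xla::F32:
          return ncclFloat32;
        case xla::F64:
          return ncclFloat64;
    #if defined(__CUDA_BF16_TYPES_EXIST__)
        // ncclBfloat16 is only declared when __nv_bfloat16 exists, hence the
        // same guard in nccl_collective_thunk.cc and nccl_utils.cc.
        case xla::BF16:
          return ncclBfloat16;
    #endif
        default:
          // The real helper returns an InvalidArgument error here.
          std::abort();
      }
    }
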
PiperOrigin-RevId: 381223653
Change-Id: I89a1119e0f24f00a6dfb385cc3d0db06a153cd8a
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
index cc86b6e..ef740db 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
@@ -185,6 +185,9 @@
     case F16:
     case F32:
     case F64:
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+    case BF16:
+#endif
       return true;
     default:
       return false;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index dd064c9..d0927b1 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -68,6 +68,10 @@
       return ncclFloat32;
     case F64:
       return ncclFloat64;
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+    case BF16:
+      return ncclBfloat16;
+#endif
     default:
       return tensorflow::errors::InvalidArgument(absl::StrFormat(
           "Unsupported data type: %s", PrimitiveType_Name(element_type)));
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index 221b7326..28315ef 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -227,6 +227,11 @@
   TestAllOpsForReduce<Eigen::half>();
 }
 
+XLA_TEST_F(CollectiveOpsTest,
+           DISABLED_ON_CPU(AllReduceTwoReplicasOneOperand_bfloat16)) {
+  TestAllOpsForReduce<bfloat16>();
+}
+
 XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
   // Test with equal elements.
   TestTwoReplicasOneOperand<bool>(