commit | 1ab997b0bfd798877c93d3edde6436881fd95abb | [log] [tgz] |
---|---|---|
author | Chris Jones <cjfj@google.com> | Thu Jun 24 04:07:41 2021 -0700 |
committer | TensorFlower Gardener <gardener@tensorflow.org> | Thu Jun 24 04:11:13 2021 -0700 |
tree | 586a72b562e30e6db04bcb26b17f7bfde9e91db4 | |
parent | b8c84970ef0a77678bd36164a359968a6a4bb01d [diff] |
[XLA:GPU] Add bf16 support to NCCL collectives. PiperOrigin-RevId: 381223653 Change-Id: I89a1119e0f24f00a6dfb385cc3d0db06a153cd8a
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc index cc86b6e..ef740db 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
@@ -185,6 +185,9 @@ case F16: case F32: case F64: +#if defined(__CUDA_BF16_TYPES_EXIST__) + case BF16: +#endif return true; default: return false;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc index dd064c9..d0927b1 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -68,6 +68,10 @@ return ncclFloat32; case F64: return ncclFloat64; +#if defined(__CUDA_BF16_TYPES_EXIST__) + case BF16: + return ncclBfloat16; +#endif default: return tensorflow::errors::InvalidArgument(absl::StrFormat( "Unsupported data type: %s", PrimitiveType_Name(element_type)));
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index 221b7326..28315ef 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -227,6 +227,11 @@ TestAllOpsForReduce<Eigen::half>(); } +XLA_TEST_F(CollectiveOpsTest, + DISABLED_ON_CPU(AllReduceTwoReplicasOneOperand_bfloat16)) { + TestAllOpsForReduce<bfloat16>(); +} + XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) { // Test with equal elements. TestTwoReplicasOneOperand<bool>(