[XLA] CPU AllReduce implementation: same process, all work done by "main" worker

Only "virtual" devices are supported which share the same thread pool.
Communication is performed via a global map.
This all-reduce implementation is not truly "collective": all the work is
performed by the main worker while the other threads simply wait.
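
As a rough standalone illustration of the synchronization pattern (plain C++
sketch, not XLA code; participant count and buffer contents are made up):

  #include <condition_variable>
  #include <iostream>
  #include <mutex>
  #include <thread>
  #include <vector>

  int main() {
    constexpr int kParticipants = 4;
    // Every participant contributes one input buffer; all of them read back
    // the same reduced result.
    std::vector<std::vector<int>> inputs(kParticipants, {1, 2, 3});
    std::vector<int> output(3, 0);

    std::mutex mu;
    std::condition_variable cv;
    int arrived = 0;
    bool done = false;

    auto worker = [&] {
      std::unique_lock<std::mutex> lock(mu);
      if (++arrived == kParticipants) {
        // The last thread to arrive plays the role of the "main" worker and
        // performs the whole element-wise reduction itself.
        for (const auto& in : inputs) {
          for (int i = 0; i < 3; ++i) output[i] += in[i];
        }
        done = true;
        cv.notify_all();
      } else {
        // Everyone else just blocks until the reduction is finished.
        cv.wait(lock, [&] { return done; });
      }
    };

    std::vector<std::thread> threads;
    for (int i = 0; i < kParticipants; ++i) threads.emplace_back(worker);
    for (std::thread& t : threads) t.join();
    std::cout << output[0] << " " << output[1] << " " << output[2] << "\n";
  }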

PiperOrigin-RevId: 277548660
Change-Id: Iccd438dbe81329a02a157d9aae72d23c706b9f52
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index a935916..f2fa964 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -218,7 +218,7 @@
     WaitAndLogIfStuck(&all_participants_present_, [&] {
       return absl::StrFormat(
           "participant for device ordinal %d, stream %p waiting for all "
-          "participants to be arrive at rendezvous %s",
+          "participants to arrive at rendezvous %s",
           participant.device_ordinal, participant.stream, key_.ToString());
     });
 
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index b30754f..fa9d00a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -288,6 +288,7 @@
         "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/compiler/xla/service:collective_ops_utils",
         "//tensorflow/compiler/xla/service:elemental_ir_emitter",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
@@ -304,6 +305,10 @@
         "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
         "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
         "//tensorflow/core:lib",
+        "//tensorflow/core/lib/math:math_util",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
@@ -501,15 +506,27 @@
     copts = runtime_copts(),
     deps = [
         "//tensorflow/compiler/xla:executable_run_options",
+        "//tensorflow/compiler/xla:refcounting_hash_map",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:collective_ops_utils",
         "//tensorflow/compiler/xla/service:computation_placer",
-        "//tensorflow/core:lib",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "//tensorflow/core/platform:dynamic_annotations",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:mutex",
+        "//tensorflow/core/platform:platform_port",
+        "//tensorflow/core/platform:types",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 645d71b..32bd6cb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -15,23 +15,34 @@
 
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
 
+#include <cstddef>
 #include <cstring>
 #include <functional>
+#include <limits>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/refcounting_hash_map.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/dynamic_annotations.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/stream_executor.h"
 
+namespace se = ::stream_executor;
+
 namespace xla {
 namespace cpu {
 namespace runtime {
@@ -100,6 +111,7 @@
     "__xla_cpu_runtime_TracingStart";
 extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
+extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
 extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
 
 }  // namespace runtime
@@ -238,6 +250,206 @@
                                          std::move(shape));
 }
 
+namespace {
+
+class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
+ public:
+  explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
+      : xla::Rendezvous<std::nullptr_t>(k) {}
+
+ protected:
+  xla::StatusOr<std::pair<std::nullptr_t, bool>> SubmitParticipantImpl(
+      xla::AllReduceParticipantData participant) override {
+    xla::PrimitiveType datatype = participant.primitive_type;
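+    // The first participant to get here becomes the "primary" and performs
+    // the reduction on behalf of everyone; the remaining participants only
+    // wait for it to finish.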
+    bool primary = [&] {
+      tensorflow::mutex_lock lock(mu_);
+      if (!initialized_) {
+        initialized_ = true;
+        return true;
+      }
+      return false;
+    }();
+
+    if (primary) {
+      switch (datatype) {
+        case xla::S8:
+          DoAllReduce<xla::S8>(participant);
+          break;
+        case xla::U8:
+          DoAllReduce<xla::U8>(participant);
+          break;
+        case xla::S32:
+          DoAllReduce<xla::S32>(participant);
+          break;
+        case xla::U32:
+          DoAllReduce<xla::U32>(participant);
+          break;
+        case xla::S64:
+          DoAllReduce<xla::S64>(participant);
+          break;
+        case xla::U64:
+          DoAllReduce<xla::U64>(participant);
+          break;
+        case xla::F16:
+          DoAllReduce<xla::F16>(participant);
+          break;
+        case xla::F32:
+          DoAllReduce<xla::F32>(participant);
+          break;
+        case xla::F64:
+          DoAllReduce<xla::F64>(participant);
+          break;
+        default:
+          LOG(FATAL) << "Unexpected datatype";
+      }
+    }
+
+    // First element is a dummy value.
+    return std::make_pair(nullptr, primary);
+  }
+
+ private:
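+  // Performs the element-wise reduction over the input buffers of all
+  // participants and writes the result into every participant's output
+  // buffer.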
+  template <xla::PrimitiveType PT>
+  void DoAllReduce(xla::AllReduceParticipantData participant) {
+    using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
+    tensorflow::mutex_lock lock(mu_);
+    CHECK(!participants_.empty());
+    xla::int64 element_count = participant.element_count;
+    xla::ReductionKind reduction_kind = participant.reduction_kind;
+    for (const auto& p : participants_) {
+      CHECK_EQ(p.element_count, element_count);
+      CHECK(p.reduction_kind == reduction_kind);
+    }
+
+    std::vector<absl::Span<T>> input_buffers;
+    std::vector<absl::Span<T>> output_buffers;
+    input_buffers.reserve(participants_.size());
+    output_buffers.reserve(participants_.size());
+
+    for (auto& p : participants_) {
+      input_buffers.emplace_back(static_cast<T*>(p.source_data.opaque()),
+                                 element_count);
+      output_buffers.emplace_back(static_cast<T*>(p.destination_data.opaque()),
+                                  element_count);
+    }
+
+    auto compute = [reduction_kind](T a, T b) -> T {
+      switch (reduction_kind) {
+        case xla::ReductionKind::SUM:
+          return a + b;
+        case xla::ReductionKind::PRODUCT:
+          return a * b;
+        case xla::ReductionKind::MIN:
+          return std::min(a, b);
+        case xla::ReductionKind::MAX:
+          return std::max(a, b);
+      }
+    };
+
+    for (int idx = 0; idx < element_count; idx++) {
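+      // Seed the accumulator with the identity element of the reduction.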
+      T out = [&]() -> T {
+        switch (reduction_kind) {
+          case xla::ReductionKind::SUM:
+            return static_cast<T>(0);
+          case xla::ReductionKind::PRODUCT:
+            return static_cast<T>(1);
+          case xla::ReductionKind::MIN:
+            return std::numeric_limits<T>::max();
+          case xla::ReductionKind::MAX:
+            // Note: lowest(), not min(); for floating-point types min() is
+            // the smallest *positive* value, not the most negative one.
+            return std::numeric_limits<T>::lowest();
+        }
+      }();
+
+      for (auto& input : input_buffers) {
+        out = compute(out, input[idx]);
+      }
+      for (auto& output : output_buffers) {
+        output[idx] = out;
+      }
+    }
+  }
+};
+
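+// Global map through which participants of the same all-reduce find their
+// shared rendezvous object, keyed by RendezvousKey.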
+xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
+GlobalRendezvousMap() {
+  static auto& m =
+      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>(
+          [](const xla::RendezvousKey& k) {
+            return absl::make_unique<CpuAllReduceRendezvous>(k);
+          });
+  return m;
+}
+
+}  // namespace
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
+    const xla::ExecutableRunOptions* run_options,
+    const void* replica_groups_str, xla::int32 replica_groups_str_size,
+    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
+    const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
+    void* output_buffer) {
+  absl::string_view replica_groups_serialized(
+      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
+
+  // FIXME(cheshire): avoid repetition w/__xla_cpu_runtime_ReplicaId.
+  int device_ordinal = [&] {
+    if (run_options->stream()) {
+      return run_options->stream()->parent()->device_ordinal();
+    } else {
+      return run_options->device_ordinal();
+    }
+  }();
+
+  std::vector<xla::ReplicaGroup> group =
+      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
+  xla::int32 replica_count = run_options->device_assignment()->replica_count();
+  std::vector<xla::int64> participating_replicas_vec =
+      xla::GetParticipatingReplicas(device_ordinal, group, replica_count,
+                                    *run_options->device_assignment())
+          .ValueOrDie();
+
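+  // Participants of the same collective operation (same run, same set of
+  // participating replicas, same op) construct identical rendezvous keys and
+  // therefore meet at the same rendezvous object.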
+  xla::RendezvousKey::CollectiveOpKind op_kind =
+      channel_id_present ? xla::RendezvousKey::kCrossModule
+                         : xla::RendezvousKey::kCrossReplica;
+  xla::RendezvousKey rendezvous_key(run_options->run_id(),
+                                    participating_replicas_vec, op_kind, op_id);
+
+  std::shared_ptr<CpuAllReduceRendezvous> rendezvous =
+      GlobalRendezvousMap()[rendezvous_key];
+
+  auto shape_str = ShapeString(shape_ptr, shape_length);
+  VLOG(2) << "All-reduce input/output shape : " << shape_str;
+
+  xla::Shape shape =
+      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
+
+  xla::AllReduceParticipantData participant(rendezvous_key);
+
+  CHECK_EQ(shape.dimensions_size(), 1);
+  participant.element_count = shape.dimensions(0);
+  participant.device_ordinal = device_ordinal;
+  participant.primitive_type = shape.element_type();
+  participant.stream = run_options->stream();
+
+  se::DeviceMemoryBase input(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
+  se::DeviceMemoryBase output(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
+  participant.source_data = input;
+  participant.destination_data = output;
+  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
+
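+  // SubmitParticipant blocks until every participant has arrived and the
+  // primary has written the reduced result; the returned blocking counter is
+  // then used to wait for all threads to drop their rendezvous reference.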
+  auto p = rendezvous->SubmitParticipant(participant).ValueOrDie();
+  std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
+  blocking_counter->DecrementCount();
+  xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
+    return absl::StrFormat(
+        "participant waiting for all threads to drop their reference to the "
+        "rendezvous: %s",
+        rendezvous_key.ToString());
+  });
+
+  rendezvous.reset();
+}
+
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer) {
   int device_ordinal = [&]() {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 23a9359..598ab35 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -28,6 +28,7 @@
 
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
@@ -67,6 +68,7 @@
 extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
 extern const char* const kParallelForkJoinSymbolName;
 extern const char* const kKeyValueSortSymbolName;
+extern const char* const kAllReduceSymbolName;
 extern const char* const kReplicaIdSymbolName;
 extern const char* const kTracingStartSymbolName;
 extern const char* const kTracingEndSymbolName;
@@ -154,6 +156,20 @@
     const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
     void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
 
+// Perform an all-reduce on the CPU.
+//
+// replica_groups_str: serialized list of replica groups; the replicas in the
+// caller's group participate in the reduction, cf. GetParticipatingReplicas.
+// channel_id_present: whether op_id is a channel ID (cross-module) or a
+// module ID (cross-replica).
+// reduction_kind: operator used for the reduction, cf. ReductionKind.
+// shape_ptr: shape of all input/output buffers.
+extern void __xla_cpu_runtime_AllReduce(
+    const xla::ExecutableRunOptions* run_options,
+    const void* replica_groups_str, xla::int32 replica_groups_str_size,
+    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
+    const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
+    void* output_buffer);
+
 // Write the replica ID into the output buffer.
 extern void __xla_cpu_runtime_ReplicaId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 2cca102..0848288 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -40,7 +40,9 @@
 #include "llvm/IR/LLVMContext.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
@@ -68,6 +70,7 @@
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/math/math_util.h"
@@ -1365,13 +1368,7 @@
   return Status::OK();
 }
 
-Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
-  if (hlo_module_config_.replica_count() != 1) {
-    // TODO(b/33011107): Support nontrivial cross replica sum on CPU.
-    return Unimplemented(
-        "AllReduce with >1 replica is not implemented on CPU.");
-  }
-
+Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
   // When there is a single replica, a cross replica sum is the identity
   // function, and the buffer assignment expects a copy.
   //
@@ -1406,6 +1403,112 @@
   return Status::OK();
 }
 
+Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
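+  // Lowers the all-reduce to a call into the __xla_cpu_runtime_AllReduce
+  // runtime function, which performs the actual reduction and all
+  // cross-replica synchronization.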
+  CHECK_GE(crs->operand_count(), 1);
+  PrimitiveType datatype = crs->operand(0)->shape().element_type();
+
+  bool is_datatype_supported = [&] {
+    // TODO(cheshire): Fix duplication wrt. cpu_runtime
+    switch (datatype) {
+      case S8:
+      case U8:
+      case S32:
+      case U32:
+      case S64:
+      case U64:
+      case F16:
+      case F32:
+      case F64:
+        return true;
+      default:
+        return false;
+    }
+  }();
+
+  if (!is_datatype_supported) {
+    return Unimplemented("AllReduce for datatype '%s' is not supported",
+                         primitive_util::LowercasePrimitiveTypeName(datatype));
+  }
+
+  if (!MatchReductionComputation(crs->to_apply()).has_value()) {
+    return Unimplemented("AllReduce for computation '%s' is not supported",
+                         crs->to_apply()->ToString());
+  }
+
+  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+  llvm::Type* int32_type = b_.getInt32Ty();
+  llvm::Type* int64_type = b_.getInt64Ty();
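+  // This signature has to stay in sync with __xla_cpu_runtime_AllReduce as
+  // declared in cpu_runtime.h.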
+  llvm::FunctionType* all_reduce_func_ty =
+      llvm::FunctionType::get(b_.getVoidTy(),
+                              {/*run_options=*/i8_ptr_type,
+                               /*replica_groups=*/i8_ptr_type,
+                               /*replica_groups_size=*/int32_type,
+                               /*channel_id_present=*/int32_type,
+                               /*op_id=*/int64_type,
+                               /*reduction_kind=*/int32_type,
+                               /*shape_ptr=*/i8_ptr_type,
+                               /*shape_length=*/int32_type,
+                               /*input_buffer=*/i8_ptr_type,
+                               /*output_buffer=*/i8_ptr_type},
+                              /*isVarArg=*/false);
+
+  auto all_reduce_func = llvm::dyn_cast<llvm::Function>(
+      module_
+          ->getOrInsertFunction(runtime::kAllReduceSymbolName,
+                                all_reduce_func_ty)
+          .getCallee());
+  all_reduce_func->setCallingConv(llvm::CallingConv::C);
+
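+  // Replica groups are passed to the runtime as a serialized string and
+  // parsed back with ParseReplicaGroupsOnly on the runtime side.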
+  std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
+  int32 replica_groups_size = replica_groups.size();
+  llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
+
+  Shape shape = crs->operand(0)->shape();
+  int32 shape_length;
+  TF_ASSIGN_OR_RETURN(
+      llvm::Value * shape_ptr,
+      llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
+
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
+                      assignment_.GetUniqueSlice(crs->operand(0), {}));
+  llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
+
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
+                      assignment_.GetUniqueSlice(crs, {}));
+  llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
+
+  Call(all_reduce_func,
+       {/*run_options=*/GetExecutableRunOptionsArgument(),
+        /*replica_groups=*/replica_groups_v,
+        /*replica_groups_size=*/b_.getInt32(replica_groups_size),
+
+        /*channel_id_present=*/
+        b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
+        /*op_id=*/
+        b_.getInt64(crs->channel_id().has_value()
+                        ? *crs->channel_id()
+                        : crs->GetModule()->unique_id()),
+
+        /*reduction_kind=*/
+        b_.getInt32(
+            static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
+
+        /*shape_ptr=*/shape_ptr,
+        /*shape_length=*/b_.getInt32(shape_length),
+
+        /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
+        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)});
+
+  return Status::OK();
+}
+
+Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
+  if (hlo_module_config_.replica_count() == 1) {
+    return HandleAllReduceSingleReplica(crs);
+  }
+  return HandleAllReduceMultipleReplica(crs);
+}
+
 Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
   llvm::FunctionType* replica_id_function_ty =
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 5f6da58..453676b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -196,6 +196,9 @@
   }
 
  private:
+  Status HandleAllReduceSingleReplica(HloInstruction* crs);
+  Status HandleAllReduceMultipleReplica(HloInstruction* crs);
+
   // Private helper to initialize an IR function for the computation.
   void InitializeIrFunction(const string& function_name);
 
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 4036497..1505f0a 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -216,6 +216,7 @@
 
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
+  REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
   REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
   REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5589821..6302907 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1893,7 +1893,11 @@
 xla_test(
     name = "collective_ops_test",
     srcs = ["collective_ops_test.cc"],
-    backends = ["gpu"],
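+    # Make the CPU backend expose several "virtual" devices in one process so
+    # that multi-replica collectives can be exercised.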
+    args = ["--xla_force_host_platform_device_count=4"],
+    backends = [
+        "gpu",
+        "cpu",
+    ],
     tags = [
         # This test is tagged "manual" because it requires multiple GPUs, and
         # Forge only supports single-GPU tests.  Guitar skips "manual" tests
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index da430ed..372a366 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -224,7 +224,7 @@
 
 // Check that the NCCL data structures in our all-reduce implementation are
 // cached as we expect.
-XLA_TEST_F(CollectiveOpsTest, AllReduce_NcclChannelCaching) {
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) {
   const int64 kNumElems = 1024;
 
   std::vector<float> input_vec(kNumElems);
@@ -398,7 +398,7 @@
   }
 }
 
-XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
   const char* const kModuleStr = R"(
   HloModule test
   ENTRY test_computation {