[XLA] CPU AllReduce implementation: same process, all work done by "main" worker
Only "virtual" devices are supported which share the same thread pool.
Communication is performed via a global map.
This all-reduce implementation is not actually "collective": all the work is
performed by the main worker, while all the other threads are waiting.
PiperOrigin-RevId: 277548660
Change-Id: Iccd438dbe81329a02a157d9aae72d23c706b9f52
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index a935916..f2fa964 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -218,7 +218,7 @@
WaitAndLogIfStuck(&all_participants_present_, [&] {
return absl::StrFormat(
"participant for device ordinal %d, stream %p waiting for all "
- "participants to be arrive at rendezvous %s",
+ "participants to arrive at rendezvous %s",
participant.device_ordinal, participant.stream, key_.ToString());
});
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index b30754f..fa9d00a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -288,6 +288,7 @@
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
@@ -304,6 +305,10 @@
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "//tensorflow/core/lib/math:math_util",
+ "//tensorflow/core/platform:logging",
+ "//tensorflow/core/platform:macros",
+ "//tensorflow/core/platform:types",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
@@ -501,15 +506,27 @@
copts = runtime_copts(),
deps = [
"//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:refcounting_hash_map",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:computation_placer",
- "//tensorflow/core:lib",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core/platform:dynamic_annotations",
+ "//tensorflow/core/platform:logging",
+ "//tensorflow/core/platform:macros",
+ "//tensorflow/core/platform:mutex",
+ "//tensorflow/core/platform:platform_port",
+ "//tensorflow/core/platform:types",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 645d71b..32bd6cb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -15,23 +15,34 @@
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include <cstddef>
#include <cstring>
#include <functional>
+#include <limits>
#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/refcounting_hash_map.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor.h"
+namespace se = ::stream_executor;
+
namespace xla {
namespace cpu {
namespace runtime {
@@ -100,6 +111,7 @@
"__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
+extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
} // namespace runtime
@@ -238,6 +250,206 @@
std::move(shape));
}
+namespace {
+
+class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
+ public:
+ explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
+ : xla::Rendezvous<std::nullptr_t>(k) {}
+
+ protected:
+ xla::StatusOr<std::pair<std::nullptr_t, bool>> SubmitParticipantImpl(
+ xla::AllReduceParticipantData participant) override {
+ xla::PrimitiveType datatype = participant.primitive_type;
+ bool primary = [&] {
+ tensorflow::mutex_lock lock(mu_);
+ if (!initialized_) {
+ initialized_ = true;
+ return true;
+ }
+ return false;
+ }();
+
+ if (primary) {
+ switch (datatype) {
+ case xla::S8:
+ DoAllReduce<xla::S8>(participant);
+ break;
+ case xla::U8:
+ DoAllReduce<xla::U8>(participant);
+ break;
+ case xla::S32:
+ DoAllReduce<xla::S32>(participant);
+ break;
+ case xla::U32:
+ DoAllReduce<xla::U32>(participant);
+ break;
+ case xla::S64:
+ DoAllReduce<xla::S64>(participant);
+ break;
+ case xla::U64:
+ DoAllReduce<xla::U64>(participant);
+ break;
+ case xla::F16:
+ DoAllReduce<xla::F16>(participant);
+ break;
+ case xla::F32:
+ DoAllReduce<xla::F32>(participant);
+ break;
+ case xla::F64:
+ DoAllReduce<xla::F64>(participant);
+ break;
+ default:
+ LOG(FATAL) << "Unexpected datatype;";
+ }
+ }
+
+ // First element is a dummy value.
+ return std::make_pair(nullptr, primary);
+ }
+
+ private:
+ template <xla::PrimitiveType PT>
+ void DoAllReduce(xla::AllReduceParticipantData participant) {
+ using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
+ tensorflow::mutex_lock lock(mu_);
+ CHECK(!participants_.empty());
+ xla::int64 element_count = participant.element_count;
+ xla::ReductionKind reduction_kind = participant.reduction_kind;
+ for (const auto& p : participants_) {
+ CHECK_EQ(p.element_count, element_count);
+ CHECK(p.reduction_kind == reduction_kind);
+ }
+
+ std::vector<absl::Span<T>> input_buffers;
+ std::vector<absl::Span<T>> output_buffers;
+ input_buffers.reserve(participants_.size());
+ output_buffers.reserve(participants_.size());
+
+ for (auto& p : participants_) {
+ input_buffers.emplace_back(static_cast<T*>(p.source_data.opaque()),
+ element_count);
+ output_buffers.emplace_back(static_cast<T*>(p.destination_data.opaque()),
+ element_count);
+ }
+
+ auto compute = [reduction_kind](T a, T b) -> T {
+ switch (reduction_kind) {
+ case xla::ReductionKind::SUM:
+ return a + b;
+ case xla::ReductionKind::PRODUCT:
+ return a * b;
+ case xla::ReductionKind::MIN:
+ return std::min(a, b);
+ case xla::ReductionKind::MAX:
+ return std::max(a, b);
+ }
+ };
+
+ for (int idx = 0; idx < element_count; idx++) {
+ T out = [&]() -> T {
+ switch (reduction_kind) {
+ case xla::ReductionKind::SUM:
+ return static_cast<T>(0);
+ case xla::ReductionKind::PRODUCT:
+ return static_cast<T>(1);
+ case xla::ReductionKind::MIN:
+ return std::numeric_limits<T>::max();
+ case xla::ReductionKind::MAX:
+ return std::numeric_limits<T>::min();
+ }
+ }();
+
+ for (auto& input : input_buffers) {
+ out = compute(out, input[idx]);
+ }
+ for (auto& output : output_buffers) {
+ output[idx] = out;
+ }
+ }
+ }
+};
+
+xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
+GlobalRendezvousMap() {
+ static auto& m =
+ *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>(
+ [](const xla::RendezvousKey& k) {
+ return absl::make_unique<CpuAllReduceRendezvous>(k);
+ });
+ return m;
+}
+
+} // namespace
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
+ const xla::ExecutableRunOptions* run_options,
+ const void* replica_groups_str, xla::int32 replica_groups_str_size,
+ xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
+ const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
+ void* output_buffer) {
+ absl::string_view replica_groups_serialized(
+ static_cast<const char*>(replica_groups_str), replica_groups_str_size);
+
+ // FIXME(cheshire): avoid repetition w/__xla_cpu_runtime_ReplicaId.
+ int device_ordinal = [&] {
+ if (run_options->stream()) {
+ return run_options->stream()->parent()->device_ordinal();
+ } else {
+ return run_options->device_ordinal();
+ }
+ }();
+
+ std::vector<xla::ReplicaGroup> group =
+ xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
+ xla::int32 replica_count = run_options->device_assignment()->replica_count();
+ std::vector<xla::int64> participating_replicas_vec =
+ xla::GetParticipatingReplicas(device_ordinal, group, replica_count,
+ *run_options->device_assignment())
+ .ValueOrDie();
+
+ xla::RendezvousKey::CollectiveOpKind op_kind =
+ channel_id_present ? xla::RendezvousKey::kCrossModule
+ : xla::RendezvousKey::kCrossReplica;
+ xla::RendezvousKey rendezvous_key(run_options->run_id(),
+ participating_replicas_vec, op_kind, op_id);
+
+ std::shared_ptr<CpuAllReduceRendezvous> rendezvous =
+ GlobalRendezvousMap()[rendezvous_key];
+
+ auto shape_str = ShapeString(shape_ptr, shape_length);
+ VLOG(2) << "All-reduce input/output shape : " << shape_str;
+
+ xla::Shape shape =
+ DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
+
+ xla::AllReduceParticipantData participant(rendezvous_key);
+
+ CHECK_EQ(shape.dimensions_size(), 1);
+ participant.element_count = shape.dimensions(0);
+ participant.device_ordinal = device_ordinal;
+ participant.primitive_type = shape.element_type();
+ participant.stream = run_options->stream();
+
+ se::DeviceMemoryBase input(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
+ se::DeviceMemoryBase output(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
+ participant.source_data = input;
+ participant.destination_data = output;
+ participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
+
+ auto p = rendezvous->SubmitParticipant(participant).ValueOrDie();
+ std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
+ blocking_counter->DecrementCount();
+ xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
+ return absl::StrFormat(
+ "participant waiting for all threads to drop their reference to the "
+ "rendezvous: %s",
+ rendezvous_key.ToString());
+ });
+
+ rendezvous.reset();
+}
+
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
const xla::ExecutableRunOptions* run_options, void* output_buffer) {
int device_ordinal = [&]() {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 23a9359..598ab35 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -28,6 +28,7 @@
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -67,6 +68,7 @@
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
extern const char* const kKeyValueSortSymbolName;
+extern const char* const kAllReduceSymbolName;
extern const char* const kReplicaIdSymbolName;
extern const char* const kTracingStartSymbolName;
extern const char* const kTracingEndSymbolName;
@@ -154,6 +156,20 @@
const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
+// Perform all reduce on a CPU.
+//
+// participating_replicas: array of replica IDs participating in the reduction,
+// cf. GetParticipatingReplicas.
+// channel_id_present, op_id: whether op_id is a channel ID or a module ID.
+// reduction_kind: operator used for a reduction, cf. ReductionKind.
+// shape_ptr: shape of all input/output buffers.
+extern void __xla_cpu_runtime_AllReduce(
+ const xla::ExecutableRunOptions* run_options,
+ const void* replica_groups_str, xla::int32 replica_groups_str_size,
+ xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
+ const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
+ void* output_buffer);
+
// Write the replica ID into the output buffer.
extern void __xla_cpu_runtime_ReplicaId(
const xla::ExecutableRunOptions* run_options, void* output_buffer);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 2cca102..0848288 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -40,7 +40,9 @@
#include "llvm/IR/LLVMContext.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
@@ -68,6 +70,7 @@
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -1365,13 +1368,7 @@
return Status::OK();
}
-Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
- if (hlo_module_config_.replica_count() != 1) {
- // TODO(b/33011107): Support nontrivial cross replica sum on CPU.
- return Unimplemented(
- "AllReduce with >1 replica is not implemented on CPU.");
- }
-
+Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
// When there is a single replica, a cross replica sum is the identity
// function, and the buffer assignment expects a copy.
//
@@ -1406,6 +1403,112 @@
return Status::OK();
}
+Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
+ CHECK_GE(crs->operand_count(), 1);
+ PrimitiveType datatype = crs->operand(0)->shape().element_type();
+
+ bool is_datatype_supported = [&] {
+ // TODO(cheshire): Fix duplication wrt. cpu_runtime
+ switch (datatype) {
+ case S8:
+ case U8:
+ case S32:
+ case U32:
+ case S64:
+ case U64:
+ case F16:
+ case F32:
+ case F64:
+ return true;
+ default:
+ return false;
+ }
+ }();
+
+ if (!is_datatype_supported) {
+ return Unimplemented("AllReduce for datatype '%s' is not supported",
+ primitive_util::LowercasePrimitiveTypeName(datatype));
+ }
+
+ if (!MatchReductionComputation(crs->to_apply()).has_value()) {
+ return Unimplemented("AllReduce for computation '%s' is not supported",
+ crs->to_apply()->ToString());
+ }
+
+ llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+ llvm::Type* int32_type = b_.getInt32Ty();
+ llvm::Type* int64_type = b_.getInt64Ty();
+ llvm::FunctionType* all_reduce_func_ty =
+ llvm::FunctionType::get(b_.getVoidTy(),
+ {/*run_options=*/i8_ptr_type,
+ /*replica_groups=*/i8_ptr_type,
+ /*replica_groups_size=*/int32_type,
+ /*channel_id_present=*/int32_type,
+ /*op_id=*/int64_type,
+ /*reduction_kind=*/int32_type,
+ /*shape_ptr=*/i8_ptr_type,
+ /*shape_length=*/int32_type,
+ /*input_buffer=*/i8_ptr_type,
+ /*output_buffer=*/i8_ptr_type},
+ /*isVarArg=*/false);
+
+ auto all_reduce_func = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(runtime::kAllReduceSymbolName,
+ all_reduce_func_ty)
+ .getCallee());
+ all_reduce_func->setCallingConv(llvm::CallingConv::C);
+
+ std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
+ int32 replica_groups_size = replica_groups.size();
+ llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
+
+ Shape shape = crs->operand(0)->shape();
+ int32 shape_length;
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * shape_ptr,
+ llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
+
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
+ assignment_.GetUniqueSlice(crs->operand(0), {}));
+ llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
+
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
+ assignment_.GetUniqueSlice(crs, {}));
+ llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
+
+ Call(all_reduce_func,
+ {/*run_options=*/GetExecutableRunOptionsArgument(),
+ /*replica_groups=*/replica_groups_v,
+ /*replica_groups_size=*/b_.getInt32(replica_groups_size),
+
+ /*channel_id_present=*/
+ b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
+ /*op_id=*/
+ b_.getInt64(crs->channel_id().has_value()
+ ? *crs->channel_id()
+ : crs->GetModule()->unique_id()),
+
+ /*reduction_kind=*/
+ b_.getInt32(
+ static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
+
+ /*shape_ptr=*/shape_ptr,
+ /*shape_length=*/b_.getInt32(shape_length),
+
+ /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
+ /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)});
+
+ return Status::OK();
+}
+
+Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
+ if (hlo_module_config_.replica_count() == 1) {
+ return HandleAllReduceSingleReplica(crs);
+ }
+ return HandleAllReduceMultipleReplica(crs);
+}
+
Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::FunctionType* replica_id_function_ty =
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 5f6da58..453676b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -196,6 +196,9 @@
}
private:
+ Status HandleAllReduceSingleReplica(HloInstruction* crs);
+ Status HandleAllReduceMultipleReplica(HloInstruction* crs);
+
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const string& function_name);
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 4036497..1505f0a 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -216,6 +216,7 @@
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
+ REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5589821..6302907 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1893,7 +1893,11 @@
xla_test(
name = "collective_ops_test",
srcs = ["collective_ops_test.cc"],
- backends = ["gpu"],
+ args = ["--xla_force_host_platform_device_count=4"],
+ backends = [
+ "gpu",
+ "cpu",
+ ],
tags = [
# This test is tagged "manual" because it requires multiple GPUs, and
# Forge only supports single-GPU tests. Guitar skips "manual" tests
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index da430ed..372a366 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -224,7 +224,7 @@
// Check that the NCCL data structures in our all-reduce implementation are
// cached as we expect.
-XLA_TEST_F(CollectiveOpsTest, AllReduce_NcclChannelCaching) {
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) {
const int64 kNumElems = 1024;
std::vector<float> input_vec(kNumElems);
@@ -398,7 +398,7 @@
}
}
-XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {