/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#include <vector>
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
// Attempts to match the computation to one of the possible cases in ReductionKind.
absl::optional<ReductionKind> MatchReductionComputation(
const HloComputation* computation);
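// Example (sketch, not part of this header's API): classifying the reduction
// performed by an all-reduce. `all_reduce` is a hypothetical HloInstruction*
// with opcode kAllReduce; its attached computation is retrieved via
// to_apply().
//
//   absl::optional<ReductionKind> kind =
//       MatchReductionComputation(all_reduce->to_apply());
//   if (kind && *kind == ReductionKind::SUM) {
//     // The attached computation is an elementwise add, i.e. a sum-reduce.
//   }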
// Figures out which devices (named by their replica-ids) are participating in
// the all-reduce subgroup that contains device_ordinal.
StatusOr<std::vector<int64>> GetParticipatingReplicas(
int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
int64 total_replica_count, const DeviceAssignment& device_assn);
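// Example (sketch): looking up the replicas that share a subgroup with a given
// device. `all_reduce` and `device_assn` are hypothetical, and the call is
// assumed to sit in a function that returns a Status so TF_ASSIGN_OR_RETURN
// applies.
//
//   TF_ASSIGN_OR_RETURN(
//       std::vector<int64> replicas,
//       GetParticipatingReplicas(/*device_ordinal=*/0,
//                                all_reduce->replica_groups(),
//                                /*total_replica_count=*/4, device_assn));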
// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
//
// * Only ops with the same RunId can communicate with each other. (This is the
// whole purpose of RunId).
//
// * Only ops with the same set of participating replicas can communicate with
// each other. This is how we separate out different replica groups (e.g. a
// single AllReduce HLO might do two reductions, between say GPUs {0,2} and
// {1,3}).
//
// * Only ops with the same opcode can communicate with each other. At the
// moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
// only ops with the same value for channel_id() can communicate with each
// other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
// !channel_id().has_value()), only ops from the same module (as
// identified by its unique_id()) can communicate with each other.
//
struct RendezvousKey {
enum CollectiveOpKind {
kCrossModule,
kCrossReplica,
};
explicit RendezvousKey(const RunId& run_id,
std::vector<int64> participating_replicas,
CollectiveOpKind collective_op_kind, int64 op_id)
: run_id(run_id),
participating_replicas(participating_replicas),
collective_op_kind(collective_op_kind),
op_id(op_id) {}
static RendezvousKey FromInstruction(
const RunId& run_id, std::vector<int64> participating_replicas,
const HloInstruction* instr) {
CollectiveOpKind collective_op_kind;
int64 op_id;
std::tie(collective_op_kind, op_id) =
instr->channel_id().has_value()
? std::make_pair(kCrossModule, instr->channel_id().value())
: std::make_pair(
kCrossReplica,
static_cast<int64>(instr->GetModule()->unique_id()));
return RendezvousKey(run_id, participating_replicas, collective_op_kind,
op_id);
}
int num_participants() const { return participating_replicas.size(); }
template <typename H>
friend H AbslHashValue(H h, const RendezvousKey& k) {
return H::combine(std::move(h), k.run_id, k.participating_replicas,
static_cast<int>(k.collective_op_kind), k.op_id);
}
friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
return a.run_id == b.run_id &&
a.participating_replicas == b.participating_replicas &&
a.collective_op_kind == b.collective_op_kind && //
a.op_id == b.op_id;
}
friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
return !(a == b);
}
string ToString() const {
return absl::StrFormat(
"RendezvousKey{run_id=%s, participating_replicas=[%s], "
"collective_op_kind=%d, op_id=%d}",
run_id.ToString(), absl::StrJoin(participating_replicas, ","),
static_cast<int>(collective_op_kind), op_id);
}
RunId run_id;
std::vector<int64> participating_replicas;
CollectiveOpKind collective_op_kind;
int64 op_id;
};
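// Example (sketch): keys built from two all-reduce ops in the same run with
// the same participating replicas compare equal exactly when the rules above
// allow the ops to communicate. `run_id`, `instr_a`, and `instr_b` are
// hypothetical.
//
//   RendezvousKey a = RendezvousKey::FromInstruction(run_id, {0, 2}, instr_a);
//   RendezvousKey b = RendezvousKey::FromInstruction(run_id, {0, 2}, instr_b);
//   // a == b iff both ops are cross-module with the same channel_id(), or
//   // cross-replica and from the same module (same unique_id()).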
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
const DescFn& desc_fn) {
VLOG(3) << "Begin: " << desc_fn();
const std::chrono::milliseconds timeout(5000);
bool ok = counter->WaitFor(timeout);
if (ok) {
VLOG(3) << "Finished: " << desc_fn();
return;
}
LOG(ERROR) << "This thread has been waiting for " << timeout.count()
<< "ms for and may be stuck: " << desc_fn();
counter->Wait();
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short: "
<< desc_fn();
}
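// Example (sketch): the decrement-then-wait pattern used by the Rendezvous
// below. `counter` is a hypothetical tensorflow::BlockingCounter initialized
// to the number of participants, and `device_ordinal` is a hypothetical int.
//
//   counter->DecrementCount();
//   WaitAndLogIfStuck(counter, [&] {
//     return absl::StrCat("participant ", device_ordinal,
//                         " waiting for its peers");
//   });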
// Encapsulates parameters to Rendezvous::SubmitParticipant.
struct AllReduceParticipantData {
explicit AllReduceParticipantData(RendezvousKey rendezvous_key)
: rendezvous_key(rendezvous_key) {}
int64 element_count;
int64 device_ordinal;
RendezvousKey rendezvous_key;
// TODO(b/125951860): We should vet that we're buffer allocating such that
// source_buffer == destination_buffer if that avoids a NCCL copy (will depend
// on how well the NCCL in-place implementation performs vs the out-of-place
// implementation).
se::DeviceMemoryBase source_data;
se::DeviceMemoryBase destination_data;
se::Stream* stream;
ReductionKind reduction_kind;
PrimitiveType primitive_type;
int num_participants() const { return rendezvous_key.num_participants(); }
string ToString() const {
return absl::StrFormat(
"AllReduceParticipantData{element_count=%d, rendezvous_key=%s, "
"device_ordinal=%d, stream=%p}",
element_count, rendezvous_key.ToString(), device_ordinal, stream);
}
};
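// Example (sketch): filling in participant data for one device before handing
// it to a Rendezvous. Every right-hand-side name is a hypothetical value the
// caller already has.
//
//   AllReduceParticipantData participant(rendezvous_key);
//   participant.element_count = num_elements;
//   participant.device_ordinal = device_ordinal;
//   participant.source_data = source_buffer;
//   participant.destination_data = destination_buffer;
//   participant.stream = stream;
//   participant.reduction_kind = ReductionKind::SUM;
//   participant.primitive_type = F32;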
// The set of threads that want to do a collective op together all pick the same
// Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join, ensuring
// that a clique exists for the desired set of GPUs, etc.
//
// Rendezvous objects can only be used once.
//
// O: Participant output.
template <typename O>
class Rendezvous {
public:
virtual ~Rendezvous() {}
explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
// Runs the all-reduce on the given thread. If successful, returns
// - a handle to the clique that was used, so that the caller may keep the
// clique alive if it chooses.
// - a BlockingCounter initialized to the number of participants, so that
// the caller can coordinate with the participants one last time if it
// chooses. This is useful for coordinating destruction of the Rendezvous.
StatusOr<std::pair<O, std::shared_ptr<tensorflow::BlockingCounter>>>
SubmitParticipant(AllReduceParticipantData participant) {
{
tensorflow::mutex_lock lock(mu_);
CHECK(!initialized_);
// Spot check for consistent replica counts among submitting threads.
if (!participants_.empty() &&
(participants_.back().element_count != participant.element_count ||
participants_.back().rendezvous_key != participant.rendezvous_key)) {
return InvalidArgument(
"Mismatch among all-reduce participants. Expected same "
"replica-count, element-count, and rendezvous-key but were %s and "
"%s",
participants_.back().ToString(), participant.ToString());
}
participants_.push_back(participant);
}
// Wait for all participants to arrive.
all_participants_present_.DecrementCount();
WaitAndLogIfStuck(&all_participants_present_, [&] {
return absl::StrFormat(
"participant for device ordinal %d, stream %p waiting for all "
"participants to arrive at rendezvous %s",
participant.device_ordinal, participant.stream, key_.ToString());
});
StatusOr<std::pair<O, bool>> p_or = SubmitParticipantImpl(participant);
done_.DecrementCount();
if (!p_or.ok()) {
return p_or.status();
}
std::pair<O, bool> p = p_or.ValueOrDie();
O handle = p.first;
bool is_primary = p.second;
// The primary owns the lock on the NCCL clique. Hold it until all threads
// are done. (We'll release it when we return from this function.)
if (is_primary) {
WaitAndLogIfStuck(&done_, [&] {
return absl::StrFormat(
"primary participant waiting for all other participants to "
"complete all-reduce %s",
key_.ToString());
});
}
CleanupImpl(handle, is_primary);
return std::make_pair(handle, returned_blocking_counter_);
}
protected:
// Returns domain-specific output O and whether this replica is primary.
virtual StatusOr<std::pair<O, bool>> SubmitParticipantImpl(
AllReduceParticipantData participant) = 0;
virtual void CleanupImpl(O handle, bool is_primary) {}
tensorflow::mutex mu_;
bool initialized_ GUARDED_BY(mu_) = false;
std::vector<AllReduceParticipantData> participants_ GUARDED_BY(mu_);
private:
const RendezvousKey key_;
tensorflow::BlockingCounter all_participants_present_{
key_.num_participants()};
tensorflow::BlockingCounter done_{key_.num_participants()};
// tensorflow::BlockingCounter returned by SubmitParticipant.
std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
std::make_shared<tensorflow::BlockingCounter>(key_.num_participants())};
};
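// Example (sketch): a minimal Rendezvous subclass. The output type (int here)
// and the body of SubmitParticipantImpl are placeholders; a real backend would
// run the device-specific all-reduce and report whether this participant is
// the primary one (e.g. the one holding the clique lock).
//
//   class ExampleRendezvous : public Rendezvous<int> {
//    public:
//     explicit ExampleRendezvous(const RendezvousKey& k)
//         : Rendezvous<int>(k) {}
//
//    protected:
//     StatusOr<std::pair<int, bool>> SubmitParticipantImpl(
//         AllReduceParticipantData participant) override {
//       // Perform the actual collective here; return {handle, is_primary}.
//       return std::make_pair(0, participant.device_ordinal == 0);
//     }
//   };
//
// Each participating thread would then call
// rendezvous->SubmitParticipant(participant) with its own participant data.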
} // end namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_