blob: 2ccdf5b545cee1e82e0cabb704beb5205b569e7d [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include <algorithm>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace {
// Timeout passed to the cluster-wide barrier that WaitForAllTasks() sets up.
constexpr absl::Duration kDevicePropagationTimeout = absl::Hours(1);
// Heartbeat timeout used when the coordination config does not specify a
// positive `heartbeat_timeout_in_ms`.
constexpr int kDefaultHeartbeatTimeoutMs = 10 * 1000; // 10 seconds
// Timeout applied to ReportErrorToTask calls issued by the service.
constexpr int kServiceToClientTimeoutMs = 10 * 1000; // 10 seconds
// A warning is logged once the number of ongoing barriers exceeds this value.
constexpr size_t kOngoingBarriersSoftLimit = 20;
// Name given to the thread that checks heartbeat and barrier staleness.
constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck";
// Formats a task name of the form "/job:<job_name>/replica:0/task:<task_id>".
// The replica index is always 0.
std::string GetTaskName(absl::string_view job_name, int task_id) {
  std::string full_name =
      strings::StrCat("/job:", job_name, "/replica:", 0, "/task:", task_id);
  return full_name;
}
// Convenience overload: formats the name of `task` from its job name and
// task id.
std::string GetTaskName(const CoordinatedTask& task) {
  const auto& job = task.job_name();
  const auto id = task.task_id();
  return GetTaskName(job, id);
}
// Builds a CoordinatedTask from a full task name by parsing it with
// DeviceNameUtils and copying over the parsed job and task fields. The result
// of the parse call is not checked.
CoordinatedTask GetTaskFromName(absl::string_view task_name) {
  DeviceNameUtils::ParsedName name_parts;
  DeviceNameUtils::ParseFullName(task_name, &name_parts);

  CoordinatedTask result;
  result.set_task_id(name_parts.task);
  result.set_job_name(name_parts.job);
  return result;
}
bool is_multi_client_leader(const ServerDef& server_def) {
const auto& config = server_def.default_session_config();
const std::string& leader =
config.experimental().coordination_config().service_leader();
const std::string& collective_leader =
config.experimental().collective_group_leader();
DeviceNameUtils::ParsedName leader_pn;
if (!leader.empty()) {
DeviceNameUtils::ParseFullName(leader, &leader_pn);
} else if (!collective_leader.empty()) {
LOG(INFO) << "No coordination leader is set, using the collective leader "
<< collective_leader;
DeviceNameUtils::ParseFullName(collective_leader, &leader_pn);
} else {
LOG(INFO) << "No coordination leader is set, using the default /job:"
<< server_def.job_name() << "/replica:0/task:0";
return server_def.task_index() == 0;
}
return server_def.job_name() == leader_pn.job &&
server_def.task_index() == leader_pn.task;
}
// Hash functor that lets CoordinatedTask be used as a container key. Only the
// job name and task id contribute to the hash.
struct CoordinatedTaskHash {
  uint64_t operator()(const CoordinatedTask& task) const {
    const std::string& job = task.job_name();
    const auto id = task.task_id();
    return absl::HashOf(job, id);
  }
};
// Equality functor that lets CoordinatedTask be used as a container key. Two
// tasks are equal when both their job names and task ids match.
struct CoordinatedTaskEqual {
  bool operator()(const CoordinatedTask& lhs,
                  const CoordinatedTask& rhs) const {
    if (lhs.task_id() != rhs.task_id()) return false;
    return lhs.job_name() == rhs.job_name();
  }
};
// Standalone implementation of the coordination service.
//
// Tracks per-task connection state, a key-value config store, and barriers.
// Cluster/barrier state is guarded by `state_mu_`; the key-value store and its
// pending callbacks are guarded by `kv_mu_`. A background thread (started in
// the constructor) periodically checks heartbeat and barrier timeouts.
class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
 public:
  // `client_cache` may be null, in which case there is no service-to-client
  // connection and errors cannot be pushed to tasks. `env` is dereferenced
  // and must not be null.
  CoordinationServiceStandaloneImpl(
      std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
      const ServerDef& server_def);
  ~CoordinationServiceStandaloneImpl() override { Stop(); }

  Status RegisterTask(const CoordinatedTask& task,
                      uint64_t incarnation) override;
  void WaitForAllTasks(const CoordinatedTask& task,
                       const CoordinationServiceDeviceInfo& devices,
                       StatusCallback done) override;
  void ShutdownTaskAsync(const CoordinatedTask& task,
                         StatusCallback done) override;
  Status ResetTask(const CoordinatedTask& task) override;
  Status RecordHeartbeat(const CoordinatedTask& task,
                         uint64_t incarnation) override;
  Status ReportTaskError(const CoordinatedTask& task, Status error) override;
  Status InsertKeyValue(const std::string& key,
                        const std::string& value) override;
  void GetKeyValueAsync(const std::string& key,
                        StatusOrValueCallback done) override;
  StatusOr<std::string> TryGetKeyValue(const std::string& key) override;
  std::vector<KeyValueEntry> GetKeyValueDir(
      absl::string_view directory_key) override;
  Status DeleteKeyValue(const std::string& key) override;
  void BarrierAsync(const std::string& barrier_id, absl::Duration timeout,
                    const CoordinatedTask& task,
                    const std::vector<CoordinatedTask>& participating_tasks,
                    StatusCallback done) override;
  Status CancelBarrier(const std::string& barrier_id,
                       const CoordinatedTask& task) override;

 private:
  const CoordinationServiceDeviceInfo& ListClusterDevices() override
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  uint64_t GetServiceIncarnation() override;
  void StartCheckStaleness();  // Checks both heartbeat and barrier timeouts.
  // Cancels pending key-value callbacks, clears cluster state, fails unpassed
  // barriers and signals the staleness thread to exit. The thread object is
  // only destroyed when `shut_staleness_thread` is true; pass false when
  // calling from the staleness thread itself.
  void Stop(bool shut_staleness_thread = true);
  // Report service error to a specified task.
  void ReportServiceErrorToTaskAsync(const CoordinatedTask& destination_task,
                                     Status error);
  // Report error from a task to all other connected tasks.
  // Note: SetTaskError() must be called before propagating its error.
  void PropagateError(const CoordinatedTask& source_task,
                      bool is_reported_by_task = false)
      TF_LOCKS_EXCLUDED(state_mu_);
  void SetTaskError(absl::string_view task_name, Status error)
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  void SetXlaGlobalDeviceIds() TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  Status DisconnectTask(const CoordinatedTask& task)
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);

  // Bookkeeping for a single barrier.
  struct BarrierState {
    bool passed = false;
    Status result = errors::Unknown(
        "Invalid barrier result.");  // Only valid if `passed` is true.
    uint64_t deadline_in_micros = 0;
    int num_pending_tasks = 0;
    // Specifies which tasks have called the barrier so far.
    absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
                        CoordinatedTaskEqual>
        tasks_at_barrier;
    std::vector<StatusCallback> done_callbacks;
  };
  void PassBarrier(absl::string_view barrier_id, Status result,
                   BarrierState* barrier)
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  // Check if participating tasks are specified correctly across barrier calls.
  bool ValidateTaskArgs(
      const std::vector<CoordinatedTask>& tasks_args,
      const absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
                                CoordinatedTaskEqual>& tasks_at_barrier,
      int64_t cluster_size);

  class TaskState {
   public:
    // Task state maintained on the coordination service side.
    // State transition:
    //                Register           Heartbeat
    //   DISCONNECTED -------> CONNECTED --------> ERROR (timeout)
    //                              |   ReportError
    //                              +--------------> ERROR
    //
    // When task state becomes ERROR, propagate this status to other CONNECTED
    // tasks in the cluster.
    enum class State {
      DISCONNECTED,
      CONNECTED,
      ERROR,
    };

    State GetState() { return state_; }
    Status GetStatus() { return status_; }
    void SetConnected(uint64_t task_incarnation);
    void Disconnect(uint64_t grace_period_duration_us);
    Status RecordHeartbeat(uint64_t task_incarnation);
    int64_t TimeSinceLastHeartbeatMs();
    // This denotes the deadline after which we stop accepting heartbeats from a
    // disconnected task. This grace period accounts for the lag time between
    // the service recording the state change and the agent stopping heartbeats.
    uint64_t GetDisconnectedGracePeriodMicros();
    void SetError(Status status);
    bool GetDeviceInfoCollected() { return device_info_collected_; }
    void MarkDeviceInfoCollected() { device_info_collected_ = true; }
    absl::flat_hash_set<std::string> GetOngoingBarriers();
    void JoinBarrier(absl::string_view barrier_id);
    void ExitBarrier(absl::string_view barrier_id);

   private:
    // Incarnation ID for CPU:0 on remote task.
    uint64_t task_incarnation_ = 0;

    State state_ = State::DISCONNECTED;
    Status status_;
    mutex last_heartbeat_mu_;
    // Timestamp (in microseconds) of the most recent heartbeat. Initialized
    // to 0 so that it never holds an indeterminate value before the task
    // first connects.
    uint64_t last_heartbeat_us_ TF_GUARDED_BY(last_heartbeat_mu_) = 0;
    // This denotes the deadline after which we stop accepting heartbeats from a
    // disconnected task. This grace period accounts for the lag time between
    // the service recording the state change and the agent stopping heartbeats.
    uint64_t disconnect_grace_period_us_ = 0;
    // Checks if task has called WaitForAllTasks() previously, which gathers the
    // local device info.
    bool device_info_collected_ = false;
    // For now, we assume there won't be many simultaneous barriers so we simply
    // use a set.
    absl::flat_hash_set<std::string> ongoing_barriers_for_task_;
  };

  std::unique_ptr<CoordinationClientCache> client_cache_;
  Env& env_;
  // Note: the barrier ids below are derived from `service_incarnation_`, so it
  // must be declared (and therefore initialized) before them.
  const uint64_t service_incarnation_ = random::New64();
  const uint64_t heartbeat_timeout_ms_;
  const absl::Duration shutdown_barrier_timeout_;
  const std::string device_propagation_barrier_id_ =
      absl::StrCat("WaitForAllTasks::", std::to_string(service_incarnation_));
  const std::string shutdown_barrier_id_ =
      absl::StrCat("Shutdown::", std::to_string(service_incarnation_));

  mutex state_mu_;
  // Keyed by task name (see GetTaskName()).
  absl::flat_hash_map<std::string, std::unique_ptr<TaskState>> cluster_state_
      TF_GUARDED_BY(state_mu_);
  CoordinationServiceDeviceInfo cluster_devices_ TF_GUARDED_BY(state_mu_);

  mutex kv_mu_;
  // Ordered map to store config key-values
  std::map<std::string, std::string> kv_store_ TF_GUARDED_BY(kv_mu_);
  // Callbacks waiting for a key that has not been inserted yet.
  absl::flat_hash_map<std::string, std::vector<StatusOrValueCallback>> get_cb_
      TF_GUARDED_BY(kv_mu_);

  mutex check_staleness_thread_shutdown_mu_;
  condition_variable check_staleness_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(check_staleness_thread_shutdown_mu_) =
      false;
  std::unique_ptr<Thread> check_staleness_thread_;

  absl::flat_hash_map<std::string, BarrierState> barriers_
      TF_GUARDED_BY(state_mu_);
  // For now, we assume there won't be many simultaneous barriers so we simply
  // use a set.
  absl::flat_hash_set<std::string> ongoing_barriers_ TF_GUARDED_BY(state_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(CoordinationServiceStandaloneImpl);
};
// Marks the task as CONNECTED with an OK status, remembers its incarnation and
// stamps the current time as its latest heartbeat.
void CoordinationServiceStandaloneImpl::TaskState::SetConnected(
    uint64_t task_incarnation) {
  task_incarnation_ = task_incarnation;
  status_ = OkStatus();
  state_ = State::CONNECTED;
  {
    mutex_lock heartbeat_lock(last_heartbeat_mu_);
    last_heartbeat_us_ = Env::Default()->NowMicros();
  }
}
// Marks the task as DISCONNECTED with an OK status and sets the heartbeat
// grace deadline to `grace_period_duration_us` microseconds from now.
void CoordinationServiceStandaloneImpl::TaskState::Disconnect(
    uint64_t grace_period_duration_us) {
  const uint64_t now_us = Env::Default()->NowMicros();
  disconnect_grace_period_us_ = now_us + grace_period_duration_us;
  status_ = OkStatus();
  state_ = State::DISCONNECTED;
}
void CoordinationServiceStandaloneImpl::TaskState::SetError(
const Status status) {
if (state_ == State::ERROR) return;
state_ = State::ERROR;
status_ = status;
}
// Records a heartbeat for this task.
// Returns the task's current status if it is not OK, an Aborted coordination
// error if `task_incarnation` differs from the registered incarnation, and OK
// (after refreshing the last-heartbeat timestamp) otherwise.
Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat(
    uint64_t task_incarnation) {
  if (!status_.ok()) {
    return status_;
  }
  if (task_incarnation == task_incarnation_) {
    mutex_lock heartbeat_lock(last_heartbeat_mu_);
    last_heartbeat_us_ = Env::Default()->NowMicros();
    return OkStatus();
  }
  return MakeCoordinationError(errors::Aborted(
      "Incarnation ID mismatch: expecting ", task_incarnation_, " but got ",
      task_incarnation, ". This means the remote task has restarted."));
}
// Returns the number of milliseconds elapsed since the last recorded
// heartbeat.
int64_t
CoordinationServiceStandaloneImpl::TaskState::TimeSinceLastHeartbeatMs() {
  mutex_lock heartbeat_lock(last_heartbeat_mu_);
  const uint64_t elapsed_us = Env::Default()->NowMicros() - last_heartbeat_us_;
  return elapsed_us / 1000;
}
// Returns the deadline (in microseconds) computed by the last Disconnect()
// call, or 0 if the task has never been disconnected.
uint64_t CoordinationServiceStandaloneImpl::TaskState::
    GetDisconnectedGracePeriodMicros() {
  const uint64_t deadline_us = disconnect_grace_period_us_;
  return deadline_us;
}
// Returns a copy of the ids of the barriers this task is currently part of.
// Callers may therefore mutate the task's barrier set while iterating the
// returned value.
absl::flat_hash_set<std::string>
CoordinationServiceStandaloneImpl::TaskState::GetOngoingBarriers() {
  absl::flat_hash_set<std::string> barrier_ids = ongoing_barriers_for_task_;
  return barrier_ids;
}
// Adds `barrier_id` to the set of barriers this task takes part in. Adding an
// id that is already present has no effect.
void CoordinationServiceStandaloneImpl::TaskState::JoinBarrier(
    absl::string_view barrier_id) {
  ongoing_barriers_for_task_.insert(std::string(barrier_id));
}
// Removes `barrier_id` from the set of barriers this task takes part in.
// Removing an id that is not present has no effect.
void CoordinationServiceStandaloneImpl::TaskState::ExitBarrier(
    absl::string_view barrier_id) {
  const auto barrier_it = ongoing_barriers_for_task_.find(barrier_id);
  if (barrier_it != ongoing_barriers_for_task_.end()) {
    ongoing_barriers_for_task_.erase(barrier_it);
  }
}
// Reads the heartbeat and shutdown-barrier timeouts from the coordination
// config in `server_def`, creates a DISCONNECTED TaskState for every task of
// every coordinated job in the cluster definition, and starts the staleness
// checking thread.
CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl(
    std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
    const ServerDef& server_def)
    : client_cache_(std::move(client_cache)),
      env_(*env),
      heartbeat_timeout_ms_([&server_def]() -> uint64_t {
        const auto configured_ms = server_def.default_session_config()
                                       .experimental()
                                       .coordination_config()
                                       .heartbeat_timeout_in_ms();
        // A non-positive value means "not configured": use the default.
        if (configured_ms > 0) {
          return configured_ms;
        }
        return kDefaultHeartbeatTimeoutMs;
      }()),
      shutdown_barrier_timeout_(
          absl::Milliseconds(server_def.default_session_config()
                                 .experimental()
                                 .coordination_config()
                                 .shutdown_barrier_timeout_in_ms())) {
  const auto& coordination_config =
      server_def.default_session_config().experimental().coordination_config();
  const std::unordered_set<std::string> job_filter(
      coordination_config.coordinated_jobs().cbegin(),
      coordination_config.coordinated_jobs().cend());

  for (const auto& job : server_def.cluster().job()) {
    // An empty filter means every job in the cluster is coordinated;
    // otherwise only the listed jobs are.
    const bool is_coordinated =
        job_filter.empty() || job_filter.count(job.name()) > 0;
    if (!is_coordinated) {
      continue;
    }
    for (const auto& task_entry : job.tasks()) {
      cluster_state_.emplace(GetTaskName(job.name(), task_entry.first),
                             std::make_unique<TaskState>());
    }
  }
  StartCheckStaleness();
}
// Checks both heartbeat and barrier timeouts in the same thread, since threads
// are a constrained resource.
void CoordinationServiceStandaloneImpl::StartCheckStaleness() {
check_staleness_thread_.reset(
env_.StartThread({}, kHealthCheckThread, [this]() {
const bool has_service_to_client_connection = client_cache_ != nullptr;
// Used to store stale tasks and barriers.
std::vector<absl::string_view> stale_task_names;
absl::flat_hash_map<std::string, BarrierState*> expired_barriers;
while (true) {
{
mutex_lock l(check_staleness_thread_shutdown_mu_);
check_staleness_thread_cv_.wait_for(l, std::chrono::seconds(1));
if (shutting_down_) {
return;
}
}
// Heartbeat check.
Status status = OkStatus();
{
mutex_lock l(state_mu_);
for (const auto& [task_name, task_state] : cluster_state_) {
// Skip tasks that are not registered or in error state
if (task_state->GetState() != TaskState::State::CONNECTED) {
continue;
}
const bool is_stale = task_state->TimeSinceLastHeartbeatMs() >
heartbeat_timeout_ms_;
VLOG(1) << "Checking staleness for " << task_name
<< " stale?=" << is_stale;
if (is_stale) {
stale_task_names.push_back(task_name);
status = MakeCoordinationError(errors::Unavailable(
"Task ", task_name,
" heartbeat timeout. This indicates that the remote task "
"has failed, got preempted, or crashed unexpectedly."));
SetTaskError(task_name, status);
}
}
}
// Propagate heartbeat timeout errors to other connected tasks.
if (!stale_task_names.empty()) {
if (!has_service_to_client_connection) {
// Error cannot be propagated since there is no service-to-client
// connection, so shut down service instead. Note: we cannot
// destroy the thread within its own function. However, this
// thread will be destroyed once the function returns.
LOG(ERROR) << "Stopping coordination service as heartbeat has "
"timed out for "
<< stale_task_names[0]
<< " and there is no service-to-client connection";
Stop(/*shut_staleness_thread=*/false);
return;
}
for (const auto& stale_task_name : stale_task_names) {
PropagateError(GetTaskFromName(stale_task_name));
}
stale_task_names.clear();
}
// Barrier timeout check.
uint64_t current_time_micros = Env::Default()->NowMicros();
{
mutex_lock l(state_mu_);
// Gather barriers which have timed out.
for (const std::string& barrier_id : ongoing_barriers_) {
auto* barrier = &barriers_[barrier_id];
if (current_time_micros > barrier->deadline_in_micros) {
expired_barriers[barrier_id] = barrier;
}
}
// Pass these barriers with the time out error.
for (const auto& [barrier_id, barrier] : expired_barriers) {
const Status error =
MakeCoordinationError(errors::DeadlineExceeded(absl::StrCat(
"Barrier timed out. Barrier_id: ", barrier_id)));
PassBarrier(barrier_id, error, barrier);
}
}
if (!has_service_to_client_connection &&
expired_barriers.contains(shutdown_barrier_id_)) {
// Error cannot be propagated since there is no service-to-client
// connection, so shut down service instead. Note: we cannot
// destroy the thread within its own function. However, this
// thread will be destroyed once the function returns.
LOG(ERROR)
<< "Stopping coordination service as shutdown barrier "
"timed out and there is no service-to-client connection.";
Stop(/*shut_staleness_thread=*/false);
}
// Reset this for the next barrier check.
expired_barriers.clear();
}
}));
}
void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) {
{
mutex_lock l(kv_mu_);
for (const auto& [key, get_kv_callbacks] : get_cb_) {
for (const auto& get_kv_callback : get_kv_callbacks) {
get_kv_callback(errors::Cancelled(
absl::StrCat("Coordination service is shutting down. Cancelling "
"GetKeyValue() for key: ",
key)));
}
}
get_cb_.clear();
}
{
mutex_lock l(state_mu_);
cluster_state_.clear();
for (auto& [barrier_id, barrier] : barriers_) {
if (!barrier.passed) {
Status error = MakeCoordinationError(errors::Aborted(absl::StrCat(
"Barrier failed because service is shutting down. Barrier_id: ",
barrier_id)));
PassBarrier(barrier_id, error, &barrier);
}
}
barriers_.clear();
}
{
mutex_lock l(check_staleness_thread_shutdown_mu_);
shutting_down_ = true;
check_staleness_thread_cv_.notify_all();
}
if (shut_staleness_thread) {
check_staleness_thread_.reset();
}
}
// Registers `task` with the given incarnation.
// Returns:
//   - InvalidArgument if the task is not part of the cluster (this error is
//     not propagated to other tasks);
//   - OK if the task was DISCONNECTED and is now CONNECTED;
//   - Aborted if the task was already connected or in error. In that case the
//     task is set to ERROR and the error is propagated.
Status CoordinationServiceStandaloneImpl::RegisterTask(
    const CoordinatedTask& task, uint64_t incarnation) {
  const std::string task_name = GetTaskName(task);
  Status registration_status;
  {
    mutex_lock state_lock(state_mu_);
    const auto task_it = cluster_state_.find(task_name);
    if (task_it == cluster_state_.end()) {
      // Note: return early here as unexpected task register errors should not
      // be propagated to other tasks.
      return MakeCoordinationError(errors::InvalidArgument(
          "Unexpected task registered with task_name=", task_name));
    }
    TaskState& task_state = *task_it->second;
    if (task_state.GetState() != TaskState::State::DISCONNECTED) {
      // The task is connected or already in error, so it must have registered
      // before.
      registration_status = MakeCoordinationError(
          errors::Aborted("Duplicate task registration with task_name=",
                          task_name),
          task);
      SetTaskError(task_name, registration_status);
    } else {
      // First registration, or registration following a ResetTask().
      task_state.SetConnected(incarnation);
      LOG(INFO) << task_name
                << " has connected to coordination service. Incarnation: "
                << incarnation;
    }
  }
  if (registration_status.ok()) {
    return registration_status;
  }
  PropagateError(task);
  return registration_status;
}
// Merges the caller's device info into the cluster-wide device list (only the
// first time a known task calls this), then joins the cluster-wide device
// propagation barrier. `done` is invoked through that barrier.
void CoordinationServiceStandaloneImpl::WaitForAllTasks(
    const CoordinatedTask& task, const CoordinationServiceDeviceInfo& devices,
    StatusCallback done) {
  {
    mutex_lock state_lock(state_mu_);
    const auto task_it = cluster_state_.find(GetTaskName(task));
    const bool needs_device_info =
        task_it != cluster_state_.end() &&
        !task_it->second->GetDeviceInfoCollected();
    if (needs_device_info) {
      cluster_devices_.MergeFrom(devices);
      task_it->second->MarkDeviceInfoCollected();
    }
  }
  // An empty participant list makes the barrier span the entire cluster.
  BarrierAsync(device_propagation_barrier_id_, kDevicePropagationTimeout, task,
               {}, std::move(done));
}
void CoordinationServiceStandaloneImpl::ShutdownTaskAsync(
const CoordinatedTask& task, StatusCallback done) {
if (shutdown_barrier_timeout_ > absl::ZeroDuration()) {
// Impose shutdown barrier so that all tasks can disconnect together.
BarrierAsync(shutdown_barrier_id_, shutdown_barrier_timeout_, task, {},
done);
} else {
Status status;
{
mutex_lock l(state_mu_);
// Disconnect task from service individually.
status = DisconnectTask(task);
}
done(status);
}
}
// Disconnects `task` under the state lock and returns the result of the
// disconnect.
Status CoordinationServiceStandaloneImpl::ResetTask(
    const CoordinatedTask& task) {
  mutex_lock state_lock(state_mu_);
  Status disconnect_status = DisconnectTask(task);
  return disconnect_status;
}
// Moves `task` to the DISCONNECTED state and fails every barrier it is
// currently part of. Requires `state_mu_` to be held.
// Returns InvalidArgument if the task is not in the cluster and
// FailedPrecondition if it is already disconnected.
Status CoordinationServiceStandaloneImpl::DisconnectTask(
    const CoordinatedTask& task) {
  const std::string task_name = GetTaskName(task);
  const auto task_it = cluster_state_.find(task_name);
  if (task_it == cluster_state_.end()) {
    return MakeCoordinationError(errors::InvalidArgument(
        "Unexpected disconnect request with task_name=", task_name));
  }
  // The TaskState lives on the heap, so this reference stays usable while the
  // barriers below are being failed.
  TaskState& task_state = *task_it->second;
  if (task_state.GetState() == TaskState::State::DISCONNECTED) {
    return MakeCoordinationError(errors::FailedPrecondition(
        "The task is already disconnected: ", task_name));
  }

  // Heartbeats are still accepted for one heartbeat timeout after this point.
  task_state.Disconnect(
      /*grace_period_duration_us=*/heartbeat_timeout_ms_ * 1000);

  // Iterate over a snapshot of the barrier ids.
  const absl::flat_hash_set<std::string> barriers_to_fail =
      task_state.GetOngoingBarriers();
  for (const std::string& barrier_id : barriers_to_fail) {
    Status barrier_error = MakeCoordinationError(errors::Internal(absl::StrCat(
        "Barrier failed from a disconnected task. Barrier Id: ", barrier_id,
        ", Task: ", task_name)));
    PassBarrier(barrier_id, barrier_error, &barriers_[barrier_id]);
  }
  LOG(INFO) << task_name << " has disconnected from coordination service.";
  return OkStatus();
}
// Returns a reference to the aggregated device info of the cluster. The
// declaration requires `state_mu_` to be held by the caller.
const CoordinationServiceDeviceInfo&
CoordinationServiceStandaloneImpl::ListClusterDevices() {
  const CoordinationServiceDeviceInfo& devices = cluster_devices_;
  return devices;
}
// Returns the random incarnation id generated when this service instance was
// constructed.
uint64_t CoordinationServiceStandaloneImpl::GetServiceIncarnation() {
  const uint64_t incarnation = service_incarnation_;
  return incarnation;
}
// Records `error` as reported by `task` and propagates it to the other
// connected tasks.
// Returns InvalidArgument if the task is not in the cluster and
// FailedPrecondition if it is not currently CONNECTED.
Status CoordinationServiceStandaloneImpl::ReportTaskError(
    const CoordinatedTask& task, Status error) {
  const std::string task_name = GetTaskName(task);
  {
    mutex_lock state_lock(state_mu_);
    const auto task_it = cluster_state_.find(task_name);
    if (task_it == cluster_state_.end()) {
      return MakeCoordinationError(
          errors::InvalidArgument("Unexpected request from task ", task_name));
    }
    if (task_it->second->GetState() != TaskState::State::CONNECTED) {
      return MakeCoordinationError(errors::FailedPrecondition(
          "The task is not connected or already has an error."));
    }
    SetTaskError(task_name, error);
  }
  // Propagation must happen without holding the state lock.
  PropagateError(task, /*is_reported_by_task=*/true);
  return OkStatus();
}
// Records a heartbeat from `task`.
// Returns:
//   - InvalidArgument if the task is not part of the cluster;
//   - the task's stored status if it is already in error;
//   - InvalidArgument if the task is disconnected and its heartbeat grace
//     period has expired;
//   - otherwise the result of recording the heartbeat. A failure here (e.g.
//     an incarnation mismatch) puts the task in ERROR and is propagated to the
//     other connected tasks.
Status CoordinationServiceStandaloneImpl::RecordHeartbeat(
    const CoordinatedTask& task, uint64_t incarnation) {
  const std::string task_name = GetTaskName(task);
  Status s = OkStatus();
  {
    mutex_lock l(state_mu_);
    const auto task_it = cluster_state_.find(task_name);
    if (task_it == cluster_state_.end()) {
      return MakeCoordinationError(errors::InvalidArgument(
          "Unexpected task request with task_name=", task_name));
    }
    TaskState& task_state = *task_it->second;
    if (!task_state.GetStatus().ok()) {
      return task_state.GetStatus();
    } else if (task_state.GetState() == TaskState::State::DISCONNECTED &&
               // We accept heartbeats for a short grace period to account for
               // the lag time between the service recording the state change
               // and the agent stopping heartbeats.
               Env::Default()->NowMicros() >
                   task_state.GetDisconnectedGracePeriodMicros()) {
      return MakeCoordinationError(errors::InvalidArgument(
          "Task with task_name=", task_name,
          " must be registered before sending heartbeat messages"));
    }
    s = task_state.RecordHeartbeat(incarnation);
    // Set the error in the same critical section in which the task was looked
    // up. Releasing and re-acquiring `state_mu_` in between would allow
    // `cluster_state_` to be cleared concurrently, after which the task entry
    // no longer exists when the error is set.
    if (!s.ok()) {
      SetTaskError(task_name, s);
    }
  }
  // Propagate any heartbeat errors. This must be done without holding
  // `state_mu_`.
  if (!s.ok()) {
    PropagateError(task);
  }
  return s;
}
// Asynchronously reports a service-originated `error` to `destination_task`.
// The error source is labelled with job name "coordination_service". When
// there is no service-to-client connection the error is only logged.
void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync(
    const CoordinatedTask& destination_task, Status error) {
  assert(!error.ok());

  // Don't report error if there is no service-to-client connection.
  if (client_cache_ == nullptr) {
    LOG(ERROR) << error;
    return;
  }

  // The request, response and call options are shared_ptrs captured by the
  // completion callback, which keeps them alive until the call finishes.
  auto request = std::make_shared<ReportErrorToTaskRequest>();
  request->set_error_code(error.code());
  request->set_error_message(error.error_message());
  request->mutable_error_payload()->mutable_source_task()->set_job_name(
      "coordination_service");
  auto response = std::make_shared<ReportErrorToTaskResponse>();

  auto call_opts = std::make_shared<CallOptions>();
  call_opts->SetTimeout(kServiceToClientTimeoutMs);

  const std::string task_name = GetTaskName(destination_task);
  CoordinationClient* client = client_cache_->GetClient(task_name);
  client->ReportErrorToTaskAsync(
      call_opts.get(), request.get(), response.get(),
      [request, response, task_name, call_opts](Status s) {
        if (s.ok()) {
          return;
        }
        LOG(ERROR) << "Encountered another error while reporting to "
                   << task_name << ": " << s;
      });
}
void CoordinationServiceStandaloneImpl::PropagateError(
const CoordinatedTask& source_task, bool is_reported_by_task) {
Status error;
{
mutex_lock l(state_mu_);
error = cluster_state_[GetTaskName(source_task)]->GetStatus();
}
assert(!error.ok());
ReportErrorToTaskRequest request;
request.set_error_code(error.code());
request.set_error_message(error.error_message());
CoordinationServiceError* payload = request.mutable_error_payload();
*payload->mutable_source_task() = source_task;
payload->set_is_reported_error(is_reported_by_task);
CallOptions call_opts;
call_opts.SetTimeout(kServiceToClientTimeoutMs);
std::vector<std::shared_ptr<absl::Notification>> notifications;
std::vector<absl::string_view> task_names;
{
tf_shared_lock l(state_mu_);
task_names.reserve(cluster_state_.size());
for (const auto& pair : cluster_state_) {
task_names.emplace_back(pair.first);
}
}
for (absl::string_view task : task_names) {
{
mutex_lock l(state_mu_);
// Propagate error only to tasks that are connected
if (cluster_state_[task]->GetState() != TaskState::State::CONNECTED)
continue;
}
// Don't propagate error if there is no service-to-client connection.
if (client_cache_ == nullptr) {
LOG(ERROR)
<< "Stopping coordination service as there is no "
"service-to-client connection, but we encountered an error: "
<< error;
Stop(/*shut_staleness_thread=*/false);
return;
}
CoordinationClient* client = client_cache_->GetClient(std::string(task));
auto response = std::make_shared<ReportErrorToTaskResponse>();
auto n = std::make_shared<absl::Notification>();
client->ReportErrorToTaskAsync(
&call_opts, &request, response.get(), [response, n, task](Status s) {
if (!s.ok()) {
LOG(ERROR) << "Encountered another error while reporting to "
<< task << ": " << s;
}
n->Notify();
});
notifications.push_back(n);
}
for (auto& n : notifications) {
n->WaitForNotification();
}
}
// Utility for normalizing structured config key string.
// The normalized key will not have leading or trailing slashes, and all parts
// in the key path are separated by exactly one slack ('/').
// E.g., ///a//b/c// --> a/b/c
std::string NormalizeKey(const StringPiece orig_key) {
std::string norm_key = std::string(orig_key);
const char* src = norm_key.c_str();
std::string::iterator dst = norm_key.begin();
// Parse all characters
while (*src) {
// Skip leading slashes
while (*src == '/') src++;
// Copy over all non-slash characters
while (*src && *src != '/') {
*dst++ = *src++;
}
// Allow one slash at the end of current directory
if (*src) {
*dst++ = *src++;
}
}
// If ending with slash, remove the trailing slash
if (dst > norm_key.begin() && *(dst - 1) == '/') dst--;
norm_key.resize(dst - norm_key.begin());
return norm_key;
}
// Stores `value` under the normalized form of `key` and completes every
// GetKeyValueAsync() call that was waiting for that key.
// Returns AlreadyExists if the key is already present; the stored value is
// left unchanged in that case.
Status CoordinationServiceStandaloneImpl::InsertKeyValue(
    const std::string& key, const std::string& value) {
  const std::string norm_key = NormalizeKey(key);
  mutex_lock kv_lock(kv_mu_);

  const bool inserted = kv_store_.emplace(norm_key, value).second;
  if (!inserted) {
    return MakeCoordinationError(
        errors::AlreadyExists("Config key ", key, " already exists."));
  }

  const auto waiters = get_cb_.find(norm_key);
  if (waiters == get_cb_.end()) {
    return OkStatus();
  }
  for (const auto& waiter : waiters->second) {
    waiter(value);
  }
  get_cb_.erase(waiters);
  return OkStatus();
}
// Invokes `done` with the value stored under the normalized `key`. If the key
// is not present yet, `done` is queued and invoked when the key is inserted
// (or cancelled when the service stops).
void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
    const std::string& key, StatusOrValueCallback done) {
  const std::string norm_key = NormalizeKey(key);
  mutex_lock kv_lock(kv_mu_);
  if (const auto entry = kv_store_.find(norm_key); entry != kv_store_.end()) {
    done(entry->second);
    return;
  }
  // operator[] creates the (empty) waiter list on first use.
  get_cb_[norm_key].emplace_back(std::move(done));
}
// Non-blocking lookup: returns the value stored under the normalized `key`,
// or NotFound if the key is absent.
StatusOr<std::string> CoordinationServiceStandaloneImpl::TryGetKeyValue(
    const std::string& key) {
  const std::string norm_key = NormalizeKey(key);
  mutex_lock kv_lock(kv_mu_);
  const auto entry = kv_store_.find(norm_key);
  if (entry != kv_store_.end()) {
    return entry->second;
  }
  return errors::NotFound("Config key ", key, " not found.");
}
// Returns every key-value pair whose key lies under the normalized
// `directory_key`, i.e. starts with "<normalized key>/". Entries are returned
// in key order.
std::vector<KeyValueEntry> CoordinationServiceStandaloneImpl::GetKeyValueDir(
    absl::string_view directory_key) {
  const std::string dir = absl::StrCat(NormalizeKey(directory_key), "/");
  std::vector<KeyValueEntry> entries;

  mutex_lock kv_lock(kv_mu_);
  // The store is ordered, so all keys sharing the prefix are contiguous and
  // start at the first key that is not less than the prefix.
  for (auto it = kv_store_.lower_bound(dir); it != kv_store_.end(); ++it) {
    const std::string& stored_key = it->first;
    const bool has_dir_prefix = stored_key.compare(0, dir.size(), dir) == 0;
    if (!has_dir_prefix) {
      break;
    }
    KeyValueEntry& entry = entries.emplace_back();
    entry.set_key(stored_key);
    entry.set_value(it->second);
  }
  return entries;
}
// Deletes the entry stored under the normalized `key` as well as every entry
// under the directory "<normalized key>/". Always returns OK, even when
// nothing was deleted.
Status CoordinationServiceStandaloneImpl::DeleteKeyValue(
    const std::string& key) {
  const std::string norm_key = NormalizeKey(key);
  const std::string dir = strings::StrCat(norm_key, "/");

  mutex_lock kv_lock(kv_mu_);
  // Directory delete: keys sharing the prefix form a contiguous range in the
  // ordered store.
  const auto range_begin = kv_store_.lower_bound(dir);
  auto range_end = range_begin;
  while (range_end != kv_store_.end() &&
         range_end->first.compare(0, dir.size(), dir) == 0) {
    ++range_end;
  }
  kv_store_.erase(range_begin, range_end);

  // Exact-key delete.
  kv_store_.erase(norm_key);
  return OkStatus();
}
// Puts the task named `task_name` into the ERROR state with `error` and fails
// every barrier the task is currently part of. Requires `state_mu_` to be
// held. The task is expected to exist in the cluster state.
void CoordinationServiceStandaloneImpl::SetTaskError(
    absl::string_view task_name, Status error) {
  cluster_state_[task_name]->SetError(error);

  // Iterate over a snapshot of the task's barrier ids.
  const absl::flat_hash_set<std::string> barriers_to_fail =
      cluster_state_[task_name]->GetOngoingBarriers();
  for (const std::string& barrier_id : barriers_to_fail) {
    Status barrier_error = MakeCoordinationError(errors::Internal(absl::StrCat(
        "Barrier failed from a task error. Barrier Id: ", barrier_id,
        ", Task: ", task_name)));
    PassBarrier(barrier_id, barrier_error, &barriers_[barrier_id]);
  }

  // Log the error passed in by the caller, not the per-barrier error.
  LOG(ERROR) << task_name << " has been set to ERROR: " << error;
}
// Records that `task` has arrived at the barrier `barrier_id` and arranges for
// `done` to be invoked with the barrier's result.
//
// The first call for a given `barrier_id` creates the barrier: participants
// are `participating_tasks`, or every task in `cluster_state_` if that list is
// empty, and the deadline is set to the current time plus `timeout`.
//
// `done` is invoked directly from this function (while `state_mu_` is held)
// when the barrier fails during creation or has already passed. Otherwise it
// is queued in `done_callbacks` and invoked from PassBarrier().
void CoordinationServiceStandaloneImpl::BarrierAsync(
const std::string& barrier_id, absl::Duration timeout,
const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done) {
mutex_lock l(state_mu_);
// try_emplace leaves an existing entry untouched; `inserted` tells us whether
// this call is the one that created the barrier.
auto pair = barriers_.try_emplace(barrier_id);
auto it = pair.first;
bool inserted = pair.second;
auto* barrier = &it->second;
// Create barrier for the first time.
if (inserted) {
// Initialize barrier state.
barrier->passed = false;
// Assume barrier is for entire cluster if no tasks are specified.
if (participating_tasks.empty()) {
for (const auto& task_state : cluster_state_) {
absl::string_view task_name = task_state.first;
barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false;
}
} else {
for (const auto& task : participating_tasks) {
// Fail the barrier immediately if unexpected task is included in the
// barrier.
const std::string task_name = GetTaskName(task);
if (!cluster_state_.contains(task_name)) {
Status error = MakeCoordinationError(errors::InvalidArgument(
absl::StrCat("Unexpected task (", task_name,
") that is not in the cluster called the barrier. "
"Barrier Id: ",
barrier_id)));
// `done` has not been queued on the barrier at this point, so the caller
// is notified explicitly after the barrier is marked as failed.
PassBarrier(barrier_id, error, barrier);
done(error);
return;
}
barrier->tasks_at_barrier[task] = false;
}
}
// Every participant starts out as pending (mapped to false above).
barrier->num_pending_tasks = barrier->tasks_at_barrier.size();
// Fail the barrier immediately if any tasks are already in error.
for (const auto& pending_task : barrier->tasks_at_barrier) {
const std::string task_name = GetTaskName(pending_task.first);
if (cluster_state_[task_name]->GetState() == TaskState::State::ERROR) {
Status error = MakeCoordinationError(errors::Internal(
absl::StrCat("Task (", task_name,
") is already in error before the barrier "
"was called. Barrier Id: ",
barrier_id)));
// As above: `done` is not queued yet, so it is called directly.
PassBarrier(barrier_id, error, barrier);
done(error);
return;
}
}
// Dividing by one microsecond converts the timeout into an integer count of
// microseconds, matching the unit of NowMicros().
barrier->deadline_in_micros =
Env::Default()->NowMicros() + (timeout / absl::Microseconds(1));
// Add ongoing barrier to cluster state.
ongoing_barriers_.emplace(barrier_id);
const size_t num_ongoing_barriers = ongoing_barriers_.size();
if (num_ongoing_barriers > kOngoingBarriersSoftLimit) {
LOG(WARNING) << "There is a high number of ongoing barriers in "
"coordination service: "
<< num_ongoing_barriers;
}
// Register the barrier with each participant's task state; SetTaskError()
// reads this back through GetOngoingBarriers() to fail the barrier when a
// participant errors out.
for (const auto& pending_task : barrier->tasks_at_barrier) {
const CoordinatedTask& task = pending_task.first;
cluster_state_[GetTaskName(task)]->JoinBarrier(barrier_id);
}
}
// Barrier has already been passed, return previous result immediately.
if (barrier->passed) {
// Special hook for shutdown barrier to disconnect task.
if (barrier_id == shutdown_barrier_id_) {
Status s = DisconnectTask(task);
// Return any errors from the disconnect attempt, otherwise return the
// barrier status outside of this hook.
if (!s.ok()) {
done(s);
return;
}
}
done(barrier->result);
return;
}
// Add pending callbacks.
// The callback is queued before the validation below so that the
// PassBarrier() calls on the failure paths also notify this caller.
barrier->done_callbacks.push_back(done);
// Check if caller task is participating in the barrier.
if (!barrier->tasks_at_barrier.contains(task)) {
// Unexpected barrier call from a task not participating in the barrier.
Status error = MakeCoordinationError(errors::InvalidArgument(
absl::StrCat("A non-participating task (", GetTaskName(task),
") called the barrier: ", barrier_id)));
PassBarrier(barrier_id, error, barrier);
return;
}
// Check if task args are specified consistently across barrier calls.
if (!ValidateTaskArgs(participating_tasks, barrier->tasks_at_barrier,
cluster_state_.size())) {
Status error = MakeCoordinationError(errors::InvalidArgument(absl::StrCat(
"Conflicting tasks specified for the same barrier: ", barrier_id)));
PassBarrier(barrier_id, error, barrier);
return;
}
// Remove pending task.
// We need to check if task made a repeated call after reaching the barrier.
// A repeated call only adds its callback; the pending count is decremented
// once per task.
if (!barrier->tasks_at_barrier[task]) {
barrier->tasks_at_barrier[task] = true;
--barrier->num_pending_tasks;
// The last task to arrive passes the barrier with an OK status.
if (barrier->num_pending_tasks == 0) {
PassBarrier(barrier_id, OkStatus(), barrier);
return;
}
}
}
// Cancels the barrier `barrier_id` on behalf of `task`.
//
// If the barrier does not exist yet, an entry is created for it (and a warning
// is logged) before it is marked as cancelled. Returns a FailedPrecondition
// coordination error if the barrier has already been passed; otherwise passes
// the barrier with a Cancelled error and returns OK.
Status CoordinationServiceStandaloneImpl::CancelBarrier(
    const std::string& barrier_id, const CoordinatedTask& task) {
  mutex_lock l(state_mu_);
  const auto emplaced = barriers_.try_emplace(barrier_id);
  auto* state = &emplaced.first->second;
  const bool newly_created = emplaced.second;
  if (newly_created) {
    LOG(WARNING) << "Barrier (" << barrier_id
                 << ") is cancelled before being created by task: "
                 << GetTaskName(task);
  }

  if (!state->passed) {
    // Still pending: fail it with a cancellation error naming the caller.
    Status cancel_status = MakeCoordinationError(errors::Cancelled(
        absl::StrCat("Barrier (", barrier_id,
                     ") is cancelled by task: ", GetTaskName(task))));
    PassBarrier(barrier_id, cancel_status, state);
    return OkStatus();
  }

  // A barrier that already has a result cannot be cancelled anymore.
  return MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
      "Barrier (", barrier_id, ") has already been passed with status code: ",
      state->result.code())));
}
// Marks `barrier` as passed with `result`, runs the hooks tied to the device
// propagation and shutdown barriers, removes the barrier from the ongoing
// bookkeeping, and finally invokes every queued callback with `result`.
void CoordinationServiceStandaloneImpl::PassBarrier(
    absl::string_view barrier_id, Status result, BarrierState* barrier) {
  barrier->passed = true;
  barrier->result = result;

  // Device propagation barrier hook: assign global device ids.
  if (barrier_id == device_propagation_barrier_id_) {
    SetXlaGlobalDeviceIds();
  }

  // Remove this barrier from each participant's task state (that state is
  // what the task error hooks consult).
  for (const auto& entry : barrier->tasks_at_barrier) {
    cluster_state_[GetTaskName(entry.first)]->ExitBarrier(barrier_id);
  }

  // Shutdown barrier hook: disconnect the tasks that made it to the barrier
  // and notify the ones that did not.
  if (barrier_id == shutdown_barrier_id_) {
    if (!result.ok()) {
      LOG(ERROR) << "Shutdown barrier failed: " << result
                 << ". This suggests that at least one worker did not complete "
                    "its job, or was too slow/hanging in its execution.";
    } else {
      LOG(INFO) << "Shutdown barrier has passed.";
    }
    Status straggler_error = MakeCoordinationError(errors::Internal(
        absl::StrCat("Shutdown barrier has been passed with status: '",
                     barrier->result.ToString(),
                     "', but this task is not at the barrier yet.")));
    for (const auto& [participant, reached_barrier] :
         barrier->tasks_at_barrier) {
      if (!reached_barrier) {
        // A straggler exists only when the barrier failed; tell it so.
        ReportServiceErrorToTaskAsync(participant, straggler_error);
        continue;
      }
      // The task reached the barrier, so it can be disconnected now.
      Status disconnect_status = DisconnectTask(participant);
      if (!disconnect_status.ok()) {
        LOG(ERROR) << disconnect_status;
      }
    }
  }

  barrier->tasks_at_barrier.clear();
  ongoing_barriers_.erase(barrier_id);
  // Note: barrier_id must not be touched past this point, since its storage
  // may be owned by one of the callbacks invoked below.
  for (const auto& notify : barrier->done_callbacks) {
    notify(result);
  }
  barrier->done_callbacks.clear();
}
// Checks that the task list passed to a barrier call is consistent with the
// participants already recorded for that barrier.
//
// An empty `tasks_args` is valid only if the barrier has exactly
// `cluster_size` participants. A non-empty `tasks_args` is valid only if it
// has as many entries as `tasks_at_barrier` and each entry is a key of
// `tasks_at_barrier`.
bool CoordinationServiceStandaloneImpl::ValidateTaskArgs(
    const std::vector<CoordinatedTask>& tasks_args,
    const absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
                              CoordinatedTaskEqual>& tasks_at_barrier,
    int64_t cluster_size) {
  if (tasks_args.empty()) {
    return tasks_at_barrier.size() == cluster_size;
  }
  if (tasks_at_barrier.size() != tasks_args.size()) {
    return false;
  }
  return std::all_of(tasks_args.begin(), tasks_args.end(),
                     [&tasks_at_barrier](const CoordinatedTask& candidate) {
                       return tasks_at_barrier.contains(candidate);
                     });
}
void CoordinationServiceStandaloneImpl::SetXlaGlobalDeviceIds() {
// No-op if TF devices are specified.
if (cluster_devices_.has_xla()) {
int global_id = 0;
for (xla::LocalTopologyProto& local_topology :
*cluster_devices_.mutable_xla()->mutable_devices()->mutable_nodes()) {
for (xla::DeviceProto& device : *local_topology.mutable_devices()) {
device.set_global_device_id(global_id);
++global_id;
}
}
}
}
} // namespace
// Creates the standalone coordination service when `server_def` identifies
// this process as the multi-client leader; returns a null pointer otherwise.
// `cache` is handed over to the created service.
std::unique_ptr<CoordinationServiceInterface> EnableCoordinationService(
    Env* env, const ServerDef& server_def,
    std::unique_ptr<CoordinationClientCache> cache) {
  if (!is_multi_client_leader(server_def)) {
    return nullptr;
  }
  return std::make_unique<CoordinationServiceStandaloneImpl>(
      std::move(cache), env, server_def);
}
// Register standalone coordination service implementation: the name
// "standalone" is paired with the EnableCoordinationService factory function.
REGISTER_COORDINATION_SERVICE("standalone", EnableCoordinationService);
} // namespace tensorflow