blob: 1c645476174538f638458547801b6731be1ac0b8 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include <algorithm>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace {
// Heartbeat timeout (10 seconds) used when the coordination config does not
// provide a positive heartbeat_timeout_in_ms.
constexpr int kDefaultHeartbeatTimeoutMs = 10000;
// Name of the background thread that checks task heartbeats for staleness.
constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck";
// Builds the canonical task name "/job:<job_name>/replica:0/task:<task_id>"
// used as the key of the per-task state map. The replica is always 0.
std::string GetTaskName(const std::string& job_name, int task_id) {
  return strings::StrCat("/job:", job_name, "/replica:0/task:", task_id);
}
bool is_multi_client_leader(const ServerDef& server_def) {
const auto& config = server_def.default_session_config();
const std::string& leader =
config.experimental().coordination_config().service_leader();
const std::string& collective_leader =
config.experimental().collective_group_leader();
DeviceNameUtils::ParsedName leader_pn;
if (!leader.empty()) {
DeviceNameUtils::ParseFullName(leader, &leader_pn);
} else if (!collective_leader.empty()) {
LOG(INFO) << "No coordination leader is set, using the collective leader "
<< collective_leader;
DeviceNameUtils::ParseFullName(collective_leader, &leader_pn);
} else {
LOG(INFO) << "No coordination leader is set, using the default /job:"
<< server_def.job_name() << "/replica:0/task:0";
return server_def.task_index() == 0;
}
return server_def.job_name() == leader_pn.job &&
server_def.task_index() == leader_pn.task;
}
// Standalone implementation of the coordination service.
//
// Tracks the state of every task in the coordinated jobs of the ServerDef
// passed to the constructor, records heartbeats, propagates task errors to the
// other connected tasks through `client_cache_`, and serves an in-memory
// key-value store.
//
// Locking:
//   * `state_mu_` guards per-task state and cluster registration bookkeeping.
//   * `kv_mu_` guards the key-value store and pending GetKeyValue callbacks.
//   * `check_staleness_thread_shutdown_mu_` guards the shutdown flag polled by
//     the background heartbeat-staleness thread.
class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
 public:
  CoordinationServiceStandaloneImpl(
      std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
      const ServerDef& server_def);
  ~CoordinationServiceStandaloneImpl() override { Stop(); }

  void RegisterWorker(const std::string& job_name, int task_id,
                      uint64 incarnation, StatusCallback done) override;
  void WaitForAllTasks(const std::string& job_name, int task_id,
                       const CoordinationServiceDeviceInfo& devices,
                       StatusCallback done) override;
  Status RecordHeartbeat(const std::string& job_name, int task_id,
                         uint64 incarnation) override;
  Status ReportTaskError(const std::string& job_name, int task_id,
                         Status error) override;
  Status InsertKeyValue(const std::string& key,
                        const std::string& value) override;
  StatusOr<std::string> GetKeyValue(const std::string& key) override;
  void GetKeyValueAsync(const std::string& key,
                        StatusOrValueCallback done) override;
  Status DeleteKeyValue(const std::string& key) override;

 private:
  // Devices merged from every WaitForAllTasks() call.
  const CoordinationServiceDeviceInfo& ListClusterDevices() override
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  // Starts the background thread that marks tasks with stale heartbeats as
  // failed and propagates the failure.
  void StartCheckStaleness();
  // Shuts the service down; called from the destructor.
  void Stop();
  // Sends `error`, attributed to (`job`, `task_id`), to every task that is
  // currently CONNECTED.
  void PropagateError(const std::string& job, int task_id, Status error,
                      bool is_reported_by_agent = false)
      TF_LOCKS_EXCLUDED(state_mu_);
  // Invokes every pending WaitForAllTasks() callback with `s`.
  void DoneClusterRegistration(Status s) TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);

  class TaskState {
   public:
    // Task state maintained on the coordination service side.
    // Transitions performed in this file:
    //   DISCONNECTED --RegisterWorker--> CONNECTED
    //   CONNECTED    --heartbeat timeout / ReportTaskError /
    //                  duplicate registration--> ERROR
    //   ERROR        --RegisterWorker--> CONNECTED
    // NOTE(review): RESTARTED is declared but never assigned in this file.
    //
    // When a task enters ERROR, the error is propagated to the other
    // CONNECTED tasks in the cluster.
    enum class State {
      DISCONNECTED,
      CONNECTED,
      ERROR,
      RESTARTED,
    };
    State GetState() const { return state_; }
    Status GetStatus() const { return status_; }
    void SetConnected(uint64 task_incarnation);
    void SetRegisteredCallback(StatusCallback cb);
    Status RecordHeartbeat(uint64 task_incarnation);
    int64 TimeSinceLastHeartbeatMs();
    void InvokeRegisteredCallback(Status s);
    void SetError(Status status);

   private:
    // Incarnation ID for CPU:0 on remote task.
    uint64 task_incarnation_ = 0;
    // WaitForAllTasks callback invoked when all tasks are registered. Must be
    // invoked exactly once.
    StatusCallback registered_callback_;
    // True while there is no pending callback: SetRegisteredCallback() clears
    // it and InvokeRegisteredCallback() sets it again when it runs the
    // callback.
    std::atomic_bool is_callback_invoked_{true};
    State state_ = State::DISCONNECTED;
    Status status_;
    mutex last_heartbeat_mu_;
    // Time of the latest heartbeat / registration as reported by
    // Env::Default()->NowMicros(). Defaulted so it is never read
    // uninitialized.
    int64 last_heartbeat_us_ TF_GUARDED_BY(last_heartbeat_mu_) = 0;
  };

  std::unique_ptr<CoordinationClientCache> client_cache_;
  Env& env_;
  const uint64 heartbeat_timeout_ms_;

  mutex state_mu_;
  condition_variable cluster_registered_cv_;
  // Per-task state keyed by GetTaskName(job, task).
  absl::flat_hash_map<std::string, std::unique_ptr<TaskState>> cluster_state_
      TF_GUARDED_BY(state_mu_);
  CoordinationServiceDeviceInfo cluster_devices_ TF_GUARDED_BY(state_mu_);
  // Number of tasks that have not yet called WaitForAllTasks().
  int cluster_pending_workers_ TF_GUARDED_BY(state_mu_) = 0;

  mutex kv_mu_;
  // Ordered map to store config key-values
  std::map<std::string, std::string> kv_store_ TF_GUARDED_BY(kv_mu_);
  // GetKeyValueAsync() callbacks waiting for a (normalized) key that has not
  // been inserted yet.
  absl::flat_hash_map<std::string, std::vector<StatusOrValueCallback>> get_cb_
      TF_GUARDED_BY(kv_mu_);

  mutex check_staleness_thread_shutdown_mu_;
  condition_variable check_staleness_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(check_staleness_thread_shutdown_mu_) =
      false;
  std::unique_ptr<Thread> check_staleness_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(CoordinationServiceStandaloneImpl);
};
// Marks the task as CONNECTED with the given incarnation, clears any previous
// error status, and restarts the heartbeat clock from "now".
void CoordinationServiceStandaloneImpl::TaskState::SetConnected(
    uint64 task_incarnation) {
  task_incarnation_ = task_incarnation;
  status_ = Status::OK();
  state_ = State::CONNECTED;
  mutex_lock heartbeat_lock(last_heartbeat_mu_);
  last_heartbeat_us_ = Env::Default()->NowMicros();
}
// Stores the WaitForAllTasks() callback and marks it as pending so that the
// next InvokeRegisteredCallback() runs it exactly once.
//
// The callback is moved, not copied, because `cb` is a by-value sink. The flag
// is published only after the callback is in place, so an invoker that
// observes "not yet invoked" always finds the new callback.
void CoordinationServiceStandaloneImpl::TaskState::SetRegisteredCallback(
    StatusCallback cb) {
  registered_callback_ = std::move(cb);
  is_callback_invoked_.store(false, std::memory_order_release);
}
void CoordinationServiceStandaloneImpl::TaskState::SetError(
const Status status) {
if (state_ == State::ERROR) return;
state_ = State::ERROR;
status_ = status;
}
// Records a heartbeat from the task.
//
// Returns the stored error if the task already has one, and Aborted if
// `task_incarnation` differs from the incarnation seen at registration (the
// remote task restarted). Otherwise it refreshes the last-heartbeat time and
// returns OK.
Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat(
    uint64 task_incarnation) {
  if (!status_.ok()) return status_;
  const bool same_incarnation = (task_incarnation == task_incarnation_);
  if (!same_incarnation) {
    return MakeCoordinationError(errors::Aborted(
        "Incarnation ID mismatch: expecting ", task_incarnation_, " but got ",
        task_incarnation, ". This means the remote task has restarted."));
  }
  {
    mutex_lock heartbeat_lock(last_heartbeat_mu_);
    last_heartbeat_us_ = Env::Default()->NowMicros();
  }
  return Status::OK();
}
// Milliseconds elapsed since the last heartbeat (or registration), never
// negative.
//
// NowMicros() returns an unsigned value measured from the Unix epoch, and that
// clock can step backwards (e.g. on an NTP adjustment). If the subtraction were
// done in unsigned arithmetic, a backwards step would wrap to a huge positive
// age and the task would be falsely reported as stale. Compute the difference
// as signed and clamp it at zero instead.
int64 CoordinationServiceStandaloneImpl::TaskState::TimeSinceLastHeartbeatMs() {
  mutex_lock l(last_heartbeat_mu_);
  const int64 now_us = static_cast<int64>(Env::Default()->NowMicros());
  const int64 elapsed_us = now_us - last_heartbeat_us_;
  return std::max<int64>(elapsed_us, 0) / 1000;
}
// Runs the pending WaitForAllTasks() callback with `s`, at most once per
// SetRegisteredCallback(). After running it, restarts the heartbeat clock.
void CoordinationServiceStandaloneImpl::TaskState::InvokeRegisteredCallback(
    Status s) {
  const bool was_already_invoked =
      is_callback_invoked_.exchange(true, std::memory_order_acq_rel);
  if (was_already_invoked) return;
  registered_callback_(s);
  mutex_lock heartbeat_lock(last_heartbeat_mu_);
  last_heartbeat_us_ = Env::Default()->NowMicros();
}
// Builds the service from `server_def`:
//  * the heartbeat timeout is coordination_config().heartbeat_timeout_in_ms()
//    when positive, otherwise kDefaultHeartbeatTimeoutMs;
//  * one TaskState is created per task of every job listed in
//    coordinated_jobs (or of every job in the cluster when that list is
//    empty);
//  * every created task counts as pending registration;
//  * the heartbeat-staleness thread is started.
CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl(
    std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
    const ServerDef& server_def)
    : client_cache_(std::move(client_cache)),
      env_(*env),
      heartbeat_timeout_ms_([&server_def]() -> uint64 {
        const auto configured_ms = server_def.default_session_config()
                                       .experimental()
                                       .coordination_config()
                                       .heartbeat_timeout_in_ms();
        if (configured_ms > 0) return configured_ms;
        return kDefaultHeartbeatTimeoutMs;
      }()) {
  const auto& coordination_config =
      server_def.default_session_config().experimental().coordination_config();
  const std::unordered_set<std::string> coordinated_jobs(
      coordination_config.coordinated_jobs().cbegin(),
      coordination_config.coordinated_jobs().cend());
  for (const auto& job : server_def.cluster().job()) {
    // An empty `coordinated_jobs` means every job is coordinated.
    const bool is_coordinated =
        coordinated_jobs.empty() || coordinated_jobs.count(job.name()) > 0;
    if (!is_coordinated) continue;
    for (const auto& task : job.tasks()) {
      cluster_state_.emplace(GetTaskName(job.name(), task.first),
                             std::make_unique<TaskState>());
    }
  }
  cluster_pending_workers_ = cluster_state_.size();
  StartCheckStaleness();
}
void CoordinationServiceStandaloneImpl::StartCheckStaleness() {
check_staleness_thread_.reset(
env_.StartThread({}, kHealthCheckThread, [this]() {
// Used to store the job and task info if a task becomes stale
DeviceNameUtils::ParsedName parsed;
while (true) {
{
mutex_lock l(check_staleness_thread_shutdown_mu_);
check_staleness_thread_cv_.wait_for(l, std::chrono::seconds(1));
if (shutting_down_) {
return;
}
}
Status status = Status::OK();
{
mutex_lock l(state_mu_);
for (const auto& worker_state : cluster_state_) {
// Skip workers that are not registered or in error state
if (worker_state.second->GetState() !=
TaskState::State::CONNECTED) {
continue;
}
const bool is_stale =
worker_state.second->TimeSinceLastHeartbeatMs() >
heartbeat_timeout_ms_;
VLOG(1) << "Checking staleness for " << worker_state.first
<< " stale?=" << is_stale;
if (is_stale) {
status = MakeCoordinationError(errors::Unavailable(
"Task ", worker_state.first,
" heartbeat timeout. This indicates that the remote task "
"has failed, got preempted, or crashed unexpectedly."));
worker_state.second->SetError(status);
DeviceNameUtils::ParseFullName(worker_state.first, &parsed);
break;
}
}
}
if (!status.ok()) {
PropagateError(parsed.job, parsed.task, status);
}
}
}));
}
void CoordinationServiceStandaloneImpl::Stop() {
{
mutex_lock l(kv_mu_);
get_cb_.clear();
}
{
mutex_lock l(state_mu_);
cluster_state_.clear();
}
{
mutex_lock l(check_staleness_thread_shutdown_mu_);
shutting_down_ = true;
check_staleness_thread_cv_.notify_all();
cluster_registered_cv_.notify_all();
}
check_staleness_thread_.reset();
}
void CoordinationServiceStandaloneImpl::RegisterWorker(
const std::string& job_name, int task_id, uint64 incarnation,
StatusCallback done) {
const std::string& task_name = GetTaskName(job_name, task_id);
Status status;
{
mutex_lock l(state_mu_);
if (!cluster_state_.contains(task_name)) {
done(MakeCoordinationError(errors::InvalidArgument(
"Unexpected worker registered with job_name=", job_name,
", task_id=", task_id)));
// Note: unexpected task register should not be propagated to other tasks
return;
} else if (cluster_state_[task_name]->GetState() ==
TaskState::State::CONNECTED) {
Status s = MakeCoordinationError(
errors::Aborted("Duplicate worker registration with job_name=",
job_name, ", task_id=", task_id),
job_name, task_id);
cluster_state_[task_name]->SetError(s);
status = s;
DoneClusterRegistration(s);
} else {
// Hit this path when the task is registering itself for the first time,
// or it's already in ERROR state and now register again. In both cases,
// the service allows it to be registered.
cluster_state_[task_name]->SetConnected(incarnation);
}
}
if (!status.ok()) PropagateError(job_name, task_id, status);
done(status);
}
// Records that (`job_name`, `task_id`) is waiting for the whole cluster.
// `devices` is merged into the cluster device list and `done` is parked on the
// task. When the last pending task arrives, every parked callback runs with
// OK. An unknown task gets InvalidArgument immediately.
//
// NOTE(review): a second call from the same task decrements
// cluster_pending_workers_ again and replaces the earlier parked callback. The
// DCHECK below guards the count only in debug builds.
void CoordinationServiceStandaloneImpl::WaitForAllTasks(
    const std::string& job_name, int task_id,
    const CoordinationServiceDeviceInfo& devices, StatusCallback done) {
  const std::string task_name = GetTaskName(job_name, task_id);
  mutex_lock l(state_mu_);
  auto found = cluster_state_.find(task_name);
  if (found == cluster_state_.end()) {
    done(MakeCoordinationError(errors::InvalidArgument(
        "Unexpected worker request with job_name=", job_name,
        ", task_id=", task_id)));
    return;
  }
  DCHECK_GT(cluster_pending_workers_, 0);
  found->second->SetRegisteredCallback(std::move(done));
  cluster_devices_.MergeFrom(devices);
  if (--cluster_pending_workers_ == 0) {
    DoneClusterRegistration(Status::OK());
  }
}
// Returns the device info accumulated from every WaitForAllTasks() call.
// The caller must hold state_mu_.
const CoordinationServiceDeviceInfo&
CoordinationServiceStandaloneImpl::ListClusterDevices() {
  const CoordinationServiceDeviceInfo& merged_devices = cluster_devices_;
  return merged_devices;
}
// Completes cluster registration with `s`. Each task's pending
// WaitForAllTasks() callback (if any) runs with `s`; null task entries are
// skipped. Waiters on cluster_registered_cv_ are then woken.
// Requires state_mu_.
void CoordinationServiceStandaloneImpl::DoneClusterRegistration(Status s) {
  for (auto& entry : cluster_state_) {
    TaskState* task_state = entry.second.get();
    if (task_state == nullptr) continue;
    task_state->InvokeRegisteredCallback(s);
  }
  cluster_registered_cv_.notify_all();
}
// Records an error reported by the agent of (`job_name`, `task_id`) and
// propagates it to the other connected tasks.
//
// Returns InvalidArgument for an unknown task and FailedPrecondition if the
// task is not CONNECTED (never registered, or already in error). Otherwise
// returns OK.
Status CoordinationServiceStandaloneImpl::ReportTaskError(
    const std::string& job_name, int task_id, Status error) {
  const std::string task_name = GetTaskName(job_name, task_id);
  {
    mutex_lock l(state_mu_);
    auto found = cluster_state_.find(task_name);
    if (found == cluster_state_.end()) {
      return MakeCoordinationError(errors::InvalidArgument(
          "Unexpected worker request with job_name=", job_name,
          ", task_id=", task_id));
    }
    TaskState* task_state = found->second.get();
    if (task_state->GetState() != TaskState::State::CONNECTED) {
      return MakeCoordinationError(errors::FailedPrecondition(
          "The task is not connected or already has an error."));
    }
    task_state->SetError(error);
  }
  PropagateError(job_name, task_id, error, /*is_reported_by_agent=*/true);
  return Status::OK();
}
// Handles a heartbeat from (`job_name`, `task_id`) carrying `incarnation`.
//
// Returns InvalidArgument for an unknown or never-registered task, and returns
// the task's stored error if it already failed. If the heartbeat itself is
// rejected (incarnation mismatch, i.e. the task restarted), the error is first
// recorded on the task, then propagated to the other connected tasks, and
// finally returned.
//
// Recording the error moves the task out of CONNECTED. Without that, the
// failed task would keep receiving its own error, every later heartbeat would
// re-propagate it, and a restarted task could not re-register cleanly.
Status CoordinationServiceStandaloneImpl::RecordHeartbeat(
    const std::string& job_name, int task_id, uint64 incarnation) {
  const std::string task_name = GetTaskName(job_name, task_id);
  Status s = Status::OK();
  {
    mutex_lock l(state_mu_);
    auto found = cluster_state_.find(task_name);
    if (found == cluster_state_.end()) {
      return MakeCoordinationError(errors::InvalidArgument(
          "Unexpected worker heartbeat with job_name=", job_name,
          ", task_id=", task_id));
    }
    TaskState* task_state = found->second.get();
    if (!task_state->GetStatus().ok()) {
      return task_state->GetStatus();
    } else if (task_state->GetState() == TaskState::State::DISCONNECTED) {
      return MakeCoordinationError(errors::InvalidArgument(
          "Task with job_name=", job_name, ", task_id=", task_id,
          " must be registered before sending heartbeat messages"));
    }
    s = task_state->RecordHeartbeat(incarnation);
    if (!s.ok()) {
      task_state->SetError(s);
    }
  }
  if (!s.ok()) {
    PropagateError(job_name, task_id, s);
  }
  return s;
}
void CoordinationServiceStandaloneImpl::PropagateError(
const std::string& job_name, int task_id, Status error,
bool is_reported_by_agent) {
assert(!error.ok());
ReportErrorToAgentRequest request;
request.set_error_code(error.code());
request.set_error_message(error.error_message());
CoordinationServiceError* payload = request.mutable_error_payload();
payload->set_job(job_name);
payload->set_task(task_id);
payload->set_is_reported_error(is_reported_by_agent);
std::vector<std::shared_ptr<Notification>> notifications;
std::vector<absl::string_view> task_names;
{
tf_shared_lock l(state_mu_);
task_names.reserve(cluster_state_.size());
for (const auto& pair : cluster_state_) {
task_names.emplace_back(pair.first);
}
}
for (absl::string_view task : task_names) {
{
mutex_lock l(state_mu_);
// Propagate error only to workers that are connected
if (cluster_state_[task]->GetState() != TaskState::State::CONNECTED)
continue;
}
CoordinationClient* client = client_cache_->GetClient(std::string(task));
auto response = std::make_shared<ReportErrorToAgentResponse>();
auto n = std::make_shared<Notification>();
client->ReportErrorToAgentAsync(
&request, response.get(), [response, n, task](Status s) {
if (!s.ok()) {
LOG(ERROR) << "Encountered another error while reporting to "
<< task << ": " << s;
}
n->Notify();
});
notifications.push_back(n);
}
for (auto& n : notifications) {
n->WaitForNotification();
}
}
// Utility for normalizing structured config key string.
// The normalized key will not have leading or trailing slashes, and all parts
// in the key path are separated by exactly one slack ('/').
// E.g., ///a//b/c// --> a/b/c
std::string NormalizeKey(const StringPiece orig_key) {
std::string norm_key = std::string(orig_key);
const char* src = norm_key.c_str();
std::string::iterator dst = norm_key.begin();
// Parse all characters
while (*src) {
// Skip leading slashes
while (*src == '/') src++;
// Copy over all non-slash characters
while (*src && *src != '/') {
*dst++ = *src++;
}
// Allow one slash at the end of current directory
if (*src) {
*dst++ = *src++;
}
}
// If ending with slash, remove the trailing slash
if (dst > norm_key.begin() && *(dst - 1) == '/') dst--;
norm_key.resize(dst - norm_key.begin());
return norm_key;
}
// Stores `value` under the normalized form of `key`.
//
// Returns AlreadyExists if the normalized key is already present; the
// existing value is left untouched. On success, every GetKeyValueAsync()
// callback waiting on that key is invoked with `value` (while kv_mu_ is held)
// and then discarded.
Status CoordinationServiceStandaloneImpl::InsertKeyValue(
    const std::string& key, const std::string& value) {
  const std::string norm_key = NormalizeKey(key);
  mutex_lock l(kv_mu_);
  const bool inserted = kv_store_.emplace(norm_key, value).second;
  if (!inserted) {
    return MakeCoordinationError(
        errors::AlreadyExists("Config key ", key, " already exists."));
  }
  auto waiters = get_cb_.find(norm_key);
  if (waiters == get_cb_.end()) return Status::OK();
  for (const StatusOrValueCallback& callback : waiters->second) {
    callback(value);
  }
  get_cb_.erase(waiters);
  return Status::OK();
}
// Blocking wrapper around GetKeyValueAsync(): waits until the callback has
// run and returns the value (or error) it received.
StatusOr<std::string> CoordinationServiceStandaloneImpl::GetKeyValue(
    const std::string& key) {
  absl::Notification value_ready;
  StatusOr<std::string> value_or_error;
  auto on_done = [&value_or_error,
                  &value_ready](const StatusOr<std::string>& status_or_value) {
    value_or_error = status_or_value;
    value_ready.Notify();
  };
  GetKeyValueAsync(key, on_done);
  value_ready.WaitForNotification();
  return value_or_error;
}
// Looks up the normalized form of `key`. If the key is present, `done` runs
// immediately with its value (while kv_mu_ is held). Otherwise `done` is
// queued and invoked by a later InsertKeyValue() of the same normalized key.
void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
    const std::string& key, StatusOrValueCallback done) {
  const std::string norm_key = NormalizeKey(key);
  mutex_lock l(kv_mu_);
  auto stored = kv_store_.find(norm_key);
  if (stored == kv_store_.end()) {
    // operator[] creates the empty waiter list on first use.
    get_cb_[norm_key].emplace_back(std::move(done));
    return;
  }
  done(stored->second);
}
// Deletes the normalized `key` and every key nested under it (every key that
// starts with "<key>/"). Always returns OK, even if nothing matched.
Status CoordinationServiceStandaloneImpl::DeleteKeyValue(
    const std::string& key) {
  const std::string norm_key = NormalizeKey(key);
  const std::string dir_prefix = strings::StrCat(norm_key, "/");
  mutex_lock l(kv_mu_);
  // In the ordered map, all keys that start with `dir_prefix` form one
  // contiguous range beginning at lower_bound(dir_prefix).
  auto range_begin = kv_store_.lower_bound(dir_prefix);
  auto range_end = range_begin;
  while (range_end != kv_store_.end() &&
         range_end->first.compare(0, dir_prefix.size(), dir_prefix) == 0) {
    ++range_end;
  }
  kv_store_.erase(range_begin, range_end);
  // Remove the entry for the key itself, if any.
  kv_store_.erase(norm_key);
  return Status::OK();
}
} // namespace
// Creates the standalone coordination service if this task is the leader
// chosen by is_multi_client_leader(). Every other task gets nullptr.
std::unique_ptr<CoordinationServiceInterface> EnableCoordinationService(
    Env* env, const ServerDef& server_def,
    std::unique_ptr<CoordinationClientCache> cache) {
  if (!is_multi_client_leader(server_def)) {
    return nullptr;
  }
  return std::make_unique<CoordinationServiceStandaloneImpl>(
      std::move(cache), env, server_def);
}
// Register standalone coordination service implementation.
// The name "standalone" is registered with EnableCoordinationService as its
// factory. NOTE(review): the macro is defined outside this file; presumably
// the service is looked up by this name — confirm in coordination_service.h.
REGISTER_COORDINATION_SERVICE("standalone", EnableCoordinationService);
} // namespace tensorflow