blob: cfaad76931e5b3eb920538fb65c2f02f614581f8 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
namespace {
// Fallback values used when the corresponding CoordinationServiceConfig field
// is unset (i.e. <= 0).

// Upper bound on how long Connect() waits for the RegisterTask RPC.
constexpr absl::Duration kDefaultClusterRegisterTimeout = absl::Minutes(60);
// Heartbeat timeout; the heartbeat thread sends one heartbeat every half of
// this duration.
constexpr absl::Duration kDefaultHeartbeatTimeout = absl::Seconds(10);
// Upper bound on how long Shutdown() waits for the ShutdownTask RPC.
constexpr absl::Duration kDefaultShutdownTimeout = absl::Seconds(10);
// Name given to the background thread started by Connect().
constexpr char kHeartbeatThread[] = "CoordinationServiceHeartbeatLoop";
// Default CoordinationServiceAgent implementation, which talks to the
// coordination service leader through a single CoordinationClient.
//
// Lifecycle (see `State`):
//   UNINITIALIZED --Initialize()--> DISCONNECTED --Connect()--> RUNNING
//   any error (SetError())         --> ERROR --Reset()--> DISCONNECTED
//   Shutdown()                     --> SHUTDOWN
class CoordinationServiceAgentImpl : public CoordinationServiceAgent {
public:
CoordinationServiceAgentImpl() = default;
// Shuts the agent down (a failure is only logged) and then cancels any
// GetKeyValueAsync() RPCs that are still outstanding.
~CoordinationServiceAgentImpl() override {
Status s = Shutdown();
if (!s.ok()) {
LOG(ERROR) << "Agent shutdown failed with status: " << s;
}
// Cancel all pending GetKeyValue() RPC calls.
cancellation_manager_.StartCancel();
}
// Initialization overloads. All of them end up in the CoordinatedTask
// overload, which moves the agent from UNINITIALIZED to DISCONNECTED.
Status Initialize(Env* env, const ServerDef& server_def,
std::unique_ptr<CoordinationClientCache> client_cache,
StatusCallback error_fn) override;
Status Initialize(Env* env, const std::string& job_name, int task_id,
const CoordinationServiceConfig& configs,
std::unique_ptr<CoordinationClient> leader_client,
StatusCallback error_fn) override;
Status Initialize(Env* env, const CoordinatedTask& task,
const CoordinationServiceConfig& configs,
std::unique_ptr<CoordinationClient> leader_client,
StatusCallback error_fn) override;
bool IsInitialized() override;
// Registers with the leader and starts the heartbeat thread.
Status Connect() override;
Status WaitForAllTasks(
const CoordinationServiceDeviceInfo& local_devices) override;
const CoordinationServiceDeviceInfo& GetClusterDeviceInfo() override;
StatusOr<CoordinatedTask> GetOwnTask() override;
StatusOr<TaskState> GetTaskStatus(const CoordinatedTask& task) override;
Status ReportError(const Status& error) override;
Status Shutdown() override;
Status Reset() override;
// Key-value store operations, forwarded to the leader.
StatusOr<std::string> GetKeyValue(const std::string& key) override;
StatusOr<std::string> GetKeyValue(const std::string& key,
absl::Duration timeout) override;
std::shared_ptr<CallOptions> GetKeyValueAsync(
const std::string& key, StatusOrValueCallback done) override;
StatusOr<std::string> TryGetKeyValue(const std::string& key) override;
StatusOr<std::vector<KeyValueEntry>> GetKeyValueDir(
const std::string& key) override;
void GetKeyValueDirAsync(const std::string& key,
StatusOrValueDirCallback done) override;
Status InsertKeyValue(const std::string& key,
const std::string& value) override;
Status DeleteKeyValue(const std::string& key) override;
Status UpdateKeyValue(const std::string& key,
const std::string& value) override;
Status StartWatchKey(const std::string& key,
ChangedKeyValuesCallback on_change) override;
Status StopWatchKey(const std::string& key) override;
// Barrier operations, forwarded to the leader.
Status WaitAtBarrier(const std::string& barrier_id, absl::Duration timeout,
const std::vector<CoordinatedTask>& tasks) override;
void WaitAtBarrierAsync(const std::string& barrier_id, absl::Duration timeout,
const std::vector<CoordinatedTask>& tasks,
StatusCallback done) override;
Status CancelBarrier(const std::string& barrier_id) override;
void CancelBarrierAsync(const std::string& barrier_id,
StatusCallback done) override;
protected:
// Moves the agent into ERROR (first error wins) and notifies `error_fn_`.
void SetError(const Status& error) override;
Status ActivateWatch(const std::string& key,
const std::map<std::string, std::string>&) override;
// Returns an error if agent is not running.
Status ValidateRunningAgent();
// Signals the heartbeat thread to exit and destroys it.
void StopHeartbeat();
private:
// Used to start the heartbeat thread.
Env* env_ = nullptr; // Not owned.
// Random ID of this agent instance, sent with RegisterTask and Heartbeat
// requests.
const uint64_t incarnation_id_ = random::New64();
CoordinatedTask task_;
CoordinationServiceConfig configs_;
// User callback invoked by SetError().
StatusCallback error_fn_;
enum class State {
UNINITIALIZED,
DISCONNECTED,
RUNNING,
ERROR,
SHUTDOWN,
};
mutable mutex state_mu_;
State state_ TF_GUARDED_BY(state_mu_) = State::UNINITIALIZED;
// First error recorded by SetError(); meaningful while in ERROR.
Status status_ TF_GUARDED_BY(state_mu_) = OkStatus();
// Note: this set grows without bounds. For now, this is okay as most users
// require < 100 barriers. If there is a use case that requires many barriers,
// consider using a monotonic sequence number to track instead.
absl::flat_hash_set<std::string> used_barrier_ids_ TF_GUARDED_BY(state_mu_);
// Leader incarnation returned by RegisterTask. Heartbeat responses are
// compared against it to detect a restarted leader.
uint64_t leader_incarnation_ = 0;
// Cluster-wide device info merged in by WaitForAllTasks().
CoordinationServiceDeviceInfo cluster_devices_;
// Shutdown signalling for the heartbeat thread: StopHeartbeat() sets
// `shutting_down_` and wakes the thread through `heartbeat_thread_cv_`.
mutex heartbeat_thread_shutdown_mu_;
condition_variable heartbeat_thread_cv_;
bool shutting_down_ TF_GUARDED_BY(heartbeat_thread_shutdown_mu_) = false;
std::unique_ptr<Thread> heartbeat_thread_;
// Must outlive coordination client which may need to access it within
// GetKeyValueAsync() callbacks.
// (Declared before `leader_client_`: members are destroyed in reverse
// declaration order, so this is destroyed after the client.)
CancellationManager cancellation_manager_;
std::unique_ptr<CoordinationClient> leader_client_;
TF_DISALLOW_COPY_AND_ASSIGN(CoordinationServiceAgentImpl);
};
// Initializes the agent from a ServerDef.
//
// The coordination config is taken from
// `server_def.default_session_config().experimental().coordination_config()`.
// If it names no service leader, the collective group leader is used. Failing
// that, "/job:<this job>/replica:0/task:0" is used. The leader client is
// obtained from `client_cache`, which must be non-null; otherwise
// InvalidArgument is returned. All other validation and state changes happen
// in the CoordinatedTask-based overload.
Status CoordinationServiceAgentImpl::Initialize(
    Env* env, const ServerDef& server_def,
    std::unique_ptr<CoordinationClientCache> client_cache,
    StatusCallback error_fn) {
  if (client_cache == nullptr) {
    return MakeCoordinationError(errors::InvalidArgument(
        "CoordinationServiceAgent must be initialized with a valid client "
        "cache."));
  }
  CoordinationServiceConfig configs =
      server_def.default_session_config().experimental().coordination_config();
  if (configs.service_leader().empty()) {
    const std::string& collective_leader = server_def.default_session_config()
                                               .experimental()
                                               .collective_group_leader();
    if (!collective_leader.empty()) {
      configs.set_service_leader(collective_leader);
      LOG(INFO) << "No coordination leader is set, using the collective leader "
                << collective_leader;
    } else {
      // Built here rather than read from a proto, so hold it by value.
      const std::string default_leader =
          strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:0");
      configs.set_service_leader(default_leader);
      LOG(INFO) << "No coordination leader is set, using the default leader "
                << default_leader;
    }
  }
  return Initialize(env, server_def.job_name(), server_def.task_index(),
                    configs,
                    client_cache->GetOwnedClient(configs.service_leader()),
                    std::move(error_fn));
}
// Convenience overload: builds the CoordinatedTask identity from `job_name`
// and `task_id`, then defers to the CoordinatedTask-based Initialize().
Status CoordinationServiceAgentImpl::Initialize(
    Env* env, const std::string& job_name, int task_id,
    const CoordinationServiceConfig& configs,
    std::unique_ptr<CoordinationClient> leader_client,
    StatusCallback error_fn) {
  CoordinatedTask own_task;
  own_task.set_task_id(task_id);
  own_task.set_job_name(job_name);
  return Initialize(env, own_task, configs, std::move(leader_client),
                    error_fn);
}
// Stores the agent's identity, config, leader client and error callback, and
// moves the agent from UNINITIALIZED to DISCONNECTED.
//
// Returns:
//   FailedPrecondition if the agent is already initialized.
//   InvalidArgument if `configs` names no service leader or `leader_client`
//     is null.
// All arguments are validated before any member is written. On error the
// agent is left exactly as it was (still UNINITIALIZED) instead of
// half-populated.
Status CoordinationServiceAgentImpl::Initialize(
    Env* env, const CoordinatedTask& task,
    const CoordinationServiceConfig& configs,
    std::unique_ptr<CoordinationClient> leader_client,
    StatusCallback error_fn) {
  mutex_lock l(state_mu_);
  if (state_ != State::UNINITIALIZED) {
    return MakeCoordinationError(errors::FailedPrecondition(
        "Coordination service agent has already been initialized."));
  }
  if (configs.service_leader().empty()) {
    return MakeCoordinationError(errors::InvalidArgument(
        "CoordinationServiceAgent must be initialized with a valid leader."));
  }
  if (leader_client == nullptr) {
    return MakeCoordinationError(errors::InvalidArgument(
        "CoordinationServiceAgent must have a valid leader client."));
  }
  env_ = env;
  task_ = task;
  configs_ = configs;
  leader_client_ = std::move(leader_client);
  error_fn_ = std::move(error_fn);
  state_ = State::DISCONNECTED;
  return OkStatus();
}
// True once Initialize() has succeeded. No code path ever returns the agent
// to UNINITIALIZED (Reset() goes back to DISCONNECTED).
bool CoordinationServiceAgentImpl::IsInitialized() {
  mutex_lock lock(state_mu_);
  const bool initialized = (state_ != State::UNINITIALIZED);
  return initialized;
}
// Tells the heartbeat thread to exit (flag + condvar wakeup) and then drops
// the thread object. Dropping it presumably blocks until the thread function
// returns (Env::StartThread contract; TODO confirm). Safe to call when no
// heartbeat thread was ever started.
void CoordinationServiceAgentImpl::StopHeartbeat() {
  {
    mutex_lock lock(heartbeat_thread_shutdown_mu_);
    shutting_down_ = true;
    heartbeat_thread_cv_.notify_all();
  }
  heartbeat_thread_ = nullptr;
}
// Registers this task with the coordination service leader and, on success,
// starts the background heartbeat thread.
//
// Requires the agent to be DISCONNECTED; otherwise returns
// FailedPrecondition. Blocks until the RegisterTask RPC completes, bounded by
// `cluster_register_timeout_in_ms` (default kDefaultClusterRegisterTimeout).
// If registration fails, the agent is moved to ERROR and that error is
// returned. On success the agent is RUNNING.
//
// The heartbeat thread sends a Heartbeat RPC every half heartbeat timeout.
// It moves the agent to ERROR on a failed heartbeat, or when the leader's
// incarnation differs from the one seen at registration (a restarted leader).
// It exits once StopHeartbeat() sets `shutting_down_`.
Status CoordinationServiceAgentImpl::Connect() {
  {
    mutex_lock l(state_mu_);
    if (state_ != State::DISCONNECTED) {
      return MakeCoordinationError(errors::FailedPrecondition(
          "Coordination service agent is not in DISCONNECTED state."));
    }
  }
  RegisterTaskRequest request;
  *request.mutable_source_task() = task_;
  request.set_incarnation(incarnation_id_);
  RegisterTaskResponse response;
  absl::Notification n;

  // Block until the remote service is up and the task is registered.
  CallOptions call_opts;
  const int64_t register_timeout =
      configs_.cluster_register_timeout_in_ms() > 0
          ? configs_.cluster_register_timeout_in_ms()
          : absl::ToInt64Milliseconds(kDefaultClusterRegisterTimeout);
  call_opts.SetTimeout(register_timeout);
  leader_client_->RegisterTaskAsync(
      &call_opts, &request, &response, [&](Status s) {
        if (!s.ok()) {
          SetError(s);
        } else {
          // Remembered so the heartbeat thread can detect a leader restart.
          leader_incarnation_ = response.leader_incarnation();
          {
            mutex_lock l(state_mu_);
            state_ = State::RUNNING;
          }
        }
        n.Notify();
      });
  n.WaitForNotification();
  {
    mutex_lock l(state_mu_);
    if (state_ == State::ERROR) {
      return status_;
    }
  }

  LOG(INFO) << "Coordination agent has successfully connected.";
  heartbeat_thread_.reset(
      env_->StartThread(ThreadOptions(), kHeartbeatThread, [this]() -> void {
        HeartbeatRequest request;
        *request.mutable_source_task() = task_;
        request.set_incarnation(incarnation_id_);
        HeartbeatResponse response;
        const int64_t heartbeat_interval_ms =
            configs_.heartbeat_timeout_in_ms() > 0
                ? configs_.heartbeat_timeout_in_ms() / 2
                : absl::ToInt64Milliseconds(kDefaultHeartbeatTimeout) / 2;
        CallOptions call_opts;
        call_opts.SetTimeout(heartbeat_interval_ms);

        while (true) {
          {
            mutex_lock l(heartbeat_thread_shutdown_mu_);
            // Check the flag before sleeping. StopHeartbeat() may have set it
            // (and notified) while the previous heartbeat RPC was in flight.
            // That notification had no waiter and is lost, so waiting
            // unconditionally would delay shutdown by a whole interval.
            if (!shutting_down_) {
              heartbeat_thread_cv_.wait_for(
                  l, std::chrono::milliseconds(heartbeat_interval_ms));
            }
            if (shutting_down_) {
              return;
            }
          }
          Status status;
          absl::Notification n;
          // Heartbeat RPC implementation automatically retries to tolerate
          // transient network failures.
          leader_client_->HeartbeatAsync(&call_opts, &request, &response,
                                         [&](Status s) {
                                           status = s;
                                           n.Notify();
                                         });
          n.WaitForNotification();
          if (!status.ok()) {
            SetError(status);
          } else if (response.leader_incarnation() != leader_incarnation_) {
            SetError(MakeCoordinationError(
                errors::Aborted("Leader incarnation ID mismatch: the "
                                "coordination leader has restarted.")));
          }
        }
      }));
  return OkStatus();
}
// Sends this task's `local_devices` to the leader in a WaitForAllTasks RPC and
// blocks until the leader answers. On success, the cluster device info in the
// response is merged into `cluster_devices_`. On RPC failure, the agent is
// moved to ERROR and the failure is returned. Requires a RUNNING agent.
Status CoordinationServiceAgentImpl::WaitForAllTasks(
    const CoordinationServiceDeviceInfo& local_devices) {
  if (Status running = ValidateRunningAgent(); !running.ok()) {
    return running;
  }
  WaitForAllTasksRequest request;
  *request.mutable_source_task() = task_;
  *request.mutable_local_device_info() = local_devices;
  WaitForAllTasksResponse response;

  Status rpc_status;
  absl::Notification rpc_done;
  leader_client_->WaitForAllTasksAsync(&request, &response,
                                       [&rpc_status, &rpc_done](Status s) {
                                         rpc_status = s;
                                         rpc_done.Notify();
                                       });
  rpc_done.WaitForNotification();

  if (rpc_status.ok()) {
    cluster_devices_.MergeFrom(response.cluster_device_info());
    return OkStatus();
  }
  SetError(rpc_status);
  return rpc_status;
}
// Returns a reference to the agent-owned cluster device info. It stays empty
// until WaitForAllTasks() succeeds and merges the leader's response into it.
const CoordinationServiceDeviceInfo&
CoordinationServiceAgentImpl::GetClusterDeviceInfo() {
  const CoordinationServiceDeviceInfo& devices = cluster_devices_;
  return devices;
}
// Returns the task identity passed to Initialize(), or FailedPrecondition if
// the agent has not been initialized yet.
StatusOr<CoordinatedTask> CoordinationServiceAgentImpl::GetOwnTask() {
  if (IsInitialized()) {
    return task_;
  }
  return MakeCoordinationError(
      errors::FailedPrecondition("Agent has not been initialized; we do not "
                                 "know the associated task yet."));
}
// Not supported by this implementation: always returns Unimplemented and
// ignores `task`.
StatusOr<CoordinationServiceAgentImpl::TaskState>
CoordinationServiceAgentImpl::GetTaskStatus(const CoordinatedTask& task) {
  const Status unimplemented = errors::Unimplemented(
      "CoordinationServiceAgentImpl::GetTaskStatus is not implemented.");
  return MakeCoordinationError(unimplemented);
}
// Puts this agent into ERROR with `error` (which also invokes the user error
// callback), then reports the error to the coordination service.
//
// Returns FailedPrecondition if the agent is uninitialized or already in
// ERROR. Otherwise returns OK. Reporting to the service is best effort: a
// failed ReportErrorToService RPC is only logged.
Status CoordinationServiceAgentImpl::ReportError(const Status& error) {
  {
    mutex_lock lock(state_mu_);
    switch (state_) {
      case State::UNINITIALIZED:
        return MakeCoordinationError(errors::FailedPrecondition(
            "Coordination service agent must be initialized first before "
            "reporting error."));
      case State::ERROR:
        return MakeCoordinationError(errors::FailedPrecondition(
            "Coordination service agent is already in error state."));
      default:
        break;
    }
  }

  SetError(MakeCoordinationError(error, task_,
                                 /*is_reported_error=*/true));
  LOG(INFO) << "Reporting error to coordination service: " << error;

  ReportErrorToServiceRequest request;
  request.set_error_code(error.code());
  request.set_error_message(error.error_message());
  *request.mutable_error_origin() = task_;
  ReportErrorToServiceResponse response;

  absl::Notification reported;
  leader_client_->ReportErrorToServiceAsync(
      &request, &response, [&reported](Status s) {
        if (!s.ok()) {
          LOG(ERROR) << "Encountered another error when reporting error to "
                        "coordination service: "
                     << s;
        }
        reported.Notify();
      });
  reported.WaitForNotification();
  return OkStatus();
}
// Shuts the agent down. Callable in any state; the destructor also calls it.
//
// If the agent is RUNNING and `agent_destruction_without_shutdown` is not set
// in the config, a ShutdownTask RPC is sent to the leader first. The wait is
// bounded by `shutdown_barrier_timeout_in_ms` (default
// kDefaultShutdownTimeout). An RPC failure is logged and becomes the return
// value, but teardown continues anyway. Then the heartbeat thread is stopped
// and the agent moves to SHUTDOWN. If the agent was in ERROR, the result is a
// FailedPrecondition that embeds the recorded error.
Status CoordinationServiceAgentImpl::Shutdown() {
  bool was_running = false;
  {
    mutex_lock lock(state_mu_);
    was_running = (state_ == State::RUNNING);
  }

  Status result = OkStatus();
  if (was_running && !configs_.agent_destruction_without_shutdown()) {
    const int64_t timeout_ms =
        configs_.shutdown_barrier_timeout_in_ms() > 0
            ? configs_.shutdown_barrier_timeout_in_ms()
            : absl::ToInt64Milliseconds(kDefaultShutdownTimeout);
    CallOptions call_opts;
    call_opts.SetTimeout(timeout_ms);

    ShutdownTaskRequest request;
    *request.mutable_source_task() = task_;
    ShutdownTaskResponse response;
    absl::Notification rpc_done;
    leader_client_->ShutdownTaskAsync(&call_opts, &request, &response,
                                      [&result, &rpc_done](Status s) {
                                        result = s;
                                        rpc_done.Notify();
                                      });
    rpc_done.WaitForNotification();

    if (!result.ok()) {
      LOG(ERROR)
          << "Failed to disconnect from coordination service with status: "
          << result << ". Proceeding with agent shutdown anyway.";
    } else {
      LOG(INFO) << "Coordination agent has successfully shut down.";
    }
  }

  StopHeartbeat();
  {
    mutex_lock lock(state_mu_);
    if (state_ == State::ERROR) {
      result = MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
          "Shutdown() was called while agent is in error state, implying that "
          "distributed execution failed. Note: agent will still shutdown "
          "anyway. Agent status: ",
          status_.ToString())));
    }
    state_ = State::SHUTDOWN;
  }
  return result;
}
// Recovers an agent that is in ERROR. It asks the leader to reset this task
// (ResetTask RPC). On success it stops the heartbeat thread and returns the
// agent to DISCONNECTED, so Connect() can be called again.
//
// Returns FailedPrecondition if the agent is not in ERROR, or the RPC error
// if the leader rejects the reset (in which case nothing local changes). The
// recorded error in `status_` is not cleared.
Status CoordinationServiceAgentImpl::Reset() {
  {
    mutex_lock lock(state_mu_);
    if (state_ != State::ERROR) {
      return MakeCoordinationError(errors::FailedPrecondition(
          "Reset() failed: coordination service agent is not in ERROR state."));
    }
  }

  ResetTaskRequest request;
  *request.mutable_source_task() = task_;
  ResetTaskResponse response;
  Status rpc_status;
  absl::Notification rpc_done;
  leader_client_->ResetTaskAsync(&request, &response,
                                 [&rpc_status, &rpc_done](Status s) {
                                   rpc_status = s;
                                   rpc_done.Notify();
                                 });
  rpc_done.WaitForNotification();
  if (!rpc_status.ok()) return rpc_status;

  // Tear down the old heartbeat thread, then re-arm the flag so a later
  // Connect() can start a fresh one.
  StopHeartbeat();
  {
    mutex_lock lock(state_mu_);
    state_ = State::DISCONNECTED;
  }
  {
    mutex_lock lock(heartbeat_thread_shutdown_mu_);
    shutting_down_ = false;
  }
  LOG(INFO) << "Coordination agent has been reset.";
  return OkStatus();
}
// Fetches the value for `key` with no deadline: equivalent to the timeout
// overload called with absl::InfiniteDuration().
StatusOr<std::string> CoordinationServiceAgentImpl::GetKeyValue(
    const std::string& key) {
  const absl::Duration no_timeout = absl::InfiniteDuration();
  return GetKeyValue(key, no_timeout);
}
// Fetches the value for `key` through GetKeyValueAsync(), waiting at most
// `timeout`.
//
// Returns the value or the RPC error if the RPC finishes in time. Otherwise
// the outstanding RPC is cancelled through its CallOptions, so it does not
// linger until the agent is destroyed, and DeadlineExceeded is returned.
StatusOr<std::string> CoordinationServiceAgentImpl::GetKeyValue(
    const std::string& key, absl::Duration timeout) {
  // Shared with the callback: it may still run (e.g. with a cancellation
  // error) after this function has returned on timeout.
  auto n = std::make_shared<absl::Notification>();
  auto result = std::make_shared<StatusOr<std::string>>();
  std::shared_ptr<CallOptions> call_opts = GetKeyValueAsync(
      key, [n, result](const StatusOr<std::string>& status_or_value) {
        *result = status_or_value;
        n->Notify();
      });
  const bool call_completed_before_timeout =
      n->WaitForNotificationWithTimeout(timeout);
  if (!call_completed_before_timeout) {
    call_opts->StartCancel();
    return MakeCoordinationError(errors::DeadlineExceeded(absl::Substitute(
        "GetKeyValue() timed out with key: $0 and duration: $1", key,
        absl::FormatDuration(timeout))));
  }
  return *result;
}
// Starts an asynchronous GetKeyValue RPC for `key`. When the RPC completes,
// `done` receives either the RPC error or the value from the response.
//
// The request, response and call options are heap-allocated and captured by
// the RPC callback, so they stay alive until the callback has run, whenever
// that is. The returned CallOptions is the same object whose StartCancel()
// the registered cancellation callback invokes.
//
// Each call registers with `cancellation_manager_`, which the destructor
// cancels. If the manager is already cancelled, `done` is invoked
// immediately with Cancelled and no RPC is issued.
std::shared_ptr<CallOptions> CoordinationServiceAgentImpl::GetKeyValueAsync(
const std::string& key, StatusOrValueCallback done) {
auto request = std::make_shared<GetKeyValueRequest>();
request->set_key(key);
auto response = std::make_shared<GetKeyValueResponse>();
auto call_opts = std::make_shared<CallOptions>();
// Arrange for the RPC to be cancelled if the cancellation manager fires
// (the destructor calls StartCancel()).
const CancellationToken token =
cancellation_manager_.get_cancellation_token();
const bool already_cancelled = !cancellation_manager_.RegisterCallback(
token, [call_opts]() { call_opts->StartCancel(); });
if (already_cancelled) {
done(errors::Cancelled("GetKeyValueAsync() was cancelled."));
return call_opts;
}
// `cm` is a reference to a member. Per the class's member comments,
// `cancellation_manager_` must outlive `leader_client_`, whose callbacks may
// still use it.
leader_client_->GetKeyValueAsync(
call_opts.get(), request.get(), response.get(),
[call_opts, request, response, done = std::move(done),
&cm = cancellation_manager_, token](const Status& s) {
// RPC call has completed (no longer needs to be cancelled if agent is
// destroyed).
cm.TryDeregisterCallback(token);
// Retrieve server response.
if (!s.ok()) {
done(s);
} else {
done(response->kv().value());
}
});
return call_opts;
}
// Issues a TryGetKeyValue RPC for `key`, blocks until it completes, and
// returns either the value from the response or the RPC error.
StatusOr<std::string> CoordinationServiceAgentImpl::TryGetKeyValue(
    const std::string& key) {
  TryGetKeyValueRequest request;
  request.set_key(key);
  TryGetKeyValueResponse response;

  StatusOr<std::string> value;
  absl::Notification rpc_done;
  leader_client_->TryGetKeyValueAsync(
      &request, &response, [&value, &rpc_done, &response](const Status& s) {
        if (!s.ok()) {
          value = s;
        } else {
          value = response.kv().value();
        }
        rpc_done.Notify();
      });
  rpc_done.WaitForNotification();
  return value;
}
// Blocking wrapper around GetKeyValueDirAsync(): waits for the callback and
// returns whatever it delivered (the entries or the RPC error).
StatusOr<std::vector<KeyValueEntry>>
CoordinationServiceAgentImpl::GetKeyValueDir(const std::string& key) {
  StatusOr<std::vector<KeyValueEntry>> entries;
  absl::Notification fetched;
  GetKeyValueDirAsync(
      key, [&entries, &fetched](StatusOr<std::vector<KeyValueEntry>> r) {
        entries = std::move(r);
        fetched.Notify();
      });
  fetched.WaitForNotification();
  return entries;
}
// Starts an asynchronous GetKeyValueDir RPC for directory `key`. When it
// completes, `done` receives either the RPC error or the entries returned in
// the response. Request and response are heap-allocated and captured by the
// callback, so they outlive this call.
void CoordinationServiceAgentImpl::GetKeyValueDirAsync(
    const std::string& key, StatusOrValueDirCallback done) {
  auto request = std::make_shared<GetKeyValueDirRequest>();
  request->set_directory_key(key);
  auto response = std::make_shared<GetKeyValueDirResponse>();
  leader_client_->GetKeyValueDirAsync(
      request.get(), response.get(),
      [request, response, done = std::move(done)](const Status& s) {
        if (!s.ok()) {
          done(s);
          return;
        }
        // Move the entries out through the mutable accessor. Wrapping the
        // const `kv()` iterators in move iterators would silently copy each
        // entry. The response is not used again after this point.
        std::vector<KeyValueEntry> kv_in_directory(
            std::make_move_iterator(response->mutable_kv()->begin()),
            std::make_move_iterator(response->mutable_kv()->end()));
        done(std::move(kv_in_directory));
      });
}
// Stores `key` -> `value` through a blocking InsertKeyValue RPC and returns
// the RPC status.
Status CoordinationServiceAgentImpl::InsertKeyValue(const std::string& key,
                                                    const std::string& value) {
  InsertKeyValueRequest request;
  auto* entry = request.mutable_kv();
  entry->set_key(key.data(), key.size());
  entry->set_value(value.data(), value.size());
  InsertKeyValueResponse response;

  Status rpc_status;
  absl::Notification rpc_done;
  leader_client_->InsertKeyValueAsync(&request, &response,
                                      [&rpc_status, &rpc_done](Status s) {
                                        rpc_status = s;
                                        rpc_done.Notify();
                                      });
  rpc_done.WaitForNotification();
  return rpc_status;
}
// Deletes `key` through a blocking DeleteKeyValue RPC and returns the RPC
// status. The request sets `is_directory`, so entries stored under `key` as a
// directory are presumably removed too (per the field name; TODO confirm
// against the service).
Status CoordinationServiceAgentImpl::DeleteKeyValue(const std::string& key) {
  DeleteKeyValueRequest request;
  request.set_key(key);
  request.set_is_directory(true);
  DeleteKeyValueResponse response;
  Status status;
  absl::Notification n;
  leader_client_->DeleteKeyValueAsync(&request, &response, [&](Status s) {
    status = s;
    n.Notify();
  });
  n.WaitForNotification();
  // Propagate the RPC result instead of unconditionally reporting success,
  // so callers can see a failed deletion.
  return status;
}
// Not supported by this implementation: always returns Unimplemented.
Status CoordinationServiceAgentImpl::UpdateKeyValue(const std::string& key,
                                                    const std::string& value) {
  const Status unimplemented = errors::Unimplemented(
      "CoordinationServiceAgent::UpdateKeyValue is not implemented.");
  return MakeCoordinationError(unimplemented);
}
// Not supported by this implementation: always returns Unimplemented and
// never invokes `on_change`.
Status CoordinationServiceAgentImpl::StartWatchKey(
    const std::string& key,
    CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change) {
  const Status unimplemented = errors::Unimplemented(
      "CoordinationServiceAgent::StartWatchKey is not implemented.");
  return MakeCoordinationError(unimplemented);
}
// Not supported by this implementation: always returns Unimplemented.
Status CoordinationServiceAgentImpl::StopWatchKey(const std::string& key) {
  const Status unimplemented = errors::Unimplemented(
      "CoordinationServiceAgent::StopWatchKey is not implemented.");
  return MakeCoordinationError(unimplemented);
}
void CoordinationServiceAgentImpl::SetError(const Status& error) {
assert(!error.ok());
mutex_lock l(state_mu_);
if (state_ == State::ERROR) return;
LOG(ERROR) << "Coordination agent is in ERROR: " << error;
state_ = State::ERROR;
status_ = error;
error_fn_(error);
}
// Not supported by this implementation: always returns Unimplemented.
Status CoordinationServiceAgentImpl::ActivateWatch(
    const std::string& key, const std::map<std::string, std::string>& kvs) {
  const Status unimplemented = errors::Unimplemented(
      "CoordinationServiceAgent::ActivateWatch is not implemented.");
  return MakeCoordinationError(unimplemented);
}
// Blocking wrapper around WaitAtBarrierAsync(): returns the status delivered
// to its completion callback.
Status CoordinationServiceAgentImpl::WaitAtBarrier(
    const std::string& barrier_id, absl::Duration timeout,
    const std::vector<CoordinatedTask>& tasks) {
  absl::Notification barrier_done;
  Status barrier_status;
  WaitAtBarrierAsync(barrier_id, timeout, tasks,
                     [&barrier_status, &barrier_done](Status s) {
                       barrier_status = s;
                       barrier_done.Notify();
                     });
  barrier_done.WaitForNotification();
  return barrier_status;
}
void CoordinationServiceAgentImpl::WaitAtBarrierAsync(
const std::string& barrier_id, absl::Duration timeout,
const std::vector<CoordinatedTask>& tasks, StatusCallback done) {
Status agent_running_status = ValidateRunningAgent();
if (!agent_running_status.ok()) {
done(agent_running_status);
return;
}
{
mutex_lock l(state_mu_);
auto [it, inserted] = used_barrier_ids_.insert(barrier_id);
if (!inserted) {
done(errors::FailedPrecondition(
"WaitAtBarrier() should not be called with the same id more than "
"once. Barrier id: ",
barrier_id));
return;
}
}
auto request = std::make_shared<BarrierRequest>();
auto response = std::make_shared<BarrierResponse>();
request->set_barrier_id(barrier_id);
request->set_barrier_timeout_in_ms(timeout / absl::Milliseconds(1));
*request->mutable_source_task() = task_;
*request->mutable_tasks() = {tasks.begin(), tasks.end()};
leader_client_->BarrierAsync(request.get(), response.get(),
[request, response, done = std::move(done)](
const Status& s) { done(s); });
}
// Blocking wrapper around CancelBarrierAsync(): returns the status delivered
// to its completion callback.
Status CoordinationServiceAgentImpl::CancelBarrier(
    const std::string& barrier_id) {
  absl::Notification cancel_done;
  Status cancel_status;
  CancelBarrierAsync(barrier_id,
                     [&cancel_status, &cancel_done](const Status& s) {
                       cancel_status = s;
                       cancel_done.Notify();
                     });
  cancel_done.WaitForNotification();
  return cancel_status;
}
void CoordinationServiceAgentImpl::CancelBarrierAsync(
const std::string& barrier_id, StatusCallback done) {
Status agent_running_status = ValidateRunningAgent();
if (!agent_running_status.ok()) {
done(agent_running_status);
return;
}
auto request = std::make_shared<CancelBarrierRequest>();
auto response = std::make_shared<CancelBarrierResponse>();
request->set_barrier_id(barrier_id);
*request->mutable_source_task() = task_;
leader_client_->CancelBarrierAsync(
request.get(), response.get(),
[request, response, done = std::move(done)](const Status& s) {
done(s);
});
}
// Returns OK if the agent is RUNNING. Otherwise returns a FailedPrecondition
// coordination error that names the agent's current state.
Status CoordinationServiceAgentImpl::ValidateRunningAgent() {
  mutex_lock lock(state_mu_);
  if (state_ == State::RUNNING) {
    return OkStatus();
  }
  const char* current = nullptr;
  switch (state_) {
    case State::UNINITIALIZED:
      current = "UNINITIALIZED";
      break;
    case State::DISCONNECTED:
      current = "DISCONNECTED";
      break;
    case State::ERROR:
      current = "in ERROR";
      break;
    case State::SHUTDOWN:
      current = "in SHUTDOWN";
      break;
    default:
      return MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
          "Agent is not in RUNNING state. Current state: ", state_)));
  }
  return MakeCoordinationError(errors::FailedPrecondition(
      "Agent must be in RUNNING state. It is currently ", current, "."));
}
} // namespace
// Factory for the default CoordinationServiceAgent implementation. The new
// agent starts UNINITIALIZED and must be initialized before use.
std::unique_ptr<CoordinationServiceAgent> CreateCoordinationServiceAgent() {
  auto agent = std::make_unique<CoordinationServiceAgentImpl>();
  return agent;
}
} // namespace tensorflow