/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#endif // !IS_MOBILE_PLATFORM
namespace tensorflow {
#if !defined(IS_MOBILE_PLATFORM)
namespace {
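// Returns true if the local devices owned by `context` can be reused by a new
// server built from `server_def`: the job name must match the host CPU's
// parsed job name, and the default session config must serialize to exactly
// the same bytes as the context's session config.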
bool AreLocalDevicesCompatible(const EagerContext* context,
const ServerDef& server_def) {
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
return false;
}
return server_def.default_session_config().SerializeAsString() ==
context->session_options().config.SerializeAsString();
}
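// Asynchronously queries each worker in `added_remote_workers` for its
// devices via NewRemoteDevices, collects the resulting Device objects under a
// mutex, and, once all callbacks have fired, registers them with
// `remote_device_mgr`. Returns the first per-worker error, if any.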
Status AddRemoteDevicesToMgr(const std::vector<string>& added_remote_workers,
WorkerCacheInterface* worker_cache,
DynamicDeviceMgr* remote_device_mgr) {
std::vector<std::unique_ptr<Device>> remote_devices;
mutex remote_devices_mu;
int num_added_workers = added_remote_workers.size();
BlockingCounter counter(num_added_workers);
std::vector<Status> statuses(num_added_workers);
for (int i = 0; i < num_added_workers; i++) {
NewRemoteDevices(
Env::Default(), worker_cache, added_remote_workers[i],
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
const Status& s, std::vector<Device*>* devices) {
statuses[i] = s;
if (s.ok()) {
mutex_lock l(remote_devices_mu);
for (Device* d : *devices) {
remote_devices.emplace_back(d);
}
}
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_added_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
return Status::OK();
}
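// Creates a fresh DynamicDeviceMgr populated with the devices of all
// `remote_workers` and returns it through `device_mgr`.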
Status GetAllRemoteDevices(const std::vector<string>& remote_workers,
WorkerCacheInterface* worker_cache,
std::unique_ptr<DynamicDeviceMgr>* device_mgr) {
auto remote_device_mgr = std::make_unique<DynamicDeviceMgr>();
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
remote_device_mgr.get()));
*device_mgr = std::move(remote_device_mgr);
return Status::OK();
}
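// Removes from `remote_device_mgr` every device whose name lies in the same
// address space (job/replica/task) as one of `removed_remote_workers`.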
Status RemoveRemoteDevicesFromMgr(
const std::vector<string>& removed_remote_workers,
DynamicDeviceMgr* remote_device_mgr) {
const std::vector<Device*> remote_devices =
(remote_device_mgr->ListDevices());
std::vector<Device*> devices_to_remove;
for (Device* d : remote_devices) {
for (const string& remote_worker : removed_remote_workers) {
if (DeviceNameUtils::IsSameAddressSpace(remote_worker, d->name())) {
devices_to_remove.emplace_back(d);
break;
}
}
}
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
return Status::OK();
}
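// Lists all workers known to `server`'s worker cache, excluding
// `local_worker` itself.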
Status ListRemoteWorkers(ServerInterface* server, const string& local_worker,
std::vector<string>* remote_workers) {
server->master_env()->worker_cache->ListWorkers(remote_workers);
remote_workers->erase(
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
remote_workers->end());
return Status::OK();
}
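// Splits two sorted worker lists into added, removed, and existing workers in
// a single pass. For example, with
//   current_list = {"/job:worker/task:0", "/job:worker/task:1"}
//   new_list     = {"/job:worker/task:1", "/job:worker/task:2"}
// the outputs are added = {"/job:worker/task:2"}, removed =
// {"/job:worker/task:0"}, and existing = {"/job:worker/task:1"}.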
void DifferentiateWorkerLists(const std::vector<string>* current_list,
const std::vector<string>* new_list,
std::vector<string>* added,
std::vector<string>* removed,
std::vector<string>* existing) {
  // Compute the set difference and the set intersection of the two lists in a
  // single traversal. As with std::set_difference, the input lists
  // (`current_list` and `new_list`) must be sorted before calling this
  // function.
added->resize(new_list->size());
removed->resize(current_list->size());
existing->resize(current_list->size());
std::vector<string>::const_iterator curr_it = current_list->begin();
std::vector<string>::const_iterator new_it = new_list->begin();
std::vector<string>::iterator added_it = added->begin();
std::vector<string>::iterator removed_it = removed->begin();
std::vector<string>::iterator existing_it = existing->begin();
while (curr_it != current_list->end() && new_it != new_list->end()) {
if (*curr_it < *new_it) {
*removed_it++ = *curr_it++;
} else if (*curr_it > *new_it) {
*added_it++ = *new_it++;
} else {
*existing_it++ = *curr_it++;
new_it++;
}
}
removed_it = std::copy(curr_it, current_list->end(), removed_it);
added_it = std::copy(new_it, new_list->end(), added_it);
added->resize(added_it - added->begin());
removed->resize(removed_it - removed->begin());
existing->resize(existing_it - existing->begin());
}
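// Probes each of `existing_workers` with a KeepAlive RPC carrying
// `context_id`. A worker is reported in `replaced_workers` when the RPC fails
// (the context id is unknown on the remote side, e.g. after a process
// restart) or when it answers with a context view id different from the local
// `context_view_id`.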
Status GetReplacedFromExistingWorkers(
const std::vector<string>* existing_workers, uint64 context_id,
uint64 context_view_id, const ServerDef& server_def,
eager::EagerClientCache* client_cache,
std::vector<string>* replaced_workers) {
BlockingCounter counter(existing_workers->size());
std::vector<Status> statuses(existing_workers->size());
eager::KeepAliveRequest request;
request.set_context_id(context_id);
std::vector<eager::KeepAliveResponse> responses(existing_workers->size());
for (int i = 0; i < existing_workers->size(); i++) {
core::RefCountPtr<eager::EagerClient> eager_client;
statuses[i] =
client_cache->GetClient(existing_workers->at(i), &eager_client);
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
eager_client->KeepAliveAsync(&request, &responses[i],
[i, &statuses, &counter](const Status& s) {
statuses[i] = s;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < existing_workers->size(); i++) {
    // If the RPC fails (indicating that the requested ID doesn't exist on the
    // remote worker), or the returned view ID is not equal to the local one
    // (indicating that the remote worker has a stale view of the cluster),
    // treat the worker as replaced.
if (!statuses[i].ok() ||
responses[i].context_view_id() != context_view_id) {
replaced_workers->emplace_back(existing_workers->at(i));
}
}
return Status::OK();
}
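// Asynchronously sends a CreateContextRequest to every worker in
// `remote_workers`, customizing `server_def` (job name, task index) and the
// cluster device list (filtered through the context's device filters) per
// worker, then aggregates all per-worker statuses into a single summary
// status.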
Status CreateRemoteContexts(EagerContext* context,
const std::vector<string>& remote_workers,
uint64 context_id, uint64 context_view_id,
int keep_alive_secs, const ServerDef& server_def,
eager::EagerClientCache* remote_eager_workers,
bool async,
const eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
BlockingCounter counter(num_remote_workers);
std::vector<Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(remote_worker, &parsed_name)) {
statuses[i] = errors::InvalidArgument("Unable to parse ", remote_worker,
" as a device name");
counter.DecrementCount();
continue;
}
core::RefCountPtr<eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = errors::Internal(
          "Cannot find a client for the given target: ", remote_worker);
    }
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
eager::CreateContextRequest request;
eager::CreateContextResponse* response = new eager::CreateContextResponse();
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
    for (int di = 0; di < filtered_device_mask.size(); di++) {
      if (filtered_device_mask[di]) {
        const auto& da = base_request.cluster_device_attributes(di);
        *request.add_cluster_device_attributes() = da;
      }
    }
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
// TODO(b/134094971): deprecate lazy_copy_remote_function_inputs when server
// doesn't try to get the value of lazy_copy_remote_function_inputs.
request.set_lazy_copy_remote_function_inputs(true);
eager_client->CreateContextAsync(
&request, response,
[i, &statuses, &counter, response](const Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
StatusGroup sg;
for (int i = 0; i < num_remote_workers; i++) {
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
sg.Update(statuses[i]);
}
}
return sg.as_summary_status();
}
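// Sends UpdateContextRequests to the `remote_workers` that stay in the
// cluster. A worker receives a full request (new server_def plus its filtered
// device list) only when one of the devices matching its filters belongs to
// an added or removed worker; otherwise a lightweight request merely bumps
// the remote context view id.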
Status UpdateRemoteContexts(EagerContext* context,
const std::vector<string>& remote_workers,
const std::vector<string>& added_workers,
const std::vector<string>& removed_workers,
uint64 context_id, uint64 context_view_id,
const ServerDef& server_def,
eager::EagerClientCache* remote_eager_workers,
const eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
BlockingCounter counter(num_remote_workers);
std::vector<Status> statuses(num_remote_workers);
int cluster_device_count = base_request.cluster_device_attributes_size();
std::unordered_set<string> added_or_removed(added_workers.begin(),
added_workers.end());
std::copy(removed_workers.begin(), removed_workers.end(),
std::inserter(added_or_removed, added_or_removed.end()));
  // Whether each device belongs to an updated (added or removed) worker.
std::vector<bool> device_added_or_removed(cluster_device_count);
  for (int i = 0; i < cluster_device_count; i++) {
const auto& da = base_request.cluster_device_attributes().at(i);
DeviceNameUtils::ParsedName pn;
DeviceNameUtils::ParseFullName(da.name(), &pn);
string task_name;
DeviceNameUtils::GetTaskName(pn, &task_name);
if (added_or_removed.find(task_name) != added_or_removed.end()) {
device_added_or_removed[i] = true;
}
}
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(remote_worker, &parsed_name)) {
statuses[i] = errors::InvalidArgument("Unable to parse ", remote_worker,
" as a device name");
counter.DecrementCount();
continue;
}
core::RefCountPtr<eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = errors::Internal(
          "Cannot find a client for the given target: ", remote_worker);
    }
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
std::vector<bool> filtered_device_mask;
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
// If any of the devices that match the device filters are in the set of
// added or removed workers, we must send a complete UpdateContextRequest.
// Otherwise, only send a simple request to increment context view ID.
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
std::transform(device_added_or_removed.begin(),
device_added_or_removed.end(), filtered_device_mask.begin(),
added_or_removed_filtered_devices.begin(),
std::logical_and<bool>());
const bool full_update_request =
std::accumulate(added_or_removed_filtered_devices.begin(),
added_or_removed_filtered_devices.end(), false,
std::logical_or<bool>());
eager::UpdateContextRequest request;
auto* response = new eager::UpdateContextResponse();
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
if (full_update_request) {
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
      for (int di = 0; di < cluster_device_count; di++) {
        if (filtered_device_mask[di]) {
          const auto& da = base_request.cluster_device_attributes(di);
          *request.add_cluster_device_attributes() = da;
        }
      }
}
eager_client->UpdateContextAsync(
&request, response,
[i, &statuses, &counter, response](const Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return Status::OK();
}
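// Builds or refreshes the distributed runtime behind `context`. When
// `reset_context` is true, a new server, context id, and remote device
// manager are created and remote contexts are created on all remote workers.
// Otherwise the existing server is updated in place: the old and new worker
// lists are diffed, replaced workers are detected, remote devices are added
// and removed, contexts are created on added workers and updated on existing
// ones, and finally the master context and worker session are updated.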
Status UpdateContextWithServerDef(EagerContext* context,
const ServerDef& server_def,
bool reset_context, int keep_alive_secs) {
  // We don't use the TF_RETURN_IF_ERROR macro directly here, since that would
  // destroy the server object (which currently CHECK-fails) and the error
  // would be lost. Instead, we log the error and then return, so the user can
  // still see the error message.
#define LOG_AND_RETURN_IF_ERROR(...)                  \
  do {                                                \
    const tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {            \
      LOG(ERROR) << _status.error_message();          \
      return _status;                                 \
    }                                                 \
  } while (0)
string worker_name =
strings::StrCat("/job:", server_def.job_name(),
"/replica:0/task:", server_def.task_index());
// List of current remote workers before updating server_def. Unused if
// resetting the server_def.
std::vector<string> curr_remote_workers;
// List of updated remote workers.
std::vector<string> remote_workers;
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<ServerInterface> new_server;
ServerInterface* server;
if (reset_context) {
DeviceMgr* device_mgr = AreLocalDevicesCompatible(context, server_def)
? context->local_device_mgr()
: nullptr;
LOG_AND_RETURN_IF_ERROR(
NewServerWithOptions(server_def, {device_mgr}, &new_server));
server = new_server.get();
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
    // No need to check the cast here, since `ListRemoteWorkers` already
    // checks whether the server is a gRPC server.
server = context->GetServer();
LOG_AND_RETURN_IF_ERROR(server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(server, worker_name, &remote_workers));
}
uint64 context_id = context->GetContextId();
uint64 context_view_id = context->GetContextViewId();
if (reset_context) {
context_id = EagerContext::NewContextId();
context_view_id = 0;
    // Make the master eager context accessible to the local eager service,
    // which might receive send-tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(
server->AddMasterEagerContextToEagerService(context_id, context));
}
std::unique_ptr<eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(
server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
  // For cluster update, use a status group to aggregate statuses from
  //   * adding and removing remote devices
  //   * creating remote contexts on newly added workers
  //   * updating remote contexts on existing workers
  //   * updating the master context
  // Note that we should not return immediately on errors in the middle of
  // these updates, to prevent the cluster from ending up with inconsistent
  // context views.
  //
  // Unused if `reset_context` is true.
StatusGroup sg;
// When updating an existing context, populate the following lists with:
// * added_workers: set(remote_workers) - set(curr_remote_workers)
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
// * replaced_workers: workers with the same task names and potentially the
// same `hostname:port`s, but replaced by different processes
std::vector<string> added_workers;
std::vector<string> removed_workers;
std::vector<string> existing_workers;
std::vector<string> replaced_workers;
// New remote device manager created for new server_def. Unused if updating
// server_def.
std::unique_ptr<DynamicDeviceMgr> new_remote_device_mgr;
DynamicDeviceMgr* remote_device_mgr = nullptr;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(
GetAllRemoteDevices(remote_workers, server->master_env()->worker_cache,
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
// NOTE(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.
context->ClearCachesAndDefaultExecutor();
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
if (remote_device_mgr == nullptr) {
LOG_AND_RETURN_IF_ERROR(errors::InvalidArgument(
"Updating context with an invalid set of remote devices."));
}
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
std::sort(remote_workers.begin(), remote_workers.end());
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
&added_workers, &removed_workers,
&existing_workers);
sg.Update(GetReplacedFromExistingWorkers(
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
VLOG(1) << "Updating cluster with following changes";
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
for (const string& w : removed_workers)
VLOG(1) << " Removed worker " << w;
for (const string& w : replaced_workers)
VLOG(1) << " Replaced worker " << w;
}
if (!replaced_workers.empty()) {
      // Treat replaced workers as removed and then added back, so that we
      // recreate remote devices and contexts, and re-register functions on
      // those workers.
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
replaced_workers.end());
added_workers.insert(added_workers.end(), replaced_workers.begin(),
replaced_workers.end());
for (const string& w : replaced_workers) {
existing_workers.erase(
std::remove(existing_workers.begin(), existing_workers.end(), w),
existing_workers.end());
}
}
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
sg.Update(AddRemoteDevicesToMgr(
added_workers, server->master_env()->worker_cache, remote_device_mgr));
}
std::vector<DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<DeviceAttributes> local_device_attributes;
server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
  // This request makes sure that we can create a Rendezvous properly between
  // the local and remote contexts.
eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
// Initialize remote eager workers.
if (reset_context) {
const Status s = CreateRemoteContexts(
context, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
base_request);
    // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
    // CreateRemoteContexts to fail. We currently only log the error instead
    // of returning it directly, since returning here would destroy the server
    // object (which currently CHECK-fails). The client will see additional
    // errors if ops are subsequently sent to the failed workers.
if (TF_PREDICT_FALSE(!s.ok())) {
LOG(ERROR) << "Error when creating contexts on remote targets: "
<< s.error_message()
<< "\nExecuting remote ops or functions on these remote "
"targets will fail.";
}
} else {
if (sg.ok()) {
      // Create remote contexts on the newly added workers only if the master
      // has collected all device information from them (i.e., the
      // GetAllRemoteDevices call returned successfully). Note that in rare
      // cases GetAllRemoteDevices can still fail even with RPCs configured to
      // wait until the remote workers become alive. If the master creates
      // remote contexts on workers whose devices have not been collected,
      // those workers will subsequently be treated as existing workers, so
      // the master will never get their devices even when retrying
      // UpdateServerDef.
sg.Update(CreateRemoteContexts(
context, added_workers, context_id, context_view_id + 1,
keep_alive_secs, server_def, remote_eager_workers.get(),
context->Executor().Async(), base_request));
}
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
VLOG(1) << "Updating cluster with existing worker " << w;
}
}
// The master's context_view_id will be incremented by one in the
// UpdateRemoteMaster call later. We want existing workers to also have
// the updated context_view_id, so we must set their context_view_id to
// the master's current context_view_id + 1.
sg.Update(UpdateRemoteContexts(context, existing_workers, added_workers,
removed_workers, context_id,
context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
}
}
auto session_name = strings::StrCat("eager_", context_id);
if (reset_context) {
RemoteRendezvous* r =
server->worker_env()->rendezvous_mgr->Find(context_id);
auto* device_mgr = server->worker_env()->device_mgr;
std::shared_ptr<WorkerSession> worker_session;
LOG_AND_RETURN_IF_ERROR(server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
context->session_options().config.isolate_session_state()));
LOG_AND_RETURN_IF_ERROR(
server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
// Initialize remote tensor communication based on worker session.
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
DistributedFunctionLibraryRuntime* cluster_flr =
eager::CreateClusterFLR(context_id, context, worker_session.get());
auto remote_mgr = std::make_unique<eager::RemoteMgr>(
/*is_master=*/true, context);
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
std::move(new_server), server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
std::move(remote_mgr)));
// NOTE: We start the server after all other initialization, because the
// GrpcServer cannot be destroyed after it is started.
LOG_AND_RETURN_IF_ERROR(server->Start());
} else {
sg.Update(server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes()));
sg.Update(context->UpdateRemoteMaster(context_id,
std::move(remote_eager_workers),
added_workers, removed_workers));
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
}
#undef LOG_AND_RETURN_IF_ERROR
return Status::OK();
}
} // namespace
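// Applies `server_def` to the eager context. Cluster device filters are
// honored only when the cluster is being initialized (reset_context == true);
// on later updates they are ignored with a warning. A minimal sketch of how a
// client might attach device filters to a ServerDef, mirroring the parsing
// loop below (names are illustrative):
//
//   ServerDef server_def = ...;
//   auto* job_filters =
//       server_def.mutable_cluster_device_filters()->add_jobs();
//   job_filters->set_name("worker");
//   // Restrict task 0 of job "worker" to the devices of /job:ps/task:0.
//   (*job_filters->mutable_tasks())[0].add_device_filters("/job:ps/task:0");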
Status EagerContextDistributedManager::SetOrUpdateServerDef(
const ServerDef& server_def, bool reset_context, int keep_alive_secs) {
if (server_def.has_cluster_device_filters()) {
if (reset_context) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker =
strings::StrCat(remote_prefix, task_index);
TF_RETURN_IF_ERROR(
context_->SetRemoteDeviceFilters(remote_worker, device_filters));
}
}
} else {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
}
return UpdateContextWithServerDef(context_, server_def, reset_context,
keep_alive_secs);
}
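// Enables collective ops on the context. If no server exists yet, a new one
// is created and started; when a coordination service is configured in the
// session config, this also brings up the coordination service and agent,
// waits for all tasks, and registers every cluster device (local and remote)
// with the context. If a server already exists, only its server_def is
// updated and the collective ops server is stored again.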
Status EagerContextDistributedManager::EnableCollectiveOps(
const ServerDef& server_def) {
  // We don't use the TF_RETURN_IF_ERROR macro directly here, since that would
  // destroy the server object (which currently CHECK-fails) and the error
  // would be lost. Instead, we log the error and then return, so the user can
  // still see the error message.
#define LOG_AND_RETURN_IF_ERROR(...)                  \
  do {                                                \
    const tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {            \
      LOG(ERROR) << _status.error_message();          \
      return _status;                                 \
    }                                                 \
  } while (0)
ServerInterface* server = context_->GetServer();
if (server == nullptr) {
std::unique_ptr<ServerInterface> new_server;
LOG_AND_RETURN_IF_ERROR(NewServer(server_def, &new_server));
server = new_server.get();
if (server == nullptr) {
LOG_AND_RETURN_IF_ERROR(errors::Internal(
"Currently, TF eager runtime only supports GrpcServer."));
}
auto worker_cache =
server->worker_env()->session_mgr->LegacySession()->worker_cache();
const auto& config = server_def.default_session_config();
const bool enable_coordination =
!config.experimental().coordination_config().service_type().empty();
if (enable_coordination) {
// For coordination leader: start the service instance
LOG_AND_RETURN_IF_ERROR(EnableCoordinationService(
config.experimental().coordination_config().service_type(),
server->worker_env(), server_def, worker_cache));
LOG_AND_RETURN_IF_ERROR(server->SetCoordinationServiceAgentInstance(
coordination_service_agent_.get()));
}
LOG_AND_RETURN_IF_ERROR(server->Start());
if (enable_coordination) {
auto session_name = strings::StrCat("eager_", context_->GetContextId());
std::shared_ptr<WorkerSession> worker_session;
LOG_AND_RETURN_IF_ERROR(server->worker_env()->session_mgr->CreateSession(
session_name, server_def,
context_->session_options().config.isolate_session_state()));
LOG_AND_RETURN_IF_ERROR(
server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
context_->SetWorkerEnv(server->worker_env(), worker_session);
// Coordination agent: initialize, connect, wait for all tasks
std::unique_ptr<CoordinationClientCache> agent_cache;
LOG_AND_RETURN_IF_ERROR(
worker_cache->GetCoordinationClientCache(&agent_cache));
LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->Initialize(
server->worker_env()->env, server->worker_env()->device_mgr,
server_def, std::move(agent_cache), [this](Status s) {
context_->GetCollectiveExecutorHandle()->get()->StartAbort(s);
}));
LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->Connect());
LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->WaitForAllTasks());
// Add remote devices to eager context.
std::vector<std::unique_ptr<Device>> remote_devices;
for (const auto& d :
coordination_service_agent_->GetClusterDeviceAttributes()) {
// Treat all devices as remote so that EagerContext::remote_device_mgr
// maintains all the devices, including both local and remote.
remote_devices.emplace_back(NewRemoteDevice(context_->TFEnv(), d));
}
LOG_AND_RETURN_IF_ERROR(context_->AddDevices(std::move(remote_devices)));
}
LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
std::move(new_server), server->worker_env()->device_mgr,
server->worker_env()->collective_executor_mgr.get()));
if (enable_coordination) {
// Update cluster_flr and remote device list
eager::EagerClusterFunctionLibraryRuntime* cluster_flr =
new eager::EagerClusterFunctionLibraryRuntime(
context_->GetContextId(), context_,
context_->GetOwnedRemoteDeviceMgr());
context_->UpdateClusterFLRAndInitDevices(cluster_flr);
}
} else {
LOG_AND_RETURN_IF_ERROR(server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
/*new_server=*/nullptr, server->worker_env()->device_mgr,
server->worker_env()->collective_executor_mgr.get()));
}
#undef LOG_AND_RETURN_IF_ERROR
return Status::OK();
}
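// Instantiates the coordination service (leader side) of type `service_type`
// using a coordination client cache obtained from `worker_cache`.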
Status EagerContextDistributedManager::EnableCoordinationService(
const std::string& service_type, const WorkerEnv* worker_env,
const ServerDef& server_def, WorkerCacheInterface* worker_cache) {
std::unique_ptr<CoordinationClientCache> client_cache;
TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&client_cache));
coordination_service_ =
CoordinationServiceInterface::EnableCoordinationService(
service_type, worker_env->env, server_def, std::move(client_cache));
return Status::OK();
}
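// Synchronously probes `remote_task_name` with a fail-fast GetStatus RPC and
// sets `*is_alive` accordingly. An RPC failure is logged but the function
// still returns OK; an error is returned only when no worker interface can be
// found for the task.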
Status EagerContextDistributedManager::CheckRemoteAlive(
const std::string& remote_task_name, bool* is_alive) {
*is_alive = false;
WorkerInterface* wi =
context_->GetServer()->master_env()->worker_cache->GetOrCreateWorker(
remote_task_name);
if (wi == nullptr) {
return errors::InvalidArgument(
"Unable to find worker interface corresponding to task ",
remote_task_name);
}
GetStatusRequest request;
GetStatusResponse response;
Status remote_status;
Notification done;
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
[&remote_status, &done](const Status& s) {
remote_status = s;
done.Notify();
});
done.WaitForNotification();
if (remote_status.ok()) {
*is_alive = true;
} else {
LOG(INFO) << "Remote worker " << remote_task_name
<< " is not alive: " << remote_status.error_message();
}
return Status::OK();
}
#endif // !IS_MOBILE_PLATFORM
} // namespace tensorflow