blob: 33221e512187b799243a7e7660cf6dbff2ce55ae [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/context.h"
#include <memory>
#include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
namespace {
bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
bool val;
if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
return val;
}
return default_val;
}
auto* eager_context_created =
monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
"True if an eager context was created.");
} // namespace
EagerContext::EagerContext(
const SessionOptions& opts,
ContextDevicePlacementPolicy default_device_placement_policy,
ContextMirroringPolicy default_mirroring_policy, bool async,
const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr,
bool device_mgr_owned, Rendezvous* rendezvous,
const CustomKernelCreator* custom_kernel_creator,
DistributedFunctionLibraryRuntime* cluster_flr)
: default_device_placement_policy_(default_device_placement_policy),
default_mirroring_policy_(default_mirroring_policy),
local_device_manager_(device_mgr, device_mgr_owned),
host_cpu_device_(device_mgr->HostCPU()),
rendezvous_(rendezvous),
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
custom_kernel_creator_(custom_kernel_creator),
cluster_flr_(cluster_flr),
log_device_placement_(opts.config.log_device_placement()),
allow_soft_placement_(opts.config.allow_soft_placement()),
num_active_steps_(0),
default_executor_(async),
log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
lazy_copy_function_remote_inputs_(lazy_copy_function_remote_inputs),
use_send_tensor_rpc_(false),
pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
"TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
&func_lib_def_, opts.config.graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr, custom_kernel_creator_);
// Starts exporting metrics through a platform-specific monitoring API (if
// provided). For builds using "tensorflow/core/platform/default", this is
// currently a no-op.
eager_context_created->GetCell()->Set(true);
InitPrioritizedDeviceTypeList();
runner_ = [this](std::function<void()> closure) {
this->thread_pool_->Schedule(std::move(closure));
};
#if !defined(IS_MOBILE_PLATFORM)
context_id_ = kInvalidContextId;
#endif // IS_MOBILE_PLATFORM
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(local_device_mgr()));
std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
opts.config, local_device_mgr(), drl.get(),
"/job:localhost/replica:0/task:0"));
collective_executor_mgr_.Reset(
new CollectiveExecutorMgr(opts.config, local_device_mgr(), std::move(drl),
std::move(cprl)),
/*owned=*/true);
}
void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
const ConfigProto* config, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool,
DistributedFunctionLibraryRuntime* cluster_flr,
const CustomKernelCreator* custom_kernel_creator) {
Rendezvous::Factory rendezvous_factory{
[this](const int64 step_id, const DeviceMgr*, Rendezvous** r) {
*r = CreateRendezvous(step_id);
return Status::OK();
}};
if (lazy_copy_function_remote_inputs_) {
pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
thread_pool, cluster_flr, custom_kernel_creator,
/*session_metadata=*/nullptr, std::move(rendezvous_factory)));
} else {
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
thread_pool, cluster_flr, custom_kernel_creator,
/*session_metadata=*/nullptr, std::move(rendezvous_factory)));
}
}
void EagerContext::InitPrioritizedDeviceTypeList() {
DeviceSet ds;
for (Device* d : local_device_mgr()->ListDevices()) {
ds.AddDevice(d);
}
auto remote_device_manager = remote_device_mgr();
if (remote_device_manager != nullptr) {
for (Device* d : remote_device_manager->ListDevices()) {
ds.AddDevice(d);
}
}
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}
namespace {
// Using absl::StrJoin with lambda does not work in tf-lite builds.
// TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
std::vector<string> v;
v.reserve(devices.size());
for (const auto& p : devices) {
v.push_back(p.first->name());
}
return v;
}
std::vector<string> DeviceTypesToString(
const PrioritizedDeviceTypeVector& types) {
std::vector<string> v;
v.reserve(types.size());
for (const auto& p : types) {
v.push_back(p.first.type_string());
}
return v;
}
// Selects the "best" device that both exists and is supported.
//
// The `existing` argument specifies the available devices in the system, in
// priority order. The `supported` argument specifies the supported device types
// and their priorities, lower index types having higher priority.
// Currently the type priority defined by the `supported` parameter takes
// precedence over system device priorities from `existing`.
//
// TODO(b/148213212): Allow setting default device in eager context.
Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern,
const PrioritizedDeviceVector& existing,
const PrioritizedDeviceTypeVector& supported) {
for (const std::pair<DeviceType, int32>& prioritized_type : supported) {
for (const std::pair<Device*, int32>& prioritized_device : existing) {
Device* dev = prioritized_device.first;
if (DeviceType(dev->attributes().device_type()) ==
prioritized_type.first &&
DeviceNameUtils::IsCompleteSpecification(pattern,
dev->parsed_name())) {
return dev;
}
}
}
return nullptr;
}
} // namespace
Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
const PrioritizedDeviceTypeVector& supported,
const DataType dtype, Device** out) const {
DCHECK(out != nullptr);
// We always place string tensors on the CPU device if we're allowed to.
if (dtype == DT_STRING && AllowSoftPlacement()) {
preferred = HostCPU()->parsed_name();
}
// Select the first matching registered device from the supported device
// list. If nothing matches and soft placement is enabled, pick a suitable
// device from the available ones.
const PrioritizedDeviceVector& existing =
pflr()->device_set()->prioritized_devices();
*out = SelectBestMatchingDevice(preferred, existing, supported);
if (*out != nullptr) {
return Status::OK();
}
if (AllowSoftPlacement()) {
DeviceNameUtils::ParsedName soft_device_name = preferred;
soft_device_name.type.clear();
soft_device_name.has_type = false;
soft_device_name.has_id = false;
// TODO(b/148213746): Soft placement logic picks up another task if the
// requested does not exist.
*out = SelectBestMatchingDevice(soft_device_name, existing, supported);
if (*out != nullptr) {
return Status::OK();
}
}
if (DeviceNameUtils::HasSomeDetails(preferred)) {
return errors::InvalidArgument(
"Could not satisfy device specification '", preferred,
"'. enable_soft_placement=", AllowSoftPlacement(),
". Supported device types [",
absl::StrJoin(DeviceTypesToString(supported), ", "),
"]. All available devices [",
absl::StrJoin(DevicesToString(existing), ", "), "].");
}
return errors::InvalidArgument(
"No supported device found in available devices [",
absl::StrJoin(DevicesToString(existing), ", "),
"]. enable_soft_placement=", AllowSoftPlacement(),
". Supported devices types [",
absl::StrJoin(DeviceTypesToString(supported), ", "), "].");
}
void EagerContext::ResetClusterFLR(
DistributedFunctionLibraryRuntime* cluster_flr) {
cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
}
EagerExecutor& EagerContext::Executor() {
tf_shared_lock l(executor_map_mu_);
return *gtl::FindWithDefault(thread_local_executor_,
std::this_thread::get_id(), &default_executor_);
}
void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
tensorflow::mutex_lock l(executor_map_mu_);
if (executor == &default_executor_) {
thread_local_executor_.erase(std::this_thread::get_id());
} else {
thread_local_executor_[std::this_thread::get_id()] = executor;
}
}
void EagerContext::ClearCachesAndThreadExecutors() {
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
{
mutex_lock l(executor_map_mu_);
executors_copy = thread_local_executor_;
}
for (const auto& entry : executors_copy) {
entry.second->WaitForAllPendingNodes().IgnoreError();
}
ClearCachesAndDefaultExecutor();
}
void EagerContext::ClearCachesAndDefaultExecutor() {
// The executor stores pointers to kernels, so we need to make sure that no
// async eager ops are still executing. We lock the cache during this time
// as well.
mutex_lock ml(cache_mu_);
default_executor_.WaitForAllPendingNodes().IgnoreError();
kernel_cache_.clear();
for (auto& entry : registered_functions_) {
entry.second->cached_kernel_keys->clear();
}
}
void EagerContext::SetThreadLocalDevicePlacementPolicy(
ContextDevicePlacementPolicy policy) {
mutex_lock ml(policy_map_mu_);
device_placement_policy_[std::this_thread::get_id()] = policy;
}
ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
tf_shared_lock l(policy_map_mu_);
auto policy_map_it =
device_placement_policy_.find(std::this_thread::get_id());
if (policy_map_it != device_placement_policy_.end()) {
return policy_map_it->second;
}
return default_device_placement_policy_;
}
void EagerContext::SetThreadLocalMirroringPolicy(
ContextMirroringPolicy policy) {
mutex_lock ml(policy_map_mu_);
mirroring_policy_[std::this_thread::get_id()] = policy;
}
ContextMirroringPolicy EagerContext::GetMirroringPolicy() const {
tf_shared_lock l(policy_map_mu_);
auto policy_map_it = mirroring_policy_.find(std::this_thread::get_id());
if (policy_map_it != mirroring_policy_.end()) {
return policy_map_it->second;
}
return default_mirroring_policy_;
}
bool EagerContext::MirrorTensors() const {
return GetMirroringPolicy() == MIRRORING_ALL;
}
bool EagerContext::LazyCopyFunctionRemoteInputs() const {
return lazy_copy_function_remote_inputs_;
}
#if !defined(IS_MOBILE_PLATFORM)
void EagerContext::CloseAndClearAllRemoteContexts() {
uint64 context_id;
uint64 context_view_id;
{
mutex_lock l(remote_state_mu_);
if (!is_master_) return;
context_id = context_id_;
context_view_id = context_view_id_;
context_id_ = kInvalidContextId;
// Forget the current view id and reset to the starting value 0.
context_view_id_ = 0;
}
CloseRemoteContexts(remote_contexts_, context_id, context_view_id);
remote_contexts_.clear();
}
void EagerContext::CloseRemoteContexts(
const std::vector<string>& remote_contexts, uint64 context_id,
uint64 context_view_id) {
// Close all remote contexts.
eager::CloseContextRequest request;
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
// Setting context_id to a new value can avoid us issuing DestroyTensorHandle
// request to closed remote workers.
std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
BlockingCounter counter(static_cast<int>(remote_contexts.size()));
int i = 0;
for (const auto& worker : remote_contexts) {
core::RefCountPtr<eager::EagerClient> client;
Status s = remote_eager_workers_->GetClient(worker, &client);
client->CloseContextAsync(
&request, &responses[i],
[&worker, &counter, context_id](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Unable to close remote context with ID "
<< context_id << " for worker: " << worker << " due to "
<< s.error_message();
}
counter.DecrementCount();
});
i++;
}
counter.Wait();
}
#endif // !IS_MOBILE_PLATFORM
void EagerContext::WaitForAndCloseRemoteContexts() {
ClearCachesAndThreadExecutors();
#if !defined(IS_MOBILE_PLATFORM)
{
mutex_lock l(keep_alive_thread_shutdown_mu_);
shutting_down_ = true;
keep_alive_thread_cv_.notify_all();
}
keep_alive_thread_.reset();
if (!remote_contexts_.empty()) {
CloseAndClearAllRemoteContexts();
}
{
mutex_lock l(remote_state_mu_);
default_executor_.ShutDown().IgnoreError();
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
{
mutex_lock l(executor_map_mu_);
executors_copy = thread_local_executor_;
}
for (const auto& it : executors_copy) {
it.second->ShutDown().IgnoreError();
}
}
// This shuts down the completion queue and joins the thread polling it.
// The thread exits only after the completion queue has been drained of all
// the events. These events' completion should invoke all remaining RPC
// callbacks.
// This also deletes all EagerClient instances. There should not be any
// references to EagerClients left after all RPCs and async ops have been
// finished.
remote_eager_workers_ = nullptr;
#endif // !IS_MOBILE_PLATFORM
}
EagerContext::~EagerContext() {
// TODO(iga): Add a separate API method to shutdown EagerContext so that we
// don't send RPCs and block in destructor.
WaitForAndCloseRemoteContexts();
ClearCachesAndThreadExecutors();
for (auto& entry : registered_functions_) {
while (!entry.second->Unref()) {
// remove all references.
}
}
registered_functions_.clear();
#if !defined(IS_MOBILE_PLATFORM)
if (server_) {
// TODO(b/136478427): Fix this.
LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
"Servers don't support clean shutdown.";
server_.release();
}
{
mutex_lock l(keep_alive_thread_shutdown_mu_);
shutting_down_ = true;
keep_alive_thread_cv_.notify_all();
}
keep_alive_thread_.reset();
if (!remote_contexts_.empty()) {
CloseAndClearAllRemoteContexts();
}
#endif // !IS_MOBILE_PLATFORM
if (rendezvous_) {
rendezvous_->Unref();
}
if (resource_deallocator_ != nullptr) {
resource_deallocator_();
}
}
bool EagerContext::FindFunctionByName(const string& name) const {
return func_lib_def_.Find(name) != nullptr;
}
Status EagerContext::FindFunctionOpData(
const string& name, const tensorflow::OpRegistrationData** op_data) {
return func_lib_def_.LookUp(name, op_data);
}
const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
return func_lib_def_.Find(name);
}
std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
std::vector<const FunctionDef*> result;
std::vector<string> function_names = func_lib_def_.ListFunctionNames();
result.reserve(function_names.size());
for (const string& fn : function_names) {
result.emplace_back(func_lib_def_.Find(fn));
}
return result;
}
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
void EagerContext::ListDevices(
std::vector<tensorflow::DeviceAttributes>* devices) {
local_device_mgr()->ListDeviceAttributes(devices);
if (remote_device_mgr()) {
remote_device_mgr()->ListDeviceAttributes(devices);
}
}
void EagerContext::StartStep() {
mutex_lock ml(metadata_mu_);
num_active_steps_++;
if (step_container_ == nullptr) {
step_container_.reset(
new ScopedStepContainer(0, [this](const string& name) {
auto local_devices = local_device_mgr()->ListDevices();
for (Device* device : local_devices) {
device->resource_manager()->Cleanup(name).IgnoreError();
}
}));
}
}
void EagerContext::EndStep() {
mutex_lock ml(metadata_mu_);
num_active_steps_--;
if (num_active_steps_ == 0) {
step_container_.reset();
}
}
ScopedStepContainer* EagerContext::StepContainer() {
if (num_active_steps_.load() == 0) {
return nullptr;
}
mutex_lock ml(metadata_mu_);
return step_container_.get();
}
Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
// Only client context can register function on remote worker context.
if (!remote_device_manager_.Owned()) return Status::OK();
#if !defined(IS_MOBILE_PLATFORM)
std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(GetContextId());
eager::RegisterFunctionOp* register_function =
request->add_queue()->mutable_register_function();
*register_function->mutable_function_def() = fdef;
StripDefaultAttributes(
*OpRegistry::Global(),
register_function->mutable_function_def()->mutable_node_def());
for (const auto& target : remote_contexts_) {
core::RefCountPtr<eager::EagerClient> eager_client;
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
eager::EnqueueResponse* response = new eager::EnqueueResponse();
eager_client->StreamingEnqueueAsync(
request.get(), response, [request, response](const Status& status) {
if (!status.ok()) {
LOG(ERROR) << "Failed to register function remotely due to "
<< status.error_message()
<< "\nThis shouldn't happen, please file a bug to "
"tensorflow team.";
}
delete response;
});
}
#endif // !IS_MOBILE_PLATFORM
return Status::OK();
}
Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
const std::vector<const FunctionDef*>& function_defs,
const std::vector<string>& remote_workers) {
#if !defined(IS_MOBILE_PLATFORM)
// Register multiple functions on selected remote workers.
uint64 context_id = GetContextId();
for (int i = 0; i < remote_workers.size(); i++) {
core::RefCountPtr<eager::EagerClient> eager_client;
Status s =
remote_eager_workers_->GetClient(remote_workers[i], &eager_client);
if (!s.ok()) {
continue;
}
for (int j = 0; j < function_defs.size(); j++) {
auto* request = new eager::EnqueueRequest;
request->set_context_id(context_id);
eager::RegisterFunctionOp* register_function =
request->add_queue()->mutable_register_function();
*register_function->mutable_function_def() = *function_defs[j];
StripDefaultAttributes(
*OpRegistry::Global(),
register_function->mutable_function_def()->mutable_node_def());
auto* response = new eager::EnqueueResponse;
eager_client->StreamingEnqueueAsync(
request, response, [request, response](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Failed to register function remotely due to "
<< s.error_message()
<< "\nThis shouldn't happen, please file a bug to "
"tensorflow team.";
}
delete request;
delete response;
});
}
}
#endif // !IS_MOBILE_PLATFORM
return Status::OK();
}
Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
return AddFunctionDef(fdef, FunctionDefLibrary(),
/* add_to_local_only=*/false);
}
Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
const FunctionDefLibrary& library,
const bool add_to_local_only) {
bool is_first_ref = false;
{
mutex_lock l(cache_mu_);
auto* registered_function =
gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
if (registered_function == nullptr) {
registered_function = new RegisteredFunction;
registered_function->cached_kernel_keys =
absl::make_unique<std::vector<Fprint128>>();
gtl::InsertOrUpdate(&registered_functions_, fdef.signature().name(),
registered_function);
} else {
registered_function->Ref();
}
is_first_ref = registered_function->RefCountIsOne();
}
if (is_first_ref) {
TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
if (!add_to_local_only) {
return MaybeRegisterFunctionRemotely(fdef);
}
}
return Status::OK();
}
const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
return func_lib_def_.Find(function_name);
}
Status EagerContext::RemoveFunction(const string& func) {
bool is_last_ref = false;
{
mutex_lock l(cache_mu_);
auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
if (registered_function == nullptr) {
return errors::InvalidArgument("Tried to remove non-existent function '",
func, "'.");
}
is_last_ref = registered_function->RefCountIsOne();
if (is_last_ref) {
for (auto& key : *registered_function->cached_kernel_keys) {
kernel_cache_.erase(key);
}
registered_functions_.erase(func);
}
registered_function->Unref();
}
if (is_last_ref) {
// TODO(fishx): Remove remote function as well.
return func_lib_def_.RemoveFunction(func);
}
return Status::OK();
}
Status EagerContext::SyncExecutors() {
StatusGroup sg;
// Synchronize on context default executor
sg.Update(default_executor_.WaitForAllPendingNodes());
default_executor_.ClearError();
// Synchronize thread local executors on client
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
{
mutex_lock l(executor_map_mu_);
executors_copy = thread_local_executor_;
}
for (const auto& entry : executors_copy) {
sg.Update(entry.second->WaitForAllPendingNodes());
entry.second->ClearError();
}
#if !defined(IS_MOBILE_PLATFORM)
// Synchronize executors on remote workers
eager::EnqueueRequest request;
request.set_context_id(GetContextId());
request.add_queue()->mutable_sync_remote_executor_for_stream();
BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
std::vector<Status> statuses(remote_contexts_.size());
for (int i = 0; i < remote_contexts_.size(); i++) {
const auto& target = remote_contexts_[i];
core::RefCountPtr<eager::EagerClient> eager_client;
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
eager::EnqueueResponse* response = new eager::EnqueueResponse();
eager_client->StreamingEnqueueAsync(
&request, response,
[response, target, &counter, &s = statuses[i]](const Status& status) {
s = status;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
for (const Status& s : statuses) {
sg.Update(s);
}
#endif // !IS_MOBILE_PLATFORM
return sg.as_summary_status();
}
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
Fprint128 cache_key) {
tf_shared_lock l(cache_mu_);
auto iter = kernel_cache_.find(cache_key);
if (iter == kernel_cache_.end()) {
return nullptr;
}
core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
new_ref->Ref();
return new_ref;
}
void EagerContext::AddKernelToCache(Fprint128 cache_key,
KernelAndDevice* kernel) {
mutex_lock ml(cache_mu_);
core::RefCountPtr<KernelAndDevice> new_ref(kernel);
new_ref->Ref();
kernel_cache_[cache_key] = std::move(new_ref);
auto* registered_function =
gtl::FindPtrOrNull(registered_functions_, kernel->name());
// The kernel name can be either a primitive op or a function.
if (registered_function != nullptr) {
registered_function->cached_kernel_keys->emplace_back(cache_key);
}
}
bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
void EagerContext::SetShouldStoreGraphs(bool value) {
mutex_lock ml(metadata_mu_);
should_store_graphs_.store(value);
if (!value) {
run_metadata_.Clear();
}
}
Status EagerContext::FindDeviceFromName(const char* device_name,
Device** device) const {
*device = HostCPU();
if (device_name == nullptr || strlen(device_name) == 0) {
return Status::OK();
}
auto status = local_device_mgr()->LookupDevice(device_name, device);
if (status.ok()) {
return status;
}
if (remote_device_mgr() != nullptr) {
return remote_device_mgr()->LookupDevice(device_name, device);
}
return status;
}
Status EagerContext::FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const {
auto dev_it = custom_devices_.find(device_name);
if (dev_it == custom_devices_.end()) {
return errors::InvalidArgument(device_name, " unknown device.");
}
*dev = dev_it->second.get();
return Status::OK();
}
Status EagerContext::RegisterCustomDevice(
const string& device_name, std::unique_ptr<CustomDevice> device) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
!parsed.has_job || !parsed.has_replica || !parsed.has_task ||
!parsed.has_type || !parsed.has_id) {
return errors::InvalidArgument(
device_name,
" could not be parsed as a device name. Use the full "
"/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
"format.");
}
Device* existing_physical_device = nullptr;
if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
return errors::AlreadyExists(device_name,
" already registered as a physical device.");
}
if (!custom_devices_.emplace(device_name, std::move(device)).second) {
return errors::AlreadyExists(device_name,
" already registered as a custom device.");
}
return Status::OK();
}
bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
if (first == nullptr) first = HostCPU();
if (second == nullptr) second = HostCPU();
return first->parsed_name().job == second->parsed_name().job &&
first->parsed_name().replica == second->parsed_name().replica &&
first->parsed_name().task == second->parsed_name().task;
}
// Gets the CPU device on the task of device.
Status EagerContext::CPUDeviceOnTask(const Device* device,
Device** cpu_device) const {
string cpu_device_name;
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
device->name(), &cpu_device_name));
return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
}
namespace {
Status GetTaskName(Device* d, string* task_name) {
string ignored;
if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
return errors::InvalidArgument("Unable to parse device name: ", d->name());
}
return Status::OK();
}
} // namespace
#if !defined(IS_MOBILE_PLATFORM)
Status EagerContext::GetClient(Device* device,
core::RefCountPtr<eager::EagerClient>* client) {
return GetClient(device->parsed_name(), client);
}
Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
core::RefCountPtr<eager::EagerClient>* client) {
if (remote_eager_workers_ == nullptr) {
return errors::Internal(
"Haven't set up remote eager worker in this eager context yet.");
}
string device_task_name;
if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
return errors::InvalidArgument(
"Task is not fully specified in device name: ",
DeviceNameUtils::ParsedNameToString(device_name));
}
TF_RETURN_IF_ERROR(
remote_eager_workers_->GetClient(device_task_name, client));
if (*client == nullptr) {
return errors::InvalidArgument(
"Unable to find eager client corresponding to device ",
DeviceNameUtils::ParsedNameToString(device_name));
}
if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
device_task_name) == remote_contexts_.end()) {
return errors::Internal("Unable to find a context for handle on task: ",
device_task_name, ". This should not be possible");
}
return Status::OK();
}
Status EagerContext::GetClient(const string& remote_task,
core::RefCountPtr<eager::EagerClient>* client) {
if (remote_eager_workers_ == nullptr) {
return errors::Internal(
"Haven't set up remote eager worker in this eager context yet.");
}
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
if (*client == nullptr) {
return errors::InvalidArgument(
"Unable to find eager client corresponding to target ", remote_task);
}
return Status::OK();
}
uint64 EagerContext::GetContextId() const {
tf_shared_lock l(remote_state_mu_);
return context_id_;
}
uint64 EagerContext::GetContextViewId() const {
tf_shared_lock l(remote_state_mu_);
return context_view_id_;
}
void EagerContext::IncrementContextViewId() {
mutex_lock l(remote_state_mu_);
context_view_id_ += 1;
}
// Set collective ops related state in the context. Passing nullptr to
// `new_server` will reuse the existing GRPC server in context.
Status EagerContext::StoreCollectiveOpsServer(
std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
local_device_manager_.Reset(device_mgr);
host_cpu_device_ = local_device_manager_.Get()->HostCPU();
InitPrioritizedDeviceTypeList();
ClearCachesAndThreadExecutors();
default_executor_.ClearError();
{
tensorflow::mutex_lock l(executor_map_mu_);
for (auto& entry : thread_local_executor_) {
entry.second->ClearError();
}
}
const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
ResetPFLR(
local_device_manager_.Get(), env_, /*config=*/config,
TF_GRAPH_DEF_VERSION, &func_lib_def_,
/*optimizer_options=*/
config ? config->graph_options().optimizer_options() : OptimizerOptions(),
thread_pool_.get());
if (new_server != nullptr) {
// Memory leak!
if (server_ != nullptr) {
LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
"Servers don't support clean shutdown.";
server_.release();
}
server_ = std::move(new_server);
}
DCHECK(server_ != nullptr);
return Status::OK();
}
Status EagerContext::SetRemoteDeviceFilters(
const string& remote_worker, const std::vector<string>& device_filters) {
// Get fully specified task name for remote worker
string remote_worker_task_name;
DeviceNameUtils::ParsedName pw;
if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
return tensorflow::errors::InvalidArgument(
"Remote worker task name is invalid ", remote_worker);
}
// Force set a replica as the key in cluster device filters map. I.e., if the
// remote worker is `/job:worker/task:0` it then becomes
// `/job:worker/replica:0/task:0`.
pw.has_replica = true;
if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
return tensorflow::errors::InvalidArgument(
"Job name and task index must be specified for worker ", remote_worker);
}
std::vector<DeviceNameUtils::ParsedName> parsed_filters;
for (auto& filter : device_filters) {
DeviceNameUtils::ParsedName parsed_filter;
if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
parsed_filters.emplace_back(parsed_filter);
} else {
return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
}
}
if (VLOG_IS_ON(1)) {
VLOG(1) << "Setting device filters for " << remote_worker << ":";
for (auto& filter : device_filters) {
VLOG(1) << " " << filter;
}
}
mutex_lock l(remote_state_mu_);
cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
return Status::OK();
}
void EagerContext::FilterDevicesForRemoteWorkers(
const string& remote_worker,
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
std::vector<bool>* filtered_device_mask) {
filtered_device_mask->resize(device_attrs.size());
std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
tf_shared_lock l(remote_state_mu_);
auto it = cluster_device_filters_.find(remote_worker);
// If no filters were specified, all devices should be visible to the worker
if (it == cluster_device_filters_.end() || it->second.empty()) {
std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
return;
}
const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
DeviceNameUtils::ParsedName parsed_remote_worker;
DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
for (int i = 0; i < device_attrs.size(); i++) {
DeviceNameUtils::ParsedName pn;
DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
// If this device is on the remote worker itself, it should be visible
// regardless of device filters
filtered_device_mask->at(i) = true;
continue;
}
for (const auto& pf : parsed_filters) {
if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
(!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
(!pn.has_task || !pf.has_task || pn.task == pf.task) &&
(!pn.has_type || !pf.has_type || pn.type == pf.type) &&
(!pn.has_id || !pf.has_id || pn.id == pf.id)) {
// Found a match, make it visible, stop processing more device filters
filtered_device_mask->at(i) = true;
break;
}
}
}
}
Status EagerContext::InitializeRemoteMaster(
std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
std::shared_ptr<WorkerSession> worker_session,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
const std::vector<string>& remote_contexts, uint64 context_id,
Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
DistributedFunctionLibraryRuntime* cluster_flr,
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
remote_mgr) {
if (context_id == kInvalidContextId) {
return errors::InvalidArgument(
"Failed to initialize remote for master context due to invalid ",
"context id");
}
if (!remote_contexts_.empty()) {
CloseAndClearAllRemoteContexts();
}
remote_contexts_ = remote_contexts;
return SetMasterContextState(
std::move(server), worker_env, std::move(worker_session),
std::move(remote_eager_workers), std::move(remote_device_manager),
context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
std::move(remote_mgr));
}
Status EagerContext::UpdateRemoteMaster(
WorkerEnv* worker_env,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
const std::vector<string>& add_remote_contexts,
const std::vector<string>& remove_remote_contexts, uint64 context_id,
Rendezvous* r) {
{
tf_shared_lock l(remote_state_mu_);
if (context_id != context_id_) {
return errors::InvalidArgument(
"Failed to update remote remote master context due to invalid ",
"context id. Request id = ", context_id,
" but current id = ", context_id_);
}
}
if (!remove_remote_contexts.empty()) {
// N.B. remove_remote_contexts include both removed and replaced workers.
// In the case where a worker is replaced by one that resolves to the same
// `hostname:port`, it is safe to close context with the current view id,
// since the newly created context on the remote worker will be holding
// a larger view id and ignores this request.
CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
for (const string& remote_context : remove_remote_contexts) {
remote_contexts_.erase(
std::remove(remote_contexts_.begin(), remote_contexts_.end(),
remote_context),
remote_contexts_.end());
}
}
if (!add_remote_contexts.empty()) {
remote_contexts_.insert(std::end(remote_contexts_),
std::begin(add_remote_contexts),
std::end(add_remote_contexts));
}
std::vector<const FunctionDef*> function_defs = ListRegisteredFunctions();
{
mutex_lock l(remote_state_mu_);
context_view_id_++;
worker_env_ = worker_env;
if (rendezvous_ != nullptr) rendezvous_->Unref();
rendezvous_ = r;
remote_eager_workers_ = std::move(remote_eager_workers);
pflr_->InitializeDeviceSet();
InitPrioritizedDeviceTypeList();
default_executor_.ClearError();
{
tensorflow::mutex_lock l(executor_map_mu_);
for (auto& entry : thread_local_executor_) {
entry.second->ClearError();
}
}
}
// Register existing functions to the newly added remote workers. Note that
// this should happen only after updating `remote_contexts_` because new
// functions might be registered while we update the context. When that
// happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
// register the new functions on all remote workers (including the newly added
// ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
// registering existing functions, where duplicate registrations will be
// ignored by the remote workers.
TF_RETURN_IF_ERROR(RegisterExistingFunctionsOnRemoteWorkers(
function_defs, add_remote_contexts));
return Status::OK();
}
// Set distributed execution related state in the master context.
Status EagerContext::SetMasterContextState(
std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
std::shared_ptr<WorkerSession> worker_session,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
remote_mgr) {
mutex_lock l(remote_state_mu_);
is_master_ = true;
context_id_ = context_id;
context_view_id_ = context_view_id;
use_send_tensor_rpc_ =
ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
local_device_manager_.Reset(local_device_mgr);
host_cpu_device_ = local_device_manager_.Get()->HostCPU();
if (rendezvous_ != nullptr) rendezvous_->Unref();
rendezvous_ = r;
// Memory leak!
if (server_ != nullptr) {
LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
"Servers don't support clean shutdown.";
server_.release();
}
server_ = std::move(server);
remote_mgr_ = std::move(remote_mgr);
worker_env_ = worker_env;
worker_session_ = std::move(worker_session);
remote_eager_workers_ = std::move(remote_eager_workers);
remote_device_manager_.Reset(std::move(remote_device_manager));
ResetClusterFLR(cluster_flr);
InitPrioritizedDeviceTypeList();
ClearCachesAndThreadExecutors();
default_executor_.ClearError();
{
tensorflow::mutex_lock l(executor_map_mu_);
for (auto& entry : thread_local_executor_) {
entry.second->ClearError();
}
}
const auto* config = pflr_->config();
ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
&func_lib_def_, config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
keep_alive_secs_ = keep_alive_secs;
sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
// Only schedule a single closure.
if (keep_alive_thread_ == nullptr) {
keep_alive_thread_.reset(
env_->StartThread({}, "EagerKeepAliveThread", [this]() {
while (true) {
{
{
mutex_lock l(keep_alive_thread_shutdown_mu_);
if (shutting_down_) {
return;
}
keep_alive_thread_cv_.wait_for(
l, std::chrono::seconds(sleep_for_secs_));
if (shutting_down_) {
return;
}
}
{
mutex_lock l(remote_state_mu_);
if (keep_alive_secs_ > 0) {
{
for (const auto& worker : remote_contexts_) {
core::RefCountPtr<eager::EagerClient> client;
Status s =
remote_eager_workers_->GetClient(worker, &client);
if (!s.ok()) {
LOG(WARNING) << "Keep-alive thread was unable to find "
"a client for target "
<< worker << ". Got error: " << s;
continue;
}
eager::KeepAliveRequest* request =
new eager::KeepAliveRequest;
eager::KeepAliveResponse* response =
new eager::KeepAliveResponse;
request->set_context_id(context_id_);
client->KeepAliveAsync(
request, response,
[request, response](const Status& s) {
delete request;
delete response;
});
}
}
}
}
}
}
}));
}
return Status::OK();
}
Status EagerContext::InitializeRemoteWorker(
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
DynamicDeviceMgr* remote_device_mgr,
const std::vector<string>& remote_contexts, uint64 context_id,
uint64 context_view_id,
std::function<Rendezvous*(const int64)> rendezvous_creator,
DistributedFunctionLibraryRuntime* cluster_flr,
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
remote_mgr,
std::function<void()> resource_deallocator) {
if (context_id == kInvalidContextId) {
return errors::InvalidArgument(
"Failed to initialize remote for worker context due to invalid ",
"context id");
}
mutex_lock l(remote_state_mu_);
if (remote_device_manager_.Owned() || server_ != nullptr ||
keep_alive_thread_ != nullptr) {
return errors::FailedPrecondition(
"EagerContext::InitializeRemoteWorker Failed. ",
"Already initialized remote as a master context.");
}
is_master_ = false;
remote_contexts_ = remote_contexts;
context_id_ = context_id;
context_view_id_ = context_view_id;
rendezvous_creator_ = std::move(rendezvous_creator);
remote_eager_workers_ = std::move(remote_eager_workers);
remote_mgr_ = std::move(remote_mgr);
ResetClusterFLR(cluster_flr);
remote_device_manager_.Reset(remote_device_mgr);
const auto* config = pflr_->config();
ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
&func_lib_def_, config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
InitPrioritizedDeviceTypeList();
ClearCachesAndThreadExecutors();
default_executor_.ClearError();
{
tensorflow::mutex_lock l(executor_map_mu_);
for (auto& entry : thread_local_executor_) {
entry.second->ClearError();
}
}
resource_deallocator_ = std::move(resource_deallocator);
return Status::OK();
}
Status EagerContext::UpdateRemoteWorker(
const DeviceMgr* worker_session_device_mgr,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
DynamicDeviceMgr* remote_device_mgr,
const std::vector<string>& remote_contexts, uint64 context_id,
DistributedFunctionLibraryRuntime* cluster_flr) {
{
mutex_lock l(remote_state_mu_);
if (context_id != context_id_) {
return errors::InvalidArgument(
"Failed to update remote for worker context due to invalid ",
"context id. Request id = ", context_id,
" but current id = ", context_id_);
}
context_view_id_++;
}
remote_contexts_ = remote_contexts;
remote_eager_workers_ = std::move(remote_eager_workers);
ResetClusterFLR(cluster_flr);
remote_device_manager_.Reset(remote_device_mgr);
InitPrioritizedDeviceTypeList();
ClearCachesAndThreadExecutors();
default_executor_.ClearError();
{
tensorflow::mutex_lock l(executor_map_mu_);
for (auto& entry : thread_local_executor_) {
entry.second->ClearError();
}
}
SessionOptions options = SessionOptions();
const auto* config = pflr_->config();
ResetPFLR(worker_session_device_mgr, options.env, config,
TF_GRAPH_DEF_VERSION, FuncLibDef(),
config->graph_options().optimizer_options(), thread_pool_.get(),
cluster_flr_.Get(), custom_kernel_creator_);
return Status::OK();
}
#endif // !IS_MOBILE_PLATFORM
} // namespace tensorflow