| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" |
| |
| #include <functional> |
| #include <string> |
| #include <utility> |
| |
| #include "absl/container/fixed_array.h" |
| #include "absl/memory/memory.h" |
| #include "absl/types/optional.h" |
| #include "tensorflow/c/c_api_internal.h" |
| #include "tensorflow/c/eager/immediate_execution_distributed_manager.h" |
| #include "tensorflow/c/tf_status_helper.h" |
| #include "tensorflow/core/common_runtime/device_mgr.h" |
| #include "tensorflow/core/common_runtime/eager/context.h" |
| #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" |
| #include "tensorflow/core/common_runtime/eager/eager_operation.h" |
| #include "tensorflow/core/common_runtime/eager/execute.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/process_util.h" |
| #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" |
| #include "tensorflow/core/distributed_runtime/message_wrappers.h" |
| #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" |
| #include "tensorflow/core/distributed_runtime/server_lib.h" |
| #include "tensorflow/core/distributed_runtime/session_mgr.h" |
| #include "tensorflow/core/distributed_runtime/worker_cache.h" |
| #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" |
| #include "tensorflow/core/distributed_runtime/worker_env.h" |
| #include "tensorflow/core/framework/rendezvous.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/cpu_info.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/host_info.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/refcount.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/protobuf/coordination_config.pb.h" |
| #include "tensorflow/core/protobuf/error_codes.pb.h" |
| |
| namespace tensorflow { |
| namespace eager { |
| |
| namespace { |
| Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, |
| const google::protobuf::Map<string, tensorflow::AttrValue>& attrs, |
| int* num_retvals) { |
| const tensorflow::OpRegistrationData* op_reg_data = nullptr; |
| auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); |
| if (errors::IsNotFound(status)) { |
| status = context->FindFunctionOpData(op_name, &op_reg_data); |
| } |
| TF_RETURN_IF_ERROR(status); |
| |
| const tensorflow::OpDef& op_def = op_reg_data->op_def; |
| |
| for (const auto& output_arg : op_def.output_arg()) { |
| if (!output_arg.number_attr().empty()) { |
| auto iter = attrs.find(output_arg.number_attr()); |
| if (iter == attrs.end()) { |
| return errors::InvalidArgument("Unable to find number_attr ", |
| output_arg.number_attr(), |
| " for Op: ", op_name); |
| } |
| *num_retvals += iter->second.i(); |
| } else if (!output_arg.type_list_attr().empty()) { |
| auto iter = attrs.find(output_arg.type_list_attr()); |
| if (iter == attrs.end()) { |
| return errors::InvalidArgument("Unable to find type_list_attr ", |
| output_arg.type_list_attr(), |
| " for Op: ", op_name); |
| } |
| *num_retvals += iter->second.list().type_size(); |
| } else { |
| *num_retvals += 1; |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status GetEagerOperationAndNumRetvals(const Operation& operation, |
| EagerContext* eager_context, |
| EagerExecutor* eager_executor, |
| EagerOperation* eager_op, |
| int* num_retvals) { |
| const char* name = operation.name().c_str(); // Shorthand |
| absl::optional<tensorflow::EagerFunctionParams> remote_func_params = |
| absl::nullopt; |
| if (operation.is_function()) { |
| if (operation.is_component_function()) { |
| remote_func_params = {operation.id(), operation.func_step_id()}; |
| } else { |
| remote_func_params = {operation.id(), absl::nullopt}; |
| } |
| } |
| TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false, |
| eager_executor, remote_func_params)); |
| |
| { |
| profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal", |
| profiler::TraceMeLevel::kVerbose); |
| for (const auto& input : operation.op_inputs()) { |
| tensorflow::TensorHandle* handle; |
| if (input.has_remote_handle()) { |
| TF_RETURN_IF_ERROR( |
| eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( |
| input.remote_handle(), &handle)); |
| TF_RETURN_IF_ERROR(eager_op->AddInput(handle)); |
| } else { |
| Tensor tensor; |
| if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) { |
| return errors::InvalidArgument("Invalid TensorProto: ", |
| input.tensor().DebugString()); |
| } else { |
| handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr, |
| nullptr, eager_context); |
| TF_RETURN_IF_ERROR(eager_op->AddInput(handle)); |
| } |
| } |
| // Unref handle since it has a ref as an input now. |
| handle->Unref(); |
| } |
| } |
| |
| for (const auto& attr : operation.attrs()) { |
| eager_op->MutableAttrs()->Set(attr.first, attr.second); |
| } |
| |
| // TODO(nareshmodi): Consider caching this. |
| return GetNumRetvals(eager_context, operation.name(), operation.attrs(), |
| num_retvals); |
| } |
| |
| Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { |
| const tensorflow::Tensor* t = nullptr; |
| TF_RETURN_IF_ERROR(handle->Tensor(&t)); |
| t->AsProtoTensorContent(proto); |
| return Status::OK(); |
| } |
| |
| Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { |
| const tensorflow::Tensor* t = nullptr; |
| |
| // TODO(nareshmodi): This call makes async calls sync calls. Fix this. |
| if (handle->Type() == TensorHandle::LOCAL) { |
| TF_RETURN_IF_ERROR(handle->Tensor(&t)); |
| |
| t->shape().AsProto(proto); |
| } else { |
| TensorShape shape; |
| TF_RETURN_IF_ERROR(handle->Shape(&shape)); |
| shape.AsProto(proto); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AddOpRetvalsToResponse( |
| EagerContext* eager_context, int op_id, int num_retvals, |
| const std::vector<int32>& output_nums, TensorHandle** retvals, |
| std::function<TensorProto*()> add_tensor_proto_fn, |
| std::function<TensorShapeProto*()> add_shape_proto_fn, |
| std::function<string*()> add_device_fn = nullptr) { |
| if (op_id == kInvalidOpId) { |
| // Copy the output tensors back along with the response, since the op id |
| // is invalid which cannot be added to RemoteMgr. |
| for (int i = 0; i < num_retvals; i++) { |
| TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn())); |
| retvals[i]->Unref(); |
| } |
| } else { |
| for (int i = 0; i < num_retvals; i++) { |
| TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn())); |
| if (add_device_fn) { |
| Device* device = retvals[i]->device(); |
| *add_device_fn() = device ? device->name() : ""; |
| } |
| if (retvals[i]->Type() == TensorHandle::REMOTE) { |
| retvals[i]->Unref(); |
| } else { |
| const int output_num = output_nums.empty() ? i : output_nums.at(i); |
| eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, |
| output_num); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| } // namespace |
| |
| Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, |
| CreateContextResponse* response) { |
| { |
| mutex_lock l(contexts_mu_); |
| auto context_it = contexts_.find(request->context_id()); |
| if (context_it != contexts_.end()) { |
| if (request->context_view_id() < |
| context_it->second->Context()->GetContextViewId()) { |
| return errors::InvalidArgument("EagerService:CreateContext failed. ", |
| "Context id: <", request->context_id(), |
| "> already exists."); |
| } else { |
| // For existing context with a stale context_view_id, close the old one |
| // and recreate with new view id. This is likely due to the worker |
| // disconnected and then reconnected after one or more cluster updates. |
| context_it->second->Unref(); |
| contexts_.erase(context_it); |
| } |
| } |
| } |
| // make sure env_ , env_->rendezvous_mgr available |
| if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { |
| return tensorflow::errors::Internal( |
| "invalid eager env_ or env_->rendezvous_mgr."); |
| } |
| |
| auto* r = env_->rendezvous_mgr->Find(request->context_id()); |
| auto session_name = |
| tensorflow::strings::StrCat("eager_", request->context_id()); |
| if (VLOG_IS_ON(2)) { |
| VLOG(2) << "Creating context on /job:" << request->server_def().job_name() |
| << "/task:" << request->server_def().task_index(); |
| for (const auto& da : request->cluster_device_attributes()) { |
| VLOG(2) << " " << da.name(); |
| } |
| } |
| TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession( |
| session_name, request->server_def(), request->cluster_device_attributes(), |
| request->server_def().default_session_config().isolate_session_state())); |
| int64_t context_id = request->context_id(); |
| std::function<void()> session_destroyer = [this, context_id, session_name]() { |
| env_->rendezvous_mgr->Cleanup(context_id); |
| auto s = env_->session_mgr->DeleteSession(session_name); |
| if (!s.ok()) { |
| LOG(WARNING) << "Failed to destroy worker session '" << session_name |
| << "' due to " << s.error_message(); |
| } |
| }; |
| |
| std::shared_ptr<WorkerSession> worker_session; |
| TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( |
| session_name, &worker_session)); |
| |
| tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); |
| |
| // Initialize remote tensor communication based on worker session. |
| TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); |
| // Set the rendezvous as context-global instance for eager op-by-op execution. |
| r->SetRemoteEagerContextDefault(); |
| |
| std::function<Rendezvous*(const int64_t)> rendezvous_creator = |
| [worker_session, this](const int64_t step_id) { |
| auto* r = env_->rendezvous_mgr->Find(step_id); |
| r->Initialize(worker_session.get()).IgnoreError(); |
| return r; |
| }; |
| |
| LOG(INFO) << "Creating " << (request->async() ? "async" : "sync") |
| << " eager service context with rendezvous_id on host " |
| << port::Hostname() << " " << worker_session->worker_name(); |
| SessionOptions opts; |
| opts.config = request->server_def().default_session_config(); |
| tensorflow::EagerContext* ctx = new tensorflow::EagerContext( |
| opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, |
| request->async(), device_mgr, false, r, worker_session->cluster_flr(), |
| env_->collective_executor_mgr.get()); |
| // Ownership will be transferred to the ServerContext, or else in an error |
| // case ctx will be deleted by this unref. |
| core::ScopedUnref unref_ctx(ctx); |
| |
| std::vector<string> remote_workers; |
| worker_session->worker_cache()->ListWorkers(&remote_workers); |
| remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(), |
| worker_session->worker_name()), |
| remote_workers.end()); |
| |
| std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers; |
| TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache( |
| &remote_eager_workers)); |
| DistributedFunctionLibraryRuntime* cluster_flr = |
| eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get()); |
| |
| auto remote_mgr = |
| absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx); |
| Status s = ctx->InitializeRemoteWorker( |
| std::move(remote_eager_workers), worker_session->remote_device_mgr(), |
| remote_workers, request->context_id(), request->context_view_id(), |
| std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr), |
| std::move(session_destroyer)); |
| if (!s.ok()) { |
| VLOG(1) << "EagerContext::InitializeRemoteWorker failed with " |
| << s.ToString(); |
| return s; |
| } |
| |
| #if !defined(IS_MOBILE_PLATFORM) |
| const auto& config = request->server_def().default_session_config(); |
| const bool enable_coordination = |
| !config.experimental().coordination_config().service_type().empty(); |
| if (enable_coordination) { |
| auto dist_mgr = std::make_unique<EagerContextDistributedManager>(ctx); |
| ctx->SetDistributedManager(std::move(dist_mgr)); |
| TF_RETURN_IF_ERROR(ctx->GetDistributedManager()->EnableCoordinationService( |
| config.experimental().coordination_config().service_type(), env_, |
| request->server_def(), worker_session->worker_cache())); |
| std::unique_ptr<CoordinationClientCache> client_cache; |
| TF_RETURN_IF_ERROR( |
| worker_session->worker_cache()->GetCoordinationClientCache( |
| &client_cache)); |
| TF_RETURN_IF_ERROR( |
| ctx->GetDistributedManager()->GetCoordinationServiceAgent()->Initialize( |
| env_->env, env_->device_mgr, request->server_def(), |
| std::move(client_cache), |
| /*error_fn=*/[](Status s) { |
| LOG(ERROR) << "Coordination agent is set to error: " << s; |
| })); |
| } |
| #endif // !IS_MOBILE_PLATFORM |
| |
| std::vector<DeviceAttributes> device_attributes; |
| device_mgr->ListDeviceAttributes(&device_attributes); |
| |
| for (const auto& da : device_attributes) { |
| *response->add_device_attributes() = da; |
| } |
| { |
| mutex_lock l(contexts_mu_); |
| auto context_it = contexts_.find(request->context_id()); |
| if (context_it != contexts_.end()) { |
| return errors::InvalidArgument("EagerService:CreateContext failed. ", |
| "Context id: <", request->context_id(), |
| "> already exists."); |
| } |
| contexts_.emplace(request->context_id(), |
| new ServerContext(ctx, request->keep_alive_secs(), env_)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, |
| UpdateContextResponse* response) { |
| // make sure env_ , env_->rendezvous_mgr available |
| if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { |
| return tensorflow::errors::Internal( |
| "invalid eager env_ or env_->rendezvous_mgr."); |
| } |
| |
| // Find the context to update by the requested context_id |
| ServerContext* server_context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &server_context)); |
| core::ScopedUnref context_unref(server_context); |
| |
| tensorflow::EagerContext* ctx = server_context->Context(); |
| if (request->context_view_id() != ctx->GetContextViewId() + 1) { |
| return errors::InvalidArgument( |
| "EagerService:UpdateContext failed. Context id: <", |
| request->context_id(), "> currently at view #", ctx->GetContextViewId(), |
| " but received update request at view #", request->context_view_id(), |
| ". View id should only be continuously incremented."); |
| } |
| if (request->cluster_device_attributes_size() == 0) { |
| // In this case, the client indicates that the updated `server_def` and |
| // device info is irrelevant to this worker, since it is not connected to |
| // the updated ones (likely due to device filter settings). The worker |
| // simply needs to update view ID and does not update other internal state. |
| ctx->IncrementContextViewId(); |
| VLOG(1) << "Processing simplified UpdateContextRequest on " |
| << ctx->HostCPU()->name(); |
| return Status::OK(); |
| } |
| |
| auto session_name = |
| tensorflow::strings::StrCat("eager_", request->context_id()); |
| |
| TF_RETURN_IF_ERROR( |
| env_->session_mgr->UpdateSession(session_name, request->server_def(), |
| request->cluster_device_attributes())); |
| |
| std::shared_ptr<WorkerSession> worker_session; |
| TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( |
| session_name, &worker_session)); |
| |
| const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); |
| |
| std::vector<string> remote_workers; |
| worker_session->worker_cache()->ListWorkers(&remote_workers); |
| remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(), |
| worker_session->worker_name()), |
| remote_workers.end()); |
| VLOG(1) << "On existing server " << worker_session->worker_name() |
| << " updating remote workers"; |
| if (VLOG_IS_ON(2)) { |
| for (const string& rw : remote_workers) { |
| VLOG(2) << "Remote worker " << rw; |
| } |
| } |
| |
| std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers; |
| TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache( |
| &remote_eager_workers)); |
| |
| ctx->ClearCachesAndThreadExecutors(); |
| Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers), |
| remote_workers, request->context_id()); |
| if (!s.ok()) { |
| VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString(); |
| return s; |
| } |
| |
| std::vector<DeviceAttributes> device_attributes; |
| device_mgr->ListDeviceAttributes(&device_attributes); |
| |
| for (const auto& da : device_attributes) { |
| *response->add_device_attributes() = da; |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::CreateMasterContext( |
| const tensorflow::uint64 context_id, EagerContext* context) { |
| { |
| mutex_lock l(contexts_mu_); |
| auto iter = contexts_.find(context_id); |
| if (iter != contexts_.end()) { |
| return errors::InvalidArgument( |
| "EagerService:CreateMasterContext failed. ", "Context id: <", |
| context_id, "> already exists."); |
| } |
| } |
| ServerContext* server_context = |
| ServerContext::CreateMasterContext(context, env_); |
| mutex_lock l(contexts_mu_); |
| contexts_.emplace(context_id, server_context); |
| return Status::OK(); |
| } |
| |
| void EagerServiceImpl::RunComponentFunction( |
| CallOptions* call_opts, const RunComponentFunctionRequest* request, |
| RunComponentFunctionResponse* response, StatusCallback done) { |
| ServerContext* context = nullptr; |
| Status s = GetServerContext(request->context_id(), &context); |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| core::ScopedUnref context_unref(context); |
| |
| auto& operation = request->operation(); |
| // This codepath should only be triggered for executing component function |
| if (!operation.is_function() || !operation.is_component_function()) { |
| done(errors::Internal( |
| "RunComponentFunction request can only be used to execute " |
| "component functions.")); |
| return; |
| } |
| |
| EagerContext* eager_context = context->Context(); |
| EagerExecutor* eager_executor = &eager_context->Executor(); |
| |
| EagerOperation* op = new EagerOperation(eager_context); |
| int* num_retvals = new int(0); |
| s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor, |
| op, num_retvals); |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| if (!op->IsLocal()) { |
| done(errors::Internal( |
| "Received RunComponentFunction request with remote function device. ")); |
| return; |
| } |
| s = op->SetAttrBool("is_component_function", true); |
| if (!s.ok()) { |
| done(errors::Internal("Error setting is_component_function attribute: ", |
| s.error_message())); |
| return; |
| } |
| |
| auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals); |
| VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op " |
| << operation.id(); |
| std::vector<int32> output_nums; |
| for (const int32_t output_num : request->output_num()) { |
| output_nums.push_back(output_num); |
| } |
| |
| auto cm = std::make_shared<CancellationManager>(); |
| op->SetCancellationManager(cm.get()); |
| call_opts->SetCancelCallback([cm] { cm->StartCancel(); }); |
| |
| context->Ref(); |
| EagerLocalExecuteAsync( |
| op, retvals->data(), num_retvals, |
| [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm, |
| call_opts, response, eager_context, context, |
| done = std::move(done)](const Status& status) { |
| call_opts->ClearCancelCallback(); |
| auto wrapped_done = [&](const Status& status) { |
| context->Unref(); |
| done(status); |
| delete op; |
| delete num_retvals; |
| delete retvals; |
| }; |
| if (!status.ok()) { |
| wrapped_done(status); |
| return; |
| } |
| // The output device of a component function is the component device |
| // which is known on the default device of it's parent function. |
| wrapped_done(AddOpRetvalsToResponse( |
| eager_context, op_id, *num_retvals, output_nums, retvals->data(), |
| [response] { return response->add_tensor(); }, |
| [response] { return response->add_shape(); })); |
| }); |
| } |
| |
| Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts, |
| const Operation& operation, |
| EagerContext* eager_context, |
| EagerExecutor* eager_executor, |
| QueueResponse* queue_response) { |
| tensorflow::EagerOperation op(eager_context); |
| int num_retvals = 0; |
| TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals( |
| operation, eager_context, eager_executor, &op, &num_retvals)); |
| |
| auto cm = std::make_shared<CancellationManager>(); |
| if (call_opts) { |
| op.SetCancellationManager(cm.get()); |
| call_opts->SetCancelCallback([cm] { cm->StartCancel(); }); |
| } |
| |
| absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals); |
| VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id(); |
| TF_RETURN_IF_ERROR(op.Execute( |
| absl::MakeSpan( |
| reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()), |
| num_retvals), |
| &num_retvals)); |
| |
| std::function<string*()> add_device_fn = nullptr; |
| // Send the output devices of a function back to let a client know where the |
| // outputs are. For a primitive op, an output devics is the op device which is |
| // known on a client. |
| if (op.is_function()) { |
| add_device_fn = [queue_response] { return queue_response->add_device(); }; |
| } |
| |
| return AddOpRetvalsToResponse( |
| eager_context, operation.id(), num_retvals, /*output_nums=*/{}, |
| retvals.data(), [queue_response] { return queue_response->add_tensor(); }, |
| [queue_response] { return queue_response->add_shape(); }, |
| std::move(add_device_fn)); |
| } |
| |
| Status EagerServiceImpl::Enqueue(CallOptions* call_opts, |
| const EnqueueRequest* request, |
| EnqueueResponse* response, uint64 stream_id) { |
| profiler::TraceMe activity( |
| [&] { |
| return absl::StrCat( |
| "EagerService:Enqueue#debug_str=", request->DebugString(), "#"); |
| }, |
| profiler::TraceMeLevel::kInfo); |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| EagerExecutor& executor = |
| stream_id == kInvalidStreamId |
| ? context->Context()->Executor() |
| : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream( |
| stream_id); |
| Status s; |
| for (const auto& item : request->queue()) { |
| auto* queue_response = response->add_queue_response(); |
| if (item.has_operation()) { |
| s = ExecuteOp(call_opts, item.operation(), context->Context(), &executor, |
| queue_response); |
| } else if (item.has_handle_to_decref()) { |
| auto handle_to_decref = absl::make_unique<RemoteTensorHandleInternal>( |
| item.handle_to_decref()); |
| auto node = absl::make_unique<ClientTensorHandleDeleteNode>( |
| context, std::move(handle_to_decref)); |
| s = context->Context()->Executor().AddOrExecute(std::move(node)); |
| } else if (item.has_send_tensor()) { |
| s = SendTensor(item.send_tensor(), context->Context()); |
| } else if (item.has_send_packed_handle()) { |
| s = SendPackedHandle(item.send_packed_handle(), context->Context()); |
| } else if (item.has_register_function()) { |
| s = RegisterFunction(item.register_function(), context->Context()); |
| } else if (item.has_cleanup_function()) { |
| s = CleanupFunction(item.cleanup_function()); |
| } else { |
| DCHECK(item.has_sync_remote_executor_for_stream()); |
| s = executor.WaitForAllPendingNodes(); |
| } |
| |
| if (!s.ok()) { |
| if (stream_id != kInvalidStreamId) { |
| context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id); |
| } |
| return s; |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, |
| WaitQueueDoneResponse* response) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| if (request->op_id_size() > 0) { |
| return errors::Unimplemented( |
| "EagerServiceImpl::WaitQueueDone is not " |
| "implemented for particular op IDs."); |
| } |
| return context->Context()->Executor().WaitForAllPendingNodes(); |
| } |
| |
| Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, |
| KeepAliveResponse* response) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| tensorflow::EagerContext* ctx = context->Context(); |
| response->set_context_view_id(ctx->GetContextViewId()); |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, |
| CloseContextResponse* response) { |
| VLOG(1) << "Executing EagerService::CloseContext for context " |
| << request->context_id(); |
| ServerContext* context = nullptr; |
| if (!GetServerContext(request->context_id(), &context).ok()) { |
| // Swallow the error here. |
| return Status::OK(); |
| } |
| core::ScopedUnref context_unref(context); |
| |
| if (request->context_view_id() < context->Context()->GetContextViewId()) { |
| // Swallow the error here. |
| LOG(INFO) << "Ignoring CloseContext request with a stale context_view_id " |
| << request->context_view_id() << " for context_id " |
| << request->context_id() << ". The current context_view_id is " |
| << context->Context()->GetContextViewId() << "."; |
| return Status::OK(); |
| } |
| |
| mutex_lock l(contexts_mu_); |
| contexts_.erase(request->context_id()); |
| |
| // GetServerContext returns a newly Reffed copy of ServerContext, which is |
| // unreffed by context_unref. Additionally, we need to unref it one time since |
| // we are releasing it from the map. |
| context->Unref(); |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::RegisterFunction( |
| const RegisterFunctionOp& register_function, EagerContext* eager_context) { |
| // If the function is a component of a multi-device function, we only need to |
| // register it locally. |
| return eager_context->AddFunctionDef( |
| register_function.function_def(), register_function.library(), |
| register_function.is_component_function()); |
| } |
| |
| Status EagerServiceImpl::CleanupFunction( |
| const CleanupFunctionOp& cleanup_function) { |
| env_->rendezvous_mgr->Cleanup(cleanup_function.step_id()); |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, |
| EagerContext* eager_context) { |
| tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors; |
| for (const auto& tensor_proto : send_tensor.tensors()) { |
| Tensor tensor; |
| if (!tensor.FromProto(tensor_proto)) { |
| return errors::InvalidArgument("Unable to parse tensor proto"); |
| } |
| |
| TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle( |
| std::move(tensor), nullptr, nullptr, eager_context); |
| TensorHandle* copied_handle = nullptr; |
| Device* device; |
| TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName( |
| send_tensor.device_name().c_str(), &device)); |
| TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context, |
| &eager_context->Executor(), device, |
| false, &copied_handle)); |
| tensors.push_back(copied_handle); |
| tensor_handle->Unref(); |
| } |
| |
| eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id()); |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::SendPackedHandle( |
| const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) { |
| if (send_packed_handle.handles().empty()) { |
| return errors::InvalidArgument("Handles should not be empty."); |
| } |
| |
| std::vector<tensorflow::TensorHandle*> handles; |
| handles.resize(send_packed_handle.handles_size()); |
| for (int i = 0; i < send_packed_handle.handles_size(); ++i) { |
| const auto& item = send_packed_handle.handles(i); |
| if (item.has_local_handle()) { |
| Tensor tensor; |
| if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) { |
| return errors::InvalidArgument( |
| "Invalid TensorProto: ", |
| item.local_handle().tensor().DebugString()); |
| } |
| Device* op_device = nullptr; |
| TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName( |
| item.local_handle().device().c_str(), &op_device)); |
| handles[i] = TensorHandle::CreateLocalHandle( |
| std::move(tensor), /*d=*/nullptr, op_device, eager_context); |
| } else { |
| TF_RETURN_IF_ERROR( |
| eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( |
| item.remote_handle(), &handles[i])); |
| } |
| } |
| |
| tensorflow::TensorHandle* packed_handle = nullptr; |
| std::vector<tensorflow::TensorHandle*> handles_to_pack = handles; |
| // Create a unshaped packed TensorHandle. |
| TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle( |
| std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(), |
| send_packed_handle.device_name(), eager_context, &packed_handle)); |
| |
| for (auto* h : handles) { |
| // Unref handle since it has a ref in the packed handle now. |
| h->Unref(); |
| } |
| |
| eager_context->RemoteMgr()->AddOperationOutputs({packed_handle}, |
| send_packed_handle.op_id()); |
| return Status::OK(); |
| } |
| |
| tensorflow::Status EagerServiceImpl::GetServerContext( |
| uint64 context_id, ServerContext** server_context) { |
| tf_shared_lock l(contexts_mu_); |
| auto iter = contexts_.find(context_id); |
| if (iter == contexts_.end()) { |
| *server_context = nullptr; |
| return errors::Aborted(strings::Printf( |
| "Unable to find a context_id matching the specified one " |
| "(%llu). Perhaps the worker was restarted, or the context was GC'd?", |
| static_cast<unsigned long long>(context_id))); |
| } |
| |
| *server_context = iter->second; |
| (*server_context)->Ref(); |
| |
| (*server_context)->RecordAccess(); |
| |
| return Status::OK(); |
| } |
| |
| } // namespace eager |
| } // namespace tensorflow |