| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" |
| |
| #include "absl/container/fixed_array.h" |
| #include "absl/memory/memory.h" |
| #include "tensorflow/c/c_api_internal.h" |
| #include "tensorflow/c/tf_status_helper.h" |
| #include "tensorflow/core/common_runtime/device_mgr.h" |
| #include "tensorflow/core/common_runtime/eager/context.h" |
| #include "tensorflow/core/common_runtime/eager/eager_operation.h" |
| #include "tensorflow/core/common_runtime/eager/execute.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/process_util.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" |
| #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" |
| #include "tensorflow/core/distributed_runtime/server_lib.h" |
| #include "tensorflow/core/distributed_runtime/session_mgr.h" |
| #include "tensorflow/core/distributed_runtime/worker_cache.h" |
| #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" |
| #include "tensorflow/core/distributed_runtime/worker_env.h" |
| #include "tensorflow/core/framework/rendezvous.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/cpu_info.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/host_info.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/protobuf/error_codes.pb.h" |
| |
| namespace tensorflow { |
| namespace eager { |
| |
| namespace { |
| Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, |
| const google::protobuf::Map<string, tensorflow::AttrValue>& attrs, |
| int* num_retvals) { |
| const tensorflow::OpRegistrationData* op_reg_data = nullptr; |
| auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); |
| if (errors::IsNotFound(status)) { |
| status = context->FindFunctionOpData(op_name, &op_reg_data); |
| } |
| TF_RETURN_IF_ERROR(status); |
| |
| const tensorflow::OpDef& op_def = op_reg_data->op_def; |
| |
| for (const auto& output_arg : op_def.output_arg()) { |
| if (!output_arg.number_attr().empty()) { |
| auto iter = attrs.find(output_arg.number_attr()); |
| if (iter == attrs.end()) { |
| return errors::InvalidArgument("Unable to find number_attr ", |
| output_arg.number_attr(), |
| " for Op: ", op_name); |
| } |
| *num_retvals += iter->second.i(); |
| } else if (!output_arg.type_list_attr().empty()) { |
| auto iter = attrs.find(output_arg.type_list_attr()); |
| if (iter == attrs.end()) { |
| return errors::InvalidArgument("Unable to find type_list_attr ", |
| output_arg.type_list_attr(), |
| " for Op: ", op_name); |
| } |
| *num_retvals += iter->second.list().type_size(); |
| } else { |
| *num_retvals += 1; |
| } |
| } |
| |
| return Status::OK(); |
| } |
| } // namespace |
| |
| Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, |
| CreateContextResponse* response) { |
| // make sure env_ , env_->rendezvous_mgr available |
| if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { |
| return tensorflow::errors::Internal( |
| "invalid eager env_ or env_->rendezvous_mgr."); |
| } |
| std::vector<DeviceAttributes> cluster_device_attributes; |
| cluster_device_attributes.reserve( |
| request->cluster_device_attributes().size()); |
| for (const auto& cluster_device : request->cluster_device_attributes()) { |
| cluster_device_attributes.push_back(cluster_device); |
| } |
| |
| auto* r = env_->rendezvous_mgr->Find(request->context_id()); |
| auto session_name = |
| tensorflow::strings::StrCat("eager_", request->context_id()); |
| TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession( |
| session_name, request->server_def(), request->cluster_device_attributes(), |
| true)); |
| |
| std::shared_ptr<WorkerSession> worker_session; |
| TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( |
| session_name, &worker_session)); |
| |
| tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); |
| |
| // Initialize remote tensor communication based on worker session. |
| TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); |
| |
| std::function<Rendezvous*(const int64)> rendezvous_creator = |
| [worker_session, this](const int64 step_id) { |
| auto* r = env_->rendezvous_mgr->Find(step_id); |
| r->Initialize(worker_session.get()).IgnoreError(); |
| return r; |
| }; |
| |
| LOG(INFO) << "Creating " << (request->async() ? "async" : "sync") |
| << " eager service context with rendezvous_id on host " |
| << port::Hostname(); |
| tensorflow::EagerContext* ctx = new tensorflow::EagerContext( |
| SessionOptions(), |
| tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, |
| tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(), |
| device_mgr, false, r, GetDefaultCustomKernelCreator(), |
| worker_session->cluster_flr()); |
| // Ownership will be transferred to the ServerContext, or else in an error |
| // case ctx will be deleted by this unref. |
| core::ScopedUnref unref_ctx(ctx); |
| |
| std::vector<string> remote_workers; |
| worker_session->worker_cache()->ListWorkers(&remote_workers); |
| remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(), |
| worker_session->worker_name()), |
| remote_workers.end()); |
| |
| std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers; |
| TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache( |
| &remote_eager_workers)); |
| |
| auto remote_mgr = |
| absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx); |
| Status s = ctx->InitializeRemoteWorker( |
| std::move(remote_eager_workers), worker_session->remote_device_mgr(), |
| remote_workers, request->context_id(), std::move(rendezvous_creator), |
| std::move(remote_mgr)); |
| if (!s.ok()) { |
| VLOG(1) << "EagerContext::InitializeRemoteWorker failed with " |
| << s.ToString(); |
| return s; |
| } |
| |
| std::vector<DeviceAttributes> device_attributes; |
| device_mgr->ListDeviceAttributes(&device_attributes); |
| |
| for (const auto& da : device_attributes) { |
| *response->add_device_attributes() = da; |
| } |
| { |
| mutex_lock l(contexts_mu_); |
| if (contexts_.find(request->context_id()) != contexts_.end()) { |
| return errors::InvalidArgument("EagerService:CreateContext failed. ", |
| "Context id: <", request->context_id(), |
| "> already exists."); |
| } |
| contexts_.emplace(request->context_id(), |
| new ServerContext(ctx, request->keep_alive_secs(), env_)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::CreateMasterContext( |
| const tensorflow::uint64 context_id, EagerContext* context) { |
| { |
| mutex_lock l(contexts_mu_); |
| auto iter = contexts_.find(context_id); |
| if (iter != contexts_.end()) { |
| return errors::InvalidArgument( |
| "EagerService:CreateMasterContext failed. ", "Context id: <", |
| context_id, "> already exists."); |
| } |
| } |
| ServerContext* server_context = |
| ServerContext::CreateMasterContext(context, env_); |
| mutex_lock l(contexts_mu_); |
| contexts_.emplace(context_id, server_context); |
| return Status::OK(); |
| } |
| |
| Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { |
| const tensorflow::Tensor* t = nullptr; |
| |
| // TODO(nareshmodi): This call makes async calls sync calls. Fix this. |
| TF_RETURN_IF_ERROR(handle->Tensor(&t)); |
| |
| t->shape().AsProto(proto); |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::ExecuteOp(const Operation& operation, |
| EagerContext* eager_context, |
| EagerExecutor* eager_executor, |
| QueueResponse* queue_response) { |
| std::unique_ptr<tensorflow::EagerOperation> op; |
| const char* name = operation.name().c_str(); // Shorthand |
| const tensorflow::AttrTypeMap* types; |
| bool is_function = false; |
| TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(name, &types, &is_function)); |
| if (is_function && !eager_context->FindFunctionByName(name)) { |
| return errors::NotFound( |
| "'", name, |
| "' is neither a type of a primitive operation nor a name " |
| "of a function registered in binary running on ", |
| port::Hostname(), |
| ". Make sure the operation or function is " |
| "registered in the binary running in this process."); |
| } |
| op.reset(new tensorflow::EagerOperation(eager_context, name, is_function, |
| types, eager_executor)); |
| |
| TF_RETURN_IF_ERROR(op->SetDeviceName(operation.device().c_str())); |
| |
| { |
| profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal", |
| profiler::TraceMeLevel::kVerbose); |
| for (const auto& remote_handle : operation.inputs()) { |
| tensorflow::TensorHandle* handle; |
| TF_RETURN_IF_ERROR( |
| eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( |
| remote_handle, &handle)); |
| op->AddInput(handle); |
| // Unref handle since it has a ref as an input now. |
| handle->Unref(); |
| } |
| } |
| |
| for (const auto& attr : operation.attrs()) { |
| op->MutableAttrs()->Set(attr.first, attr.second); |
| } |
| |
| int num_retvals = 0; |
| // TODO(nareshmodi): Consider caching this. |
| TF_RETURN_IF_ERROR(GetNumRetvals(eager_context, operation.name(), |
| operation.attrs(), &num_retvals)); |
| |
| absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals); |
| VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id(); |
| TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals)); |
| |
| eager_context->RemoteMgr()->AddOperationOutputs( |
| absl::MakeSpan(retvals.data(), num_retvals), operation.id()); |
| |
| for (int i = 0; i < num_retvals; i++) { |
| TF_RETURN_IF_ERROR( |
| TensorHandleShape(retvals[i], queue_response->add_shape())); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::Enqueue(const EnqueueRequest* request, |
| EnqueueResponse* response, uint64 stream_id) { |
| profiler::TraceMe activity( |
| [&] { |
| return absl::StrCat("EagerService:Enqueue:", request->DebugString()); |
| }, |
| profiler::TraceMeLevel::kInfo); |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| EagerExecutor& executor = |
| stream_id == kInvalidStreamId |
| ? context->Context()->Executor() |
| : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream( |
| stream_id); |
| Status s; |
| for (const auto& item : request->queue()) { |
| auto* queue_response = response->add_queue_response(); |
| if (item.has_operation()) { |
| s = ExecuteOp(item.operation(), context->Context(), &executor, |
| queue_response); |
| } else if (item.has_handle_to_decref()) { |
| auto handle_to_decref = absl::make_unique<RemoteTensorHandleInternal>( |
| item.handle_to_decref()); |
| auto node = absl::make_unique<ClientTensorHandleDeleteNode>( |
| context, std::move(handle_to_decref)); |
| s = context->Context()->Executor().AddOrExecute(std::move(node)); |
| } else { |
| s = SendTensor(item.send_tensor(), context->Context()); |
| } |
| |
| if (!s.ok()) { |
| if (stream_id != kInvalidStreamId) { |
| // TODO(b/138847548): Cleanup the executor when StreamCall is deleted. |
| context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id); |
| } |
| return s; |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, |
| WaitQueueDoneResponse* response) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| if (request->op_id_size() > 0) { |
| return errors::Unimplemented( |
| "EagerServiceImpl::WaitQueueDone is not " |
| "implemented for particular op IDs."); |
| } |
| return context->Context()->Executor().WaitForAllPendingNodes(); |
| } |
| |
| Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, |
| KeepAliveResponse* response) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, |
| CloseContextResponse* response) { |
| VLOG(1) << "Executing EagerService::CloseContext for context " |
| << request->context_id(); |
| ServerContext* context = nullptr; |
| if (!GetServerContext(request->context_id(), &context).ok()) { |
| // Swallow the error here. |
| return Status::OK(); |
| } |
| |
| core::ScopedUnref context_unref(context); |
| |
| mutex_lock l(contexts_mu_); |
| contexts_.erase(request->context_id()); |
| |
| // GetServerContext returns a newly Reffed copy of ServerContext, which is |
| // unreffed by context_unref. Additionally, we need to unref it one time since |
| // we are releasing it from the map. |
| context->Unref(); |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::RegisterFunction( |
| const RegisterFunctionRequest* request, |
| RegisterFunctionResponse* response) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| // If the function is a component of a multi-device function, we only need to |
| // register it locally. |
| return context->Context()->AddFunctionDef(request->function_def(), |
| request->is_component_function()); |
| } |
| |
| Status EagerServiceImpl::SendTensor(const SendTensorRequest* request, |
| SendTensorResponse* response) { |
| ServerContext* context = nullptr; |
| TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); |
| core::ScopedUnref context_unref(context); |
| |
| tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors; |
| for (const auto& tensor_proto : request->tensors()) { |
| Tensor tensor; |
| if (!tensor.FromProto(tensor_proto)) { |
| return errors::InvalidArgument("Unable to parse tensor proto"); |
| } |
| |
| TensorHandle* tensor_handle = nullptr; |
| TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle)); |
| TensorHandle* copied_handle = nullptr; |
| EagerContext* ctx = context->Context(); |
| Device* device; |
| TF_RETURN_IF_ERROR( |
| ctx->FindDeviceFromName(request->device_name().c_str(), &device)); |
| TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, ctx, &ctx->Executor(), |
| device, false, &copied_handle)); |
| tensors.push_back(copied_handle); |
| tensor_handle->Unref(); |
| } |
| |
| context->Context()->RemoteMgr()->AddOperationOutputs(tensors, |
| request->op_id()); |
| |
| return Status::OK(); |
| } |
| |
| Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, |
| EagerContext* eager_context) { |
| tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors; |
| for (const auto& tensor_proto : send_tensor.tensors()) { |
| Tensor tensor; |
| if (!tensor.FromProto(tensor_proto)) { |
| return errors::InvalidArgument("Unable to parse tensor proto"); |
| } |
| |
| TensorHandle* tensor_handle = nullptr; |
| TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle)); |
| TensorHandle* copied_handle = nullptr; |
| Device* device; |
| TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName( |
| send_tensor.device_name().c_str(), &device)); |
| TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context, |
| &eager_context->Executor(), device, |
| false, &copied_handle)); |
| tensors.push_back(copied_handle); |
| tensor_handle->Unref(); |
| } |
| |
| eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id()); |
| |
| return Status::OK(); |
| } |
| |
| tensorflow::Status EagerServiceImpl::GetServerContext( |
| uint64 context_id, ServerContext** server_context) { |
| mutex_lock l(contexts_mu_); |
| auto iter = contexts_.find(context_id); |
| if (iter == contexts_.end()) { |
| *server_context = nullptr; |
| return errors::InvalidArgument(strings::Printf( |
| "Unable to find a context_id matching the specified one " |
| "(%llu). Perhaps the worker was restarted, or the context was GC'd?", |
| context_id)); |
| } |
| |
| *server_context = iter->second; |
| (*server_context)->Ref(); |
| |
| (*server_context)->RecordAccess(); |
| |
| return Status::OK(); |
| } |
| |
| } // namespace eager |
| } // namespace tensorflow |