| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/distributed_runtime/worker.h" |
| |
| #include "tensorflow/core/common_runtime/collective_executor_mgr.h" |
| #include "tensorflow/core/common_runtime/device_mgr.h" |
| #include "tensorflow/core/common_runtime/process_util.h" |
| #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" |
| #include "tensorflow/core/common_runtime/step_stats_collector.h" |
| #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" |
| #include "tensorflow/core/distributed_runtime/tensor_coding.h" |
| #include "tensorflow/core/distributed_runtime/worker_session.h" |
| #include "tensorflow/core/platform/tracing.h" |
| #include "tensorflow/core/profiler/lib/profiler_session.h" |
| |
| namespace tensorflow { |
| |
| Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) { |
| // Enable log history collection in StatusGroup so that recent warning and |
| // error log messages will be attached to the root error status to be |
| // forwarded to the master. |
| StatusGroup::ConfigureLogHistory(); |
| } |
| |
| void Worker::GetStatusAsync(const GetStatusRequest* request, |
| GetStatusResponse* response, StatusCallback done) { |
| DeviceMgr* dm = env_->device_mgr; |
| std::vector<DeviceAttributes> devices; |
| dm->ListDeviceAttributes(&devices); |
| response->mutable_device_attributes()->Reserve(devices.size()); |
| for (auto& d : devices) { |
| response->add_device_attributes()->Swap(&d); |
| } |
| done(Status::OK()); |
| } |
| |
| void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, |
| CreateWorkerSessionResponse* response, |
| StatusCallback done) { |
| Status s = env_->session_mgr->CreateSession( |
| request->session_handle(), request->server_def(), |
| request->cluster_device_attributes(), request->isolate_session_state()); |
| done(s); |
| } |
| |
| void Worker::DeleteWorkerSessionAsync(CallOptions* opts, |
| const DeleteWorkerSessionRequest* request, |
| DeleteWorkerSessionResponse* response, |
| StatusCallback done) { |
| Status s = env_->session_mgr->DeleteSession(request->session_handle()); |
| done(s); |
| } |
| |
| void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, |
| RegisterGraphResponse* response, |
| StatusCallback done) { |
| std::shared_ptr<WorkerSession> session; |
| Status s; |
| if (request->create_worker_session_called()) { |
| s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), |
| &session); |
| } else { |
| session = env_->session_mgr->LegacySession(); |
| } |
| if (s.ok()) { |
| s = session->graph_mgr->Register( |
| request->session_handle(), request->graph_def(), session.get(), |
| request->graph_options(), request->debug_options(), |
| request->collective_graph_key(), session->cluster_flr.get(), |
| response->mutable_graph_handle()); |
| } |
| done(s); |
| } |
| |
| void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, |
| DeregisterGraphResponse* response, |
| StatusCallback done) { |
| std::shared_ptr<WorkerSession> session; |
| Status s; |
| if (request->create_worker_session_called()) { |
| s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), |
| &session); |
| } else { |
| session = env_->session_mgr->LegacySession(); |
| } |
| if (s.ok()) { |
| s = session->graph_mgr->Deregister(request->graph_handle()); |
| } |
| |
| done(s); |
| } |
| |
| void Worker::AbortStep(int64 step_id) { |
| Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id); |
| SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { |
| // Delay a bit before aborting the step. This way, the root |
| // cause may return first back to the client instead of this |
| // cancellation generated abort error. |
| rendez->StartAbort(errors::Aborted("Step ", step_id, |
| " cancelled. Cancelling rendezvous.")); |
| rendez->Unref(); |
| }); |
| } |
| |
| Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req, |
| GraphMgr::NamedTensors* in, |
| GraphMgr::NamedTensors* out) { |
| static Tensor empty_tensor(DT_FLOAT); |
| if (req->num_sends() > 0) { |
| Tensor val; |
| for (size_t i = 0; i < req->num_sends(); ++i) { |
| TF_RETURN_IF_ERROR(req->SendValue(i, &val)); |
| in->insert({req->send_key(i), val}); |
| } |
| } |
| for (size_t i = 0; i < req->num_recvs(); ++i) { |
| out->insert({req->recv_key(i), empty_tensor}); |
| } |
| return Status::OK(); |
| } |
| |
| void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, |
| MutableRunGraphResponseWrapper* response, |
| StatusCallback done) { |
| if (request->store_errors_in_response_body()) { |
| done = [response, done](const Status& status) { |
| response->set_status(status); |
| done(Status::OK()); |
| }; |
| } |
| if (request->is_partial()) { |
| DoPartialRunGraph(opts, request, response, std::move(done)); |
| } else { |
| DoRunGraph(opts, request, response, std::move(done)); |
| } |
| } |
| |
| MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() { |
| return new InMemoryRunGraphRequest; |
| } |
| |
| MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() { |
| return new InMemoryRunGraphResponse; |
| } |
| |
| void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, |
| MutableRunGraphResponseWrapper* response, |
| StatusCallback done) { |
| const int64 step_id = request->step_id(); |
| TRACEPRINTF("RunGraph: %lld", step_id); |
| Status s = recent_request_ids_.TrackUnique(request->request_id(), |
| "RunGraph (Worker)", request); |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| |
| std::shared_ptr<WorkerSession> session; |
| if (request->create_worker_session_called()) { |
| s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), |
| &session); |
| } else { |
| session = env_->session_mgr->LegacySession(); |
| } |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| GraphMgr::NamedTensors in; |
| GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; |
| s = PrepareRunGraph(request, &in, out); |
| if (!s.ok()) { |
| delete out; |
| done(s); |
| return; |
| } |
| StepStatsCollector* collector = nullptr; |
| if (request->exec_opts().report_tensor_allocations_upon_oom() || |
| request->exec_opts().record_timeline() || |
| request->exec_opts().record_costs()) { |
| collector = new StepStatsCollector(response->mutable_step_stats()); |
| } |
| ProfilerSession* profiler_session = nullptr; |
| if (collector && request->exec_opts().record_timeline()) { |
| // If timeline was requested, assume we want hardware level tracing. |
| profiler_session = ProfilerSession::Create().release(); |
| } |
| CancellationManager* cm = new CancellationManager; |
| opts->SetCancelCallback([this, cm, step_id]() { |
| LOG(INFO) << "Cancellation requested for RunGraph."; |
| cm->StartCancel(); |
| AbortStep(step_id); |
| }); |
| CancellationToken token; |
| token = cancellation_manager_.get_cancellation_token(); |
| bool already_cancelled = !cancellation_manager_.RegisterCallback( |
| token, [cm]() { cm->StartCancel(); }); |
| if (already_cancelled) { |
| opts->ClearCancelCallback(); |
| delete cm; |
| delete collector; |
| delete profiler_session; |
| delete out; |
| done(errors::Aborted("Call was aborted")); |
| return; |
| } |
| session->graph_mgr->ExecuteAsync( |
| request->graph_handle(), step_id, session.get(), request->exec_opts(), |
| collector, response, cm, in, |
| [this, step_id, response, session, cm, out, token, collector, |
| profiler_session, opts, done](const Status& status) { |
| Status s = status; |
| if (s.ok()) { |
| s = session->graph_mgr->RecvOutputs(step_id, out); |
| } |
| |
| opts->ClearCancelCallback(); |
| cancellation_manager_.DeregisterCallback(token); |
| delete cm; |
| |
| if (profiler_session) { |
| RunMetadata run_metadata; |
| profiler_session->CollectData(&run_metadata).IgnoreError(); |
| response->mutable_step_stats()->MergeFrom(run_metadata.step_stats()); |
| } |
| |
| if (s.ok()) { |
| for (const auto& p : *out) { |
| const string& key = p.first; |
| const Tensor& val = p.second; |
| response->AddRecv(key, val); |
| } |
| } |
| |
| if (collector) collector->Finalize(); |
| delete collector; |
| delete profiler_session; |
| delete out; |
| done(s); |
| }); |
| } |
| |
| // TODO(suharshs): Add stats collection support to partial run. |
| void Worker::DoPartialRunGraph(CallOptions* opts, |
| RunGraphRequestWrapper* request, |
| MutableRunGraphResponseWrapper* response, |
| StatusCallback done) { |
| const int64 step_id = request->step_id(); |
| const string& graph_handle = request->graph_handle(); |
| TRACEPRINTF("PartialRunGraph: %lld", step_id); |
| Status s = recent_request_ids_.TrackUnique( |
| request->request_id(), "PartialRunGraph (Worker)", request); |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| |
| std::shared_ptr<WorkerSession> session; |
| if (request->create_worker_session_called()) { |
| s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), |
| &session); |
| } else { |
| session = env_->session_mgr->LegacySession(); |
| } |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| |
| GraphMgr::NamedTensors in; |
| GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; |
| s = PrepareRunGraph(request, &in, out); |
| auto finish = [done, out, opts](const Status& s) { |
| opts->ClearCancelCallback(); |
| delete out; |
| done(s); |
| }; |
| if (!s.ok()) { |
| finish(s); |
| return; |
| } |
| |
| CancellationManager* cm = nullptr; |
| bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm); |
| |
| // Before we start doing anything, we set the RPC cancellation. |
| opts->SetCancelCallback([this, cm, step_id]() { |
| LOG(INFO) << "Cancellation requested for PartialRunGraph."; |
| cm->StartCancel(); |
| AbortStep(step_id); |
| }); |
| |
| // If this is a new partial run request, the request will need to start the |
| // executors. |
| if (is_new_partial_run) { |
| CancellationToken token; |
| token = cancellation_manager_.get_cancellation_token(); |
| cancellation_manager_.RegisterCallback(token, |
| [cm]() { cm->StartCancel(); }); |
| session->graph_mgr->ExecuteAsync( |
| graph_handle, step_id, session.get(), request->exec_opts(), |
| nullptr /* collector */, nullptr /* response */, cm, in, |
| [this, token, step_id, session](Status s) { |
| cancellation_manager_.DeregisterCallback(token); |
| partial_run_mgr_.ExecutorDone(step_id, s); |
| }); |
| } else { |
| // Send the partial run's new inputs. |
| s = session->graph_mgr->SendInputs(step_id, in); |
| if (!s.ok()) { |
| finish(s); |
| return; |
| } |
| } |
| |
| session->graph_mgr->RecvOutputsAsync( |
| step_id, out, [this, out, request, response, step_id, finish](Status s) { |
| if (s.ok()) { |
| // Construct and return the resp. |
| for (const auto& p : *out) { |
| const string& key = p.first; |
| const Tensor& val = p.second; |
| response->AddRecv(key, val); |
| } |
| } |
| if (request->is_last_partial_run()) { |
| partial_run_mgr_.PartialRunDone(step_id, finish, s); |
| } else { |
| finish(s); |
| } |
| }); |
| } |
| |
| void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, |
| CleanupGraphResponse* response, |
| StatusCallback done) { |
| const int64 step_id = request->step_id(); |
| env_->rendezvous_mgr->Cleanup(step_id); |
| if (env_->collective_executor_mgr) { |
| env_->collective_executor_mgr->Cleanup(step_id); |
| } |
| for (Device* d : env_->local_devices) { |
| ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr(); |
| if (sam) { |
| sam->Cleanup(step_id); |
| } |
| } |
| done(Status::OK()); |
| } |
| |
| void Worker::CleanupAllAsync(const CleanupAllRequest* request, |
| CleanupAllResponse* response, |
| StatusCallback done) { |
| std::vector<string> containers; |
| for (const auto& c : request->container()) containers.push_back(c); |
| env_->device_mgr->ClearContainers(containers); |
| done(Status::OK()); |
| } |
| |
| void Worker::LoggingAsync(const LoggingRequest* request, |
| LoggingResponse* response, StatusCallback done) { |
| done(errors::Unimplemented("Logging")); |
| } |
| |
| void Worker::TracingAsync(const TracingRequest* request, |
| TracingResponse* response, StatusCallback done) { |
| done(errors::Unimplemented("Tracing")); |
| } |
| |
| void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, |
| RecvBufResponse* response, StatusCallback done) { |
| // The base Worker class does not implement RecvBufAsync because |
| // it is not currently used for worker-to-worker communication. Use a |
| // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`) |
| // instead. |
| done(errors::Unimplemented("Worker::RecvBufAsync()")); |
| } |
| |
| void Worker::CompleteGroupAsync(CallOptions* opts, |
| const CompleteGroupRequest* request, |
| CompleteGroupResponse* response, |
| StatusCallback done) { |
| if (env_->collective_executor_mgr) { |
| env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync( |
| request, response, &cancellation_manager_, done); |
| } else { |
| done( |
| errors::Internal("Runtime not initialized with CollectiveExecutorMgr")); |
| } |
| } |
| |
| void Worker::CompleteInstanceAsync(CallOptions* opts, |
| const CompleteInstanceRequest* request, |
| CompleteInstanceResponse* response, |
| StatusCallback done) { |
| if (env_->collective_executor_mgr) { |
| env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync( |
| request, response, &cancellation_manager_, done); |
| } else { |
| done( |
| errors::Internal("Runtime not initialized with CollectiveExecutorMgr")); |
| } |
| } |
| |
| void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request, |
| GetStepSequenceResponse* response, |
| StatusCallback done) { |
| if (env_->collective_executor_mgr) { |
| env_->collective_executor_mgr->GetStepSequenceAsync(request, response, |
| done); |
| } else { |
| done( |
| errors::Internal("Runtime not initialized with CollectiveExecutorMgr")); |
| } |
| } |
| |
| // Helper for RecvTensor. Validates "key" and returns the source |
| // device in "*src_dev". |
| Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, |
| Device** src_dev) { |
| // Figures out which device the tensor is hosted on. |
| string local_name = DeviceNameUtils::LocalName(parsed.src_device); |
| TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev)); |
| |
| // Does the device have the right incarnation number we expect? |
| if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { |
| return errors::Aborted( |
| "RecvTensor expects a different device incarnation: ", |
| parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), |
| ". Your worker job (\"", |
| env_->session_mgr->LegacySession()->worker_name, |
| "\") was probably restarted. Check your " |
| "worker job for the reason why it was restarted."); |
| } |
| |
| return Status::OK(); |
| } |
| |
| void Worker::RecvTensorAsync(CallOptions* opts, |
| const RecvTensorRequest* request, |
| TensorResponse* response, StatusCallback done) { |
| // The base Worker class does not implement RecvTensorAsync, because |
| // it is not currently used for worker-to-worker communication. Use a |
| // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`) |
| // instead. |
| done(errors::Unimplemented("Worker::RecvTensorAsync()")); |
| } |
| |
| } // namespace tensorflow |