blob: 7850ecc46b2244cdd31ade3a466406dd99319bcb [file] [log] [blame]
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
namespace tensorflow {
Worker::Worker(WorkerEnv* env)
    : env_(env),
      // Capacity of the recent-request-id tracker (presumably the number of
      // request ids it remembers for duplicate detection).
      recent_request_ids_(100000) {
  // Turn on log-history collection in StatusGroup: with it enabled, recent
  // warning/error log lines get attached to the root error status and are
  // therefore forwarded to the master along with it.
  StatusGroup::ConfigureLogHistory();
}
void Worker::GetStatusAsync(const GetStatusRequest* request,
GetStatusResponse* response, bool fail_fast,
StatusCallback done) {
DeviceMgr* dm = env_->device_mgr;
std::vector<DeviceAttributes> devices;
dm->ListDeviceAttributes(&devices);
response->mutable_device_attributes()->Reserve(devices.size());
for (auto& d : devices) {
response->add_device_attributes()->Swap(&d);
}
done(Status::OK());
}
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response,
StatusCallback done) {
Status s = env_->session_mgr->CreateSession(
request->session_handle(), request->server_def(),
request->cluster_device_attributes(), request->isolate_session_state());
done(s);
}
void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  // Removes the session named in the request via the SessionMgr and reports
  // the resulting status; `opts` and `response` are unused.
  done(env_->session_mgr->DeleteSession(request->session_handle()));
}
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) {
std::shared_ptr<WorkerSession> session;
Status s;
if (request->create_worker_session_called()) {
s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
&session);
} else {
session = env_->session_mgr->LegacySession();
}
if (s.ok()) {
s = session->graph_mgr()->Register(
request->session_handle(), request->graph_def(), session.get(),
request->graph_options(), request->debug_options(),
request->config_proto(), request->collective_graph_key(),
session->cluster_flr(), response->mutable_graph_handle());
}
done(s);
}
void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) {
std::shared_ptr<WorkerSession> session;
Status s;
if (request->create_worker_session_called()) {
s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
&session);
} else {
session = env_->session_mgr->LegacySession();
}
if (s.ok()) {
s = session->graph_mgr()->Deregister(request->graph_handle());
}
done(s);
}
void Worker::AbortStep(int64 step_id) {
  // The scheduled closure below takes over the reference obtained from Find()
  // and releases it with Unref() once the abort has been issued.
  Rendezvous* step_rendezvous = env_->rendezvous_mgr->Find(step_id);
  // Deliberately delay the abort so that the step's root-cause error can reach
  // the client before this cancellation-generated Aborted error does.
  // (Presumably microseconds, i.e. about one second.)
  constexpr int64 kAbortDelay = 1000000;
  SchedNonBlockingClosureAfter(kAbortDelay, [step_rendezvous, step_id]() {
    const Status abort_status = errors::Aborted(
        "Step ", step_id, " cancelled. Cancelling rendezvous.");
    step_rendezvous->StartAbort(abort_status);
    step_rendezvous->Unref();
  });
}
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  // Shared placeholder inserted for every requested output; callers later
  // hand `out` to RecvOutputs/RecvOutputsAsync to obtain the real values.
  static Tensor empty_tensor(DT_FLOAT);

  // Inputs: one entry per sent tensor, keyed by its send key. Stops at the
  // first SendValue failure and propagates it.
  Tensor send_value;
  for (size_t idx = 0; idx < req->num_sends(); ++idx) {
    TF_RETURN_IF_ERROR(req->SendValue(idx, &send_value));
    in->insert({req->send_key(idx), send_value});
  }

  // Outputs: one placeholder entry per receive key.
  for (size_t idx = 0; idx < req->num_recvs(); ++idx) {
    out->insert({req->recv_key(idx), empty_tensor});
  }
  return Status::OK();
}
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  StatusCallback on_done = std::move(done);
  if (request->store_errors_in_response_body()) {
    // The caller wants errors reported inside the response body: stash the
    // real status there and always report OK at the RPC level.
    StatusCallback forward = std::move(on_done);
    on_done = [response, forward](const Status& status) {
      response->set_status(status);
      forward(Status::OK());
    };
  }
  // Partial runs and full runs take separate execution paths.
  if (!request->is_partial()) {
    DoRunGraph(opts, request, response, std::move(on_done));
  } else {
    DoPartialRunGraph(opts, request, response, std::move(on_done));
  }
}
MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  // The base worker uses the in-memory request representation. The object is
  // not retained here, so the caller owns it.
  MutableRunGraphRequestWrapper* fresh_request = new InMemoryRunGraphRequest;
  return fresh_request;
}
MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  // The base worker uses the in-memory response representation. The object is
  // not retained here, so the caller owns it.
  MutableRunGraphResponseWrapper* fresh_response = new InMemoryRunGraphResponse;
  return fresh_response;
}
// Executes a complete (non-partial) RunGraph request against a graph that was
// previously registered with the session's GraphMgr.
//
// Flow: reject the request if recent_request_ids_.TrackUnique fails, resolve
// the WorkerSession, build the input/output tensor maps, optionally set up a
// step-stats collector and a profiler, wire up cancellation, then run the
// graph asynchronously. On completion the fetched outputs are added to
// `response`. `done` is called once on every path in this function (assuming
// ExecuteAsync invokes its callback once).
//
// Ownership: `out`, `collector`, `profiler_session` and `cm` are allocated on
// the heap. Each early-exit path deletes whichever of them it has created;
// once ExecuteAsync is called, the completion callback deletes them all.
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64 step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  // Fail fast if the request-id tracker rejects this request (by its name,
  // presumably because the request_id was seen recently, e.g. a retried RPC).
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }
  // Use the session looked up by handle if CreateWorkerSession was called,
  // otherwise fall back to the legacy session.
  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  // `in` only needs to live until ExecuteAsync is called. `out` must outlive
  // this frame: the completion callback fills it via RecvOutputs and then
  // deletes it.
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  // Step stats are only collected when an exec option needs them. The
  // collector writes into the response's step_stats.
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  ProfilerSession* profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    profiler_session = ProfilerSession::Create().release();
  }
  // Per-step cancellation manager handed to the executors.
  CancellationManager* cm = new CancellationManager;
  // RPC-level cancellation: cancel this step's executors and (after a delay,
  // see AbortStep) abort its rendezvous.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  // Worker-level cancellation: when cancellation_manager_ is cancelled, this
  // step's cm is cancelled too. RegisterCallback returning false is treated
  // as "the worker was already cancelled", in which case nothing is run and
  // all helpers are freed here.
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector,
       profiler_session, opts, done](const Status& status) {
        // Completion: fetch outputs on success, tear down both cancellation
        // hooks, merge profiler data, copy outputs into the response, then
        // free the heap-allocated helpers and report the final status.
        Status s = status;
        if (s.ok()) {
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }
        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;
        if (profiler_session) {
          // Merge the profiler's step stats into the response. Errors from
          // CollectData are deliberately ignored.
          RunMetadata run_metadata;
          profiler_session->CollectData(&run_metadata).IgnoreError();
          response->mutable_step_stats()->MergeFrom(run_metadata.step_stats());
        }
        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        // Finalize the collector before it is destroyed.
        if (collector) collector->Finalize();
        delete collector;
        delete profiler_session;
        delete out;
        done(s);
      });
}
// TODO(suharshs): Add stats collection support to partial run.
//
// Executes one request belonging to a partial run identified by step_id.
// The first request for a step_id (as reported by partial_run_mgr_.FindOrCreate)
// starts the graph's executors. Later requests feed their new inputs with
// SendInputs. Every request then waits for its requested outputs and adds
// them to `response`. For the last partial run of a step, completion is
// handed to partial_run_mgr_.PartialRunDone; otherwise `done` is reported
// directly via `finish`.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  // Fail fast if the request-id tracker rejects this request (presumably a
  // recently seen request_id).
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }
  // Use the session looked up by handle if CreateWorkerSession was called,
  // otherwise fall back to the legacy session.
  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  // Common exit path once `out` exists: clears the RPC cancel callback,
  // frees `out`, and reports the status.
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }
  // `cm` comes from partial_run_mgr_ and is never deleted in this function
  // (presumably partial_run_mgr_ owns it for the lifetime of the step).
  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);
  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    // Forward worker-level cancellation to this step's cm.
    // NOTE(review): unlike DoRunGraph, the return value of RegisterCallback is
    // ignored here. If cancellation_manager_ has already been cancelled, the
    // callback is presumably not registered and the executors still start.
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          // When the executors finish, unhook worker-level cancellation and
          // report the executor status to partial_run_mgr_.
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }
  // Wait for this request's outputs. `request` and `response` are captured as
  // raw pointers, so they must stay alive until this callback runs.
  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the resp.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        // On the last partial run, partial_run_mgr_ decides when `finish`
        // runs (presumably once the executors have also completed).
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) {
const int64 step_id = request->step_id();
env_->rendezvous_mgr->Cleanup(step_id);
if (env_->collective_executor_mgr) {
env_->collective_executor_mgr->Cleanup(step_id);
}
for (Device* d : env_->local_devices) {
ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
if (sam) {
sam->Cleanup(step_id);
}
}
done(Status::OK());
}
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
CleanupAllResponse* response,
StatusCallback done) {
std::vector<string> containers;
for (const auto& c : request->container()) containers.push_back(c);
env_->device_mgr->ClearContainers(containers);
done(Status::OK());
}
void Worker::LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done) {
done(errors::Unimplemented("Logging"));
}
void Worker::TracingAsync(const TracingRequest* request,
TracingResponse* response, StatusCallback done) {
done(errors::Unimplemented("Tracing"));
}
void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // Intentionally unimplemented in the base Worker: it is not currently used
  // for worker-to-worker communication. Transport-specific subclasses (for
  // example `GrpcWorker::RecvBufAsync()`) provide the real implementation.
  const Status unsupported = errors::Unimplemented("Worker::RecvBufAsync()");
  done(unsupported);
}
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  // Group resolution is delegated to the collective executor manager's param
  // resolver. Without a collective executor manager this RPC cannot be served.
  if (!env_->collective_executor_mgr) {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
    return;
  }
  env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
      request, response, &cancellation_manager_, done);
}
void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  // Instance resolution is delegated to the collective executor manager's
  // param resolver. Without a collective executor manager this RPC cannot be
  // served.
  if (!env_->collective_executor_mgr) {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
    return;
  }
  env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
      request, response, &cancellation_manager_, done);
}
void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  // Step sequences are owned by the collective executor manager. Without one
  // this RPC cannot be served.
  if (!env_->collective_executor_mgr) {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
    return;
  }
  env_->collective_executor_mgr->GetStepSequenceAsync(request, response, done);
}
// Helper for RecvTensor. Looks up the device named by `parsed.src_device`
// (converted to its local name) and stores it in `*src_dev`. Then verifies
// that the device's incarnation matches `parsed.src_incarnation`. A mismatch
// produces an Aborted error suggesting that this worker was restarted.
// Lookup failures from the DeviceMgr are propagated unchanged.
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  const string device_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(device_name, src_dev));

  const auto current_incarnation = (*src_dev)->attributes().incarnation();
  if (current_incarnation == parsed.src_incarnation) {
    return Status::OK();
  }
  return errors::Aborted(
      "RecvTensor expects a different device incarnation: ",
      parsed.src_incarnation, " vs. ", current_incarnation,
      ". Your worker job (\"",
      env_->session_mgr->LegacySession()->worker_name(),
      "\") was probably restarted. Check your "
      "worker job for the reason why it was restarted.");
}
void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // Intentionally unimplemented in the base Worker: it is not currently used
  // for worker-to-worker communication. Transport-specific subclasses (for
  // example `GrpcWorker::RecvTensorAsync()`) provide the real implementation.
  const Status unsupported =
      errors::Unimplemented("Worker::RecvTensorAsync()");
  done(unsupported);
}
} // namespace tensorflow