blob: 2648a2aec1cc2953a9cf119ba08f3f182eab3a81 [file] [log] [blame]
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include <deque>
#include <memory>
#include <unordered_map>
#include <vector>
#include "grpcpp/alarm.h"
#include "grpcpp/server_builder.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
namespace {
// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on
// `this->cq_`.
//
// This macro is invoked one or more times for each RPC method to
// ensure that there are sufficient completion queue entries to
// handle incoming requests without blocking.
//
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
//
// Arguments:
//   method:          RPC name; token-pasted to form `method##Request`,
//                    `method##Response`, `GrpcWorkerMethod::k##method` and
//                    `GrpcWorkerServiceThread::method##Handler`.
//   supports_cancel: forwarded unchanged as the last argument of
//                    `EnqueueRequestForMethod`.
//
// The macro must be expanded inside a GrpcWorkerServiceThread member function:
// it refers to the members `shutdown_mu_`, `is_shutdown_`, `worker_service_`
// and `cq_`. It holds `shutdown_mu_` while enqueueing and enqueues nothing
// once `is_shutdown_` has been set.
#define ENQUEUE_REQUEST(method, supports_cancel) \
do { \
mutex_lock l(shutdown_mu_); \
if (!is_shutdown_) { \
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequestForMethod( \
worker_service_, cq_.get(), \
static_cast<int>(GrpcWorkerMethod::k##method), \
&GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
} \
} while (0)
// Pre-populates the completion queue for one RPC method by invoking
// ENQUEUE_REQUEST(method, supports_cancel) repeatedly.
//
// The repeat count is looked up in `queue_depth_` under the key
// `static_cast<int>(GrpcWorkerMethod::k##method)`; `default_depth` is the
// fallback passed to gtl::FindWithDefault for methods without an entry.
// Note that the lookup is part of the loop condition, so it is re-evaluated
// on every iteration.
//
// Like ENQUEUE_REQUEST, this must be expanded inside a
// GrpcWorkerServiceThread member function.
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel) \
for (int i = 0; \
i < gtl::FindWithDefault(queue_depth_, \
static_cast<int>(GrpcWorkerMethod::k##method), \
default_depth); \
++i) { \
ENQUEUE_REQUEST(method, supports_cancel); \
}
// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests. Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
public:
explicit GrpcWorkerServiceThread(
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
grpc::WorkerService::AsyncService* worker_service)
: worker_(worker),
queue_depth_(queue_depth),
cache_(cache),
worker_service_(worker_service),
is_shutdown_(false) {
cq_ = builder->AddCompletionQueue();
}
void Start() {
thread_.reset(
worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
[this]() { HandleRPCsLoop(); }));
}
void Join() { thread_.reset(); } // Blocks until thread exits
void Shutdown() {
{
mutex_lock lock(shutdown_mu_);
is_shutdown_ = true;
}
cq_->Shutdown();
}
private:
// Add one or more completion queue entries for each worker method, then
// begin servicing requests from the completion queue.
void HandleRPCsLoop() {
// TODO(ncteisen): This may require performance engineering. We can
// change the number of threads, the number of handlers per thread,
// or even decide to specialize certain threads to certain methods.
SETUP_FOR_REQUEST(GetStatus, 1, false);
SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
SETUP_FOR_REQUEST(CleanupAll, 1, false);
SETUP_FOR_REQUEST(RegisterGraph, 1, false);
SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
SETUP_FOR_REQUEST(Logging, 1, false);
SETUP_FOR_REQUEST(Tracing, 1, false);
SETUP_FOR_REQUEST(CompleteGroup, 10, true);
SETUP_FOR_REQUEST(CompleteInstance, 10, true);
SETUP_FOR_REQUEST(GetStepSequence, 10, true);
SETUP_FOR_REQUEST(RecvBuf, 500, true);
SETUP_FOR_REQUEST(RunGraph, 100, true);
SETUP_FOR_REQUEST(CleanupGraph, 100, false);
SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
// TODO(ncteisen): Determine a better policy for enqueuing the
// appropriate number of each request type.
for (int i = 0;
i < gtl::FindWithDefault(
queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
1000);
++i) {
EnqueueRecvTensorRequestRaw();
}
void* tag;
bool ok;
while (cq_->Next(&tag, &ok)) {
UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
CHECK(callback_tag);
callback_tag->OnCompleted(this, ok);
}
}
private:
void Schedule(std::function<void()> f) {
worker_->env()->compute_pool->Schedule(std::move(f));
}
// The following section contains one request handler method per
// RPC. The `FooHandler` method is called (indirectly) by
// `HandleRPCsLoop()` when the next Foo RPC is received. Each
// `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
// and is responsible for requesting the next Foo call by calling
// `ENQUEUE_REQUEST(Foo)`.
template <class RequestMessage, class ResponseMessage>
using WorkerCall =
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
RequestMessage, ResponseMessage>;
// Handle all non-cancellable simple methods with a standard wrapper.
// The boolean `may_block_on_compute_pool` indicates whether or not the
// operation may block on activities (such as op execution) that run on the
// compute pool.
#define HANDLE_CALL(method, may_block_on_compute_pool) \
void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
auto closure = [this, call]() { \
Status s = worker_->method(&call->request, &call->response); \
if (!s.ok()) { \
VLOG(3) << "Bad response from " << #method << ": " << s; \
} \
call->SendResponse(ToGrpcStatus(s)); \
}; \
if ((may_block_on_compute_pool)) { \
worker_->env()->env->SchedClosure(std::move(closure)); \
} else { \
worker_->env()->compute_pool->Schedule(std::move(closure)); \
} \
ENQUEUE_REQUEST(method, false); \
}
HANDLE_CALL(GetStatus, false);
HANDLE_CALL(CreateWorkerSession, false);
HANDLE_CALL(DeleteWorkerSession, true);
HANDLE_CALL(CleanupAll, false);
HANDLE_CALL(RegisterGraph, false);
HANDLE_CALL(DeregisterGraph, false);
HANDLE_CALL(CleanupGraph, false);
HANDLE_CALL(Logging, false);
HANDLE_CALL(Tracing, false);
#undef HANDLE_CALL
void GetStepSequenceHandler(
WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
Schedule([this, call]() {
worker_->GetStepSequenceAsync(
&call->request, &call->response, [call](const Status& s) {
VLOG(3) << "Bad response from GetStepSequence:" << s;
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(GetStepSequence, true);
}
void MarkRecvFinishedHandler(
WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
VLOG(3) << "Clean cache entry for request " << call->request.request_id();
worker_->RemoveCacheEntryForId(call->request.request_id());
call->SendResponse(::grpc::Status::OK);
ENQUEUE_REQUEST(MarkRecvFinished, false);
}
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
ProtoRunGraphRequest* wrapped_request =
new ProtoRunGraphRequest(&call->request);
NonOwnedProtoRunGraphResponse* wrapped_response =
new NonOwnedProtoRunGraphResponse(&call->response);
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
[call, call_opts, wrapped_request,
wrapped_response](const Status& s) {
VLOG(3) << "RunGraph::Done";
if (!s.ok()) {
VLOG(3) << "Bad response from RunGraph:" << s;
}
call->ClearCancelCallback();
delete call_opts;
delete wrapped_request;
delete wrapped_response;
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(RunGraph, true);
}
void RecvTensorHandlerRaw(
WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
worker_->GrpcRecvTensorAsync(
call_opts, &call->request, &call->response,
[call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(3) << "Bad response from RecvTensor:" << s;
}
call->SendResponse(ToGrpcStatus(s));
});
});
EnqueueRecvTensorRequestRaw();
}
void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
worker_->RecvBufAsync(call_opts, &call->request, &call->response,
[call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(3) << "Bad response from RecvBuf:" << s;
}
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(RecvBuf, true);
}
void CompleteGroupHandler(
WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
worker_->CompleteGroupAsync(
call_opts, &call->request, &call->response,
[call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(3) << "Bad response from CompleteGroup:" << s;
}
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(CompleteGroup, true);
}
void CompleteInstanceHandler(
WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
worker_->CompleteInstanceAsync(
call_opts, &call->request, &call->response,
[call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(3) << "Bad response from CompleteInstance:" << s;
}
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(CompleteInstance, false);
}
#undef ENQUEUE_REQUEST
void EnqueueRecvTensorRequestRaw() {
mutex_lock l(shutdown_mu_);
if (!is_shutdown_) {
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
RecvTensorRequest, ::grpc::ByteBuffer>::
EnqueueRequestForMethod(
worker_service_, cq_.get(),
static_cast<int>(GrpcWorkerMethod::kRecvTensor),
&GrpcWorkerServiceThread::RecvTensorHandlerRaw,
true /* supports cancel*/);
}
}
GrpcWorker* const worker_ = nullptr; // Not owned.
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
std::unique_ptr<Thread> thread_;
std::unordered_map<int, int> queue_depth_;
GrpcResponseCache* cache_;
grpc::WorkerService::AsyncService* const worker_service_;
mutex shutdown_mu_;
bool is_shutdown_ TF_GUARDED_BY(shutdown_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
};
// AsyncServiceInterface implementation that registers the worker gRPC service
// with a server builder and fans request handling out over a fixed set of
// GrpcWorkerServiceThreads.
class GrpcWorkerService : public AsyncServiceInterface {
 public:
  // Registers the async service with `builder` and creates
  // `options.num_serving_threads` serving threads.  The threads are created
  // here but not started until HandleRPCsLoop() runs.
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);
    for (int thread_index = 0; thread_index < options.num_serving_threads;
         ++thread_index) {
      threads_.push_back(std::make_unique<GrpcWorkerServiceThread>(
          worker, builder, options.queue_depth, cache_.get(),
          &worker_service_));
    }
  }

  // Shuts down every serving thread.  Only the first call does any work;
  // later calls return immediately.
  void Shutdown() override {
    {
      mutex_lock l(service_shutdown_mu_);
      if (is_shutdown_) {
        return;
      }
      LOG(INFO) << "Shutting down GrpcWorkerService.";
      is_shutdown_ = true;
    }
    // The per-thread shutdowns happen outside the service-level lock.
    for (const auto& serving_thread : threads_) {
      serving_thread->Shutdown();
    }
  }

  // This method blocks forever handling requests from the completion queue.
  void HandleRPCsLoop() override {
    // Start every thread before joining any of them.
    for (const auto& serving_thread : threads_) {
      serving_thread->Start();
    }
    for (const auto& serving_thread : threads_) {
      serving_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
  std::unique_ptr<GrpcResponseCache> cache_;

  mutex service_shutdown_mu_;
  bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);
  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};
} // namespace
// Builds a worker on top of `worker_env`.
//
// The RecvBuf chunk size is taken from
// `config.experimental().recv_buf_max_chunk()`:
//   == 0  -> 4096
//    > 0  -> the configured value
//    < 0  -> 0
// The response cache is enabled when `rpc_options().cache_rpc_response()` is
// set.
GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
    : Worker(worker_env),
      recv_buf_max_chunk_(
          config.experimental().recv_buf_max_chunk() == 0
              ? 4096
              : (config.experimental().recv_buf_max_chunk() > 0
                     ? config.experimental().recv_buf_max_chunk()
                     : 0)) {
  const bool cache_responses = config.rpc_options().cache_rpc_response();
  if (cache_responses) {
    EnableResponseCache();
  }
}
void GrpcWorker::EnableResponseCache() {
VLOG(3) << "Enabling gRPC tensor response cache.";
response_cache_ = absl::make_unique<GrpcResponseCache>();
}
// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for a response object, to avoid extra protocol buffer serialization
// overhead we generate our response directly into a ::grpc::ByteBuffer object
//
// Args:
//   opts:     call options; a cancel callback that aborts the step is
//             installed while the rendezvous is pending.
//   request:  the RecvTensor request.  The pointer is captured by the
//             rendezvous callback, so it must stay valid until `done` runs.
//   response: buffer that receives the encoded tensor when the status is OK.
//   done:     invoked with the final status.
void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                     const RecvTensorRequest* request,
                                     ::grpc::ByteBuffer* response,
                                     StatusCallback done) {
  VLOG(3) << "GrpcRecvTensorAsync req: " << request->DebugString();
  const int64_t request_id = request->request_id();
  const int64_t step_id = request->step_id();

  // The cache is bypassed for requests whose request_id is 0.
  const bool cache_enabled =
      (response_cache_ != nullptr && request_id != 0);

  // Encodes the tensor into `response` (on success only) and completes the
  // RPC.
  auto do_response = [response, done, cache_enabled](const Tensor& tensor,
                                                     bool is_dead,
                                                     const Status& status) {
    if (status.ok()) {
      grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response);
    }
    done(status);
  };

  // If response cache is enabled and the response cache already contains the
  // request, we delegate this retry request to the response cache. Otherwise,
  // we add the request to the response cache and start the computation to
  // retrieve the requested data.
  if (cache_enabled &&
      response_cache_->QueueRequest(request_id, step_id, do_response)) {
    return;
  }

  auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                             const Tensor& tensor, bool is_dead,
                             const Status& status) {
    if (cache_enabled) {
      // Data is ready. Process all pending requests in the response cache.
      response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
    } else {
      do_response(tensor, is_dead, status);
    }
  };

  // Only called synchronously below, so capturing by reference is safe.
  auto fail = [&rendezvous_done](const Status& status) {
    rendezvous_done(Tensor(), false, status);
  };

  Status s = recent_request_ids_.TrackUnique(
      request_id, "RecvTensor (GrpcWorker)", *request);
  if (!s.ok()) {
    fail(s);
    return;
  }

  const string& key = request->rendezvous_key();
  // `%lld` requires a `long long`; int64_t is not guaranteed to be that type,
  // so convert explicitly.
  TRACEPRINTF("RecvTensor: %lld %s", static_cast<long long>(step_id),
              key.c_str());
  Rendezvous::ParsedKey parsed;
  s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    s = PrepareRecvTensor(parsed, &src_dev);
  }
  if (!s.ok()) {
    fail(s);
    return;
  }

  // Request the tensor associated with the rendezvous key.
  // Any time while waiting for the tensor to be produced, up until the start of
  // execution of the callback lambda body below, an RPC cancellation should
  // abort the rendezvous.
  // Note that gRPC can generate cancellations in response to transient network
  // failures, and the client might not observe any errors or cancellations but
  // simply waits for the responses. Aborting the step would report an error to
  // the client, and avoid permanent hanging in distributed function execution.
  opts->SetCancelCallback([this, step_id]() {
    LOG(WARNING) << "RecvTensor cancelled for " << step_id;
    AbortStep(step_id);
  });
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [opts, rendezvous_done, src_dev, request](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& val,
          const bool is_dead) {
        opts->ClearCancelCallback();

        // A staging copy is made only when the rendezvous succeeded, the
        // source device reports accelerator device info, and the value was
        // not sent with the on_host allocation attribute.
        const bool needs_host_copy =
            status.ok() &&
            src_dev->tensorflow_accelerator_device_info() != nullptr &&
            !send_args.alloc_attrs.on_host();
        if (!needs_host_copy) {
          rendezvous_done(val, is_dead, status);
          return;
        }

        DeviceContext* send_dev_context = send_args.device_context;
        AllocatorAttributes alloc_attrs;
        alloc_attrs.set_gpu_compatible(true);
        alloc_attrs.set_on_host(true);
        Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
        // Raw pointer because `copy_ready` is a copyable StatusCallback; the
        // tensor is released inside that callback.
        Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
        CHECK(send_dev_context)
            << "send dev name: " << src_dev->name() << " gpu_info: "
            << src_dev->tensorflow_accelerator_device_info();
        // "val" is on an accelerator device. Uses the device_context to
        // fill the copy on host.
        StatusCallback copy_ready = [rendezvous_done, copy,
                                     is_dead](const Status& s) {
          // The value is now ready to be returned on the wire.
          rendezvous_done(*copy, is_dead, s);
          delete copy;
        };
        CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
                         src_dev, copy, send_dev_context, copy_ready);
      });
}
namespace {
// Packs the bytes of `tensor` into `response`'s transport_options as a
// RecvBufRespExtra whose tensor_content is a sequence of strings.
//
// Rationale: if RecvBufRespExtra.tensor_content is a single large string,
// then gRPC can stall on the recv side when the string buffer needs to be
// enlarged, since the size is not sent in advance.  Splitting the payload
// into a sequence of small strings costs some extra time on the send side
// (otherwise unnecessary copies), but it improves runtime overall by
// improving flow control.  Best performance is likely achieved with a
// max_chunk_bytes equal to the memory page size.
//
// A `max_chunk_bytes` that is not positive disables splitting: all bytes are
// then emitted as one chunk.
//
// TODO(tucker): When proto3 supports [ctype=CORD] then change
// RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
// and remove this function.
void SetTensorInRecvBufResp(int64_t max_chunk_bytes, const Tensor* tensor,
                            RecvBufResponse* response) {
  RecvBufRespExtra extra;
  const char* cursor = reinterpret_cast<const char*>(DMAHelper::base(tensor));
  for (int64_t remaining = tensor->TotalBytes(); remaining > 0;) {
    // Take everything that is left unless a positive chunk limit is smaller.
    int64_t chunk_bytes = remaining;
    if (max_chunk_bytes > 0 && max_chunk_bytes < chunk_bytes) {
      chunk_bytes = max_chunk_bytes;
    }
    extra.add_tensor_content(std::string(cursor, chunk_bytes));
    cursor += chunk_bytes;
    remaining -= chunk_bytes;
  }
  response->mutable_transport_options()->PackFrom(extra);
}
} // namespace
// Serves a RecvBuf RPC: consumes the buffer registered under
// `request->buf_rendezvous_key()` from the step's BufRendezvous, copies it to
// a CPU tensor when the producer's value is not already on the host, and
// packs the bytes into `response`.
//
// Args:
//   opts:     call options (not used by this implementation).
//   request:  the RecvBuf request.  The pointer is captured by the consumer
//             callback, so it must stay valid until `done` runs.
//   response: filled with the tensor bytes (on success), the send start time
//             and the require_ack flag.
//   done:     invoked with the final status.
void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                              RecvBufResponse* response, StatusCallback done) {
  const int64_t request_id = request->request_id();
  const int64_t step_id = request->step_id();

  // The cache is bypassed for requests whose request_id is 0.
  const bool cache_enabled =
      (response_cache_ != nullptr && request_id != 0);

  // Fills in `response` and completes the RPC.  The timestamp and ack flag
  // are set even when the status is an error.
  auto do_response = [this, response, done, cache_enabled](
                         const Tensor& tensor, bool is_dead,
                         const Status& status) {
    if (status.ok()) {
      SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
    }
    response->set_send_start_micros(env_->env->NowMicros());
    response->set_require_ack(cache_enabled);
    done(status);
  };

  // If response cache is enabled and the response cache already contains the
  // request, we delegate this retry request to the response cache. Otherwise,
  // we add the request to the response cache and start the computation to
  // retrieve the requested data.
  if (cache_enabled &&
      response_cache_->QueueRequest(request_id, step_id, do_response)) {
    return;
  }

  auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                             const Tensor& tensor, const Status& status) {
    if (cache_enabled) {
      // Data is ready. Process all pending requests in the response cache.
      response_cache_->OnRequestFinished(request_id, tensor, false, status);
    } else {
      do_response(tensor, false, status);
    }
  };

  // Only called synchronously below, so capturing by reference is safe.
  auto fail = [&rendezvous_done](const Status& status) {
    rendezvous_done(Tensor(), status);
  };

  // This is a generic, low performance implementation appropriate for grpc.
  Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
                                             *request);
  if (!s.ok()) {
    fail(s);
    return;
  }

  CollectiveExecutor::Handle ce_handle(
      env_->collective_executor_mgr->FindOrCreate(step_id), true);
  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
  auto consumer_callback = [this, request, rendezvous_done](
                               const Status& status,
                               BufRendezvous::Hook* hook) {
    Status s = status;
    if (s.ok()) {
      if (hook == nullptr) {
        s = errors::Internal("Invalid null hook for key ",
                             request->buf_rendezvous_key());
      } else if (!DMAHelper::CanUseDMA(hook->prod_value)) {
        // `else if` matters: a null hook must not be dereferenced here.
        s = errors::Internal("Tensor value for key ",
                             request->buf_rendezvous_key(),
                             " is not of a type supported by RecvBuf");
      }
    } else if (hook != nullptr) {
      LOG(ERROR) << "Got hook " << hook << " with status " << s
                 << " from ConsumeBuf";
    }

    if (s.ok()) {
      // The RPC source tensor needs to be in CPU RAM.  If not already
      // there make a copy using memory appropriate to the purpose.
      const size_t num_bytes = hook->prod_value->TotalBytes();
      const bool on_host =
          hook->prod_dev->attributes().device_type() == "CPU" ||
          hook->prod_attr.on_host();
      if ((!on_host) && (num_bytes > 0)) {
        Device* cpu_dev = nullptr;
        s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
        if (s.ok()) {
          AllocatorAttributes cpu_attr;
          cpu_attr.set_gpu_compatible(true);
          cpu_attr.set_nic_compatible(true);
          profiler::ScopedMemoryDebugAnnotation op_annotation(
              "GrpcWorker::RecvBufAsync::consumer_callback", request->step_id(),
              "dynamic", hook->prod_value->dtype(),
              [hook]() { return hook->prod_value->shape().DebugString(); });
          // Released in the copy-completion callback below.
          Tensor* cpu_tensor =
              new Tensor(cpu_dev->GetAllocator(cpu_attr),
                         hook->prod_value->dtype(), hook->prod_value->shape());
          hook->prod_ctx->CopyDeviceTensorToCPU(
              hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
              [hook, cpu_tensor, rendezvous_done](const Status& s) {
                rendezvous_done(*cpu_tensor, s);
                BufRendezvous::DoneWithHook(hook);
                delete cpu_tensor;
              });
          return;
        }
        // A failed device lookup falls through and is reported below.
      }
    }

    if (hook == nullptr) {
      rendezvous_done(Tensor(), s);
    } else {
      rendezvous_done(*hook->prod_value, s);
      BufRendezvous::DoneWithHook(hook);
    }
  };
  rma->buf_rendezvous()->ConsumeBuf(
      request->buf_rendezvous_key(), request->src_device(),
      request->src_incarnation(), consumer_callback,
      /*cancellation_manager=*/nullptr);
}
// Applies the logging controls carried by `request` to the session manager
// and copies the requested step logs into `response`.  `done` is always
// invoked with an OK status, including when there is no environment or
// session manager to act on.
void GrpcWorker::LoggingAsync(const LoggingRequest* request,
                              LoggingResponse* response, StatusCallback done) {
  auto worker_env = this->env();
  auto session_mgr = worker_env ? worker_env->session_mgr : nullptr;
  if (session_mgr) {
    const bool enable_requested = request->enable_rpc_logging();
    if (enable_requested) {
      session_mgr->SetLogging(true);
    }

    // NOTE(mrry): Handle old masters that disable RPC logging by setting
    // `request->enable_rpc_logging` to `false`.
    const bool legacy_disable =
        !enable_requested && request->fetch_step_id_size() == 0;
    if (request->disable_rpc_logging() || legacy_disable) {
      session_mgr->SetLogging(false);
    }

    for (const auto& wanted_step : request->fetch_step_id()) {
      session_mgr->RetrieveLogs(wanted_step, response);
    }

    if (request->clear()) {
      session_mgr->ClearLogs();
    }
  }
  done(OkStatus());
}
// Drops this step's response-cache entries (if a cache is installed) and then
// delegates to the base Worker implementation.
void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
                                   CleanupGraphResponse* response,
                                   StatusCallback done) {
  if (response_cache_ != nullptr) {
    // Stale entries for this step can remain if a worker crashes before
    // acking a request, so clear them out here.
    const auto finished_step = request->step_id();
    response_cache_->CleanEntriesForStep(finished_step);
  }
  Worker::CleanupGraphAsync(request, response, done);
}
// Accessor for the worker environment pointer held in `env_`.
WorkerEnv* GrpcWorker::env() {
  return this->env_;
}
// Erases `request_id` from the response cache.  Does nothing when no response
// cache is installed.
void GrpcWorker::RemoveCacheEntryForId(int64_t request_id) {
  if (response_cache_ == nullptr) {
    return;
  }
  response_cache_->EraseRequestId(request_id);
}
// Factory: constructs a GrpcWorker from `env` and `config` and returns it to
// the caller as an owning pointer.
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                          const ConfigProto& config) {
  std::unique_ptr<GrpcWorker> grpc_worker(new GrpcWorker(env, config));
  return grpc_worker;
}
// Factory: constructs a GrpcWorkerService for `worker`, registering it with
// `builder`, and returns it through the AsyncServiceInterface base type.
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  std::unique_ptr<AsyncServiceInterface> service(
      new GrpcWorkerService(worker, builder, options));
  return service;
}
} // namespace tensorflow