| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" |
| |
| #include "grpcpp/generic/generic_stub.h" |
| #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" |
| #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" |
| #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" |
| #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" |
| #include "tensorflow/core/lib/core/refcount.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/protobuf/eager_service.pb.h" |
| #include "tensorflow/core/util/env_var.h" |
| |
| namespace tensorflow { |
| namespace eager { |
| namespace { |
| |
| /* |
| * Setting environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" to |
| * true will turn on asynchronous execution of remote op. It means that when |
| * executing an op on a remote worker, client will not block on waiting |
| * for the response anymore. Using follow code as example: |
| * |
| * with tf.device('worker:0'): |
| * a = tf.matmul(...) |
| * b = tf.matmul(...) |
| * logging.into('Requests sent') # Probably not executed yet |
| * logging.info('b: %s', b.numpy()) # Block until 'b' finished. |
| * |
| * Streaming RPC will preserve order as well. So 'a' must be executed before |
| * 'b' on 'worker:0'. |
| * |
| * When turning on this feature, you should explicitly wait for some result |
| * from remote workers at the end of you python program. Otherwise, client may |
| * shutdown remote workers without waiting all pending ops. |
| * |
| * TODO(fishx): When exiting client, make sure all pending ops on remote workers |
| * are finished. |
| * |
| * TODO(b/139210648): Move this comment to eager/execute.py when this feature is |
| * on by default. |
| */ |
| bool EnableStreaming() { |
| bool result; |
| TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE", |
| true, &result)); |
| return result; |
| } |
| |
// Ref-counted thread to handle callbacks for completed requests on a gRPC
// completion queue. The thread might be shared by multiple eager clients, and
// each of them holds a reference to ensure that the thread outlives the
// clients.
// To ensure that every tag in the completion queue is processed, this thread
// also holds a reference to itself and only exits once the ref count has
// dropped to one.
| class GrpcEagerClientThread : public core::RefCounted { |
| public: |
| GrpcEagerClientThread() { |
| // Hold a reference to ensure every completion tag gets processed. |
| Ref(); |
| thread_.reset(Env::Default()->StartThread( |
| ThreadOptions(), "eager_client_thread", [this]() { |
| void* tag; |
| bool ok; |
| while (completion_queue_.Next(&tag, &ok)) { |
| VLOG(4) << "GrpcEagerClientThread got next tag"; |
| GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag); |
| callback_tag->OnCompleted(ok); |
| VLOG(4) << "GrpcEagerClientThread blocking for next tag"; |
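            // All clients have released their references; only the thread's
            // own self-reference remains, so it is safe to stop polling.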
| if (RefCountIsOne()) { |
| break; |
| } |
| } |
| VLOG(4) << "GrpcEagerClientThread exiting"; |
| completion_queue_.Shutdown(); |
| // `this` holds the final reference so cannot directly Unref here. |
| // Instead, schedule a separate thread to clean it up. |
| Env::Default()->SchedClosure([this]() { this->Unref(); }); |
| })); |
| } |
| |
| ~GrpcEagerClientThread() override {} |
| |
| ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; } |
| |
| private: |
| ::grpc::CompletionQueue completion_queue_; |
| std::unique_ptr<Thread> thread_; |
| }; |
| |
| class GrpcEagerClient : public EagerClient { |
| public: |
| GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel, |
| GrpcEagerClientThread* thread) |
| : stub_(channel), thread_(thread) { |
    // Hold a reference to make sure the corresponding GrpcEagerClientThread
    // outlives the client.
| thread_->Ref(); |
| cq_ = thread->completion_queue(); |
| } |
| ~GrpcEagerClient() override { thread_->Unref(); } |
| |
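// Generates the unary async client method for `method`. For example,
// CLIENT_METHOD(Enqueue) expands to an EnqueueAsync() that issues
// /tensorflow.eager.EagerService/Enqueue and invokes `done` with the final
// status once the call completes.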
| #define CLIENT_METHOD(method) \ |
| void method##Async(const method##Request* request, \ |
| method##Response* response, StatusCallback done) \ |
| override { \ |
| StatusCallback done_wrapped = callback_wrapper(std::move(done)); \ |
| new RPCState<protobuf::Message>( \ |
| &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \ |
| response, std::move(done_wrapped), /*call_opts=*/nullptr, \ |
| /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true); \ |
| } |
| |
| CLIENT_METHOD(CreateContext); |
| CLIENT_METHOD(UpdateContext); |
| CLIENT_METHOD(Enqueue); |
| CLIENT_METHOD(WaitQueueDone); |
| CLIENT_METHOD(KeepAlive); |
| |
| #undef CLIENT_METHOD |
| |
| void CloseContextAsync(const CloseContextRequest* request, |
| CloseContextResponse* response, |
| StatusCallback done) override { |
| StatusCallback done_wrapped = callback_wrapper(std::move(done)); |
| new RPCState<protobuf::Message>( |
| &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request, |
| response, std::move(done_wrapped), /*call_opts=*/nullptr, |
| /*threadpool=*/nullptr); |
| |
| VLOG(1) << "Sending RPC to close remote eager context " |
| << request->DebugString(); |
| |
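    // Cancel any pending streaming enqueue call for this context and drop its
    // dispatcher; the context is going away, so no more requests will follow.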
| mutex_lock l(mu_); |
    auto it = enqueue_dispatchers_.find(request->context_id());
| if (it != enqueue_dispatchers_.end()) { |
| it->second.CancelCall(); |
| enqueue_dispatchers_.erase(it); |
| } else if (EnableStreaming()) { |
| LOG(ERROR) << "Remote EagerContext with id " << request->context_id() |
| << " does not seem to exist."; |
| } |
| } |
| |
| void StreamingEnqueueAsync(const EnqueueRequest* request, |
| EnqueueResponse* response, |
| StatusCallback done) override { |
| StatusCallback done_wrapped = callback_wrapper(std::move(done)); |
| if (EnableStreaming()) { |
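      // Lazily create one streaming dispatcher per context id; all enqueues
      // for the same context share a single ordered stream. An exclusive lock
      // is required because we may mutate the map below.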
      mutex_lock l(mu_);
| auto it = enqueue_dispatchers_.find(request->context_id()); |
| if (it == enqueue_dispatchers_.end()) { |
| auto it_and_bool = enqueue_dispatchers_.emplace( |
| std::piecewise_construct, |
| std::forward_as_tuple(request->context_id()), |
| std::forward_as_tuple( |
| &stub_, cq_, |
| "/tensorflow.eager.EagerService/StreamingEnqueue")); |
| it = it_and_bool.first; |
| } |
| it->second.SendNextRequest(*request, response, std::move(done_wrapped)); |
| } else { |
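      // Streaming is disabled: fall back to a regular enqueue and block until
      // the response arrives, preserving synchronous semantics.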
| Notification n; |
| Status status; |
| EnqueueAsync(request, response, [&n, &status](const Status& s) { |
| status.Update(s); |
| n.Notify(); |
| }); |
| n.WaitForNotification(); |
| done_wrapped(status); |
| } |
| } |
| |
| private: |
| ::grpc::GenericStub stub_; |
| const GrpcEagerClientThread* thread_; |
| |
| ::grpc::CompletionQueue* cq_; |
| |
| mutable mutex mu_; |
| |
| std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>> |
| enqueue_dispatchers_ GUARDED_BY(mu_); |
| |
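  // Wraps `done` so the client holds a reference to itself until the callback
  // has run, preventing the client from being destroyed while an RPC is still
  // in flight.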
| StatusCallback callback_wrapper(StatusCallback done) { |
| Ref(); |
| return [this, done = std::move(done)](const Status& status) { |
| done(status); |
| this->Unref(); |
| }; |
| } |
| }; |
| |
| class GrpcEagerClientCache : public EagerClientCache { |
| public: |
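  // Starts a fixed pool of completion-queue polling threads (4 here). Clients
  // are assigned to threads round-robin, sticky per target; see
  // AssignClientToThread() below.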
| explicit GrpcEagerClientCache( |
| std::shared_ptr<tensorflow::GrpcChannelCache> cache) |
| : next_round_robin_assignment_(0), cache_(cache), threads_(4) { |
    for (size_t i = 0; i < threads_.size(); i++) {
| threads_[i].reset(new GrpcEagerClientThread()); |
| } |
| } |
| |
| ~GrpcEagerClientCache() override { threads_.clear(); } |
| |
| Status GetClient(const string& target, |
| core::RefCountPtr<EagerClient>* client) override { |
| auto it = clients_.find(target); |
| if (it == clients_.end()) { |
| tensorflow::SharedGrpcChannelPtr shared = |
| cache_->FindWorkerChannel(target); |
| if (shared == nullptr) { |
| return errors::InvalidArgument("Client for target ", target, |
| " not found."); |
| } |
| int assigned_index = AssignClientToThread(target); |
| GrpcEagerClientThread* thread = threads_[assigned_index].get(); |
| auto worker = new GrpcEagerClient(shared, thread); |
| it = clients_.emplace(target, worker).first; |
| } |
| |
| it->second->Ref(); |
| client->reset(it->second.get()); |
| return Status::OK(); |
| } |
| |
| private: |
| mutex assignment_mu_; |
| std::unordered_map<std::string, size_t> target_assignments_ |
| GUARDED_BY(assignment_mu_); |
| size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_); |
| |
| size_t AssignClientToThread(const string& target) { |
    // Round-robin target assignment, but always keep the same target on the
    // same polling thread, as this is important for gRPC performance.
| mutex_lock lock(assignment_mu_); |
| auto it = target_assignments_.find(target); |
| if (it == target_assignments_.end()) { |
| it = target_assignments_ |
| .insert(std::make_pair( |
| target, (next_round_robin_assignment_++) % threads_.size())) |
| .first; |
| } |
| return it->second; |
| } |
| |
| std::shared_ptr<tensorflow::GrpcChannelCache> cache_; |
| std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_; |
| std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_; |
| }; |
| |
| } // namespace |
| |
| EagerClientCache* NewGrpcEagerClientCache( |
| std::shared_ptr<tensorflow::GrpcChannelCache> channel) { |
| return new GrpcEagerClientCache(channel); |
| } |
| |
| } // namespace eager |
| } // namespace tensorflow |