Protect EagerContext on worker side when updating cluster.

When handling worker failures, the failure-handling thread sends update-context requests to all workers while other eager executors may still be sending op/function execution requests. Protecting the context on the worker side removes the need to grab a global lock on the client side to prevent races between concurrent updates and executions.

* Ref-count the eager clients to avoid deallocating them before pending requests finish (see the sketch after this message).
* Hold the context lock on the worker side to avoid executing enqueue ops concurrently with a context update.
* Adjust local device initialization to avoid clearing the _context_devices list, since update_server_def can invoke it multiple times.

PiperOrigin-RevId: 283847202
Change-Id: I3f84d56c44cd2adce5136f7fd4f67313a1da3610
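
To illustrate the first bullet, a minimal, self-contained C++ sketch follows.
`RefCounted` and `RefCountPtr` are simplified stand-ins for the real classes in
tensorflow/core/lib/core/refcount.h, and the cache and client below are
schematic, not the actual TensorFlow API:

    #include <atomic>
    #include <string>
    #include <unordered_map>

    // Simplified stand-in for tensorflow::core::RefCounted.
    class RefCounted {
     public:
      void Ref() { count_.fetch_add(1); }
      void Unref() {
        if (count_.fetch_sub(1) == 1) delete this;
      }

     protected:
      virtual ~RefCounted() = default;

     private:
      std::atomic<int> count_{1};  // Creation gives the creator one reference.
    };

    // Simplified stand-in for tensorflow::core::RefCountPtr.
    template <typename T>
    struct RefCountPtr {
      RefCountPtr() = default;
      RefCountPtr(const RefCountPtr&) = delete;
      RefCountPtr& operator=(const RefCountPtr&) = delete;
      ~RefCountPtr() { reset(nullptr); }
      void reset(T* p) {  // Adopts one reference; releases any previous one.
        if (ptr) ptr->Unref();
        ptr = p;
      }
      T* operator->() const { return ptr; }
      T* ptr = nullptr;
    };

    struct EagerClient : RefCounted {
      void EnqueueAsync() { /* pretend to issue an RPC */ }
    };

    int main() {
      // The client cache holds one reference per client.
      std::unordered_map<std::string, RefCountPtr<EagerClient>> cache;
      cache["/job:worker/task:0"].reset(new EagerClient);

      // GetClient(): hand the caller its own counted reference.
      RefCountPtr<EagerClient> client;
      cache["/job:worker/task:0"]->Ref();
      client.reset(cache["/job:worker/task:0"].ptr);

      // A cluster update replaces the cache and drops its reference; the
      // client survives because the pending caller still holds one.
      cache.clear();
      client->EnqueueAsync();  // Safe: no use-after-free.
      return 0;
    }  // `client` destructs here and drops the last reference.

This is why EagerClient now derives from core::RefCounted, and why every
GetClient caller in the diff below switches from a raw pointer to a
RefCountPtr.
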
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 46ade1b..8793e30 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -233,7 +233,7 @@
   std::vector<tensorflow::eager::KeepAliveResponse> responses(
       existing_workers->size());
   for (int i = 0; i < existing_workers->size(); i++) {
-    tensorflow::eager::EagerClient* eager_client;
+    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
     statuses[i] =
         client_cache->GetClient(existing_workers->at(i), &eager_client);
     if (!statuses[i].ok()) {
@@ -282,7 +282,7 @@
       continue;
     }
 
-    tensorflow::eager::EagerClient* eager_client;
+    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
     statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
     if (eager_client == nullptr) {
       statuses[i] = tensorflow::errors::Internal(
@@ -340,7 +340,7 @@
       continue;
     }
 
-    tensorflow::eager::EagerClient* eager_client;
+    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
     statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
     if (eager_client == nullptr) {
       statuses[i] = tensorflow::errors::Internal(
@@ -819,7 +819,7 @@
   }
 
   // TODO(yuefengz): support partially specified `worker_name`.
-  tensorflow::eager::EagerClient* eager_client;
+  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
   status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
   if (!status->status.ok()) {
     return false;
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index b5b0bce..b8dd8d8 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -273,7 +273,7 @@
 
   int i = 0;
   for (const auto& worker : remote_contexts) {
-    eager::EagerClient* client;
+    core::RefCountPtr<eager::EagerClient> client;
     Status s = remote_eager_workers_->GetClient(worker, &client);
 
     client->CloseContextAsync(
@@ -449,7 +449,7 @@
       register_function->mutable_function_def()->mutable_node_def());
 
   for (const auto& target : remote_contexts_) {
-    eager::EagerClient* eager_client;
+    core::RefCountPtr<eager::EagerClient> eager_client;
     TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
 
     eager::EnqueueResponse* response = new eager::EnqueueResponse();
@@ -475,7 +475,7 @@
   // Register multiple functions on selected remote workers.
   uint64 context_id = GetContextId();
   for (int i = 0; i < remote_workers.size(); i++) {
-    eager::EagerClient* eager_client;
+    core::RefCountPtr<eager::EagerClient> eager_client;
     Status s =
         remote_eager_workers_->GetClient(remote_workers[i], &eager_client);
     if (!s.ok()) {
@@ -649,12 +649,13 @@
 }  // namespace
 
 #if !defined(IS_MOBILE_PLATFORM)
-Status EagerContext::GetClient(Device* device, eager::EagerClient** client) {
+Status EagerContext::GetClient(Device* device,
+                               core::RefCountPtr<eager::EagerClient>* client) {
   return GetClient(device->parsed_name(), client);
 }
 
 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
-                               eager::EagerClient** client) {
+                               core::RefCountPtr<eager::EagerClient>* client) {
   if (remote_eager_workers_ == nullptr) {
     return errors::Internal(
         "Haven't set up remote eager worker in this eager context yet.");
@@ -685,7 +686,7 @@
 }
 
 Status EagerContext::GetClient(const string& remote_task,
-                               eager::EagerClient** client) {
+                               core::RefCountPtr<eager::EagerClient>* client) {
   if (remote_eager_workers_ == nullptr) {
     return errors::Internal(
         "Haven't set up remote eager worker in this eager context yet.");
@@ -934,7 +935,7 @@
                 if (keep_alive_secs_ > 0) {
                   {
                     for (const auto& worker : remote_contexts_) {
-                      eager::EagerClient* client;
+                      core::RefCountPtr<eager::EagerClient> client;
                       Status s =
                           remote_eager_workers_->GetClient(worker, &client);
 
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 4a9606e..93fbd89 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -265,10 +265,18 @@
   FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
 
 #if !defined(IS_MOBILE_PLATFORM)
-  Status GetClient(Device* device, eager::EagerClient** client);
+  // Assign the EagerClient pointer to `client` based on the given device or
+  // task name, and increment the client's refcount. Ownership of the
+  // reference is transferred to the caller, and the unref happens
+  // automatically when the caller's RefCountPtr is destructed.
+  // `client` must not already be initialized or hold a reference to another
+  // object before calling this method.
+  Status GetClient(Device* device,
+                   core::RefCountPtr<eager::EagerClient>* client);
   Status GetClient(const DeviceNameUtils::ParsedName& device_name,
-                   eager::EagerClient** client);
-  Status GetClient(const string& remote_task, eager::EagerClient** client);
+                   core::RefCountPtr<eager::EagerClient>* client);
+  Status GetClient(const string& remote_task,
+                   core::RefCountPtr<eager::EagerClient>* client);
 
   uint64 GetContextId();
   uint64 GetContextViewId();
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 32fdb21..32937bf 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
 
@@ -727,7 +728,7 @@
     op->SetDevice(device);
   }
 
-  eager::EagerClient* eager_client = nullptr;
+  core::RefCountPtr<eager::EagerClient> eager_client;
   uint64 context_id = ctx->GetContextId();
   TF_RETURN_IF_ERROR(ctx->GetClient(op->GetDeviceParsedName(), &eager_client));
   string remote_task;
@@ -860,7 +861,7 @@
            << " (is async?: " << executor.Async() << ").";
 
   std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
-      std::move(request), op_device, eager_client,
+      std::move(request), op_device, eager_client.get(),
       op->MutableAttrs()->BuildNodeDef(), op->EagerContext()->FuncLibDef(),
       op->Inputs(), {retvals, num_outputs}));
   Status s = executor.AddOrExecute(std::move(node));
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index bbcc10b..6cd525b 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -65,6 +65,7 @@
     deps = [
         "//tensorflow/core:eager_service_proto_cc",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
     ],
 )
 
diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
index a1cfe58..3f94028 100644
--- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
@@ -59,7 +59,7 @@
 
   VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
           << " (this: " << this << ")";
-  eager::EagerClient* eager_client = nullptr;
+  core::RefCountPtr<eager::EagerClient> eager_client;
   Device* device;
   s = ctx_->FindDeviceFromName(target.c_str(), &device);
   if (!s.ok()) {
@@ -97,7 +97,8 @@
 
   eager_client->EnqueueAsync(request, response,
                              [this, request, response, handle, released_op,
-                              target, eager_client, done](const Status& s) {
+                              target, eager_client = eager_client.get(),
+                              done](const Status& s) {
                                {
                                  mutex_lock l(mu_);
                                  *handle = function_data_.size();
diff --git a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h
index 869345f..bc1670b 100644
--- a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h
+++ b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h
@@ -30,45 +30,24 @@
 class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
  public:
   DestroyTensorHandleNode(std::unique_ptr<EnqueueRequest> request,
-                          EagerContext* ctx, const string& remote_task,
-                          bool ready)
+                          EagerClient* eager_client, bool ready)
       : tensorflow::AsyncEagerNode(),
         request_(std::move(request)),
-        ctx_(ctx),
-        remote_task_(remote_task),
+        eager_client_(eager_client),
         ready_(ready) {
-    ctx_->Ref();
+    eager_client_->Ref();
   }
 
-  ~DestroyTensorHandleNode() override { ctx_->Unref(); }
+  ~DestroyTensorHandleNode() override { eager_client_->Unref(); }
 
   void RunAsync(StatusCallback done) override {
-    auto context_id = request_->context_id();
-    if (ctx_->GetContextId() != context_id) {
-      // This means that this tensor was pointing to a remote device, which
-      // has been changed out from under us. Simply return since there is
-      // nothing we can do.
-      done(Status::OK());
-      return;
-    }
-
-    eager::EagerClient* eager_client;
-    Status status = ctx_->GetClient(remote_task_, &eager_client);
-    if (!status.ok()) {
-      LOG_EVERY_N_SEC(INFO, 60)
-          << "Unable to destroy remote tensor handle because the target "
-          << remote_task_ << " is no longer available.";
-      done(Status::OK());
-      return;
-    }
-
     EnqueueResponse* response = new EnqueueResponse;
     bool ready = ready_;
     // NOTE(fishx): Don't use StreamingEnqueueAsync here. When a
     // StreamingEnqueueAsync request fails all following requests will fail as
     // well. We don't want this request poison following requests since it is
     // safe to ignore a failing destroy tensor handle request.
-    eager_client->EnqueueAsync(
+    eager_client_->EnqueueAsync(
         request_.get(), response,
         [response, ready, done](const tensorflow::Status& s) {
           // Omit the warning if:
@@ -96,7 +75,6 @@
 
  private:
   std::unique_ptr<EnqueueRequest> request_;
-  EagerContext* ctx_;
+  EagerClient* eager_client_;
-  const string remote_task_;
   bool ready_;
 };
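
The rewrite above moves client resolution from RunAsync to node-construction
time and pins the client with a reference. A self-contained sketch of that
pattern (schematic names, non-atomic refcount for brevity; not the real
classes):

    #include <functional>

    // Schematic ref-counted client; the real EagerClient derives from
    // tensorflow::core::RefCounted.
    struct Client {
      void Ref() { ++refs; }
      void Unref() {
        if (--refs == 0) delete this;
      }
      void EnqueueAsync(std::function<void()> done) { done(); }
      int refs = 1;
    };

    class DestroyHandleNode {
     public:
      explicit DestroyHandleNode(Client* client) : client_(client) {
        client_->Ref();  // Pin the client for the node's lifetime.
      }
      ~DestroyHandleNode() { client_->Unref(); }

      void RunAsync(std::function<void()> done) {
        // No GetClient() lookup here: the pinned client_ stays valid even
        // if the cluster (and the client cache) was updated after this
        // node was created.
        client_->EnqueueAsync(std::move(done));
      }

     private:
      Client* client_;  // The node owns one counted reference.
    };

    int main() {
      Client* c = new Client;  // refs == 1: the "cache" reference.
      {
        DestroyHandleNode node(c);  // refs == 2.
        c->Unref();                 // Cache swapped by an update; refs == 1.
        node.RunAsync([] {});       // Still safe.
      }  // The node's destructor drops the last reference; client is freed.
      return 0;
    }
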
diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h
index 089cf25..3b083f3 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_client.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_client.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_CLIENT_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_CLIENT_H_
 
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/protobuf/eager_service.pb.h"
@@ -25,9 +26,9 @@
 
 // This is a base class that can be implemented by a variety of
 // transports (e.g. gRPC which for each of the client methods makes an RPC).
-class EagerClient {
+class EagerClient : public core::RefCounted {
  public:
-  virtual ~EagerClient() {}
+  ~EagerClient() override {}
 #define CLIENT_METHOD(method)                                \
   virtual void method##Async(const method##Request* request, \
                              method##Response* response,     \
@@ -62,7 +63,13 @@
 class EagerClientCache {
  public:
   virtual ~EagerClientCache() {}
-  virtual Status GetClient(const string& target, EagerClient** client) = 0;
+
+  // If `target` exists, assign the EagerClient pointer to `client` and
+  // increment the client's refcount. Ownership of the reference is
+  // transferred to the caller, and the unref happens automatically when
+  // the caller's RefCountPtr is destructed.
+  virtual Status GetClient(const string& target,
+                           core::RefCountPtr<EagerClient>* client) = 0;
 };
 
 }  // namespace eager
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 92e3d2f..e1a5f34 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -235,7 +235,6 @@
         " but received update request at view #", request->context_view_id(),
         ". View id should only be continuously incremented.");
   }
-  ctx->ClearCaches();
   // TODO(b/143914772): Potential memory leak if rendezvous has pending
   // tensors for removed / replaced workers.
 
@@ -277,13 +276,25 @@
   DistributedFunctionLibraryRuntime* cluster_flr =
       eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
 
-  Status s = ctx->UpdateRemoteWorker(
-      device_mgr, std::move(remote_eager_workers),
-      worker_session->remote_device_mgr(), remote_workers,
-      request->context_id(), cluster_flr);
-  if (!s.ok()) {
-    VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
-    return s;
+  {
+    // Hold `contexts_mu_` exclusively, wait for all pending nodes to finish
+    // (implicitly, by calling WaitForAllPendingNodes inside
+    // `ctx->ClearCaches`), and update the context state.
+    // This lock prevents other threads from handling enqueue requests at the
+    // same time. Each enqueue request will be processed either with the
+    // context state from before the update or from after it; if a specific
+    // ordering is required, the client must enforce it.
+    mutex_lock lock(contexts_mu_);
+    ctx->ClearCaches();
+    Status s = ctx->UpdateRemoteWorker(
+        device_mgr, std::move(remote_eager_workers),
+        worker_session->remote_device_mgr(), remote_workers,
+        request->context_id(), cluster_flr);
+    if (!s.ok()) {
+      VLOG(1) << "EagerContext::UpdateRemoteWorker failed with "
+              << s.ToString();
+      return s;
+    }
   }
 
   std::vector<DeviceAttributes> device_attributes;
@@ -408,6 +419,9 @@
   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
   core::ScopedUnref context_unref(context);
 
+  // Acquire shared lock to prevent handling enqueue requests while updating
+  // context (see UpdateContext).
+  tf_shared_lock lock(contexts_mu_);
   EagerExecutor& executor =
       stream_id == kInvalidStreamId
           ? context->Context()->Executor()
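
The two hunks above form a reader-writer protocol on `contexts_mu_`. A
self-contained sketch of the idea, with std::shared_mutex standing in for
TensorFlow's mutex and tf_shared_lock (names below are illustrative only):

    #include <cstdio>
    #include <shared_mutex>
    #include <thread>
    #include <vector>

    std::shared_mutex contexts_mu;
    int context_view = 0;

    void Enqueue(int op_id) {
      std::shared_lock<std::shared_mutex> lock(contexts_mu);
      // Many Enqueue handlers may hold the lock concurrently, but never
      // while an update holds it exclusively: each request sees the
      // context state entirely before or entirely after the update.
      std::printf("op %d ran at view %d\n", op_id, context_view);
    }

    void UpdateContext() {
      std::unique_lock<std::shared_mutex> lock(contexts_mu);
      // No Enqueue handler is in flight here; safe to mutate context state.
      ++context_view;
    }

    int main() {
      std::vector<std::thread> threads;
      for (int i = 0; i < 4; ++i) threads.emplace_back(Enqueue, i);
      threads.emplace_back(UpdateContext);
      for (std::thread& t : threads) t.join();
      return 0;
    }
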
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index dbf3c63..a2c15da 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -103,13 +103,15 @@
 class DummyEagerClientCache : public EagerClientCache {
  public:
   DummyEagerClientCache() : client_(new FakeEagerClient) {}
-  Status GetClient(const string& target, EagerClient** client) override {
-    *client = client_.get();
+  Status GetClient(const string& target,
+                   core::RefCountPtr<EagerClient>* client) override {
+    client->reset(client_.get());
+    client_->Ref();
     return Status::OK();
   }
 
  private:
-  std::unique_ptr<EagerClient> client_;
+  core::RefCountPtr<EagerClient> client_;
 };
 
 class FakeCache : public TestWorkerCache {
@@ -481,9 +483,9 @@
     TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
     Device* device;
     TF_ASSERT_OK(ctx->FindDeviceFromName(local_device_.c_str(), &device));
-    EagerClient* client;
+    core::RefCountPtr<EagerClient> client;
     TF_ASSERT_OK(ctx->GetClient(device, &client));
-    FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client);
+    FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client.get());
     fake_client->SetServiceImpl(&eager_service_impl_);
 
     // Create an input on local_device for MatMulFunction.
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index 0dfcd82..d0b07a5 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -156,7 +156,7 @@
     remote_op->set_id(ctx_->RemoteMgr()->NextOpId());
 
     // Issue the RPC
-    eager::EagerClient* eager_client;
+    core::RefCountPtr<eager::EagerClient> eager_client;
     status = ctx_->GetClient(send_device_, &eager_client);
     if (!status.ok()) {
       captured_state_->SetSendStatus(status);
@@ -199,7 +199,7 @@
   PrepareRemoteOp(remote_op, op);
   remote_op->set_id(recv_op_id_);
 
-  eager::EagerClient* eager_client;
+  core::RefCountPtr<eager::EagerClient> eager_client;
   Status status = ctx_->GetClient(recv_device_, &eager_client);
   if (!status.ok()) {
     captured_state_->dst()->Poison(status);
@@ -307,7 +307,7 @@
   }
   tensor.AsProtoTensorContent(send_tensor->add_tensors());
 
-  eager::EagerClient* eager_client;
+  core::RefCountPtr<eager::EagerClient> eager_client;
   s = ctx_->GetClient(recv_device_, &eager_client);
   if (!s.ok()) {
     captured_state_->dst()->Poison(s);
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index 3736173..b0342fc 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -60,6 +60,7 @@
     for (auto handle : inputs_) {
       handle->Ref();
     }
+    eager_client_->Ref();
   }
 
   ~RemoteExecuteNode() override {
@@ -70,6 +71,7 @@
     for (auto handle : inputs_) {
       handle->Unref();
     }
+    eager_client_->Unref();
   }
 
   Status Prepare() override {
diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
index 58741ee..af63c20 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
@@ -34,7 +34,7 @@
     return;
   }
 
-  eager::EagerClient* eager_client;
+  core::RefCountPtr<eager::EagerClient> eager_client;
   Status status = ctx->GetClient(remote_task, &eager_client);
   if (!status.ok()) {
     LOG_EVERY_N_SEC(INFO, 60)
@@ -52,8 +52,8 @@
 
   VLOG(3) << "Sending request to delete " << request->DebugString();
   std::unique_ptr<EagerNode> node(
-      absl::make_unique<eager::DestroyTensorHandleNode>(std::move(request), ctx,
-                                                        remote_task, ready));
+      absl::make_unique<eager::DestroyTensorHandleNode>(
+          std::move(request), eager_client.get(), ready));
   auto& executor = ctx->Executor();
   if (executor.Async()) {
     Status status = executor.AddOrExecute(std::move(node));
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
index 487479a..921696e 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/protobuf/eager_service.pb.h"
@@ -61,21 +62,68 @@
   return result;
 }
 
+// Ref-counted thread to handle callbacks for completed requests in a gRPC
+// completion queue. The thread might be shared by multiple eager clients, and
+// each of them should hold a reference to ensure that the thread outlives
+// the clients.
+// To ensure that every tag in the completion queue is processed, this thread
+// also holds a reference to itself and waits until the ref count is one to exit.
+class GrpcEagerClientThread : public core::RefCounted {
+ public:
+  GrpcEagerClientThread() {
+    // Hold a reference to ensure every completion tag gets processed.
+    Ref();
+    thread_.reset(Env::Default()->StartThread(
+        ThreadOptions(), "eager_client_thread", [this]() {
+          void* tag;
+          bool ok;
+          while (completion_queue_.Next(&tag, &ok)) {
+            VLOG(4) << "GrpcEagerClientThread got next tag";
+            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
+            callback_tag->OnCompleted(ok);
+            VLOG(4) << "GrpcEagerClientThread blocking for next tag";
+            if (RefCountIsOne()) {
+              break;
+            }
+          }
+          VLOG(4) << "GrpcEagerClientThread exiting";
+          completion_queue_.Shutdown();
+          // `this` holds the final reference, so we cannot Unref directly
+          // Instead, schedule a separate thread to clean it up.
+          Env::Default()->SchedClosure([this]() { this->Unref(); });
+        }));
+  }
+
+  ~GrpcEagerClientThread() override {}
+
+  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
+
+ private:
+  ::grpc::CompletionQueue completion_queue_;
+  std::unique_ptr<Thread> thread_;
+};
+
 class GrpcEagerClient : public EagerClient {
  public:
   GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
-                  ::grpc::CompletionQueue* cq)
-      : stub_(channel), cq_(cq) {}
-  ~GrpcEagerClient() override {}
+                  GrpcEagerClientThread* thread)
+      : stub_(channel), thread_(thread) {
+    // Hold a reference to make sure the corresponding EagerClientThread
+    // outlives the client.
+    thread_->Ref();
+    cq_ = thread->completion_queue();
+  }
+  ~GrpcEagerClient() override { thread_->Unref(); }
 
 #define CLIENT_METHOD(method)                                             \
   void method##Async(const method##Request* request,                      \
                      method##Response* response, StatusCallback done)     \
       override {                                                          \
+    StatusCallback done_wrapped = callback_wrapper(std::move(done));      \
     new RPCState<protobuf::Message>(                                      \
         &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
-        response, std::move(done), nullptr, nullptr, /*max_retries=*/0,   \
-        /*fail_fast=*/true);                                              \
+        response, std::move(done_wrapped), /*call_opts=*/nullptr,         \
+        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true);   \
   }
 
   CLIENT_METHOD(CreateContext);
@@ -89,9 +137,11 @@
   void CloseContextAsync(const CloseContextRequest* request,
                          CloseContextResponse* response,
                          StatusCallback done) override {
+    StatusCallback done_wrapped = callback_wrapper(std::move(done));
     new RPCState<protobuf::Message>(
         &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
-        response, std::move(done), nullptr, nullptr);
+        response, std::move(done_wrapped), /*call_opts=*/nullptr,
+        /*threadpool=*/nullptr);
 
     VLOG(1) << "Sending RPC to close remote eager context "
             << request->DebugString();
@@ -110,6 +160,7 @@
   void StreamingEnqueueAsync(const EnqueueRequest* request,
                              EnqueueResponse* response,
                              StatusCallback done) override {
+    StatusCallback done_wrapped = callback_wrapper(std::move(done));
     if (EnableStreaming()) {
       tf_shared_lock l(mu_);
       auto it = enqueue_dispatchers_.find(request->context_id());
@@ -122,7 +173,7 @@
                 "/tensorflow.eager.EagerService/StreamingEnqueue"));
         it = it_and_bool.first;
       }
-      it->second.SendNextRequest(*request, response, std::move(done));
+      it->second.SendNextRequest(*request, response, std::move(done_wrapped));
     } else {
       Notification n;
       Status status;
@@ -131,29 +182,44 @@
         n.Notify();
       });
       n.WaitForNotification();
-      done(status);
+      done_wrapped(status);
     }
   }
 
  private:
   ::grpc::GenericStub stub_;
+  const GrpcEagerClientThread* thread_;
+
   ::grpc::CompletionQueue* cq_;
 
   mutable mutex mu_;
 
   std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
       enqueue_dispatchers_ GUARDED_BY(mu_);
+
+  StatusCallback callback_wrapper(StatusCallback done) {
+    Ref();
+    return [this, done = std::move(done)](const Status& status) {
+      done(status);
+      this->Unref();
+    };
+  }
 };
 
 class GrpcEagerClientCache : public EagerClientCache {
  public:
   explicit GrpcEagerClientCache(
       std::shared_ptr<tensorflow::GrpcChannelCache> cache)
-      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {}
+      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
+    for (int i = 0; i < threads_.size(); i++) {
+      threads_[i].reset(new GrpcEagerClientThread());
+    }
+  }
 
   ~GrpcEagerClientCache() override { threads_.clear(); }
 
-  Status GetClient(const string& target, EagerClient** client) override {
+  Status GetClient(const string& target,
+                   core::RefCountPtr<EagerClient>* client) override {
     auto it = clients_.find(target);
     if (it == clients_.end()) {
       tensorflow::SharedGrpcChannelPtr shared =
@@ -162,13 +228,14 @@
         return errors::InvalidArgument("Client for target ", target,
                                        " not found.");
       }
-      auto worker = std::unique_ptr<EagerClient>(new GrpcEagerClient(
-          shared, threads_[AssignClientToThread(target)].completion_queue()));
-
-      it = clients_.emplace(target, std::move(worker)).first;
+      int assigned_index = AssignClientToThread(target);
+      GrpcEagerClientThread* thread = threads_[assigned_index].get();
+      auto worker = new GrpcEagerClient(shared, thread);
+      it = clients_.emplace(target, worker).first;
     }
 
-    *client = it->second.get();
+    it->second->Ref();
+    client->reset(it->second.get());
     return Status::OK();
   }
 
@@ -192,39 +259,9 @@
     return it->second;
   }
 
-  class GrpcEagerClientThread {
-   public:
-    GrpcEagerClientThread() {
-      thread_.reset(Env::Default()->StartThread(
-          ThreadOptions(), "eager_client_thread", [this]() {
-            void* tag;
-            bool ok;
-            while (completion_queue_.Next(&tag, &ok)) {
-              VLOG(4) << "GrpcEagerClientThread got next tag";
-              GrpcClientCQTag* callback_tag =
-                  static_cast<GrpcClientCQTag*>(tag);
-              callback_tag->OnCompleted(ok);
-              VLOG(4) << "GrpcEagerClientThread blocking for next tag";
-            }
-            VLOG(4) << "GrpcEagerClientThread exiting";
-          }));
-    }
-
-    ~GrpcEagerClientThread() {
-      completion_queue_.Shutdown();
-      thread_.reset();
-    }
-
-    ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
-
-   private:
-    ::grpc::CompletionQueue completion_queue_;
-    std::unique_ptr<Thread> thread_;
-  };  // GrpcEagerClientThread
-
   std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
-  std::unordered_map<string, std::unique_ptr<EagerClient>> clients_;
-  std::vector<GrpcEagerClientThread> threads_;
+  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_;
+  std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
 };
 
 }  // namespace
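
A self-contained model of the self-reference trick in GrpcEagerClientThread
above, with std::thread and a polling loop standing in for the gRPC completion
queue and SchedClosure (illustrative only): the thread holds one reference to
its owner so every tag gets drained, and the final Unref is handed to a helper
thread because the destructor joins the polling thread, which therefore cannot
delete the object itself.

    #include <atomic>
    #include <chrono>
    #include <thread>

    class PollerThread {
     public:
      PollerThread() {
        Ref();  // Self-reference: keeps the loop alive until clients are gone.
        thread_ = std::thread([this] {
          while (true) {
            // A real implementation would process one completion-queue tag
            // here; this model just polls.
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
            if (RefCountIsOne()) break;  // Only the self-reference remains.
          }
          // The destructor joins thread_, so `delete this` must not run on
          // this thread; hand the final Unref to a helper thread instead.
          std::thread([this] { Unref(); }).detach();
        });
      }

      void Ref() { count_.fetch_add(1); }
      void Unref() {
        if (count_.fetch_sub(1) == 1) delete this;
      }
      bool RefCountIsOne() { return count_.load() == 1; }

     private:
      ~PollerThread() { thread_.join(); }
      std::thread thread_;
      std::atomic<int> count_{1};  // The creator's reference.
    };

    int main() {
      PollerThread* poller = new PollerThread;  // count: creator + self == 2.
      poller->Unref();  // Last client goes away; the loop exits on its own.
      // Crude wait so the detached cleanup finishes before the demo exits.
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      return 0;
    }

The same idea drives callback_wrapper above: each in-flight RPC holds one
reference to its GrpcEagerClient, which in turn holds one to its thread, so
neither is destroyed while callbacks are pending.
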
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index dbcdd4a..19626ec 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -461,27 +461,29 @@
   def _initialize_logical_devices(self):
     """Helper to initialize devices."""
     # Store list of devices
-    self._logical_devices = []
-    self._context_devices = []
+    logical_devices = []
+    context_devices = []
     device_list = pywrap_tensorflow.TFE_ContextListDevices(
         self._context_handle)
     try:
       self._num_gpus = 0
       for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
         dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
-        self._context_devices.append(pydev.canonical_name(dev_name))
+        context_devices.append(pydev.canonical_name(dev_name))
         spec = pydev.DeviceSpec.from_string(dev_name)
         # If the job is localhost, we assume that the cluster has not yet been
         # configured and thus clear the job, replica & task.
         if spec.job == "localhost":
           spec = spec.replace(job=None, replica=None, task=None)
-        self._logical_devices.append(
+        logical_devices.append(
             LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
         dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
         if dev_type == "GPU":
           self._num_gpus += 1
 
     finally:
+      self._logical_devices = logical_devices
+      self._context_devices = context_devices
       pywrap_tensorflow.TF_DeleteDeviceList(device_list)
 
   def ensure_initialized(self):