Reuse existing rendezvous when updating cluster.

Before this change, we clean up and recreate the rendezvous when updating members in the cluster. This cancels pending function calls on workers that are not affected by the update. This CL keeps the original rendezvous and the WorkerSession it refers to, and updates the WorkerSession in place (just as we already do on the worker side). In addition, we reuse the session manager in the grpc_server worker_env, since there is no need to recreate it.

PiperOrigin-RevId: 278735207
Change-Id: I883b50c70b1cd269db2b2bce5dc3b9481ab88e10
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 83649f2..e6dfb20 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -457,7 +457,8 @@
     remote_device_mgr = new_remote_device_mgr.get();
   } else {
     ctx->context->ClearCaches();
-    grpc_server->worker_env()->rendezvous_mgr->Cleanup(context_id);
+    // TODO(b/143914772): Potential memory leak if rendezvous has pending
+    // tensors for removed / replaced workers.
 
     remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
     if (remote_device_mgr == nullptr) {
@@ -552,43 +553,52 @@
 
   tensorflow::RemoteRendezvous* r =
       grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
-
   auto session_name = tensorflow::strings::StrCat("eager_", context_id);
-  TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
-      session_name, server_def, base_request.cluster_device_attributes(),
-      true));
-
-  std::shared_ptr<tensorflow::WorkerSession> worker_session;
-  TF_RETURN_IF_ERROR(
-      grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
-          session_name, &worker_session));
-
-  // Initialize remote tensor communication based on worker session.
-  TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
-
   auto* device_mgr = grpc_server->worker_env()->device_mgr;
-  tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
-      tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
-                                          worker_session.get());
-  auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
-      /*is_master=*/true, ctx->context);
+  std::shared_ptr<tensorflow::WorkerSession> worker_session;
 
   if (reset_context) {
+    TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
+        session_name, server_def, base_request.cluster_device_attributes(),
+        true));
+    TF_RETURN_IF_ERROR(
+        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
+            session_name, &worker_session));
+
+    // Initialize remote tensor communication based on worker session.
+    TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
+
+    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
+        tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
+                                            worker_session.get());
+    auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
+        /*is_master=*/true, ctx->context);
+
     LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
         std::move(new_server), grpc_server->worker_env(), worker_session,
         std::move(remote_eager_workers), std::move(new_remote_device_mgr),
         remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
         std::move(remote_mgr)));
-  } else {
-    LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
-        grpc_server->worker_env(), worker_session,
-        std::move(remote_eager_workers), added_workers, removed_workers,
-        context_id, r, device_mgr, keep_alive_secs, cluster_flr));
-  }
 
-  // NOTE: We start the server after all other initialization, because the
-  // GrpcServer cannot be destroyed after it is started.
-  LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
+    // NOTE: We start the server after all other initialization, because the
+    // GrpcServer cannot be destroyed after it is started.
+    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
+  } else {
+    LOG_AND_RETURN_IF_ERROR(
+        grpc_server->worker_env()->session_mgr->UpdateSession(
+            session_name, server_def, base_request.cluster_device_attributes(),
+            true));
+    TF_RETURN_IF_ERROR(
+        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
+            session_name, &worker_session));
+    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
+        tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
+                                            worker_session.get());
+    LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
+        grpc_server->worker_env(), std::move(remote_eager_workers),
+        added_workers, removed_workers, context_id, r, device_mgr,
+        keep_alive_secs, cluster_flr));
+  }
 #undef LOG_AND_RETURN_IF_ERROR
 
   return tensorflow::Status::OK();
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 717e01f..3a0b24f 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -795,7 +795,7 @@
 }
 
 Status EagerContext::UpdateRemoteMaster(
-    WorkerEnv* worker_env, std::shared_ptr<WorkerSession> worker_session,
+    WorkerEnv* worker_env,
     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
     const std::vector<string>& add_remote_contexts,
     const std::vector<string>& remove_remote_contexts, uint64 context_id,
@@ -832,7 +832,7 @@
   }
   std::vector<const FunctionDef*> function_defs = ListRegisteredFunctions();
   TF_RETURN_IF_ERROR(SetMasterContextState(
-      /*server=*/nullptr, worker_env, std::move(worker_session),
+      /*server=*/nullptr, worker_env, /*worker_session=*/nullptr,
       std::move(remote_eager_workers), /*remote_device_manager=*/nullptr,
       context_id, GetContextViewId() + 1, r, local_device_mgr, keep_alive_secs,
       cluster_flr, /*remote_mgr=*/nullptr));
@@ -851,9 +851,9 @@
 }
 
 // Set distributed execution related fields in the master context. Passing
-// nullptr to `server` / `remote_device_mgr` will only update the existing GRPC
-// server / remote device manager in the master context (instead of resetting
-// with new ones).
+// nullptr to `server` / `worker_session` / `remote_device_mgr` will only update
+// the existing GRPC server / worker session / remote device manager in the
+// master context (instead of resetting with new ones).
 Status EagerContext::SetMasterContextState(
     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
     std::shared_ptr<WorkerSession> worker_session,
@@ -893,7 +893,10 @@
     remote_mgr_ = std::move(remote_mgr);
   }
   worker_env_ = worker_env;
-  worker_session_ = worker_session;
+  if (worker_session != nullptr) {
+    worker_session_ = worker_session;
+  }
+  DCHECK(worker_session_ != nullptr);
   remote_eager_workers_ = std::move(remote_eager_workers);
 
   if (remote_device_manager != nullptr) {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 95da8be..1749188 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -210,7 +210,6 @@
   bool LogMemory() const { return log_memory_; }
 
   Rendezvous* GetRendezvous() const { return rendezvous_; }
-  void ResetRendezvous(Rendezvous* r) { rendezvous_ = r; }
   Rendezvous* CreateRendezvous(const int64 step_id) const {
     if (rendezvous_creator_ != nullptr) {
       return rendezvous_creator_(step_id);
@@ -305,7 +304,7 @@
   // can still be accessed, and will automatically register existing functions
   // if there are newly added hosts.
   Status UpdateRemoteMaster(
-      WorkerEnv* worker_env, std::shared_ptr<WorkerSession> worker_session,
+      WorkerEnv* worker_env,
       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
       const std::vector<string>& add_remote_contexts,
       const std::vector<string>& remove_remote_contexts, uint64 context_id,
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index c5a349a..083aeef 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -235,14 +235,9 @@
         " but received update request at view #", request->context_view_id(),
         ". View id should only be continuously incremented.");
   }
-
-  // Remove then recreate rendezvous. Necessary because rondezvous does not
-  // allow double initialization.
-  // NOTE: safe to clean up rendezvous on worker assuming the remote client
-  // calls to WaitForAllPendingNodes on all executors (for example, through
-  // `ClearCaches()`) before issuing requests to update contexts.
-  env_->rendezvous_mgr->Cleanup(request->context_id());
-  auto* r = env_->rendezvous_mgr->Find(request->context_id());
+  ctx->ClearCaches();
+  // TODO(b/143914772): Potential memory leak if rendezvous has pending
+  // tensors for removed / replaced workers.
 
   std::vector<DeviceAttributes> cluster_device_attributes;
   cluster_device_attributes.reserve(
@@ -262,10 +257,6 @@
 
   tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
 
-  // Initialize remote tensor communication based on worker session.
-  TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
-  ctx->ResetRendezvous(r);
-
   std::vector<string> remote_workers;
   worker_session->worker_cache()->ListWorkers(&remote_workers);
   remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 200a20b..32083fc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -437,13 +437,6 @@
       std::move(dev_resolver), std::move(param_resolver), worker_cache,
       default_worker_name);
 
-  worker_env_.session_mgr = new SessionMgr(
-      &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
-      std::unique_ptr<WorkerCacheInterface>(worker_cache),
-      [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
-        WorkerCacheFactoryOptions options(server_def);
-        return WorkerCacheFactory(options, worker_cache);
-      });
   master_env_.worker_cache = worker_cache;
   master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
   return Status::OK();
diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h
index 97ff73d..09bb41d 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.h
+++ b/tensorflow/core/distributed_runtime/session_mgr.h
@@ -53,6 +53,8 @@
       const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
       bool isolate_session_state);
 
+  // Updates state (worker cache, devices) of worker session identified by
+  // session name (`session`) based on a new server_def and set of devices.
   Status UpdateSession(const string& session, const ServerDef& server_def,
                        const protobuf::RepeatedPtrField<DeviceAttributes>&
                            cluster_device_attributes,