Allow coordination service agent to be initialized with a single RPC channel (instead of an RPC channel cache).

PiperOrigin-RevId: 408408618
Change-Id: I828d1674f785ad693ca15e0be25931e2fcd7bd3c
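
Usage sketch of the new overload (only the Initialize() and GetOwnedClient()
signatures come from this change; `agent`, `worker_env`, `client_cache`,
`coordination_config` and `leader_task` are assumed to be set up by the
caller):

    // Instead of handing the agent a whole CoordinationClientCache, obtain a
    // single owned client for the channel to the coordination service leader
    // and pass only that client to the agent.
    std::unique_ptr<CoordinationClient> leader_client =
        client_cache->GetOwnedClient(leader_task);
    TF_RETURN_IF_ERROR(agent->Initialize(
        worker_env, /*job_name=*/"worker", /*task_id=*/0, coordination_config,
        std::move(leader_client),
        /*error_fn=*/[](const Status& s) { LOG(ERROR) << s; }));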
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_client.h b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
index b8bea15..1f92eb6 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_client.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_CLIENT_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_CLIENT_H_
 
+#include <memory>
 #include <string>
 
 namespace tensorflow {
@@ -32,6 +33,11 @@
   // If the `target` names a remote task, returns a pointer of the
   // CoordinationClient object wrapping that channel to the remote task.
   virtual CoordinationClient* GetClient(const std::string& target) = 0;
+
+  // If the `target` names a remote task, returns an owned pointer to the
+  // CoordinationClient object wrapping the channel to the remote task.
+  virtual std::unique_ptr<CoordinationClient> GetOwnedClient(
+      const std::string& target) = 0;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
index 297a4ee..9dd8832 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
@@ -25,6 +25,7 @@
 #include "tensorflow/core/platform/statusor.h"
 
 namespace tensorflow {
+class CoordinationServiceConfig;
 class DeviceAttributes;
 class WorkerEnv;
 class ServerDef;
@@ -52,6 +53,15 @@
       const WorkerEnv* worker_env, const ServerDef& server_def,
       std::unique_ptr<CoordinationClientCache> client_cache,
       StatusCallback error_fn) = 0;
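+
+  // Initializes the coordination service agent with a single client connected
+  // to the coordination service leader, instead of a full client cache.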
+  virtual Status Initialize(const WorkerEnv* worker_env,
+                            const std::string& job_name, int task_id,
+                            const CoordinationServiceConfig& configs,
+                            std::unique_ptr<CoordinationClient> leader_client,
+                            StatusCallback error_fn) = 0;
+
   // Return true if the coordination service agent has been initialized.
   virtual bool IsInitialized() = 0;
 
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index 8b20af3..3ccf374 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -418,8 +418,9 @@
  public:
   // Default behavior is to set fail_fast = False and handle timeouts
   // manually.
-  StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
-                    const std::shared_ptr<::grpc::ClientContext>& context)
+  StreamingRPCState(
+      std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call,
+      const std::shared_ptr<::grpc::ClientContext>& context)
       : context_(context), call_(std::move(call)), call_state_(State::kActive) {
     Ref();
     VLOG(3) << "Created new StreamingRPCState " << this;
@@ -625,7 +626,7 @@
   // Order of context_ and call_ is important because context_ must outlive
   // call_.
   const std::shared_ptr<const ::grpc::ClientContext> context_;
-  std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
+  std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call_;
 
   mutable mutex mu_;
   ExchangeQueue exchanges_ TF_GUARDED_BY(mu_);
@@ -711,7 +712,7 @@
     // the channel to become ready.
     context_->set_wait_for_ready(true);
 
-    std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call =
+    std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call =
         stub_->PrepareCall(context_.get(), method_, cq_);
 
     state_.reset(new StreamingRPCState<Response>(std::move(call), context_));