Allow coordination service agent to be initialized with a single RPC channel (instead of an RPC channel cache).
PiperOrigin-RevId: 408408618
Change-Id: I828d1674f785ad693ca15e0be25931e2fcd7bd3c
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_client.h b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
index b8bea15..1f92eb6 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_client.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
@@ -16,6 +16,7 @@
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_CLIENT_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_CLIENT_H_
+#include <memory>
#include <string>
namespace tensorflow {
@@ -32,6 +33,11 @@
// If the `target` names a remote task, returns a pointer of the
// CoordinationClient object wrapping that channel to the remote task.
virtual CoordinationClient* GetClient(const std::string& target) = 0;
+
+ // If the `target` names a remote task, returns an owning pointer to a
+ // CoordinationClient object wrapping the channel to that remote task.
+ virtual std::unique_ptr<CoordinationClient> GetOwnedClient(
+ const std::string& target) = 0;
};
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
index 297a4ee..9dd8832 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
@@ -25,6 +25,7 @@
#include "tensorflow/core/platform/statusor.h"
namespace tensorflow {
+class CoordinationServiceConfig;
class DeviceAttributes;
class WorkerEnv;
class ServerDef;
@@ -52,6 +53,12 @@
const WorkerEnv* worker_env, const ServerDef& server_def,
std::unique_ptr<CoordinationClientCache> client_cache,
StatusCallback error_fn) = 0;
+ virtual Status Initialize(const WorkerEnv* worker_env,
+ const std::string& job_name, int task_id,
+ const CoordinationServiceConfig& configs,
+ std::unique_ptr<CoordinationClient> leader_client,
+ StatusCallback error_fn) = 0;
+
// Return true if the coordination service agent has been initialized.
virtual bool IsInitialized() = 0;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index 8b20af3..3ccf374 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -418,8 +418,9 @@
public:
// Default behavior is to set fail_fast = False and handle timeouts
// manually.
- StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
- const std::shared_ptr<::grpc::ClientContext>& context)
+ StreamingRPCState(
+ std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call,
+ const std::shared_ptr<::grpc::ClientContext>& context)
: context_(context), call_(std::move(call)), call_state_(State::kActive) {
Ref();
VLOG(3) << "Created new StreamingRPCState " << this;
@@ -625,7 +626,7 @@
// Order of context_ and call_ is important because context_ must outlive
// call_.
const std::shared_ptr<const ::grpc::ClientContext> context_;
- std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
+ std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call_;
mutable mutex mu_;
ExchangeQueue exchanges_ TF_GUARDED_BY(mu_);
@@ -711,7 +712,7 @@
// the channel to become ready.
context_->set_wait_for_ready(true);
- std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call =
+ std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call =
stub_->PrepareCall(context_.get(), method_, cq_);
state_.reset(new StreamingRPCState<Response>(std::move(call), context_));