Replace coordination service agent dependency on worker environment with OS Env + DeviceMgr.
PiperOrigin-RevId: 416144518
Change-Id: I754ba9db887cb9f0d4e001985b3790bbc5c4eeb6
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
index f22aaad..5b3dfef 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -15,6 +15,11 @@
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
+#include <algorithm>
+#include <numeric>
+#include <string>
+#include <utility>
+
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -719,8 +724,8 @@
LOG_AND_RETURN_IF_ERROR(
worker_cache->GetCoordinationClientCache(&agent_cache));
LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->Initialize(
- server->worker_env(), server_def, std::move(agent_cache),
- [this](Status s) {
+ server->worker_env()->env, server->worker_env()->device_mgr,
+ server_def, std::move(agent_cache), [this](Status s) {
context_->GetCollectiveExecutorHandle()->get()->StartAbort(s);
}));
LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->Connect());
diff --git a/tensorflow/core/distributed_runtime/coordination/BUILD b/tensorflow/core/distributed_runtime/coordination/BUILD
index 7931e62..6a8e75e 100644
--- a/tensorflow/core/distributed_runtime/coordination/BUILD
+++ b/tensorflow/core/distributed_runtime/coordination/BUILD
@@ -88,13 +88,9 @@
hdrs = ["coordination_service_agent.h"],
deps = [
":coordination_client",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:device_mgr",
- "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/protobuf:coordination_service_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
index 1592f60..55f7dc9 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
@@ -21,7 +21,6 @@
#include "absl/synchronization/notification.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
-#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
@@ -44,11 +43,13 @@
public:
CoordinationServiceAgentImpl() = default;
~CoordinationServiceAgentImpl() override { Stop(); }
- Status Initialize(const WorkerEnv* worker_env, const ServerDef& server_def,
+ Status Initialize(Env* env, const DeviceMgr* device_mgr,
+ const ServerDef& server_def,
std::unique_ptr<CoordinationClientCache> client_cache,
StatusCallback error_fn) override;
- Status Initialize(const WorkerEnv* worker_env, const std::string& job_name,
- int task_id, const CoordinationServiceConfig& configs,
+ Status Initialize(Env* env, const DeviceMgr* device_mgr,
+ const std::string& job_name, int task_id,
+ const CoordinationServiceConfig& configs,
std::unique_ptr<CoordinationClient> leader_client,
StatusCallback error_fn) override;
bool IsInitialized() override;
@@ -81,7 +82,8 @@
void Stop();
private:
- const WorkerEnv* env_;
+ Env* env_; // Not owned.
+ const DeviceMgr* device_mgr_; // Not owned.
const int64_t incarnation_id_ = random::New64();
std::string job_name_;
int task_id_;
@@ -111,7 +113,7 @@
};
Status CoordinationServiceAgentImpl::Initialize(
- const WorkerEnv* env, const ServerDef& server_def,
+ Env* env, const DeviceMgr* device_mgr, const ServerDef& server_def,
std::unique_ptr<CoordinationClientCache> client_cache,
StatusCallback error_fn) {
CoordinationServiceConfig configs =
@@ -133,13 +135,13 @@
}
}
return Initialize(
- env, server_def.job_name(), server_def.task_index(), configs,
+ env, device_mgr, server_def.job_name(), server_def.task_index(), configs,
client_cache->GetOwnedClient(configs.service_leader()), error_fn);
}
Status CoordinationServiceAgentImpl::Initialize(
- const WorkerEnv* worker_env, const std::string& job_name, int task_id,
- const CoordinationServiceConfig& configs,
+ Env* env, const DeviceMgr* device_mgr, const std::string& job_name,
+ int task_id, const CoordinationServiceConfig& configs,
std::unique_ptr<CoordinationClient> leader_client,
StatusCallback error_fn) {
mutex_lock l(state_mu_);
@@ -148,7 +150,8 @@
"Coordination service agent has already been initialized.");
}
- env_ = worker_env;
+ env_ = env;
+ device_mgr_ = device_mgr;
job_name_ = job_name;
task_id_ = task_id;
configs_ = configs;
@@ -227,8 +230,8 @@
}
}
- heartbeat_thread_.reset(env_->env->StartThread(
- ThreadOptions(), kHeartbeatThread, [this]() -> void {
+ heartbeat_thread_.reset(
+ env_->StartThread(ThreadOptions(), kHeartbeatThread, [this]() -> void {
HeartbeatRequest request;
request.set_job(job_name_);
request.set_task(task_id_);
@@ -282,7 +285,7 @@
request.set_job(job_name_);
request.set_task(task_id_);
std::vector<DeviceAttributes> devices;
- env_->device_mgr->ListDeviceAttributes(&devices);
+ device_mgr_->ListDeviceAttributes(&devices);
for (auto& d : devices) {
request.add_local_device_attributes()->Swap(&d);
}
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
index 9dd8832..acd8a3b 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
@@ -27,7 +27,8 @@
namespace tensorflow {
class CoordinationServiceConfig;
class DeviceAttributes;
-class WorkerEnv;
+class DeviceMgr;
+class Env;
class ServerDef;
// CoordinationServiceAgent defines the interface for tasks to communicate with
@@ -50,10 +51,10 @@
// Initialize coordination service agent.
virtual Status Initialize(
- const WorkerEnv* worker_env, const ServerDef& server_def,
+ Env* env, const DeviceMgr* device_mgr, const ServerDef& server_def,
std::unique_ptr<CoordinationClientCache> client_cache,
StatusCallback error_fn) = 0;
- virtual Status Initialize(const WorkerEnv* worker_env,
+ virtual Status Initialize(Env* env, const DeviceMgr* device_mgr,
const std::string& job_name, int task_id,
const CoordinationServiceConfig& configs,
std::unique_ptr<CoordinationClient> leader_client,
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index b781263..5150085 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -15,6 +15,10 @@
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
+#include <functional>
+#include <string>
+#include <utility>
+
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
@@ -326,7 +330,8 @@
&client_cache));
TF_RETURN_IF_ERROR(
ctx->GetDistributedManager()->GetCoordinationServiceAgent()->Initialize(
- env_, request->server_def(), std::move(client_cache),
+ env_->env, env_->device_mgr, request->server_def(),
+ std::move(client_cache),
/*error_fn=*/[](Status s) {
LOG(ERROR) << "Coordination agent is set to error: " << s;
}));
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index 6530ac5..a00cd48 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -209,7 +209,8 @@
TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&agent_cache));
coordination_service_agent_ = CreateCoordinationServiceAgent();
TF_RETURN_IF_ERROR(coordination_service_agent_->Initialize(
- worker_env_, server_def, std::move(agent_cache),
+ worker_env_->env, worker_env_->device_mgr, server_def,
+ std::move(agent_cache),
/*error_fn=*/[](Status s) {
LOG(ERROR) << "Coordination agent is set to error: " << s;
}));