Replace coordination service agent dependency on worker environment with OS Env + DeviceMgr.

PiperOrigin-RevId: 416144518
Change-Id: I754ba9db887cb9f0d4e001985b3790bbc5c4eeb6
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
index f22aaad..5b3dfef 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -15,6 +15,11 @@
 
 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
 
+#include <algorithm>
+#include <numeric>
+#include <string>
+#include <utility>
+
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -719,8 +724,8 @@
       LOG_AND_RETURN_IF_ERROR(
           worker_cache->GetCoordinationClientCache(&agent_cache));
       LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->Initialize(
-          server->worker_env(), server_def, std::move(agent_cache),
-          [this](Status s) {
+          server->worker_env()->env, server->worker_env()->device_mgr,
+          server_def, std::move(agent_cache), [this](Status s) {
             context_->GetCollectiveExecutorHandle()->get()->StartAbort(s);
           }));
       LOG_AND_RETURN_IF_ERROR(coordination_service_agent_->Connect());
diff --git a/tensorflow/core/distributed_runtime/coordination/BUILD b/tensorflow/core/distributed_runtime/coordination/BUILD
index 7931e62..6a8e75e 100644
--- a/tensorflow/core/distributed_runtime/coordination/BUILD
+++ b/tensorflow/core/distributed_runtime/coordination/BUILD
@@ -88,13 +88,9 @@
     hdrs = ["coordination_service_agent.h"],
     deps = [
         ":coordination_client",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:device_mgr",
-        "//tensorflow/core/distributed_runtime:worker_env",
         "//tensorflow/core/protobuf:coordination_service_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/synchronization",
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
index 1592f60..55f7dc9 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
@@ -21,7 +21,6 @@
 #include "absl/synchronization/notification.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
-#include "tensorflow/core/distributed_runtime/worker_env.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
@@ -44,11 +43,13 @@
  public:
   CoordinationServiceAgentImpl() = default;
   ~CoordinationServiceAgentImpl() override { Stop(); }
-  Status Initialize(const WorkerEnv* worker_env, const ServerDef& server_def,
+  Status Initialize(Env* env, const DeviceMgr* device_mgr,
+                    const ServerDef& server_def,
                     std::unique_ptr<CoordinationClientCache> client_cache,
                     StatusCallback error_fn) override;
-  Status Initialize(const WorkerEnv* worker_env, const std::string& job_name,
-                    int task_id, const CoordinationServiceConfig& configs,
+  Status Initialize(Env* env, const DeviceMgr* device_mgr,
+                    const std::string& job_name, int task_id,
+                    const CoordinationServiceConfig& configs,
                     std::unique_ptr<CoordinationClient> leader_client,
                     StatusCallback error_fn) override;
   bool IsInitialized() override;
@@ -81,7 +82,8 @@
   void Stop();
 
  private:
-  const WorkerEnv* env_;
+  Env* env_;                     // Not owned.
+  const DeviceMgr* device_mgr_;  // Not owned.
   const int64_t incarnation_id_ = random::New64();
   std::string job_name_;
   int task_id_;
@@ -111,7 +113,7 @@
 };
 
 Status CoordinationServiceAgentImpl::Initialize(
-    const WorkerEnv* env, const ServerDef& server_def,
+    Env* env, const DeviceMgr* device_mgr, const ServerDef& server_def,
     std::unique_ptr<CoordinationClientCache> client_cache,
     StatusCallback error_fn) {
   CoordinationServiceConfig configs =
@@ -133,13 +135,13 @@
     }
   }
   return Initialize(
-      env, server_def.job_name(), server_def.task_index(), configs,
+      env, device_mgr, server_def.job_name(), server_def.task_index(), configs,
       client_cache->GetOwnedClient(configs.service_leader()), error_fn);
 }
 
 Status CoordinationServiceAgentImpl::Initialize(
-    const WorkerEnv* worker_env, const std::string& job_name, int task_id,
-    const CoordinationServiceConfig& configs,
+    Env* env, const DeviceMgr* device_mgr, const std::string& job_name,
+    int task_id, const CoordinationServiceConfig& configs,
     std::unique_ptr<CoordinationClient> leader_client,
     StatusCallback error_fn) {
   mutex_lock l(state_mu_);
@@ -148,7 +150,8 @@
         "Coordination service agent has already been initialized.");
   }
 
-  env_ = worker_env;
+  env_ = env;
+  device_mgr_ = device_mgr;
   job_name_ = job_name;
   task_id_ = task_id;
   configs_ = configs;
@@ -227,8 +230,8 @@
     }
   }
 
-  heartbeat_thread_.reset(env_->env->StartThread(
-      ThreadOptions(), kHeartbeatThread, [this]() -> void {
+  heartbeat_thread_.reset(
+      env_->StartThread(ThreadOptions(), kHeartbeatThread, [this]() -> void {
         HeartbeatRequest request;
         request.set_job(job_name_);
         request.set_task(task_id_);
@@ -282,7 +285,7 @@
   request.set_job(job_name_);
   request.set_task(task_id_);
   std::vector<DeviceAttributes> devices;
-  env_->device_mgr->ListDeviceAttributes(&devices);
+  device_mgr_->ListDeviceAttributes(&devices);
   for (auto& d : devices) {
     request.add_local_device_attributes()->Swap(&d);
   }
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
index 9dd8832..acd8a3b 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h
@@ -27,7 +27,8 @@
 namespace tensorflow {
 class CoordinationServiceConfig;
 class DeviceAttributes;
-class WorkerEnv;
+class DeviceMgr;
+class Env;
 class ServerDef;
 
 // CoordinationServiceAgent defines the interface for tasks to communicate with
@@ -50,10 +51,10 @@
 
   // Initialize coordination service agent.
   virtual Status Initialize(
-      const WorkerEnv* worker_env, const ServerDef& server_def,
+      Env* env, const DeviceMgr* device_mgr, const ServerDef& server_def,
       std::unique_ptr<CoordinationClientCache> client_cache,
       StatusCallback error_fn) = 0;
-  virtual Status Initialize(const WorkerEnv* worker_env,
+  virtual Status Initialize(Env* env, const DeviceMgr* device_mgr,
                             const std::string& job_name, int task_id,
                             const CoordinationServiceConfig& configs,
                             std::unique_ptr<CoordinationClient> leader_client,
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index b781263..5150085 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -15,6 +15,10 @@
 
 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
 
+#include <functional>
+#include <string>
+#include <utility>
+
 #include "absl/container/fixed_array.h"
 #include "absl/memory/memory.h"
 #include "absl/types/optional.h"
@@ -326,7 +330,8 @@
             &client_cache));
     TF_RETURN_IF_ERROR(
         ctx->GetDistributedManager()->GetCoordinationServiceAgent()->Initialize(
-            env_, request->server_def(), std::move(client_cache),
+            env_->env, env_->device_mgr, request->server_def(),
+            std::move(client_cache),
             /*error_fn=*/[](Status s) {
               LOG(ERROR) << "Coordination agent is set to error: " << s;
             }));
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index 6530ac5..a00cd48 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -209,7 +209,8 @@
     TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&agent_cache));
     coordination_service_agent_ = CreateCoordinationServiceAgent();
     TF_RETURN_IF_ERROR(coordination_service_agent_->Initialize(
-        worker_env_, server_def, std::move(agent_cache),
+        worker_env_->env, worker_env_->device_mgr, server_def,
+        std::move(agent_cache),
         /*error_fn=*/[](Status s) {
           LOG(ERROR) << "Coordination agent is set to error: " << s;
         }));