Add CheckPeerHealth to PeerAccessInterface

This method will be used to detect peer failures. It sends a GetStatus RPC to the peer and verifies that the device incarnations match the local record. Verifying the incarnation is necessary to detect the case where a worker fails and restarts quickly.
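
For illustration, a condensed sketch of the incarnation check (PeerIncarnationsMatch is a hypothetical helper name; the change itself keeps this logic inline in CollectiveRemoteAccessDistributed::CheckPeerHealth):

    #include <cstdint>
    #include <vector>

    #include "absl/container/flat_hash_set.h"
    #include "tensorflow/core/framework/device_attributes.pb.h"

    // A peer is healthy only if every locally cached incarnation is still
    // reported by the peer's GetStatus response; a missing incarnation means
    // the device belongs to a restarted worker.
    bool PeerIncarnationsMatch(
        const std::vector<tensorflow::DeviceAttributes>& cached_attrs,
        const std::vector<tensorflow::DeviceAttributes>& remote_attrs) {
      absl::flat_hash_set<uint64_t> remote_incarnations;
      for (const tensorflow::DeviceAttributes& da : remote_attrs) {
        remote_incarnations.insert(da.incarnation());
      }
      for (const tensorflow::DeviceAttributes& da : cached_attrs) {
        if (!remote_incarnations.contains(da.incarnation())) {
          return false;
        }
      }
      return true;
    }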

This change also adds a GetTaskCached method to DeviceResolverInterface in order to get the local record of device incarnations. We cannot use GetDeviceAttributesAsync because it sends an RPC with fail_fast=false if the device is not in the cache. Such an RPC may block forever if the peer is down and never comes back.
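
A minimal usage sketch for GetTaskCached, assuming a DeviceResolverInterface* named dev_resolver and mirroring how CheckPeerHealth consumes the result:

    std::vector<DeviceAttributes> cached_attrs;
    Status s = dev_resolver->GetTaskCached("/job:worker/replica:0/task:1",
                                           &cached_attrs);
    if (errors::IsNotFound(s)) {
      // Nothing is cached yet because no collective has run with this peer,
      // so there is no incarnation to validate against; skip the check.
      s = Status::OK();
    } else if (s.ok()) {
      // cached_attrs now holds the locally recorded DeviceAttributes
      // (including incarnations) for every device on the task.
    }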

PiperOrigin-RevId: 327711672
Change-Id: I870b7079b32f00803a1f509dfa009141c67fdf49
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
index 4cd9f82..ec875d0 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -108,6 +108,13 @@
                              from_alloc_attr, done);
 }
 
+void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  const StatusCallback& done) {
+  // Local devices are always healthy, so this should never be called.
+  done(errors::Internal(
+      "CheckPeerHealth is not supposed to be called for local collectives"));
+}
+
 /*static*/
 void CollectiveRemoteAccessLocal::MemCpyAsync(
     DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index 8a0bbd5..12aca90 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -53,6 +53,9 @@
                   const DeviceLocality& client_locality,
                   const StatusCallback& done) override;
 
+  void CheckPeerHealth(const string& peer_task,
+                       const StatusCallback& done) override;
+
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
 
   // Copy utility that always copies bytes from src to dst even if
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index d721fc3..2c60614 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -151,5 +151,16 @@
   EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
 }
 
+TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
+  Status status;
+  Notification done;
+  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
+    status = s;
+    done.Notify();
+  });
+  done.WaitForNotification();
+  EXPECT_TRUE(errors::IsInternal(status));
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc
index 12e1e28..9a898e7 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.cc
+++ b/tensorflow/core/common_runtime/device_resolver_local.cc
@@ -46,4 +46,10 @@
   done(s);
 }
 
+Status DeviceResolverLocal::GetTaskCached(
+    const string& task, std::vector<DeviceAttributes>* attributes) {
+  return errors::Internal(
+      "GetTaskCached is not supposed to be called in local collectives");
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h
index 53a3c87..12b7dce 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.h
+++ b/tensorflow/core/common_runtime/device_resolver_local.h
@@ -39,6 +39,9 @@
                                 DeviceAttributes* attributes,
                                 const StatusCallback& done) override;
 
+  Status GetTaskCached(const string& task,
+                       std::vector<DeviceAttributes>* attributes) override;
+
   void ClearTask(const string& task) override {}
 
   void ClearCache() override {}
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 505e0c3..94570c1 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -538,6 +538,7 @@
         "//tensorflow/core:lib_internal",  # protobuf::Any
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:worker_proto_cc",
+        "@com_google_absl//absl/memory",
     ],
 )
 
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index dbc9417..4215b16 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -185,6 +185,62 @@
                                           dev_attributes_callback);
 }
 
+void CollectiveRemoteAccessDistributed::CheckPeerHealth(
+    const string& peer_task, const StatusCallback& done) {
+  if (peer_task == task_name_) {
+    // Fast path if the peer is the worker itself.
+    done(Status::OK());
+    return;
+  }
+  // We send a GetStatus RPC with fail_fast=true to check the health of a peer
+  // task. If the RPC succeeds, we then verify that the peer's device
+  // incarnations match the local record, if we have one. Note that
+  // DeviceResolverInterface always caches the device attributes.
+  WorkerInterface* wi = worker_cache_->GetOrCreateWorker(peer_task);
+  if (wi == nullptr) {
+    done(errors::InvalidArgument(peer_task,
+                                 " not found. It's probably invalid. The "
+                                 "valid form is /job:xxx/replica:0/task:N"));
+    return;
+  }
+  auto req = new GetStatusRequest();
+  auto resp = new GetStatusResponse();
+  // We're not using a cancellable call because GetStatusAsync doesn't support
+  // cancellation yet.
+  wi->GetStatusAsync(
+      req, resp, /*fail_fast*/ true,
+      [this, req, resp, wi, peer_task, done](Status s) {
+        std::vector<DeviceAttributes> cached_attrs;
+        if (s.ok()) {
+          s = dev_resolver_->GetTaskCached(peer_task, &cached_attrs);
+        }
+        if (s.ok()) {
+          absl::flat_hash_set<uint64> remote_incarnations;
+          for (const DeviceAttributes& da : resp->device_attributes()) {
+            remote_incarnations.insert(da.incarnation());
+          }
+          for (const DeviceAttributes& attr : cached_attrs) {
+            if (!remote_incarnations.contains(attr.incarnation())) {
+              s = errors::FailedPrecondition(
+                  attr.name(), " with incarnation ", attr.incarnation(),
+                  " is not available. This usually means ", peer_task,
+                  " has restarted");
+              break;
+            }
+          }
+        } else if (errors::IsNotFound(s)) {
+          // Skip validating device incarnations if we don't know what they
+          // should be. Device attributes are cached after the first
+          // collective.
+          s = Status::OK();
+        }
+        delete req;
+        delete resp;
+        worker_cache_->ReleaseWorker(peer_task, wi);
+        done(s);
+      });
+}
+
 void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
   CollectiveRemoteAccessLocal::StartAbort(s);
   cancel_mgr_.StartCancel();
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index d6546e3..ed4d448 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -28,10 +28,11 @@
   CollectiveRemoteAccessDistributed(
       const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
       std::shared_ptr<UnboundedWorkQueue> work_queue,
-      WorkerCacheInterface* worker_cache, int64 step_id)
+      WorkerCacheInterface* worker_cache, int64 step_id, string task_name)
       : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
         worker_cache_(worker_cache),
-        work_queue_(std::move(work_queue)) {}
+        work_queue_(std::move(work_queue)),
+        task_name_(std::move(task_name)) {}
 
   ~CollectiveRemoteAccessDistributed() override {}
 
@@ -43,6 +44,9 @@
                     int dev_to_dev_stream_index,
                     const StatusCallback& done) override;
 
+  void CheckPeerHealth(const string& peer_task,
+                       const StatusCallback& done) override;
+
   void StartAbort(const Status& s) override;
 
  protected:
@@ -51,6 +55,7 @@
   // `CollectiveExecutorMgr`.
   std::shared_ptr<UnboundedWorkQueue> work_queue_;
   CancellationManager cancel_mgr_;
+  string task_name_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 2975442..b6975e4 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -63,11 +63,12 @@
 class FakeWorker : public TestWorkerInterface {
  public:
   FakeWorker(const string& name, DeviceMgr* dev_mgr,
-             DeviceResolverDistributed* dres)
+             DeviceResolverDistributed* dres, bool is_failed)
       : name_(name),
         device_mgr_(dev_mgr),
         device_resolver_(dres),
-        buf_rendezvous_(kStepId, dev_mgr) {}
+        buf_rendezvous_(kStepId, dev_mgr),
+        is_failed_(is_failed) {}
 
   // Direct access to a BufRendezvous that holds whatever the remote
   // worker is supposed to have.
@@ -76,6 +77,10 @@
   void GetStatusAsync(const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
+    if (is_failed_) {
+      done(errors::Unavailable("peer down"));
+      return;
+    }
     std::vector<DeviceAttributes> dev_attr;
     device_mgr_->ListDeviceAttributes(&dev_attr);
     for (const auto& da : dev_attr) {
@@ -86,6 +91,10 @@
 
   void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                     RecvBufResponse* response, StatusCallback done) override {
+    if (is_failed_) {
+      done(errors::Unavailable("peer down"));
+      return;
+    }
     opts->SetCancelCallback([this]() {
       // Within this test the call is satisfied by a process-local
       // BufRendezvous table. In real application the BufRendezvous
@@ -125,6 +134,7 @@
   DeviceMgr* device_mgr_;
   DeviceResolverDistributed* device_resolver_;
   BufRendezvous buf_rendezvous_;
+  bool is_failed_;
 };
 
 class FakeCache : public TestWorkerCache {
@@ -201,7 +211,7 @@
     // All tests simulate requests from worker 0 to worker 1.
     rma_.reset(new CollectiveRemoteAccessDistributed(
         device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
-        kStepId));
+        kStepId, "/job:worker/replica:0/task:0"));
 
     const int kNumElts = 8;
     expected_value_ = Tensor(DT_FLOAT, {kNumElts});
@@ -215,7 +225,7 @@
   }
 
   void DefineWorker(const string& worker_name, const string& device_type,
-                    int num_devices) {
+                    int num_devices, bool is_failed = false) {
     std::vector<std::unique_ptr<Device>> devices;
     for (int i = 0; i < num_devices; ++i) {
       devices.push_back(NewDevice(
@@ -232,19 +242,19 @@
     DeviceResolverDistributed* dev_res =
         new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
     dev_resolvers_[worker_name] = dev_res;
-    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res, is_failed);
     workers_.push_back(fw);
     wc_.AddWorker(worker_name, fw);
   }
 
   void RestartWorker(const string& worker_name, const string& device_type,
-                     int num_devices) {
+                     int num_devices, bool is_failed = false) {
     auto it = dev_resolvers_.find(worker_name);
     if (it != dev_resolvers_.end()) {
       delete it->second;
       dev_resolvers_.erase(it);
     }
-    DefineWorker(worker_name, device_type, num_devices);
+    DefineWorker(worker_name, device_type, num_devices, is_failed);
   }
 
   void ValidateResultTensor() {
@@ -401,7 +411,7 @@
   ValidateResultTensor();
 
   // Restart task 1 and check that recv from task 1 to task 0 fails.
-  RestartWorker("/job:worker/replica:0/task:1", "CPU", 1);
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
   Notification post_restart_note;
   rma_->RecvFromPeer(
       "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
@@ -417,5 +427,139 @@
   EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
 }
 
+TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status& s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  TF_EXPECT_OK(check_health_status);
+}
+
+TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status& s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  TF_EXPECT_OK(check_health_status);
+}
+
+TEST_F(CollRMADistTest, CheckHealthRestarted) {
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status& s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
+}
+
+TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
+                /*is_failed*/ true);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status& s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  EXPECT_TRUE(errors::IsUnavailable(check_health_status));
+}
+
+TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
+  RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
+
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:GPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status& s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
index 927925c..ab0b3a6 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
@@ -113,6 +113,22 @@
       });
 }
 
+Status DeviceResolverDistributed::GetTaskCached(
+    const string& task, std::vector<DeviceAttributes>* attributes) {
+  mutex_lock l(mu_);
+  attributes->clear();
+  for (const auto& it : attr_table_) {
+    const string& device_name = it.first;
+    if (DeviceNameUtils::IsSameAddressSpace(task, device_name)) {
+      attributes->push_back(it.second);
+    }
+  }
+  if (attributes->empty()) {
+    return errors::NotFound(task, " not found in the cache");
+  }
+  return Status::OK();
+}
+
 void DeviceResolverDistributed::ClearTask(const string& task) {
   mutex_lock l(mu_);
   // First find all the keys belonging to the task.
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
index 93d51a5..d400fb5 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed.h
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
@@ -43,6 +43,9 @@
                                 DeviceAttributes* attributes,
                                 const StatusCallback& done) override;
 
+  Status GetTaskCached(const string& task,
+                       std::vector<DeviceAttributes>* attributes) override;
+
   void ClearTask(const string& task) override;
 
   void ClearCache() override;
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 4fbc4bb..62a67b5 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -47,8 +47,9 @@
 
 CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
   CollectiveRemoteAccessDistributed* rma =
-      new CollectiveRemoteAccessDistributed(
-          dev_mgr_, dev_resolver_.get(), work_queue_, worker_cache_, step_id);
+      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+                                            work_queue_, worker_cache_, step_id,
+                                            task_name_);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_, work_queue_);
 }
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 05eefed..d0c5323 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -161,17 +161,20 @@
       std::vector<DeviceAttributes>* attributes,
       const StatusCallback& done) = 0;
 
-  // Populate *attributes with the DeviceAttributes of the specified
-  // device.
+  // Populates *attributes with the DeviceAttributes of the specified device.
   virtual void GetDeviceAttributesAsync(const string& device,
                                         const string& task,
                                         DeviceAttributes* attributes,
                                         const StatusCallback& done) = 0;
 
-  // Clear the cache of device data belonging to the specified task.
+  // Returns the cached device attributes of a task.
+  virtual Status GetTaskCached(const string& task,
+                               std::vector<DeviceAttributes>* attributes) = 0;
+
+  // Clears the cache of device data belonging to the specified task.
   virtual void ClearTask(const string& task) = 0;
 
-  // Clear the cache of all device data.
+  // Clears the cache of all device data.
   virtual void ClearCache() = 0;
 };
 
@@ -279,6 +282,12 @@
                           const DeviceLocality& client_locality,
                           const StatusCallback& done) = 0;
 
+  // Checks the health of a collective peer. It probes the peer to see if it is
+  // alive. Note that if a peer has restarted, it's considered a different one,
+  // so CheckPeerHealth fails.
+  virtual void CheckPeerHealth(const string& peer_task,
+                               const StatusCallback& done) = 0;
+
   virtual BufRendezvous* buf_rendezvous() = 0;
 
   virtual void StartAbort(const Status& s) = 0;