Add RPC deadline for ShutdownTask, ReportErrorToTask to prevent hanging calls after service has shut down.

PiperOrigin-RevId: 437360289
diff --git a/tensorflow/core/distributed_runtime/coordination/BUILD b/tensorflow/core/distributed_runtime/coordination/BUILD
index e0d6257..08aceb9 100644
--- a/tensorflow/core/distributed_runtime/coordination/BUILD
+++ b/tensorflow/core/distributed_runtime/coordination/BUILD
@@ -56,6 +56,7 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:device_mgr",
+        "//tensorflow/core/distributed_runtime:call_options",
         "//tensorflow/core/protobuf:coordination_service_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_client.h b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
index ef1749a..2a34ee7c 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_client.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
@@ -44,7 +44,8 @@
                                     WaitForAllTasksResponse* response,
                                     StatusCallback done) = 0;
 
-  virtual void ShutdownTaskAsync(const ShutdownTaskRequest* request,
+  virtual void ShutdownTaskAsync(CallOptions* call_opts,
+                                 const ShutdownTaskRequest* request,
                                  ShutdownTaskResponse* response,
                                  StatusCallback done) = 0;
 
@@ -52,7 +53,8 @@
                               ResetTaskResponse* response,
                               StatusCallback done) = 0;
 
-  virtual void ReportErrorToTaskAsync(const ReportErrorToTaskRequest* request,
+  virtual void ReportErrorToTaskAsync(CallOptions* call_opts,
+                                      const ReportErrorToTaskRequest* request,
                                       ReportErrorToTaskResponse* response,
                                       StatusCallback done) = 0;
 
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service.cc
index dead837..b4dae99 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service.cc
@@ -30,6 +30,7 @@
 #include "absl/time/time.h"
 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
 #include "tensorflow/core/platform/env.h"
@@ -52,6 +53,7 @@
 namespace {
 
 constexpr int kDefaultHeartbeatTimeoutMs = 10 * 1000;  // 10 seconds
+constexpr int kServiceToClientTimeoutMs = 10 * 1000;   // 10 seconds
 constexpr size_t kOngoingBarriersSoftLimit = 20;
 constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck";
 
@@ -146,8 +148,8 @@
   void StartCheckStaleness();  // Checks both heartbeat and barrier timeouts.
   void Stop(bool shut_staleness_thread = true);
   // Report service error to a specified task.
-  void ReportServiceErrorToTask(const CoordinatedTask& destination_task,
-                                Status error);
+  void ReportServiceErrorToTaskAsync(const CoordinatedTask& destination_task,
+                                     Status error);
   // Report error from a task to all other connected tasks.
   // Note: SetTaskError() must be called before propagating its error.
   void PropagateError(const CoordinatedTask& source_task,
@@ -648,7 +650,7 @@
   return s;
 }
 
-void CoordinationServiceStandaloneImpl::ReportServiceErrorToTask(
+void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync(
     const CoordinatedTask& destination_task, Status error) {
   assert(!error.ok());
 
@@ -665,11 +667,14 @@
   CoordinatedTask* error_source =
       request->mutable_error_payload()->mutable_source_task();
   error_source->set_job_name("coordination_service");
+  auto call_opts = std::make_shared<CallOptions>();
+  call_opts->SetTimeout(kServiceToClientTimeoutMs);
 
   const std::string task_name = GetTaskName(destination_task);
   CoordinationClient* client = client_cache_->GetClient(task_name);
   client->ReportErrorToTaskAsync(
-      request.get(), response.get(), [request, response, task_name](Status s) {
+      call_opts.get(), request.get(), response.get(),
+      [request, response, task_name, call_opts](Status s) {
         if (!s.ok()) {
           LOG(ERROR) << "Encountered another error while reporting to "
                      << task_name << ": " << s;
@@ -691,6 +696,8 @@
   CoordinationServiceError* payload = request.mutable_error_payload();
   *payload->mutable_source_task() = source_task;
   payload->set_is_reported_error(is_reported_by_task);
+  CallOptions call_opts;
+  call_opts.SetTimeout(kServiceToClientTimeoutMs);
   std::vector<std::shared_ptr<Notification>> notifications;
 
   std::vector<absl::string_view> task_names;
@@ -718,7 +725,7 @@
     auto response = std::make_shared<ReportErrorToTaskResponse>();
     auto n = std::make_shared<Notification>();
     client->ReportErrorToTaskAsync(
-        &request, response.get(), [response, n, task](Status s) {
+        &call_opts, &request, response.get(), [response, n, task](Status s) {
           if (!s.ok()) {
             LOG(ERROR) << "Encountered another error while reporting to "
                        << task << ": " << s;
@@ -1028,7 +1035,7 @@
         // Propagate errors to straggling tasks that have not reached the
         // barrier. The barrier must have failed if any task did not reach the
         // barrier.
-        ReportServiceErrorToTask(task, shutdown_error);
+        ReportServiceErrorToTaskAsync(task, shutdown_error);
       }
     }
   }
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
index af3a94b..4c44e2c 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
@@ -379,9 +379,11 @@
     ShutdownTaskRequest request;
     *request.mutable_source_task() = task_;
     ShutdownTaskResponse response;
+    CallOptions call_opts;
+    call_opts.SetTimeout(configs_.shutdown_barrier_timeout_in_ms());
 
     absl::Notification n;
-    leader_client_->ShutdownTaskAsync(&request, &response,
+    leader_client_->ShutdownTaskAsync(&call_opts, &request, &response,
                                       [&status, &n](Status s) {
                                         status = s;
                                         n.Notify();
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc
index 6fc3ec0..49e011f 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc
@@ -47,7 +47,7 @@
                                       GetKeyValueResponse*, StatusCallback));
   MOCK_METHOD4(RegisterTaskAsync, void(CallOptions*, const RegisterTaskRequest*,
                                        RegisterTaskResponse*, StatusCallback));
-  MOCK_METHOD3(ShutdownTaskAsync, void(const ShutdownTaskRequest*,
+  MOCK_METHOD4(ShutdownTaskAsync, void(CallOptions*, const ShutdownTaskRequest*,
                                        ShutdownTaskResponse*, StatusCallback));
   MOCK_METHOD3(ResetTaskAsync, void(const ResetTaskRequest*, ResetTaskResponse*,
                                     StatusCallback));
@@ -64,12 +64,17 @@
 
   UNIMPLEMENTED(Heartbeat);
   UNIMPLEMENTED(WaitForAllTasks);
-  UNIMPLEMENTED(ReportErrorToTask);
   UNIMPLEMENTED(InsertKeyValue);
   UNIMPLEMENTED(DeleteKeyValue);
   UNIMPLEMENTED(Barrier);
   UNIMPLEMENTED(CancelBarrier);
 #undef UNIMPLEMENTED
+  void ReportErrorToTaskAsync(CallOptions* call_opts,
+                              const ReportErrorToTaskRequest* request,
+                              ReportErrorToTaskResponse* response,
+                              StatusCallback done) override {
+    done(errors::Unimplemented("ReportErrorToTaskAsync"));
+  }
 };
 
 class CoordinationServiceAgentTest : public ::testing::Test {
@@ -77,8 +82,8 @@
   void SetUp() override {
     ON_CALL(*client_, RegisterTaskAsync(_, _, _, _))
         .WillByDefault(InvokeArgument<3>(Status::OK()));
-    ON_CALL(*client_, ShutdownTaskAsync(_, _, _))
-        .WillByDefault(InvokeArgument<2>(Status::OK()));
+    ON_CALL(*client_, ShutdownTaskAsync(_, _, _, _))
+        .WillByDefault(InvokeArgument<3>(Status::OK()));
     ON_CALL(*client_, ReportErrorToServiceAsync(_, _, _))
         .WillByDefault(InvokeArgument<2>(Status::OK()));
     ON_CALL(*GetClient(), ResetTaskAsync(_, _, _))
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc
index c96f861..09bc8a7 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc
@@ -62,7 +62,8 @@
     done(Status::OK());
   }
 
-  void ReportErrorToTaskAsync(const ReportErrorToTaskRequest* request,
+  void ReportErrorToTaskAsync(CallOptions* call_opts,
+                              const ReportErrorToTaskRequest* request,
                               ReportErrorToTaskResponse* response,
                               StatusCallback done) override {
     mutex_lock l(mu_);
@@ -80,7 +81,6 @@
 
   UNIMPLEMENTED(Heartbeat);
   UNIMPLEMENTED(WaitForAllTasks);
-  UNIMPLEMENTED(ShutdownTask);
   UNIMPLEMENTED(ResetTask);
   UNIMPLEMENTED(ReportErrorToService);
   UNIMPLEMENTED(InsertKeyValue);
@@ -89,6 +89,12 @@
   UNIMPLEMENTED(Barrier);
   UNIMPLEMENTED(CancelBarrier);
 #undef UNIMPLEMENTED
+  void ShutdownTaskAsync(CallOptions* call_opts,
+                         const ShutdownTaskRequest* request,
+                         ShutdownTaskResponse* response,
+                         StatusCallback done) override {
+    done(errors::Unimplemented("ShutdownTaskAsync"));
+  }
 
  private:
   mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
index 2e07f28..c7f1dca 100644
--- a/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
@@ -92,12 +92,13 @@
         &target_);
   }
 
-  void ShutdownTaskAsync(const ShutdownTaskRequest* request,
+  void ShutdownTaskAsync(CallOptions* call_opts,
+                         const ShutdownTaskRequest* request,
                          ShutdownTaskResponse* response,
                          StatusCallback done) override {
     new RPCState<protobuf::Message>(
         &stub_, cq_, "/tensorflow.CoordinationService/ShutdownTask", *request,
-        response, std::move(done), /*call_opts=*/nullptr,
+        response, std::move(done), call_opts,
         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
         &target_);
   }
@@ -124,12 +125,13 @@
         /*fail_fast=*/true, &target_);
   }
 
-  void ReportErrorToTaskAsync(const ReportErrorToTaskRequest* request,
+  void ReportErrorToTaskAsync(CallOptions* call_opts,
+                              const ReportErrorToTaskRequest* request,
                               ReportErrorToTaskResponse* response,
                               StatusCallback done) override {
     new RPCState<protobuf::Message>(
         &stub_, cq_, "/tensorflow.CoordinationService/ReportErrorToTask",
-        *request, response, std::move(done), /*call_opts=*/nullptr,
+        *request, response, std::move(done), call_opts,
         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
         &target_);
   }