Add RPC deadlines for ShutdownTask and ReportErrorToTask to prevent hanging calls after the service has shut down.
PiperOrigin-RevId: 437360289
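
For context, the agent-side calling pattern after this change looks roughly like the sketch below. This is a minimal illustration, not code from this CL: ShutdownWithDeadline and timeout_ms are hypothetical names introduced here for clarity, it assumes an existing CoordinationClient* client, and the include paths are indicative.

    #include "absl/synchronization/notification.h"
    #include "tensorflow/core/distributed_runtime/call_options.h"
    #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
    #include "tensorflow/core/protobuf/coordination_service.pb.h"

    // Hypothetical helper: issue ShutdownTask with an RPC deadline so the call
    // returns an error instead of hanging if the peer has already shut down.
    tensorflow::Status ShutdownWithDeadline(tensorflow::CoordinationClient* client,
                                            const tensorflow::CoordinatedTask& task,
                                            int64_t timeout_ms) {
      tensorflow::ShutdownTaskRequest request;
      *request.mutable_source_task() = task;
      tensorflow::ShutdownTaskResponse response;

      tensorflow::CallOptions call_opts;
      call_opts.SetTimeout(timeout_ms);  // RPC deadline in milliseconds.

      tensorflow::Status status;
      absl::Notification n;
      client->ShutdownTaskAsync(&call_opts, &request, &response,
                                [&status, &n](tensorflow::Status s) {
                                  status = s;
                                  n.Notify();
                                });
      n.WaitForNotification();  // Bounded by the deadline set above.
      return status;
    }

In this CL, the service uses the new kServiceToClientTimeoutMs (10 seconds) for its ReportErrorToTask calls, while the agent's ShutdownTask call reuses the configured shutdown barrier timeout.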
diff --git a/tensorflow/core/distributed_runtime/coordination/BUILD b/tensorflow/core/distributed_runtime/coordination/BUILD
index e0d6257..08aceb9 100644
--- a/tensorflow/core/distributed_runtime/coordination/BUILD
+++ b/tensorflow/core/distributed_runtime/coordination/BUILD
@@ -56,6 +56,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:device_mgr",
+ "//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/protobuf:coordination_service_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_client.h b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
index ef1749a..2a34ee7c 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_client.h
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_client.h
@@ -44,7 +44,8 @@
WaitForAllTasksResponse* response,
StatusCallback done) = 0;
- virtual void ShutdownTaskAsync(const ShutdownTaskRequest* request,
+ virtual void ShutdownTaskAsync(CallOptions* call_opts,
+ const ShutdownTaskRequest* request,
ShutdownTaskResponse* response,
StatusCallback done) = 0;
@@ -52,7 +53,8 @@
ResetTaskResponse* response,
StatusCallback done) = 0;
- virtual void ReportErrorToTaskAsync(const ReportErrorToTaskRequest* request,
+ virtual void ReportErrorToTaskAsync(CallOptions* call_opts,
+ const ReportErrorToTaskRequest* request,
ReportErrorToTaskResponse* response,
StatusCallback done) = 0;
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service.cc
index dead837..b4dae99 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service.cc
@@ -30,6 +30,7 @@
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
#include "tensorflow/core/platform/env.h"
@@ -52,6 +53,7 @@
namespace {
constexpr int kDefaultHeartbeatTimeoutMs = 10 * 1000; // 10 seconds
+constexpr int kServiceToClientTimeoutMs = 10 * 1000; // 10 seconds
constexpr size_t kOngoingBarriersSoftLimit = 20;
constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck";
@@ -146,8 +148,8 @@
void StartCheckStaleness(); // Checks both heartbeat and barrier timeouts.
void Stop(bool shut_staleness_thread = true);
// Report service error to a specified task.
- void ReportServiceErrorToTask(const CoordinatedTask& destination_task,
- Status error);
+ void ReportServiceErrorToTaskAsync(const CoordinatedTask& destination_task,
+ Status error);
// Report error from a task to all other connected tasks.
// Note: SetTaskError() must be called before propagating its error.
void PropagateError(const CoordinatedTask& source_task,
@@ -648,7 +650,7 @@
return s;
}
-void CoordinationServiceStandaloneImpl::ReportServiceErrorToTask(
+void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync(
const CoordinatedTask& destination_task, Status error) {
assert(!error.ok());
@@ -665,11 +667,14 @@
CoordinatedTask* error_source =
request->mutable_error_payload()->mutable_source_task();
error_source->set_job_name("coordination_service");
+ auto call_opts = std::make_shared<CallOptions>();
+ call_opts->SetTimeout(kServiceToClientTimeoutMs);
const std::string task_name = GetTaskName(destination_task);
CoordinationClient* client = client_cache_->GetClient(task_name);
client->ReportErrorToTaskAsync(
- request.get(), response.get(), [request, response, task_name](Status s) {
+ call_opts.get(), request.get(), response.get(),
+ [request, response, task_name, call_opts](Status s) {
if (!s.ok()) {
LOG(ERROR) << "Encountered another error while reporting to "
<< task_name << ": " << s;
@@ -691,6 +696,8 @@
CoordinationServiceError* payload = request.mutable_error_payload();
*payload->mutable_source_task() = source_task;
payload->set_is_reported_error(is_reported_by_task);
+ CallOptions call_opts;
+ call_opts.SetTimeout(kServiceToClientTimeoutMs);
std::vector<std::shared_ptr<Notification>> notifications;
std::vector<absl::string_view> task_names;
@@ -718,7 +725,7 @@
auto response = std::make_shared<ReportErrorToTaskResponse>();
auto n = std::make_shared<Notification>();
client->ReportErrorToTaskAsync(
- &request, response.get(), [response, n, task](Status s) {
+ &call_opts, &request, response.get(), [response, n, task](Status s) {
if (!s.ok()) {
LOG(ERROR) << "Encountered another error while reporting to "
<< task << ": " << s;
@@ -1028,7 +1035,7 @@
// Propagate errors to straggling tasks that have not reached the
// barrier. The barrier must have failed if any task did not reach the
// barrier.
- ReportServiceErrorToTask(task, shutdown_error);
+ ReportServiceErrorToTaskAsync(task, shutdown_error);
}
}
}
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
index af3a94b..4c44e2c 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent.cc
@@ -379,9 +379,11 @@
ShutdownTaskRequest request;
*request.mutable_source_task() = task_;
ShutdownTaskResponse response;
+ CallOptions call_opts;
+ call_opts.SetTimeout(configs_.shutdown_barrier_timeout_in_ms());
absl::Notification n;
- leader_client_->ShutdownTaskAsync(&request, &response,
+ leader_client_->ShutdownTaskAsync(&call_opts, &request, &response,
[&status, &n](Status s) {
status = s;
n.Notify();
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc
index 6fc3ec0..49e011f 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_agent_test.cc
@@ -47,7 +47,7 @@
GetKeyValueResponse*, StatusCallback));
MOCK_METHOD4(RegisterTaskAsync, void(CallOptions*, const RegisterTaskRequest*,
RegisterTaskResponse*, StatusCallback));
- MOCK_METHOD3(ShutdownTaskAsync, void(const ShutdownTaskRequest*,
+ MOCK_METHOD4(ShutdownTaskAsync, void(CallOptions*, const ShutdownTaskRequest*,
ShutdownTaskResponse*, StatusCallback));
MOCK_METHOD3(ResetTaskAsync, void(const ResetTaskRequest*, ResetTaskResponse*,
StatusCallback));
@@ -64,12 +64,17 @@
UNIMPLEMENTED(Heartbeat);
UNIMPLEMENTED(WaitForAllTasks);
- UNIMPLEMENTED(ReportErrorToTask);
UNIMPLEMENTED(InsertKeyValue);
UNIMPLEMENTED(DeleteKeyValue);
UNIMPLEMENTED(Barrier);
UNIMPLEMENTED(CancelBarrier);
#undef UNIMPLEMENTED
+ void ReportErrorToTaskAsync(CallOptions* call_opts,
+ const ReportErrorToTaskRequest* request,
+ ReportErrorToTaskResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("ReportErrorToTaskAsync"));
+ }
};
class CoordinationServiceAgentTest : public ::testing::Test {
@@ -77,8 +82,8 @@
void SetUp() override {
ON_CALL(*client_, RegisterTaskAsync(_, _, _, _))
.WillByDefault(InvokeArgument<3>(Status::OK()));
- ON_CALL(*client_, ShutdownTaskAsync(_, _, _))
- .WillByDefault(InvokeArgument<2>(Status::OK()));
+ ON_CALL(*client_, ShutdownTaskAsync(_, _, _, _))
+ .WillByDefault(InvokeArgument<3>(Status::OK()));
ON_CALL(*client_, ReportErrorToServiceAsync(_, _, _))
.WillByDefault(InvokeArgument<2>(Status::OK()));
ON_CALL(*GetClient(), ResetTaskAsync(_, _, _))
diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc
index c96f861..09bc8a7 100644
--- a/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc
+++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_test.cc
@@ -62,7 +62,8 @@
done(Status::OK());
}
- void ReportErrorToTaskAsync(const ReportErrorToTaskRequest* request,
+ void ReportErrorToTaskAsync(CallOptions* call_opts,
+ const ReportErrorToTaskRequest* request,
ReportErrorToTaskResponse* response,
StatusCallback done) override {
mutex_lock l(mu_);
@@ -80,7 +81,6 @@
UNIMPLEMENTED(Heartbeat);
UNIMPLEMENTED(WaitForAllTasks);
- UNIMPLEMENTED(ShutdownTask);
UNIMPLEMENTED(ResetTask);
UNIMPLEMENTED(ReportErrorToService);
UNIMPLEMENTED(InsertKeyValue);
@@ -89,6 +89,12 @@
UNIMPLEMENTED(Barrier);
UNIMPLEMENTED(CancelBarrier);
#undef UNIMPLEMENTED
+ void ShutdownTaskAsync(CallOptions* call_opts,
+ const ShutdownTaskRequest* request,
+ ShutdownTaskResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("ShutdownTaskAsync"));
+ }
private:
mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
index 2e07f28..c7f1dca 100644
--- a/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
@@ -92,12 +92,13 @@
&target_);
}
- void ShutdownTaskAsync(const ShutdownTaskRequest* request,
+ void ShutdownTaskAsync(CallOptions* call_opts,
+ const ShutdownTaskRequest* request,
ShutdownTaskResponse* response,
StatusCallback done) override {
new RPCState<protobuf::Message>(
&stub_, cq_, "/tensorflow.CoordinationService/ShutdownTask", *request,
- response, std::move(done), /*call_opts=*/nullptr,
+ response, std::move(done), call_opts,
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
&target_);
}
@@ -124,12 +125,13 @@
/*fail_fast=*/true, &target_);
}
- void ReportErrorToTaskAsync(const ReportErrorToTaskRequest* request,
+ void ReportErrorToTaskAsync(CallOptions* call_opts,
+ const ReportErrorToTaskRequest* request,
ReportErrorToTaskResponse* response,
StatusCallback done) override {
new RPCState<protobuf::Message>(
&stub_, cq_, "/tensorflow.CoordinationService/ReportErrorToTask",
- *request, response, std::move(done), /*call_opts=*/nullptr,
+ *request, response, std::move(done), call_opts,
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
&target_);
}