blob: 195fb39d52a26ae00707572c3b6083a399e8ac62 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include <memory>
#include <string>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"
namespace tensorflow {
namespace {
using ::testing::EqualsProto;
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
constexpr absl::Duration kHeartbeatTimeout = absl::Seconds(2);
constexpr absl::Duration kShutdownBarrierTimeout = absl::Seconds(1);
constexpr char kCoordinationServiceType[] = "standalone";
KeyValueEntry CreateKv(const std::string& key, const std::string& value) {
KeyValueEntry kv;
kv.set_key(key);
kv.set_value(value);
return kv;
}
class TestCoordinationClient : public CoordinationClient {
public:
TestCoordinationClient() = default;
Status GetStatus() {
mutex_lock l(mu_);
return status_;
}
void RegisterTaskAsync(CallOptions* opts, const RegisterTaskRequest* request,
RegisterTaskResponse* response,
StatusCallback done) override {
done(OkStatus());
}
void ReportErrorToTaskAsync(CallOptions* call_opts,
const ReportErrorToTaskRequest* request,
ReportErrorToTaskResponse* response,
StatusCallback done) override {
mutex_lock l(mu_);
status_ = Status(static_cast<errors::Code>(request->error_code()),
request->error_message());
done(OkStatus());
}
#define UNIMPLEMENTED(method) \
void method##Async(const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
done(errors::Unimplemented(#method "Async")); \
}
UNIMPLEMENTED(WaitForAllTasks);
UNIMPLEMENTED(ResetTask);
UNIMPLEMENTED(ReportErrorToService);
UNIMPLEMENTED(InsertKeyValue);
UNIMPLEMENTED(TryGetKeyValue);
UNIMPLEMENTED(GetKeyValueDir);
UNIMPLEMENTED(DeleteKeyValue);
UNIMPLEMENTED(Barrier);
UNIMPLEMENTED(CancelBarrier);
#undef UNIMPLEMENTED
#define UNIMPLEMENTED_WITH_CALL_OPTS(method) \
void method##Async(CallOptions* call_opts, const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
done(errors::Unimplemented(#method "Async")); \
}
UNIMPLEMENTED_WITH_CALL_OPTS(GetKeyValue);
UNIMPLEMENTED_WITH_CALL_OPTS(Heartbeat);
UNIMPLEMENTED_WITH_CALL_OPTS(ShutdownTask);
#undef UNIMPLEMENTED_WITH_CALL_OPTS
private:
mutex mu_;
Status status_ TF_GUARDED_BY(mu_);
};
class TestCoordinationClientCache : public CoordinationClientCache {
public:
void AddTask(const std::string& target, CoordinationClient* client) {
clients_.emplace(target, client);
}
CoordinationClient* GetClient(const string& target) override {
auto it = clients_.find(target);
if (it == clients_.end()) return nullptr;
return it->second;
}
std::unique_ptr<CoordinationClient> GetOwnedClient(
const string& target) override {
LOG(ERROR) << "GetOwnedClient is not supported.";
return nullptr;
}
private:
std::unordered_map<std::string, CoordinationClient*> clients_;
};
class CoordinationBarrierTest : public ::testing::Test {
protected:
CoordinationBarrierTest() {
// Set up fake cluster with 3 tasks.
const int num_tasks = 3;
const ServerDef& server_def = GetMultiClientServerDef("worker", num_tasks);
auto client_cache = std::make_unique<TestCoordinationClientCache>();
for (int i = 0; i < num_tasks; ++i) {
CoordinatedTask task;
task.set_job_name("worker");
task.set_task_id(i);
auto client = std::make_unique<TestCoordinationClient>();
client_cache->AddTask(absl::StrCat("/job:worker/replica:0/task:", i),
client.get());
tasks_.push_back(task);
clients_.push_back(std::move(client));
}
coord_service_ = CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
std::move(client_cache));
// Register the tasks.
for (int i = 0; i < num_tasks; ++i) {
Status s = coord_service_->RegisterTask(tasks_[i], /*incarnation=*/0);
if (!s.ok()) {
LOG(FATAL) << "RegisterTask() failed in CoordinationBarrierTest(): "
<< s;
}
}
}
CoordinationServiceInterface* GetCoordinationService() {
return coord_service_.get();
}
CoordinatedTask GetTask(int i) { return tasks_[i]; }
private:
std::unique_ptr<CoordinationServiceInterface> coord_service_;
std::vector<CoordinatedTask> tasks_;
std::vector<std::unique_ptr<TestCoordinationClient>> clients_;
};
// Sets up coordination service that expects 2 worker tasks.
class CoordinateTwoTasksTest : public ::testing::Test {
protected:
CoordinateTwoTasksTest() {
task_0_.set_job_name("worker");
task_0_.set_task_id(0);
task_1_.set_job_name("worker");
task_1_.set_task_id(1);
}
// Set up coordination service.
void EnableCoordinationService(bool has_service_to_client_connection = true,
bool enable_shutdown_barrier = false) {
ServerDef server_def = GetMultiClientServerDef("worker", /*num_tasks=*/2);
auto client_cache = std::make_unique<TestCoordinationClientCache>();
if (has_service_to_client_connection) {
client_cache->AddTask("/job:worker/replica:0/task:0", &client_0_);
client_cache->AddTask("/job:worker/replica:0/task:1", &client_1_);
} else {
client_cache = nullptr;
}
auto coord_config = server_def.mutable_default_session_config()
->mutable_experimental()
->mutable_coordination_config();
coord_config->set_service_type(kCoordinationServiceType);
coord_config->set_heartbeat_timeout_in_ms(kHeartbeatTimeout /
absl::Milliseconds(1));
if (enable_shutdown_barrier) {
coord_config->set_shutdown_barrier_timeout_in_ms(kShutdownBarrierTimeout /
absl::Milliseconds(1));
}
// Init service.
coord_service_ = CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
std::move(client_cache));
}
CoordinatedTask task_0_;
const uint64_t incarnation_0_ = random::New64();
TestCoordinationClient client_0_;
CoordinatedTask task_1_;
const uint64_t incarnation_1_ = random::New64();
TestCoordinationClient client_1_;
std::unique_ptr<CoordinationServiceInterface> coord_service_;
};
// Construct fake device protos.
DeviceAttributes CreateTestTfDevice(absl::string_view name) {
DeviceAttributes device;
device.set_name(name);
device.set_device_type("CPU");
return device;
}
xla::DeviceProto CreateTestXlaDevice(absl::string_view name,
const int local_id) {
xla::DeviceProto device;
device.set_name(name);
device.set_local_device_ordinal(local_id);
return device;
}
TEST_F(CoordinateTwoTasksTest, TestStandaloneService) {
EnableCoordinationService();
// Not specified in server def.
CoordinatedTask task_2;
task_2.set_job_name("worker");
task_2.set_task_id(2);
TF_ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
absl::Notification wait_for_all;
coord_service_->WaitForAllTasks(task_0_, {}, [&](Status s) {
TF_ASSERT_OK(s);
wait_for_all.Notify();
});
// Not all tasks have registered, so must not be notified here.
ASSERT_FALSE(wait_for_all.HasBeenNotified());
TF_ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
coord_service_->WaitForAllTasks(task_1_, {},
[&](Status s) { TF_ASSERT_OK(s); });
// All tasks have registered.
wait_for_all.WaitForNotification();
TF_ASSERT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_));
TF_ASSERT_OK(coord_service_->RecordHeartbeat(task_1_, incarnation_1_));
EXPECT_TRUE(
errors::IsInvalidArgument(coord_service_->RecordHeartbeat(task_2, 0)));
// Sending heartbeat with incarnation mismatch leads to Aborted error.
EXPECT_TRUE(errors::IsAborted(coord_service_->RecordHeartbeat(task_1_, 0)));
EXPECT_TRUE(errors::IsAborted(coord_service_->RecordHeartbeat(task_1_, 0)));
// Error is propagated to other tasks.
EXPECT_TRUE(errors::IsAborted(client_0_.GetStatus()));
}
TEST(CoordinationServiceTest, TestCoordinatedJobs) {
ServerDef server_def = GetMultiClientServerDef("chief", 1);
CoordinatedTask chief;
chief.set_job_name("chief");
chief.set_task_id(0);
CoordinatedTask task_0;
task_0.set_job_name("worker");
task_0.set_task_id(0);
CoordinatedTask task_1;
task_1.set_job_name("worker");
task_1.set_task_id(1);
CoordinatedTask evaluator;
evaluator.set_job_name("evaluator");
evaluator.set_task_id(0);
// Add a worker job with 2 tasks
ClusterDef* cluster_def = server_def.mutable_cluster();
JobDef* job_def = cluster_def->add_job();
job_def->set_name("worker");
job_def->mutable_tasks()->insert({0, "dummy address"});
job_def->mutable_tasks()->insert({1, "dummy address"});
// Add an evaluator job with 1 task
job_def = cluster_def->add_job();
job_def->set_name("evaluator");
job_def->mutable_tasks()->insert({0, "dummy address"});
CoordinationServiceConfig* configs =
server_def.mutable_default_session_config()
->mutable_experimental()
->mutable_coordination_config();
configs->mutable_coordinated_jobs()->Add("chief");
configs->mutable_coordinated_jobs()->Add("worker");
auto client_cache = std::make_unique<TestCoordinationClientCache>();
TestCoordinationClient ci;
client_cache->AddTask("/job:chief/replica:0/task:0", &ci);
TestCoordinationClient wi0;
client_cache->AddTask("/job:worker/replica:0/task:0", &wi0);
TestCoordinationClient wi1;
client_cache->AddTask("/job:worker/replica:0/task:1", &wi1);
TestCoordinationClient ei;
client_cache->AddTask("/job:evaluator/replica:0/task:0", &ei);
std::unique_ptr<CoordinationServiceInterface> coord_service =
CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
std::move(client_cache));
// Each coordinated task registers and waits for other tasks.
absl::Notification register_chief;
TF_ASSERT_OK(coord_service->RegisterTask(chief, /*incarnation=*/0));
coord_service->WaitForAllTasks(chief, {}, [&](Status s) {
TF_ASSERT_OK(s);
register_chief.Notify();
});
absl::Notification register_task0;
TF_ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0));
coord_service->WaitForAllTasks(task_0, {}, [&](Status s) {
TF_ASSERT_OK(s);
register_task0.Notify();
});
absl::Notification register_task1;
TF_ASSERT_OK(coord_service->RegisterTask(task_1, /*incarnation=*/0));
coord_service->WaitForAllTasks(task_1, {}, [&](Status s) {
TF_ASSERT_OK(s);
register_task1.Notify();
});
// All tasks in the coordinated jobs have registered.
register_chief.WaitForNotification();
register_task0.WaitForNotification();
register_task1.WaitForNotification();
// Registering the evaluator task is unexpected
Status status = coord_service->RegisterTask(evaluator, /*incarnation=*/0);
EXPECT_TRUE(errors::IsInvalidArgument(status)) << status;
}
TEST(CoordinationServiceTest, RegisterTask_AlreadyConnected_Fails) {
ServerDef server_def = GetMultiClientServerDef("worker", 1);
JobDef* job_def = server_def.mutable_cluster()->add_job();
job_def->set_name("worker");
job_def->mutable_tasks()->insert({0, "dummy address"});
CoordinatedTask task_0;
task_0.set_job_name("worker");
task_0.set_task_id(0);
std::unique_ptr<CoordinationServiceInterface> coord_service =
CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
/*cache=*/nullptr);
// Task connects to coordination service.
TF_ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0));
// Registration should fail since task already registered previously.
const Status status = coord_service->RegisterTask(task_0, /*incarnation=*/0);
EXPECT_TRUE(errors::IsAborted(status)) << status;
}
TEST(CoordinationServiceTest, RegisterTask_AlreadyInError_Fails) {
ServerDef server_def = GetMultiClientServerDef("worker", 1);
JobDef* job_def = server_def.mutable_cluster()->add_job();
job_def->set_name("worker");
job_def->mutable_tasks()->insert({0, "dummy address"});
CoordinatedTask task_0;
task_0.set_job_name("worker");
task_0.set_task_id(0);
std::unique_ptr<CoordinationServiceInterface> coord_service =
CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
/*cache=*/nullptr);
// Task connects to coordination service.
TF_ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0));
// Arbitrarily set task to be in error.
TF_ASSERT_OK(
coord_service->ReportTaskError(task_0, errors::Internal("test_error")));
// Registration should fail since task already registered previously.
const Status status = coord_service->RegisterTask(task_0, /*incarnation=*/0);
EXPECT_TRUE(errors::IsAborted(status)) << status;
}
TEST_F(CoordinateTwoTasksTest, TestTaskHeartbeatTimeout) {
EnableCoordinationService();
TF_ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
// No heartbeat for a while, leader considers the task as stale.
Env::Default()->SleepForMicroseconds(
absl::ToInt64Microseconds(2 * kHeartbeatTimeout));
EXPECT_TRUE(errors::IsUnavailable(
coord_service_->RecordHeartbeat(task_0_, incarnation_0_)));
EXPECT_TRUE(errors::IsUnavailable(
coord_service_->RecordHeartbeat(task_1_, incarnation_1_)));
}
TEST_F(CoordinateTwoTasksTest,
HeartbeatTimeoutWithoutServerToClientConnection) {
EnableCoordinationService(/*has_service_to_client_connection=*/false);
TF_ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
// No heartbeat for a while, leader consider the task as stale.
// Service stops and disconnects both tasks.
Env::Default()->SleepForMicroseconds(
absl::ToInt64Microseconds(2 * kHeartbeatTimeout));
// Unexpected heartbeat from unregistered tasks since service state has been
// reset.
EXPECT_TRUE(errors::IsInvalidArgument(
coord_service_->RecordHeartbeat(task_0_, incarnation_0_)));
EXPECT_TRUE(errors::IsInvalidArgument(
coord_service_->RecordHeartbeat(task_1_, incarnation_1_)));
}
TEST_F(CoordinateTwoTasksTest, TestTaskRestart) {
EnableCoordinationService();
TF_ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
// Simulate task restart scenario: trying to register to cluster again.
Status s =
coord_service_->RegisterTask(task_1_, /*incarnation=*/random::New64());
EXPECT_TRUE(errors::IsAborted(s)) << s;
// Aborted error is also propagated to other tasks in cluster.
EXPECT_TRUE(errors::IsAborted(client_0_.GetStatus()))
<< client_0_.GetStatus();
}
TEST_F(CoordinateTwoTasksTest, TestSetGetValues) {
EnableCoordinationService();
// Simple key
TF_ASSERT_OK(coord_service_->InsertKeyValue("key0", "value0"));
// Unix file like key path
TF_ASSERT_OK(coord_service_->InsertKeyValue("/path", "value"));
TF_ASSERT_OK(coord_service_->InsertKeyValue("/path/to/key1", "value1"));
// Key with redundant slashes
TF_ASSERT_OK(coord_service_->InsertKeyValue("path/to//key2/", "value2"));
// Error when repeatedly inserting the same key
EXPECT_TRUE(errors::IsAlreadyExists(
coord_service_->InsertKeyValue("/path/to/key1/", "value2")));
// Get simple key
absl::Notification n1;
StatusOr<std::string> ret;
coord_service_->GetKeyValueAsync(
"key0", [&](const StatusOr<std::string>& status_or_value) {
ret = status_or_value;
n1.Notify();
});
n1.WaitForNotification();
TF_ASSERT_OK(ret.status());
EXPECT_EQ(ret.ValueOrDie(), "value0");
// Get key with redundant slashes
absl::Notification n2;
coord_service_->GetKeyValueAsync(
"path//to///key1////", [&](const StatusOr<std::string>& status_or_value) {
ret = status_or_value;
n2.Notify();
});
n2.WaitForNotification();
EXPECT_EQ(ret.ValueOrDie(), "value1");
// Delete single key-value
TF_ASSERT_OK(coord_service_->DeleteKeyValue("key0"));
// Get key that is not available
absl::Notification n3;
coord_service_->GetKeyValueAsync(
"key0", [&](const StatusOr<std::string>& status_or_value) {
ret = status_or_value;
n3.Notify();
});
EXPECT_FALSE(n3.HasBeenNotified());
// Insert the previously deleted key again
TF_ASSERT_OK(coord_service_->InsertKeyValue("key0", "value0_new"));
n3.WaitForNotification();
EXPECT_EQ(ret.ValueOrDie(), "value0_new");
// Delete key-values recursively
TF_ASSERT_OK(coord_service_->DeleteKeyValue("/path"));
// Get key that is not available
auto n4 = std::make_shared<absl::Notification>();
coord_service_->GetKeyValueAsync(
"/path/to/key1",
// Note: this callback will remain pending until it is cleaned up during
// service shutdown. Hence, we use a shared pointer for notification so
// that the it will not be deallocated before the pending callback is
// cleaned up.
[n4](const StatusOr<std::string>& status_or_value) { n4->Notify(); });
EXPECT_FALSE(n4->HasBeenNotified());
}
TEST(CoordinationServiceTest, TryGetKeyValue) {
const ServerDef& server_def = GetMultiClientServerDef("worker", 1);
auto client_cache = std::make_unique<TestCoordinationClientCache>();
std::unique_ptr<CoordinationServiceInterface> coord_service =
CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
std::move(client_cache));
// Try to get nonexistent key.
StatusOr<std::string> result = coord_service->TryGetKeyValue("test_key");
EXPECT_TRUE(errors::IsNotFound(result.status()));
// Insert key value.
TF_ASSERT_OK(coord_service->InsertKeyValue("test_key", "test_value"));
result = coord_service->TryGetKeyValue("test_key");
EXPECT_EQ(result.ValueOrDie(), "test_value");
// Delete Key, and try to get the key again.
TF_ASSERT_OK(coord_service->DeleteKeyValue("test_key"));
result = coord_service->TryGetKeyValue("test_key");
EXPECT_TRUE(errors::IsNotFound(result.status()));
}
TEST_F(CoordinateTwoTasksTest, GetKeyValueDir_SingleValueInDirectory) {
EnableCoordinationService();
KeyValueEntry kv = CreateKv("dir/path", "value0");
TF_ASSERT_OK(coord_service_->InsertKeyValue(kv.key(), kv.value()));
std::vector<KeyValueEntry> result = coord_service_->GetKeyValueDir("dir");
EXPECT_THAT(result, UnorderedElementsAre(EqualsProto(kv)));
}
TEST_F(CoordinateTwoTasksTest, GetKeyValueDir_MultipleValuesInDirectory) {
EnableCoordinationService();
KeyValueEntry kv = CreateKv("dir/path", "value0");
KeyValueEntry kv2 = CreateKv("dir/path2", "value1");
// Placed in nested subdirectory.
KeyValueEntry kv_sub = CreateKv("dir/sub_dir/path", "value_sub");
TF_ASSERT_OK(coord_service_->InsertKeyValue(kv.key(), kv.value()));
TF_ASSERT_OK(coord_service_->InsertKeyValue(kv2.key(), kv2.value()));
TF_ASSERT_OK(coord_service_->InsertKeyValue(kv_sub.key(), kv_sub.value()));
std::vector<KeyValueEntry> result = coord_service_->GetKeyValueDir("dir");
EXPECT_THAT(result, UnorderedElementsAre(EqualsProto(kv), EqualsProto(kv2),
EqualsProto(kv_sub)));
}
TEST_F(CoordinateTwoTasksTest, GetKeyValueDir_Empty_ReturnsEmptyList) {
EnableCoordinationService();
std::vector<KeyValueEntry> result = coord_service_->GetKeyValueDir("dir");
EXPECT_THAT(result, IsEmpty());
}
TEST_F(CoordinateTwoTasksTest, GetKeyValueDir_WrongDir_ReturnsEmptyList) {
EnableCoordinationService();
// Wrong directory.
TF_ASSERT_OK(coord_service_->InsertKeyValue("dir0/path", "value0"));
std::vector<KeyValueEntry> result = coord_service_->GetKeyValueDir("dir");
EXPECT_THAT(result, IsEmpty());
}
TEST_F(CoordinateTwoTasksTest, GetKeyValueDir_WrongDirPrefix_ReturnsEmptyList) {
EnableCoordinationService();
// Check that we don't match with nested subdirectories with the wrong prefix.
TF_ASSERT_OK(coord_service_->InsertKeyValue("wrong_dir/dir/path", "value0"));
std::vector<KeyValueEntry> result = coord_service_->GetKeyValueDir("dir");
EXPECT_THAT(result, IsEmpty());
}
TEST_F(CoordinateTwoTasksTest,
GetKeyValueDir_NonDirectoryPrefix_ReturnsEmptyList) {
EnableCoordinationService();
// Wrong directory.
TF_ASSERT_OK(coord_service_->InsertKeyValue("dir_key", "value0"));
std::vector<KeyValueEntry> result = coord_service_->GetKeyValueDir("dir");
EXPECT_THAT(result, IsEmpty());
}
TEST_F(CoordinateTwoTasksTest,
GetKeyValueDir_NonDirectoryKey_ReturnsEmptyList) {
EnableCoordinationService();
// Insert same key that is not a directory.
TF_ASSERT_OK(coord_service_->InsertKeyValue("dir", "value0"));
std::vector<KeyValueEntry> result = coord_service_->GetKeyValueDir("dir");
EXPECT_THAT(result, IsEmpty());
}
} // namespace
// Verify that coordination service can gather each task's device info and
// propagate the aggregated cluster device info correctly.
TEST(CoordinationServiceTest, ListClusterDevices_TfDevice) {
const ServerDef& server_def = GetMultiClientServerDef("worker", 3);
CoordinatedTask task_0;
task_0.set_job_name("worker");
task_0.set_task_id(0);
CoordinatedTask task_1;
task_1.set_job_name("worker");
task_1.set_task_id(1);
CoordinatedTask task_2;
task_2.set_job_name("worker");
task_2.set_task_id(2);
Status status = OkStatus();
auto client_cache = std::make_unique<TestCoordinationClientCache>();
std::unique_ptr<CoordinationServiceInterface> coord_service =
CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
std::move(client_cache));
absl::Notification n;
// Map fake devices to each task.
CoordinationServiceDeviceInfo local_devices_0;
CoordinationServiceDeviceInfo local_devices_1;
CoordinationServiceDeviceInfo local_devices_2;
*local_devices_0.mutable_tf()->mutable_devices()->Add() =
CreateTestTfDevice("task0_device0");
*local_devices_0.mutable_tf()->mutable_devices()->Add() =
CreateTestTfDevice("task0_device1");
*local_devices_1.mutable_tf()->mutable_devices()->Add() =
CreateTestTfDevice("task1_device0");
*local_devices_2.mutable_tf()->mutable_devices()->Add() =
CreateTestTfDevice("task2_device0");
// Each task sends its device info.
CoordinationServiceDeviceInfo cluster_devices;
coord_service->WaitForAllTasks(task_0, local_devices_0,
[&](Status s) { TF_ASSERT_OK(s); });
coord_service->WaitForAllTasks(task_1, local_devices_1,
[&](Status s) { TF_ASSERT_OK(s); });
coord_service->WaitForAllTasks(task_2, local_devices_2, [&](Status s) {
TF_ASSERT_OK(s);
// Gather the cluster device info.
cluster_devices = coord_service->ListClusterDevices();
n.Notify();
});
n.WaitForNotification();
CoordinationServiceDeviceInfo expected_cluster_devices;
auto expected_devices =
expected_cluster_devices.mutable_tf()->mutable_devices();
expected_devices->Add(local_devices_0.mutable_tf()->devices().begin(),
local_devices_0.mutable_tf()->devices().end());
expected_devices->Add(local_devices_1.mutable_tf()->devices().begin(),
local_devices_1.mutable_tf()->devices().end());
expected_devices->Add(local_devices_2.mutable_tf()->devices().begin(),
local_devices_2.mutable_tf()->devices().end());
EXPECT_THAT(cluster_devices, IgnoringRepeatedFieldOrdering(
EqualsProto(expected_cluster_devices)));
}
TEST(CoordinationServiceTest, ListClusterDevices_XlaDevice) {
const ServerDef& server_def = GetMultiClientServerDef("worker", 3);
CoordinatedTask task_0;
task_0.set_job_name("worker");
task_0.set_task_id(0);
CoordinatedTask task_1;
task_1.set_job_name("worker");
task_1.set_task_id(1);
CoordinatedTask task_2;
task_2.set_job_name("worker");
task_2.set_task_id(2);
Status status = OkStatus();
auto client_cache = std::make_unique<TestCoordinationClientCache>();
std::unique_ptr<CoordinationServiceInterface> coord_service =
CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
std::move(client_cache));
absl::Notification n;
// Map fake devices to each task.
CoordinationServiceDeviceInfo local_devices_0;
CoordinationServiceDeviceInfo local_devices_1;
CoordinationServiceDeviceInfo local_devices_2;
xla::LocalTopologyProto local_0;
xla::LocalTopologyProto local_1;
xla::LocalTopologyProto local_2;
local_0.set_node_id(0);
local_1.set_node_id(1);
local_2.set_node_id(2);
*local_0.add_devices() = CreateTestXlaDevice("task0_device0", 0);
*local_0.add_devices() = CreateTestXlaDevice("task0_device1", 1);
*local_1.add_devices() = CreateTestXlaDevice("task1_device0", 0);
*local_2.add_devices() = CreateTestXlaDevice("task2_device0", 0);
*local_devices_0.mutable_xla()->mutable_devices()->add_nodes() = local_0;
*local_devices_1.mutable_xla()->mutable_devices()->add_nodes() = local_1;
*local_devices_2.mutable_xla()->mutable_devices()->add_nodes() = local_2;
// Each task sends its device info.
CoordinationServiceDeviceInfo cluster_devices;
coord_service->WaitForAllTasks(task_0, local_devices_0,
[&](Status s) { TF_ASSERT_OK(s); });
coord_service->WaitForAllTasks(task_1, local_devices_1,
[&](Status s) { TF_ASSERT_OK(s); });
coord_service->WaitForAllTasks(task_2, local_devices_2, [&](Status s) {
TF_ASSERT_OK(s);
// Gather the cluster device info.
cluster_devices = coord_service->ListClusterDevices();
n.Notify();
});
n.WaitForNotification();
CoordinationServiceDeviceInfo expected_cluster_devices;
local_0.mutable_devices(0)->set_global_device_id(0);
local_0.mutable_devices(1)->set_global_device_id(1);
local_1.mutable_devices(0)->set_global_device_id(2);
local_2.mutable_devices(0)->set_global_device_id(3);
*expected_cluster_devices.mutable_xla()->mutable_devices()->add_nodes() =
local_0;
*expected_cluster_devices.mutable_xla()->mutable_devices()->add_nodes() =
local_1;
*expected_cluster_devices.mutable_xla()->mutable_devices()->add_nodes() =
local_2;
EXPECT_THAT(cluster_devices, IgnoringRepeatedFieldOrdering(
EqualsProto(expected_cluster_devices)));
}
// Task devices should not be added twice if same task calls WaitForAllDevices()
// twice.
TEST(CoordinationServiceTest, ListClusterDevices_DevicesAreNotAddedTwice) {
const ServerDef& server_def = GetMultiClientServerDef("worker", 2);
CoordinatedTask task_0;
task_0.set_job_name("worker");
task_0.set_task_id(0);
CoordinatedTask task_1;
task_1.set_job_name("worker");
task_1.set_task_id(1);
Status status = OkStatus();
auto client_cache = std::make_unique<TestCoordinationClientCache>();
std::unique_ptr<CoordinationServiceInterface> coord_service =
CoordinationServiceInterface::EnableCoordinationService(
kCoordinationServiceType, Env::Default(), server_def,
std::move(client_cache));
absl::Notification n;
// Map fake devices to each task.
CoordinationServiceDeviceInfo local_devices_0;
CoordinationServiceDeviceInfo local_devices_1;
*local_devices_0.mutable_tf()->mutable_devices()->Add() =
CreateTestTfDevice("task0_device0");
*local_devices_0.mutable_tf()->mutable_devices()->Add() =
CreateTestTfDevice("task0_device1");
*local_devices_1.mutable_tf()->mutable_devices()->Add() =
CreateTestTfDevice("task1_device0");
// Task0 sends device info.
CoordinationServiceDeviceInfo cluster_devices;
coord_service->WaitForAllTasks(task_0, local_devices_0,
[](Status s) { TF_ASSERT_OK(s); });
// Task0 sends device info sgain.
coord_service->WaitForAllTasks(task_0, local_devices_0,
[](Status s) { TF_ASSERT_OK(s); });
coord_service->WaitForAllTasks(
task_1, local_devices_1,
[coord_service = coord_service.get(), &cluster_devices, &n](Status s) {
TF_ASSERT_OK(s);
// Gather the cluster device info.
cluster_devices = coord_service->ListClusterDevices();
n.Notify();
});
n.WaitForNotification();
// No duplicates found.
CoordinationServiceDeviceInfo expected_cluster_devices;
auto expected_devices =
expected_cluster_devices.mutable_tf()->mutable_devices();
expected_devices->Add(local_devices_0.mutable_tf()->devices().begin(),
local_devices_0.mutable_tf()->devices().end());
expected_devices->Add(local_devices_1.mutable_tf()->devices().begin(),
local_devices_1.mutable_tf()->devices().end());
EXPECT_THAT(cluster_devices, IgnoringRepeatedFieldOrdering(
EqualsProto(expected_cluster_devices)));
}
TEST_F(CoordinationBarrierTest, Barrier) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
Status barrier_status_1;
Status barrier_status_2;
absl::Notification n_0;
absl::Notification n_1;
absl::Notification n_2;
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status_0, &n_0](Status s) {
barrier_status_0 = s;
n_0.Notify();
});
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{},
[&barrier_status_1, &n_1](Status s) {
barrier_status_1 = s;
n_1.Notify();
});
// Make sure barrier has not been exited prematurely.
EXPECT_FALSE(n_0.HasBeenNotified());
EXPECT_FALSE(n_1.HasBeenNotified());
EXPECT_FALSE(n_2.HasBeenNotified());
// Last task calls the barrier.
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(2),
/*participating_tasks=*/{},
[&barrier_status_2, &n_2](Status s) {
barrier_status_2 = s;
n_2.Notify();
});
EXPECT_TRUE(n_0.HasBeenNotified());
EXPECT_TRUE(n_1.HasBeenNotified());
EXPECT_TRUE(n_2.HasBeenNotified());
TF_EXPECT_OK(barrier_status_0);
TF_EXPECT_OK(barrier_status_1);
TF_EXPECT_OK(barrier_status_2);
}
TEST_F(CoordinationBarrierTest, BarrierWithSubsetOfTasks) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
Status barrier_status_1;
absl::Notification n_0;
absl::Notification n_1;
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_0, &n_0](Status s) {
barrier_status_0 = s;
n_0.Notify();
});
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_1, &n_1](Status s) {
barrier_status_1 = s;
n_1.Notify();
});
// All listed tasks passed the barrier.
EXPECT_TRUE(n_0.HasBeenNotified());
EXPECT_TRUE(n_1.HasBeenNotified());
TF_EXPECT_OK(barrier_status_0);
TF_EXPECT_OK(barrier_status_1);
}
TEST_F(CoordinationBarrierTest, BarrierWithMismatchedTasks) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
Status barrier_status_1;
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_0](Status s) { barrier_status_0 = s; });
// task_1's barrier call specified a conflicting set of tasks (task_2 instead
// of task_0).
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{GetTask(1), GetTask(2)},
[&barrier_status_1](Status s) { barrier_status_1 = s; });
EXPECT_TRUE(errors::IsInvalidArgument(barrier_status_0));
EXPECT_TRUE(errors::IsInvalidArgument(barrier_status_1));
}
TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
Status barrier_status_1;
absl::Notification n_0;
absl::Notification n_1;
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_0](Status s) { barrier_status_0 = s; });
// Task 2 unexpectedly calls a barrier that it is not participating in.
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(2),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_1](Status s) { barrier_status_1 = s; });
// Barrier should fail for all tasks with the unexpected call.
EXPECT_TRUE(errors::IsInvalidArgument(barrier_status_0));
EXPECT_TRUE(errors::IsInvalidArgument(barrier_status_1));
}
TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
absl::Notification n_0;
CoordinatedTask unspecified_task;
unspecified_task.set_job_name("task_from_another_cluster");
unspecified_task.set_task_id(2);
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{GetTask(0), unspecified_task},
[&barrier_status_0, &n_0](Status s) {
barrier_status_0 = s;
n_0.Notify();
});
n_0.WaitForNotification();
// Barrier should fail with the unexpected participating task argument.
EXPECT_TRUE(errors::IsInvalidArgument(barrier_status_0));
}
TEST_F(CoordinationBarrierTest, BarrierTimeout) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(1);
Status barrier_status_0;
absl::Notification n_0;
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status_0, &n_0](Status s) {
barrier_status_0 = s;
n_0.Notify();
});
// Block until user-specified timeout.
n_0.WaitForNotification();
EXPECT_TRUE(errors::IsDeadlineExceeded(barrier_status_0));
}
TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(1);
Status barrier_status_0;
Status barrier_status_1;
absl::Notification n_0;
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status_0, &n_0](Status s) {
barrier_status_0 = s;
n_0.Notify();
});
TF_ASSERT_OK(GetCoordinationService()->ReportTaskError(
GetTask(0), errors::Internal("test_error")));
// Block until barrier has failed due to task error.
n_0.WaitForNotification();
// Same response should be returned immediately.
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{},
[&barrier_status_1](Status s) { barrier_status_1 = s; });
EXPECT_TRUE(errors::IsInternal(barrier_status_0));
EXPECT_TRUE(errors::IsInternal(barrier_status_1));
}
TEST_F(CoordinationBarrierTest, BarrierCancelled) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status;
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status](Status s) { barrier_status = s; });
Status cancelled_status =
GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0));
EXPECT_TRUE(errors::IsCancelled(barrier_status));
TF_EXPECT_OK(cancelled_status);
}
TEST_F(CoordinationBarrierTest, CancelNonExistentBarrier_FutureBarrierFails) {
const std::string barrier_id = "cancelled_barrier_id";
absl::Duration timeout = absl::Seconds(1);
Status barrier_status;
// Cancel barrier should still succeed.
TF_ASSERT_OK(GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)));
// Calling a cancelled barrier should fail instantly.
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status](Status s) { barrier_status = s; });
EXPECT_TRUE(errors::IsCancelled(barrier_status)) << barrier_status;
}
TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
Status barrier_status_1;
Status barrier_status_2;
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status_0](Status s) { barrier_status_0 = s; });
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{},
[&barrier_status_1](Status s) { barrier_status_1 = s; });
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(2),
/*participating_tasks=*/{},
[&barrier_status_2](Status s) { barrier_status_2 = s; });
// Cancel barrier should fail if barrier has already been passed.
Status cancelled_status =
GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0));
EXPECT_TRUE(errors::IsFailedPrecondition(cancelled_status));
TF_EXPECT_OK(barrier_status_0);
TF_EXPECT_OK(barrier_status_1);
TF_EXPECT_OK(barrier_status_2);
}
TEST_F(CoordinationBarrierTest, PassedBarrierReturnsImmediately) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
Status barrier_status_1;
Status barrier_status_2;
Status barrier_status_repeat;
absl::Notification n0;
absl::Notification n1;
absl::Notification n2;
absl::Notification n_repeat;
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status_0, &n0](Status s) {
barrier_status_0 = s;
n0.Notify();
});
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{},
[&barrier_status_1, &n1](Status s) {
barrier_status_1 = s;
n1.Notify();
});
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(2),
/*participating_tasks=*/{},
[&barrier_status_2, &n2](Status s) {
barrier_status_2 = s;
n2.Notify();
});
// Repeated call should return the same result.
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{},
[&barrier_status_repeat, &n_repeat](Status s) {
barrier_status_repeat = s;
n_repeat.Notify();
});
EXPECT_TRUE(n0.HasBeenNotified());
EXPECT_TRUE(n1.HasBeenNotified());
EXPECT_TRUE(n2.HasBeenNotified());
EXPECT_TRUE(n_repeat.HasBeenNotified());
TF_EXPECT_OK(barrier_status_0);
TF_EXPECT_OK(barrier_status_1);
TF_EXPECT_OK(barrier_status_2);
TF_EXPECT_OK(barrier_status_repeat);
}
TEST_F(CoordinationBarrierTest, BarrierFailsIfTaskIsAlreadyInError) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
// Set task 0 to error state.
TF_ASSERT_OK(GetCoordinationService()->ReportTaskError(
GetTask(0), errors::Internal("test_error")));
Status barrier_status;
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{},
[&barrier_status](Status s) { barrier_status = s; });
EXPECT_TRUE(errors::IsInternal(barrier_status));
}
TEST_F(CoordinationBarrierTest, BarrierFailsUponTaskError) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
absl::Notification n0;
Status barrier_status;
GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{},
[&barrier_status, &n0](Status s) {
barrier_status = s;
n0.Notify();
});
TF_ASSERT_OK(GetCoordinationService()->ReportTaskError(
GetTask(0), errors::Internal("test_error")));
n0.WaitForNotification();
EXPECT_TRUE(errors::IsInternal(barrier_status));
}
TEST_F(CoordinationBarrierTest,
BarrierStillBlocksIfSameTaskCallsOngoingBarrierRepeatedly) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Status barrier_status_0;
Status barrier_status_1;
Status barrier_status_2;
absl::Notification n_0;
absl::Notification n_1;
absl::Notification n_2;
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_0, &n_0](Status s) {
barrier_status_0 = s;
n_0.Notify();
});
// Duplicate call.
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(0),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_1, &n_1](Status s) {
barrier_status_1 = s;
n_1.Notify();
});
// All listed tasks passed the barrier.
EXPECT_FALSE(n_0.HasBeenNotified());
EXPECT_FALSE(n_1.HasBeenNotified());
GetCoordinationService()->BarrierAsync(
barrier_id, timeout, GetTask(1),
/*participating_tasks=*/{GetTask(0), GetTask(1)},
[&barrier_status_2, &n_2](Status s) {
barrier_status_2 = s;
n_2.Notify();
});
TF_EXPECT_OK(barrier_status_0);
TF_EXPECT_OK(barrier_status_1);
TF_EXPECT_OK(barrier_status_2);
}
TEST_F(CoordinateTwoTasksTest, ResetAndRegisterAgain) {
EnableCoordinationService();
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_EXPECT_OK(coord_service_->ResetTask(task_0_));
// Task should be allowed to register again after being reset.
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
}
TEST_F(CoordinateTwoTasksTest, Reset_HeartbeatsAreAcceptedForAGracePeriod) {
EnableCoordinationService();
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_EXPECT_OK(coord_service_->ResetTask(task_0_));
// Heartbeat should be allowed for a short grace period after reset.
TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_));
// Heartbeat failure should be triggered for disconnected task after grace
// period.
Env::Default()->SleepForMicroseconds(
absl::ToInt64Microseconds(3 * kHeartbeatTimeout));
EXPECT_TRUE(errors::IsInvalidArgument(
coord_service_->RecordHeartbeat(task_0_, incarnation_0_)));
}
TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) {
EnableCoordinationService(/*has_service_to_client_connection=*/true,
/*enable_shutdown_barrier=*/false);
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
Status barrier_status;
absl::Notification barrier_n;
coord_service_->BarrierAsync(
"ongoing_barrier", absl::InfiniteDuration(), task_0_,
/*participating_tasks=*/{}, [&barrier_status, &barrier_n](Status s) {
barrier_status = s;
barrier_n.Notify();
});
TF_EXPECT_OK(coord_service_->ResetTask(task_0_));
// Ongoing barrier should fail with error after shutdown.
EXPECT_TRUE(barrier_n.HasBeenNotified());
EXPECT_TRUE(errors::IsInternal(barrier_status)) << barrier_status;
}
TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) {
EnableCoordinationService(/*has_service_to_client_connection=*/true,
/*enable_shutdown_barrier=*/false);
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
absl::Notification n;
coord_service_->ShutdownTaskAsync(task_0_, [&n](Status s) {
TF_EXPECT_OK(s);
n.Notify();
});
n.WaitForNotification();
// Heartbeat should be allowed for a short grace period after shutdown.
TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_));
// Heartbeat failure should be triggered for disconnected task after grace
// period.
Env::Default()->SleepForMicroseconds(
absl::ToInt64Microseconds(3 * kHeartbeatTimeout));
EXPECT_TRUE(errors::IsInvalidArgument(
coord_service_->RecordHeartbeat(task_0_, incarnation_0_)));
}
TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) {
EnableCoordinationService(/*has_service_to_client_connection=*/true,
/*enable_shutdown_barrier=*/false);
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
Status barrier_status;
absl::Notification barrier_n;
coord_service_->BarrierAsync(
"ongoing_barrier", absl::InfiniteDuration(), task_0_,
/*participating_tasks=*/{}, [&barrier_status, &barrier_n](Status s) {
barrier_status = s;
barrier_n.Notify();
});
absl::Notification shutdown_n;
coord_service_->ShutdownTaskAsync(task_0_, [&shutdown_n](Status s) {
TF_EXPECT_OK(s);
shutdown_n.Notify();
});
shutdown_n.WaitForNotification();
// Ongoing barrier should fail with error after shutdown.
EXPECT_TRUE(barrier_n.HasBeenNotified());
EXPECT_TRUE(errors::IsInternal(barrier_status)) << barrier_status;
}
TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierSucceeds) {
EnableCoordinationService(/*has_service_to_client_connection=*/true,
/*enable_shutdown_barrier=*/true);
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
Status barrier_status;
Status barrier_status_2;
coord_service_->ShutdownTaskAsync(
task_0_, [&barrier_status](Status s) { barrier_status = s; });
coord_service_->ShutdownTaskAsync(
task_1_, [&barrier_status_2](Status s) { barrier_status_2 = s; });
TF_EXPECT_OK(barrier_status);
TF_EXPECT_OK(barrier_status_2);
// Confirm that both tasks have disconnected.
// Note: this should not happen in prod where RegisterTask() is called after
// Shutdown(), which is prevented by agent-side logic.
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
}
TEST_F(CoordinateTwoTasksTest,
ShutdownWithBarrier_BarrierFails_TaskDisconnectsOtherTaskIsAlerted) {
EnableCoordinationService(/*has_service_to_client_connection=*/true,
/*enable_shutdown_barrier=*/true);
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
Status barrier_status;
absl::Notification n;
coord_service_->ShutdownTaskAsync(task_0_, [&n, &barrier_status](Status s) {
barrier_status = s;
n.Notify();
});
// Block until barrier times out.
n.WaitForNotification();
EXPECT_TRUE(errors::IsDeadlineExceeded(barrier_status)) << barrier_status;
// Confirm that task_0_ has disconnected.
// Note: this should not happen in prod where RegisterTask() is called after
// Shutdown(), which is prevented by agent-side logic.
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
// Other task is alerted that shutdown has been initiated without it.
Status other_task_status = client_1_.GetStatus();
EXPECT_TRUE(errors::IsInternal(other_task_status)) << other_task_status;
}
TEST_F(CoordinateTwoTasksTest,
ShutdownWithBarrier_BarrierFailsWithoutClientConnection_ServiceStops) {
EnableCoordinationService(/*has_service_to_client_connection=*/false,
/*enable_shutdown_barrier=*/true);
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
Status barrier_status;
absl::Notification n;
coord_service_->ShutdownTaskAsync(task_0_, [&n, &barrier_status](Status s) {
barrier_status = s;
n.Notify();
});
// Block until barrier times out.
n.WaitForNotification();
// Provide time for coordination service to shut down after barrier timeout.
Env::Default()->SleepForMicroseconds(
absl::ToInt64Microseconds(absl::Seconds(1)));
EXPECT_TRUE(errors::IsDeadlineExceeded(barrier_status)) << barrier_status;
// Service stops because no service-to-client connection is available for
// error propagation.
// Task 1 still sends unexpected heartbeat because it doesn't know that
// service has stopped yet, which should fail.
Status s = coord_service_->RecordHeartbeat(task_1_, incarnation_1_);
EXPECT_TRUE(errors::IsInvalidArgument(s)) << s;
}
} // namespace tensorflow