// blob: 43b23bda9bee045d543419084053ce72324f0bcc [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
#include <memory>
#include <string>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"
namespace tensorflow {
namespace {
using ::testing::_;
using ::testing::DoAll;
using ::testing::InvokeArgument;
using ::testing::SetArgPointee;
using ::testing::UnorderedPointwise;
using ::testing::WithArgs;
// Pointwise matcher for (KeyValueEntry, KeyValueEntry) tuples: two entries
// match when both their keys and their values are equal.
MATCHER(KvEq, "simple KeyValueEntry matcher") {
  const KeyValueEntry& lhs = std::get<0>(arg);
  const KeyValueEntry& rhs = std::get<1>(arg);
  if (lhs.key() != rhs.key()) {
    return false;
  }
  return lhs.value() == rhs.value();
}
// Builds a KeyValueEntry holding `key` and `value`.
KeyValueEntry CreateKv(const std::string& key, const std::string& value) {
  KeyValueEntry entry;
  entry.set_value(value);
  entry.set_key(key);
  return entry;
}
// CoordinationClient test double used to drive CoordinationServiceAgent.
// The RPCs that tests need to control are gMock mocks (configure them with
// ON_CALL / EXPECT_CALL); every other RPC completes immediately by invoking
// its `done` callback with an Unimplemented error.
class TestCoordinationClient : public CoordinationClient {
public:
TestCoordinationClient() = default;
// MOCK_METHOD does not work on Windows build, so the deprecated fixed-arity
// MOCK_METHODn macros (MOCK_METHOD3 / MOCK_METHOD4) are used instead.
MOCK_METHOD4(GetKeyValueAsync,
void(CallOptions* call_opts, const GetKeyValueRequest*,
GetKeyValueResponse*, StatusCallback));
MOCK_METHOD3(TryGetKeyValueAsync,
void(const TryGetKeyValueRequest*, TryGetKeyValueResponse*,
StatusCallback));
MOCK_METHOD3(GetKeyValueDirAsync,
void(const GetKeyValueDirRequest*, GetKeyValueDirResponse*,
StatusCallback));
MOCK_METHOD4(RegisterTaskAsync, void(CallOptions*, const RegisterTaskRequest*,
RegisterTaskResponse*, StatusCallback));
MOCK_METHOD4(ShutdownTaskAsync, void(CallOptions*, const ShutdownTaskRequest*,
ShutdownTaskResponse*, StatusCallback));
MOCK_METHOD3(ResetTaskAsync, void(const ResetTaskRequest*, ResetTaskResponse*,
StatusCallback));
MOCK_METHOD3(ReportErrorToServiceAsync,
void(const ReportErrorToServiceRequest*,
ReportErrorToServiceResponse*, StatusCallback));
MOCK_METHOD3(BarrierAsync,
void(const BarrierRequest*, BarrierResponse*, StatusCallback));
// Defines a `<method>Async(request, response, done)` override that ignores
// its request/response and immediately calls `done` with an Unimplemented
// error whose message is "<method>Async".
#define UNIMPLEMENTED(method) \
void method##Async(const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
done(errors::Unimplemented(#method "Async")); \
}
UNIMPLEMENTED(WaitForAllTasks);
UNIMPLEMENTED(InsertKeyValue);
UNIMPLEMENTED(DeleteKeyValue);
UNIMPLEMENTED(CancelBarrier);
#undef UNIMPLEMENTED
// The RPCs below take a leading CallOptions* argument, so they cannot be
// generated by the UNIMPLEMENTED macro; they are stubbed out by hand with
// the same Unimplemented behavior.
void HeartbeatAsync(CallOptions* call_opts, const HeartbeatRequest* request,
HeartbeatResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("HeartbeatAsync"));
}
void ReportErrorToTaskAsync(CallOptions* call_opts,
const ReportErrorToTaskRequest* request,
ReportErrorToTaskResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("ReportErrorToTaskAsync"));
}
};
class CoordinationServiceAgentTest : public ::testing::Test {
public:
void SetUp() override {
ON_CALL(*client_, RegisterTaskAsync(_, _, _, _))
.WillByDefault(InvokeArgument<3>(OkStatus()));
ON_CALL(*client_, ShutdownTaskAsync(_, _, _, _))
.WillByDefault(InvokeArgument<3>(OkStatus()));
ON_CALL(*client_, ReportErrorToServiceAsync(_, _, _))
.WillByDefault(InvokeArgument<2>(OkStatus()));
ON_CALL(*client_, ResetTaskAsync(_, _, _))
.WillByDefault(InvokeArgument<2>(OkStatus()));
ON_CALL(*client_, BarrierAsync(_, _, _))
.WillByDefault(InvokeArgument<2>(OkStatus()));
}
// Should be called after mocking service responses, before testing the agent.
void InitializeAgent() {
CoordinationServiceConfig config;
config.set_service_leader("test_leader");
TF_EXPECT_OK(agent_->Initialize(
Env::Default(), /*job_name=*/"test_job",
/*task_id=*/0, config, std::move(client_),
/*error_fn=*/[](Status s) {
LOG(ERROR) << "Coordination agent is set to error: " << s;
}));
}
TestCoordinationClient* GetClient() {
// InitializeAgent() transfers ownership of the coordination client.
CHECK(client_ != nullptr)
<< "GetClient() was called after InitializeAgent()";
return client_.get();
}
protected:
std::unique_ptr<CoordinationServiceAgent> agent_ =
CreateCoordinationServiceAgent();
std::unique_ptr<TestCoordinationClient> client_ =
std::make_unique<TestCoordinationClient>();
};
// GetKeyValue() without a timeout returns the value the service replies with.
TEST_F(CoordinationServiceAgentTest, GetKeyValue_Simple_Success) {
  // Owned strings rather than const references bound to temporaries.
  const std::string test_key = "test_key";
  const std::string test_value = "test_value";
  // Mock server response: set key-value pair and invoke done callback.
  GetKeyValueResponse mocked_response;
  auto kv = mocked_response.mutable_kv();
  kv->set_key(test_key);
  kv->set_value(test_value);
  ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _))
      .WillByDefault(DoAll(SetArgPointee<2>(mocked_response),
                           InvokeArgument<3>(OkStatus())));
  // Initialize coordination agent.
  InitializeAgent();

  auto result = agent_->GetKeyValue(test_key);
  // Fatal assertion: ValueOrDie() would abort the whole test binary if the
  // status were not OK.
  TF_ASSERT_OK(result.status());
  EXPECT_EQ(result.ValueOrDie(), test_value);
}
// GetKeyValue() with a (generous) timeout returns the value the service
// replies with.
TEST_F(CoordinationServiceAgentTest, GetKeyValue_WithTimeout_Success) {
  // Owned strings rather than const references bound to temporaries.
  const std::string test_key = "test_key";
  const std::string test_value = "test_value";
  // Mock server response: set key-value pair and invoke done callback.
  GetKeyValueResponse mocked_response;
  auto kv = mocked_response.mutable_kv();
  kv->set_key(test_key);
  kv->set_value(test_value);
  ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _))
      .WillByDefault(DoAll(SetArgPointee<2>(mocked_response),
                           InvokeArgument<3>(OkStatus())));
  // Initialize coordination agent.
  InitializeAgent();

  auto result = agent_->GetKeyValue(test_key, /*timeout=*/absl::Seconds(10));
  // Fatal assertion: ValueOrDie() would abort the whole test binary if the
  // status were not OK.
  TF_ASSERT_OK(result.status());
  EXPECT_EQ(result.ValueOrDie(), test_value);
}
// GetKeyValue() returns DEADLINE_EXCEEDED when the service never responds
// within the timeout.
TEST_F(CoordinationServiceAgentTest, GetKeyValue_Timeout_ReturnError) {
  const std::string test_key = "test_key";
  StatusCallback owned_done;
  ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _))
      .WillByDefault(WithArgs<3>([&](StatusCallback done) {
        // Keep the callback alive; the mocked server never answers on its own.
        owned_done = std::move(done);
      }));
  InitializeAgent();

  auto result = agent_->GetKeyValue(test_key, /*timeout=*/absl::Seconds(1));
  EXPECT_EQ(result.status().code(), error::DEADLINE_EXCEEDED);

  // The mocked RPC must have run; otherwise invoking an empty std::function
  // below would throw std::bad_function_call instead of failing cleanly.
  ASSERT_TRUE(owned_done != nullptr);
  // Needed to tear down test safely since agent dtor would cancel pending
  // calls, which would reference deallocated call_opts.
  owned_done(errors::Cancelled("error"));
}
// A service response that arrives after GetKeyValue() has already timed out
// must not touch freed memory (checked by sanitizers, not by assertions).
TEST_F(CoordinationServiceAgentTest,
       GetKeyValue_DelayedResponse_TimeoutWithoutMemoryError) {
  const std::string test_key = "test_key";
  const std::string test_value = "test_value";
  // Initialized so an un-invoked mock is detected instead of dereferencing
  // an indeterminate pointer.
  GetKeyValueResponse* owned_response = nullptr;
  StatusCallback owned_done;
  ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _))
      .WillByDefault(WithArgs<2, 3>(
          [&](GetKeyValueResponse* response, StatusCallback done) {
            // Capture the RPC arguments so the "server" can reply after the
            // agent-side timeout has fired.
            owned_response = response;
            owned_done = std::move(done);
          }));
  // Initialize coordination service agent.
  InitializeAgent();

  auto result = agent_->GetKeyValue(test_key, /*timeout=*/absl::Seconds(1));
  EXPECT_EQ(result.status().code(), error::DEADLINE_EXCEEDED);

  // The delayed reply below is only meaningful if the RPC was issued.
  ASSERT_NE(owned_response, nullptr);
  ASSERT_TRUE(owned_done != nullptr);
  // Delayed server response: set key-value response, and invoke done callback.
  auto kv = owned_response->mutable_kv();
  kv->set_key(test_key);
  kv->set_value(test_value);
  owned_done(OkStatus());
  // No explicit test, but used to verify there is no stack-use-after-return
  // or other memory-related errors.
}
// A service response that arrives (from another thread) before the timeout
// expires is returned by GetKeyValue().
TEST_F(CoordinationServiceAgentTest,
       GetKeyValue_DelayedResponseBeforeTimeout_Success) {
  const std::string test_key = "test_key";
  const std::string test_value = "test_value";
  // State shared with the responder thread. It is declared BEFORE
  // `async_thread` so it is destroyed AFTER the thread has been joined: the
  // thread may still be returning from `owned_done(...)` after the agent has
  // already observed the response and GetKeyValue() has returned.
  GetKeyValueResponse* owned_response = nullptr;
  StatusCallback owned_done;
  std::unique_ptr<Thread> async_thread;
  ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _))
      // Setup async callback to insert key-value after a brief delay (5s)
      // before timeout (10s).
      .WillByDefault(WithArgs<2, 3>(
          [&](GetKeyValueResponse* response, StatusCallback done) {
            // Copy method arguments to prevent de-allocation before
            // triggering this async callback.
            owned_response = response;
            owned_done = done;
            async_thread = absl::WrapUnique(Env::Default()->StartThread(
                ThreadOptions(), "async_thread", [&]() {
                  // Set brief delay.
                  absl::SleepFor(absl::Seconds(5));
                  // Set key-value response, and invoke done callback.
                  auto kv = owned_response->mutable_kv();
                  kv->set_key(test_key);
                  kv->set_value(test_value);
                  owned_done(OkStatus());
                }));
          }));
  InitializeAgent();

  auto result = agent_->GetKeyValue(test_key, /*timeout=*/absl::Seconds(10));
  // Fatal assertion: ValueOrDie() would abort the whole test binary if the
  // status were not OK.
  TF_ASSERT_OK(result.status());
  EXPECT_EQ(result.ValueOrDie(), test_value);
  // Deleting the Thread blocks until the responder has finished running.
  async_thread.reset();
}
// Cancelling the CallOptions returned by GetKeyValueAsync() delivers a
// Cancelled status to the caller's callback.
TEST_F(CoordinationServiceAgentTest, CancelGetKeyValue_Success) {
  const std::string test_key = "test_key";
  ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _))
      .WillByDefault(
          WithArgs<0, 3>([](CallOptions* call_opts, StatusCallback done) {
            // The mocked RPC never completes by itself; it only reports
            // Cancelled once the caller cancels through its CallOptions.
            auto on_cancel = [callback = std::move(done)]() {
              callback(errors::Cancelled("RPC call cancelled."));
            };
            call_opts->SetCancelCallback(std::move(on_cancel));
          }));
  InitializeAgent();

  Status observed_status;
  auto record_status =
      [&observed_status](const StatusOr<std::string>& result) {
        observed_status = result.status();
      };
  std::shared_ptr<CallOptions> get_kv_opts =
      agent_->GetKeyValueAsync(test_key, record_status);
  get_kv_opts->StartCancel();
  EXPECT_TRUE(errors::IsCancelled(observed_status)) << observed_status;

  // The cancel callback above owns the RPC's done callback; clear it so that
  // state is released. In practice, this should not be necessary.
  get_kv_opts->ClearCancelCallback();
}
// TryGetKeyValue() returns the value the service replies with.
TEST_F(CoordinationServiceAgentTest, TryGetKeyValue_Simple_Success) {
  const std::string test_key = "test_key";
  const std::string test_value = "test_value";

  // Canned server reply carrying the requested key-value pair.
  TryGetKeyValueResponse mocked_response;
  auto* entry = mocked_response.mutable_kv();
  entry->set_value(test_value);
  entry->set_key(test_key);
  ON_CALL(*GetClient(), TryGetKeyValueAsync(_, _, _))
      .WillByDefault(DoAll(SetArgPointee<1>(mocked_response),
                           InvokeArgument<2>(OkStatus())));

  InitializeAgent();
  auto lookup = agent_->TryGetKeyValue(test_key);
  TF_ASSERT_OK(lookup.status());
  EXPECT_EQ(lookup.ValueOrDie(), test_value);
}
// GetKeyValueDir() returns every entry the service lists under the directory,
// in any order.
TEST_F(CoordinationServiceAgentTest, GetKeyValueDir_Simple_Success) {
  const std::string test_key = "test_key_dir";
  std::vector<KeyValueEntry> test_values;
  test_values.push_back(CreateKv("test_key_dir/task_0", "0"));
  test_values.push_back(CreateKv("test_key_dir/task_1", "1"));
  // Mock server response: set key-value pair and invoke done callback.
  GetKeyValueDirResponse mocked_response;
  mocked_response.set_directory_key(test_key);
  *mocked_response.mutable_kv() = {test_values.begin(), test_values.end()};
  ON_CALL(*GetClient(), GetKeyValueDirAsync(_, _, _))
      .WillByDefault(DoAll(SetArgPointee<1>(mocked_response),
                           InvokeArgument<2>(OkStatus())));
  // Initialize coordination agent.
  InitializeAgent();

  auto result = agent_->GetKeyValueDir(test_key);
  // Fatal assertion: ValueOrDie() would abort the whole test binary if the
  // status were not OK.
  TF_ASSERT_OK(result.status());
  EXPECT_THAT(result.ValueOrDie(), UnorderedPointwise(KvEq(), test_values));
}
// Once shut down, the agent refuses to reconnect.
TEST_F(CoordinationServiceAgentTest, NotAllowedToConnectAfterShuttingDown) {
  InitializeAgent();
  TF_EXPECT_OK(agent_->Connect());
  TF_EXPECT_OK(agent_->Shutdown());

  const Status reconnect_status = agent_->Connect();
  EXPECT_TRUE(errors::IsFailedPrecondition(reconnect_status));
}
// An agent that has reported an error cannot shut down cleanly.
TEST_F(CoordinationServiceAgentTest, ShutdownInErrorShouldReturnError) {
  InitializeAgent();
  TF_EXPECT_OK(agent_->Connect());
  // Move the connected agent into the error state.
  TF_EXPECT_OK(agent_->ReportError(errors::Internal("Test Error.")));

  const Status shutdown_status = agent_->Shutdown();
  EXPECT_TRUE(errors::IsFailedPrecondition(shutdown_status));
}
// Reset() is rejected when the agent is connected but has not errored.
TEST_F(CoordinationServiceAgentTest, Reset_ConnectedButNotInError_Fail) {
  InitializeAgent();
  TF_EXPECT_OK(agent_->Connect());

  // A healthy, connected agent is not in the ERROR state, so Reset() fails.
  const auto reset_status = agent_->Reset();
  EXPECT_TRUE(errors::IsFailedPrecondition(reset_status));
}
// After an error is reported and then reset, the agent can connect again.
TEST_F(CoordinationServiceAgentTest, ConnectAfterResetError) {
  InitializeAgent();

  // Connect, then force the agent into the error state.
  TF_EXPECT_OK(agent_->Connect());
  TF_EXPECT_OK(agent_->ReportError(errors::Internal("Test Error.")));

  // Clearing the error re-enables connecting to the service.
  TF_EXPECT_OK(agent_->Reset());
  TF_EXPECT_OK(agent_->Connect());
}
// A Reset() that fails on the RPC can be retried successfully.
TEST_F(CoordinationServiceAgentTest, ResetCanBeRetried) {
  // The first ResetTask RPC fails; the second one succeeds.
  EXPECT_CALL(*GetClient(), ResetTaskAsync(_, _, _))
      .WillOnce(InvokeArgument<2>(errors::Internal("Reset error")))
      .WillOnce(InvokeArgument<2>(OkStatus()));

  InitializeAgent();
  TF_EXPECT_OK(agent_->Connect());
  TF_EXPECT_OK(agent_->ReportError(errors::Internal("Test Error.")));

  // The failed RPC surfaces as an Internal error...
  const Status first_reset = agent_->Reset();
  EXPECT_TRUE(errors::IsInternal(first_reset));
  // ...which does not prevent a retry,
  TF_EXPECT_OK(agent_->Reset());
  // and after the successful reset the agent can reconnect.
  TF_EXPECT_OK(agent_->Connect());
}
// GetOwnTask() reports the job name and task id passed to Initialize().
TEST_F(CoordinationServiceAgentTest, GetOwnTask) {
  InitializeAgent();

  auto result = agent_->GetOwnTask();
  // Fatal assertion: ValueOrDie() would abort the whole test binary if the
  // status were not OK.
  TF_ASSERT_OK(result.status());
  const CoordinatedTask& actual_task = result.ValueOrDie();
  // These fields are from the arguments used in InitializeAgent().
  CoordinatedTask expected_task;
  expected_task.set_job_name("test_job");
  expected_task.set_task_id(0);
  EXPECT_EQ(actual_task.job_name(), expected_task.job_name());
  EXPECT_EQ(actual_task.task_id(), expected_task.task_id());
}
// GetOwnTask() fails before the agent has been initialized.
TEST_F(CoordinationServiceAgentTest, GetOwnTask_Uninitialized) {
  // InitializeAgent() is deliberately not called here.
  auto own_task = agent_->GetOwnTask();
  EXPECT_TRUE(errors::IsFailedPrecondition(own_task.status()));
}
// A barrier id may only be waited on once; reusing it is rejected.
TEST_F(CoordinationServiceAgentTest, WaitAtBarrier_SameIdUsedTwice_Fails) {
  InitializeAgent();
  TF_EXPECT_OK(agent_->Connect());

  const std::string barrier_id = "only_use_once";
  // First use of the id succeeds (the fixture mocks BarrierAsync to succeed).
  TF_EXPECT_OK(
      agent_->WaitAtBarrier(barrier_id, absl::Seconds(1), /*tasks=*/{}));

  // Any later call with the same id fails.
  auto reuse_status =
      agent_->WaitAtBarrier(barrier_id, absl::Seconds(1), /*tasks=*/{});
  EXPECT_TRUE(errors::IsFailedPrecondition(reuse_status));
}
} // namespace
} // namespace tensorflow