/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/synchronization/barrier.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "grpcpp/grpcpp.h"
#include "grpcpp/server.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
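// Test parameters. Each test in this suite runs twice: once against the
// coordination service and once against the distributed runtime service
// (see INSTANTIATE_TEST_SUITE_P at the bottom of the file).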
struct ServiceParams {
std::string test_name;
// If true, the test exercises the coordination service; if false, it uses
// the distributed runtime service instead.
bool use_coordination_service = false;
};
class ClientServerTest : public testing::TestWithParam<ServiceParams> {
public:
void StartService(DistributedRuntimeServiceImpl::Options service_options,
bool use_coordination_service,
absl::string_view service_address = "") {
::grpc::ServerBuilder builder;
// Add a listening port if address is specified.
if (!service_address.empty()) {
auto credentials = ::grpc::InsecureServerCredentials();
builder.AddListeningPort(std::string(service_address), credentials);
}
// Set up and register service on the gRPC server.
if (use_coordination_service) {
coord_service_ =
std::make_unique<CoordinationServiceImpl>(service_options, &builder);
server_ = builder.BuildAndStart();
coord_service_->StartRpcThread();
} else {
distributed_runtime_service_ =
std::make_unique<DistributedRuntimeServiceImpl>(service_options);
builder.RegisterService(distributed_runtime_service_.get());
server_ = builder.BuildAndStart();
}
}
// Shut down the server.
void Stop() {
// Avoid shutting down the server twice if the test has already called
// Stop() earlier.
if (stop_is_already_called_) {
return;
}
server_->Shutdown();
stop_is_already_called_ = true;
}
void TearDown() override { Stop(); }
std::unique_ptr<::grpc::Server> server_;
private:
std::unique_ptr<CoordinationServiceImpl> coord_service_;
std::unique_ptr<DistributedRuntimeServiceImpl> distributed_runtime_service_;
bool stop_is_already_called_ = false;
};
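// Connect() and Shutdown() should act as barriers: no client may return from
// Connect() until every client has called it, and likewise for Shutdown().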
TEST_P(ClientServerTest, ConnectAndShutdownAreBarriers) {
int num_nodes = 3;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
StartService(service_options, GetParam().use_coordination_service);
absl::Mutex mu;
int connect_count = 0;
int shutdown_count = 0;
absl::Barrier barrier(num_nodes);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
// Allow the threads to call Connect one-by-one in order.
auto my_connect_turn = [&]() {
mu.AssertHeld();
return connect_count == node_id;
};
{
absl::MutexLock lock(&mu);
mu.Await(absl::Condition(&my_connect_turn));
++connect_count;
}
TF_RETURN_IF_ERROR(client->Connect());
// Verify that all of the threads have called Connect() by the time we get
// here.
{
absl::MutexLock lock(&mu);
TF_RET_CHECK(connect_count == num_nodes);
}
// Similarly for shutting down.
auto my_shutdown_turn = [&]() {
mu.AssertHeld();
return shutdown_count == node_id;
};
{
absl::MutexLock lock(&mu);
mu.Await(absl::Condition(&my_shutdown_turn));
++shutdown_count;
}
TF_RETURN_IF_ERROR(client->Shutdown());
{
absl::MutexLock lock(&mu);
TF_RET_CHECK(shutdown_count == num_nodes);
}
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
TF_EXPECT_OK(statuses[i]);
}
}
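// EnumerateDevices() should merge the per-node local topologies into a global
// topology with consecutively assigned global device ids, and key-value pairs
// set by one client should be readable by the other.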
TEST_P(ClientServerTest, ConnectAndEnumerateDevices) {
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = 2;
StartService(service_options, GetParam().use_coordination_service);
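// Node 0 reports three devices (with arbitrary local ordinals, including a
// duplicate); node 1 reports one. The service assigns global ids 0 through 3.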
std::vector<LocalTopologyProto> locals(2);
locals[0].set_node_id(0);
locals[1].set_node_id(1);
DeviceProto* d0 = locals[0].add_devices();
d0->set_local_device_ordinal(0);
DeviceProto* d1 = locals[0].add_devices();
d1->set_local_device_ordinal(0);
DeviceProto* d2 = locals[0].add_devices();
d2->set_local_device_ordinal(707);
DeviceProto* d3 = locals[1].add_devices();
d3->set_local_device_ordinal(1);
GlobalTopologyProto expected_topology;
auto* node0 = expected_topology.add_nodes();
auto* node1 = expected_topology.add_nodes();
*node0 = locals[0];
node0->mutable_devices(0)->set_global_device_id(0);
node0->mutable_devices(1)->set_global_device_id(1);
node0->mutable_devices(2)->set_global_device_id(2);
*node1 = locals[1];
node1->mutable_devices(0)->set_global_device_id(3);
// Used to ensure that thread0's client connects before thread1's client to
// set the global device ids deterministically.
absl::Notification n;
auto thread0_fn = [&]() -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = 0;
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
GlobalTopologyProto topology;
// Unblock the second thread.
// Note: with the distributed runtime service, client->Connect() blocks
// until all clients have connected, so notifying after this Connect()
// would deadlock.
n.Notify();
TF_RETURN_IF_ERROR(client->Connect());
TF_RETURN_IF_ERROR(client->EnumerateDevices(locals[0], &topology));
TF_RET_CHECK(
xla::protobuf_util::ProtobufEquals(topology, expected_topology))
<< topology.DebugString();
TF_RETURN_IF_ERROR(client->KeyValueSet("key1", "value1"));
TF_ASSIGN_OR_RETURN(
std::string value,
client->BlockingKeyValueGet("key2", absl::InfiniteDuration()));
TF_RET_CHECK(value == "value2");
return ::tensorflow::OkStatus();
};
auto thread1_fn = [&]() -> xla::Status {
// Wait until thread0's client is ready to connect, so that global ids are
// assigned in order (thread0's client first, then thread1's).
n.WaitForNotification();
DistributedRuntimeClient::Options client_options;
client_options.node_id = 1;
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
GlobalTopologyProto topology;
TF_RETURN_IF_ERROR(client->Connect());
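// Stagger the two EnumerateDevices() calls; both nodes should still see the
// same merged topology.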
absl::SleepFor(absl::Seconds(1));
TF_RETURN_IF_ERROR(client->EnumerateDevices(locals[1], &topology));
TF_RET_CHECK(
xla::protobuf_util::ProtobufEquals(topology, expected_topology))
<< topology.DebugString();
TF_ASSIGN_OR_RETURN(
std::string value,
client->BlockingKeyValueGet("key1", absl::InfiniteDuration()));
TF_RET_CHECK(value == "value1");
TF_RETURN_IF_ERROR(client->KeyValueSet("key2", "value2"));
return ::tensorflow::OkStatus();
};
std::vector<std::function<xla::Status()>> functions = {thread0_fn,
thread1_fn};
std::vector<xla::Status> statuses(functions.size());
{
tensorflow::thread::ThreadPool thread_pool(
tensorflow::Env::Default(), "test_threads", functions.size());
for (int i = 0; i < functions.size(); ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = functions[i](); });
}
}
TF_EXPECT_OK(statuses[0]);
TF_EXPECT_OK(statuses[1]);
}
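// If one client goes away without shutting down, the other clients' pending
// Shutdown() calls should be interrupted with an error.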
TEST_P(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) {
int num_nodes = 3;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
service_options.heartbeat_interval = absl::Milliseconds(500);
service_options.max_missing_heartbeats = 2;
StartService(service_options, GetParam().use_coordination_service);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
client_options.heartbeat_interval = service_options.heartbeat_interval;
client_options.max_missing_heartbeats = 2;
client_options.shutdown_on_destruction = node_id != 0;
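// Replace the default missed-heartbeat callback so that a missed heartbeat
// is simply ignored by this test.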
client_options.missed_heartbeat_callback =
[&](xla::Status status, bool coordinator_initiated) {};
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
if (node_id == 0) {
return ::tensorflow::OkStatus();
}
// The call to Shutdown() should be interrupted if a worker stops issuing
// heartbeats.
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
TF_EXPECT_OK(statuses[0]);
for (int i = 1; i < num_nodes; ++i) {
if (GetParam().use_coordination_service) {
// Other nodes will be placed into ERROR state when the service informs
// them of node 0's missing heartbeat failure.
// agent->Shutdown() may return one of two error codes depending on the
// timing of the call:
// 1. Internal: the node enters ERROR state during the shutdown call.
// 2. FailedPrecondition: the node is already in ERROR state before the
//    shutdown call (note: the agent still stops sending heartbeats).
EXPECT_TRUE(tensorflow::errors::IsInternal(statuses[i]) ||
tensorflow::errors::IsFailedPrecondition(statuses[i]));
} else {
EXPECT_EQ(statuses[i].code(), tensorflow::error::ABORTED);
}
}
}
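// If one client goes away without shutting down, the remaining clients should
// be notified through their missed-heartbeat callbacks.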
TEST_P(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) {
int num_nodes = 3;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
service_options.heartbeat_interval = absl::Milliseconds(500);
service_options.max_missing_heartbeats = 2;
StartService(service_options, GetParam().use_coordination_service);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
client_options.heartbeat_interval = service_options.heartbeat_interval;
client_options.max_missing_heartbeats = 2;
client_options.shutdown_on_destruction = (node_id != 0);
absl::Notification shutdown;
client_options.missed_heartbeat_callback = [&](xla::Status status,
bool coordinator_initiated) {
shutdown.Notify();
};
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
if (node_id == 0) {
return ::tensorflow::OkStatus();
}
shutdown.WaitForNotification();
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
TF_EXPECT_OK(statuses[i]);
}
}
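// If the service itself goes away, clients should observe missed heartbeats
// and their Shutdown() calls should fail.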
TEST_P(ClientServerTest, ClientsTerminateIfServiceGoesAway) {
int num_nodes = 3;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
service_options.heartbeat_interval = absl::Milliseconds(500);
service_options.max_missing_heartbeats = 2;
// We use a socket connection for this test case because the in-process API
// does not react well to the server being told to shut down while there are
// active clients.
int port = tensorflow::testing::PickUnusedPortOrDie();
StartService(service_options, GetParam().use_coordination_service,
absl::StrCat("[::]:", port));
absl::Barrier barrier(num_nodes + 1);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
client_options.heartbeat_interval = service_options.heartbeat_interval;
client_options.max_missing_heartbeats = 2;
client_options.rpc_timeout = absl::Seconds(1);
client_options.shutdown_timeout = absl::Seconds(10);
absl::Notification shutdown;
client_options.missed_heartbeat_callback = [&](xla::Status status,
bool coordinator_initiated) {
shutdown.Notify();
};
std::shared_ptr<::grpc::ChannelCredentials> creds =
::grpc::InsecureChannelCredentials();
std::shared_ptr<::grpc::Channel> channel =
::grpc::CreateChannel(absl::StrCat("dns:///localhost:", port), creds);
auto client = GetDistributedRuntimeClient(
channel, client_options, GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
barrier.Block();
shutdown.WaitForNotification();
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
barrier.Block();
Stop();
}
for (int i = 0; i < num_nodes; ++i) {
if (GetParam().use_coordination_service) {
EXPECT_EQ(statuses[i].code(), tensorflow::error::FAILED_PRECONDITION);
} else {
EXPECT_EQ(statuses[i].code(), tensorflow::error::DEADLINE_EXCEEDED)
<< statuses[i];
}
}
}
// We should eventually connect, even if some clients are late to show up.
TEST_P(ClientServerTest, LateClientsAreOk) {
int num_nodes = 3;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
StartService(service_options, GetParam().use_coordination_service);
absl::Barrier barrier(num_nodes);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
client_options.init_timeout = absl::Milliseconds(20000);
client_options.rpc_timeout = absl::Milliseconds(200);
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
barrier.Block();
absl::SleepFor(absl::Milliseconds(200) * node_id);
TF_RETURN_IF_ERROR(client->Connect());
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
TF_EXPECT_OK(statuses[i]);
}
}
// We should eventually time out if a client does not show up.
TEST_P(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) {
int num_nodes = 3;
absl::Duration timeout = absl::Milliseconds(500);
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
service_options.enumerate_devices_timeout = timeout;
service_options.shutdown_timeout = timeout;
StartService(service_options, GetParam().use_coordination_service);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
client_options.init_timeout = timeout;
client_options.rpc_timeout = absl::Milliseconds(200);
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
// Note: one fewer thread than 'num_nodes'.
std::vector<xla::Status> statuses(num_nodes - 1);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes - 1; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes - 1; ++i) {
EXPECT_EQ(statuses[i].code(), tensorflow::error::DEADLINE_EXCEEDED);
}
}
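// Every node reaches both barriers, so both WaitAtBarrier() calls should
// succeed on every node.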
TEST_P(ClientServerTest, WaitAtBarrier_Succeed) {
int num_nodes = 2;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
StartService(service_options, GetParam().use_coordination_service);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
TF_RETURN_IF_ERROR(
client->WaitAtBarrier("barrier_1", absl::Milliseconds(100)));
TF_RETURN_IF_ERROR(
client->WaitAtBarrier("barrier_2", absl::Milliseconds(100)));
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
TF_EXPECT_OK(statuses[i]);
}
}
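// Node 1 enters the barrier only after node 0's barrier call has already
// timed out, so the barrier is never satisfied in time.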
TEST_P(ClientServerTest, WaitAtBarrier_Timeout) {
int num_nodes = 2;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
StartService(service_options, GetParam().use_coordination_service);
absl::Notification n;
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
if (node_id == 1) {
n.WaitForNotification();
}
Status barrier_status =
client->WaitAtBarrier("barrier_1", absl::Milliseconds(100));
if (node_id == 0) {
n.Notify();
}
TF_RETURN_IF_ERROR(barrier_status);
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
if (GetParam().use_coordination_service) {
// The coordination service reports the earlier barrier failure to node 1
// right away, without waiting for node 1's own barrier call to time out.
EXPECT_EQ(statuses[i].code(), tensorflow::error::DEADLINE_EXCEEDED)
<< " node id: " << i;
} else {
if (i == 0) {
EXPECT_EQ(statuses[i].code(), tensorflow::error::DEADLINE_EXCEEDED)
<< " node id: " << i;
}
if (i == 1) {
EXPECT_EQ(statuses[i].code(), tensorflow::error::FAILED_PRECONDITION)
<< " node id: " << i;
}
}
}
}
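// The two nodes wait on different barrier ids, so neither barrier is ever
// satisfied and both calls time out.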
TEST_P(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) {
int num_nodes = 2;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
StartService(service_options, GetParam().use_coordination_service);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
std::string barrier_id;
if (node_id == 0) {
barrier_id = "barrier_0";
} else if (node_id == 1) {
barrier_id = "barrier_1";
}
TF_RETURN_IF_ERROR(
client->WaitAtBarrier(barrier_id, absl::Milliseconds(100)));
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
EXPECT_EQ(statuses[i].code(), tensorflow::error::DEADLINE_EXCEEDED)
<< " node id: " << i;
}
}
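// Waiting a second time on a barrier id that has already completed should
// fail on every node.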
TEST_P(ClientServerTest, WaitAtBarrier_FailWithSameBarrierId) {
int num_nodes = 2;
DistributedRuntimeServiceImpl::Options service_options;
service_options.num_nodes = num_nodes;
StartService(service_options, GetParam().use_coordination_service);
auto thread_fn = [&](int node_id) -> xla::Status {
DistributedRuntimeClient::Options client_options;
client_options.node_id = node_id;
auto client = GetDistributedRuntimeClient(
server_->InProcessChannel(::grpc::ChannelArguments()), client_options,
GetParam().use_coordination_service);
TF_RETURN_IF_ERROR(client->Connect());
TF_RETURN_IF_ERROR(
client->WaitAtBarrier("barrier_1", absl::Milliseconds(100)));
TF_RETURN_IF_ERROR(
client->WaitAtBarrier("barrier_1", absl::Milliseconds(100)));
TF_RETURN_IF_ERROR(client->Shutdown());
return ::tensorflow::OkStatus();
};
std::vector<xla::Status> statuses(num_nodes);
{
tensorflow::thread::ThreadPool thread_pool(tensorflow::Env::Default(),
"test_threads", num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
EXPECT_EQ(statuses[i].code(), tensorflow::error::FAILED_PRECONDITION)
<< " node id: " << i;
}
}
INSTANTIATE_TEST_SUITE_P(
ClientServerTests, ClientServerTest,
::testing::ValuesIn<ServiceParams>({
{"CoordinationService", true},
{"DistributedRuntimeService", false},
}),
[](const ::testing::TestParamInfo<ClientServerTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace xla