/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "absl/time/time.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
constexpr int kBarrierTimedOut = -1000;
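// Builds a standalone coordination service from the PJRT service options.
// The timeouts below are derived directly from the options: the heartbeat
// timeout is heartbeat_interval * max_missing_heartbeats (e.g. a 10-second
// interval with 10 allowed missed heartbeats yields a 100000 ms timeout),
// the cluster registration timeout reuses enumerate_devices_timeout, and the
// shutdown barrier timeout reuses shutdown_timeout. (The example values are
// illustrative only; concrete values are chosen by the caller.)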
std::unique_ptr<tensorflow::CoordinationServiceInterface>
EnableCoordinationService(
const xla::DistributedRuntimeServiceImpl::Options& options) {
const std::string& job_name = "jax_worker";
// TODO(b/205307544): Remove TensorFlow server def references once it is no
// longer needed.
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
auto job_def = server_def.mutable_cluster()->add_job();
job_def->set_name(job_name);
for (size_t i = 0; i < options.num_nodes; ++i) {
job_def->mutable_tasks()->insert({i, "UNKNOWN_SERVER_ADDRESS"});
}
// Convert options to coordination service config.
auto coordination_config = server_def.mutable_default_session_config()
->mutable_experimental()
->mutable_coordination_config();
coordination_config->set_service_type("standalone");
coordination_config->set_service_leader(
absl::StrCat("/job:", job_name, "/task:0"));
coordination_config->set_cluster_register_timeout_in_ms(
absl::ToInt64Milliseconds(options.enumerate_devices_timeout));
coordination_config->set_heartbeat_timeout_in_ms(absl::ToInt64Milliseconds(
options.heartbeat_interval * options.max_missing_heartbeats));
coordination_config->set_shutdown_barrier_timeout_in_ms(
absl::ToInt64Milliseconds(options.shutdown_timeout));
return tensorflow::CoordinationServiceInterface::EnableCoordinationService(
"standalone", options.env, server_def, /*cache=*/nullptr);
}
} // namespace
namespace xla {
DistributedRuntimeServiceImpl::DistributedRuntimeServiceImpl(
const Options& options)
: options_(options), session_id_(tensorflow::random::New64()) {
nodes_.resize(options.num_nodes);
local_topologies_.resize(options.num_nodes);
}
DistributedRuntimeServiceImpl::~DistributedRuntimeServiceImpl() {
{
absl::MutexLock lock(&mu_);
state_ = State::kClosed;
service_status_ =
tensorflow::errors::FailedPrecondition("Service shutting down.");
if (!stop_heartbeat_thread_.HasBeenNotified()) {
stop_heartbeat_thread_.Notify();
}
}
}
// Steals the contents of `local_topologies`.
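// Global device IDs are assigned in the order the local topologies appear:
// for example, if node 0 reports two devices and node 1 reports two devices,
// node 0's devices receive global IDs 0 and 1 and node 1's devices receive
// global IDs 2 and 3. (Illustrative example of the numbering loop below.)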
void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
GlobalTopologyProto* global_topology) {
int next_global_device_id = 0;
for (LocalTopologyProto& local : local_topologies) {
for (DeviceProto& device : *local.mutable_devices()) {
device.set_global_device_id(next_global_device_id++);
}
global_topology->add_nodes()->Swap(&local);
}
}
xla::Status DistributedRuntimeServiceImpl::ValidateNodeId(int node_id) {
if (node_id < 0) {
return xla::InvalidArgument("Invalid node ID %d, must be non-negative",
node_id);
}
if (node_id >= options_.num_nodes) {
return xla::FailedPrecondition(
"Invalid node ID %d, must be in the range [0, %d)", node_id,
options_.num_nodes);
}
return ::tensorflow::OkStatus();
}
xla::Status DistributedRuntimeServiceImpl::ValidateSessionId(
uint64_t session_id) {
if (session_id != session_id_) {
return xla::FailedPrecondition(
"Session ID of request %llu does not match active session ID %llu",
session_id, session_id_);
}
return ::tensorflow::OkStatus();
}
::grpc::Status DistributedRuntimeServiceImpl::Connect(
::grpc::ServerContext* context, const ConnectRequest* request,
ConnectResponse* response) {
VLOG(10) << "Connect " << request->DebugString();
if (request->protocol_version() != DistributedRuntimeProtocolVersion()) {
return xla::ToGrpcStatus(xla::InvalidArgument("Invalid protocol version %d",
request->protocol_version()));
}
absl::MutexLock lock(&mu_);
if (state_ != State::kInitializing) {
// This most likely indicates that a client task was restarted but the
// old master is still up. Clients should retry on failure.
return xla::ToGrpcStatus(tensorflow::errors::Aborted(
"Connect() called when system is not initializing."));
}
int node_id = request->node_id();
xla::Status status = ValidateNodeId(node_id);
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
if (!nodes_[node_id].present) {
nodes_[node_id].present = true;
++num_nodes_present_;
}
nodes_[node_id].client_id = request->client_id();
auto all_nodes_present_or_duplicate_request = [&]() {
mu_.AssertHeld();
return num_nodes_present_ == nodes_.size() ||
nodes_[node_id].client_id != request->client_id();
};
auto connect_timeout = absl::Milliseconds(request->timeout_milliseconds());
if (!mu_.AwaitWithTimeout(
absl::Condition(&all_nodes_present_or_duplicate_request),
connect_timeout)) {
nodes_[node_id].present = false;
--num_nodes_present_;
return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
"Timed out after ", absl::FormatDuration(connect_timeout),
" waiting for all nodes to call Connect()"));
}
if (nodes_[node_id].client_id != request->client_id()) {
    // This can happen either if two nodes are erroneously configured with the
    // same node ID, or if a task fails and is restarted while we are waiting
    // for nodes to connect. The second scenario looks like this:
    // * a task calls Connect() with a particular node_id and client_id.
    // * the task is killed and restarted, or alternatively the client's RPC
    //   times out and it decides to retry.
    // * the task calls Connect() again with the same node_id and a different
    //   client_id.
    // In this scenario we take whichever client showed up most recently and
    // evict the client with an out-of-date client ID.
return xla::ToGrpcStatus(
tensorflow::errors::Aborted("Duplicate node ID ", node_id));
}
if (node_id == 0) {
state_ = State::kRunning;
heartbeat_thread_.reset(options_.env->StartThread(
tensorflow::ThreadOptions(), "pjrt_service_heartbeat",
[this]() { HeartbeatLoop(); }));
} else {
auto running = [&]() {
mu_.AssertHeld();
return state_ == State::kRunning;
};
mu_.Await(absl::Condition(&running));
}
nodes_[node_id].last_heartbeat = absl::Now();
response->set_session_id(session_id_);
return ::grpc::Status::OK;
}
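// Connect() is effectively a startup barrier: each node blocks until all
// options_.num_nodes nodes have called Connect(), or until the request's
// timeout_milliseconds elapses. A minimal request, as seen by this handler,
// carries a protocol_version matching DistributedRuntimeProtocolVersion(), a
// node_id in [0, num_nodes), a client_id used to detect restarted tasks, and
// a timeout_milliseconds value. (Illustrative summary; the request proto is
// defined alongside protocol.h, not in this file.)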
::grpc::Status DistributedRuntimeServiceImpl::Shutdown(
::grpc::ServerContext* context, const ShutdownRequest* request,
ShutdownResponse* response) {
VLOG(10) << "Shutdown " << request->DebugString();
xla::Status status = ValidateSessionId(request->session_id());
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
absl::MutexLock lock(&mu_);
if (state_ != State::kRunning) {
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return xla::ToGrpcStatus(xla::FailedPrecondition(
"Shutdown() called when system is not running."));
}
int node_id = request->node_id();
status = ValidateNodeId(node_id);
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
++num_nodes_shutting_down_;
auto all_nodes_shutting_down = [&]() {
mu_.AssertHeld();
return num_nodes_shutting_down_ == nodes_.size() || !service_status_.ok();
};
if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_shutting_down),
options_.shutdown_timeout)) {
state_ = State::kClosed;
return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
"Timed out after ", absl::FormatDuration(options_.shutdown_timeout),
" waiting for all nodes to call Shutdown()"));
}
state_ = State::kClosed;
if (!stop_heartbeat_thread_.HasBeenNotified()) {
stop_heartbeat_thread_.Notify();
}
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return ::grpc::Status::OK;
}
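// Like Connect(), Shutdown() rendezvouses all nodes: each caller blocks until
// every node has called Shutdown() or options_.shutdown_timeout expires, after
// which the heartbeat thread is stopped and the service moves to kClosed.
// (Summary of the handler above, for readers skimming the RPC surface.)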
::grpc::Status DistributedRuntimeServiceImpl::EnumerateDevices(
::grpc::ServerContext* context, const EnumerateDevicesRequest* request,
EnumerateDevicesResponse* response) {
VLOG(10) << "EnumerateDevices " << request->DebugString();
xla::Status status = ValidateSessionId(request->session_id());
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
absl::MutexLock lock(&mu_);
if (state_ != State::kRunning) {
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return xla::ToGrpcStatus(xla::FailedPrecondition(
"EnumerateDevices() called when system is not running."));
}
int node_id = request->local_topology().node_id();
status = ValidateNodeId(node_id);
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
local_topologies_[node_id] = request->local_topology();
++num_topologies_present_;
auto all_topologies_present = [&]() {
mu_.AssertHeld();
return num_topologies_present_ == nodes_.size() || !service_status_.ok();
};
if (!mu_.AwaitWithTimeout(absl::Condition(&all_topologies_present),
options_.enumerate_devices_timeout)) {
return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
"Timed out after ",
absl::FormatDuration(options_.enumerate_devices_timeout),
" waiting for all nodes to call EnumerateDevices()"));
}
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
if (node_id == 0) {
topology_.emplace();
BuildGlobalTopology(absl::Span<LocalTopologyProto>(local_topologies_),
&*topology_);
local_topologies_.clear();
} else {
auto topology_ready = [&]() -> bool {
mu_.AssertHeld();
return topology_.has_value();
};
mu_.Await(absl::Condition(&topology_ready));
}
*response->mutable_global_topology() = *topology_;
return ::grpc::Status::OK;
}
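// EnumerateDevices() is the device-exchange rendezvous: every node sends its
// LocalTopologyProto, node 0 stitches them into a single GlobalTopologyProto
// via BuildGlobalTopology(), and all nodes receive the same global topology in
// their responses. For example, with two single-GPU nodes, both responses
// contain a global topology with two device entries whose global IDs are 0
// and 1. (Illustrative example; actual device counts come from the clients.)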
::grpc::Status DistributedRuntimeServiceImpl::Heartbeat(
::grpc::ServerContext* context, const HeartbeatRequest* request,
HeartbeatResponse* response) {
VLOG(10) << "Heartbeat " << request->DebugString();
xla::Status status = ValidateSessionId(request->session_id());
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
absl::MutexLock lock(&mu_);
if (state_ != State::kRunning) {
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return xla::ToGrpcStatus(xla::FailedPrecondition(
"Heartbeat() called when system is not running."));
}
int node_id = request->node_id();
status = ValidateNodeId(node_id);
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
nodes_[node_id].last_heartbeat = absl::Now();
return ::grpc::Status::OK;
}
void DistributedRuntimeServiceImpl::HeartbeatLoop() {
while (true) {
stop_heartbeat_thread_.WaitForNotificationWithTimeout(
options_.heartbeat_interval);
VLOG(10) << "Checking heartbeats";
if (stop_heartbeat_thread_.HasBeenNotified()) {
VLOG(10) << "Heartbeat checking stopped.";
return;
}
absl::Time now = absl::Now();
absl::MutexLock lock(&mu_);
for (size_t i = 0; i < nodes_.size(); ++i) {
// If we haven't heard from the node for a number of heartbeat intervals,
// declare that we are unhealthy.
VLOG(10) << "Node " << i
<< " last heartbeat: " << nodes_[i].last_heartbeat;
if (nodes_[i].last_heartbeat +
options_.max_missing_heartbeats * options_.heartbeat_interval <
now) {
LOG(INFO) << "Missed heartbeats from node " << i << ". Shutting down.";
state_ = State::kClosed;
service_status_ = tensorflow::errors::Aborted(
"Shutting down due to missed heartbeat from task ", i);
return;
}
}
}
}
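// Liveness rule used above: node i is considered failed once
//   last_heartbeat + max_missing_heartbeats * heartbeat_interval < now.
// For example (illustrative values only), with a 10-second heartbeat_interval
// and max_missing_heartbeats = 10, a node is declared failed roughly 100
// seconds after its last successful Heartbeat() call, and the whole service
// shuts down with an Aborted status.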
::grpc::Status DistributedRuntimeServiceImpl::KeyValueGet(
::grpc::ServerContext* context, const KeyValueGetRequest* request,
KeyValueGetResponse* response) {
VLOG(10) << "KeyValueGet " << request->DebugString();
xla::Status status = ValidateSessionId(request->session_id());
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
{
absl::MutexLock lock(&mu_);
if (state_ != State::kRunning) {
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return xla::ToGrpcStatus(xla::FailedPrecondition(
"KeyValueGet() called when system is not running."));
}
}
return key_value_store_.Get(
request->key(), absl::Milliseconds(request->timeout_milliseconds()),
response->mutable_value());
}
::grpc::Status DistributedRuntimeServiceImpl::KeyValueSet(
::grpc::ServerContext* context, const KeyValueSetRequest* request,
KeyValueSetResponse* response) {
VLOG(10) << "KeyValueSet " << request->DebugString();
xla::Status status = ValidateSessionId(request->session_id());
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
{
absl::MutexLock lock(&mu_);
if (state_ != State::kRunning) {
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return xla::ToGrpcStatus(xla::FailedPrecondition(
"KeyValueSet() called when system is not running; clients must call "
"Connect() first"));
}
}
return key_value_store_.Set(request->key(), request->value());
}
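// The key-value store is a simple publish/lookup mechanism between nodes.
// KeyValueGet() takes a timeout so a node can block waiting for a key that
// another node has not yet published. A typical (illustrative) pattern:
//
//   // On node 0:
//   KeyValueSetRequest set;
//   set.set_session_id(session_id);
//   set.set_key("config/address");
//   set.set_value("10.0.0.1:1234");
//
//   // On the other nodes:
//   KeyValueGetRequest get;
//   get.set_session_id(session_id);
//   get.set_key("config/address");
//   get.set_timeout_milliseconds(10000);
//
// (Sketch only: the setters follow standard protobuf naming for the fields
// read by the handlers above; the exact proto definitions live elsewhere, and
// the key and value strings are placeholders.)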
::grpc::Status DistributedRuntimeServiceImpl::WaitAtBarrier(
::grpc::ServerContext* context, const WaitAtBarrierRequest* request,
WaitAtBarrierResponse* response) {
VLOG(10) << "WaitAtBarrier " << request->DebugString();
xla::Status status = ValidateSessionId(request->session_id());
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
absl::MutexLock lock(&mu_);
if (state_ != State::kRunning) {
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return xla::ToGrpcStatus(xla::FailedPrecondition(
"WaitAtBarrier() called when system is not running."));
}
int node_id = request->node_id();
status = ValidateNodeId(node_id);
if (!status.ok()) {
return xla::ToGrpcStatus(status);
}
std::string barrier_id = request->barrier_id();
if (barrier_id_to_num_nodes_[barrier_id] == nodes_.size()) {
return xla::ToGrpcStatus(
xla::FailedPrecondition("Calling WaitAtBarrier with the same id "
"across barriers is not allowed. Please use "
"unique barrier ids across barriers."));
}
if (barrier_id_to_num_nodes_[barrier_id] == kBarrierTimedOut) {
return xla::ToGrpcStatus(xla::FailedPrecondition(
"A process timed out waiting at the barrier. Exiting early because the "
"current process will also timeout."));
}
++barrier_id_to_num_nodes_[barrier_id];
absl::Duration timeout = absl::Milliseconds(request->timeout_milliseconds());
auto all_nodes_at_barrier = [&]() {
mu_.AssertHeld();
return barrier_id_to_num_nodes_[barrier_id] == nodes_.size() ||
!service_status_.ok();
};
// TODO(yashkatariya,hanyangtay): Do something similar to the coordination
// service here.
if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_at_barrier), timeout)) {
barrier_id_to_num_nodes_[barrier_id] = kBarrierTimedOut;
return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
"Timed out after ", timeout,
" waiting for all nodes to be at WaitAtBarrier()"));
}
if (!service_status_.ok()) {
return xla::ToGrpcStatus(service_status_);
}
return ::grpc::Status::OK;
}
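// Barrier semantics implemented above, summarized with an illustrative
// two-node example:
//   * node 0 calls WaitAtBarrier with barrier_id "init" and blocks;
//   * node 1 calls WaitAtBarrier with "init" before the timeout, so both
//     calls return OK;
//   * reusing "init" for a later barrier fails with FailedPrecondition;
//   * if any node times out, the barrier's count is set to kBarrierTimedOut so
//     stragglers arriving at the same id fail fast instead of waiting.
// (The id "init" and the two-node setup are illustrative only.)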
CoordinationServiceImpl::CoordinationServiceImpl(
const DistributedRuntimeServiceImpl::Options& options,
::grpc::ServerBuilder* builder)
: env_(options.env) {
coord_service_ = EnableCoordinationService(options);
coord_compute_pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
options.env, "CoordinationServiceRpcHandler",
/*num_threads=*/4);
coord_rpc_service_ =
std::make_unique<tensorflow::GrpcCoordinationServiceImpl>(
coord_compute_pool_.get(), builder);
LOG(INFO) << "Experimental coordination service is enabled.";
}
CoordinationServiceImpl::~CoordinationServiceImpl() {
coord_rpc_service_->Shutdown();
}
void CoordinationServiceImpl::StartRpcThread() {
coord_rpc_thread_.reset(env_->StartThread(
tensorflow::ThreadOptions(), "CoordinationServiceHandleRPCsLoop",
[service = coord_rpc_service_.get()] { service->HandleRPCsLoop(); }));
}
xla::StatusOr<std::unique_ptr<DistributedRuntimeService>>
DistributedRuntimeService::Get(
const std::string& address,
std::shared_ptr<::grpc::ServerCredentials> credentials,
const DistributedRuntimeServiceImpl::Options& options,
bool use_coordination_service) {
::grpc::ServerBuilder builder;
builder.AddListeningPort(address, credentials);
VLOG(1) << "Distributed runtime service address " << address;
auto service = std::make_unique<DistributedRuntimeService>(
options, &builder, use_coordination_service);
if (!service->server_) {
return xla::Unknown("Failed to start RPC server");
}
LOG(INFO) << "Jax service listening on " << address;
return service;
}
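// Illustrative usage sketch (the Options fields below are the ones read in
// this file; the address, node count, and credentials are placeholders chosen
// for the example, not defaults of this API):
//
//   xla::DistributedRuntimeServiceImpl::Options options;
//   options.env = tensorflow::Env::Default();
//   options.num_nodes = 2;
//   auto service_or = xla::DistributedRuntimeService::Get(
//       "[::]:12345", ::grpc::InsecureServerCredentials(), options,
//       /*use_coordination_service=*/false);
//   if (service_or.ok()) {
//     std::unique_ptr<xla::DistributedRuntimeService> service =
//         std::move(service_or).value();
//     // ... serve until shutdown ...
//   }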
DistributedRuntimeService::DistributedRuntimeService(
const DistributedRuntimeServiceImpl::Options& options,
::grpc::ServerBuilder* builder, bool use_coordination_service) {
if (use_coordination_service) {
coord_impl_ = std::make_unique<CoordinationServiceImpl>(options, builder);
server_ = builder->BuildAndStart();
coord_impl_->StartRpcThread();
} else {
impl_ = std::make_unique<DistributedRuntimeServiceImpl>(options);
builder->RegisterService(impl_.get());
server_ = builder->BuildAndStart();
}
}
DistributedRuntimeService::~DistributedRuntimeService() { Shutdown(); }
void DistributedRuntimeService::Shutdown() {
if (server_) {
LOG(INFO) << "Jax service shutting down";
server_->Shutdown();
server_->Wait();
}
}
} // namespace xla