| /* Copyright 2020 Google LLC |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/pjrt/distributed/client.h" |
| |
| #include <algorithm> |
| #include <chrono> // NOLINT |
| #include <random> |
| #include <string> |
| #include <utility> |
| |
| #include "absl/synchronization/mutex.h" |
| #include "absl/synchronization/notification.h" |
| #include "absl/time/time.h" |
| #include "grpcpp/channel.h" |
| #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" |
| #include "tensorflow/compiler/xla/pjrt/distributed/util.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h" |
| #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h" |
| #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h" |
| #include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h" |
| #include "tensorflow/core/distributed_runtime/worker_env.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/random.h" |
| #include "tensorflow/core/protobuf/coordination_config.pb.h" |
| #include "tensorflow/core/protobuf/coordination_service.pb.h" |
| |
| namespace xla { |
| class DistributedRuntimeClientImpl : public DistributedRuntimeClient { |
| public: |
| DistributedRuntimeClientImpl(std::shared_ptr<::grpc::Channel> channel, |
| const Options& options); |
| explicit DistributedRuntimeClientImpl( |
| std::shared_ptr<::grpc::Channel> channel) |
| : DistributedRuntimeClientImpl(channel, Options()) {} |
| ~DistributedRuntimeClientImpl() override; |
| |
| xla::Status Connect() override; |
| xla::Status Shutdown() override; |
| xla::Status EnumerateDevices(const LocalTopologyProto& local_topology, |
| GlobalTopologyProto* global_topology) override; |
| xla::StatusOr<std::string> BlockingKeyValueGet( |
| std::string key, absl::Duration timeout) override; |
| xla::Status KeyValueSet(std::string key, std::string value) override; |
| xla::Status WaitAtBarrier(std::string barrier_id, |
| absl::Duration timeout) override; |
| |
| private: |
| // Entry point for the heartbeat thread. |
| void HeartbeatLoop(); |
| |
| const std::unique_ptr<grpc::DistributedRuntimeService::Stub> stub_; |
| const DistributedRuntimeClient::Options options_; |
| |
| // Possible states of the client. |
| // The only legal transitions are downwards in the order below. i.e., there is |
| // no way to reopen a closed client. |
| enum class State { |
| // The client has not yet connected to the server, i.e., had a Connect() |
| // RPC succeed. |
| kNotConnected, |
| |
| // The client is connected to the server and as far as we are aware the |
| // connection is healthy. |
| kConnected, |
| |
| // The client is in the process of shutting down, i.e., Shutdown() has been |
| // called. |
| kShuttingDown, |
| |
| // The client has shut down its server connection, either due to an error |
| // or due to an explicit shutdown. |
| kClosed, |
| }; |
| |
| static absl::string_view StateToString(State state); |
| |
| // state_ is protected by a mutex because the heartbeat thread needs to look |
| // at it. |
| absl::Mutex mu_; |
| State state_ ABSL_GUARDED_BY(mu_) = State::kNotConnected; |
| |
| // A unique session ID, assigned by the server during Connect(). |
| uint64_t session_id_; |
| |
| // Notification that tells the heartbeat thread to stop running. |
| absl::Notification stop_heartbeats_; |
| |
| // Thread responsible for performing heartbeats. |
| std::unique_ptr<tensorflow::Thread> heartbeat_thread_; |
| }; |
| |
| class DistributedRuntimeCoordinationServiceClient |
| : public DistributedRuntimeClient { |
| public: |
| DistributedRuntimeCoordinationServiceClient( |
| std::shared_ptr<::grpc::Channel> channel, const Options& options); |
| explicit DistributedRuntimeCoordinationServiceClient( |
| std::shared_ptr<::grpc::Channel> channel) |
| : DistributedRuntimeCoordinationServiceClient(channel, Options()) {} |
| ~DistributedRuntimeCoordinationServiceClient() override; |
| |
| xla::Status Connect() override; |
| xla::Status Shutdown() override; |
| xla::Status EnumerateDevices(const LocalTopologyProto& local_topology, |
| GlobalTopologyProto* global_topology) override; |
| xla::StatusOr<std::string> BlockingKeyValueGet( |
| std::string key, absl::Duration timeout) override; |
| xla::Status KeyValueSet(std::string key, std::string value) override; |
| xla::Status WaitAtBarrier(std::string barrier_id, |
| absl::Duration timeout) override; |
| |
| private: |
| std::unique_ptr<tensorflow::CoordinationServiceAgent> coord_agent_; |
| tensorflow::CoordinationServiceConfig config_; |
| int task_id_; |
| }; |
| |
| DistributedRuntimeClientImpl::DistributedRuntimeClientImpl( |
| std::shared_ptr<::grpc::Channel> channel, const Options& options) |
| : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))), |
| options_(options) {} |
| |
| DistributedRuntimeClientImpl::~DistributedRuntimeClientImpl() { |
| bool connected; |
| { |
| absl::MutexLock lock(&mu_); |
| connected = (state_ == State::kConnected); |
| } |
| if (connected) { |
| if (options_.shutdown_on_destruction) { |
| Status status = Shutdown(); |
| if (!status.ok()) { |
| LOG(WARNING) << "PJRT shutdown failed: " << status; |
| } |
| } else { |
| if (!stop_heartbeats_.HasBeenNotified()) { |
| stop_heartbeats_.Notify(); |
| } |
| } |
| } |
| } |
| |
| /*static*/ absl::string_view DistributedRuntimeClientImpl::StateToString( |
| State state) { |
| switch (state) { |
| case State::kNotConnected: |
| return "kNotConnected"; |
| case State::kConnected: |
| return "kConnected"; |
| case State::kShuttingDown: |
| return "kShuttingDown"; |
| case State::kClosed: |
| return "kClosed"; |
| } |
| } |
| |
| xla::Status DistributedRuntimeClientImpl::Connect() { |
| { |
| absl::MutexLock lock(&mu_); |
| if (state_ != State::kNotConnected) { |
| return xla::FailedPrecondition("Connect() called when client in state %s", |
| StateToString(state_)); |
| } |
| } |
| ConnectRequest request; |
| request.set_protocol_version(DistributedRuntimeProtocolVersion()); |
| request.set_timeout_milliseconds( |
| absl::ToInt64Milliseconds(options_.rpc_timeout) / 2); |
| request.set_node_id(options_.node_id); |
| VLOG(10) << "Connect: " << request.DebugString(); |
| ConnectResponse response; |
| ::grpc::Status status; |
| absl::Time deadline = absl::Now() + options_.init_timeout; |
| int attempt = 0; |
| std::default_random_engine generator; |
| std::uniform_real_distribution<double> distribution(0.0, 1.0); |
| do { |
| ::grpc::ClientContext ctx; |
| ctx.set_fail_fast(false); |
| ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout)); |
| request.set_client_id(tensorflow::random::New64()); |
| response.Clear(); |
| status = stub_->Connect(&ctx, request, &response); |
| if (!status.ok()) { |
| VLOG(1) << "Connect failed() with status: " << FromGrpcStatus(status); |
| if (attempt % 10 == 0) { |
| LOG(INFO) << "Connect failed() with status: " << FromGrpcStatus(status); |
| } |
| // Exponential backoff with jitter. Note we will retry for `init_timeout` |
| // time in total; the `14` here corresponds to an ~16s maximum interval |
| // between connection attempts. |
| int backoff = 1 << std::min(14, attempt); |
| absl::SleepFor(absl::Milliseconds(backoff * distribution(generator))); |
| } |
| ++attempt; |
| } while (!status.ok() && absl::Now() < deadline); |
| if (!status.ok()) { |
| LOG(ERROR) << "Connect() failed after " << attempt << " retries in " |
| << options_.init_timeout |
| << "; most recent failure status: " << FromGrpcStatus(status); |
| return tensorflow::errors::DeadlineExceeded( |
| absl::StrFormat("Connect() timed out after %s with %d attempts. Most " |
| "recent failure was: %s", |
| absl::FormatDuration(options_.init_timeout), attempt, |
| FromGrpcStatus(status).ToString())); |
| } |
| VLOG(10) << "Connect() response: " << response.DebugString(); |
| { |
| absl::MutexLock lock(&mu_); |
| state_ = State::kConnected; |
| } |
| session_id_ = response.session_id(); |
| |
| heartbeat_thread_.reset(options_.env->StartThread( |
| tensorflow::ThreadOptions(), "pjrt_distributed_heartbeat", |
| [this]() { HeartbeatLoop(); })); |
| LOG(INFO) << "Connected to distributed JAX controller"; |
| return ::tensorflow::OkStatus(); |
| } |
| |
| xla::Status DistributedRuntimeClientImpl::EnumerateDevices( |
| const LocalTopologyProto& local_topology, |
| GlobalTopologyProto* global_topology) { |
| { |
| absl::MutexLock lock(&mu_); |
| if (state_ != State::kConnected) { |
| return xla::FailedPrecondition( |
| "EnumerateDevices() called when client not connected."); |
| } |
| } |
| ::grpc::ClientContext ctx; |
| ctx.set_fail_fast(false); |
| ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout)); |
| EnumerateDevicesRequest request; |
| request.set_session_id(session_id_); |
| *request.mutable_local_topology() = local_topology; |
| request.mutable_local_topology()->set_node_id(options_.node_id); |
| |
| VLOG(10) << "EnumerateDevices: " << request.DebugString(); |
| EnumerateDevicesResponse response; |
| ::grpc::Status status = stub_->EnumerateDevices(&ctx, request, &response); |
| if (!status.ok()) { |
| return FromGrpcStatus(status); |
| } |
| VLOG(10) << "EnumerateDevices() response: " << response.DebugString(); |
| response.mutable_global_topology()->Swap(global_topology); |
| return ::tensorflow::OkStatus(); |
| } |
| |
| xla::Status DistributedRuntimeClientImpl::Shutdown() { |
| LOG(INFO) << "Waiting for all distributed JAX tasks to shut down."; |
| ::grpc::ClientContext ctx; |
| { |
| absl::MutexLock lock(&mu_); |
| if (state_ != State::kConnected) { |
| return xla::FailedPrecondition( |
| "Shutdown() called when client not connected."); |
| } |
| state_ = State::kShuttingDown; |
| } |
| ctx.set_fail_fast(false); |
| ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.shutdown_timeout)); |
| ShutdownRequest request; |
| request.set_session_id(session_id_); |
| VLOG(10) << "Shutdown: " << request.DebugString(); |
| ShutdownResponse response; |
| ::grpc::Status status = stub_->Shutdown(&ctx, request, &response); |
| |
| LOG(INFO) << "Distributed task shutdown result: " << FromGrpcStatus(status); |
| if (!status.ok()) { |
| return FromGrpcStatus(status); |
| } |
| if (!stop_heartbeats_.HasBeenNotified()) { |
| stop_heartbeats_.Notify(); |
| } |
| VLOG(10) << "Shutdown() response: " << response.DebugString(); |
| absl::MutexLock lock(&mu_); |
| state_ = State::kClosed; |
| return ::tensorflow::OkStatus(); |
| } |
| |
| xla::StatusOr<std::string> DistributedRuntimeClientImpl::BlockingKeyValueGet( |
| std::string key, absl::Duration timeout) { |
| { |
| absl::MutexLock lock(&mu_); |
| if (state_ != State::kConnected) { |
| return xla::FailedPrecondition( |
| "BlockingKeyValueGet() called when client not connected."); |
| } |
| } |
| ::grpc::ClientContext ctx; |
| ctx.set_fail_fast(false); |
| ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout)); |
| KeyValueGetRequest request; |
| request.set_session_id(session_id_); |
| request.set_key(std::move(key)); |
| timeout = std::min(timeout, absl::Minutes(10)); // Avoid overflow |
| request.set_timeout_milliseconds(absl::ToInt64Milliseconds(timeout)); |
| VLOG(10) << "BlockingKeyValueGet: " << request.DebugString(); |
| KeyValueGetResponse response; |
| ::grpc::Status status = stub_->KeyValueGet(&ctx, request, &response); |
| if (!status.ok()) { |
| return FromGrpcStatus(status); |
| } |
| return response.value(); |
| } |
| |
| xla::Status DistributedRuntimeClientImpl::KeyValueSet(std::string key, |
| std::string value) { |
| { |
| absl::MutexLock lock(&mu_); |
| if (state_ != State::kConnected) { |
| return xla::FailedPrecondition( |
| "KeyValueSet() called when client not connected."); |
| } |
| } |
| ::grpc::ClientContext ctx; |
| ctx.set_fail_fast(false); |
| ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout)); |
| KeyValueSetRequest request; |
| request.set_session_id(session_id_); |
| request.set_key(std::move(key)); |
| request.set_value(std::move(value)); |
| VLOG(10) << "KeyValueSet: " << request.DebugString(); |
| KeyValueSetResponse response; |
| ::grpc::Status status = stub_->KeyValueSet(&ctx, request, &response); |
| return FromGrpcStatus(status); |
| } |
| |
| xla::Status DistributedRuntimeClientImpl::WaitAtBarrier( |
| std::string barrier_id, absl::Duration timeout) { |
| { |
| absl::MutexLock lock(&mu_); |
| if (state_ != State::kConnected) { |
| return xla::FailedPrecondition( |
| "WaitAtBarrier() called when client not connected."); |
| } |
| } |
| ::grpc::ClientContext ctx; |
| ctx.set_fail_fast(false); |
| ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout)); |
| WaitAtBarrierRequest request; |
| request.set_session_id(session_id_); |
| request.set_barrier_id(std::move(barrier_id)); |
| request.set_node_id(options_.node_id); |
| // TODO(yashkatariya,hanyuangtay): Change timeout_milliseconds to int64 in |
| // protocol.proto so that we don't need a minimum timeout here. |
| timeout = std::min(timeout, absl::Minutes(10)); // Avoid overflow |
| request.set_timeout_milliseconds(absl::ToInt64Milliseconds(timeout)); |
| VLOG(10) << "WaitAtBarrier: " << request.DebugString(); |
| WaitAtBarrierResponse response; |
| ::grpc::Status status = stub_->WaitAtBarrier(&ctx, request, &response); |
| return FromGrpcStatus(status); |
| } |
| |
| void DistributedRuntimeClientImpl::HeartbeatLoop() { |
| int num_missing_heartbeats = 0; |
| while (true) { |
| stop_heartbeats_.WaitForNotificationWithTimeout( |
| options_.heartbeat_interval); |
| if (stop_heartbeats_.HasBeenNotified()) { |
| return; |
| } |
| |
| ::grpc::ClientContext ctx; |
| ctx.set_fail_fast(false); |
| ctx.set_deadline( |
| absl::ToChronoTime(absl::Now() + options_.heartbeat_interval)); |
| HeartbeatRequest request; |
| request.set_session_id(session_id_); |
| request.set_node_id(options_.node_id); |
| VLOG(10) << "Heartbeat: " << request.DebugString(); |
| HeartbeatResponse response; |
| ::grpc::Status status = stub_->Heartbeat(&ctx, request, &response); |
| if (status.ok()) { |
| VLOG(10) << "Heartbeat ok"; |
| num_missing_heartbeats = 0; |
| } else { |
| ++num_missing_heartbeats; |
| VLOG(10) << "Heartbeat error, " |
| << options_.max_missing_heartbeats - num_missing_heartbeats |
| << " tries left: " << status.error_message(); |
| bool is_transient_error = |
| (status.error_code() == ::grpc::StatusCode::DEADLINE_EXCEEDED || |
| status.error_code() == ::grpc::StatusCode::UNAVAILABLE); |
| if (!stop_heartbeats_.HasBeenNotified() && |
| (!is_transient_error || |
| num_missing_heartbeats >= options_.max_missing_heartbeats)) { |
| // If we are shutting down, missed heartbeats are benign: they may |
| // simply mean that the server has shut down already before it saw |
| // the heartbeat request. |
| absl::MutexLock lock(&mu_); |
| if (state_ != State::kShuttingDown) { |
| options_.missed_heartbeat_callback(FromGrpcStatus(status), |
| !is_transient_error); |
| } |
| return; |
| } |
| } |
| } |
| } |
| |
| DistributedRuntimeCoordinationServiceClient:: |
| DistributedRuntimeCoordinationServiceClient( |
| std::shared_ptr<::grpc::Channel> channel, const Options& options) { |
| // Convert options to coordination config. |
| tensorflow::CoordinationServiceConfig config; |
| config.set_service_type("standalone"); |
| config.set_service_leader("/job:jax_worker/task:0"); |
| config.set_cluster_register_timeout_in_ms( |
| absl::ToInt64Milliseconds(options.init_timeout)); |
| config.set_heartbeat_timeout_in_ms(absl::ToInt64Milliseconds( |
| options.heartbeat_interval * options.max_missing_heartbeats)); |
| config.set_shutdown_barrier_timeout_in_ms( |
| absl::ToInt64Milliseconds(options.shutdown_timeout)); |
| config.set_agent_destruction_without_shutdown( |
| !options.shutdown_on_destruction); |
| auto error_fn = |
| [timeout_fn = options.missed_heartbeat_callback](const Status& status) { |
| LOG(ERROR) << "Coordination service agent in error status: " << status; |
| timeout_fn(status, /*coordinator_reported_failure=*/true); |
| }; |
| |
| std::unique_ptr<tensorflow::CoordinationClient> leader_client; |
| leader_client.reset(tensorflow::NewGrpcCoordinationClient(channel)); |
| coord_agent_ = tensorflow::CreateCoordinationServiceAgent(); |
| const Status status = |
| coord_agent_->Initialize(options.env, "jax_worker", options.node_id, |
| config, std::move(leader_client), error_fn); |
| if (!status.ok()) { |
| LOG(ERROR) << "Coordination agent failed to initialize: " << status; |
| } |
| task_id_ = options.node_id; |
| config_ = config; |
| } |
| |
| DistributedRuntimeCoordinationServiceClient:: |
| ~DistributedRuntimeCoordinationServiceClient() {} |
| |
| xla::Status DistributedRuntimeCoordinationServiceClient::Connect() { |
| Status s = tensorflow::errors::Unknown("Connection not attempted yet."); |
| absl::Duration timeout = |
| absl::Milliseconds(config_.cluster_register_timeout_in_ms()); |
| absl::Time deadline = absl::Now() + timeout; |
| int attempt = 0; |
| std::default_random_engine generator; |
| std::uniform_real_distribution<double> distribution(0.0, 1.0); |
| |
| do { |
| ++attempt; |
| s = coord_agent_->Connect(); |
| if (s.ok()) { |
| s = coord_agent_->WaitAtBarrier("PjRT_Client_Connect", timeout, |
| /*tasks=*/{}); |
| } |
| // Exponential backoff with jitter. Note we will retry for `init_timeout` |
| // time in total; the `14` here corresponds to an ~16s maximum interval |
| // between connection attempts. |
| |
| int backoff = 1 << std::min(14, attempt); |
| absl::SleepFor(absl::Milliseconds(backoff * distribution(generator))); |
| } while (!s.ok() && absl::Now() < deadline && |
| // Retries are only made for RPC errors. If a valid service error is |
| // returned, fail immediately. |
| s.GetPayload(tensorflow::CoordinationErrorPayloadKey()) == |
| absl::nullopt); |
| if (s.ok()) { |
| LOG(INFO) << "Connected to distributed JAX controller"; |
| } else { |
| LOG(INFO) << "Failed to connect to distributed JAX controller: " << s; |
| } |
| return s; |
| } |
| |
| xla::Status DistributedRuntimeCoordinationServiceClient::Shutdown() { |
| LOG(INFO) << "Distributed task shutdown initiated."; |
| Status s = coord_agent_->Shutdown(); |
| LOG(INFO) << "Distributed task shutdown result: " << s; |
| return s; |
| } |
| |
| xla::Status DistributedRuntimeCoordinationServiceClient::EnumerateDevices( |
| const LocalTopologyProto& local_topology, |
| GlobalTopologyProto* global_topology) { |
| tensorflow::CoordinationServiceDeviceInfo devices; |
| LocalTopologyProto* device = |
| devices.mutable_xla()->mutable_devices()->add_nodes(); |
| *device = local_topology; |
| device->set_node_id(task_id_); |
| Status s = coord_agent_->WaitForAllTasks(devices); |
| if (!s.ok()) return s; |
| *global_topology = coord_agent_->GetClusterDeviceInfo().xla().devices(); |
| return ::tensorflow::OkStatus(); |
| } |
| |
| xla::StatusOr<std::string> |
| DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( |
| std::string key, absl::Duration timeout) { |
| return coord_agent_->GetKeyValue(key, timeout); |
| } |
| |
| xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet( |
| std::string key, std::string value) { |
| return coord_agent_->InsertKeyValue(key, value); |
| } |
| |
| xla::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier( |
| std::string barrier_id, absl::Duration timeout) { |
| return coord_agent_->WaitAtBarrier(barrier_id, timeout, /*tasks=*/{}); |
| } |
| |
| std::unique_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient( |
| std::shared_ptr<::grpc::Channel> channel, |
| const DistributedRuntimeClient::Options& options, |
| bool use_coordination_service) { |
| if (use_coordination_service) { |
| return std::make_unique<xla::DistributedRuntimeCoordinationServiceClient>( |
| channel, options); |
| } |
| return std::make_unique<xla::DistributedRuntimeClientImpl>(channel, options); |
| } |
| } // namespace xla |