/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/data_service.h"

#include <limits>

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace data {

namespace {

constexpr const char kParallelEpochs[] = "parallel_epochs";
constexpr const char kDistributedEpoch[] = "distributed_epoch";

}  // namespace
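
// Parses the string `s` into `mode`, returning InvalidArgument if the string
// does not name a known processing mode.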
Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
if (s == kParallelEpochs) {
mode = ProcessingMode::PARALLEL_EPOCHS;
} else if (s == kDistributedEpoch) {
mode = ProcessingMode::DISTRIBUTED_EPOCH;
} else {
return errors::InvalidArgument("Unrecognized processing mode: ", s);
}
return Status::OK();
}
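
// Returns the canonical string name for `mode`.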
std::string ProcessingModeToString(ProcessingMode mode) {
switch (mode) {
case ProcessingMode::PARALLEL_EPOCHS:
return kParallelEpochs;
case ProcessingMode::DISTRIBUTED_EPOCH:
return kDistributedEpoch;
default:
DCHECK(false);
return "Unknown";
}
}
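
// Performs a heartbeat from the worker at `worker_address` to the dispatcher,
// reporting the tasks it is currently running. On success, `new_tasks` holds
// the tasks the worker should start and `tasks_to_delete` the tasks it should
// stop.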
Status DataServiceDispatcherClient::WorkerHeartbeat(
const std::string& worker_address, const std::vector<int64>& current_tasks,
std::vector<TaskDef>& new_tasks, std::vector<int64>& tasks_to_delete) {
TF_RETURN_IF_ERROR(EnsureInitialized());
WorkerHeartbeatRequest req;
req.set_worker_address(worker_address);
for (int64 task : current_tasks) {
req.add_current_tasks(task);
}
WorkerHeartbeatResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->WorkerHeartbeat(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to perform worker heartbeat", status);
}
for (const auto& task : resp.new_tasks()) {
new_tasks.push_back(task);
}
for (int64 task_to_delete : resp.tasks_to_delete()) {
tasks_to_delete.push_back(task_to_delete);
}
return Status::OK();
}
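
// Reports per-task progress from the worker at `worker_address` to the
// dispatcher.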
Status DataServiceDispatcherClient::WorkerUpdate(
const std::string& worker_address,
std::vector<TaskProgress>& task_progress) {
TF_RETURN_IF_ERROR(EnsureInitialized());
WorkerUpdateRequest req;
req.set_worker_address(worker_address);
for (const auto& update : task_progress) {
*(req.add_updates()) = update;
}
WorkerUpdateResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->WorkerUpdate(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to send worker update", status);
}
return Status::OK();
}
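
// Fetches the definition of the dataset registered under `dataset_id` from
// the dispatcher.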
Status DataServiceDispatcherClient::GetDatasetDef(int64 dataset_id,
DatasetDef& dataset_def) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetDatasetDefRequest req;
req.set_dataset_id(dataset_id);
GetDatasetDefResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetDatasetDef(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to get dataset def", status);
}
dataset_def = resp.dataset_def();
return Status::OK();
}
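
// Requests the next split for `job_id` at the given `repetition`. Sets
// `end_of_splits` instead of filling `split` once the source is exhausted.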
Status DataServiceDispatcherClient::GetSplit(int64 job_id, int64 repetition,
Tensor& split,
bool& end_of_splits) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetSplitRequest req;
req.set_job_id(job_id);
req.set_repetition(repetition);
GetSplitResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetSplit(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to get split", status);
}
end_of_splits = resp.end_of_splits();
if (!end_of_splits) {
if (!split.FromProto(resp.split())) {
return errors::Internal("Failed to parse split tensor proto");
}
}
return Status::OK();
}
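
// Registers the dataset graph with the dispatcher (or finds an existing
// registration) and stores the resulting id in `dataset_id`.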
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
int64& dataset_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrRegisterDatasetRequest req;
*req.mutable_dataset()->mutable_graph() = dataset;
GetOrRegisterDatasetResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError("Failed to register dataset", status);
}
dataset_id = resp.dataset_id();
return Status::OK();
}
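
// Gets or creates a job reading from the dataset with `dataset_id`. When
// `job_key` is set, clients passing the same key share a single job. On
// success, `job_client_id` identifies this client's connection to the job.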
Status DataServiceDispatcherClient::GetOrCreateJob(
int64 dataset_id, ProcessingMode processing_mode,
const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
int64& job_client_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrCreateJobRequest req;
req.set_dataset_id(dataset_id);
req.set_processing_mode(ProcessingModeDef(processing_mode));
if (job_key.has_value()) {
*req.mutable_job_key() = job_key.value();
}
if (num_consumers.has_value()) {
req.set_num_consumers(num_consumers.value());
}
GetOrCreateJobResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError(
absl::StrCat("Failed to get or create job for dataset with id ",
dataset_id),
status);
}
job_client_id = resp.job_client_id();
return Status::OK();
}
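
// Notifies the dispatcher that the client with `job_client_id` is done
// reading from its job.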
Status DataServiceDispatcherClient::ReleaseJobClient(int64 job_client_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
ReleaseJobClientRequest req;
req.set_job_client_id(job_client_id);
ReleaseJobClientResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->ReleaseJobClient(&client_ctx, req, &resp);
if (!status.ok()) {
return grpc_util::WrapError(
absl::StrCat("Failed to release job client with id ", job_client_id),
status);
}
return Status::OK();
}
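
// Queries the dispatcher for the tasks currently assigned to the job and for
// whether the job has finished.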
Status DataServiceDispatcherClient::GetTasks(int64 job_client_id,
std::vector<TaskInfo>& tasks,
bool& job_finished) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetTasksRequest req;
req.set_job_client_id(job_client_id);
GetTasksResponse resp;
grpc::ClientContext ctx;
grpc::Status s = stub_->GetTasks(&ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to get tasks", s);
}
tasks.clear();
for (auto& task : resp.task_info()) {
tasks.push_back(task);
}
job_finished = resp.job_finished();
return Status::OK();
}
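
// Queries the dispatcher for all registered workers.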
Status DataServiceDispatcherClient::GetWorkers(
std::vector<WorkerInfo>& workers) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetWorkersRequest req;
GetWorkersResponse resp;
grpc::ClientContext ctx;
grpc::Status s = stub_->GetWorkers(&ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to get workers", s);
}
workers.clear();
for (auto& worker : resp.workers()) {
workers.push_back(worker);
}
return Status::OK();
}
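
// Lazily creates the gRPC stub for talking to the dispatcher. The receive
// message limit is raised to INT32_MAX so that large responses, such as
// serialized dataset graphs, are not rejected by gRPC's default limit.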
Status DataServiceDispatcherClient::EnsureInitialized() {
mutex_lock l(mu_);
if (stub_) {
return Status::OK();
}
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
stub_ = DispatcherService::NewStub(channel);
return Status::OK();
}
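
// Fetches one element for `task_id` from the worker. `consumer_index` and
// `round_index` are set only for coordinated (round-robin) reads. The client
// context is tracked in `active_contexts_` so TryCancel() can abort the RPC
// while it is in flight.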
Status DataServiceWorkerClient::GetElement(int64 task_id,
absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
CompressedElement& element,
bool& end_of_sequence) {
TF_RETURN_IF_ERROR(EnsureInitialized());
{
mutex_lock l(mu_);
if (cancelled_) {
return errors::Cancelled("Client was cancelled.");
}
}
GetElementRequest req;
req.set_task_id(task_id);
if (consumer_index.has_value()) {
req.set_consumer_index(consumer_index.value());
}
if (round_index.has_value()) {
req.set_round_index(round_index.value());
}
GetElementResponse resp;
grpc::ClientContext ctx;
{
mutex_lock l(mu_);
active_contexts_.insert(&ctx);
}
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
{
mutex_lock l(mu_);
active_contexts_.erase(&ctx);
}
if (!s.ok()) {
return grpc_util::WrapError("Failed to get element", s);
}
end_of_sequence = resp.end_of_sequence();
if (!end_of_sequence) {
element = std::move(*resp.mutable_compressed_element());
}
return Status::OK();
}
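
// Lazily creates the gRPC stub for talking to the worker. A receive message
// size of -1 means unlimited, since compressed elements can be large.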
Status DataServiceWorkerClient::EnsureInitialized() {
mutex_lock l(mu_);
if (stub_) {
return Status::OK();
}
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(-1);
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
stub_ = WorkerService::NewStub(channel);
return Status::OK();
}
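
// Marks the client as cancelled and attempts to cancel all in-flight RPCs.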
void DataServiceWorkerClient::TryCancel() {
mutex_lock l(mu_);
cancelled_ = true;
for (const auto& ctx : active_contexts_) {
ctx->TryCancel();
}
}
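
// Creates and initializes a dispatcher client for the given address and
// protocol. A minimal usage sketch (assuming a dispatcher is listening at
// "localhost:5000" and the insecure "grpc" protocol is registered):
//
//   std::unique_ptr<DataServiceDispatcherClient> dispatcher;
//   TF_RETURN_IF_ERROR(CreateDataServiceDispatcherClient(
//       "localhost:5000", "grpc", dispatcher));
//   std::vector<WorkerInfo> workers;
//   TF_RETURN_IF_ERROR(dispatcher->GetWorkers(workers));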
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceDispatcherClient>& out) {
auto client =
absl::make_unique<DataServiceDispatcherClient>(address, protocol);
TF_RETURN_IF_ERROR(client->Initialize());
out = std::move(client);
return Status::OK();
}
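
// Creates and initializes a worker client for the given address and protocol.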
Status CreateDataServiceWorkerClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceWorkerClient>& out) {
auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
TF_RETURN_IF_ERROR(client->Initialize());
out = std::move(client);
return Status::OK();
}
}  // namespace data
}  // namespace tensorflow