// blob: 1ef926c28bed19cf9e1a3717dfb3b2a4094d224a [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#include <memory>
#include <string>
#include <vector>

#include "grpcpp/impl/codegen/client_context.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace data {
// Describes how a tf.data service job processes its dataset.
//
// Each enumerator is pinned to an explicit value over an `int64` underlying
// type.
// NOTE(review): presumably those values are relied upon outside this header
// (e.g. when the mode is passed around as an integer) -- confirm before
// renumbering.
enum class ProcessingMode : int64 {
  // No mode has been selected; this is the zero value.
  UNSET = 0,

  // Every tf.data worker processes a complete epoch on its own. For example,
  // a dataset of 2 elements served by 3 workers makes the job produce 6
  // elements.
  PARALLEL_EPOCHS = 1,

  // The work of a single epoch is spread across all tf.data workers.
  DISTRIBUTED_EPOCH = 2,
};
// Parses `s`, a string naming a processing mode, and writes the parsed value to
// the out-parameter `mode`. Returns an InvalidArgument status if the string is
// not recognized.
// NOTE(review): the accepted spellings are decided by the implementation, which
// is not visible in this header; presumably they are the strings produced by
// `ProcessingModeToString` -- confirm.
Status ParseProcessingMode(const std::string& s, ProcessingMode& mode);
// Converts the processing mode `mode` to its corresponding string, which is
// returned by value.
// NOTE(review): presumably the inverse of `ParseProcessingMode` -- confirm
// against the implementation.
std::string ProcessingModeToString(ProcessingMode mode);
// Common base class of the tf.data service clients. Data service clients are
// threadsafe.
class DataServiceClientBase {
 public:
  // Stores copies of `address` and `protocol`. The constructor body is empty,
  // so no initialization work happens here; see `Initialize()`.
  DataServiceClientBase(const std::string& address, const std::string& protocol)
      : address_(address), protocol_(protocol) {}

  // Copying is disabled. Because the copy operations are user-declared, no
  // move operations are implicitly generated either, so instances are neither
  // copyable nor movable.
  DataServiceClientBase(const DataServiceClientBase&) = delete;
  DataServiceClientBase& operator=(const DataServiceClientBase&) = delete;

  virtual ~DataServiceClientBase() = default;

  // Initializes the client ahead of time by delegating to
  // `EnsureInitialized()`. Calling this is optional, because the first RPC
  // performs any necessary initialization itself; doing it proactively lets
  // initialization errors surface earlier.
  Status Initialize() { return EnsureInitialized(); }

 protected:
  // Initializes the client unless it has already been initialized. Every
  // concrete client supplies its own implementation.
  virtual Status EnsureInitialized() = 0;

  // Values handed to the constructor; immutable for the client's lifetime.
  const std::string address_;
  const std::string protocol_;
};
// Client used to talk to the tf.data service dispatcher.
class DataServiceDispatcherClient : public DataServiceClientBase {
 public:
  // Forwards `address` and `protocol` to the base class; the constructor body
  // itself is empty.
  DataServiceDispatcherClient(const std::string& address,
                              const std::string& protocol)
      : DataServiceClientBase(address, protocol) {}

  // Heartbeats to the dispatcher for the worker at `worker_address`. A worker
  // that is not yet registered with the dispatcher gets registered by this
  // call. The dispatcher replies with the tasks the worker should start
  // running, stored in `new_tasks`, and the tasks it should delete, stored in
  // `tasks_to_delete`.
  // NOTE(review): `current_tasks` is presumably the ids of the tasks the worker
  // is running right now -- confirm against the implementation.
  Status WorkerHeartbeat(const std::string& worker_address,
                         const std::vector<int64>& current_tasks,
                         std::vector<TaskDef>& new_tasks,
                         std::vector<int64>& tasks_to_delete);

  // Reports information about the state of the worker at `worker_address` to
  // the dispatcher.
  // NOTE(review): `task_progress` is taken by non-const reference although it
  // reads like an input -- confirm whether the implementation modifies it.
  Status WorkerUpdate(const std::string& worker_address,
                      std::vector<TaskProgress>& task_progress);

  // Looks up the dataset definition registered under `dataset_id` and stores
  // it in `dataset_def`.
  Status GetDatasetDef(int64 dataset_id, DatasetDef& dataset_def);

  // Fetches the next split for the given `job_id` and `repetition`.
  // NOTE(review): presumably the split is written to `split` and
  // `end_of_splits` reports that no further splits remain -- confirm.
  Status GetSplit(int64 job_id, int64 repetition, Tensor& split,
                  bool& end_of_splits);

  // Registers `dataset` with the tf.data service and stores the id generated
  // for it in `dataset_id`. Note that `dataset` is accepted by value.
  Status RegisterDataset(GraphDef dataset, int64& dataset_id);

  // When `job_key` is set, searches for a job matching `job_key`. When
  // `job_key` is absent, or no matching job exists, a new job is created. The
  // resulting id is stored in `job_client_id`.
  // NOTE(review): `num_consumers` is not described by the original comment;
  // presumably it relates to round-robin reads -- confirm.
  Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
                        const absl::optional<JobKey>& job_key,
                        absl::optional<int64> num_consumers,
                        int64& job_client_id);

  // Releases `job_client_id`, signalling that the id will no longer be used to
  // read from the job.
  Status ReleaseJobClient(int64 job_client_id);

  // Asks the dispatcher for the tasks associated with the specified job. The
  // tasks are stored in `tasks`; `job_finished` receives whether the job is
  // finished.
  Status GetTasks(int64 job_client_id, std::vector<TaskInfo>& tasks,
                  bool& job_finished);

  // Asks the dispatcher for its registered workers and stores the worker info
  // in `workers`.
  Status GetWorkers(std::vector<WorkerInfo>& workers);

 protected:
  Status EnsureInitialized() override;

 private:
  mutex mu_;
  // `mu_` guards initialization of the stub; once initialized, the stub may be
  // used without holding `mu_`.
  std::unique_ptr<DispatcherService::Stub> stub_;
};
// Client used to talk to a tf.data service worker.
class DataServiceWorkerClient : public DataServiceClientBase {
 public:
  // Forwards `address` and `protocol` to the base class; the constructor body
  // itself is empty.
  DataServiceWorkerClient(const std::string& address,
                          const std::string& protocol)
      : DataServiceClientBase(address, protocol) {}

  // Requests the next element of the task identified by `task_id`. Tasks that
  // use round-robin ordering require the otherwise optional `consumer_index`
  // and `round_index` to be specified. The compressed tensors of the element
  // are stored in `element`. When no element is available, `end_of_sequence`
  // is set to `true` and `element` is left unchanged.
  Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
                    absl::optional<int64> round_index,
                    CompressedElement& element, bool& end_of_sequence);

  // Cancels, on a best-effort basis, every call of this client that is still
  // in progress, and makes subsequent calls return a Cancelled status.
  void TryCancel();

 protected:
  Status EnsureInitialized() override;

 private:
  mutex mu_;
  // `mu_` guards initialization of the stub; once initialized, the stub may be
  // used without holding `mu_`.
  std::unique_ptr<WorkerService::Stub> stub_;
  // All client contexts that are currently active, tracked so that they can be
  // cancelled.
  absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
  // Set once the client has been cancelled; from then on no further requests
  // should be accepted.
  bool cancelled_ GUARDED_BY(mu_) = false;
};
// Creates and initializes a new tf.data service dispatcher client from the
// given `address` and `protocol`. The client is handed back through the
// out-parameter `out`.
// NOTE(review): presumably `out` is only assigned when the returned status is
// OK -- confirm against the implementation.
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceDispatcherClient>& out);
// Creates and initializes a new tf.data service worker client from the given
// `address` and `protocol`. The client is handed back through the
// out-parameter `out`.
// NOTE(review): presumably `out` is only assigned when the returned status is
// OK -- confirm against the implementation.
Status CreateDataServiceWorkerClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceWorkerClient>& out);
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_