#include "grpcpp/impl/codegen/client_context.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace data {
// Modes for how a service job should process a dataset.
//
// The enum was previously missing the enumerators described by its comments
// and its closing brace, which swallowed every following declaration.
enum class ProcessingMode : int64 {
  // Zero value; presumably means no processing mode was specified.
  UNSET = 0,
  // Each worker processes an entire epoch. If a dataset contains 2
  // elements and there are 3 workers, the job will produce 6 elements.
  PARALLEL_EPOCHS = 1,
  // Processing of a single epoch is distributed across all workers.
  DISTRIBUTED_EPOCH = 2,
};
// Interprets `s` as the name of a processing mode. On success the parsed value
// is written to `mode`; if `s` is not a recognized mode name, an
// InvalidArgument status is returned instead.
Status ParseProcessingMode(const std::string& s,
                           ProcessingMode& mode);
// Returns the string that corresponds to the given processing `mode`.
std::string ProcessingModeToString(
    ProcessingMode mode);
// Base class for data service clients. Data service clients are
// threadsafe.
//
// Concrete clients implement `EnsureInitialized()`. Callers may invoke
// `Initialize()` up front to surface setup errors early; otherwise the first
// RPC performs initialization.
class DataServiceClientBase {
 public:
  // Stores `address` and `protocol` in `address_` / `protocol_` for use by
  // subclasses.
  DataServiceClientBase(const std::string& address, const std::string& protocol)
      : address_(address), protocol_(protocol) {}

  virtual ~DataServiceClientBase() = default;

  // Not copyable or movable. Deleting the copy operations also suppresses the
  // implicit move operations.
  DataServiceClientBase(const DataServiceClientBase&) = delete;
  DataServiceClientBase& operator=(const DataServiceClientBase&) = delete;

  // Initializes the client. Calling `Initialize()` is not required since the
  // first RPC will perform any necessary initialization. However, it can be
  // useful to call `Initialize()` proactively so that any errors that happen
  // during initialization can be surfaced earlier.
  Status Initialize() { return EnsureInitialized(); }

 protected:
  // Initializes the client if it isn't already initialized. Implemented by
  // each concrete client.
  virtual Status EnsureInitialized() = 0;

  // The address and protocol passed to the constructor.
  const std::string address_;
  const std::string protocol_;
};
// Client for communicating with the service dispatcher.
class DataServiceDispatcherClient : public DataServiceClientBase {
 public:
  DataServiceDispatcherClient(const std::string& address,
                              const std::string& protocol)
      : DataServiceClientBase(address, protocol) {}

  // Sends a heartbeat to the dispatcher. If the worker wasn't already
  // registered with the dispatcher, this will register the worker. The
  // dispatcher will report which new tasks the worker should run, and which
  // tasks it should delete. This is stored into `new_tasks` and
  // `tasks_to_delete`.
  Status WorkerHeartbeat(const std::string& worker_address,
                         const std::vector<int64>& current_tasks,
                         std::vector<TaskDef>& new_tasks,
                         std::vector<int64>& tasks_to_delete);

  // Updates the dispatcher with information about the worker's state.
  Status WorkerUpdate(const std::string& worker_address,
                      std::vector<TaskProgress>& task_progress);

  // Gets a dataset definition for the given dataset id, and stores the
  // definition in `dataset_def`.
  Status GetDatasetDef(int64 dataset_id, DatasetDef& dataset_def);

  // Gets the next split for the specified job id and repetition.
  Status GetSplit(int64 job_id, int64 repetition, Tensor& split,
                  bool& end_of_splits);

  // Registers a dataset with the service, and stores the generated
  // dataset id in `dataset_id`.
  Status RegisterDataset(GraphDef dataset, int64& dataset_id);

  // If `job_key` is set, looks up a job matching `job_key`. If `job_key` is
  // absent or no matching job is found, creates a new job. The resulting job
  // id is stored in `job_client_id`.
  Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
                        const absl::optional<JobKey>& job_key,
                        absl::optional<int64> num_consumers,
                        int64& job_client_id);

  // Releases a job client id, indicating that the id will no longer be used to
  // read from the job.
  Status ReleaseJobClient(int64 job_client_id);

  // Queries the dispatcher for the tasks associated with the specified job.
  // The tasks will be stored in `tasks`, and whether the job is finished will
  // be stored in `job_finished`.
  Status GetTasks(int64 job_client_id, std::vector<TaskInfo>& tasks,
                  bool& job_finished);

  // Queries the dispatcher for its registered workers. The worker info will be
  // stored in `workers`.
  Status GetWorkers(std::vector<WorkerInfo>& workers);

 protected:
  Status EnsureInitialized() override;

 private:
  mutex mu_;
  // Initialization is guarded by `mu_`, but using the stub does not require
  // holding `mu_`.
  std::unique_ptr<DispatcherService::Stub> stub_;
};
// Client for communicating with the service worker.
class DataServiceWorkerClient : public DataServiceClientBase {
 public:
  DataServiceWorkerClient(const std::string& address,
                          const std::string& protocol)
      : DataServiceClientBase(address, protocol) {}

  // Fetches the next element for the specified task_id. The optional
  // `consumer_index` and `round_index` must be specified for tasks which use
  // round-robin ordering. The element's compressed tensors will be stored in
  // `element`. If no element is available, `end_of_sequence` will be `true`,
  // and `element` will be left unchanged.
  Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
                    absl::optional<int64> round_index,
                    CompressedElement& element, bool& end_of_sequence);

  // Makes a best effort to cancel all outstanding calls in progress for the
  // client, and causes further calls to return Cancelled status.
  void TryCancel();

 protected:
  Status EnsureInitialized() override;

 private:
  mutex mu_;
  // Initialization is guarded by `mu_`, but using the stub does not require
  // holding `mu_`.
  std::unique_ptr<WorkerService::Stub> stub_;
  // Set of all currently active client contexts. Used to support
  // cancellation. NOTE(review): raw pointers, presumably non-owning — confirm
  // against the implementation.
  absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
  // Indicates that the client has been cancelled, so no further requests should
  // be accepted.
  bool cancelled_ GUARDED_BY(mu_) = false;
};
// Builds a dispatcher client for `address` over `protocol`, initializes it,
// and hands ownership of the result to the caller through `out`.
Status CreateDataServiceDispatcherClient(
    const std::string& address,
    const std::string& protocol,
    std::unique_ptr<DataServiceDispatcherClient>& out);
// Builds a worker client for `address` over `protocol`, initializes it, and
// hands ownership of the result to the caller through `out`.
Status CreateDataServiceWorkerClient(
    const std::string& address,
    const std::string& protocol,
    std::unique_ptr<DataServiceWorkerClient>& out);
} // namespace data
} // namespace tensorflow