blob: 36edbe5ce747d28f2da6acb82486dcfec8ad2e71 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace data {
// A TensorFlow DataService serves dataset elements over RPC.
class DataServiceWorkerImpl {
public:
explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config);
~DataServiceWorkerImpl();
// Starts the worker. The worker needs to know its own address so that it can
// register with the dispatcher. This is set in `Start` instead of in the
// constructor because the worker may be binding to port `0`, in which case
// the address isn't known until the worker has started and decided which port
// to bind to.
Status Start(const std::string& worker_address);
// See worker.proto for API documentation.
/// Dispatcher-facing API.
Status ProcessTask(const ProcessTaskRequest* request,
ProcessTaskResponse* response);
/// Client-facing API.
Status GetElement(const GetElementRequest* request,
GetElementResponse* response);
private:
Status MakeDispatcherStub(std::unique_ptr<DispatcherService::Stub>* stub);
// Registers the worker with the dispatcher.
Status Register(DispatcherService::Stub* dispatcher) LOCKS_EXCLUDED(mu_);
// Sends task status to the dispatcher and checks for dispatcher commands.
Status SendTaskUpdates(DispatcherService::Stub* dispatcher)
LOCKS_EXCLUDED(mu_);
// Creates an iterator to process a task.
Status ProcessTaskInternal(const TaskDef& task) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// A thread for doing async background processing not associated with a
// specific RPC, such as reporting finished tasks. The thread takes
// ownership of the passed dispatcher_ptr. We use a raw pointer instead of
// unique_ptr since unique_ptr cannot be passed to std::function.
void BackgroundThread(DispatcherService::Stub* dispatcher_ptr)
LOCKS_EXCLUDED(mu_);
typedef struct Task {
int64 task_id;
// TODO(aaudibert): Have standalone::Iterator own a reference to
// standalone::Dataset so that we don't need to store the dataset here.
std::unique_ptr<standalone::Dataset> dataset;
std::unique_ptr<standalone::Iterator> iterator;
} Task;
const experimental::WorkerConfig config_;
// The worker's own address.
std::string worker_address_;
mutex mu_;
// Information about tasks, keyed by task ids.
absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
// Completed tasks which haven't yet been communicated to the dispatcher.
absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Whether the worker has registered with the dispatcher yet.
bool registered_ TF_GUARDED_BY(mu_) = false;
// Condition variable for notifying the background thread.
condition_variable background_cv_ TF_GUARDED_BY(mu_);
std::unique_ptr<Thread> background_thread_;
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_