blob: 2e6e0465f7b1fd88902550177b480b1f2b0537ad [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <deque>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
OP_REQUIRES(ctx, cycle_length > 0,
errors::InvalidArgument("`cycle_length` must be > 0"));
int64 block_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "block_length", &block_length));
OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0"));
bool sloppy = false;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
int64 buffer_output_elements = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
&buffer_output_elements));
OP_REQUIRES(
ctx, buffer_output_elements > 0,
errors::InvalidArgument("`buffer_output_elements` must be > 0"));
int64 prefetch_input_elements = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
&prefetch_input_elements));
OP_REQUIRES(
ctx, prefetch_input_elements >= 0,
errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
&captured_func));
*output =
new Dataset(ctx, input, interleave_func_, std::move(captured_func),
cycle_length, block_length, sloppy, buffer_output_elements,
prefetch_input_elements, output_types_, output_shapes_);
}
private:
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, bool sloppy, int64 buffer_output_elements,
int64 prefetch_input_elements, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
interleave_func_(func),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
sloppy_(sloppy),
buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::ParallelInterleave")}));
}
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return "ParallelInterleaveDatasetOp::Dataset";
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
Node* sloppy_node;
TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
Node* buffer_output_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
Node* prefetch_input_elements_node;
TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
&prefetch_input_elements_node));
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
b->BuildAttrValue(interleave_func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this,
{{0, input_node},
{2, cycle_length_node},
{3, block_length_node},
{4, sloppy_node},
{5, buffer_output_elements_node},
{6, prefetch_input_elements_node}},
{{1, other_arguments}},
{{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
return Status::OK();
}
private:
int64 num_threads() const {
return cycle_length_ + prefetch_input_elements_;
}
// Parallel interleave's implementation is designed around a few principles:
// 1. Thread creation is relatively expensive. (Not reusing
// threads causes a number of indirect costs such as poorer tcmalloc
// performance due to thread-local caches, etc.) We allocate a fixed
// number of threads at the start and never change. This is why we've
// fused functionality that is theoretically orthogonal (i.e.
// .prefetch()) into the implementation.
// 2. Drop-in replacement for standard interleave. The goal will be to
// auto-opt people into an optimized implementation without any work
// on the customer's part. We thus go through great pains to maintain
// identical iteration orders, full determinism (disabled only via a
// flag, etc.)
// 3. Performance across a variety of environments and I/O envelopes.
//
// The actual implementation centers around a collection of worker threads
// and their corresponding worker state (tracked in the `workers_` vector).
// Worker threads repeatedly receive a vector of Tensors that are used as
// input to the flat-map function (`captured_func_`). The output of this
// function must be a dataset. The worker thread then repeatedly calls
// `GetNext()`, maintaining a buffer of elements to minimize the likelihood
// that a caller will block waiting for an element to be produced.
//
// Pointers to these worker states are kept in 2 disjoint data structures:
// 1. `interleave_indices_` is a vector containing indices of WorkerStates
// in `workers_` that we are interleaving. Worker threads backing these
// WorkerStates should be regularly producing values.
// 2. `staging_indices_` is a deque containing indices of WorkerStates in
// `workers_` that we will move to `interleave_indices_` when an
// iterator in `interleave_indices_` is exhausted.
//
// The client calls `GetNext[Internal]()` to retrieve an output element. The
// internal implementation updates the state of `interleave_indices_` and
// `staging_indices_` as output iterators (run by the worker threads) are
// exhausted.
//
// `input_impl_` is the input iterator that generates arguments for the
// flat-map function (`captured_func_`). It is set to an iterator at
// Iterator construction, and is fixed until we consume all input elements.
// Once it is exhausted, we reset the unique_ptr to eagerly deallocate
// memory.
//
// A few invariants are maintained:
// 1. No element in interleave_indices_ should be a -1 unless
// `staging_indices_` is empty and `input_impl_` is empty.
// 2. Every `worker_` element is pointed to by at most one element of the
// union of `interleave_indices_` and `staging_indices_`.
// 3. Unless `input_impl_` is empty, every `worker_` must be pointed to by
// an element in `interleave_indices_` or `staging_indices_`.
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
workers_(dataset()->num_threads()),
worker_thread_states_(dataset()->num_threads()) {}
~Iterator() override {
mutex_lock l(mu_);
cancelled_ = true;
// Notify all workers in case they are blocked.
for (auto& worker : workers_) {
worker.cond_var.notify_all();
}
}
Status Initialize(IteratorContext* ctx) override {
AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
// It is implemented so that it matches the deterministic interleave
// unless getting the next element would block and we are allowed to be
// sloppy.
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
while (!cancelled_) {
// Wait for an item to become available, blocking if necessary. If we
// are allowed to be sloppy, we can skip over input datasets that do
// not have an item readily available.
bool can_produce_elements = false;
bool must_wait_for_input = true;
for (int64 i = 0; i < interleave_indices_.size(); ++i) {
int64 index = (next_index_ + i) % interleave_indices_.size();
int64 current_worker_index = interleave_indices_[index];
if (current_worker_index < 0) {
continue; // Empty interleave elements.
}
WorkerState* current_worker = &workers_[current_worker_index];
can_produce_elements |= current_worker->MayHaveElements();
if (!current_worker->outputs.empty()) {
// We have an element!
next_index_ = index;
const bool element_acquired_sloppily =
dataset()->sloppy_ && i > 1;
if (!element_acquired_sloppily) {
// If the element was acquired in the regular (non-sloppy)
// order, then advance the current block and cycle pointers to
// the next element in the regular order.
block_count_++;
if (block_count_ == dataset()->block_length_) {
next_index_ = (index + 1) % interleave_indices_.size();
block_count_ = 0;
}
} else {
block_count_ = 0;
}
*end_of_sequence = false;
Status s = current_worker->outputs.front().status;
current_worker->outputs.front().output.swap(*out_tensors);
current_worker->outputs.pop_front();
current_worker->cond_var.notify_one();
return s;
} else if (current_worker->is_producing && !dataset()->sloppy_) {
// current_worker.outputs.empty(), and we must wait for this
// iterator.
if (next_index_ != index) {
// We have advanced to a new iterator; reset block counts.
next_index_ = index;
block_count_ = 0;
}
break;
} else if (!current_worker->is_producing) {
// This iterator has reached end of input.
interleave_indices_[index] = -1;
if (input_impl_) {
// Start prefetching a new iterator.
std::vector<Tensor> args;
bool end_of_input = false;
Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
if (end_of_input) {
input_impl_.reset();
} else {
current_worker->SetInputs(s, std::move(args));
staging_indices_.emplace_back(current_worker_index);
}
}
if (!staging_indices_.empty()) {
// Move a worker from `staging_indices_` to
// `interleave_indices_`.
interleave_indices_[index] = staging_indices_.front();
staging_indices_.pop_front();
next_index_ = (index + 1) % interleave_indices_.size();
block_count_ = 0;
// Restart the inner [for] loop
can_produce_elements = true;
must_wait_for_input = false;
break;
}
}
}
if (!can_produce_elements && !input_impl_) {
// No potential for future values.
*end_of_sequence = true;
return Status::OK();
}
if (must_wait_for_input) {
// Wait for elements to become available.
RecordStop(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
RecordStart(ctx);
}
}
return errors::Cancelled(
"ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
// The order of locking is important here to avoid deadlock.
mutex_lock l(mu_);
mutex_lock ckpt_l(ckpt_mu_);
if (input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_exhausted"), ""));
}
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("next_index"), next_index_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("block_count"), block_count_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("workers_size"), workers_.size()));
for (int i = 0; i < workers_.size(); ++i) {
TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
}
for (int i = 0; i < worker_thread_states_.size(); ++i) {
TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("interleave_size"),
interleave_indices_.size()));
for (int i = 0; i < interleave_indices_.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("interleave_indices_", i)),
interleave_indices_[i]));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("staging_size"),
staging_indices_.size()));
for (int i = 0; i < staging_indices_.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("staging_indices_", i)),
staging_indices_[i]));
}
if (!worker_threads_.empty()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("worker_threads_running"), ""));
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
// The order of locking is important here to avoid deadlock.
mutex_lock l(mu_);
mutex_lock ckpt_l(ckpt_mu_);
if (!reader->Contains(full_name("input_exhausted"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next_index"), &temp));
next_index_ = size_t(temp);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("block_count"), &temp));
block_count_ = size_t(temp);
// Restore WorkerStates.
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("workers_size"), &temp));
if (temp != dataset()->num_threads()) {
return errors::Internal("Expected ", dataset()->num_threads(),
" worker states but found ", temp, ".");
}
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
}
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
TF_RETURN_IF_ERROR(ReadWorkerThreadStateLocked(reader, i, ctx));
}
// Restore `interleave_indices_`.
std::set<int64> all_indices;
{
int64 interleave_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("interleave_size"),
&interleave_size));
interleave_indices_.reserve(interleave_size);
for (int64 i = 0; i < interleave_size; ++i) {
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("interleave_indices_", i)), &temp));
if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
return errors::Internal(
"Duplicate entry for ", temp,
" found when reading interleave and staging indices.");
}
if (temp >= 0) {
all_indices.insert(temp);
}
interleave_indices_.emplace_back(temp);
}
}
// Restore `staging_indices_`.
{
int64 staging_size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("staging_size"), &staging_size));
for (int i = 0; i < staging_size; ++i) {
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("staging_indices_", i)), &temp));
if (all_indices.find(temp) != all_indices.end()) {
return errors::Internal(
"Duplicate entry for ", temp,
" found when reading interleave and staging indices.");
}
if (temp >= 0) {
all_indices.insert(temp);
}
staging_indices_.emplace_back(temp);
}
}
// Start Worker threads.
if (reader->Contains(full_name("worker_threads_running"))) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
}
private:
// OutputElem contains the information from a call to GetNext by an output
// iterator.
struct OutputElem {
// The output iterator sets `status` if getting the output element
// fails.
Status status;
// The buffered data element.
std::vector<Tensor> output;
explicit OutputElem(const Status& s) : status(s) {}
};
// Worker threads operate on their relevant WorkerState structs.
//
// WorkerState's fields are all protected by mu_;
struct WorkerState {
// The arguments to be used to construct an output iterator.
std::vector<Tensor> input;
// The buffered output elements.
std::deque<OutputElem> outputs;
// Set to true iff the worker thread expects to append more elements to
// outputs. is_producing can be false despite !outputs.empty().
// Concretely, all output elements will have been consumed only when:
// is_producing == false && outputs.empty();
bool is_producing = false;
// Condition variable used to coordinate between threads. The worker
// thread waits on this condition variable when it is either (1) waiting
// for the main thread to add arguments to `input`, or (2) waiting for
// the main thread to consume an element of `outputs`. The main thread
// waits on cond_var if it is waiting for the worker thread to produce
// an element into `outputs` (this implies sloppy_==false).
condition_variable cond_var;
inline bool MayHaveElements() const {
return is_producing || !outputs.empty();
}
// Sets inputs for a worker thread and notifies it to start processing.
void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
if (s.ok()) {
DCHECK(!MayHaveElements())
<< "Tried to start inputs, despite already producing!";
input = std::move(input_arguments);
is_producing = true;
cond_var.notify_one();
} else {
outputs.emplace_back(s);
}
}
};
// The internal state of a worker thread that is not already captured
// in its `WorkerState`.
//
// This is needed only for checkpointing purposes. We keep this
// separate from `WorkerState` and guard its fields using a separate
// lock `ckpt_mu_` so as to not affect the performance of main pipeline.
struct WorkerThreadState {
// The output element that has been produced from the input iterator
// and is waiting to be added to `WorkerState.outputs`.
OutputElem output_elem;
// Whether the input iterator returned an `end_of_sequence`.
bool end_of_sequence = false;
// Status returned from `MakeIteratorFromInputElement`.
Status iterator_creation_status;
// The arguments to be used to construct `iterator`.
std::vector<Tensor> input;
std::unique_ptr<IteratorBase> iterator;
WorkerThreadState() : output_elem(Status::OK()) {}
};
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (worker_threads_.empty()) {
worker_threads_.reserve(dataset()->num_threads());
for (int64 i = 0; i < dataset()->num_threads(); ++i) {
std::vector<Tensor> args;
bool end_of_input = false;
Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
if (end_of_input) {
input_impl_.reset();
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
staging_indices_.push_back(i);
}
}
DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
DCHECK(staging_indices_.size() ==
dataset()->prefetch_input_elements_);
}
return Status::OK();
}
// Produces elements into the worker's output buffers.
void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
const int64 thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
// in `worker_thread_states_[thread_index]`.
// 2. `WorkerThreadState` should contain state that is needed only for
// checkpointing, i.e., if we were to remove checkpointing support,
// we could keep that state as local variables in this thread.
// 3. This thread should only read/write state at `thread_index`
// and should not access other thread states.
// 4. When restoring from checkpoint, threads are started only after
// the restore is complete.
// 5. Once restored from a checkpoint, the local state is edited only
// by this thread. 3 & 4 allow making assumptions like temporarily
// caching local state in this thread and using it outside a lock
// e.g. `make_new_iterator`.
// 6. `ckpt_mu_` should be wisely used to create *consistent*
// checkpoint markers.
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
RecordStop(ctx.get());
});
bool make_new_iterator;
{
tf_shared_lock l(ckpt_mu_);
// Decide whether a new iterator should be built.
// 1. If there is an existing iterator, we use it.
// 2. If there was an error in iterator creation that could not be
// notified to the client we attempt to send that to the client
// first.
make_new_iterator =
worker_thread_states_[thread_index].iterator == nullptr &&
worker_thread_states_[thread_index].iterator_creation_status.ok();
}
// Even though `make_new_iterator` has cached values from
// `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
// it is safe to *read* `make_new_iterator`outside of a lock without
// worrying about concurrent changes to values in
// `worker_thread_states_[thread_index]`. See comment at the start of
// this function for details.
while (true) {
// Whether creation of the iterator succeeded.
Status iterator_creation_status;
// 1. Build a new iterator or use the existing one.
if (make_new_iterator) {
// 1a. Get new input tensors or use the exiting ones.
bool read_new_input;
{
tf_shared_lock l(ckpt_mu_);
// worker_thread_states_[thread_index].input will be non-empty
// if checkpointing happened at CHECKPOINT_MARKER_A.
read_new_input =
worker_thread_states_[thread_index].input.empty();
}
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
RecordStart(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
// when building the iterator.
// We keep a copy of the input tensors in
// `WorkerThreadState.input` till the iterator is in use. This is
// used in `RestoreInternal` to re-build the iterator.
// TODO(b/78046638): Explore ways to avoid tracking the input
// tensors.
tf_shared_lock ckpt_l(ckpt_mu_);
worker_thread_states_[thread_index].input.swap(
workers_[thread_index].input);
// CHECKPOINT_MARKER_A
// We have the input tensors but have not built the iterator yet.
}
// 1b. Run the user defined function to produce a new iterator.
{
tf_shared_lock l(ckpt_mu_);
worker_thread_states_[thread_index].iterator_creation_status =
MakeIteratorFromInputElement(
ctx.get(), worker_thread_states_[thread_index].input,
thread_index, dataset()->captured_func_.get(), prefix(),
&worker_thread_states_[thread_index].iterator);
iterator_creation_status =
worker_thread_states_[thread_index].iterator_creation_status;
if (!iterator_creation_status.ok()) {
worker_thread_states_[thread_index].input.clear();
}
// CHECKPOINT_MARKER_B
// Either an iterator has been successfully built and placed in
// `worker_thread_states_[thread_index].iterator` or it failed and
// a non-OK status has been put in
// `worker_thread_states_[thread_index].iterator_creation_status`.
}
} else {
tf_shared_lock l(ckpt_mu_);
iterator_creation_status =
worker_thread_states_[thread_index].iterator_creation_status;
// Mark that we have used up the restored iterator.
make_new_iterator = true;
}
// 2. Start producing elements or send error state to client if
// iterator creation failed.
if (!iterator_creation_status.ok()) {
mutex_lock l(mu_);
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
workers_[thread_index].outputs.emplace_back(
iterator_creation_status);
workers_[thread_index].is_producing = false;
worker_thread_states_[thread_index].iterator_creation_status =
Status::OK();
// CHECKPOINT_MARKER_C
// Non-OK iterator creation status has been notified to the
// client.
workers_[thread_index].cond_var.notify_one();
} else {
bool end_of_sequence = false;
while (!end_of_sequence) {
// 3.a Produce an element!
{
tf_shared_lock ckpt_l(ckpt_mu_);
if (worker_thread_states_[thread_index]
.output_elem.status.ok() &&
worker_thread_states_[thread_index]
.output_elem.output.empty() &&
!worker_thread_states_[thread_index].end_of_sequence) {
worker_thread_states_[thread_index].output_elem.status =
worker_thread_states_[thread_index].iterator->GetNext(
ctx.get(),
&worker_thread_states_[thread_index]
.output_elem.output,
&worker_thread_states_[thread_index].end_of_sequence);
end_of_sequence =
worker_thread_states_[thread_index].end_of_sequence;
} else {
end_of_sequence =
worker_thread_states_[thread_index].end_of_sequence;
}
// CHECKPOINT_MARKER_D
// An element has been read or an error or end_of_sequence has
// been received from the input iterator and is waiting to be
// sent to client.
}
// 3.b Make it available to the client.
{
mutex_lock l(mu_);
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
workers_[thread_index].is_producing = !end_of_sequence;
// Output the element.
// Move the temporary state in WorkerThreadState to WorkerState
// and mark it as used.
if (end_of_sequence) {
worker_thread_states_[thread_index].iterator.reset();
worker_thread_states_[thread_index].input.clear();
worker_thread_states_[thread_index].end_of_sequence = false;
} else {
workers_[thread_index].outputs.emplace_back(
worker_thread_states_[thread_index].output_elem.status);
workers_[thread_index].outputs.back().output.swap(
worker_thread_states_[thread_index].output_elem.output);
}
worker_thread_states_[thread_index].output_elem.status =
Status::OK();
if (dataset()->sloppy_) {
sloppy_cond_var_.notify_one();
} else {
workers_[thread_index].cond_var.notify_one();
}
// CHECKPOINT_MARKER_E
// Output element or iterator status has been sent to the
// client.
}
}
}
}
}
Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
string prefix = strings::StrCat("worker_", index);
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_input_size")),
workers_[index].input.size()));
for (int i = 0; i < workers_[index].input.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat(prefix, "_input_", i)),
workers_[index].input[i]));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_outputs_size")),
workers_[index].outputs.size()));
for (int i = 0; i < workers_[index].outputs.size(); ++i) {
TF_RETURN_IF_ERROR(WriteOutputElemLocked(
writer, workers_[index].outputs[i],
full_name(strings::StrCat(prefix, "_outputs_", i))));
}
if (workers_[index].is_producing) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_is_producing")), ""));
}
return Status::OK();
}
Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
string worker_prefix = strings::StrCat("worker_", index);
// Restore inputs.
int64 input_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(worker_prefix, "_input_size")),
&input_size));
workers_[index].input.reserve(input_size);
for (int i = 0; i < input_size; ++i) {
workers_[index].input.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat(worker_prefix, "_input_", i)),
&workers_[index].input.back()));
}
int64 outputs_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(worker_prefix, "_outputs_size")),
&outputs_size));
for (int i = 0; i < outputs_size; ++i) {
workers_[index].outputs.emplace_back(Status::OK());
TF_RETURN_IF_ERROR(ReadOutputElemLocked(
reader, &workers_[index].outputs.back(),
full_name(strings::StrCat(worker_prefix, "_outputs_", i))));
}
if (reader->Contains(
full_name(strings::StrCat(worker_prefix, "_is_producing")))) {
workers_[index].is_producing = true;
} else {
workers_[index].is_producing = false;
}
return Status::OK();
}
Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer,
int index)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
string prefix = strings::StrCat("worker_thread_", index);
if (worker_thread_states_[index].iterator != nullptr) {
TF_RETURN_IF_ERROR(
SaveInput(writer, worker_thread_states_[index].iterator));
} else {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_iterator_exhausted")), ""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_input_size")),
worker_thread_states_[index].input.size()));
for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat(prefix, "_input_", i)),
worker_thread_states_[index].input[i]));
}
TF_RETURN_IF_ERROR(WriteStatusLocked(
writer, strings::StrCat(prefix, "_iterator_creation_status"),
worker_thread_states_[index].iterator_creation_status));
TF_RETURN_IF_ERROR(WriteOutputElemLocked(
writer, worker_thread_states_[index].output_elem,
full_name(strings::StrCat(prefix, "_output"))));
if (worker_thread_states_[index].end_of_sequence) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_end_of_sequence")), ""));
}
return Status::OK();
}
Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
string worker_prefix = strings::StrCat("worker_thread_", index);
// Restore inputs.
int64 input_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(worker_prefix, "_input_size")),
&input_size));
worker_thread_states_[index].input.reserve(input_size);
for (int i = 0; i < input_size; ++i) {
worker_thread_states_[index].input.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat(worker_prefix, "_input_", i)),
&worker_thread_states_[index].input.back()));
}
// Restore iterator.
if (reader->Contains(full_name(
strings::StrCat(worker_prefix, "_iterator_exhausted")))) {
worker_thread_states_[index].iterator.reset();
} else {
std::unique_ptr<IteratorBase> iterator;
Status s = MakeIteratorFromInputElement(
ctx, worker_thread_states_[index].input, index,
dataset()->captured_func_.get(), prefix(), &iterator);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
worker_thread_states_[index].iterator.swap(iterator);
}
TF_RETURN_IF_ERROR(ReadStatusLocked(
reader, strings::StrCat(worker_prefix, "_iterator_creation_status"),
&worker_thread_states_[index].iterator_creation_status));
TF_RETURN_IF_ERROR(ReadOutputElemLocked(
reader, &worker_thread_states_[index].output_elem,
full_name(strings::StrCat(worker_prefix, "_output"))));
if (reader->Contains(full_name(
strings::StrCat(worker_prefix, "_end_of_sequence")))) {
worker_thread_states_[index].end_of_sequence = true;
} else {
worker_thread_states_[index].end_of_sequence = false;
}
return Status::OK();
}
Status WriteOutputElemLocked(IteratorStateWriter* writer,
const OutputElem& output_elem,
const string& prefix)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
TF_RETURN_IF_ERROR(WriteStatusLocked(
writer, strings::StrCat(prefix, "_status"), output_elem.status));
TF_RETURN_IF_ERROR(
writer->WriteScalar(strings::StrCat(prefix, "_output_size"),
output_elem.output.size()));
for (int i = 0; i < output_elem.output.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
strings::StrCat(prefix, "_output_", i), output_elem.output[i]));
}
return Status::OK();
}
Status ReadOutputElemLocked(IteratorStateReader* reader,
OutputElem* output_elem, const string& prefix)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
TF_RETURN_IF_ERROR(ReadStatusLocked(
reader, strings::StrCat(prefix, "_status"), &output_elem->status));
int64 output_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
strings::StrCat(prefix, "_output_size"), &output_size));
output_elem->output.reserve(output_size);
for (int i = 0; i < output_size; ++i) {
output_elem->output.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(strings::StrCat(prefix, "_output_", i),
&output_elem->output.back()));
}
return Status::OK();
}
Status WriteStatusLocked(IteratorStateWriter* writer,
const string& prefix, const Status& status)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
Status* status)
EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_code")), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
string error_message;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_msg")), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
// Mutex & condition variable to guard mutable iterator internals and
// coordinate among worker threads and client thread[s].
mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
// The main thread waits on this condition variable if running in sloppy
// mode and no values are available.
condition_variable sloppy_cond_var_;
// Mutex used to wait for a consistent state while checkpointing.
// Only Save and Restore require an exclusive lock on this mutex. In
// other scenarios we just acquire a shared lock so the pipeline's
// performance should not be affected in the absence of checkpointing.
// A thread must not wait on any condition variable while holding
// `ckpt_mu_` in either shared or exclusive modes.
mutex ckpt_mu_;
// The iterator producing elements which are converted to datasets by
// the dataset()->captured_func_ then interleaved together.
// input_impl_ is reset when we have exhausted its input.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
// The WorkerState structs the worker threads operate on.
// workers_ elements are in at most one of interleave_ and staging_.
std::vector<WorkerState> workers_ GUARDED_BY(mu_);
// Stores the temporary state of WorkerThreads which is not stored in
// WorkerState. This is used for checkpointing purposes only.
std::vector<WorkerThreadState> worker_thread_states_ GUARDED_BY(ckpt_mu_);
// Indices in `workers_` of iterators to interleave.
std::vector<int64> interleave_indices_ GUARDED_BY(mu_);
// Indices in `workers_` of prefetched iterators.
std::deque<int64> staging_indices_ GUARDED_BY(mu_);
// The index into output_elements_ for next element to produce.
size_t next_index_ GUARDED_BY(mu_) = 0;
// The number of items produced so far within the block
size_t block_count_ GUARDED_BY(mu_) = 0;
// Flag to instruct the worker threads to exit.
bool cancelled_ GUARDED_BY(mu_) = false;
// The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads.
std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
const NameAttrList interleave_func_;
const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_;
const int64 block_length_;
const bool sloppy_;
const int64 buffer_output_elements_;
const int64 prefetch_input_elements_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList interleave_func_;
};
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
// The motivation for creating an alternative implementation of parallel
// interleave is to decouple the degree of parallelism from the cycle length.
// This makes it possible to change the degree of parallelism (e.g. through
// auto-tuning) without changing the cycle length (which would change the order
// in which elements are produced).
//
// Furthermore, this class favors modularity over extended functionality. In
// particular, it refrains from implementing configurable buffering of output
// elements and prefetching of input iterators, relying on other parts of
// tf.data to provide this functionality if necessary.
//
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
OP_REQUIRES(ctx, cycle_length > 0,
errors::InvalidArgument("`cycle_length` must be > 0"));
int64 block_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "block_length", &block_length));
OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0"));
int64 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
OP_REQUIRES(
ctx, num_parallel_calls <= cycle_length,
errors::InvalidArgument(
"num_parallel_calls must less than or equal to cycle_length."));
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
&captured_func));
*output = new Dataset(ctx, input, interleave_func_,
std::move(captured_func), cycle_length, block_length,
num_parallel_calls, output_types_, output_shapes_);
}
private:
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, int64 num_parallel_calls,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
interleave_func_(func),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
num_parallel_calls_(num_parallel_calls),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::ParallelInterleaveV2")}));
}
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return "ParallelInterleaveDatasetV2Op::Dataset";
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
Node* num_parallel_calls_node;
TF_RETURN_IF_ERROR(
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
b->BuildAttrValue(interleave_func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this,
{{0, input_node},
{2, cycle_length_node},
{3, block_length_node},
{4, num_parallel_calls_node}},
{{1, other_arguments}},
{{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
num_parallel_calls_(params.dataset->num_parallel_calls_),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
thread_pool_(new thread::ThreadPool(
Env::Default(), ThreadOptions(), "parallel_interleave",
dataset()->cycle_length_ /* num_threads */,
false /* low_latency_hint */)) {}
~Iterator() override {
mutex_lock l(mu_);
// Cancel the runner thread.
cancelled_ = true;
cond_var_.notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_.wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
AddTunableParameter(ctx, "parallelism",
&num_parallel_calls_ /* value */, 1 /* min */,
dataset()->cycle_length_ /* max */, &cond_var_);
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
do {
{
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
RecordStop(ctx);
cond_var_.wait(l);
RecordStart(ctx);
}
if (!invocation_results_.empty()) {
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
} else {
*end_of_sequence = true;
return Status::OK();
}
cond_var_.notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
} while (result->skip);
if (result->status.ok()) {
*out_tensors = std::move(result->return_values);
}
*end_of_sequence = false;
return result->status;
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_.wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("invocation_results.size"), invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
std::shared_ptr<InvocationResult> result = invocation_results_[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
result->return_values.size()));
for (size_t j = 0; j < result->return_values.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(
strings::StrCat("invocation_results[", i, "][", j, "]")),
result->return_values[j]));
}
if (result->skip) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].skip")),
""));
}
}
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("cycle_index"), cycle_index_));
if (end_of_input_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input"), ""));
}
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("num_open"), num_open_));
TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
std::shared_ptr<InvocationResult> result(new InvocationResult());
invocation_results_.push_back(result);
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(
strings::StrCat("invocation_results[", i, "].size")),
": ", size, " is not a valid value of type size_t."));
}
}
result->return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result->return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(
"invocation_results[", i, "][", j, "]")),
&result->return_values.back()));
}
result->skip = reader->Contains(
full_name(strings::StrCat("invocation_results[", i, "].skip")));
result->notification.Notify();
}
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("num_open"), &num_open_));
TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
return Status::OK();
}
private:
struct InvocationResult {
Notification notification; // used for coordination with the consumer
Status status; // the invocation status
std::vector<Tensor> return_values; // the invocation result values
bool skip; // if set the result should be skipped
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
[this, new_ctx]() { RunnerThread(new_ctx); }));
}
}
// Fetches up to `results.size()` outputs from the cycle element at
// position `cycle_index`.
//
// If end of input is encountered, the `skip` field of the invocation
// result is used to identify results that should be skipped.
void FetchOutputs(
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
result->status = current_elements_[cycle_index]->GetNext(
ctx.get(), &result->return_values, &end_of_input);
}
if (end_of_input) {
result->skip = true;
}
result->notification.Notify();
if (!result->status.ok()) {
break;
}
}
// Release the ownership of the cycle element iterator, closing the
// iterator if end of input was encountered.
if (end_of_input) {
current_elements_[cycle_index].reset();
}
mutex_lock l(mu_);
element_in_use_[cycle_index] = false;
num_calls_--;
if (end_of_input) {
args_list_[cycle_index].clear();
num_open_--;
}
cond_var_.notify_all();
}
// Method responsible for 1) creating iterators out of input elements, 2)
// determining the order in which elements are fetched from the iterators,
// and 3) scheduling the fetching of the elements to a threadpool.
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
return element_in_use_[cycle_index_] ||
num_calls_ >= num_parallel_calls_ ||
invocation_results_.size() >=
dataset()->cycle_length_ * dataset()->block_length_;
};
while (true) {
mutex_lock l(mu_);
// Wait until this thread is cancelled, the end of input has been
// reached, or the cycle element at the `cycle_index_` position is
// not in use and there is space in the `invocation_results_` queue.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
RecordStop(ctx.get());
cond_var_.wait(l);
RecordStart(ctx.get());
}
if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
return;
}
while ((!end_of_input_ || num_open_ > 0) && !busy()) {
if (!current_elements_[cycle_index_]) {
// Try to create a new iterator from the next input element.
Status status = input_impl_->GetNext(
ctx.get(), &args_list_[cycle_index_], &end_of_input_);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
invocation_results_.back();
result->status.Update(status);
result->notification.Notify();
break;
}
if (!end_of_input_) {
Status status = MakeIteratorFromInputElement(
ctx.get(), args_list_[cycle_index_], cycle_index_,
dataset()->captured_func_.get(), prefix(),
&current_elements_[cycle_index_]);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
invocation_results_.back();
result->status.Update(status);
result->notification.Notify();
break;
}
++num_open_;
}
}
if (current_elements_[cycle_index_]) {
// Pre-allocate invocation results for outputs to be fetched
// and then fetch the outputs asynchronously.
std::vector<std::shared_ptr<InvocationResult>> results;
results.reserve(dataset()->block_length_);
for (int i = 0; i < dataset()->block_length_; ++i) {
invocation_results_.emplace_back(new InvocationResult());
results.push_back(invocation_results_.back());
}
num_calls_++;
element_in_use_[cycle_index_] = true;
thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
ctx, cycle_index_,
std::move(results)));
}
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
cond_var_.notify_all();
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
string error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat("invocation_results[", index, "].code"));
}
string ErrorMessageKey(size_t index) {
return full_name(
strings::StrCat("invocation_results[", index, "].error_message"));
}
Status WriteCurrentElements(IteratorStateWriter* writer)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("args_size[", idx, "]")),
args_list_[idx].size()));
for (int i = 0; i < args_list_[idx].size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
args_list_[idx][i]));
}
}
}
return Status::OK();
}
Status ReadCurrentElements(IteratorContext* ctx,
IteratorStateReader* reader)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(
full_name(strings::StrCat("args_size[", idx, "]")))) {
int64 args_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("args_size[", idx, "]")),
&args_size));
args_list_[idx].resize(args_size);
for (int i = 0; i < args_size; i++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
prefix(), &current_elements_[idx]));
TF_RETURN_IF_ERROR(
RestoreInput(ctx, reader, current_elements_[idx]));
} else {
current_elements_[idx].reset();
}
}
return Status::OK();
}
// Used for coordination between the main thread, the runner thread, and
// the worker threads.
mutex mu_;
// Used for coordination between the main thread, the runner thread, and
// the worker threads. In particular, the runner thread should only
// schedule new calls when the number of in-flight calls is less than the
// user specified level of parallelism, there are slots available in the
// `invocation_results_` buffer, the current cycle element is not in use,
// and there are elements left to be fetched.
condition_variable cond_var_;
// Identifies the maximum number of parallel calls.
std::atomic<int64> num_parallel_calls_;
// Iterator for input elements.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
// Identifies current cycle element.
int64 cycle_index_ = 0;
// Arguments for creating an iterator for cycle elements.
std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
// Iterators for the current cycle elements. Concurrent access is
// protected by `element_in_use_`.
std::vector<std::unique_ptr<IteratorBase>> current_elements_;
// Identifies cycle elements that are in use by worker threads.
std::vector<bool> element_in_use_ GUARDED_BY(mu_);
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(mu_);
// Identifies whether end of input has been reached.
bool end_of_input_ GUARDED_BY(mu_) = false;
// Identifies the number of open iterators.
int64 num_open_ GUARDED_BY(mu_) = 0;
// Identifies the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
// Identifies whether background activity should be cancelled.
bool cancelled_ GUARDED_BY(mu_) = false;
};
const DatasetBase* const input_;
const NameAttrList interleave_func_;
const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_;
const int64 block_length_;
const int64 num_parallel_calls_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList interleave_func_;
};
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
ParallelInterleaveDatasetV2Op);
} // namespace
} // namespace data
} // namespace tensorflow