blob: 677ef65428fd34e2b78d5d9a1c9938e6d8322c62 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_interleave_dataset_op.h"
#include <atomic>
#include <deque>
#include <memory>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
// Out-of-class definitions of the static constexpr members of
// ParallelInterleaveDatasetOp (declared in the class definition, presumably
// in parallel_interleave_dataset_op.h). Before C++17 made static constexpr
// data members implicitly inline, an ODR-used member of this kind needs
// exactly one such namespace-scope definition.
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kInputDataset;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kTarguments;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
// Name of the thread pool created in Initialize() for the worker-manager,
// current-worker and future-worker threads.
constexpr char kTfDataParallelInterleaveWorkerPool[] =
"tf_data_parallel_interleave_worker_pool";
// Name of the tunable parallelism parameter registered in CreateNode().
constexpr char kParallelism[] = "parallelism";
// Checkpoint keys (passed through full_name()) and key suffixes used when
// writing and reading the iterator state in SaveInternal/RestoreInternal and
// their helpers (WriteElement, CodeKey, ErrorMessageKey, ...).
constexpr char kBlockIndex[] = "block_index";
constexpr char kCycleIndex[] = "cycle_index";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kElementIdCounter[] = "element_id_counter";
constexpr char kCurrentElements[] = "current_elements";
constexpr char kCurrentElementsSize[] = "current_elements.size";
constexpr char kFutureElements[] = "future_elements";
constexpr char kFutureElementsSize[] = "future_elements.size";
constexpr char kResultsSuffix[] = ".results";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";
constexpr char kIdSuffix[] = ".id";
constexpr char kSizeSuffix[] = ".size";
constexpr char kInputsSuffix[] = ".inputs";
constexpr char kIsReadySuffix[] = ".is_ready";
// `kCyclePrefetchFactor * cycle_length` is the number of future cycle elements
// that will be prefetched ahead of time. The purpose of prefetching future
// cycle elements is to overlap expensive initialization (e.g. opening of a
// remote file) with other computation.
// NOTE: the `L` suffix makes the literal a long double; it is converted to
// double by this declaration.
constexpr double kCyclePrefetchFactor = 2.0L;
// `kPerIteratorPrefetchFactor * block_length + 1` is the number of per-iterator
// results that will be prefetched ahead of time. The `+ 1` is to match the
// behavior of the original autotune implementation.
constexpr double kPerIteratorPrefetchFactor = 2.0L;
// The motivation for creating an alternative implementation of parallel
// interleave is to decouple the degree of parallelism from the cycle length.
// This makes it possible to change the degree of parallelism (e.g. through
// auto-tuning) without changing the cycle length (which would change the order
// in which elements are produced).
//
// Furthermore, this class favors modularity over extended functionality. In
// particular, neither the buffering of output elements nor the prefetching of
// input iterators is user-configurable: both are sized by the fixed factors
// `kPerIteratorPrefetchFactor` and `kCyclePrefetchFactor` above.
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
public:
// Creates the dataset. `input` is Ref()'d here and Unref()'d in the
// destructor, so the input dataset outlives this one. `num_parallel_calls`
// may be model::kAutotune; the iterator then starts with `cycle_length`
// parallel calls (see ParallelInterleaveIterator::Initialize). `sloppy`
// permits results to be returned out of interleave order.
Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, int64 num_parallel_calls, bool sloppy,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
num_parallel_calls_(num_parallel_calls),
sloppy_(sloppy),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
}
~Dataset() override {
  // Drop the reference on the input dataset taken in the constructor.
  input_->Unref();
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
name_utils::IteratorPrefixParams params;
params.op_version = op_version_;
return absl::make_unique<ParallelInterleaveIterator>(
ParallelInterleaveIterator::Params{
this,
name_utils::IteratorPrefix(
ParallelInterleaveDatasetOp::kDatasetType, prefix, params)},
sloppy_);
}
const DataTypeVector& output_dtypes() const override {
  // The output types supplied at construction time.
  return output_types_;
}
// The output shapes supplied at construction time.
const std::vector<PartialTensorShape>& output_shapes() const override {
  const std::vector<PartialTensorShape>& shapes = output_shapes_;
  return shapes;
}
// Human-readable description: the dataset type decorated with the op version.
string DebugString() const override {
  name_utils::DatasetDebugStringParams debug_params;
  debug_params.op_version = op_version_;
  const auto* dataset_type = ParallelInterleaveDatasetOp::kDatasetType;
  return name_utils::DatasetDebugString(dataset_type, debug_params);
}
// Checks the captured function first and returns its status if it is not OK;
// otherwise returns the status of checking the input dataset.
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:
// Serializes this dataset into a graph node.
// - Inputs 0, 2, 3 and 4 are the input dataset, cycle_length, block_length
//   and num_parallel_calls.
// - Input 1 is the list of arguments captured by `f`.
// - The attrs are `f`, `Targuments` and `sloppy`.
Status AsGraphDefInternal(SerializationContext* ctx,
                          DatasetGraphDefBuilder* b,
                          Node** output) const override {
  Node* input_graph_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
  Node* cycle_length_scalar = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_scalar));
  Node* block_length_scalar = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_scalar));
  Node* parallelism_scalar = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(num_parallel_calls_, &parallelism_scalar));

  std::vector<Node*> captured_args;
  DataTypeVector captured_arg_types;
  TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &captured_args,
                                                &captured_arg_types));

  AttrValue func_attr;
  b->BuildAttrValue(captured_func_->func(), &func_attr);
  AttrValue captured_arg_types_attr;
  b->BuildAttrValue(captured_arg_types, &captured_arg_types_attr);
  AttrValue sloppy_attr;
  b->BuildAttrValue(sloppy_, &sloppy_attr);

  TF_RETURN_IF_ERROR(b->AddDataset(
      this,
      /*inputs=*/
      {{0, input_graph_node},
       {2, cycle_length_scalar},
       {3, block_length_scalar},
       {4, parallelism_scalar}},
      /*list_inputs=*/{{1, captured_args}},
      /*attrs=*/
      {{kFunc, func_attr},
       {kTarguments, captured_arg_types_attr},
       {kSloppy, sloppy_attr}},
      output));
  return Status::OK();
}
private:
class ParallelInterleaveIterator : public DatasetIterator<Dataset> {
public:
// `sloppy` allows GetNext to return results out of interleave order.
// Buffer sizes are derived from the dataset parameters:
//   per_iterator_prefetch_    = int(block_length * kPerIteratorPrefetchFactor) + 1
//   future_elements_prefetch_ = int(cycle_length * kCyclePrefetchFactor)
// `num_parallel_calls_` is shared state tied to `mu_` and
// `num_parallel_calls_cond_var_`, so that autotuning (see CreateNode) can
// change it and wake the worker manager.
ParallelInterleaveIterator(const Params& params, bool sloppy)
: DatasetIterator<Dataset>(params),
per_iterator_prefetch_(
static_cast<int>(params.dataset->block_length_ *
kPerIteratorPrefetchFactor) +
1),
future_elements_prefetch_(static_cast<int>(
params.dataset->cycle_length_ * kCyclePrefetchFactor)),
mu_(std::make_shared<mutex>()),
num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
params.dataset->num_parallel_calls_, mu_,
num_parallel_calls_cond_var_)),
sloppy_(sloppy),
current_elements_(params.dataset->cycle_length_) {}
// Cancels the iterator and waits for every worker thread to exit, then
// wakes any caller still blocked in GetNextInternal or SaveInternal.
~ParallelInterleaveIterator() override {
  mutex_lock lock(*mu_);
  cancelled_ = true;
  StopAllThreads(&lock);
  for (const auto& slot : current_elements_) {
    if (slot != nullptr) {
      slot->cond_var.notify_all();
    }
  }
  sloppy_cond_var_.notify_all();
  zero_active_workers_cond_var_.notify_all();
}
// Trace name encoding the current parallelism and the static configuration.
string BuildTraceMeName() override {
  // NOTE: `num_parallel_calls_->value` is read without holding `mu_` on
  // purpose, to keep tracing overhead low.
  const int64 current_parallelism = num_parallel_calls_->value;
  const bool autotuned = dataset()->num_parallel_calls_ == model::kAutotune;
  const bool deterministic = !sloppy_;
  return strings::StrCat(prefix(), "#parallelism=", current_parallelism,
                         ",cycle_length=", dataset()->cycle_length_,
                         ",block_length=", dataset()->block_length_,
                         ",autotune=", autotuned,
                         ",deterministic=", deterministic, "#");
}
// Creates the worker thread pool and resolves an autotuned parallelism to an
// initial value of `cycle_length`. It also keeps a copy of `ctx` for the
// worker threads, creates the input iterator and instantiates the captured
// function.
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
// Note that if `ctx->thread_pool()` is non-null, then instead of creating
// a dedicated thread pool of size `num_threads`, computation will be
// scheduled into the shared threadpool. The threadpool is guaranteed to
// support `num_threads` concurrent tasks without blocking indefinitely.
//
// Allocate one thread for the worker manager and `cycle_length_` threads for
// the current workers. Future workers get `future_elements_prefetch_ +
// cycle_length_` threads; WorkerManagerThread explains why the extra
// `cycle_length_` future workers are needed.
int max_current_workers = dataset()->cycle_length_;
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
const int num_threads = 1 + max_current_workers + future_workers;
thread_pool_ = ctx->CreateThreadPool(kTfDataParallelInterleaveWorkerPool,
num_threads);
// With autotuning, start at the maximum useful parallelism (cycle_length).
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = dataset()->cycle_length_;
}
last_valid_current_element_ = dataset()->cycle_length_ - 1;
// Worker threads use this copy of the context after Initialize returns.
ctx_ = std::make_unique<IteratorContext>(*ctx);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_);
}
// Returns the next interleaved result.
//
// On the first call it creates the first cycle of elements and starts the
// worker threads. It then blocks until Consume() yields a result or the
// iterator is cancelled. A null result from Consume() means end of sequence.
// A non-OK result status is returned as-is, with `*end_of_sequence` set to
// false.
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<Result> result;
{
mutex_lock l(*mu_);
EnsureInitialElementsCreated();
EnsureThreadsStarted();
// Each wait is bracketed by RecordStop/RecordStart so that blocked time is
// not recorded as this iterator's processing time.
while (!cancelled_ && !Consume(&result)) {
RecordStop(ctx);
if (sloppy_) {
// In sloppy mode an update to any element may unblock us.
sloppy_cond_var_.wait(l);
} else {
// In deterministic mode only the element at `cycle_index_` can supply
// the next result.
VLOG(3) << "Blocked waiting for element "
<< current_elements_[cycle_index_]->id;
current_elements_[cycle_index_]->cond_var.wait(l);
}
RecordStart(ctx);
}
if (cancelled_) {
return errors::Cancelled("Iterator was cancelled");
}
}
// The lock is released here; `result` is owned exclusively by this call.
if (!result) {
*end_of_sequence = true;
return Status::OK();
}
if (result->status.ok()) {
*out_tensors = std::move(result->return_values);
RecordBufferDequeue(ctx, *out_tensors);
}
*end_of_sequence = false;
return result->status;
}
protected:
// Registers this iterator with the performance model as an asynchronous
// interleave node whose parallelism is tunable in [1, cycle_length].
std::shared_ptr<model::Node> CreateNode(
    IteratorContext* ctx, model::Node::Args args) const override {
  auto parallelism_param =
      model::MakeParameter(kParallelism, num_parallel_calls_, /*min=*/1,
                           /*max=*/dataset()->cycle_length_);
  return model::MakeAsyncInterleaveManyNode(std::move(args),
                                            {std::move(parallelism_param)});
}
// Checkpoints the iterator.
//
// Setting `wait_for_checkpoint_` stops workers from picking up new work, and
// this method waits until no worker is active so that the state being written
// is stable. It then initializes every pending element (so its input is
// captured), drops elements that turned out to have no input, and writes the
// state.
//
// Workers that parked while `wait_for_checkpoint_` was set are woken again on
// every exit path, including when one of the writes fails. Otherwise a failed
// save would leave every worker asleep, and a later GetNext waiting on an
// element that still needs processing would block forever.
Status SaveInternal(IteratorStateWriter* writer) override {
  mutex_lock l(*mu_);
  wait_for_checkpoint_ = true;
  // Wait for all in-flight calls to complete.
  while (num_active_workers_ > 0) {
    zero_active_workers_cond_var_.wait(l);
  }
  // Initialize all elements and filter out elements with no input.
  InitializeInputs(element_id_counter_);
  for (auto& element : current_elements_) {
    if (element && element->no_input) {
      element.reset();
    }
  }
  while (!future_elements_.empty() && future_elements_.back()->no_input) {
    future_elements_.pop_back();
  }
  wait_for_checkpoint_ = false;
  DCHECK_EQ(num_active_workers_, 0);
  // Wake workers back up when leaving this function, whether or not the
  // writes below succeed. `wake_workers` is destroyed before `l`, so the
  // notifications are still issued while `mu_` is held.
  auto wake_workers = gtl::MakeCleanup([this]() {
    current_workers_cond_var_.notify_all();
    future_workers_cond_var_.notify_all();
  });
  TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kBlockIndex), block_index_));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kCycleIndex), cycle_index_));
  if (end_of_input_) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kEndOfInput), ""));
  }
  TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kElementIdCounter),
                                         element_id_counter_));
  TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
  TF_RETURN_IF_ERROR(WriteFutureElements(writer));
  return Status::OK();
}
// Restores iterator state written by SaveInternal.
//
// Existing worker threads are stopped first. After reading, every restored
// element is marked initialized, and each non-empty current slot is queued for
// processing starting at `cycle_index_`. `threads_initialized_` is cleared so
// that the next GetNext restarts the workers.
// NOTE(review): `elements_to_process_` and `uninitialized_elements_` are not
// cleared here. This presumably relies on restore being applied to a freshly
// created iterator; confirm.
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
StopAllThreads(&l);
// StopAllThreads leaves `cancelled_` set; clear it so the restored iterator
// can run.
cancelled_ = false;
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(kBlockIndex), &block_index_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(kCycleIndex), &cycle_index_));
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kElementIdCounter),
&element_id_counter_));
// Only the presence of this key is meaningful (see SaveInternal).
if (reader->Contains(full_name(kEndOfInput))) end_of_input_ = true;
TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
TF_RETURN_IF_ERROR(ReadFutureElements(ctx, reader));
// Queue restored current elements in cycle order starting at the saved
// position. The initial cycle counts as created iff any slot is non-empty.
initial_elements_created_ = false;
for (int i = 0; i < current_elements_.size(); ++i) {
int index = (cycle_index_ + i) % current_elements_.size();
auto element = current_elements_[index];
if (element) {
elements_to_process_.push_back(index);
element->initialized = true;
element->cycle_index = index;
initial_elements_created_ = true;
}
}
for (auto element : future_elements_) {
element->initialized = true;
}
// Recompute the highest non-empty slot; ConsumeHelper only scans up to it.
last_valid_current_element_ = current_elements_.size() - 1;
while (last_valid_current_element_ >= 0 &&
!current_elements_[last_valid_current_element_]) {
last_valid_current_element_--;
}
threads_initialized_ = false;
VLOG(2) << "Parallel interleave iterator restored";
return Status::OK();
}
private:
// Represents the result of fetching an element from a dataset.
struct Result {
// Status of the GetNext call that produced this result, or the error
// recorded by AddErrorResult when element initialization failed.
Status status;
// Output tensors. GetNextInternal moves them out only when `status` is OK.
std::vector<Tensor> return_values;
};
// The interleave transformation repeatedly inputs elements, applies the
// user-provided function to transform the input elements to datasets, and
// interleaves the elements of these datasets as its output.
//
// This structure represents an input element and derived state.
struct Element {
// Unique identifier, needed to support checkpointing.
int64 id GUARDED_BY(&ParallelInterleaveIterator::mu_);
// The tensors of the input element; `iterator` is created from them, and
// they are checkpointed alongside it. A null value indicates that the element
// reached end of input, failed to create its iterator, or hasn't been
// initialized yet.
std::unique_ptr<std::vector<Tensor>> inputs
GUARDED_BY(&ParallelInterleaveIterator::mu_);
// Iterator created from `inputs`. A null value indicates that the element
// reached end of input, failed to initialize, or hasn't been initialized yet.
std::unique_ptr<IteratorBase> iterator
GUARDED_BY(&ParallelInterleaveIterator::mu_);
// Buffer for the outputs of `iterator` (and for any initialization error),
// consumed front-first by ConsumeHelper.
// NOTE(review): unlike the other fields, GUARDED_BY precedes the member name
// here; confirm the annotation is applied to `results` as intended.
std::deque<std::shared_ptr<Result>> GUARDED_BY(
&ParallelInterleaveIterator::mu_) results;
// The element's index in the cycle, if it is in the current cycle.
// -1 if the element is not in the current cycle.
int64 cycle_index GUARDED_BY(&ParallelInterleaveIterator::mu_) = -1;
// Whether the element is currently being processed by a worker thread.
// This is used to ensure that only one thread at a time tries to process
// an element.
bool active GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
// Whether initialization of the inputs and iterator has been attempted. It
// is set even when initialization fails or finds no input.
bool initialized GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
// Whether we tried to initialize the element, but the input iterator
// was exhausted so we could produce no inputs.
bool no_input GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
// Notified (when not sloppy) whenever this element's results or state
// change. GetNextInternal waits on it for the element at the current cycle
// position; cancellation also notifies it.
condition_variable cond_var;
};
// Fills every slot of the first interleave cycle with a fresh, uninitialized
// element and queues each filled slot for a current worker. It does nothing
// once `initial_elements_created_` is set; RestoreInternal resets that flag
// according to the restored state.
void EnsureInitialElementsCreated() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (initial_elements_created_) {
    return;
  }
  for (int slot = 0; slot < dataset()->cycle_length_; ++slot) {
    auto& fresh = current_elements_[slot];
    fresh = MakeElement();
    if (!fresh) {
      continue;  // Input already exhausted; leave the slot empty.
    }
    fresh->cycle_index = slot;
    elements_to_process_.push_back(slot);
  }
  initial_elements_created_ = true;
}
// Launches the worker-manager thread the first time it is needed. The
// manager is counted as an outstanding thread, so StopAllThreads waits for it.
void EnsureThreadsStarted() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (threads_initialized_) {
    return;
  }
  IncrementOutstandingThreads();
  thread_pool_->Schedule([this]() { WorkerManagerThread(); });
  threads_initialized_ = true;
}
// Moves to the beginning of the block of the next cycle slot, wrapping
// around after the last slot.
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  block_index_ = 0;
  ++cycle_index_;
  cycle_index_ %= dataset()->cycle_length_;
}
// Steps one output forward within the current block. When the block is
// complete, moves on to the next cycle slot.
void AdvancePosition() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (++block_index_ == dataset()->block_length_) {
    AdvanceToNextInCycle();
  }
}
// Tries to take the next result. Returns true when `*result` has been set,
// either to a real result or to null at end of input. In deterministic mode
// only the current cycle position is tried. In sloppy mode each slot of the
// cycle is tried once, in order, before giving up.
bool Consume(std::shared_ptr<Result>* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (!sloppy_) {
    return ConsumeHelper(result);
  }
  for (int64 remaining = dataset()->cycle_length_; remaining > 0;
       --remaining) {
    if (ConsumeHelper(result)) {
      return true;
    }
    AdvanceToNextInCycle();
  }
  return false;
}
// Consumes a result (if available), returning an indication of whether
// a result is available. If `true` is returned, `result` either
// points to a valid result or is null if end of input has been reached.
//
// Empty slots are skipped: the scan goes up to `last_valid_current_element_`,
// starting at `cycle_index_`. When the element at the current slot has been
// fully consumed, it is replaced by the oldest future element, or else by a
// newly made element, and the scan continues from the next slot.
bool ConsumeHelper(std::shared_ptr<Result>* result)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
while (true) {
// Move `cycle_index_` to the first non-empty slot at or after it. Moving to a
// different slot starts a new block.
for (int64 i = 0; i < (last_valid_current_element_ + 1); ++i) {
int64 index = (cycle_index_ + i) % (last_valid_current_element_ + 1);
if (current_elements_[index]) {
cycle_index_ = index;
if (i > 0) {
block_index_ = 0;
}
break;
}
}
if (!current_elements_[cycle_index_]) {
// Reached end of input.
return true;
}
std::shared_ptr<Element> element = current_elements_[cycle_index_];
if (!element->results.empty()) {
// We found a result.
std::swap(*result, element->results.front());
element->results.pop_front();
// The buffer now has room; queue the element for a current worker unless
// a worker is already processing it.
if (!element->active) {
elements_to_process_.push_back(cycle_index_);
current_workers_cond_var_.notify_one();
}
AdvancePosition();
return true;
}
if (!element->initialized || element->iterator) {
// The element is still producing results, so we wait.
return false;
}
// We've consumed all results from the element. Get a new element from
// future_elements, or create a new element if no future elements are
// available.
if (!future_elements_.empty()) {
std::shared_ptr<Element> future_element =
std::move(future_elements_.front());
future_elements_.pop_front();
// Autotuning is disabled for future elements' iterators in
// InitializeInputs; re-enable it now that the element joins the cycle.
if (future_element->iterator) {
EnableAutotune(ctx_.get(), future_element->iterator.get());
}
future_element->cycle_index = cycle_index_;
current_elements_[cycle_index_] = std::move(future_element);
// A future slot has been freed up.
future_workers_cond_var_.notify_one();
if (!current_elements_[cycle_index_]->active) {
current_workers_cond_var_.notify_one();
}
} else {
current_elements_[cycle_index_] = MakeElement();
if (current_elements_[cycle_index_]) {
current_elements_[cycle_index_]->cycle_index = cycle_index_;
elements_to_process_.push_back(cycle_index_);
// NOTE(review): `element` is the exhausted element that previously
// occupied this slot and already has `cycle_index == cycle_index_`, so
// this assignment appears to be a no-op; confirm whether it was meant
// for the new element.
element->cycle_index = cycle_index_;
current_workers_cond_var_.notify_one();
}
// MakeElement returns null once input is exhausted; shrink the range of
// slots that still hold elements.
while (last_valid_current_element_ >= 0 &&
!current_elements_[last_valid_current_element_]) {
last_valid_current_element_--;
}
}
AdvanceToNextInCycle();
}
}
// Creates a new element with the next id and registers it in
// `uninitialized_elements_`. Returns nullptr once the input is exhausted.
std::shared_ptr<Element> MakeElement() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<Element> created;
  if (!end_of_input_) {
    created = std::make_shared<Element>();
    created->id = element_id_counter_;
    ++element_id_counter_;
    uninitialized_elements_.push_back(created);
  }
  return created;
}
// Thread responsible for launching all worker threads. The thread stays
// around after startup in case autotuning increases num_parallel_calls.
// It was counted as an outstanding thread by EnsureThreadsStarted and
// decrements that count when it exits (on cancellation or end of input).
void WorkerManagerThread() LOCKS_EXCLUDED(mu_) {
int initial_current_workers;
// When elements are moved from `future_elements_` to `current_elements_`,
// the future worker which created the element may continue to process
// the element for some time. That is why we need an additional
// `cycle_length_` future workers to guarantee that whenever
// `future_elements_.size() < future_elements_prefetch_`, there will be a
// future worker available to create a new future element.
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
// Account for all initial workers before scheduling any of them.
{
mutex_lock l(*mu_);
initial_current_workers = num_parallel_calls_->value;
outstanding_threads_ += initial_current_workers + future_workers;
num_current_workers_ += initial_current_workers;
num_active_workers_ += initial_current_workers + future_workers;
num_current_active_workers_ += initial_current_workers;
}
// Start current workers before future workers to improve startup time.
for (int i = 0; i < initial_current_workers; ++i) {
StartCurrentWorkerThread();
}
for (int i = 0; i < future_workers; ++i) {
StartFutureWorkerThread();
}
// Start one more current worker each time `num_parallel_calls_` rises above
// the number of running current workers.
while (true) {
{
mutex_lock l(*mu_);
while (!cancelled_ &&
num_current_workers_ >= num_parallel_calls_->value) {
num_parallel_calls_cond_var_->wait(l);
}
if (cancelled_ || end_of_input_) {
DecrementOutstandingThreads();
return;
}
IncrementOutstandingThreads();
IncrementCurrentWorkers();
IncrementActiveWorkers();
IncrementCurrentActiveWorkers();
StartCurrentWorkerThread();
}
}
}
// Schedules one CurrentWorkerThread on the pool. The caller is responsible
// for the matching counter bookkeeping.
void StartCurrentWorkerThread() {
  auto work = [this]() { CurrentWorkerThread(); };
  thread_pool_->Schedule(std::move(work));
}
// Schedules one FutureWorkerThread on the pool. The caller is responsible for
// the matching counter bookkeeping.
void StartFutureWorkerThread() {
  auto work = [this]() { FutureWorkerThread(); };
  thread_pool_->Schedule(std::move(work));
}
// Current workers are responsible for keeping elements in
// `current_elements_` processed. An element is processed if it is either
// done or its `results` buffer is full (contains `per_iterator_prefetch_`
// results).
//
// Current workers cycle between two phases: (1) finding an element and (2)
// processing it. When a worker is processing an element, it will
// claim the element by setting `element->active`, then continue to produce
// results for the element until enough results have been computed for the
// current cycle and the results buffer is full.
//
// A worker exits when the iterator is cancelled, or when autotuning has
// lowered `num_parallel_calls_` below the number of running current workers.
void CurrentWorkerThread() LOCKS_EXCLUDED(mu_) {
  RecordStart(ctx_.get());
  // Runs with `mu_` held just before the thread exits. It undoes the counter
  // bookkeeping done when the worker was started, and it stops the timing
  // begun by the RecordStart above (WaitWorkerThread restarts it after every
  // wait). Without the RecordStop, each exiting worker would leave an
  // unmatched RecordStart; FutureWorkerThread already pairs them this way.
  auto done = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    RecordStop(ctx_.get());
    DecrementActiveWorkers();
    DecrementCurrentActiveWorkers();
    DecrementOutstandingThreads();
    DecrementCurrentWorkers();
  };
  while (true) {
    std::shared_ptr<Element> element;
    // Phase 1: find an element to process.
    {
      mutex_lock l(*mu_);
      // In case autotune changes num_parallel_calls.
      if (num_current_workers_ > num_parallel_calls_->value) {
        done();
        return;
      }
      // Look for an element that needs processing. Queued slots stay queued
      // while a checkpoint is in progress.
      while (!cancelled_) {
        while (!elements_to_process_.empty() && !wait_for_checkpoint_) {
          int index = elements_to_process_.front();
          elements_to_process_.pop_front();
          auto& candidate = current_elements_[index];
          if (NeedsProcessing(candidate) && !candidate->active) {
            element = candidate;
            break;
          }
        }
        if (element) {
          break;
        }
        DecrementCurrentActiveWorkers();
        WaitWorkerThread(&current_workers_cond_var_, &l);
        IncrementCurrentActiveWorkers();
      }
      if (cancelled_) {
        done();
        return;
      }
      VLOG(3) << "Current worker woke up to process " << element->id;
      element->active = true;
    }
    // Phase 2: loop on the element until we fill its results buffer or reach
    // end of input for the element.
    while (true) {
      ProcessElement(element);
      mutex_lock l(*mu_);
      if (!NeedsProcessing(element)) {
        element->active = false;
        break;
      }
    }
  }
}
// Future workers process elements after the current interleave cycle. A
// future worker's job is to keep `future_elements_` filled with elements.
// Each iteration creates a new element and appends it to `future_elements_`.
// It then produces up to `per_iterator_prefetch_` results for the element via
// ProcessElement(), holding the element `active` while it does so.
void FutureWorkerThread() LOCKS_EXCLUDED(mu_) {
RecordStart(ctx_.get());
// Called with `mu_` held just before the thread exits.
auto done = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
RecordStop(ctx_.get());
DecrementActiveWorkers();
DecrementOutstandingThreads();
};
std::shared_ptr<Element> element;
while (true) {
{
mutex_lock l(*mu_);
// Release the element processed in the previous iteration, if any.
if (element) {
element->active = false;
if (element->cycle_index != -1) {
// The element was moved into the current cycle while we processed it.
element->cond_var.notify_one();
// A current worker may need to process the element further.
elements_to_process_.push_back(element->cycle_index);
current_workers_cond_var_.notify_one();
}
}
// Park while the future buffer is full or a checkpoint is in progress.
while (!cancelled_ &&
(future_elements_.size() >= future_elements_prefetch_ ||
wait_for_checkpoint_)) {
WaitWorkerThread(&future_workers_cond_var_, &l);
}
if (cancelled_) {
done();
return;
}
// A null element means the input is exhausted; this worker is done.
element = MakeElement();
if (!element) {
done();
return;
}
VLOG(3) << "Future worker created element " << element->id;
element->active = true;
future_elements_.push_back(element);
}
ProcessElement(element);
}
}
// Generates results for the given element until its results buffer holds at
// least `per_iterator_prefetch_` results or its iterator reaches end of input.
// The element's inputs and iterator are initialized first if necessary.
//
// The caller must have marked the element `active`. That gives this thread
// exclusive use of `element->iterator` while `mu_` is released around the
// (potentially slow) GetNext calls.
void ProcessElement(std::shared_ptr<Element> element) LOCKS_EXCLUDED(mu_) {
  DCHECK(element != nullptr);
  IteratorBase* iterator;
  // Initialize the inputs and iterator if necessary.
  {
    mutex_lock l(*mu_);
    DCHECK(element->active);
    if (!element->iterator) {
      InitializeInputs(element->id);
      if (!element->iterator) {
        // The element has no iterator: it already reached end of input, had
        // no input, or failed to initialize (an error result was recorded).
        return;
      }
    }
    // `iterator` will remain valid after releasing the lock because we have
    // marked the element as active, so no other thread will modify its
    // iterator.
    iterator = element->iterator.get();
  }
  DCHECK(iterator != nullptr);
  // Process until the results queue is full or we reach end of input.
  while (true) {
    auto result = std::make_shared<Result>();
    bool end_of_input = false;
    result->status = iterator->GetNext(ctx_.get(), &result->return_values,
                                       &end_of_input);
    if (end_of_input) {
      mutex_lock l(*mu_);
      element->iterator.reset();
      element->inputs.reset();
      NotifyElementUpdate(element);
      break;
    }
    RecordBufferEnqueue(ctx_.get(), result->return_values);
    mutex_lock l(*mu_);
    element->results.push_back(std::move(result));
    NotifyElementUpdate(element);
    // `>=` (the exact complement of the `<` test in NeedsProcessing) rather
    // than `==`, so the loop still stops if the buffer already held more than
    // `per_iterator_prefetch_` results. With `==` it would grow the buffer
    // without bound in that case.
    if (element->results.size() >=
        static_cast<size_t>(per_iterator_prefetch_)) {
      break;
    }
  }
}
// Initialize inputs and create an iterator for all elements up to
// element_id.
void InitializeInputs(int element_id) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
while (!uninitialized_elements_.empty() &&
uninitialized_elements_.front()->id <= element_id) {
std::shared_ptr<Element> element = uninitialized_elements_.front();
uninitialized_elements_.pop_front();
element->initialized = true;
// Check if we've already reached end of input.
if (end_of_input_) {
element->no_input = true;
NotifyElementUpdate(element);
continue;
}
std::vector<Tensor> inputs;
Status status =
input_impl_->GetNext(ctx_.get(), &inputs, &end_of_input_);
if (!status.ok()) {
AddErrorResult(element, status);
continue;
}
if (end_of_input_) {
element->no_input = true;
NotifyElementUpdate(element);
continue;
}
element->inputs =
absl::make_unique<std::vector<Tensor>>(std::move(inputs));
status = MakeIteratorFromInputElement(
ctx_.get(), *element->inputs, element->id,
*instantiated_captured_func_, prefix(), &element->iterator);
if (!status.ok()) {
element->inputs.reset();
element->iterator.reset();
AddErrorResult(element, status);
continue;
}
if (element->cycle_index == -1) {
DisableAutotune(ctx_.get(), element->iterator.get());
}
}
}
// Appends a result carrying `status` to `element`'s buffer and notifies any
// waiter for that element.
void AddErrorResult(std::shared_ptr<Element> element, Status status)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  auto error_result = std::make_shared<Result>();
  error_result->status = status;
  element->results.emplace_back(std::move(error_result));
  NotifyElementUpdate(element);
}
// Sets `cancelled_` and wakes every condition variable a worker or the
// manager may be parked on. It then waits, releasing `*l` while blocked,
// until the outstanding-thread count drops to zero.
void StopAllThreads(mutex_lock* l) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  cancelled_ = true;
  num_parallel_calls_cond_var_->notify_all();
  for (const auto& slot : current_elements_) {
    if (slot) {
      slot->cond_var.notify_all();
    }
  }
  current_workers_cond_var_.notify_all();
  future_workers_cond_var_.notify_all();
  while (outstanding_threads_ > 0) {
    outstanding_threads_finished_cond_var_.wait(*l);
  }
}
// Waits on the given cond_var in a worker thread.
// While parked, the worker is not counted as active, which is what lets
// SaveInternal proceed once all workers are parked. Model timing is also
// stopped for the duration of the wait.
void WaitWorkerThread(condition_variable* cond_var, mutex_lock* l)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
DecrementActiveWorkers();
RecordStop(ctx_.get());
cond_var->wait(*l);
RecordStart(ctx_.get());
IncrementActiveWorkers();
}
// Wakes one GetNextInternal caller that may be waiting for `element`. In
// deterministic mode that caller waits on the element's own condition
// variable; in sloppy mode all callers share `sloppy_cond_var_`.
void NotifyElementUpdate(std::shared_ptr<Element> element)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (!sloppy_) {
    element->cond_var.notify_one();
    return;
  }
  sloppy_cond_var_.notify_one();
}
// True if a worker should work on `element`: it exists and is either still
// uninitialized, or has a live iterator and room in its results buffer.
bool NeedsProcessing(const std::shared_ptr<Element>& element)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (element == nullptr) {
    return false;
  }
  const bool has_room = element->results.size() < per_iterator_prefetch_;
  return !element->initialized || (element->iterator != nullptr && has_room);
}
// Counts one more running current-worker thread.
inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  ++num_current_workers_;
}
// Counts one fewer running current-worker thread.
inline void DecrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  --num_current_workers_;
}
// Counts one more worker that is not parked in WaitWorkerThread.
inline void IncrementActiveWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  ++num_active_workers_;
}
// Counts one fewer active worker. When none remain, wakes SaveInternal, which
// waits for zero active workers.
inline void DecrementActiveWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (--num_active_workers_ == 0) {
    zero_active_workers_cond_var_.notify_one();
  }
}
// Counts one more busy current worker and refreshes the utilization stat.
inline void IncrementCurrentActiveWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  ++num_current_active_workers_;
  UpdateThreadUtilizationStats();
}
// Counts one fewer busy current worker and refreshes the utilization stat.
inline void DecrementCurrentActiveWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  --num_current_active_workers_;
  UpdateThreadUtilizationStats();
}
// Counts one more scheduled thread that StopAllThreads must wait for.
inline void IncrementOutstandingThreads() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  ++outstanding_threads_;
}
// Counts one fewer outstanding thread. When the last one exits, wakes
// StopAllThreads.
inline void DecrementOutstandingThreads() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (--outstanding_threads_ == 0) {
    outstanding_threads_finished_cond_var_.notify_one();
  }
}
// If a stats aggregator is attached, records the fraction of allowed parallel
// calls currently busy (busy current workers / num_parallel_calls).
inline void UpdateThreadUtilizationStats() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const auto& aggregator = ctx_->stats_aggregator();
  if (!aggregator) {
    return;
  }
  const float utilization =
      static_cast<float>(num_current_active_workers_) /
      static_cast<float>(num_parallel_calls_->value);
  aggregator->AddScalar(
      stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
      utilization, num_elements());
}
// Writes `status` for result `idx` under `key_prefix`. The error code is
// always written; the error message only for non-OK statuses.
Status WriteStatusLocked(IteratorStateWriter* writer,
                         const string& key_prefix, size_t idx,
                         const Status& status)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const int64 code_value = static_cast<int64>(status.code());
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(CodeKey(key_prefix, idx), code_value));
  if (status.ok()) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(key_prefix, idx),
                                         status.error_message()));
  return Status::OK();
}
// Reads back a status written by WriteStatusLocked into `*status`. The
// message key is read only when the stored code is not OK.
Status ReadStatusLocked(IteratorStateReader* reader,
                        const string& key_prefix, size_t idx,
                        Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64 code_int;
  TF_RETURN_IF_ERROR(
      reader->ReadScalar(CodeKey(key_prefix, idx), &code_int));
  const auto code = static_cast<error::Code>(code_int);
  if (code == error::Code::OK) {
    *status = Status::OK();
    return Status::OK();
  }
  tstring error_message;
  TF_RETURN_IF_ERROR(reader->ReadScalar(ErrorMessageKey(key_prefix, idx),
                                        &error_message));
  *status = Status(code, error_message);
  return Status::OK();
}
// Checkpoint key of the status code of result `idx`:
// full_name("<key_prefix>.results[<idx>].code").
string CodeKey(const string& key_prefix, size_t idx) {
  const string relative_key = strings::StrCat(
      key_prefix, kResultsSuffix, "[", idx, "]", kCodeSuffix);
  return full_name(relative_key);
}
// Checkpoint key of the error message of result `idx`:
// full_name("<key_prefix>.results[<idx>].error_message").
string ErrorMessageKey(const string& key_prefix, size_t idx) {
  const string relative_key = strings::StrCat(
      key_prefix, kResultsSuffix, "[", idx, "]", kErrorMessageSuffix);
  return full_name(relative_key);
}
// Writes one element under "<key_prefix>[<idx>]".
//
// If the element still has an iterator, the iterator state, the element id
// and the input tensors are written. The buffered results are always written:
// for each result, its status, its tensors and an `.is_ready` marker.
Status WriteElement(std::shared_ptr<Element> element, int idx,
                    const string& key_prefix, IteratorStateWriter* writer)
    EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
  const string element_prefix = strings::StrCat(key_prefix, "[", idx, "]");
  if (element->iterator) {
    TF_RETURN_IF_ERROR(SaveInput(writer, element->iterator));
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(element_prefix, kIdSuffix)), element->id));
    const std::vector<Tensor>& input_tensors = *element->inputs;
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(element_prefix, kInputsSuffix, kSizeSuffix)),
        input_tensors.size()));
    for (size_t t = 0; t < input_tensors.size(); ++t) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          full_name(strings::StrCat(element_prefix, kInputsSuffix, "[", t,
                                    "]")),
          input_tensors[t]));
    }
  }
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      full_name(strings::StrCat(element_prefix, kResultsSuffix, kSizeSuffix)),
      element->results.size()));
  for (size_t r = 0; r < element->results.size(); ++r) {
    const std::shared_ptr<Result>& buffered = element->results[r];
    const string result_prefix =
        strings::StrCat(element_prefix, kResultsSuffix, "[", r, "]");
    TF_RETURN_IF_ERROR(
        WriteStatusLocked(writer, element_prefix, r, buffered->status));
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(result_prefix, kSizeSuffix)),
        buffered->return_values.size()));
    for (size_t v = 0; v < buffered->return_values.size(); ++v) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          full_name(strings::StrCat(result_prefix, "[", v, "]")),
          buffered->return_values[v]));
    }
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(result_prefix, kIsReadySuffix)), ""));
  }
  return Status::OK();
}
// Saves the size of `current_elements_`, then every non-null slot under
// kCurrentElements. Null slots are skipped; ReadElement treats their absence
// as "no element".
Status WriteCurrentElements(IteratorStateWriter* writer)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const size_t slot_count = current_elements_.size();
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kCurrentElementsSize), slot_count));
  for (int idx = 0; idx < slot_count; ++idx) {
    const std::shared_ptr<Element>& slot = current_elements_[idx];
    if (!slot) {
      continue;
    }
    TF_RETURN_IF_ERROR(WriteElement(slot, idx, kCurrentElements, writer));
  }
  return Status::OK();
}
// Saves the size of `future_elements_`, then every non-null entry under
// kFutureElements. Null entries are skipped; ReadElement treats their absence
// as "no element".
Status WriteFutureElements(IteratorStateWriter* writer)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const size_t entry_count = future_elements_.size();
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kFutureElementsSize), entry_count));
  for (int idx = 0; idx < entry_count; ++idx) {
    const std::shared_ptr<Element>& entry = future_elements_[idx];
    if (!entry) {
      continue;
    }
    TF_RETURN_IF_ERROR(WriteElement(entry, idx, kFutureElements, writer));
  }
  return Status::OK();
}
// Restores the element stored under "<key_prefix>[<idx>]" (the layout written
// by WriteElement) into `*out`.
//
// - If no results-size key exists, the slot was null when checkpointed, and
//   `*out` is reset to null.
// - If results exist but no inputs were saved, the element is restored
//   without an iterator.
// - Otherwise the element's inputs and id are restored, and its iterator is
//   rebuilt with the instantiated captured function and then restored.
//
// Returns DataLoss if a checkpointed size is negative, and propagates any
// reader or iterator-restoration error.
Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader,
                   int idx, const string& key_prefix,
                   std::shared_ptr<Element>* out)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const string element_key = strings::StrCat(key_prefix, "[", idx, "]");
  const string results_key = strings::StrCat(element_key, kResultsSuffix);
  const string results_size_key =
      full_name(strings::StrCat(results_key, kSizeSuffix));
  if (!reader->Contains(results_size_key)) {
    // WriteElement always records the results size, so a missing key means
    // this slot held no element at save time. Restore it as empty rather than
    // keeping whatever `*out` referred to before the restore.
    out->reset();
    return Status::OK();
  }
  auto element = std::make_shared<Element>();
  int64 results_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(results_size_key, &results_size));
  // Sizes come from the checkpoint. A negative int64 would turn into a huge
  // size_t in resize()/reserve(), so reject it up front.
  if (results_size < 0) {
    return errors::DataLoss("Invalid number of results (", results_size,
                            ") for checkpointed element ", element_key);
  }
  element->results.resize(results_size);
  for (int64 i = 0; i < results_size; ++i) {
    auto result = std::make_shared<Result>();
    TF_RETURN_IF_ERROR(ReadStatusLocked(reader, element_key,
                                        static_cast<size_t>(i),
                                        &result->status));
    const string result_key = strings::StrCat(results_key, "[", i, "]");
    int64 num_return_values;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        full_name(strings::StrCat(result_key, kSizeSuffix)),
        &num_return_values));
    if (num_return_values < 0) {
      return errors::DataLoss("Invalid number of return values (",
                              num_return_values, ") for result ", i,
                              " of checkpointed element ", element_key);
    }
    result->return_values.reserve(num_return_values);
    for (int64 j = 0; j < num_return_values; ++j) {
      result->return_values.emplace_back();
      TF_RETURN_IF_ERROR(reader->ReadTensor(
          full_name(strings::StrCat(result_key, "[", j, "]")),
          &result->return_values.back()));
    }
    element->results[i] = std::move(result);
  }
  const string inputs_size_key =
      full_name(strings::StrCat(element_key, kInputsSuffix, kSizeSuffix));
  if (!reader->Contains(inputs_size_key)) {
    // WriteElement saves inputs only for elements that owned an iterator.
    element->iterator.reset();
    *out = std::move(element);
    return Status::OK();
  }
  int64 inputs_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(inputs_size_key, &inputs_size));
  if (inputs_size < 0) {
    return errors::DataLoss("Invalid number of inputs (", inputs_size,
                            ") for checkpointed element ", element_key);
  }
  element->inputs = std::make_unique<std::vector<Tensor>>(inputs_size);
  for (int64 i = 0; i < inputs_size; ++i) {
    TF_RETURN_IF_ERROR(reader->ReadTensor(
        full_name(strings::StrCat(element_key, kInputsSuffix, "[", i, "]")),
        &element->inputs->at(i)));
  }
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      full_name(strings::StrCat(element_key, kIdSuffix)), &element->id));
  TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
      ctx, *element->inputs, element->id,
      *instantiated_captured_func_.get(), prefix(), &element->iterator));
  TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, element->iterator));
  *out = std::move(element);
  return Status::OK();
}
// Restores `current_elements_` slot-for-slot from the checkpoint.
//
// The vector is not resized. The checkpointed slot count must equal
// `current_elements_.size()`, otherwise FailedPrecondition is returned.
// A mismatch can occur when a checkpoint is restored into an iterator with a
// different cycle length. For example, cycle_length=AUTOTUNE resolves to the
// local schedulable CPU count in MakeDataset, which may differ across
// machines. Previously this was only a DCHECK, so release builds silently
// restored a mismatched cycle.
Status ReadCurrentElements(IteratorContext* ctx,
                           IteratorStateReader* reader)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64 size;
  TF_RETURN_IF_ERROR(
      reader->ReadScalar(full_name(kCurrentElementsSize), &size));
  if (size != static_cast<int64>(current_elements_.size())) {
    return errors::FailedPrecondition(
        "The iterator cycle length ", current_elements_.size(),
        " is different from the cycle length to restore from the "
        "checkpoint: ",
        size);
  }
  for (int idx = 0; idx < current_elements_.size(); idx++) {
    TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, kCurrentElements,
                                   &current_elements_[idx]));
  }
  return Status::OK();
}
// Restores `future_elements_` from the checkpoint.
// The deque is resized to the checkpointed count, then each entry is read
// with ReadElement.
// Returns DataLoss if the checkpointed count is negative. A negative count
// would otherwise become an enormous size_t in resize() and abort the
// process.
Status ReadFutureElements(IteratorContext* ctx, IteratorStateReader* reader)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64 size;
  TF_RETURN_IF_ERROR(
      reader->ReadScalar(full_name(kFutureElementsSize), &size));
  if (size < 0) {
    return errors::DataLoss("Invalid number of future elements (", size,
                            ") in checkpoint");
  }
  future_elements_.resize(size);
  for (int idx = 0; idx < future_elements_.size(); idx++) {
    TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, kFutureElements,
                                   &future_elements_[idx]));
  }
  return Status::OK();
}
// Indices of `current_elements_` which need to be processed by a current
// worker.
// NOTE(review): unlike most neighbouring fields this is not annotated
// GUARDED_BY(mu_); confirm whether it is only accessed under `mu_`.
std::deque<int> elements_to_process_;
// The last index in `current_elements_` containing a non-null element.
// This allows us to optimize the situation when the cycle_length is large
// but the input dataset doesn't have many elements. By tracking the index
// of the last valid element, GetNext can avoid checking many null entries
// each time through the cycle.
// TODO(aaudibert): Generalize this optimization by removing null elements
// from `current_elements_`, e.g. by compacting the vector when x% of
// its elements are null.
// NOTE(review): no in-class initializer; presumably set in the constructor
// or on restore -- confirm.
int64 last_valid_current_element_ GUARDED_BY(mu_);
// Prefetch depths; presumably the number of results buffered ahead for each
// current element and for each future element respectively -- TODO confirm
// against the constructor.
const int per_iterator_prefetch_;
const int future_elements_prefetch_;
// Identifies whether the current_elements_ vector has been initialized.
bool initial_elements_created_ GUARDED_BY(mu_) = false;
// Identifies whether the element threads have been initialized.
bool threads_initialized_ GUARDED_BY(mu_) = false;
// Used for coordination between the main thread, the manager threads, and
// the worker threads.
const std::shared_ptr<mutex> mu_;
// Condition variable for waking up current workers.
condition_variable current_workers_cond_var_;
// Condition variable for waking up future workers.
condition_variable future_workers_cond_var_;
// Number of active worker threads which might be processing elements,
// including both current workers and future workers. Used by
// checkpointing to wait for outstanding work to finish.
int num_active_workers_ GUARDED_BY(mu_) = 0;
// Number of active current worker threads. Used as the numerator of the
// thread-utilization statistic in UpdateThreadUtilizationStats().
int num_current_active_workers_ GUARDED_BY(mu_) = 0;
// Condition variable notified whenever the total number of active workers
// drops to zero. Used for checkpointing.
condition_variable zero_active_workers_cond_var_;
// Condition notified whenever num_parallel_calls_ changes. Shared so that
// autotuning can notify us when num_parallel_calls_ changes.
std::shared_ptr<condition_variable> num_parallel_calls_cond_var_;
// Identifies the maximum number of parallel calls. Its `value` is the
// denominator of the thread-utilization statistic.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
// The number of current workers currently alive or scheduled to be started.
// This includes current workers which are blocked waiting for work.
int num_current_workers_ GUARDED_BY(mu_) = 0;
// Condition variable to signal that a result has been produced by some
// element thread. Only used when `sloppy_` is true.
condition_variable sloppy_cond_var_;
// Determines whether outputs can be produced in non-deterministic order.
const bool sloppy_;
// Iterator for input elements.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
// Identifies position in the interleave cycle.
int64 block_index_ GUARDED_BY(mu_) = 0;
int64 cycle_index_ GUARDED_BY(mu_) = 0;
// Elements of the current interleave cycle. ReadCurrentElements restores
// this vector slot-for-slot without resizing it.
std::vector<std::shared_ptr<Element>> current_elements_ GUARDED_BY(mu_);
// Elements which still need their inputs and iterators to be initialized.
// Elements at the front need to be initialized first.
std::deque<std::shared_ptr<Element>> uninitialized_elements_
GUARDED_BY(mu_);
// Elements to be used in the interleave cycle in the future. The element
// at the front is the next element to add to the interleave cycle when a
// current element is exhausted. ReadFutureElements resizes it to the
// checkpointed count.
std::deque<std::shared_ptr<Element>> future_elements_ GUARDED_BY(mu_);
// Identifies whether the global end of input has been reached.
bool end_of_input_ GUARDED_BY(mu_) = false;
// The number of outstanding element threads, maintained by
// IncrementOutstandingThreads() / DecrementOutstandingThreads().
int outstanding_threads_ GUARDED_BY(mu_) = 0;
// Condition variable notified when outstanding_threads_ drops to 0.
// DecrementOutstandingThreads() uses notify_one(), so only a single waiter
// is woken.
condition_variable outstanding_threads_finished_cond_var_;
// Thread pool; presumably runs the manager and worker threads -- confirm.
std::unique_ptr<thread::ThreadPool> thread_pool_;
// Presumably the source of `Element::id` values for newly created elements
// (ReadElement restores `id` from the checkpoint) -- confirm.
int64 element_id_counter_ GUARDED_BY(mu_) = 0;
// Iterator context used in worker threads. UpdateThreadUtilizationStats()
// reads its stats aggregator.
std::unique_ptr<IteratorContext> ctx_;
// Set to true during checkpointing to alert element threads that they
// should pause operation. This is needed to prevent constantly-active
// worker threads from blocking checkpointing indefinitely.
// NOTE(review): not annotated GUARDED_BY(mu_); confirm all accesses hold
// `mu_`.
bool wait_for_checkpoint_ = false;
// Identifies whether background threads should be cancelled.
bool cancelled_ GUARDED_BY(mu_) = false;
// Instantiated form of the dataset's captured function. ReadElement passes
// it to MakeIteratorFromInputElement to rebuild element iterators on
// restore.
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
};
// Dataset configuration. These fields are presumably copied from the Dataset
// constructor arguments that MakeDataset passes in -- confirm against the
// constructor.
// The input dataset whose elements are fed to the captured function.
const DatasetBase* const input_;
// Function mapping each input element to a dataset to interleave.
const std::unique_ptr<CapturedFunction> captured_func_;
// MakeDataset rejects values <= 0 and resolves AUTOTUNE to the number of
// schedulable CPUs before constructing the Dataset.
const int64 cycle_length_;
// MakeDataset rejects values <= 0.
const int64 block_length_;
// MakeDataset allows either a positive value <= cycle_length or
// model::kAutotune, and passes it through unchanged.
const int64 num_parallel_calls_;
// Presumably corresponds to the "ParallelInterleaveDatasetV2" op registered
// for this kernel.
const int op_version_ = 2;
// Whether outputs may be produced in non-deterministic order.
const bool sloppy_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
// Builds the kernel from its static attributes.
// It creates metadata for the `kFunc` function attribute with
// multi-device instantiation requested, then reads the output types,
// output shapes and sloppy flag. Any failure is reported through
// OP_REQUIRES_OK.
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  FunctionMetadata::Params metadata_params;
  metadata_params.is_multi_device_function = true;
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, metadata_params,
                                               &func_metadata_));
  // Attributes describing the produced elements and ordering mode.
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy_));
}
// Validates the scalar arguments and creates the parallel interleave Dataset.
//
// Validation rules:
// - `cycle_length`: model::kAutotune is replaced by the number of
//   schedulable CPUs; the result must be > 0.
// - `block_length`: must be > 0.
// - `num_parallel_calls`: must be > 0 or model::kAutotune, and must not
//   exceed cycle_length (kAutotune is negative, so it always passes).
//
// Any violation is reported as InvalidArgument through OP_REQUIRES. When
// num_parallel_calls is autotuned, the autotune usage metric is recorded.
void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase* input,
                                              DatasetBase** output) {
  int64 cycle_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
  if (cycle_length == model::kAutotune) {
    cycle_length = port::NumSchedulableCPUs();
  }
  OP_REQUIRES(ctx, cycle_length > 0,
              errors::InvalidArgument("`cycle_length` must be > 0"));
  int64 block_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));
  int64 num_parallel_calls = 0;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
  OP_REQUIRES(
      ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
      errors::InvalidArgument("num_parallel_calls must be greater than zero."));
  OP_REQUIRES(
      ctx, num_parallel_calls <= cycle_length,
      errors::InvalidArgument(
          "num_parallel_calls must be less than or equal to cycle_length."));
  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));
  if (num_parallel_calls == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }
  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
                        block_length, num_parallel_calls, sloppy_,
                        output_types_, output_shapes_);
}
namespace {
// Registers ParallelInterleaveDatasetOp as the CPU kernel for the
// "ParallelInterleaveDatasetV2" op.
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
// Adds the op to the input-colocation exemption registry; presumably this
// lets its inputs be placed independently of the op -- confirm against the
// registry's documentation.
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV2");
} // namespace
} // namespace data
} // namespace tensorflow