/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/parallel_map_iterator.h"

#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"

namespace tensorflow {
namespace data {
namespace {
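
// An iterator that applies a user-provided function to the elements of an
// input dataset, running up to `num_parallel_calls` invocations of the
// function concurrently. Completed results are buffered in
// `invocation_results_` and handed to consumers in the order in which the
// calls were scheduled.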
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
const typename DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls)
: DatasetBaseIterator(params),
input_dataset_(input_dataset),
init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
num_parallel_calls_(num_parallel_calls) {}
~ParallelMapIterator() override {
mutex_lock l(mu_);
// Cancel the runner thread.
cancelled_ = true;
cond_var_.notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_.wait(l);
}
}
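
  // Configures the level of parallelism (fixed or autotuned), creates the
  // input iterator, and runs the optional initialization function.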
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
// use it here for the maximum.
AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
1 /* min */, port::NumSchedulableCPUs() /* max */,
&cond_var_);
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
TF_RETURN_IF_ERROR(init_func_(ctx));
}
return Status::OK();
}
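
  // Returns the oldest buffered invocation result, blocking first until one
  // has been scheduled and then until its computation has completed.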
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
RecordStop(ctx);
cond_var_.wait(l);
RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
cond_var_.notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
protected:
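  // Waits for all in-flight calls to complete so that the buffer of
  // invocation results is stable, then writes the input iterator state and
  // the buffered results.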
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_.wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
std::shared_ptr<InvocationResult> result = invocation_results_[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
result->return_values.size()));
for (size_t j = 0; j < result->return_values.size(); j++) {
TF_RETURN_IF_ERROR(
writer->WriteTensor(full_name(strings::StrCat(
"invocation_results[", i, "][", j, "]")),
result->return_values[j]));
}
if (result->end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")),
""));
}
}
return Status::OK();
}
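
  // Restores the input iterator state and the buffered invocation results
  // written by `SaveInternal`, marking each restored result as complete.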
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
    for (int64 i = 0; i < invocation_results_size; i++) {
      std::shared_ptr<InvocationResult> result =
          std::make_shared<InvocationResult>();
invocation_results_.push_back(result);
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(strings::StrCat(
"invocation_results[", i, "].size")),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(
strings::StrCat("invocation_results[", i, "].size")),
": ", size, " is not a valid value of type size_t."));
}
}
result->return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result->return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(
"invocation_results[", i, "][", j, "]")),
&result->return_values.back()));
}
result->end_of_input = reader->Contains(full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")));
result->notification.Notify();
}
return Status::OK();
}
private:
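  // The state of a single (possibly still in-flight) invocation of the map
  // function.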
struct InvocationResult {
Notification notification;
Status status;
std::vector<Tensor> return_values;
    bool end_of_input = false;
};
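
  // Launches the background runner thread on the first call, using a copy of
  // `ctx` that outlives the caller.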
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!runner_thread_) {
      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
}
}
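
  // Records the completion of a call: decrements the in-flight count, wakes
  // waiters, and notifies the consumer of `result`.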
void CallCompleted(const std::shared_ptr<InvocationResult>& result)
LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
num_calls_--;
cond_var_.notify_all();
}
result->notification.Notify();
}
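
  // Fetches the next input element and invokes the map function on it
  // asynchronously, arranging for `CallCompleted` to run when the result is
  // ready.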
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
LOCKS_EXCLUDED(mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
CallCompleted(result);
return;
}
    // Call `map_func_(input_element)`, store the result in
    // `result->return_values`, and notify `result->notification` to unblock
    // a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
map_func_(ctx.get(), std::move(input_element), &result->return_values,
std::move(done));
}
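
  // Converts a completed invocation result into this iterator's output. An
  // `errors::OutOfRange` status from the mapped function is translated into
  // a clean end-of-sequence signal.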
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
if (!result->end_of_input && result->status.ok()) {
*out_tensors = std::move(result->return_values);
*end_of_sequence = false;
return Status::OK();
}
if (errors::IsOutOfRange(result->status)) {
      // The mapped function may deliberately raise `errors::OutOfRange` to
      // indicate that we should terminate the iteration early.
*end_of_sequence = true;
return Status::OK();
}
*end_of_sequence = result->end_of_input;
return result->status;
}
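
  // Main loop of the runner thread: keeps up to `num_parallel_calls_` calls
  // in flight, subject to free slots in the results buffer, until the
  // iterator is destroyed.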
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
new_calls.reserve(num_parallel_calls_);
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_;
      return num_calls_ >= num_parallel_calls ||
             invocation_results_.size() >=
                 static_cast<size_t>(num_parallel_calls);
};
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
cond_var_.wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
          invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
cond_var_.notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
new_calls.clear();
}
}
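
  // Writes `status` for the `index`-th invocation result.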
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
}
return Status::OK();
}
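
  // Reads back the status of the `index`-th invocation result.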
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
string error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat("invocation_results[", index, "].code"));
}
string ErrorMessageKey(size_t index) {
return full_name(
strings::StrCat("invocation_results[", index, "].error_message"));
}
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
// Used for coordination between the main thread and the runner thread.
mutex mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
  // number of in-flight calls is less than the user-specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
condition_variable cond_var_;
// Identifies the maximum number of parallel calls.
std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
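  // Iterator over the input dataset from which elements are fetched.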
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(mu_);
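  // Background thread that schedules invocations of the map function.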
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
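  // Signals the runner thread to exit; set by the destructor.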
bool cancelled_ GUARDED_BY(mu_) = false;
};
} // namespace
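
// Creates a `ParallelMapIterator` with no initialization function. A dataset
// kernel would typically call this from `MakeIteratorInternal`, e.g.
// (sketch):
//
//   return NewParallelMapIterator(
//       {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
//       std::move(map_func), num_parallel_calls_);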
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls) {
return NewParallelMapIterator(params, input_dataset, nullptr,
std::move(map_func), num_parallel_calls);
}
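
// Creates a `ParallelMapIterator` that runs `init_func` once when the
// iterator is initialized, before any elements are produced.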
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
return std::unique_ptr<IteratorBase>(
new ParallelMapIterator(params, input_dataset, std::move(init_func),
std::move(map_func), num_parallel_calls));
}
} // namespace data
} // namespace tensorflow