| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" |
| |
| #include <deque> |
| |
| #include "tensorflow/core/common_runtime/metrics.h" |
| #include "tensorflow/core/framework/partial_tensor_shape.h" |
| #include "tensorflow/core/framework/stats_aggregator.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/kernels/data/name_utils.h" |
| #include "tensorflow/core/kernels/data/stats_utils.h" |
| #include "tensorflow/core/lib/core/error_codes.pb.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| |
| namespace tensorflow { |
| namespace data { |
| |
| // See documentation in ../../ops/dataset_ops.cc for a high-level |
| // description of the following op. |
| |
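| // This op is the kernel behind `tf.data.Dataset.prefetch(buffer_size)`: a |
| // background thread fills a bounded buffer from the input iterator while |
| // consumers drain it through `GetNext()`. |
| |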
| /* static */ constexpr const char* const PrefetchDatasetOp::kDatasetType; |
| /* static */ constexpr const char* const PrefetchDatasetOp::kInputDataset; |
| /* static */ constexpr const char* const PrefetchDatasetOp::kBufferSize; |
| /* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes; |
| /* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes; |
| /* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod; |
| /* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune; |
| |
| // Determines the fraction of slack time by which to delay prefetching of data. |
| constexpr double kSleepFactor = 0.2; |
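| |
| // Worked example (illustrative numbers, not measurements): with kSleepFactor |
| // = 0.2, if the running estimate slack_us_ is 1000us and the newly measured |
| // slack is 500us, Consume() updates slack_us_ = 0.2 * 1000 + 500 = 700us, |
| // and the prefetch thread then sleeps for kSleepFactor * slack_us_ = 140us |
| // before fetching the first element of the next "burst". |
| |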
| constexpr char kBuffer[] = "buffer"; |
| constexpr char kStatus[] = "status"; |
| constexpr char kSizeSuffix[] = ".size"; |
| constexpr char kCodeSuffix[] = ".code"; |
| constexpr char kErrorMessageSuffix[] = ".error_message"; |
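| |
| // Together with full_name(), the constants above yield checkpoint keys of |
| // the form (index 3 illustrative; keys are additionally scoped by the |
| // iterator prefix): |
| //   buffer[3].size          - number of tensors in buffer element 3 |
| //   buffer[3][0]            - tensor 0 of buffer element 3 |
| //   status[3].code          - status code of buffer element 3 |
| //   status[3].error_message - written only for non-OK statuses |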
| |
| class PrefetchDatasetOp::Dataset : public DatasetBase { |
| public: |
| Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, |
| int64 slack_period, bool legacy_autotune) |
| : DatasetBase(DatasetContext(ctx)), |
| input_(input), |
| buffer_size_(buffer_size), |
| slack_period_(slack_period), |
| legacy_autotune_(legacy_autotune) { |
| input_->Ref(); |
| } |
| |
| ~Dataset() override { input_->Unref(); } |
| |
| std::unique_ptr<IteratorBase> MakeIteratorInternal( |
| const string& prefix) const override { |
| return absl::make_unique<Iterator>(Iterator::Params{ |
| this, name_utils::IteratorPrefix(kDatasetType, prefix)}); |
| } |
| |
| const DataTypeVector& output_dtypes() const override { |
| return input_->output_dtypes(); |
| } |
| |
| const std::vector<PartialTensorShape>& output_shapes() const override { |
| return input_->output_shapes(); |
| } |
| |
| string DebugString() const override { |
| return name_utils::DatasetDebugString(kDatasetType); |
| } |
| |
| int64 Cardinality() const override { return input_->Cardinality(); } |
| |
| Status CheckExternalState() const override { |
| return input_->CheckExternalState(); |
| } |
| |
| protected: |
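| // Serializes this dataset as a `PrefetchDataset` node whose inputs are the |
| // input dataset and `buffer_size`, with `slack_period` attached as an attr. |
| // (Note that `legacy_autotune` is not serialized here.) |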
| Status AsGraphDefInternal(SerializationContext* ctx, |
| DatasetGraphDefBuilder* b, |
| Node** output) const override { |
| Node* input_graph_node = nullptr; |
| TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); |
| Node* buffer_size = nullptr; |
| TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); |
| AttrValue slack_period_attr; |
| b->BuildAttrValue(slack_period_, &slack_period_attr); |
| TF_RETURN_IF_ERROR(b->AddDataset( |
| this, {input_graph_node, buffer_size}, |
| {std::make_pair(kSlackPeriod, slack_period_attr)}, output)); |
| return Status::OK(); |
| } |
| |
| private: |
| class Iterator : public DatasetIterator<Dataset> { |
| public: |
| explicit Iterator(const Params& params) |
| : DatasetIterator<Dataset>(params), |
| auto_tuner_(params.dataset->buffer_size_) { |
| slack_us_ = 0; |
| } |
| |
| ~Iterator() override { |
| // Signal the prefetch thread to terminate. We will then |
| // join that thread when `this->prefetch_thread_` is destroyed. |
| // |
| // TODO(mrry): Replace this cancellation logic with a |
| // CancellationManager. The syntax would be more heavyweight, |
| // but it would be possible to thread a cancellation manager |
| // through the IteratorContext to upstream, |
| // potentially-blocking iterators, when we add these. |
| { |
| mutex_lock l(mu_); |
| cancelled_ = true; |
| cond_var_.notify_all(); |
| } |
| } |
| |
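| // Emits a TraceMe annotation of the form (values illustrative): |
| //   "<prefix>#buffer_limit=4,slack=1500#" |
| // where the slack component appears only when slack_period_ > 0. |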
| string BuildTraceMeName() override { |
| int64 buffer_limit; |
| { |
| tf_shared_lock l(mu_); |
| buffer_limit = auto_tuner_.buffer_limit(); |
| } |
| string prefetch_with_slack_trace = ""; |
| if (dataset()->slack_period_ > 0) { |
| int64 slack_us = slack_us_; |
| prefetch_with_slack_trace = strings::StrCat(",slack=", slack_us); |
| } |
| return strings::StrCat(prefix(), "#buffer_limit=", buffer_limit, |
| prefetch_with_slack_trace, "#"); |
| } |
| |
| Status Initialize(IteratorContext* ctx) override { |
| return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); |
| } |
| |
| Status GetNextInternal(IteratorContext* ctx, |
| std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) override { |
| const auto& stats_aggregator = ctx->stats_aggregator(); |
| { |
| mutex_lock l(mu_); |
| TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); |
| // Wait until the next element in the buffer has been |
| // produced, or we are shutting down. |
| while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ && |
| auto_tuner_.buffer_limit() != 0) { |
| auto_tuner_.RecordEmpty(); |
| RecordStop(ctx); |
| cond_var_.wait(l); |
| RecordStart(ctx); |
| } |
| |
| if (cancelled_) { |
| return errors::Cancelled( |
| "PrefetchDatasetOp::Dataset::Iterator::GetNext"); |
| } |
| |
| if (!buffer_.empty()) { |
| return Consume(ctx, out_tensors, end_of_sequence); |
| } |
| |
| if (prefetch_thread_finished_) { |
| *end_of_sequence = true; |
| return Status::OK(); |
| } |
| |
| DCHECK_EQ(auto_tuner_.buffer_limit(), 0); |
| } |
| |
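| // The buffer limit is 0 (e.g. buffer_size == 0), so prefetching is |
| // effectively disabled; read the next element synchronously from the |
| // input iterator, holding `parent_mu_` to exclude the prefetch thread. |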
| mutex_lock parent_l(parent_mu_); |
| mutex_lock l(mu_); |
| if (stats_aggregator) { |
| stats_aggregator->AddScalar( |
| stats_utils::BufferSizeScalarName(dataset()->node_name()), |
| static_cast<float>(buffer_.size()), num_elements()); |
| stats_aggregator->AddScalar( |
| stats_utils::BufferCapacityScalarName(dataset()->node_name()), |
| static_cast<float>(auto_tuner_.buffer_limit()), num_elements()); |
| } |
| return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); |
| } |
| |
| protected: |
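| // Models this iterator for autotuning as an asynchronous node that produces |
| // one output element per input element (ratio 1) and exposes no tunable |
| // parameters. |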
| std::shared_ptr<model::Node> CreateNode( |
| IteratorContext* ctx, model::Node::Args args) const override { |
| return model::MakeAsyncKnownRatioNode(std::move(args), |
| /*ratio=*/1, |
| /*parameters=*/{}); |
| } |
| |
| Status SaveInternal(IteratorStateWriter* writer) override { |
| // Acquire both locks to ensure that the prefetch thread and |
| // all GetNext threads are blocked. |
| mutex_lock parent_l(parent_mu_); |
| mutex_lock l(mu_); |
| TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); |
| TF_RETURN_IF_ERROR( |
| writer->WriteScalar(full_name(kBufferSize), buffer_.size())); |
| for (size_t i = 0; i < buffer_.size(); i++) { |
| auto& buffer_element = buffer_[i]; |
| TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status)); |
| if (buffer_element.status.ok()) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)), |
| buffer_element.value.size())); |
| for (size_t j = 0; j < buffer_element.value.size(); j++) { |
| TF_RETURN_IF_ERROR(writer->WriteTensor( |
| full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")), |
| buffer_element.value[j])); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status RestoreInternal(IteratorContext* ctx, |
| IteratorStateReader* reader) override { |
| mutex_lock parent_l(parent_mu_); |
| mutex_lock l(mu_); |
| buffer_.clear(); |
| TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); |
| size_t buffer_size; |
| { |
| int64 temp; |
| TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBufferSize), &temp)); |
| buffer_size = static_cast<size_t>(temp); |
| } |
| for (size_t i = 0; i < buffer_size; i++) { |
| buffer_.emplace_back(); |
| auto& buffer_element = buffer_.back(); |
| TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status)); |
| if (buffer_element.status.ok()) { |
| size_t value_size; |
| { |
| int64 temp; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)), |
| &temp)); |
| value_size = static_cast<size_t>(temp); |
| } |
| buffer_element.value.reserve(value_size); |
| for (size_t j = 0; j < value_size; j++) { |
| buffer_element.value.emplace_back(); |
| TF_RETURN_IF_ERROR(reader->ReadTensor( |
| full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")), |
| &buffer_element.value.back())); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| private: |
| // A buffer element comprises a status and (if that status is |
| // OK) a vector of tensors, representing an element of the input dataset. |
| struct BufferElement { |
| // The producer sets `status` if getting the input element fails. |
| Status status; |
| // The buffered data element. |
| std::vector<Tensor> value; |
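| // Time, in microseconds, at which the element was added to the buffer; |
| // used by the slack mechanism to measure how long it sat before being |
| // consumed. |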
| int64 created_us; |
| }; |
| |
| Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| const auto& stats_aggregator = ctx->stats_aggregator(); |
| if (stats_aggregator) { |
| stats_aggregator->AddToHistogram( |
| stats_utils::BufferUtilizationHistogramName(dataset()->node_name()), |
| {static_cast<float>(buffer_.size()) / |
| static_cast<float>(auto_tuner_.buffer_limit())}, |
| num_elements()); |
| stats_aggregator->AddScalar( |
| stats_utils::BufferSizeScalarName(dataset()->node_name()), |
| static_cast<float>(buffer_.size()), num_elements()); |
| stats_aggregator->AddScalar( |
| stats_utils::BufferCapacityScalarName(dataset()->node_name()), |
| static_cast<float>(auto_tuner_.buffer_limit()), num_elements()); |
| } |
| // A new element is available. Forward the status from computing it, and |
| // (if we successfully got an element) the output values. |
| Status s = buffer_.front().status; |
| if (s.ok()) { |
| if (dataset()->slack_period_ > 0 && |
| (num_elements() + 1) % dataset()->slack_period_ == 0) { |
| // TODO(rachelim): Consider doing something more sophisticated |
| // to decide how long to sleep for; e.g. using a kalman filter. |
| int64 slack_us = |
| Env::Default()->NowMicros() - buffer_.front().created_us; |
| // Every slack_period_-th element, update the most recent slack time, |
| // measured by the duration between when the element is prefetched |
| // and when it is consumed. We add kSleepFactor * slack_us_ to the |
| // measurement because we slept for that duration before prefetching |
| // the element. |
| slack_us_ = kSleepFactor * slack_us_ + slack_us; |
| VLOG(2) << "Setting slack_us_: " << slack_us_; |
| } |
| *out_tensors = std::move(buffer_.front().value); |
| RecordBufferDequeue(ctx, *out_tensors); |
| } |
| auto_tuner_.RecordConsumption(buffer_.size()); |
| buffer_.pop_front(); |
| *end_of_sequence = false; |
| |
| // Wake the prefetch thread, in case it has been waiting for space |
| // in the buffer. Also wake up threads from other calls to GetNext. |
| // |
| // TODO(mrry): Consider using different condition variables for |
| // GetNext and Prefetch. |
| cond_var_.notify_all(); |
| return s; |
| } |
| |
| Status EnsurePrefetchThreadStarted(IteratorContext* ctx) |
| EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| if (!prefetch_thread_) { |
| std::shared_ptr<IteratorContext> new_ctx = |
| std::make_shared<IteratorContext>(*ctx); |
| prefetch_thread_ = ctx->StartThread( |
| "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); }); |
| } |
| return Status::OK(); |
| } |
| |
| // Prefetches elements of the input, storing results in an internal buffer. |
| // |
| // It shares ownership of the iterator context passed to it, keeping the |
| // context alive for the lifetime of the thread. |
| void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) { |
| RecordStart(ctx.get()); |
| auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); |
| // Keep track of where we are in an iteration "burst". |
| int num_produced = 0; |
| while (true) { |
| // 1. Wait for a slot in the buffer. |
| { |
| mutex_lock l(mu_); |
| while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) { |
| RecordStop(ctx.get()); |
| cond_var_.wait(l); |
| RecordStart(ctx.get()); |
| } |
| |
| if (cancelled_) { |
| return; |
| } |
| } |
| |
| if (dataset()->slack_period_ > 0 && |
| num_produced % dataset()->slack_period_ == 0) { |
| // For the first element in the "burst", sleep for a bit if there is |
| // slack. |
| VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor; |
| ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor); |
| } |
| |
| // 2. Read the next element. |
| // Acquire the parent lock since we will be reading an element from the |
| // input iterator. Note that we do not wish to release this lock until |
| // we have added the fetched element to `buffer_`; otherwise the element |
| // would exist only in this thread's local state and could be missed by |
| // SaveInternal. |
| mutex_lock parent_l(parent_mu_); |
| bool end_of_sequence; |
| BufferElement buffer_element; |
| buffer_element.status = input_impl_->GetNext( |
| ctx.get(), &buffer_element.value, &end_of_sequence); |
| if (buffer_element.status.ok() && end_of_sequence) { |
| mutex_lock l(mu_); |
| prefetch_thread_finished_ = true; |
| cond_var_.notify_all(); |
| return; |
| } |
| |
| // 3. Signal that the element has been produced. |
| { |
| mutex_lock l(mu_); |
| RecordBufferEnqueue(ctx.get(), buffer_element.value); |
| buffer_element.created_us = ctx->env()->NowMicros(); |
| buffer_.push_back(std::move(buffer_element)); |
| cond_var_.notify_all(); |
| } |
| ++num_produced; |
| } |
| } |
| |
| Status WriteStatus(IteratorStateWriter* writer, size_t index, |
| const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| CodeKey(index), static_cast<int64>(status.code()))); |
| if (!status.ok()) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), |
| status.error_message())); |
| } |
| return Status::OK(); |
| } |
| |
| Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status) |
| EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| int64 code_int; |
| TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); |
| error::Code code = static_cast<error::Code>(code_int); |
| |
| if (code != error::Code::OK) { |
| string error_message; |
| TF_RETURN_IF_ERROR( |
| reader->ReadScalar(ErrorMessageKey(index), &error_message)); |
| *status = Status(code, error_message); |
| } else { |
| *status = Status::OK(); |
| } |
| return Status::OK(); |
| } |
| |
| string CodeKey(size_t index) { |
| return full_name(strings::StrCat(kStatus, "[", index, "]", kCodeSuffix)); |
| } |
| |
| string ErrorMessageKey(size_t index) { |
| return full_name( |
| strings::StrCat(kStatus, "[", index, "]", kErrorMessageSuffix)); |
| } |
| |
| // This mutex is used to ensure exclusivity between multiple threads |
| // reading/writing this iterator's local state. |
| mutex mu_; |
| // This mutex is used to ensure exclusivity between multiple threads |
| // accessing the parent iterator. We keep this separate from `mu_` to |
| // allow prefetching to run in parallel with GetNext calls. |
| mutex parent_mu_ ACQUIRED_BEFORE(mu_); |
| std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_); |
| condition_variable cond_var_; |
| PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); |
| std::deque<BufferElement> buffer_ GUARDED_BY(mu_); |
| std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); |
| bool cancelled_ GUARDED_BY(mu_) = false; |
| bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; |
| |
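| // Smoothed slack estimate, in microseconds. Atomic because it is written |
| // under `mu_` in Consume() but read without `mu_` by the prefetch thread |
| // and by BuildTraceMeName(). |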
| std::atomic<int64> slack_us_; |
| }; |
| const DatasetBase* const input_; |
| const int64 buffer_size_; |
| |
| // If non-zero, determines the period between injecting "slack" into the |
| // execution. |
| const int64 slack_period_; |
| |
| // Determines whether legacy autotuning should be used. |
| const bool legacy_autotune_; |
| }; |
| |
| PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx) |
| : UnaryDatasetOpKernel(ctx) { |
| if (ctx->HasAttr(kSlackPeriod)) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_)); |
| } |
| if (ctx->HasAttr(kLegacyAutotune)) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_)); |
| } |
| } |
| |
| void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, |
| DatasetBase** output) { |
| int64 buffer_size = 0; |
| OP_REQUIRES_OK(ctx, |
| ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size)); |
| OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == model::kAutotune, |
| errors::InvalidArgument("buffer_size must be >= 0, or ", |
| model::kAutotune, " to enable auto-tuning")); |
| |
| if (buffer_size == model::kAutotune) { |
| metrics::RecordTFDataAutotune(kDatasetType); |
| } |
| |
| *output = |
| new Dataset(ctx, input, buffer_size, slack_period_, legacy_autotune_); |
| } |
| |
| namespace { |
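| // Both kernels implement the same op; the CPU registration carries a higher |
| // priority (2) than the GPU one (1), so placement prefers the CPU kernel |
| // when either would work. The GPU registration keeps its inputs and output |
| // in host memory because dataset variants live on the host. |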
| REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU).Priority(2), |
| PrefetchDatasetOp); |
| REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") |
| .Device(DEVICE_GPU) |
| .HostMemory("buffer_size") |
| .HostMemory("input_dataset") |
| .HostMemory("handle") |
| .Priority(1), |
| PrefetchDatasetOp); |
| } // namespace |
| |
| } // namespace data |
| } // namespace tensorflow |