| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/kernels/data/parallel_interleave_dataset_op.h" |
| |
| #include <atomic> |
| #include <deque> |
| #include <memory> |
| #include <utility> |
| |
| #include "absl/strings/str_format.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" |
| #include "tensorflow/core/data/captured_function.h" |
| #include "tensorflow/core/data/dataset_utils.h" |
| #include "tensorflow/core/data/name_utils.h" |
| #include "tensorflow/core/data/stats_utils.h" |
| #include "tensorflow/core/framework/dataset.h" |
| #include "tensorflow/core/framework/metrics.h" |
| #include "tensorflow/core/framework/model.h" |
| #include "tensorflow/core/framework/partial_tensor_shape.h" |
| #include "tensorflow/core/framework/stats_aggregator.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/blocking_counter.h" |
| #include "tensorflow/core/platform/cpu_info.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/stringprintf.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/profiler/lib/traceme_encode.h" |
| |
| namespace tensorflow { |
| namespace data { |
| |
| // See documentation in ../../ops/dataset_ops.cc for a high-level |
| // description of the following op. |
| |
// Out-of-line definitions for the `static constexpr` string constants
// declared in parallel_interleave_dataset_op.h. Prior to C++17, odr-used
// static constexpr data members require a namespace-scope definition such as
// these to avoid link errors.
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kInputDataset;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kTarguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
| |
| namespace { |
| |
// Name of the autotunable parallelism parameter exposed to the tf.data model.
constexpr char kParallelism[] = "parallelism";

// Keys used to serialize the iterator's top-level state in checkpoints.
constexpr char kBlockIndex[] = "block_index";
constexpr char kCycleIndex[] = "cycle_index";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kElementIdCounter[] = "element_id_counter";
constexpr char kCurrentElements[] = "current_elements";
constexpr char kCurrentElementsSize[] = "current_elements.size";
constexpr char kFutureElements[] = "future_elements";
constexpr char kFutureElementsSize[] = "future_elements.size";
// Suffixes appended to per-element checkpoint key prefixes.
constexpr char kResultsSuffix[] = ".results";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";
constexpr char kIdSuffix[] = ".id";
constexpr char kSizeSuffix[] = ".size";
constexpr char kInputsSuffix[] = ".inputs";
constexpr char kIsReadySuffix[] = ".is_ready";

// Registered op names corresponding to the different op versions.
constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";

// `kCyclePrefetchFactor * cycle_length` is the default number of future cycle
// elements that will be prefetched ahead of time. The purpose of prefetching
// future cycle elements is to overlap expensive initialization (e.g. opening of
// a remote file) with other computation.
constexpr double kDefaultCyclePrefetchFactor = 2.0L;

// `kPerIteratorPrefetchFactor * block_length + 1` is the default number of
// per-iterator results that will be prefetched ahead of time. The `+ 1` is to
// match the behavior of the original implementation.
constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;

// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;
| |
// Integer ceiling division for positive denominators: the smallest integer
// greater than or equal to `numerator / denominator`.
inline int64_t CeilDiv(int64_t numerator, int64_t denominator) {
  const int64_t rounded_up = numerator + denominator - 1;
  return rounded_up / denominator;
}
| |
| int64_t ComputeBufferOutputElements(int64_t configured_buffer_output_elements, |
| int64_t block_length) { |
| if (configured_buffer_output_elements != model::kAutotune) { |
| return configured_buffer_output_elements; |
| } |
| return kDefaultPerIteratorPrefetchFactor * block_length + 1; |
| } |
| |
| int64_t ComputePrefetchInputElements(int64_t configured_prefetch_input_elements, |
| int64_t cycle_length) { |
| if (configured_prefetch_input_elements != model::kAutotune) { |
| return configured_prefetch_input_elements; |
| } |
| return kDefaultCyclePrefetchFactor * cycle_length; |
| } |
| |
| int64_t OpVersionFromOpName(absl::string_view op_name) { |
| if (op_name == kParallelInterleaveDatasetV2) { |
| return 2; |
| } else if (op_name == kParallelInterleaveDatasetV3) { |
| return 3; |
| } else { |
| DCHECK_EQ(op_name, kParallelInterleaveDatasetV4); |
| return 4; |
| } |
| } |
| |
| } // namespace |
| |
| // The motivation for creating an alternative implementation of parallel |
| // interleave is to decouple the degree of parallelism from the cycle length. |
| // This makes it possible to change the degree of parallelism (e.g. through |
| // auto-tuning) without changing the cycle length (which would change the order |
| // in which elements are produced). |
| class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { |
| public: |
  // Constructs the dataset. `buffer_output_elements` and
  // `prefetch_input_elements` may be `model::kAutotune`, in which case
  // defaults derived from `block_length`/`cycle_length` are substituted.
  // Takes a reference on `input`, released in the destructor.
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64_t cycle_length,
          int64_t block_length, int64_t buffer_output_elements,
          int64_t prefetch_input_elements, int64_t num_parallel_calls,
          DeterminismPolicy deterministic, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        buffer_output_elements_(
            ComputeBufferOutputElements(buffer_output_elements, block_length)),
        prefetch_input_elements_(ComputePrefetchInputElements(
            prefetch_input_elements, cycle_length)),
        num_parallel_calls_(num_parallel_calls),
        deterministic_(deterministic),
        output_types_(output_types),
        output_shapes_(output_shapes),
        op_version_(op_version),
        // Static profiler metadata; per-iterator values are added later in
        // GetTraceMeMetadata().
        traceme_metadata_(
            {{"autotune",
              num_parallel_calls == model::kAutotune ? "true" : "false"},
             {"block_length",
              strings::Printf("%lld", static_cast<long long>(block_length))},
             {"cycle_length",
              strings::Printf("%lld", static_cast<long long>(cycle_length))},
             {"deterministic",
              deterministic.IsNondeterministic() ? "false" : "true"}}) {
    input_->Ref();
  }
| |
  // Releases the reference on the input dataset taken by the constructor.
  ~Dataset() override { input_->Unref(); }
| |
  // Creates a new iterator over this dataset. Treats the default determinism
  // policy as deterministic; nondeterminism must be requested explicitly.
  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    bool deterministic =
        deterministic_.IsDeterministic() || deterministic_.IsDefault();
    return absl::make_unique<ParallelInterleaveIterator>(
        ParallelInterleaveIterator::Params{
            this,
            name_utils::IteratorPrefix(
                ParallelInterleaveDatasetOp::kDatasetType, prefix, params)},
        deterministic);
  }
| |
  // Output types/shapes are taken from the op attrs, i.e. they describe the
  // elements produced by the user-defined function, not the input dataset.
  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }
| |
  // Returns a human-readable name for this dataset, versioned to match the
  // registered op (V2/V3/V4).
  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    return name_utils::DatasetDebugString(
        ParallelInterleaveDatasetOp::kDatasetType, params);
  }
| |
| int64_t Cardinality() const override { |
| int64_t n = input_->Cardinality(); |
| if (n == kInfiniteCardinality) { |
| return n; |
| } |
| return kUnknownCardinality; |
| } |
| |
  // Reports the single input dataset for graph-rewrite traversals.
  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  // Fails if either the captured function or the input dataset depends on
  // external (non-serializable) state.
  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }
| |
| protected: |
  // Serializes this dataset into a graph node. Inputs must be added in the
  // exact positional order expected by the op definition; op version 4+
  // additionally serializes the buffer/prefetch sizes, and the determinism
  // attr encoding differs between v2 (`sloppy` bool) and v3+ (`deterministic`
  // string).
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    // (position, node) pairs for scalar inputs and (position, node list)
    // pairs for list inputs; `input_index` tracks the op argument position.
    std::vector<std::pair<size_t, Node*>> inputs;
    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
    int input_index = 0;

    Node* input_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    inputs.emplace_back(input_index++, input_node);

    // The captured arguments of the user-defined function form a single
    // list-valued input.
    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    list_inputs.emplace_back(input_index++, other_arguments);

    Node* cycle_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    inputs.emplace_back(input_index++, cycle_length_node);

    Node* block_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    inputs.emplace_back(input_index++, block_length_node);

    // Buffer sizing became graph-visible in op version 4.
    if (op_version_ >= 4) {
      Node* buffer_output_elements_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
      inputs.emplace_back(input_index++, buffer_output_elements_node);

      Node* prefetch_input_elements_node;
      TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
                                      &prefetch_input_elements_node));
      inputs.emplace_back(input_index++, prefetch_input_elements_node);
    }

    Node* num_parallel_calls_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
    inputs.emplace_back(input_index++, num_parallel_calls_node);

    std::vector<std::pair<StringPiece, AttrValue>> attrs;
    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    attrs.emplace_back(kFunc, f);

    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    if (op_version_ == 2) {
      // v2 expresses nondeterminism via the boolean `sloppy` attr.
      AttrValue sloppy_attr;
      b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
      attrs.emplace_back(kSloppy, sloppy_attr);
    }
    if (op_version_ >= 3) {
      // v3+ uses the string-valued `deterministic` attr.
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
    return Status::OK();
  }
| |
| private: |
| class ParallelInterleaveIterator : public DatasetIterator<Dataset> { |
| public: |
    // Constructs the iterator. `mu_` and `num_parallel_calls_cond_var_` are
    // shared with the `model::SharedState` so that autotuning can adjust the
    // parallelism under the same lock the iterator uses. The cycle slots in
    // `current_elements_` start out empty and are populated lazily by
    // `EnsureInitialElementsCreated()`.
    ParallelInterleaveIterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
          num_parallel_calls_(std::make_shared<model::SharedState>(
              params.dataset->num_parallel_calls_, mu_,
              num_parallel_calls_cond_var_)),
          deterministic_(deterministic),
          current_elements_(params.dataset->cycle_length_) {}
| |
    // Cancels all worker threads and blocks until they have all exited so
    // that no thread outlives the iterator's state.
    ~ParallelInterleaveIterator() override {
      CancelThreads(/*wait=*/true);
    }
| |
| // TODO(jsimsa): Register cancellation callback once the implementation is |
| // refactored not to hold mu_ while calling `GetNext` on the input. |
    // Creates the worker thread pool, the input iterator, and the
    // instantiated user-defined function. Worker threads themselves are not
    // started here; see `EnsureThreadsStarted()`.
    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      // Note that if `ctx->thread_pool()` is non-null, then instead of
      // creating a dedicated thread pool of size `num_threads`, computation
      // will be scheduled into the shared threadpool. The threadpool is
      // guaranteed to support `num_threads` concurrent tasks without blocking
      // indefinitely.
      //
      // Allocate one thread for the worker manager, `cycle_length_` threads
      // for the current workers, `prefetch_input_elements_ + cycle_length_`
      // threads for the future workers, and — only when a stats aggregator is
      // present — one extra thread for stats collection.
      int max_current_workers = dataset()->cycle_length_;
      int future_workers =
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
      int num_threads = 1 + max_current_workers + future_workers;
      if (ctx->stats_aggregator()) {
        num_threads++;
      }
      thread_pool_ = ctx->CreateThreadPool(
          "data_parallel_interleave_worker_pool", num_threads);
      if (num_parallel_calls_->value == model::kAutotune) {
        num_parallel_calls_->value = dataset()->cycle_length_;
      }
      // Keep a copy of the context for use by the long-lived worker threads.
      ctx_ = std::make_unique<IteratorContext>(*ctx);
      cancellation_manager_ = absl::make_unique<CancellationManager>();
      IteratorContext::Params params(ctx);
      params.interleave_depth += 1;
      params.cancellation_manager = cancellation_manager_.get();
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }
| |
    // Produces the next interleaved element, blocking until one is available,
    // end of input is reached, or the iterator is cancelled.
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      std::shared_ptr<Result> result;
      {
        mutex_lock l(*mu_);
        EnsureInitialElementsCreated();
        EnsureThreadsStarted();
        // In deterministic mode we must wait for the specific element at
        // `cycle_index_`; otherwise any element's result will do.
        while (!cancelled_ && !Consume(&result)) {
          RecordStop(ctx);
          if (deterministic_) {
            VLOG(3) << "Blocked waiting for element "
                    << current_elements_[cycle_index_]->id;
            current_elements_[cycle_index_]->cond_var.wait(l);
          } else {
            any_element_available_cond_var_.wait(l);
          }
          RecordStart(ctx);
        }
        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }
      }
      // `Consume()` succeeded with a null result: end of input.
      if (!result) {
        *end_of_sequence = true;
        return Status::OK();
      }
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("ParallelInterleaveConsume",
                                       {{"element_id", result->id}});
      });
      // On error, propagate the element's status without touching
      // `out_tensors`.
      if (result->status.ok()) {
        *out_tensors = std::move(result->return_values);
        RecordBufferDequeue(ctx, *out_tensors);
      }
      *end_of_sequence = false;
      return result->status;
    }
| |
   protected:
    // Models this iterator as an async-interleave-many node whose
    // `parallelism` parameter is autotunable between 1 and the cycle length.
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncInterleaveManyNode(
          std::move(args),
          {model::MakeParameter(kParallelism, num_parallel_calls_, /*min=*/1,
                                /*max=*/dataset()->cycle_length_)});
    }
| |
| // TODO(aaudibert): Refactor the implementations to avoid the need for |
| // `IteratorContext` when saving the state of the iterator. |
    // Checkpoints the iterator. Workers are quiesced first
    // (`wait_for_checkpoint_`) so that the saved state is internally
    // consistent, then woken again before returning.
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(*mu_);
      wait_for_checkpoint_ = true;
      // Wait for all in-flight calls to complete.
      while (num_active_workers_ > 0) {
        zero_active_workers_cond_var_.wait(l);
      }
      // Initialize all elements and filter out elements with no input.
      InitializeInputs(element_id_counter_);
      for (auto& element : current_elements_) {
        if (element && element->no_input) {
          element.reset();
        }
      }
      while (!future_elements_.empty() && future_elements_.back()->no_input) {
        future_elements_.pop_back();
      }
      wait_for_checkpoint_ = false;
      DCHECK_EQ(num_active_workers_, 0);
      VLOG(4) << "State before save:\n" << DebugString();
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBlockIndex, block_index_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kCycleIndex, cycle_index_));
      // `end_of_input_` is encoded by key presence rather than by value; see
      // the matching `reader->Contains` in RestoreInternal.
      if (end_of_input_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kEndOfInput, ""));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kElementIdCounter,
                                             element_id_counter_));
      TF_RETURN_IF_ERROR(WriteCurrentElements(ctx, writer));
      TF_RETURN_IF_ERROR(WriteFutureElements(ctx, writer));
      // Wake workers back up.
      current_workers_cond_var_.notify_all();
      future_workers_cond_var_.notify_all();
      return Status::OK();
    }
| |
    // Restores iterator state from a checkpoint. Must be called before any
    // worker threads are started or elements created.
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      {
        mutex_lock l(*mu_);
        DCHECK(!threads_initialized_);
        DCHECK(!initial_elements_created_);
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kBlockIndex, &block_index_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kCycleIndex, &cycle_index_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kElementIdCounter,
                                              &element_id_counter_));
        // Presence of the key encodes the flag; see SaveInternal.
        end_of_input_ = reader->Contains(prefix(), kEndOfInput);
      }
      // These helpers acquire `mu_` themselves, so the lock is released
      // above and re-taken below.
      TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
      TF_RETURN_IF_ERROR(ReadFutureElements(ctx, reader));
      mutex_lock l(*mu_);
      initial_elements_created_ = false;
      // Re-enqueue restored cycle elements for processing, starting at the
      // restored cycle index to preserve deterministic ordering.
      for (int i = 0; i < current_elements_.size(); ++i) {
        int index = (cycle_index_ + i) % current_elements_.size();
        auto element = current_elements_[index];
        if (element) {
          elements_to_process_.push_back(index);
          element->initialized = true;
          element->cycle_index = index;
          initial_elements_created_ = true;
        }
      }
      for (const auto& element : future_elements_) {
        element->initialized = true;
      }
      // Recompute the index of the last non-empty cycle slot.
      last_valid_current_element_ = current_elements_.size() - 1;
      while (last_valid_current_element_ >= 0 &&
             !current_elements_[last_valid_current_element_]) {
        last_valid_current_element_--;
      }
      VLOG(2) << "Parallel interleave iterator restored";
      VLOG(4) << "State after restore:\n" << DebugString();
      return Status::OK();
    }
| |
| TraceMeMetadata GetTraceMeMetadata() const override { |
| int64_t parallelism = -1; |
| int64_t results_ready = -1; |
| int64_t active_elements = -1; |
| // NOTE: We only set the parallelism value if the lock can be acquired |
| // right away to avoid introducing tracing overhead. |
| if (mu_->try_lock()) { |
| parallelism = num_parallel_calls_->value; |
| results_ready = 0; |
| active_elements = 0; |
| for (int i = 0; i < current_elements_.size(); ++i) { |
| if (current_elements_[i]) { |
| results_ready += current_elements_[i]->results.size(); |
| if (current_elements_[i]->active) { |
| active_elements++; |
| } |
| } |
| } |
| mu_->unlock(); |
| } |
| auto result = dataset()->traceme_metadata_; |
| result.push_back(std::make_pair( |
| "parallelism", |
| parallelism == -1 |
| ? kTraceInfoUnavailable |
| : strings::Printf("%lld", static_cast<long long>(parallelism)))); |
| result.push_back(std::make_pair( |
| "results_ready", results_ready == -1 |
| ? kTraceInfoUnavailable |
| : strings::Printf("%lld", static_cast<long long>( |
| results_ready)))); |
| result.push_back(std::make_pair( |
| "active_elements", |
| results_ready == -1 ? kTraceInfoUnavailable |
| : strings::Printf("%lld", static_cast<long long>( |
| active_elements)))); |
| return result; |
| } |
| |
| private: |
    // Represents the result of fetching an element from a dataset.
    struct Result {
      // OK iff `return_values` holds a valid element; otherwise the error to
      // propagate to the consumer.
      Status status;
      // Identifier of the producing `Element`, used for profiler tracing.
      int64_t id = -1;
      // The tensors of the fetched element (valid only when `status` is OK).
      std::vector<Tensor> return_values;
    };
| |
    // The interleave transformation repeatedly inputs elements, applies the
    // user-provided function to transform the input elements to datasets, and
    // interleaves the elements of these datasets as its output.
    //
    // This structure represents an input element and derived state.
    struct Element {
      // Unique identifier, needed to support checkpointing.
      int64_t id TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
      // The actual input element. A null value indicates that the element
      // either reached end of input or hasn't been initialized yet.
      std::unique_ptr<std::vector<Tensor>> inputs
          TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
      // Iterator created from the input element. A null value indicates that
      // the element either reached end of input or hasn't been initialized
      // yet.
      std::unique_ptr<IteratorBase> iterator
          TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
      // Buffer for storing the outputs of `iterator`.
      std::deque<std::shared_ptr<Result>> TF_GUARDED_BY(
          &ParallelInterleaveIterator::mu_) results;
      // The element's index in the cycle, if it is in the current cycle.
      // -1 if the element is not in the current cycle.
      int64_t cycle_index TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = -1;
      // Whether the element is currently being processed by a worker thread.
      // This is used to ensure that only one thread at a time tries to process
      // an element.
      bool active TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
      // Whether the inputs and iterator have been initialized.
      bool initialized TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
      // Whether we tried to initialize the element, but the input iterator
      // was exhausted so we could produce no inputs.
      bool no_input TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
      // Condition variable for communicating between current worker threads
      // and GetNext.
      condition_variable cond_var;

      // Human-readable summary of the element's state, for VLOG debugging.
      std::string DebugString()
          TF_EXCLUSIVE_LOCKS_REQUIRED(&ParallelInterleaveIterator::mu_) {
        return absl::StrFormat(
            "Element(id: %d, iterator_null: %d, results_size: %d, "
            "cycle_index: %d, active: %d, initialized: %d, no_input: %d)",
            id, iterator == nullptr, results.size(), cycle_index, active,
            initialized, no_input);
      }
    };
| |
| // Sets the cancellation bit and wakes up all threads that need to be |
| // cancelled. Optionally, the method waits until all threads finish |
| // executing. |
    // Sets the cancellation bit and wakes up all threads that need to be
    // cancelled. When `wait` is true, blocks until every outstanding thread
    // has finished.
    void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      // Wake up all threads so that they can exit. This will also wake up any
      // threads waiting in GetNextInternal.
      for (const auto& element : current_elements_) {
        if (element) {
          element->cond_var.notify_all();
        }
      }
      current_workers_cond_var_.notify_all();
      future_workers_cond_var_.notify_all();
      num_parallel_calls_cond_var_->notify_all();
      stats_thread_cond_var_.notify_all();
      // `outstanding_threads_finished_cond_var_` is signaled by each exiting
      // thread via DecrementOutstandingThreads().
      while (wait && outstanding_threads_ > 0) {
        outstanding_threads_finished_cond_var_.wait(l);
      }
      any_element_available_cond_var_.notify_all();
      zero_active_workers_cond_var_.notify_all();
    }
| |
    // Lazily populates the cycle with up to `cycle_length_` elements and
    // queues them for processing. Stops early if the input is exhausted
    // (MakeElement() returns null). Idempotent.
    void EnsureInitialElementsCreated() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!initial_elements_created_) {
        for (int i = 0; i < dataset()->cycle_length_; ++i) {
          current_elements_[i] = MakeElement();
          if (!current_elements_[i]) {
            break;
          }
          current_elements_[i]->cycle_index = i;
          elements_to_process_.push_back(i);
          last_valid_current_element_ = i;
        }
        initial_elements_created_ = true;
      }
    }
| |
    // Lazily starts the worker manager thread (which in turn launches the
    // worker threads) and, if stats collection is enabled, the stats thread.
    // Idempotent.
    void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!threads_initialized_) {
        IncrementOutstandingThreads();
        thread_pool_->Schedule([this]() { WorkerManagerThread(); });
        if (ctx_->stats_aggregator()) {
          IncrementOutstandingThreads();
          thread_pool_->Schedule([this]() { StatsThread(); });
        }
        threads_initialized_ = true;
      }
    }
| |
| // Advances the position in the interleave cycle to the next cycle |
| // element. |
    // Advances the position in the interleave cycle to the next cycle
    // element, wrapping within the populated portion of the cycle and
    // resetting the per-element block index.
    void AdvanceToNextInCycle() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      DCHECK_NE(last_valid_current_element_, -1);
      block_index_ = 0;
      cycle_index_ = (cycle_index_ + 1) % (last_valid_current_element_ + 1);
    }
| |
| // Advances the position in the interleave cycle by one. |
    // Advances the position in the interleave cycle by one result; once
    // `block_length_` results have been taken from the current element, moves
    // on to the next cycle element.
    void AdvancePosition() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      ++block_index_;
      if (block_index_ == dataset()->block_length_) {
        AdvanceToNextInCycle();
      }
    }
| |
| // Consumes a result (if available), returning an indication of whether |
| // a result is available. If `true` is returned, `result` either |
| // points to a valid result or is null if end of input has been reached. |
    // Consumes a result (if available), returning an indication of whether
    // a result is available. If `true` is returned, `result` either
    // points to a valid result or is null if end of input has been reached.
    bool Consume(std::shared_ptr<Result>* result)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (deterministic_) {
        return ConsumeHelper(result);
      }
      // If we are allowed to be nondeterministic (i.e. return results out of
      // order), try to find an element in the cycle that has a result
      // available.
      for (int i = 0; i < dataset()->cycle_length_; ++i) {
        if (ConsumeHelper(result)) {
          return true;
        }
        AdvanceToNextInCycle();
      }
      return false;
    }
| |
| // Consumes a result (if available), returning an indication of whether |
| // a result is available. If `true` is returned, `result` either |
| // points to a valid result or is null if end of input has been reached. |
    // Core consumption loop: tries to take a result from the element at the
    // current cycle position, replacing exhausted elements with future (or
    // freshly created) ones as needed. Returns true with a valid result, true
    // with a null result at end of input, or false if the caller must wait.
    bool ConsumeHelper(std::shared_ptr<Result>* result)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      while (true) {
        if (last_valid_current_element_ == -1) {
          // Reached end of input.
          return true;
        }
        // Skip forward past empty cycle slots to the first populated one,
        // resetting the block index if any slot was skipped.
        for (int64_t i = 0; i < (last_valid_current_element_ + 1); ++i) {
          int64_t index =
              (cycle_index_ + i) % (last_valid_current_element_ + 1);
          if (current_elements_[index]) {
            cycle_index_ = index;
            if (i > 0) {
              block_index_ = 0;
            }
            break;
          }
        }
        DCHECK(current_elements_[cycle_index_]);
        std::shared_ptr<Element> element = current_elements_[cycle_index_];
        if (!element->results.empty()) {
          // We found a result.
          std::swap(*result, element->results.front());
          element->results.pop_front();
          if (!element->active) {
            // The element has buffer space again; hand it back to a current
            // worker.
            elements_to_process_.push_back(cycle_index_);
            current_workers_cond_var_.notify_one();
          }
          AdvancePosition();
          return true;
        }
        if (!element->initialized || element->iterator) {
          // The element is still producing results, so we wait.
          return false;
        }
        // We've consumed all results from the element. Get a new element from
        // future_elements, or create a new element if no future elements are
        // available.
        if (!future_elements_.empty()) {
          std::shared_ptr<Element> future_element =
              std::move(future_elements_.front());
          future_elements_.pop_front();
          if (future_element->iterator) {
            // The element is now on the critical path; let its iterator
            // participate in autotuning.
            EnableAutotune(ctx_.get(), future_element->iterator.get());
          }
          future_element->cycle_index = cycle_index_;
          current_elements_[cycle_index_] = std::move(future_element);
          future_workers_cond_var_.notify_one();
          if (!current_elements_[cycle_index_]->active) {
            current_workers_cond_var_.notify_one();
          }
        } else {
          current_elements_[cycle_index_] = MakeElement();
          if (current_elements_[cycle_index_]) {
            current_elements_[cycle_index_]->cycle_index = cycle_index_;
            elements_to_process_.push_back(cycle_index_);
            element->cycle_index = cycle_index_;
            current_workers_cond_var_.notify_one();
          }
          // MakeElement() returns null at end of input; shrink the populated
          // portion of the cycle from the back, keeping `cycle_index_` in
          // range.
          while (last_valid_current_element_ >= 0 &&
                 !current_elements_[last_valid_current_element_]) {
            last_valid_current_element_--;
            if (cycle_index_ > last_valid_current_element_) {
              // We are about to move the cycle index below in
              // AdvanceToNextInCycle().
              cycle_index_ = last_valid_current_element_;
            }
          }
        }
        if (last_valid_current_element_ != -1) {
          AdvanceToNextInCycle();
        }
      }
    }
| |
| // Creates a new element. |
    // Creates a new, uninitialized element with a fresh checkpointable id, or
    // returns null once the input has been exhausted. Initialization of the
    // element's inputs/iterator happens later on a worker thread.
    std::shared_ptr<Element> MakeElement() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (end_of_input_) {
        return nullptr;
      }
      auto element = std::make_shared<Element>();
      element->id = element_id_counter_++;
      uninitialized_elements_.push_back(element);
      return element;
    }
| |
| // Thread responsible for launching all worker threads. The thread stays |
| // around after startup in case autotuning increases num_parallel_calls. |
    // Thread responsible for launching all worker threads. The thread stays
    // around after startup in case autotuning increases num_parallel_calls,
    // in which case it launches additional current workers.
    void WorkerManagerThread() TF_LOCKS_EXCLUDED(mu_) {
      RecordStart(ctx_.get());
      auto cleanup = gtl::MakeCleanup([this]() {
        RecordStop(ctx_.get());
        mutex_lock l(*mu_);
        DecrementOutstandingThreads();
      });
      int initial_current_workers;
      // When elements are moved from `future_elements_` to
      // `current_elements_`, the future worker which created the element may
      // continue to process the element for some time. That is why we need an
      // additional `cycle_length_` future workers to guarantee that whenever
      // `future_element_.size() < future_elements_prefetch_`, there will be a
      // future worker available to create a new future element.
      int future_workers =
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
      {
        mutex_lock l(*mu_);
        initial_current_workers = num_parallel_calls_->value;
        // Account for all workers up front so CancelThreads() can wait on
        // them even if some have not been scheduled yet.
        outstanding_threads_ += initial_current_workers + future_workers;
        num_current_workers_ += initial_current_workers;
        num_active_workers_ += initial_current_workers + future_workers;
        num_current_active_workers_ += initial_current_workers;
      }
      // Start current workers before future workers to improve startup time.
      for (int i = 0; i < initial_current_workers; ++i) {
        StartCurrentWorkerThread();
      }
      for (int i = 0; i < future_workers; ++i) {
        StartFutureWorkerThread();
      }
      // Top up the current-worker pool whenever autotuning raises
      // `num_parallel_calls_` above the number of running workers.
      while (true) {
        {
          mutex_lock l(*mu_);
          while (!cancelled_ &&
                 num_current_workers_ >= num_parallel_calls_->value) {
            RecordStop(ctx_.get());
            num_parallel_calls_cond_var_->wait(l);
            RecordStart(ctx_.get());
          }
          if (cancelled_ || end_of_input_) {
            return;
          }
          IncrementOutstandingThreads();
          IncrementCurrentWorkers();
          IncrementActiveWorkers();
          IncrementCurrentActiveWorkers();
          StartCurrentWorkerThread();
        }
      }
    }
| |
    // Schedules a current worker (keeps `current_elements_` processed) on the
    // shared thread pool.
    void StartCurrentWorkerThread() {
      thread_pool_->Schedule([this]() { CurrentWorkerThread(); });
    }

    // Schedules a future worker (prefetches `future_elements_`) on the shared
    // thread pool.
    void StartFutureWorkerThread() {
      thread_pool_->Schedule([this]() { FutureWorkerThread(); });
    }
| |
| // Current workers are responsible for keeping elements in |
| // `current_elements_` processed. An element is processed if it is either |
| // done or its `results` buffer is full (contains `kPerIteratorPrefetch` |
| // elements). |
| // |
| // Current workers cycle between two phases: (1) finding an element and (2) |
| // processing it. When a worker is processing an element, it will |
| // claim the element by setting `element->active`, then continue to produce |
| // results for the element until enough results have been computed for the |
| // current cycle and the results buffer is full. |
| void CurrentWorkerThread() TF_LOCKS_EXCLUDED(mu_) { |
| RecordStart(ctx_.get()); |
| auto done = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| RecordStop(ctx_.get()); |
| DecrementActiveWorkers(); |
| DecrementCurrentActiveWorkers(); |
| DecrementOutstandingThreads(); |
| DecrementCurrentWorkers(); |
| }; |
| while (true) { |
| int element_index; |
| std::shared_ptr<Element> element; |
| // Find an element to process. |
| { |
| mutex_lock l(*mu_); |
| // In case autotune changes num_parallel_calls. |
| if (num_current_workers_ > num_parallel_calls_->value) { |
| done(); |
| return; |
| } |
| // Look for an element that needs processing. |
| element.reset(); |
| while (!cancelled_) { |
| while (!elements_to_process_.empty() && !wait_for_checkpoint_) { |
| int index = elements_to_process_.front(); |
| elements_to_process_.pop_front(); |
| auto& e = current_elements_[index]; |
| if (NeedsProcessing(e) && !e->active) { |
| element_index = index; |
| element = e; |
| break; |
| } |
| } |
| if (element) { |
| break; |
| } |
| DecrementCurrentActiveWorkers(); |
| WaitWorkerThread(¤t_workers_cond_var_, &l); |
| IncrementCurrentActiveWorkers(); |
| } |
| if (cancelled_) { |
| done(); |
| return; |
| } |
| VLOG(3) << "Current worker woke up to process " << element->id; |
| element->active = true; |
| } |
| // Loop on the element until we fill its results buffer or reach end of |
| // input for the element. |
| while (true) { |
| ProcessElement(element); |
| { |
| mutex_lock l(*mu_); |
| // Check whether we have produced enough results for the current |
| // cycle. |
| if (!NeedsProcessing(element)) { |
| element->active = false; |
| break; |
| } |
| } |
| } |
| } |
| } |
| |
// Future workers process elements after the current interleave cycle. A
// future worker's job is to keep `future_elements_` filled with elements.
// Elements in `future_elements` have had their first `kPerIteratorPrefetch`
// results computed.
void FutureWorkerThread() TF_LOCKS_EXCLUDED(mu_) {
  RecordStart(ctx_.get());
  // Exit bookkeeping; must be invoked while holding `mu_`.
  auto done = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    RecordStop(ctx_.get());
    DecrementActiveWorkers();
    DecrementOutstandingThreads();
  };
  std::shared_ptr<Element> element;
  while (true) {
    {
      mutex_lock l(*mu_);
      // Release the element processed in the previous loop iteration.
      if (element) {
        element->active = false;
        // A non-negative `cycle_index` means the element was moved into the
        // current cycle while we were processing it.
        if (element->cycle_index != -1) {
          element->cond_var.notify_one();
          // A current worker may need to process the element further.
          elements_to_process_.push_back(element->cycle_index);
          current_workers_cond_var_.notify_one();
        }
      }
      // Block while the prefetch buffer is full or a checkpoint is in
      // progress.
      while (!cancelled_ && (future_elements_.size() >=
                                 dataset()->prefetch_input_elements_ ||
                             wait_for_checkpoint_)) {
        WaitWorkerThread(&future_workers_cond_var_, &l);
      }
      if (cancelled_) {
        done();
        return;
      }
      // `MakeElement` returns nullptr once no more elements can be created.
      element = MakeElement();
      if (!element) {
        done();
        return;
      }
      VLOG(3) << "Future worker created element " << element->id;
      // Claim the element so no other worker touches it while we produce
      // results outside the lock.
      element->active = true;
      future_elements_.push_back(element);
    }
    ProcessElement(element);
  }
}
| |
// Generates results for the given element until the element's results
// buffer is full or the element is done producing results.
//
// The caller must have claimed the element (`element->active == true`);
// this is what makes it safe to use the element's iterator outside `mu_`.
void ProcessElement(std::shared_ptr<Element> element)
    TF_LOCKS_EXCLUDED(mu_) {
  DCHECK(element != nullptr);
  IteratorBase* iterator;
  int64_t input_element_id;
  // Initialize the inputs and iterator if necessary.
  {
    mutex_lock l(*mu_);
    DCHECK(element->active);
    input_element_id = element->id;
    if (!element->iterator) {
      InitializeInputs(input_element_id);
      // Still no iterator: the element had no input or initialization
      // failed (in which case `InitializeInputs` recorded an error result).
      if (!element->iterator) {
        return;
      }
    }
    // `iterator` will remain valid after releasing the lock because we have
    // marked the element as active, so no other thread will modify its
    // iterator.
    iterator = element->iterator.get();
  }
  DCHECK(iterator != nullptr);
  // Process until the results queue is full or we reach end of input.
  while (true) {
    auto result = std::make_shared<Result>();
    profiler::TraceMe traceme([&] {
      result->id = profiler::TraceMe::NewActivityId();
      return profiler::TraceMeEncode(
          "ParallelInterleaveProduce",
          {{"input_element_id", input_element_id},
           {"element_id", result->id}});
    });
    bool end_of_input = false;
    // Note: GetNext runs without holding `mu_`.
    result->status = iterator->GetNext(ctx_.get(), &result->return_values,
                                       &end_of_input);
    if (end_of_input) {
      mutex_lock l(*mu_);
      // The element is exhausted; release its iterator and inputs.
      element->iterator.reset();
      element->inputs.reset();
      NotifyElementUpdate(element);
      break;
    }
    RecordBufferEnqueue(ctx_.get(), result->return_values);
    mutex_lock l(*mu_);
    element->results.push_back(std::move(result));
    NotifyElementUpdate(element);
    if (element->results.size() == dataset()->buffer_output_elements_) {
      break;
    }
  }
}
| |
// Initialize inputs and create an iterator for all elements up to
// element_id.
//
// Elements are initialized in increasing `id` order so that their inputs
// are drawn from `input_impl_` in order.
void InitializeInputs(int element_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  while (!uninitialized_elements_.empty() &&
         uninitialized_elements_.front()->id <= element_id) {
    std::shared_ptr<Element> element = uninitialized_elements_.front();
    uninitialized_elements_.pop_front();
    element->initialized = true;
    // Check if we've already reached end of input.
    if (end_of_input_) {
      element->no_input = true;
      NotifyElementUpdate(element);
      continue;
    }
    profiler::TraceMe traceme([input_element_id = element->id] {
      return profiler::TraceMeEncode(
          "ParallelInterleaveInitializeInput",
          {{"input_element_id", input_element_id}});
    });
    std::vector<Tensor> inputs;
    Status status;
    {
      // TODO(aaudibert): Refactor the implementation to move calls of
      // `GetNext` out of the scope of `mu_`.
      status = input_impl_->GetNext(ctx_.get(), &inputs, &end_of_input_);
    }
    if (!status.ok()) {
      // Surface the input error to consumers of this element.
      AddErrorResult(element, status);
      continue;
    }
    if (end_of_input_) {
      element->no_input = true;
      NotifyElementUpdate(element);
      continue;
    }
    element->inputs =
        absl::make_unique<std::vector<Tensor>>(std::move(inputs));
    // The nested iterator is created one level deeper in the interleave.
    IteratorContext::Params params(ctx_.get());
    params.interleave_depth += 1;
    IteratorContext ctx(params);
    status = MakeIteratorFromInputElement(
        &ctx, this, *element->inputs, element->id,
        *instantiated_captured_func_, prefix(), &element->iterator,
        model_node());
    if (!status.ok()) {
      element->inputs.reset();
      element->iterator.reset();
      AddErrorResult(element, status);
      continue;
    }
    // Autotune is disabled for iterators of elements not yet in the cycle
    // (`cycle_index == -1`).
    if (element->cycle_index == -1) {
      DisableAutotune(ctx_.get(), element->iterator.get());
    }
  }
}
| |
| // Adds an error result for the given element. |
| void AddErrorResult(std::shared_ptr<Element> element, Status status) |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| auto result = std::make_shared<Result>(); |
| result->status = status; |
| element->results.push_back(std::move(result)); |
| NotifyElementUpdate(element); |
| } |
| |
// NOTE(review): despite the name and original comment ("cancels all
// threads ... and waits for them to finish"), this is currently a no-op.
// Thread shutdown appears to be driven by `cancelled_` /
// `outstanding_threads_` elsewhere — confirm whether this stub can be
// removed.
void StopAllThreads(mutex_lock* l) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {}
| |
// Waits on the given cond_var in a worker thread.
//
// While parked, the thread is excluded from the active-worker count (so
// checkpointing is not blocked by idle workers) and reported as stopped to
// the profiler. Both are restored after wakeup; the statement order here
// is deliberate.
void WaitWorkerThread(condition_variable* cond_var, mutex_lock* l)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  DecrementActiveWorkers();
  RecordStop(ctx_.get());
  cond_var->wait(*l);
  RecordStart(ctx_.get());
  IncrementActiveWorkers();
}
| |
| void NotifyElementUpdate(std::shared_ptr<Element> element) |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| if (deterministic_) { |
| element->cond_var.notify_one(); |
| } else { |
| any_element_available_cond_var_.notify_one(); |
| } |
| } |
| |
| bool NeedsProcessing(const std::shared_ptr<Element>& element) |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| if (!element) { |
| return false; |
| } |
| if (!element->initialized) { |
| return true; |
| } |
| return element->iterator && |
| element->results.size() < dataset()->buffer_output_elements_; |
| } |
| |
| inline void IncrementCurrentWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| num_current_workers_++; |
| } |
| |
| inline void DecrementCurrentWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| num_current_workers_--; |
| } |
| |
| inline void IncrementActiveWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| num_active_workers_++; |
| } |
| |
| inline void DecrementActiveWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| num_active_workers_--; |
| if (num_active_workers_ == 0) { |
| zero_active_workers_cond_var_.notify_one(); |
| } |
| } |
| |
| inline void IncrementCurrentActiveWorkers() |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| num_current_active_workers_++; |
| } |
| |
| inline void DecrementCurrentActiveWorkers() |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| num_current_active_workers_--; |
| } |
| |
| inline void IncrementOutstandingThreads() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| outstanding_threads_++; |
| } |
| |
| inline void DecrementOutstandingThreads() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| outstanding_threads_--; |
| if (outstanding_threads_ == 0) { |
| outstanding_threads_finished_cond_var_.notify_one(); |
| } |
| } |
| |
// Periodically reports the ratio of active current workers to current
// workers as a thread-utilization stat, until cancelled.
void StatsThread() {
  for (int64_t step = 0;; ++step) {
    int num_current_active_workers;
    int num_current_workers;
    {
      mutex_lock l(*mu_);
      // Sleep between reports (skipped on the first iteration so the
      // initial report happens immediately). The wait may also be cut
      // short by `stats_thread_cond_var_` on cancellation.
      if (step != 0 && !cancelled_) {
        stats_thread_cond_var_.wait_for(
            l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
      }
      if (cancelled_) {
        DecrementOutstandingThreads();
        return;
      }
      // Snapshot the counters under the lock; report outside it.
      num_current_active_workers = num_current_active_workers_;
      num_current_workers = num_current_workers_;
    }
    if (num_current_workers == 0) {
      // Avoid division by zero.
      num_current_workers = 1;
    }
    // NOTE(review): `ctx_->stats_aggregator()` is dereferenced without a
    // null check — presumably this thread is only started when an
    // aggregator is configured; confirm against the caller.
    ctx_->stats_aggregator()->AddScalar(
        stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
        static_cast<float>(num_current_active_workers) /
            static_cast<float>(num_current_workers),
        step);
  }
}
| |
| Status WriteStatusLocked(IteratorStateWriter* writer, |
| const string& iterator_name, size_t idx, |
| const Status& status) |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| iterator_name, CodeKey(idx), static_cast<int64_t>(status.code()))); |
| if (!status.ok()) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| iterator_name, ErrorMessageKey(idx), status.error_message())); |
| } |
| return Status::OK(); |
| } |
| |
| Status ReadStatusLocked(IteratorStateReader* reader, |
| const string& iterator_name, size_t idx, |
| Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| int64_t code_int; |
| TF_RETURN_IF_ERROR( |
| reader->ReadScalar(iterator_name, CodeKey(idx), &code_int)); |
| error::Code code = static_cast<error::Code>(code_int); |
| |
| if (code != error::Code::OK) { |
| tstring error_message; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| iterator_name, ErrorMessageKey(idx), &error_message)); |
| *status = Status(code, error_message); |
| } else { |
| *status = Status::OK(); |
| } |
| return Status::OK(); |
| } |
| |
| string CodeKey(size_t idx) { |
| return absl::StrCat(kResultsSuffix, "[", idx, "]", kCodeSuffix); |
| } |
| |
| string ErrorMessageKey(size_t idx) { |
| return absl::StrCat(kResultsSuffix, "[", idx, "]", kErrorMessageSuffix); |
| } |
| |
// Serializes `element` at position `idx` under `key_prefix`; the inverse
// of `ReadElement`. The input tensors and nested iterator state are only
// written while the element still has an open iterator; buffered results
// are always written.
Status WriteElement(SerializationContext* ctx,
                    std::shared_ptr<Element> element, int idx,
                    const string& key_prefix, IteratorStateWriter* writer)
    TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
  const auto& iterator_name =
      absl::StrCat(prefix(), "::", key_prefix, "::", idx);
  if (element->iterator) {
    // Save the nested iterator state, the element id, and the input
    // tensors the iterator was created from.
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, element->iterator));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(iterator_name, kIdSuffix, element->id));
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        iterator_name, absl::StrCat(kInputsSuffix, kSizeSuffix),
        element->inputs->size()));
    for (int i = 0; i < element->inputs->size(); i++) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          iterator_name, absl::StrCat(kInputsSuffix, "[", i, "]"),
          element->inputs->at(i)));
    }
  }
  // Save each buffered result: its status, then its return values.
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      iterator_name, absl::StrCat(kResultsSuffix, kSizeSuffix),
      element->results.size()));
  for (size_t i = 0; i < element->results.size(); i++) {
    std::shared_ptr<Result> result = element->results[i];
    TF_RETURN_IF_ERROR(
        WriteStatusLocked(writer, iterator_name, i, result->status));
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        iterator_name,
        absl::StrCat(kResultsSuffix, "[", i, "]", kSizeSuffix),
        result->return_values.size()));
    for (size_t j = 0; j < result->return_values.size(); j++) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          iterator_name, absl::StrCat(kResultsSuffix, "[", i, "][", j, "]"),
          result->return_values[j]));
    }
    // Marker key; its value is unused.
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        iterator_name,
        absl::StrCat(kResultsSuffix, "[", i, "]", kIsReadySuffix), ""));
  }
  return Status::OK();
}
| |
| Status WriteCurrentElements(SerializationContext* ctx, |
| IteratorStateWriter* writer) |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentElementsSize, |
| current_elements_.size())); |
| for (int idx = 0; idx < current_elements_.size(); idx++) { |
| if (current_elements_[idx]) { |
| TF_RETURN_IF_ERROR(WriteElement(ctx, current_elements_[idx], idx, |
| kCurrentElements, writer)); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status WriteFutureElements(SerializationContext* ctx, |
| IteratorStateWriter* writer) |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kFutureElementsSize, |
| future_elements_.size())); |
| for (int idx = 0; idx < future_elements_.size(); idx++) { |
| if (future_elements_[idx]) { |
| TF_RETURN_IF_ERROR(WriteElement(ctx, future_elements_[idx], idx, |
| kFutureElements, writer)); |
| } |
| } |
| return Status::OK(); |
| } |
| |
// Deserializes the element at position `idx` under `key_prefix` into
// `*out`; the inverse of `WriteElement`. Leaves `*out` unset when no
// element was written at this position. Reads checkpoint data under `mu_`,
// but restores the nested iterator state with the lock released.
Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader,
                   int idx, const string& key_prefix,
                   std::shared_ptr<Element>* out) {
  std::unique_ptr<IteratorBase> iterator;
  auto element = std::make_shared<Element>();
  {
    mutex_lock l(*mu_);
    const auto& iterator_name =
        absl::StrCat(prefix(), "::", key_prefix, "::", idx);
    // Absence of the results-size key means this slot held no element.
    if (!reader->Contains(iterator_name,
                          absl::StrCat(kResultsSuffix, kSizeSuffix))) {
      return Status::OK();
    }
    // Restore the buffered results: status and return values of each.
    int64_t results_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        iterator_name, absl::StrCat(kResultsSuffix, kSizeSuffix),
        &results_size));
    element->results.resize(results_size);
    for (size_t i = 0; i < results_size; i++) {
      auto result = std::make_shared<Result>();
      TF_RETURN_IF_ERROR(
          ReadStatusLocked(reader, iterator_name, i, &result->status));
      int64_t num_return_values;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          iterator_name,
          absl::StrCat(kResultsSuffix, "[", i, "]", kSizeSuffix),
          &num_return_values));
      result->return_values.reserve(num_return_values);
      for (size_t j = 0; j < num_return_values; j++) {
        result->return_values.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            ctx->flr(), iterator_name,
            absl::StrCat(kResultsSuffix, "[", i, "][", j, "]"),
            &result->return_values.back()));
      }
      RecordBufferEnqueue(ctx, result->return_values);
      element->results[i] = std::move(result);
    }
    // Absence of the inputs-size key means the element's iterator was
    // already exhausted when the checkpoint was written.
    if (!reader->Contains(iterator_name,
                          absl::StrCat(kInputsSuffix, kSizeSuffix))) {
      element->iterator.reset();
      *out = std::move(element);
      return Status::OK();
    }
    // Restore the input tensors and recreate the nested iterator.
    int64_t inputs_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        iterator_name, absl::StrCat(kInputsSuffix, kSizeSuffix),
        &inputs_size));
    element->inputs = std::make_unique<std::vector<Tensor>>(inputs_size);
    for (int i = 0; i < inputs_size; i++) {
      TF_RETURN_IF_ERROR(
          reader->ReadTensor(ctx->flr(), iterator_name,
                             absl::StrCat(kInputsSuffix, "[", i, "]"),
                             &element->inputs->at(i)));
    }
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(iterator_name, kIdSuffix, &element->id));
    // The nested iterator is created one level deeper in the interleave.
    IteratorContext::Params params(ctx);
    params.interleave_depth += 1;
    IteratorContext ctx_copy(params);
    TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
        &ctx_copy, this, *element->inputs, element->id,
        *instantiated_captured_func_.get(), prefix(), &iterator,
        model_node()));
  }
  // Restore the nested iterator's state without holding `mu_`.
  TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
  mutex_lock l(*mu_);
  element->iterator = std::move(iterator);
  *out = std::move(element);
  return Status::OK();
}
| |
// Restores the current interleave cycle written by `WriteCurrentElements`.
// Fails if the checkpointed cycle length differs from this iterator's.
Status ReadCurrentElements(IteratorContext* ctx,
                           IteratorStateReader* reader) {
  int64_t size;
  {
    mutex_lock l(*mu_);
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(prefix(), kCurrentElementsSize, &size));
    if (current_elements_.size() != size) {
      // This could mean two things: (1) the user created their checkpoint
      // from a dataset with one cycle_length, then changed the cycle_length
      // and tried to restore from the old checkpoint, or (2) the user set
      // the cycle length to tf.data.AUTOTUNE, wrote the checkpoint from one
      // machine, then tried to restore the checkpoint on another machine
      // with a different CPU budget (causing autotune to pick a different
      // cycle length).
      return errors::FailedPrecondition(
          "The iterator cycle length ", current_elements_.size(),
          " is different from the cycle length to restore from the "
          "checkpoint: ",
          size);
    }
  }
  if (size == 0) {
    return Status::OK();
  }
  // Read the elements in parallel without holding `mu_`, then install them
  // into the cycle under the lock.
  std::vector<std::shared_ptr<Element>> elements;
  TF_RETURN_IF_ERROR(
      ReadElementsParallel(ctx, reader, size, kCurrentElements, &elements));
  mutex_lock l(*mu_);
  // The cycle slots must still be empty at this point.
  for (auto& element : current_elements_) {
    DCHECK(element == nullptr);
  }
  for (int idx = 0; idx < size; ++idx) {
    current_elements_[idx] = std::move(elements[idx]);
  }
  return Status::OK();
}
| |
| Status ReadFutureElements(IteratorContext* ctx, |
| IteratorStateReader* reader) { |
| int64_t size; |
| { |
| mutex_lock l(*mu_); |
| TF_RETURN_IF_ERROR( |
| reader->ReadScalar(prefix(), kFutureElementsSize, &size)); |
| future_elements_.resize(size); |
| } |
| if (size == 0) { |
| return Status::OK(); |
| } |
| std::vector<std::shared_ptr<Element>> elements; |
| TF_RETURN_IF_ERROR( |
| ReadElementsParallel(ctx, reader, size, kFutureElements, &elements)); |
| mutex_lock l(*mu_); |
| for (auto& element : future_elements_) { |
| DCHECK(element == nullptr); |
| } |
| for (int idx = 0; idx < size; ++idx) { |
| future_elements_[idx] = std::move(elements[idx]); |
| } |
| return Status::OK(); |
| } |
| |
// Reads `size` elements stored under `name` from the checkpoint, using the
// thread pool to restore them concurrently. On success, `*elements` holds
// the restored elements in index order (entries may be null where no
// element was written).
Status ReadElementsParallel(
    IteratorContext* ctx, IteratorStateReader* reader, int64_t size,
    const string& name, std::vector<std::shared_ptr<Element>>* elements) {
  elements->resize(size);
  // `s` and `counter` are captured by reference below; this is safe
  // because `counter.Wait()` keeps this frame alive until every scheduled
  // task has run its cleanup.
  Status s = Status::OK();
  BlockingCounter counter(size);
  for (int idx = 0; idx < size; ++idx) {
    thread_pool_->Schedule(
        [this, ctx, reader, idx, name, &s, &counter, elements] {
          RecordStart(ctx);
          // Always decrement the counter, even on early returns below.
          auto cleanup = gtl::MakeCleanup([this, ctx, &counter]() {
            RecordStop(ctx);
            counter.DecrementCount();
          });
          std::shared_ptr<Element> elem;
          Status ret_status = ReadElement(ctx, reader, idx, name, &elem);
          // The shared status `s` is only touched while holding `mu_`.
          mutex_lock l(*mu_);
          if (cancelled_) {
            s.Update(
                errors::Cancelled("Cancelled in ReadElementsParallel"));
            return;
          }
          if (!ret_status.ok()) {
            s.Update(ret_status);
            return;
          }
          (*elements)[idx] = elem;
        });
  }
  counter.Wait();
  return s;
}
| |
| std::string DebugString() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| std::string result; |
| result.append(strings::StrCat("Cycle index: ", cycle_index_, "\n")); |
| result.append(strings::StrCat("Block index: ", block_index_, "\n")); |
| result.append(strings::StrCat("End of input: ", end_of_input_, "\n")); |
| { |
| result.append("Current elements:\n"); |
| for (int i = 0; i < current_elements_.size(); ++i) { |
| string element_string = "null"; |
| if (current_elements_[i]) { |
| element_string = current_elements_[i]->DebugString(); |
| } |
| result.append(absl::StrFormat("%d: %s\n", i, element_string)); |
| } |
| } |
| { |
| result.append("Future elements:\n"); |
| for (int i = 0; i < future_elements_.size(); ++i) { |
| string element_string = "null"; |
| if (future_elements_[i]) { |
| element_string = future_elements_[i]->DebugString(); |
| } |
| result.append(absl::StrFormat("%d: %s\n", i, element_string)); |
| } |
| } |
| return result; |
| } |
| |
| // Indices of `current_elements_` which need to be processed by a current |
| // worker. |
| std::deque<int> elements_to_process_; |
| |
| // The last index in `current_elements_` containing a non-null element. |
| // This allows us to optimize the situation when the cycle_length is large |
| // but the input dataset doesn't have many elements. By tracking the index |
| // of the last valid element, GetNext can avoid checking many null entries |
| // each time through the cycle. |
| // |
| // TODO(aaudibert): Generalize this optimization by removing null elements |
| // from `current_elements_`, e.g. by compacting the vector when x% of |
| // its elements are null. |
| int64_t last_valid_current_element_ TF_GUARDED_BY(mu_) = -1; |
| |
| // Identifies whether the current_elements_ vector has been initialized. |
| bool initial_elements_created_ TF_GUARDED_BY(mu_) = false; |
| |
| // Identifies whether the element threads have been initialized. |
| bool threads_initialized_ TF_GUARDED_BY(mu_) = false; |
| |
| // Used for coordination between the main thread, the manager threads, and |
| // the worker threads. |
| // |
| // NOTE: We should never call GetNext on the input while holding this mutex. |
| const std::shared_ptr<mutex> mu_; |
| |
| // Condition variable for waking up current workers. |
| condition_variable current_workers_cond_var_; |
| |
| // Condition variable for waking up future workers. |
| condition_variable future_workers_cond_var_; |
| |
| // Condition variable for waking up the stats thread. |
| condition_variable stats_thread_cond_var_; |
| |
| // Number of active worker threads which might be processing elements, |
| // including both current workers and future workers. Used by |
| // checkpointing to wait for outstanding work to finish. |
| int num_active_workers_ TF_GUARDED_BY(mu_) = 0; |
| |
| // Number of active current worker threads. |
| int num_current_active_workers_ TF_GUARDED_BY(mu_) = 0; |
| |
| // Condition variable notified whenever the total number of active workers |
| // drops to zero. Used for checkpointing. |
| condition_variable zero_active_workers_cond_var_; |
| |
| // Condition notified whenever num_parallel_calls_ changes. Shared so that |
| // autotuning can notify us when num_parallel_calls_ changes. |
| std::shared_ptr<condition_variable> num_parallel_calls_cond_var_; |
| |
| // Identifies the maximum number of parallel calls. |
| const std::shared_ptr<model::SharedState> num_parallel_calls_; |
| |
| // The number of current workers currently alive or scheduled to be started. |
| // This includes current workers which are blocked waiting for work. |
| int num_current_workers_ TF_GUARDED_BY(mu_) = 0; |
| |
| // Condition variable to signal that a result has been produced by some |
| // element thread. Only used when `deterministic` is false. |
| condition_variable any_element_available_cond_var_; |
| |
| // Determines whether outputs can be produced in deterministic order. |
| const bool deterministic_; |
| |
| // Controls cancellation of `input_impl_`. Must be ordered before |
| // `input_impl_` so that `input_impl_` is destroyed first. |
| std::unique_ptr<CancellationManager> cancellation_manager_; |
| |
| // Iterator for input elements. |
| std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_); |
| |
| // Identifies position in the interleave cycle. |
| int64_t block_index_ TF_GUARDED_BY(mu_) = 0; |
| // It is an invariant that either `last_valid_current_element_ == -1` or |
| // `cycle_index_ <= last_valid_current_element_`. |
| int64_t cycle_index_ TF_GUARDED_BY(mu_) = 0; |
| |
| // Elements of the current interleave cycle. |
| std::vector<std::shared_ptr<Element>> current_elements_ TF_GUARDED_BY(mu_); |
| |
| // Elements which still need their inputs and iterators to be initialized. |
| // Elements at the front need to be initialized first. |
| std::deque<std::shared_ptr<Element>> uninitialized_elements_ |
| TF_GUARDED_BY(mu_); |
| |
| // Elements to be used in the interleave cycle in the future. The element |
| // at the front is the next element to add to the interleave cycle when a |
| // current element is exhausted. |
| std::deque<std::shared_ptr<Element>> future_elements_ TF_GUARDED_BY(mu_); |
| |
| // Identifies whether the global end of input has been reached. |
| bool end_of_input_ TF_GUARDED_BY(mu_) = false; |
| |
| // The number of outstanding element threads. |
| int outstanding_threads_ TF_GUARDED_BY(mu_) = 0; |
| |
| // Condition variable notified when outstanding_threads_ drops to 0. |
| condition_variable outstanding_threads_finished_cond_var_; |
| |
| std::unique_ptr<thread::ThreadPool> thread_pool_; |
| |
| int64_t element_id_counter_ TF_GUARDED_BY(mu_) = 0; |
| |
| // Iterator context used in worker threads. |
| std::unique_ptr<IteratorContext> ctx_; |
| |
| // Set to true during checkpointing to alert element threads that they |
| // should pause operation. This is needed to prevent constantly-active |
| // worker threads from blocking checkpointing indefinitely. |
| bool wait_for_checkpoint_ = false; |
| |
| std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_; |
| |
| // Identifies whether background threads should be cancelled. |
| bool cancelled_ TF_GUARDED_BY(mu_) = false; |
| }; |
| |
| const DatasetBase* const input_; |
| const std::unique_ptr<CapturedFunction> captured_func_; |
| const int64_t cycle_length_; |
| const int64_t block_length_; |
| const int64_t buffer_output_elements_; |
| const int64_t prefetch_input_elements_; |
| const int64_t num_parallel_calls_; |
| const DeterminismPolicy deterministic_; |
| const DataTypeVector output_types_; |
| const std::vector<PartialTensorShape> output_shapes_; |
| const int op_version_; |
| const TraceMeMetadata traceme_metadata_; |
| }; |
| |
| ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp( |
| OpKernelConstruction* ctx) |
| : UnaryDatasetOpKernel(ctx), |
| op_version_(OpVersionFromOpName(ctx->def().op())) { |
| OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{}, |
| &func_metadata_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); |
| if (op_version_ == 2) { |
| bool sloppy; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy)); |
| if (sloppy) { |
| deterministic_ = |
| DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic); |
| } else { |
| deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault); |
| } |
| } |
| if (op_version_ >= 3) { |
| std::string deterministic; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic)); |
| OP_REQUIRES_OK( |
| ctx, DeterminismPolicy::FromString(deterministic, &deterministic_)); |
| } |
| } |
| |
| void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx, |
| DatasetBase* input, |
| DatasetBase** output) { |
| int64_t block_length = 0; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length)); |
| OP_REQUIRES(ctx, block_length > 0, |
| errors::InvalidArgument("`block_length` must be > 0")); |
| |
| int64_t buffer_output_elements = model::kAutotune; |
| int64_t prefetch_input_elements = model::kAutotune; |
| if (op_version_ >= 4) { |
| OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements, |
| &buffer_output_elements)); |
| OP_REQUIRES(ctx, |
| buffer_output_elements == model::kAutotune || |
| buffer_output_elements > 0, |
| errors::InvalidArgument("`buffer_output_elements` must be ", |
| model::kAutotune, " or > 0 but is ", |
| buffer_output_elements)); |
| |
| OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements, |
| &prefetch_input_elements)); |
| OP_REQUIRES(ctx, |
| prefetch_input_elements == model::kAutotune || |
| prefetch_input_elements >= 0, |
| errors::InvalidArgument("`prefetch_input_elements` must be ", |
| model::kAutotune, " or >= 0 but is ", |
| prefetch_input_elements)); |
| } |
| |
| int64_t num_parallel_calls = 0; |
| OP_REQUIRES_OK( |
| ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls)); |
| OP_REQUIRES( |
| ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune, |
| errors::InvalidArgument("num_parallel_calls must be greater than zero.")); |
| int64_t cycle_length = 0; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length)); |
| if (cycle_length == model::kAutotune) { |
| if (num_parallel_calls != model::kAutotune) { |
| cycle_length = std::min(num_parallel_calls, |
| static_cast<int64_t>(port::MaxParallelism())); |
| } else { |
| // If parallelism is to be autotuned, we set the cycle length so that |
| // the number of thread created for the current and future cycle elements |
| // roughly matches the number of schedulable cores. |
| const int num_threads_per_cycle_length = kDefaultCyclePrefetchFactor + 1; |
| cycle_length = |
| CeilDiv(port::MaxParallelism(), num_threads_per_cycle_length); |
| } |
| } |
| OP_REQUIRES(ctx, cycle_length > 0, |
| errors::InvalidArgument("`cycle_length` must be > 0")); |
| |
| OP_REQUIRES( |
| ctx, num_parallel_calls <= cycle_length, |
| errors::InvalidArgument( |
| "num_parallel_calls must less than or equal to cycle_length.")); |
| |
| std::unique_ptr<CapturedFunction> captured_func; |
| OP_REQUIRES_OK(ctx, |
| CapturedFunction::Create(ctx, func_metadata_, kOtherArguments, |
| &captured_func)); |
| |
| if (num_parallel_calls == model::kAutotune) { |
| metrics::RecordTFDataAutotune(kDatasetType); |
| } |
| |
| *output = new Dataset( |
| ctx, input, std::move(captured_func), cycle_length, block_length, |
| buffer_output_elements, prefetch_input_elements, num_parallel_calls, |
| deterministic_, output_types_, output_shapes_, op_version_); |
| } |
| |
namespace {
// Register one kernel per op version. V2 carries the boolean `sloppy`
// attr, V3 replaces it with the string `deterministic` attr, and V4
// additionally takes `buffer_output_elements` and
// `prefetch_input_elements` inputs (see the constructor and MakeDataset).
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
// Exempt these ops' inputs from input-colocation constraints.
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
}  // namespace
| } // namespace data |
| } // namespace tensorflow |