| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/kernels/data/shuffle_dataset_op.h" |
| |
| #include <deque> |
| #include <tuple> |
| #include <vector> |
| |
| #include "tensorflow/core/framework/partial_tensor_shape.h" |
| #include "tensorflow/core/framework/resource_mgr.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/kernels/data/name_utils.h" |
| #include "tensorflow/core/kernels/data/random_seed_ops.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/random/philox_random.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/lib/random/random_distributions.h" |
| |
| namespace tensorflow { |
| namespace data { |
| |
| // See documentation in ../../ops/dataset_ops.cc for a high-level |
| // description of the following op. |
| |
// Out-of-line definitions for the static constexpr string constants declared
// in the op headers. Prior to C++17, constexpr static data members are not
// implicitly inline, so they still require a namespace-scope definition when
// ODR-used.
/* static */ constexpr const char* const ShuffleDatasetOpBase::kInputDataset;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kBufferSize;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;

/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    ShuffleDatasetOp::kReshuffleEachIteration;

/* static */ constexpr const char* const
    ShuffleAndRepeatDatasetOp::kDatasetType;
/* static */ constexpr const char* const ShuffleAndRepeatDatasetOp::kCount;
| |
| const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds. |
| const int64 kMaxEpochsInBuffer = 3; |
| |
| constexpr char kNumRandomSamples[] = "num_random_samples"; |
| constexpr char kEndOfInputSequence[] = "end_of_input_sequence"; |
| constexpr char kEpoch[] = "epoch"; |
| constexpr char kNumElements[] = "num_elements"; |
| constexpr char kSlicesSize[] = "slices_size"; |
| constexpr char kSlicesStart[] = "slices_start"; |
| constexpr char kSlicesEnd[] = "slices_end"; |
| constexpr char kBuffer[] = "buffer"; |
| constexpr char kSize[] = "size"; |
| constexpr char kRandomSeedGenerator[] = "RandomSeedGenerator"; |
| constexpr char kTFData[] = "tf_data"; |
| constexpr char kDSNumRandomSamples[] = "ds_num_random_samples"; |
| constexpr char kFixedSeedDatasetPrefix[] = "FixedSeed"; |
| constexpr char kReshufflingDatasetPrefix[] = "Reshuffling"; |
| constexpr char kShuffleDataset[] = "ShuffleDataset"; |
| |
// Construction just forwards to `UnaryDatasetOpKernel`, which handles the
// single input dataset common to all shuffle op variants.
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {}
| |
| // Abstract base dataset that implements a shuffling iterator. |
class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
 public:
  ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                     int64 buffer_size, int64 count)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        count_(count) {
    // Keep the input dataset alive for the lifetime of this dataset.
    input_->Ref();
  }

  ~ShuffleDatasetBase() override { input_->Unref(); }

  // Shuffling preserves element types and shapes; both are forwarded from
  // the input dataset.
  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  // The dataset repeats its input `count_` times (-1 meaning "forever"), so
  // cardinality is the input's cardinality scaled by `count_`, with
  // infinite and unknown cardinalities propagated.
  int64 Cardinality() const override {
    if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
      return kInfiniteCardinality;
    } else if (input_->Cardinality() == kUnknownCardinality) {
      return kUnknownCardinality;
    } else {
      return input_->Cardinality() * count_;
    }
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  // Iterator that fills a circular buffer of up to `buffer_size_` elements
  // from the input and, on each `GetNextInternal` call, produces a
  // uniformly random element from the oldest epoch still in the buffer.
  // `T` is the concrete dataset type, which derives from
  // `ShuffleDatasetBase`.
  template <class T>
  class Iterator : public DatasetIterator<T> {
   public:
    explicit Iterator(const typename DatasetIterator<T>::Params& params,
                      int64 seed, int64 seed2)
        : DatasetIterator<T>(params),
          seed_(seed),
          seed2_(seed2),
          input_impl_(nullptr),
          epoch_(0),
          num_elements_(0),
          parent_generator_(seed, seed2),
          generator_(&parent_generator_) {
      // `buffer_` is indexed modulo `buffer_size_`; `slices_` records which
      // index ranges belong to which epoch.
      buffer_ = absl::make_unique<std::vector<Tensor>[]>(
          params.dataset->buffer_size_);
      slices_.push_back(absl::make_unique<Slice>(0, 0));
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      int64 start_micros = ctx->env()->NowMicros();
      int64 num_log_entries = 0;
      bool first_call = false;
      // Lazily create the input iterator on the very first call.
      if (!input_impl_ && epoch_ == 0) {
        first_call = true;
        TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
            ctx, this->prefix(), &input_impl_));
      }
      // Fill the shuffle buffer until it is full or the input (across all
      // requested epochs) is exhausted.
      while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
        // Periodically log progress, since filling a large buffer can take
        // a long time.
        if (ctx->env()->NowMicros() >
            ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
          num_log_entries++;
          LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
                    << num_elements_ << " of " << this->dataset()->buffer_size_;
        }
        std::vector<Tensor> input_element;
        bool end_of_input_sequence = false;
        // Pull the next element, restarting the input iterator at epoch
        // boundaries until `count_` epochs have been consumed.
        while (this->dataset()->count_ == -1 ||
               epoch_ < this->dataset()->count_) {
          TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
                                                  &end_of_input_sequence));
          if (!end_of_input_sequence) {
            first_call = false;
            break;
          }
          if (first_call && this->dataset()->count_ == -1) {
            // If the first call to GetNext() fails because the end
            // of sequence has been reached, we terminate the
            // iteration immediately. (Otherwise, this iterator
            // would loop infinitely and never produce a value.)
            *end_of_sequence = true;
            return Status::OK();
          }
          // End of an epoch: open a fresh (empty) slice and restart the
          // input iterator for the next epoch.
          epoch_++;
          int64 n = slices_.back()->end;
          slices_.push_back(absl::make_unique<Slice>(n, n));
          TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
              ctx, this->prefix(), &input_impl_));
        }
        if (!end_of_input_sequence) {
          if (num_elements_ == 0) {
            VLOG(1) << "Starting to fill up shuffle buffer of size: "
                    << this->dataset()->buffer_size_;
          }
          this->RecordBufferEnqueue(ctx, input_element);
          // Append the element at the current slice's end position (modulo
          // buffer size).
          buffer_[slices_.back()->end % this->dataset()->buffer_size_] =
              std::move(input_element);
          num_elements_++;
          slices_.back()->end++;
        } else {
          // All epochs are exhausted; release the input iterator.
          input_impl_.reset();
        }
        if (slices_.size() > kMaxEpochsInBuffer) {
          // When the elements stored in `buffer_` span more than
          // `kMaxEpochsInBuffer` epochs, we do not fill the buffer further to
          // conserve memory. This means that the upper bound on the size of
          // `buffer_` is `kMaxEpochsInBuffer * cardinality(input_dataset) +
          // 1`.
          break;
        }
      }
      if (num_log_entries > 0) {
        LOG(INFO) << "Shuffle buffer filled.";
      }

      if (num_elements_ > 0) {
        *end_of_sequence = false;
        // Garbage collect all empty slices.
        while (!slices_.empty() &&
               slices_.front()->start == slices_.front()->end) {
          slices_.pop_front();
        }
        DCHECK(!slices_.empty());
        // Choose an element to produce uniformly at random from the first
        // slice, and then remove the element from the slice.
        int64 offset =
            Random() % (slices_.front()->end - slices_.front()->start);
        int64 index =
            (slices_.front()->start + offset) % this->dataset()->buffer_size_;
        *out_tensors = std::move(buffer_[index]);
        this->RecordBufferDequeue(ctx, *out_tensors);
        // Compact the slice: move the element at the slice front into the
        // vacated position, then shrink the slice from the front.
        std::swap(
            buffer_[index],
            buffer_[slices_.front()->start % this->dataset()->buffer_size_]);
        slices_.front()->start++;
        num_elements_--;
      } else {
        DCHECK(input_impl_ == nullptr);
        *end_of_sequence = true;
      }
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      // Shuffle produces exactly one output element per input element.
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      // Reset the generators based on the current iterator seeds.
      parent_generator_ = random::PhiloxRandom(seed_, seed2_);
      generator_ =
          random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
      // Fast-forward past the samples already drawn so the RNG state matches
      // the point recorded by `num_random_samples_`.
      generator_.Skip(num_random_samples_);
    }

    Status SaveInternal(IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      // Save state needed to restore the random number generators.
      TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples),
                                             num_random_samples_));
      TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed), seed_));
      TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed2), seed2_));

      // Save input iterator if it hasn't been exhausted else write
      // "end_of_input_sequence".
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(this->full_name(kEndOfInputSequence), ""));
      } else {
        TF_RETURN_IF_ERROR(this->SaveInput(writer, input_impl_));
      }

      // Save the epoch counter, buffer, and buffer slices. Buffer elements
      // are keyed by their physical index ("buffer_<index>_<k>"), so only
      // positions covered by some slice are written.
      TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kEpoch), epoch_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(this->full_name(kNumElements), num_elements_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(this->full_name(kSlicesSize), slices_.size()));
      for (size_t i = 0; i < slices_.size(); ++i) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(this->full_name(absl::StrJoin(
                                    std::make_tuple(kSlicesStart, i), "_")),
                                slices_[i]->start));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
            slices_[i]->end));
        for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) {
          size_t index = j % this->dataset()->buffer_size_;
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              this->full_name(
                  absl::StrJoin(std::make_tuple(kBuffer, index, kSize), "_")),
              buffer_[index].size()));
          for (size_t k = 0; k < buffer_[index].size(); ++k) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                this->full_name(
                    absl::StrJoin(std::make_tuple(kBuffer, index, k), "_")),
                buffer_[index][k]));
          }
        }
      }

      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      // Restore the random number generators.
      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kNumRandomSamples),
                                            &num_random_samples_));
      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed), &seed_));
      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed2), &seed2_));
      ResetRngs();

      // Restore the input iterator if it wasn't already exhausted.
      if (!reader->Contains(this->full_name(kEndOfInputSequence))) {
        TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
            ctx, this->prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(this->RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
      }

      // Restore the epoch counter, buffer, and buffer slices, mirroring the
      // layout written by `SaveInternal`.
      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kEpoch), &epoch_));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(this->full_name(kNumElements), &num_elements_));
      size_t slices_size;
      {
        int64 temp;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(this->full_name(kSlicesSize), &temp));
        slices_size = static_cast<size_t>(temp);
      }
      buffer_ = absl::make_unique<std::vector<Tensor>[]>(
          this->dataset()->buffer_size_);
      for (size_t i = 0; i < slices_size; ++i) {
        int64 start;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(this->full_name(absl::StrJoin(
                                   std::make_tuple(kSlicesStart, i), "_")),
                               &start));
        int64 end;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
            &end));
        slices_.push_back(absl::make_unique<Slice>(start, end));
        for (size_t j = start; j < end; ++j) {
          size_t index = j % this->dataset()->buffer_size_;
          int64 list_size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              this->full_name(
                  absl::StrJoin(std::make_tuple(kBuffer, index, kSize), "_")),
              &list_size));
          buffer_[index] = std::vector<Tensor>(list_size);
          for (int k = 0; k < list_size; ++k) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                this->full_name(
                    absl::StrJoin(std::make_tuple(kBuffer, index, k), "_")),
                &buffer_[index][k]));
          }
        }
      }

      return Status::OK();
    }

    mutex mu_;
    // Seeds currently in effect; subclasses may overwrite them (e.g. from a
    // seed generator) before calling `ResetRngs`, hence protected access.
    int64 seed_ GUARDED_BY(mu_);
    int64 seed2_ GUARDED_BY(mu_);

   private:
    // Used to represent slices of `buffer_` that belong to different epochs.
    // The invariant maintained by the implementation is: `start` <= `end`.
    // When using `start` and `end` to index into `buffer_`, their values
    // should be taken modulo the size of `buffer_` as their absolute value
    // can be greater than the range of `buffer_`.
    struct Slice {
      Slice(int64 start, int64 end) : start(start), end(end) {}

      int64 start;
      int64 end;
    };

    // Draws one pseudorandom sample, counting draws so `ResetRngs` can
    // replay the generator to the same state after a checkpoint restore.
    random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      num_random_samples_++;
      auto out = generator_();
      return out;
    }

    std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    int64 epoch_ GUARDED_BY(mu_);
    int64 num_elements_ GUARDED_BY(mu_);
    std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_);
    random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
    random::SingleSampleAdapter<random::PhiloxRandom> generator_
        GUARDED_BY(mu_);
    int64 num_random_samples_ GUARDED_BY(mu_) = 0;
  };

  const DatasetBase* const input_;
  const int64 buffer_size_;
  const int64 count_;
};
| |
| // A dataset that uses a pseudorandom sequence of seeds for the iterators |
| // created from it. Used when `reshuffle_each_iteration` is true. |
| class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase { |
| public: |
| ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input, |
| int64 buffer_size, int64 seed, int64 seed2, int64 count) |
| : ShuffleDatasetBase(ctx, input, buffer_size, count), |
| seed_(seed), |
| seed2_(seed2) {} |
| |
| string DebugString() const override { |
| name_utils::DatasetDebugStringParams params; |
| params.dataset_prefix = kReshufflingDatasetPrefix; |
| params.set_args(buffer_size_, seed_, seed2_); |
| return name_utils::DatasetDebugString(kDatasetType, params); |
| } |
| |
| std::unique_ptr<IteratorBase> MakeIteratorInternal( |
| const string& prefix) const override { |
| return absl::make_unique<Iterator>( |
| Iterator::Params{this, |
| name_utils::IteratorPrefix(kDatasetType, prefix)}, |
| seed_, seed2_); |
| } |
| |
| protected: |
| class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDataset> { |
| public: |
| Iterator(const Params& params, int64 seed, int64 seed2) |
| : ShuffleDatasetBase::Iterator<ReshufflingDataset>(params, seed, |
| seed2) {} |
| |
| ~Iterator() override { seed_generator_->Unref(); } |
| |
| Status Initialize(IteratorContext* ctx) override { |
| // Firstly, lookup or create a seed generator from the IteratorResource |
| // resource_mgr. |
| ResourceMgr* mgr = ctx->resource_mgr(); |
| RandomSeedGenerator* seed_generator; |
| const string name = strings::StrCat( |
| prefix(), name_utils::kDelimiter, dataset()->type_string(), |
| name_utils::kDelimiter, kRandomSeedGenerator); |
| |
| int64 dataset_seed, dataset_seed2; |
| { |
| tf_shared_lock l(mu_); |
| // Ideally we'd like to hold this lock in the LookupOrCreate method, |
| // but that trips up our Deadlock detection code. |
| dataset_seed = seed_; |
| dataset_seed2 = seed2_; |
| } |
| TF_RETURN_IF_ERROR(mgr->LookupOrCreate<RandomSeedGenerator>( |
| kTFData, name, &seed_generator, |
| [dataset_seed, dataset_seed2](RandomSeedGenerator** seed_generator) { |
| // On the first iterator creation, use the original seeds from the |
| // dataset to seed a `RandomSeedGenerator` that will provide seeds |
| // for subsequent repetitions of the same dataset. |
| *seed_generator = |
| new RandomSeedGenerator(dataset_seed, dataset_seed2); |
| return Status::OK(); |
| })); |
| seed_generator_ = seed_generator; |
| seed_generator_->GenerateRandomSeeds(&seed_, &seed2_); |
| mutex_lock l(mu_); |
| ResetRngs(); |
| return Status::OK(); |
| } |
| |
| protected: |
| std::shared_ptr<model::Node> CreateNode( |
| IteratorContext* ctx, model::Node::Args args) const override { |
| return model::MakeKnownRatioNode(std::move(args), |
| /*ratio=*/1); |
| } |
| |
| Status SaveInternal(IteratorStateWriter* writer) override { |
| // Save RNG state of Dataset. |
| TF_RETURN_IF_ERROR( |
| writer->WriteScalar(full_name(kDSNumRandomSamples), |
| seed_generator_->num_random_samples())); |
| |
| // Save the Iterator. |
| return ShuffleDatasetBase::Iterator<ReshufflingDataset>::SaveInternal( |
| writer); |
| } |
| |
| Status RestoreInternal(IteratorContext* ctx, |
| IteratorStateReader* reader) override { |
| // Restore RNG state of Dataset. |
| int64 num_random_samples; |
| TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kDSNumRandomSamples), |
| &num_random_samples)); |
| seed_generator_->set_num_random_samples(num_random_samples); |
| seed_generator_->Reset(); |
| |
| // Restore the Iterator. |
| return ShuffleDatasetBase::Iterator<ReshufflingDataset>::RestoreInternal( |
| ctx, reader); |
| } |
| |
| private: |
| RandomSeedGenerator* seed_generator_; |
| }; |
| |
| Status AsGraphDefInternal(SerializationContext* ctx, |
| DatasetGraphDefBuilder* b, |
| Node** output) const override { |
| Node* input_graph_node = nullptr; |
| TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); |
| Node* buffer_size = nullptr; |
| Node* seed = nullptr; |
| Node* seed2 = nullptr; |
| AttrValue reshuffle_each_iteration; |
| |
| TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); |
| TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); |
| TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); |
| b->BuildAttrValue(true, &reshuffle_each_iteration); |
| TF_RETURN_IF_ERROR(b->AddDataset( |
| this, {input_graph_node, buffer_size, seed, seed2}, // Inputs |
| {std::make_pair(kReshuffleEachIteration, |
| reshuffle_each_iteration)}, // Attrs |
| output)); |
| return Status::OK(); |
| } |
| |
| private: |
| const int64 seed_; |
| const int64 seed2_; |
| }; |
| |
| // A dataset that uses a pseudorandom sequence of seeds for the iterators |
| // created from it. Used in TF 2.0 when `reshuffle_each_iteration` is true. |
class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase {
 public:
  // Takes ownership of one reference on `seed_generator`, released in the
  // destructor. `resource_handle` is the handle tensor naming the seed
  // generator resource; it is re-serialized in `AsGraphDefInternal`.
  ReshufflingDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                       int64 buffer_size, int64 count,
                       const Tensor& resource_handle,
                       RandomSeedGenerator* seed_generator)
      : ShuffleDatasetBase(ctx, input, buffer_size, count),
        resource_handle_(resource_handle),
        seed_generator_(seed_generator) {}

  ~ReshufflingDatasetV2() override { seed_generator_->Unref(); }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kReshufflingDatasetPrefix;
    params.set_args(buffer_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  // This dataset's behavior depends on a resource that cannot be
  // serialized, so checkpointing the dataset itself fails explicitly.
  Status CheckExternalState() const override {
    return errors::FailedPrecondition(
        DebugString(), " depends on random seed generator resource.");
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(
        Iterator::Params{this,
                         name_utils::IteratorPrefix(kDatasetType, prefix)},
        seed_generator_);
  }

 protected:
  class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDatasetV2> {
   public:
    // The base iterator is constructed with placeholder seeds (0, 0); the
    // real seeds are drawn from `seed_generator` in `Initialize`.
    Iterator(const Params& params, RandomSeedGenerator* seed_generator)
        : ShuffleDatasetBase::Iterator<ReshufflingDatasetV2>(params, 0, 0),
          seed_generator_(seed_generator) {}

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      seed_generator_->GenerateRandomSeeds(&seed_, &seed2_);
      ResetRngs();
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
    }

    Status SaveInternal(IteratorStateWriter* writer) override {
      // Save state of the seed generator.
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kDSNumRandomSamples),
                              seed_generator_->num_random_samples()));

      // Save the iterator state.
      return ShuffleDatasetBase::Iterator<ReshufflingDatasetV2>::SaveInternal(
          writer);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // Restore state of the seed generator.
      int64 num_random_samples;
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kDSNumRandomSamples),
                                            &num_random_samples));
      seed_generator_->set_num_random_samples(num_random_samples);
      seed_generator_->Reset();

      // Restore the iterator state.
      return ShuffleDatasetBase::Iterator<
          ReshufflingDatasetV2>::RestoreInternal(ctx, reader);
    }

   private:
    // Not owned here; the dataset holds the owning reference.
    RandomSeedGenerator* seed_generator_;
  };

  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
    Node* resource_handle_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        {input_graph_node, buffer_size_node, resource_handle_node},  // Inputs
        {},                                                          // Attrs
        output));
    return Status::OK();
  }

 private:
  const Tensor resource_handle_;
  RandomSeedGenerator* seed_generator_ = nullptr;
};
| |
| // A dataset that uses the same fixed seed for all iterators created from it. |
| // Used when `reshuffle_each_iteration` is false. |
| class ShuffleDatasetOp::FixedSeedDataset : public ShuffleDatasetBase { |
| public: |
| FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input, |
| int64 buffer_size, int64 seed, int64 seed2, int64 count) |
| : ShuffleDatasetBase(ctx, input, buffer_size, count), |
| seed_(seed), |
| seed2_(seed2) {} |
| |
| string DebugString() const override { |
| name_utils::DatasetDebugStringParams params; |
| params.dataset_prefix = kFixedSeedDatasetPrefix; |
| params.set_args(buffer_size_, seed_, seed2_); |
| return name_utils::DatasetDebugString(kDatasetType, params); |
| } |
| |
| std::unique_ptr<IteratorBase> MakeIteratorInternal( |
| const string& prefix) const override { |
| return absl::make_unique<ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>( |
| ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{ |
| this, name_utils::IteratorPrefix(kDatasetType, prefix)}, |
| seed_, seed2_); |
| } |
| |
| protected: |
| Status AsGraphDefInternal(SerializationContext* ctx, |
| DatasetGraphDefBuilder* b, |
| Node** output) const override { |
| Node* input_graph_node = nullptr; |
| TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); |
| Node* buffer_size = nullptr; |
| Node* seed = nullptr; |
| Node* seed2 = nullptr; |
| AttrValue reshuffle_each_iteration; |
| |
| TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); |
| TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); |
| TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); |
| b->BuildAttrValue(false, &reshuffle_each_iteration); |
| TF_RETURN_IF_ERROR(b->AddDataset( |
| this, {input_graph_node, buffer_size, seed, seed2}, // Inputs |
| {std::make_pair(kReshuffleEachIteration, |
| reshuffle_each_iteration)}, // Attrs |
| output)); |
| return Status::OK(); |
| } |
| |
| private: |
| const int64 seed_; |
| const int64 seed2_; |
| }; |
| |
ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
    : ShuffleDatasetOpBase(ctx),
      // V2 of the op takes a seed-generator resource instead of explicit
      // seed inputs; the version is detected from the registered op name.
      op_version_(ctx->def().op() == kShuffleDataset ? 1 : 2) {
  // The `reshuffle_each_iteration` attr is not present on all op versions,
  // so only read it when it exists.
  if (ctx->HasAttr(kReshuffleEachIteration)) {
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
  }
}
| |
| void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, |
| DatasetBase** output) { |
| int64 buffer_size = 0; |
| OP_REQUIRES_OK(ctx, |
| ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size)); |
| OP_REQUIRES( |
| ctx, buffer_size > 0, |
| errors::InvalidArgument("buffer_size must be greater than zero.")); |
| |
| int64 count = 1; |
| if (op_version_ == 2) { |
| RandomSeedGenerator* seed_generator = nullptr; |
| OP_REQUIRES_OK( |
| ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &seed_generator)); |
| // Transferring ownership of seed generator reference onto |
| // `ReshufflingDatasetV2`. |
| *output = new ReshufflingDatasetV2(ctx, input, buffer_size, count, |
| ctx->input(2), seed_generator); |
| return; |
| } |
| |
| int64 seed; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed)); |
| |
| int64 seed2; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2)); |
| |
| // By TensorFlow convention, passing 0 for both seeds indicates |
| // that the shuffling should be seeded non-deterministically. |
| if (seed == 0 && seed2 == 0) { |
| seed = random::New64(); |
| seed2 = random::New64(); |
| } |
| |
| if (reshuffle_each_iteration_) { |
| *output = |
| new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count); |
| } else { |
| *output = new FixedSeedDataset(ctx, input, buffer_size, seed, seed2, count); |
| } |
| } |
| |
| class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase { |
| public: |
| Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, |
| int64 seed, int64 seed2, int64 count) |
| : ShuffleDatasetBase(ctx, input, buffer_size, count), |
| seed_(seed), |
| seed2_(seed2) {} |
| |
| string DebugString() const override { |
| name_utils::DatasetDebugStringParams params; |
| params.set_args(buffer_size_, seed_, seed2_); |
| return name_utils::DatasetDebugString(kDatasetType, params); |
| } |
| |
| std::unique_ptr<IteratorBase> MakeIteratorInternal( |
| const string& prefix) const override { |
| return absl::make_unique<ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>( |
| ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{ |
| this, name_utils::IteratorPrefix(kDatasetType, prefix)}, |
| seed_, seed2_); |
| } |
| |
| protected: |
| Status AsGraphDefInternal(SerializationContext* ctx, |
| DatasetGraphDefBuilder* b, |
| Node** output) const override { |
| Node* input_graph_node = nullptr; |
| TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); |
| Node* buffer_size = nullptr; |
| Node* seed = nullptr; |
| Node* seed2 = nullptr; |
| Node* count = nullptr; |
| |
| TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); |
| TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); |
| TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); |
| TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); |
| TF_RETURN_IF_ERROR(b->AddDataset( |
| this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs |
| {}, // Attrs |
| output)); |
| return Status::OK(); |
| } |
| |
| private: |
| const int64 seed_; |
| const int64 seed2_; |
| }; |
| |
// `ShuffleAndRepeatDataset` has no attrs to parse; construction simply
// forwards to the shared base op kernel.
ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
    : ShuffleDatasetOpBase(ctx) {}
| |
| void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx, |
| DatasetBase* input, |
| DatasetBase** output) { |
| int64 buffer_size = 0; |
| OP_REQUIRES_OK(ctx, |
| ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size)); |
| OP_REQUIRES( |
| ctx, buffer_size > 0, |
| errors::InvalidArgument("buffer_size must be greater than zero.")); |
| |
| int64 seed; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed)); |
| |
| int64 seed2; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2)); |
| |
| int64 count; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count)); |
| |
| OP_REQUIRES(ctx, count > 0 || count == -1, |
| errors::InvalidArgument( |
| "count must be greater than zero or equal to -1.")); |
| |
| // By TensorFlow convention, if both seeds are 0, then shuffling should be |
| // seeded non-deterministically. |
| if (seed == 0 && seed2 == 0) { |
| seed = random::New64(); |
| seed2 = random::New64(); |
| } |
| |
| *output = new Dataset(ctx, input, buffer_size, seed, seed2, count); |
| } |
| |
namespace {
// V1 and V2 of ShuffleDataset share one kernel class; the op version is
// detected from the op name at kernel-construction time.
REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
                        ShuffleDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
                        ShuffleDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
                        ShuffleAndRepeatDatasetOp);
}  // namespace
| } // namespace data |
| } // namespace tensorflow |