blob: 159b43b4cd057c8adc763c3fc5a332c26b759e68 [file] [log] [blame]
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
#define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
#include <algorithm>
#include <atomic>
#include <random>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
// RecordYielder produces value records from a set of tfrecord files
// in a random order.
//
// It guarantees that:
// 1) all records in tfrecords are yielded within every epoch;
// 2) each record is yielded only once within every epoch;
// 3) the order in which records are yielded is highly randomized.
// 4) the peak memory usage is roughly avg record size *
// (opts.bufsize + opts.parallelism * 16).
//
// Usage example:
// RecordYielder::Options opts;
// opts.file_pattern = "input-*";
// opts.seed = 301;
// opts.bufsize = 1000000; // A randomized buffer with 1M records.
// opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate
// // through all files.
// RecordYielder yielder(context, opts); // context: OpKernelConstruction*
// string val;
// while (true) {
// Status s = yielder.YieldOne(&val);
// if (!s.ok()) break;
// // process val
// }
//
// RecordYielder can be accessed by multiple threads concurrently.
class RecordYielder {
 public:
  // Construction-time configuration.
  struct Options {
    // Glob pattern for tfrecords.
    string file_pattern;

    // Random seed. It determines how data files are shuffled and how
    // records are shuffled.
    int64 seed = 0;

    // Each epoch, all files are first shuffled according to the
    // random seed and the epoch number, and then all files are
    // left-shifted by file_shuffle_shift_ratio * num_files slots. If
    // file_shuffle_shift_ratio is not within [0, 1), the
    // implementation clips it to [0, 1).
    float file_shuffle_shift_ratio = 0;

    // The randomization buffer keeps this many records.
    uint64 bufsize = 1;

    // Uses this many concurrent tfrecord iterators to iterate through
    // tfrecords.
    int32 parallelism = 1;

    // Compression type of the input files.
    // NOTE(review): the accepted values are defined by the implementation
    // and are not visible in this header -- confirm against the .cc.
    string compression_type;
  };

  // Creates a yielder configured by 'opts'. 'context' is the kernel
  // construction context supplied by the caller.
  explicit RecordYielder(OpKernelConstruction* context,
                         const RecordYielder::Options& opts);
  ~RecordYielder();

  // Not copyable: the object holds a mutex, condition variables and an
  // owned thread pool.
  RecordYielder(const RecordYielder&) = delete;
  RecordYielder& operator=(const RecordYielder&) = delete;

  // Yields one 'value'. The returned Status reports whether a record was
  // produced.
  Status YieldOne(string* value);

  // Returns the current epoch number.
  int64 current_epoch() const { return epoch_.load(); }

 private:
  using ME = RecordYielder;

  Options opts_;

  // Background threads. Owned.
  // Defaults to nullptr so the pointer never holds an indeterminate value.
  thread::ThreadPool* thread_ = nullptr;

  // Epoch number. Explicitly zero-initialized: a default-constructed
  // std::atomic holds an indeterminate value before C++20.
  std::atomic<int64> epoch_{0};

  mutex mu_;

  // Turned to true when this is deleted.
  bool stop_ GUARDED_BY(mu_) = false;
  Status status_ GUARDED_BY(mu_);

  // PRG used for randomization.
  std::mt19937_64 rnd_ GUARDED_BY(mu_);

  // Randomization buffer.
  std::vector<string> buf_ GUARDED_BY(mu_);

  // True iff we are draining an epoch.
  // NOTE(review): read under mu_ in BufEnough() but not annotated
  // GUARDED_BY(mu_); confirm the locking discipline in the .cc before
  // adding the annotation here or on the two counters below.
  bool epoch_end_ = false;

  int64 num_records_added_in_epoch_ = 0;
  int64 num_records_yielded_in_epoch_ = 0;

  // Notified when the main loop has exited.
  Notification main_loop_done_;

  // Condition variables. Each one is declared next to the predicate it is
  // paired with; all predicates also hold once stop_ is set.

  condition_variable buf_empty_;
  // True when the buffer holds no records (or we are stopping).
  bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) {
    return stop_ || buf_.empty();
  }

  condition_variable buf_not_full_;
  // True when the buffer holds fewer than opts_.bufsize records (or we are
  // stopping).
  bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) {
    return stop_ || buf_.size() < opts_.bufsize;
  }

  condition_variable buf_enough_;
  // True when a record may be handed out, or when stopping / an error makes
  // further waiting pointless.
  bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) {
    if (stop_ || !status_.ok()) return true;
    // While finishing an epoch, any remaining record may be yielded.
    if (epoch_end_) return !buf_.empty();
    // NOTE: Unless we are finishing an epoch, we want to make sure
    // the buf_ contains enough randomized elements before yielding
    // any: at least half of the buffer, and never fewer than one.
    return buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2);
  }

  void MainLoop();
  struct Shard;
  void ShardLoop(Shard* shard);
  bool ShouldFinish(const Status& s);
  bool Add(std::vector<string>* values);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_