blob: 1f76ca21476f436857c240d554b3079a2faf258b [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/time/clock.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/raw_coding.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#endif // IS_SLIM_BUILD
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
// How a snapshot iterator behaves once it has inspected the on-disk metadata.
enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };

// Default maximum size of a single snapshot shard file: 10 GiB.
constexpr int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;

// Size of the fixed-width length header written in front of every record.
constexpr size_t kHeaderSize = sizeof(uint64);

// Name of the metadata file kept in each graph-hash directory.
constexpr char kSnapshotFilename[] = "snapshot.metadata";

// Thread pool names used by the reader and writer iterators respectively.
constexpr char kSnapshotReaderWorkerPool[] = "snapshot_reader_worker_pool";
constexpr char kSnapshotWriterWorkerPool[] = "snapshot_writer_worker_pool";
// Writes length-prefixed records to a WritableFile, optionally GZIP-compressed.
//
// Each record is written as an 8-byte fixed-width length header (encoded with
// core::EncodeFixed64) followed by the record payload. When
// `compression_type` is io::compression::kGzip (and this is not a slim build),
// the writer wraps `dest` in an io::ZlibOutputBuffer that it owns and deletes
// on Close(); otherwise records go straight to `dest`. `dest` itself is never
// owned by this class: it must outlive the writer, and callers close it
// themselves after calling Close().
class SnapshotWriter {
 public:
  explicit SnapshotWriter(WritableFile* dest, const string& compression_type =
                                                  io::compression::kNone)
      : dest_(dest), compression_type_(compression_type) {
    if (compression_type == io::compression::kGzip) {
#if defined(IS_SLIM_BUILD)
      LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
                 << "off compression.";
#else   // IS_SLIM_BUILD
      io::ZlibCompressionOptions zlib_options;
      zlib_options = io::ZlibCompressionOptions::GZIP();

      io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
          dest, zlib_options.input_buffer_size, zlib_options.output_buffer_size,
          zlib_options);
      TF_CHECK_OK(zlib_output_buffer->Init());
      dest_ = zlib_output_buffer;
      dest_is_owned_ = true;
#endif  // IS_SLIM_BUILD
    }
  }

  // The writer may own `dest_` through a raw pointer, so a copy would end up
  // deleting it twice. Copying is therefore disabled.
  SnapshotWriter(const SnapshotWriter&) = delete;
  SnapshotWriter& operator=(const SnapshotWriter&) = delete;

  // Appends one record: an 8-byte length header followed by `data`.
  Status WriteRecord(const StringPiece& data) {
    char header[kHeaderSize];
    core::EncodeFixed64(header, data.size());
    TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
    return dest_->Append(data);
  }

#if defined(PLATFORM_GOOGLE)
  // Cord overload of WriteRecord(); same on-disk format.
  Status WriteRecord(const absl::Cord& data) {
    char header[kHeaderSize];
    core::EncodeFixed64(header, data.size());
    TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
    return dest_->Append(data);
  }
#endif  // PLATFORM_GOOGLE

  // Closes and deletes the owned compression wrapper, if any. It does not
  // close the caller-provided file. Safe to call more than once: after the
  // first call the writer no longer owns anything.
  Status Close() {
    if (dest_is_owned_ && dest_ != nullptr) {
      Status s = dest_->Close();
      delete dest_;
      dest_ = nullptr;
      dest_is_owned_ = false;
      return s;
    }
    return Status::OK();
  }

  ~SnapshotWriter() {
    if (dest_ != nullptr) {
      Status s = Close();
      if (!s.ok()) {
        LOG(ERROR) << "Could not finish writing file: " << s;
      }
    }
  }

 private:
  // Either the caller's file (not owned) or an owned zlib wrapper around it.
  WritableFile* dest_;
  bool dest_is_owned_ = false;

  const string compression_type_;
};
// Reads length-prefixed records, in the format written by SnapshotWriter,
// from a RandomAccessFile, optionally GZIP-decompressing them.
//
// Each record is an 8-byte fixed-width length header (decoded with
// core::DecodeFixed64) followed by that many payload bytes. Callers in this
// file keep ownership of `file`, which must outlive the reader. In slim
// builds GZIP is unavailable and data is read uncompressed.
class SnapshotReader {
 public:
  explicit SnapshotReader(
      RandomAccessFile* file,
      const string& compression_type = io::compression::kNone)
      : input_stream_(new io::RandomAccessInputStream(file)),
        compression_type_(compression_type) {
    if (compression_type_ == io::compression::kGzip) {
#if defined(IS_SLIM_BUILD)
      LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
                 << "off compression.";
#else   // IS_SLIM_BUILD
      io::ZlibCompressionOptions zlib_options;
      zlib_options = io::ZlibCompressionOptions::GZIP();

      // release() hands the raw stream to the zlib wrapper; the trailing
      // `true` presumably makes the wrapper own it (TODO confirm).
      input_stream_.reset(new io::ZlibInputStream(
          input_stream_.release(), zlib_options.input_buffer_size,
          zlib_options.output_buffer_size, zlib_options, true));
#endif  // IS_SLIM_BUILD
    }
  }

  // Reads the next record into `record`. Errors from the underlying stream
  // are returned unchanged. Callers in this file treat OutOfRange as "end of
  // file".
  Status ReadRecord(string* record) {
    string header;
    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
    uint64 length = core::DecodeFixed64(header.data());
    return input_stream_->ReadNBytes(length, record);
  }

#if defined(PLATFORM_GOOGLE)
  // Cord overload of ReadRecord(). `record` is only extended when the whole
  // payload has been read successfully.
  Status ReadRecord(absl::Cord* record) {
    string header;
    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
    uint64 length = core::DecodeFixed64(header.data());

    if (compression_type_ == io::compression::kNone) {
      return input_stream_->ReadNBytes(length, record);
    }

    // Compressed streams are read through a temporary string. Do not leak a
    // partially read payload into `record` on failure.
    string tmp_str;
    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, &tmp_str));
    record->Append(tmp_str);
    return Status::OK();
  }
#endif

 private:
  std::unique_ptr<io::InputStreamInterface> input_stream_;
  const string compression_type_;
};
// Serializes `metadata` as a single record into
// `<hash_dir>/snapshot.metadata`, creating `hash_dir` first if needed.
//
// Both the record writer and the underlying file are closed explicitly, so
// any error reported by either Close() is returned to the caller instead of
// being silently dropped during destruction.
Status WriteMetadataFile(const string& hash_dir,
                         const experimental::SnapshotMetadataRecord& metadata) {
  string metadata_filename = absl::StrCat(hash_dir, "/", kSnapshotFilename);
  TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));

  std::unique_ptr<WritableFile> file;
  TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(metadata_filename, &file));

  // Uncompressed writer: it does not own `file`, so `file` is closed below.
  // Declared after `file` so it is destroyed first.
  SnapshotWriter writer(file.get());
  TF_RETURN_IF_ERROR(writer.WriteRecord(metadata.SerializeAsString()));
  TF_RETURN_IF_ERROR(writer.Close());
  return file->Close();
}
// Reads and parses `<hash_dir>/snapshot.metadata` into `metadata`.
//
// Returns the status of FileExists() when the file cannot be found (callers
// rely on NotFound to detect a fresh snapshot directory). Returns any
// open/read error unchanged, and DataLoss if the record does not parse as a
// SnapshotMetadataRecord. None of these abort the process; for example, an
// empty or partially written metadata file yields an error status.
Status ReadMetadataFile(const string& hash_dir,
                        experimental::SnapshotMetadataRecord* metadata) {
  string metadata_filename = absl::StrCat(hash_dir, "/", kSnapshotFilename);
  TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));

  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(
      Env::Default()->NewRandomAccessFile(metadata_filename, &file));

  // Declared after `file` so the reader is destroyed before the file.
  SnapshotReader reader(file.get());
  string record_bytes;
  TF_RETURN_IF_ERROR(reader.ReadRecord(&record_bytes));

  if (!metadata->ParseFromString(record_bytes)) {
    return errors::DataLoss("Unable to parse snapshot metadata file: ",
                            metadata_filename);
  }
  return Status::OK();
}
// Chooses how the snapshot iterator should behave for one hash directory.
//
// `file_status` is the status of reading the metadata file into `metadata`.
// Returns:
//   * WRITER if the metadata file was not found;
//   * READER if the snapshot has been finalized;
//   * PASSTHROUGH if the snapshot is not finalized and its creation timestamp
//     is no more than `pending_snapshot_expiry_seconds` old (someone else is
//     presumably still writing);
//   * WRITER otherwise (the pending snapshot has expired and is taken over).
// Any other non-OK `file_status` is not treated specially; the decision is
// then based on whatever `metadata` contains.
SnapshotMode DetermineOpState(
    const Status& file_status,
    const experimental::SnapshotMetadataRecord& metadata,
    const uint64 pending_snapshot_expiry_seconds) {
  if (errors::IsNotFound(file_status)) {
    return WRITER;
  }

  if (metadata.finalized()) {
    // File found, snapshot has been finalized.
    return READER;
  }

  // Work out the snapshot's age in signed arithmetic. Mixing int64 and uint64
  // here would silently turn the whole expression unsigned.
  constexpr uint64 kMicrosPerSecond = 1000000;
  const int64 now_micros = static_cast<int64>(Env::Default()->NowMicros());
  const int64 age_micros = now_micros - metadata.creation_timestamp();

  // "Not expired" means age_micros <= expiry_seconds * 1e6. For a positive
  // age this is exactly ceil(age_micros / 1e6) <= expiry_seconds. That form
  // cannot overflow, unlike multiplying a large user-supplied expiry by 1e6.
  const bool within_expiry =
      age_micros <= 0 ||
      (static_cast<uint64>(age_micros) + kMicrosPerSecond - 1) /
              kMicrosPerSecond <=
          pending_snapshot_expiry_seconds;

  if (within_expiry) {
    // Someone else is already writing and time has not expired.
    return PASSTHROUGH;
  }
  // Time has expired, we write regardless.
  return WRITER;
}
// Fingerprints a serialized dataset graph.
//
// The first node whose op is "_Retval" is used as the sink. The value of
// HashSubgraph() for that node is written to `hash` as 16 zero-padded hex
// digits. Returns an Internal error if no named "_Retval" node is found.
Status GraphHash(const GraphDef& graph_def, std::string* hash) {
  grappler::GraphView graph_view(&graph_def);

  std::string sink_node_name;
  bool found_retval = false;
  for (int idx = 0; !found_retval && idx < graph_def.node_size(); ++idx) {
    const auto& candidate = graph_def.node(idx);
    if (candidate.op() == "_Retval") {
      sink_node_name = candidate.name();
      found_retval = true;
    }
  }

  if (sink_node_name.empty()) {
    return errors::Internal("Cannot find sink node for dataset graph.");
  }

  const uint64 fingerprint =
      HashSubgraph(graph_def, graph_view.GetNode(sink_node_name));
  *hash = strings::StrCat(strings::Hex(fingerprint, strings::kZeroPad16));
  return Status::OK();
}
// Kernel for the `SnapshotDataset` op.
//
// On first iteration the dataset looks for snapshot metadata in
// `<path>/<graph_hash>`, where `graph_hash` fingerprints the serialized input
// pipeline (see GraphHash). Based on that metadata (see DetermineOpState),
// iteration does one of three things:
//   * writes a new snapshot while passing elements through;
//   * reads elements back from a finalized snapshot;
//   * simply passes the input through while another writer is presumed
//     active.
//
// Attributes set to -1 are replaced by defaults in the constructor, and all
// values are validated there.
class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit SnapshotDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx),
        graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));

    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("reader_path_prefix", &reader_path_prefix_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("writer_path_prefix", &writer_path_prefix_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("compression", &compression_));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("shard_size_bytes", &shard_size_bytes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("pending_snapshot_expiry_seconds",
                                     &pending_snapshot_expiry_seconds_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("num_reader_threads", &num_reader_threads_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("reader_buffer_size", &reader_buffer_size_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("num_writer_threads", &num_writer_threads_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("writer_buffer_size", &writer_buffer_size_));

    // -1 means "use the default" for every numeric attribute below.
    if (shard_size_bytes_ == -1) shard_size_bytes_ = kDefaultShardSizeBytes;

    // Default to 1 day expiry for snapshots.
    if (pending_snapshot_expiry_seconds_ == -1) {
      pending_snapshot_expiry_seconds_ = 86400;
    }

    if (num_reader_threads_ == -1) num_reader_threads_ = 1;
    if (reader_buffer_size_ == -1) reader_buffer_size_ = 1;
    if (num_writer_threads_ == -1) num_writer_threads_ = 1;
    if (writer_buffer_size_ == -1) writer_buffer_size_ = 1;

    OP_REQUIRES(
        ctx,
        compression_ == io::compression::kNone ||
            compression_ == io::compression::kGzip,
        errors::InvalidArgument("compression must be either '' or 'GZIP'."));

    OP_REQUIRES(
        ctx, shard_size_bytes_ >= 1024 * 1024,
        errors::InvalidArgument("shard_size_bytes must be at least 1 MiB."));
    OP_REQUIRES(
        ctx, pending_snapshot_expiry_seconds_ >= 1,
        errors::InvalidArgument(
            "pending_snapshot_expiry_seconds must be at least 1 second."));

    // These values reach the Dataset as uint64 and are used as thread counts
    // and buffer capacities. With zero threads or a zero-capacity buffer the
    // iterators below cannot make progress (producers wait for space that
    // never frees up). Negative values would wrap to huge unsigned numbers.
    OP_REQUIRES(ctx, num_reader_threads_ >= 1,
                errors::InvalidArgument(
                    "num_reader_threads must be at least 1."));
    OP_REQUIRES(ctx, reader_buffer_size_ >= 1,
                errors::InvalidArgument(
                    "reader_buffer_size must be at least 1."));
    OP_REQUIRES(ctx, num_writer_threads_ >= 1,
                errors::InvalidArgument(
                    "num_writer_threads must be at least 1."));
    OP_REQUIRES(ctx, writer_buffer_size_ >= 1,
                errors::InvalidArgument(
                    "writer_buffer_size must be at least 1."));
  }

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    string path;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));

    // Serialize the input pipeline (without checking for external state) so
    // it can be hashed into the snapshot directory name.
    SerializationContext::Params params;
    std::vector<std::pair<string, Tensor>> input_list;
    params.input_list = &input_list;
    params.check_external_state = false;

    GraphDef graph_def;
    OP_REQUIRES_OK(
        ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));

    string graph_hash;
    OP_REQUIRES_OK(ctx, GraphHash(graph_def, &graph_hash));

    *output = new Dataset(ctx, input, path, graph_hash, reader_path_prefix_,
                          writer_path_prefix_, compression_, shard_size_bytes_,
                          pending_snapshot_expiry_seconds_, num_reader_threads_,
                          reader_buffer_size_, num_writer_threads_,
                          writer_buffer_size_);
  }

 private:
  // Dataset produced by SnapshotDatasetOp. It holds a reference on `input`
  // for its lifetime and forwards dtypes, shapes and cardinality to it.
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input, const string& path,
            const string& graph_hash, const string& reader_path_prefix,
            const string& writer_path_prefix, const string& compression,
            const uint64 shard_size_bytes,
            const uint64 pending_snapshot_expiry_seconds,
            const uint64 num_reader_threads, const uint64 reader_buffer_size,
            const uint64 num_writer_threads, const uint64 writer_buffer_size)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          dir_(path),
          graph_hash_(graph_hash),
          reader_path_prefix_(reader_path_prefix),
          writer_path_prefix_(writer_path_prefix),
          compression_(compression),
          shard_size_bytes_(shard_size_bytes),
          pending_snapshot_expiry_seconds_(pending_snapshot_expiry_seconds),
          num_reader_threads_(num_reader_threads),
          reader_buffer_size_(reader_buffer_size),
          num_writer_threads_(num_writer_threads),
          writer_buffer_size_(writer_buffer_size) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(
          Iterator::Params{this, strings::StrCat(prefix, "::Snapshot")});
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override { return "SnapshotDatasetOp::Dataset"; }

    int64 Cardinality() const override { return input_->Cardinality(); }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

      Node* path = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(dir_, &path));

      AttrValue compression_attr;
      b->BuildAttrValue(compression_, &compression_attr);

      AttrValue reader_path_prefix_attr;
      b->BuildAttrValue(reader_path_prefix_, &reader_path_prefix_attr);

      AttrValue writer_path_prefix_attr;
      b->BuildAttrValue(writer_path_prefix_, &writer_path_prefix_attr);

      AttrValue shard_size_bytes_attr;
      b->BuildAttrValue<int64>(shard_size_bytes_, &shard_size_bytes_attr);

      AttrValue pending_snapshot_expiry_seconds_attr;
      b->BuildAttrValue<int64>(pending_snapshot_expiry_seconds_,
                               &pending_snapshot_expiry_seconds_attr);

      AttrValue num_reader_threads_attr;
      b->BuildAttrValue<int64>(num_reader_threads_, &num_reader_threads_attr);

      AttrValue reader_buffer_size_attr;
      b->BuildAttrValue<int64>(reader_buffer_size_, &reader_buffer_size_attr);

      AttrValue num_writer_threads_attr;
      b->BuildAttrValue<int64>(num_writer_threads_, &num_writer_threads_attr);

      AttrValue writer_buffer_size_attr;
      b->BuildAttrValue<int64>(writer_buffer_size_, &writer_buffer_size_attr);

      TF_RETURN_IF_ERROR(b->AddDataset(
          this,
          /*inputs=*/
          {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
          /*list_inputs=*/
          {},
          /*attrs=*/
          {{"compression", compression_attr},
           {"reader_path_prefix", reader_path_prefix_attr},
           {"writer_path_prefix", writer_path_prefix_attr},
           {"shard_size_bytes", shard_size_bytes_attr},
           {"pending_snapshot_expiry_seconds",
            pending_snapshot_expiry_seconds_attr},
           {"num_reader_threads", num_reader_threads_attr},
           {"reader_buffer_size", reader_buffer_size_attr},
           {"num_writer_threads", num_writer_threads_attr},
           {"writer_buffer_size", writer_buffer_size_attr}},
          output));
      return Status::OK();
    }

   private:
    // Top-level iterator. On the first GetNext call it reads the metadata in
    // `<dir>/<graph_hash>`, picks a mode with DetermineOpState and delegates
    // every subsequent call to a reader, writer or passthrough iterator.
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(mu_);
        hash_dir_ = absl::StrCat(dataset()->dir_, "/", dataset()->graph_hash_);
        return Status::OK();
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (iterator_ == nullptr) {
          experimental::SnapshotMetadataRecord metadata;
          Status s = ReadMetadataFile(hash_dir_, &metadata);
          state_ = DetermineOpState(
              s, metadata, dataset()->pending_snapshot_expiry_seconds_);

          switch (state_) {
            case WRITER:
              iterator_ = absl::make_unique<SnapshotWriterIterator>(
                  SnapshotWriterIterator::Params{
                      dataset(), strings::StrCat(prefix(), "Impl")},
                  hash_dir_);
              break;
            case READER:
              iterator_ = absl::make_unique<SnapshotReaderIterator>(
                  SnapshotReaderIterator::Params{
                      dataset(), strings::StrCat(prefix(), "Impl")},
                  hash_dir_, metadata);
              break;
            case PASSTHROUGH:
              iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
                  SnapshotPassthroughIterator::Params{
                      dataset(), strings::StrCat(prefix(), "Impl")});
              break;
          }

          // Do not keep a half-initialized delegate around. A later call
          // would otherwise use it (e.g. a reader whose Initialize found no
          // files). Dropping it lets the next call retry from scratch.
          Status init_status = iterator_->Initialize(ctx);
          if (!init_status.ok()) {
            iterator_.reset();
            return init_status;
          }
        }

        return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      Status SaveInternal(IteratorStateWriter* writer) override {
        // TODO(frankchn): Make save iterators work
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        // TODO(frankchn): Make iterator restores work
        return Status::OK();
      }

     private:
      // Reads a finalized snapshot from `<hash_dir>/<run_id>`.
      //
      // `num_reader_threads_` background threads each take whole files and
      // push decoded elements into a buffer bounded by
      // `reader_buffer_size_`, which GetNext drains. With several reader
      // threads, elements from different files may interleave.
      class SnapshotReaderIterator : public DatasetIterator<Dataset> {
       public:
        explicit SnapshotReaderIterator(
            const Params& params, const string& hash_dir,
            const experimental::SnapshotMetadataRecord& metadata)
            : DatasetIterator<Dataset>(params),
              hash_dir_(hash_dir),
              metadata_(metadata) {}

        // Signals cancellation and blocks until every reader thread exits.
        ~SnapshotReaderIterator() override {
          mutex_lock l(mu_);
          cancelled_ = true;
          cond_var_.notify_all();
          while (num_active_threads_ > 0) {
            cond_var_.wait(l);
          }
        }

        Status Initialize(IteratorContext* ctx) override {
          mutex_lock l(mu_);

          thread_pool_ = ctx->CreateThreadPool(kSnapshotReaderWorkerPool,
                                               dataset()->num_reader_threads_);

          run_id_ = metadata_.run_id();
          run_dir_ = absl::StrCat(hash_dir_, "/", run_id_);
          // Get all the files in the run_dir.
          TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
              absl::StrCat(run_dir_, "/*"), &filenames_));
          if (filenames_.empty()) {
            return errors::InvalidArgument("Could not find any files in dir: ",
                                           run_dir_);
          }
          std::sort(filenames_.begin(), filenames_.end());
          return Status::OK();
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          absl::Time start = absl::Now();

          mutex_lock l(mu_);
          if (!background_threads_started_) {
            for (uint64 i = 0; i < dataset()->num_reader_threads_; ++i) {
              ++num_active_threads_;
              thread_pool_->Schedule([this]() { ReadingFilesLoop(); });
            }
            background_threads_started_ = true;
          }

          // Wait till the buffer has something in it. Also stop once every
          // reader thread has exited: a thread that hit an error never marks
          // the read as finished, so nothing more would ever arrive.
          while (!cancelled_ && buffer_.empty() &&
                 !background_threads_finished_ && num_active_threads_ > 0) {
            cond_var_.wait(l);
          }

          if (cancelled_) {
            return errors::Cancelled(
                "SnapshotDatasetOp::Dataset::SnapshotReaderIterator::GetNext");
          }

          if (!buffer_.empty()) {
            Status s = buffer_.front().status;
            if (s.ok()) {
              *end_of_sequence = false;
              *out_tensors = std::move(buffer_.front().value);

              // Printing some statistics along the way.
              int64 num_bytes = 0;
              for (const Tensor& t : *out_tensors) {
                num_bytes += t.TotalBytes();
              }
              absl::Time end = absl::Now();
              absl::Duration d = end - start;
              time_spent_micros_ += absl::ToInt64Microseconds(d);
              kbytes_read_ += static_cast<double>(num_bytes) / 1024.0;
              elements_produced_++;
              if (elements_produced_ % 10000 == 0) {
                LOG(INFO) << "Current read throughput (MBPS): "
                          << ((kbytes_read_ / 1024.0) /
                              (time_spent_micros_ / 1000000.0));
              }
            }
            buffer_.pop_front();
            cond_var_.notify_all();
            return s;
          }

          if (background_threads_finished_) {
            *end_of_sequence = true;
            return Status::OK();
          }

          // Every reader thread has exited without reading all files. The
          // error that stopped them was already returned by an earlier call.
          return errors::Internal(
              "SnapshotDatasetOp::Dataset::SnapshotReaderIterator::GetNext: "
              "reader threads exited before all snapshot files were read.");
        }

       private:
        // Reads one file end to end, pushing each decoded element into
        // `buffer_`. Returns OK when the reader reports OutOfRange (end of
        // file), Cancelled if the iterator is being destroyed, and any
        // open/read/parse error otherwise.
        Status ReadFile(const string& filename) {
          std::unique_ptr<RandomAccessFile> file;
          TF_RETURN_IF_ERROR(
              Env::Default()->NewRandomAccessFile(filename, &file));
          std::unique_ptr<SnapshotReader> reader(
              new SnapshotReader(file.get(), dataset()->compression_));

          while (true) {
            // Wait for a slot in the buffer.
            {
              mutex_lock l(mu_);
              while (!cancelled_ &&
                     buffer_.size() >= dataset()->reader_buffer_size_) {
                cond_var_.wait(l);
              }

              if (cancelled_) {
                return errors::Cancelled(
                    "SnapshotDatasetOp::Dataset::SnapshotReaderIterator::"
                    "ReadFile");
              }
            }

#if !defined(PLATFORM_GOOGLE)
            string record_bytes;
            Status s = reader->ReadRecord(&record_bytes);
#else
            absl::Cord record_cord;
            Status s = reader->ReadRecord(&record_cord);
#endif
            if (errors::IsOutOfRange(s)) {
              // Reached the end of this file.
              return Status::OK();
            }
            TF_RETURN_IF_ERROR(s);

            experimental::SnapshotRecord record;
#if !defined(PLATFORM_GOOGLE)
            const bool parsed = record.ParseFromString(record_bytes);
#else
            const bool parsed = record.ParseFromCord(record_cord);
#endif
            if (!parsed) {
              return errors::DataLoss(
                  "Unable to parse snapshot record in file: ", filename);
            }

            std::vector<Tensor> out_tensors;
            out_tensors.reserve(record.tensor_size());
            for (int i = 0; i < record.tensor_size(); ++i) {
              Tensor t;
              if (!t.FromProto(record.tensor(i))) {
                return errors::DataLoss("Unable to parse tensor from proto.");
              }
              out_tensors.push_back(std::move(t));
            }

            BufferElement elem;
            std::swap(elem.value, out_tensors);
            elem.status = Status::OK();

            {
              mutex_lock l(mu_);
              buffer_.push_back(std::move(elem));
              cond_var_.notify_all();
            }
          }
        }

        // Pulls one file off the filenames_ list and reads it through. When
        // all files are read, terminates.
        void ReadingFilesLoop() {
          auto cleanup = gtl::MakeCleanup([this]() {
            mutex_lock l(mu_);
            --num_active_threads_;
            cond_var_.notify_all();
          });
          while (true) {
            string filename = "";
            {
              mutex_lock l(mu_);
              if (next_file_index_ >= filenames_.size()) {
                return;
              }
              filename = absl::StrCat(dataset()->reader_path_prefix_,
                                      filenames_[next_file_index_]);
              VLOG(2) << "Starting to read: " << filename;
              next_file_index_++;
            }

            Status s = ReadFile(filename);

            // A cleanly finished file counts towards `num_files_done_`; once
            // every file is done the stream is marked finished. On error the
            // status is queued for GetNext and this thread stops; GetNext
            // notices when all threads have exited.
            if (s.ok()) {
              VLOG(2) << "Finished reading: " << filename;
              mutex_lock l(mu_);
              num_files_done_++;
              if (num_files_done_ >= filenames_.size()) {
                background_threads_finished_ = true;
                cond_var_.notify_all();
                return;
              }
            } else {
              LOG(ERROR) << "Encountered an error: " << s.ToString();
              BufferElement elem;
              elem.status = s;

              mutex_lock l(mu_);
              buffer_.push_back(std::move(elem));
              cond_var_.notify_all();
              return;
            }
          }
        }

        struct BufferElement {
          Status status;
          std::vector<Tensor> value;
        };

        mutex mu_;
        condition_variable cond_var_;

        const string hash_dir_;
        const experimental::SnapshotMetadataRecord metadata_;
        string run_id_ GUARDED_BY(mu_);
        string run_dir_ GUARDED_BY(mu_);
        std::vector<string> filenames_;

        uint64 elements_produced_ GUARDED_BY(mu_) = 0;
        int64 time_spent_micros_ GUARDED_BY(mu_) = 0;
        double kbytes_read_ GUARDED_BY(mu_) = 0;
        size_t next_file_index_ GUARDED_BY(mu_) = 0;
        int64 num_files_done_ GUARDED_BY(mu_) = 0;

        std::unique_ptr<thread::ThreadPool> thread_pool_;
        int64 num_active_threads_ GUARDED_BY(mu_) = 0;
        std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
        bool cancelled_ GUARDED_BY(mu_) = false;
        bool background_threads_started_ GUARDED_BY(mu_) = false;
        bool background_threads_finished_ GUARDED_BY(mu_) = false;
      };

      // Writes a new snapshot under `<writer_path_prefix><hash_dir>/<run_id>`
      // while returning the input's elements unchanged.
      //
      // GetNext pulls elements from the input one element ahead and copies
      // each into a buffer bounded by `writer_buffer_size_`.
      // `num_writer_threads_` background threads drain the buffer into
      // shard files. The metadata is marked finalized once all writer
      // threads have finished successfully.
      class SnapshotWriterIterator : public DatasetIterator<Dataset> {
       public:
        explicit SnapshotWriterIterator(const Params& params,
                                        const string& hash_dir)
            : DatasetIterator<Dataset>(params), hash_dir_(hash_dir) {}

        // Signals cancellation and blocks until every writer thread exits.
        ~SnapshotWriterIterator() override {
          mutex_lock l(mu_);
          cancelled_ = true;
          cond_var_.notify_all();
          while (num_active_threads_ > 0) {
            cond_var_.wait(l);
          }
        }

        Status Initialize(IteratorContext* ctx) override {
          mutex_lock l(mu_);
          thread_pool_ = ctx->CreateThreadPool(kSnapshotWriterWorkerPool,
                                               dataset()->num_writer_threads_);
          run_id_ = strings::StrCat(
              strings::Hex(random::New64(), strings::kZeroPad4));
          run_dir_ = absl::StrCat(dataset()->writer_path_prefix_, hash_dir_,
                                  "/", run_id_);

          TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir_));

          experimental::SnapshotMetadataRecord metadata;
          metadata.set_creation_timestamp(Env::Default()->NowMicros());
          metadata.set_graph_hash(dataset()->graph_hash_);
          metadata.set_run_id(run_id_);
          metadata.set_finalized(false);
          TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata));

          return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          absl::Time start = absl::Now();

          bool first_call;
          {
            mutex_lock l(mu_);
            first_call = first_call_;
            if (first_call_) {
              for (uint64 i = 0; i < dataset()->num_writer_threads_; ++i) {
                ++num_active_threads_;
                thread_pool_->Schedule([this]() { WriterThread(); });
              }
              first_call_ = false;
            }
          }

          // When we reach the end of the data, we'd like to finalize the
          // snapshot and write the metadata file out. If we just check for
          // end_of_sequence on the GetNext call then we will need to make
          // N + 1 GetNext calls (if N is the total number of elements in the
          // dataset). So right now we solve this issue by prefetching the next
          // element in the data stream. Therefore the first call ends up
          // pulling two elements.
          if (first_call) {
            TF_RETURN_IF_ERROR(FillBuffer(ctx));
          }

          {
            mutex_lock l(mu_);
            // Populate out_tensors with the prefetched data.
            *out_tensors = std::move(next_elem_.value);
            *end_of_sequence = next_elem_.end_of_sequence;
          }

          if (*end_of_sequence) {
            // The FillBuffer call that saw the end of input already waited
            // for all writer threads. There is nothing more to prefetch, so
            // don't pull from the exhausted input again.
            return Status::OK();
          }

          // Update prefetched_elem with the next element.
          TF_RETURN_IF_ERROR(FillBuffer(ctx));

          // Book keeping to report some statistics.
          mutex_lock l(mu_);
          int64 num_bytes = 0;
          for (const Tensor& out_tensor : *out_tensors) {
            num_bytes += out_tensor.TotalBytes();
          }

          absl::Time end = absl::Now();
          absl::Duration d = end - start;
          time_spent_micros_ += absl::ToInt64Microseconds(d);
          bytes_produced_ += num_bytes;
          elements_produced_++;

          if (elements_produced_ % 10000 == 0) {
            LOG(INFO) << "Current write throughput (MBPS): "
                      << (bytes_produced_ * 1000000.0) /
                             (time_spent_micros_ * 1024.0 * 1024.0);
          }
          return Status::OK();
        }

       private:
        struct BufferElement {
          std::vector<Tensor> value;
          bool end_of_sequence = false;
        };

        // Returns the next shard file name in `run_dir_` and advances the
        // shard counter.
        string GetSnapshotFilename() {
          mutex_lock l(mu_);
          string snapshot_data_filename = absl::StrCat(
              run_dir_, "/", strings::Printf("%08llu", next_file_index_),
              ".snapshot");
          next_file_index_++;
          return snapshot_data_filename;
        }

        // Pulls one element from the input into `next_elem_` and, unless it
        // is the end of input, queues a copy for the writer threads.
        //
        // At end of input it marks `end_of_sequence_` and blocks until all
        // writer threads have exited. Returns Cancelled if the iterator is
        // being destroyed. Returns Internal if the writer threads have failed,
        // since nothing would ever drain the buffer again.
        Status FillBuffer(IteratorContext* ctx) LOCKS_EXCLUDED(mu_) {
          BufferElement elem;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &elem.value, &elem.end_of_sequence));

          mutex_lock l(mu_);
          next_elem_ = std::move(elem);

          if (next_elem_.end_of_sequence) {
            end_of_sequence_ = true;
            cond_var_.notify_all();
            // Now we wait till all background threads finish.
            while (num_active_threads_ > 0) {
              cond_var_.wait(l);
            }
            return Status::OK();
          }

          // Wait for a space in the buffer_. Failed writer threads no longer
          // consume from it, so stop waiting in that case as well.
          while (!cancelled_ && !snapshot_failed_ &&
                 buffer_.size() >= dataset()->writer_buffer_size_) {
            cond_var_.wait(l);
          }

          if (cancelled_) {
            return errors::Cancelled(
                "SnapshotDatasetOp::SnapshotWriterIterator::GetNext");
          }

          if (snapshot_failed_) {
            return errors::Internal(
                "SnapshotDataset::SnapshotWriterIterator snapshot failed");
          }

          if (buffer_.size() >= dataset()->writer_buffer_size_) {
            return errors::Internal(
                "Buffer size: ", buffer_.size(), " should be smaller than ",
                "maximum size: ", dataset()->writer_buffer_size_);
          }

          // Keep `next_elem_` intact: GetNext hands it to the caller later.
          buffer_.push_back(next_elem_);
          cond_var_.notify_all();
          return Status::OK();
        }

        // Performs one step of a writer thread. It either writes one
        // buffered element (rolling over to a new shard file when the shard
        // size is exceeded), or closes the files once the input is exhausted.
        // Returns an error on cancellation or when another thread has failed.
        //
        // Only the last writer thread to reach the end finalizes the metadata,
        // and only if no thread failed. A thread that finished earlier could
        // otherwise mark the snapshot complete while others are still
        // writing.
        Status ProcessOneElement(int64* bytes_written,
                                 string* snapshot_data_filename,
                                 std::unique_ptr<WritableFile>* file,
                                 std::unique_ptr<SnapshotWriter>* writer,
                                 bool* end_of_processing) {
          bool cancelled = false;
          *end_of_processing = false;
          bool produced_elem = false;
          bool snapshot_failed = false;
          BufferElement elem;
          {
            mutex_lock l(mu_);
            // Wait for buffer to not be empty.
            while (!cancelled_ && buffer_.empty() && !end_of_sequence_ &&
                   !snapshot_failed_) {
              cond_var_.wait(l);
            }
            cancelled = cancelled_;
            if (!buffer_.empty()) {
              produced_elem = true;
              std::swap(elem, buffer_.front());
              buffer_.pop_front();
              cond_var_.notify_all();
            } else {
              *end_of_processing = end_of_sequence_;
            }
            snapshot_failed = snapshot_failed_;
          }

          if (cancelled || snapshot_failed) {
            TF_RETURN_IF_ERROR((*writer)->Close());
            TF_RETURN_IF_ERROR((*file)->Close());
            if (snapshot_failed) {
              return errors::Internal(
                  "SnapshotDataset::SnapshotWriterIterator snapshot failed");
            }
            return errors::Cancelled(
                "SnapshotDataset::SnapshotWriterIterator cancelled");
          }

          if (produced_elem) {
            experimental::SnapshotRecord record;
            for (const Tensor& out_tensor : elem.value) {
              *bytes_written += out_tensor.TotalBytes();
              TensorProto* t = record.add_tensor();
              out_tensor.AsProtoTensorContent(t);
            }

            if (*bytes_written > dataset()->shard_size_bytes_) {
              // If we exceed the shard size, we get a new file and reset.
              TF_RETURN_IF_ERROR((*writer)->Close());
              TF_RETURN_IF_ERROR((*file)->Close());
              *snapshot_data_filename = GetSnapshotFilename();
              TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(
                  *snapshot_data_filename, file));
              *writer = absl::make_unique<SnapshotWriter>(
                  file->get(), dataset()->compression_);
              *bytes_written = 0;
            }
#if defined(PLATFORM_GOOGLE)
            TF_RETURN_IF_ERROR(
                (*writer)->WriteRecord(record.SerializeAsCord()));
#else   // PLATFORM_GOOGLE
            TF_RETURN_IF_ERROR(
                (*writer)->WriteRecord(record.SerializeAsString()));
#endif  // PLATFORM_GOOGLE
            return Status::OK();
          }

          if (*end_of_processing) {
            TF_RETURN_IF_ERROR((*writer)->Close());
            TF_RETURN_IF_ERROR((*file)->Close());
            mutex_lock l(mu_);
            ++num_writers_finished_;
            const bool all_writers_done =
                num_writers_finished_ == dataset()->num_writer_threads_;
            if (all_writers_done && !snapshot_failed_ &&
                !written_final_metadata_file_) {
              experimental::SnapshotMetadataRecord metadata;
              TF_RETURN_IF_ERROR(ReadMetadataFile(hash_dir_, &metadata));

              if (metadata.run_id() == run_id_) {
                metadata.set_finalized(true);
                TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata));
              } else {
                // TODO(frankchn): We lost the race, remove all snapshots.
              }
              written_final_metadata_file_ = true;
              cond_var_.notify_all();
            }
          }
          return Status::OK();
        }

        // Just pulls off elements from the buffer and writes them.
        void WriterThread() {
          auto cleanup = gtl::MakeCleanup([this]() {
            mutex_lock l(mu_);
            --num_active_threads_;
            cond_var_.notify_all();
          });

          int64 bytes_written = 0;
          string snapshot_data_filename = GetSnapshotFilename();
          std::unique_ptr<WritableFile> file;
          Status s =
              Env::Default()->NewWritableFile(snapshot_data_filename, &file);
          if (!s.ok()) {
            LOG(ERROR) << "Creating " << snapshot_data_filename
                       << " failed: " << s.ToString();
            mutex_lock l(mu_);
            snapshot_failed_ = true;
            cond_var_.notify_all();
            return;
          }
          std::unique_ptr<SnapshotWriter> writer(
              new SnapshotWriter(file.get(), dataset()->compression_));

          bool end_of_processing = false;
          while (!end_of_processing) {
            Status s =
                ProcessOneElement(&bytes_written, &snapshot_data_filename,
                                  &file, &writer, &end_of_processing);
            if (!s.ok()) {
              LOG(INFO) << "Error while writing snapshot data to disk: "
                        << s.ToString();
              mutex_lock l(mu_);
              snapshot_failed_ = true;
              cond_var_.notify_all();
              return;
            }
          }
        }

        mutex mu_;
        // This condition variable is notified
        // 1. By the background writer threads when an element from the buffer
        //    is consumed.
        // 2. By the main thread when it puts something into the buffer.
        // 3. By the main thread when the destructor is called to cancel.
        // 4. By the background writer threads when any error is encountered
        //    while writing.
        // 5. By the background threads when they finish.
        condition_variable cond_var_;

        BufferElement next_elem_ GUARDED_BY(mu_);
        std::unique_ptr<IteratorBase> input_impl_;

        const string hash_dir_;
        string run_id_ GUARDED_BY(mu_);
        string run_dir_ GUARDED_BY(mu_);

        uint64 elements_produced_ GUARDED_BY(mu_) = 0;
        int64 time_spent_micros_ GUARDED_BY(mu_) = 0;
        int64 bytes_produced_ GUARDED_BY(mu_) = 0;

        std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
        bool snapshot_failed_ GUARDED_BY(mu_) = false;
        bool cancelled_ GUARDED_BY(mu_) = false;
        bool first_call_ GUARDED_BY(mu_) = true;
        bool end_of_sequence_ GUARDED_BY(mu_) = false;
        bool written_final_metadata_file_ GUARDED_BY(mu_) = false;
        // Number of writer threads that reached end of processing cleanly.
        uint64 num_writers_finished_ GUARDED_BY(mu_) = 0;
        uint64 next_file_index_ GUARDED_BY(mu_) = 0;
        std::unique_ptr<thread::ThreadPool> thread_pool_;
        int64 num_active_threads_ GUARDED_BY(mu_) = 0;
      };

      // Forwards every call to the input iterator without touching the
      // snapshot (used while another writer is presumed active).
      class SnapshotPassthroughIterator : public DatasetIterator<Dataset> {
       public:
        explicit SnapshotPassthroughIterator(const Params& params)
            : DatasetIterator<Dataset>(params) {}

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
        }

       private:
        std::unique_ptr<IteratorBase> input_impl_;
      };

      string hash_dir_ GUARDED_BY(mu_);
      SnapshotMode state_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);

      mutex mu_;
    };

    const DatasetBase* const input_;
    const string dir_;
    const string graph_hash_;

    const string reader_path_prefix_;
    const string writer_path_prefix_;
    const string compression_;

    const uint64 shard_size_bytes_;
    const uint64 pending_snapshot_expiry_seconds_;
    const uint64 num_reader_threads_;
    const uint64 reader_buffer_size_;
    const uint64 num_writer_threads_;
    const uint64 writer_buffer_size_;
  };

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;

  string reader_path_prefix_;
  string writer_path_prefix_;
  string compression_;

  int64 shard_size_bytes_;
  int64 pending_snapshot_expiry_seconds_;
  int64 num_reader_threads_;
  int64 reader_buffer_size_;
  int64 num_writer_threads_;
  int64 writer_buffer_size_;
};
// Registers SnapshotDatasetOp as the CPU kernel for the "SnapshotDataset" op.
REGISTER_KERNEL_BUILDER(Name("SnapshotDataset").Device(DEVICE_CPU),
SnapshotDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow