blob: 97a86977e826a01d66cbe6d59f3eaf04560a63ec [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include <queue>
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"
namespace tensorflow {
namespace data {
namespace snapshot_util {
/* static */ constexpr const int64
CustomReader::kSnappyReaderInputBufferSizeBytes;
/* static */ constexpr const int64
CustomReader::kSnappyReaderOutputBufferSizeBytes;
std::string GetCurrentCheckpointFile(const std::string& shard_directory,
const uint64 current_checkpoint_id) {
return io::JoinPath(shard_directory,
absl::StrFormat("%08d.snapshot", current_checkpoint_id));
}
Status Writer::Create(Env* env, const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Writer>* out_writer) {
switch (version) {
case 1:
*out_writer =
absl::make_unique<CustomWriter>(filename, compression_type, dtypes);
break;
case 2:
*out_writer =
absl::make_unique<TFRecordWriter>(filename, compression_type);
break;
default:
return errors::InvalidArgument("Snapshot writer version: ", version,
" is not supported.");
}
return (*out_writer)->Initialize(env);
}
TFRecordWriter::TFRecordWriter(const std::string& filename,
const std::string& compression_type)
: filename_(filename), compression_type_(compression_type) {}
Status TFRecordWriter::Initialize(tensorflow::Env* env) {
TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
record_writer_ = absl::make_unique<io::RecordWriter>(
dest_.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
/*compression_type=*/compression_type_));
return Status::OK();
}
Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
for (const auto& tensor : tensors) {
TensorProto proto;
tensor.AsProtoTensorContent(&proto);
#if defined(PLATFORM_GOOGLE)
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsCord()));
#else // PLATFORM_GOOGLE
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
#endif // PLATFORM_GOOGLE
}
return Status::OK();
}
Status TFRecordWriter::Sync() {
TF_RETURN_IF_ERROR(record_writer_->Flush());
return dest_->Flush();
}
Status TFRecordWriter::Close() {
if (record_writer_ != nullptr) {
TF_RETURN_IF_ERROR(Sync());
TF_RETURN_IF_ERROR(record_writer_->Close());
TF_RETURN_IF_ERROR(dest_->Close());
record_writer_ = nullptr;
dest_ = nullptr;
}
return Status::OK();
}
TFRecordWriter::~TFRecordWriter() {
Status s = Close();
if (!s.ok()) {
LOG(ERROR) << "Failed to close snapshot file " << filename_ << ": " << s;
}
}
CustomWriter::CustomWriter(const std::string& filename,
const std::string& compression_type,
const DataTypeVector& dtypes)
: filename_(filename),
compression_type_(compression_type),
dtypes_(dtypes) {}
Status CustomWriter::Initialize(tensorflow::Env* env) {
TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
#if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
<< "off compression.";
}
#else // IS_SLIM_BUILD
if (compression_type_ == io::compression::kGzip) {
zlib_underlying_dest_.swap(dest_);
io::ZlibCompressionOptions zlib_options;
zlib_options = io::ZlibCompressionOptions::GZIP();
io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options);
TF_CHECK_OK(zlib_output_buffer->Init());
dest_.reset(zlib_output_buffer);
}
#endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes_.size());
for (const auto& dtype : dtypes_) {
if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true);
num_simple_++;
} else {
simple_tensor_mask_.push_back(false);
num_complex_++;
}
}
return Status::OK();
}
Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
if (compression_type_ != io::compression::kSnappy) {
experimental::SnapshotRecord record;
for (const auto& tensor : tensors) {
TensorProto* t = record.add_tensor();
tensor.AsProtoTensorContent(t);
}
#if defined(PLATFORM_GOOGLE)
return WriteRecord(record.SerializeAsCord());
#else // PLATFORM_GOOGLE
return WriteRecord(record.SerializeAsString());
#endif // PLATFORM_GOOGLE
}
if (compression_type_ != io::compression::kSnappy) {
return errors::InvalidArgument("Compression ", compression_type_,
" is not supported.");
}
std::vector<const TensorBuffer*> tensor_buffers;
tensor_buffers.reserve(num_simple_);
std::vector<TensorProto> tensor_protos;
tensor_protos.reserve(num_complex_);
experimental::SnapshotTensorMetadata metadata;
int64 total_size = 0;
for (int i = 0; i < tensors.size(); ++i) {
const Tensor& tensor = tensors[i];
experimental::TensorMetadata* tensor_metadata =
metadata.add_tensor_metadata();
tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
int64 size = 0;
if (simple_tensor_mask_[i]) {
auto tensor_buffer = DMAHelper::buffer(&tensor);
tensor_buffers.push_back(tensor_buffer);
size = tensor_buffer->size();
} else {
TensorProto proto;
tensor.AsProtoTensorContent(&proto);
size = proto.ByteSizeLong();
tensor_protos.push_back(std::move(proto));
}
tensor_metadata->set_tensor_size_bytes(size);
total_size += size;
}
std::vector<char> uncompressed(total_size);
char* position = uncompressed.data();
int buffer_index = 0;
int proto_index = 0;
for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor_metadata = metadata.tensor_metadata(i);
if (simple_tensor_mask_[i]) {
memcpy(position, tensor_buffers[buffer_index]->data(),
tensor_metadata.tensor_size_bytes());
buffer_index++;
} else {
tensor_protos[proto_index].SerializeToArray(
position, tensor_metadata.tensor_size_bytes());
proto_index++;
}
position += tensor_metadata.tensor_size_bytes();
}
DCHECK_EQ(position, uncompressed.data() + total_size);
string output;
if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
return errors::Internal("Failed to compress using snappy.");
}
#if defined(PLATFORM_GOOGLE)
absl::Cord metadata_serialized = metadata.SerializeAsCord();
#else // PLATFORM_GOOGLE
std::string metadata_serialized = metadata.SerializeAsString();
#endif // PLATFORM_GOOGLE
TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
TF_RETURN_IF_ERROR(WriteRecord(output));
return Status::OK();
}
Status CustomWriter::Sync() { return dest_->Sync(); }
Status CustomWriter::Close() {
if (dest_ != nullptr) {
TF_RETURN_IF_ERROR(dest_->Close());
dest_ = nullptr;
}
if (zlib_underlying_dest_ != nullptr) {
TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
zlib_underlying_dest_ = nullptr;
}
return Status::OK();
}
CustomWriter::~CustomWriter() {
Status s = Close();
if (!s.ok()) {
LOG(ERROR) << "Could not finish writing file: " << s;
}
}
Status CustomWriter::WriteRecord(const StringPiece& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#if defined(PLATFORM_GOOGLE)
Status CustomWriter::WriteRecord(const absl::Cord& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#endif // PLATFORM_GOOGLE
Status Reader::Create(Env* env, const std::string& filename,
const string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Reader>* out_reader) {
switch (version) {
// CustomReader is able to read a legacy snapshot file format (v0) though
// custom writer doesn't have the ability to write it any more since it is
// strictly worse than V1.
case 0:
case 1:
*out_reader = absl::make_unique<CustomReader>(filename, compression_type,
version, dtypes);
break;
case 2:
*out_reader =
absl::make_unique<TFRecordReader>(filename, compression_type, dtypes);
break;
default:
return errors::InvalidArgument("Snapshot reader version: ", version,
" is not supported.");
}
return (*out_reader)->Initialize(env);
}
Status Reader::SkipRecords(int64 num_records) {
// TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
for (int i = 0; i < num_records; ++i) {
std::vector<Tensor> unused_tensors;
TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
}
return Status::OK();
}
class Reader::Dataset : public DatasetBase {
public:
explicit Dataset(const std::string& shard_dir, const std::string& compression,
const int64 version, const DataTypeVector& dtypes,
const std::vector<PartialTensorShape>& shapes,
const int64 start_index, DatasetContext::Params params)
: DatasetBase(DatasetContext(std::move(params))),
shard_dir_(shard_dir),
compression_(compression),
version_(version),
dtypes_(dtypes),
shapes_(shapes),
start_index_(start_index) {}
const DataTypeVector& output_dtypes() const override { return dtypes_; }
const std::vector<PartialTensorShape>& output_shapes() const override {
return shapes_;
}
std::string DebugString() const override {
return "snapshot_util::Reader::Dataset";
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** node) const override {
// Not necessary perform any serialization as this dataset is only
// constructed at runtime in C++ and will be reconstructed every time.
return Status::OK();
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(node_name(), prefix)});
}
private:
const std::string shard_dir_;
const std::string compression_;
const int64 version_;
const DataTypeVector dtypes_;
const std::vector<PartialTensorShape> shapes_;
const int64 start_index_;
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params), current_checkpoint_id_(0) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(Reader::Create(
ctx->env(), GetCurrentFilename(), dataset()->compression_,
dataset()->version_, dataset()->dtypes_, &reader_));
bool end_of_sequence;
for (int64 i = 0; i < dataset()->start_index_; ++i) {
// TODO(frankchn): Optimize this to not parse every single element.
std::vector<Tensor> unused;
TF_RETURN_IF_ERROR(GetNextInternal(ctx, &unused, &end_of_sequence));
}
return Status::OK();
}
protected:
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
*end_of_sequence = false;
Status s = reader_->ReadTensors(out_tensors);
if (!errors::IsOutOfRange(s)) {
return s;
}
Status status = AdvanceToNextFile(ctx->env());
if (errors::IsNotFound(status)) {
*end_of_sequence = true;
return Status::OK();
} else {
return status;
}
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
// Not necessary to save any state as this iterator will be reconstructed
// from scratch when the parent snapshot dataset is restored from
// checkpoint.
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
// Not necessary to restore any state as this iterator will be
// reconstructed from scratch when the parent snapshot dataset is restored
// from checkpoint.
return Status::OK();
}
private:
std::unique_ptr<Reader> reader_;
// Stores the id current checkpoint file that we are in the process of
// reading (e.g. if the file is currently 00000001.snapshot, then this will
// be 1).
uint64 current_checkpoint_id_;
std::string GetCurrentFilename() {
return GetCurrentCheckpointFile(dataset()->shard_dir_,
current_checkpoint_id_);
}
Status AdvanceToNextFile(Env* env) {
current_checkpoint_id_++;
TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename()));
return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
dataset()->version_, dataset()->dtypes_, &reader_);
}
};
};
class Reader::NestedDataset : public DatasetBase {
public:
explicit NestedDataset(std::vector<DatasetBase*> datasets,
DatasetContext::Params params)
: DatasetBase(DatasetContext(std::move(params))), datasets_(datasets) {
dtypes_.push_back(DT_VARIANT);
gtl::InlinedVector<int64, 1> element_dim_sizes;
element_dim_sizes.push_back(1);
partial_shapes_.emplace_back(element_dim_sizes);
}
const DataTypeVector& output_dtypes() const override { return dtypes_; }
const std::vector<PartialTensorShape>& output_shapes() const override {
return partial_shapes_;
}
std::string DebugString() const override {
return "snapshot_util::Reader::NestedDataset";
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** node) const override {
// Not necessary perform any serialization as this dataset is only
// constructed at runtime in C++ and will be reconstructed every time.
return Status::OK();
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(node_name(), prefix)});
}
private:
std::vector<DatasetBase*> datasets_;
DataTypeVector dtypes_;
std::vector<PartialTensorShape> partial_shapes_;
class Iterator : public DatasetIterator<NestedDataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<NestedDataset>(params), index_(0) {}
protected:
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
*end_of_sequence = dataset()->datasets_.size() == index_;
if (!*end_of_sequence) {
Tensor tensor(DT_VARIANT, TensorShape({}));
TF_RETURN_IF_ERROR(
StoreDatasetInVariantTensor(dataset()->datasets_[index_], &tensor));
out_tensors->clear();
out_tensors->push_back(std::move(tensor));
index_++;
}
return Status::OK();
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
// Not necessary to save any state as this iterator will be reconstructed
// from scratch when the parent snapshot dataset is restored from
// checkpoint.
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
// Not necessary to restore any state as this iterator will be
// reconstructed from scratch when the parent snapshot dataset is restored
// from checkpoint.
return Status::OK();
}
private:
int64 index_;
};
};
Status Reader::MakeNestedDataset(Env* env,
const std::vector<std::string>& shard_dirs,
const string& compression_type, int version,
const DataTypeVector& dtypes,
const std::vector<PartialTensorShape>& shapes,
const int64 start_index,
DatasetBase** output) {
std::vector<DatasetBase*> datasets;
datasets.reserve(shard_dirs.size());
for (const auto& shard_dir : shard_dirs) {
// TODO(frankchn): The reading pattern could be controlled in a non-round
// robin fashion, so we cannot assume a round-robin manner when restoring.
int64 dataset_start_index = start_index / shard_dirs.size();
if (start_index % shard_dirs.size() > datasets.size()) {
dataset_start_index++;
}
datasets.push_back(
new Dataset(shard_dir, compression_type, version, dtypes, shapes,
dataset_start_index,
DatasetContext::Params({"snapshot_util::Reader::Dataset",
"snapshot_util_reader_Dataset"})));
}
// Rotate the vector such that the first dataset contains the next element
// to be produced.
std::rotate(datasets.begin(),
datasets.begin() + (start_index % shard_dirs.size()),
datasets.end());
*output = new NestedDataset(
datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset",
"snapshot_util_reader_NestedDataset"}));
return Status::OK();
}
TFRecordReader::TFRecordReader(const std::string& filename,
const string& compression_type,
const DataTypeVector& dtypes)
: filename_(filename),
offset_(0),
compression_type_(compression_type),
dtypes_(dtypes) {}
Status TFRecordReader::Initialize(Env* env) {
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
record_reader_ = absl::make_unique<io::RecordReader>(
file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
/*compression_type=*/compression_type_));
return Status::OK();
}
Status TFRecordReader::ReadTensors(std::vector<Tensor>* read_tensors) {
read_tensors->reserve(dtypes_.size());
for (int i = 0; i < dtypes_.size(); ++i) {
tstring record;
TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record));
TensorProto proto;
proto.ParseFromArray(record.data(), record.size());
Tensor tensor;
if (!tensor.FromProto(proto)) {
return errors::DataLoss("Unable to parse tensor from stored proto.");
}
read_tensors->push_back(std::move(tensor));
}
return Status::OK();
}
CustomReader::CustomReader(const std::string& filename,
const string& compression_type, const int version,
const DataTypeVector& dtypes)
: filename_(filename),
compression_type_(compression_type),
version_(version),
dtypes_(dtypes) {}
Status CustomReader::Initialize(Env* env) {
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
#if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
<< "off compression.";
}
#else // IS_SLIM_BUILD
if (compression_type_ == io::compression::kGzip) {
io::ZlibCompressionOptions zlib_options;
zlib_options = io::ZlibCompressionOptions::GZIP();
input_stream_ = absl::make_unique<io::ZlibInputStream>(
input_stream_.release(), zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options, true);
} else if (compression_type_ == io::compression::kSnappy) {
if (version_ == 0) {
input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
} else {
input_stream_ =
absl::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
}
}
#endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes_.size());
for (const auto& dtype : dtypes_) {
if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true);
num_simple_++;
} else {
simple_tensor_mask_.push_back(false);
num_complex_++;
}
}
return Status::OK();
}
Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
profiler::TraceMe activity(
[&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
profiler::TraceMeLevel::kInfo);
if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
return ReadTensorsV0(read_tensors);
}
if (version_ != 1) {
return errors::InvalidArgument("Version: ", version_, " is not supported.");
}
if (compression_type_ != io::compression::kSnappy) {
return errors::InvalidArgument("Compression ", compression_type_,
" is not supported.");
}
experimental::SnapshotTensorMetadata metadata;
tstring metadata_str;
TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
return errors::DataLoss("Could not parse SnapshotTensorMetadata");
}
read_tensors->reserve(metadata.tensor_metadata_size());
std::vector<Tensor> simple_tensors;
simple_tensors.reserve(num_simple_);
std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
tensor_proto_strs.reserve(num_complex_);
TF_RETURN_IF_ERROR(
SnappyUncompress(&metadata, &simple_tensors, &tensor_proto_strs));
int simple_index = 0;
int complex_index = 0;
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
if (simple_tensor_mask_[i]) {
read_tensors->push_back(std::move(simple_tensors[simple_index]));
simple_index++;
} else {
auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
TensorProto tp;
#if defined(PLATFORM_GOOGLE)
absl::string_view tensor_proto_view(tensor_proto_str.get(),
tensor_proto_size);
absl::Cord c = absl::MakeCordFromExternal(
tensor_proto_view, [s = std::move(tensor_proto_str)] {});
if (!tp.ParseFromCord(c)) {
return errors::Internal("Could not parse TensorProto");
}
#else // PLATFORM_GOOGLE
if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
return errors::Internal("Could not parse TensorProto");
}
#endif // PLATFORM_GOOGLE
Tensor t;
if (!t.FromProto(tp)) {
return errors::Internal("Could not parse Tensor");
}
read_tensors->push_back(std::move(t));
complex_index++;
}
}
return Status::OK();
}
Status CustomReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
experimental::SnapshotRecord record;
#if defined(PLATFORM_GOOGLE)
absl::Cord c;
TF_RETURN_IF_ERROR(ReadRecord(&c));
record.ParseFromCord(c);
#else // PLATFORM_GOOGLE
tstring record_bytes;
TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
record.ParseFromArray(record_bytes.data(), record_bytes.size());
#endif // PLATFORM_GOOGLE
read_tensors->reserve(record.tensor_size());
for (int i = 0; i < record.tensor_size(); ++i) {
read_tensors->emplace_back();
if (!read_tensors->back().FromProto(record.tensor(i))) {
return errors::DataLoss("Unable to parse tensor from proto.");
}
}
return Status::OK();
}
Status CustomReader::SnappyUncompress(
const experimental::SnapshotTensorMetadata* metadata,
std::vector<Tensor>* simple_tensors,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
tensor_proto_strs) {
tstring compressed;
TF_RETURN_IF_ERROR(ReadRecord(&compressed));
size_t size;
if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
&size)) {
return errors::Internal("Could not get snappy uncompressed length");
}
int num_tensors = metadata->tensor_metadata_size();
std::vector<struct iovec> iov(num_tensors);
int index = 0;
int64 total_size = 0;
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
const auto& tensor_metadata = metadata->tensor_metadata(i);
if (simple_tensor_mask_[i]) {
TensorShape shape(tensor_metadata.tensor_shape());
Tensor simple_tensor(dtypes_[i], shape);
TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
iov[index].iov_base = buffer->data();
iov[index].iov_len = buffer->size();
simple_tensors->push_back(std::move(simple_tensor));
} else {
auto tensor_proto_str =
absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
iov[index].iov_base = tensor_proto_str.get();
iov[index].iov_len = tensor_metadata.tensor_size_bytes();
tensor_proto_strs->push_back(std::make_pair(
std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
}
total_size += iov[index].iov_len;
index++;
}
if (size != total_size) {
return errors::Internal("Uncompressed size mismatch. Snappy expects ", size,
" whereas the tensor metadata suggests ",
total_size);
}
if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
iov.data(), num_tensors)) {
return errors::Internal("Failed to perform snappy decompression.");
}
return Status::OK();
}
Status CustomReader::ReadRecord(tstring* record) {
tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data());
return input_stream_->ReadNBytes(length, record);
}
#if defined(PLATFORM_GOOGLE)
Status CustomReader::ReadRecord(absl::Cord* record) {
tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data());
if (compression_type_ == io::compression::kNone) {
return input_stream_->ReadNBytes(length, record);
} else {
auto tmp_str = absl::make_unique<tstring>();
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str.get()));
absl::string_view tmp_str_view(*tmp_str);
record->Append(
absl::MakeCordFromExternal(tmp_str_view, [s = std::move(tmp_str)] {}));
return Status::OK();
}
}
#endif
Status WriteMetadataFile(const string& hash_dir,
const experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
std::string tmp_filename =
absl::StrCat(metadata_filename, "-tmp-", random::New64());
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, *metadata));
return Env::Default()->RenameFile(tmp_filename, metadata_filename);
}
Status ReadMetadataFile(const string& hash_dir,
experimental::SnapshotMetadataRecord* metadata,
bool* file_exists) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
Status s = Env::Default()->FileExists(metadata_filename);
*file_exists = s.ok();
if (*file_exists) {
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
} else {
return Status::OK();
}
}
Status DumpDatasetGraph(const std::string& path, uint64 hash,
const GraphDef* graph) {
std::string hash_hex =
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
std::string graph_file =
io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));
LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(path));
return WriteTextProto(Env::Default(), graph_file, *graph);
}
Status DetermineOpState(const std::string& mode_string, bool file_exists,
const experimental::SnapshotMetadataRecord* metadata,
const uint64 pending_snapshot_expiry_seconds,
Mode* mode) {
if (mode_string == kModeRead) {
// In read mode, we should expect a metadata file is written.
if (!file_exists) {
return errors::NotFound("Metadata file does not exist.");
}
LOG(INFO) << "Overriding mode to reader.";
*mode = READER;
return Status::OK();
}
if (mode_string == kModeWrite) {
LOG(INFO) << "Overriding mode to writer.";
*mode = WRITER;
return Status::OK();
}
if (mode_string == kModePassthrough) {
LOG(INFO) << "Overriding mode to passthrough.";
*mode = PASSTHROUGH;
return Status::OK();
}
if (!file_exists) {
*mode = WRITER;
return Status::OK();
}
if (metadata->finalized()) {
// File found, snapshot has been finalized.
*mode = READER;
return Status::OK();
}
if (metadata->creation_timestamp() >=
(static_cast<int64>(EnvTime::NowMicros()) -
pending_snapshot_expiry_seconds * 1000000)) {
// Someone else is already writing and time has not expired.
*mode = PASSTHROUGH;
return Status::OK();
} else {
// Time has expired, we write regardless.
*mode = WRITER;
return Status::OK();
}
}
} // namespace snapshot_util
} // namespace data
} // namespace tensorflow