blob: ba8336653f4de6be4876ba20b7401faedd54027a [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"
namespace tensorflow {
namespace data {
namespace snapshot_util {
/* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes;
/* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes;
Writer::Writer(const std::string& filename, const std::string& compression_type,
int version, const DataTypeVector& dtypes)
: filename_(filename),
compression_type_(compression_type),
version_(version),
dtypes_(dtypes) {}
Status Writer::Create(Env* env, const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Writer>* out_writer) {
*out_writer =
absl::WrapUnique(new Writer(filename, compression_type, version, dtypes));
return (*out_writer)->Initialize(env);
}
Status Writer::Initialize(tensorflow::Env* env) {
TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
#if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
<< "off compression.";
}
#else // IS_SLIM_BUILD
if (compression_type_ == io::compression::kGzip) {
zlib_underlying_dest_.swap(dest_);
io::ZlibCompressionOptions zlib_options;
zlib_options = io::ZlibCompressionOptions::GZIP();
io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options);
TF_CHECK_OK(zlib_output_buffer->Init());
dest_.reset(zlib_output_buffer);
}
#endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes_.size());
for (const auto& dtype : dtypes_) {
if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true);
num_simple_++;
} else {
simple_tensor_mask_.push_back(false);
num_complex_++;
}
}
return Status::OK();
}
Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
if (compression_type_ != io::compression::kSnappy) {
experimental::SnapshotRecord record;
for (const auto& tensor : tensors) {
TensorProto* t = record.add_tensor();
tensor.AsProtoTensorContent(t);
}
#if defined(PLATFORM_GOOGLE)
return WriteRecord(record.SerializeAsCord());
#else // PLATFORM_GOOGLE
return WriteRecord(record.SerializeAsString());
#endif // PLATFORM_GOOGLE
}
if (version_ != 1) {
return errors::InvalidArgument("Version: ", version_, " is not supported.");
}
if (compression_type_ != io::compression::kSnappy) {
return errors::InvalidArgument(
"Version 1 is only compatible with snappy compression");
}
std::vector<const TensorBuffer*> tensor_buffers;
tensor_buffers.reserve(num_simple_);
std::vector<TensorProto> tensor_protos;
tensor_protos.reserve(num_complex_);
experimental::SnapshotTensorMetadata metadata;
int64 total_size = 0;
for (int i = 0; i < tensors.size(); ++i) {
const Tensor& tensor = tensors[i];
experimental::TensorMetadata* tensor_metadata =
metadata.add_tensor_metadata();
tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
int64 size = 0;
if (simple_tensor_mask_[i]) {
auto tensor_buffer = DMAHelper::buffer(&tensor);
tensor_buffers.push_back(tensor_buffer);
size = tensor_buffer->size();
} else {
TensorProto proto;
tensor.AsProtoTensorContent(&proto);
size = proto.ByteSizeLong();
tensor_protos.push_back(std::move(proto));
}
tensor_metadata->set_tensor_size_bytes(size);
total_size += size;
}
std::vector<char> uncompressed(total_size);
char* position = uncompressed.data();
int buffer_index = 0;
int proto_index = 0;
for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor_metadata = metadata.tensor_metadata(i);
if (simple_tensor_mask_[i]) {
memcpy(position, tensor_buffers[buffer_index]->data(),
tensor_metadata.tensor_size_bytes());
buffer_index++;
} else {
tensor_protos[proto_index].SerializeToArray(
position, tensor_metadata.tensor_size_bytes());
proto_index++;
}
position += tensor_metadata.tensor_size_bytes();
}
DCHECK_EQ(position, uncompressed.data() + total_size);
string output;
if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
return errors::Internal("Failed to compress using snappy.");
}
#if defined(PLATFORM_GOOGLE)
absl::Cord metadata_serialized = metadata.SerializeAsCord();
#else // PLATFORM_GOOGLE
std::string metadata_serialized = metadata.SerializeAsString();
#endif // PLATFORM_GOOGLE
TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
TF_RETURN_IF_ERROR(WriteRecord(output));
return Status::OK();
}
Status Writer::Sync() { return dest_->Sync(); }
Status Writer::Close() {
if (dest_ != nullptr) {
TF_RETURN_IF_ERROR(dest_->Close());
dest_ = nullptr;
}
if (zlib_underlying_dest_ != nullptr) {
TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
zlib_underlying_dest_ = nullptr;
}
return Status::OK();
}
Writer::~Writer() {
Status s = Close();
if (!s.ok()) {
LOG(ERROR) << "Could not finish writing file: " << s;
}
}
Status Writer::WriteRecord(const StringPiece& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#if defined(PLATFORM_GOOGLE)
Status Writer::WriteRecord(const absl::Cord& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#endif // PLATFORM_GOOGLE
Reader::Reader(RandomAccessFile* file, const string& compression_type,
int version, const DataTypeVector& dtypes)
: file_(file),
input_stream_(new io::RandomAccessInputStream(file)),
compression_type_(compression_type),
version_(version),
dtypes_(dtypes) {
#if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
<< "off compression.";
}
#else // IS_SLIM_BUILD
if (compression_type_ == io::compression::kGzip) {
io::ZlibCompressionOptions zlib_options;
zlib_options = io::ZlibCompressionOptions::GZIP();
input_stream_ = absl::make_unique<io::ZlibInputStream>(
input_stream_.release(), zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options, true);
} else if (compression_type_ == io::compression::kSnappy) {
if (version_ == 0) {
input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
} else {
input_stream_ =
absl::make_unique<io::BufferedInputStream>(file_, 64 << 20);
}
}
#endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes.size());
for (const auto& dtype : dtypes) {
if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true);
num_simple_++;
} else {
simple_tensor_mask_.push_back(false);
num_complex_++;
}
}
}
Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
profiler::TraceMe activity(
[&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
profiler::TraceMeLevel::kInfo);
if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
return ReadTensorsV0(read_tensors);
}
if (version_ != 1) {
return errors::InvalidArgument("Version: ", version_, " is not supported.");
}
if (compression_type_ != io::compression::kSnappy) {
return errors::InvalidArgument("Version 1 only supports snappy.");
}
experimental::SnapshotTensorMetadata metadata;
tstring metadata_str;
TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
return errors::DataLoss("Could not parse SnapshotTensorMetadata");
}
read_tensors->reserve(metadata.tensor_metadata_size());
std::vector<Tensor> simple_tensors;
simple_tensors.reserve(num_simple_);
std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
tensor_proto_strs.reserve(num_complex_);
TF_RETURN_IF_ERROR(
SnappyUncompress(&metadata, &simple_tensors, &tensor_proto_strs));
int simple_index = 0;
int complex_index = 0;
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
if (simple_tensor_mask_[i]) {
read_tensors->push_back(std::move(simple_tensors[simple_index]));
simple_index++;
} else {
auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
TensorProto tp;
#if defined(PLATFORM_GOOGLE)
auto tensor_proto_ptr = tensor_proto_str.release();
absl::Cord c;
c.AppendExternalMemory(
absl::string_view(tensor_proto_ptr, tensor_proto_size),
tensor_proto_ptr,
[](void* arg) { delete[] static_cast<char*>(arg); });
if (!tp.ParseFromCord(c)) {
return errors::Internal("Could not parse TensorProto");
}
#else // PLATFORM_GOOGLE
if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
return errors::Internal("Could not parse TensorProto");
}
#endif // PLATFORM_GOOGLE
Tensor t;
if (!t.FromProto(tp)) {
return errors::Internal("Could not parse Tensor");
}
read_tensors->push_back(std::move(t));
complex_index++;
}
}
return Status::OK();
}
Status Reader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
experimental::SnapshotRecord record;
#if defined(PLATFORM_GOOGLE)
absl::Cord c;
TF_RETURN_IF_ERROR(ReadRecord(&c));
record.ParseFromCord(c);
#else // PLATFORM_GOOGLE
tstring record_bytes;
TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
record.ParseFromArray(record_bytes.data(), record_bytes.size());
#endif // PLATFORM_GOOGLE
read_tensors->reserve(record.tensor_size());
for (int i = 0; i < record.tensor_size(); ++i) {
read_tensors->emplace_back();
if (!read_tensors->back().FromProto(record.tensor(i))) {
return errors::DataLoss("Unable to parse tensor from proto.");
}
}
return Status::OK();
}
Status Reader::SnappyUncompress(
const experimental::SnapshotTensorMetadata* metadata,
std::vector<Tensor>* simple_tensors,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
tensor_proto_strs) {
tstring compressed;
TF_RETURN_IF_ERROR(ReadRecord(&compressed));
size_t size;
if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
&size)) {
return errors::Internal("Could not get snappy uncompressed length");
}
int num_tensors = metadata->tensor_metadata_size();
std::vector<struct iovec> iov(num_tensors);
int index = 0;
int64 total_size = 0;
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
const auto& tensor_metadata = metadata->tensor_metadata(i);
if (simple_tensor_mask_[i]) {
TensorShape shape(tensor_metadata.tensor_shape());
Tensor simple_tensor(dtypes_[i], shape);
TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
iov[index].iov_base = buffer->data();
iov[index].iov_len = buffer->size();
simple_tensors->push_back(std::move(simple_tensor));
} else {
auto tensor_proto_str =
absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
iov[index].iov_base = tensor_proto_str.get();
iov[index].iov_len = tensor_metadata.tensor_size_bytes();
tensor_proto_strs->push_back(std::make_pair(
std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
}
total_size += iov[index].iov_len;
index++;
}
if (size != total_size) {
return errors::Internal("Uncompressed size mismatch. Snappy expects ", size,
" whereas the tensor metadata suggests ",
total_size);
}
if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
iov.data(), num_tensors)) {
return errors::Internal("Failed to perform snappy decompression.");
}
return Status::OK();
}
Status Reader::ReadRecord(tstring* record) {
tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data());
return input_stream_->ReadNBytes(length, record);
}
#if defined(PLATFORM_GOOGLE)
Status Reader::ReadRecord(absl::Cord* record) {
tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data());
if (compression_type_ == io::compression::kNone) {
return input_stream_->ReadNBytes(length, record);
} else {
auto tmp_str = absl::make_unique<tstring>();
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str.get()));
tstring* tmp_str_raw = tmp_str.release();
record->AppendExternalMemory(*tmp_str_raw, tmp_str_raw,
[](absl::string_view unused_data, void* arg) {
delete static_cast<tstring*>(arg);
});
return Status::OK();
}
}
#endif
Status WriteMetadataFile(const string& hash_dir,
const experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
std::string tmp_filename =
absl::StrCat(metadata_filename, "-tmp-", random::New64());
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, *metadata));
return Env::Default()->RenameFile(tmp_filename, metadata_filename);
}
Status ReadMetadataFile(const string& hash_dir,
experimental::SnapshotMetadataRecord* metadata,
bool* file_exists) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
Status s = Env::Default()->FileExists(metadata_filename);
*file_exists = s.ok();
if (*file_exists) {
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
} else {
return Status::OK();
}
}
Status DumpDatasetGraph(const std::string& path, uint64 hash,
const GraphDef* graph) {
std::string hash_hex =
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
std::string graph_file =
io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));
LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(path));
return WriteTextProto(Env::Default(), graph_file, *graph);
}
Status DetermineOpState(const std::string& mode_string, bool file_exists,
const experimental::SnapshotMetadataRecord* metadata,
const uint64 pending_snapshot_expiry_seconds,
Mode* mode) {
if (mode_string == kModeRead) {
// In read mode, we should expect a metadata file is written.
if (!file_exists) {
return errors::NotFound("Metadata file does not exist.");
}
LOG(INFO) << "Overriding mode to reader.";
*mode = READER;
return Status::OK();
}
if (mode_string == kModeWrite) {
LOG(INFO) << "Overriding mode to writer.";
*mode = WRITER;
return Status::OK();
}
if (mode_string == kModePassthrough) {
LOG(INFO) << "Overriding mode to passthrough.";
*mode = PASSTHROUGH;
return Status::OK();
}
if (!file_exists) {
*mode = WRITER;
return Status::OK();
}
if (metadata->finalized()) {
// File found, snapshot has been finalized.
*mode = READER;
return Status::OK();
}
if (metadata->creation_timestamp() >=
(static_cast<int64>(EnvTime::NowMicros()) -
pending_snapshot_expiry_seconds * 1000000)) {
// Someone else is already writing and time has not expired.
*mode = PASSTHROUGH;
return Status::OK();
} else {
// Time has expired, we write regardless.
*mode = WRITER;
return Status::OK();
}
}
} // namespace snapshot_util
} // namespace data
} // namespace tensorflow