blob: e1c6dbeb67b9d270e065ba5b89110c03dbb46f34 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class GraphDef;
namespace data {
namespace experimental {
class SnapshotMetadataRecord;
class SnapshotTensorMetadata;
} // namespace experimental
namespace snapshot_util {
// Operating mode of a snapshot op. The numeric value of each enumerator is
// spelled out explicitly.
enum Mode {
  READER = 0,
  WRITER = 1,
  PASSTHROUGH = 2,
};

// File name used for snapshot metadata.
// NOTE(review): presumably the file handled by ReadMetadataFile /
// WriteMetadataFile inside a snapshot hash directory -- confirm in the .cc.
constexpr char kMetadataFilename[] = "snapshot.metadata";

// Textual snapshot-mode values. There is no "auto" member in Mode.
// NOTE(review): presumably DetermineOpState() maps these strings onto Mode,
// resolving "auto" from the on-disk state -- confirm in the implementation.
constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";
// Writes tensors as records into a single snapshot file.
//
// Writers are obtained through the static Create() factory, which returns the
// new instance via `out_writer`. The constructor takes every argument except
// the Env; the Env is passed separately to Initialize().
//
// NOTE(review): this view shows no `public:` / `private:` specifiers, no
// closing `};`, and no `#endif` matching the `#if defined(PLATFORM_GOOGLE)`
// below. As shown, every member would default to private and everything after
// the `#if` would be conditional. Confirm against the full file.
class Writer {
// Size in bytes of a record header: one uint64.
// NOTE(review): presumably a length prefix written before each record --
// confirm in the implementation.
static constexpr const size_t kHeaderSize = sizeof(uint64);
// Name fragments. NOTE(review): presumably joined with kSeparator (e.g.
// "SnapshotWriter::WriteCord") to build tracing/profiling labels -- confirm.
static constexpr const char* const kClassName = "SnapshotWriter";
static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
static constexpr const char* const kWriteCord = "WriteCord";
static constexpr const char* const kSeparator = "::";
// Creates a writer for `filename` using `compression_type`, snapshot format
// `version` and element types `dtypes`, and stores it in `*out_writer`.
// Errors are reported through the returned Status.
static Status Create(Env* env, const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Writer>* out_writer);
// Writes `tensors` to the snapshot file.
Status WriteTensors(const std::vector<Tensor>& tensors);
// Syncs written data to the underlying file.
// NOTE(review): presumably delegates to WritableFile::Sync -- confirm.
Status Sync();
// Closes the writer's output.
Status Close();
// Captures the configuration only; the Env is supplied later to Initialize().
explicit Writer(const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes);
// Second construction phase. It takes the Env and returns a Status,
// presumably so that fallible setup (e.g. opening `filename_`) can report
// errors -- confirm.
Status Initialize(tensorflow::Env* env);
// Writes one serialized record.
Status WriteRecord(const StringPiece& data);
#if defined(PLATFORM_GOOGLE)
// absl::Cord overload; only declared when PLATFORM_GOOGLE is defined.
Status WriteRecord(const absl::Cord& data);
// Destination for records. When compression is requested this may be a
// ZlibOutputBuffer wrapping the original file (see zlib_underlying_dest_).
std::unique_ptr<WritableFile> dest_;
// Constructor arguments, fixed for the writer's lifetime.
const std::string filename_;
const std::string compression_type_;
const int version_;
const DataTypeVector dtypes_;
// We hold zlib_underlying_dest_ because we may create a ZlibOutputBuffer and
// put that in dest_ if we want compression. ZlibOutputBuffer doesn't own the
// original dest_ and so we need somewhere to store the original one.
std::unique_ptr<WritableFile> zlib_underlying_dest_;
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
// NOTE(review): presumably the number of true / false entries in
// simple_tensor_mask_ -- confirm in the implementation.
int num_simple_ = 0;
int num_complex_ = 0;
// Reads tensors back from a snapshot file accessed through a RandomAccessFile.
//
// NOTE(review): this view shows no `public:` / `private:` specifiers, no
// closing `};`, and no `#endif` matching the `#if defined(PLATFORM_GOOGLE)`
// below. The SnappyUncompress declaration is also truncated (see below).
// Confirm against the full file.
// NOTE(review): this class spells `string` where Writer uses `std::string`;
// both presumably name the same type (tensorflow::string).
class Reader {
// The reader input buffer size is deliberately large because the input reader
// will throw an error if the compressed block length cannot fit in the input
// buffer.
static constexpr const int64 kSnappyReaderInputBufferSizeBytes =
1 << 30; // 1 GiB
// TODO(b/148804377): Set this in a smarter fashion.
static constexpr const int64 kSnappyReaderOutputBufferSizeBytes =
32 << 20; // 32 MiB
// Size in bytes of a record header: one uint64.
static constexpr const size_t kHeaderSize = sizeof(uint64);
// Name fragments. NOTE(review): presumably joined with kSeparator (e.g.
// "SnapshotReader::ReadCord") for tracing/profiling labels -- confirm.
static constexpr const char* const kClassName = "SnapshotReader";
static constexpr const char* const kReadString = "ReadString";
static constexpr const char* const kReadCord = "ReadCord";
static constexpr const char* const kSeparator = "::";
// `file` is held as a raw pointer in file_.
// NOTE(review): presumably not owned, so the caller must keep it alive for
// the Reader's lifetime -- confirm.
explicit Reader(RandomAccessFile* file, const string& compression_type,
int version, const DataTypeVector& dtypes);
// Reads the next element's tensors into `*read_tensors`.
Status ReadTensors(std::vector<Tensor>* read_tensors);
// NOTE(review): presumably the code path for snapshot format version 0
// (see version_) -- confirm.
Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
// NOTE(review): presumably decompresses a snappy-compressed element described
// by `metadata` into `simple_tensors` plus raw buffers for complex tensors --
// confirm. The declaration below ends at the pointer type without a parameter
// name or closing `);` in this view; it appears truncated.
Status SnappyUncompress(
const experimental::SnapshotTensorMetadata* metadata,
std::vector<Tensor>* simple_tensors,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
// Reads one raw record into `*record`.
Status ReadRecord(tstring* record);
#if defined(PLATFORM_GOOGLE)
// absl::Cord overload; only declared when PLATFORM_GOOGLE is defined.
Status ReadRecord(absl::Cord* record);
RandomAccessFile* file_;
// NOTE(review): presumably the (possibly decompressing) stream wrapped
// around file_ -- confirm.
std::unique_ptr<io::InputStreamInterface> input_stream_;
// Constructor arguments, fixed for the reader's lifetime.
const string compression_type_;
const int version_;
const DataTypeVector dtypes_;
// NOTE(review): presumably the number of true / false entries in
// simple_tensor_mask_ -- confirm in the implementation.
int num_simple_ = 0;
int num_complex_ = 0;
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
// Free helpers for snapshot metadata and mode selection. They are only
// declared here; the implementations are not visible in this header.

// Writes `metadata` as the snapshot metadata file inside directory `hash_dir`.
// NOTE(review): presumably the file is named kMetadataFilename -- confirm.
Status WriteMetadataFile(const std::string& hash_dir,
                         const experimental::SnapshotMetadataRecord* metadata);

// Reads the snapshot metadata file inside `hash_dir` into `*metadata` and
// sets `*file_exists` (presumably whether the file was present).
// NOTE(review): whether a missing file yields an error Status or OK with
// *file_exists == false is not visible here -- confirm.
Status ReadMetadataFile(const std::string& hash_dir,
                        experimental::SnapshotMetadataRecord* metadata,
                        bool* file_exists);

// Writes the dataset `graph` identified by fingerprint `hash` under `path`.
// NOTE(review): exact file naming is not visible here -- confirm in the .cc.
Status DumpDatasetGraph(const std::string& path, uint64 hash,
                        const GraphDef* graph);

// Chooses the snapshot Mode and stores it in `*mode`. The inputs are:
// the requested `mode_string` (presumably one of kModeAuto / kModeWrite /
// kModeRead / kModePassthrough), whether the metadata file exists, the
// existing `metadata`, and `pending_snapshot_expiry_seconds`.
// NOTE(review): presumably "auto" resolves to READER / WRITER / PASSTHROUGH
// from this on-disk state -- confirm.
//
// The by-value parameter carries no top-level `const` here: in a declaration
// it has no effect on the function type, so the definition may still use
// `const uint64`.
Status DetermineOpState(const std::string& mode_string, bool file_exists,
                        const experimental::SnapshotMetadataRecord* metadata,
                        uint64 pending_snapshot_expiry_seconds, Mode* mode);
} // namespace snapshot_util
} // namespace data
} // namespace tensorflow