Update snapshot_util::Writer to own its WritableFiles instead of borrowing them
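
With this change, callers no longer open a WritableFile themselves and pass a
raw pointer to the Writer constructor; they ask the Writer to create and own
the file via a factory. A minimal sketch of the new call pattern (the local
variables filename, compression, dtypes, and tensors are illustrative, not
part of the API):

    std::unique_ptr<snapshot_util::Writer> writer;
    // Writer::Create opens the file with `env` and, if compression is
    // requested, wraps it; the Writer owns the resulting WritableFile(s).
    TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
        env, filename, compression, kCurrentVersion, dtypes, &writer));
    TF_RETURN_IF_ERROR(writer->WriteTensors(tensors));
    // Close() flushes and closes the owned file(s); syncing and closing the
    // underlying WritableFile is no longer the caller's responsibility.
    TF_RETURN_IF_ERROR(writer->Close());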

PiperOrigin-RevId: 306084988
Change-Id: I31e57d8f4d766127799ea55445c3eed401f3e457
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index bbad927..e5614be 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -529,6 +529,7 @@
         "//tensorflow/core/platform:coding",
         "//tensorflow/core/platform:random",
         "//tensorflow/core/profiler/lib:traceme",
+        "@com_google_absl//absl/memory",
     ],
 )
 
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index b752c3a..db9984e 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -965,7 +965,8 @@
               }
               for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
                 ++num_active_threads_;
-                thread_pool_->Schedule([this]() { WriterThread(); });
+                thread_pool_->Schedule(
+                    [this, env = ctx->env()]() { WriterThread(env); });
               }
               first_call_ = false;
             }
@@ -1262,9 +1263,8 @@
 
         Status ProcessOneElement(int64* bytes_written,
                                  string* snapshot_data_filename,
-                                 std::unique_ptr<WritableFile>* file,
                                  std::unique_ptr<snapshot_util::Writer>* writer,
-                                 bool* end_of_processing) {
+                                 bool* end_of_processing, Env* env) {
           profiler::TraceMe activity(
               [&]() {
                 return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
@@ -1296,8 +1296,6 @@
 
           if (cancelled || snapshot_failed) {
             TF_RETURN_IF_ERROR((*writer)->Close());
-            TF_RETURN_IF_ERROR((*file)->Sync());
-            TF_RETURN_IF_ERROR((*file)->Close());
             if (snapshot_failed) {
               return errors::Internal(
                   "SnapshotDataset::SnapshotWriterIterator snapshot failed");
@@ -1312,20 +1310,17 @@
             }
 
             bool should_close;
-            TF_RETURN_IF_ERROR(ShouldCloseFile(*snapshot_data_filename,
-                                               *bytes_written, (*writer).get(),
-                                               (*file).get(), &should_close));
+            TF_RETURN_IF_ERROR(
+                ShouldCloseWriter(*snapshot_data_filename, *bytes_written,
+                                  (*writer).get(), &should_close));
             if (should_close) {
               // If we exceed the shard size, we get a new file and reset.
               TF_RETURN_IF_ERROR((*writer)->Close());
-              TF_RETURN_IF_ERROR((*file)->Sync());
-              TF_RETURN_IF_ERROR((*file)->Close());
               *snapshot_data_filename = GetSnapshotFilename();
-              TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
-                  *snapshot_data_filename, file));
-              *writer = absl::make_unique<snapshot_util::Writer>(
-                  file->get(), dataset()->compression_, kCurrentVersion,
-                  dataset()->output_dtypes());
+
+              TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
+                  env, *snapshot_data_filename, dataset()->compression_,
+                  kCurrentVersion, dataset()->output_dtypes(), writer));
               *bytes_written = 0;
             }
             TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
@@ -1334,8 +1329,6 @@
 
           if (*end_of_processing) {
             TF_RETURN_IF_ERROR((*writer)->Close());
-            TF_RETURN_IF_ERROR((*file)->Sync());
-            TF_RETURN_IF_ERROR((*file)->Close());
             mutex_lock l(mu_);
             if (!written_final_metadata_file_) {
               experimental::SnapshotMetadataRecord metadata;
@@ -1358,7 +1351,7 @@
         }
 
         // Just pulls off elements from the buffer and writes them.
-        void WriterThread() {
+        void WriterThread(Env* env) {
           auto cleanup = gtl::MakeCleanup([this]() {
             mutex_lock l(mu_);
             --num_active_threads_;
@@ -1367,9 +1360,10 @@
 
           int64 bytes_written = 0;
           string snapshot_data_filename = GetSnapshotFilename();
-          std::unique_ptr<WritableFile> file;
-          Status s =
-              Env::Default()->NewAppendableFile(snapshot_data_filename, &file);
+          std::unique_ptr<snapshot_util::Writer> writer;
+          Status s = snapshot_util::Writer::Create(
+              env, snapshot_data_filename, dataset()->compression_,
+              kCurrentVersion, dataset()->output_dtypes(), &writer);
           if (!s.ok()) {
             LOG(ERROR) << "Creating " << snapshot_data_filename
                        << " failed: " << s.ToString();
@@ -1378,16 +1372,12 @@
             cond_var_.notify_all();
             return;
           }
-          std::unique_ptr<snapshot_util::Writer> writer(
-              new snapshot_util::Writer(file.get(), dataset()->compression_,
-                                        kCurrentVersion,
-                                        dataset()->output_dtypes()));
 
           bool end_of_processing = false;
           while (!end_of_processing) {
             Status s =
                 ProcessOneElement(&bytes_written, &snapshot_data_filename,
-                                  &file, &writer, &end_of_processing);
+                                  &writer, &end_of_processing, env);
             if (!s.ok()) {
               LOG(INFO) << "Error while writing snapshot data to disk: "
                         << s.ToString();
@@ -1401,9 +1391,9 @@
           }
         }
 
-        Status ShouldCloseFile(const string& filename, uint64 bytes_written,
-                               snapshot_util::Writer* writer,
-                               WritableFile* file, bool* should_close) {
+        Status ShouldCloseWriter(const string& filename, uint64 bytes_written,
+                                 snapshot_util::Writer* writer,
+                                 bool* should_close) {
           // If the compression ratio has been estimated, use it to decide
           // whether the file should be closed. We avoid estimating the
           // compression ratio repeatedly because it requires syncing the file,
@@ -1425,7 +1415,6 @@
           // Use the actual file size to determine compression ratio.
           // Make sure that all bytes are written out.
           TF_RETURN_IF_ERROR(writer->Sync());
-          TF_RETURN_IF_ERROR(file->Sync());
           uint64 file_size;
           TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
           mutex_lock l(mu_);
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
index 72d2c5c..ba83366 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
@@ -39,29 +40,45 @@
 /* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes;
 /* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes;
 
-Writer::Writer(WritableFile* dest, const string& compression_type, int version,
-               const DataTypeVector& dtypes)
-    : dest_(dest), compression_type_(compression_type), version_(version) {
+Writer::Writer(const std::string& filename, const std::string& compression_type,
+               int version, const DataTypeVector& dtypes)
+    : filename_(filename),
+      compression_type_(compression_type),
+      version_(version),
+      dtypes_(dtypes) {}
+
+Status Writer::Create(Env* env, const std::string& filename,
+                      const std::string& compression_type, int version,
+                      const DataTypeVector& dtypes,
+                      std::unique_ptr<Writer>* out_writer) {
+  *out_writer =
+      absl::WrapUnique(new Writer(filename, compression_type, version, dtypes));
+
+  return (*out_writer)->Initialize(env);
+}
+
+Status Writer::Initialize(tensorflow::Env* env) {
+  TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
 #if defined(IS_SLIM_BUILD)
-  if (compression_type != io::compression::kNone) {
+  if (compression_type_ != io::compression::kNone) {
     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
                << "off compression.";
   }
 #else   // IS_SLIM_BUILD
-  if (compression_type == io::compression::kGzip) {
+  if (compression_type_ == io::compression::kGzip) {
+    zlib_underlying_dest_.swap(dest_);
     io::ZlibCompressionOptions zlib_options;
     zlib_options = io::ZlibCompressionOptions::GZIP();
 
-    io::ZlibOutputBuffer* zlib_output_buffer =
-        new io::ZlibOutputBuffer(dest, zlib_options.input_buffer_size,
-                                 zlib_options.output_buffer_size, zlib_options);
+    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
+        zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
+        zlib_options.output_buffer_size, zlib_options);
     TF_CHECK_OK(zlib_output_buffer->Init());
-    dest_ = zlib_output_buffer;
-    dest_is_owned_ = true;
+    dest_.reset(zlib_output_buffer);
   }
 #endif  // IS_SLIM_BUILD
-  simple_tensor_mask_.reserve(dtypes.size());
-  for (const auto& dtype : dtypes) {
+  simple_tensor_mask_.reserve(dtypes_.size());
+  for (const auto& dtype : dtypes_) {
     if (DataTypeCanUseMemcpy(dtype)) {
       simple_tensor_mask_.push_back(true);
       num_simple_++;
@@ -70,6 +87,8 @@
       num_complex_++;
     }
   }
+
+  return Status::OK();
 }
 
 Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
@@ -156,21 +175,21 @@
 Status Writer::Sync() { return dest_->Sync(); }
 
 Status Writer::Close() {
-  if (dest_is_owned_) {
-    Status s = dest_->Close();
-    delete dest_;
+  if (dest_ != nullptr) {
+    TF_RETURN_IF_ERROR(dest_->Close());
     dest_ = nullptr;
-    return s;
+  }
+  if (zlib_underlying_dest_ != nullptr) {
+    TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
+    zlib_underlying_dest_ = nullptr;
   }
   return Status::OK();
 }
 
 Writer::~Writer() {
-  if (dest_ != nullptr) {
-    Status s = Close();
-    if (!s.ok()) {
-      LOG(ERROR) << "Could not finish writing file: " << s;
-    }
+  Status s = Close();
+  if (!s.ok()) {
+    LOG(ERROR) << "Could not finish writing file: " << s;
   }
 }
 
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h
index e962bb5..e1c6dbe 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.h
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h
@@ -20,6 +20,7 @@
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/io/compression.h"
 #include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/status.h"
 
@@ -56,8 +57,10 @@
   static constexpr const char* const kWriteCord = "WriteCord";
   static constexpr const char* const kSeparator = "::";
 
-  explicit Writer(WritableFile* dest, const string& compression_type,
-                  int version, const DataTypeVector& dtypes);
+  static Status Create(Env* env, const std::string& filename,
+                       const std::string& compression_type, int version,
+                       const DataTypeVector& dtypes,
+                       std::unique_ptr<Writer>* out_writer);
 
   Status WriteTensors(const std::vector<Tensor>& tensors);
 
@@ -68,16 +71,27 @@
   ~Writer();
 
  private:
+  explicit Writer(const std::string& filename,
+                  const std::string& compression_type, int version,
+                  const DataTypeVector& dtypes);
+
+  Status Initialize(tensorflow::Env* env);
+
   Status WriteRecord(const StringPiece& data);
 
 #if defined(PLATFORM_GOOGLE)
   Status WriteRecord(const absl::Cord& data);
 #endif  // PLATFORM_GOOGLE
 
-  WritableFile* dest_;
-  bool dest_is_owned_ = false;
-  const string compression_type_;
+  std::unique_ptr<WritableFile> dest_;
+  const std::string filename_;
+  const std::string compression_type_;
   const int version_;
+  const DataTypeVector dtypes_;
+  // We hold zlib_underlying_dest_ because, when compression is enabled, we
+  // wrap the file in a ZlibOutputBuffer and store that in dest_. The
+  // ZlibOutputBuffer does not own the underlying file, so we keep the
+  // original WritableFile alive here.
+  std::unique_ptr<WritableFile> zlib_underlying_dest_;
   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
   int num_simple_ = 0;
   int num_complex_ = 0;