Add `write_record_metadata` to PyTorchFileWriter (#125184)

Add `PyTorchFileWriter.write_record_metadata(record_name, num_bytes)` that
- writes the zipfile header/end of central directory metadata for an entry*
- reserves `num_bytes` in the zipfile for the payload.

*Since the payload is not provided, the CRC32 computation is skipped and 0s are written in the corresponding entry of the zipfile header

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125184
Approved by: https://github.com/albanD
diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc
index 533fd42..173153e 100644
--- a/caffe2/serialize/inline_container.cc
+++ b/caffe2/serialize/inline_container.cc
@@ -612,15 +612,35 @@
   return ret;
 }
 
+// This func will not update combined_uncomp_crc32_ with the uncomp_crc32
+// since there is no way to get the uncomp_crc32 when no buffer is provided.
+size_t ostream_seek_func(
+  void* pOpaque,
+  mz_uint64 file_ofs,
+  size_t n) {
+  auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
+  if (self->current_pos_ != file_ofs) {
+    CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
+  }
+  size_t ret = self->seek_func_(n);
+  if (self->current_pos_ + n != ret) {
+    self->err_seen_ = true;
+  }
+  self->current_pos_ += n;
+  return n;
+}
+
 PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
     : archive_name_(basename(file_name)) {
   setup(file_name);
 }
 
 PyTorchStreamWriter::PyTorchStreamWriter(
-    const std::function<size_t(const void*, size_t)> writer_func)
+    const std::function<size_t(const void*, size_t)> writer_func,
+    const std::function<size_t(size_t)> seek_func)
     : archive_name_("archive"),
-      writer_func_(writer_func) {
+      writer_func_(writer_func),
+      seek_func_(seek_func) {
   setup(archive_name_);
 }
 
@@ -649,10 +669,15 @@
       file_stream_.write(static_cast<const char*>(buf), nbytes);
       return !file_stream_ ? 0 : nbytes;
     };
+    seek_func_ = [this](size_t nbytes) -> size_t {
+      file_stream_.seekp(nbytes, std::ios_base::cur);
+      return file_stream_.tellp();
+    };
   }
 
   ar_->m_pIO_opaque = this;
   ar_->m_pWrite = ostream_write_func;
+  ar_->m_pSeek = ostream_seek_func;
 
   mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
   valid("initializing archive ", file_name.c_str());
@@ -682,20 +707,20 @@
       detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
   uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
   mz_zip_writer_add_mem_ex_v2(
-      ar_.get(),
-      full_name.c_str(),
-      data,
-      size,
-      nullptr,
-      0,
-      flags,
-      0,
-      0,
-      nullptr,
-      padding_.c_str(),
-      padding_size,
-      nullptr,
-      0);
+      /*pZip=*/ar_.get(),
+      /*pArchive_name=*/full_name.c_str(),
+      /*pBuf=*/data,
+      /*buf_size=*/size,
+      /*pComment=*/nullptr,
+      /*comment_size=*/0,
+      /*level_and_flags=*/flags,
+      /*uncomp_size=*/0,
+      /*uncomp_crc32=*/0,
+      /*last_modified=*/nullptr,
+      /*user_extra_data=*/padding_.c_str(),
+      /*user_extra_data_len=*/padding_size,
+      /*user_extra_data_central=*/nullptr,
+      /*user_extra_data_central_len=*/0);
   valid("writing file ", name.c_str());
   files_written_.insert(name);
 }
diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h
index 6a13d41..6dea54f 100644
--- a/caffe2/serialize/inline_container.h
+++ b/caffe2/serialize/inline_container.h
@@ -203,11 +203,21 @@
   size_t additional_reader_size_threshold_;
 };
 
+namespace {
+
+size_t default_seek_func(size_t nbytes) {
+  TORCH_CHECK(false, "attempting to write record metadata but seek_func unimplemented, please implement seek_func");
+  return 0;
+}
+
+} // namespace
+
 class TORCH_API PyTorchStreamWriter final {
  public:
   explicit PyTorchStreamWriter(const std::string& archive_name);
   explicit PyTorchStreamWriter(
-      const std::function<size_t(const void*, size_t)> writer_func);
+      const std::function<size_t(const void*, size_t)> writer_func,
+      const std::function<size_t(size_t)> seek_func = default_seek_func);
 
   void setMinVersion(const uint64_t version);
 
@@ -246,6 +256,7 @@
   std::string padding_;
   std::ofstream file_stream_;
   std::function<size_t(const void*, size_t)> writer_func_;
+  std::function<size_t(size_t)> seek_func_;
   uint64_t combined_uncomp_crc32_ = 0;
   std::string serialization_id_;
 
@@ -259,6 +270,10 @@
       uint64_t file_ofs,
       const void* pBuf,
       size_t n);
+  friend size_t ostream_seek_func(
+      void* pOpaque,
+      uint64_t file_ofs,
+      size_t n);
 };
 
 namespace detail {
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 2f7e6ba..e3e7b8c 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -4000,6 +4000,50 @@
             y['even'][0] = torch.tensor(-0.25, dtype=dtype)
             self.assertEqual(y['x'][:2].to(dtype=torch.float32), torch.tensor([-0.25, 0.25]))
 
+    @parametrize('filename', (True, False))
+    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
+    def test_filewriter_metadata_writing(self, filename):
+        sd = torch.nn.Linear(3, 5).state_dict()
+        weight_nbytes = sd['weight'].untyped_storage().nbytes()
+        bias_nbytes = sd['bias'].untyped_storage().nbytes()
+        # TemporaryFileName will give a string
+        # NamedTemporaryFile will be treated as a buffer
+        file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile
+
+        with file_creation_func() as f, file_creation_func() as g:
+            # save state_dict in f
+            torch.save(sd, f)
+            if not filename:
+                f.seek(0)
+            # extract 'data.pkl' for use in our fake checkpoint
+            with torch.serialization._open_file_like(f, 'rb') as opened_file:
+                with torch.serialization._open_zipfile_reader(opened_file) as zip_file:
+                    data_file = io.BytesIO(zip_file.get_record('data.pkl'))
+                    data_0_offset = zip_file.get_record_offset('data/0')
+                    data_1_offset = zip_file.get_record_offset('data/1')
+
+            # write nulls for 'data/0' and 'data/1'
+            with open(f if filename else f.name, 'rb+') as opened_f:
+                opened_f.seek(data_0_offset)
+                opened_f.write(b'0' * weight_nbytes)
+                opened_f.seek(data_1_offset)
+                opened_f.write(b'0' * bias_nbytes)
+
+            with torch.serialization._open_zipfile_writer(g) as zip_file:
+                data_value = data_file.getvalue()
+                zip_file.write_record('data.pkl', data_value, len(data_value))
+                zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
+                # Only write metadata for storages
+                zip_file.write_record_metadata('data/0', weight_nbytes)
+                zip_file.write_record_metadata('data/1', bias_nbytes)
+
+            if not filename:
+                f.seek(0)
+                g.seek(0)
+            sd_loaded = torch.load(g)
+            sd_loaded_ref = torch.load(f)
+            self.assertEqual(sd_loaded, sd_loaded_ref)
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
             return super().run(*args, **kwargs)
diff --git a/third_party/miniz-2.1.0/miniz.c b/third_party/miniz-2.1.0/miniz.c
index 4b5d53f..7d526cf 100755
--- a/third_party/miniz-2.1.0/miniz.c
+++ b/third_party/miniz-2.1.0/miniz.c
@@ -6250,6 +6250,7 @@
     mz_uint32 extra_size = 0;
     mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE];
     mz_uint16 bit_flags = 0;
+    mz_bool write_metadata_only = buf_size && !pBuf;
 
     if ((int)level_and_flags < 0)
         level_and_flags = MZ_DEFAULT_LEVEL;
@@ -6263,7 +6264,7 @@
     level = level_and_flags & 0xF;
     store_data_uncompressed = ((!level) || (level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA));
 
-    if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || ((buf_size) && (!pBuf)) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
+    if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
         return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER);
 
     pState = pZip->m_pState;
@@ -6308,7 +6309,9 @@
 
 	if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))
 	{
-		uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
+        if (!write_metadata_only) {
+            uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
+        }
 		uncomp_size = buf_size;
 		if (uncomp_size <= 3)
 		{
@@ -6330,8 +6333,8 @@
     if (!pState->m_zip64)
     {
         /* Bail early if the archive would obviously become too large */
-        if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size 
-			+ MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len + 
+        if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size
+			+ MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len +
 			pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + user_extra_data_central_len
 			+ MZ_ZIP_DATA_DESCRIPTER_SIZE32) > 0xFFFFFFFF)
         {
@@ -6442,7 +6445,14 @@
 
     if (store_data_uncompressed)
     {
-        if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size)
+        mz_bool write_failed;
+        if (write_metadata_only) {
+            write_failed = pZip->m_pSeek(pZip->m_pIO_opaque, cur_archive_file_ofs, buf_size) != buf_size;
+        } else {
+            write_failed = pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size;
+        }
+
+        if (write_failed)
         {
             pZip->m_pFree(pZip->m_pAlloc_opaque, pComp);
             return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED);
diff --git a/third_party/miniz-2.1.0/miniz.h b/third_party/miniz-2.1.0/miniz.h
index 2cad137..cb7eb9d 100755
--- a/third_party/miniz-2.1.0/miniz.h
+++ b/third_party/miniz-2.1.0/miniz.h
@@ -116,7 +116,7 @@
 
 
 
-/* Defines to completely disable specific portions of miniz.c: 
+/* Defines to completely disable specific portions of miniz.c:
    If all macros here are defined the only functionality remaining will be CRC-32, adler-32, tinfl, and tdefl. */
 
 /* Define MINIZ_NO_STDIO to disable all usage and any functions which rely on stdio for file I/O. */
@@ -139,7 +139,7 @@
 /* Define MINIZ_NO_ZLIB_COMPATIBLE_NAME to disable zlib names, to prevent conflicts against stock zlib. */
 #define MINIZ_NO_ZLIB_COMPATIBLE_NAMES
 
-/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc. 
+/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc.
    Note if MINIZ_NO_MALLOC is defined then the user must always provide custom user alloc/free/realloc
    callbacks to the zlib and archive API's, and a few stand-alone helper API's which don't provide custom user
    functions (such as tdefl_compress_mem_to_heap() and tinfl_decompress_mem_to_heap()) won't work. */
@@ -980,6 +980,7 @@
 
 typedef size_t (*mz_file_read_func)(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n);
 typedef size_t (*mz_file_write_func)(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n);
+typedef size_t (*mz_file_seek_func)(void *pOpaque, mz_uint64 file_ofs, size_t n);
 typedef mz_bool (*mz_file_needs_keepalive)(void *pOpaque);
 
 struct mz_zip_internal_state_tag;
@@ -1071,6 +1072,7 @@
 
     mz_file_read_func m_pRead;
     mz_file_write_func m_pWrite;
+    mz_file_seek_func m_pSeek;
     mz_file_needs_keepalive m_pNeeds_keepalive;
     void *m_pIO_opaque;
 
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index a5e3c60..8b3e606 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -1394,10 +1394,22 @@
           buffer.attr("write")(std::move(memory_view));
           return size;
         };
-        return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
+        auto seek_func = [=](size_t offset) {
+          auto current_pos = py::cast<size_t>(buffer.attr("tell")());
+          buffer.attr("seek")(
+              offset, py::module::import("os").attr("SEEK_CUR"));
+          return current_pos + offset;
+        };
+        return std::make_unique<PyTorchStreamWriter>(
+            std::move(writer_func), std::move(seek_func));
       }))
       .def(py::init<const std::function<size_t(const void*, size_t)>&>())
       .def(
+          "write_record_metadata",
+          [](PyTorchStreamWriter& self, const std::string& name, size_t size) {
+            return self.writeRecord(name, nullptr, size);
+          })
+      .def(
           "write_record",
           [](PyTorchStreamWriter& self,
              const std::string& name,