Add `write_record_metadata` to PyTorchFileWriter (#125184)
Add `PyTorchFileWriter.write_record_metadata(record_name, num_bytes)` that
- writes the zipfile header/end of central directory metadata for an entry*
- reserves `num_bytes` in the zipfile for the payload.
*Since the payload is not provided, the CRC32 computation is skipped and 0 is written in the CRC-32 field of the corresponding zipfile header entry.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125184
Approved by: https://github.com/albanD
diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc
index 533fd42..173153e 100644
--- a/caffe2/serialize/inline_container.cc
+++ b/caffe2/serialize/inline_container.cc
@@ -612,15 +612,35 @@
return ret;
}
+// This func will not update combined_uncomp_crc32_ with the uncomp_crc32
+// since there is no way to get the uncomp_crc32 when no buffer is provided.
+size_t ostream_seek_func(
+ void* pOpaque,
+ mz_uint64 file_ofs,
+ size_t n) {
+ auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
+ if (self->current_pos_ != file_ofs) {
+ CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
+ }
+ size_t ret = self->seek_func_(n);
+ if (self->current_pos_ + n != ret) {
+ self->err_seen_ = true;
+ }
+ self->current_pos_ += n;
+ return n;
+}
+
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
: archive_name_(basename(file_name)) {
setup(file_name);
}
PyTorchStreamWriter::PyTorchStreamWriter(
- const std::function<size_t(const void*, size_t)> writer_func)
+ const std::function<size_t(const void*, size_t)> writer_func,
+ const std::function<size_t(size_t)> seek_func)
: archive_name_("archive"),
- writer_func_(writer_func) {
+ writer_func_(writer_func),
+ seek_func_(seek_func) {
setup(archive_name_);
}
@@ -649,10 +669,15 @@
file_stream_.write(static_cast<const char*>(buf), nbytes);
return !file_stream_ ? 0 : nbytes;
};
+ seek_func_ = [this](size_t nbytes) -> size_t {
+ file_stream_.seekp(nbytes, std::ios_base::cur);
+ return file_stream_.tellp();
+ };
}
ar_->m_pIO_opaque = this;
ar_->m_pWrite = ostream_write_func;
+ ar_->m_pSeek = ostream_seek_func;
mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
valid("initializing archive ", file_name.c_str());
@@ -682,20 +707,20 @@
detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
mz_zip_writer_add_mem_ex_v2(
- ar_.get(),
- full_name.c_str(),
- data,
- size,
- nullptr,
- 0,
- flags,
- 0,
- 0,
- nullptr,
- padding_.c_str(),
- padding_size,
- nullptr,
- 0);
+ /*pZip=*/ar_.get(),
+ /*pArchive_name=*/full_name.c_str(),
+ /*pBuf=*/data,
+ /*buf_size=*/size,
+ /*pComment=*/nullptr,
+ /*comment_size=*/0,
+ /*level_and_flags=*/flags,
+ /*uncomp_size=*/0,
+ /*uncomp_crc32=*/0,
+ /*last_modified=*/nullptr,
+ /*user_extra_data=*/padding_.c_str(),
+ /*user_extra_data_len=*/padding_size,
+ /*user_extra_data_central=*/nullptr,
+ /*user_extra_data_central_len=*/0);
valid("writing file ", name.c_str());
files_written_.insert(name);
}
diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h
index 6a13d41..6dea54f 100644
--- a/caffe2/serialize/inline_container.h
+++ b/caffe2/serialize/inline_container.h
@@ -203,11 +203,21 @@
size_t additional_reader_size_threshold_;
};
+// Default seek_func: fails loudly if a writer constructed without a seek
+// callback tries to reserve space via write_record_metadata. Declared
+// `inline` rather than placed in an anonymous namespace — an anonymous
+// namespace in a header gives every including TU its own internal copy.
+inline size_t default_seek_func(size_t nbytes) {
+  TORCH_CHECK(false, "attempting to write record metadata but seek_func unimplemented, please implement seek_func");
+  return 0;
+}
+
class TORCH_API PyTorchStreamWriter final {
public:
explicit PyTorchStreamWriter(const std::string& archive_name);
explicit PyTorchStreamWriter(
- const std::function<size_t(const void*, size_t)> writer_func);
+ const std::function<size_t(const void*, size_t)> writer_func,
+ const std::function<size_t(size_t)> seek_func = default_seek_func);
void setMinVersion(const uint64_t version);
@@ -246,6 +256,7 @@
std::string padding_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
+ std::function<size_t(size_t)> seek_func_;
uint64_t combined_uncomp_crc32_ = 0;
std::string serialization_id_;
@@ -259,6 +270,10 @@
uint64_t file_ofs,
const void* pBuf,
size_t n);
+ friend size_t ostream_seek_func(
+ void* pOpaque,
+ uint64_t file_ofs,
+ size_t n);
};
namespace detail {
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 2f7e6ba..e3e7b8c 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -4000,6 +4000,50 @@
y['even'][0] = torch.tensor(-0.25, dtype=dtype)
self.assertEqual(y['x'][:2].to(dtype=torch.float32), torch.tensor([-0.25, 0.25]))
+    @parametrize('filename', (True, False))
+    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
+    def test_filewriter_metadata_writing(self, filename):
+        sd = torch.nn.Linear(3, 5).state_dict()
+        weight_nbytes = sd['weight'].untyped_storage().nbytes()
+        bias_nbytes = sd['bias'].untyped_storage().nbytes()
+        # TemporaryFileName will give a string
+        # NamedTemporaryFile will be treated as a buffer
+        file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile
+
+        with file_creation_func() as f, file_creation_func() as g:
+            # save state_dict in f
+            torch.save(sd, f)
+            if not filename:
+                f.seek(0)
+            # extract 'data.pkl' for use in our fake checkpoint
+            with torch.serialization._open_file_like(f, 'rb') as opened_file:
+                with torch.serialization._open_zipfile_reader(opened_file) as zip_file:
+                    data_file = io.BytesIO(zip_file.get_record('data.pkl'))
+                    data_0_offset = zip_file.get_record_offset('data/0')
+                    data_1_offset = zip_file.get_record_offset('data/1')
+
+            # write nulls for 'data/0' and 'data/1' to match g's zero-filled reserved regions
+            with open(f if filename else f.name, 'rb+') as opened_f:
+                opened_f.seek(data_0_offset)
+                opened_f.write(b'\0' * weight_nbytes)
+                opened_f.seek(data_1_offset)
+                opened_f.write(b'\0' * bias_nbytes)
+
+            with torch.serialization._open_zipfile_writer(g) as zip_file:
+                data_value = data_file.getvalue()
+                zip_file.write_record('data.pkl', data_value, len(data_value))
+                zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
+                # Only write metadata for storages
+                zip_file.write_record_metadata('data/0', weight_nbytes)
+                zip_file.write_record_metadata('data/1', bias_nbytes)
+
+            if not filename:
+                f.seek(0)
+                g.seek(0)
+            sd_loaded = torch.load(g)
+            sd_loaded_ref = torch.load(f)
+            self.assertEqual(sd_loaded, sd_loaded_ref)
+
def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)
diff --git a/third_party/miniz-2.1.0/miniz.c b/third_party/miniz-2.1.0/miniz.c
index 4b5d53f..7d526cf 100755
--- a/third_party/miniz-2.1.0/miniz.c
+++ b/third_party/miniz-2.1.0/miniz.c
@@ -6250,6 +6250,7 @@
mz_uint32 extra_size = 0;
mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE];
mz_uint16 bit_flags = 0;
+ mz_bool write_metadata_only = buf_size && !pBuf;
if ((int)level_and_flags < 0)
level_and_flags = MZ_DEFAULT_LEVEL;
@@ -6263,7 +6264,7 @@
level = level_and_flags & 0xF;
store_data_uncompressed = ((!level) || (level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA));
- if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || ((buf_size) && (!pBuf)) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
+ if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER);
pState = pZip->m_pState;
@@ -6308,7 +6309,9 @@
if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))
{
- uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
+ if (!write_metadata_only) {
+ uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
+ }
uncomp_size = buf_size;
if (uncomp_size <= 3)
{
@@ -6330,8 +6333,8 @@
if (!pState->m_zip64)
{
/* Bail early if the archive would obviously become too large */
- if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size
- + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len +
+ if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size
+ + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len +
pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + user_extra_data_central_len
+ MZ_ZIP_DATA_DESCRIPTER_SIZE32) > 0xFFFFFFFF)
{
@@ -6442,7 +6445,14 @@
if (store_data_uncompressed)
{
- if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size)
+ mz_bool write_failed;
+ if (write_metadata_only) {
+ write_failed = pZip->m_pSeek(pZip->m_pIO_opaque, cur_archive_file_ofs, buf_size) != buf_size;
+ } else {
+ write_failed = pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size;
+ }
+
+ if (write_failed)
{
pZip->m_pFree(pZip->m_pAlloc_opaque, pComp);
return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED);
diff --git a/third_party/miniz-2.1.0/miniz.h b/third_party/miniz-2.1.0/miniz.h
index 2cad137..cb7eb9d 100755
--- a/third_party/miniz-2.1.0/miniz.h
+++ b/third_party/miniz-2.1.0/miniz.h
@@ -116,7 +116,7 @@
-/* Defines to completely disable specific portions of miniz.c:
+/* Defines to completely disable specific portions of miniz.c:
If all macros here are defined the only functionality remaining will be CRC-32, adler-32, tinfl, and tdefl. */
/* Define MINIZ_NO_STDIO to disable all usage and any functions which rely on stdio for file I/O. */
@@ -139,7 +139,7 @@
/* Define MINIZ_NO_ZLIB_COMPATIBLE_NAME to disable zlib names, to prevent conflicts against stock zlib. */
#define MINIZ_NO_ZLIB_COMPATIBLE_NAMES
-/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc.
+/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc.
Note if MINIZ_NO_MALLOC is defined then the user must always provide custom user alloc/free/realloc
callbacks to the zlib and archive API's, and a few stand-alone helper API's which don't provide custom user
functions (such as tdefl_compress_mem_to_heap() and tinfl_decompress_mem_to_heap()) won't work. */
@@ -980,6 +980,7 @@
typedef size_t (*mz_file_read_func)(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n);
typedef size_t (*mz_file_write_func)(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n);
+typedef size_t (*mz_file_seek_func)(void *pOpaque, mz_uint64 file_ofs, size_t n);
typedef mz_bool (*mz_file_needs_keepalive)(void *pOpaque);
struct mz_zip_internal_state_tag;
@@ -1071,6 +1072,7 @@
mz_file_read_func m_pRead;
mz_file_write_func m_pWrite;
+ mz_file_seek_func m_pSeek;
mz_file_needs_keepalive m_pNeeds_keepalive;
void *m_pIO_opaque;
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index a5e3c60..8b3e606 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -1394,10 +1394,22 @@
buffer.attr("write")(std::move(memory_view));
return size;
};
- return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
+ auto seek_func = [=](size_t offset) {
+ auto current_pos = py::cast<size_t>(buffer.attr("tell")());
+ buffer.attr("seek")(
+ offset, py::module::import("os").attr("SEEK_CUR"));
+ return current_pos + offset;
+ };
+ return std::make_unique<PyTorchStreamWriter>(
+ std::move(writer_func), std::move(seek_func));
}))
.def(py::init<const std::function<size_t(const void*, size_t)>&>())
.def(
+ "write_record_metadata",
+ [](PyTorchStreamWriter& self, const std::string& name, size_t size) {
+ return self.writeRecord(name, nullptr, size);
+ })
+ .def(
"write_record",
[](PyTorchStreamWriter& self,
const std::string& name,