blob: 9d1c426e67b6777bd0034d44a074995da465af7d [file] [log] [blame]
#include <cstdio>
#include <cstring>
#include <cerrno>
#include <istream>
#include <ostream>
#include <fstream>
#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
#include "miniz.h"
namespace caffe2 {
namespace serialize {
size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
return self->read(file_ofs, static_cast<char*>(pBuf), n);
}
// Returns the last path component of `name` with its final extension removed.
// Both '/' and '\\' count as separators. The result is empty when `name` ends
// in a separator or when the component starts with its only '.'.
static std::string basename(const std::string& name) {
  // The component starts right after the last separator, or at 0 if none.
  const size_t last_sep = name.find_last_of("\\/");
  const size_t stem_begin = (last_sep == std::string::npos) ? 0 : last_sep + 1;
  if (stem_begin >= name.size()) {
    return "";
  }

  // Only a '.' inside the final component marks an extension.
  const size_t last_dot = name.rfind('.');
  const bool has_extension =
      last_dot != std::string::npos && last_dot >= stem_begin;
  const size_t stem_end = has_extension ? last_dot : name.size();

  return name.substr(stem_begin, stem_end - stem_begin);
}
// Reads `n` bytes starting at offset `pos` of the underlying adapter into
// `buf` and returns whatever count the adapter reports.
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
  const size_t bytes_read = in_->read(pos, buf, n, "reading file");
  return bytes_read;
}
// Constructs a reader over the archive named by `file_name`: the path is
// wrapped in a FileAdapter, a fresh miniz archive struct is allocated, and
// init() then attaches miniz to the adapter and validates the archive.
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
: ar_(caffe2::make_unique<mz_zip_archive>()),
in_(caffe2::make_unique<FileAdapter>(file_name)) {
init();
}
// Constructs a reader over an already-open stream: `in` is wrapped in an
// IStreamAdapter, a fresh miniz archive struct is allocated, and init() then
// attaches miniz to the adapter and validates the archive.
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
: ar_(caffe2::make_unique<mz_zip_archive>()),
in_(caffe2::make_unique<IStreamAdapter>(in)) {
init();
}
// Constructs a reader over a caller-supplied adapter. Ownership of `in` is
// moved into the reader; init() asserts that it is non-null before use.
PyTorchStreamReader::PyTorchStreamReader(
std::unique_ptr<ReadAdapterInterface> in)
: ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
init();
}
// Shared constructor tail. In order it:
//   1. rejects files that start with the preview-release magic tag,
//   2. attaches miniz to the read adapter and opens the zip directory,
//   3. derives archive_name_ from the first entry's top-level folder,
//   4. reads the "version" record and checks it against the supported range.
// Any failure is reported by throwing (AT_ASSERT*/CAFFE_THROW).
void PyTorchStreamReader::init() {
  AT_ASSERT(in_ != nullptr);
  AT_ASSERT(ar_ != nullptr);
  // Start from an all-zero miniz archive struct.
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  const size_t total_bytes = in_->size();

  // check for the old magic number,
  constexpr size_t kMagicValueLength = 8;
  if (total_bytes > kMagicValueLength) {
    char buf[kMagicValueLength];
    read(0, buf, kMagicValueLength);
    valid("checking magic number");
    AT_ASSERTM(
        memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
        "File is an unsupported archive format from the preview release.");
  }

  // Route every miniz read through this object (see istream_read_func).
  ar_->m_pIO_opaque = this;
  ar_->m_pRead = istream_read_func;
  mz_zip_reader_init(ar_.get(), total_bytes, 0);
  valid("reading zip archive");

  // figure out the archive_name (i.e. the zip folder all the other files are in)
  // all lookups to getRecord will be prefixed by this folder
  if (mz_zip_reader_get_num_files(ar_.get()) == 0) {
    CAFFE_THROW("archive does not contain any files");
  }
  // The first call passes no buffer and is used only for the size it returns;
  // the second call fills a string of exactly that size.
  const size_t first_name_size =
      mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
  valid("getting filename");
  std::string first_name(first_name_size, '\0');
  mz_zip_reader_get_filename(ar_.get(), 0, &first_name[0], first_name_size);
  valid("getting filename");
  const auto slash = first_name.find('/');
  if (slash == std::string::npos) {
    CAFFE_THROW("file in archive is not in a subdirectory: ", first_name);
  }
  archive_name_ = first_name.substr(0, slash);

  // version check
  at::DataPtr version_ptr;
  size_t version_size;
  std::tie(version_ptr, version_size) = getRecord("version");
  std::string version(
      static_cast<const char*>(version_ptr.get()), version_size);
  size_t version_number = caffe2::stoull(version);
  AT_ASSERTM(
      version_number >= kMinSupportedFileFormatVersion,
      "Attempted to read a PyTorch file with version ",
      c10::to_string(version_number),
      ", but the minimum supported version for reading is ",
      c10::to_string(kMinSupportedFileFormatVersion),
      ". Your PyTorch script module file is too old. Please re-export it again.");
  AT_ASSERTM(
      version_number <= kMaxSupportedFileFormatVersion,
      "Attempted to read a PyTorch file with version ",
      version_number,
      ", but the maximum supported version for reading is ",
      kMaxSupportedFileFormatVersion,
      ". Your PyTorch installation may be too old.");
}
// Throws if miniz has an error pending on the archive. `what` names the
// operation that was just attempted and becomes part of the message.
void PyTorchStreamReader::valid(const char* what) {
  const auto status = mz_zip_get_last_error(ar_.get());
  if (status == MZ_ZIP_NO_ERROR) {
    return;
  }
  CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(status));
}
// Layout constants for a zip local file header, as used below: the header read
// by getRecordOffset() is 30 bytes long, with the 16-bit little-endian
// file-name length at byte 26 and the 16-bit extra-field length at byte 28.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
static std::string getPadding(size_t cursor, const std::string& filename, size_t size) {
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() + sizeof(mz_uint16) * 2;
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
start += sizeof(mz_uint16) * 2;
if (size >= MZ_UINT32_MAX) {
start += 2*sizeof(mz_uint64);
}
if (cursor >= MZ_UINT32_MAX) {
start += sizeof(mz_uint64);
}
}
size_t mod = start % kFieldAlignment;
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
size_t padding_size = next_offset - start;
std::string buf(padding_size + 4, 'Z');
// zip extra encoding (key, size_of_extra_bytes)
buf[0] = 'F';
buf[1] = 'B';
buf[2] = (uint8_t) padding_size;
buf[3] = (uint8_t) (padding_size >> 8);
return buf;
}
size_t PyTorchStreamReader::getFileID(const std::string& name) {
std::stringstream ss;
ss << archive_name_ << "/" << name;
size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
CAFFE_THROW("file not found: ", ss.str());
}
valid("locating file");
return result;
}
// Extracts the record `name` into a freshly allocated buffer.
//
// Returns (data, size): `data` owns a malloc'd CPU buffer holding the
// uncompressed bytes and `size` is its length in bytes.
// Throws if the record cannot be located, stat'ed or extracted.
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
  const size_t key = getFileID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  valid("retrieving file meta-data");

  // Give the buffer to its owning DataPtr before extracting, so that it is
  // freed if valid() throws below instead of leaking.
  void* ptr = malloc(stat.m_uncomp_size);
  at::DataPtr retval(ptr, ptr, free, at::kCPU);

  mz_zip_reader_extract_to_mem(ar_.get(), key, ptr, stat.m_uncomp_size, 0);
  valid("reading file");
  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
// Decodes the two bytes at `buf` as an unsigned 16-bit little-endian value
// (buf[0] is the low byte, buf[1] the high byte).
static int64_t read_le_16(uint8_t* buf) {
  const int low_byte = buf[0];
  const int high_byte = buf[1];
  return (high_byte << 8) | low_byte;
}
size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
valid("retriving file meta-data");
uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
in_->read(
stat.m_local_header_ofs,
local_header,
MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
"reading file header");
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
}
// Releases the miniz reader state.
//
// Destructors are implicitly noexcept, so reporting a failure by throwing (as
// valid() does) would terminate the process, including for a stale error
// latched by an earlier, already-handled failure. The result of
// mz_zip_reader_end() is checked directly instead, and a failure is written
// to stderr.
PyTorchStreamReader::~PyTorchStreamReader() {
  if (!mz_zip_reader_end(ar_.get())) {
    std::fprintf(stderr, "PytorchStreamReader failed closing reader\n");
  }
}
size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
if (self->current_pos_ != file_ofs) {
// xxx - windows ostringstream refuses to seek to the end of an empty string
// so we workaround this by not calling seek unless necessary
// in the case of the first write (to the empty string) file_ofs and
// current_pos_ will be 0 and the seek won't occur.
self->out_->seekp(file_ofs);
if(!*self->out_)
return 0;
}
self->out_->write(static_cast<const char*>(pBuf), n);
if(!*self->out_)
return 0;
self->current_pos_ = file_ofs + n;
return n;
}
// Opens a new archive for writing.
//
// `file_name` supplies the archive's root folder name (its basename without
// extension) and, when `out` is null, also the path of the file that is
// created/truncated to receive the data. When `out` is non-null the archive is
// written to that stream instead. A "version" record is written immediately.
// Throws on an unusable file name, a failed open, or any miniz/stream error.
PyTorchStreamWriter::PyTorchStreamWriter(
    std::string file_name,
    std::ostream* out)
    : ar_(caffe2::make_unique<mz_zip_archive>()),
      archive_name_(basename(file_name)),
      out_(out) {
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  if (archive_name_.empty()) {
    CAFFE_THROW("invalid file name: ", file_name);
  }

  // No caller-provided stream: write to a file we own.
  if (out_ == nullptr) {
    file_stream_.open(
        file_name,
        std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
    out_ = &file_stream_;
    valid("opening archive");
  }

  // Route every miniz write through ostream_write_func.
  ar_->m_pIO_opaque = this;
  ar_->m_pWrite = ostream_write_func;
  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
  valid("initializing archive");

  std::stringstream version;
  version << kMaxSupportedFileFormatVersion << "\n";
  const std::string version_record = version.str();
  writeRecord("version", version_record.c_str(), version_record.size());
}
void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size) {
AT_ASSERT(!finalized_);
std::stringstream ss;
ss << archive_name_ << "/" << name;
const std::string& full_name = ss.str();
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
uint32_t flags = 0;
mz_zip_writer_add_mem_ex_v2(
ar_.get(),
full_name.c_str(),
data,
size,
nullptr,
0,
flags,
0,
0,
nullptr,
padding.c_str(),
padding.size(),
nullptr,
0);
valid("writing file");
}
// Finalizes the archive and releases the miniz writer state, then closes the
// owned file stream if one was opened. May be called only once; the writer is
// marked finalized before any work is attempted. Throws on miniz or stream
// errors.
void PyTorchStreamWriter::writeEndOfFile() {
  AT_ASSERT(!finalized_);
  finalized_ = true;

  mz_zip_writer_finalize_archive(ar_.get());
  mz_zip_writer_end(ar_.get());
  valid("writing central directory");

  if (!file_stream_.is_open()) {
    return;
  }
  file_stream_.close();
}
// Throws if miniz has an error pending on the archive or if the output stream
// is in a failed state. `what` names the operation that was just attempted.
void PyTorchStreamWriter::valid(const char* what) {
  const auto status = mz_zip_get_last_error(ar_.get());
  const bool archive_ok = (status == MZ_ZIP_NO_ERROR);
  if (!archive_ok) {
    CAFFE_THROW("PytorchStreamWriter failed ", what, ": ", mz_zip_get_error_string(status));
  }
  if (out_->fail()) {
    CAFFE_THROW("PytorchStreamWriter failed ", what, ".");
  }
}
// Finalizes the archive on destruction unless writeEndOfFile() already ran.
// NOTE(review): writeEndOfFile() reports failures by throwing, so an error on
// this path would escape a destructor — callers that need to observe write
// failures should presumably call writeEndOfFile() themselves; confirm.
PyTorchStreamWriter::~PyTorchStreamWriter() {
  if (finalized_) {
    return;
  }
  writeEndOfFile();
}
} // namespace serialize
} // namespace caffe2