Revert D34805092: Extend _save_for_mobile and _load_for_mobile to support flatbuffer format; Default format is pickle + Change buck targets to support `only pickle` and `pickle + flatbuffer` for migration
Test Plan: revert-hammer
Differential Revision:
D34805092 (https://github.com/pytorch/pytorch/commit/284b2b713592728c362d2d9a2813e021895e89eb)
Original commit changeset: 57f3fc81d68f
Original Phabricator Diff: D34805092 (https://github.com/pytorch/pytorch/commit/284b2b713592728c362d2d9a2813e021895e89eb)
fbshipit-source-id: 780dfb6fd6ba5f9348f24a2fb3c57971b7155541
(cherry picked from commit bebeb8b84e11c34cbde4857d0e1c291731a7c781)
diff --git a/test/cpp/jit/test_flatbuffer.cpp b/test/cpp/jit/test_flatbuffer.cpp
index 18c3f79..a25143a 100644
--- a/test/cpp/jit/test_flatbuffer.cpp
+++ b/test/cpp/jit/test_flatbuffer.cpp
@@ -153,30 +153,16 @@
extra_files["metadata.json"] = "abc";
extra_files["mobile_info.json"] = "{\"key\": 23}";
- std::unordered_map<std::string, std::string> loaded_extra_files;
-#if defined ENABLE_FLATBUFFER
- std::stringstream ss;
- module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);
-
- loaded_extra_files["metadata.json"] = "";
- auto mobile_module = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
-
- ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
- ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
-
- // load it twice using the same stream
- auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
-#else
CompilationOptions options;
mobile::Module bc = jitModuleToMobile(*module, options);
auto buff = save_mobile_module_to_bytes(bc, extra_files);
+ std::unordered_map<std::string, std::string> loaded_extra_files;
loaded_extra_files["metadata.json"] = "";
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(buff.data());
parseExtraFiles(flatbuffer_module, loaded_extra_files);
-#endif
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp
index e42cfdd..5e00eaf 100644
--- a/test/cpp/jit/test_lite_interpreter.cpp
+++ b/test/cpp/jit/test_lite_interpreter.cpp
@@ -991,9 +991,9 @@
module->_save_for_mobile(oss, extra_files);
std::istringstream iss(oss.str());
+ caffe2::serialize::IStreamAdapter adapter{&iss};
std::unordered_map<std::string, std::string> loaded_extra_files;
loaded_extra_files["metadata.json"] = "";
- ASSERT_TRUE(iss.tellg() == std::ios::beg);
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
@@ -1006,8 +1006,7 @@
loaded_extra_files[file_name.substr(6)] = "";
}
}
- iss.seekg(0, std::ios::beg);
- ASSERT_TRUE(iss.tellg() == std::ios::beg);
+
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h
index a6aa492..a11bf04 100644
--- a/torch/csrc/jit/api/module.h
+++ b/torch/csrc/jit/api/module.h
@@ -223,14 +223,12 @@
void _save_for_mobile(
std::ostream& out,
const ExtraFilesMap& extra_files = ExtraFilesMap(),
- bool save_mobile_debug_info = false,
- bool use_flatbuffer = false) const;
+ bool save_mobile_debug_info = false) const;
void _save_for_mobile(
const std::string& filename,
const ExtraFilesMap& extra_files = ExtraFilesMap(),
- bool save_mobile_debug_info = false,
- bool use_flatbuffer = false) const;
+ bool save_mobile_debug_info = false) const;
Module copy() const;
diff --git a/torch/csrc/jit/api/module_save.cpp b/torch/csrc/jit/api/module_save.cpp
index 912c386..c8afa5ef 100644
--- a/torch/csrc/jit/api/module_save.cpp
+++ b/torch/csrc/jit/api/module_save.cpp
@@ -16,29 +16,25 @@
void Module::_save_for_mobile(
std::ostream& out,
const ExtraFilesMap& extra_files,
- bool save_mobile_debug_info,
- bool use_flatbuffer) const {
+ bool save_mobile_debug_info) const {
ExportModule(
*this,
out,
extra_files,
true /* bytecode_format */,
- save_mobile_debug_info,
- use_flatbuffer);
+ save_mobile_debug_info);
}
void Module::_save_for_mobile(
const std::string& filename,
const ExtraFilesMap& extra_files,
- bool save_mobile_debug_info,
- bool use_flatbuffer) const {
+ bool save_mobile_debug_info) const {
ExportModule(
*this,
filename,
extra_files,
true /* bytecode_format */,
- save_mobile_debug_info,
- use_flatbuffer);
+ save_mobile_debug_info);
}
} // namespace jit
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 0140d90..9b09db7 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -609,34 +609,6 @@
return std::make_tuple(data, size);
}
-std::tuple<std::shared_ptr<char>, size_t> get_stream_content(std::istream& in) {
- // get size of the stream and reset to orig
- std::streampos orig_pos = in.tellg();
- in.seekg(orig_pos, std::ios::end);
- const long size = in.tellg();
- in.seekg(orig_pos, in.beg);
-
- // read stream
- // NOLINT make sure buffer size is multiple of alignment
- size_t buffer_size =
- (size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
-#ifdef _WIN32
- std::shared_ptr<char> data(
- static_cast<char*>(
- _aligned_malloc(buffer_size, FLATBUFFERS_MAX_ALIGNMENT)),
- _aligned_free); // NOLINT
-#else
- std::shared_ptr<char> data(
- static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, buffer_size)),
- free); // NOLINT
-#endif
- in.read(data.get(), size);
-
- // reset stream to original position
- in.seekg(orig_pos, in.beg);
- return std::make_tuple(data, size);
-}
-
void FlatbufferLoader::extractJitSourceAndConstants(
ExtraFilesMap* jit_sources,
std::vector<IValue>* constants) {
@@ -654,9 +626,6 @@
std::shared_ptr<char> data,
size_t,
c10::optional<at::Device>) {
- TORCH_CHECK(
- mobile::serialization::ModuleBufferHasIdentifier(data.get()),
- "Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h
index ab01fcf..2947e3a 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.h
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -59,9 +59,6 @@
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename);
-TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
- std::istream& in);
-
class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();
diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp
index ea1b3ca..b1c7079 100644
--- a/torch/csrc/jit/mobile/import.cpp
+++ b/torch/csrc/jit/mobile/import.cpp
@@ -10,10 +10,6 @@
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/compilation_unit.h>
-#include <torch/csrc/jit/mobile/file_format.h>
-#if defined(ENABLE_FLATBUFFER)
-#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
-#endif
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
@@ -540,72 +536,18 @@
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
- auto format = getFileFormat(in);
- switch (format) {
- case FileFormat::ZipFileFormat: {
- std::unique_ptr<IStreamAdapter> rai =
- std::make_unique<IStreamAdapter>(&in);
- auto module = _load_for_mobile(std::move(rai), device, extra_files);
- return module;
- }
-#if defined(ENABLE_FLATBUFFER)
- case FileFormat::FlatbufferFileFormat: {
- std::shared_ptr<char> data;
- size_t size = 0;
- std::tie(data, size) = get_stream_content(in);
- auto* flatbuffer_module =
- mobile::serialization::GetMutableModule(data.get());
- mobile::Module m = initialize_mobile_module(flatbuffer_module);
- parseExtraFiles(flatbuffer_module, extra_files);
- return m;
- }
-#else
- case FileFormat::FlatbufferFileFormat: {
- TORCH_CHECK(
- false,
- "Flatbuffer input file but the build hasn't enabled flatbuffer");
- }
-#endif
- default: {
- TORCH_CHECK(false, "Format error");
- }
- }
+ std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
+ auto module = _load_for_mobile(std::move(rai), device, extra_files);
+ return module;
}
mobile::Module _load_for_mobile(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
- auto format = getFileFormat(filename);
- switch (format) {
- case FileFormat::ZipFileFormat: {
- std::unique_ptr<FileAdapter> rai =
- std::make_unique<FileAdapter>(filename);
- auto module = _load_for_mobile(std::move(rai), device, extra_files);
- return module;
- }
-#if defined(ENABLE_FLATBUFFER)
- case FileFormat::FlatbufferFileFormat: {
- std::shared_ptr<char> data;
- size_t size = 0;
- std::tie(data, size) = get_file_content(filename.c_str());
- auto* flatbuffer_module =
- mobile::serialization::GetMutableModule(data.get());
- mobile::Module m = initialize_mobile_module(flatbuffer_module);
- parseExtraFiles(flatbuffer_module, extra_files);
- return m;
- }
-#else
- case FileFormat::FlatbufferFileFormat: {
- TORCH_CHECK(
- false,
- "Flatbuffer input file but the build hasn't enabled flatbuffer");
- }
-#endif
- default: {
- TORCH_CHECK(false, "Format error");
- }
- }
+ std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
+ auto module = _load_for_mobile(std::move(rai), device, extra_files);
+ return module;
}
mobile::Module _load_for_mobile(
@@ -613,37 +555,10 @@
c10::optional<at::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
- auto format = getFileFormat(filename);
- switch (format) {
- case FileFormat::ZipFileFormat: {
- std::unique_ptr<FileAdapter> rai =
- std::make_unique<FileAdapter>(filename);
- auto module = _load_for_mobile_impl(
- std::move(rai), device, extra_files, module_load_options);
- return module;
- }
-#if defined(ENABLE_FLATBUFFER)
- case FileFormat::FlatbufferFileFormat: {
- std::shared_ptr<char> data;
- size_t size = 0;
- std::tie(data, size) = get_file_content(filename.c_str());
- auto* flatbuffer_module =
- mobile::serialization::GetMutableModule(data.get());
- mobile::Module m = initialize_mobile_module(flatbuffer_module);
- parseExtraFiles(flatbuffer_module, extra_files);
- return m;
- }
-#else
- case FileFormat::FlatbufferFileFormat: {
- TORCH_CHECK(
- false,
- "Flatbuffer input file but the build hasn't enabled flatbuffer");
- }
-#endif
- default: {
- TORCH_CHECK(false, "Format error");
- }
- }
+ std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
+ auto module = _load_for_mobile_impl(
+ std::move(rai), device, extra_files, module_load_options);
+ return module;
}
mobile::Module _load_for_mobile(
diff --git a/torch/csrc/jit/mobile/import_data.cpp b/torch/csrc/jit/mobile/import_data.cpp
index cbf644a..c67b521 100644
--- a/torch/csrc/jit/mobile/import_data.cpp
+++ b/torch/csrc/jit/mobile/import_data.cpp
@@ -9,6 +9,7 @@
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/custom_class.h>
+
#include <exception>
#include <fstream>
#include <string>
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 35a8109..f734967 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -1096,32 +1096,23 @@
[](Module& m,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
- bool _save_mobile_debug_info = false,
- bool _use_flatbuffer = false) {
- m._save_for_mobile(
- filename,
- _extra_files,
- _save_mobile_debug_info,
- _use_flatbuffer);
+ bool _save_mobile_debug_info = false) {
+ m._save_for_mobile(filename, _extra_files, _save_mobile_debug_info);
},
py::arg("filename"),
py::arg("_extra_files") = ExtraFilesMap(),
- py::arg("_save_mobile_debug_info") = false,
- py::arg("_use_flatbuffer") = false)
+ py::arg("_save_mobile_debug_info") = false)
.def(
"_save_to_buffer_for_mobile",
[](Module& m,
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
- bool _save_mobile_debug_info = false,
- bool _use_flatbuffer = false) {
+ bool _save_mobile_debug_info = false) {
std::ostringstream buf;
- m._save_for_mobile(
- buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
+ m._save_for_mobile(buf, _extra_files, _save_mobile_debug_info);
return py::bytes(buf.str());
},
py::arg("_extra_files") = ExtraFilesMap(),
- py::arg("_save_mobile_debug_info") = false,
- py::arg("_use_flatbuffer") = false)
+ py::arg("_save_mobile_debug_info") = false)
.def("_set_optimized", &Module::set_optimized)
.def(
"dump",
@@ -1900,10 +1891,6 @@
std::istringstream in(buffer);
return _get_mobile_model_contained_types(in);
});
- m.def("_nn_module_to_mobile", [](const Module& module) {
- CompilationOptions options;
- return jitModuleToMobile(module, options);
- });
py::class_<OperatorInfo>(m, "OperatorInfo")
.def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
m.def("_get_model_ops_and_info", [](const std::string& filename) {
diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h
index 36effaa..a865daf 100644
--- a/torch/csrc/jit/serialization/export.h
+++ b/torch/csrc/jit/serialization/export.h
@@ -158,24 +158,21 @@
std::ostream& out,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
- bool save_mobile_debug_info = false,
- bool use_flatbuffer = false);
+ bool save_mobile_debug_info = false);
TORCH_API void ExportModule(
const Module& module,
const std::string& filename,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
- bool save_mobile_debug_info = false,
- bool use_flatbuffer = false);
+ bool save_mobile_debug_info = false);
TORCH_API void ExportModule(
const Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
- bool save_mobile_debug_info = false,
- bool use_flatbuffer = false);
+ bool save_mobile_debug_info = false);
// Write the bytes of a pickle archive and the tensors referenced inside that
// archive
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index 25e5299..45cbbdf 100644
--- a/torch/csrc/jit/serialization/export_module.cpp
+++ b/torch/csrc/jit/serialization/export_module.cpp
@@ -16,9 +16,6 @@
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
-#if defined(ENABLE_FLATBUFFER)
-#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
-#endif
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
@@ -791,45 +788,20 @@
return storage_context_;
}
-#if defined(ENABLE_FLATBUFFER)
-void save_mobile_module_to(
- const Module& module,
- const ExtraFilesMap& extra_files,
- bool save_mobile_debug_info,
- const std::function<size_t(const void*, size_t)>& writer_func) {
- CompilationOptions options = getOptionsFromGlobal();
- mobile::Module mod = jitModuleToMobile(module, options);
- auto buffer = save_mobile_module_to_bytes(mod, extra_files);
- writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
-}
-#endif
-
void ExportModule(
const Module& module,
std::ostream& out,
const ExtraFilesMap& extra_files,
bool bytecode_format,
- bool save_mobile_debug_info,
- bool use_flatbuffer) {
- auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
- out.write(static_cast<const char*>(buf), nbytes);
- return !out ? 0 : nbytes;
- };
- if (use_flatbuffer) {
-#if defined(ENABLE_FLATBUFFER)
- save_mobile_module_to(
- module, extra_files, save_mobile_debug_info, writer_func);
-#else
- TORCH_CHECK(
- false,
- "Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
-#endif
- } else {
- caffe2::serialize::PyTorchStreamWriter writer(writer_func);
- ScriptModuleSerializer serializer(writer);
- serializer.serialize(
- module, extra_files, bytecode_format, save_mobile_debug_info);
- }
+ bool save_mobile_debug_info) {
+ caffe2::serialize::PyTorchStreamWriter writer(
+ [&](const void* buf, size_t nbytes) -> size_t {
+ out.write(static_cast<const char*>(buf), nbytes);
+ return !out ? 0 : nbytes;
+ });
+ ScriptModuleSerializer serializer(writer);
+ serializer.serialize(
+ module, extra_files, bytecode_format, save_mobile_debug_info);
}
void ExportModule(
@@ -837,29 +809,11 @@
const std::string& filename,
const ExtraFilesMap& extra_files,
bool bytecode_format,
- bool save_mobile_debug_info,
- bool use_flatbuffer) {
- if (use_flatbuffer) {
-#if defined(ENABLE_FLATBUFFER)
- auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
- std::fstream ofile(filename, std::ios::binary | std::ios::out);
- ofile.write(static_cast<const char*>(buf), nbytes);
- ofile.close();
- return !ofile ? 0 : nbytes;
- };
- save_mobile_module_to(
- module, extra_files, save_mobile_debug_info, writer_func);
-#else
- TORCH_CHECK(
- false,
- "Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
-#endif
- } else {
- caffe2::serialize::PyTorchStreamWriter writer(filename);
- ScriptModuleSerializer serializer(writer);
- serializer.serialize(
- module, extra_files, bytecode_format, save_mobile_debug_info);
- }
+ bool save_mobile_debug_info) {
+ caffe2::serialize::PyTorchStreamWriter writer(filename);
+ ScriptModuleSerializer serializer(writer);
+ serializer.serialize(
+ module, extra_files, bytecode_format, save_mobile_debug_info);
}
void ExportModule(
@@ -867,23 +821,11 @@
const std::function<size_t(const void*, size_t)>& writer_func,
const ExtraFilesMap& extra_files,
bool bytecode_format,
- bool save_mobile_debug_info,
- bool use_flatbuffer) {
- if (use_flatbuffer) {
-#if defined(ENABLE_FLATBUFFER)
- save_mobile_module_to(
- module, extra_files, save_mobile_debug_info, writer_func);
-#else
- TORCH_CHECK(
- false,
- "Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
-#endif
- } else {
- caffe2::serialize::PyTorchStreamWriter writer(writer_func);
- ScriptModuleSerializer serializer(writer);
- serializer.serialize(
- module, extra_files, bytecode_format, save_mobile_debug_info);
- }
+ bool save_mobile_debug_info) {
+ caffe2::serialize::PyTorchStreamWriter writer(writer_func);
+ ScriptModuleSerializer serializer(writer);
+ serializer.serialize(
+ module, extra_files, bytecode_format, save_mobile_debug_info);
}
namespace {