Remove flatbuffer types/headers from flatbuffer_loader.h (#82893)
This completely hides the flatbuffer types and headers from users of flatbuffer_loader/serializer, turning them into an internal implementation detail.
A followup diff will fix up the buck files to hide the dependencies more thoroughly.
While doing this I found another use of a flatbuffer-defined name (`FLATBUFFERS_MAX_ALIGNMENT`), which highlighted the issues described in T128189662.
Differential Revision: [D38292794](https://our.internmc.facebook.com/intern/diff/D38292794/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82893
Approved by: https://github.com/qihqi
diff --git a/torch/csrc/init_flatbuffer_module.cpp b/torch/csrc/init_flatbuffer_module.cpp
index 90f5bd8..f739f83 100644
--- a/torch/csrc/init_flatbuffer_module.cpp
+++ b/torch/csrc/init_flatbuffer_module.cpp
@@ -21,21 +21,24 @@
namespace py = pybind11;
+using torch::jit::kFlatbufferDataAlignmentBytes;
+
static std::shared_ptr<char> copyStr(const std::string& bytes) {
- size_t size = (bytes.size() / FLATBUFFERS_MAX_ALIGNMENT + 1) *
- FLATBUFFERS_MAX_ALIGNMENT;
+ size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
+ kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
std::shared_ptr<char> bytes_copy(
- static_cast<char*>(_aligned_malloc(size, FLATBUFFERS_MAX_ALIGNMENT)),
+ static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
_aligned_free);
#elif defined(__APPLE__)
void* p;
- ::posix_memalign(&p, FLATBUFFERS_MAX_ALIGNMENT, size);
+ ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
std::shared_ptr<char> bytes_copy(
- static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, size)), free);
+ static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
+ free);
#endif
memcpy(bytes_copy.get(), bytes.data(), bytes.size());
return bytes_copy;
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 8865290..c37366a 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -1,5 +1,19 @@
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
+#ifdef FLATBUFFERS_VERSION_MAJOR
+#error "flatbuffer_loader.h must not include any flatbuffers headers"
+#endif // FLATBUFFERS_VERSION_MAJOR
+
+#include <array>
+#include <istream>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
#include <ATen/ATen.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/ivalue.h>
@@ -12,8 +26,10 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/mobile/file_format.h>
+#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
@@ -28,35 +44,110 @@
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#endif
-#if defined(HAVE_MMAP)
-#include <fcntl.h>
-#include <sys/mman.h>
-#include <sys/stat.h>
-#include <unistd.h>
-#endif
-
#ifdef _WIN32
#include <malloc.h>
#else
#include <cstdlib>
#endif
-#include <string>
-#include <vector>
-
namespace torch {
namespace jit {
+// Our own alignment requirement does not need to be exactly the same as what
+// flatbuffers supports, but what flatbuffers supports needs to satisfy our
+// requirement.
+static_assert(
+ kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
+ "Sizes must be compatible");
+static_assert(
+ (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
+ kFlatbufferDataAlignmentBytes,
+ "Must be a power of 2");
+
+namespace {
+
static constexpr c10::string_view kCustomClassPrefix =
"__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
-template <typename T, typename U>
-std::vector<T> parseListNative(const U* list) {
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
- return {list->items()->begin(), list->items()->end()};
-}
+class FlatbufferLoader final {
+ public:
+ FlatbufferLoader();
+
+ typedef IValue (
+ *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
+ void registerIValueParser(
+ mobile::serialization::IValueUnion ivalue_type,
+ IValueParser parser);
+ mobile::Module parseModule(mobile::serialization::Module* module);
+
+ void extractJitSourceAndConstants(
+ ExtraFilesMap* jit_sources,
+ std::vector<IValue>* constants);
+
+ typedef TypePtr (*TypeResolver)(
+ const std::string& type_str,
+ std::shared_ptr<CompilationUnit> cu);
+
+ void internal_registerTypeResolver(TypeResolver type_resolver);
+
+ IValue& getIValue(uint32_t pos) {
+ TORCH_CHECK(pos < all_ivalues_.size());
+ return all_ivalues_[pos];
+ }
+
+ mobile::Function* getFunction(uint32_t pos) {
+ return all_functions_[pos];
+ }
+
+ ClassTypePtr getType(uint32_t pos) {
+ TORCH_CHECK(pos < all_types_.size());
+ return all_types_[pos];
+ }
+
+ c10::Storage getStorage(uint32_t index);
+ TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
+ ClassTypePtr getOrCreateClassTypeForObject(
+ const mobile::serialization::Object* object);
+
+ const mobile::serialization::Module* getCurrentFlatbufferInput() {
+ return module_;
+ }
+
+ void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
+ should_copy_tensor_memory_ = should_copy_tensor_memory;
+ }
+
+ std::shared_ptr<mobile::CompilationUnit> mcu_;
+ std::shared_ptr<CompilationUnit> cu_;
+
+ private:
+ IValue parseIValue(const mobile::serialization::IValue* ivalue);
+ std::unique_ptr<mobile::Function> parseFunction(
+ const mobile::serialization::Function* method);
+ void parseAndPopulate(
+ uint32_t i,
+ const mobile::serialization::IValue* ivalue);
+
+ std::unordered_map<uint32_t, mobile::Function*> all_functions_;
+ std::vector<ClassTypePtr> all_types_;
+ std::unordered_set<uint32_t> initialized_types_;
+ std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
+ std::vector<bool> storage_loaded_;
+ std::vector<c10::Storage> storages_;
+ std::vector<IValue> all_ivalues_;
+ std::array<
+ IValueParser,
+ static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
+ ivalue_parsers_;
+ TypeResolver type_resolver_ = nullptr;
+ mobile::serialization::Module* module_ = nullptr;
+ bool module_parsed_ = false;
+ bool should_copy_tensor_memory_ = false;
+ // 0 -> mobile_ivalue_size_ elements are from the mobile module.
+ uint32_t mobile_ivalue_size_ = 0;
+};
IValue parseList(
FlatbufferLoader&,
@@ -225,7 +316,6 @@
return m;
}
-namespace {
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
@@ -233,7 +323,6 @@
}
#endif
}
-} // namespace
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
const mobile::serialization::Function* method) {
@@ -266,9 +355,7 @@
op->name()->str(), op->overload_name()->str(), num_args);
}
- if (should_load_operators_) {
- function->initialize_operators(true);
- }
+ function->initialize_operators(true);
for (const auto i : *method->type_annotations()) {
function->append_type(getOrCreateTypeAnnotations(i));
@@ -434,6 +521,12 @@
return res;
}
+template <typename T, typename U>
+std::vector<T> parseListNative(const U* list) {
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
+ return {list->items()->begin(), list->items()->end()};
+}
+
IValue parseIntList(
FlatbufferLoader&,
const mobile::serialization::IValue& ivalue) {
@@ -641,6 +734,8 @@
parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}
+} // namespace
+
mobile::Module parse_and_initialize_mobile_module(
void* data,
size_t,
@@ -649,6 +744,8 @@
bool should_copy_tensor_memory) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+ // TODO(T128189662): If not copying, enforce that data is aligned to
+ // kFlatbufferDataAlignmentBytes, and add unit tests.
FlatbufferLoader loader;
loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
@@ -687,6 +784,8 @@
ExtraFilesMap* extra_files) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+ // TODO(T128189662): Enforce that data is aligned to
+ // kFlatbufferDataAlignmentBytes, and add unit tests.
FlatbufferLoader loader;
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
@@ -699,16 +798,6 @@
return m;
}
-mobile::Module initialize_mobile_module(
- mobile::serialization::Module* flatbuffer_module,
- c10::optional<at::Device>,
- bool should_copy_tensor_memory) {
- auto flatbufferLoader = FlatbufferLoader();
- flatbufferLoader.setShouldCopyTensorMemory(should_copy_tensor_memory);
- mobile::Module m = flatbufferLoader.parseModule(flatbuffer_module);
- return m;
-}
-
mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<c10::Device> device,
@@ -786,7 +875,8 @@
std::move(data), size, device, extra_files);
}
-static mobile::Module parse_flatbuffer_no_object(
+namespace {
+mobile::Module parse_flatbuffer_no_object(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device) {
@@ -815,6 +905,7 @@
m.set_delete_memory(std::move(data));
return m;
}
+} // namespace
bool register_flatbuffer_loader() {
load_flatbuffer_bytes = parse_and_initialize_mobile_module;
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h
index a1d3095..eee44d4 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.h
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -1,25 +1,32 @@
#pragma once
+#include <istream>
+#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <ATen/core/ivalue.h>
-#include <caffe2/serialize/inline_container.h>
-#include <torch/csrc/jit/mobile/function.h>
-#include <torch/csrc/jit/mobile/interpreter.h>
+#include <c10/core/Device.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Optional.h>
#include <torch/csrc/jit/mobile/module.h>
-#include <torch/csrc/jit/runtime/instruction.h>
-#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
-#include <torch/custom_class.h>
/**
* Defines the public API for loading flatbuffer-serialized mobile modules.
+ * Note that this header must not include or depend on flatbuffer-defined
+ * types, to avoid leaking those details to PyTorch clients.
*/
namespace torch {
namespace jit {
+/// All non-copied data pointers provided to `parse_and_initialize_*` functions
+/// must be aligned to this boundary. Since the Module will point directly into
+/// the data, this alignment is necessary to ensure that certain types/structs
+/// are properly aligned.
+constexpr size_t kFlatbufferDataAlignmentBytes = 16;
+
/// Maps file names to file contents.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
@@ -30,22 +37,9 @@
// structure
// 3. Module initialization: Produce mobile::Module out of the structure
// produced in 2.
-// Under this context, the structure described in 2. is
-// mobile::serialization::Module
-
-/// DEPRECATED: Use a parse/load function below.
-// Parse a mobile::Module from flatbuffer's in-memory Module representation.
-// The caller is assumed to manage the lifetimes of Module.
-// This function does step 3 described above.
-// If should_copy_tensor_memory is true, then the returned module will NOT
-// have refences to flatbuffer_module, so it can be discarded.
-// If should_copy_tensor_memory is false, then returned module will have
-// tensors that points inside of flatbuffer_module; the caller need to make
-// sure that flatbuffer_module outlives returned Module.
-TORCH_API mobile::Module initialize_mobile_module(
- mobile::serialization::Module* flatbuffer_module,
- c10::optional<at::Device> device = c10::nullopt,
- bool should_copy_tensor_memory = false);
+// Under this context, the structure described in 2. is the flatbuffer-defined
+// type mobile::serialization::Module. However, this step/type is not visible in
+// the public API.
// Parse a mobile::Module from raw bytes.
//
@@ -59,7 +53,8 @@
//
// If should_copy_tensor_memory is false, then returned module will have tensors
// that points inside of `data`; the caller will need to make sure that `data`
-// outlives the returned Module.
+// outlives the returned Module. Also, `data` must be aligned to
+// kFlatbufferDataAlignmentBytes.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
void* data,
size_t size, // of `data`, in bytes.
@@ -71,7 +66,8 @@
//
// This function does steps 2+3 described above.
//
-// The returned Module holds a reference to `data`.
+// The returned Module holds a reference to `data`, which must be aligned to
+// kFlatbufferDataAlignmentBytes.
//
// If you do not want the Module to hold a reference to `data`, see the raw
// pointer overload of this function.
@@ -107,12 +103,6 @@
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
-/// DEPRECATED: Use the `extra_files` parameter of one of the parse/load
-/// functions above.
-TORCH_API void parseExtraFiles(
- mobile::serialization::Module* module,
- ExtraFilesMap& extra_files);
-
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content);
@@ -133,97 +123,5 @@
// in this file directly.
TORCH_API bool register_flatbuffer_loader();
-/// DEPRECATED: Use one of the parse/load functions above.
-class TORCH_API FlatbufferLoader {
- public:
- FlatbufferLoader();
-
- typedef IValue (
- *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
- void registerIValueParser(
- mobile::serialization::IValueUnion ivalue_type,
- IValueParser parser);
- mobile::Module parseModule(mobile::serialization::Module* module);
-
- void extractJitSourceAndConstants(
- ExtraFilesMap* jit_sources,
- std::vector<IValue>* constants);
-
- typedef TypePtr (*TypeResolver)(
- const std::string& type_str,
- std::shared_ptr<CompilationUnit> cu);
-
- void internal_registerTypeResolver(TypeResolver type_resolver);
-
- IValue& getIValue(uint32_t pos) {
- TORCH_CHECK(pos < all_ivalues_.size());
- return all_ivalues_[pos];
- }
-
- mobile::Function* getFunction(uint32_t pos) {
- return all_functions_[pos];
- }
-
- ClassTypePtr getType(uint32_t pos) {
- TORCH_CHECK(pos < all_ivalues_.size());
- return all_types_[pos];
- }
-
- c10::Storage getStorage(uint32_t index);
- TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
- ClassTypePtr getOrCreateClassTypeForObject(
- const mobile::serialization::Object* object);
-
- const mobile::serialization::Module* getCurrentFlatbufferInput() {
- return module_;
- }
-
- bool getShouldCopyTensorMemory() {
- return should_copy_tensor_memory_;
- }
-
- void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
- should_copy_tensor_memory_ = should_copy_tensor_memory;
- }
-
- // Whether or not should load operators in functions.
- // Not loading operators is useful because if an operator is not found
- // then we throw exceptions, and sometimes we want to print out
- // what operators are included before that to debug.
- void setShouldLoadOperators(bool should_load_operators) {
- should_load_operators_ = should_load_operators;
- }
-
- std::shared_ptr<mobile::CompilationUnit> mcu_;
- std::shared_ptr<CompilationUnit> cu_;
-
- private:
- IValue parseIValue(const mobile::serialization::IValue* ivalue);
- std::unique_ptr<mobile::Function> parseFunction(
- const mobile::serialization::Function* method);
- void parseAndPopulate(
- uint32_t i,
- const mobile::serialization::IValue* ivalue);
-
- std::unordered_map<uint32_t, mobile::Function*> all_functions_;
- std::vector<ClassTypePtr> all_types_;
- std::unordered_set<uint32_t> initialized_types_;
- std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
- std::vector<bool> storage_loaded_;
- std::vector<c10::Storage> storages_;
- std::vector<IValue> all_ivalues_;
- std::array<
- IValueParser,
- static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
- ivalue_parsers_;
- TypeResolver type_resolver_ = nullptr;
- mobile::serialization::Module* module_ = nullptr;
- bool module_parsed_ = false;
- bool should_copy_tensor_memory_ = false;
- bool should_load_operators_ = true;
- // 0 -> mobile_ivalue_size_ elements are from the mobile module.
- uint32_t mobile_ivalue_size_ = 0;
-};
-
} // namespace jit
} // namespace torch