Remove flatbuffer types/headers from flatbuffer_loader.h (#82893)

This completely hides the flatbuffer types and headers from users of flatbuffer_loader/serializer, turning them into an internal implementation detail.

A followup diff will fix up the buck files to hide the dependencies more thoroughly.

While doing this I found another use of a flatbuffer-defined name (`FLATBUFFERS_MAX_ALIGNMENT`), which highlighted the issues described in T128189662.

Differential Revision: [D38292794](https://our.internmc.facebook.com/intern/diff/D38292794/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82893
Approved by: https://github.com/qihqi
diff --git a/torch/csrc/init_flatbuffer_module.cpp b/torch/csrc/init_flatbuffer_module.cpp
index 90f5bd8..f739f83 100644
--- a/torch/csrc/init_flatbuffer_module.cpp
+++ b/torch/csrc/init_flatbuffer_module.cpp
@@ -21,21 +21,24 @@
 
 namespace py = pybind11;
 
+using torch::jit::kFlatbufferDataAlignmentBytes;
+
 static std::shared_ptr<char> copyStr(const std::string& bytes) {
-  size_t size = (bytes.size() / FLATBUFFERS_MAX_ALIGNMENT + 1) *
-      FLATBUFFERS_MAX_ALIGNMENT;
+  size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
+      kFlatbufferDataAlignmentBytes;
 #ifdef _WIN32
   std::shared_ptr<char> bytes_copy(
-      static_cast<char*>(_aligned_malloc(size, FLATBUFFERS_MAX_ALIGNMENT)),
+      static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
       _aligned_free);
 #elif defined(__APPLE__)
   void* p;
-  ::posix_memalign(&p, FLATBUFFERS_MAX_ALIGNMENT, size);
+  ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
   TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
   std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
 #else
   std::shared_ptr<char> bytes_copy(
-      static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, size)), free);
+      static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
+      free);
 #endif
   memcpy(bytes_copy.get(), bytes.data(), bytes.size());
   return bytes_copy;
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 8865290..c37366a 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -1,5 +1,19 @@
 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
 
+#ifdef FLATBUFFERS_VERSION_MAJOR
+#error "flatbuffer_loader.h must not include any flatbuffers headers"
+#endif // FLATBUFFERS_VERSION_MAJOR
+
+#include <array>
+#include <istream>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
 #include <ATen/ATen.h>
 #include <ATen/core/dynamic_type.h>
 #include <ATen/core/ivalue.h>
@@ -12,8 +26,10 @@
 #include <caffe2/serialize/inline_container.h>
 #include <torch/csrc/jit/frontend/script_type_parser.h>
 #include <torch/csrc/jit/mobile/file_format.h>
+#include <torch/csrc/jit/mobile/function.h>
 #include <torch/csrc/jit/mobile/import.h>
 #include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/module.h>
 #include <torch/csrc/jit/mobile/observer.h>
 #include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/runtime/instruction.h>
@@ -28,35 +44,110 @@
 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
 #endif
 
-#if defined(HAVE_MMAP)
-#include <fcntl.h>
-#include <sys/mman.h>
-#include <sys/stat.h>
-#include <unistd.h>
-#endif
-
 #ifdef _WIN32
 #include <malloc.h>
 #else
 #include <cstdlib>
 #endif
 
-#include <string>
-#include <vector>
-
 namespace torch {
 namespace jit {
 
+// Our own alignment requirement does not need to be exactly the same as what
+// flatbuffers supports, but what flatbuffers supports needs to satisfy our
+// requirement.
+static_assert(
+    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
+    "Sizes must be compatible");
+static_assert(
+    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
+        kFlatbufferDataAlignmentBytes,
+    "Must be a power of 2");
+
+namespace {
+
 static constexpr c10::string_view kCustomClassPrefix =
     "__torch__.torch.classes";
 static constexpr c10::string_view kTorchPrefix = "__torch__";
 static constexpr c10::string_view kJitPrefix = "torch.jit";
 
-template <typename T, typename U>
-std::vector<T> parseListNative(const U* list) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
-  return {list->items()->begin(), list->items()->end()};
-}
+class FlatbufferLoader final {
+ public:
+  FlatbufferLoader();
+
+  typedef IValue (
+      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
+  void registerIValueParser(
+      mobile::serialization::IValueUnion ivalue_type,
+      IValueParser parser);
+  mobile::Module parseModule(mobile::serialization::Module* module);
+
+  void extractJitSourceAndConstants(
+      ExtraFilesMap* jit_sources,
+      std::vector<IValue>* constants);
+
+  typedef TypePtr (*TypeResolver)(
+      const std::string& type_str,
+      std::shared_ptr<CompilationUnit> cu);
+
+  void internal_registerTypeResolver(TypeResolver type_resolver);
+
+  IValue& getIValue(uint32_t pos) {
+    TORCH_CHECK(pos < all_ivalues_.size());
+    return all_ivalues_[pos];
+  }
+
+  mobile::Function* getFunction(uint32_t pos) {
+    return all_functions_[pos];
+  }
+
+  ClassTypePtr getType(uint32_t pos) {
+    TORCH_CHECK(pos < all_types_.size());
+    return all_types_[pos];
+  }
+
+  c10::Storage getStorage(uint32_t index);
+  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
+  ClassTypePtr getOrCreateClassTypeForObject(
+      const mobile::serialization::Object* object);
+
+  const mobile::serialization::Module* getCurrentFlatbufferInput() {
+    return module_;
+  }
+
+  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
+    should_copy_tensor_memory_ = should_copy_tensor_memory;
+  }
+
+  std::shared_ptr<mobile::CompilationUnit> mcu_;
+  std::shared_ptr<CompilationUnit> cu_;
+
+ private:
+  IValue parseIValue(const mobile::serialization::IValue* ivalue);
+  std::unique_ptr<mobile::Function> parseFunction(
+      const mobile::serialization::Function* method);
+  void parseAndPopulate(
+      uint32_t i,
+      const mobile::serialization::IValue* ivalue);
+
+  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
+  std::vector<ClassTypePtr> all_types_;
+  std::unordered_set<uint32_t> initialized_types_;
+  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
+  std::vector<bool> storage_loaded_;
+  std::vector<c10::Storage> storages_;
+  std::vector<IValue> all_ivalues_;
+  std::array<
+      IValueParser,
+      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
+      ivalue_parsers_;
+  TypeResolver type_resolver_ = nullptr;
+  mobile::serialization::Module* module_ = nullptr;
+  bool module_parsed_ = false;
+  bool should_copy_tensor_memory_ = false;
+  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
+  uint32_t mobile_ivalue_size_ = 0;
+};
 
 IValue parseList(
     FlatbufferLoader&,
@@ -225,7 +316,6 @@
   return m;
 }
 
-namespace {
 void appendUpgraderFunctions(mobile::Function* function) {
 #ifndef DISABLE_UPGRADER
   for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
@@ -233,7 +323,6 @@
   }
 #endif
 }
-} // namespace
 
 std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
     const mobile::serialization::Function* method) {
@@ -266,9 +355,7 @@
         op->name()->str(), op->overload_name()->str(), num_args);
   }
 
-  if (should_load_operators_) {
-    function->initialize_operators(true);
-  }
+  function->initialize_operators(true);
 
   for (const auto i : *method->type_annotations()) {
     function->append_type(getOrCreateTypeAnnotations(i));
@@ -434,6 +521,12 @@
   return res;
 }
 
+template <typename T, typename U>
+std::vector<T> parseListNative(const U* list) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
+  return {list->items()->begin(), list->items()->end()};
+}
+
 IValue parseIntList(
     FlatbufferLoader&,
     const mobile::serialization::IValue& ivalue) {
@@ -641,6 +734,8 @@
   parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
 }
 
+} // namespace
+
 mobile::Module parse_and_initialize_mobile_module(
     void* data,
     size_t,
@@ -649,6 +744,8 @@
     bool should_copy_tensor_memory) {
   TORCH_CHECK(
       mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+  // TODO(T128189662): If not copying, enforce that data is aligned to
+  // kFlatbufferDataAlignmentBytes, and add unit tests.
 
   FlatbufferLoader loader;
   loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
@@ -687,6 +784,8 @@
     ExtraFilesMap* extra_files) {
   TORCH_CHECK(
       mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+  // TODO(T128189662): Enforce that data is aligned to
+  // kFlatbufferDataAlignmentBytes, and add unit tests.
 
   FlatbufferLoader loader;
   auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
@@ -699,16 +798,6 @@
   return m;
 }
 
-mobile::Module initialize_mobile_module(
-    mobile::serialization::Module* flatbuffer_module,
-    c10::optional<at::Device>,
-    bool should_copy_tensor_memory) {
-  auto flatbufferLoader = FlatbufferLoader();
-  flatbufferLoader.setShouldCopyTensorMemory(should_copy_tensor_memory);
-  mobile::Module m = flatbufferLoader.parseModule(flatbuffer_module);
-  return m;
-}
-
 mobile::Module load_mobile_module_from_file(
     const std::string& filename,
     c10::optional<c10::Device> device,
@@ -786,7 +875,8 @@
       std::move(data), size, device, extra_files);
 }
 
-static mobile::Module parse_flatbuffer_no_object(
+namespace {
+mobile::Module parse_flatbuffer_no_object(
     std::shared_ptr<char> data,
     size_t size,
     c10::optional<at::Device> device) {
@@ -815,6 +905,7 @@
   m.set_delete_memory(std::move(data));
   return m;
 }
+} // namespace
 
 bool register_flatbuffer_loader() {
   load_flatbuffer_bytes = parse_and_initialize_mobile_module;
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h
index a1d3095..eee44d4 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.h
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -1,25 +1,32 @@
 #pragma once
 
+#include <istream>
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
 
 #include <ATen/core/ivalue.h>
-#include <caffe2/serialize/inline_container.h>
-#include <torch/csrc/jit/mobile/function.h>
-#include <torch/csrc/jit/mobile/interpreter.h>
+#include <c10/core/Device.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Optional.h>
 #include <torch/csrc/jit/mobile/module.h>
-#include <torch/csrc/jit/runtime/instruction.h>
-#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
-#include <torch/custom_class.h>
 
 /**
  * Defines the public API for loading flatbuffer-serialized mobile modules.
+ * Note that this header must not include or depend on flatbuffer-defined
+ * types, to avoid leaking those details to PyTorch clients.
  */
 
 namespace torch {
 namespace jit {
 
+/// All non-copied data pointers provided to `parse_and_initialize_*` functions
+/// must be aligned to this boundary. Since the Module will point directly into
+/// the data, this alignment is necessary to ensure that certain types/structs
+/// are properly aligned.
+constexpr size_t kFlatbufferDataAlignmentBytes = 16;
+
 /// Maps file names to file contents.
 using ExtraFilesMap = std::unordered_map<std::string, std::string>;
 
@@ -30,22 +37,9 @@
 //    structure
 // 3. Module initialization: Produce mobile::Module out of the structure
 //    produced in 2.
-// Under this context, the structure described in 2. is
-// mobile::serialization::Module
-
-/// DEPRECATED: Use a parse/load function below.
-// Parse a mobile::Module from flatbuffer's in-memory Module representation.
-// The caller is assumed to manage the lifetimes of Module.
-// This function does step 3 described above.
-// If should_copy_tensor_memory is true, then the returned module will NOT
-// have refences to flatbuffer_module, so it can be discarded.
-// If should_copy_tensor_memory is false, then returned module will have
-// tensors that points inside of flatbuffer_module; the caller need to make
-// sure that flatbuffer_module outlives returned Module.
-TORCH_API mobile::Module initialize_mobile_module(
-    mobile::serialization::Module* flatbuffer_module,
-    c10::optional<at::Device> device = c10::nullopt,
-    bool should_copy_tensor_memory = false);
+// Under this context, the structure described in 2. is the flatbuffer-defined
+// type mobile::serialization::Module. However, this step/type is not visible in
+// the public API.
 
 // Parse a mobile::Module from raw bytes.
 //
@@ -59,7 +53,8 @@
 //
 // If should_copy_tensor_memory is false, then returned module will have tensors
 // that points inside of `data`; the caller will need to make sure that `data`
-// outlives the returned Module.
+// outlives the returned Module. Also, `data` must be aligned to
+// kFlatbufferDataAlignmentBytes.
 TORCH_API mobile::Module parse_and_initialize_mobile_module(
     void* data,
     size_t size, // of `data`, in bytes.
@@ -71,7 +66,8 @@
 //
 // This function does steps 2+3 described above.
 //
-// The returned Module holds a reference to `data`.
+// The returned Module holds a reference to `data`, which must be aligned to
+// kFlatbufferDataAlignmentBytes.
 //
 // If you do not want the Module to hold a reference to `data`, see the raw
 // pointer overload of this function.
@@ -107,12 +103,6 @@
     c10::optional<at::Device> device = c10::nullopt,
     ExtraFilesMap* extra_files = nullptr);
 
-/// DEPRECATED: Use the `extra_files` parameter of one of the parse/load
-/// functions above.
-TORCH_API void parseExtraFiles(
-    mobile::serialization::Module* module,
-    ExtraFilesMap& extra_files);
-
 TORCH_API uint64_t get_bytecode_version(std::istream& in);
 TORCH_API uint64_t get_bytecode_version(const std::string& filename);
 TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content);
@@ -133,97 +123,5 @@
 // in this file directly.
 TORCH_API bool register_flatbuffer_loader();
 
-/// DEPRECATED: Use one of the parse/load functions above.
-class TORCH_API FlatbufferLoader {
- public:
-  FlatbufferLoader();
-
-  typedef IValue (
-      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
-  void registerIValueParser(
-      mobile::serialization::IValueUnion ivalue_type,
-      IValueParser parser);
-  mobile::Module parseModule(mobile::serialization::Module* module);
-
-  void extractJitSourceAndConstants(
-      ExtraFilesMap* jit_sources,
-      std::vector<IValue>* constants);
-
-  typedef TypePtr (*TypeResolver)(
-      const std::string& type_str,
-      std::shared_ptr<CompilationUnit> cu);
-
-  void internal_registerTypeResolver(TypeResolver type_resolver);
-
-  IValue& getIValue(uint32_t pos) {
-    TORCH_CHECK(pos < all_ivalues_.size());
-    return all_ivalues_[pos];
-  }
-
-  mobile::Function* getFunction(uint32_t pos) {
-    return all_functions_[pos];
-  }
-
-  ClassTypePtr getType(uint32_t pos) {
-    TORCH_CHECK(pos < all_ivalues_.size());
-    return all_types_[pos];
-  }
-
-  c10::Storage getStorage(uint32_t index);
-  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
-  ClassTypePtr getOrCreateClassTypeForObject(
-      const mobile::serialization::Object* object);
-
-  const mobile::serialization::Module* getCurrentFlatbufferInput() {
-    return module_;
-  }
-
-  bool getShouldCopyTensorMemory() {
-    return should_copy_tensor_memory_;
-  }
-
-  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
-    should_copy_tensor_memory_ = should_copy_tensor_memory;
-  }
-
-  // Whether or not should load operators in functions.
-  // Not loading operators is useful because if an operator is not found
-  // then we throw exceptions, and sometimes we want to print out
-  // what operators are included before that to debug.
-  void setShouldLoadOperators(bool should_load_operators) {
-    should_load_operators_ = should_load_operators;
-  }
-
-  std::shared_ptr<mobile::CompilationUnit> mcu_;
-  std::shared_ptr<CompilationUnit> cu_;
-
- private:
-  IValue parseIValue(const mobile::serialization::IValue* ivalue);
-  std::unique_ptr<mobile::Function> parseFunction(
-      const mobile::serialization::Function* method);
-  void parseAndPopulate(
-      uint32_t i,
-      const mobile::serialization::IValue* ivalue);
-
-  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
-  std::vector<ClassTypePtr> all_types_;
-  std::unordered_set<uint32_t> initialized_types_;
-  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
-  std::vector<bool> storage_loaded_;
-  std::vector<c10::Storage> storages_;
-  std::vector<IValue> all_ivalues_;
-  std::array<
-      IValueParser,
-      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
-      ivalue_parsers_;
-  TypeResolver type_resolver_ = nullptr;
-  mobile::serialization::Module* module_ = nullptr;
-  bool module_parsed_ = false;
-  bool should_copy_tensor_memory_ = false;
-  bool should_load_operators_ = true;
-  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
-  uint32_t mobile_ivalue_size_ = 0;
-};
-
 } // namespace jit
 } // namespace torch