Update torch flatbuffer usage to OSS version (#71957)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71957

Update users of flatbuffer serializer/loader to use the version in torch/csrc.

Test Plan:
sandcastle

Ran `buck run :test_models -- -k test_aten_relu` passes

Reviewed By: gmagogsfm

Differential Revision: D33720611

fbshipit-source-id: 6cdf7ab43ffca83327a677853be8f4918c47d53d
(cherry picked from commit 4f59e3547e2cd346a3f2310bc2d1f6a931fb826e)
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 1704109..c61738b 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -80,6 +80,31 @@
     FlatbufferLoader&,
     const mobile::serialization::IValue& ivalue);
 
+TypePtr resolveType(
+    const std::string& type_string,
+    std::shared_ptr<CompilationUnit> cu) {
+  TypePtr type;
+  c10::string_view type_str(type_string);
+  if (type_str.starts_with(kCustomClassPrefix)) {
+    type = getCustomClass(type_string);
+    TORCH_CHECK(
+        type, "The implementation of class ", type_string, " cannot be found.");
+  } else if (
+      type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
+    c10::QualifiedName qn(type_string);
+    if (cu->get_class(qn) == nullptr) {
+      auto classtype = ClassType::create(qn, cu, true);
+      cu->register_type(classtype);
+      type = classtype;
+    } else {
+      type = cu->get_class(qn);
+    }
+  } else {
+    type = c10::parseType(type_string);
+  }
+  return type;
+}
+
 FlatbufferLoader::FlatbufferLoader()
     : mcu_(std::make_shared<mobile::CompilationUnit>()),
       cu_(std::make_shared<CompilationUnit>()),
@@ -107,6 +132,7 @@
   registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
   registerIValueParser(
       mobile::serialization::IValueUnion::EnumValue, &parseEnum);
+  internal_registerTypeResolver(&resolveType);
 }
 
 void FlatbufferLoader::registerIValueParser(
@@ -115,6 +141,11 @@
   ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
 }
 
+void FlatbufferLoader::internal_registerTypeResolver(
+    TypeResolver type_resolver) {
+  type_resolver_ = type_resolver;
+}
+
 mobile::Module FlatbufferLoader::parseModule(
     mobile::serialization::Module* module) {
   module_ = module;
@@ -496,28 +527,7 @@
   if (iter != type_annotations_.end()) {
     return iter->second;
   }
-  TypePtr type;
-  c10::string_view qn_str(offset->c_str(), offset->size());
-  c10::QualifiedName qn(offset->str());
-  if (qn_str.starts_with(kCustomClassPrefix)) {
-    type = getCustomClass(qn.qualifiedName());
-    TORCH_CHECK(
-        type,
-        "The implementation of class ",
-        qn.qualifiedName(),
-        " cannot be found.");
-  } else if (
-      qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
-    if (cu_->get_class(qn) == nullptr) {
-      auto classtype = ClassType::create(qn, cu_, true);
-      cu_->register_type(classtype);
-      type = classtype;
-    } else {
-      type = cu_->get_class(qn);
-    }
-  } else {
-    type = c10::parseType(qn.qualifiedName());
-  }
+  TypePtr type = type_resolver_(offset->str(), cu_);
   type_annotations_[offset] = type;
   return type;
 }
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h
index 9a225fa..9b719b6 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.h
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -61,6 +61,12 @@
       IValueParser parser);
   mobile::Module parseModule(mobile::serialization::Module* module);
 
+  typedef TypePtr (*TypeResolver)(
+      const std::string& type_str,
+      std::shared_ptr<CompilationUnit> cu);
+
+  void internal_registerTypeResolver(TypeResolver type_resolver);
+
   IValue& getIValue(uint32_t pos) {
     TORCH_CHECK(pos < all_ivalues_.size());
     return all_ivalues_[pos];
@@ -103,6 +109,7 @@
       IValueParser,
       static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
       ivalue_parsers_;
+  TypeResolver type_resolver_ = nullptr;
   mobile::serialization::Module* module_ = nullptr;
 };