Update torch flatbuffer usage to OSS version (#71957)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71957
Update users of flatbuffer serializer/loader to use the version in torch/csrc.
Test Plan:
sandcastle
Ran `buck run :test_models -- -k test_aten_relu`; it passes.
Reviewed By: gmagogsfm
Differential Revision: D33720611
fbshipit-source-id: 6cdf7ab43ffca83327a677853be8f4918c47d53d
(cherry picked from commit 4f59e3547e2cd346a3f2310bc2d1f6a931fb826e)
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 1704109..c61738b 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -80,6 +80,31 @@
FlatbufferLoader&,
const mobile::serialization::IValue& ivalue);
+TypePtr resolveType(
+ const std::string& type_string,
+ std::shared_ptr<CompilationUnit> cu) {
+ TypePtr type;
+ c10::string_view type_str(type_string);
+ if (type_str.starts_with(kCustomClassPrefix)) {
+ type = getCustomClass(type_string);
+ TORCH_CHECK(
+ type, "The implementation of class ", type_string, " cannot be found.");
+ } else if (
+ type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
+ c10::QualifiedName qn(type_string);
+ if (cu->get_class(qn) == nullptr) {
+ auto classtype = ClassType::create(qn, cu, true);
+ cu->register_type(classtype);
+ type = classtype;
+ } else {
+ type = cu->get_class(qn);
+ }
+ } else {
+ type = c10::parseType(type_string);
+ }
+ return type;
+}
+
FlatbufferLoader::FlatbufferLoader()
: mcu_(std::make_shared<mobile::CompilationUnit>()),
cu_(std::make_shared<CompilationUnit>()),
@@ -107,6 +132,7 @@
registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
registerIValueParser(
mobile::serialization::IValueUnion::EnumValue, &parseEnum);
+ internal_registerTypeResolver(&resolveType);
}
void FlatbufferLoader::registerIValueParser(
@@ -115,6 +141,11 @@
ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
}
+void FlatbufferLoader::internal_registerTypeResolver(
+ TypeResolver type_resolver) {
+ type_resolver_ = type_resolver;
+}
+
mobile::Module FlatbufferLoader::parseModule(
mobile::serialization::Module* module) {
module_ = module;
@@ -496,28 +527,7 @@
if (iter != type_annotations_.end()) {
return iter->second;
}
- TypePtr type;
- c10::string_view qn_str(offset->c_str(), offset->size());
- c10::QualifiedName qn(offset->str());
- if (qn_str.starts_with(kCustomClassPrefix)) {
- type = getCustomClass(qn.qualifiedName());
- TORCH_CHECK(
- type,
- "The implementation of class ",
- qn.qualifiedName(),
- " cannot be found.");
- } else if (
- qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
- if (cu_->get_class(qn) == nullptr) {
- auto classtype = ClassType::create(qn, cu_, true);
- cu_->register_type(classtype);
- type = classtype;
- } else {
- type = cu_->get_class(qn);
- }
- } else {
- type = c10::parseType(qn.qualifiedName());
- }
+ TypePtr type = type_resolver_(offset->str(), cu_);
type_annotations_[offset] = type;
return type;
}
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h
index 9a225fa..9b719b6 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.h
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -61,6 +61,12 @@
IValueParser parser);
mobile::Module parseModule(mobile::serialization::Module* module);
+ typedef TypePtr (*TypeResolver)(
+ const std::string& type_str,
+ std::shared_ptr<CompilationUnit> cu);
+
+ void internal_registerTypeResolver(TypeResolver type_resolver);
+
IValue& getIValue(uint32_t pos) {
TORCH_CHECK(pos < all_ivalues_.size());
return all_ivalues_[pos];
@@ -103,6 +109,7 @@
IValueParser,
static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
ivalue_parsers_;
+ TypeResolver type_resolver_ = nullptr;
mobile::serialization::Module* module_ = nullptr;
};