Segmentation fault in flatbuffers when parsing malformed modules (#95221)
Fixes #95061, #95062
Add Flatbuffer verification before parsing to avoid crashing on malformed modules. Flatbuffers doesn't perform boundary checks at runtime for the sake of performance, so when parsing untrusted modules it is highly recommended to verify overall buffer integrity.
This bug can be triggered both by C++ (`torch::jit::load`, `torch::jitload_jit_module_from_file`) and Python API (`torch.jit.load`, `torch.jit.jit_module_from_flatbuffer`).
Crash files to reproduce:
[crash-1feb368861083e3d242e5c3fcb1090869f4819c4.txt](https://github.com/pytorch/pytorch/files/10795267/crash-1feb368861083e3d242e5c3fcb1090869f4819c4.txt)
[crash-7e8ffd314223be96b43ca246d3d3481702869455.txt](https://github.com/pytorch/pytorch/files/10795268/crash-7e8ffd314223be96b43ca246d3d3481702869455.txt)
[crash-ad4d7c6183af8f34fe1cb5c8133315c6389c409f.txt](https://github.com/pytorch/pytorch/files/10795279/crash-ad4d7c6183af8f34fe1cb5c8133315c6389c409f.txt)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95221
Approved by: https://github.com/qihqi, https://github.com/davidberard98
diff --git a/test/cpp/jit/test_flatbuffer.cpp b/test/cpp/jit/test_flatbuffer.cpp
index 1eea96d..2a631bf 100644
--- a/test/cpp/jit/test_flatbuffer.cpp
+++ b/test/cpp/jit/test_flatbuffer.cpp
@@ -53,6 +53,23 @@
}
} // namespace
+TEST(FlatbufferTest, LoadMalformedModule) {
+ // Manually create some data with Flatbuffer header.
+ std::stringstream bad_data;
+ bad_data << "PK\x03\x04PTMF\x00\x00"
+ << "*}NV\xb3\xfa\xdf\x00pa";
+
+ // Loading module from it should throw an exception.
+ // Check guard at parse_and_initialize_mobile_module_for_jit.
+ ASSERT_THROWS_WITH_MESSAGE(
+ torch::jit::load(bad_data), "Malformed Flatbuffer module");
+
+ // Check guard at parse_and_initialize_mobile_module.
+ ASSERT_THROWS_WITH_MESSAGE(
+ parse_mobile_module(bad_data.str().data(), bad_data.str().size()),
+ "Malformed Flatbuffer module");
+}
+
TEST(FlatbufferTest, UpsampleNearest2d) {
Module m("m");
m.define(R"(
diff --git a/test/cpp/jit/test_lite_trainer.cpp b/test/cpp/jit/test_lite_trainer.cpp
index 311a818..c88775a 100644
--- a/test/cpp/jit/test_lite_trainer.cpp
+++ b/test/cpp/jit/test_lite_trainer.cpp
@@ -1,3 +1,5 @@
+#include <test/cpp/jit/test_utils.h>
+
#include <gtest/gtest.h>
#include <c10/core/TensorOptions.h>
@@ -238,6 +240,17 @@
EXPECT_ANY_THROW(_load_parameters(empty));
}
+TEST(MobileTest, LoadParametersMalformedFlatbuffer) {
+ // Manually create some data with Flatbuffer header.
+ std::stringstream bad_data;
+ bad_data << "PK\x03\x04PTMF\x00\x00"
+ << "*}NV\xb3\xfa\xdf\x00pa";
+
+ // Loading parameters from it should throw an exception.
+ ASSERT_THROWS_WITH_MESSAGE(
+ _load_parameters(bad_data), "Malformed Flatbuffer module");
+}
+
TEST(LiteTrainerTest, SGD) {
Module m("m");
m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 687ffad..7c296b0 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -743,7 +743,7 @@
mobile::Module parse_and_initialize_mobile_module(
void* data,
- size_t,
+ size_t size,
c10::optional<at::Device>,
ExtraFilesMap* extra_files,
bool should_copy_tensor_memory) {
@@ -752,6 +752,12 @@
// TODO(T128189662): If not copying, enforce that data is aligned to
// kFlatbufferDataAlignmentBytes, and add unit tests.
+ // Validate Flatbuffer module before parsing.
+ flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
+ TORCH_CHECK(
+ mobile::serialization::VerifyModuleBuffer(verifier),
+ "Malformed Flatbuffer module");
+
FlatbufferLoader loader;
loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
@@ -782,7 +788,7 @@
mobile::Module parse_and_initialize_mobile_module_for_jit(
void* data,
- size_t,
+ size_t size,
ExtraFilesMap& jit_sources,
std::vector<IValue>& jit_constants,
c10::optional<at::Device>,
@@ -792,6 +798,12 @@
// TODO(T128189662): Enforce that data is aligned to
// kFlatbufferDataAlignmentBytes, and add unit tests.
+ // Validate Flatbuffer module before parsing.
+ flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
+ TORCH_CHECK(
+ mobile::serialization::VerifyModuleBuffer(verifier),
+ "Malformed Flatbuffer module");
+
FlatbufferLoader loader;
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
mobile::Module m = loader.parseModule(flatbuffer_module);
@@ -886,6 +898,13 @@
c10::optional<at::Device> device) {
(void)device;
(void)size;
+
+ // Validate Flatbuffer module before parsing.
+ flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data.get()), size);
+ TORCH_CHECK(
+ mobile::serialization::VerifyModuleBuffer(verifier),
+ "Malformed Flatbuffer module");
+
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
FlatbufferLoader loader;
// replace parserObject with to handle only class with field case