| /* |
| * Copyright (c) Meta Platforms, Inc. and affiliates. |
| * All rights reserved. |
| * |
| * This source code is licensed under the BSD-style license found in the |
| * LICENSE file in the root directory of this source tree. |
| */ |
| |
#include <executorch/runtime/executor/program.h>

#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <memory>
#include <utility>

#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/schema/program_generated.h>
#include <executorch/test/utils/DeathTest.h>

#include <gtest/gtest.h>
| |
| using namespace ::testing; |
| using executorch::runtime::DataLoader; |
| using executorch::runtime::Error; |
| using executorch::runtime::FreeableBuffer; |
| using executorch::runtime::Program; |
| using executorch::runtime::Result; |
| using torch::executor::util::BufferDataLoader; |
| using torch::executor::util::FileDataLoader; |
| |
| // Verification level to use for tests not specifically focused on verification. |
| // Use the highest level to exercise it more. |
| constexpr Program::Verification kDefaultVerification = |
| Program::Verification::InternalConsistency; |
| |
| class ProgramTest : public ::testing::Test { |
| protected: |
| void SetUp() override { |
| // Since these tests cause ET_LOG to be called, the PAL must be initialized |
| // first. |
| executorch::runtime::runtime_init(); |
| |
| // Load the serialized ModuleAdd data. |
| const char* path = std::getenv("ET_MODULE_ADD_PATH"); |
| Result<FileDataLoader> loader = FileDataLoader::from(path); |
| ASSERT_EQ(loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> header = loader->load( |
| /*offset=*/0, |
| Program::kMinHeadBytes, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(header->data(), header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get())); |
| |
| // Load the serialized ModuleMultiEntry data. |
| path = std::getenv("ET_MODULE_MULTI_ENTRY_PATH"); |
| Result<FileDataLoader> multi_loader = FileDataLoader::from(path); |
| ASSERT_EQ(multi_loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> multi_header = multi_loader->load( |
| /*offset=*/0, |
| Program::kMinHeadBytes, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(multi_header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(multi_header->data(), multi_header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| multi_loader_ = |
| std::make_unique<FileDataLoader>(std::move(multi_loader.get())); |
| } |
| |
| std::unique_ptr<FileDataLoader> add_loader_; |
| std::unique_ptr<FileDataLoader> multi_loader_; |
| }; |
| |
| namespace executorch { |
| namespace runtime { |
| namespace testing { |
| // Provides access to private Program methods. |
| class ProgramTestFriend final { |
| public: |
| ET_NODISCARD static Result<FreeableBuffer> LoadSegment( |
| const Program* program, |
| const DataLoader::SegmentInfo& segment_info) { |
| return program->LoadSegment(segment_info); |
| } |
| |
| ET_NODISCARD static Error load_mutable_subsegment_into( |
| const Program* program, |
| size_t mutable_data_segments_index, |
| size_t offset_index, |
| size_t size, |
| void* buffer) { |
| return program->load_mutable_subsegment_into( |
| mutable_data_segments_index, offset_index, size, buffer); |
| } |
| |
| const static executorch_flatbuffer::Program* GetInternalProgram( |
| const Program* program) { |
| return program->internal_program_; |
| } |
| }; |
| } // namespace testing |
| } // namespace runtime |
| } // namespace executorch |
| |
| using executorch::runtime::testing::ProgramTestFriend; |
| |
| TEST_F(ProgramTest, DataParsesWithMinimalVerification) { |
| // Parse the Program from the data. |
| Result<Program> program = |
| Program::load(add_loader_.get(), Program::Verification::Minimal); |
| |
| // Should have succeeded. |
| EXPECT_EQ(program.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, DataParsesWithInternalConsistencyVerification) { |
| // Parse the Program from the data. |
| Result<Program> program = Program::load( |
| add_loader_.get(), Program::Verification::InternalConsistency); |
| |
| // Should have succeeded. |
| EXPECT_EQ(program.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, BadMagicFailsToLoad) { |
| // Make a local copy of the data. |
| size_t data_len = add_loader_->size().get(); |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->load( |
| /*offset=*/0, |
| data_len, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Corrupt the magic value. |
| EXPECT_EQ(data[4], 'E'); |
| data[4] = 'X'; |
| EXPECT_EQ(data[5], 'T'); |
| data[5] = 'Y'; |
| |
| // Wrap the modified data in a loader. |
| BufferDataLoader data_loader(data.get(), data_len); |
| |
| { |
| // Parse the Program from the data. Use minimal verification to show that |
| // even this catches the header problem. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::Minimal); |
| |
| // Should fail. |
| ASSERT_EQ(program.error(), Error::InvalidProgram); |
| } |
| |
| // Fix the data. |
| data[4] = 'E'; |
| data[5] = 'T'; |
| |
| { |
| // Parse the Program from the data again. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::Minimal); |
| |
| // Should now succeed. |
| ASSERT_EQ(program.error(), Error::Ok); |
| } |
| } |
| |
| TEST_F(ProgramTest, VerificationCatchesTruncation) { |
| // Get the program data. |
| size_t full_data_len = add_loader_->size().get(); |
| Result<FreeableBuffer> full_data = add_loader_->load( |
| /*offset=*/0, |
| full_data_len, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(full_data.error(), Error::Ok); |
| |
| // Make a loader that only exposes half of the data. |
| BufferDataLoader half_data_loader(full_data->data(), full_data_len / 2); |
| |
| // Loading with full verification should fail. |
| Result<Program> program = Program::load( |
| &half_data_loader, Program::Verification::InternalConsistency); |
| ASSERT_EQ(program.error(), Error::InvalidProgram); |
| } |
| |
| TEST_F(ProgramTest, VerificationCatchesCorruption) { |
| // Make a local copy of the data. |
| size_t data_len = add_loader_->size().get(); |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->load( |
| /*offset=*/0, |
| data_len, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Corrupt the second half of the data. |
| std::memset(&data[data_len / 2], 0x55, data_len - (data_len / 2)); |
| |
| // Wrap the corrupted data in a loader. |
| BufferDataLoader data_loader(data.get(), data_len); |
| |
| // Should fail to parse corrupted data when using full verification. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::InternalConsistency); |
| ASSERT_EQ(program.error(), Error::InvalidProgram); |
| } |
| |
| TEST_F(ProgramTest, UnalignedProgramDataFails) { |
| // Make a local copy of the data, on an odd alignment. |
| size_t data_len = add_loader_->size().get(); |
| auto data = std::make_unique<char[]>(data_len + 1); |
| { |
| Result<FreeableBuffer> src = add_loader_->load( |
| /*offset=*/0, |
| data_len, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get() + 1, src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Wrap the offset data in a loader. |
| BufferDataLoader data_loader(data.get() + 1, data_len); |
| |
| // Should refuse to accept unaligned data. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::Minimal); |
| ASSERT_NE(program.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, LoadSegmentWithNoSegments) { |
| // Load a program with no appended segments. |
| Result<Program> program = |
| Program::load(add_loader_.get(), kDefaultVerification); |
| EXPECT_EQ(program.error(), Error::Ok); |
| |
| // Loading a non-program segment should fail. |
| const auto segment_info = DataLoader::SegmentInfo( |
| DataLoader::SegmentInfo::Type::Backend, |
| /*segment_index=*/0, |
| "some-backend"); |
| Result<FreeableBuffer> segment = |
| ProgramTestFriend::LoadSegment(&program.get(), segment_info); |
| EXPECT_NE(segment.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, ShortDataHeader) { |
| Result<FreeableBuffer> header = add_loader_->load( |
| /*offset=*/0, |
| Program::kMinHeadBytes, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(header.error(), Error::Ok); |
| |
| // Provide less than the required amount of data. |
| EXPECT_EQ( |
| Program::check_header(header->data(), Program::kMinHeadBytes - 1), |
| Program::HeaderStatus::ShortData); |
| } |
| |
| TEST_F(ProgramTest, IncompatibleHeader) { |
| // Make a local copy of the header. |
| size_t data_len = Program::kMinHeadBytes; |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->load( |
| /*offset=*/0, |
| data_len, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Change the number part of the magic value to a different value. |
| EXPECT_EQ(data[4], 'E'); |
| EXPECT_EQ(data[5], 'T'); |
| EXPECT_TRUE(std::isdigit(data[6])) << "Not a digit: " << data[6]; |
| EXPECT_TRUE(std::isdigit(data[7])) << "Not a digit: " << data[7]; |
| |
| // Modify the tens digit. |
| if (data[6] == '9') { |
| data[6] = '0'; |
| } else { |
| data[6] += 1; |
| } |
| EXPECT_TRUE(std::isdigit(data[6])) << "Not a digit: " << data[6]; |
| |
| // Should count as present but incompatible. |
| EXPECT_EQ( |
| Program::check_header(data.get(), data_len), |
| Program::HeaderStatus::IncompatibleVersion); |
| } |
| |
| TEST_F(ProgramTest, HeaderNotPresent) { |
| // Make a local copy of the header. |
| size_t data_len = Program::kMinHeadBytes; |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->load( |
| /*offset=*/0, |
| data_len, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Corrupt the magic value. |
| EXPECT_EQ(data[4], 'E'); |
| data[4] = 'X'; |
| EXPECT_EQ(data[5], 'T'); |
| data[5] = 'Y'; |
| |
| // The header is not present. |
| EXPECT_EQ( |
| Program::check_header(data.get(), data_len), |
| Program::HeaderStatus::NotPresent); |
| } |
| |
| TEST_F(ProgramTest, getMethods) { |
| // Parse the Program from the data. |
| Result<Program> program_res = |
| Program::load(multi_loader_.get(), kDefaultVerification); |
| EXPECT_EQ(program_res.error(), Error::Ok); |
| |
| Program program(std::move(program_res.get())); |
| |
| // Method calls should succeed without hitting ET_CHECK. |
| EXPECT_EQ(program.num_methods(), 2); |
| auto res = program.get_method_name(0); |
| EXPECT_TRUE(res.ok()); |
| EXPECT_EQ(strcmp(res.get(), "forward"), 0); |
| auto res2 = program.get_method_name(1); |
| EXPECT_TRUE(res2.ok()); |
| EXPECT_EQ(strcmp(res2.get(), "forward2"), 0); |
| } |
| |
| // Test that the deprecated Load method (capital 'L') still works. |
| TEST_F(ProgramTest, DEPRECATEDLoad) { |
| // Parse the Program from the data. |
| // NOLINTNEXTLINE(facebook-hte-Deprecated) |
| Result<Program> program_res = Program::Load(multi_loader_.get()); |
| EXPECT_EQ(program_res.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) { |
| Result<Program> program = |
| Program::load(add_loader_.get(), kDefaultVerification); |
| ASSERT_EQ(program.error(), Error::Ok); |
| |
| // Load constant segment data should fail. |
| const auto segment_info = DataLoader::SegmentInfo( |
| DataLoader::SegmentInfo::Type::Constant, |
| /*segment_index=*/0); |
| Result<FreeableBuffer> segment = |
| ProgramTestFriend::LoadSegment(&program.get(), segment_info); |
| EXPECT_NE(segment.error(), Error::Ok); |
| |
| const executorch_flatbuffer::Program* flatbuffer_program = |
| ProgramTestFriend::GetInternalProgram(&program.get()); |
| |
| // The constant buffer should be empty. |
| EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0); |
| |
| // Expect 1 constant segment, placeholder for non-const tensors. |
| EXPECT_EQ(flatbuffer_program->segments()->size(), 1); |
| } |
| |
| TEST_F(ProgramTest, LoadConstantSegment) { |
| // Load the serialized ModuleLinear data, with constants in the segment. |
| const char* linear_path = std::getenv("ET_MODULE_LINEAR_PATH"); |
| Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path); |
| ASSERT_EQ(linear_loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> linear_header = linear_loader->load( |
| /*offset=*/0, |
| Program::kMinHeadBytes, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(linear_header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(linear_header->data(), linear_header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| Result<Program> program = Program::load(&linear_loader.get()); |
| ASSERT_EQ(program.error(), Error::Ok); |
| |
| // Load constant segment data, which is currently always in segment index |
| // zero. |
| const auto segment_info = DataLoader::SegmentInfo( |
| DataLoader::SegmentInfo::Type::Constant, |
| /*segment_index=*/0); |
| Result<FreeableBuffer> segment = |
| ProgramTestFriend::LoadSegment(&program.get(), segment_info); |
| EXPECT_EQ(segment.error(), Error::Ok); |
| |
| const executorch_flatbuffer::Program* flatbuffer_program = |
| ProgramTestFriend::GetInternalProgram(&program.get()); |
| |
| // Expect one segment containing the constants. |
| EXPECT_EQ(flatbuffer_program->segments()->size(), 1); |
| |
| // The constant buffer should be empty. |
| EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0); |
| |
| // Check constant segment offsets. |
| EXPECT_EQ(flatbuffer_program->constant_segment()->segment_index(), 0); |
| EXPECT_GE(flatbuffer_program->constant_segment()->offsets()->size(), 1); |
| } |
| |
| TEST_F(ProgramTest, LoadConstantSegmentWhenConstantBufferExists) { |
| // Load the serialized ModuleLinear data, with constants in the flatbuffer and |
| // no constants in the segment. |
| const char* linear_path = |
| std::getenv("DEPRECATED_ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH"); |
| Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path); |
| ASSERT_EQ(linear_loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> linear_header = linear_loader->load( |
| /*offset=*/0, |
| Program::kMinHeadBytes, |
| /*segment_info=*/ |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(linear_header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(linear_header->data(), linear_header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| Result<Program> program = Program::load(&linear_loader.get()); |
| ASSERT_EQ(program.error(), Error::Ok); |
| |
| const executorch_flatbuffer::Program* flatbuffer_program = |
| ProgramTestFriend::GetInternalProgram(&program.get()); |
| |
| // Expect no segments. |
| EXPECT_EQ(flatbuffer_program->segments()->size(), 0); |
| |
| // The constant buffer should exist. |
| EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1); |
| } |
| |
| TEST_F(ProgramTest, LoadFromMutableSegment) { |
| // Load the serialized ModuleSimpleTrain data. |
| auto path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH"); |
| Result<FileDataLoader> training_loader = FileDataLoader::from(path); |
| ASSERT_EQ(training_loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> training_header = training_loader->load( |
| /*offset=*/0, |
| Program::kMinHeadBytes, |
| DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); |
| ASSERT_EQ(training_header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(training_header->data(), training_header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| Result<Program> program = Program::load(&training_loader.get()); |
| ASSERT_EQ(program.error(), Error::Ok); |
| |
| // dummy buffers to load into |
| uint8_t buffer[1] = {0}; |
| uint8_t buffer2[1] = {0}; |
| |
| // Load some mutable segment data |
| Error err = ProgramTestFriend::load_mutable_subsegment_into( |
| &program.get(), 0, 1, 1, buffer); |
| EXPECT_EQ(err, Error::Ok); |
| |
| // Check that the data loaded correctly, and then mutate it |
| EXPECT_EQ(buffer[0], 232); // 232 comes from inspecting the file itself. The |
| // file is seeded so this value should be stable. |
| buffer[0] = 0; |
| |
| // Load the same mutable segment data from file into a different buffer. |
| err = ProgramTestFriend::load_mutable_subsegment_into( |
| &program.get(), |
| 0, // mutable_data_segments_index |
| 1, // offset_index |
| 1, // size |
| buffer2); |
| EXPECT_EQ(err, Error::Ok); |
| |
| // Check that new data loaded from the file does not reflect the change to |
| // buffer. |
| EXPECT_EQ(buffer2[0], 232); |
| |
| const executorch_flatbuffer::Program* flatbuffer_program = |
| ProgramTestFriend::GetInternalProgram(&program.get()); |
| |
| // Expect 2 segments. 1 mutable segment and 1 constant segment. |
| EXPECT_EQ(flatbuffer_program->segments()->size(), 2); |
| |
| // Expect a mutable data segment. |
| EXPECT_EQ(flatbuffer_program->mutable_data_segments()->size(), 1); |
| |
| // Expect the 0 index to be reserved and the offsets for weight and bias of |
| // linear to be indices 1 and 2. |
| EXPECT_EQ( |
| flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->size(), |
| 3); |
| EXPECT_EQ( |
| flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(0), |
| 0); |
| EXPECT_EQ( |
| flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(1), |
| 0); |
| EXPECT_EQ( |
| flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(2), |
| 36); |
| |
| // Loading beyond file should fail |
| err = ProgramTestFriend::load_mutable_subsegment_into( |
| &program.get(), 0, 1, 500, buffer); |
| EXPECT_NE(err, Error::Ok); |
| |
| // Loading beyond offsets should fail |
| err = ProgramTestFriend::load_mutable_subsegment_into( |
| &program.get(), 0, 500, 1, buffer); |
| EXPECT_NE(err, Error::Ok); |
| |
| // Loading beyond segments should fail |
| err = ProgramTestFriend::load_mutable_subsegment_into( |
| &program.get(), 500, 1, 1, buffer); |
| EXPECT_NE(err, Error::Ok); |
| } |