| /* |
| * Copyright (c) Meta Platforms, Inc. and affiliates. |
| * All rights reserved. |
| * |
| * This source code is licensed under the BSD-style license found in the |
| * LICENSE file in the root directory of this source tree. |
| */ |
| |
| #include <cctype> |
| #include <filesystem> |
| |
| #include <cstring> |
| #include <memory> |
| |
| #include <executorch/extension/data_loader/buffer_data_loader.h> |
| #include <executorch/extension/data_loader/file_data_loader.h> |
| #include <executorch/runtime/core/error.h> |
| #include <executorch/runtime/core/result.h> |
| #include <executorch/runtime/executor/program.h> |
| #include <executorch/runtime/platform/runtime.h> |
| #include <executorch/schema/program_generated.h> |
| #include <executorch/test/utils/DeathTest.h> |
| |
| #include <gtest/gtest.h> |
| |
| using namespace ::testing; |
| using torch::executor::Error; |
| using torch::executor::FreeableBuffer; |
| using torch::executor::Program; |
| using torch::executor::Result; |
| using torch::executor::util::BufferDataLoader; |
| using torch::executor::util::FileDataLoader; |
| |
| // Verification level to use for tests not specifically focused on verification. |
| // Use the highest level to exercise it more. |
| constexpr Program::Verification kDefaultVerification = |
| Program::Verification::InternalConsistency; |
| |
| class ProgramTest : public ::testing::Test { |
| protected: |
| void SetUp() override { |
| // Since these tests cause ET_LOG to be called, the PAL must be initialized |
| // first. |
| torch::executor::runtime_init(); |
| |
| // Load the serialized ModuleAdd data. |
| const char* path = std::getenv("ET_MODULE_ADD_PATH"); |
| Result<FileDataLoader> loader = FileDataLoader::from(path); |
| ASSERT_EQ(loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> header = |
| loader->Load(/*offset=*/0, Program::kMinHeadBytes); |
| ASSERT_EQ(header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(header->data(), header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get())); |
| |
| // Load the serialized ModuleAdd data. |
| path = std::getenv("ET_MODULE_MULTI_ENTRY_PATH"); |
| Result<FileDataLoader> multi_loader = FileDataLoader::from(path); |
| ASSERT_EQ(multi_loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> multi_header = |
| multi_loader->Load(/*offset=*/0, Program::kMinHeadBytes); |
| ASSERT_EQ(multi_header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(multi_header->data(), multi_header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| multi_loader_ = |
| std::make_unique<FileDataLoader>(std::move(multi_loader.get())); |
| } |
| |
| std::unique_ptr<FileDataLoader> add_loader_; |
| std::unique_ptr<FileDataLoader> multi_loader_; |
| }; |
| |
| namespace torch { |
| namespace executor { |
| namespace testing { |
| // Provides access to private Program methods. |
| class ProgramTestFriend final { |
| public: |
| __ET_NODISCARD static Result<FreeableBuffer> LoadSegment( |
| const Program* program, |
| size_t index) { |
| return program->LoadSegment(index); |
| } |
| |
| const static executorch_flatbuffer::Program* GetInternalProgram( |
| const Program* program) { |
| return program->internal_program_; |
| } |
| }; |
| } // namespace testing |
| } // namespace executor |
| } // namespace torch |
| |
| using torch::executor::testing::ProgramTestFriend; |
| |
| TEST_F(ProgramTest, DataParsesWithMinimalVerification) { |
| // Parse the Program from the data. |
| Result<Program> program = |
| Program::load(add_loader_.get(), Program::Verification::Minimal); |
| |
| // Should have succeeded. |
| EXPECT_EQ(program.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, DataParsesWithInternalConsistencyVerification) { |
| // Parse the Program from the data. |
| Result<Program> program = Program::load( |
| add_loader_.get(), Program::Verification::InternalConsistency); |
| |
| // Should have succeeded. |
| EXPECT_EQ(program.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, BadMagicFailsToLoad) { |
| // Make a local copy of the data. |
| size_t data_len = add_loader_->size().get(); |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->Load(/*offset=*/0, data_len); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Corrupt the magic value. |
| EXPECT_EQ(data[4], 'E'); |
| data[4] = 'X'; |
| EXPECT_EQ(data[5], 'T'); |
| data[5] = 'Y'; |
| |
| // Wrap the modified data in a loader. |
| BufferDataLoader data_loader(data.get(), data_len); |
| |
| { |
| // Parse the Program from the data. Use minimal verification to show that |
| // even this catches the header problem. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::Minimal); |
| |
| // Should fail. |
| ASSERT_EQ(program.error(), Error::InvalidProgram); |
| } |
| |
| // Fix the data. |
| data[4] = 'E'; |
| data[5] = 'T'; |
| |
| { |
| // Parse the Program from the data again. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::Minimal); |
| |
| // Should now succeed. |
| ASSERT_EQ(program.error(), Error::Ok); |
| } |
| } |
| |
| TEST_F(ProgramTest, VerificationCatchesTruncation) { |
| // Get the program data. |
| size_t full_data_len = add_loader_->size().get(); |
| Result<FreeableBuffer> full_data = |
| add_loader_->Load(/*offset=*/0, full_data_len); |
| ASSERT_EQ(full_data.error(), Error::Ok); |
| |
| // Make a loader that only exposes half of the data. |
| BufferDataLoader half_data_loader(full_data->data(), full_data_len / 2); |
| |
| // Loading with full verification should fail. |
| Result<Program> program = Program::load( |
| &half_data_loader, Program::Verification::InternalConsistency); |
| ASSERT_EQ(program.error(), Error::InvalidProgram); |
| } |
| |
| TEST_F(ProgramTest, VerificationCatchesCorruption) { |
| // Make a local copy of the data. |
| size_t data_len = add_loader_->size().get(); |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->Load(/*offset=*/0, data_len); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Corrupt the second half of the data. |
| std::memset(&data[data_len / 2], 0x55, data_len - (data_len / 2)); |
| |
| // Wrap the corrupted data in a loader. |
| BufferDataLoader data_loader(data.get(), data_len); |
| |
| // Should fail to parse corrupted data when using full verification. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::InternalConsistency); |
| ASSERT_EQ(program.error(), Error::InvalidProgram); |
| } |
| |
| TEST_F(ProgramTest, UnalignedProgramDataFails) { |
| // Make a local copy of the data, on an odd alignment. |
| size_t data_len = add_loader_->size().get(); |
| auto data = std::make_unique<char[]>(data_len + 1); |
| { |
| Result<FreeableBuffer> src = add_loader_->Load(/*offset=*/0, data_len); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get() + 1, src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Wrap the offset data in a loader. |
| BufferDataLoader data_loader(data.get() + 1, data_len); |
| |
| // Should refuse to accept unaligned data. |
| Result<Program> program = |
| Program::load(&data_loader, Program::Verification::Minimal); |
| ASSERT_NE(program.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, LoadSegmentWithNoSegments) { |
| // Load a program with no segments. |
| Result<Program> program = |
| Program::load(add_loader_.get(), kDefaultVerification); |
| EXPECT_EQ(program.error(), Error::Ok); |
| |
| // Loading a segment should fail. |
| Result<FreeableBuffer> segment = |
| ProgramTestFriend::LoadSegment(&program.get(), 0); |
| EXPECT_NE(segment.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, ShortDataHeader) { |
| Result<FreeableBuffer> header = |
| add_loader_->Load(/*offset=*/0, Program::kMinHeadBytes); |
| ASSERT_EQ(header.error(), Error::Ok); |
| |
| // Provide less than the required amount of data. |
| EXPECT_EQ( |
| Program::check_header(header->data(), Program::kMinHeadBytes - 1), |
| Program::HeaderStatus::ShortData); |
| } |
| |
| TEST_F(ProgramTest, IncompatibleHeader) { |
| // Make a local copy of the header. |
| size_t data_len = Program::kMinHeadBytes; |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->Load(/*offset=*/0, data_len); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Change the number part of the magic value to a different value. |
| EXPECT_EQ(data[4], 'E'); |
| EXPECT_EQ(data[5], 'T'); |
| EXPECT_TRUE(std::isdigit(data[6])) << "Not a digit: " << data[6]; |
| EXPECT_TRUE(std::isdigit(data[7])) << "Not a digit: " << data[7]; |
| |
| // Modify the tens digit. |
| if (data[6] == '9') { |
| data[6] = '0'; |
| } else { |
| data[6] += 1; |
| } |
| EXPECT_TRUE(std::isdigit(data[6])) << "Not a digit: " << data[6]; |
| |
| // Should count as present but incompatible. |
| EXPECT_EQ( |
| Program::check_header(data.get(), data_len), |
| Program::HeaderStatus::IncompatibleVersion); |
| } |
| |
| TEST_F(ProgramTest, HeaderNotPresent) { |
| // Make a local copy of the header. |
| size_t data_len = Program::kMinHeadBytes; |
| auto data = std::make_unique<char[]>(data_len); |
| { |
| Result<FreeableBuffer> src = add_loader_->Load(/*offset=*/0, data_len); |
| ASSERT_EQ(src.error(), Error::Ok); |
| ASSERT_EQ(src->size(), data_len); |
| memcpy(data.get(), src->data(), data_len); |
| // FreeableBuffer goes out of scope and frees its data. |
| } |
| |
| // Corrupt the magic value. |
| EXPECT_EQ(data[4], 'E'); |
| data[4] = 'X'; |
| EXPECT_EQ(data[5], 'T'); |
| data[5] = 'Y'; |
| |
| // The header is not present. |
| EXPECT_EQ( |
| Program::check_header(data.get(), data_len), |
| Program::HeaderStatus::NotPresent); |
| } |
| |
| TEST_F(ProgramTest, getMethods) { |
| // Parse the Program from the data. |
| Result<Program> program_res = |
| Program::load(multi_loader_.get(), kDefaultVerification); |
| EXPECT_EQ(program_res.error(), Error::Ok); |
| |
| Program program(std::move(program_res.get())); |
| |
| // Method calls should succeed without hitting ET_CHECK. |
| EXPECT_EQ(program.num_methods(), 2); |
| auto res = program.get_method_name(0); |
| EXPECT_TRUE(res.ok()); |
| EXPECT_EQ(strcmp(res.get(), "forward"), 0); |
| auto res2 = program.get_method_name(1); |
| EXPECT_TRUE(res2.ok()); |
| EXPECT_EQ(strcmp(res2.get(), "forward2"), 0); |
| } |
| |
| // Test that the deprecated Load method (capital 'L') still works. |
| TEST_F(ProgramTest, DEPRECATEDLoad) { |
| // Parse the Program from the data. |
| Result<Program> program_res = Program::Load(multi_loader_.get()); |
| EXPECT_EQ(program_res.error(), Error::Ok); |
| } |
| |
| TEST_F(ProgramTest, LoadConstantSegment) { |
| // Load the serialized ModuleLinear data, with constants in the segment and no |
| // constants in the flatbuffer. |
| const char* linear_path = |
| std::getenv("ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH"); |
| Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path); |
| ASSERT_EQ(linear_loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> linear_header = |
| linear_loader->Load(/*offset=*/0, Program::kMinHeadBytes); |
| ASSERT_EQ(linear_header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(linear_header->data(), linear_header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| Result<Program> program = Program::load(&linear_loader.get()); |
| ASSERT_EQ(program.error(), Error::Ok); |
| |
| // Load constant segment data. |
| Result<FreeableBuffer> segment = |
| ProgramTestFriend::LoadSegment(&program.get(), 0); |
| EXPECT_EQ(segment.error(), Error::Ok); |
| |
| const executorch_flatbuffer::Program* flatbuffer_program = |
| ProgramTestFriend::GetInternalProgram(&program.get()); |
| |
| // Expect one segment containing the constants. |
| EXPECT_EQ(flatbuffer_program->segments()->size(), 1); |
| |
| // The constant buffer should be empty. |
| EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0); |
| |
| // Check constant segment offsets. |
| EXPECT_EQ(flatbuffer_program->constant_segment()->segment_index(), 0); |
| EXPECT_GE(flatbuffer_program->constant_segment()->offsets()->size(), 1); |
| } |
| |
| TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) { |
| // Load the serialized ModuleLinear data, with constants in the flatbuffer and |
| // no constants in the segment. |
| const char* linear_path = |
| std::getenv("ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH"); |
| Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path); |
| ASSERT_EQ(linear_loader.error(), Error::Ok); |
| |
| // This file should always be compatible. |
| Result<FreeableBuffer> linear_header = |
| linear_loader->Load(/*offset=*/0, Program::kMinHeadBytes); |
| ASSERT_EQ(linear_header.error(), Error::Ok); |
| EXPECT_EQ( |
| Program::check_header(linear_header->data(), linear_header->size()), |
| Program::HeaderStatus::CompatibleVersion); |
| |
| Result<Program> program = Program::load(&linear_loader.get()); |
| ASSERT_EQ(program.error(), Error::Ok); |
| |
| const executorch_flatbuffer::Program* flatbuffer_program = |
| ProgramTestFriend::GetInternalProgram(&program.get()); |
| |
| // Expect no segments. |
| EXPECT_EQ(flatbuffer_program->segments()->size(), 0); |
| |
| // The constant buffer should exist. |
| EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1); |
| } |