| /* |
| * Copyright (c) Meta Platforms, Inc. and affiliates. |
| * All rights reserved. |
| * |
| * This source code is licensed under the BSD-style license found in the |
| * LICENSE file in the root directory of this source tree. |
| */ |
| |
| /** |
| * @file |
| * |
| * Creates multiple Executor instances at the same time, demonstrating that the |
| * same process can handle multiple runtimes at once. |
| * |
| * Usage: |
| * multi_runner --models=<model.pte>[,<m2.pte>[,...]] [--num_instances=<num>] |
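| * |
| * Example invocation (model file names here are hypothetical): |
| *   multi_runner --models=add.pte,mul.pte --num_instances=4 |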
| */ |
| |
| #include <gflags/gflags.h> |
| |
| #include <sys/stat.h> |
| |
| #include <cassert> |
| #include <cinttypes> |
| #include <condition_variable> |
| #include <cstdio> |
| #include <functional> |
| #include <memory> |
| #include <sstream> |
| #include <thread> |
| #include <tuple> |
| |
| #include <executorch/extension/data_loader/buffer_data_loader.h> |
| #include <executorch/runtime/core/error.h> |
| #include <executorch/runtime/core/result.h> |
| #include <executorch/runtime/executor/executor.h> |
| #include <executorch/runtime/executor/test/managed_memory_manager.h> |
| #include <executorch/runtime/platform/log.h> |
| #include <executorch/runtime/platform/runtime.h> |
| #include <executorch/util/read_file.h> |
| #include <executorch/util/util.h> |
| |
| DEFINE_string( |
| models, |
| "", |
| "Comma-separated list of paths to serialized Executorch model files"); |
| DEFINE_int32( |
| num_instances, |
| 10, |
| "Number of Executor instances to create in parallel, for each model"); |
| |
| static bool validate_path_list( |
| const char* flagname, |
| const std::string& path_list); |
| DEFINE_validator(models, &validate_path_list); |
| |
| static bool validate_positive_int32(const char* flagname, int32_t val); |
| DEFINE_validator(num_instances, &validate_positive_int32); |
| |
| namespace { |
| using torch::executor::DataLoader; |
| using torch::executor::Error; |
| using torch::executor::Executor; |
| using torch::executor::Program; |
| using torch::executor::Result; |
| using torch::executor::testing::ManagedMemoryManager; |
| using torch::executor::util::BufferDataLoader; |
| |
| /** |
| * A model that has been loaded and has had its execution plan and inputs |
| * prepared. Can be run once. |
| * |
| * Creates and owns the underlying state, making lifetimes easier to manage. |
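| * |
| * Usage sketch (data pointer and sizes are whatever the caller provides): |
| *   PreparedModel model("add", data, data_size, non_const_bytes, rt_bytes); |
| *   model.run();  // Executes the plan; valid exactly once. |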
| */ |
| class PreparedModel final { |
| public: |
| PreparedModel( |
| const std::string& name, |
| const void* model_data, |
| size_t model_data_size, |
| size_t non_const_mem_bytes, |
| size_t runtime_mem_bytes) |
| : name_(name), |
| loader_(model_data, model_data_size), |
| program_(load_program_or_die(loader_)), |
| memory_manager_(non_const_mem_bytes, runtime_mem_bytes), |
| executor_(&program_, &memory_manager_.get()), |
| has_run_(false) { |
| Error status = executor_.init_execution_plan(); |
| ET_CHECK_MSG( |
| status == Error::Ok, |
| "init_execution_plan() failed with status 0x%" PRIx32, |
| static_cast<uint32_t>(status)); |
| inputs_ = |
| torch::executor::util::PrepareInputTensors(executor_.execution_plan()); |
| } |
| |
| void run() { |
| ET_CHECK_MSG(!has_run_, "A PreparedModel may only be run once"); |
| has_run_ = true; |
| |
| Error status = executor_.execution_plan().execute(); |
| ET_CHECK_MSG( |
| status == Error::Ok, |
| "plan.execute() failed with status 0x%" PRIx32, |
| static_cast<uint32_t>(status)); |
| |
| // TODO(T131578656): Do something with the outputs. |
| } |
| |
| const std::string& name() const { |
| return name_; |
| } |
| |
| ~PreparedModel() { |
| torch::executor::util::FreeInputs(inputs_); |
| } |
| |
| private: |
| static Program load_program_or_die(DataLoader& loader) { |
| Result<Program> program = Program::Load(&loader); |
| ET_CHECK(program.ok()); |
| return std::move(program.get()); |
| } |
| |
| const std::string name_; |
| BufferDataLoader loader_; // Needs to outlive program_ |
| Program program_; // Needs to outlive executor_ |
| ManagedMemoryManager memory_manager_; // Needs to outlive executor_ |
| Executor executor_; |
| exec_aten::ArrayRef<void*> inputs_; |
| |
| bool has_run_; |
| }; |
| |
| /** |
| * Creates PreparedModels based on the provided serialized data and memory |
| * parameters. |
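| * |
| * Usage sketch (model bytes assumed already loaded into memory): |
| *   ModelFactory factory("add.pte", model_data, model_data_size); |
| *   std::unique_ptr<PreparedModel> m = factory.prepare("0");  // "0:add.pte" |
| *   m->run(); |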
| */ |
| class ModelFactory { |
| public: |
| ModelFactory( |
| const std::string& name, // For debugging |
| std::shared_ptr<const char> model_data, |
| size_t model_data_size, |
| size_t non_const_mem_bytes = 40 * 1024U * 1024U, // 40 MB |
| size_t runtime_mem_bytes = 2 * 1024U * 1024U) // 2 MB |
| : name_(name), |
| model_data_(model_data), |
| model_data_size_(model_data_size), |
| non_const_mem_bytes_(non_const_mem_bytes), |
| runtime_mem_bytes_(runtime_mem_bytes) {} |
| |
| std::unique_ptr<PreparedModel> prepare( |
| std::string_view name_affix = "") const { |
| return std::make_unique<PreparedModel>( |
| name_affix.empty() ? name_ : std::string(name_affix) + ":" + name_, |
| model_data_.get(), |
| model_data_size_, |
| non_const_mem_bytes_, |
| runtime_mem_bytes_); |
| } |
| |
| const std::string& name() const { |
| return name_; |
| } |
| |
| private: |
| const std::string name_; |
| std::shared_ptr<const char> model_data_; |
| |
| const size_t model_data_size_; |
| const size_t non_const_mem_bytes_; |
| const size_t runtime_mem_bytes_; |
| }; |
| |
| /// Synchronizes a set of model threads as they walk through prepare/run states. |
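| /// |
| /// Usage sketch: behaves like a reusable barrier keyed on State. Every thread |
| /// calls advance_to() with the same sequence of states, and each call blocks |
| /// until all of the threads have made the matching call: |
| /// |
| ///   Synchronizer sync(/*total_threads=*/2); |
| ///   // Executed concurrently by both threads: |
| ///   sync.advance_to(Synchronizer::State::PREPARE_MODEL);  // both block here |
| ///   sync.advance_to(Synchronizer::State::RUN_MODEL);      // and again here |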
| class Synchronizer { |
| public: |
| explicit Synchronizer(size_t total_threads) |
| : total_threads_(total_threads), state_(State::INIT_THREAD) {} |
| |
| /// The states for threads to move through. Must advance in order. |
| enum class State { |
| /// Initial state. |
| INIT_THREAD, |
| |
| /// Thread is ready to prepare its model instance. |
| PREPARE_MODEL, |
| |
| /// Thread is ready to run its model instance. |
| RUN_MODEL, |
| }; |
| |
| /// Wait until all threads have requested to advance to this state, then |
| /// advance all of them. |
| void advance_to(State new_state) { |
| std::unique_lock<std::mutex> lock(lock_); |
| |
| // Enforce valid state machine transitions. |
| assert( |
| (new_state == State::PREPARE_MODEL && state_ == State::INIT_THREAD) || |
| (new_state == State::RUN_MODEL && state_ == State::PREPARE_MODEL)); |
| |
| // Indicate that this thread is ready to move to the new state. |
| num_ready_++; |
| if (num_ready_ == total_threads_) { |
| // We were the last thread to become ready. Tell all threads to |
| // move to the next state. |
| state_ = new_state; |
| num_ready_ = 0; |
| cv_.notify_all(); |
| } else { |
| // Wait until all other threads are ready. |
| cv_.wait(lock, [this, new_state] { return state_ == new_state; }); |
| } |
| } |
| |
| private: |
| /// The total number of threads to wait for. |
| const size_t total_threads_; |
| |
| /// Locks all mutable fields in this class. |
| std::mutex lock_; |
| |
| /// The number of threads that are ready to move to the next state. |
| size_t num_ready_ = 0; |
| |
| /// The state that all threads should be in. |
| State state_; |
| |
| /// Signals threads to check for state updates. |
| std::condition_variable cv_; |
| }; |
| |
| /** |
| * Waits for all threads to begin running; prepares a model and waits for all |
| * threads to finish preparation; runs the model and exits. |
| */ |
| void model_thread(ModelFactory& factory, Synchronizer& sync, size_t thread_id) { |
| ET_LOG( |
| Info, |
| "[%zu] Thread has started for %s.", |
| thread_id, |
| factory.name().c_str()); |
| |
| sync.advance_to(Synchronizer::State::PREPARE_MODEL); |
| |
| // Create and prepare our model instance. |
| ET_LOG(Info, "[%zu] Preparing %s...", thread_id, factory.name().c_str()); |
| std::unique_ptr<PreparedModel> model = |
| factory.prepare(/*name_affix=*/std::to_string(thread_id)); |
| ET_LOG(Info, "[%zu] Prepared %s.", thread_id, model->name().c_str()); |
| |
| sync.advance_to(Synchronizer::State::RUN_MODEL); |
| |
| // Run our model. |
| ET_LOG(Info, "[%zu] Running %s...", thread_id, model->name().c_str()); |
| model->run(); |
| ET_LOG( |
| Info, "[%zu] Finished running %s...", thread_id, model->name().c_str()); |
| |
| // TODO(T131578656): Check the model output. |
| } |
| |
| /** |
| * Splits the provided string on `,` and returns a vector of the non-empty |
| * elements. Does not strip whitespace. |
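| * |
| * For example: |
| *   split_string_list("a.pte,,b.pte") -> {"a.pte", "b.pte"} |
| *   split_string_list(" a ,b")        -> {" a ", "b"} |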
| */ |
| std::vector<std::string> split_string_list(const std::string& list) { |
| std::vector<std::string> items; |
| std::stringstream sstream(list); |
| while (sstream.good()) { |
| std::string item; |
| std::getline(sstream, item, ','); |
| if (!item.empty()) { |
| items.push_back(item); |
| } |
| } |
| return items; |
| } |
| |
| } // namespace |
| |
| int main(int argc, char** argv) { |
| torch::executor::runtime_init(); |
| |
| // Parse and extract flags. |
| gflags::SetUsageMessage( |
| "Creates multiple Executor instances at the same time, demonstrating " |
| "that the same process can handle multiple runtimes at once."); |
| gflags::ParseCommandLineFlags(&argc, &argv, true); |
| std::vector<std::string> model_paths = split_string_list(FLAGS_models); |
| size_t num_instances = FLAGS_num_instances; |
| |
| // Create a factory for each model provided on the commandline. |
| std::vector<std::unique_ptr<ModelFactory>> factories; |
| for (const auto& model_path : model_paths) { |
| std::shared_ptr<char> file_data; |
| size_t file_size; |
| Error err = torch::executor::util::read_file_content( |
| model_path.c_str(), &file_data, &file_size); |
| ET_CHECK(err == Error::Ok); |
| factories.push_back(std::make_unique<ModelFactory>( |
| /*name=*/model_path, file_data, file_size)); |
| } |
| |
| // Spawn threads to prepare and run separate instances of the models in |
| // parallel. |
| const size_t num_threads = factories.size() * num_instances; |
| Synchronizer sync(num_threads); |
| std::vector<std::thread> threads; |
| size_t thread_id = 0; // Unique ID for every thread. |
| ET_LOG(Info, "Creating %zu threads...", num_threads); |
| for (const auto& factory : factories) { |
| for (size_t i = 0; i < num_instances; ++i) { |
| threads.emplace_back( |
| model_thread, std::ref(*factory), std::ref(sync), thread_id++); |
| } |
| } |
| |
| // Wait for all threads to finish. |
| ET_LOG(Info, "Waiting for %zu threads to exit...", threads.size()); |
| for (auto& thread : threads) { |
| thread.join(); |
| } |
| ET_LOG(Info, "All %zu threads exited.", threads.size()); |
| } |
| |
| // |
| // Flag validation |
| // |
| |
| /// Returns true if the specified path exists in the filesystem. |
| static bool path_exists(const std::string& path) { |
| struct stat st; |
| return stat(path.c_str(), &st) == 0; |
| } |
| |
| /// Returns true if `path_list` contains a comma-separated list of at least one |
| /// path that exists in the filesystem. |
| static bool validate_path_list( |
| const char* flagname, |
| const std::string& path_list) { |
| const std::vector<std::string> paths = split_string_list(path_list); |
| if (paths.empty()) { |
| fprintf( |
| stderr, "Must specify at least one valid path with --%s\n", flagname); |
| return false; |
| } |
| for (const auto& path : split_string_list(path_list)) { |
| if (!path_exists(path)) { |
| fprintf( |
| stderr, |
| "Path '%s' does not exist in --%s='%s'\n", |
| path.c_str(), |
| flagname, |
| path_list.c_str()); |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| /// Returns true if `val` is positive. |
| static bool validate_positive_int32(const char* flagname, int32_t val) { |
| if (val <= 0) { |
| fprintf( |
| stderr, "Value of --%s must be positive; got %" PRId32 "\n", flagname, val); |
| return false; |
| } |
| return true; |
| } |