blob: 64763c72576a6ffbce4a94069e48d938b08c1a5a [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// A simple LLaVA runner that includes preprocessing and post processing logic.
// The runner takes in a prompt string as well as a list of images as input and
// emits a string as output.
#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/examples/models/llava/runner/llava_text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <ctime>
#include <memory>
#include <sstream>
#include <vector>
using ::executorch::extension::llm::Stats;
namespace torch::executor {
// Reports whether the runner is ready to use: every component created by
// load() must exist, and both the text decoder and the image prefiller must
// report that their method is loaded.
bool LlavaRunner::is_loaded() {
  // A missing component means load() has not (fully) run yet.
  if (!tokenizer_ || !text_decoder_runner_ || !text_prefiller_ ||
      !image_prefiller_ || !text_token_generator_) {
    return false;
  }
  // All components exist; check the decoder first, then the image prefiller.
  if (!text_decoder_runner_->is_method_loaded()) {
    return false;
  }
  return image_prefiller_->is_method_loaded();
}
// Creates and loads every component the runner needs: tokenizer, text decoder
// runner, text prefiller, image prefiller and text token generator.
//
// Returns Error::Ok when the runner is already loaded or when every step
// succeeds. If the tokenizer, the text decoder or the image prefiller fails
// to load, that error is returned immediately instead of being dropped, so a
// failed load is never reported as success.
Error LlavaRunner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  stats_.model_load_start_ms = util::time_in_ms();

  // Load the tokenizer. It must succeed before vocab_size() / eos_tok() are
  // queried below.
  tokenizer_ = std::make_unique<BPETokenizer>();
  ET_CHECK_OK_OR_RETURN_ERROR(tokenizer_->load(tokenizer_path_));

  // Load the text decoder runner.
  text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
      module_.get(), tokenizer_->vocab_size(), temperature_);
  ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());

  // Create the text prefiller on top of the decoder runner.
  text_prefiller_ = std::make_unique<TextPrefiller>(
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      /*enable_parallel_prefill=*/true);

  // Load the image prefiller.
  image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());
  ET_CHECK_OK_OR_RETURN_ERROR(image_prefiller_->load());

  // Create the text token generator; the tokenizer's EOS token is the only
  // entry in the EOS id set handed to it.
  text_token_generator_ = std::make_unique<TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      std::make_unique<std::unordered_set<uint64_t>>(
          std::unordered_set<uint64_t>{tokenizer_->eos_tok()}),
      &stats_);

  stats_.model_load_end_ms = util::time_in_ms();
  return Error::Ok;
}
// Prefills each image in `images`, in order, through the image prefiller.
// `start_pos` is handed to the prefiller by reference so the position can be
// advanced as images are consumed. The first prefill failure is returned to
// the caller; otherwise Error::Ok is returned.
Error LlavaRunner::prefill_images(
    std::vector<Image>& images,
    int64_t& start_pos) {
  for (Image& current_image : images) {
    // The prefill output is not needed here; only the error (if any) and the
    // updated position matter.
    ET_UNWRAP(image_prefiller_->prefill(current_image, start_pos));
  }
  return Error::Ok;
}
// Tokenizes `prompt` (with the requested number of BOS / EOS tokens) and
// prefills the resulting tokens through the text prefiller, starting at
// `start_pos`. Returns whatever the text prefiller returns, or the tokenizer
// error if encoding fails.
Result<uint64_t> LlavaRunner::prefill_prompt(
    const std::string& prompt,
    int64_t& start_pos,
    int8_t bos,
    int8_t eos) {
  auto encoded = tokenizer_->encode(prompt, bos, eos);
  if (!encoded.ok()) {
    return encoded.error();
  }
  std::vector<uint64_t> token_ids = std::move(*encoded);
  return text_prefiller_->prefill(token_ids, start_pos);
}
// Prefills the user `prompt` starting at `start_pos` and then generates
// tokens until the token generator stops.
//
// Parameters:
//   prompt         - user prompt text; it is echoed through `token_callback`
//                    before being prefilled.
//   seq_len        - sequence length forwarded to the token generator.
//   start_pos      - position at which the prompt is prefilled; taken by
//                    value, so the caller's position is not modified.
//   token_callback - receives the echoed prompt and each generated piece.
//                    May be empty, in which case output is discarded.
//   stats_callback - optional; invoked with the collected stats at the end.
//
// Returns Error::Ok on success, or the first error produced by prompt
// prefill or token generation.
Error LlavaRunner::generate_from_pos(
    const std::string& prompt,
    int32_t seq_len,
    int64_t start_pos,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const ::executorch::extension::llm::Stats&)>
        stats_callback) {
  // Invoking an empty std::function is an error, and the callback is used
  // both here and by the token generator, so substitute a no-op when the
  // caller did not provide one.
  if (!token_callback) {
    token_callback = [](const std::string&) {};
  }

  // Echo the prompt, then prefill it. No BOS because the preset prompt
  // already has it.
  token_callback(prompt);
  uint64_t prefill_next_token =
      ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
  // start_pos was advanced by the prefill above.
  stats_.num_prompt_tokens = start_pos;

  // Generate tokens
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      {prefill_next_token}, start_pos, seq_len, token_callback));

  // Bookkeeping
  stats_.num_generated_tokens = num_generated_tokens;
  ::executorch::llm::print_report(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }
  return Error::Ok;
}
// End-to-end generation: loads the model if needed, prefills the preset
// prompt and the images, then prefills the user prompt and generates text.
//
// Parameters:
//   images         - images to prefill after the preset prompt.
//   prompt         - user prompt; must not be empty (checked with
//                    ET_CHECK_MSG).
//   seq_len        - sequence length forwarded to generate_from_pos().
//   token_callback - optional; receives every piece after it is printed to
//                    stdout.
//   stats_callback - optional; forwarded to generate_from_pos().
//
// Returns the first error from loading, preset-prompt prefill, image prefill
// or generation; Error::Ok otherwise. Prefill failures stop the run instead
// of being ignored, so generation never proceeds from a failed prefill.
Error LlavaRunner::generate(
    std::vector<Image> images,
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }
  ET_LOG(
      Info,
      "RSS after loading model: %f MiB (0 if unsupported)",
      util::get_rss_bytes() / 1024.0 / 1024.0);

  // Wrap the token_callback with print function
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        util::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };

  int64_t pos = 0;

  // Prefill the preset prompt. Only the error and the advanced position
  // matter here; the returned token is not used.
  const auto preset_result =
      prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
  if (!preset_result.ok()) {
    return preset_result.error();
  }

  // Prefill images; stop on the first failure.
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_images(images, pos));

  ET_LOG(
      Info,
      "RSS after prompt and image prefill: %f MiB (0 if unsupported)",
      util::get_rss_bytes() / 1024.0 / 1024.0);

  // Generate tokens
  Error err =
      generate_from_pos(prompt, seq_len, pos, wrapped_callback, stats_callback);
  ET_LOG(
      Info,
      "RSS after finishing text generation: %f MiB (0 if unsupported)",
      util::get_rss_bytes() / 1024.0 / 1024.0);
  return err;
}
} // namespace torch::executor