blob: 1e17c75400759a34565d5bf7312bf355db5120e6 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// A simple llama2 runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.
#include <executorch/examples/models/llama2/runner/runner.h>
#include <ctime>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/examples/models/llama2/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
namespace torch::executor {
namespace {
// Names of the optional metadata methods an exported model may provide.
// Runner::load() looks each key up in the module's method list and, when
// present, overrides the constructor-installed default in metadata_.
// Note: `static` was dropped — entities in an anonymous namespace already
// have internal linkage.
constexpr auto kAppendEosToPrompt = "append_eos_to_prompt";
constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
constexpr auto kBosId = "get_bos_id";
constexpr auto kEosIds = "get_eos_ids";
constexpr auto kMaxSeqLen = "get_max_seq_len";
constexpr auto kNBos = "get_n_bos";
constexpr auto kNEos = "get_n_eos";
constexpr auto kVocabSize = "get_vocab_size";
constexpr auto kUseKVCache = "use_kv_cache";
constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
} // namespace
// Constructs a runner for the given model and tokenizer paths. Only cheap
// member initialization happens here; the expensive work (loading the
// forward method, the tokenizer, and the metadata) is deferred to load().
//
// @param model_path Path to the exported model file, loaded via Module.
// @param tokenizer_path Path to the serialized tokenizer artifact.
// @param temperature Sampling temperature, forwarded to the text decoder
//     runner in load().
Runner::Runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    // NOTE: we observed ~2x loading performance increase on iPhone 15
    // and a ~5% improvement on Galaxy S22 by switching to
    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
    : temperature_(temperature),
      module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
      tokenizer_path_(tokenizer_path),
      // Defaults used when the model does not export the matching metadata
      // method; load() overwrites any entry whose method exists.
      metadata_({
          {kAppendEosToPrompt, false},
          {kEnableDynamicShape, false},
          {kMaxSeqLen, 128},
          {kNBos, 1},
          {kNEos, 1},
          {kUseKVCache, true},
          {kUseSDPAWithKVCache, false},
      }) {
  ET_LOG(
      Info,
      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
      model_path.c_str(),
      tokenizer_path.c_str());
}
// Returns true only when the module and every helper component created by
// load() (tokenizer, decoder runner, prefiller, token generator) are ready.
bool Runner::is_loaded() const {
  if (!module_->is_loaded()) {
    return false;
  }
  return tokenizer_ && text_decoder_runner_ && text_prefiller_ &&
      text_token_generator_;
}
// Loads the model's forward method, the tokenizer, and the model metadata,
// then builds the decode/prefill/generation helper pipeline. Idempotent:
// returns immediately if everything is already loaded.
//
// @return Error::Ok on success, or the first failing step's error.
Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

  // Load tokenizer. Assuming tiktoken is the default tokenizer.
  tokenizer_ = get_tiktoken_for_llama();
  Error err = tokenizer_->load(tokenizer_path_);
  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
  // fallback to BPE tokenizer.
  if (err == Error::InvalidArgument) {
    ET_LOG(
        Info,
        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
        tokenizer_path_.c_str());
    tokenizer_ = std::make_unique<BPETokenizer>();
    // Bug fix: the BPE fallback's load status used to be discarded, so a
    // broken artifact went unnoticed until first use. Propagate it instead.
    ET_CHECK_OK_OR_RETURN_ERROR(tokenizer_->load(tokenizer_path_));
  }

  ET_LOG(Info, "Reading metadata from model");

  // Seed metadata derived from the tokenizer itself.
  metadata_[kBosId] = tokenizer_->bos_tok();
  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
  metadata_[kVocabSize] = tokenizer_->vocab_size();

  // For every known metadata key, prefer the value exported by the model;
  // otherwise keep the default installed by the constructor.
  const auto method_names =
      ET_UNWRAP(module_->method_names(), "Failed reading method names");
  for (auto& [method_name, value] : metadata_) {
    if (method_names.count(method_name)) {
      value = ET_UNWRAP(module_->get(method_name))
                  .toScalar()
                  .to<decltype(metadata_)::mapped_type>();
    } else {
      ET_LOG(
          Info,
          "Method %s not found, using the default value %" PRId64,
          method_name.c_str(),
          value);
    }
    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
  }

  // The model may export its own set of EOS ids (e.g. several stop tokens);
  // when it does, that set replaces the tokenizer-derived default.
  if (method_names.count(kEosIds)) {
    eos_ids->clear();
    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
      auto value = eos_id.toScalar().to<int64_t>();
      eos_ids->emplace(value);
      ET_LOG(Info, "eos_id = %" PRId64, value);
    }
  }

  // Wire up the helper pipeline: decoding, prompt prefill, token generation.
  text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
      module_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kVocabSize),
      temperature_);
  text_prefiller_ = std::make_unique<TextPrefiller>(
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kEnableDynamicShape));
  text_token_generator_ = std::make_unique<TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
      std::move(eos_ids),
      &stats_);
  return Error::Ok;
}
// Runs the full text-generation pipeline for one prompt: lazy-load the model
// if needed, tokenize the prompt, prefill the model with it, then decode
// autoregressively, streaming each decoded piece through token_callback.
//
// @param prompt Input text; must be non-empty.
// @param seq_len Requested total sequence length (prompt + generated).
//     Values <= 0 or above the model's max seq len are clamped to the max.
// @param token_callback Optional; called with each decoded text piece
//     (pieces are also printed to stdout regardless).
// @param stats_callback Optional; called once at the end with timing and
//     token-count statistics.
// @param echo When true, the prompt text itself is passed through the
//     callback before generation starts.
// @return Error::Ok on success, or the first failing step's error.
Error Runner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback,
    bool echo) {
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
  // Lazily load on first use so construction stays cheap; time the load.
  if (!is_loaded()) {
    stats_.model_load_start_ms = util::time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    stats_.model_load_end_ms = util::time_in_ms();
  }
  ET_LOG(
      Info,
      "RSS after loading model: %f MiB (0 if unsupported)",
      util::get_rss_bytes() / 1024.0 / 1024.0);
  // Wrap the token_callback with print function: every piece is printed to
  // stdout (safely, flushed immediately for streaming) before the user
  // callback, if any, is invoked.
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        util::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };
  // First token time only measures the time it takes to encode the prompt and
  // return a response token.
  stats_.inference_start_ms = util::time_in_ms();
  shouldStop_ = false;
  // Set the sequence length to the max seq length if not provided or out of
  // range.
  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
      ? seq_len
      : metadata_.at(kMaxSeqLen);
  // Encode the (string) prompt into a token sequence, prepending n_bos BOS
  // tokens and optionally appending n_eos EOS tokens.
  Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
      prompt,
      metadata_.at(kNBos),
      metadata_.at(kAppendEosToPrompt) ? metadata_.at(kNEos) : 0);
  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();
  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
  // The prompt must leave room for at least one generated token, both
  // against the model's max sequence length and the requested seq_len.
  ET_CHECK_MSG(
      num_prompt_tokens < metadata_.at(kMaxSeqLen),
      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
      num_prompt_tokens,
      metadata_.at(kMaxSeqLen));
  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
      num_prompt_tokens,
      seq_len);
  // Prefill first
  // Here feed all tokens to the model and get the next predicted token
  // after the prompt. After that we will enter generate loop.
  // Optionally echo the prompt through the callback before any output.
  if (echo) {
    wrapped_callback(prompt);
  }
  int64_t pos = 0;
  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
  // Timestamps are recorded before the error check so prompt-eval timing is
  // captured as close to the prefill call as possible.
  stats_.first_token_ms = util::time_in_ms();
  stats_.prompt_eval_end_ms = util::time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
  uint64_t cur_token = prefill_res.get();
  // print the first token from prefill. No prev_token so use cur_token for it.
  wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
  ET_LOG(
      Info,
      "RSS after prompt prefill: %f MiB (0 if unsupported)",
      util::get_rss_bytes() / 1024.0 / 1024.0);
  // Start the main generation loop; the first predicted token is appended to
  // the prompt tokens so the generator sees the full context.
  prompt_tokens.push_back(cur_token);
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
  stats_.inference_end_ms = util::time_in_ms();
  printf("\n");
  ET_LOG(
      Info,
      "RSS after finishing text generation: %f MiB (0 if unsupported)",
      util::get_rss_bytes() / 1024.0 / 1024.0);
  // Generation can end either on an EOS token or by hitting seq_len.
  if (num_prompt_tokens + num_generated_tokens == seq_len) {
    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
  }
  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = num_generated_tokens;
  ::executorch::llm::print_report(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }
  return Error::Ok;
}
// Requests that an in-flight generate() call stop producing tokens. A no-op
// (with an error log) when the runner has not been loaded yet, since there
// is no token generator to interrupt.
void Runner::stop() {
  if (!is_loaded()) {
    ET_LOG(Error, "Token generator is not loaded, cannot stop");
    return;
  }
  text_token_generator_->stop();
}
} // namespace torch::executor