/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// A simple llama2 runner that includes preprocessing and post-processing logic.
// The module takes in a string as input and emits a string as output.
#include <executorch/examples/models/llama2/runner/runner.h>
#include <executorch/extension/runner_util/managed_tensor.h>
#include <algorithm>
#include <cinttypes>
#include <cstring>
#include <ctime>
#include <memory>
#include <sstream>
#ifdef USE_ATEN_LIB
#include <torch/torch.h>
#endif
#include <executorch/examples/models/llama2/runner/util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>
namespace torch::executor {
namespace {
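// kTopp configures nucleus (top-p) sampling: the sampler restricts its draw
// to the smallest set of tokens whose cumulative probability exceeds 0.9,
// cutting off the low-probability tail of the distribution.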
static constexpr auto kTopp = 0.9f;
} // namespace
Runner::Runner(
const std::string& model_path,
const std::string& tokenizer_path,
const float temperature)
: module_(std::make_unique<Module>(
model_path,
Module::MlockConfig::UseMlockIgnoreErrors)),
tokenizer_path_(tokenizer_path),
temperature_(temperature) {
ET_LOG(
Info,
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
model_path.c_str(),
tokenizer_path.c_str());
}
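// Example usage (an illustrative sketch; the paths, prompt, and temperature
// below are hypothetical placeholders):
//
//   Runner runner("llama2.pte", "tokenizer.bin", /*temperature=*/0.8f);
//   runner.load();
//   runner.generate("Once upon a time", /*seq_len=*/128,
//       [](const std::string& piece) { fputs(piece.c_str(), stdout); });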
bool Runner::is_loaded() const {
return module_->is_loaded() && tokenizer_ && sampler_;
}
Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
}
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
// Read out metadata from the model: vocab_size (expected by the model),
// BOS, EOS, n_BOS, n_EOS, and max_seq_len.
ET_LOG(Info, "Reading metadata from model");
const auto method_names = module_->method_names();
ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
model_methods_ = method_names.get();
vocab_size_ = getMetadataHelper<int64_t>("get_vocab_size", 32000);
bos_id_ = getMetadataHelper<int64_t>("get_bos_id", 1);
eos_id_ = getMetadataHelper<int64_t>("get_eos_id", 2);
n_bos_ = getMetadataHelper<int64_t>("get_n_bos", 1);
n_eos_ = getMetadataHelper<int64_t>("get_n_eos", 1);
max_seq_len_ = getMetadataHelper<int64_t>("get_max_seq_len", 128);
use_kv_cache_ = getMetadataHelper("use_kv_cache", false);
use_sdpa_with_kv_cache_ = getMetadataHelper("use_sdpa_with_kv_cache", false);
append_eos_ = getMetadataHelper("append_eos_to_prompt", false);
// Load tokenizer
tokenizer_ = std::make_unique<Tokenizer>(vocab_size_, bos_id_, eos_id_);
ET_CHECK_OK_OR_RETURN_ERROR(tokenizer_->load(tokenizer_path_));
if (tokenizer_->bos_tok() != bos_id_) {
ET_LOG(
Error,
"Tokenizer's BOS id %d does not match model's BOS id %d, will override tokenizer's BOS.",
tokenizer_->bos_tok(),
bos_id_);
}
if (tokenizer_->eos_tok() != eos_id_) {
ET_LOG(
Error,
"Tokenizer's EOS id %d does not match model's EOS id %d, will override tokenizer's EOS.",
tokenizer_->eos_tok(),
eos_id_);
}
// Create sampler
sampler_ = std::make_unique<Sampler>(
vocab_size_,
temperature_,
kTopp,
static_cast<unsigned long long>(std::time(nullptr)));
return Error::Ok;
}
template <typename T>
T Runner::getMetadataHelper(std::string method_name, T default_val) {
T res = default_val;
if (model_methods_.count(method_name)) {
Result<std::vector<EValue>> outputs = module_->execute(method_name);
if (outputs.ok()) {
std::vector<EValue> outs = outputs.get();
if (outs.size() > 0) {
res = outs[0].to<T>();
}
}
} else {
ET_LOG(
Info,
"The model does not contain %s method, using default value %lld",
method_name.c_str(),
(long long)default_val);
}
ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res);
return res;
}
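// For example, the call made in load() above,
//
//   getMetadataHelper<int64_t>("get_vocab_size", 32000);
//
// runs the exported "get_vocab_size" method when the program provides one
// and returns its first output; otherwise it returns the default of 32000.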
std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
// shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
// self.head_dim)
std::vector<std::string> methods = {
"get_n_layers",
"get_max_batch_size",
"get_max_seq_len",
"get_n_kv_heads",
"get_head_dim"};
std::vector<int64_t> default_values = {12, 1, 128, 32, 128};
std::vector<exec_aten::SizesType> result;
for (size_t i = 0; i < methods.size(); ++i) {
// getMetadataHelper returns int64_t; push_back narrows it to
// exec_aten::SizesType (int32_t).
result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
}
return result;
}
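// With the default values above, each cache buffer sized in generate() works
// out to (illustrative arithmetic, assuming float32 elements):
//
//   12 * 1 * 128 * 32 * 128 = 6,291,456 elements
//   6,291,456 * 4 bytes     = 25,165,824 bytes (~24 MiB)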
template <typename T>
int32_t Runner::logitsToToken(
const exec_aten::Tensor& logits_tensor,
int64_t pos,
T _) {
(void)_;
T* logits = logits_tensor.mutable_data_ptr<T>();
// Since the logits are for all tokens, get the last token probabilities
T* logits_last = logits;
if (!use_kv_cache_) {
logits_last += pos * tokenizer_->vocab_size();
}
return sampler_->sample(logits_last);
}
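// Layout note: without a KV cache the model emits logits for every input
// position as a [1, seq_len, vocab_size] tensor, so the last token's row
// begins at offset pos * vocab_size; with a KV cache only the current
// token's [1, 1, vocab_size] logits are returned and no offset is needed.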
Error Runner::generate(
const std::string& prompt,
int32_t seq_len,
std::function<void(const std::string&)> callback) {
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
if (!is_loaded()) {
timers_.model_load_start_ms = util::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(load());
timers_.model_load_end_ms = util::time_in_ms();
}
// First token time only measures the time it takes to encode the prompt and
// return a response token.
timers_.inference_start_ms = util::time_in_ms();
shouldStop_ = false;
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
// Upper bound on the number of prompt tokens: one token per prompt byte,
// plus the trailing '\0', plus any BOS/EOS tokens that may be added.
int* prompt_tokens = new int[prompt.size() + 1 + n_bos_ + n_eos_];
// Set the sequence length to the max seq length if not provided
seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
tokenizer_->encode(
prompt.c_str(),
n_bos_,
append_eos_ ? n_eos_ : 0,
prompt_tokens,
&num_prompt_tokens);
for (int i = 0; i < num_prompt_tokens; i++) {
ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
}
ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
ET_CHECK_MSG(
num_prompt_tokens < max_seq_len_,
"Max seq length exceeded - please increase max seq len value in .../llama2/model.py");
ET_CHECK_MSG(
num_prompt_tokens < seq_len,
"Sequence length exceeded - please increase the seq_len value passed to generate()");
// start the main loop
int next; // will store the next token in the sequence
int64_t pos = num_prompt_tokens - 1; // position in the sequence
int token = prompt_tokens[pos]; // current token; prefill covers tokens 0..num_prompt_tokens-1
int logits_index = 0; // index of the logits tensor in the output
int k_cache_index = 0;
int v_cache_index = 0;
std::vector<exec_aten::SizesType> kv_cache_shape = getKVCacheShape();
std::vector<exec_aten::SizesType> input_shape = {1, 1};
std::vector<uint8_t> k_data;
std::vector<uint8_t> v_data;
std::vector<int64_t> token_data; // allocate space for the tokens
ScalarType dtype = static_cast<ScalarType>(
getMetadataHelper("get_dtype", (int64_t)ScalarType::Float));
if (use_kv_cache_) {
// set pos to 0, prefill token by token
pos = 0;
logits_index = 2;
k_cache_index = 0;
v_cache_index = 1;
// TODO(): Fix this by inspecting graph signature
if (use_sdpa_with_kv_cache_) {
logits_index = 0;
k_cache_index = 1;
v_cache_index = 2;
}
// initialize kv cache
size_t n_bytes = 1;
for (exec_aten::SizesType shape : kv_cache_shape) {
n_bytes *= shape;
}
n_bytes *= torch::executor::elementSize(dtype);
k_data.resize(n_bytes);
std::fill(k_data.begin(), k_data.end(), 0);
v_data.resize(n_bytes);
std::fill(v_data.begin(), v_data.end(), 0);
token_data.resize(1);
} else {
// allocate space for all seq_len tokens up front
token_data.resize(seq_len);
}
// initialize tensor wrappers
ManagedTensor k_managed(k_data.data(), k_data.size(), kv_cache_shape, dtype);
ManagedTensor v_managed(v_data.data(), v_data.size(), kv_cache_shape, dtype);
ManagedTensor pos_managed(&pos, 0, {}, ScalarType::Long);
// copy prompt tokens into data
for (int i = 0; i <= pos; ++i) {
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
token_data[i] = prompt_tokens[i];
if (i > 0) {
printf(
"%s",
ET_UNWRAP(
tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
}
}
// Main generation loop: run the model one step at a time until seq_len.
while (pos < seq_len) {
// ET_LOG(Info, "Generating step %d...", pos);
// set the current token in the tensor
std::vector<EValue> inputs;
if (use_kv_cache_) {
token_data[0] = token;
input_shape[1] = 1;
// inputs: [tokens, start_pos, k_cache, v_cache]
inputs.emplace_back(pos_managed.get_aliasing_tensor());
inputs.emplace_back(k_managed.get_aliasing_tensor());
inputs.emplace_back(v_managed.get_aliasing_tensor());
} else {
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
token_data[pos] = token;
input_shape[1] = pos + 1;
}
ManagedTensor token_managed(
token_data.data(), token_data.size(), input_shape, ScalarType::Long);
inputs.insert(inputs.begin(), token_managed.get_aliasing_tensor());
// For kv cache, inputs: [tokens, start_pos, k_cache, v_cache]
// Otherwise inputs: [tokens]
Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
ET_CHECK_MSG(
outputs_res.ok(),
"Execution of method forward failed with status 0x%" PRIx32,
static_cast<int32_t>(outputs_res.error()));
// ET_LOG(Info, "Model executed successfully.");
std::vector<EValue> outputs = outputs_res.get();
// Check the outputs.
ET_CHECK_MSG(
outputs.size() > 0,
"Expecting output to have at least one evalue. Got %zu",
outputs.size());
if (pos == num_prompt_tokens) {
timers_.first_token_ms = util::time_in_ms();
} else if (pos == num_prompt_tokens - 1) {
timers_.prompt_eval_end_ms = util::time_in_ms();
}
int32_t next_tok;
exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
long sample_start_time_ms = util::time_in_ms();
switch (logits_tensor.scalar_type()) {
case ScalarType::Float: {
next_tok = logitsToToken<float>(logits_tensor, pos, 0);
break;
}
case ScalarType::Half: {
next_tok = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
break;
}
default:
ET_CHECK_MSG(
false,
"Unsupported dtype output %hhd",
static_cast<int8_t>(logits_tensor.scalar_type()));
}
timers_.aggregate_sampling_time_ms +=
util::time_in_ms() - sample_start_time_ms;
// advance the state machine
if (pos < num_prompt_tokens - 1) {
// prefill, force the next token to be the next prompt token
next = prompt_tokens[pos + 1];
} else {
// otherwise sample the next token from the logits
next = next_tok;
}
// ET_LOG(Info, "Output saved, next = %d", next);
pos++;
// print the token as string, decode it with the Tokenizer object
auto piece_res = tokenizer_->decode(token, next);
ET_CHECK(piece_res.ok());
const char* piece = piece_res.get();
// same as printf("%s", piece), but skips "unsafe" bytes
util::safe_printf(piece);
fflush(stdout);
if (callback) {
callback(piece);
}
if (shouldStop_) {
break;
}
// data-dependent terminating condition: an EOS token generated after the
// prompt ends the sequence
if (pos >= num_prompt_tokens && next == eos_id_) {
printf("\n");
ET_LOG(Info, "\nReached to the end of generation");
break;
}
token = next;
if (use_kv_cache_) {
// Copy the updated KV caches from the model outputs back into the input
// buffers for the next step; k_cache_index/v_cache_index were resolved
// above based on use_sdpa_with_kv_cache_.
memcpy(
k_data.data(),
outputs.at(k_cache_index).toTensor().const_data_ptr(),
k_data.size());
memcpy(
v_data.data(),
outputs.at(v_cache_index).toTensor().const_data_ptr(),
v_data.size());
}
}
timers_.inference_end_ms = util::time_in_ms();
printf("\n");
if (pos == seq_len) {
ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
}
timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
delete[] prompt_tokens;
return Error::Ok;
}
void Runner::TimeStamps::printReport(
const int64_t& num_prompt_tokens,
const int64_t& num_generated_tokens) {
printf(
"PyTorchObserver %s\n",
toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
ET_LOG(
Info,
"\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
num_prompt_tokens,
num_generated_tokens);
ET_LOG(
Info,
"\tModel Load Time:\t\t%f (seconds)",
((double)(model_load_end_ms - model_load_start_ms) /
SCALING_FACTOR_UNITS_PER_SECOND));
double inference_time_ms = (double)(inference_end_ms - inference_start_ms);
ET_LOG(
Info,
"\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
inference_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
(num_generated_tokens) / inference_time_ms *
SCALING_FACTOR_UNITS_PER_SECOND);
double prompt_eval_time = (double)(prompt_eval_end_ms - inference_start_ms);
ET_LOG(
Info,
"\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
prompt_eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
(num_prompt_tokens) / prompt_eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
double eval_time = (double)(inference_end_ms - prompt_eval_end_ms);
ET_LOG(
Info,
"\t\tGenerated %" PRIu64
" tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
num_generated_tokens,
eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
num_generated_tokens / eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
// Time to first token is measured from the start of inference, excluding
// model load time.
ET_LOG(
Info,
"\tTime to first generated token:\t%f (seconds)",
((double)(first_token_ms - inference_start_ms) /
SCALING_FACTOR_UNITS_PER_SECOND));
ET_LOG(
Info,
"\tSampling time over %" PRIu64
" tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
num_prompt_tokens + num_generated_tokens,
(double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
(num_prompt_tokens + num_generated_tokens) /
(double)aggregate_sampling_time_ms * SCALING_FACTOR_UNITS_PER_SECOND);
}
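// Worked example with hypothetical numbers, assuming the millisecond
// timestamps used throughout (SCALING_FACTOR_UNITS_PER_SECOND = 1000):
// 64 generated tokens over an 8000 ms decode window give
//
//   64 / 8000.0 * 1000 = 8.0 tokens/second
//
// which is the rate formula applied in each ET_LOG line above.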
const std::string Runner::TimeStamps::toJsonString(
const int64_t& num_prompt_tokens,
const int64_t& num_generated_tokens) {
std::stringstream ss;
ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
<< "\"generated_tokens\":" << num_generated_tokens << ","
<< "\"model_load_start_ms\":" << model_load_start_ms << ","
<< "\"model_load_end_ms\":" << model_load_end_ms << ","
<< "\"inference_start_ms\":" << inference_start_ms << ","
<< "\"inference_end_ms\":" << inference_end_ms << ","
<< "\"prompt_eval_end_ms\":" << prompt_eval_end_ms << ","
<< "\"first_token_ms\":" << first_token_ms << ","
<< "\"aggregate_sampling_time_ms\":" << aggregate_sampling_time_ms << ","
<< "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
<< SCALING_FACTOR_UNITS_PER_SECOND << "}";
return ss.str();
}
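// A typical observer line printed by printReport() looks like this (values
// are illustrative):
//
//   PyTorchObserver {"prompt_tokens":10,"generated_tokens":64,...,
//       "SCALING_FACTOR_UNITS_PER_SECOND":1000}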
void Runner::stop() {
shouldStop_ = true;
}
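// stop() only raises a flag that generate() polls after each emitted token,
// so generation halts at the next token boundary. A hypothetical caller-side
// sketch (note shouldStop_ is a plain bool; strictly portable cross-thread
// use would want std::atomic<bool>):
//
//   std::thread watchdog([&runner] {
//     std::this_thread::sleep_for(std::chrono::seconds(30));
//     runner.stop(); // cap generation at a 30-second budget
//   });
//   ...
//   watchdog.join();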
// explicit instantiation of template methods
template int64_t Runner::getMetadataHelper<int64_t>(
std::string method_name,
int64_t default_val);
template bool Runner::getMetadataHelper<bool>(
std::string method_name,
bool default_val);
} // namespace torch::executor