/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cassert>
#include <chrono>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <executorch/examples/models/llama2/runner/runner.h>
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/runtime.h>
#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#ifdef __ANDROID__
#include <android/log.h>
// For Android, forward ExecuTorch PAL log messages to logcat by overriding
// the et_pal_emit_log_message() hook.
void et_pal_emit_log_message(
et_timestamp_t timestamp,
et_pal_log_level_t level,
const char* filename,
const char* function,
size_t line,
const char* message,
size_t length) {
int android_log_level = ANDROID_LOG_UNKNOWN;
if (level == 'D') {
android_log_level = ANDROID_LOG_DEBUG;
} else if (level == 'I') {
android_log_level = ANDROID_LOG_INFO;
} else if (level == 'E') {
android_log_level = ANDROID_LOG_ERROR;
} else if (level == 'F') {
android_log_level = ANDROID_LOG_FATAL;
}
__android_log_print(android_log_level, "LLAMA", "%s", message);
}
#endif
using namespace torch::executor;
namespace executorch_jni {
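// JNI wrapper around the Java org.pytorch.executorch.LlamaCallback interface.
// Forwards generated tokens and decoding statistics from the native runner
// back to the Java callback object.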
class ExecuTorchLlamaCallbackJni
: public facebook::jni::JavaClass<ExecuTorchLlamaCallbackJni> {
public:
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/LlamaCallback;";
void onResult(std::string result) const {
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
static const auto method =
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
method(self(), s);
}
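  // Reports decoding throughput to the Java callback as tokens per second:
  // generated tokens divided by the decode time (inference_end_ms -
  // prompt_eval_end_ms), scaled from milliseconds to seconds.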
void onStats(const Stats& result) const {
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
static const auto method = cls->getMethod<void(jfloat)>("onStats");
double eval_time =
(double)(result.inference_end_ms - result.prompt_eval_end_ms);
float tps = result.num_generated_tokens / eval_time *
result.SCALING_FACTOR_UNITS_PER_SECOND;
method(self(), tps);
}
};
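// Hybrid class backing the Java org.pytorch.executorch.LlamaModule. Depending
// on the model type category it owns either a text-only Runner or a
// multimodal (LLaVA) runner.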
class ExecuTorchLlamaJni
: public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
private:
friend HybridBase;
int model_type_category_;
std::unique_ptr<Runner> runner_;
std::unique_ptr<MultimodalRunner> multi_modal_runner_;
public:
constexpr static auto kJavaDescriptor =
"Lorg/pytorch/executorch/LlamaModule;";
constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
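  // Entry point used by the Java constructor to create the native peer.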
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
jint model_type_category,
facebook::jni::alias_ref<jstring> model_path,
facebook::jni::alias_ref<jstring> tokenizer_path,
jfloat temperature) {
return makeCxxInstance(
model_type_category, model_path, tokenizer_path, temperature);
}
ExecuTorchLlamaJni(
jint model_type_category,
facebook::jni::alias_ref<jstring> model_path,
facebook::jni::alias_ref<jstring> tokenizer_path,
jfloat temperature) {
#if defined(ET_USE_THREADPOOL)
// Reserve 1 thread for the main thread.
    // Note: get_num_performant_cores() returns an unsigned count, so do the
    // subtraction in signed arithmetic to avoid wrap-around when it reports 0.
    int32_t num_performant_cores =
        static_cast<int32_t>(
            torch::executorch::cpuinfo::get_num_performant_cores()) -
        1;
    if (num_performant_cores > 0) {
      ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
      torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
          num_performant_cores);
    }
#endif
model_type_category_ = model_type_category;
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_ = std::make_unique<LlavaRunner>(
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
temperature);
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
runner_ = std::make_unique<Runner>(
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
temperature);
}
}
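  // Runs generation and streams results through the Java callback. For
  // multimodal models an optional image (flattened int array) is consumed
  // together with the prompt; for text-only models only the prompt is used.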
jint generate(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels,
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
auto image_size = image->size();
std::vector<Image> images;
if (image_size != 0) {
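        // Copy the Java int[] into the uint8_t buffer expected by the runner;
        // each element is assumed to hold one channel value in [0, 255].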
std::vector<jint> image_data_jint(image_size);
std::vector<uint8_t> image_data(image_size);
image->getRegion(0, image_size, image_data_jint.data());
for (int i = 0; i < image_size; i++) {
image_data[i] = image_data_jint[i];
}
Image image_runner{image_data, width, height, channels};
images.push_back(image_runner);
}
multi_modal_runner_->generate(
std::move(images),
prompt->toStdString(),
seq_len,
[callback](std::string result) { callback->onResult(result); },
[callback](const Stats& result) { callback->onStats(result); });
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
runner_->generate(
prompt->toStdString(),
seq_len,
[callback](std::string result) { callback->onResult(result); },
[callback](const Stats& result) { callback->onStats(result); });
}
return 0;
}
  // Returns an (error, start_pos) pair packed into a jlong[2].
  // This contract is only valid within a single AAR, where the JNI layer and
  // the corresponding Java code are built and shipped together.
  // If the first element is not Error::Ok, the second element is undefined.
facebook::jni::local_ref<jlongArray> prefill_prompt(
facebook::jni::alias_ref<jstring> prompt,
jlong start_pos,
jint bos,
jint eos) {
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
return tuple_result;
}
auto&& result = multi_modal_runner_->prefill_prompt(
prompt->toStdString(), start_pos, bos, eos);
    // Propagate the runner's error code; per the contract above, the second
    // element is only meaningful when the first element is Error::Ok.
    tuple_result->pin()[0] = static_cast<jlong>(result.error());
    if (result.ok()) {
      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
    }
return tuple_result;
}
  // Returns an (error, start_pos) pair packed into a jlong[2].
  // This contract is only valid within a single AAR, where the JNI layer and
  // the corresponding Java code are built and shipped together.
  // If the first element is not Error::Ok, the second element is undefined.
facebook::jni::local_ref<jlongArray> prefill_images(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels,
jlong start_pos) {
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
return tuple_result;
}
auto image_size = image->size();
std::vector<Image> images;
if (image_size != 0) {
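      // Same Java int[] -> uint8_t conversion as in generate().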
std::vector<jint> image_data_jint(image_size);
std::vector<uint8_t> image_data(image_size);
image->getRegion(0, image_size, image_data_jint.data());
for (int i = 0; i < image_size; i++) {
image_data[i] = image_data_jint[i];
}
Image image_runner{image_data, width, height, channels};
images.push_back(image_runner);
}
// TODO(hsz): make start_pos a reference and update it here
jint result = static_cast<jint>(
multi_modal_runner_->prefill_images(images, start_pos));
tuple_result->pin()[0] = result;
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
return tuple_result;
}
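  // Continues generation from an existing position (e.g. after
  // prefill_prompt / prefill_images). Only supported by the multimodal runner.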
jint generate_from_pos(
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jlong start_pos,
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(Error::NotSupported);
}
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
prompt->toStdString(),
seq_len,
start_pos,
[callback](const std::string& result) { callback->onResult(result); },
[callback](const ::executorch::extension::llm::Stats& stats) {
callback->onStats(stats);
}));
}
void stop() {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_->stop();
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
runner_->stop();
}
}
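  // Loads the model and tokenizer eagerly; returns the runner's Error code.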
jint load() {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(multi_modal_runner_->load());
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
return static_cast<jint>(runner_->load());
}
return static_cast<jint>(Error::InvalidArgument);
}
  static void registerNatives() {
    registerHybrid({
        makeNativeMethod("initHybrid", ExecuTorchLlamaJni::initHybrid),
        makeNativeMethod("generate", ExecuTorchLlamaJni::generate),
        makeNativeMethod("stop", ExecuTorchLlamaJni::stop),
        makeNativeMethod("load", ExecuTorchLlamaJni::load),
        // The prefill/generate_from_pos helpers above are unreachable from
        // Java unless they are registered as well. The Java-side native method
        // names below are assumptions; they must match the native declarations
        // in org.pytorch.executorch.LlamaModule.
        makeNativeMethod(
            "prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt),
        makeNativeMethod(
            "prefillImagesNative", ExecuTorchLlamaJni::prefill_images),
        makeNativeMethod(
            "generateFromPos", ExecuTorchLlamaJni::generate_from_pos),
    });
  }
};
} // namespace executorch_jni
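// Standard JNI entry point: registers the native methods when the shared
// library is loaded by the JVM / ART.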
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(
vm, [] { executorch_jni::ExecuTorchLlamaJni::registerNatives(); });
}