blob: fc0857fdf4355f5f4b0eebdc002b6fe253bd03f6 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <dlfcn.h>
#include <jni.h>
#include <stdio.h>
#include <time.h>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
#include "tensorflow/lite/java/src/main/native/jni_utils.h"
#include "tensorflow/lite/util.h"
namespace tflite {
// This is to be provided at link-time by a library.
// createInterpreter() below calls this to obtain the OpResolver that is handed
// to InterpreterBuilder; which ops it can resolve therefore depends on the
// library that was linked in.
extern std::unique_ptr<OpResolver> CreateOpResolver();
} // namespace tflite
using tflite::jni::BufferErrorReporter;
using tflite::jni::ThrowException;
namespace {

// Converts a Java-held native handle back into an Interpreter pointer.
// A zero handle raises IllegalArgumentException on the Java side and yields
// nullptr. Non-zero handles are trusted and reinterpreted as-is.
tflite_api_dispatcher::Interpreter* convertLongToInterpreter(JNIEnv* env,
                                                             jlong handle) {
  if (handle == 0) {
    ThrowException(env, kIllegalArgumentException,
                   "Internal error: Invalid handle to Interpreter.");
    return nullptr;
  }
  return reinterpret_cast<tflite_api_dispatcher::Interpreter*>(handle);
}

// Same contract as convertLongToInterpreter(), for model handles.
tflite_api_dispatcher::TfLiteModel* convertLongToModel(JNIEnv* env,
                                                       jlong handle) {
  if (handle == 0) {
    ThrowException(env, kIllegalArgumentException,
                   "Internal error: Invalid handle to model.");
    return nullptr;
  }
  return reinterpret_cast<tflite_api_dispatcher::TfLiteModel*>(handle);
}

// Same contract as convertLongToInterpreter(), for error-reporter handles.
BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
  if (handle == 0) {
    ThrowException(env, kIllegalArgumentException,
                   "Internal error: Invalid handle to ErrorReporter.");
    return nullptr;
  }
  return reinterpret_cast<BufferErrorReporter*>(handle);
}

// Same contract as convertLongToInterpreter(), for delegate handles.
TfLiteDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
  if (handle == 0) {
    ThrowException(env, kIllegalArgumentException,
                   "Internal error: Invalid handle to delegate.");
    return nullptr;
  }
  return reinterpret_cast<TfLiteDelegate*>(handle);
}

// Copies the contents of a Java int[] into a std::vector<int>.
// If the array elements cannot be accessed, throws IllegalArgumentException
// and returns an empty vector. Callers should check for a pending exception
// before using the result.
std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
  const jsize size = env->GetArrayLength(inputs);
  jint* ptr = env->GetIntArrayElements(inputs, nullptr);
  if (ptr == nullptr) {
    ThrowException(env, kIllegalArgumentException,
                   "Array has empty dimensions.");
    return {};
  }
  std::vector<int> outputs(ptr, ptr + size);
  // JNI_ABORT: the array was only read, so there is nothing to copy back.
  env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT);
  return outputs;
}

// Maps a TfLiteType to the integer code reported to Java (see
// getOutputDataType). Presumably these codes mirror the Java-side DataType
// enum; TODO confirm. Types without a mapping yield -1.
int getDataType(TfLiteType data_type) {
  switch (data_type) {
    case kTfLiteFloat32:
      return 1;
    case kTfLiteInt32:
      return 2;
    case kTfLiteUInt8:
      return 3;
    case kTfLiteInt64:
      return 4;
    case kTfLiteString:
      return 5;
    case kTfLiteBool:
      return 6;
    default:
      return -1;
  }
}

// Renders |dims| as a comma-separated string into |buffer|, with the first
// dimension shown as '?' (e.g. {4, 2, 3} -> "?,2,3"). Output is truncated to
// fit in |max_size| bytes and is always NUL-terminated when max_size > 0.
void printDims(char* buffer, int max_size, int* dims, int num_dims) {
  if (max_size <= 0) return;
  if (max_size == 1) {
    // No room for the placeholder plus a terminator.
    buffer[0] = '\0';
    return;
  }
  buffer[0] = '?';
  // Terminate now so the result is a valid string even when num_dims <= 1.
  buffer[1] = '\0';
  int size = 1;
  for (int i = 1; i < num_dims && size < max_size; ++i) {
    // snprintf always NUL-terminates within the space it was given. On
    // truncation it returns the untruncated length, so |size| can exceed
    // |max_size|, which ends the loop.
    const int written_size =
        snprintf(buffer + size, max_size - size, ",%d", dims[i]);
    if (written_size < 0) return;
    size += written_size;
  }
}

// Checks whether there is any difference between the dimensions of |tensor|
// and the given Java int[] |dims|. Returns true if they differ (rank or any
// extent), else false. If the Java array cannot be accessed, throws
// IllegalArgumentException and returns true.
bool AreDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
  const int num_dims = static_cast<int>(env->GetArrayLength(dims));
  jint* ptr = env->GetIntArrayElements(dims, nullptr);
  if (ptr == nullptr) {
    ThrowException(env, kIllegalArgumentException,
                   "Empty dimensions of input array.");
    return true;
  }
  bool is_different = tensor->dims->size != num_dims;
  for (int i = 0; !is_different && i < num_dims; ++i) {
    is_different = ptr[i] != tensor->dims->data[i];
  }
  env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
  return is_different;
}

// Returns true if |buf| (|len| bytes) passes flatbuffer verification as a
// TFLite Model.
// TODO(yichengfan): evaluate the benefit to use tflite verifier.
bool VerifyModel(const void* buf, size_t len) {
  flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
  return tflite::VerifyModelBuffer(verifier);
}

}  // namespace
#ifdef __cplusplus
extern "C" {
#endif
// Returns a Java String[] holding the name of every model input, in input
// order. Returns nullptr with a pending Java exception on failure.
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
                                                                jclass clazz,
                                                                jlong handle) {
  tflite_api_dispatcher::Interpreter* interpreter =
      convertLongToInterpreter(env, handle);
  if (interpreter == nullptr) return nullptr;
  jclass string_class = env->FindClass("java/lang/String");
  if (string_class == nullptr) {
    ThrowException(env, kUnsupportedOperationException,
                   "Internal error: Can not find java/lang/String class to get "
                   "input names.");
    return nullptr;
  }
  const jsize size = static_cast<jsize>(interpreter->inputs().size());
  // Every slot is overwritten below, so no initial element is needed.
  jobjectArray names = env->NewObjectArray(size, string_class, nullptr);
  if (names == nullptr) return nullptr;  // OutOfMemoryError is pending.
  for (jsize i = 0; i < size; ++i) {
    jstring name = env->NewStringUTF(interpreter->GetInputName(i));
    if (name == nullptr) return nullptr;  // OutOfMemoryError is pending.
    env->SetObjectArrayElement(names, i, name);
    // Release each local ref so models with many inputs cannot exhaust the
    // JNI local reference table.
    env->DeleteLocalRef(name);
  }
  return names;
}
// Asks the interpreter behind |handle| to allocate its tensors. If allocation
// fails, raises IllegalStateException carrying the message last cached by the
// error reporter behind |error_handle|.
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
    JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
  auto* const interp = convertLongToInterpreter(env, handle);
  if (!interp) return;
  auto* const reporter = convertLongToErrorReporter(env, error_handle);
  if (!reporter) return;
  const TfLiteStatus alloc_status = interp->AllocateTensors();
  if (alloc_status == kTfLiteOk) return;
  ThrowException(
      env, kIllegalStateException,
      "Internal error: Unexpected failure when preparing tensor allocations:"
      " %s",
      reporter->CachedErrorMessage());
}
// Returns JNI_TRUE if any node in the execution plan of any subgraph is an
// unresolved custom op whose name identifies it as a Flex op.
JNIEXPORT jboolean JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
    JNIEnv* env, jclass clazz, jlong handle) {
  auto* const interp = convertLongToInterpreter(env, handle);
  if (!interp) return JNI_FALSE;
  // TODO(b/132995737): Remove this logic by caching whether an unresolved
  // Flex op is present during Interpreter creation.
  const size_t num_subgraphs = interp->subgraphs_size();
  for (size_t s = 0; s < num_subgraphs; ++s) {
    const auto* graph = interp->subgraph(static_cast<int>(s));
    for (const int node_index : graph->execution_plan()) {
      const auto& reg = graph->node_and_registration(node_index)->second;
      const bool unresolved_flex = tflite::IsUnresolvedCustomOp(reg) &&
                                   tflite::IsFlexOp(reg.custom_name);
      if (unresolved_flex) return JNI_TRUE;
    }
  }
  return JNI_FALSE;
}
// Returns the tensor index backing the |input_index|-th model input.
// An out-of-range |input_index| raises IllegalArgumentException (and returns
// 0) instead of reading past the end of the inputs list.
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
    JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
  tflite_api_dispatcher::Interpreter* interpreter =
      convertLongToInterpreter(env, handle);
  if (interpreter == nullptr) return 0;
  const auto& inputs = interpreter->inputs();
  const int input_count = static_cast<int>(inputs.size());
  if (input_index < 0 || input_index >= input_count) {
    ThrowException(env, kIllegalArgumentException,
                   "Failed to get %d-th input out of %d inputs", input_index,
                   input_count);
    return 0;
  }
  return inputs[input_index];
}
// Returns the tensor index backing the |output_index|-th model output.
// An out-of-range |output_index| raises IllegalArgumentException (and returns
// 0) instead of reading past the end of the outputs list.
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
    JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
  tflite_api_dispatcher::Interpreter* interpreter =
      convertLongToInterpreter(env, handle);
  if (interpreter == nullptr) return 0;
  const auto& outputs = interpreter->outputs();
  const int output_count = static_cast<int>(outputs.size());
  if (output_index < 0 || output_index >= output_count) {
    ThrowException(env, kIllegalArgumentException,
                   "Failed to get %d-th output out of %d outputs",
                   output_index, output_count);
    return 0;
  }
  return outputs[output_index];
}
// Returns the number of nodes in the interpreter's execution plan, or 0 (with
// a pending exception) when |handle| is invalid.
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
    JNIEnv* env, jclass clazz, jlong handle) {
  auto* const interp = convertLongToInterpreter(env, handle);
  return interp == nullptr
             ? 0
             : static_cast<jint>(interp->execution_plan().size());
}
// Returns the number of model inputs, or 0 (with a pending exception) when
// |handle| is invalid.
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                jclass clazz,
                                                                jlong handle) {
  auto* const interp = convertLongToInterpreter(env, handle);
  return interp == nullptr ? 0 : static_cast<jint>(interp->inputs().size());
}
// Returns the number of model outputs, or 0 (with a pending exception) when
// |handle| is invalid.
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
  auto* const interp = convertLongToInterpreter(env, handle);
  return interp == nullptr ? 0 : static_cast<jint>(interp->outputs().size());
}
// Returns a Java String[] holding the name of every model output, in output
// order. Returns nullptr with a pending Java exception on failure.
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
  tflite_api_dispatcher::Interpreter* interpreter =
      convertLongToInterpreter(env, handle);
  if (interpreter == nullptr) return nullptr;
  jclass string_class = env->FindClass("java/lang/String");
  if (string_class == nullptr) {
    ThrowException(env, kUnsupportedOperationException,
                   "Internal error: Can not find java/lang/String class to get "
                   "output names.");
    return nullptr;
  }
  const jsize size = static_cast<jsize>(interpreter->outputs().size());
  // Every slot is overwritten below, so no initial element is needed.
  jobjectArray names = env->NewObjectArray(size, string_class, nullptr);
  if (names == nullptr) return nullptr;  // OutOfMemoryError is pending.
  for (jsize i = 0; i < size; ++i) {
    jstring name = env->NewStringUTF(interpreter->GetOutputName(i));
    if (name == nullptr) return nullptr;  // OutOfMemoryError is pending.
    env->SetObjectArrayElement(names, i, name);
    // Release each local ref so models with many outputs cannot exhaust the
    // JNI local reference table.
    env->DeleteLocalRef(name);
  }
  return names;
}
// Forwards the Java-side NNAPI toggle to Interpreter::UseNNAPI().
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
                                                           jclass clazz,
                                                           jlong handle,
                                                           jboolean state) {
  if (auto* const interp = convertLongToInterpreter(env, handle)) {
    interp->UseNNAPI(state != JNI_FALSE);
  }
}
// Forwards the Java-side flag to Interpreter::SetAllowFp16PrecisionForFp32().
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
    JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
  if (auto* const interp = convertLongToInterpreter(env, handle)) {
    interp->SetAllowFp16PrecisionForFp32(allow != JNI_FALSE);
  }
}
// Forwards the Java-side flag to Interpreter::SetAllowBufferHandleOutput().
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
    JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
  if (auto* const interp = convertLongToInterpreter(env, handle)) {
    interp->SetAllowBufferHandleOutput(allow != JNI_FALSE);
  }
}
// Applies the XNNPACK delegate to the interpreter when |state| is true.
// The delegate entry points are looked up with dlsym() at run time rather than
// linked directly, so a missing XNNPACK dependency is reported as an exception
// instead of a link failure.
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
    JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
    jint num_threads) {
  // Nothing to do unless XNNPACK was requested.
  if (state == JNI_FALSE) return;
  auto* const interp = convertLongToInterpreter(env, handle);
  if (interp == nullptr) return;
  auto* const reporter = convertLongToErrorReporter(env, error_handle);
  if (reporter == nullptr) return;

  // Dynamic lookup avoids a hard dependency on XNNPack, which lets clients
  // using trimmed builds save on binary size.
  using OptionsDefaultFn = decltype(TfLiteXNNPackDelegateOptionsDefault);
  using CreateFn = decltype(TfLiteXNNPackDelegateCreate);
  using DeleteFn = decltype(TfLiteXNNPackDelegateDelete);
  auto* const options_default_fn = reinterpret_cast<OptionsDefaultFn*>(
      dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateOptionsDefault"));
  auto* const create_fn = reinterpret_cast<CreateFn*>(
      dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateCreate"));
  auto* const delete_fn = reinterpret_cast<DeleteFn*>(
      dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateDelete"));

  const bool symbols_found = options_default_fn != nullptr &&
                             create_fn != nullptr && delete_fn != nullptr;
  if (!symbols_found) {
    ThrowException(env, kIllegalArgumentException,
                   "Failed to load XNNPACK delegate from current runtime. "
                   "Have you added the necessary dependencies?");
    return;
  }

  TfLiteXNNPackDelegateOptions options = options_default_fn();
  if (num_threads > 0) options.num_threads = num_threads;
  tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr xnnpack_delegate(
      create_fn(&options), delete_fn);
  const TfLiteStatus status =
      interp->ModifyGraphWithDelegate(std::move(xnnpack_delegate));
  // kTfLiteApplicationError occurs when delegation fails but the runtime is
  // still invokable (e.g. another delegate has already been applied), so it
  // is not treated as fatal here.
  // TODO(b/166483905): Add support for multiple delegates when model allows.
  const bool fatal =
      status != kTfLiteOk && status != kTfLiteApplicationError;
  if (fatal) {
    ThrowException(env, kIllegalArgumentException,
                   "Internal error: Failed to apply XNNPACK delegate: %s",
                   reporter->CachedErrorMessage());
  }
}
// Forwards the requested thread count to Interpreter::SetNumThreads().
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
                                                             jclass clazz,
                                                             jlong handle,
                                                             jint num_threads) {
  if (auto* const interp = convertLongToInterpreter(env, handle)) {
    interp->SetNumThreads(static_cast<int>(num_threads));
  }
}
// Heap-allocates a BufferErrorReporter of |size| and returns it as an opaque
// handle. The caller owns it. NativeInterpreterWrapper_delete() in this file
// frees it when given the handle as |error_handle|.
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
    JNIEnv* env, jclass clazz, jint size) {
  return reinterpret_cast<jlong>(
      new BufferErrorReporter(env, static_cast<int>(size)));
}
// TfLiteVerifier that checks whether the model bytes form a valid TFLite
// flatbuffer, reporting through the supplied ErrorReporter when they do not.
class JNIFlatBufferVerifier : public tflite_api_dispatcher::TfLiteVerifier {
 public:
  // Returns true if |data| (|length| bytes) verifies as a TFLite model;
  // otherwise reports the failure to |reporter| and returns false.
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    const bool valid = VerifyModel(data, length);
    if (!valid) reporter->Report("The model is not a valid Flatbuffer file");
    return valid;
  }
};
// Loads, verifies and builds a model from the file at |model_file|.
// Returns an owning handle to the model, or 0 with a pending Java exception
// on failure.
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
    JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
  BufferErrorReporter* error_reporter =
      convertLongToErrorReporter(env, error_handle);
  if (error_reporter == nullptr) return 0;
  if (model_file == nullptr) {
    ThrowException(env, kIllegalArgumentException,
                   "Model file path must not be null.");
    return 0;
  }
  const char* path = env->GetStringUTFChars(model_file, nullptr);
  if (path == nullptr) {
    // GetStringUTFChars has already raised OutOfMemoryError.
    return 0;
  }
  JNIFlatBufferVerifier verifier;
  auto model = tflite_api_dispatcher::TfLiteModel::VerifyAndBuildFromFile(
      path, &verifier, error_reporter);
  if (!model) {
    // |path| is still needed for the message, so release it afterwards.
    ThrowException(env, kIllegalArgumentException,
                   "Contents of %s does not encode a valid "
                   "TensorFlow Lite model: %s",
                   path, error_reporter->CachedErrorMessage());
  }
  // Single release point for both the success and failure paths.
  env->ReleaseStringUTFChars(model_file, path);
  return model ? reinterpret_cast<jlong>(model.release()) : 0;
}
// Verifies and builds a model backed by the memory of the direct ByteBuffer
// |model_buffer|. Returns an owning handle to the model, or 0 with a pending
// Java exception on failure.
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
    JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) {
  BufferErrorReporter* error_reporter =
      convertLongToErrorReporter(env, error_handle);
  if (error_reporter == nullptr) return 0;
  const char* buf =
      static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
  const jlong capacity = env->GetDirectBufferCapacity(model_buffer);
  // A buffer that is not direct yields a null address and a capacity of -1.
  // Passing those on would give the verifier a null pointer with a huge
  // (wrapped) length, so reject them up front.
  if (buf == nullptr || capacity <= 0 ||
      !VerifyModel(buf, static_cast<size_t>(capacity))) {
    ThrowException(env, kIllegalArgumentException,
                   "ByteBuffer is not a valid flatbuffer model");
    return 0;
  }
  auto model = tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(
      buf, static_cast<size_t>(capacity), error_reporter);
  if (!model) {
    ThrowException(env, kIllegalArgumentException,
                   "ByteBuffer does not encode a valid model: %s",
                   error_reporter->CachedErrorMessage());
    return 0;
  }
  return reinterpret_cast<jlong>(model.release());
}
// Builds an interpreter for the model behind |model_handle|, using the op
// resolver supplied by tflite::CreateOpResolver() and |num_threads| threads.
// Returns an owning handle, or 0 with a pending Java exception on failure.
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
    jint num_threads) {
  auto* const model = convertLongToModel(env, model_handle);
  if (!model) return 0;
  auto* const reporter = convertLongToErrorReporter(env, error_handle);
  if (!reporter) return 0;

  auto op_resolver = ::tflite::CreateOpResolver();
  std::unique_ptr<tflite_api_dispatcher::Interpreter> built;
  const TfLiteStatus build_status =
      tflite_api_dispatcher::InterpreterBuilder(*model, *op_resolver)(
          &built, static_cast<int>(num_threads));
  if (build_status != kTfLiteOk) {
    ThrowException(env, kIllegalArgumentException,
                   "Internal error: Cannot create interpreter: %s",
                   reporter->CachedErrorMessage());
    return 0;
  }
  // Tensor allocation is deliberately left to the owning Java
  // NativeInterpreterWrapper instance.
  return reinterpret_cast<jlong>(built.release());
}
// Runs inference by calling Invoke() on the interpreter behind
// |interpreter_handle|. On failure, raises IllegalArgumentException carrying
// the message cached by the error reporter behind |error_handle|.
JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
    JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
  auto* const interp = convertLongToInterpreter(env, interpreter_handle);
  if (!interp) return;
  auto* const reporter = convertLongToErrorReporter(env, error_handle);
  if (!reporter) return;
  const TfLiteStatus invoke_status = interp->Invoke();
  if (invoke_status == kTfLiteOk) return;
  ThrowException(env, kIllegalArgumentException,
                 "Internal error: Failed to run on the given Interpreter: %s",
                 reporter->CachedErrorMessage());
}
// Returns the Java data-type code (see getDataType) of the |output_idx|-th
// model output, or -1 with a pending exception for an invalid handle or index.
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
    JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
  tflite_api_dispatcher::Interpreter* interpreter =
      convertLongToInterpreter(env, handle);
  if (interpreter == nullptr) return -1;
  const int idx = static_cast<int>(output_idx);
  // Cast to int up front: a size_t passed through the varargs "%d" below
  // would be undefined behavior on LP64 targets.
  const int output_count = static_cast<int>(interpreter->outputs().size());
  if (idx < 0 || idx >= output_count) {
    ThrowException(env, kIllegalArgumentException,
                   "Failed to get %d-th output out of %d outputs", idx,
                   output_count);
    return -1;
  }
  TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]);
  return static_cast<jint>(getDataType(target->type));
}
// Resizes the |input_idx|-th model input to |dims|. When |strict| is true,
// ResizeInputTensorStrict() is used instead of ResizeInputTensor().
// Returns JNI_TRUE only if the dimensions actually changed and the resize
// succeeded. On any failure, returns JNI_FALSE with a pending Java exception.
JNIEXPORT jboolean JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
    JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
    jint input_idx, jintArray dims, jboolean strict) {
  BufferErrorReporter* error_reporter =
      convertLongToErrorReporter(env, error_handle);
  if (error_reporter == nullptr) return JNI_FALSE;
  tflite_api_dispatcher::Interpreter* interpreter =
      convertLongToInterpreter(env, interpreter_handle);
  if (interpreter == nullptr) return JNI_FALSE;
  // int, not size_t: the count is passed through varargs "%d" below.
  const int input_count = static_cast<int>(interpreter->inputs().size());
  if (input_idx < 0 || input_idx >= input_count) {
    ThrowException(env, kIllegalArgumentException,
                   "Input error: Can not resize %d-th input for a model having "
                   "%d inputs.",
                   input_idx, input_count);
    return JNI_FALSE;
  }
  const int tensor_idx = interpreter->inputs()[input_idx];
  // Check whether it is resizing with the same dimensions.
  TfLiteTensor* target = interpreter->tensor(tensor_idx);
  const bool is_changed = AreDimsDifferent(env, target, dims);
  // AreDimsDifferent signals failure by throwing (and returning true). No
  // further JNI calls may be made while that exception is pending.
  if (env->ExceptionCheck()) return JNI_FALSE;
  if (!is_changed) return JNI_FALSE;

  const std::vector<int> new_dims = convertJIntArrayToVector(env, dims);
  if (env->ExceptionCheck()) return JNI_FALSE;
  const TfLiteStatus status =
      strict ? interpreter->ResizeInputTensorStrict(tensor_idx, new_dims)
             : interpreter->ResizeInputTensor(tensor_idx, new_dims);
  if (status != kTfLiteOk) {
    ThrowException(env, kIllegalArgumentException,
                   "Internal error: Failed to resize %d-th input: %s",
                   input_idx, error_reporter->CachedErrorMessage());
    return JNI_FALSE;
  }
  return JNI_TRUE;
}
// Applies the (non-owned) delegate behind |delegate_handle| to the
// interpreter. On failure, raises IllegalArgumentException carrying the
// message cached by the error reporter behind |error_handle|.
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
    JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
    jlong delegate_handle) {
  auto* const interp = convertLongToInterpreter(env, interpreter_handle);
  if (!interp) return;
  auto* const reporter = convertLongToErrorReporter(env, error_handle);
  if (!reporter) return;
  auto* const delegate = convertLongToDelegate(env, delegate_handle);
  if (!delegate) return;
  if (interp->ModifyGraphWithDelegate(delegate) == kTfLiteOk) return;
  ThrowException(env, kIllegalArgumentException,
                 "Internal error: Failed to apply delegate: %s",
                 reporter->CachedErrorMessage());
}
// Calls Interpreter::ResetVariableTensors(). On failure, raises
// IllegalArgumentException carrying the message cached by the error reporter
// behind |error_handle|.
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
    JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
  auto* const interp = convertLongToInterpreter(env, interpreter_handle);
  if (!interp) return;
  auto* const reporter = convertLongToErrorReporter(env, error_handle);
  if (!reporter) return;
  if (interp->ResetVariableTensors() == kTfLiteOk) return;
  ThrowException(env, kIllegalArgumentException,
                 "Internal error: Failed to reset variable tensors: %s",
                 reporter->CachedErrorMessage());
}
// Frees whichever of the interpreter, model and error reporter have non-zero
// handles. Zero handles are skipped silently. The interpreter is released
// before the model it was built from, and the error reporter last.
JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete(
    JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle,
    jlong interpreter_handle) {
  if (interpreter_handle) {
    auto* const interp = convertLongToInterpreter(env, interpreter_handle);
    delete interp;
  }
  if (model_handle) {
    auto* const model = convertLongToModel(env, model_handle);
    delete model;
  }
  if (error_handle) {
    auto* const reporter = convertLongToErrorReporter(env, error_handle);
    delete reporter;
  }
}
#ifdef __cplusplus
} // extern "C"
#endif