| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <dlfcn.h> |
| #include <jni.h> |
| #include <stdio.h> |
| #include <time.h> |
| |
| #include <atomic> |
| #include <map> |
| #include <utility> |
| #include <vector> |
| |
| #include "tensorflow/lite/core/shims/c/common.h" |
| #include "tensorflow/lite/core/shims/cc/create_op_resolver.h" |
| #include "tensorflow/lite/core/shims/cc/interpreter.h" |
| #include "tensorflow/lite/core/shims/cc/interpreter_builder.h" |
| #include "tensorflow/lite/core/shims/cc/model_builder.h" |
| #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" |
| #include "tensorflow/lite/java/src/main/native/jni_utils.h" |
| #include "tensorflow/lite/minimal_logging.h" |
| #include "tensorflow/lite/util.h" |
| |
| using tflite::OpResolver; |
| using tflite::jni::BufferErrorReporter; |
| using tflite::jni::ThrowException; |
| using tflite_shims::FlatBufferModel; |
| using tflite_shims::Interpreter; |
| using tflite_shims::InterpreterBuilder; |
| |
| namespace { |
| |
| Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { |
| if (handle == 0) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Invalid handle to Interpreter."); |
| return nullptr; |
| } |
| return reinterpret_cast<Interpreter*>(handle); |
| } |
| |
| FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) { |
| if (handle == 0) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Invalid handle to model."); |
| return nullptr; |
| } |
| return reinterpret_cast<FlatBufferModel*>(handle); |
| } |
| |
| BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { |
| if (handle == 0) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Invalid handle to ErrorReporter."); |
| return nullptr; |
| } |
| return reinterpret_cast<BufferErrorReporter*>(handle); |
| } |
| |
| TfLiteOpaqueDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) { |
| if (handle == 0) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Invalid handle to delegate."); |
| return nullptr; |
| } |
| return reinterpret_cast<TfLiteOpaqueDelegate*>(handle); |
| } |
| |
| std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { |
| int size = static_cast<int>(env->GetArrayLength(inputs)); |
| std::vector<int> outputs(size, 0); |
| jint* ptr = env->GetIntArrayElements(inputs, nullptr); |
| if (ptr == nullptr) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Array has empty dimensions."); |
| return {}; |
| } |
| for (int i = 0; i < size; ++i) { |
| outputs[i] = ptr[i]; |
| } |
| env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT); |
| return outputs; |
| } |
| |
| int getDataType(TfLiteType data_type) { |
| switch (data_type) { |
| case kTfLiteFloat32: |
| return 1; |
| case kTfLiteInt32: |
| return 2; |
| case kTfLiteUInt8: |
| return 3; |
| case kTfLiteInt64: |
| return 4; |
| case kTfLiteString: |
| return 5; |
| case kTfLiteBool: |
| return 6; |
| default: |
| return -1; |
| } |
| } |
| |
// Renders `dims` into `buffer` as a comma-separated list in which the first
// dimension is always shown as '?' (e.g. "?,224,3").
//
//   buffer:   destination with room for at least `max_size` bytes.
//   max_size: capacity of `buffer` including the terminating NUL; nothing is
//             written when it is <= 0.
//   dims:     dimension values; dims[0] is never read.
//   num_dims: number of entries in `dims`.
//
// The result is always NUL-terminated and is truncated when it does not fit.
void printDims(char* buffer, int max_size, int* dims, int num_dims) {
  if (max_size <= 0) return;
  if (max_size == 1) {
    // Only the terminator fits.
    buffer[0] = '\0';
    return;
  }
  buffer[0] = '?';
  // Terminate right away so the output is a valid C string even when no
  // further dimension is appended (num_dims <= 1).
  buffer[1] = '\0';
  int size = 1;
  // snprintf reports the length it *wanted* to write, so `size` can reach or
  // exceed `max_size` after a truncation; the loop condition then stops any
  // further writes.
  for (int i = 1; i < num_dims && size < max_size - 1; ++i) {
    const int written_size =
        snprintf(buffer + size, max_size - size, ",%d", dims[i]);
    if (written_size < 0) return;
    size += written_size;
  }
}
| |
| // Checks whether there is any difference between dimensions of a tensor and a |
| // given dimensions. Returns true if there is difference, else false. |
| bool AreDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { |
| int num_dims = static_cast<int>(env->GetArrayLength(dims)); |
| jint* ptr = env->GetIntArrayElements(dims, nullptr); |
| if (ptr == nullptr) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Empty dimensions of input array."); |
| return true; |
| } |
| bool is_different = false; |
| if (tensor->dims->size != num_dims) { |
| is_different = true; |
| } else { |
| for (int i = 0; i < num_dims; ++i) { |
| if (ptr[i] != tensor->dims->data[i]) { |
| is_different = true; |
| break; |
| } |
| } |
| } |
| env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); |
| return is_different; |
| } |
| |
| // TODO(yichengfan): evaluate the benefit to use tflite verifier. |
| bool VerifyModel(const void* buf, size_t len) { |
| flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len); |
| return tflite::VerifyModelBuffer(verifier); |
| } |
| |
| #if !TFLITE_DISABLE_SELECT_JAVA_APIS |
| // Return true when the given subgraph index is valid or throw an exception. |
| bool ValidateSubgraphIndex(JNIEnv* env, Interpreter* interpreter, |
| const int subgraph_idx) { |
| if (subgraph_idx < 0 || subgraph_idx >= interpreter->subgraphs_size()) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Input error: Can not access %d-th subgraph for a model " |
| "having %d subgraphs", |
| subgraph_idx, interpreter->subgraphs_size()); |
| return false; |
| } |
| return true; |
| } |
| #endif |
| |
| #if !TFLITE_DISABLE_SELECT_JAVA_APIS |
| // Helper method that fetches the tensor index based on SignatureDef details |
| // from either inputs or outputs. |
| // Returns -1 if invalid names are passed. |
| int GetTensorIndexForSignature(JNIEnv* env, jstring signature_tensor_name, |
| jstring signature_key, Interpreter* interpreter, |
| bool is_input) { |
| // Fetch name strings. |
| const char* signature_key_ptr = |
| env->GetStringUTFChars(signature_key, nullptr); |
| const char* signature_input_name_ptr = |
| env->GetStringUTFChars(signature_tensor_name, nullptr); |
| // Lookup if the input is valid. |
| const auto& signature_list = |
| (is_input ? interpreter->signature_inputs(signature_key_ptr) |
| : interpreter->signature_outputs(signature_key_ptr)); |
| const auto& tensor = signature_list.find(signature_input_name_ptr); |
| // Release the memory before returning. |
| env->ReleaseStringUTFChars(signature_key, signature_key_ptr); |
| env->ReleaseStringUTFChars(signature_tensor_name, signature_input_name_ptr); |
| return tensor == signature_list.end() ? -1 : tensor->second; |
| } |
| |
| jobjectArray GetSignatureInputsOutputsList( |
| const std::map<std::string, uint32_t>& input_output_list, JNIEnv* env) { |
| jclass string_class = env->FindClass("java/lang/String"); |
| if (string_class == nullptr) { |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Internal error: Can not find java/lang/String class to get " |
| "SignatureDef names."); |
| return nullptr; |
| } |
| |
| jobjectArray names = env->NewObjectArray(input_output_list.size(), |
| string_class, env->NewStringUTF("")); |
| int i = 0; |
| for (const auto& input : input_output_list) { |
| env->SetObjectArrayElement(names, i++, |
| env->NewStringUTF(input.first.c_str())); |
| } |
| return names; |
| } |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| |
| // Verifies whether the model is a flatbuffer file. |
| class JNIFlatBufferVerifier : public tflite::TfLiteVerifier { |
| public: |
| bool Verify(const char* data, int length, |
| tflite::ErrorReporter* reporter) override { |
| if (!VerifyModel(data, length)) { |
| reporter->Report("The model is not a valid Flatbuffer file"); |
| return false; |
| } |
| return true; |
| } |
| }; |
| |
| } // namespace |
| |
| extern "C" { |
| |
| JNIEXPORT jobjectArray JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, |
| jclass clazz, |
| jlong handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return nullptr; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return nullptr; |
| jclass string_class = env->FindClass("java/lang/String"); |
| if (string_class == nullptr) { |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Internal error: Can not find java/lang/String class to get " |
| "input names."); |
| return nullptr; |
| } |
| size_t size = interpreter->inputs().size(); |
| jobjectArray names = static_cast<jobjectArray>( |
| env->NewObjectArray(size, string_class, env->NewStringUTF(""))); |
| for (int i = 0; i < size; ++i) { |
| env->SetObjectArrayElement(names, i, |
| env->NewStringUTF(interpreter->GetInputName(i))); |
| } |
| return names; |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( |
| JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, |
| jint subgraph_idx) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return; |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return; |
| |
| if (subgraph_idx == 0) { |
| if (interpreter->AllocateTensors() != kTfLiteOk) { |
| ThrowException(env, tflite::jni::kIllegalStateException, |
| "Internal error: Unexpected failure when preparing tensor " |
| "allocations: %s", |
| error_reporter->CachedErrorMessage()); |
| } |
| return; |
| } |
| |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| TFLITE_LOG(tflite::TFLITE_LOG_WARNING, |
| "Not supported: allocateTensors (non-primary subgraph)"); |
| #else |
| if (!ValidateSubgraphIndex(env, interpreter, subgraph_idx)) return; |
| tflite::Subgraph* subgraph = interpreter->subgraph(subgraph_idx); |
| // TODO(b/184696042): Update the following subgraph API-based implementation |
| // with C++ signature runner API. |
| if (subgraph->AllocateTensors() != kTfLiteOk) { |
| ThrowException( |
| env, tflite::jni::kIllegalStateException, |
| "Internal error: Unexpected failure when preparing tensor allocations:" |
| " %s", |
| error_reporter->CachedErrorMessage()); |
| } |
| #endif |
| } |
| |
| JNIEXPORT jboolean JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp( |
| JNIEnv* env, jclass clazz, jlong handle) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| TFLITE_LOG(tflite::TFLITE_LOG_WARNING, "Not supported: hasUnresolvedFlexOp"); |
| return JNI_FALSE; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return JNI_FALSE; |
| |
| // TODO(b/132995737): Remove this logic by caching whether an unresolved |
| // Flex op is present during Interpreter creation. |
| for (size_t subgraph_i = 0; subgraph_i < interpreter->subgraphs_size(); |
| ++subgraph_i) { |
| const auto* subgraph = interpreter->subgraph(static_cast<int>(subgraph_i)); |
| for (int node_i : subgraph->execution_plan()) { |
| const auto& registration = |
| subgraph->node_and_registration(node_i)->second; |
| if (tflite::IsUnresolvedCustomOp(registration) && |
| tflite::IsFlexOp(registration.custom_name)) { |
| return JNI_TRUE; |
| } |
| } |
| } |
| return JNI_FALSE; |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jobjectArray JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames( |
| JNIEnv* env, jclass clazz, jlong handle) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| TFLITE_LOG(tflite::TFLITE_LOG_WARNING, "Not supported: getSignatureDefNames"); |
| return nullptr; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return nullptr; |
| jclass string_class = env->FindClass("java/lang/String"); |
| if (string_class == nullptr) { |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Internal error: Can not find java/lang/String class to get " |
| "SignatureDef names."); |
| return nullptr; |
| } |
| const auto& signature_defs = interpreter->signature_def_names(); |
| jobjectArray names = static_cast<jobjectArray>(env->NewObjectArray( |
| signature_defs.size(), string_class, env->NewStringUTF(""))); |
| for (int i = 0; i < signature_defs.size(); ++i) { |
| env->SetObjectArrayElement(names, i, |
| env->NewStringUTF(signature_defs[i]->c_str())); |
| } |
| return names; |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jobjectArray JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs( |
| JNIEnv* env, jclass clazz, jlong handle, jstring signature_key) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: getSignatureInputs"); |
| return nullptr; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return nullptr; |
| const char* signature_key_ptr = |
| env->GetStringUTFChars(signature_key, nullptr); |
| const jobjectArray signature_inputs = GetSignatureInputsOutputsList( |
| interpreter->signature_inputs(signature_key_ptr), env); |
| // Release the memory before returning. |
| env->ReleaseStringUTFChars(signature_key, signature_key_ptr); |
| return signature_inputs; |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jobjectArray JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureOutputs( |
| JNIEnv* env, jclass clazz, jlong handle, jstring signature_key) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: getSignatureOutputs"); |
| return nullptr; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return nullptr; |
| const char* signature_key_ptr = |
| env->GetStringUTFChars(signature_key, nullptr); |
| const jobjectArray signature_outputs = GetSignatureInputsOutputsList( |
| interpreter->signature_outputs(signature_key_ptr), env); |
| // Release the memory before returning. |
| env->ReleaseStringUTFChars(signature_key, signature_key_ptr); |
| return signature_outputs; |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getSubgraphIndexFromSignature( |
| JNIEnv* env, jclass clazz, jlong handle, jstring signature_key) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: getSubgraphIndexFromSignature"); |
| return -1; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return -1; |
| const char* signature_key_ptr = |
| env->GetStringUTFChars(signature_key, nullptr); |
| |
| int32_t subgraph_idx = |
| interpreter->GetSubgraphIndexFromSignature(signature_key_ptr); |
| // Release the memory before returning. |
| env->ReleaseStringUTFChars(signature_key, signature_key_ptr); |
| return subgraph_idx; |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndexFromSignature( |
| JNIEnv* env, jclass clazz, jlong handle, jstring signature_input_name, |
| jstring signature_key) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: getInputTensorIndexFromSignature"); |
| return -1; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return -1; |
| return GetTensorIndexForSignature(env, signature_input_name, signature_key, |
| interpreter, /*is_input=*/true); |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndexFromSignature( |
| JNIEnv* env, jclass clazz, jlong handle, jstring signature_output_name, |
| jstring signature_key) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: getOutputTensorIndexFromSignature"); |
| return -1; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return -1; |
| return GetTensorIndexForSignature(env, signature_output_name, signature_key, |
| interpreter, /*is_input=*/false); |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex( |
| JNIEnv* env, jclass clazz, jlong handle, jint input_index) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return 0; |
| return interpreter->inputs()[input_index]; |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex( |
| JNIEnv* env, jclass clazz, jlong handle, jint output_index) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return 0; |
| return interpreter->outputs()[output_index]; |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength( |
| JNIEnv* env, jclass clazz, jlong handle) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: getExecutionPlanLength"); |
| return -1; |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return 0; |
| return static_cast<jint>(interpreter->execution_plan().size()); |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, |
| jclass clazz, |
| jlong handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return 0; |
| return static_cast<jint>(interpreter->inputs().size()); |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, |
| jclass clazz, |
| jlong handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return 0; |
| return static_cast<jint>(interpreter->outputs().size()); |
| } |
| |
| JNIEXPORT jobjectArray JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, |
| jclass clazz, |
| jlong handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return nullptr; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return nullptr; |
| jclass string_class = env->FindClass("java/lang/String"); |
| if (string_class == nullptr) { |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Internal error: Can not find java/lang/String class to get " |
| "output names."); |
| return nullptr; |
| } |
| size_t size = interpreter->outputs().size(); |
| jobjectArray names = static_cast<jobjectArray>( |
| env->NewObjectArray(size, string_class, env->NewStringUTF(""))); |
| for (int i = 0; i < size; ++i) { |
| env->SetObjectArrayElement( |
| names, i, env->NewStringUTF(interpreter->GetOutputName(i))); |
| } |
| return names; |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32( |
| JNIEnv* env, jclass clazz, jlong handle, jboolean allow) { |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return; |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| if (allow) { |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: SetAllowFp16PrecisionForFp32(true)"); |
| } |
| #else |
| interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow)); |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput( |
| JNIEnv* env, jclass clazz, jlong handle, jboolean allow) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| if (allow) { |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: allowBufferHandleOutput(true)"); |
| } |
| #else |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return; |
| interpreter->SetAllowBufferHandleOutput(allow); |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK( |
| JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jint state, |
| jint num_threads) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return; |
| |
| // If not using xnnpack, simply don't apply the delegate. |
| if (state == 0) { |
| return; |
| } |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) { |
| return; |
| } |
| |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) { |
| return; |
| } |
| |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| // TODO(b/173022832): Implement support for XNNPack unconditionally. |
| if (state == -1) { |
| // Instead of throwing an exception, we tolerate the fact that XNNPACK is |
| // not implemented yet, because we try to apply XNNPACK delegate by default. |
| TF_LITE_REPORT_ERROR(error_reporter, |
| "WARNING: Not applying XNNPACK delegate by default " |
| "because it isn't supported in this module.\n"); |
| } else { |
| // In this case, XNNPACK was explicitly requested, so we throw an exception. |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: XNNPACK delegate"); |
| } |
| #else |
| // We use dynamic loading to avoid taking a hard dependency on XNNPack. |
| // This allows clients that use trimmed builds to save on binary size. |
| auto xnnpack_options_default = |
| reinterpret_cast<decltype(TfLiteXNNPackDelegateOptionsDefault)*>( |
| dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateOptionsDefault")); |
| auto xnnpack_create = |
| reinterpret_cast<decltype(TfLiteXNNPackDelegateCreate)*>( |
| dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateCreate")); |
| auto xnnpack_delete = |
| reinterpret_cast<decltype(TfLiteXNNPackDelegateDelete)*>( |
| dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateDelete")); |
| |
| if (xnnpack_options_default && xnnpack_create && xnnpack_delete) { |
| TfLiteXNNPackDelegateOptions options = xnnpack_options_default(); |
| if (num_threads > 0) { |
| options.num_threads = num_threads; |
| } |
| Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options), |
| xnnpack_delete); |
| auto delegation_status = |
| interpreter->ModifyGraphWithDelegate(std::move(delegate)); |
| // kTfLiteApplicationError occurs in cases where delegation fails but |
| // the runtime is invokable (eg. another delegate has already been applied). |
| // We don't throw an Exception in that case. |
| // TODO(b/166483905): Add support for multiple delegates when model allows. |
| if (delegation_status != kTfLiteOk && |
| delegation_status != kTfLiteApplicationError) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Failed to apply XNNPACK delegate: %s", |
| error_reporter->CachedErrorMessage()); |
| } |
| } else if (state == -1) { |
| // Instead of throwing an exception, we tolerate the missing of such |
| // dependencies because we try to apply XNNPACK delegate by default. |
| TF_LITE_REPORT_ERROR( |
| error_reporter, |
| "WARNING: Missing necessary XNNPACK delegate dependencies to apply it " |
| "by default.\n"); |
| } else { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Failed to load XNNPACK delegate from current runtime. " |
| "Have you added the necessary dependencies?"); |
| } |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT jlong JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( |
| JNIEnv* env, jclass clazz, jint size) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| BufferErrorReporter* error_reporter = |
| new BufferErrorReporter(env, static_cast<int>(size)); |
| return reinterpret_cast<jlong>(error_reporter); |
| } |
| |
| JNIEXPORT jlong JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( |
| JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return 0; |
| const char* path = env->GetStringUTFChars(model_file, nullptr); |
| |
| std::unique_ptr<tflite::TfLiteVerifier> verifier; |
| verifier.reset(new JNIFlatBufferVerifier()); |
| |
| auto model = FlatBufferModel::VerifyAndBuildFromFile(path, verifier.get(), |
| error_reporter); |
| if (!model) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Contents of %s does not encode a valid " |
| "TensorFlow Lite model: %s", |
| path, error_reporter->CachedErrorMessage()); |
| env->ReleaseStringUTFChars(model_file, path); |
| return 0; |
| } |
| env->ReleaseStringUTFChars(model_file, path); |
| return reinterpret_cast<jlong>(model.release()); |
| } |
| |
| JNIEXPORT jlong JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( |
| JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return 0; |
| const char* buf = |
| static_cast<char*>(env->GetDirectBufferAddress(model_buffer)); |
| jlong capacity = env->GetDirectBufferCapacity(model_buffer); |
| if (!VerifyModel(buf, capacity)) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "ByteBuffer is not a valid flatbuffer model"); |
| return 0; |
| } |
| |
| auto model = FlatBufferModel::BuildFromBuffer( |
| buf, static_cast<size_t>(capacity), error_reporter); |
| if (!model) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "ByteBuffer does not encode a valid model: %s", |
| error_reporter->CachedErrorMessage()); |
| return 0; |
| } |
| return reinterpret_cast<jlong>(model.release()); |
| } |
| |
| JNIEXPORT jlong JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( |
| JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle, |
| jint num_threads) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0; |
| |
| FlatBufferModel* model = convertLongToModel(env, model_handle); |
| if (model == nullptr) return 0; |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return 0; |
| std::unique_ptr<OpResolver> resolver = tflite_shims::CreateOpResolver(); |
| InterpreterBuilder interpreter_builder(*model, *resolver); |
| interpreter_builder.SetNumThreads(static_cast<int>(num_threads)); |
| std::unique_ptr<Interpreter> interpreter; |
| TfLiteStatus status = interpreter_builder(&interpreter); |
| if (status != kTfLiteOk) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Cannot create interpreter: %s", |
| error_reporter->CachedErrorMessage()); |
| return 0; |
| } |
| // Note that tensor allocation is performed explicitly by the owning Java |
| // NativeInterpreterWrapper instance. |
| return reinterpret_cast<jlong>(interpreter.release()); |
| } |
| |
| // Sets inputs, runs inference, and returns outputs as long handles. |
| JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( |
| JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); |
| if (interpreter == nullptr) return; |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return; |
| |
| if (interpreter->Invoke() != kTfLiteOk) { |
| // TODO(b/168266570): Return InterruptedException. |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Failed to run on the given Interpreter: %s", |
| error_reporter->CachedErrorMessage()); |
| return; |
| } |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_runSignature( |
| JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, |
| jint subgraph_idx) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| TFLITE_LOG(tflite::TFLITE_LOG_WARNING, "Not supported: runSignature"); |
| #else |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); |
| if (interpreter == nullptr) return; |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return; |
| |
| if (!ValidateSubgraphIndex(env, interpreter, subgraph_idx)) return; |
| tflite::Subgraph* subgraph = interpreter->subgraph(subgraph_idx); |
| if (subgraph->Invoke() != kTfLiteOk) { |
| // TODO(b/168266570): Return InterruptedException. |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Failed to run on the given Interpreter: %s", |
| error_reporter->CachedErrorMessage()); |
| return; |
| } |
| // Make sure that the output tensors are readable. |
| for (int tensor_index : subgraph->outputs()) { |
| subgraph->EnsureTensorDataIsReadable(tensor_index); |
| } |
| #endif |
| } |
| |
| JNIEXPORT jint JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( |
| JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return -1; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, handle); |
| if (interpreter == nullptr) return -1; |
| const int idx = static_cast<int>(output_idx); |
| if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Failed to get %d-th output out of %d outputs", output_idx, |
| interpreter->outputs().size()); |
| return -1; |
| } |
| TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); |
| int type = getDataType(target->type); |
| return static_cast<jint>(type); |
| } |
| |
| JNIEXPORT jboolean JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( |
| JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, |
| jint input_idx, jintArray dims, jboolean strict, jint subgraph_idx) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return JNI_FALSE; |
| |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return JNI_FALSE; |
| Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); |
| if (interpreter == nullptr) return JNI_FALSE; |
| if (subgraph_idx == 0) { |
| if (input_idx < 0 || input_idx >= interpreter->inputs().size()) { |
| ThrowException( |
| env, tflite::jni::kIllegalArgumentException, |
| "Input error: Can not resize %d-th input for a model having " |
| "%d inputs.", |
| input_idx, interpreter->inputs().size()); |
| return JNI_FALSE; |
| } |
| const int tensor_idx = interpreter->inputs()[input_idx]; |
| // check whether it is resizing with the same dimensions. |
| TfLiteTensor* target = interpreter->tensor(tensor_idx); |
| bool is_changed = AreDimsDifferent(env, target, dims); |
| if (is_changed) { |
| TfLiteStatus status; |
| if (strict) { |
| status = interpreter->ResizeInputTensorStrict( |
| tensor_idx, convertJIntArrayToVector(env, dims)); |
| } else { |
| status = interpreter->ResizeInputTensor( |
| tensor_idx, convertJIntArrayToVector(env, dims)); |
| } |
| if (status != kTfLiteOk) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Failed to resize %d-th input: %s", |
| input_idx, error_reporter->CachedErrorMessage()); |
| return JNI_FALSE; |
| } |
| } |
| return is_changed ? JNI_TRUE : JNI_FALSE; |
| } |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| TFLITE_LOG(tflite::TFLITE_LOG_WARNING, |
| "Not supported: resizeInput (non-primary subgraph)"); |
| return JNI_FALSE; |
| #else |
| if (!ValidateSubgraphIndex(env, interpreter, subgraph_idx)) return JNI_FALSE; |
| tflite::Subgraph* subgraph = interpreter->subgraph(subgraph_idx); |
| if (input_idx < 0 || input_idx >= subgraph->inputs().size()) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Input error: Can not resize %d-th input for a model having " |
| "%d inputs.", |
| input_idx, subgraph->inputs().size()); |
| return JNI_FALSE; |
| } |
| const int tensor_idx = subgraph->inputs()[input_idx]; |
| // check whether it is resizing with the same dimensions. |
| TfLiteTensor* target = subgraph->tensor(tensor_idx); |
| bool is_changed = AreDimsDifferent(env, target, dims); |
| if (is_changed) { |
| TfLiteStatus status; |
| if (strict) { |
| status = subgraph->ResizeInputTensorStrict( |
| tensor_idx, convertJIntArrayToVector(env, dims)); |
| } else { |
| status = subgraph->ResizeInputTensor(tensor_idx, |
| convertJIntArrayToVector(env, dims)); |
| } |
| if (status != kTfLiteOk) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Failed to resize %d-th input: %s", |
| input_idx, error_reporter->CachedErrorMessage()); |
| return JNI_FALSE; |
| } |
| } |
| return is_changed ? JNI_TRUE : JNI_FALSE; |
| #endif |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate( |
| JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, |
| jlong delegate_handle) { |
| if (!tflite::jni::CheckJniInitializedOrThrow(env)) return; |
| |
| Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); |
| if (interpreter == nullptr) return; |
| |
| BufferErrorReporter* error_reporter = |
| convertLongToErrorReporter(env, error_handle); |
| if (error_reporter == nullptr) return; |
| |
| TfLiteOpaqueDelegate* delegate = convertLongToDelegate(env, delegate_handle); |
| if (delegate == nullptr) return; |
| |
| TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate); |
| if (status != kTfLiteOk) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Failed to apply delegate: %s", |
| error_reporter->CachedErrorMessage()); |
| } |
| } |
| |
| JNIEXPORT jlong JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag( |
| JNIEnv* env, jclass clazz, jlong interpreter_handle) { |
| Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); |
| if (interpreter == nullptr) { |
| ThrowException(env, tflite::jni::kIllegalArgumentException, |
| "Internal error: Invalid handle to interpreter."); |
| return 0; |
| } |
| std::atomic_bool* cancellation_flag = new std::atomic_bool(false); |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: cancellation"); |
| #else |
| interpreter->SetCancellationFunction(cancellation_flag, [](void* payload) { |
| std::atomic_bool* cancellation_flag = |
| reinterpret_cast<std::atomic_bool*>(payload); |
| return cancellation_flag->load() == true; |
| }); |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| return reinterpret_cast<jlong>(cancellation_flag); |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_deleteCancellationFlag( |
| JNIEnv* env, jclass clazz, jlong flag_handle) { |
| std::atomic_bool* cancellation_flag = |
| reinterpret_cast<std::atomic_bool*>(flag_handle); |
| delete cancellation_flag; |
| } |
| |
| JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_NativeInterpreterWrapper_setCancelled( |
| JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong flag_handle, |
| jboolean value) { |
| #if TFLITE_DISABLE_SELECT_JAVA_APIS |
| ThrowException(env, tflite::jni::kUnsupportedOperationException, |
| "Not supported: cancellation"); |
| #else |
| std::atomic_bool* cancellation_flag = |
| reinterpret_cast<std::atomic_bool*>(flag_handle); |
| if (cancellation_flag != nullptr) { |
| cancellation_flag->store(static_cast<bool>(value)); |
| } |
| #endif // TFLITE_DISABLE_SELECT_JAVA_APIS |
| } |
| |
| JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( |
| JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, |
| jlong interpreter_handle) { |
| if (interpreter_handle != 0) { |
| delete convertLongToInterpreter(env, interpreter_handle); |
| } |
| if (model_handle != 0) { |
| delete convertLongToModel(env, model_handle); |
| } |
| if (error_handle != 0) { |
| delete convertLongToErrorReporter(env, error_handle); |
| } |
| } |
| |
| } // extern "C" |