Add more interpreter functions on Unity Plugin
diff --git a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
index 83291e6..5b885f6 100644
--- a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
+++ b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
@@ -44,11 +44,20 @@
}
void Start () {
- interpreter = new Interpreter(model.bytes);
- Debug.LogFormat(
- "InputCount: {0}, OutputCount: {1}",
- interpreter.GetInputTensorCount(),
- interpreter.GetOutputTensorCount());
+ Debug.LogFormat("TensorFlow Lite Verion: {0}", Interpreter.GetVersion());
+
+ interpreter = new Interpreter(
+ modelData: model.bytes,
+ threads: 2);
+
+ int inputCount = interpreter.GetInputTensorCount();
+ int outputCount = interpreter.GetOutputTensorCount();
+ for (int i = 0; i < inputCount; i++) {
+ Debug.LogFormat("Input {0}: {1}", i, interpreter.GetInputTensorInfo(i));
+ }
+ for (int i = 0; i < inputCount; i++) {
+ Debug.LogFormat("Output {0}: {1}", i, interpreter.GetOutputTensorInfo(i));
+ }
}
void Update () {
diff --git a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
index fd8818f..2fc89bd 100644
--- a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
+++ b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
@@ -19,6 +19,7 @@
using TfLiteInterpreterOptions = System.IntPtr;
using TfLiteModel = System.IntPtr;
using TfLiteTensor = System.IntPtr;
+using TfLiteDelegate = System.IntPtr;
namespace TensorFlowLite
{
@@ -31,25 +32,31 @@
private TfLiteModel model;
private TfLiteInterpreter interpreter;
+ private TfLiteInterpreterOptions options;
- public Interpreter(byte[] modelData) {
+ public Interpreter(byte[] modelData, int threads) {
GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
model = TfLiteModelCreate(modelDataPtr, modelData.Length);
if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model");
- interpreter = TfLiteInterpreterCreate(model, /*options=*/IntPtr.Zero);
+
+ options = TfLiteInterpreterOptionsCreate();
+
+ if (threads > 1) {
+ TfLiteInterpreterOptionsSetNumThreads(options, threads);
+ }
+
+ interpreter = TfLiteInterpreterCreate(model, options);
if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
}
- ~Interpreter() {
- Dispose();
- }
-
public void Dispose() {
- if (interpreter != IntPtr.Zero) TfLiteInterpreterDelete(interpreter);
- interpreter = IntPtr.Zero;
if (model != IntPtr.Zero) TfLiteModelDelete(model);
model = IntPtr.Zero;
+ if (interpreter != IntPtr.Zero) TfLiteInterpreterDelete(interpreter);
+ interpreter = IntPtr.Zero;
+ if (options != IntPtr.Zero) TfLiteInterpreterOptionsDelete(options);
+ options = IntPtr.Zero;
}
public void Invoke() {
@@ -89,20 +96,66 @@
tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData)));
}
+ public string GetInputTensorInfo(int index) {
+ TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, index);
+ return GetTensorInfo(tensor);
+ }
+
+ public string GetOutputTensorInfo(int index) {
+ TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, index);
+ return GetTensorInfo(tensor);
+ }
+
public static string GetVersion() {
return Marshal.PtrToStringAnsi(TfLiteVersion());
}
+ private static string GetTensorName(TfLiteTensor tensor) {
+ return Marshal.PtrToStringAnsi(TfLiteTensorName(tensor));
+ }
+
+ private static string GetTensorInfo(TfLiteTensor tensor) {
+ var sb = new System.Text.StringBuilder();
+ sb.AppendFormat("{0} type:{1}, dims:[",
+ GetTensorName(tensor),
+ TfLiteTensorType(tensor));
+
+ int dims = TfLiteTensorNumDims(tensor);
+ for (int i = 0; i < dims; i++) {
+ sb.Append(TfLiteTensorDim(tensor, i));
+ sb.Append(i == dims - 1 ? "]" : ", ");
+ }
+ return sb.ToString();
+ }
+
private static void ThrowIfError(int resultCode) {
if (resultCode != 0) throw new Exception("TensorFlowLite operation failed.");
}
#region Externs
+ public enum TfLiteType {
+ NoType = 0,
+ Float32 = 1,
+ Int32 = 2,
+ UInt8 = 3,
+ Int64 = 4,
+ String = 5,
+ Bool = 6,
+ Int16 = 7,
+ Complex64 = 8,
+ Int8 = 9,
+ Float16 = 10,
+ }
+
+ public struct TfLiteQuantizationParams {
+ public float scale;
+ public int zero_point;
+ }
+
[DllImport (TensorFlowLibrary)]
private static extern unsafe IntPtr TfLiteVersion();
-
[DllImport (TensorFlowLibrary)]
private static extern unsafe TfLiteInterpreter TfLiteModelCreate(IntPtr model_data, int model_size);
@@ -110,6 +163,23 @@
private static extern unsafe TfLiteInterpreter TfLiteModelDelete(TfLiteModel model);
[DllImport (TensorFlowLibrary)]
+ private static extern unsafe TfLiteInterpreterOptions TfLiteInterpreterOptionsCreate();
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions options);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe TfLiteInterpreterOptions TfLiteInterpreterOptionsSetNumThreads(
+ TfLiteInterpreterOptions options,
+ int num_threads
+ );
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe TfLiteInterpreterOptions TfLiteInterpreterOptionsAddDelegate(
+ TfLiteInterpreterOptions options,
+ TfLiteDelegate _delegate);
+
+ [DllImport (TensorFlowLibrary)]
private static extern unsafe TfLiteInterpreter TfLiteInterpreterCreate(
TfLiteModel model,
TfLiteInterpreterOptions optional_options);
@@ -148,6 +218,27 @@
private static extern unsafe TfLiteTensor TfLiteInterpreterGetOutputTensor(
TfLiteInterpreter interpreter,
int output_index);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe TfLiteType TfLiteTensorType(TfLiteTensor tensor);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TfLiteTensorNumDims(TfLiteTensor tensor);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern uint TfLiteTensorByteSize(TfLiteTensor tensor);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe IntPtr TfLiteTensorData(TfLiteTensor tensor);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe IntPtr TfLiteTensorName(TfLiteTensor tensor);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe TfLiteQuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor);
[DllImport (TensorFlowLibrary)]
private static extern unsafe int TfLiteTensorCopyFromBuffer(