Add intermediate tensor dump.
Add utility classes for reading binary tensors.
Add tensor visualization.
Test: mm
Merged-In:I5aeca89bf10ade0f206a0426c3f860f9483d7aa3
Change-Id:I5aeca89bf10ade0f206a0426c3f860f9483d7aa3
(cherry picked from commit 809f892f69eaf82cbf22bd64d958a878fc767b4a)
diff --git a/AndroidManifest.xml b/AndroidManifest.xml
index 93648b7..f4a7eed 100644
--- a/AndroidManifest.xml
+++ b/AndroidManifest.xml
@@ -37,7 +37,7 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
- <activity android:exported="true" android:name="com.android.nn.benchmark.util.DumpAllTensors">
+ <activity android:exported="true" android:name="com.android.nn.benchmark.util.DumpIntermediateTensors">
</activity>
<activity android:exported="true" android:name="com.android.nn.benchmark.util.TestExternalStorageActivity">
</activity>
diff --git a/jni/run_tflite.cpp b/jni/run_tflite.cpp
index 9472480..5ca88a8 100644
--- a/jni/run_tflite.cpp
+++ b/jni/run_tflite.cpp
@@ -100,7 +100,7 @@
if (enable_intermediate_tensors_dump) {
// Make output of every op a model output. This way we will be able to
// fetch each intermediate tensor when running with delegates.
- std::vector<int> outputs;
+ outputs.clear();
for (size_t node = 0; node < mTfliteInterpreter->nodes_size(); ++node) {
auto node_outputs =
mTfliteInterpreter->node_and_registration(node)->first.outputs;
@@ -337,15 +337,18 @@
return false;
}
- for (int tensor = 0; tensor < mTfliteInterpreter->tensors_size();
- ++tensor) {
- auto* output_tensor = mTfliteInterpreter->tensor(tensor);
+ // The order of the tensor is not sorted by the tensor index
+ for (int tensor_order = 0; tensor_order < outputs.size(); ++tensor_order) {
+ int tensor_index = outputs[tensor_order];
+ auto* output_tensor = mTfliteInterpreter->tensor(tensor_index);
if (output_tensor->data.raw == nullptr) {
+ __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
+ "output_tensor->data.raw == nullptr at index %d ", tensor_index);
continue;
}
char fullpath[1024];
- snprintf(fullpath, 1024, "%s/dump_%.3d_seq_%.3d_tensor_%.3d", path,
- seqInferenceIndex, i, tensor);
+ snprintf(fullpath, 1024, "%s/dump_%.3d_seq_%.3d_order_%.3d_tensor_%.3d", path,
+ seqInferenceIndex, i, tensor_order, tensor_index);
FILE* f = fopen(fullpath, "wb");
fwrite(output_tensor->data.raw, output_tensor->bytes, 1, f);
fclose(f);
diff --git a/jni/run_tflite.h b/jni/run_tflite.h
index dae6bb0..4a533db 100644
--- a/jni/run_tflite.h
+++ b/jni/run_tflite.h
@@ -94,6 +94,8 @@
std::unique_ptr<tflite::FlatBufferModel> mTfliteModel;
std::unique_ptr<tflite::Interpreter> mTfliteInterpreter;
+ // Store indices of output tensors, used to dump intermediate tensors
+ std::vector<int> outputs;
};
#endif // COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
diff --git a/src/com/android/nn/benchmark/util/DumpAllTensors.java b/src/com/android/nn/benchmark/util/DumpAllTensors.java
deleted file mode 100644
index 07bedec..0000000
--- a/src/com/android/nn/benchmark/util/DumpAllTensors.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.nn.benchmark.util;
-
-import android.app.Activity;
-import android.os.Bundle;
-import com.android.nn.benchmark.core.NNTestBase;
-import com.android.nn.benchmark.core.TestModels.TestModelEntry;
-import com.android.nn.benchmark.core.TestModels;
-import java.io.IOException;
-import java.io.File;
-
-/** Helper activity for dumping state of interference intermediate tensors.
- *
- * Example usage:
- * adb shell am start -n com.android.nn.benchmark.app/com.android.nn.benchmark.\
- * util.DumpAllTensors --es modelName mobilenet_quantized inputAssetIndex 0 useNNAPI true
- *
- * Assets will be then dumped into /data/data/com.android.nn.benchmark.app/files/dump
- * To fetch:
- * adb pull /data/data/com.android.benchmark.app/files/dump
- *
- */
-public class DumpAllTensors extends Activity {
- public static final String EXTRA_MODEL_NAME = "modelName";
- public static final String EXTRA_USE_NNAPI = "useNNAPI";
- public static final String EXTRA_INPUT_ASSET_INDEX= "inputAssetIndex";
- public static final String EXTRA_INPUT_ASSET_SIZE= "inputAssetSize";
-
- @Override
- protected void onCreate(Bundle savedInstanceState) {
- super.onCreate(savedInstanceState);
- Bundle extras = getIntent().getExtras();
-
- String modelName = extras.getString(EXTRA_MODEL_NAME);
- if (modelName == null) {
- throw new IllegalArgumentException("No modelName extra passed with intent");
- }
- boolean useNNAPI = extras.getBoolean(EXTRA_USE_NNAPI, false);
- int inputAssetIndex = extras.getInt(EXTRA_INPUT_ASSET_INDEX, 0);
- int inputAssetSize = extras.getInt(EXTRA_INPUT_ASSET_SIZE, 1);
-
- try {
- File dumpDir = new File(getFilesDir(), "dump");
- deleteRecursive(dumpDir);
- dumpDir.mkdir();
-
- TestModelEntry modelEntry = TestModels.getModelByName(modelName);
- NNTestBase testBase = modelEntry.createNNTestBase(useNNAPI, true);
- testBase.setupModel(this);
- testBase.dumpAllLayers(dumpDir, inputAssetIndex, inputAssetSize);
- } catch (IOException e) {
- throw new IllegalStateException("Failed to dump tensors", e);
- }
- finish();
- }
-
- private void deleteRecursive(File fileOrDirectory) {
- if (fileOrDirectory.isDirectory()) {
- for (File child : fileOrDirectory.listFiles()) {
- deleteRecursive(child);
- }
- }
- fileOrDirectory.delete();
- }
-}
diff --git a/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java b/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
new file mode 100644
index 0000000..e79a8b3
--- /dev/null
+++ b/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.nn.benchmark.util;
+
+import android.util.Log;
+
+import android.app.Activity;
+import android.os.Bundle;
+import com.android.nn.benchmark.core.NNTestBase;
+import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import com.android.nn.benchmark.core.TestModels;
+import java.io.IOException;
+import java.io.File;
+
+
+/** Helper activity for dumping state of interference intermediate tensors.
+ *
+ * Example usage:
+ * adb shell am start -n com.android.nn.benchmark.app/com.android.nn.benchmark.\
+ * util.DumpIntermediateTensors --es modelName mobilenet_v1_1.0_224_quant_topk_aosp,tts_float\
+ * inputAssetIndex 0
+ *
+ * Assets will be then dumped into /data/data/com.android.nn.benchmark.app/files/intermediate
+ * To fetch:
+ * adb pull /data/data/com.android.nn.benchmark.app/files/intermediate
+ *
+ */
+public class DumpIntermediateTensors extends Activity {
+ protected static final String TAG = "VDEBUG";
+ public static final String EXTRA_MODEL_NAME = "modelName";
+ public static final String EXTRA_INPUT_ASSET_INDEX= "inputAssetIndex";
+ public static final String EXTRA_INPUT_ASSET_SIZE= "inputAssetSize";
+ public static final String DUMP_DIR = "intermediate";
+ public static final String CPU_DIR = "cpu";
+ public static final String NNAPI_DIR = "nnapi";
+ // TODO(veralin): Update to use other models in vendor as well.
+ // Due to recent change in NNScoringTest, the model names are moved to here.
+ private static final String[] MODEL_NAMES = new String[]{
+ "tts_float",
+ "asr_float",
+ "mobilenet_v1_1.0_224_quant_topk_aosp",
+ "mobilenet_v1_1.0_224_topk_aosp",
+ "mobilenet_v1_0.75_192_quant_topk_aosp",
+ "mobilenet_v1_0.75_192_topk_aosp",
+ "mobilenet_v1_0.5_160_quant_topk_aosp",
+ "mobilenet_v1_0.5_160_topk_aosp",
+ "mobilenet_v1_0.25_128_quant_topk_aosp",
+ "mobilenet_v1_0.25_128_topk_aosp",
+ "mobilenet_v2_0.35_128_topk_aosp",
+ "mobilenet_v2_0.5_160_topk_aosp",
+ "mobilenet_v2_0.75_192_topk_aosp",
+ "mobilenet_v2_1.0_224_topk_aosp",
+ "mobilenet_v2_1.0_224_quant_topk_aosp",
+ };
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ Bundle extras = getIntent().getExtras();
+
+ String userModelName = extras.getString(EXTRA_MODEL_NAME);
+ int inputAssetIndex = extras.getInt(EXTRA_INPUT_ASSET_INDEX, 0);
+ int inputAssetSize = extras.getInt(EXTRA_INPUT_ASSET_SIZE, 1);
+
+ // Default to run all models in NNScoringTest
+ String[] modelNames = userModelName == null? MODEL_NAMES: userModelName.split(",");
+
+ try {
+ File dumpDir = new File(getFilesDir(), DUMP_DIR);
+ safeMkdir(dumpDir);
+
+ for (String modelName : modelNames) {
+ File modelDir = new File(getFilesDir() + "/" + DUMP_DIR, modelName);
+ safeMkdir(modelDir);
+ // Run in CPU and NNAPI mode
+ for (final boolean useNNAPI : new boolean[] {false, true}) {
+ String useNNAPIDir = useNNAPI? NNAPI_DIR: CPU_DIR;
+ Log.i(TAG, "Running " + modelName + " in " + useNNAPIDir);
+ TestModelEntry modelEntry = TestModels.getModelByName(modelName);
+ NNTestBase testBase = modelEntry.createNNTestBase(
+ useNNAPI, true/*enableIntermediateTensorsDump*/);
+ testBase.setupModel(this);
+ File outputDir = new File(getFilesDir() + "/" + DUMP_DIR +
+ "/" + modelName, useNNAPIDir);
+ safeMkdir(outputDir);
+ testBase.dumpAllLayers(outputDir, inputAssetIndex, inputAssetSize);
+ }
+ }
+
+ } catch (Exception e) {
+ Log.e(TAG, "Failed to dump tensors", e);
+ throw new IllegalStateException("Failed to dump tensors", e);
+ }
+ finish();
+ }
+
+ private void deleteRecursive(File fileOrDirectory) {
+ if (fileOrDirectory.isDirectory()) {
+ for (File child : fileOrDirectory.listFiles()) {
+ deleteRecursive(child);
+ }
+ }
+ fileOrDirectory.delete();
+ }
+
+ private void safeMkdir(File fileOrDirectory) {
+ deleteRecursive(fileOrDirectory);
+ fileOrDirectory.mkdir();
+ }
+}
diff --git a/tools/gen_tflite_visualization.sh b/tools/gen_tflite_visualization.sh
new file mode 100755
index 0000000..5dbe893
--- /dev/null
+++ b/tools/gen_tflite_visualization.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Prereq:
+# g4d -f NAME
+# blaze build third_party/tensorflow/lite/tools:visualize
+
+ANDROID_BUILD_TOP="/usr/local/google/home/$(whoami)/android/master"
+MODEL_DIR="$ANDROID_BUILD_TOP/test/mlts/models/assets"
+# The .json files are always output to /tmp
+HTML_DIR="/tmp"
+
+mkdir -p $HTML_DIR
+
+for file in "$MODEL_DIR"/*.tflite
+do
+ if [ -f "$file" ]; then
+ filename=`basename $file`
+ modelname=${filename%.*}
+ blaze-bin/third_party/tensorflow/lite/tools/visualize $file $HTML_DIR/$modelname.html
+ fi
+done
+
+# Example visualization: blaze-bin/third_party/tensorflow/lite/tools/visualize ~/android/master/test/mlts/models/assets/mobilenet_v1_0.75_192.tflite /tmp/mobilenet_v1_0.75_192.html
diff --git a/tools/tensor_utils.py b/tools/tensor_utils.py
new file mode 100644
index 0000000..a29fd3c
--- /dev/null
+++ b/tools/tensor_utils.py
@@ -0,0 +1,296 @@
+ #!/usr/bin/python3
+"""Read intermediate tensors generated by DumpAllTensors activity
+
+Tools for reading/ parsing intermediate tensors.
+"""
+
+import argparse
+import numpy as np
+import os
+import pandas as pd
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import json
+
+from matplotlib.pylab import *
+import matplotlib.animation as animation
+# Enable tensor.numpy()
+tf.compat.v1.enable_eager_execution()
+
+################################ ModelMetaDataManager ################################
+class ModelMetaDataManager(object):
+ """Maps model name in nnapi to its graph architecture with lazy initialization.
+
+ # Arguments
+ android_build_top: the root directory of android source tree
+ dump_dir: directory containing intermediate tensors pulled from device
+ tflite_model_json_path: directory containing intermediate json output of
+ model visualization tool (third_party/tensorflow/lite/tools:visualize)
+ The json output path from the tool is always /tmp.
+ """
+
+ class ModelMetaData(object):
+ """Store graph information of a model."""
+
+ def __init__(self, tflite_model_json_path='/tmp'):
+ with open(tflite_model_json_path, 'rb') as f:
+ model_json = json.load(f)
+ self.operators = model_json['subgraphs'][0]['operators']
+ self.operator_codes = [item['builtin_code']\
+ for item in model_json['operator_codes']]
+ self.output_meta_data = []
+ self.load_output_meta_data()
+
+ def load_output_meta_data(self):
+ for operator in self.operators:
+ data = {}
+ # Each operator can only have one output
+ assert(len(operator['outputs']) == 1)
+ data['output_tensor_index'] = operator['outputs'][0]
+ data['fused_activation_function'] = operator\
+ .get('builtin_options', {})\
+ .get('fused_activation_function', '')
+ data['operator_code'] = self.operator_codes[operator['opcode_index']]
+ self.output_meta_data.append(data)
+
+ def __init__(self, android_build_top, dump_dir, tflite_model_json_dir='/tmp'):
+ # key: nnapi model name, value: ModelMetaData
+ self.models = dict()
+ self.ANDROID_BUILD_TOP = android_build_top
+ self.TFLITE_MODEL_JSON_DIR = tflite_model_json_dir
+ self.DUMP_DIR = dump_dir
+ self.nnapi_to_tflite_name = dict()
+ self.tflite_to_nnapi_name = dict()
+ self.__load_mobilenet_topk_aosp__()
+ self.model_names = sorted(os.listdir(dump_dir))
+
+ def __load_mobilenet_topk_aosp__(self):
+ """Load information about tflite and nnapi model names."""
+ json_path = '{}/{}'.format(
+ self.ANDROID_BUILD_TOP,
+ 'test/mlts/models/assets/models_list/mobilenet_topk_aosp.json')
+ with open(json_path, 'rb') as f:
+ topk_aosp = json.load(f)
+ for model in topk_aosp['models']:
+ self.nnapi_to_tflite_name[model['name']] = model['modelFile']
+ self.tflite_to_nnapi_name[model['modelFile']] = model['name']
+
+ def __get_model_json_path__(self, tflite_model_name):
+ """Return tflite model jason path."""
+ json_path = '{}/{}.json'.format(self.TFLITE_MODEL_JSON_DIR, tflite_model_name)
+ return json_path
+
+ def __load_model__(self, tflite_model_name):
+ """Initialize a ModelMetaData for this model."""
+ model = self.ModelMetaData(self.__get_model_json_path__(tflite_model_name))
+ nnapi_model_name = self.model_name_tflite_to_nnapi(tflite_model_name)
+ self.models[nnapi_model_name] = model
+
+ def model_name_nnapi_to_tflite(self, nnapi_model_name):
+ return self.nnapi_to_tflite_name.get(nnapi_model_name, nnapi_model_name)
+
+ def model_name_tflite_to_nnapi(self, tflite_model_name):
+ return self.tflite_to_nnapi_name.get(tflite_model_name, tflite_model_name)
+
+ def get_model_meta_data(self, nnapi_model_name):
+ """Retrieve the ModelMetaData with lazy initialization."""
+ tflite_model_name = self.model_name_nnapi_to_tflite(nnapi_model_name)
+ if nnapi_model_name not in self.models:
+ self.__load_model__(tflite_model_name)
+ return self.models[nnapi_model_name]
+
+ def generate_animation_html(self, output_file_path, model_names=None):
+ model_names = self.model_names if model_names is None else model_names
+ html_data = ''
+ for model_name in model_names:
+ print('processing', model_name)
+ html_data += '<h3>{}</h3>'.format(model_name)
+ model_data = ModelData(nnapi_model_name=model_name, manager=self)
+ ani = model_data.gen_error_hist_animation()
+ html_data += ani.to_jshtml()
+ with open(output_file_path, 'w') as f:
+ f.write(html_data)
+
+
+################################ TensorDict ################################
+class TensorDict(dict):
+ """A class to store cpu and nnapi tensors.
+
+ # Arguments
+ model_dir: directory containing intermediate tensors pulled from device
+ """
+ def __init__(self, model_dir):
+ super().__init__()
+ for useNNAPIDir in ['cpu', 'nnapi']:
+ dir_path = model_dir + useNNAPIDir + "/"
+ self[useNNAPIDir] = self.read_tensors_from_dir(dir_path)
+ self.tensor_sanity_check()
+ self.max_absolute_diff, self.min_absolute_diff = 0.0, 0.0
+ self.max_relative_diff, self.min_relative_diff = 0.0, 0.0
+ self.layers = sorted(self['cpu'].keys())
+ self.calc_range()
+
+ def bytes_to_numpy_tensor(self, file_path):
+ tensor_type = tf.int8 if 'quant' in file_path else tf.float32
+ with open(file_path, mode='rb') as f:
+ tensor_bytes = f.read()
+ tensor = tf.decode_raw(input_bytes=tensor_bytes, out_type=tensor_type)
+ if np.isnan(np.sum(tensor)):
+ print('WARNING: tensor contains inf or nan')
+ return tensor.numpy()
+
+ def read_tensors_from_dir(self, dir_path):
+ tensor_dict = dict()
+ for tensor_file in os.listdir(dir_path):
+ tensor = self.bytes_to_numpy_tensor(dir_path + tensor_file)
+ tensor_dict[tensor_file] = tensor
+ return tensor_dict
+
+ def tensor_sanity_check(self):
+ # Make sure the cpu tensors and nnapi tensors have the same outputs
+ assert(len(self['cpu']) == len(self['nnapi']))
+ key_diff = set(self['cpu'].keys()) - set(self['nnapi'].keys())
+ assert(len(key_diff) == 0)
+ print('Tensor sanity check passed')
+
+ def calc_range(self):
+ for layer in self.layers:
+ diff = self.calc_diff(layer, relative_error=False)
+ # update absolute max, min
+ self.max_absolute_diff = max(self.max_absolute_diff, np.max(diff))
+ self.min_absolute_diff = min(self.min_absolute_diff, np.min(diff))
+ self.absolute_range = max(abs(self.min_absolute_diff),
+ abs(self.max_absolute_diff))
+
+ def calc_diff(self, layer, relative_error=True):
+ cpu_tensor = self['cpu'][layer]
+ nnapi_tensor = self['nnapi'][layer]
+ assert(cpu_tensor.shape == nnapi_tensor.shape)
+ diff = cpu_tensor - nnapi_tensor
+ if not relative_error:
+ return diff
+ diff = diff.astype(float)
+ cpu_tensor = cpu_tensor.astype(float)
+ max_cpu_nnapi_tensor = np.maximum(np.abs(cpu_tensor), np.abs(nnapi_tensor))
+ relative_diff = np.divide(diff, max_cpu_nnapi_tensor, out=np.zeros_like(diff),\
+ where=max_cpu_nnapi_tensor>0)
+ relative_diff[relative_diff>1] = 1.0
+ relative_diff[relative_diff<-1] = -1.0
+ return relative_diff
+
+ def gen_tensor_diff_stats(self, relative_error=True, return_df=True, plot_diff=False):
+ stats = []
+ for layer in self.layers:
+ diff = self.calc_diff(layer, relative_error)
+ if plot_diff:
+ self.plot_tensor_diff(diff)
+ if return_df:
+ stats.append({
+ 'layer': layer,
+ 'min': np.min(diff),
+ 'max': np.max(diff),
+ 'mean': np.mean(diff),
+ 'median': np.median(diff)
+ })
+ if return_df:
+ return pd.DataFrame(stats)
+
+ def plot_tensor_diff(diff):
+ plt.figure()
+ plt.hist(diff, bins=50, log=True)
+ plt.plot()
+
+
+################################ Model Data ################################
+
+class ModelData(object):
+ """A class to store all relevant inormation of a model.
+
+ # Arguments
+ nnapi_model_name: the name of the model
+ manager: ModelMetaDataManager
+ """
+ def __init__(self, nnapi_model_name, manager):
+ self.nnapi_model_name = nnapi_model_name
+ self.manager = manager
+ self.model_dir = self.get_target_model_dir(manager.DUMP_DIR, nnapi_model_name)
+ self.tensor_dict = TensorDict(self.model_dir)
+ self.mmd = manager.get_model_meta_data(nnapi_model_name)
+ self.stats = self.tensor_dict.gen_tensor_diff_stats(relative_error=True,
+ return_df=True)
+ self.layers = sorted(self.tensor_dict['cpu'].keys())
+
+ def get_target_model_dir(self, dump_dir, target_model_name):
+ target_model_dir = dump_dir + target_model_name + "/"
+ return target_model_dir
+
+ def updateData(self, i, fig, ax1, ax2, bins=50):
+ operation = self.mmd.output_meta_data[i % len(self.mmd.output_meta_data)]['operator_code']
+ layer = self.layers[i]
+ subtitle = fig.suptitle('{} | {}\n{}'
+ .format(self.nnapi_model_name, layer, operation),
+ fontsize='x-large')
+ for ax in (ax1, ax2):
+ ax.clear()
+ ax1.set_title('Relative Error')
+ ax2.set_title('Absolute Error')
+ ax1.hist(self.tensor_dict.calc_diff(layer, relative_error=True), bins=bins,
+ range=(-1, 1), log=True)
+ absolute_range = self.tensor_dict.absolute_range
+ ax2.hist(self.tensor_dict.calc_diff(layer, relative_error=False), bins=bins,
+ range=(-absolute_range, absolute_range), log=True)
+
+ def gen_error_hist_animation(self):
+ # For fast testing, add [:10] to the end of next line
+ layers = self.layers
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12,9))
+ ani = animation.FuncAnimation(fig, self.updateData, len(layers),
+ fargs=(fig, ax1, ax2),
+ interval=200, repeat=False)
+ # close before return to avoid dangling plot
+ plt.close()
+ return ani
+
+ def plot_error_heatmap(self, target_layer, length=1):
+ target_diff = self.tensor_dict['cpu'][target_layer] - \
+ self.tensor_dict['nnapi'][target_layer]
+ width = int(len(target_diff)/ length)
+ reshaped_target_diff = target_diff[:length * width].reshape(length, width)
+ fig, ax = subplots(figsize=(18, 5))
+ plt.title('Heat Map of Error between CPU and NNAPI')
+ plt.imshow(reshaped_target_diff, cmap='hot', interpolation='nearest')
+ plt.colorbar()
+ plt.show()
+
+
+################################NumpyEncoder ################################
+
+class NumpyEncoder(json.JSONEncoder):
+ """Enable numpy array serilization in a dictionary.
+
+ # Usage:
+ a = np.array([[1, 2, 3], [4, 5, 6]])
+ json.dumps({'a': a, 'aa': [2, (2, 3, 4), a], 'bb': [2]}, cls=NumpyEncoder)
+ """
+ def default(self, obj):
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ return json.JSONEncoder.default(self, obj)
+
+def main(android_build_top, dump_dir, model_name):
+ manager = ModelMetaDataManager(
+ android_build_top,
+ dump_dir,
+ tflite_model_json_dir='/tmp')
+ model_data = ModelData(nnapi_model_name=model_name, manager=manager)
+ print(model_data.tensor_dict)
+
+if __name__ == '__main__':
+ # Example usage
+ # python tensor_utils.py ~/android/master/ ~/android/master/intermediate/ tts_float
+ parser = argparse.ArgumentParser(description='Utilities for parsing intermediate tensors.')
+ parser.add_argument('android_build_top', help='Your Android build top path')
+ parser.add_argument('dump_dir', help='The dump dir pulled from the device')
+ parser.add_argument('model_name', help='NNAPI model name')
+ args = parser.parse_args()
+ main(args.android_build_top, args.dump_dir, args.model_name)
\ No newline at end of file