[automerger skipped] Improve script for building and running dump intermediate tensors. am: 0288bc3de7 am: e86341ca48 -s ours
am skip reason: Change-Id I4d2994d5b69fade7f29b48e387d0764469480d3a with SHA-1 ff911b7ecb is in history
Original change: https://android-review.googlesource.com/c/platform/test/mlts/benchmark/+/1393199
Change-Id: I6b33e9ab10b5dd3f537407c484e016a63e9fe94b
diff --git a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
index 0ce2fe8..a2d485f 100644
--- a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
+++ b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
@@ -122,7 +122,7 @@
mNotificationManager = NotificationManagerCompat.from(this);
NotificationChannel channel =
new NotificationChannel(CHANNEL_ID, "Default", NotificationManager.IMPORTANCE_LOW);
- // mNotificationManager.createNotificationChannel(channel);
+ mNotificationManager.createNotificationChannel(channel);
mNotificationManager = NotificationManagerCompat.from(this);
String title = "NN API Dogfood";
String msg = String.format("Background test %d of %d is running", getNumRuns(), NUM_RUNS);
diff --git a/tools/gen_tflite_visualization.sh b/tools/gen_tflite_visualization.sh
index 5dbe893..67bb543 100755
--- a/tools/gen_tflite_visualization.sh
+++ b/tools/gen_tflite_visualization.sh
@@ -1,23 +1,38 @@
#!/bin/bash
-# Prereq:
-# g4d -f NAME
-# blaze build third_party/tensorflow/lite/tools:visualize
+# This script generates visualizations and metadata json files of the tflite models.
+# HTML results go to /tmp by default (override with the first argument); the .json files always go to /tmp.
+# Prerequisites:
+# Follow the instructions at the link below to be able to run visualize.py
+# https://www.tensorflow.org/lite/guide/faq#how_do_i_inspect_a_tflite_file
-ANDROID_BUILD_TOP="/usr/local/google/home/$(whoami)/android/master"
-MODEL_DIR="$ANDROID_BUILD_TOP/test/mlts/models/assets"
-# The .json files are always output to /tmp
-HTML_DIR="/tmp"
+if [[ -z "$ANDROID_BUILD_TOP" ]]; then
+ echo ANDROID_BUILD_TOP not set, bailing out
+ echo you must run lunch before running this script
+ exit 1
+fi
-mkdir -p $HTML_DIR
+echo "Follow the link to set up the prerequisites for running visualize.py: \
+https://www.tensorflow.org/lite/guide/faq#how_do_i_inspect_a_tflite_file"
+read -p "Are you able to run the visualize.py script with bazel? [Y/N]" -n 1 -r
+echo
+if [[ $REPLY =~ ^[Yy]$ ]]; then
+ MODEL_DIR="$ANDROID_BUILD_TOP/test/mlts/models/assets"
+ # The .json files are always output to /tmp by the tflite visualize tool
+ HTML_DIR="${1:-/tmp}"
+ mkdir -p "$HTML_DIR" && HTML_DIR="$(cd "$HTML_DIR" && pwd)" # make absolute: bazel run changes the working directory
-for file in "$MODEL_DIR"/*.tflite
-do
- if [ -f "$file" ]; then
- filename=`basename $file`
- modelname=${filename%.*}
- blaze-bin/third_party/tensorflow/lite/tools/visualize $file $HTML_DIR/$modelname.html
- fi
-done
+ set -e
+ for file in "$MODEL_DIR"/*.tflite
+ do
+ if [ -f "$file" ]; then
+ filename=$(basename "$file")
+ modelname=${filename%.*}
+ bazel run //tensorflow/lite/tools:visualize "$file" "$HTML_DIR/$modelname.html"
+ fi
+ done
+else
+ echo "Please set up first following https://www.tensorflow.org/lite/guide/faq#how_do_i_inspect_a_tflite_file."
+fi
-# Example visualization: blaze-bin/third_party/tensorflow/lite/tools/visualize ~/android/master/test/mlts/models/assets/mobilenet_v1_0.75_192.tflite /tmp/mobilenet_v1_0.75_192.html
+exit
\ No newline at end of file
diff --git a/tools/requirements.txt b/tools/requirements.txt
new file mode 100644
index 0000000..c9c40d8
--- /dev/null
+++ b/tools/requirements.txt
@@ -0,0 +1,211 @@
+absl-py==0.7.1
+alabaster==0.7.12
+anaconda-client==1.7.2
+anaconda-navigator==1.9.7
+anaconda-project==0.8.2
+asn1crypto==0.24.0
+astor==0.8.0
+astroid==2.2.5
+astropy==3.1.2
+atomicwrites==1.3.0
+attrs==19.1.0
+Babel==2.6.0
+backcall==0.1.0
+backports.os==0.1.1
+backports.shutil-get-terminal-size==1.0.0
+beautifulsoup4==4.7.1
+bitarray==0.8.3
+bkcharts==0.2
+bleach==3.1.0
+bokeh==1.0.4
+boto==2.49.0
+Bottleneck==1.2.1
+certifi==2019.3.9
+cffi==1.12.2
+chardet==3.0.4
+Click==7.0
+cloudpickle==0.8.0
+clyent==1.2.2
+colorama==0.4.1
+conda==4.6.11
+conda-build==3.17.8
+conda-verify==3.1.1
+contextlib2==0.5.5
+cryptography==2.6.1
+cycler==0.10.0
+Cython==0.29.6
+cytoolz==0.9.0.1
+dask==1.1.4
+decorator==4.4.0
+defusedxml==0.5.0
+distributed==1.26.0
+docutils==0.14
+entrypoints==0.3
+et-xmlfile==1.0.1
+fastcache==1.0.2
+filelock==3.0.10
+Flask==1.0.2
+future==0.17.1
+gast==0.2.2
+gevent==1.4.0
+glob2==0.6
+gmpy2==2.0.8
+google-pasta==0.1.7
+greenlet==0.4.15
+grpcio==1.22.0
+h5py==2.9.0
+heapdict==1.0.0
+html5lib==1.0.1
+idna==2.8
+imageio==2.5.0
+imagesize==1.1.0
+importlib-metadata==0.0.0
+ipykernel==5.1.0
+ipython==7.4.0
+ipython-genutils==0.2.0
+ipywidgets==7.4.2
+isort==4.3.16
+itsdangerous==1.1.0
+jdcal==1.4
+jedi==0.13.3
+jeepney==0.4
+Jinja2==2.10
+jsonschema==3.0.1
+jupyter==1.0.0
+jupyter-client==5.2.4
+jupyter-console==6.0.0
+jupyter-core==4.4.0
+jupyterlab==0.35.4
+jupyterlab-server==0.2.0
+Keras-Applications==1.0.8
+Keras-Preprocessing==1.1.0
+keyring==18.0.0
+kiwisolver==1.0.1
+lazy-object-proxy==1.3.1
+libarchive-c==2.8
+lief==0.9.0
+llvmlite==0.28.0
+locket==0.2.0
+lxml==4.3.2
+Markdown==3.1.1
+MarkupSafe==1.1.1
+matplotlib==3.0.3
+mccabe==0.6.1
+mistune==0.8.4
+mkl-fft==1.0.10
+mkl-random==1.0.2
+more-itertools==6.0.0
+mpmath==1.1.0
+msgpack==0.6.1
+multipledispatch==0.6.0
+navigator-updater==0.2.1
+nbconvert==5.4.1
+nbformat==4.4.0
+networkx==2.2
+nltk==3.4
+nose==1.3.7
+notebook==5.7.8
+numba==0.43.1
+numexpr==2.6.9
+numpy==1.16.2
+numpydoc==0.8.0
+olefile==0.46
+openpyxl==2.6.1
+packaging==19.0
+pandas==0.24.2
+pandocfilters==1.4.2
+parso==0.3.4
+partd==0.3.10
+path.py==11.5.0
+pathlib2==2.3.3
+patsy==0.5.1
+pep8==1.7.1
+pexpect==4.6.0
+pickleshare==0.7.5
+Pillow==5.4.1
+pkginfo==1.5.0.1
+pluggy==0.9.0
+ply==3.11
+prometheus-client==0.6.0
+prompt-toolkit==2.0.9
+protobuf==3.8.0
+psutil==5.6.1
+ptyprocess==0.6.0
+py==1.8.0
+pycodestyle==2.5.0
+pycosat==0.6.3
+pycparser==2.19
+pycrypto==2.6.1
+pycurl==7.43.0.2
+pyflakes==2.1.1
+Pygments==2.3.1
+pylint==2.3.1
+pyodbc==4.0.26
+pyOpenSSL==19.0.0
+pyparsing==2.3.1
+pyrsistent==0.14.11
+PySocks==1.6.8
+pytest==4.3.1
+pytest-arraydiff==0.3
+pytest-astropy==0.5.0
+pytest-doctestplus==0.3.0
+pytest-openfiles==0.3.2
+pytest-remotedata==0.3.1
+python-dateutil==2.8.0
+pytz==2018.9
+PyWavelets==1.0.2
+PyYAML==5.1
+pyzmq==18.0.0
+QtAwesome==0.5.7
+qtconsole==4.4.3
+QtPy==1.7.0
+requests==2.21.0
+rope==0.12.0
+ruamel-yaml==0.15.46
+scikit-image==0.14.2
+scikit-learn==0.20.3
+scipy==1.2.1
+seaborn==0.9.0
+SecretStorage==3.1.1
+Send2Trash==1.5.0
+simplegeneric==0.8.1
+singledispatch==3.4.0.3
+six==1.12.0
+snowballstemmer==1.2.1
+sortedcollections==1.1.2
+sortedcontainers==2.1.0
+soupsieve==1.8
+Sphinx==1.8.5
+sphinxcontrib-websupport==1.1.0
+spyder==3.3.3
+spyder-kernels==0.4.2
+SQLAlchemy==1.3.1
+statsmodels==0.9.0
+sympy==1.3
+tables==3.5.1
+tb-nightly==1.14.0a20190603
+tblib==1.3.2
+tensorboard==1.14.0
+tensorflow==1.14.0
+tensorflow-estimator==1.14.0
+termcolor==1.1.0
+terminado==0.8.1
+testpath==0.4.2
+tf-estimator-nightly==1.14.0.dev2019060501
+toolz==0.9.0
+tornado==6.0.2
+tqdm==4.31.1
+traitlets==4.3.2
+unicodecsv==0.14.1
+urllib3==1.24.1
+wcwidth==0.1.7
+webencodings==0.5.1
+Werkzeug==0.14.1
+widgetsnbextension==3.4.2
+wrapt==1.11.1
+wurlitzer==1.0.2
+xlrd==1.2.0
+XlsxWriter==1.1.5
+xlwt==1.3.0
+zict==0.1.4
+zipp==0.3.3
diff --git a/tools/tensor_utils.py b/tools/tensor_utils.py
index cdb4682..1200195 100644
--- a/tools/tensor_utils.py
+++ b/tools/tensor_utils.py
@@ -5,16 +5,23 @@
"""
import argparse
+import datetime
import numpy as np
import os
import pandas as pd
import tensorflow as tf
-import matplotlib.pyplot as plt
import json
import seaborn as sns
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+import multiprocessing
from matplotlib.pylab import *
-import matplotlib.animation as animation
+from tqdm import tqdm
+# Enable large animation size
+matplotlib.rcParams['animation.embed_limit'] = 2**128
# Enable tensor.numpy()
tf.compat.v1.enable_eager_execution()
@@ -81,9 +88,9 @@
def __init__(self, android_build_top, dump_dir, tflite_model_json_dir='/tmp'):
# key: nnapi model name, value: ModelMetaData
self.models = dict()
- self.ANDROID_BUILD_TOP = android_build_top
- self.TFLITE_MODEL_JSON_DIR = tflite_model_json_dir
- self.DUMP_DIR = dump_dir
+ self.ANDROID_BUILD_TOP = android_build_top + "/"
+ self.TFLITE_MODEL_JSON_DIR = tflite_model_json_dir + "/"
+ self.DUMP_DIR = dump_dir + "/"
self.nnapi_to_tflite_name = dict()
self.tflite_to_nnapi_name = dict()
self.__load_mobilenet_topk_aosp()
@@ -129,8 +136,8 @@
"""Generate a html file containing the hist and heatmap animation of all models"""
model_names = self.model_names if model_names is None else model_names
html_data = ''
- for model_name in model_names:
- print('processing', model_name)
+ for model_name in tqdm(model_names):
+ print(datetime.datetime.now(), 'Processing', model_name)
html_data += '<h3>{}</h3>'.format(model_name)
model_data = ModelData(nnapi_model_name=model_name, manager=self)
ani = model_data.gen_error_hist_animation()
@@ -141,6 +148,47 @@
with open(output_file_path, 'w') as f:
f.write(html_data)
+ def generate_hist_animation_html(self, model_name):
+ """Generate the html error-histogram animation for one model and store it in self.return_dict[model_name + "-hist"]; meant to run as a multiprocessing target (expects the caller to have created self.return_dict)."""
+ html_data = '<h3>{}</h3>'.format(model_name)
+ model_data = ModelData(nnapi_model_name=model_name, manager=self)
+ ani = model_data.gen_error_hist_animation()
+ html_data += ani.to_jshtml()
+ print(datetime.datetime.now(), "Done histogram for", model_name)
+ self.return_dict[model_name + "-hist"] = html_data
+
+ def generate_heatmap_animation_html(self, model_name):
+ """Generate the html heatmap animation for one model and store it in self.return_dict[model_name + "-heatmap"]; meant to run as a multiprocessing target (expects the caller to have created self.return_dict)."""
+ model_data = ModelData(nnapi_model_name=model_name, manager=self)
+ ani = model_data.gen_heatmap_animation()
+ html_data = ani.to_jshtml()
+ print(datetime.datetime.now(), "Done heatmap for", model_name)
+ self.return_dict[model_name + "-heatmap"] = html_data
+
+ def multiprocessing_generate_animation_html(self, output_file_path,
+ model_names=None, heatmap=True):
+ """
+ Generate a html file containing the hist (and, if heatmap is True, heatmap) animation
+ of all models, with one worker process per animation. Raises RuntimeError if a worker fails.
+ """
+ model_names = self.model_names if model_names is None else model_names
+ # (result key suffix, worker) pairs; each worker stores its html in self.return_dict
+ workers = [("-hist", self.generate_hist_animation_html)]
+ if heatmap: workers.append(("-heatmap", self.generate_heatmap_animation_html))
+ manager = multiprocessing.Manager()
+ self.return_dict = manager.dict()
+ jobs = [multiprocessing.Process(target=func, args=(name,), name=name + suffix) for name in model_names for suffix, func in workers]
+ for job in jobs: job.start()
+ # wait for completion; a worker whose target raised exits with a non-zero exitcode
+ for job in jobs: job.join()
+ failed = [job.name for job in jobs if job.exitcode != 0]
+ if failed: raise RuntimeError('Animation worker(s) failed: {}'.format(failed))
+
+ with open(output_file_path, 'w') as f:
+ for model_name in model_names:
+ for suffix, _ in workers:
+ f.write(self.return_dict[model_name + suffix])
+
############################ TensorDict ############################
class TensorDict(dict):
@@ -203,7 +251,7 @@
return diff
diff = diff.astype(float)
cpu_tensor = cpu_tensor.astype(float)
# Divide by max so the relative error range is conveniently [-1, 1]
max_cpu_nnapi_tensor = np.maximum(np.abs(cpu_tensor), np.abs(nnapi_tensor))
relative_diff = np.divide(diff, max_cpu_nnapi_tensor, out=np.zeros_like(diff),
where=max_cpu_nnapi_tensor>0)
@@ -240,7 +288,7 @@
nnapi_model_name: the name of the model
manager: ModelMetaDataManager
"""
- def __init__(self, nnapi_model_name, manager):
+ def __init__(self, nnapi_model_name, manager, seq_limit=10):
self.nnapi_model_name = nnapi_model_name
self.manager = manager
self.model_dir = self.get_target_model_dir(manager.DUMP_DIR,
@@ -251,6 +299,7 @@
return_df=True)
self.layers = sorted(self.tensor_dict['cpu'].keys())
self.cmap = sns.diverging_palette(220, 20, sep=20, as_cmap=True)
+ self.seq_limit = seq_limit
def get_target_model_dir(self, dump_dir, target_model_name):
# Get the model directory path
@@ -265,6 +314,11 @@
ax.hist(self.tensor_dict.calc_diff(layer, relative_error=relative_error), bins=bins,
range=range, log=True)
+ def __get_layer_num(self):
+ if self.seq_limit: # cap the animation frame count at seq_limit passes over output_meta_data
+ return min(len(self.layers), len(self.mmd.output_meta_data) * self.seq_limit)
+ return len(self.layers) # falsy seq_limit: one frame per layer, no cap
+
def update_hist_data(self, i, fig, ax1, ax2, bins=50, plot_library='sns'):
# Use % because there may be multiple testing samples
operation = self.mmd.output_meta_data[i % len(self.mmd.output_meta_data)]['operator_code']
@@ -286,9 +340,8 @@
range=(-absolute_range, absolute_range), relative_error=False)
def gen_error_hist_animation(self, save_video_path=None, video_fps=10):
- layers = self.layers
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12,9))
- ani = animation.FuncAnimation(fig, self.update_hist_data, len(layers),
+ ani = animation.FuncAnimation(fig, self.update_hist_data, self.__get_layer_num(),
fargs=(fig, ax1, ax2),
interval=200, repeat=False)
# close before return to avoid dangling plot
@@ -327,7 +380,6 @@
g3 = self.__sns_heatmap(data=reshaped_nnapi, ax=axs[0][2], cbar_ax=axs[1][2])
def gen_heatmap_animation(self, save_video_path=None, video_fps=10, figsize=(13,6)):
- layers = self.layers
fig = plt.figure(constrained_layout=True, figsize=figsize)
widths = [1, 1, 1]
heights = [7, 1]
@@ -339,7 +391,7 @@
for col in range(3):
axs[-1].append(fig.add_subplot(spec[row, col]))
- ani = animation.FuncAnimation(fig, self.update_heatmap_data, len(layers),
+ ani = animation.FuncAnimation(fig, self.update_heatmap_data, self.__get_layer_num(),
fargs=(fig, axs),
interval=200, repeat=False)
if save_video_path:
@@ -471,19 +523,25 @@
return obj.tolist()
return json.JSONEncoder.default(self, obj)
-def main(android_build_top, dump_dir, model_name, output_file_path=None):
- if output_file_path is None:
- output_file_path = '/tmp/intermediate.html'
+def main(args):
+ output_file_path = args.output_file_path or '/tmp/intermediate.html'
+
manager = ModelMetaDataManager(
- android_build_top,
- dump_dir,
+ args.android_build_top,
+ args.dump_dir,
tflite_model_json_dir='/tmp')
- if model_name:
+
+ if args.no_parallel or args.model_name:
+ generation_func = manager.generate_animation_html
+ else:
+ generation_func = manager.multiprocessing_generate_animation_html
+
+ if args.model_name:
- model_data = ModelData(nnapi_model_name=model_name, manager=manager)
+ model_data = ModelData(nnapi_model_name=args.model_name, manager=manager)
print(model_data.tensor_dict)
- manager.generate_animation_html(output_file_path=output_file_path, model_names=[model_name])
+ generation_func(output_file_path=output_file_path, model_names=[args.model_name])
else:
- manager.generate_animation_html(output_file_path=output_file_path)
+ generation_func(output_file_path=output_file_path)
if __name__ == '__main__':
@@ -494,5 +552,6 @@
parser.add_argument('dump_dir', help='The dump dir pulled from the device.')
parser.add_argument('--model_name', help='NNAPI model name. Run all models if not specified.')
parser.add_argument('--output_file_path', help='Animation HTML path.')
+ parser.add_argument('--no_parallel', action='store_true', help='Run on a single process instead of multiple processes.')
args = parser.parse_args()
- main(args.android_build_top, args.dump_dir, args.model_name, args.output_file_path)
\ No newline at end of file
+ main(args)