Merge "Fix failure in scoring test caused by timeout"
diff --git a/tools/tensor_utils.py b/tools/tensor_utils.py
index a29fd3c..7d96541 100644
--- a/tools/tensor_utils.py
+++ b/tools/tensor_utils.py
@@ -11,24 +11,49 @@
 import tensorflow as tf
 import matplotlib.pyplot as plt
 import json
+import seaborn as sns
 
 from matplotlib.pylab import *
 import matplotlib.animation as animation
 # Enable tensor.numpy()
 tf.compat.v1.enable_eager_execution()
 
-################################ ModelMetaDataManager ################################
+
+############################ Helper Functions ############################
+def reshape_to_matrix(array):
+  """Flatten array into a near-square (width, height) matrix, nan-padded at the end."""
+  array = np.asarray(array, dtype=float).ravel()
+  width = math.ceil(len(array)**0.5)
+  # Guard width == 0 so that an empty array yields an empty (0, 0) matrix.
+  height = math.ceil(len(array) / width) if width else 0
+  padded = np.pad(array,
+                  pad_width=(0, width * height - len(array)),
+                  mode='constant', constant_values=np.nan)
+  # Use the explicit shape: reshape(0, -1) is ambiguous and raises for empty input.
+  return padded.reshape(width, height)
+
+def save_ani_to_video(ani, save_video_path, video_fps=5):
+  """Save the animation as a video using matplotlib's ffmpeg writer at dpi=250."""
+  writer = animation.writers['ffmpeg'](fps=video_fps)
+  ani.save(save_video_path, writer=writer, dpi=250)
+
+def save_ani_to_html(ani, save_html_path):
+  """Write the animation's to_jshtml() HTML representation to save_html_path."""
+  with open(save_html_path, 'w') as html_file:
+    html_file.write(ani.to_jshtml())
+
+############################ ModelMetaDataManager ############################
 class ModelMetaDataManager(object):
   """Maps model name in nnapi to its graph architecture with lazy initialization.
 
   # Arguments
     android_build_top: the root directory of android source tree
     dump_dir: directory containing intermediate tensors pulled from device
-    tflite_model_json_path: directory containing intermediate json output of
+    tflite_model_json_dir: directory containing intermediate json output of
     model visualization tool (third_party/tensorflow/lite/tools:visualize)
     The json output path from the tool is always /tmp.
   """
-
+  ############################ ModelMetaData ############################
   class ModelMetaData(object):
     """Store graph information of a model."""
 
@@ -61,10 +86,10 @@
     self.DUMP_DIR = dump_dir
     self.nnapi_to_tflite_name = dict()
     self.tflite_to_nnapi_name = dict()
-    self.__load_mobilenet_topk_aosp__()
+    self.__load_mobilenet_topk_aosp()
     self.model_names = sorted(os.listdir(dump_dir))
 
-  def __load_mobilenet_topk_aosp__(self):
+  def __load_mobilenet_topk_aosp(self):
     """Load information about tflite and nnapi model names."""
     json_path = '{}/{}'.format(
         self.ANDROID_BUILD_TOP,
@@ -75,14 +100,15 @@
       self.nnapi_to_tflite_name[model['name']] = model['modelFile']
       self.tflite_to_nnapi_name[model['modelFile']] = model['name']
 
-  def __get_model_json_path__(self, tflite_model_name):
-    """Return tflite model jason path."""
-    json_path = '{}/{}.json'.format(self.TFLITE_MODEL_JSON_DIR, tflite_model_name)
+  def __get_model_json_path(self, tflite_model_name):
+    """Return the model's json path: TFLITE_MODEL_JSON_DIR/<tflite_model_name>.json."""
+    json_path = '{}/{}.json'.format(self.TFLITE_MODEL_JSON_DIR,
+                                    tflite_model_name)
     return json_path
 
-  def __load_model__(self, tflite_model_name):
+  def __load_model(self, tflite_model_name):
     """Initialize a ModelMetaData for this model."""
-    model = self.ModelMetaData(self.__get_model_json_path__(tflite_model_name))
+    model = self.ModelMetaData(self.__get_model_json_path(tflite_model_name))
     nnapi_model_name = self.model_name_tflite_to_nnapi(tflite_model_name)
     self.models[nnapi_model_name] = model
 
@@ -96,10 +122,11 @@
     """Retrieve the ModelMetaData with lazy initialization."""
     tflite_model_name = self.model_name_nnapi_to_tflite(nnapi_model_name)
     if nnapi_model_name not in self.models:
-      self.__load_model__(tflite_model_name)
+      self.__load_model(tflite_model_name)
     return self.models[nnapi_model_name]
 
-  def generate_animation_html(self, output_file_path, model_names=None):
+  def generate_animation_html(self, output_file_path, model_names=None, heatmap=True):
+    """Write hist (+ heatmap if heatmap) animations of model_names (default: all) to one html file."""
     model_names = self.model_names if model_names is None else model_names
     html_data = ''
     for model_name in model_names:
@@ -108,11 +135,14 @@
       model_data = ModelData(nnapi_model_name=model_name, manager=self)
       ani = model_data.gen_error_hist_animation()
       html_data += ani.to_jshtml()
+      if heatmap:
+        ani = model_data.gen_heatmap_animation()
+        html_data += ani.to_jshtml()
     with open(output_file_path, 'w') as f:
       f.write(html_data)
 
 
-################################ TensorDict ################################
+############################ TensorDict ############################
 class TensorDict(dict):
   """A class to store cpu and nnapi tensors.
 
@@ -131,6 +161,7 @@
     self.calc_range()
 
   def bytes_to_numpy_tensor(self, file_path):
+    """Load bytes output by DumpIntermediateTensor into a numpy tensor."""
     tensor_type = tf.int8 if 'quant' in file_path else tf.float32
     with open(file_path, mode='rb') as f:
       tensor_bytes = f.read()
@@ -148,9 +179,7 @@
 
   def tensor_sanity_check(self):
     # Make sure the cpu tensors and nnapi tensors have the same outputs
-    assert(len(self['cpu']) == len(self['nnapi']))
-    key_diff = set(self['cpu'].keys()) - set(self['nnapi'].keys())
-    assert(len(key_diff) == 0)
+    assert set(self['cpu']) == set(self['nnapi'])
     print('Tensor sanity check passed')
 
   def calc_range(self):
@@ -171,11 +200,10 @@
       return diff
     diff = diff.astype(float)
     cpu_tensor = cpu_tensor.astype(float)
+    # Relative error = diff / max(|cpu|, |nnapi|); for diff = cpu - nnapi it lies in [-2, 2]
     max_cpu_nnapi_tensor = np.maximum(np.abs(cpu_tensor), np.abs(nnapi_tensor))
-    relative_diff = np.divide(diff, max_cpu_nnapi_tensor, out=np.zeros_like(diff),\
+    relative_diff = np.divide(diff, max_cpu_nnapi_tensor, out=np.zeros_like(diff),
                               where=max_cpu_nnapi_tensor>0)
-    relative_diff[relative_diff>1] = 1.0
-    relative_diff[relative_diff<-1] = -1.0
     return relative_diff
 
   def gen_tensor_diff_stats(self, relative_error=True, return_df=True, plot_diff=False):
@@ -201,8 +229,7 @@
     plt.plot()
 
 
-################################ Model Data ################################
-
+############################ Model Data ############################
 class ModelData(object):
   """A class to store all relevant inormation of a model.
 
@@ -213,62 +240,226 @@
   def __init__(self, nnapi_model_name, manager):
     self.nnapi_model_name = nnapi_model_name
     self.manager = manager
-    self.model_dir = self.get_target_model_dir(manager.DUMP_DIR, nnapi_model_name)
+    self.model_dir = self.get_target_model_dir(manager.DUMP_DIR,
+                                               nnapi_model_name)
     self.tensor_dict = TensorDict(self.model_dir)
     self.mmd = manager.get_model_meta_data(nnapi_model_name)
     self.stats = self.tensor_dict.gen_tensor_diff_stats(relative_error=True,
                                                         return_df=True)
     self.layers = sorted(self.tensor_dict['cpu'].keys())
+    self.cmap = sns.diverging_palette(220, 20, sep=20, as_cmap=True)
 
   def get_target_model_dir(self, dump_dir, target_model_name):
-    target_model_dir = dump_dir + target_model_name + "/"
-    return target_model_dir
+    """Return the model's dump directory path, always ending with a separator."""
+    # os.path.join accepts dump_dir with or without a trailing separator.
+    return os.path.join(dump_dir, target_model_name, '')
 
-  def updateData(self, i, fig, ax1, ax2, bins=50):
+  def __sns_distplot(self, layer, bins, ax, range, relative_error):
+    diff = self.tensor_dict.calc_diff(layer, relative_error=relative_error)
+    sns.distplot(diff, bins=bins, hist_kws={"range": range, "log": True}, ax=ax, kde=False)
+
+  def __plt_hist(self, layer, bins, ax, range, relative_error):
+    diff = self.tensor_dict.calc_diff(layer, relative_error=relative_error)
+    ax.hist(diff, bins=bins, range=range, log=True)
+
+  def update_hist_data(self, i, fig, ax1, ax2, bins=50, plot_library='sns'):
+    # Use % because there may be multiple testing samples
     operation = self.mmd.output_meta_data[i % len(self.mmd.output_meta_data)]['operator_code']
     layer = self.layers[i]
     subtitle = fig.suptitle('{} | {}\n{}'
                       .format(self.nnapi_model_name, layer, operation),
                       fontsize='x-large')
     for ax in (ax1, ax2):
-        ax.clear()
+      ax.clear()
     ax1.set_title('Relative Error')
     ax2.set_title('Absolute Error')
-    ax1.hist(self.tensor_dict.calc_diff(layer, relative_error=True), bins=bins,
-             range=(-1, 1), log=True)
     absolute_range = self.tensor_dict.absolute_range
-    ax2.hist(self.tensor_dict.calc_diff(layer, relative_error=False), bins=bins,
-             range=(-absolute_range, absolute_range), log=True)
 
-  def gen_error_hist_animation(self):
-    # For fast testing, add [:10] to the end of next line
+    # Determine underlying plotting library
+    hist_func = self.__plt_hist if plot_library == 'matplotlib' else self.__sns_distplot
+    # diff / max(|cpu|, |nnapi|) is unclipped and reaches +/-2 when the signs differ.
+    hist_func(layer=layer, bins=bins, ax=ax1, range=(-2, 2), relative_error=True)
+    hist_func(layer=layer, bins=bins, ax=ax2,
+              range=(-absolute_range, absolute_range), relative_error=False)
+
+  def gen_error_hist_animation(self, save_video_path=None, video_fps=10):
     layers = self.layers
     fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12,9))
-    ani = animation.FuncAnimation(fig, self.updateData, len(layers),
+    ani = animation.FuncAnimation(fig, self.update_hist_data, len(layers),
                                   fargs=(fig, ax1, ax2),
                                   interval=200, repeat=False)
     # close before return to avoid dangling plot
+    if save_video_path:
+      save_ani_to_video(ani, save_video_path, video_fps)
+    plt.close()
+    return ani
+
+  def __sns_heatmap(self, data, ax, cbar_ax, **kwargs):
+    return sns.heatmap(data, ax=ax, cmap=self.cmap, center=0, cbar=True,
+                       cbar_ax=cbar_ax, cbar_kws={"orientation": "horizontal"}, **kwargs)
+
+  def update_heatmap_data(self, i, fig, axs):
+    """Redraw frame i with heatmaps of the diff, CPU and NNAPI tensors of layer i."""
+    # Use % because there may be multiple testing samples
+    operation = self.mmd.output_meta_data[i % len(self.mmd.output_meta_data)]['operator_code']
+    layer = self.layers[i]
+    fig.suptitle('{} | {}\n{}\n'.format(self.nnapi_model_name, layer, operation),
+                 fontsize='x-large')
+    # Clear all the axs and redraw
+    # It's important to clear the colorbars as well to avoid duplicate colorbars
+    for ax_row in axs:
+      for ax in ax_row:
+        ax.clear()
+    axs[0][0].set_title('Diff')
+    axs[0][1].set_title('CPU Tensor')
+    axs[0][2].set_title('NNAPI Tensor')
+
+    reshaped_diff = reshape_to_matrix(self.tensor_dict.calc_diff(layer, relative_error=False))
+    reshaped_cpu = reshape_to_matrix(self.tensor_dict['cpu'][layer])
+    reshaped_nnapi = reshape_to_matrix(self.tensor_dict['nnapi'][layer])
+    absolute_range = self.tensor_dict.absolute_range
+    self.__sns_heatmap(data=reshaped_diff, ax=axs[0][0], cbar_ax=axs[1][0],
+                       vmin=-absolute_range, vmax=absolute_range)
+    self.__sns_heatmap(data=reshaped_cpu, ax=axs[0][1], cbar_ax=axs[1][1])
+    self.__sns_heatmap(data=reshaped_nnapi, ax=axs[0][2], cbar_ax=axs[1][2])
+
+  def gen_heatmap_animation(self, save_video_path=None, video_fps=10, figsize=(13,6)):
+    """Animate per-layer heatmaps of the diff, CPU and NNAPI tensors.
+
+    Returns the FuncAnimation; also saves it as a video if save_video_path is set.
+    """
+    fig = plt.figure(constrained_layout=True, figsize=figsize)
+    # Top row (height 7): three heatmaps; bottom row (height 1): their colorbars.
+    spec = fig.add_gridspec(ncols=3, nrows=2, width_ratios=[1, 1, 1],
+                            height_ratios=[7, 1])
+    axs = [[fig.add_subplot(spec[row, col]) for col in range(3)]
+           for row in range(2)]
+
+    n_frames = len(self.layers)
+    ani = animation.FuncAnimation(fig, self.update_heatmap_data, n_frames,
+                                  fargs=(fig, axs),
+                                  interval=200, repeat=False)
+    if save_video_path:
+      save_ani_to_video(ani, save_video_path, video_fps)
+    # close before return to avoid dangling plot
     plt.close()
     return ani
 
-  def plot_error_heatmap(self, target_layer, length=1):
+  def plot_error_heatmap(self, target_layer, vmin=None, vmax=None):
+    # Plot the diff heatmap for a given layer
     target_diff = self.tensor_dict['cpu'][target_layer] - \
                   self.tensor_dict['nnapi'][target_layer]
-    width = int(len(target_diff)/ length)
-    reshaped_target_diff = target_diff[:length * width].reshape(length, width)
-    fig, ax = subplots(figsize=(18, 5))
+    reshaped_target_diff = reshape_to_matrix(target_diff)
+    fig, ax = subplots(figsize=(9, 9))
     plt.title('Heat Map of Error between CPU and NNAPI')
-    plt.imshow(reshaped_target_diff, cmap='hot', interpolation='nearest')
-    plt.colorbar()
+    sns.heatmap(reshaped_target_diff,
+                cmap=self.cmap,
+                mask=np.isnan(reshaped_target_diff),
+                center=0, vmin=vmin, vmax=vmax)
     plt.show()
 
 
-################################NumpyEncoder ################################
+############################ ModelDataComparison ############################
+class ModelDataComparison(object):
+  """Load the same model from several dump dirs and compare their errors.
+
+  # Arguments
+    dump_dir_list: dump dirs to compare; one ModelData is built per dir.
+    android_build_top: the root directory of android source tree
+    tflite_model_json_dir: directory containing the tflite visualization json
+    model_name: NNAPI model name to compare; change it with set_model_name.
+  """
+  def __init__(self, dump_dir_list, android_build_top, tflite_model_json_dir, model_name):
+    self.dump_dir_list = dump_dir_list
+    self.android_build_top = android_build_top
+    self.tflite_model_json_dir = tflite_model_json_dir
+    self.set_model_name(model_name)
+
+  def set_model_name(self, model_name):
+    """Set the model to be compared and load/reload all of its ModelData."""
+    self.model_name = model_name
+    self.__load_data()
+
+  def __load_data(self):
+    """Build one ModelMetaDataManager and one ModelData per dump dir."""
+    self.manager_list = []
+    self.model_data_list = []
+    for dump_dir in self.dump_dir_list:
+      manager = ModelMetaDataManager(self.android_build_top,
+                                     dump_dir,
+                                     tflite_model_json_dir=self.tflite_model_json_dir)
+      model_data = ModelData(nnapi_model_name=self.model_name, manager=manager)
+      self.manager_list.append(manager)
+      self.model_data_list.append(model_data)
+    self.sanity_check()
+
+  def sanity_check(self):
+    """Check there is data to compare and every run has the same layers."""
+    assert len(self.model_data_list) >= 1, 'No model data to compare'
+    sample_model_data = self.model_data_list[0]
+    sample_layers = set(sample_model_data.tensor_dict['cpu'].keys())
+    for model_data in self.model_data_list[1:]:
+      # Compare like with like: every run's cpu layers against the first run's.
+      assert set(model_data.tensor_dict['cpu'].keys()) == sample_layers
+    print('Sanity Check Passed')
+    self.layers = sample_model_data.layers
+    self.mmd = sample_model_data.mmd
+
+  def update_hist_comparison_data(self, i, fig, axs, bins=50):
+    """Redraw frame i: one shared diff histogram plus a diff heatmap per run."""
+    sample_model_data = self.model_data_list[0]
+    # Use % because there may be multiple testing samples
+    operation = self.mmd.output_meta_data[i % len(self.mmd.output_meta_data)]['operator_code']
+    layer = self.layers[i]
+    fig.suptitle('{} | {}\n{}'
+                 .format(sample_model_data.nnapi_model_name, layer, operation),
+                 fontsize='x-large')
+    for row in axs:
+      for ax in row:
+        ax.clear()
+
+    hist_ax = axs[0][0]
+    hist_ax.set_title('Diff Histogram')
+    # normpath makes the label independent of a trailing '/' on the dump dir
+    labels = [os.path.basename(os.path.normpath(dump_dir))
+              for dump_dir in self.dump_dir_list]
+    # Same palette as ModelData.cmap; reuse it instead of rebuilding every frame
+    cmap = sample_model_data.cmap
+    # j, not i: the frame index i must not be shadowed
+    for j, (model_data, label) in enumerate(zip(self.model_data_list, labels)):
+      diff = model_data.tensor_dict.calc_diff(layer, relative_error=False)
+      axs[1][j].set_title(label)
+      sns.heatmap(reshape_to_matrix(diff), cmap=cmap, cbar=True, ax=axs[1][j],
+                  cbar_ax=axs[2][j], cbar_kws={"orientation": "horizontal"}, center=0)
+      sns.distplot(diff, bins=bins, hist_kws={"log": True}, ax=hist_ax, kde=False)
+    hist_ax.legend(labels)
+
+  def gen_error_hist_comparison_animation(self, save_video_path=None, video_fps=10):
+    """Animate the comparison over all layers; optionally save it as a video."""
+    n_runs = len(self.model_data_list)
+    fig = plt.figure(figsize=(5 * n_runs, 4 * n_runs))
+    # Rows: shared histogram, per-run heatmaps, per-run horizontal colorbars
+    gs = fig.add_gridspec(3, n_runs, width_ratios=[1] * n_runs,
+                          height_ratios=[n_runs * 0.7, 1, 0.2])
+    axs = [[fig.add_subplot(gs[0, :])], [], []]
+    for col in range(n_runs):
+      axs[1].append(fig.add_subplot(gs[1, col]))
+      axs[2].append(fig.add_subplot(gs[2, col]))
+    ani = animation.FuncAnimation(fig, self.update_hist_comparison_data,
+                                  len(self.layers), fargs=(fig, axs),
+                                  interval=200, repeat=False)
+    if save_video_path:
+      save_ani_to_video(ani, save_video_path, video_fps)
+    # close before return to avoid dangling plot
+    plt.close()
+    return ani
+
+
+############################ NumpyEncoder ############################
 class NumpyEncoder(json.JSONEncoder):
   """Enable numpy array serilization in a dictionary.
 
-  # Usage:
+  Usage:
     a = np.array([[1, 2, 3], [4, 5, 6]])
     json.dumps({'a': a, 'aa': [2, (2, 3, 4), a], 'bb': [2]}, cls=NumpyEncoder)
   """
@@ -277,20 +468,28 @@
           return obj.tolist()
       return json.JSONEncoder.default(self, obj)
 
-def main(android_build_top, dump_dir, model_name):
+def main(android_build_top, dump_dir, model_name, output_file_path=None):
+  if output_file_path is None:
+    output_file_path = '/tmp/intermediate.html'
   manager = ModelMetaDataManager(
     android_build_top,
     dump_dir,
     tflite_model_json_dir='/tmp')
-  model_data = ModelData(nnapi_model_name=model_name, manager=manager)
-  print(model_data.tensor_dict)
+  # Without a model_name, generate_animation_html covers every model in dump_dir.
+  model_names = [model_name] if model_name else None
+  if model_name:
+    print(ModelData(nnapi_model_name=model_name, manager=manager).tensor_dict)
+  manager.generate_animation_html(output_file_path=output_file_path,
+                                  model_names=model_names)
+
 
 if __name__ == '__main__':
   # Example usage
-  # python tensor_utils.py ~/android/master/ ~/android/master/intermediate/ tts_float
+  # python tensor_utils.py ~/android/master/ ~/android/master/intermediate/ --model_name=tts_float
   parser = argparse.ArgumentParser(description='Utilities for parsing intermediate tensors.')
-  parser.add_argument('android_build_top', help='Your Android build top path')
-  parser.add_argument('dump_dir', help='The dump dir pulled from the device')
-  parser.add_argument('model_name', help='NNAPI model name')
+  parser.add_argument('android_build_top', help='Your Android build top path.')
+  parser.add_argument('dump_dir', help='The dump dir pulled from the device.')
+  parser.add_argument('--model_name', help='NNAPI model name. Run all models if not specified.')
+  parser.add_argument('--output_file_path', help='Animation HTML path (default: /tmp/intermediate.html).')
   args = parser.parse_args()
-  main(args.android_build_top, args.dump_dir, args.model_name)
\ No newline at end of file
+  main(args.android_build_top, args.dump_dir, args.model_name, args.output_file_path)