Add model-wise comparison with the float golden model.
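
Example usage of the new model_debug_metrics option (an illustrative sketch, not
part of the patch: quant_debug_tflite, float_tflite and calibration_gen are
placeholders for the user's debug-mode quantized model content, float model
content and input sample generator, and the import path assumes this module's
location under tensorflow/lite/experimental/quantization_debugger):

    import numpy as np
    from tensorflow.lite.experimental.quantization_debugger import debugger

    options = debugger.QuantizationDebugOptions(
        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))},
        model_debug_metrics={
            # Compares the first output of each model; returns one scalar.
            'mean_abs_error': lambda f, q: np.mean(np.abs(f[0] - q[0]))})
    quant_debugger = debugger.QuantizationDebugger(
        quant_debug_model_content=quant_debug_tflite,  # quantized debug model
        float_model_content=float_tflite,              # float golden model
        debug_dataset=calibration_gen,
        debug_options=options)
    quant_debugger.run()
    # quant_debugger.layer_statistics: {layer_name: {metric_name: value}}
    # quant_debugger.model_statistics: {metric_name: value}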

PiperOrigin-RevId: 354013412
Change-Id: Iebc07fa89c61121f09362fc3409a048abf726688
diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger.py b/tensorflow/lite/experimental/quantization_debugger/debugger.py
index 2d0793d..25ea2dc 100644
--- a/tensorflow/lite/experimental/quantization_debugger/debugger.py
+++ b/tensorflow/lite/experimental/quantization_debugger/debugger.py
@@ -35,20 +35,27 @@
 class QuantizationDebugOptions:
   """Debug options to set up a given QuantizationDebugger."""
 
-  def __init__(
-      self,
-      layer_debug_metrics: Optional[Mapping[str, Callable[[np.ndarray],
-                                                          float]]] = None
-  ) -> None:
+  def __init__(self,
+               layer_debug_metrics: Optional[Mapping[str,
+                                                     Callable[[np.ndarray],
+                                                              float]]] = None,
+               model_debug_metrics: Optional[Mapping[
+                   str, Callable[[Sequence[np.ndarray], Sequence[np.ndarray]],
+                                 float]]] = None):
     """Initializes debugger options.
 
     Args:
       layer_debug_metrics: a dict to specify layer debug functions
-        {function_name_str: function} where the function accpets result of
+        {function_name_str: function} where the function accepts result of
           NumericVerify Op, which is value difference between float and
           dequantized op results. The function returns single scalar value.
+      model_debug_metrics: a dict to specify model debug functions
+        {function_name_str: function} where the function accepts outputs from
+          the float and quantized models, and returns a single scalar value
+          for a metric (e.g. accuracy, IoU).
     """
     self.layer_debug_metrics = layer_debug_metrics
+    self.model_debug_metrics = model_debug_metrics
 
 
 @tf_export.tf_export(v1=['lite.experimental.QuantizationDebugger'])
@@ -64,14 +71,18 @@
       self,
       quant_debug_model_path: Optional[str] = None,
       quant_debug_model_content: Optional[bytes] = None,
+      float_model_path: Optional[str] = None,
+      float_model_content: Optional[bytes] = None,
       debug_dataset: Optional[Callable[[],
                                        Iterable[Sequence[np.ndarray]]]] = None,
       debug_options: Optional[QuantizationDebugOptions] = None) -> None:
     """Runs the TFLite debugging model with given debug options.
 
     Args:
-      quant_debug_model_path: Path to debug mode TF-Lite Flatbuffer file.
-      quant_debug_model_content: Content of the quantized debug model.
+      quant_debug_model_path: Path to the quantized debug TFLite model file.
+      quant_debug_model_content: Content of the quantized debug TFLite model.
+      float_model_path: Path to the float TFLite model file.
+      float_model_content: Content of the float TFLite model.
       debug_dataset: a factory function that returns dataset generator which is
         used to generate input samples (list of np.ndarray) for the model. The
         generated elements must have same types and shape as inputs to the
@@ -84,6 +95,8 @@
     Attributes:
       layer_statistics: results of error metrics for each NumericVerify op
         results. in {layer_name: {metric_name: metric}} format.
+      model_statistics: results of error metrics for the difference between
+        float and quantized models. in {metric_name: metric} format.
     """
     self._data_gen = debug_dataset
     self._debug_options = debug_options or QuantizationDebugOptions()
@@ -91,6 +104,9 @@
     input_data = next(iter(self._data_gen()))
     self._quant_interpreter = tf.lite.Interpreter(quant_debug_model_path,
                                                   quant_debug_model_content)
+    if self._debug_options.model_debug_metrics:
+      self._float_interpreter = tf.lite.Interpreter(float_model_path,
+                                                    float_model_content)
 
     self._numeric_verify_tensor_details = None
     if not self._get_numeric_verify_tensor_details():
@@ -101,10 +117,13 @@
       self._layer_debug_metrics.update(self._debug_options.layer_debug_metrics)
 
     self.layer_statistics = None
+    self.model_statistics = None
 
   def run(self) -> None:
     """Runs models and gets metrics."""
     self.layer_statistics = self._collect_layer_statistics()
+    if self._debug_options.model_debug_metrics:
+      self.model_statistics = self._collect_model_statistics()
 
   def _collect_layer_statistics(self) -> Dict[str, Dict[str, float]]:
     """Collects layer statistics by applying layer debug metrics.
@@ -142,6 +161,49 @@
 
     return layer_statistics
 
+  def _collect_model_statistics(self) -> Dict[str, float]:
+    """Collects model output metrics.
+
+    For all data from the given debug dataset, collects model output results
+    from both the float model and the quantized debug model, and calculates
+    metrics by applying the model debug metric functions. As a result,
+    self.model_statistics is filled, where
+    self.model_statistics[metric_name] = `aggregated metric value` (a scalar)
+    for each given metric function.
+
+    Returns:
+      aggregated per-model output discrepancy metrics, in
+      {metric_name: aggregated_metric} format.
+    """
+
+    model_statistics = collections.defaultdict(list)
+
+    initialize = True
+    for tensor_data in self._data_gen():
+      self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
+      self._set_input_tensors(self._float_interpreter, tensor_data, initialize)
+      initialize = False
+
+      # Run the models.
+      self._quant_interpreter.invoke()
+      self._float_interpreter.invoke()
+
+      # Collect the output results from both models.
+      float_tensor_data = self._get_output_tensors(self._float_interpreter)
+      quant_tensor_data = self._get_output_tensors(self._quant_interpreter)
+
+      # Calculate the metrics.
+      for (metric_name,
+           metric_fn) in self._debug_options.model_debug_metrics.items():
+        model_statistics[metric_name].append(
+            metric_fn(float_tensor_data, quant_tensor_data))
+
+    # Aggregate the per-batch values into a final value for each metric.
+    return {
+        metric_name: np.mean(metric)
+        for metric_name, metric in model_statistics.items()
+    }
+
   def _set_input_tensors(
       self,
       interpreter: tf.lite.Interpreter,
@@ -176,6 +238,21 @@
     for input_idx, tensor in zip(input_indices, tensor_data):
       interpreter.set_tensor(input_idx, tensor)
 
+  def _get_output_tensors(self,
+                          interpreter: tf.lite.Interpreter) -> List[np.ndarray]:
+    """Returns output tensors of given TFLite model Interpreter.
+
+    Args:
+      interpreter: a tf.lite.Interpreter object with allocated tensors.
+
+    Returns:
+      a list of numpy arrays representing output tensor results.
+    """
+    return [
+        interpreter.get_tensor(tensor['index'])
+        for tensor in interpreter.get_output_details()
+    ]
+
   def _get_numeric_verify_tensor_details(self) -> List[str]:
     """Returns all names of all tensors from NumericVerify op."""
     if not self._numeric_verify_tensor_details:
diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
index 2c2a145..12a5a7a 100644
--- a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
+++ b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
@@ -48,6 +48,12 @@
     yield [np.arange(9).reshape((1, 3, 3, 1)).astype(np.float32) * i]
 
 
+def _convert_model(func):
+  """Converts TF model to TFLite float model."""
+  converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+  return converter.convert()
+
+
 def _quantize_model(func, calibration_gen, debug=True):
   """Quantizes model, in debug or normal mode."""
   converter = lite.TFLiteConverterV2.from_concrete_functions([func])
@@ -70,11 +76,12 @@
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
-    cls.float_model = _get_model()
-    cls.debug_model = _quantize_model(cls.float_model, _calibration_gen)
+    cls.tf_model = _get_model()
+    cls.float_model = _convert_model(cls.tf_model)
+    cls.debug_model = _quantize_model(cls.tf_model, _calibration_gen)
 
   @test_util.run_v2_only
-  def test_quantization_debugger(self):
+  def test_quantization_debugger_layer_metrics(self):
     options = debugger.QuantizationDebugOptions(
         layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))})
     quant_debugger = debugger.QuantizationDebugger(
@@ -99,6 +106,24 @@
       self.assertAlmostEqual(value, actual_metrics[key], places=5)
 
   @test_util.run_v2_only
+  def test_quantization_debugger_model_metrics(self):
+    options = debugger.QuantizationDebugOptions(
+        model_debug_metrics={'stdev': lambda x, y: np.std(x[0] - y[0])})
+    quant_debugger = debugger.QuantizationDebugger(
+        quant_debug_model_content=QuantizationDebuggerTest.debug_model,
+        float_model_content=QuantizationDebuggerTest.float_model,
+        debug_dataset=_calibration_gen,
+        debug_options=options)
+    quant_debugger.run()
+
+    expected_metrics = {'stdev': 0.050998904}
+    actual_metrics = quant_debugger.model_statistics
+
+    self.assertCountEqual(expected_metrics.keys(), actual_metrics.keys())
+    for key, value in expected_metrics.items():
+      self.assertAlmostEqual(value, actual_metrics[key], places=5)
+
+  @test_util.run_v2_only
   def test_quantization_debugger_wrong_input_raises_ValueError(self):
 
     def wrong_calibration_gen():
@@ -118,7 +143,7 @@
   @test_util.run_v2_only
   def test_quantization_debugger_non_debug_model_raises_ValueError(self):
     normal_quant_model = _quantize_model(
-        QuantizationDebuggerTest.float_model, _calibration_gen, debug=False)
+        QuantizationDebuggerTest.tf_model, _calibration_gen, debug=False)
 
     with self.assertRaisesRegex(
         ValueError, 'Please check if the quantized model is in debug mode'):
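
Note on writing model_debug_metrics (a sketch, not part of the patch): each
metric function receives the float model's outputs and the quantized model's
outputs as lists of np.ndarray (one entry per model output, as collected by
_get_output_tensors) and must return a single scalar; the debugger then
averages the per-batch values over the debug dataset with np.mean. Assuming a
classifier whose first output is a batch of class scores, a top-1 agreement
metric could look like:

    import numpy as np

    def top1_agreement(float_outputs, quant_outputs):
      """Fraction of samples where both models predict the same top-1 class."""
      float_top1 = np.argmax(float_outputs[0], axis=-1)
      quant_top1 = np.argmax(quant_outputs[0], axis=-1)
      return float(np.mean(float_top1 == quant_top1))

    options = debugger.QuantizationDebugOptions(
        model_debug_metrics={'top1_agreement': top1_agreement})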