Add serialization unit tests for the different combinations of metric inputs accepted by Model.compile.
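
The new test parameterizes over strings, built-in and custom metric
functions, Metric classes, and nested lists/dicts of these. As a rough,
illustrative sketch only (names taken from the new test file; `model` stands
in for the multi-output model built there):

    model.compile(
        'sgd',
        'mae',
        # One list applied to every output: string, built-in function,
        # and custom Metric instance forms.
        metrics=['mae', metrics.mae, MyMeanAbsoluteError()],
        # Dict keyed by output name, one list per output.
        weighted_metrics={'output': ['mae'], 'output_1': [metrics.mae]})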

PiperOrigin-RevId: 286305667
Change-Id: I4b016ac41492d52f361a121f2c303cdd753fbae8
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index b5f565f..1c14fb1 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -724,6 +724,20 @@
 )
 
 tf_py_test(
+    name = "metrics_serialization_test",
+    size = "medium",
+    srcs = ["saving/metrics_serialization_test.py"],
+    python_version = "PY3",
+    shard_count = 4,
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+tf_py_test(
     name = "advanced_activations_test",
     size = "medium",
     srcs = ["layers/advanced_activations_test.py"],
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 164d9ba..bf315be 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -38,7 +38,6 @@
 from tensorflow.python.framework import random_seed
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
@@ -644,18 +643,10 @@
       ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label))
       return ds.batch(8, drop_remainder=True)
 
-    class Bias(base_layer.Layer):
-
-      def build(self, input_shape):
-        self.bias = self.add_variable('bias', (1,), initializer='zeros')
-
-      def call(self, inputs):
-        return inputs + self.bias
-
     # Very simple bias model to eliminate randomness.
     optimizer = gradient_descent.SGD(0.1)
     model = sequential.Sequential()
-    model.add(Bias(input_shape=(1,)))
+    model.add(testing_utils.Bias(input_shape=(1,)))
     model.compile(loss='mae', optimizer=optimizer, metrics=['mae'])
     train_ds = get_input_datasets()
 
diff --git a/tensorflow/python/keras/engine/correctness_test.py b/tensorflow/python/keras/engine/correctness_test.py
index 3f75b2b..9f6f418 100644
--- a/tensorflow/python/keras/engine/correctness_test.py
+++ b/tensorflow/python/keras/engine/correctness_test.py
@@ -27,23 +27,13 @@
 from tensorflow.python.platform import test
 
 
-class Bias(keras.layers.Layer):
-  """Layer that add a bias to its inputs."""
-
-  def build(self, input_shape):
-    self.bias = self.add_variable('bias', (1,), initializer='zeros')
-
-  def call(self, inputs):
-    return inputs + self.bias
-
-
 class MultiInputSubclassed(keras.Model):
   """Subclassed Model that adds its inputs and then adds a bias."""
 
   def __init__(self):
     super(MultiInputSubclassed, self).__init__()
     self.add = keras.layers.Add()
-    self.bias = Bias()
+    self.bias = testing_utils.Bias()
 
   def call(self, inputs):
     added = self.add(inputs)
@@ -56,7 +46,7 @@
   input_2 = keras.Input(shape=(1,))
   input_3 = keras.Input(shape=(1,))
   added = keras.layers.Add()([input_1, input_2, input_3])
-  output = Bias()(added)
+  output = testing_utils.Bias()(added)
   return keras.Model([input_1, input_2, input_3], output)
 
 
@@ -65,7 +55,8 @@
 class SimpleBiasTest(keras_parameterized.TestCase):
 
   def _get_simple_bias_model(self):
-    model = testing_utils.get_model_from_layers([Bias()], input_shape=(1,))
+    model = testing_utils.get_model_from_layers([testing_utils.Bias()],
+                                                input_shape=(1,))
     model.compile(
         keras.optimizer_v2.gradient_descent.SGD(0.1),
         'mae',
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 2f268b9..a83689c 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1476,17 +1476,9 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_add_loss_correctness(self):
-    class Bias(keras.layers.Layer):
-
-      def build(self, input_shape):
-        self.bias = self.add_variable('bias', (1,), initializer='zeros')
-
-      def call(self, inputs):
-        return inputs + self.bias
-
     inputs = keras.Input(shape=(1,))
     targets = keras.Input(shape=(1,))
-    outputs = Bias()(inputs)
+    outputs = testing_utils.Bias()(inputs)
     model = keras.Model([inputs, targets], outputs)
 
     model.add_loss(2 * math_ops.reduce_mean(
@@ -1507,19 +1499,10 @@
 
   @keras_parameterized.run_all_keras_modes
   def test_add_loss_with_sample_weight_correctness(self):
-
-    class Bias(keras.layers.Layer):
-
-      def build(self, input_shape):
-        self.bias = self.add_variable('bias', (1,), initializer='zeros')
-
-      def call(self, inputs):
-        return inputs + self.bias
-
     inputs = keras.Input(shape=(1,))
     targets = keras.Input(shape=(1,))
     sw = keras.Input(shape=(1,))
-    outputs = Bias()(inputs)
+    outputs = testing_utils.Bias()(inputs)
     model = keras.Model([inputs, targets, sw], outputs)
 
     model.add_loss(2 * math_ops.reduce_mean(
diff --git a/tensorflow/python/keras/saving/metrics_serialization_test.py b/tensorflow/python/keras/saving/metrics_serialization_test.py
new file mode 100644
index 0000000..e0a7fc9
--- /dev/null
+++ b/tensorflow/python/keras/saving/metrics_serialization_test.py
@@ -0,0 +1,276 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras metrics serialization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras import optimizer_v2
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import nest
+
+try:
+  import h5py  # pylint:disable=g-import-not-at-top
+except ImportError:
+  h5py = None
+
+
+# Custom metric
+class MyMeanAbsoluteError(metrics.MeanMetricWrapper):
+
+  def __init__(self, name='my_mae', dtype=None):
+    super(MyMeanAbsoluteError, self).__init__(_my_mae, name, dtype=dtype)
+
+
+# Custom metric function
+def _my_mae(y_true, y_pred):
+  return keras.backend.mean(math_ops.abs(y_pred - y_true), axis=-1)
+
+
+def _get_multi_io_model():
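+  """Builds a two-input, two-output functional model sharing one Bias layer.
+
+  Reusing the single layer named 'output' yields model output names 'output'
+  and 'output_1', which the dict-valued testcases below key on.
+  """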
+  inp_1 = layers.Input(shape=(1,), name='input_1')
+  inp_2 = layers.Input(shape=(1,), name='input_2')
+  d = testing_utils.Bias(name='output')
+  out_1 = d(inp_1)
+  out_2 = d(inp_2)
+  return keras.Model([inp_1, inp_2], [out_1, out_2])
+
+
+@keras_parameterized.run_all_keras_modes
+@parameterized.named_parameters(
+    dict(testcase_name='string', value=['mae']),
+    dict(testcase_name='built_in_fn', value=[metrics.mae]),
+    dict(testcase_name='built_in_class', value=[metrics.MeanAbsoluteError]),
+    dict(testcase_name='custom_fn', value=[_my_mae]),
+    dict(testcase_name='custom_class', value=[MyMeanAbsoluteError]),
+    dict(testcase_name='list_of_strings', value=['mae', 'mae']),
+    dict(
+        testcase_name='list_of_built_in_fns', value=[metrics.mae, metrics.mae]),
+    dict(
+        testcase_name='list_of_built_in_classes',
+        value=[metrics.MeanAbsoluteError, metrics.MeanAbsoluteError]),
+    dict(testcase_name='list_of_custom_fns', value=[_my_mae, _my_mae]),
+    dict(
+        testcase_name='list_of_custom_classes',
+        value=[MyMeanAbsoluteError, MyMeanAbsoluteError]),
+    dict(testcase_name='list_of_string_and_list', value=['mae', ['mae']]),
+    dict(
+        testcase_name='list_of_built_in_fn_and_list',
+        value=[metrics.mae, [metrics.mae]]),
+    dict(
+        testcase_name='list_of_built_in_class_and_list',
+        value=[metrics.MeanAbsoluteError, [metrics.MeanAbsoluteError]]),
+    dict(
+        testcase_name='list_of_custom_fn_and_list', value=[_my_mae, [_my_mae]]),
+    dict(
+        testcase_name='list_of_custom_class_and_list',
+        value=[MyMeanAbsoluteError, [MyMeanAbsoluteError]]),
+    dict(
+        testcase_name='list_of_lists_of_custom_fns',
+        value=[[_my_mae], [_my_mae, 'mae']]),
+    dict(
+        testcase_name='list_of_lists_of_custom_classes',
+        value=[[MyMeanAbsoluteError], [MyMeanAbsoluteError, 'mae']]),
+    dict(
+        testcase_name='dict_of_list_of_string',
+        value={
+            'output': ['mae'],
+            'output_1': ['mae'],
+        }),
+    dict(
+        testcase_name='dict_of_list_of_built_in_fn',
+        value={
+            'output': [metrics.mae],
+            'output_1': [metrics.mae],
+        }),
+    dict(
+        testcase_name='dict_of_list_of_built_in_class',
+        value={
+            'output': [metrics.MeanAbsoluteError],
+            'output_1': [metrics.MeanAbsoluteError],
+        }),
+    dict(
+        testcase_name='dict_of_list_of_custom_fn',
+        value={
+            'output': [_my_mae],
+            'output_1': [_my_mae],
+        }),
+    dict(
+        testcase_name='dict_of_list_of_custom_class',
+        value={
+            'output': [MyMeanAbsoluteError],
+            'output_1': [MyMeanAbsoluteError],
+        }),
+    dict(
+        testcase_name='dict_of_string',
+        value={
+            'output': 'mae',
+            'output_1': 'mae',
+        }),
+    dict(
+        testcase_name='dict_of_built_in_fn',
+        value={
+            'output': metrics.mae,
+            'output_1': metrics.mae,
+        }),
+    dict(
+        testcase_name='dict_of_built_in_class',
+        value={
+            'output': metrics.MeanAbsoluteError,
+            'output_1': metrics.MeanAbsoluteError,
+        }),
+    dict(
+        testcase_name='dict_of_custom_fn',
+        value={
+            'output': _my_mae,
+            'output_1': _my_mae
+        }),
+    dict(
+        testcase_name='dict_of_custom_class',
+        value={
+            'output': MyMeanAbsoluteError,
+            'output_1': MyMeanAbsoluteError,
+        }),
+)
+class MetricsSerialization(keras_parameterized.TestCase):
+
+  def setUp(self):
+    super(MetricsSerialization, self).setUp()
+    tmpdir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, tmpdir)
+    self.model_filename = os.path.join(tmpdir, 'tmp_model_metric.h5')
+    self.x = np.array([[0.], [1.], [2.]], dtype='float32')
+    self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
+    self.w = np.array([1.25, 0.5, 1.25], dtype='float32')
+
+  def test_serializing_model_with_metric_with_custom_object_scope(self, value):
+
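+    # Metric classes must be instantiated before `compile`; strings and
+    # plain functions pass through unchanged.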
+    def get_instance(x):
+      if isinstance(x, str):
+        return x
+      if isinstance(x, type) and issubclass(x, metrics.Metric):
+        return x()
+      return x
+
+    metric_input = nest.map_structure(get_instance, value)
+    weighted_metric_input = nest.map_structure(get_instance, value)
+
+    with generic_utils.custom_object_scope({
+        'MyMeanAbsoluteError': MyMeanAbsoluteError,
+        '_my_mae': _my_mae,
+        'Bias': testing_utils.Bias,
+    }):
+      model = _get_multi_io_model()
+      model.compile(
+          optimizer_v2.gradient_descent.SGD(0.1),
+          'mae',
+          metrics=metric_input,
+          weighted_metrics=weighted_metric_input,
+          run_eagerly=testing_utils.should_run_eagerly(),
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
+      history = model.fit([self.x, self.x], [self.y, self.y],
+                          batch_size=3,
+                          epochs=3,
+                          sample_weight=[self.w, self.w])
+
+      # Assert training.
+      self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3)
+      eval_results = model.evaluate([self.x, self.x], [self.y, self.y],
+                                    sample_weight=[self.w, self.w])
+
+      if h5py is None:
+        return
+      model.save(self.model_filename)
+      loaded_model = keras.models.load_model(self.model_filename)
+      loaded_model.predict([self.x, self.x])
+      loaded_eval_results = loaded_model.evaluate(
+          [self.x, self.x], [self.y, self.y], sample_weight=[self.w, self.w])
+
+      # Assert all evaluation results are the same.
+      self.assertAllClose(eval_results, loaded_eval_results, 1e-9)
+
+  def test_serializing_model_with_metric_with_custom_objects(self, value):
+
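+    # Metric classes must be instantiated before `compile`; strings and
+    # plain functions pass through unchanged.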
+    def get_instance(x):
+      if isinstance(x, str):
+        return x
+      if isinstance(x, type) and issubclass(x, metrics.Metric):
+        return x()
+      return x
+
+    metric_input = nest.map_structure(get_instance, value)
+    weighted_metric_input = nest.map_structure(get_instance, value)
+
+    model = _get_multi_io_model()
+    model.compile(
+        optimizer_v2.gradient_descent.SGD(0.1),
+        'mae',
+        metrics=metric_input,
+        weighted_metrics=weighted_metric_input,
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    history = model.fit([self.x, self.x], [self.y, self.y],
+                        batch_size=3,
+                        epochs=3,
+                        sample_weight=[self.w, self.w])
+
+    # Assert training.
+    self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3)
+    eval_results = model.evaluate([self.x, self.x], [self.y, self.y],
+                                  sample_weight=[self.w, self.w])
+
+    if h5py is None:
+      return
+    model.save(self.model_filename)
+    loaded_model = keras.models.load_model(
+        self.model_filename,
+        custom_objects={
+            'MyMeanAbsoluteError': MyMeanAbsoluteError,
+            '_my_mae': _my_mae,
+            'Bias': testing_utils.Bias,
+        })
+    loaded_model.predict([self.x, self.x])
+    loaded_eval_results = loaded_model.evaluate([self.x, self.x],
+                                                [self.y, self.y],
+                                                sample_weight=[self.w, self.w])
+
+    # Assert all evaluation results are the same.
+    self.assertAllClose(eval_results, loaded_eval_results, 1e-9)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 2c48434..4ee32ee 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -605,6 +605,16 @@
   raise ValueError('Unknown model type {}'.format(model_type))
 
 
+class Bias(keras.layers.Layer):
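+  """Layer that adds a trainable bias to its inputs."""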
+
+  def build(self, input_shape):
+    self.bias = self.add_variable('bias', (1,), initializer='zeros')
+
+  def call(self, inputs):
+    return inputs + self.bias
+
+
 class _MultiIOSubclassModel(keras.Model):
   """Multi IO Keras subclass model."""