Add usage examples to metric functions and improve some API docs.
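
For reference, the `Usage:` snippets added below are doctest-style and run
directly under TF 2.x eager execution. A minimal standalone sketch that
mirrors them (illustrative only; assumes a TensorFlow 2.x install):

    import tensorflow as tf

    # Stateful class form: state is accumulated across update_state() calls.
    m = tf.keras.metrics.Accuracy()
    m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])
    print(m.result().numpy())  # 0.75

    # Stateless function form: returns per-sample metric values.
    acc = tf.keras.metrics.categorical_accuracy(
        [[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0.0]])
    print(acc.numpy())  # [0. 1.]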

PiperOrigin-RevId: 303200932
Change-Id: I18ac3d545b32975e85f3532f4c60e59e9bb05548
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 4333ff7..5cbd59c 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # pylint: disable=unused-import
+# pylint: disable=g-classes-have-attributes
 """Built-in metrics.
 """
 from __future__ import absolute_import
@@ -77,6 +78,11 @@
 class Metric(base_layer.Layer):
   """Encapsulates metric logic and state.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    **kwargs: Additional layer keyword arguments.
+
   Usage:
 
   ```python
@@ -291,16 +297,15 @@
 
 
 class Reduce(Metric):
-  """Encapsulates metrics that perform a reduce operation on the values."""
+  """Encapsulates metrics that perform a reduce operation on the values.
+
+  Args:
+    reduction: a `tf.keras.metrics.Reduction` enum value.
+    name: string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+  """
 
   def __init__(self, reduction, name, dtype=None):
-    """Creates a `Reduce` instance.
-
-    Args:
-      reduction: a `tf.keras.metrics.Reduction` enum value.
-      name: string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(Reduce, self).__init__(name=name, dtype=dtype)
     self.reduction = reduction
     with ops.init_scope():
@@ -312,11 +317,7 @@
             'count', initializer=init_ops.zeros_initializer)
 
   def update_state(self, values, sample_weight=None):
-    """Accumulates statistics for computing the reduction metric.
-
-    For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE,
-    then the value of `result()` is 4. If the `sample_weight` is specified as
-    [1, 1, 0, 0] then value of `result()` would be 2.
+    """Accumulates statistics for computing the metric.
 
     Args:
       values: Per-example value.
@@ -399,6 +400,10 @@
   If `sample_weight` is `None`, weights default to 1.  Use `sample_weight` of 0
   to mask values.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.Sum()
@@ -416,12 +421,6 @@
   """
 
   def __init__(self, name='sum', dtype=None):
-    """Creates a `Sum` instance.
-
-    Args:
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
                               name=name, dtype=dtype)
 
@@ -440,6 +439,10 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.Mean()
@@ -461,12 +464,6 @@
   """
 
   def __init__(self, name='mean', dtype=None):
-    """Creates a `Mean` instance.
-
-    Args:
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(Mean, self).__init__(
         reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype)
 
@@ -483,6 +480,11 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    normalizer: The normalizer values with the same shape as predictions.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3])
@@ -506,13 +508,6 @@
   """
 
   def __init__(self, normalizer, name=None, dtype=None):
-    """Creates a `MeanRelativeError` instance.
-
-    Args:
-      normalizer: The normalizer values with same shape as predictions.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(MeanRelativeError, self).__init__(name=name, dtype=dtype)
     normalizer = math_ops.cast(normalizer, self._dtype)
     self.normalizer = normalizer
@@ -555,18 +550,17 @@
 
 
 class MeanMetricWrapper(Mean):
-  """Wraps a stateless metric function with the Mean metric."""
+  """Wraps a stateless metric function with the Mean metric.
+
+  Args:
+    fn: The metric function to wrap, with signature `fn(y_true, y_pred,
+      **kwargs)`.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    **kwargs: The keyword arguments that are passed on to `fn`.
+  """
 
   def __init__(self, fn, name=None, dtype=None, **kwargs):
-    """Creates a `MeanMetricWrapper` instance.
-
-    Args:
-      fn: The metric function to wrap, with signature
-        `fn(y_true, y_pred, **kwargs)`.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-      **kwargs: The keyword arguments that are passed on to `fn`.
-    """
     super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
     self._fn = fn
     self._fn_kwargs = kwargs
@@ -640,6 +634,10 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.Accuracy()
@@ -677,6 +675,12 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    threshold: (Optional) Float representing the threshold for deciding
+      whether prediction values are 1 or 0.
+
   Usage:
 
   >>> m = tf.keras.metrics.BinaryAccuracy()
@@ -699,14 +703,6 @@
   """
 
   def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
-    """Creates a `BinaryAccuracy` instance.
-
-    Args:
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-      threshold: (Optional) Float representing the threshold for deciding
-      whether prediction values are 1 or 0.
-    """
     super(BinaryAccuracy, self).__init__(
         binary_accuracy, name, dtype=dtype, threshold=threshold)
 
@@ -729,6 +725,10 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.CategoricalAccuracy()
@@ -756,12 +756,6 @@
   """
 
   def __init__(self, name='categorical_accuracy', dtype=None):
-    """Creates a `CategoricalAccuracy` instance.
-
-    Args:
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(CategoricalAccuracy, self).__init__(
         categorical_accuracy, name, dtype=dtype)
 
@@ -773,7 +767,7 @@
   ```python
   acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1))
   ```
-  
+
   You can provide logits of classes as `y_pred`, since argmax of
   logits and probabilities are same.
 
@@ -785,6 +779,10 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.SparseCategoricalAccuracy()
@@ -818,6 +816,12 @@
 class TopKCategoricalAccuracy(MeanMetricWrapper):
   """Computes how often targets are in the top `K` predictions.
 
+  Args:
+    k: (Optional) Number of top elements to look at for computing accuracy.
+      Defaults to 5.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
@@ -842,14 +846,6 @@
   """
 
   def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
-    """Creates a `TopKCategoricalAccuracy` instance.
-
-    Args:
-      k: (Optional) Number of top elements to look at for computing accuracy.
-        Defaults to 5.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(TopKCategoricalAccuracy, self).__init__(
         top_k_categorical_accuracy, name, dtype=dtype, k=k)
 
@@ -858,6 +854,12 @@
 class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
   """Computes how often integer targets are in the top `K` predictions.
 
+  Args:
+    k: (Optional) Number of top elements to look at for computing accuracy.
+      Defaults to 5.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
@@ -882,38 +884,29 @@
   """
 
   def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
-    """Creates a `SparseTopKCategoricalAccuracy` instance.
-
-    Args:
-      k: (Optional) Number of top elements to look at for computing accuracy.
-        Defaults to 5.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(SparseTopKCategoricalAccuracy, self).__init__(
         sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k)
 
 
 class _ConfusionMatrixConditionCount(Metric):
-  """Calculates the number of the given confusion matrix condition."""
+  """Calculates the number of the given confusion matrix condition.
+
+  Args:
+    confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
+    thresholds: (Optional) Defaults to 0.5. A float value or a python
+      list/tuple of float threshold values in [0, 1]. A threshold is compared
+      with prediction values to determine the truth value of predictions
+      (i.e., above the threshold is `true`, below is `false`). One metric
+      value is generated for each threshold value.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+  """
 
   def __init__(self,
                confusion_matrix_cond,
                thresholds=None,
                name=None,
                dtype=None):
-    """Creates a `_ConfusionMatrixConditionCount` instance.
-
-    Args:
-      confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
-      thresholds: (Optional) Defaults to 0.5. A float value or a python
-        list/tuple of float threshold values in [0, 1]. A threshold is compared
-        with prediction values to determine the truth value of predictions
-        (i.e., above the threshold is `true`, below is `false`). One metric
-        value is generated for each threshold value.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
     self._confusion_matrix_cond = confusion_matrix_cond
     self.init_thresholds = thresholds
@@ -925,7 +918,7 @@
         initializer=init_ops.zeros_initializer)
 
   def update_state(self, y_true, y_pred, sample_weight=None):
-    """Accumulates the given confusion matrix condition statistics.
+    """Accumulates the metric statistics.
 
     Args:
       y_true: The ground truth values.
@@ -973,6 +966,15 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    thresholds: (Optional) Defaults to 0.5. A float value or a python
+      list/tuple of float threshold values in [0, 1]. A threshold is compared
+      with prediction values to determine the truth value of predictions
+      (i.e., above the threshold is `true`, below is `false`). One metric
+      value is generated for each threshold value.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.FalsePositives()
@@ -994,17 +996,6 @@
   """
 
   def __init__(self, thresholds=None, name=None, dtype=None):
-    """Creates a `FalsePositives` instance.
-
-    Args:
-      thresholds: (Optional) Defaults to 0.5. A float value or a python
-        list/tuple of float threshold values in [0, 1]. A threshold is compared
-        with prediction values to determine the truth value of predictions
-        (i.e., above the threshold is `true`, below is `false`). One metric
-        value is generated for each threshold value.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(FalsePositives, self).__init__(
         confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
         thresholds=thresholds,
@@ -1023,6 +1014,15 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    thresholds: (Optional) Defaults to 0.5. A float value or a python
+      list/tuple of float threshold values in [0, 1]. A threshold is compared
+      with prediction values to determine the truth value of predictions
+      (i.e., above the threshold is `true`, below is `false`). One metric
+      value is generated for each threshold value.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.FalseNegatives()
@@ -1044,17 +1044,6 @@
   """
 
   def __init__(self, thresholds=None, name=None, dtype=None):
-    """Creates a `FalseNegatives` instance.
-
-    Args:
-      thresholds: (Optional) Defaults to 0.5. A float value or a python
-        list/tuple of float threshold values in [0, 1]. A threshold is compared
-        with prediction values to determine the truth value of predictions
-        (i.e., above the threshold is `true`, below is `false`). One metric
-        value is generated for each threshold value.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(FalseNegatives, self).__init__(
         confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
         thresholds=thresholds,
@@ -1073,6 +1062,15 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    thresholds: (Optional) Defaults to 0.5. A float value or a python
+      list/tuple of float threshold values in [0, 1]. A threshold is compared
+      with prediction values to determine the truth value of predictions
+      (i.e., above the threshold is `true`, below is `false`). One metric
+      value is generated for each threshold value.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.TrueNegatives()
@@ -1094,17 +1092,6 @@
   """
 
   def __init__(self, thresholds=None, name=None, dtype=None):
-    """Creates a `TrueNegatives` instance.
-
-    Args:
-      thresholds: (Optional) Defaults to 0.5. A float value or a python
-        list/tuple of float threshold values in [0, 1]. A threshold is compared
-        with prediction values to determine the truth value of predictions
-        (i.e., above the threshold is `true`, below is `false`). One metric
-        value is generated for each threshold value.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(TrueNegatives, self).__init__(
         confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
         thresholds=thresholds,
@@ -1123,6 +1110,15 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    thresholds: (Optional) Defaults to 0.5. A float value or a python
+      list/tuple of float threshold values in [0, 1]. A threshold is compared
+      with prediction values to determine the truth value of predictions
+      (i.e., above the threshold is `true`, below is `false`). One metric
+      value is generated for each threshold value.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.TruePositives()
@@ -1144,17 +1140,6 @@
   """
 
   def __init__(self, thresholds=None, name=None, dtype=None):
-    """Creates a `TruePositives` instance.
-
-    Args:
-      thresholds: (Optional) Defaults to 0.5. A float value or a python
-        list/tuple of float threshold values in [0, 1]. A threshold is compared
-        with prediction values to determine the truth value of predictions
-        (i.e., above the threshold is `true`, below is `false`). One metric
-        value is generated for each threshold value.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(TruePositives, self).__init__(
         confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
         thresholds=thresholds,
@@ -1183,6 +1168,21 @@
   top-k highest predictions, and computing the fraction of them for which
   `class_id` is indeed a correct label.
 
+  Args:
+    thresholds: (Optional) A float value or a python list/tuple of float
+      threshold values in [0, 1]. A threshold is compared with prediction
+      values to determine the truth value of predictions (i.e., above the
+      threshold is `true`, below is `false`). One metric value is generated
+      for each threshold value. If neither thresholds nor top_k are set, the
+      default is to calculate precision with `thresholds=0.5`.
+    top_k: (Optional) Unset by default. An int value specifying the top-k
+      predictions to consider when calculating precision.
+    class_id: (Optional) Integer class ID for which we want binary metrics.
+      This must be in the half-open interval `[0, num_classes)`, where
+      `num_classes` is the last dimension of predictions.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.Precision()
@@ -1221,23 +1221,6 @@
                class_id=None,
                name=None,
                dtype=None):
-    """Creates a `Precision` instance.
-
-    Args:
-      thresholds: (Optional) A float value or a python list/tuple of float
-        threshold values in [0, 1]. A threshold is compared with prediction
-        values to determine the truth value of predictions (i.e., above the
-        threshold is `true`, below is `false`). One metric value is generated
-        for each threshold value. If neither thresholds nor top_k are set, the
-        default is to calculate precision with `thresholds=0.5`.
-      top_k: (Optional) Unset by default. An int value specifying the top-k
-        predictions to consider when calculating precision.
-      class_id: (Optional) Integer class ID for which we want binary metrics.
-        This must be in the half-open interval `[0, num_classes)`, where
-        `num_classes` is the last dimension of predictions.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(Precision, self).__init__(name=name, dtype=dtype)
     self.init_thresholds = thresholds
     self.top_k = top_k
@@ -1321,6 +1304,21 @@
   fraction of them for which `class_id` is above the threshold and/or in the
   top-k predictions.
 
+  Args:
+    thresholds: (Optional) A float value or a python list/tuple of float
+      threshold values in [0, 1]. A threshold is compared with prediction
+      values to determine the truth value of predictions (i.e., above the
+      threshold is `true`, below is `false`). One metric value is generated
+      for each threshold value. If neither thresholds nor top_k are set, the
+      default is to calculate recall with `thresholds=0.5`.
+    top_k: (Optional) Unset by default. An int value specifying the top-k
+      predictions to consider when calculating recall.
+    class_id: (Optional) Integer class ID for which we want binary metrics.
+      This must be in the half-open interval `[0, num_classes)`, where
+      `num_classes` is the last dimension of predictions.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.Recall()
@@ -1347,23 +1345,6 @@
                class_id=None,
                name=None,
                dtype=None):
-    """Creates a `Recall` instance.
-
-    Args:
-      thresholds: (Optional) A float value or a python list/tuple of float
-        threshold values in [0, 1]. A threshold is compared with prediction
-        values to determine the truth value of predictions (i.e., above the
-        threshold is `true`, below is `false`). One metric value is generated
-        for each threshold value. If neither thresholds nor top_k are set, the
-        default is to calculate recall with `thresholds=0.5`.
-      top_k: (Optional) Unset by default. An int value specifying the top-k
-        predictions to consider when calculating recall.
-      class_id: (Optional) Integer class ID for which we want binary metrics.
-        This must be in the half-open interval `[0, num_classes)`, where
-        `num_classes` is the last dimension of predictions.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(Recall, self).__init__(name=name, dtype=dtype)
     self.init_thresholds = thresholds
     self.top_k = top_k
@@ -1541,6 +1522,13 @@
   For additional information about specificity and sensitivity, see the
   following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
 
+  Args:
+    specificity: A scalar value in range `[0, 1]`.
+    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+      use for matching the given specificity.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5)
@@ -1566,15 +1554,6 @@
   """
 
   def __init__(self, specificity, num_thresholds=200, name=None, dtype=None):
-    """Creates a `SensitivityAtSpecificity` instance.
-
-    Args:
-      specificity: A scalar value in range `[0, 1]`.
-      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
-        use for matching the given specificity.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     if specificity < 0 or specificity > 1:
       raise ValueError('`specificity` must be in the range [0, 1].')
     self.specificity = specificity
@@ -1619,6 +1598,13 @@
   For additional information about specificity and sensitivity, see the
   following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
 
+  Args:
+    sensitivity: A scalar value in range `[0, 1]`.
+    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+      use for matching the given sensitivity.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5)
@@ -1644,15 +1630,6 @@
   """
 
   def __init__(self, sensitivity, num_thresholds=200, name=None, dtype=None):
-    """Creates a `SpecificityAtSensitivity` instance.
-
-    Args:
-      sensitivity: A scalar value in range `[0, 1]`.
-      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
-        use for matching the given sensitivity.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     if sensitivity < 0 or sensitivity > 1:
       raise ValueError('`sensitivity` must be in the range [0, 1].')
     self.sensitivity = sensitivity
@@ -1689,6 +1666,13 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    recall: A scalar value in range `[0, 1]`.
+    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+      use for matching the given recall.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.PrecisionAtRecall(0.5)
@@ -1714,15 +1698,6 @@
   """
 
   def __init__(self, recall, num_thresholds=200, name=None, dtype=None):
-    """Creates a `PrecisionAtRecall` instance.
-
-    Args:
-      recall: A scalar value in range `[0, 1]`.
-      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
-        use for matching the given recall.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     if recall < 0 or recall > 1:
       raise ValueError('`recall` must be in the range [0, 1].')
     self.recall = recall
@@ -1762,6 +1737,13 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    precision: A scalar value in range `[0, 1]`.
+    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+      use for matching the given precision.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.RecallAtPrecision(0.8)
@@ -1787,15 +1769,6 @@
   """
 
   def __init__(self, precision, num_thresholds=200, name=None, dtype=None):
-    """Creates a `RecallAtPrecision` instance.
-
-    Args:
-      precision: A scalar value in range `[0, 1]`.
-      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
-        use for matching the given precision.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     if precision < 0 or precision > 1:
       raise ValueError('`precision` must be in the range [0, 1].')
     self.precision = precision
@@ -1850,6 +1823,44 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+      use when discretizing the roc curve. Values must be > 1.
+    curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
+      [default] or 'PR' for the Precision-Recall-curve.
+    summation_method: (Optional) Specifies the Riemann summation method used
+      (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default],
+        applies mid-point summation scheme for `ROC`. For PR-AUC, interpolates
+        (true/false) positives but not the ratio that is precision (see Davis
+        & Goadrich 2006 for details); 'minoring' that applies left summation
+        for increasing intervals and right summation for decreasing intervals;
+        'majoring' that does the opposite.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    thresholds: (Optional) A list of floating point values to use as the
+      thresholds for discretizing the curve. If set, the `num_thresholds`
+      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
+      be automatically included with these to correctly handle predictions
+      equal to exactly 0 or 1.
+    multi_label: boolean indicating whether multilabel data should be
+      treated as such, wherein AUC is computed separately for each label and
+      then averaged across labels, or (when False) if the data should be
+      flattened into a single label before AUC computation. In the latter
+      case, when multilabel data is passed to AUC, each label-prediction pair
+      is treated as an individual data point. Should be set to False for
+      multi-class data.
+    label_weights: (Optional) list, array, or tensor of non-negative weights
+      used to compute AUCs for multilabel data. When `multi_label` is True,
+      the weights are applied to the individual label AUCs when they are
+      averaged to produce the multi-label AUC. When it's False, they are used
+      to weight the individual label predictions in computing the confusion
+      matrix on the flattened data. Note that this is unlike class_weights in
+      that class_weights weights the example depending on the value of its
+      label, whereas label_weights depends only on the index of that label
+      before flattening; therefore `label_weights` should not be used for
+      multi-class data.
+
   Usage:
 
   >>> m = tf.keras.metrics.AUC(num_thresholds=3)
@@ -1884,46 +1895,6 @@
                thresholds=None,
                multi_label=False,
                label_weights=None):
-    """Creates an `AUC` instance.
-
-    Args:
-      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
-        use when discretizing the roc curve. Values must be > 1.
-      curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
-        [default] or 'PR' for the Precision-Recall-curve.
-      summation_method: (Optional) Specifies the Riemann summation method used
-        (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default],
-          applies mid-point summation scheme for `ROC`. For PR-AUC, interpolates
-          (true/false) positives but not the ratio that is precision (see Davis
-          & Goadrich 2006 for details); 'minoring' that applies left summation
-          for increasing intervals and right summation for decreasing intervals;
-          'majoring' that does the opposite.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-      thresholds: (Optional) A list of floating point values to use as the
-        thresholds for discretizing the curve. If set, the `num_thresholds`
-        parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
-        equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
-        be automatically included with these to correctly handle predictions
-        equal to exactly 0 or 1.
-      multi_label: boolean indicating whether multilabel data should be
-        treated as such, wherein AUC is computed separately for each label and
-        then averaged across labels, or (when False) if the data should be
-        flattened into a single label before AUC computation. In the latter
-        case, when multilabel data is passed to AUC, each label-prediction pair
-        is treated as an individual data point. Should be set to False for
-        multi-class data.
-      label_weights: (optional) list, array, or tensor of non-negative weights
-        used to compute AUCs for multilabel data. When `multi_label` is True,
-        the weights are applied to the individual label AUCs when they are
-        averaged to produce the multi-label AUC. When it's False, they are used
-        to weight the individual label predictions in computing the confusion
-        matrix on the flattened data. Note that this is unlike class_weights in
-        that class_weights weights the example depending on the value of its
-        label, whereas label_weights depends only on the index of that label
-        before flattening; therefore `label_weights` should not be used for
-        multi-class data.
-    """
     # Validate configurations.
     if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
         metrics_utils.AUCCurve):
@@ -2262,6 +2233,12 @@
   This metric keeps the average cosine similarity between `predictions` and
   `labels` over a stream of data.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    axis: (Optional) Defaults to -1. The dimension along which the cosine
+      similarity is computed.
+
   Usage:
 
   >>> # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]]
@@ -2292,14 +2269,6 @@
   """
 
   def __init__(self, name='cosine_similarity', dtype=None, axis=-1):
-    """Creates a `CosineSimilarity` instance.
-
-    Args:
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-      axis: (Optional) Defaults to -1. The dimension along which the cosine
-        similarity is computed.
-    """
     super(CosineSimilarity, self).__init__(
         cosine_similarity, name, dtype=dtype, axis=axis)
 
@@ -2308,6 +2277,10 @@
 class MeanAbsoluteError(MeanMetricWrapper):
   """Computes the mean absolute error between the labels and predictions.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.MeanAbsoluteError()
@@ -2339,6 +2312,10 @@
 class MeanAbsolutePercentageError(MeanMetricWrapper):
   """Computes the mean absolute percentage error between `y_true` and `y_pred`.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.MeanAbsolutePercentageError()
@@ -2372,6 +2349,10 @@
 class MeanSquaredError(MeanMetricWrapper):
   """Computes the mean squared error between `y_true` and `y_pred`.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.MeanSquaredError()
@@ -2403,6 +2384,10 @@
 class MeanSquaredLogarithmicError(MeanMetricWrapper):
   """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.MeanSquaredLogarithmicError()
@@ -2439,6 +2424,10 @@
   `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
   provided we will convert them to -1 or 1.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.Hinge()
@@ -2471,6 +2460,10 @@
   `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
   provided we will convert them to -1 or 1.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.SquaredHinge()
@@ -2503,6 +2496,10 @@
 class CategoricalHinge(MeanMetricWrapper):
   """Computes the categorical hinge metric between `y_true` and `y_pred`.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.CategoricalHinge()
@@ -2593,6 +2590,10 @@
 
   `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true)
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.LogCoshError()
@@ -2624,6 +2625,10 @@
 
   `metric = y_pred - y_true * log(y_pred)`
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.Poisson()
@@ -2655,6 +2660,10 @@
 
   `metric = y_true * log(y_true / y_pred)`
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.KLDivergence()
@@ -2695,6 +2704,13 @@
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
 
+  Args:
+    num_classes: The possible number of labels the prediction task can have.
+      This value must be provided, since a confusion matrix of dimension =
+      [num_classes, num_classes] will be allocated.
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> # cm = [[1, 1],
@@ -2725,15 +2741,6 @@
   """
 
   def __init__(self, num_classes, name=None, dtype=None):
-    """Creates a `MeanIoU` instance.
-
-    Args:
-      num_classes: The possible number of labels the prediction task can have.
-        This value must be provided, since a confusion matrix of dimension =
-        [num_classes, num_classes] will be allocated.
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(MeanIoU, self).__init__(name=name, dtype=dtype)
     self.num_classes = num_classes
 
@@ -2825,6 +2832,10 @@
   `total` tracks the sum of the weighted values, and `count` stores the sum of
   the weighted counts.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+
   Usage:
 
   >>> m = tf.keras.metrics.MeanTensor()
@@ -2839,12 +2850,6 @@
   """
 
   def __init__(self, name='mean_tensor', dtype=None):
-    """Creates a `MeanTensor` instance.
-
-    Args:
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-    """
     super(MeanTensor, self).__init__(name=name, dtype=dtype)
     self._shape = None
     self._total = None
@@ -2936,6 +2941,16 @@
   This is the crossentropy metric class to be used when there are only two
   label classes (0 and 1).
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    from_logits: (Optional) Whether output is expected to be a logits tensor.
+      By default, we consider that output encodes a probability distribution.
+    label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
+      smoothed, meaning the confidence on label values is relaxed, e.g.
+      `label_smoothing=0.2` means that we will use a value of `0.1` for
+      label `0` and `0.9` for label `1`.
+
   Usage:
 
   >>> m = tf.keras.metrics.BinaryCrossentropy()
@@ -2965,19 +2980,6 @@
                dtype=None,
                from_logits=False,
                label_smoothing=0):
-    """Creates a `BinaryCrossentropy` instance.
-
-    Args:
-      name: (Optional) string name of the metric instance.
-      dtype: (Optional) data type of the metric result.
-      from_logits: (Optional )Whether output is expected to be a logits tensor.
-        By default, we consider that output encodes a probability distribution.
-      label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
-        smoothed, meaning the confidence on label values are relaxed.
-        e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for
-        label `0` and `0.9` for label `1`"
-    """
-
     super(BinaryCrossentropy, self).__init__(
         binary_crossentropy,
         name,
@@ -2995,6 +2997,16 @@
   representation. eg., When labels values are [2, 0, 1],
    `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]].
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    from_logits: (Optional) Whether output is expected to be a logits tensor.
+      By default, we consider that output encodes a probability distribution.
+    label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
+      smoothed, meaning the confidence on label values is relaxed, e.g.
+      `label_smoothing=0.2` means that we will use a value of `0.1` for label
+      `0` and `0.9` for label `1`.
+
   Usage:
 
   >>> # EPSILON = 1e-7, y = y_true, y` = y_pred
@@ -3026,16 +3038,6 @@
     loss='mse',
     metrics=[tf.keras.metrics.CategoricalCrossentropy()])
   ```
-
-  Args:
-    name: (Optional) string name of the metric instance.
-    dtype: (Optional) data type of the metric result.
-    from_logits: (Optional ) Whether `y_pred` is expected to be a logits tensor.
-      By default, we assume that `y_pred` encodes a probability distribution.
-    label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
-      meaning the confidence on label values are relaxed. e.g.
-      `label_smoothing=0.2` means that we will use a value of `0.1` for label
-      `0` and `0.9` for label `1`"
   """
 
   def __init__(self,
@@ -3043,7 +3045,6 @@
                dtype=None,
                from_logits=False,
                label_smoothing=0):
-
     super(CategoricalCrossentropy, self).__init__(
         categorical_crossentropy,
         name,
@@ -3067,6 +3068,14 @@
   The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
   `[batch_size, num_classes]`.
 
+  Args:
+    name: (Optional) string name of the metric instance.
+    dtype: (Optional) data type of the metric result.
+    from_logits: (Optional) Whether output is expected to be a logits tensor.
+      By default, we consider that output encodes a probability distribution.
+    axis: (Optional) Defaults to -1. The dimension along which the metric is
+      computed.
+
   Usage:
 
   >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
@@ -3101,14 +3110,6 @@
     loss='mse',
     metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()])
   ```
-
-  Args:
-    name: (Optional) string name of the metric instance.
-    dtype: (Optional) data type of the metric result.
-    from_logits: (Optional ) Whether `y_pred` is expected to be a logits tensor.
-      By default, we assume that `y_pred` encodes a probability distribution.
-    axis: (Optional) Defaults to -1. The dimension along which the metric is
-      computed.
   """
 
   def __init__(self,
@@ -3116,7 +3117,6 @@
                dtype=None,
                from_logits=False,
                axis=-1):
-
     super(SparseCategoricalCrossentropy, self).__init__(
         sparse_categorical_crossentropy,
         name,
@@ -3196,6 +3196,14 @@
 def binary_accuracy(y_true, y_pred, threshold=0.5):
   """Calculates how often predictions matches binary labels.
 
+  Usage:
+  >>> y_true = [[1], [1], [0], [0]]
+  >>> y_pred = [[1], [1], [0], [0]]
+  >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred)
+  >>> assert m.shape == (4,)
+  >>> m.numpy()
+  array([1., 1., 1., 1.], dtype=float32)
+
   Args:
     y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
     y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
@@ -3205,6 +3213,7 @@
   Returns:
     Binary accuracy values. shape = `[batch_size, d0, .. dN-1]`
   """
+  y_pred = ops.convert_to_tensor_v2(y_pred)
   threshold = math_ops.cast(threshold, y_pred.dtype)
   y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype)
   return K.mean(math_ops.equal(y_true, y_pred), axis=-1)
@@ -3214,6 +3223,14 @@
 def categorical_accuracy(y_true, y_pred):
   """Calculates how often predictions matches one-hot labels.
 
+  Usage:
+  >>> y_true = [[0, 0, 1], [0, 1, 0]]
+  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
+  >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred)
+  >>> assert m.shape == (2,)
+  >>> m.numpy()
+  array([0., 1.], dtype=float32)
+
   You can provide logits of classes as `y_pred`, since argmax of
   logits and probabilities are same.
 
@@ -3234,6 +3251,14 @@
 def sparse_categorical_accuracy(y_true, y_pred):
   """Calculates how often predictions matches integer labels.
 
+  Usage:
+  >>> y_true = [2, 1]
+  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
+  >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
+  >>> assert m.shape == (2,)
+  >>> m.numpy()
+  array([0., 1.], dtype=float32)
+
   You can provide logits of classes as `y_pred`, since argmax of
   logits and probabilities are same.
 
@@ -3244,8 +3269,10 @@
   Returns:
     Sparse categorical accuracy values.
   """
-  y_pred_rank = ops.convert_to_tensor_v2(y_pred).shape.ndims
-  y_true_rank = ops.convert_to_tensor_v2(y_true).shape.ndims
+  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_true = ops.convert_to_tensor_v2(y_true)
+  y_pred_rank = y_pred.shape.ndims
+  y_true_rank = y_true.shape.ndims
   # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
   if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
       K.int_shape(y_true)) == len(K.int_shape(y_pred))):
@@ -3264,6 +3291,14 @@
 def top_k_categorical_accuracy(y_true, y_pred, k=5):
   """Computes how often targets are in the top `K` predictions.
 
+  Usage:
+  >>> y_true = [[0, 0, 1], [0, 1, 0]]
+  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
+  >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)
+  >>> assert m.shape == (2,)
+  >>> m.numpy()
+  array([1., 1.], dtype=float32)
+
   Args:
     y_true: The ground truth values.
     y_pred: The prediction values.
@@ -3281,6 +3316,15 @@
 def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
   """Computes how often integer targets are in the top `K` predictions.
 
+  Usage:
+  >>> y_true = [2, 1]
+  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
+  >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy(
+  ...     y_true, y_pred, k=3)
+  >>> assert m.shape == (2,)
+  >>> m.numpy()
+  array([1., 1.], dtype=float32)
+
   Args:
     y_true: tensor of true targets.
     y_pred: tensor of predicted targets.
@@ -3345,11 +3389,29 @@
 
 @keras_export('keras.metrics.serialize')
 def serialize(metric):
+  """Serializes metric function or `Metric` instance.
+
+  Arguments:
+    metric: A Keras `Metric` instance or a metric function.
+
+  Returns:
+    Metric configuration dictionary.
+  """
   return serialize_keras_object(metric)
 
 
 @keras_export('keras.metrics.deserialize')
 def deserialize(config, custom_objects=None):
+  """Deserializes a serialized metric class/function instance.
+
+  Arguments:
+    config: Metric configuration.
+    custom_objects: Optional dictionary mapping names (strings) to custom
+      objects (classes and functions) to be considered during deserialization.
+
+  Returns:
+    A Keras `Metric` instance or a metric function.
+  """
   return deserialize_keras_object(
       config,
       module_objects=globals(),
@@ -3359,7 +3421,38 @@
 
 @keras_export('keras.metrics.get')
 def get(identifier):
-  """Return a metric given its identifer."""
+  """Retrieves a Keras metric as a `function`/`Metric` class instance.
+
+  The `identifier` may be the string name of a metric function or class.
+
+  >>> metric = tf.keras.metrics.get("categorical_crossentropy")
+  >>> type(metric)
+  <class 'function'>
+  >>> metric = tf.keras.metrics.get("CategoricalCrossentropy")
+  >>> type(metric)
+  <class '...tensorflow.python.keras.metrics.CategoricalCrossentropy'>
+
+  You can also specify `config` of the metric to this function by passing a
+  dict containing `class_name` and `config` as an identifier. Also note that
+  the `class_name` must map to a `Metric` class.
+
+  >>> identifier = {"class_name": "CategoricalCrossentropy",
+  ...               "config": {"from_logits": True}}
+  >>> metric = tf.keras.metrics.get(identifier)
+  >>> type(metric)
+  <class '...tensorflow.python.keras.metrics.CategoricalCrossentropy'>
+
+  Arguments:
+    identifier: A metric identifier. One of `None`, a string name of a metric
+      function or class, a metric configuration dictionary, or a metric
+      function or `Metric` class instance.
+
+  Returns:
+    A Keras metric as a `function`/`Metric` class instance.
+
+  Raises:
+    ValueError: If `identifier` cannot be interpreted.
+  """
   if isinstance(identifier, dict):
     return deserialize(identifier)
   elif isinstance(identifier, six.string_types):