| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=protected-access |
| """Utils related to keras metrics. |
| """ |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import functools |
| import weakref |
| |
| from enum import Enum |
| |
| from tensorflow.python.distribute import distribution_strategy_context |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.keras.utils import tf_utils |
| from tensorflow.python.keras.utils.generic_utils import to_list |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import check_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import gen_math_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import nn_ops |
| from tensorflow.python.ops import weights_broadcast_ops |
| from tensorflow.python.ops.losses import util as tf_losses_utils |
| from tensorflow.python.ops.ragged import ragged_tensor |
| from tensorflow.python.ops.ragged import ragged_util |
| from tensorflow.python.tpu import tpu |
| from tensorflow.python.util import tf_decorator |
| |
# Large negative stand-in for -infinity. `_filter_top_k` writes it into every
# prediction that falls outside the top k, and it is also accepted as a
# threshold value when `top_k` is set.
NEG_INF = -1e10
| |
| |
class Reduction(Enum):
  """Ways of reducing weighted metric values to a single scalar.

  Members:

  * `SUM`: Scalar sum of the weighted values.
  * `SUM_OVER_BATCH_SIZE`: Scalar sum of the weighted values divided by the
    number of elements.
  * `WEIGHTED_MEAN`: Scalar sum of the weighted values divided by the sum of
    the weights.
  """
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
  WEIGHTED_MEAN = 'weighted_mean'
| |
| |
def update_state_wrapper(update_state_fn):
  """Decorator that registers the op returned by a metric's `update_state()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    Decorated function that calls `update_state_fn()` and passes any non-None
    result to the metric's `add_update()`.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Calls `update_state_fn` and forwards its op to `add_update()`."""
    current_strategy = distribution_strategy_context.get_strategy()
    # TODO(b/142574744): Remove this check if a better solution is found for
    # declaring keras Metric outside of TPUStrategy and then updating it per
    # replica.
    created_outside_scope = any(
        tpu.is_tpu_strategy(current_strategy) and
        not current_strategy.extended.variable_created_in_scope(variable) and
        not distribution_strategy_context.in_cross_replica_context()
        for variable in metric_obj.weights)
    if created_outside_scope:
      raise ValueError(
          'Trying to run metric.update_state in replica context when '
          'the metric was not created in TPUStrategy scope. '
          'Make sure the keras Metric is created in TPUstrategy scope. ')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
      state_update = update_state_fn(*args, **kwargs)
    # In eager execution there is no op to register: the result is None.
    if state_update is not None:
      metric_obj.add_update(state_update)
    return state_update

  return tf_decorator.make_decorator(update_state_fn, decorated)
| |
| |
def result_wrapper(result_fn):
  """Decorator that runs a metric's `result()` inside `merge_call()`.

  Computing the result is idempotent: it only derives the metric value from
  the state variables.

  When the state variables are distributed across replicas/devices and
  `result()` is requested from a replica context, `result_fn()` is executed
  through the distribution strategy's `merge_call()` so that the state
  variables are aggregated across devices.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """

  def decorated(metric_obj, *args):
    """Computes the result, going through `merge_call` when in a replica."""
    strategy_active = distribution_strategy_context.has_strategy()
    replica_ctx = distribution_strategy_context.get_replica_context()

    if strategy_active and replica_ctx is not None:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # `merge_call` passes the distribution object as the first argument of
      # the function it is given. This adapter absorbs that argument so that
      # the result function itself does not need to accept it.
      def run_in_cross_replica(distribution, per_replica_fn, *fn_args):
        # `per_replica_fn` arrives as a `PerReplica` value whose entries are
        # identical copies of the function passed below; use the first one.
        local_fns = distribution.experimental_local_results(per_replica_fn)
        outcome = local_fns[0](*fn_args)

        # The identity makes the control dependency between the update op
        # from `update_state` and the result work when the result is a
        # tensor.
        return array_ops.identity(outcome)

      # `merge_call` leaves replica mode and computes the value in cross
      # replica mode.
      result_t = replica_ctx.merge_call(
          run_in_cross_replica, args=(result_fn,) + args)
    else:
      result_t = array_ops.identity(result_fn(*args))

    # Keep the result op on the metric so the train/test execution functions
    # can use the op that was built with a control dependency on the updates.
    metric_obj._call_result = result_t
    return result_t

  return tf_decorator.make_decorator(result_fn, decorated)
| |
| |
def weakmethod(method):
  """Creates a callable that invokes a bound method via a weak reference.

  The returned function does not keep the method's instance alive: it stores
  only a `weakref.ref` to the instance and re-binds the underlying function to
  it on every call.

  Args:
    method: A bound method whose instance supports weak references.

  Returns:
    A function taking the same arguments as `method` (without `self`). Calling
    it after the instance has been garbage collected raises `ReferenceError`.
  """
  # `__self__` / `__func__` exist on bound methods in both Python 2.6+ and
  # Python 3, whereas the `im_self` / `im_func` / `im_class` spellings are
  # Python 2 only.
  func = method.__func__
  instance_ref = weakref.ref(method.__self__)
  cls = getattr(method, 'im_class', None)
  if cls is None:
    cls = type(method.__self__)

  # Wrap `func` rather than `method`: on Python 3 `functools.wraps` stores the
  # wrapped object in `__wrapped__`, and a bound method stored there would hold
  # a strong reference to the instance and defeat the weak reference.
  @functools.wraps(func)
  def inner(*args, **kwargs):
    instance = instance_ref()
    if instance is None:
      raise ReferenceError(
          'The instance bound to weakly-referenced method `{}` no longer '
          'exists.'.format(getattr(func, '__name__', func)))
    return func.__get__(instance, cls)(*args, **kwargs)

  del method
  return inner
| |
| |
def assert_thresholds_range(thresholds):
  """Checks that every threshold is a value inside `[0, 1]`.

  Args:
    thresholds: Iterable of threshold values, or `None` to skip the check.

  Raises:
    ValueError: If any threshold is `None`, smaller than 0 or greater than 1.
      The message lists all offending values.
  """
  if thresholds is None:
    return

  rejected = []
  for value in thresholds:
    if value is None or value < 0 or value > 1:
      rejected.append(value)

  if rejected:
    raise ValueError(
        'Threshold values must be in [0, 1]. Invalid values: {}'.format(
            rejected))
| |
| |
def parse_init_thresholds(thresholds, default_threshold=0.5):
  """Normalizes a metric's `thresholds` constructor argument.

  Args:
    thresholds: Threshold value(s) supplied by the caller, or `None`.
    default_threshold: Value used when `thresholds` is `None`.

  Returns:
    The result of `to_list` applied to `thresholds`, or to `default_threshold`
    when `thresholds` is `None`.

  Raises:
    ValueError: If `thresholds` is given and fails
      `assert_thresholds_range`.
  """
  if thresholds is None:
    # The default is used as-is; only caller-supplied values are validated.
    return to_list(default_threshold)

  parsed = to_list(thresholds)
  assert_thresholds_range(parsed)
  return parsed
| |
| |
class ConfusionMatrix(Enum):
  """Keys identifying the four confusion matrix counters.

  The members are the valid keys of the `variables_to_update` dictionary
  accepted by `update_confusion_matrix_variables`.
  """
  TRUE_POSITIVES = 'tp'
  FALSE_POSITIVES = 'fp'
  TRUE_NEGATIVES = 'tn'
  FALSE_NEGATIVES = 'fn'
| |
| |
class AUCCurve(Enum):
  """Kind of curve an AUC is computed over: ROC or PR."""
  ROC = 'ROC'
  PR = 'PR'

  @staticmethod
  def from_str(key):
    """Maps a curve name to its `AUCCurve` member.

    Args:
      key: One of 'pr', 'PR', 'roc' or 'ROC'.

    Returns:
      The matching `AUCCurve` member.

    Raises:
      ValueError: If `key` is none of the accepted spellings.
    """
    spellings = (
        (('pr', 'PR'), AUCCurve.PR),
        (('roc', 'ROC'), AUCCurve.ROC),
    )
    for accepted, curve in spellings:
      if key in accepted:
        return curve
    raise ValueError('Invalid AUC curve value "%s".' % key)
| |
| |
class AUCSummationMethod(Enum):
  """Riemann summation scheme used to compute an AUC.

  See https://en.wikipedia.org/wiki/Riemann_sum.

  Members:
  * `INTERPOLATION` ('interpolation'): Applies mid-point summation scheme for
    `ROC` curve. For `PR` curve, interpolates (true/false) positives but not
    the ratio that is precision (see Davis & Goadrich 2006 for details).
  * `MINORING` ('minoring'): Applies left summation for increasing intervals
    and right summation for decreasing intervals.
  * `MAJORING` ('majoring'): Applies right summation for increasing intervals
    and left summation for decreasing intervals.
  """
  INTERPOLATION = 'interpolation'
  MAJORING = 'majoring'
  MINORING = 'minoring'

  @staticmethod
  def from_str(key):
    """Maps a summation method name to its `AUCSummationMethod` member.

    Args:
      key: The method name, either all lower case or capitalized (for example
        'minoring' or 'Minoring').

    Returns:
      The matching `AUCSummationMethod` member.

    Raises:
      ValueError: If `key` is none of the accepted spellings.
    """
    spellings = (
        (('interpolation', 'Interpolation'),
         AUCSummationMethod.INTERPOLATION),
        (('majoring', 'Majoring'), AUCSummationMethod.MAJORING),
        (('minoring', 'Minoring'), AUCSummationMethod.MINORING),
    )
    for accepted, method in spellings:
      if key in accepted:
        return method
    raise ValueError('Invalid AUC summation method value "%s".' % key)
| |
| |
def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None):
  """Builds the op that accumulates confusion matrix counts into variables.

  Each (y_true, y_pred) pair is classified against a threshold as follows:

    true_positive: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positive: y_true == False and y_pred > thresholds

  The outcomes are weighted and summed. With several thresholds the same
  computation is repeated for every threshold.

  To estimate these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value or a python list or tuple of float thresholds in
      `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).

  Returns:
    Update op, or `None` when `variables_to_update` is `None`.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
      `variables_to_update` contains invalid keys, or if `label_weights` is
      given together with `multi_label=True`.
  """
  if multi_label and label_weights is not None:
    raise ValueError('`label_weights` for multilabel data should be handled '
                     'outside of `update_confusion_matrix_variables` when '
                     '`multi_label` is True.')
  if variables_to_update is None:
    return

  y_true = math_ops.cast(y_true, dtype=dtypes.float32)
  y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
  if multi_label:
    num_thresholds = array_ops.shape(thresholds)[0]
    # True when `thresholds` has rank 1, i.e. a single set of thresholds.
    one_thresh = math_ops.equal(
        math_ops.cast(1, dtype=dtypes.int32),
        array_ops.rank(thresholds),
        name='one_set_of_thresholds_cond')
  else:
    # Ragged inputs are replaced by their flat values; the returned mask is
    # not used here.
    flat_tensors, _ = ragged_assert_compatible_and_get_flat_values(
        [y_pred, y_true], sample_weight)
    y_pred, y_true = flat_tensors
    num_thresholds = len(to_list(thresholds))
    one_thresh = math_ops.cast(True, dtype=dtypes.bool)

  valid_keys = list(ConfusionMatrix)
  if not any(key in valid_keys for key in variables_to_update):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(valid_keys, variables_to_update.keys()))

  invalid_keys = [key for key in variables_to_update if key not in valid_keys]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, valid_keys))

  # Ops created under these dependencies only run once the predictions are
  # known to lie in [0, 1].
  prediction_range_checks = [
      check_ops.assert_greater_equal(
          y_pred,
          math_ops.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      check_ops.assert_less_equal(
          y_pred,
          math_ops.cast(1.0, dtype=y_pred.dtype),
          message='predictions must be <= 1')
  ]
  with ops.control_dependencies(prediction_range_checks):
    if sample_weight is None:
      y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
          y_pred, y_true)
    else:
      y_pred, y_true, sample_weight = (
          tf_losses_utils.squeeze_or_expand_dimensions(
              y_pred, y_true, sample_weight=sample_weight))
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  if top_k is not None:
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  pred_shape = array_ops.shape(y_pred)
  num_predictions = pred_shape[0]
  if y_pred.shape.ndims == 1:
    num_labels = 1
  else:
    num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0)
  thresh_label_tile = control_flow_ops.cond(
      one_thresh, lambda: num_labels,
      lambda: math_ops.cast(1, dtype=dtypes.int32))

  # Give predictions and labels a leading dim for the thresholds, and choose
  # the shapes used to tile thresholds against predictions and vice versa.
  if multi_label:
    predictions_extra_dim = array_ops.expand_dims(y_pred, 0)
    labels_extra_dim = array_ops.expand_dims(
        math_ops.cast(y_true, dtype=dtypes.bool), 0)
    thresh_pretile_shape = [num_thresholds, 1, -1]
    thresh_tiles = [1, num_predictions, thresh_label_tile]
    data_tiles = [num_thresholds, 1, 1]
  else:
    # Without multilabel handling predictions and labels are flattened.
    predictions_extra_dim = array_ops.reshape(y_pred, [1, -1])
    labels_extra_dim = array_ops.reshape(
        math_ops.cast(y_true, dtype=dtypes.bool), [1, -1])
    thresh_pretile_shape = [num_thresholds, -1]
    thresh_tiles = [1, num_predictions * num_labels]
    data_tiles = [num_thresholds, 1]

  # Repeat the thresholds for every prediction.
  thresh_values = array_ops.constant(thresholds, dtype=dtypes.float32)
  thresh_reshaped = array_ops.reshape(thresh_values, thresh_pretile_shape)
  thresh_tiled = array_ops.tile(thresh_reshaped, array_ops.stack(thresh_tiles))

  # Repeat the predictions for every threshold and compare the two.
  preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles)
  pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled)

  # Repeat the labels for every threshold.
  label_is_pos = array_ops.tile(labels_extra_dim, data_tiles)

  weights_tiled = None
  if sample_weight is not None:
    sample_weight = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred)
    weights_tiled = array_ops.tile(
        array_ops.reshape(sample_weight, thresh_tiles), data_tiles)

  if label_weights is not None and not multi_label:
    label_weights = array_ops.expand_dims(label_weights, 0)
    label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
                                                            y_pred)
    label_weights_tiled = array_ops.tile(
        array_ops.reshape(label_weights, thresh_tiles), data_tiles)
    if weights_tiled is None:
      weights_tiled = label_weights_tiled
    else:
      weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled)

  def weighted_assign_add(label, pred, weights, var):
    """Adds the weighted count of `label AND pred` to `var`."""
    hits = math_ops.cast(
        math_ops.logical_and(label, pred), dtype=dtypes.float32)
    if weights is not None:
      hits *= weights
    return var.assign_add(math_ops.reduce_sum(hits, 1))

  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  # (key, label condition, prediction condition) for each counter that can be
  # computed. The negated tensors are only built when a requested counter
  # needs them.
  candidates = [(ConfusionMatrix.TRUE_POSITIVES, label_is_pos, pred_is_pos)]

  if update_fn or update_tn:
    pred_is_neg = math_ops.logical_not(pred_is_pos)
    candidates.append(
        (ConfusionMatrix.FALSE_NEGATIVES, label_is_pos, pred_is_neg))

  if update_fp or update_tn:
    label_is_neg = math_ops.logical_not(label_is_pos)
    candidates.append(
        (ConfusionMatrix.FALSE_POSITIVES, label_is_neg, pred_is_pos))
    if update_tn:
      candidates.append(
          (ConfusionMatrix.TRUE_NEGATIVES, label_is_neg, pred_is_neg))

  update_ops = [
      weighted_assign_add(label, pred, weights_tiled,
                          variables_to_update[matrix_cond])
      for matrix_cond, label, pred in candidates
      if matrix_cond in variables_to_update
  ]

  return control_flow_ops.group(update_ops)
| |
| |
def _filter_top_k(x, k):
  """Keeps the top-k values along the last dim of x; the rest become NEG_INF.

  Used for computing top-k prediction values in dense labels (which has the same
  shape as predictions) for recall and precision top-k metrics.

  Args:
    x: tensor with any dimensions.
    k: the number of values to keep.

  Returns:
    tensor with same shape and dtype as x.
  """
  _, top_k_idx = nn_ops.top_k(x, k, sorted=False)
  last_dim_size = array_ops.shape(x)[-1]
  # One one-hot row per selected index; summing over the k axis gives a mask
  # over the last dim of x marking the selected positions.
  selected_one_hot = array_ops.one_hot(top_k_idx, last_dim_size, axis=-1)
  keep_mask = math_ops.reduce_sum(selected_one_hot, axis=-2)
  return x * keep_mask + NEG_INF * (1 - keep_mask)
| |
| |
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """Replaces compatible ragged inputs with their flat values.

  Note: If two tensors are dense, it does not check their compatibility.
  Note: Although two ragged tensors with different ragged ranks could have
  identical overall rank and dimension sizes and hence be compatible,
  we do not support those cases.

  Args:
    values: A potentially ragged tensor, or a list of potentially ragged
      tensors of the same ragged_rank.
    mask: A potentially ragged tensor of the same ragged_rank as elements in
      `values`.

  Returns:
    A tuple `(values, mask)`. When the inputs are ragged, each ragged tensor
    is replaced by its flat_values with a trailing dimension of size 1 added;
    otherwise the arguments are returned unchanged. `values` keeps the form it
    was passed in (single tensor or list).

  Raises:
    TypeError: If only some of `values` are ragged, if `values` are all ragged
      but `mask` is a non-ragged tensor, or if `mask` is ragged while `values`
      are not.
  """
  ragged_cls = ragged_tensor.RaggedTensor
  given_as_list = isinstance(values, list)
  tensors = values if given_as_list else [values]
  is_all_ragged = all(isinstance(rt, ragged_cls) for rt in tensors)
  is_any_ragged = any(isinstance(rt, ragged_cls) for rt in tensors)
  mask_is_ragged = isinstance(mask, ragged_cls)

  if not (is_all_ragged and (mask is None or mask_is_ragged)):
    if is_any_ragged:
      raise TypeError('One of the inputs does not have acceptable types.')
    # values are empty or value are not ragged and mask is ragged.
    if mask_is_ragged:
      raise TypeError('Ragged mask is not allowed with non-ragged inputs.')
    return values, mask

  # NOTE: compatibility of the flat_values themselves is left to
  # tf.TensorShape `assert_is_compatible_with`; here only the row splits are
  # asserted to match before the flat_values are used.
  nested_row_split_list = [rt.nested_row_splits for rt in tensors]
  assertion_list = ragged_util.assert_splits_match(nested_row_split_list)

  # A ragged mask (sample weights) must have the same row splits as values.
  if mask_is_ragged:
    mask_assertions = ragged_util.assert_splits_match(
        [nested_row_split_list[0], mask.nested_row_splits])
    checked_mask = control_flow_ops.with_dependencies(mask_assertions,
                                                      mask.flat_values)
    mask = array_ops.expand_dims(checked_mask, -1)

  # values has at least 1 element.
  flat_values = [
      array_ops.expand_dims(
          control_flow_ops.with_dependencies(assertion_list, rt.flat_values),
          -1) for rt in tensors
  ]

  if given_as_list:
    return flat_values, mask
  return flat_values[0], mask