# blob: 1d769a027c8b42b945e156d3aa2b51f031b10109 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utils related to keras metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import weakref
from enum import Enum
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.tpu import tpu
from tensorflow.python.util import tf_decorator
# Stand-in for "minus infinity" when masking out non-top-k predictions (see
# `_filter_top_k`). A large finite value is used because the mask arithmetic
# multiplies it by 0 for kept entries, and -inf * 0 would produce NaN.
NEG_INF = -1.0e10
class Reduction(Enum):
  """Ways in which a metric reduces its weighted values to a single scalar.

  Members:
    * `SUM`: sum of all weighted values.
    * `SUM_OVER_BATCH_SIZE`: sum of weighted values divided by the number of
      elements.
    * `WEIGHTED_MEAN`: sum of weighted values divided by the sum of weights.
  """
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
  WEIGHTED_MEAN = 'weighted_mean'
def update_state_wrapper(update_state_fn):
  """Decorates a metric's `update_state()` so its op is passed to `add_update()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    A version of `update_state_fn` wrapped with `tf_decorator.make_decorator`.
    Its first positional argument is the metric object; the remaining
    arguments are forwarded to `update_state_fn`.

  Raises (from the returned function):
    ValueError: under a TPUStrategy, in replica context, if any of the
      metric's weights was not created in the strategy's scope.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Calls `update_state_fn` and registers its op via `add_update()`."""
    strategy = distribution_strategy_context.get_strategy()

    # TODO(b/142574744): Remove this check if a better solution is found for
    # declaring keras Metric outside of TPUStrategy and then updating it per
    # replica.
    for weight in metric_obj.weights:
      misplaced_update = (
          tpu.is_tpu_strategy(strategy) and
          not strategy.extended.variable_created_in_scope(weight) and
          not distribution_strategy_context.in_cross_replica_context())
      if misplaced_update:
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUstrategy scope. ')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
      update_op = update_state_fn(*args, **kwargs)

    # In eager execution `update_state_fn` returns None, so there is nothing
    # to register.
    if update_op is not None:
      metric_obj.add_update(update_op)
    return update_op

  return tf_decorator.make_decorator(update_state_fn, decorated)
def result_wrapper(result_fn):
  """Decorates a metric's `result()` so it is computed in cross-replica context.

  Computing a result only reads the metric's state variables. If those
  variables are distributed across replicas/devices and `result()` is called
  from a replica context, the computation is routed through
  `replica_context.merge_call()` so the aggregated state is used.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    A version of `result_fn` wrapped with `tf_decorator.make_decorator`. Its
    first positional argument is the metric object; the remaining positional
    arguments are forwarded to `result_fn`.
  """

  def merge_fn_wrapper(distribution, merge_fn, *args):
    """Runs the (per-replica) `merge_fn` once and returns an identity of it."""
    # `merge_call` hands over a per-replica value whose entries are all the
    # same `result_fn`, so the first local copy is used.
    local_result_fn = distribution.experimental_local_results(merge_fn)[0]
    # The identity keeps the control dependency between the `update_state` op
    # and the result working when `result_fn` returns a tensor.
    return array_ops.identity(local_result_fn(*args))

  def decorated(metric_obj, *args):
    """Computes the result, going through `merge_call` in a replica context."""
    strategy_active = distribution_strategy_context.has_strategy()
    replica_context = distribution_strategy_context.get_replica_context()
    if strategy_active and replica_context is not None:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.
      # `merge_call` leaves replica mode and passes the distribution object as
      # the first argument; `merge_fn_wrapper` absorbs it so `result_fn` does
      # not need that parameter.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)
    else:
      result_t = array_ops.identity(result_fn(*args))

    # Saved for the train/test execution functions: this result op carries a
    # control dependency on the metric's updates.
    metric_obj._call_result = result_t
    return result_t

  return tf_decorator.make_decorator(result_fn, decorated)
def weakmethod(method):
  """Creates a weak reference to the bound method.

  The returned function re-binds the method's underlying function to its
  instance on every call. It holds the instance only through a weak
  reference, so it does not keep the instance alive.

  Args:
    method: a bound method.

  Returns:
    A function that forwards its arguments to `method`, resolved against the
    weakly referenced instance.

  Raises (from the returned function):
    ReferenceError: if it is called after the instance was garbage collected.
  """
  # `__func__`/`__self__` exist on bound methods in both Python 2 and 3,
  # unlike the Python-2-only `im_func`/`im_self`/`im_class`.
  func = method.__func__
  instance_ref = weakref.ref(method.__self__)

  # Wrap `func`, not `method`: `functools.wraps` stores its argument in
  # `__wrapped__`, and a bound method strongly references its instance, which
  # would defeat the weak reference.
  @functools.wraps(func)
  def inner(*args, **kwargs):
    instance = instance_ref()
    if instance is None:
      raise ReferenceError(
          'The instance of the weakly referenced method {!r} has been '
          'garbage collected.'.format(func.__name__))
    return func.__get__(instance, type(instance))(*args, **kwargs)

  del method
  return inner
def assert_thresholds_range(thresholds):
  """Raises if any threshold is missing or lies outside the closed range [0, 1].

  Args:
    thresholds: `None` (nothing is checked) or an iterable of numbers.

  Raises:
    ValueError: if any value is `None`, NaN, or outside `[0, 1]`. The message
      lists every offending value.
  """
  if thresholds is None:
    return
  # `not 0 <= t <= 1` (rather than `t < 0 or t > 1`) also rejects NaN, for
  # which every ordered comparison is False.
  invalid_thresholds = [t for t in thresholds if t is None or not 0 <= t <= 1]
  if invalid_thresholds:
    raise ValueError(
        'Threshold values must be in [0, 1]. Invalid values: {}'.format(
            invalid_thresholds))
def parse_init_thresholds(thresholds, default_threshold=0.5):
  """Normalizes a metric's `thresholds` constructor argument to a list.

  Args:
    thresholds: `None`, a single threshold, or a list/tuple of thresholds.
      Values supplied here are validated with `assert_thresholds_range`.
    default_threshold: value used when `thresholds` is `None`. It is not
      validated.

  Returns:
    A list of thresholds.

  Raises:
    ValueError: if a supplied threshold is `None` or outside `[0, 1]`.
  """
  if thresholds is None:
    return to_list(default_threshold)
  # NOTE(review): `to_list` appears to pass lists through and wrap anything
  # else as a single element. That would turn a tuple of thresholds into
  # `[(t0, t1, ...)]` and make the range check fail, so tuples are converted
  # first to behave like lists.
  if isinstance(thresholds, tuple):
    thresholds = list(thresholds)
  thresholds = to_list(thresholds)
  assert_thresholds_range(thresholds)
  return thresholds
class ConfusionMatrix(Enum):
  """Identifiers for the four confusion-matrix counts.

  The members themselves (not their string values) are the keys of the
  `variables_to_update` dict taken by `update_confusion_matrix_variables`.
  """
  TRUE_POSITIVES = 'tp'
  FALSE_POSITIVES = 'fp'
  TRUE_NEGATIVES = 'tn'
  FALSE_NEGATIVES = 'fn'
class AUCCurve(Enum):
  """The curve an AUC metric integrates.

  * `ROC`: receiver operating characteristic curve.
  * `PR`: precision-recall curve.
  """
  ROC = 'ROC'
  PR = 'PR'

  @staticmethod
  def from_str(key):
    """Maps 'pr'/'PR' or 'roc'/'ROC' to the corresponding member.

    Raises:
      ValueError: for any other `key`.
    """
    aliases = (
        (('pr', 'PR'), AUCCurve.PR),
        (('roc', 'ROC'), AUCCurve.ROC),
    )
    for spellings, curve in aliases:
      if key in spellings:
        return curve
    raise ValueError('Invalid AUC curve value "%s".' % key)
class AUCSummationMethod(Enum):
  """Type of AUC summation method.

  See https://en.wikipedia.org/wiki/Riemann_sum for background on the
  summation schemes.

  Contains the following values:
  * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
    `PR` curve, interpolates (true/false) positives but not the ratio that is
    precision (see Davis & Goadrich 2006 for details).
  * 'minoring': Applies left summation for increasing intervals and right
    summation for decreasing intervals.
  * 'majoring': Applies right summation for increasing intervals and left
    summation for decreasing intervals.
  """
  INTERPOLATION = 'interpolation'
  MAJORING = 'majoring'
  MINORING = 'minoring'

  @staticmethod
  def from_str(key):
    """Maps a lower-case or capitalized method name to its member.

    Raises:
      ValueError: if `key` is not one of 'interpolation', 'Interpolation',
        'majoring', 'Majoring', 'minoring' or 'Minoring'.
    """
    aliases = (
        (('interpolation', 'Interpolation'), AUCSummationMethod.INTERPOLATION),
        (('majoring', 'Majoring'), AUCSummationMethod.MAJORING),
        (('minoring', 'Minoring'), AUCSummationMethod.MINORING),
    )
    for spellings, method in aliases:
      if key in spellings:
        return method
    raise ValueError('Invalid AUC summation method value "%s".' % key)
def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None):
  """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positive: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positive: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary mapping `ConfusionMatrix` members
      (`TRUE_POSITIVES`, `FALSE_NEGATIVES`, `TRUE_NEGATIVES`,
      `FALSE_POSITIVES`, whose values are 'tp', 'fn', 'tn', 'fp') to the
      corresponding variables to update. Keys are checked for membership in
      `list(ConfusionMatrix)`, so the plain strings 'tp', ... are rejected.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value or a python list or tuple of float thresholds in
      `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).

  Returns:
    Update op (built with `control_flow_ops.group`), or `None` if
    `variables_to_update` is `None`.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
      `variables_to_update` contains invalid keys or no valid key, or if
      `label_weights` is given while `multi_label` is True.
  """
  if multi_label and label_weights is not None:
    raise ValueError('`label_weights` for multilabel data should be handled '
                     'outside of `update_confusion_matrix_variables` when '
                     '`multi_label` is True.')
  if variables_to_update is None:
    return
  y_true = math_ops.cast(y_true, dtype=dtypes.float32)
  y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
  if multi_label:
    thresh_shape = array_ops.shape(thresholds)
    num_thresholds = thresh_shape[0]
    # True when `thresholds` is rank 1, i.e. one set of thresholds shared by
    # all labels. NOTE(review): a rank-2 value presumably holds per-label
    # thresholds of shape (num_thresholds, num_labels) — TODO confirm.
    one_thresh = math_ops.equal(
        math_ops.cast(1, dtype=dtypes.int32),
        array_ops.rank(thresholds),
        name='one_set_of_thresholds_cond')
  else:
    # Ragged inputs are replaced by their flat values; the returned mask is
    # discarded, so `sample_weight` keeps its original form below.
    [y_pred,
     y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
                                                               sample_weight)
    num_thresholds = len(to_list(thresholds))
    one_thresh = math_ops.cast(True, dtype=dtypes.bool)

  if not any(
      key for key in variables_to_update if key in list(ConfusionMatrix)):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(
            list(ConfusionMatrix), variables_to_update.keys()))

  invalid_keys = [
      key for key in variables_to_update if key not in list(ConfusionMatrix)
  ]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, list(ConfusionMatrix)))

  # Ops created inside this block depend on the [0, 1] range assertions on
  # `y_pred`.
  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          y_pred,
          math_ops.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      check_ops.assert_less_equal(
          y_pred,
          math_ops.cast(1.0, dtype=y_pred.dtype),
          message='predictions must be <= 1')
  ]):
    if sample_weight is None:
      y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
          y_pred, y_true)
    else:
      y_pred, y_true, sample_weight = (
          tf_losses_utils.squeeze_or_expand_dimensions(
              y_pred, y_true, sample_weight=sample_weight))
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  if top_k is not None:
    # Non-top-k predictions become NEG_INF, so they never exceed a threshold.
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  pred_shape = array_ops.shape(y_pred)
  num_predictions = pred_shape[0]
  if y_pred.shape.ndims == 1:
    num_labels = 1
  else:
    # Product of all non-leading dimensions.
    num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0)
  # With a single shared set of thresholds, the thresholds are tiled across
  # every label; otherwise the label dimension is left as-is (tile factor 1).
  thresh_label_tile = control_flow_ops.cond(
      one_thresh, lambda: num_labels,
      lambda: math_ops.cast(1, dtype=dtypes.int32))

  # Reshape predictions and labels, adding a dim for thresholding.
  if multi_label:
    predictions_extra_dim = array_ops.expand_dims(y_pred, 0)
    labels_extra_dim = array_ops.expand_dims(
        math_ops.cast(y_true, dtype=dtypes.bool), 0)
  else:
    # Flatten predictions and labels when not multilabel.
    predictions_extra_dim = array_ops.reshape(y_pred, [1, -1])
    labels_extra_dim = array_ops.reshape(
        math_ops.cast(y_true, dtype=dtypes.bool), [1, -1])

  # Tile the thresholds for every prediction.
  # Axis 0 indexes thresholds in every tiled tensor below.
  if multi_label:
    thresh_pretile_shape = [num_thresholds, 1, -1]
    thresh_tiles = [1, num_predictions, thresh_label_tile]
    data_tiles = [num_thresholds, 1, 1]
  else:
    thresh_pretile_shape = [num_thresholds, -1]
    thresh_tiles = [1, num_predictions * num_labels]
    data_tiles = [num_thresholds, 1]

  thresh_tiled = array_ops.tile(
      array_ops.reshape(
          array_ops.constant(thresholds, dtype=dtypes.float32),
          thresh_pretile_shape), array_ops.stack(thresh_tiles))

  # Tile the predictions for every threshold.
  preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles)

  # Compare predictions and threshold.
  pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_extra_dim, data_tiles)

  if sample_weight is not None:
    sample_weight = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred)
    weights_tiled = array_ops.tile(
        array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
  else:
    weights_tiled = None

  if label_weights is not None and not multi_label:
    # Add a leading axis so the per-label weights broadcast over `y_pred`.
    label_weights = array_ops.expand_dims(label_weights, 0)
    label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
                                                            y_pred)
    label_weights_tiled = array_ops.tile(
        array_ops.reshape(label_weights, thresh_tiles), data_tiles)
    if weights_tiled is None:
      weights_tiled = label_weights_tiled
    else:
      weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled)

  update_ops = []

  def weighted_assign_add(label, pred, weights, var):
    # Counts (optionally weighted) positions where both conditions hold.
    # Axis 1 is summed away, leaving one count per threshold (and, for
    # multi_label, per label).
    label_and_pred = math_ops.cast(
        math_ops.logical_and(label, pred), dtype=dtypes.float32)
    if weights is not None:
      label_and_pred *= weights
    return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))

  loop_vars = {
      ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
  }
  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  # The negated masks are only built when a requested count needs them.
  if update_fn or update_tn:
    pred_is_neg = math_ops.logical_not(pred_is_pos)
    loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

  if update_fp or update_tn:
    label_is_neg = math_ops.logical_not(label_is_pos)
    loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
    if update_tn:
      loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)

  for matrix_cond, (label, pred) in loop_vars.items():
    if matrix_cond in variables_to_update:
      update_ops.append(
          weighted_assign_add(label, pred, weights_tiled,
                              variables_to_update[matrix_cond]))
  return control_flow_ops.group(update_ops)
def _filter_top_k(x, k):
  """Keeps the k largest entries along the last axis of `x`; others -> NEG_INF.

  Used for computing top-k prediction values in dense labels (which have the
  same shape as predictions) for the recall and precision top-k metrics.

  Args:
    x: tensor with any dimensions.
    k: the number of values to keep.

  Returns:
    tensor with the same shape as x. NOTE(review): `one_hot` defaults to
    float32, so `x` presumably needs to be float32 for the mask arithmetic.
  """
  _, top_k_indices = nn_ops.top_k(x, k, sorted=False)
  depth = array_ops.shape(x)[-1]
  # One-hot over the last axis gives a [..., k, depth] indicator; summing over
  # the k axis yields a 0/1 mask with the same shape as `x`.
  top_k_one_hot = array_ops.one_hot(top_k_indices, depth, axis=-1)
  keep_mask = math_ops.reduce_sum(top_k_one_hot, axis=-2)
  return x * keep_mask + NEG_INF * (1 - keep_mask)
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """If ragged, it checks the compatibility and then returns the flat_values.

  Note: If two tensors are dense, it does not check their compatibility.
  Note: Although two ragged tensors with different ragged ranks could have
      identical overall rank and dimension sizes and hence be compatible,
      we do not support those cases.

  Args:
    values: A list of potentially ragged tensors of the same ragged_rank, or a
      single potentially ragged tensor.
    mask: A potentially ragged tensor of the same ragged_rank as elements in
      Values.

  Returns:
    A tuple in which the first element is the list of tensors and the second
    is the mask tensor. ([Values], mask). When ragged, Mask and each element
    in Values are replaced by their flat_values with a trailing dimension of
    size 1 added. When a single tensor (not a list) is passed, the first
    element is that single tensor rather than a list.

  Raises:
    TypeError: if only some of `values` are ragged; if `values` are ragged but
      `mask` is dense; or if `mask` is ragged while `values` is not (this
      includes an empty `values` list).
  """
  if isinstance(values, list):
    # `all()` of an empty sequence is True. Without the `bool(values)` guard,
    # an empty list would take the ragged path, and a ragged `mask` would hit
    # an IndexError on `nested_row_split_list[0]`.
    is_all_ragged = bool(values) and all(
        isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
    is_any_ragged = any(
        isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
  else:
    is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor)
    is_any_ragged = is_all_ragged
  if (is_all_ragged and
      ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))):
    to_be_stripped = False
    if not isinstance(values, list):
      values = [values]
      to_be_stripped = True

    # NOTE: we leave the flat_values compatibility to
    # tf.TensorShape `assert_is_compatible_with`
    # check if both dynamic dimensions are equal and then use the flat_values.
    nested_row_split_list = [rt.nested_row_splits for rt in values]
    assertion_list = ragged_util.assert_splits_match(nested_row_split_list)

    # if both are ragged sample_weights also should be ragged with same dims.
    if isinstance(mask, ragged_tensor.RaggedTensor):
      assertion_list_for_mask = ragged_util.assert_splits_match(
          [nested_row_split_list[0], mask.nested_row_splits])
      tmp = control_flow_ops.with_dependencies(assertion_list_for_mask,
                                               mask.flat_values)
      mask = array_ops.expand_dims(tmp, -1)

    # values has at least 1 element (guaranteed by the `bool(values)` guard).
    flat_values = []
    for value in values:
      tmp = control_flow_ops.with_dependencies(assertion_list,
                                               value.flat_values)
      flat_values.append(array_ops.expand_dims(tmp, -1))

    values = flat_values[0] if to_be_stripped else flat_values

  elif is_any_ragged:
    raise TypeError('One of the inputs does not have acceptable types.')
  # values are empty or value are not ragged and mask is ragged.
  elif isinstance(mask, ragged_tensor.RaggedTensor):
    raise TypeError('Ragged mask is not allowed with non-ragged inputs.')

  return values, mask