Update Keras metrics to use a memory-efficient alternative when collecting values for evenly distributed thresholds.
This implementation is based on the example in tf.slim. It exhibits a run time and space complexity of O(T + N), where T is the number of thresholds and N is the size of the predictions. Metrics that rely on the standard implementation instead exhibit a complexity of O(T * N), so this can save a large amount of memory when N is large.
Added a unit test to verify the memory consumption. Under an eager context, the ratio of memory usage between the old and new approaches is between 80 and 500. The limit is set to 50 to avoid flakiness.
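For illustration, here is a minimal NumPy sketch of the bucketing trick
(tp_naive and tp_bucketed are hypothetical names used only in this sketch; the
actual implementation uses tf.math.unsorted_segment_sum and tf.cumsum, see
metrics_utils.py below):

  import numpy as np

  def tp_naive(y_true, y_pred, thresholds):
    # Standard approach, O(T * N): broadcast every prediction against every
    # threshold; a prediction must be strictly greater than a threshold to
    # count.
    hits = y_pred[None, :] > thresholds[:, None]
    return (hits * y_true[None, :]).sum(axis=1)

  def tp_bucketed(y_true, y_pred, num_thresholds):
    # Optimized approach, O(T + N): map each prediction to a bucket, then
    # take a reversed cumulative sum over the bucket counts.
    # ceil(...) - 1 because a prediction must be strictly greater than the
    # threshold; index -1 (y_pred == 0.0) is dropped.
    bucket = np.ceil(y_pred * (num_thresholds - 1)).astype(int) - 1
    tp_bucket = np.zeros(num_thresholds)
    valid = bucket >= 0
    np.add.at(tp_bucket, bucket[valid], y_true[valid])
    return np.cumsum(tp_bucket[::-1])[::-1]

  y_true = np.random.randint(2, size=1000).astype(float)
  y_pred = np.random.rand(1000)
  assert np.allclose(tp_naive(y_true, y_pred, np.linspace(0.0, 1.0, 5)),
                     tp_bucketed(y_true, y_pred, 5))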
PiperOrigin-RevId: 374315460
Change-Id: If775df7031287d647a56589a7cfe9bafa7dd8cf3
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 2ce823d..bf7c617 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -973,6 +973,8 @@
self.init_thresholds = thresholds
self.thresholds = metrics_utils.parse_init_thresholds(
thresholds, default_threshold=0.5)
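+    # Cache whether the thresholds are evenly distributed so that
+    # update_state() can pick the memory-efficient confusion-matrix update.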
+ self._evenly_distribute_thresholds = (
+ metrics_utils.evenly_distributed_thresholds(self.thresholds))
self.accumulator = self.add_weight(
'accumulator',
shape=(len(self.thresholds),),
@@ -996,6 +998,7 @@
y_true,
y_pred,
thresholds=self.thresholds,
+ evenly_distribute_thresholds=self._evenly_distribute_thresholds,
sample_weight=sample_weight)
def result(self):
@@ -1295,6 +1298,8 @@
default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
self.thresholds = metrics_utils.parse_init_thresholds(
thresholds, default_threshold=default_threshold)
+ self._evenly_distribute_thresholds = (
+ metrics_utils.evenly_distributed_thresholds(self.thresholds))
self.true_positives = self.add_weight(
'true_positives',
shape=(len(self.thresholds),),
@@ -1326,6 +1331,7 @@
y_true,
y_pred,
thresholds=self.thresholds,
+ evenly_distribute_thresholds=self._evenly_distribute_thresholds,
top_k=self.top_k,
class_id=self.class_id,
sample_weight=sample_weight)
@@ -1421,6 +1427,8 @@
default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
self.thresholds = metrics_utils.parse_init_thresholds(
thresholds, default_threshold=default_threshold)
+ self._evenly_distribute_thresholds = (
+ metrics_utils.evenly_distributed_thresholds(self.thresholds))
self.true_positives = self.add_weight(
'true_positives',
shape=(len(self.thresholds),),
@@ -1452,6 +1460,7 @@
y_true,
y_pred,
thresholds=self.thresholds,
+ evenly_distribute_thresholds=self._evenly_distribute_thresholds,
top_k=self.top_k,
class_id=self.class_id,
sample_weight=sample_weight)
@@ -1515,10 +1524,12 @@
# Compute `num_thresholds` thresholds in [0, 1]
if num_thresholds == 1:
self.thresholds = [0.5]
+ self._evenly_distribute_thresholds = False
else:
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
self.thresholds = [0.0] + thresholds + [1.0]
+ self._evenly_distribute_thresholds = True
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates confusion matrix statistics.
@@ -1543,6 +1554,7 @@
y_true,
y_pred,
thresholds=self.thresholds,
+ evenly_distribute_thresholds=self._evenly_distribute_thresholds,
class_id=self.class_id,
sample_weight=sample_weight)
@@ -2079,6 +2091,9 @@
# If specified, use the supplied thresholds.
self.num_thresholds = len(thresholds) + 2
thresholds = sorted(thresholds)
+ self._evenly_distribute_thresholds = (
+ metrics_utils.evenly_distributed_thresholds(
+ np.array([0.0] + thresholds + [1.0])))
else:
if num_thresholds <= 1:
raise ValueError('`num_thresholds` must be > 1.')
@@ -2088,6 +2103,7 @@
self.num_thresholds = num_thresholds
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
+ self._evenly_distribute_thresholds = True
# Add an endpoint "threshold" below zero and above one for either
# threshold method to account for floating point imprecisions.
@@ -2240,6 +2256,7 @@
y_true,
y_pred,
self._thresholds,
+ evenly_distribute_thresholds=self._evenly_distribute_thresholds,
sample_weight=sample_weight,
multi_label=self.multi_label,
label_weights=label_weights)
diff --git a/tensorflow/python/keras/metrics_confusion_matrix_test.py b/tensorflow/python/keras/metrics_confusion_matrix_test.py
index b39d0e3..f0c5a87 100644
--- a/tensorflow/python/keras/metrics_confusion_matrix_test.py
+++ b/tensorflow/python/keras/metrics_confusion_matrix_test.py
@@ -19,6 +19,7 @@
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import combinations
@@ -33,6 +34,11 @@
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
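+# memory_profiler is an optional dependency; the memory test below is skipped
+# when it is not available.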
+try:
+ import memory_profiler # pylint:disable=g-import-not-at-top
+except ImportError:
+ memory_profiler = None
+
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FalsePositivesTest(test.TestCase, parameterized.TestCase):
@@ -1776,5 +1782,165 @@
self.assertAllEqual(auc_obj.true_positives, np.zeros((5, 2)))
+@combinations.generate(combinations.combine(mode=['eager']))
+class ThresholdsTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters([
+ metrics.TruePositives(),
+ metrics.TrueNegatives(),
+ metrics.FalsePositives(),
+ metrics.FalseNegatives(),
+ metrics.Precision(),
+ metrics.Recall(),
+ metrics.SensitivityAtSpecificity(0.5),
+ metrics.SpecificityAtSensitivity(0.5),
+ metrics.PrecisionAtRecall(0.5),
+ metrics.RecallAtPrecision(0.5),
+ metrics.AUC()])
+ def test_with_default_thresholds(self, metric_obj):
+    # By default, the thresholds are evenly distributed if there is more than
+    # one. When there is only one threshold, we expect
+    # _evenly_distribute_thresholds to be False.
+ expected = len(metric_obj.thresholds) > 1
+ self.assertEqual(metric_obj._evenly_distribute_thresholds, expected)
+
+ @parameterized.parameters([
+ metrics.TruePositives,
+ metrics.TrueNegatives,
+ metrics.FalsePositives,
+ metrics.FalseNegatives,
+ metrics.Precision,
+ metrics.Recall])
+ def test_with_manual_thresholds(self, metric_cls):
+ even_thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
+ metric_obj = metric_cls(thresholds=even_thresholds)
+ self.assertTrue(metric_obj._evenly_distribute_thresholds)
+
+ uneven_thresholds = [0.0, 0.45, 1.0]
+ metric_obj = metric_cls(thresholds=uneven_thresholds)
+ self.assertFalse(metric_obj._evenly_distribute_thresholds)
+
+ def test_manual_thresholds_auc(self):
+    # The AUC metric handles manual thresholds input differently (it adds
+    # 0.0 and 1.0 on the user's behalf).
+ even_thresholds = [0.25, 0.5, 0.75]
+ auc = metrics.AUC(thresholds=even_thresholds)
+ self.assertTrue(auc._evenly_distribute_thresholds)
+
+    # Test the config round-trip used when saving/loading the model.
+ cloned = metrics.AUC.from_config(auc.get_config())
+ self.assertTrue(cloned._evenly_distribute_thresholds)
+
+ uneven_thresholds = [0.45,]
+ auc = metrics.AUC(thresholds=uneven_thresholds)
+ self.assertFalse(auc._evenly_distribute_thresholds)
+
+ cloned = metrics.AUC.from_config(auc.get_config())
+ self.assertFalse(cloned._evenly_distribute_thresholds)
+
+ @parameterized.parameters([
+ metrics.TruePositives,
+ metrics.TrueNegatives,
+ metrics.FalsePositives,
+ metrics.FalseNegatives,
+ metrics.Precision,
+ metrics.Recall,
+ metrics.AUC])
+ def test_even_thresholds_correctness(self, metric_cls):
+ with compat.forward_compatibility_horizon(2021, 6, 9):
+      # Make sure the old and new approaches produce the same result for
+      # evenly distributed thresholds.
+ y_true = np.random.randint(2, size=(10,))
+ y_pred = np.random.rand(10)
+
+ even_thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
+ if metric_cls == metrics.AUC:
+ even_thresholds = even_thresholds[1:-1]
+ metric_obj = metric_cls(thresholds=even_thresholds)
+ metric_obj.update_state(y_true, y_pred)
+ result1 = metric_obj.result()
+
+ metric_obj2 = metric_cls(thresholds=even_thresholds)
+      # Force the use of the old approach.
+ metric_obj2._evenly_distribute_thresholds = False
+ metric_obj2.update_state(y_true, y_pred)
+ result2 = metric_obj2.result()
+
+ self.assertAllClose(result1, result2)
+      # Check that all the variables are the same, e.g. tp, tn, fp, fn.
+ for v1, v2 in zip(metric_obj.variables, metric_obj2.variables):
+ self.assertAllClose(v1, v2)
+
+ @parameterized.parameters([
+ metrics.SensitivityAtSpecificity,
+ metrics.SpecificityAtSensitivity,
+ metrics.PrecisionAtRecall,
+ metrics.RecallAtPrecision])
+ def test_even_thresholds_correctness_2(self, metric_cls):
+ with compat.forward_compatibility_horizon(2021, 6, 9):
+ y_true = np.random.randint(2, size=(10,))
+ y_pred = np.random.rand(10)
+
+ metric_obj = metric_cls(0.5)
+ metric_obj.update_state(y_true, y_pred)
+ result1 = metric_obj.result()
+
+ metric_obj2 = metric_cls(0.5)
+      # Force the use of the old approach.
+ metric_obj2._evenly_distribute_thresholds = False
+ metric_obj2.update_state(y_true, y_pred)
+ result2 = metric_obj2.result()
+
+ self.assertAllClose(result1, result2)
+      # Check that all the variables are the same, e.g. tp, tn, fp, fn.
+ for v1, v2 in zip(metric_obj.variables, metric_obj2.variables):
+ self.assertAllClose(v1, v2)
+
+
+@combinations.generate(combinations.combine(mode=['eager']))
+class AUCMemoryTest(test.TestCase, parameterized.TestCase):
+ # This test is added to measure the memory footprint for
+ # metrics_utils._update_confusion_matrix_variables_optimized().
+
+ def test_memory_usage(self):
+ if memory_profiler is None:
+ self.skipTest('Skip test since memory_profiler is not available.')
+
+ with compat.forward_compatibility_horizon(2021, 6, 9):
+ self.y_true = np.random.randint(2, size=(1024, 1024))
+ self.y_pred = np.random.rand(1024, 1024)
+
+    memory_usage_1 = memory_profiler.memory_usage((self.even_thresholds_auc))
+    memory_usage_2 = memory_profiler.memory_usage(
+        (self.uneven_thresholds_auc))
+    # memory_usage returns a list of samples taken while the function runs.
+    # The net memory consumption is approximately max(usage) - min(usage).
+    memory_usage_1 = max(memory_usage_1) - min(memory_usage_1)
+    memory_usage_2 = max(memory_usage_2) - min(memory_usage_2)
+
+    # We expect the new approach to have a memory footprint of O(T + N),
+    # while the old approach has O(T * N). With T = 200 here, the ratio
+    # between the two should be at least 50 (leaving room for other
+    # overhead).
+    self.assertLess(memory_usage_1 * 50, memory_usage_2)
+
+ def even_thresholds_auc(self):
+ auc = metrics.AUC(num_thresholds=200)
+ self.assertTrue(auc._evenly_distribute_thresholds)
+
+ auc(self.y_true, self.y_pred)
+
+ def uneven_thresholds_auc(self):
+ num_thresholds = 200
+ thresholds = [x / (num_thresholds - 1) for x in range(num_thresholds)]
+ thresholds[100] += 1 / 200
+ thresholds = thresholds[1:-1]
+
+ auc = metrics.AUC(thresholds=thresholds)
+ self.assertFalse(auc._evenly_distribute_thresholds)
+ self.assertEqual(auc.num_thresholds, num_thresholds)
+
+ auc(self.y_true, self.y_pred)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py
index 098552e..b4cfe40 100644
--- a/tensorflow/python/keras/utils/metrics_utils.py
+++ b/tensorflow/python/keras/utils/metrics_utils.py
@@ -20,6 +20,9 @@
from enum import Enum
+import numpy as np
+
+from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -29,12 +32,14 @@
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.ops.parallel_for.control_flow_ops import vectorized_map
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import tf_decorator
@@ -243,6 +248,231 @@
raise ValueError('Invalid AUC summation method value "%s".' % key)
+def _update_confusion_matrix_variables_optimized(
+ variables_to_update,
+ y_true,
+ y_pred,
+ thresholds,
+ multi_label=False,
+ sample_weights=None,
+ label_weights=None,
+ thresholds_with_epsilon=False):
+ """Update confusion matrix variables with memory efficient alternative.
+
+  Note that the thresholds need to be evenly distributed within the list, i.e.,
+  the difference between consecutive elements is the same.
+
+ To compute TP/FP/TN/FN, we are measuring a binary classifier
+ C(t) = (predictions >= t)
+ at each threshold 't'. So we have
+ TP(t) = sum( C(t) * true_labels )
+ FP(t) = sum( C(t) * false_labels )
+
+ But, computing C(t) requires computation for each t. To make it fast,
+ observe that C(t) is a cumulative integral, and so if we have
+ thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1}
+ where n = num_thresholds, and if we can compute the bucket function
+    B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
+ then we get
+ C(t_i) = sum( B(j), j >= i )
+ which is the reversed cumulative sum in tf.cumsum().
+
+ We can compute B(i) efficiently by taking advantage of the fact that
+ our thresholds are evenly distributed, in that
+ width = 1.0 / (num_thresholds - 1)
+ thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
+ Given a prediction value p, we can map it to its bucket by
+ bucket_index(p) = floor( p * (num_thresholds - 1) )
+ so we can use tf.math.unsorted_segment_sum() to update the buckets in one
+ pass.
+
+ Consider following example:
+ y_true = [0, 0, 1, 1]
+ y_pred = [0.1, 0.5, 0.3, 0.9]
+ thresholds = [0.0, 0.5, 1.0]
+ num_buckets = 2 # [0.0, 1.0], (1.0, 2.0]
+ bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
+ = tf.math.floor([0.2, 1.0, 0.6, 1.8])
+ = [0, 0, 0, 1]
+  # The meaning of the bucket is that if the label for a prediction is true,
+  # then 1 is added to the bucket with the corresponding index.
+  # E.g., if the label for 0.2 is true, then 1 is added to bucket 0. If the
+  # label for 1.8 is true, then 1 is added to bucket 1.
+  #
+  # Note that the second item "1.0" is floored to 0, since a value needs to
+  # be strictly larger than the bucket lower bound.
+  # In the implementation, we use tf.math.ceil() - 1 to achieve this.
+ tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices,
+ num_segments=num_thresholds)
+ = [1, 1, 0]
+  # The [1, 1, 0] here means there is 1 true value contributed by bucket 0,
+  # and 1 true value contributed by bucket 1. When we aggregate them together,
+  # the result becomes [a + b + c, b + c, c], since predictions in larger
+  # buckets always contribute to the counts for smaller thresholds.
+ true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
+ = [2, 1, 0]
+
+ This implementation exhibits a run time and space complexity of O(T + N),
+ where T is the number of thresholds and N is the size of predictions.
+ Metrics that rely on standard implementation instead exhibit a complexity of
+ O(T * N).
+
+ Args:
+ variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
+ and corresponding variables to update as values.
+ y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be cast
+ to `bool`.
+ y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
+ the range `[0, 1]`.
+    thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
+      It needs to be evenly distributed (the difference between consecutive
+      elements must be the same).
+ multi_label: Optional boolean indicating whether multidimensional
+ prediction/labels should be treated as multilabel responses, or flattened
+      into a single label. When True, the values of `variables_to_update` must
+ have a second dimension equal to the number of labels in y_true and
+ y_pred, and those tensors must not be RaggedTensors.
+ sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
+ as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
+ must be either `1`, or the same as the corresponding `y_true` dimension).
+ label_weights: Optional tensor of non-negative weights for multilabel
+ data. The weights are applied when calculating TP, FP, FN, and TN without
+ explicit multilabel handling (i.e. when the data is to be flattened).
+    thresholds_with_epsilon: Optional boolean indicating whether the leading
+      and trailing thresholds have an epsilon added for floating point
+      imprecision. It changes how we handle the leading and trailing buckets.
+
+ Returns:
+ Update op.
+ """
+ num_thresholds = thresholds.shape.as_list()[0]
+
+ if sample_weights is None:
+ sample_weights = 1.0
+ else:
+ sample_weights = weights_broadcast_ops.broadcast_weights(
+ math_ops.cast(sample_weights, dtype=y_pred.dtype), y_pred)
+ if not multi_label:
+ sample_weights = array_ops.reshape(sample_weights, [-1])
+ if label_weights is None:
+ label_weights = 1.0
+ else:
+ label_weights = array_ops.expand_dims(label_weights, 0)
+ label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
+ y_pred)
+ if not multi_label:
+ label_weights = array_ops.reshape(label_weights, [-1])
+ weights = math_ops.multiply(sample_weights, label_weights)
+
+  # We shouldn't need this, but in case there are prediction values that fall
+  # outside the range [0.0, 1.0].
+ y_pred = clip_ops.clip_by_value(y_pred,
+ clip_value_min=0.0, clip_value_max=1.0)
+
+ y_true = math_ops.cast(math_ops.cast(y_true, dtypes.bool), y_true.dtype)
+ if not multi_label:
+ y_true = array_ops.reshape(y_true, [-1])
+ y_pred = array_ops.reshape(y_pred, [-1])
+
+ true_labels = math_ops.multiply(y_true, weights)
+ false_labels = math_ops.multiply((1.0 - y_true), weights)
+
+ # Compute the bucket indices for each prediction value.
+  # Since a prediction value has to be strictly greater than a threshold to
+  # count, e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the
+  # first bucket. We have to use math.ceil(val) - 1 for the bucket index.
+ bucket_indices = math_ops.ceil(y_pred * (num_thresholds - 1)) - 1
+
+ if thresholds_with_epsilon:
+    # In this case, the first bucket should actually be taken into account,
+    # since any prediction in [0.0, 1.0] is larger than the first threshold.
+    # We change the bucket value from -1 to 0.
+ bucket_indices = nn_ops.relu(bucket_indices)
+
+ bucket_indices = math_ops.cast(bucket_indices, dtypes.int32)
+
+ if multi_label:
+    # We need to run the bucket segment sum for each of the label classes. In
+    # the multi_label case, the rank of the labels is 2. We first transpose so
+    # that the label dim becomes the first and we can run through the labels
+    # in parallel.
+ true_labels = array_ops.transpose_v2(true_labels)
+ false_labels = array_ops.transpose_v2(false_labels)
+ bucket_indices = array_ops.transpose_v2(bucket_indices)
+
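+    # After the transposes, true_labels, false_labels and bucket_indices all
+    # have shape [num_labels, num_predictions]; gather_bucket reduces one row
+    # to a [num_thresholds] bucket sum, so each vectorized_map call below
+    # yields a [num_labels, num_thresholds] result.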
+ def gather_bucket(label_and_bucket_index):
+ label, bucket_index = label_and_bucket_index[0], label_and_bucket_index[1]
+ return math_ops.unsorted_segment_sum(
+ data=label, segment_ids=bucket_index, num_segments=num_thresholds)
+ tp_bucket_v = vectorized_map(
+ gather_bucket, (true_labels, bucket_indices))
+ fp_bucket_v = vectorized_map(
+ gather_bucket, (false_labels, bucket_indices))
+ tp = array_ops.transpose_v2(
+ math_ops.cumsum(tp_bucket_v, reverse=True, axis=1))
+ fp = array_ops.transpose_v2(
+ math_ops.cumsum(fp_bucket_v, reverse=True, axis=1))
+ else:
+ tp_bucket_v = math_ops.unsorted_segment_sum(
+ data=true_labels, segment_ids=bucket_indices,
+ num_segments=num_thresholds)
+ fp_bucket_v = math_ops.unsorted_segment_sum(
+ data=false_labels, segment_ids=bucket_indices,
+ num_segments=num_thresholds)
+ tp = math_ops.cumsum(tp_bucket_v, reverse=True)
+ fp = math_ops.cumsum(fp_bucket_v, reverse=True)
+
+ # fn = sum(true_labels) - tp
+ # tn = sum(false_labels) - fp
+ if (ConfusionMatrix.TRUE_NEGATIVES in variables_to_update or
+ ConfusionMatrix.FALSE_NEGATIVES in variables_to_update):
+ if multi_label:
+ total_true_labels = math_ops.reduce_sum(true_labels, axis=1)
+ total_false_labels = math_ops.reduce_sum(false_labels, axis=1)
+ else:
+ total_true_labels = math_ops.reduce_sum(true_labels)
+ total_false_labels = math_ops.reduce_sum(false_labels)
+
+ update_ops = []
+ if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
+ variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
+ update_ops.append(variable.assign_add(tp))
+ if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
+ variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
+ update_ops.append(variable.assign_add(fp))
+ if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
+ variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
+ tn = total_false_labels - fp
+ update_ops.append(variable.assign_add(tn))
+ if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
+ variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
+ fn = total_true_labels - tp
+ update_ops.append(variable.assign_add(fn))
+ return control_flow_ops.group(update_ops)
+
+
+def evenly_distributed_thresholds(thresholds):
+  """Check whether the thresholds list is evenly distributed.
+
+  We can leverage evenly distributed thresholds to use less memory when
+  calculating metrics like AUC, where each individual threshold needs to be
+  evaluated.
+
+ Args:
+    thresholds: A python list or tuple, or a 1D numpy array whose values are
+      in the range [0, 1].
+
+ Returns:
+ boolean, whether the values in the inputs are evenly distributed.
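+
+  For example:
+
+  >>> evenly_distributed_thresholds(np.array([0.0, 0.25, 0.5, 0.75, 1.0]))
+  True
+  >>> evenly_distributed_thresholds(np.array([0.0, 0.45, 1.0]))
+  False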
+ """
+ # Check the list value and see if it is evenly distributed.
+ num_thresholds = len(thresholds)
+ if num_thresholds < 3:
+ return False
+ even_thresholds = np.arange(num_thresholds,
+ dtype=np.float32) / (num_thresholds - 1)
+ return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())
+
+
def update_confusion_matrix_variables(variables_to_update,
y_true,
y_pred,
@@ -251,7 +481,8 @@
class_id=None,
sample_weight=None,
multi_label=False,
- label_weights=None):
+ label_weights=None,
+ evenly_distribute_thresholds=False):
"""Returns op to update the given confusion matrix variables.
For every pair of values in y_true and y_pred:
@@ -293,6 +524,10 @@
label_weights: (optional) tensor of non-negative weights for multilabel
data. The weights are applied when calculating TP, FP, FN, and TN without
explicit multilabel handling (i.e. when the data is to be flattened).
+ evenly_distribute_thresholds: Boolean, whether the thresholds are evenly
+ distributed within the list. An optimized method will be used if this is
+ the case. See _update_confusion_matrix_variables_optimized() for more
+ details.
Returns:
Update op.
@@ -320,9 +555,20 @@
y_true = math_ops.cast(y_true, dtype=variable_dtype)
y_pred = math_ops.cast(y_pred, dtype=variable_dtype)
+
+ if evenly_distribute_thresholds:
+    # Check whether the thresholds have a leading or trailing epsilon added
+    # for floating point imprecision. The leading and trailing thresholds are
+    # handled a bit differently, as a corner case.
+    # At this point, thresholds should be a list/array with more than 2 items,
+    # with values in [0, 1]. See evenly_distributed_thresholds() for more
+    # details.
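+    # For example, AUC pads the thresholds with an endpoint below 0.0 and
+    # another above 1.0 to account for floating point imprecision (see
+    # metrics.py above), in which case thresholds[0] < 0.0.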
+ thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0
+
thresholds = ops.convert_to_tensor_v2_with_dispatch(
thresholds, dtype=variable_dtype)
- num_thresholds = thresholds.shape[0]
+ num_thresholds = thresholds.shape.as_list()[0]
+
if multi_label:
one_thresh = math_ops.equal(
math_ops.cast(1, dtype=dtypes.int32),
@@ -368,6 +614,15 @@
y_true = y_true[..., class_id]
y_pred = y_pred[..., class_id]
+ if evenly_distribute_thresholds and compat.forward_compatible(2021, 6, 8):
+    # The new approach takes effect after 2021/6/8, to give enough time for
+    # the Brella release to pick up the new op tf.math.cumsum with float32.
+ return _update_confusion_matrix_variables_optimized(
+ variables_to_update, y_true, y_pred, thresholds,
+ multi_label=multi_label, sample_weights=sample_weight,
+ label_weights=label_weights,
+ thresholds_with_epsilon=thresholds_with_epsilon)
+
pred_shape = array_ops.shape(y_pred)
num_predictions = pred_shape[0]
if y_pred.shape.ndims == 1: