Fix cross-entropy losses to use the underlying logits of the Softmax (or Sigmoid) output op by backtracking through any Identity ops.
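
Previously, categorical_crossentropy, sparse_categorical_crossentropy, and
binary_crossentropy only recognized a Softmax (or Sigmoid) output when it was
the tensor's immediate producing op. If the activation output was wrapped in
one or more Identity ops (e.g. via tf.identity, as in the new unit tests),
the check failed and the losses fell back to the clipped-probability path
instead of the numerically stable *_cross_entropy_with_logits kernels. A new
_backtrack_identity helper now walks back through Identity ops before the op
type is checked; eager tensors and variables are skipped as before.

A minimal sketch of the scenario this fixes (it mirrors the new unit tests;
assumes the TF 2.x public API, with tf.function so the loss sees a graph
tensor rather than an EagerTensor):

    import tensorflow as tf

    @tf.function
    def loss():
      t = tf.constant([[1., 0., 0.]])
      logits = tf.constant([[8., 1., 1.]])
      # The Identity op used to hide the Softmax producer from the loss.
      p = tf.identity(tf.keras.backend.softmax(logits))
      # With this fix the Identity is backtracked and the loss is computed
      # from `logits` via softmax_cross_entropy_with_logits_v2.
      return tf.keras.backend.categorical_crossentropy(t, p)

    print(loss())  # ~[0.002]
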
PiperOrigin-RevId: 284448871
Change-Id: Ic0c6214ba78abaf348c23abdcb9e2543e2d08a38
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 411dcca..7122d6e 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -4440,6 +4440,13 @@
return nn.softsign(x)


+def _backtrack_identity(tensor):
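+ """Walks back through `Identity` ops and returns the underlying tensor."""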
+ while tensor.op.type == 'Identity':
+ tensor = tensor.op.inputs[0]
+ return tensor
+
+
@keras_export('keras.backend.categorical_crossentropy')
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy between an output tensor and a target tensor.
@@ -4484,24 +4491,28 @@
dtype=float32)
"""
- if not from_logits:
- if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
- output.op.type != 'Softmax'):
- # scale preds so that the class probas of each sample sum to 1
- output = output / math_ops.reduce_sum(output, axis, True)
- # Compute cross entropy from probabilities.
- epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
- output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
- return -math_ops.reduce_sum(target * math_ops.log(output), axis)
- else:
+ if from_logits:
+ return nn.softmax_cross_entropy_with_logits_v2(
+ labels=target, logits=output, axis=axis)
+
+ if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
+ output = _backtrack_identity(output)
+ if output.op.type == 'Softmax':
# When the softmax activation function is used for the output operation, we
# use the logits from the softmax op directly to compute the loss, in order
# to prevent it from collapsing to zero when training.
# See b/117284466
assert len(output.op.inputs) == 1
output = output.op.inputs[0]
- return nn.softmax_cross_entropy_with_logits_v2(
- labels=target, logits=output, axis=axis)
+ return nn.softmax_cross_entropy_with_logits_v2(
+ labels=target, logits=output, axis=axis)
+
+ # Scale predictions so that the class probabilities of each sample sum to 1.
+ output = output / math_ops.reduce_sum(output, axis, True)
+ # Compute cross entropy from probabilities.
+ epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
+ output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
+ return -math_ops.reduce_sum(target * math_ops.log(output), axis)


@keras_export('keras.backend.sparse_categorical_crossentropy')
@@ -4525,19 +4536,22 @@
Raises:
ValueError: if `axis` is neither -1 nor one of the axes of `output`.
"""
- if not from_logits:
- if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
- output.op.type != 'Softmax'):
- epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
- output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
- output = math_ops.log(output)
- else:
+ if not from_logits and not isinstance(
+ output, (ops.EagerTensor, variables_module.Variable)):
+ output = _backtrack_identity(output)
+ if output.op.type == 'Softmax':
# When the softmax activation function is used for the output operation, we
# use the logits from the softmax op directly to compute the loss, in order
# to prevent it from collapsing to zero when training.
# See b/117284466
assert len(output.op.inputs) == 1
output = output.op.inputs[0]
+ from_logits = True
+
+ if not from_logits:
+ epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
+ output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
+ output = math_ops.log(output)

if isinstance(output.shape, (tuple, list)):
output_rank = len(output.shape)
@@ -4596,23 +4610,26 @@
Returns:
A tensor.
"""
- if not from_logits:
- if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
- output.op.type != 'Sigmoid'):
- epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
- output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
+ if from_logits:
+ return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
- # Compute cross entropy from probabilities.
- bce = target * math_ops.log(output + epsilon())
- bce += (1 - target) * math_ops.log(1 - output + epsilon())
- return -bce
- else:
+ if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
+ output = _backtrack_identity(output)
+ if output.op.type == 'Sigmoid':
# When the sigmoid activation function is used for the output operation, we
# use the logits from the sigmoid op directly to compute the loss, in order
# to prevent it from collapsing to zero when training.
assert len(output.op.inputs) == 1
output = output.op.inputs[0]
- return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+ return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+
+ epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
+ output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
+
+ # Compute cross entropy from probabilities.
+ bce = target * math_ops.log(output + epsilon())
+ bce += (1 - target) * math_ops.log(1 - output + epsilon())
+ return -bce


@keras_export('keras.backend.sigmoid')
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 9b6f923..8d8d24f 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -1579,6 +1579,15 @@
class BackendCrossEntropyLossesTest(test.TestCase):

@test_util.run_in_graph_and_eager_modes
+ def test_binary_crossentropy_with_sigmoid(self):
+ t = keras.backend.constant([[0, 1, 0]])
+ logits = keras.backend.constant([[8., 1., 1.]])
+ p = keras.backend.sigmoid(logits)
+ p = array_ops.identity(array_ops.identity(p))
+ result = self.evaluate(keras.backend.binary_crossentropy(t, p))
+ self.assertArrayNear(result[0], [8., 0.313, 1.313], 1e-3)
+
+ @test_util.run_in_graph_and_eager_modes
def test_categorical_crossentropy_loss(self):
t = keras.backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
@@ -1638,6 +1647,15 @@
self.assertArrayNear(result, [.002, .003, .036], 1e-3)

@test_util.run_in_graph_and_eager_modes
+ def test_categorical_crossentropy_with_softmax(self):
+ t = keras.backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+ logits = keras.backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+ p = keras.backend.softmax(logits)
+ p = array_ops.identity(array_ops.identity(p))
+ result = self.evaluate(keras.backend.categorical_crossentropy(t, p))
+ self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3)
+
+ @test_util.run_in_graph_and_eager_modes
def test_sparse_categorical_crossentropy_loss(self):
t = keras.backend.constant([0, 1, 2])
@@ -1702,6 +1720,15 @@
_ = f([t_val, p_val])

+ @test_util.run_in_graph_and_eager_modes
+ def test_sparse_categorical_crossentropy_with_softmax(self):
+ t = keras.backend.constant([0, 1, 2])
+ logits = keras.backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+ p = keras.backend.softmax(logits)
+ p = array_ops.identity(array_ops.identity(p))
+ result = self.evaluate(keras.backend.sparse_categorical_crossentropy(t, p))
+ self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3)
+
@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2