Fix cross entropy losses to use the underlying logits of the Softmax (or Sigmoid) layer by backtracking through Identity ops, if any.
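
Minimal sketch of the backtracking idea (illustration only, not code from this change; it assumes graph-mode tensors, since eager tensors have no `.op` and are excluded by the `isinstance` checks in the patch):

    import tensorflow as tf

    with tf.Graph().as_default():  # graph tensors expose .op; eager tensors do not
      logits = tf.constant([[2., 1., 0.1]])
      # A Softmax output wrapped in Identity ops, e.g. via tf.identity.
      probs = tf.identity(tf.identity(tf.nn.softmax(logits)))

      # Same walk as the new _backtrack_identity helper in the patch below.
      tensor = probs
      while tensor.op.type == 'Identity':
        tensor = tensor.op.inputs[0]

      assert tensor.op.type == 'Softmax'
      recovered_logits = tensor.op.inputs[0]  # the original `logits` tensor

With this, an output that is a Softmax/Sigmoid result wrapped in Identity ops is still recognized, so the loss is computed from the logits directly instead of re-deriving it from clipped probabilities.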

PiperOrigin-RevId: 284448871
Change-Id: Ic0c6214ba78abaf348c23abdcb9e2543e2d08a38
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 411dcca..7122d6e 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -4440,6 +4440,12 @@
   return nn.softsign(x)
 
 
+def _backtrack_identity(tensor):
+  while tensor.op.type == 'Identity':
+    tensor = tensor.op.inputs[0]
+  return tensor
+
+
 @keras_export('keras.backend.categorical_crossentropy')
 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
   """Categorical crossentropy between an output tensor and a target tensor.
@@ -4484,24 +4490,28 @@
   dtype=float32)
 
   """
-  if not from_logits:
-    if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
-        output.op.type != 'Softmax'):
-      # scale preds so that the class probas of each sample sum to 1
-      output = output / math_ops.reduce_sum(output, axis, True)
-      # Compute cross entropy from probabilities.
-      epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
-      output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
-      return -math_ops.reduce_sum(target * math_ops.log(output), axis)
-    else:
+  if from_logits:
+    return nn.softmax_cross_entropy_with_logits_v2(
+        labels=target, logits=output, axis=axis)
+
+  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
+    output = _backtrack_identity(output)
+    if output.op.type == 'Softmax':
       # When softmax activation function is used for output operation, we
       # use logits from the softmax function directly to compute loss in order
       # to prevent collapsing zero when training.
       # See b/117284466
       assert len(output.op.inputs) == 1
       output = output.op.inputs[0]
-  return nn.softmax_cross_entropy_with_logits_v2(
-      labels=target, logits=output, axis=axis)
+      return nn.softmax_cross_entropy_with_logits_v2(
+          labels=target, logits=output, axis=axis)
+
+  # scale preds so that the class probas of each sample sum to 1
+  output = output / math_ops.reduce_sum(output, axis, True)
+  # Compute cross entropy from probabilities.
+  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
+  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
+  return -math_ops.reduce_sum(target * math_ops.log(output), axis)
 
 
 @keras_export('keras.backend.sparse_categorical_crossentropy')
@@ -4525,19 +4535,22 @@
   Raises:
       ValueError: if `axis` is neither -1 nor one of the axes of `output`.
   """
-  if not from_logits:
-    if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
-        output.op.type != 'Softmax'):
-      epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
-      output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
-      output = math_ops.log(output)
-    else:
+  if not from_logits and not isinstance(
+      output, (ops.EagerTensor, variables_module.Variable)):
+    output = _backtrack_identity(output)
+    if output.op.type == 'Softmax':
       # When softmax activation function is used for output operation, we
       # use logits from the softmax function directly to compute loss in order
       # to prevent collapsing zero when training.
       # See b/117284466
       assert len(output.op.inputs) == 1
       output = output.op.inputs[0]
+      from_logits = True
+
+  if not from_logits:
+    epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
+    output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
+    output = math_ops.log(output)
 
   if isinstance(output.shape, (tuple, list)):
     output_rank = len(output.shape)
@@ -4596,23 +4609,26 @@
   Returns:
       A tensor.
   """
-  if not from_logits:
-    if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
-        output.op.type != 'Sigmoid'):
-      epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
-      output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
+  if from_logits:
+    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
-      # Compute cross entropy from probabilities.
-      bce = target * math_ops.log(output + epsilon())
-      bce += (1 - target) * math_ops.log(1 - output + epsilon())
-      return -bce
-    else:
+  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
+    output = _backtrack_identity(output)
+    if output.op.type == 'Sigmoid':
       # When sigmoid activation function is used for output operation, we
       # use logits from the sigmoid function directly to compute loss in order
       # to prevent collapsing zero when training.
       assert len(output.op.inputs) == 1
       output = output.op.inputs[0]
-  return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+      return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+
+  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
+  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
+
+  # Compute cross entropy from probabilities.
+  bce = target * math_ops.log(output + epsilon())
+  bce += (1 - target) * math_ops.log(1 - output + epsilon())
+  return -bce
 
 
 @keras_export('keras.backend.sigmoid')
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 9b6f923..8d8d24f 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -1579,6 +1579,15 @@
 class BackendCrossEntropyLossesTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
+  def test_binary_crossentropy_with_sigmoid(self):
+    t = keras.backend.constant([[0, 1, 0]])
+    logits = keras.backend.constant([[8., 1., 1.]])
+    p = keras.backend.sigmoid(logits)
+    p = array_ops.identity(array_ops.identity(p))
+    result = self.evaluate(keras.backend.binary_crossentropy(t, p))
+    self.assertArrayNear(result[0], [8., 0.313, 1.313], 1e-3)
+
+  @test_util.run_in_graph_and_eager_modes
   def test_categorical_crossentropy_loss(self):
     t = keras.backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
 
@@ -1638,6 +1647,15 @@
     self.assertArrayNear(result, [.002, .003, .036], 1e-3)
 
   @test_util.run_in_graph_and_eager_modes
+  def test_categorical_crossentropy_with_softmax(self):
+    t = keras.backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+    logits = keras.backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p = keras.backend.softmax(logits)
+    p = array_ops.identity(array_ops.identity(p))
+    result = self.evaluate(keras.backend.categorical_crossentropy(t, p))
+    self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3)
+
+  @test_util.run_in_graph_and_eager_modes
   def test_sparse_categorical_crossentropy_loss(self):
     t = keras.backend.constant([0, 1, 2])
 
@@ -1702,6 +1720,15 @@
 
       _ = f([t_val, p_val])
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_sparse_categorical_crossentropy_with_softmax(self):
+    t = keras.backend.constant([0, 1, 2])
+    logits = keras.backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p = keras.backend.softmax(logits)
+    p = array_ops.identity(array_ops.identity(p))
+    result = self.evaluate(keras.backend.sparse_categorical_crossentropy(t, p))
+    self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3)
+
 
 @test_util.run_all_in_graph_and_eager_modes
 @test_util.with_control_flow_v2