Removing identity backtracking from entropy losses.

PiperOrigin-RevId: 316956157
Change-Id: I91130052e29e69ae131fe8aad0bbd1d4d42b00f1
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 391c695..9330425 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -4637,12 +4637,6 @@
   return nn.softsign(x)
 
 
-def _backtrack_identity(tensor):
-  while tensor.op.type == 'Identity':
-    tensor = tensor.op.inputs[0]
-  return tensor
-
-
 @keras_export('keras.backend.categorical_crossentropy')
 @dispatch.add_dispatch_support
 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
@@ -4695,17 +4689,16 @@
     return nn.softmax_cross_entropy_with_logits_v2(
         labels=target, logits=output, axis=axis)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.softmax_cross_entropy_with_logits_v2(
-          labels=target, logits=output, axis=axis)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When softmax activation function is used for output operation, we
+    # use logits from the softmax function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.softmax_cross_entropy_with_logits_v2(
+        labels=target, logits=output, axis=axis)
 
   # scale preds so that the class probas of each sample sum to 1
   output = output / math_ops.reduce_sum(output, axis, True)
@@ -4740,17 +4733,16 @@
   target = ops.convert_to_tensor_v2(target)
   output = ops.convert_to_tensor_v2(output)
 
-  if not from_logits and not isinstance(
-      output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      from_logits = True
+  if (not from_logits and
+      not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When softmax activation function is used for output operation, we
+    # use logits from the softmax function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    from_logits = True
 
   if not from_logits:
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
@@ -4821,15 +4813,14 @@
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Sigmoid':
-      # When sigmoid activation function is used for output operation, we
-      # use logits from the sigmoid function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Sigmoid'):
+    # When sigmoid activation function is used for output operation, we
+    # use logits from the sigmoid function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
   epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
   output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
diff --git a/tensorflow/python/keras/tests/add_loss_correctness_test.py b/tensorflow/python/keras/tests/add_loss_correctness_test.py
index a19eec7..f99b285 100644
--- a/tensorflow/python/keras/tests/add_loss_correctness_test.py
+++ b/tensorflow/python/keras/tests/add_loss_correctness_test.py
@@ -34,6 +34,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 MAE = losses.MeanAbsoluteError
@@ -450,6 +451,19 @@
           'Expected a symbolic Tensors or a callable for the loss value'):
         model.add_loss(model.weights[0])
 
+  @keras_parameterized.run_all_keras_modes
+  def test_add_entropy_loss_on_functional_model(self):
+    inputs = Input(shape=(1,))
+    targets = Input(shape=(1,))
+    outputs = testing_utils.Bias()(inputs)
+    model = Model([inputs, targets], outputs)
+    model.add_loss(losses.binary_crossentropy(targets, outputs))
+    model.compile('sgd', run_eagerly=testing_utils.should_run_eagerly())
+    with test.mock.patch.object(logging, 'warning') as mock_log:
+      model.fit([self.x, self.y], batch_size=3, epochs=5)
+      self.assertNotIn('Gradients do not exist for variables',
+                       str(mock_log.call_args))
+
 
 if __name__ == '__main__':
   test.main()