Remove identity backtracking from the Keras backend entropy losses.
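
Instead of walking backwards through Identity ops to find a Softmax or
Sigmoid producer, the entropy losses now check the type of the output op
itself. The regression test added below verifies that fitting a functional
model with an added cross-entropy loss no longer triggers "Gradients do not
exist for variables" warnings.

For reference, a standalone sketch (not the Keras source) of what the
removed helper did: it unwrapped tf.identity so that an identity-wrapped
softmax output was still recognized and its logits recovered.

    import tensorflow as tf

    def _backtrack_identity(tensor):  # copy of the helper removed below
      while tensor.op.type == 'Identity':
        tensor = tensor.op.inputs[0]
      return tensor

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(tf.float32, [None, 3])
      probs = tf.identity(tf.nn.softmax(x))
      print(probs.op.type)                       # Identity
      print(_backtrack_identity(probs).op.type)  # Softmax

After this change, only an output whose producing op is literally Softmax or
Sigmoid takes the logits fast path; identity-wrapped outputs fall through to
the clipped-probability computation, so the loss stays defined on the tensor
that was actually passed in.
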
PiperOrigin-RevId: 316956157
Change-Id: I91130052e29e69ae131fe8aad0bbd1d4d42b00f1
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 391c695..9330425 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -4637,12 +4637,6 @@
   return nn.softsign(x)
 
 
-def _backtrack_identity(tensor):
-  while tensor.op.type == 'Identity':
-    tensor = tensor.op.inputs[0]
-  return tensor
-
-
 @keras_export('keras.backend.categorical_crossentropy')
 @dispatch.add_dispatch_support
 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
@@ -4695,17 +4689,16 @@
     return nn.softmax_cross_entropy_with_logits_v2(
         labels=target, logits=output, axis=axis)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.softmax_cross_entropy_with_logits_v2(
-          labels=target, logits=output, axis=axis)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When the softmax activation function is used for the output operation,
+    # we compute the loss directly from the logits feeding the softmax in
+    # order to prevent the output from collapsing to zero during training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.softmax_cross_entropy_with_logits_v2(
+        labels=target, logits=output, axis=axis)
 
   # scale preds so that the class probas of each sample sum to 1
   output = output / math_ops.reduce_sum(output, axis, True)
@@ -4740,17 +4733,16 @@
   target = ops.convert_to_tensor_v2(target)
   output = ops.convert_to_tensor_v2(output)
 
-  if not from_logits and not isinstance(
-      output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      from_logits = True
+  if (not from_logits and
+      not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When the softmax activation function is used for the output operation,
+    # we compute the loss directly from the logits feeding the softmax in
+    # order to prevent the output from collapsing to zero during training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    from_logits = True
 
   if not from_logits:
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
@@ -4821,15 +4813,14 @@
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Sigmoid':
-      # When sigmoid activation function is used for output operation, we
-      # use logits from the sigmoid function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Sigmoid'):
+    # When the sigmoid activation function is used for the output operation,
+    # we compute the loss directly from the logits feeding the sigmoid in
+    # order to prevent the output from collapsing to zero during training.
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
   epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
   output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
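
Note on the fallback path: outputs that are not produced directly by a
Softmax/Sigmoid op (including identity-wrapped ones, which the removed
backtracking used to unwrap) now always take the probability branch. A
simplified standalone sketch of that branch for categorical_crossentropy
(hypothetical helper name; eps stands in for backend.epsilon()):

    import tensorflow as tf

    def categorical_crossentropy_from_probs(target, output, eps=1e-7, axis=-1):
      # Renormalize so the class probabilities of each sample sum to 1,
      # clip away exact 0/1 to keep log() finite, then take cross entropy.
      output = output / tf.reduce_sum(output, axis, keepdims=True)
      output = tf.clip_by_value(output, eps, 1.0 - eps)
      return -tf.reduce_sum(target * tf.math.log(output), axis)

Keeping the loss defined on the tensor actually passed in, rather than on a
backtracked ancestor, preserves the gradient path that the new test below
exercises.
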
diff --git a/tensorflow/python/keras/tests/add_loss_correctness_test.py b/tensorflow/python/keras/tests/add_loss_correctness_test.py
index a19eec7..f99b285 100644
--- a/tensorflow/python/keras/tests/add_loss_correctness_test.py
+++ b/tensorflow/python/keras/tests/add_loss_correctness_test.py
@@ -34,6 +34,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 MAE = losses.MeanAbsoluteError
@@ -450,6 +451,19 @@
         'Expected a symbolic Tensors or a callable for the loss value'):
       model.add_loss(model.weights[0])
 
+  @keras_parameterized.run_all_keras_modes
+  def test_add_entropy_loss_on_functional_model(self):
+    inputs = Input(shape=(1,))
+    targets = Input(shape=(1,))
+    outputs = testing_utils.Bias()(inputs)
+    model = Model([inputs, targets], outputs)
+    model.add_loss(losses.binary_crossentropy(targets, outputs))
+    model.compile('sgd', run_eagerly=testing_utils.should_run_eagerly())
+    with test.mock.patch.object(logging, 'warning') as mock_log:
+      model.fit([self.x, self.y], batch_size=3, epochs=5)
+      self.assertNotIn('Gradients do not exist for variables',
+                       str(mock_log.call_args))
+
 
 if __name__ == '__main__':
   test.main()
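
For context on the warning the new test asserts against: Keras logs
"Gradients do not exist for variables" when the gradient of the loss with
respect to a trainable variable comes back as None, which is what happens
when a loss ends up wired to a tensor upstream of the one the variable
feeds. A minimal, hypothetical illustration of that failure mode (not this
bug's exact repro):

    import tensorflow as tf

    v = tf.Variable(0.0)
    with tf.GradientTape() as tape:
      x = tf.constant(1.0)
      y = x + v      # the variable contributes to y
      loss = x * x   # loss built from x instead of y: v is disconnected
    print(tape.gradient(loss, v))  # None -> Keras warns during fit()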