Add axis argument
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 645ed5b..70a4cf3 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -1609,7 +1609,8 @@
def categorical_crossentropy(y_true,
y_pred,
from_logits=False,
- label_smoothing=0):
+ label_smoothing=0,
+ axis=-1):
"""Computes the categorical crossentropy loss.
Standalone usage:
@@ -1629,6 +1630,8 @@
label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
example, if `0.1`, use `0.1 / num_classes` for non-target labels
and `0.9 + 0.1 / num_classes` for target labels.
+ axis: (Optional) Defaults to -1. The dimension along which the entropy is
+ computed.
Returns:
Categorical crossentropy loss value.
@@ -1644,7 +1647,7 @@
y_true = control_flow_util.smart_cond(
label_smoothing, _smooth_labels, lambda: y_true)
- return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
+ return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits, axis=axis)
@dispatch.dispatch_for_types(categorical_crossentropy,
@@ -1707,7 +1710,7 @@
@keras_export('keras.metrics.binary_crossentropy',
'keras.losses.binary_crossentropy')
@dispatch.add_dispatch_support
-def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
+def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0, axis=-1):
"""Computes the binary crossentropy loss.
Standalone usage:
@@ -1727,6 +1730,8 @@
label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by
squeezing them towards 0.5 That is, using `1. - 0.5 * label_smoothing`
for the target class and `0.5 * label_smoothing` for the non-target class.
+ axis: (Optional) Defaults to -1. The dimension along which the mean is
+ computed.
Returns:
Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
@@ -1742,7 +1747,7 @@
y_true = control_flow_util.smart_cond(
label_smoothing, _smooth_labels, lambda: y_true)
return K.mean(
- K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
+ K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=axis)
@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)