Add axis argument
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 645ed5b..70a4cf3 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -1609,7 +1609,8 @@
 def categorical_crossentropy(y_true,
                              y_pred,
                              from_logits=False,
-                             label_smoothing=0):
+                             label_smoothing=0,
+                             axis=-1):
   """Computes the categorical crossentropy loss.
 
   Standalone usage:
@@ -1629,6 +1630,8 @@
     label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
       example, if `0.1`, use `0.1 / num_classes` for non-target labels
       and `0.9 + 0.1 / num_classes` for target labels.
+    axis: (Optional) Defaults to -1. The dimension along which the entropy is
+      computed.
 
   Returns:
     Categorical crossentropy loss value.
@@ -1644,7 +1647,7 @@
 
   y_true = control_flow_util.smart_cond(
       label_smoothing, _smooth_labels, lambda: y_true)
-  return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
+  return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits, axis=axis)
 
 
 @dispatch.dispatch_for_types(categorical_crossentropy,
@@ -1707,7 +1710,7 @@
 @keras_export('keras.metrics.binary_crossentropy',
               'keras.losses.binary_crossentropy')
 @dispatch.add_dispatch_support
-def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
+def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0, axis=-1):
   """Computes the binary crossentropy loss.
 
   Standalone usage:
@@ -1727,6 +1730,8 @@
     label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by 
       squeezing them towards 0.5 That is, using `1. - 0.5 * label_smoothing`
       for the target class and `0.5 * label_smoothing` for the non-target class.
+    axis: (Optional) Defaults to -1. The dimension along which the mean is
+      computed.
 
   Returns:
     Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
@@ -1742,7 +1747,7 @@
   y_true = control_flow_util.smart_cond(
       label_smoothing, _smooth_labels, lambda: y_true)
   return K.mean(
-      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
+      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=axis)
 
 
 @dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)