Add RaggedTensor support for sparse_categorical_crossentropy.
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 98de43f..c70997f 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -1215,13 +1215,15 @@
   return backend.mean(math_ops.squared_difference(y_pred, y_true), axis=-1)
 
 
-def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred):
+def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred, y_pred_extra_dim=False):
   """Apply a loss function on a per batch basis.
 
   Args:
     loss_fn: The loss function
     y_true: truth values (RaggedTensor)
     y_pred: predicted values (RaggedTensor)
+    y_pred_extra_dim: whether y_pred has an additional dimension compared to
+      y_true
 
   Returns:
     Loss-function result. A dense tensor if the output has a single dimension
@@ -1244,17 +1246,35 @@
     ])
 
   def _convert_to_dense(inputs):
-    return tuple(rt.to_tensor() for rt in inputs)
+    return tuple(
+        rt.to_tensor() if isinstance(rt, ragged_tensor.RaggedTensor) else rt
+        for rt in inputs)
 
-  def _wrapper(inputs):
+  def _call_loss(inputs, ragged_output):
+    """ Adapt the result to ragged or dense tensor according to the expected
+        output type. This is done so that all the return values of the map
+        operation have the same type.
+    """
+    r = loss_fn(*inputs)
+    if ragged_output and not isinstance(r, ragged_tensor.RaggedTensor):
+      r = ragged_tensor.RaggedTensor.from_tensor(r)
+    elif not ragged_output and isinstance(r, ragged_tensor.RaggedTensor):
+      r = r.to_tensor()
+    return r
+
+  def _wrapper(inputs, ragged_output):
     _, y_pred = inputs
     if isinstance(y_pred, ragged_tensor.RaggedTensor):
       return control_flow_ops.cond(
           rt_is_equiv_dense(y_pred),
-          lambda: loss_fn(*_convert_to_dense(inputs)), lambda: loss_fn(*inputs))
+          lambda: _call_loss(_convert_to_dense(inputs), ragged_output),
+          lambda: _call_loss(inputs, ragged_output))
 
     return loss_fn(*inputs)
 
+  if not isinstance(y_true, ragged_tensor.RaggedTensor):
+    return loss_fn(y_true, y_pred.to_tensor())
+
   lshape = y_pred.shape.as_list()[1:-1]
   if len(lshape) > 0:
     spec = ragged_tensor.RaggedTensorSpec(shape=lshape, dtype=y_pred.dtype)
@@ -1262,9 +1282,14 @@
     spec = tensor_spec.TensorSpec(shape=[], dtype=y_pred.dtype)
 
   nested_splits_list = [rt.nested_row_splits for rt in (y_true, y_pred)]
+  if y_pred_extra_dim:
+    nested_splits_list[1] = nested_splits_list[1][:-1]
+
+  map_fn = functools.partial(_wrapper, ragged_output=len(lshape) > 1)
+
   assertion_list = ragged_util.assert_splits_match(nested_splits_list)
   with ops.control_dependencies(assertion_list):
-    return ragged_map_ops.map_fn(_wrapper, elems=(y_true, y_pred), dtype=spec)
+    return ragged_map_ops.map_fn(map_fn, elems=(y_true, y_pred), dtype=spec)
 
 
 @dispatch.dispatch_for_types(mean_squared_error, ragged_tensor.RaggedTensor)
@@ -1713,6 +1738,29 @@
       y_true, y_pred, from_logits=from_logits, axis=axis)
 
 
+@dispatch.dispatch_for_types(sparse_categorical_crossentropy,
+                             ragged_tensor.RaggedTensor)
+def _ragged_tensor_sparse_categorical_crossentropy(
+    y_true, y_pred, from_logits=False, axis=-1):
+  """ Implements support for handling RaggedTensors.
+
+      Expected y_pred shape: (batch, sequence_len, n_classes) with sequence_len
+      being variable per batch.
+      Return shape: (batch, sequence_len).
+
+      When used by SparseCategoricalCrossentropy() with the default reduction
+      (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
+      number of elements independent of the batch. E.g. if the RaggedTensor
+      has 2 batches with [2, 1] values respectively, the resulting loss is
+      the sum of the individual loss values divided by 3.
+  """
+  fn = functools.partial(
+      sparse_categorical_crossentropy,
+      from_logits=from_logits,
+      axis=axis)
+  return _ragged_tensor_apply_loss(fn, y_true, y_pred, y_pred_extra_dim=True)
+
+
 @keras_export('keras.metrics.binary_crossentropy',
               'keras.losses.binary_crossentropy')
 @dispatch.add_dispatch_support
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 03716d0..c8e6f72 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -1150,6 +1150,35 @@
     loss = cce_obj(y_true, y_pred, sample_weight=2.3)
     self.assertAlmostEqual(self.evaluate(loss), .7449, 3)
 
+  def test_ragged_tensors(self):
+    cce_obj = losses.SparseCategoricalCrossentropy()
+    y_true = ragged_factory_ops.constant([[0, 1], [2]])
+    y_pred = ragged_factory_ops.constant(
+        [[[.9, .05, .05], [.5, .89, .6]], [[.05, .01, .94]]],
+        dtype=dtypes.float32)
+    # batch losses [[0.1054, 0.8047], [0.0619]]
+    sample_weight = constant_op.constant([[1.2], [3.4]], shape=(2, 1))
+    loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
+    # sum([0.1054, 0.8047, 0.0619]) / 3
+    self.assertAlmostEqual(self.evaluate(loss), 0.4341, 3)
+
+    # Test with logits.
+    logits = ragged_factory_ops.constant([[[8., 1., 1.], [0., 9., 1.]],
+                                          [[2., 3., 5.]]])
+    cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True)
+    # batch losses [[0.0018, 0.0004], [0.1698]]
+    loss = cce_obj(y_true, logits, sample_weight=sample_weight)
+    self.assertAlmostEqual(self.evaluate(loss), 0.1934, 3)
+
+  def test_ragged_tensors_3d(self):
+    # shape [2, 1, None]
+    y_true = ragged_factory_ops.constant([[[1, 1]], [[0]]])
+    # shape [2, 1, None, 2]
+    y_pred = ragged_factory_ops.constant(
+        [[[[0.1, 0.9], [0.1, 0.9]]], [[[0.9, 0.1]]]])
+    cce_obj = losses.SparseCategoricalCrossentropy()
+    loss = cce_obj(y_true, y_pred)
+    self.assertAlmostEqual(self.evaluate(loss), 0.1054, 3)
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class HingeTest(test.TestCase):