RaggedTensor support for sparse_categorical_crossentropy.
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 98de43f..c70997f 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -1215,13 +1215,15 @@
return backend.mean(math_ops.squared_difference(y_pred, y_true), axis=-1)
-def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred):
+def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred, y_pred_extra_dim=False):
"""Apply a loss function on a per batch basis.
Args:
loss_fn: The loss function
y_true: truth values (RaggedTensor)
y_pred: predicted values (RaggedTensor)
+ y_pred_extra_dim: whether y_pred has an additional dimension compared to
+ y_true
Returns:
Loss-function result. A dense tensor if the output has a single dimension
@@ -1244,17 +1246,35 @@
])
def _convert_to_dense(inputs):
- return tuple(rt.to_tensor() for rt in inputs)
+ return tuple(
+ rt.to_tensor() if isinstance(rt, ragged_tensor.RaggedTensor) else rt
+ for rt in inputs)
- def _wrapper(inputs):
+ def _call_loss(inputs, ragged_output):
+ """ Adapt the result to ragged or dense tensor according to the expected
+ output type. This is done so that all the return values of the map
+ operation have the same type.
+ """
+ r = loss_fn(*inputs)
+ if ragged_output and not isinstance(r, ragged_tensor.RaggedTensor):
+ r = ragged_tensor.RaggedTensor.from_tensor(r)
+ elif not ragged_output and isinstance(r, ragged_tensor.RaggedTensor):
+ r = r.to_tensor()
+ return r
+
+ def _wrapper(inputs, ragged_output):
_, y_pred = inputs
if isinstance(y_pred, ragged_tensor.RaggedTensor):
return control_flow_ops.cond(
rt_is_equiv_dense(y_pred),
- lambda: loss_fn(*_convert_to_dense(inputs)), lambda: loss_fn(*inputs))
+ lambda: _call_loss(_convert_to_dense(inputs), ragged_output),
+ lambda: _call_loss(inputs, ragged_output))
return loss_fn(*inputs)
+ if not isinstance(y_true, ragged_tensor.RaggedTensor):
+ return loss_fn(y_true, y_pred.to_tensor())
+
lshape = y_pred.shape.as_list()[1:-1]
if len(lshape) > 0:
spec = ragged_tensor.RaggedTensorSpec(shape=lshape, dtype=y_pred.dtype)
@@ -1262,9 +1282,14 @@
spec = tensor_spec.TensorSpec(shape=[], dtype=y_pred.dtype)
nested_splits_list = [rt.nested_row_splits for rt in (y_true, y_pred)]
+ if y_pred_extra_dim:
+ nested_splits_list[1] = nested_splits_list[1][:-1]
+
+ map_fn = functools.partial(_wrapper, ragged_output=len(lshape) > 1)
+
assertion_list = ragged_util.assert_splits_match(nested_splits_list)
with ops.control_dependencies(assertion_list):
- return ragged_map_ops.map_fn(_wrapper, elems=(y_true, y_pred), dtype=spec)
+ return ragged_map_ops.map_fn(map_fn, elems=(y_true, y_pred), dtype=spec)
@dispatch.dispatch_for_types(mean_squared_error, ragged_tensor.RaggedTensor)
@@ -1713,6 +1738,29 @@
y_true, y_pred, from_logits=from_logits, axis=axis)
+@dispatch.dispatch_for_types(sparse_categorical_crossentropy,
+ ragged_tensor.RaggedTensor)
+def _ragged_tensor_sparse_categorical_crossentropy(
+ y_true, y_pred, from_logits=False, axis=-1):
+ """ Implements support for handling RaggedTensors.
+
+ Expected y_pred shape: (batch, sequence_len, n_classes) with sequence_len
+ being variable per batch.
+ Return shape: (batch, sequence_len).
+
+ When used by SparseCategoricalCrossentropy() with the default reduction
+ (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
+ number of elements independent of the batch. E.g. if the RaggedTensor
+ has 2 batches with [2, 1] values respectively, the resulting loss is
+ the sum of the individual loss values divided by 3.
+ """
+ fn = functools.partial(
+ sparse_categorical_crossentropy,
+ from_logits=from_logits,
+ axis=axis)
+ return _ragged_tensor_apply_loss(fn, y_true, y_pred, y_pred_extra_dim=True)
+
+
@keras_export('keras.metrics.binary_crossentropy',
'keras.losses.binary_crossentropy')
@dispatch.add_dispatch_support
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 03716d0..c8e6f72 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -1150,6 +1150,35 @@
loss = cce_obj(y_true, y_pred, sample_weight=2.3)
self.assertAlmostEqual(self.evaluate(loss), .7449, 3)
+ def test_ragged_tensors(self):
+ cce_obj = losses.SparseCategoricalCrossentropy()
+ y_true = ragged_factory_ops.constant([[0, 1], [2]])
+ y_pred = ragged_factory_ops.constant(
+ [[[.9, .05, .05], [.5, .89, .6]], [[.05, .01, .94]]],
+ dtype=dtypes.float32)
+ # batch losses [[0.1054, 0.8047], [0.0619]]
+ sample_weight = constant_op.constant([[1.2], [3.4]], shape=(2, 1))
+ loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
+ # sum([0.1054, 0.8047, 0.0619]) / 3
+ self.assertAlmostEqual(self.evaluate(loss), 0.4341, 3)
+
+ # Test with logits.
+ logits = ragged_factory_ops.constant([[[8., 1., 1.], [0., 9., 1.]],
+ [[2., 3., 5.]]])
+ cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True)
+ # batch losses [[0.0018, 0.0004], [0.1698]]
+ loss = cce_obj(y_true, logits, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 0.1934, 3)
+
+ def test_ragged_tensors_3d(self):
+ # shape [2, 1, None]
+ y_true = ragged_factory_ops.constant([[[1, 1]], [[0]]])
+ # shape [2, 1, None, 2]
+ y_pred = ragged_factory_ops.constant(
+ [[[[0.1, 0.9], [0.1, 0.9]]], [[[0.9, 0.1]]]])
+ cce_obj = losses.SparseCategoricalCrossentropy()
+ loss = cce_obj(y_true, y_pred)
+ self.assertAlmostEqual(self.evaluate(loss), 0.1054, 3)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class HingeTest(test.TestCase):