Rollforward of 0e074ae064c441185e09b1fe4b0ee87676f10266.

The rolled-back CL checked the `dtype` attribute of the input logits, but the logits may also be a Python list (or anything else convertible to a tensor), which has no `dtype` attribute. The logits are now explicitly converted to a tensor before the dtype check.
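
A minimal sketch of the failure mode (hypothetical repro; the values are illustrative and not taken from this CL):

    import tensorflow as tf

    logits = [[[0.1, 0.9], [0.8, 0.2]]]  # a plain Python list has no `dtype`
    # Before this change, the fp16 handling accessed `logits.dtype` directly,
    # which raises AttributeError for list inputs.
    logits = tf.convert_to_tensor(logits)  # the fix: convert to a tensor first
    assert logits.dtype == tf.float32  # the dtype check is now well-defined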
END_PUBLIC
PiperOrigin-RevId: 426035926
Change-Id: I5df63cebe034060ed569f29544b5caf48dafe77c
diff --git a/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py
index 650bc95..b1fb7bf 100644
--- a/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py
+++ b/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py
@@ -936,18 +936,20 @@
         [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
 
 
-def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu):
+def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu,
+                 sparse=True):
   with test_util.device(use_gpu=use_gpu):
-    sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
+    if sparse:
+      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
     with backprop.GradientTape() as t:
       t.watch(logits)
       ref_loss = ctc_ops.ctc_loss_v3(
-          labels=sparse_labels,
+          labels=labels,
           logits=logits,
           label_length=label_length,
           logit_length=logit_length,
           blank_index=0)
-    ref_grad = t.gradient(ref_loss, [logits])
+    ref_grad = t.gradient(ref_loss, logits)
     return ref_loss, ref_grad
 
 
@@ -1000,6 +1002,78 @@
     self.assertAllClose(loss, ref_loss, atol=1e-6)
     self.assertAllClose(grad, ref_grad, atol=2e-6)
 
+  @parameterized.parameters([False, True])
+  def testCtcLossFp16(self, sparse_labels):
+    batch_size = 8
+    num_labels = 6
+    max_label_length = 5
+    num_frames = 12
+
+    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
+    labels = ops.convert_to_tensor(labels, dtypes.int64)
+    fp16_logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
+    fp16_logits = ops.convert_to_tensor(fp16_logits, dtypes.float16)
+    label_length = np.random.randint(2, max_label_length, [batch_size])
+    label_length = ops.convert_to_tensor(label_length, dtypes.int64)
+
+    label_mask = array_ops.sequence_mask(
+        label_length, maxlen=max_label_length, dtype=label_length.dtype)
+    labels *= label_mask
+    logit_length = [num_frames] * batch_size
+
+    fp16_loss, fp16_grad = _ctc_loss_v3(
+        labels, fp16_logits, label_length, logit_length, use_gpu=True,
+        sparse=sparse_labels)
+    fp32_loss, fp32_grad = _ctc_loss_v3(
+        labels, math_ops.cast(fp16_logits, dtypes.float32), label_length,
+        logit_length, use_gpu=True, sparse=sparse_labels)
+
+    self.assertEqual(fp16_loss.dtype, dtypes.float16)
+    self.assertEqual(fp16_grad.dtype, dtypes.float16)
+    self.assertAllClose(
+        self.evaluate(fp16_loss),
+        self.evaluate(math_ops.cast(fp32_loss, dtypes.float16))
+    )
+    self.assertAllClose(
+        self.evaluate(fp16_grad),
+        self.evaluate(math_ops.cast(fp32_grad, dtypes.float16))
+    )
+
+  @parameterized.parameters([False, True])
+  def testCtcLossWithListLogits(self, sparse_labels):
+    batch_size = 8
+    num_labels = 6
+    max_label_length = 5
+    num_frames = 12
+
+    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
+    labels = ops.convert_to_tensor(labels, dtypes.int64)
+    logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
+    label_length = np.random.randint(2, max_label_length, [batch_size])
+    label_length = ops.convert_to_tensor(label_length, dtypes.int64)
+
+    label_mask = array_ops.sequence_mask(
+        label_length, maxlen=max_label_length, dtype=label_length.dtype)
+    labels *= label_mask
+    logit_length = [num_frames] * batch_size
+    if sparse_labels:
+      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
+
+    list_loss = ctc_ops.ctc_loss_v3(
+        labels=labels,
+        logits=logits.tolist(),
+        label_length=label_length,
+        logit_length=logit_length,
+        blank_index=0)
+    tensor_loss = ctc_ops.ctc_loss_v3(
+        labels=labels,
+        logits=ops.convert_to_tensor(logits, dtypes.float32),
+        label_length=label_length,
+        logit_length=logit_length,
+        blank_index=0)
+
+    self.assertAllClose(self.evaluate(list_loss), self.evaluate(tensor_loss))
+
   @test_util.run_v2_only
   def testCtcLossAlgorithmFallback(self):
     """Test if GPU CTC loss can fallback to the correct algorithm."""
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index a0f101b..7a56b4a 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -214,9 +214,15 @@
   # For internal calculations, we transpose to [time, batch, num_classes]
   inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
                                                   inputs)
+
+  inputs = ops.convert_to_tensor(inputs, name="logits")
   if not time_major:
     inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)
 
+  orig_dtype = inputs.dtype
+  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+    inputs = math_ops.cast(inputs, dtypes.float32)
+
   # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
   # blank index to be 0, but v1 views it as the last index.
   if use_cudnn:
@@ -233,6 +239,9 @@
         ctc_merge_repeated=ctc_merge_repeated,
         ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
 
+  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+    loss = math_ops.cast(loss, orig_dtype)
+
   return loss
 
 # pylint: disable=unused-argument
@@ -927,6 +936,8 @@
   if blank_index < 0:
     blank_index += _get_dim(logits, 2)
 
+  logits = ops.convert_to_tensor(logits, name="logits")
+
   params = {
       "labels": labels,
       "logits": logits,
@@ -1042,6 +1053,10 @@
   label_length = ops.convert_to_tensor(label_length, name="label_length")
   logit_length = ops.convert_to_tensor(logit_length, name="logit_length")
 
+  orig_dtype = logits.dtype
+  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+    logits = math_ops.cast(logits, dtypes.float32)
+
   if not logits_time_major:
     logits = array_ops.transpose(logits, perm=[1, 0, 2])
 
@@ -1093,7 +1108,10 @@
 
     return result[0], grad
 
-  return compute_ctc_loss(*args)
+  loss = compute_ctc_loss(*args)
+  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+    loss = math_ops.cast(loss, orig_dtype)
+  return loss
 
 
 @tf_export("nn.collapse_repeated")