Rollforward 0e074ae064c441185e09b1fe4b0ee87676f10266.

The original change was rolled back because it read the `dtype` attribute directly from the input logits, which fails when the logits are a list or any other non-tensor value that is merely convertible to a tensor. This rollforward explicitly converts the logits to a tensor before their dtype is inspected.
END_PUBLIC

PiperOrigin-RevId: 426035926
Change-Id: I5df63cebe034060ed569f29544b5caf48dafe77c
diff --git a/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py
index 650bc95..b1fb7bf 100644
--- a/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py
+++ b/tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py
@@ -936,18 +936,20 @@
           [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
 
 
-def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu):
+def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu,
+                 sparse=True):  # sparse=False feeds dense labels as-is
   with test_util.device(use_gpu=use_gpu):
-    sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
+    if sparse:  # convert only on request; otherwise `labels` is untouched
+      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
     with backprop.GradientTape() as t:
       t.watch(logits)
       ref_loss = ctc_ops.ctc_loss_v3(
-          labels=sparse_labels,
+          labels=labels,  # sparse or dense, depending on `sparse`
           logits=logits,
           label_length=label_length,
           logit_length=logit_length,
           blank_index=0)
-    ref_grad = t.gradient(ref_loss, [logits])
+    ref_grad = t.gradient(ref_loss, logits)  # bare source, not [logits]
     return ref_loss, ref_grad
 
 
@@ -1000,6 +1002,78 @@
     self.assertAllClose(loss, ref_loss, atol=1e-6)
     self.assertAllClose(grad, ref_grad, atol=2e-6)
 
+  @parameterized.parameters([False, True])
+  def testCtcLossFp16(self, sparse_labels):  # fp16 loss/grad vs fp32 run
+    batch_size = 8
+    num_labels = 6  # size of the logits' last axis; index 0 is the blank
+    max_label_length = 5
+    num_frames = 12
+
+    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
+    labels = ops.convert_to_tensor(labels, dtypes.int64)
+    fp16_logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
+    fp16_logits = ops.convert_to_tensor(fp16_logits, dtypes.float16)
+    label_length = np.random.randint(2, max_label_length, [batch_size])
+    label_length = ops.convert_to_tensor(label_length, dtypes.int64)
+
+    label_mask = array_ops.sequence_mask(
+        label_length, maxlen=max_label_length, dtype=label_length.dtype)
+    labels *= label_mask  # zero the entries at/after each label_length
+    logit_length = [num_frames] * batch_size  # same length for every example
+
+    fp16_loss, fp16_grad = _ctc_loss_v3(
+        labels, fp16_logits, label_length, logit_length, use_gpu=True,
+        sparse=sparse_labels)
+    fp32_loss, fp32_grad = _ctc_loss_v3(  # same inputs, logits cast to fp32
+        labels, math_ops.cast(fp16_logits, dtypes.float32), label_length,
+        logit_length, use_gpu=True, sparse=sparse_labels)
+
+    self.assertEqual(fp16_loss.dtype, dtypes.float16)  # dtype is preserved
+    self.assertEqual(fp16_grad.dtype, dtypes.float16)
+    self.assertAllClose(  # fp32 result is cast down before comparing
+        self.evaluate(fp16_loss),
+        self.evaluate(math_ops.cast(fp32_loss, dtypes.float16))
+    )
+    self.assertAllClose(
+        self.evaluate(fp16_grad),
+        self.evaluate(math_ops.cast(fp32_grad, dtypes.float16))
+    )
+
+  @parameterized.parameters([False, True])
+  def testCtcLossWithListLogits(self, sparse_labels):  # list vs tensor
+    batch_size = 8
+    num_labels = 6  # size of the logits' last axis; index 0 is the blank
+    max_label_length = 5
+    num_frames = 12
+
+    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
+    labels = ops.convert_to_tensor(labels, dtypes.int64)
+    logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
+    label_length = np.random.randint(2, max_label_length, [batch_size])
+    label_length = ops.convert_to_tensor(label_length, dtypes.int64)
+
+    label_mask = array_ops.sequence_mask(
+        label_length, maxlen=max_label_length, dtype=label_length.dtype)
+    labels *= label_mask  # zero the entries at/after each label_length
+    logit_length = [num_frames] * batch_size  # same length for every example
+    if sparse_labels:  # also run with labels converted to sparse form
+      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
+
+    list_loss = ctc_ops.ctc_loss_v3(
+        labels=labels,
+        logits=logits.tolist(),  # nested Python list, not a tensor
+        label_length=label_length,
+        logit_length=logit_length,
+        blank_index=0)
+    tensor_loss = ctc_ops.ctc_loss_v3(  # reference: same values as a tensor
+        labels=labels,
+        logits=ops.convert_to_tensor(logits, dtypes.float32),
+        label_length=label_length,
+        logit_length=logit_length,
+        blank_index=0)
+
+    self.assertAllClose(self.evaluate(list_loss), self.evaluate(tensor_loss))
+
   @test_util.run_v2_only
   def testCtcLossAlgorithmFallback(self):
     """Test if GPU CTC loss can fallback to the correct algorithm."""
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index a0f101b..7a56b4a 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -214,9 +214,15 @@
   # For internal calculations, we transpose to [time, batch, num_classes]
   inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
                                                   inputs)
+
+  inputs = ops.convert_to_tensor(inputs, name="logits")
   if not time_major:
     inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)
 
+  orig_dtype = inputs.dtype
+  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+    inputs = math_ops.cast(inputs, dtypes.float32)
+
   # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
   # blank index to be 0, but v1 views it as the last index.
   if use_cudnn:
@@ -233,6 +239,9 @@
       ctc_merge_repeated=ctc_merge_repeated,
       ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
 
+  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+    loss = math_ops.cast(loss, orig_dtype)
+
   return loss
 
 # pylint: disable=unused-argument
@@ -927,6 +936,8 @@
     if blank_index < 0:
       blank_index += _get_dim(logits, 2)
 
+    # NOTE(review): _get_dim(logits, 2) above sees raw logits; check lists.
+    logits = ops.convert_to_tensor(logits, name="logits")
     params = {
         "labels": labels,
         "logits": logits,
@@ -1042,6 +1053,10 @@
     label_length = ops.convert_to_tensor(label_length, name="label_length")
     logit_length = ops.convert_to_tensor(logit_length, name="logit_length")
 
+    orig_dtype = logits.dtype
+    if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+      logits = math_ops.cast(logits, dtypes.float32)
+
     if not logits_time_major:
       logits = array_ops.transpose(logits, perm=[1, 0, 2])
 
@@ -1093,7 +1108,10 @@
 
       return result[0], grad
 
-    return compute_ctc_loss(*args)
+    loss = compute_ctc_loss(*args)
+    if orig_dtype in (dtypes.float16, dtypes.bfloat16):
+      loss = math_ops.cast(loss, orig_dtype)
+    return loss
 
 
 @tf_export("nn.collapse_repeated")