Fix bug in shape inference of TensorListElementShape

The shape inference propagated without checking the result type here, but this
operation can return different element type from those of the shape.

PiperOrigin-RevId: 358902169
Change-Id: I592621a6d7533cb58f140ffb0dd37c6e62d19875
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 98fe02e..a5ac3a9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -685,9 +685,11 @@
     // CHECK-SAME: tensor<!tf.variant<tensor<16x1xf32>>>
     %tl_0 = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<2xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x1xf32>>>
     %tl_1 = "tf.TensorListPushBack"(%tl_0, %elem) : (tensor<!tf.variant<tensor<?x1xf32>>>, tensor<16x1xf32>) -> tensor<!tf.variant<tensor<?x1xf32>>>
-    %shape = "tf.TensorListElementShape"(%tl_1) : (tensor<!tf.variant<tensor<?x1xf32>>>) -> tensor<?xi32>
-    // CHECK: "tf._SomeOtherOp"(%[[ELEMENT_SHAPE]])
-    "tf._SomeOtherOp"(%shape) : (tensor<?xi32>) -> ()
+    %shape_32 = "tf.TensorListElementShape"(%tl_1) : (tensor<!tf.variant<tensor<?x1xf32>>>) -> tensor<?xi32>
+    %shape_64 = "tf.TensorListElementShape"(%tl_1) : (tensor<!tf.variant<tensor<?x1xf32>>>) -> tensor<?xi64>
+    // CHECK: %[[CAST:.*]] = "tf.Cast"(%[[ELEMENT_SHAPE]]){{.*}}: (tensor<2xi32>) -> tensor<2xi64>
+    // CHECK: "tf._SomeOtherOp"(%[[ELEMENT_SHAPE]], %[[CAST]])
+    "tf._SomeOtherOp"(%shape_32, %shape_64) : (tensor<?xi32>, tensor<?xi64>) -> ()
     return
   }
 
@@ -1146,4 +1148,4 @@
     %3 = "tf.SpaceToBatchND"(%arg0, %0, %2) {device = ""} : (tensor<1x192x256x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?x128xf32>
     return
   }
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 23ee5b6..a890c2b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -42,6 +42,7 @@
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
@@ -257,12 +258,23 @@
       continue;
     }
     // Refining the tensor list element type might change the output of
-    // TensorListElementShape which is expected tp be the originally assigned
+    // TensorListElementShape which is expected to be the originally assigned
     // shape to TensorList init ops. So replace it with the original element
     // shape value.
     if (auto tl_element_shape =
             dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
-      tl_element_shape.replaceAllUsesWith(initial_element_shape);
+      // If element types match, we can do a direct replacement.
+      if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
+          getElementTypeOrSelf(initial_element_shape.getType())) {
+        tl_element_shape.replaceAllUsesWith(initial_element_shape);
+      } else {
+        OpBuilder b(use.getOwner());
+        auto cast_op = b.create<TF::CastOp>(
+            use.getOwner()->getLoc(), tl_element_shape.getResult().getType(),
+            initial_element_shape,
+            /*truncate=*/b.getBoolAttr(false));
+        tl_element_shape.replaceAllUsesWith(cast_op.getResult());
+      }
       continue;
     }
     // Ignore ops that just consume a TensorList and do not output another
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index cbda91f..bc1bb54 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -987,7 +987,6 @@
 
     self._test_loop_fn(loop_fn, 2)
 
-  @test_util.disable_tfrt("b/180206304")
   def test_create_inside_and_read(self):
 
     def loop_fn(i):