Add support for tf_device.cluster ops to the shape inference pass.

This follows how shape inference is already handled for pass-through ops such as tf_device.launch: the cluster's result types are refined from the operand types of its tf_device.return terminator.
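
A minimal sketch of the intended behavior (illustrative IR only; the op and shapes here are made up and are not part of the test added below): the cluster's unranked result type is refined to match the value yielded by tf_device.return.

  // Before the pass: the cluster result is unranked.
  %0 = "tf_device.cluster"() ({
    %1 = "tf.Identity"(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    tf_device.return %1 : tensor<4x8xf32>
  }) : () -> tensor<*xf32>

  // After the pass: the result type matches the returned value. A tf.Cast
  // back to tensor<*xf32> is inserted only for uses that still require the
  // original type (see the new test case below).
  %0 = "tf_device.cluster"() ({
    %1 = "tf.Identity"(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    tf_device.return %1 : tensor<4x8xf32>
  }) : () -> tensor<4x8xf32>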

PiperOrigin-RevId: 328241534
Change-Id: I83cc618291f5dc84ba18f963b413d5117e98de8c
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 3e61357..26df602 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -530,6 +530,21 @@
     return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
   }
 
+  // CHECK-LABEL: infer_device_cluster
+  func @infer_device_cluster(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) {
+    %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+    %1 = "tf_device.cluster"() ({
+      %2 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32>
+      tf_device.return %2 : tensor<1x8x2xf32>
+    // CHECK: () -> tensor<1x8x2xf32>
+    }) : () -> tensor<*xf32>
+    // CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32>
+    // CHECK: (tensor<i32>, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>)
+    %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+    %4 = addf %1, %1 : tensor<*xf32>
+    return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
+  }
+
   // CHECK-LABEL: func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<1xi32>
   func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<*xi32> {
    // CHECK: %[[RESULT:.*]] = tensor_cast
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 88ad787..0042216 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -116,12 +116,12 @@
 
 // Returns if the shape inference pass supports an op outside the TF dialect.
 bool IsSupportedNonTFOp(Operation* op) {
-  return isa<ReturnOp, tf_device::ReturnOp, tf_executor::EnterOp,
-             tf_executor::ExitOp, tf_executor::FetchOp, tf_executor::GraphOp,
-             tf_executor::IslandOp, tf_executor::LoopCondOp,
-             tf_executor::MergeOp, tf_executor::NextIterationSinkOp,
-             tf_executor::SwitchNOp, tf_executor::SwitchOp,
-             tf_executor::YieldOp>(op);
+  return isa<ReturnOp, tf_device::ReturnOp, tf_device::ClusterOp,
+             tf_device::LaunchOp, tf_executor::EnterOp, tf_executor::ExitOp,
+             tf_executor::FetchOp, tf_executor::GraphOp, tf_executor::IslandOp,
+             tf_executor::LoopCondOp, tf_executor::MergeOp,
+             tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
+             tf_executor::SwitchOp, tf_executor::YieldOp>(op);
 }
 
 // Returns whether a cast back would need to be inserted, e.g., whether the
@@ -745,6 +745,11 @@
     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
                                             op->getResults());
   }
+  if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
+    auto terminator = cluster_op.GetBody().getTerminator();
+    return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
+                                            op->getResults());
+  }
   if (op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
     return RefineShapeForPassThroughOps(op);
   }