[MLIR][HLO] Add more tests for `rank-specialization-cluster` pass

PiperOrigin-RevId: 373343750
Change-Id: Ie07010935710749bc9f6eb8a0690a58b418e2139
diff --git a/tensorflow/compiler/mlir/hlo/tests/rank-specialization.mlir b/tensorflow/compiler/mlir/hlo/tests/rank-specialization.mlir
index 8e8b646..eb505cf 100644
--- a/tensorflow/compiler/mlir/hlo/tests/rank-specialization.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/rank-specialization.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s
+// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
 
 // CHECK-LABEL: @add_mul
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
@@ -17,3 +17,35 @@
       : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   return %1 : tensor<*xf32>
 }
+
+// -----
+
+// Unary MHLO operation.
+// CHECK-LABEL: @sqrt
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]])
+  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP0:.*]] = "mhlo.sqrt"(%[[ARG_]])
+  // CHECK:   %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]])
+  // CHECK:   %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]])
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
+  // CHECK: return %[[RES]]
+  %0 = "mhlo.sqrt"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
+  %1 = "mhlo.sqrt"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  %2 = "mhlo.sqrt"(%1) : (tensor<*xf32>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
+// -----
+
+// Don't cluster single ranked operation.
+// CHECK-LABEL: @sqrt_ranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
+func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
+  // CHECK-NOT: rank_specialization_cluster
+  %0 = "mhlo.sqrt"(%arg) : (tensor<3x?xf32>) -> tensor<3x?xf32>
+  %1 = "mhlo.sqrt"(%0) : (tensor<3x?xf32>) -> tensor<3x?xf32>
+  %2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32>
+  return %2 : tensor<3x?xf32>
+}