[JITRT] Add clustering policy for gather-scatter.

PiperOrigin-RevId: 459510174
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc
index 6703609..683ae39e 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc
@@ -748,6 +748,44 @@
   }
 };
 
+// -------------------------------------------------------------------------- //
+// Gather Operations.
+// -------------------------------------------------------------------------- //
+
+class GatherOpClusteringPolicy : public DefaultClusteringPolicy {
+ public:
+  GatherOpClusteringPolicy()
+      : DefaultClusteringPolicy(IsGatherOp(), ValueConstraint::kRank) {}
+
+ private:
+  std::function<bool(Operation* op)> IsGatherOp() {
+    return [](Operation* op) {
+      return mlir::isa<mlir::TF::GatherNdOp, mlir::TF::GatherV2Op,
+                       mlir::TF::GatherOp>(op);
+    };
+  }
+};
+
+// -------------------------------------------------------------------------- //
+// Scatter Operations.
+// -------------------------------------------------------------------------- //
+
+class ScatterOpClusteringPolicy : public DefaultClusteringPolicy {
+ public:
+  ScatterOpClusteringPolicy()
+      : DefaultClusteringPolicy(IsScatterOp(), ValueConstraint::kRank) {}
+
+ private:
+  std::function<bool(Operation* op)> IsScatterOp() {
+    return [](Operation* op) {
+      return mlir::isa<
+          mlir::TF::ScatterNdOp, mlir::TF::TensorScatterAddOp,
+          mlir::TF::TensorScatterMaxOp, mlir::TF::TensorScatterMinOp,
+          mlir::TF::TensorScatterSubOp, mlir::TF::TensorScatterUpdateOp>(op);
+    };
+  }
+};
+
 }  // namespace
 
 void populateTfJitRtClusteringPolicies(ClusteringPolicySet& policies,
@@ -780,6 +818,11 @@
                  SqueezeOpClusteringPolicy>();
   }
 
+  if (is_enabled(JitRtClusteringTier::kGatherScatter)) {
+    policies.Add<GatherOpClusteringPolicy,  //
+                 ScatterOpClusteringPolicy>();
+  }
+
   if (is_enabled(JitRtClusteringTier::kAll)) {
     policies.Add<BatchMatMulV2OpClusteringPolicy,  //
                  BroadcastToOpClusteringPolicy,    //
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h
index 10bbd31..9dda224 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h
@@ -30,8 +30,9 @@
 enum class JitRtClusteringTier : uint8_t {
   kCwise = 0x1,
   kTranspose = 0x2,
-  kMetadata = 0x4,    // shape, reshape, ...
-  kReductions = 0x8,  // all, any, min, max, mean, prod, sum
+  kMetadata = 0x4,        // shape, reshape, ...
+  kReductions = 0x8,      // all, any, min, max, mean, prod, sum
+  kGatherScatter = 0x10,  // gather, scatter, gather_v2,...
 
   // Only cwise operations (unary, binary, ternary).
   kTier0 = kCwise,