[JITRT] Add clustering policy for gather-scatter.
PiperOrigin-RevId: 459510174
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc
index 6703609..683ae39e 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc
@@ -748,6 +748,44 @@
}
};
+// -------------------------------------------------------------------------- //
+// Gather Operations.
+// -------------------------------------------------------------------------- //
+
+class GatherOpClusteringPolicy : public DefaultClusteringPolicy {
+ public:
+ GatherOpClusteringPolicy()
+ : DefaultClusteringPolicy(IsGatherOp(), ValueConstraint::kRank) {}
+
+ private:
+ std::function<bool(Operation* op)> IsGatherOp() {
+ return [](Operation* op) {
+ return mlir::isa<mlir::TF::GatherNdOp, mlir::TF::GatherV2Op,
+ mlir::TF::GatherOp>(op);
+ };
+ }
+};
+
+// -------------------------------------------------------------------------- //
+// Scatter Operations.
+// -------------------------------------------------------------------------- //
+
+class ScatterOpClusteringPolicy : public DefaultClusteringPolicy {
+ public:
+ ScatterOpClusteringPolicy()
+ : DefaultClusteringPolicy(IsScatterOp(), ValueConstraint::kRank) {}
+
+ private:
+ std::function<bool(Operation* op)> IsScatterOp() {
+ return [](Operation* op) {
+ return mlir::isa<
+ mlir::TF::ScatterNdOp, mlir::TF::TensorScatterAddOp,
+ mlir::TF::TensorScatterMaxOp, mlir::TF::TensorScatterMinOp,
+ mlir::TF::TensorScatterSubOp, mlir::TF::TensorScatterUpdateOp>(op);
+ };
+ }
+};
+
} // namespace
void populateTfJitRtClusteringPolicies(ClusteringPolicySet& policies,
@@ -780,6 +818,11 @@
SqueezeOpClusteringPolicy>();
}
+ if (is_enabled(JitRtClusteringTier::kGatherScatter)) {
+ policies.Add<GatherOpClusteringPolicy, //
+ ScatterOpClusteringPolicy>();
+ }
+
if (is_enabled(JitRtClusteringTier::kAll)) {
policies.Add<BatchMatMulV2OpClusteringPolicy, //
BroadcastToOpClusteringPolicy, //
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h
index 10bbd31..9dda224 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h
@@ -30,8 +30,9 @@
enum class JitRtClusteringTier : uint8_t {
kCwise = 0x1,
kTranspose = 0x2,
- kMetadata = 0x4, // shape, reshape, ...
- kReductions = 0x8, // all, any, min, max, mean, prod, sum
+ kMetadata = 0x4, // shape, reshape, ...
+ kReductions = 0x8, // all, any, min, max, mean, prod, sum
+ kGatherScatter = 0x10, // gather, scatter, gather_v2,...
// Only cwise operations (unary, binary, ternary).
kTier0 = kCwise,