[GML] Implement fusion of unary ops through the fusion interface

Also add tests for point-wise fusion.

PiperOrigin-RevId: 453731794
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 30baafa..90fd8a5 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -2536,6 +2536,7 @@
         "@llvm-project//mlir:ArithmeticDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc
index 5664fad..b6c0add 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc
@@ -23,6 +23,7 @@
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -141,15 +142,6 @@
       op.known_nonexpanding_dimensionsAttr());
 }
 
-// TODO(frgossen): This should become a fusion interface.
-template <class OpTy>
-Value whatWillBeTheFusionIfaceUnaryOp(OpTy op, Value tile,
-                                      PatternRewriter& rewriter) {
-  auto loc = op.getLoc();
-  auto operandSub = rewriter.create<MaterializeOp>(loc, op.operand(), tile);
-  return rewriter.create<OpTy>(loc, operandSub);
-}
-
 struct FusionPattern : public OpRewritePattern<gml_st::MaterializeOp> {
   using OpRewritePattern<gml_st::MaterializeOp>::OpRewritePattern;
 
@@ -158,8 +150,10 @@
     Operation* def = op.source().getDefiningOp();
 
     if (auto iface = llvm::dyn_cast_or_null<FusionIterface>(def)) {
-      rewriter.replaceOp(op, iface.fuse(op, rewriter));
-      return success();
+      if (Value fused = iface.fuse(op, rewriter)) {
+        rewriter.replaceOp(op, fused);
+        return success();
+      }
     }
 
     // TODO(frgossen): The below cases should eventually be replaced by the use
@@ -173,27 +167,14 @@
       return success();
     }
 
-    // Case `cos`.
-    if (auto cos = llvm::dyn_cast_or_null<mhlo::CosOp>(def)) {
-      rewriter.replaceOp(
-          op, whatWillBeTheFusionIfaceUnaryOp(cos, op.subset(), rewriter));
-      return success();
-    }
-
-    // Case `tanh`.
-    if (auto tanh = llvm::dyn_cast_or_null<mhlo::TanhOp>(def)) {
-      rewriter.replaceOp(
-          op, whatWillBeTheFusionIfaceUnaryOp(tanh, op.subset(), rewriter));
-      return success();
-    }
-
     return failure();
   }
 };
 
 class FusionPass : public FusionPassBase<FusionPass> {
   void getDependentDialects(DialectRegistry& registry) const final {
-    registry.insert<GmlStDialect>();
+    registry
+        .insert<GmlStDialect, math::MathDialect, arith::ArithmeticDialect>();
     registerFusionInterfaceExternalModels(registry);
   }
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc
index 5a86f00..80c2689 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc
@@ -20,8 +20,6 @@
 #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.cc.inc"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Support/LogicalResult.h"
 
 namespace mlir {
 namespace gml_st {
@@ -29,33 +27,34 @@
 namespace {
 
 template <typename OpTy>
-struct BinaryElementwiseFusionInterface
-    : public FusionIterface::ExternalModel<
-          BinaryElementwiseFusionInterface<OpTy>, OpTy> {
+struct ElementwiseFusionInterface
+    : public FusionIterface::ExternalModel<ElementwiseFusionInterface<OpTy>,
+                                           OpTy> {
   Value fuse(Operation* op, MaterializeOp materializeOp,
              OpBuilder& builder) const {
-    auto binaryElementwiseOp = cast<OpTy>(op);
+    // Supports tile and point subsets.
     Value subset = materializeOp.subset();
-    Location loc = materializeOp.getLoc();
+    auto subsetTy = subset.getType();
+    if (!subsetTy.isa<PointType, TileType>()) return {};
 
-    return llvm::TypeSwitch<Type, Value>(subset.getType())
-        .Case([&](PointType) -> Value {
-          auto lhs = builder.create<MaterializeOp>(
-              loc, binaryElementwiseOp.lhs(), subset);
-          auto rhs = builder.create<MaterializeOp>(
-              loc, binaryElementwiseOp.rhs(), subset);
-          return mhlo::MhloOpToStdScalarOp::map<OpTy>(
-              binaryElementwiseOp, materializeOp.getType(),
-              llvm::ArrayRef<Value>{lhs, rhs}, &builder);
-        })
+    // Materialize subsets for all arguments.
+    auto ewiseOp = cast<OpTy>(op);
+    Location loc = materializeOp.getLoc();
+    auto subsetArgs = llvm::to_vector(
+        llvm::map_range(ewiseOp->getOperands(), [&](const auto& arg) -> Value {
+          return builder.create<MaterializeOp>(loc, arg, subset);
+        }));
+
+    // Materialize elementwise op for subset.
+    return llvm::TypeSwitch<Type, Value>(subsetTy)
         .Case([&](TileType) -> Value {
-          auto lhs = builder.create<MaterializeOp>(
-              loc, binaryElementwiseOp.lhs(), subset);
-          auto rhs = builder.create<MaterializeOp>(
-              loc, binaryElementwiseOp.rhs(), subset);
-          return builder.create<OpTy>(loc, lhs, rhs);
+          return builder.create<OpTy>(loc, subsetArgs);
         })
-        .Default([&](Type) -> Value { return {}; });
+        .Case([&](PointType) -> Value {
+          return mhlo::MhloOpToStdScalarOp::map<OpTy>(
+              ewiseOp, materializeOp.getType(), subsetArgs, &builder);
+        })
+        .Default([](Type) -> Value { return {}; });
   }
 };
 
@@ -64,9 +63,10 @@
 void registerFusionInterfaceExternalModels(DialectRegistry& registry) {
   registry.insert<mhlo::MhloDialect>();
   registry.addExtension(+[](MLIRContext* ctx, mhlo::MhloDialect* /*dialect*/) {
-    mhlo::AddOp::attachInterface<BinaryElementwiseFusionInterface<mhlo::AddOp>>(
-        *ctx);
-    mhlo::SubOp::attachInterface<BinaryElementwiseFusionInterface<mhlo::SubOp>>(
+    mhlo::AddOp::attachInterface<ElementwiseFusionInterface<mhlo::AddOp>>(*ctx);
+    mhlo::SubOp::attachInterface<ElementwiseFusionInterface<mhlo::SubOp>>(*ctx);
+    mhlo::CosOp::attachInterface<ElementwiseFusionInterface<mhlo::CosOp>>(*ctx);
+    mhlo::TanhOp::attachInterface<ElementwiseFusionInterface<mhlo::TanhOp>>(
         *ctx);
   });
 }
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir
index 64b8a71..230b895 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir
@@ -90,6 +90,48 @@
 
 // -----
 
+// CHECK-LABEL: @cos
+// CHECK-SAME:  %[[ARG:.*]]: tensor<32x32xf32>, %[[TILE:.*]]: !gml_st.tile<?x?>
+func.func @cos(%arg: tensor<32x32xf32>, %tile: !gml_st.tile<?x?>)
+    -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]] at %[[TILE]] : tensor<32x32xf32> at !gml_st.tile<?x?>
+  // CHECK-DAG: %[[RES:.*]] = mhlo.cosine %[[ARG_SUB]] : tensor<?x?xf32>
+  // CHECK:     return %[[RES]]
+  %0 = mhlo.cosine %arg : tensor<32x32xf32>
+  %1 = gml_st.materialize %0 at %tile : tensor<32x32xf32> at !gml_st.tile<?x?>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_point
+// CHECK-SAME:  %[[LHS:.*]]: tensor<32x32xf32>, %[[RHS:.*]]: tensor<32x32xf32>, %[[POINT:.*]]: !gml_st.point
+func.func @add_point(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>,
+    %point: !gml_st.point) -> f32 {
+  // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]] at %[[POINT]] : tensor<32x32xf32> at !gml_st.point
+  // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]] at %[[POINT]] : tensor<32x32xf32> at !gml_st.point
+  // CHECK-DAG: %[[RES:.*]] = arith.addf %[[LHS_SUB]], %[[RHS_SUB]]
+  // CHECK:     return %[[RES]]
+  %0 = mhlo.add %lhs, %rhs : tensor<32x32xf32>
+  %1 = gml_st.materialize %0 at %point : tensor<32x32xf32> at !gml_st.point
+  func.return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @cos_point
+// CHECK-SAME:  %[[ARG:.*]]: tensor<32x32xf32>, %[[POINT:.*]]: !gml_st.point
+func.func @cos_point(%arg: tensor<32x32xf32>, %point: !gml_st.point) -> f32 {
+  // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]] at %[[POINT]] : tensor<32x32xf32> at !gml_st.point
+  // CHECK-DAG: %[[RES:.*]] = math.cos %[[ARG_SUB]]
+  // CHECK:     return %[[RES]]
+  %0 = mhlo.cosine %arg : tensor<32x32xf32>
+  %1 = gml_st.materialize %0 at %point : tensor<32x32xf32> at !gml_st.point
+  return %1 : f32
+}
+
+// -----
+
 #cwise_trait = {
   indexing_maps = [
     affine_map<(d0) -> (d0)>,