[GML] Implement fusion of unary ops through the fusion interface
Also add tests for point-wise fusion.
PiperOrigin-RevId: 453731794
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 30baafa..90fd8a5 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -2536,6 +2536,7 @@
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc
index 5664fad..b6c0add 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc
@@ -23,6 +23,7 @@
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -141,15 +142,6 @@
op.known_nonexpanding_dimensionsAttr());
}
-// TODO(frgossen): This should become a fusion interface.
-template <class OpTy>
-Value whatWillBeTheFusionIfaceUnaryOp(OpTy op, Value tile,
- PatternRewriter& rewriter) {
- auto loc = op.getLoc();
- auto operandSub = rewriter.create<MaterializeOp>(loc, op.operand(), tile);
- return rewriter.create<OpTy>(loc, operandSub);
-}
-
struct FusionPattern : public OpRewritePattern<gml_st::MaterializeOp> {
using OpRewritePattern<gml_st::MaterializeOp>::OpRewritePattern;
@@ -158,8 +150,10 @@
Operation* def = op.source().getDefiningOp();
if (auto iface = llvm::dyn_cast_or_null<FusionIterface>(def)) {
- rewriter.replaceOp(op, iface.fuse(op, rewriter));
- return success();
+ if (Value fused = iface.fuse(op, rewriter)) {
+ rewriter.replaceOp(op, fused);
+ return success();
+ }
}
// TODO(frgossen): The below cases should eventually be replaced by the use
@@ -173,27 +167,14 @@
return success();
}
- // Case `cos`.
- if (auto cos = llvm::dyn_cast_or_null<mhlo::CosOp>(def)) {
- rewriter.replaceOp(
- op, whatWillBeTheFusionIfaceUnaryOp(cos, op.subset(), rewriter));
- return success();
- }
-
- // Case `tanh`.
- if (auto tanh = llvm::dyn_cast_or_null<mhlo::TanhOp>(def)) {
- rewriter.replaceOp(
- op, whatWillBeTheFusionIfaceUnaryOp(tanh, op.subset(), rewriter));
- return success();
- }
-
return failure();
}
};
class FusionPass : public FusionPassBase<FusionPass> {
void getDependentDialects(DialectRegistry& registry) const final {
- registry.insert<GmlStDialect>();
+ registry
+ .insert<GmlStDialect, math::MathDialect, arith::ArithmeticDialect>();
registerFusionInterfaceExternalModels(registry);
}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc
index 5a86f00..80c2689 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface.cc
@@ -20,8 +20,6 @@
#include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.cc.inc"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace gml_st {
@@ -29,33 +27,34 @@
namespace {
template <typename OpTy>
-struct BinaryElementwiseFusionInterface
- : public FusionIterface::ExternalModel<
- BinaryElementwiseFusionInterface<OpTy>, OpTy> {
+struct ElementwiseFusionInterface
+ : public FusionIterface::ExternalModel<ElementwiseFusionInterface<OpTy>,
+ OpTy> {
Value fuse(Operation* op, MaterializeOp materializeOp,
OpBuilder& builder) const {
- auto binaryElementwiseOp = cast<OpTy>(op);
+ // Supports tile and point subsets.
Value subset = materializeOp.subset();
- Location loc = materializeOp.getLoc();
+ auto subsetTy = subset.getType();
+ if (!subsetTy.isa<PointType, TileType>()) return {};
- return llvm::TypeSwitch<Type, Value>(subset.getType())
- .Case([&](PointType) -> Value {
- auto lhs = builder.create<MaterializeOp>(
- loc, binaryElementwiseOp.lhs(), subset);
- auto rhs = builder.create<MaterializeOp>(
- loc, binaryElementwiseOp.rhs(), subset);
- return mhlo::MhloOpToStdScalarOp::map<OpTy>(
- binaryElementwiseOp, materializeOp.getType(),
- llvm::ArrayRef<Value>{lhs, rhs}, &builder);
- })
+ // Materialize subsets for all arguments.
+ auto ewiseOp = cast<OpTy>(op);
+ Location loc = materializeOp.getLoc();
+ auto subsetArgs = llvm::to_vector(
+ llvm::map_range(ewiseOp->getOperands(), [&](const auto& arg) -> Value {
+ return builder.create<MaterializeOp>(loc, arg, subset);
+ }));
+
+ // Materialize elementwise op for subset.
+ return llvm::TypeSwitch<Type, Value>(subsetTy)
.Case([&](TileType) -> Value {
- auto lhs = builder.create<MaterializeOp>(
- loc, binaryElementwiseOp.lhs(), subset);
- auto rhs = builder.create<MaterializeOp>(
- loc, binaryElementwiseOp.rhs(), subset);
- return builder.create<OpTy>(loc, lhs, rhs);
+ return builder.create<OpTy>(loc, subsetArgs);
})
- .Default([&](Type) -> Value { return {}; });
+ .Case([&](PointType) -> Value {
+ return mhlo::MhloOpToStdScalarOp::map<OpTy>(
+ ewiseOp, materializeOp.getType(), subsetArgs, &builder);
+ })
+ .Default([](Type) -> Value { return {}; });
}
};
@@ -64,9 +63,10 @@
void registerFusionInterfaceExternalModels(DialectRegistry& registry) {
registry.insert<mhlo::MhloDialect>();
registry.addExtension(+[](MLIRContext* ctx, mhlo::MhloDialect* /*dialect*/) {
- mhlo::AddOp::attachInterface<BinaryElementwiseFusionInterface<mhlo::AddOp>>(
- *ctx);
- mhlo::SubOp::attachInterface<BinaryElementwiseFusionInterface<mhlo::SubOp>>(
+ mhlo::AddOp::attachInterface<ElementwiseFusionInterface<mhlo::AddOp>>(*ctx);
+ mhlo::SubOp::attachInterface<ElementwiseFusionInterface<mhlo::SubOp>>(*ctx);
+ mhlo::CosOp::attachInterface<ElementwiseFusionInterface<mhlo::CosOp>>(*ctx);
+ mhlo::TanhOp::attachInterface<ElementwiseFusionInterface<mhlo::TanhOp>>(
*ctx);
});
}
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir
index 64b8a71..230b895 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir
@@ -90,6 +90,48 @@
// -----
+// CHECK-LABEL: @cos
+// CHECK-SAME: %[[ARG:.*]]: tensor<32x32xf32>, %[[TILE:.*]]: !gml_st.tile<?x?>
+func.func @cos(%arg: tensor<32x32xf32>, %tile: !gml_st.tile<?x?>)
+ -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]] at %[[TILE]] : tensor<32x32xf32> at !gml_st.tile<?x?>
+ // CHECK-DAG: %[[RES:.*]] = mhlo.cosine %[[ARG_SUB]] : tensor<?x?xf32>
+ // CHECK: return %[[RES]]
+ %0 = mhlo.cosine %arg : tensor<32x32xf32>
+ %1 = gml_st.materialize %0 at %tile : tensor<32x32xf32> at !gml_st.tile<?x?>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_point
+// CHECK-SAME: %[[LHS:.*]]: tensor<32x32xf32>, %[[RHS:.*]]: tensor<32x32xf32>, %[[POINT:.*]]: !gml_st.point
+func.func @add_point(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>,
+ %point: !gml_st.point) -> f32 {
+ // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]] at %[[POINT]] : tensor<32x32xf32> at !gml_st.point
+ // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]] at %[[POINT]] : tensor<32x32xf32> at !gml_st.point
+ // CHECK-DAG: %[[RES:.*]] = arith.addf %[[LHS_SUB]], %[[RHS_SUB]]
+ // CHECK: return %[[RES]]
+ %0 = mhlo.add %lhs, %rhs : tensor<32x32xf32>
+ %1 = gml_st.materialize %0 at %point : tensor<32x32xf32> at !gml_st.point
+ func.return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @cos_point
+// CHECK-SAME: %[[ARG:.*]]: tensor<32x32xf32>, %[[POINT:.*]]: !gml_st.point
+func.func @cos_point(%arg: tensor<32x32xf32>, %point: !gml_st.point) -> f32 {
+ // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]] at %[[POINT]] : tensor<32x32xf32> at !gml_st.point
+ // CHECK-DAG: %[[RES:.*]] = math.cos %[[ARG_SUB]]
+ // CHECK: return %[[RES]]
+ %0 = mhlo.cosine %arg : tensor<32x32xf32>
+ %1 = gml_st.materialize %0 at %point : tensor<32x32xf32> at !gml_st.point
+ return %1 : f32
+}
+
+// -----
+
#cwise_trait = {
indexing_maps = [
affine_map<(d0) -> (d0)>,