Use variadic insert to add OpRewritePatterns.
PiperOrigin-RevId: 267179090
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
index 780e638..03f55f1 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
@@ -111,26 +111,23 @@
}
template <typename LhloOp>
-struct BinaryOpConverter : public RewritePattern {
- explicit BinaryOpConverter(MLIRContext* context)
- : RewritePattern(LhloOp::getOperationName(), {}, 1, context) {}
+struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
+ using OpRewritePattern<LhloOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation* op,
+ PatternMatchResult matchAndRewrite(LhloOp op,
PatternRewriter& rewriter) const override {
- auto binary_op = cast<LhloOp>(op);
-
- const auto& lhs = binary_op.lhs();
- const auto& rhs = binary_op.rhs();
+ const auto& lhs = op.lhs();
+ const auto& rhs = op.rhs();
const auto& lhs_type = lhs->getType().template cast<MemRefType>();
const auto& rhs_type = rhs->getType().template cast<MemRefType>();
const auto& element_type = lhs_type.getElementType();
if (lhs_type.getShape() != rhs_type.getShape()) {
- return matchFailure();
+ return this->matchFailure();
}
const auto& shape = lhs_type.getShape();
SmallVector<Value*, 4> induction_vars;
- const auto loc = op->getLoc();
+ const auto loc = op.getLoc();
for (int i = 0; i < shape.size(); ++i) {
auto forOp = rewriter.create<AffineForOp>(loc, 0, shape[i]);
induction_vars.push_back(forOp.getInductionVar());
@@ -140,23 +137,26 @@
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
auto result = GetBinaryOp<LhloOp>(element_type, loc, l, r, rewriter);
if (result == nullptr) {
- return matchFailure();
+ return this->matchFailure();
}
- rewriter.create<StoreOp>(loc, result, binary_op.out(), induction_vars);
+ rewriter.create<StoreOp>(loc, result, op.out(), induction_vars);
rewriter.replaceOp(op, {});
- return matchSuccess();
+ return this->matchSuccess();
}
};
void populateLHLOToAffineConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
- patterns->insert<BinaryOpConverter<xla_lhlo::AddOp>>(context);
- patterns->insert<BinaryOpConverter<xla_lhlo::AndOp>>(context);
- patterns->insert<BinaryOpConverter<xla_lhlo::DivOp>>(context);
- patterns->insert<BinaryOpConverter<xla_lhlo::MaxOp>>(context);
- patterns->insert<BinaryOpConverter<xla_lhlo::MinOp>>(context);
- patterns->insert<BinaryOpConverter<xla_lhlo::MulOp>>(context);
- patterns->insert<BinaryOpConverter<xla_lhlo::SubOp>>(context);
+ // clang-format off
+ patterns->insert<
+ BinaryOpConverter<xla_lhlo::AddOp>,
+ BinaryOpConverter<xla_lhlo::AndOp>,
+ BinaryOpConverter<xla_lhlo::DivOp>,
+ BinaryOpConverter<xla_lhlo::MaxOp>,
+ BinaryOpConverter<xla_lhlo::MinOp>,
+ BinaryOpConverter<xla_lhlo::MulOp>,
+ BinaryOpConverter<xla_lhlo::SubOp>>(context);
+ // clang-format on
}
struct LhloLegalizeToAffine : public FunctionPass<LhloLegalizeToAffine> {