Use fused location for rewritten ops in generated rewrites.

    This does tracks the location by recording all the ops in the source pattern and using the fused location for the transformed op. Track the locations via the rewrite state which is a bit heavy weight, in follow up to change to matchAndRewrite this will be addressed (and need for extra array go away).

--

PiperOrigin-RevId: 249986555
diff --git a/test/mlir-tblgen/pattern.td b/test/mlir-tblgen/pattern.td
index 66ff381..d8e92cb 100644
--- a/test/mlir-tblgen/pattern.td
+++ b/test/mlir-tblgen/pattern.td
@@ -19,6 +19,7 @@
 }
 
 def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
+def MyRule2 : Pat<(OpA (OpA $input, $attr), $attr), (OpB $input, $attr)>;
 
 // Test rewrite rule naming
 // ---
@@ -27,6 +28,13 @@
 // CHECK-NEXT: {{.*pattern.td.*}}
 // CHECK: struct MyRule : public RewritePattern
 
+// CHECK-LABEL: struct MyRule2 : public RewritePattern
+// CHECK: s.autogeneratedRewritePatternOps[0] = op0;
+// CHECK: s.autogeneratedRewritePatternOps[1] = op1;
+// CHECK: rewriter.getFusedLoc({
+// CHECK-SAME: s.autogeneratedRewritePatternOps[0]->getLoc()
+// CHECK-SAME: s.autogeneratedRewritePatternOps[1]->getLoc()
+
 def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
 
 // Test basic structure generated from Pattern
diff --git a/tools/mlir-tblgen/RewriterGen.cpp b/tools/mlir-tblgen/RewriterGen.cpp
index 9103cb0..cd93b98 100644
--- a/tools/mlir-tblgen/RewriterGen.cpp
+++ b/tools/mlir-tblgen/RewriterGen.cpp
@@ -166,6 +166,7 @@
   // Emits C++ statements for matching the `index`-th argument of the given DAG
   // `tree` as an operand.
   void emitOperandMatch(DagNode tree, int index, int depth, int indent);
+
   // Emits C++ statements for matching the `index`-th argument of the given DAG
   // `tree` as an attribute.
   void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
@@ -231,6 +232,9 @@
   FmtContext matchCtx;
   // Format contexts containing placeholder substitutations for rewrite().
   FmtContext rewriteCtx;
+
+  // Number of op processed.
+  int opCounter = 0;
 };
 } // end anonymous namespace
 
@@ -289,6 +293,9 @@
           << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
                      depth + 1, depth, i);
       emitOpMatch(argTree, depth + 1);
+      os.indent(indent + 2)
+          << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n",
+                     ++opCounter, depth + 1);
       os.indent(indent) << "}\n";
       continue;
     }
@@ -397,6 +404,7 @@
     auto ctx = op0->getContext(); (void)ctx;
     auto state = llvm::make_unique<MatchedState>();
     auto &s = *state;
+    s.autogeneratedRewritePatternOps[0] = op0;
 )";
 
   // The rewrite pattern may specify that certain outputs should be unused in
@@ -500,6 +508,10 @@
   for (const auto &result : pattern.getSourcePatternBoundOps()) {
     os.indent(4) << "Operation *" << result.getKey() << ";\n";
   }
+  // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent
+  // numbering so that it can be reused for fused loc.
+  os.indent(4) << "Operation* autogeneratedRewritePatternOps["
+               << pattern.getSourcePattern().getNumOps() << "];\n";
   os << "  };\n";
 
   emitMatchMethod(tree);
@@ -521,8 +533,12 @@
   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
                PatternRewriter &rewriter) const override {
     auto& s = *static_cast<MatchedState *>(state.get());
-    auto loc = op->getLoc(); (void)loc;
-)";
+    auto loc = rewriter.getFusedLoc({)";
+  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
+    os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i
+       << "]->getLoc()";
+  }
+  os << "}); (void)loc;\n";
 
   // Collect the replacement value for each result
   llvm::SmallVector<std::string, 2> resultValues;