Add pattern file location to generated code to trace origin of pattern.

--

PiperOrigin-RevId: 249734666
diff --git a/include/mlir/TableGen/Pattern.h b/include/mlir/TableGen/Pattern.h
index 1b75712..f5eb9a3 100644
--- a/include/mlir/TableGen/Pattern.h
+++ b/include/mlir/TableGen/Pattern.h
@@ -220,6 +220,12 @@
   // Returns the benefit score of the pattern.
   int getBenefit() const;
 
+  using IdentifierLine = std::pair<StringRef, unsigned>;
+
+  // Returns the file location of the pattern (buffer identifier + line number
+  // pair).
+  std::vector<IdentifierLine> getLocation() const;
+
 private:
   // Recursively collects all bound arguments inside the DAG tree rooted
   // at `tree`.
diff --git a/lib/TableGen/Pattern.cpp b/lib/TableGen/Pattern.cpp
index 285f1c9..31bab81 100644
--- a/lib/TableGen/Pattern.cpp
+++ b/lib/TableGen/Pattern.cpp
@@ -229,6 +229,20 @@
   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
 }
 
+std::vector<tblgen::Pattern::IdentifierLine>
+tblgen::Pattern::getLocation() const {
+  std::vector<std::pair<StringRef, unsigned>> result;
+  result.reserve(def.getLoc().size());
+  for (auto loc : def.getLoc()) {
+    unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
+    assert(buf && "invalid source location");
+    result.emplace_back(
+        llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
+        llvm::SrcMgr.getLineAndColumn(loc, buf).first);
+  }
+  return result;
+}
+
 void tblgen::Pattern::collectBoundArguments(DagNode tree) {
   auto &op = getDialectOp(tree);
   auto numOpArgs = op.getNumArgs();
diff --git a/test/mlir-tblgen/pattern.td b/test/mlir-tblgen/pattern.td
index bb5055a..b5a6c60 100644
--- a/test/mlir-tblgen/pattern.td
+++ b/test/mlir-tblgen/pattern.td
@@ -23,6 +23,8 @@
 // Test rewrite rule naming
 // ---
 
+// CHECK: Generated from:
+// CHECK-NEXT: {{.*pattern.td.*}}
 // CHECK: struct MyRule : public RewritePattern
 
 def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
diff --git a/tools/mlir-tblgen/RewriterGen.cpp b/tools/mlir-tblgen/RewriterGen.cpp
index 57068e2..b00be1c 100644
--- a/tools/mlir-tblgen/RewriterGen.cpp
+++ b/tools/mlir-tblgen/RewriterGen.cpp
@@ -30,6 +30,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatAdapters.h"
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
@@ -41,6 +42,15 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
+namespace llvm {
+template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
+  static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
+                     raw_ostream &os, StringRef style) {
+    os << v.first << ":" << v.second;
+  }
+};
+} // end namespace llvm
+
 // Returns the bound symbol for the given op argument or op named `symbol`.
 //
 // Arguments and ops bound in the source pattern are grouped into a
@@ -445,6 +455,9 @@
         loc, "replacing op with variadic results not supported right now");
 
   // Emit RewritePattern for Pattern.
+  auto locs = pattern.getLocation();
+  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
+                make_range(locs.rbegin(), locs.rend()));
   os << formatv(R"(struct {0} : public RewritePattern {
   {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
                 rewriteName, rootName, pattern.getBenefit())