Add pattern file location to generated code to trace origin of pattern.
--
PiperOrigin-RevId: 249734666
diff --git a/include/mlir/TableGen/Pattern.h b/include/mlir/TableGen/Pattern.h
index 1b75712..f5eb9a3 100644
--- a/include/mlir/TableGen/Pattern.h
+++ b/include/mlir/TableGen/Pattern.h
@@ -220,6 +220,12 @@
// Returns the benefit score of the pattern.
int getBenefit() const;
+ using IdentifierLine = std::pair<StringRef, unsigned>;
+
+ // Returns the file location of the pattern (buffer identifier + line number
+ // pair).
+ std::vector<IdentifierLine> getLocation() const;
+
private:
// Recursively collects all bound arguments inside the DAG tree rooted
// at `tree`.
diff --git a/lib/TableGen/Pattern.cpp b/lib/TableGen/Pattern.cpp
index 285f1c9..31bab81 100644
--- a/lib/TableGen/Pattern.cpp
+++ b/lib/TableGen/Pattern.cpp
@@ -229,6 +229,20 @@
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
}
+std::vector<tblgen::Pattern::IdentifierLine>
+tblgen::Pattern::getLocation() const {
+ std::vector<std::pair<StringRef, unsigned>> result;
+ result.reserve(def.getLoc().size());
+ for (auto loc : def.getLoc()) {
+ unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
+ assert(buf && "invalid source location");
+ result.emplace_back(
+ llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
+ llvm::SrcMgr.getLineAndColumn(loc, buf).first);
+ }
+ return result;
+}
+
void tblgen::Pattern::collectBoundArguments(DagNode tree) {
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
diff --git a/test/mlir-tblgen/pattern.td b/test/mlir-tblgen/pattern.td
index bb5055a..b5a6c60 100644
--- a/test/mlir-tblgen/pattern.td
+++ b/test/mlir-tblgen/pattern.td
@@ -23,6 +23,8 @@
// Test rewrite rule naming
// ---
+// CHECK: Generated from:
+// CHECK-NEXT: {{.*pattern.td.*}}
// CHECK: struct MyRule : public RewritePattern
def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
diff --git a/tools/mlir-tblgen/RewriterGen.cpp b/tools/mlir-tblgen/RewriterGen.cpp
index 57068e2..b00be1c 100644
--- a/tools/mlir-tblgen/RewriterGen.cpp
+++ b/tools/mlir-tblgen/RewriterGen.cpp
@@ -30,6 +30,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
@@ -41,6 +42,15 @@
using namespace mlir;
using namespace mlir::tblgen;
+namespace llvm {
+template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
+ static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
+ raw_ostream &os, StringRef style) {
+ os << v.first << ":" << v.second;
+ }
+};
+} // end namespace llvm
+
// Returns the bound symbol for the given op argument or op named `symbol`.
//
// Arguments and ops bound in the source pattern are grouped into a
@@ -445,6 +455,9 @@
loc, "replacing op with variadic results not supported right now");
// Emit RewritePattern for Pattern.
+ auto locs = pattern.getLocation();
+ os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
+ make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
rewriteName, rootName, pattern.getBenefit())