Add static pass registration
Add static pass registration and change mlir-opt to use it. Future work is needed to refactor the registration for PassManager usage.
Change build targets to alwayslink to enforce registration.
PiperOrigin-RevId: 220390178
diff --git a/include/mlir/Pass.h b/include/mlir/Pass.h
index cd7b702..d1610bf 100644
--- a/include/mlir/Pass.h
+++ b/include/mlir/Pass.h
@@ -18,7 +18,10 @@
#ifndef MLIR_PASS_H
#define MLIR_PASS_H
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Compiler.h"
+#include <functional>
namespace mlir {
class Function;
@@ -81,6 +84,65 @@
virtual PassResult runOnModule(Module *m) override;
};
+using PassAllocatorFunction = std::function<Pass *()>;
+
+/// Structure to group information about a pass (argument to invoke via
+/// mlir-opt, description, pass allocator and unique ID).
+class PassInfo {
+public:
+ /// PassInfo constructor should not be invoked directly, instead use
+ /// PassRegistration or registerPass.
+ PassInfo(StringRef arg, StringRef description, const void *passID,
+ PassAllocatorFunction allocator)
+ : arg(arg), description(description), allocator(allocator),
+ passID(passID){};
+
+ /// Returns an allocated instance of this pass.
+ Pass *createPass() const {
+ assert(allocator &&
+ "Cannot call createPass on PassInfo without default allocator");
+ return allocator();
+ }
+
+ /// Returns the command line option that may be passed to 'mlir-opt' that will
+ /// cause this pass to run or null if there is no such argument.
+ StringRef getPassArgument() const { return arg; }
+
+ /// Returns a description for the pass, this never returns null.
+ StringRef getPassDescription() const { return description; }
+
+private:
+ // The argument with which to invoke the pass via mlir-opt.
+ StringRef arg;
+
+ // Description of the pass.
+ StringRef description;
+
+ // Allocator to construct an instance of this pass.
+ PassAllocatorFunction allocator;
+
+ // Unique identifier for pass.
+ const void *passID;
+};
+
+/// Register a specific dialect creation function with the system, typically
+/// used through the PassRegistration template.
+void registerPass(StringRef arg, StringRef description, const void *passID,
+ const PassAllocatorFunction &function);
+
+/// PassRegistration provides a global initializer that registers a Pass
+/// allocation routine.
+///
+/// Usage:
+///
+/// // At namespace scope.
+/// static PassRegistration<MyPass> Unused("unused", "Unused pass");
+template <typename ConcretePass> struct PassRegistration {
+ PassRegistration(StringRef arg, StringRef description) {
+ registerPass(arg, description, &ConcretePass::passID,
+ [&]() { return new ConcretePass(); });
+ }
+};
} // end namespace mlir
#endif // MLIR_PASS_H
diff --git a/include/mlir/Support/PassNameParser.h b/include/mlir/Support/PassNameParser.h
new file mode 100644
index 0000000..bbdf433
--- /dev/null
+++ b/include/mlir/Support/PassNameParser.h
@@ -0,0 +1,40 @@
+//===- PassNameParser.h - Base classes for compiler passes ------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// The PassNameParser class adds all passes linked in to the system that are
+// creatable to the tool.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_PASSNAMEPARSER_H_
+#define MLIR_SUPPORT_PASSNAMEPARSER_H_
+
+#include "llvm/Support/CommandLine.h"
+
+namespace mlir {
+class PassInfo;
+
+/// Adds command line option for each registered pass.
+struct PassNameParser : public llvm::cl::parser<const PassInfo *> {
+ PassNameParser(llvm::cl::Option &opt);
+
+ void printOptionInfo(const llvm::cl::Option &O,
+ size_t GlobalWidth) const override;
+};
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_PASSNAMEPARSER_H_
diff --git a/lib/Analysis/MemRefBoundCheck.cpp b/lib/Analysis/MemRefBoundCheck.cpp
index 0725cea..a7f0ebf 100644
--- a/lib/Analysis/MemRefBoundCheck.cpp
+++ b/lib/Analysis/MemRefBoundCheck.cpp
@@ -45,10 +45,14 @@
PassResult runOnCFGFunction(CFGFunction *f) override { return success(); }
void visitOperationStmt(OperationStmt *opStmt);
+
+ static char passID;
};
} // end anonymous namespace
+char MemRefBoundCheck::passID = 0;
+
FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
@@ -164,3 +168,7 @@
PassResult MemRefBoundCheck::runOnMLFunction(MLFunction *f) {
return walk(f), success();
}
+
+static PassRegistration<MemRefBoundCheck>
+ memRefBoundCheck("memref-bound-check",
+ "Check memref accesses in an MLFunction");
diff --git a/lib/Analysis/MemRefDependenceCheck.cpp b/lib/Analysis/MemRefDependenceCheck.cpp
index 3ca669c..7a620c1 100644
--- a/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/lib/Analysis/MemRefDependenceCheck.cpp
@@ -51,10 +51,13 @@
loadsAndStores.push_back(opStmt);
}
}
+ static char passID;
};
} // end anonymous namespace
+char MemRefDependenceCheck::passID = 0;
+
FunctionPass *mlir::createMemRefDependenceCheckPass() {
return new MemRefDependenceCheck();
}
@@ -132,3 +135,7 @@
checkDependences(loadsAndStores);
return success();
}
+
+static PassRegistration<MemRefDependenceCheck>
+ pass("memref-dependence-check",
+ "Checks dependences between all pairs of memref accesses.");
diff --git a/lib/Analysis/Pass.cpp b/lib/Analysis/Pass.cpp
index 1249c18..ea9da5b 100644
--- a/lib/Analysis/Pass.cpp
+++ b/lib/Analysis/Pass.cpp
@@ -23,6 +23,9 @@
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
+#include "mlir/Support/PassNameParser.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
/// Out of line virtual method to ensure vtables and metadata are emitted to a
@@ -51,3 +54,37 @@
return success();
}
+
+// TODO: The pass registry and pass name parsing should be moved out.
+static llvm::ManagedStatic<llvm::DenseMap<const void *, PassInfo>> passRegistry;
+
+void mlir::registerPass(StringRef arg, StringRef description,
+ const void *passID,
+ const PassAllocatorFunction &function) {
+ bool inserted = passRegistry
+ ->insert(std::make_pair(
+ passID, PassInfo(arg, description, passID, function)))
+ .second;
+ assert(inserted && "Pass registered multiple times");
+ (void)inserted;
+}
+
+PassNameParser::PassNameParser(llvm::cl::Option &opt)
+ : llvm::cl::parser<const PassInfo *>(opt) {
+ for (const auto &kv : *passRegistry) {
+ addLiteralOption(kv.second.getPassArgument(), &kv.second,
+ kv.second.getPassDescription());
+ }
+}
+
+void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
+ size_t GlobalWidth) const {
+ PassNameParser *TP = const_cast<PassNameParser *>(this);
+ llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
+ [](const PassNameParser::OptionInfo *VT1,
+ const PassNameParser::OptionInfo *VT2) {
+ return VT1->Name.compare(VT2->Name);
+ });
+ using llvm::cl::parser;
+ parser<const PassInfo *>::printOptionInfo(O, GlobalWidth);
+}
diff --git a/lib/Transforms/CFGFunctionViewGraph.cpp b/lib/Transforms/CFGFunctionViewGraph.cpp
index a75d26c..810264c 100644
--- a/lib/Transforms/CFGFunctionViewGraph.cpp
+++ b/lib/Transforms/CFGFunctionViewGraph.cpp
@@ -74,13 +74,16 @@
namespace {
struct PrintCFGPass : public FunctionPass {
- PrintCFGPass(llvm::raw_ostream &os, bool shortNames, const llvm::Twine &title)
+ PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false,
+ const llvm::Twine &title = "")
: os(os), shortNames(shortNames), title(title) {}
PassResult runOnCFGFunction(CFGFunction *function) override {
mlir::writeGraph(os, function, shortNames, title);
return success();
}
+ static char passID;
+
private:
llvm::raw_ostream &os;
bool shortNames;
@@ -88,8 +91,13 @@
};
} // namespace
+char PrintCFGPass::passID = 0;
+
FunctionPass *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os,
bool shortNames,
const llvm::Twine &title) {
return new PrintCFGPass(os, shortNames, title);
}
+
+static PassRegistration<PrintCFGPass> pass("print-cfg-graph",
+ "Print CFG graph per function");
diff --git a/lib/Transforms/Canonicalizer.cpp b/lib/Transforms/Canonicalizer.cpp
index f34118c..3a62f132 100644
--- a/lib/Transforms/Canonicalizer.cpp
+++ b/lib/Transforms/Canonicalizer.cpp
@@ -35,9 +35,13 @@
/// Canonicalize operations in functions.
struct Canonicalizer : public FunctionPass {
PassResult runOnFunction(Function *fn) override;
+
+ static char passID;
};
} // end anonymous namespace
+char Canonicalizer::passID = 0;
+
PassResult Canonicalizer::runOnFunction(Function *fn) {
auto *context = fn->getContext();
OwningPatternList patterns;
@@ -54,3 +58,6 @@
/// Create a Canonicalizer pass.
FunctionPass *mlir::createCanonicalizerPass() { return new Canonicalizer(); }
+
+static PassRegistration<Canonicalizer> pass("canonicalize",
+ "Canonicalize operations");
diff --git a/lib/Transforms/ComposeAffineMaps.cpp b/lib/Transforms/ComposeAffineMaps.cpp
index af4a5d1..61e2e8f 100644
--- a/lib/Transforms/ComposeAffineMaps.cpp
+++ b/lib/Transforms/ComposeAffineMaps.cpp
@@ -50,10 +50,14 @@
void visitOperationStmt(OperationStmt *stmt);
PassResult runOnMLFunction(MLFunction *f) override;
using StmtWalker<ComposeAffineMaps>::walk;
+
+ static char passID;
};
} // end anonymous namespace
+char ComposeAffineMaps::passID = 0;
+
FunctionPass *mlir::createComposeAffineMapsPass() {
return new ComposeAffineMaps();
}
@@ -92,3 +96,6 @@
}
return success();
}
+
+static PassRegistration<ComposeAffineMaps> pass("compose-affine-maps",
+ "Compose affine maps");
diff --git a/lib/Transforms/ConstantFold.cpp b/lib/Transforms/ConstantFold.cpp
index 411d1ca..9005c2b 100644
--- a/lib/Transforms/ConstantFold.cpp
+++ b/lib/Transforms/ConstantFold.cpp
@@ -40,9 +40,13 @@
void visitForStmt(ForStmt *stmt);
PassResult runOnCFGFunction(CFGFunction *f) override;
PassResult runOnMLFunction(MLFunction *f) override;
+
+ static char passID;
};
} // end anonymous namespace
+char ConstantFold::passID = 0;
+
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
@@ -174,3 +178,6 @@
/// Creates a constant folding pass.
FunctionPass *mlir::createConstantFoldPass() { return new ConstantFold(); }
+
+static PassRegistration<ConstantFold>
+ pass("constant-fold", "Constant fold operations in functions");
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 52687da..b36717d 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -70,6 +70,8 @@
PassResult runOnModule(Module *m) override;
+ static char passID;
+
private:
// Generates CFG functions for all ML functions in the module.
void convertMLFunctions();
@@ -90,6 +92,8 @@
};
} // end anonymous namespace
+char ModuleConverter::passID = 0;
+
// Iterates over all functions in the module generating CFG functions
// equivalent to ML functions and replacing references to ML functions
// with references to the generated ML functions.
@@ -163,3 +167,7 @@
/// Function references are appropriately patched to refer to the newly
/// generated CFG functions.
ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); }
+
+static PassRegistration<ModuleConverter>
+ pass("convert-to-cfg",
+ "Convert all ML functions in the module to CFG ones");
diff --git a/lib/Transforms/LoopFusion.cpp b/lib/Transforms/LoopFusion.cpp
index d9cdf9d..ae4647e 100644
--- a/lib/Transforms/LoopFusion.cpp
+++ b/lib/Transforms/LoopFusion.cpp
@@ -45,6 +45,7 @@
LoopFusion() {}
PassResult runOnMLFunction(MLFunction *f) override;
+ static char passID;
};
// LoopCollector walks the statements in an MLFunction and builds a map from
@@ -75,6 +76,8 @@
} // end anonymous namespace
+char LoopFusion::passID = 0;
+
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
// TODO(andydavis) Remove the following test code when more general loop
@@ -242,3 +245,5 @@
return success();
}
+
+static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");
diff --git a/lib/Transforms/LoopTiling.cpp b/lib/Transforms/LoopTiling.cpp
index bd66e33..3bff008 100644
--- a/lib/Transforms/LoopTiling.cpp
+++ b/lib/Transforms/LoopTiling.cpp
@@ -42,10 +42,14 @@
struct LoopTiling : public FunctionPass {
PassResult runOnMLFunction(MLFunction *f) override;
constexpr static unsigned kDefaultTileSize = 32;
+
+ static char passID;
};
} // end anonymous namespace
+char LoopTiling::passID = 0;
+
/// Creates a pass to perform loop tiling on all suitable loop nests of an
/// MLFunction.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
@@ -238,3 +242,5 @@
}
return success();
}
+
+static PassRegistration<LoopTiling> pass("loop-tile", "Tile loop nests");
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 15c7014..ae09098 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -56,22 +56,20 @@
Optional<unsigned> unrollFactor;
Optional<bool> unrollFull;
- explicit LoopUnroll(Optional<unsigned> unrollFactor,
- Optional<bool> unrollFull)
+ explicit LoopUnroll(Optional<unsigned> unrollFactor = None,
+ Optional<bool> unrollFull = None)
: unrollFactor(unrollFactor), unrollFull(unrollFull) {}
PassResult runOnMLFunction(MLFunction *f) override;
/// Unroll this for stmt. Returns false if nothing was done.
bool runOnForStmt(ForStmt *forStmt);
+
+ static char passID;
};
} // end anonymous namespace
-FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
- return new LoopUnroll(unrollFactor == -1 ? None
- : Optional<unsigned>(unrollFactor),
- unrollFull == -1 ? None : Optional<bool>(unrollFull));
-}
+char LoopUnroll::passID = 0;
PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
@@ -286,3 +284,11 @@
return true;
}
+
+FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
+ return new LoopUnroll(unrollFactor == -1 ? None
+ : Optional<unsigned>(unrollFactor),
+ unrollFull == -1 ? None : Optional<bool>(unrollFull));
+}
+
+static PassRegistration<LoopUnroll> pass("loop-unroll", "Unroll loops");
diff --git a/lib/Transforms/LoopUnrollAndJam.cpp b/lib/Transforms/LoopUnrollAndJam.cpp
index f437b44..ce6e939 100644
--- a/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/lib/Transforms/LoopUnrollAndJam.cpp
@@ -70,14 +70,18 @@
Optional<unsigned> unrollJamFactor;
static const unsigned kDefaultUnrollJamFactor = 4;
- explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor)
+ explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None)
: unrollJamFactor(unrollJamFactor) {}
PassResult runOnMLFunction(MLFunction *f) override;
bool runOnForStmt(ForStmt *forStmt);
+
+ static char passID;
};
} // end anonymous namespace
+char LoopUnrollAndJam::passID = 0;
+
FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
return new LoopUnrollAndJam(
unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
@@ -239,3 +243,6 @@
return true;
}
+
+static PassRegistration<LoopUnrollAndJam> pass("loop-unroll-jam",
+ "Unroll and jam loops");
diff --git a/lib/Transforms/PipelineDataTransfer.cpp b/lib/Transforms/PipelineDataTransfer.cpp
index c59e007..52052e0 100644
--- a/lib/Transforms/PipelineDataTransfer.cpp
+++ b/lib/Transforms/PipelineDataTransfer.cpp
@@ -47,10 +47,14 @@
// Collect all 'for' statements.
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
std::vector<ForStmt *> forStmts;
+
+ static char passID;
};
} // end anonymous namespace
+char PipelineDataTransfer::passID = 0;
+
/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
FunctionPass *mlir::createPipelineDataTransferPass() {
@@ -306,3 +310,8 @@
return success();
}
+
+static PassRegistration<PipelineDataTransfer> pass(
+ "pipeline-data-transfer",
+ "Pipeline non-blocking data transfers between explicitly managed levels of "
+ "the memory hierarchy");
diff --git a/lib/Transforms/SimplifyAffineExpr.cpp b/lib/Transforms/SimplifyAffineExpr.cpp
index a412a83..92d585f 100644
--- a/lib/Transforms/SimplifyAffineExpr.cpp
+++ b/lib/Transforms/SimplifyAffineExpr.cpp
@@ -47,10 +47,14 @@
void visitIfStmt(IfStmt *ifStmt);
void visitOperationStmt(OperationStmt *opStmt);
+
+ static char passID;
};
} // end anonymous namespace
+char SimplifyAffineStructures::passID = 0;
+
FunctionPass *mlir::createSimplifyAffineStructuresPass() {
return new SimplifyAffineStructures();
}
@@ -83,3 +87,6 @@
walk(f);
return success();
}
+
+static PassRegistration<SimplifyAffineStructures>
+ pass("simplify-affine-structures", "Simplify affine expressions");
diff --git a/lib/Transforms/Vectorize.cpp b/lib/Transforms/Vectorize.cpp
index fa97b70..63969af 100644
--- a/lib/Transforms/Vectorize.cpp
+++ b/lib/Transforms/Vectorize.cpp
@@ -199,10 +199,14 @@
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
MLFunctionMatcherContext MLContext;
+
+ static char passID;
};
} // end anonymous namespace
+char Vectorize::passID = 0;
+
/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate. //////
namespace {
@@ -669,3 +673,7 @@
}
FunctionPass *mlir::createVectorizePass() { return new Vectorize(); }
+
+static PassRegistration<Vectorize>
+ pass("vectorize",
+ "Vectorize to a target independent n-D vector abstraction");
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 700436e..3225860 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass.h"
+#include "mlir/Support/PassNameParser.h"
#include "mlir/TensorFlow/ControlFlowOps.h"
#include "mlir/TensorFlow/Passes.h"
#include "mlir/TensorFlowLite/Passes.h"
@@ -67,58 +68,7 @@
"expected-* lines on the corresponding line"),
cl::init(false));
-enum Passes {
- Canonicalize,
- ComposeAffineMaps,
- ConstantFold,
- ConvertToCFG,
- TFLiteLegaize,
- LoopFusion,
- LoopTiling,
- LoopUnroll,
- LoopUnrollAndJam,
- MemRefBoundCheck,
- MemRefDependenceCheck,
- PipelineDataTransfer,
- PrintCFGGraph,
- SimplifyAffineStructures,
- TFRaiseControlFlow,
- Vectorize,
- XLALower,
-};
-
-static cl::list<Passes> passList(
- "", cl::desc("Compiler passes to run"),
- cl::values(
- clEnumValN(Canonicalize, "canonicalize", "Canonicalize operations"),
- clEnumValN(ComposeAffineMaps, "compose-affine-maps",
- "Compose affine maps"),
- clEnumValN(ConstantFold, "constant-fold",
- "Constant fold operations in functions"),
- clEnumValN(ConvertToCFG, "convert-to-cfg",
- "Convert all ML functions in the module to CFG ones"),
- clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"),
- clEnumValN(LoopTiling, "loop-tile", "Tile loop nests"),
- clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
- clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
- clEnumValN(MemRefBoundCheck, "memref-bound-check",
- "Convert all ML functions in the module to CFG ones"),
- clEnumValN(MemRefDependenceCheck, "memref-dependence-check",
- "Checks dependences between all pairs of memref accesses."),
- clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
- "Pipeline non-blocking data transfers between"
- "explicitly managed levels of the memory hierarchy"),
- clEnumValN(PrintCFGGraph, "print-cfg-graph",
- "Print CFG graph per function"),
- clEnumValN(SimplifyAffineStructures, "simplify-affine-structures",
- "Simplify affine expressions"),
- clEnumValN(TFLiteLegaize, "tfl-legalize",
- "Legalize operations to TensorFlow Lite dialect"),
- clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
- "Dynamic TensorFlow Switch/Match nodes to a CFG"),
- clEnumValN(Vectorize, "vectorize",
- "Vectorize to a target independent n-D vector abstraction."),
- clEnumValN(XLALower, "xla-lower", "Lower to XLA dialect")));
+static std::vector<const mlir::PassInfo *> *passList;
enum OptResult { OptSuccess, OptFailure };
@@ -190,65 +140,9 @@
return OptFailure;
// Run each of the passes that were selected.
- for (unsigned i = 0, e = passList.size(); i != e; ++i) {
- auto passKind = passList[i];
- Pass *pass = nullptr;
- switch (passKind) {
- case Canonicalize:
- pass = createCanonicalizerPass();
- break;
- case ComposeAffineMaps:
- pass = createComposeAffineMapsPass();
- break;
- case ConstantFold:
- pass = createConstantFoldPass();
- break;
- case ConvertToCFG:
- pass = createConvertToCFGPass();
- break;
- case LoopFusion:
- pass = createLoopFusionPass();
- break;
- case LoopTiling:
- pass = createLoopTilingPass();
- break;
- case LoopUnroll:
- pass = createLoopUnrollPass();
- break;
- case LoopUnrollAndJam:
- pass = createLoopUnrollAndJamPass();
- break;
- case MemRefBoundCheck:
- pass = createMemRefBoundCheckPass();
- break;
- case MemRefDependenceCheck:
- pass = createMemRefDependenceCheckPass();
- break;
- case PipelineDataTransfer:
- pass = createPipelineDataTransferPass();
- break;
- case PrintCFGGraph:
- pass = createPrintCFGGraphPass();
- break;
- case SimplifyAffineStructures:
- pass = createSimplifyAffineStructuresPass();
- break;
- case TFLiteLegaize:
- pass = tfl::createLegalizer();
- break;
- case TFRaiseControlFlow:
- pass = createRaiseTFControlFlowPass();
- break;
- case Vectorize:
- pass = createVectorizePass();
- break;
- case XLALower:
- pass = createXLALowerPass();
- break;
- }
-
+ for (const auto *passInfo : *passList) {
+ std::unique_ptr<Pass> pass(passInfo->createPass());
PassResult result = pass->runOnModule(module.get());
- delete pass;
if (result)
return OptFailure;
@@ -468,6 +362,10 @@
llvm::PrettyStackTraceProgram x(argc, argv);
InitLLVM y(argc, argv);
+ // Parse pass names in main to ensure static initialization completed.
+ llvm::cl::list<const mlir::PassInfo *, bool, mlir::PassNameParser> passList(
+ "", llvm::cl::desc("Compiler passes to run"));
+ ::passList = &passList;
cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
// Set up the input file.