Add support for coalescing adjacent nested pass pipelines.

This allows for parallelizing across pipelines of multiple operation types. AdaptorPasses can now hold pass managers for multiple operation types and will dispatch based upon the operation being operated on.

PiperOrigin-RevId: 268017344
diff --git a/third_party/mlir/include/mlir/Pass/PassManager.h b/third_party/mlir/include/mlir/Pass/PassManager.h
index b240e5b..75c8653 100644
--- a/third_party/mlir/include/mlir/Pass/PassManager.h
+++ b/third_party/mlir/include/mlir/Pass/PassManager.h
@@ -51,6 +51,7 @@
   OpPassManager(OpPassManager &&) = default;
   OpPassManager(const OpPassManager &rhs);
   ~OpPassManager();
+  OpPassManager &operator=(const OpPassManager &rhs);
 
   /// Run the held passes over the given operation.
   LogicalResult run(Operation *op, AnalysisManager am);
@@ -77,12 +78,12 @@
   /// Return the operation name that this pass manager operates on.
   const OperationName &getOpName() const;
 
-private:
-  OpPassManager(OperationName name, bool disableThreads, bool verifyPasses);
-
   /// Returns the internal implementation instance.
   detail::OpPassManagerImpl &getImpl();
 
+private:
+  OpPassManager(OperationName name, bool disableThreads, bool verifyPasses);
+
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;
 
diff --git a/third_party/mlir/lib/Pass/Pass.cpp b/third_party/mlir/lib/Pass/Pass.cpp
index c183e36..716aa7e 100644
--- a/third_party/mlir/lib/Pass/Pass.cpp
+++ b/third_party/mlir/lib/Pass/Pass.cpp
@@ -25,7 +25,6 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Module.h"
-#include "mlir/Pass/PassManager.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Mutex.h"
 #include "llvm/Support/Parallel.h"
@@ -81,7 +80,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// OpPassManager
+// OpPassManagerImpl
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
@@ -91,26 +90,18 @@
       : name(name), disableThreads(disableThreads), verifyPasses(verifyPasses) {
   }
 
-  /// Returns the pass manager instance corresponding to the last pass added
-  /// if that pass was a PassAdaptor.
-  OpPassManager *getLastNestedPM() {
-    if (passes.empty())
-      return nullptr;
-    auto lastPassIt = passes.rbegin();
-
-    // If this pass was a verifier, skip it as it is opaque to ordering for
-    // pipeline construction.
-    if (isa<VerifierPass>(*lastPassIt))
-      ++lastPassIt;
-
-    // Get the internal pass manager if this pass is an adaptor.
-    if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(lastPassIt->get()))
-      return &adaptor->getPassManager();
-    if (auto *adaptor = dyn_cast<OpToOpPassAdaptorParallel>(lastPassIt->get()))
-      return &adaptor->getPassManager();
-    return nullptr;
+  /// Merge the passes of this pass manager into the one provided.
+  void mergeInto(OpPassManagerImpl &rhs) {
+    assert(name == rhs.name && "merging unrelated pass managers");
+    for (auto &pass : passes)
+      rhs.passes.push_back(std::move(pass));
+    passes.clear();
   }
 
+  /// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
+  /// recursively through the pipeline graph.
+  void coalesceAdjacentAdaptorPasses();
+
   /// The name of the operation that passes of this pass manager operate on.
   OperationName name;
 
@@ -126,6 +117,62 @@
 } // end namespace detail
 } // end namespace mlir
 
+/// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
+/// recursively through the pipeline graph.
+void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
+  // Bail out early if there are no adaptor passes.
+  if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
+        return isAdaptorPass(pass.get());
+      }))
+    return;
+
+  // Walk the pass list and merge adjacent adaptors.
+  OpToOpPassAdaptorBase *lastAdaptor = nullptr;
+  for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
+    // Check to see if this pass is an adaptor.
+    if (auto *currentAdaptor = getAdaptorPassBase(it->get())) {
+      // If it is the first adaptor in a possible chain, remember it and
+      // continue.
+      if (!lastAdaptor) {
+        lastAdaptor = currentAdaptor;
+        continue;
+      }
+
+      // Otherwise, merge into the existing adaptor and delete the current one.
+      currentAdaptor->mergeInto(*lastAdaptor);
+      it->reset();
+
+      // If the verifier is enabled, then next pass is a verifier run so
+      // drop it. Verifier passes are inserted after every pass, so this one
+      // would be a duplicate.
+      if (verifyPasses) {
+        assert(std::next(it) != e && isa<VerifierPass>(*std::next(it)));
+        (++it)->reset();
+      }
+    } else if (lastAdaptor && !isa<VerifierPass>(*it)) {
+      // If this pass is not an adaptor and not a verifier pass, then coalesce
+      // and forget any existing adaptor.
+      for (auto &pm : lastAdaptor->getPassManagers())
+        pm.getImpl().coalesceAdjacentAdaptorPasses();
+      lastAdaptor = nullptr;
+    }
+  }
+
+  // If there was an adaptor at the end of the manager, coalesce it as well.
+  if (lastAdaptor) {
+    for (auto &pm : lastAdaptor->getPassManagers())
+      pm.getImpl().coalesceAdjacentAdaptorPasses();
+  }
+
+  // Now that the adaptors have been merged, erase the empty slot corresponding
+  // to the merged adaptors that were nulled-out in the loop above.
+  llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
+}
+
+//===----------------------------------------------------------------------===//
+// OpPassManager
+//===----------------------------------------------------------------------===//
+
 OpPassManager::OpPassManager(OperationName name, bool disableThreads,
                              bool verifyPasses)
     : impl(new OpPassManagerImpl(name, disableThreads, verifyPasses)) {
@@ -136,11 +183,13 @@
          "OpPassManager only supports operating on operations marked as "
          "'IsolatedFromAbove'");
 }
-OpPassManager::OpPassManager(const OpPassManager &rhs)
-    : impl(new OpPassManagerImpl(rhs.impl->name, rhs.impl->disableThreads,
-                                 rhs.impl->verifyPasses)) {
+OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
+OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
+  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->disableThreads,
+                                   rhs.impl->verifyPasses));
   for (auto &pass : rhs.impl->passes)
     impl->passes.emplace_back(pass->clone());
+  return *this;
 }
 
 OpPassManager::~OpPassManager() {}
@@ -157,23 +206,19 @@
 /// Nest a new operation pass manager for the given operation kind under this
 /// pass manager.
 OpPassManager &OpPassManager::nest(const OperationName &nestedName) {
-  // Check to see if an existing nested pass manager already exists.
-  if (auto *nestedPM = impl->getLastNestedPM()) {
-    if (nestedPM->getOpName() == nestedName)
-      return *nestedPM;
+  OpPassManager nested(nestedName, impl->disableThreads, impl->verifyPasses);
+
+  /// Create an adaptor for this pass. If multi-threading is disabled, then
+  /// create a synchronous adaptor.
+  if (impl->disableThreads || !llvm::llvm_is_multithreaded()) {
+    auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
+    addPass(std::unique_ptr<Pass>(adaptor));
+    return adaptor->getPassManagers().front();
   }
 
-  std::unique_ptr<OpPassManager> nested(
-      new OpPassManager(nestedName, impl->disableThreads, impl->verifyPasses));
-  auto &nestedRef = *nested;
-
-  /// Create an executor adaptor for this pass. If multi-threading is disabled,
-  /// then create a synchronous adaptor.
-  if (impl->disableThreads || !llvm::llvm_is_multithreaded())
-    addPass(std::make_unique<OpToOpPassAdaptor>(std::move(nested)));
-  else
-    addPass(std::make_unique<OpToOpPassAdaptorParallel>(std::move(nested)));
-  return nestedRef;
+  auto *adaptor = new OpToOpPassAdaptorParallel(std::move(nested));
+  addPass(std::unique_ptr<Pass>(adaptor));
+  return adaptor->getPassManagers().front();
 }
 OpPassManager &OpPassManager::nest(StringRef nestedName) {
   return nest(OperationName(nestedName, getContext()));
@@ -227,10 +272,43 @@
   return result;
 }
 
-OpToOpPassAdaptor::OpToOpPassAdaptor(std::unique_ptr<OpPassManager> mgr)
-    : mgr(std::move(mgr)) {}
-OpToOpPassAdaptor::OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs)
-    : mgr(new OpPassManager(*rhs.mgr)) {}
+/// Find an operation pass manager that can operate on an operation of the given
+/// type, or nullptr if one does not exist.
+static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
+                                         const OperationName &name) {
+  auto it = llvm::find_if(
+      mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
+  return it == mgrs.end() ? nullptr : &*it;
+}
+
+OpToOpPassAdaptorBase::OpToOpPassAdaptorBase(OpPassManager &&mgr) {
+  mgrs.emplace_back(std::move(mgr));
+}
+
+/// Merge the current pass adaptor into given 'rhs'.
+void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
+  for (auto &pm : mgrs) {
+    // If an existing pass manager exists, then merge the given pass manager
+    // into it.
+    if (auto *existingPM = findPassManagerFor(rhs.mgrs, pm.getOpName())) {
+      pm.getImpl().mergeInto(existingPM->getImpl());
+    } else {
+      // Otherwise, add the given pass manager to the list.
+      rhs.mgrs.emplace_back(std::move(pm));
+    }
+  }
+  mgrs.clear();
+
+  // After coalescing, sort the pass managers within rhs by name.
+  llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
+                       [](const OpPassManager *lhs, const OpPassManager *rhs) {
+                         return lhs->getOpName().getStringRef().compare(
+                             rhs->getOpName().getStringRef());
+                       });
+}
+
+OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr)
+    : OpToOpPassAdaptorBase(std::move(mgr)) {}
 
 /// Run the held pipeline over all nested operations.
 void OpToOpPassAdaptor::runOnOperation() {
@@ -240,7 +318,8 @@
   for (auto &region : getOperation()->getRegions()) {
     for (auto &block : region) {
       for (auto &op : block) {
-        if (op.getName() != mgr->getOpName())
+        auto *mgr = findPassManagerFor(mgrs, op.getName());
+        if (!mgr)
           continue;
 
         // Run the held pipeline over the current operation.
@@ -257,12 +336,17 @@
   }
 }
 
-OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(
-    std::unique_ptr<OpPassManager> mgr)
-    : mgr(std::move(mgr)) {}
-OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(
-    const OpToOpPassAdaptorParallel &rhs)
-    : mgr(std::make_unique<OpPassManager>(*rhs.mgr)) {}
+OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(OpPassManager &&mgr)
+    : OpToOpPassAdaptorBase(std::move(mgr)) {}
+
+/// Utility functor that checks if the two ranges of pass managers have a size
+/// mismatch.
+static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
+                            ArrayRef<OpPassManager> rhs) {
+  return lhs.size() != rhs.size() ||
+         llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
+                      [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
+}
 
 // Run the held pipeline asynchronously across the functions within the module.
 void OpToOpPassAdaptorParallel::runOnOperation() {
@@ -270,8 +354,8 @@
 
   // Create the async executors if they haven't been created, or if the main
   // pipeline has changed.
-  if (asyncExecutors.empty() || asyncExecutors.front().size() != mgr->size())
-    asyncExecutors = {llvm::hardware_concurrency(), *mgr};
+  if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
+    asyncExecutors.assign(llvm::hardware_concurrency(), mgrs);
 
   // Run a prepass over the module to collect the operations to execute over.
   // This ensures that an analysis manager exists for each operation, as well as
@@ -280,8 +364,8 @@
   for (auto &region : getOperation()->getRegions()) {
     for (auto &block : region) {
       for (auto &op : block) {
-        // Add this operation iff the name matches the current pass manager.
-        if (op.getName() == mgr->getOpName())
+        // Add this operation iff the name matches the any of the pass managers.
+        if (findPassManagerFor(mgrs, op.getName()))
           opAMPairs.emplace_back(&op, am.slice(&op));
       }
     }
@@ -304,7 +388,7 @@
       llvm::parallel::par, asyncExecutors.begin(),
       std::next(asyncExecutors.begin(),
                 std::min(asyncExecutors.size(), opAMPairs.size())),
-      [&](OpPassManager &pm) {
+      [&](MutableArrayRef<OpPassManager> pms) {
         for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
           // Get the next available operation index.
           unsigned nextID = opIt++;
@@ -314,14 +398,16 @@
           // Set the order id for this thread in the diagnostic handler.
           diagHandler.setOrderIDForThread(nextID);
 
+          // Get the pass manager for this operation and execute it.
           auto &it = opAMPairs[nextID];
+          auto *pm = findPassManagerFor(pms, it.first->getName());
+          assert(pm && "expected valid pass manager for operation");
 
-          // Run the executor over the current operation.
           if (instrumentor)
-            instrumentor->runBeforePipeline(pm.getOpName(), parentThreadID);
-          auto pipelineResult = runPipeline(pm, it.first, it.second);
+            instrumentor->runBeforePipeline(pm->getOpName(), parentThreadID);
+          auto pipelineResult = runPipeline(*pm, it.first, it.second);
           if (instrumentor)
-            instrumentor->runAfterPipeline(pm.getOpName(), parentThreadID);
+            instrumentor->runAfterPipeline(pm->getOpName(), parentThreadID);
 
           // Handle a failed pipeline result.
           if (failed(pipelineResult)) {
@@ -336,14 +422,14 @@
     signalPassFailure();
 }
 
-/// Utility function to return the operation name that the given adaptor pass
-/// operates on. Return None if the given pass is not an adaptor pass.
-Optional<StringRef> mlir::detail::getAdaptorPassOpName(Pass *pass) {
+/// Utility function to convert the given class to the base adaptor it is an
+/// adaptor pass, returns nullptr otherwise.
+OpToOpPassAdaptorBase *mlir::detail::getAdaptorPassBase(Pass *pass) {
   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
-    return adaptor->getPassManager().getOpName().getStringRef();
+    return adaptor;
   if (auto *adaptor = dyn_cast<OpToOpPassAdaptorParallel>(pass))
-    return adaptor->getPassManager().getOpName().getStringRef();
-  return llvm::None;
+    return adaptor;
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
@@ -359,6 +445,11 @@
 
 /// Run the passes within this manager on the provided module.
 LogicalResult PassManager::run(ModuleOp module) {
+  // Before running, make sure to coalesce any adjacent pass adaptors in the
+  // pipeline.
+  opPassManager.getImpl().coalesceAdjacentAdaptorPasses();
+
+  // Construct an analysis manager for the pipeline and run it.
   ModuleAnalysisManager am(module, instrumentor.get());
   return opPassManager.run(module, am);
 }
diff --git a/third_party/mlir/lib/Pass/PassDetail.h b/third_party/mlir/lib/Pass/PassDetail.h
index fdd5610..29bb04d 100644
--- a/third_party/mlir/lib/Pass/PassDetail.h
+++ b/third_party/mlir/lib/Pass/PassDetail.h
@@ -18,10 +18,9 @@
 #define MLIR_PASS_PASSDETAIL_H_
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 
 namespace mlir {
-class OpPassManager;
-
 namespace detail {
 
 //===----------------------------------------------------------------------===//
@@ -37,56 +36,61 @@
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
 
+/// A base class for Op-to-Op adaptor passes.
+class OpToOpPassAdaptorBase {
+public:
+  OpToOpPassAdaptorBase(OpPassManager &&mgr);
+  OpToOpPassAdaptorBase(const OpToOpPassAdaptorBase &rhs) = default;
+
+  /// Merge the current pass adaptor into given 'rhs'.
+  void mergeInto(OpToOpPassAdaptorBase &rhs);
+
+  /// Returns the pass managers held by this adaptor.
+  MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
+
+protected:
+  // A set of adaptors to run.
+  SmallVector<OpPassManager, 1> mgrs;
+};
+
 /// An adaptor pass used to run operation passes over nested operations
 /// synchronously on a single thread.
-class OpToOpPassAdaptor : public OperationPass<OpToOpPassAdaptor> {
+class OpToOpPassAdaptor : public OperationPass<OpToOpPassAdaptor>,
+                          public OpToOpPassAdaptorBase {
 public:
-  OpToOpPassAdaptor(std::unique_ptr<OpPassManager> mgr);
-  OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs);
+  OpToOpPassAdaptor(OpPassManager &&mgr);
 
   /// Run the held pipeline over all operations.
   void runOnOperation() override;
-
-  /// Returns the nested pass manager for this adaptor.
-  OpPassManager &getPassManager() { return *mgr; }
-
-private:
-  std::unique_ptr<OpPassManager> mgr;
 };
 
 /// An adaptor pass used to run operation passes over nested operations
 /// asynchronously across multiple threads.
 class OpToOpPassAdaptorParallel
-    : public OperationPass<OpToOpPassAdaptorParallel> {
+    : public OperationPass<OpToOpPassAdaptorParallel>,
+      public OpToOpPassAdaptorBase {
 public:
-  OpToOpPassAdaptorParallel(std::unique_ptr<OpPassManager> mgr);
-  OpToOpPassAdaptorParallel(const OpToOpPassAdaptorParallel &rhs);
+  OpToOpPassAdaptorParallel(OpPassManager &&mgr);
 
   /// Run the held pipeline over all operations.
   void runOnOperation() override;
 
-  /// Returns the nested pass manager for this adaptor.
-  OpPassManager &getPassManager() { return *mgr; }
-
 private:
-  // The main pass executor for this adaptor.
-  std::unique_ptr<OpPassManager> mgr;
-
   // A set of executors, cloned from the main executor, that run asynchronously
   // on different threads.
-  std::vector<OpPassManager> asyncExecutors;
+  SmallVector<SmallVector<OpPassManager, 1>, 8> asyncExecutors;
 };
 
+/// Utility function to convert the given class to the base adaptor it is an
+/// adaptor pass, returns nullptr otherwise.
+OpToOpPassAdaptorBase *getAdaptorPassBase(Pass *pass);
+
 /// Utility function to return if a pass refers to an adaptor pass. Adaptor
 /// passes are those that internally execute a pipeline.
 inline bool isAdaptorPass(Pass *pass) {
   return isa<OpToOpPassAdaptorParallel>(pass) || isa<OpToOpPassAdaptor>(pass);
 }
 
-/// Utility function to return the operation name that the given adaptor pass
-/// operates on. Return None if the given pass is not an adaptor pass.
-Optional<StringRef> getAdaptorPassOpName(Pass *pass);
-
 } // end namespace detail
 } // end namespace mlir
 #endif // MLIR_PASS_PASSDETAIL_H_
diff --git a/third_party/mlir/lib/Pass/PassTiming.cpp b/third_party/mlir/lib/Pass/PassTiming.cpp
index 6e4a81b..3c5a37d 100644
--- a/third_party/mlir/lib/Pass/PassTiming.cpp
+++ b/third_party/mlir/lib/Pass/PassTiming.cpp
@@ -289,8 +289,15 @@
   auto kind = isAdaptorPass(pass) ? TimerKind::PipelineCollection
                                   : TimerKind::PassOrAnalysis;
   Timer *timer = getTimer(pass, kind, [pass]() -> std::string {
-    if (auto pipelineName = getAdaptorPassOpName(pass))
-      return ("Pipeline Collection : ['" + *pipelineName + "']").str();
+    if (auto *adaptor = getAdaptorPassBase(pass)) {
+      std::string name = "Pipeline Collection : [";
+      llvm::raw_string_ostream os(name);
+      interleaveComma(adaptor->getPassManagers(), os, [&](OpPassManager &pm) {
+        os << '\'' << pm.getOpName() << '\'';
+      });
+      os << ']';
+      return os.str();
+    }
     return pass->getName();
   });