blob: 5fd6dfd18b560ebc49b6ee25c94cf28ca59ce1ab [file] [log] [blame]
//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/PointerIntPair.h"
namespace mlir {
/// Abstract base class of every compiler pass. It records which concrete pass
/// a derived object is (through its PassID) and which kind of pass it is,
/// e.g. a function pass or a module pass.
class Pass {
public:
  enum class Kind { FunctionPass, ModulePass };

  virtual ~Pass() = default;

  /// Returns the unique identifier that corresponds to this pass.
  const PassID *getPassID() const {
    const PassID *id = passIDAndKind.getPointer();
    return id;
  }

  /// Returns the pass info registered for `passID`, or null if unknown.
  static const PassInfo *lookupPassInfo(const PassID *passID);

  /// Returns the pass info registered for the pass class `PassT`, or null if
  /// unknown.
  template <typename PassT> static const PassInfo *lookupPassInfo() {
    const PassID *id = PassID::getID<PassT>();
    return lookupPassInfo(id);
  }

  /// Returns the pass info registered for this pass, or null if unknown.
  const PassInfo *lookupPassInfo() const {
    const PassID *id = getPassID();
    return lookupPassInfo(id);
  }

  /// Returns whether this is a function pass or a module pass.
  Kind getKind() const { return passIDAndKind.getInt(); }

  /// Returns the name of the derived pass.
  virtual StringRef getName() = 0;

protected:
  Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {}

private:
  /// Out-of-line virtual method so that the vtable and type metadata are
  /// emitted into a single object file.
  virtual void anchor();

  /// The pass identifier packed together with the pass kind; a single bit is
  /// enough because `Kind` has exactly two enumerators.
  llvm::PointerIntPair<const PassID *, 1, Kind> passIDAndKind;
};
namespace detail {
class FunctionPassExecutor;
class ModulePassExecutor;

/// Bundles the state of a single execution of a pass over one IR unit: the IR
/// unit itself, a failure flag, the analysis manager for that unit, and the
/// analyses the pass has declared preserved. It gives passes one uniform place
/// to read and initialize their per-run state.
template <typename IRUnitT, typename AnalysisManagerT>
struct PassExecutionState {
  PassExecutionState(IRUnitT *irUnit, AnalysisManagerT &am)
      : irAndPassFailed(irUnit, /*passFailed=*/false), analysisManager(am) {}

  /// The IR unit currently being transformed, packed with a bit that records
  /// whether the pass signaled a failure.
  llvm::PointerIntPair<IRUnitT *, 1, bool> irAndPassFailed;

  /// The analysis manager for the IR unit above.
  AnalysisManagerT &analysisManager;

  /// The analyses marked as preserved during this execution.
  detail::PreservedAnalyses preservedAnalyses;
};
} // namespace detail
/// Pass to transform a specific function within a module. Derived passes should
/// not inherit from this class directly, and instead should use the CRTP
/// FunctionPass class.
class FunctionPassBase : public Pass {
using PassStateT =
detail::PassExecutionState<Function, FunctionAnalysisManager>;
public:
static bool classof(const Pass *pass) {
return pass->getKind() == Kind::FunctionPass;
}
protected:
explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {}
/// The polymorphic API that runs the pass over the currently held function.
virtual void runOnFunction() = 0;
/// A clone method to create a copy of this pass.
virtual FunctionPassBase *clone() const = 0;
/// Return the current function being transformed.
Function &getFunction() {
return *getPassState().irAndPassFailed.getPointer();
}
/// Return the MLIR context for the current function being transformed.
MLIRContext &getContext() { return *getFunction().getContext(); }
/// Returns the current pass state.
PassStateT &getPassState() {
assert(passState && "pass state was never initialized");
return *passState;
}
/// Returns the current analysis manager.
FunctionAnalysisManager &getAnalysisManager() {
return getPassState().analysisManager;
}
private:
/// Forwarding function to execute this pass.
LLVM_NODISCARD
LogicalResult run(Function *fn, FunctionAnalysisManager &fam);
/// The current execution state for the pass.
llvm::Optional<PassStateT> passState;
/// Allow access to 'run'.
friend detail::FunctionPassExecutor;
};
/// Pass to transform a module. Derived passes should not inherit from this
/// class directly, and instead should use the CRTP ModulePass class.
class ModulePassBase : public Pass {
using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
public:
static bool classof(const Pass *pass) {
return pass->getKind() == Kind::ModulePass;
}
protected:
explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {}
/// The polymorphic API that runs the pass over the currently held module.
virtual void runOnModule() = 0;
/// Return the current module being transformed.
Module &getModule() { return *getPassState().irAndPassFailed.getPointer(); }
/// Return the MLIR context for the current module being transformed.
MLIRContext &getContext() { return *getModule().getContext(); }
/// Returns the current pass state.
PassStateT &getPassState() {
assert(passState && "pass state was never initialized");
return *passState;
}
/// Returns the current analysis manager.
ModuleAnalysisManager &getAnalysisManager() {
return getPassState().analysisManager;
}
private:
/// Forwarding function to execute this pass.
LLVM_NODISCARD
LogicalResult run(Module *module, ModuleAnalysisManager &mam);
/// The current execution state for the pass.
llvm::Optional<PassStateT> passState;
/// Allow access to 'run'.
friend detail::ModulePassExecutor;
};
//===----------------------------------------------------------------------===//
// Pass Model Definitions
//===----------------------------------------------------------------------===//
namespace detail {
/// The opaque CRTP model of a pass. This class provides utilities for derived
/// pass execution and handles all of the necessary polymorphic API.
///
/// Template parameters:
///   - IRUnitT:   the IR unit the pass runs on (e.g. Function or Module).
///   - PassT:     the concrete, most-derived pass class (CRTP parameter).
///   - BasePassT: the kind-specific base class, e.g. FunctionPassBase.
template <typename IRUnitT, typename PassT, typename BasePassT>
class PassModel : public BasePassT {
public:
  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const Pass *pass) {
    return pass->getPassID() == PassID::getID<PassT>();
  }

  /// Returns the derived pass name. This is the type name of PassT with one
  /// leading qualifier removed: "mlir::" if present, otherwise an
  /// anonymous-namespace qualifier.
  ///
  /// Public, matching the access of the pure virtual it overrides in Pass.
  StringRef getName() override {
    StringRef name = llvm::getTypeName<PassT>();
    if (name.consume_front("mlir::"))
      return name;
    // llvm::getTypeName returns the compiler's pretty-printed type name, and
    // each compiler spells the anonymous namespace differently: Clang, GCC and
    // MSVC respectively. Strip whichever one is present.
    for (StringRef anonPrefix : {"(anonymous namespace)::", "{anonymous}::",
                                 "`anonymous namespace'::"})
      if (name.consume_front(anonPrefix))
        break;
    return name;
  }

protected:
  PassModel() : BasePassT(PassID::getID<PassT>()) {}

  /// Signal that some invariant was broken when running. The IR is allowed to
  /// be in an invalid state.
  void signalPassFailure() {
    this->getPassState().irAndPassFailed.setInt(true);
  }

  /// Query an analysis for the current ir unit.
  template <typename AnalysisT> AnalysisT &getAnalysis() {
    return this->getAnalysisManager().template getAnalysis<AnalysisT>();
  }

  /// Query a cached instance of an analysis for the current ir unit if one
  /// exists.
  template <typename AnalysisT>
  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() {
    return this->getAnalysisManager().template getCachedAnalysis<AnalysisT>();
  }

  /// Mark all analyses as preserved.
  void markAllAnalysesPreserved() {
    this->getPassState().preservedAnalyses.preserveAll();
  }

  /// Mark the provided analyses as preserved.
  template <typename... AnalysesT> void markAnalysesPreserved() {
    this->getPassState().preservedAnalyses.template preserve<AnalysesT...>();
  }

  /// Mark the analysis identified by `id` as preserved. This is useful when
  /// the analysis type is not statically known.
  void markAnalysesPreserved(const AnalysisID *id) {
    this->getPassState().preservedAnalyses.preserve(id);
  }
};
} // end namespace detail
/// CRTP base for function passes. It layers function-pass specific utilities
/// on top of detail::PassModel.
///
/// Function passes must not:
///   - read or modify any other function within the parent module, because
///     other threads may be manipulating them concurrently;
///   - modify any state within the parent module, which includes adding new
///     functions.
///
/// A derived function pass `T` is expected to provide:
///   - a 'void runOnFunction()' method.
template <typename T>
struct FunctionPass : public detail::PassModel<Function, T, FunctionPassBase> {
  /// Returns the analysis `AnalysisT` of the parent module, if it exists.
  template <typename AnalysisT>
  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedModuleAnalysis() {
    auto &am = this->getAnalysisManager();
    return am.template getCachedModuleAnalysis<AnalysisT>();
  }

  /// Creates a new, heap-allocated copy of the derived pass through its copy
  /// constructor. Nothing here retains the pointer, so the caller takes
  /// ownership of it.
  FunctionPassBase *clone() const override {
    const T &self = *static_cast<const T *>(this);
    return new T(self);
  }
};
/// CRTP base for module passes. It layers module-pass specific utilities on
/// top of detail::PassModel.
///
/// A derived module pass `T` is expected to provide:
///   - a 'void runOnModule()' method.
template <typename T>
struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
  /// Returns the analysis `AnalysisT` for the child function `f`.
  template <typename AnalysisT> AnalysisT &getFunctionAnalysis(Function *f) {
    auto &am = this->getAnalysisManager();
    return am.template getFunctionAnalysis<AnalysisT>(f);
  }

  /// Returns the existing analysis `AnalysisT` for the child function `f`, if
  /// one exists.
  template <typename AnalysisT>
  llvm::Optional<std::reference_wrapper<AnalysisT>>
  getCachedFunctionAnalysis(Function *f) {
    auto &am = this->getAnalysisManager();
    return am.template getCachedFunctionAnalysis<AnalysisT>(f);
  }
};
} // end namespace mlir
#endif // MLIR_PASS_PASS_H