//===- PatternMatch.h - PatternMatcher classes ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_PATTERNMATCHER_H
#define MLIR_PATTERNMATCHER_H
#include "mlir/IR/Builders.h"
namespace mlir {
class PatternRewriter;
//===----------------------------------------------------------------------===//
// PatternBenefit class
//===----------------------------------------------------------------------===//
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65K.  The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// This also has a sentinel representation that can be used for patterns that
/// fail to match.
///
class PatternBenefit {
  /// Reserved value of `representation` that marks a pattern which can never
  /// match.
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /// Implicitly construct a benefit from a plain number.  Declared here and
  /// defined out of line.
  /*implicit*/ PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel value used for patterns that fail to match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }

  /// Return true if this value is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit.  If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  /// Comparisons operate directly on the stored representation, so the
  /// sentinel takes part in the ordering as the ordinary value 65535.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  /// The remaining relational operators are all derived from operator< so the
  /// ordering is defined in exactly one place.
  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
  bool operator<=(const PatternBenefit &rhs) const { return !(rhs < *this); }
  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }

private:
  /// Builds the sentinel; only reachable through impossibleToMatch().
  PatternBenefit() : representation(ImpossibleToMatchSentinel) {}

  unsigned short representation;
};
/// Base class for state that a pattern wants to keep between its match and
/// rewrite phases.  A pattern that needs such state defines its own subclass
/// of this type.
class PatternState {
public:
  virtual ~PatternState() = default;

protected:
  // Construction is restricted to subclasses.
  PatternState() = default;
};
/// This is the type returned by a pattern match.  A match failure is
/// represented by a None value.  A match success is represented by a Some
/// value that holds whatever state the pattern needs to maintain; the state
/// pointer held inside a successful result may itself be null.
using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
//===----------------------------------------------------------------------===//
// Pattern class
//===----------------------------------------------------------------------===//
/// A Pattern is something that can be matched against SSA IR.  How a match is
/// used depends on the subclass and on the driver doing the matching.  For
/// example, RewritePatterns implement a rewrite from one matched pattern to a
/// replacement DAG tile.
class Pattern {
public:
  /// Return the benefit (the inverse of "cost") of matching this pattern.  The
  /// benefit of a Pattern is always static - rewrites that may have dynamic
  /// benefit can be instantiated multiple times (different Pattern instances)
  /// for each benefit that they may return, and be guarded by different match
  /// condition predicates.
  PatternBenefit getBenefit() const { return benefit; }

  /// Return the root node that this pattern matches.  Patterns that can match
  /// multiple root types are instantiated once per root.
  OperationName getRootKind() const { return rootKind; }

  //===--------------------------------------------------------------------===//
  // Implementation hooks for patterns to implement.
  //===--------------------------------------------------------------------===//

  /// Attempt to match against code rooted at the specified operation, which
  /// has the same operation code as getRootKind().  On failure this returns a
  /// None value.  On success it returns a (possibly null) pattern-specific
  /// state wrapped in an Optional.
  virtual PatternMatchResult match(Operation *op) const = 0;

  virtual ~Pattern() = default;

  //===--------------------------------------------------------------------===//
  // Helper methods to simplify pattern implementations
  //===--------------------------------------------------------------------===//

  /// Build the result that signals "no match was found".
  static PatternMatchResult matchFailure() { return None; }

  /// Build the result that signals "a match was found", wrapping the optional
  /// pattern-specific state.  The state defaults to a null pointer.
  PatternMatchResult
  matchSuccess(std::unique_ptr<PatternState> state = {}) const {
    PatternMatchResult success(std::move(state));
    return success;
  }

protected:
  /// Patterns must specify the root operation name they match against, and can
  /// also specify the benefit of the pattern matching.
  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);

private:
  const OperationName rootKind;
  const PatternBenefit benefit;

  virtual void anchor();
};
/// RewritePattern is the common base class for all DAG to DAG replacements.
/// There are two possible usages of this class:
///   * Multi-step RewritePattern with "match" and "rewrite"
///     - By overriding the "match" and "rewrite" functions, the user can
///       separate the concerns of matching and rewriting.
///   * Single-step RewritePattern with "matchAndRewrite"
///     - By overriding the "matchAndRewrite" function, the user can perform
///       the rewrite in the same call as the match.  This removes the need for
///       any PatternState.
///
class RewritePattern : public Pattern {
public:
  /// Rewrite the IR rooted at the specified operation with the result of this
  /// pattern, generating any new operations with the specified rewriter.  The
  /// `state` argument is the pattern-specific state produced by the match.  If
  /// an unexpected error is encountered (an internal compiler error), it is
  /// emitted through the normal MLIR diagnostic hooks and the IR is left in a
  /// valid state.
  virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
                       PatternRewriter &rewriter) const;

  /// Stateless form of the rewrite hook: rewrite the IR rooted at the
  /// specified operation, generating any new operations with the specified
  /// rewriter.  If an unexpected error is encountered (an internal compiler
  /// error), it is emitted through the normal MLIR diagnostic hooks and the IR
  /// is left in a valid state.
  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;

  /// Attempt to match against code rooted at the specified operation, which
  /// has the same operation code as getRootKind().  On failure, this returns a
  /// None value.  On success, it returns a (possibly null) pattern-specific
  /// state wrapped in an Optional.  This state is passed back into the rewrite
  /// function if this match is selected.
  PatternMatchResult match(Operation *op) const override;

  /// Attempt to match against code rooted at the specified operation, which
  /// has the same operation code as getRootKind().  If the match succeeds, the
  /// rewrite is performed immediately with the state the match produced.
  virtual PatternMatchResult matchAndRewrite(Operation *op,
                                             PatternRewriter &rewriter) const {
    auto matched = match(op);
    if (!matched)
      return matchFailure();

    // Hand ownership of the matched state over to the rewrite hook.
    rewrite(op, std::move(*matched), rewriter);
    return matchSuccess();
  }

  /// Return a list of operations that may be generated when rewriting an
  /// operation instance with this pattern.
  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }

protected:
  /// Patterns must specify the root operation name they match against, and can
  /// also specify the benefit of the pattern matching.
  RewritePattern(StringRef rootName, PatternBenefit benefit,
                 MLIRContext *context)
      : Pattern(rootName, benefit, context) {}

  /// Patterns must specify the root operation name they match against, and can
  /// also specify the benefit of the pattern matching.  They can also specify
  /// the names of operations that may be generated during a successful
  /// rewrite.
  RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
                 PatternBenefit benefit, MLIRContext *context);

  /// A list of the potential operations that may be generated when rewriting
  /// an op with this pattern.
  llvm::SmallVector<OperationName, 2> generatedOps;
};
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
  /// The root operation name is taken from SourceOp::getOperationName(); the
  /// benefit of the pattern may optionally be specified and defaults to 1.
  OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
      : RewritePattern(SourceOp::getOperationName(), benefit, context) {}

  /// Wrappers around the RewritePattern methods that cast the raw operation to
  /// SourceOp and dispatch to the SourceOp-typed hooks below.
  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
               PatternRewriter &rewriter) const final {
    SourceOp sourceOp = llvm::cast<SourceOp>(op);
    rewrite(sourceOp, std::move(state), rewriter);
  }
  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
    SourceOp sourceOp = llvm::cast<SourceOp>(op);
    rewrite(sourceOp, rewriter);
  }
  PatternMatchResult match(Operation *op) const final {
    SourceOp sourceOp = llvm::cast<SourceOp>(op);
    return match(sourceOp);
  }
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const final {
    SourceOp sourceOp = llvm::cast<SourceOp>(op);
    return matchAndRewrite(sourceOp, rewriter);
  }

  /// Rewrite and match hooks that operate on the SourceOp type.  A derived
  /// pattern overrides either matchAndRewrite, or match together with one of
  /// the rewrite overloads; the defaults for match and the stateless rewrite
  /// are unreachable.

  /// Stateful rewrite; by default the state is dropped and the stateless
  /// overload is called.
  virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
                       PatternRewriter &rewriter) const {
    rewrite(op, rewriter);
  }
  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
    llvm_unreachable("must override matchAndRewrite or a rewrite method");
  }
  virtual PatternMatchResult match(SourceOp op) const {
    llvm_unreachable("must override match or matchAndRewrite");
  }

  /// Default single-step driver: run match and, if it succeeds, run rewrite
  /// with the state that the match produced.
  virtual PatternMatchResult matchAndRewrite(SourceOp op,
                                             PatternRewriter &rewriter) const {
    auto matched = match(op);
    if (!matched)
      return matchFailure();

    rewrite(op, std::move(*matched), rewriter);
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===//
// PatternRewriter class
//===----------------------------------------------------------------------===//
/// This class coordinates the application of a pattern to the current function,
/// providing a way to create operations and keep track of what gets deleted.
///
/// These class serves two purposes:
/// 1) it is the interface that patterns interact with to make mutations to the
/// IR they are being applied to.
/// 2) It is a base class that clients of the PatternMatcher use when they want
/// to apply patterns and observe their effects (e.g. to keep worklists or
/// other data structures up to date).
///
class PatternRewriter : public OpBuilder {
public:
/// Create operation of specific op type at the current insertion point
/// without verifying to see if it is valid.
template <typename OpTy, typename... Args>
OpTy create(Location location, Args... args) {
OperationState state(location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *op = createOperation(state);
auto result = dyn_cast<OpTy>(op);
assert(result && "Builder didn't return the right type");
return result;
}
/// Creates an operation of specific op type at the current insertion point.
/// If the result is an invalid op (the verifier hook fails), emit an error
/// and return null.
template <typename OpTy, typename... Args>
OpTy createChecked(Location location, Args... args) {
OperationState state(location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *op = createOperation(state);
// If the Operation we produce is valid, return it.
if (!OpTy::verifyInvariants(op)) {
auto result = dyn_cast<OpTy>(op);
assert(result && "Builder didn't return the right type");
return result;
}
// Otherwise, the error message got emitted. Just remove the operation
// we made.
op->erase();
return OpTy();
}
/// This is implemented to create the specified operations and serves as a
/// notification hook for rewriters that want to know about new operations.
virtual Operation *createOperation(const OperationState &state) = 0;
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
// of control to the region and pass it the correct block arguments.
virtual void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values. In addition to replacing and removing the specified operation,
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those values are dead, this will
/// remove them as well.
virtual void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead);
void replaceOp(Operation *op, ArrayRef<Value *> newValues) {
replaceOp(op, newValues, llvm::None);
}
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
template <typename OpTy, typename... Args>
void replaceOpWithNewOp(Operation *op, Args &&... args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
}
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types. This allows
/// specifying a list of ops that may be removed if dead.
template <typename OpTy, typename... Args>
void replaceOpWithNewOp(ArrayRef<Value *> valuesToRemoveIfDead, Operation *op,
Args &&... args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
valuesToRemoveIfDead);
}
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before) {
return block->splitBlock(before);
}
/// This method is used as the final notification hook for patterns that end
/// up modifying the pattern root in place, by changing its operands. This is
/// a minor efficiency win (it avoids creating a new operation and removing
/// the old one) but also often allows simpler code in the client.
///
/// The valuesToRemoveIfDead list is an optional list of values that the
/// rewriter should remove if they are dead at this point.
///
void updatedRootInPlace(Operation *op,
ArrayRef<Value *> valuesToRemoveIfDead = {});
protected:
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
virtual ~PatternRewriter();
// These are the callback methods that subclasses can choose to implement if
// they would like to be notified about certain types of mutations.
/// Notify the pattern rewriter that the specified operation has been mutated
/// in place. This is called after the mutation is done.
virtual void notifyRootUpdated(Operation *op) {}
/// Notify the pattern rewriter that the specified operation is about to be
/// replaced with another set of operations. This is called before the uses
/// of the operation have been changed.
virtual void notifyRootReplaced(Operation *op) {}
/// This is called on an operation that a pattern match is removing, right
/// before the operation is deleted. At this point, the operation has zero
/// uses.
virtual void notifyOperationRemoved(Operation *op) {}
private:
/// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
ArrayRef<Value *> valuesToRemoveIfDead);
};
//===----------------------------------------------------------------------===//
// Pattern-driven rewriters
//===----------------------------------------------------------------------===//
/// A list of rewrite patterns that owns its elements: each pattern is held
/// through a std::unique_ptr inside a std::vector.
class OwningRewritePatternList {
  using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;

public:
  /// Iteration over the owned patterns, in insertion order.
  PatternListT::iterator begin() { return patterns.begin(); }
  PatternListT::iterator end() { return patterns.end(); }
  PatternListT::const_iterator begin() const { return patterns.begin(); }
  PatternListT::const_iterator end() const { return patterns.end(); }

  /// Remove (and destroy) every pattern in the list.
  void clear() { patterns.clear(); }

  //===--------------------------------------------------------------------===//
  // Pattern Insertion
  //===--------------------------------------------------------------------===//

  /// Add an instance of each of the pattern types 'Ts' to the pattern list
  /// with the given arguments.  Every pattern type is constructed from the
  /// same argument list.
  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
  template <typename... Ts, typename ConstructorArg,
            typename... ConstructorArgs,
            typename = std::enable_if_t<sizeof...(Ts) != 0>>
  void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
    // A parameter pack cannot be expanded as a sequence of statements in
    // c++11, so the pack is expanded inside an array initializer instead: each
    // element appends one pattern and then evaluates to 0.  The leading 0
    // keeps the array non-empty.  The arguments are passed as lvalues because
    // the same ones are reused for every type in 'Ts'.
    // FIXME: In c++17 this can be simplified by using 'fold expressions'.
    int expansion[] = {
        0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
    (void)expansion;
  }

private:
  PatternListT patterns;
};
/// This class manages optimization and execution of a group of rewrite
/// patterns, providing an API for finding and applying the best match against
/// a given node.
///
class RewritePatternMatcher {
public:
/// Create a RewritePatternMatcher with the specified set of patterns.
explicit RewritePatternMatcher(const OwningRewritePatternList &patterns);
/// Try to match the given operation to a pattern and rewrite it. Return
/// true if any pattern matches.
bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
private:
// Copy construction and copy assignment are deleted: a matcher cannot be
// copied.
RewritePatternMatcher(const RewritePatternMatcher &) = delete;
void operator=(const RewritePatternMatcher &) = delete;
/// The group of patterns that are matched for optimization through this
/// matcher. These are raw, non-owning pointers.
/// NOTE(review): presumably they point into the OwningRewritePatternList
/// given to the constructor, which would then have to outlive this matcher -
/// confirm against the out-of-line constructor.
std::vector<RewritePattern *> patterns;
};
/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return true if no more patterns can be matched in
/// the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself.
/// Note: This method also performs folding and simple dead-code elimination
/// before attempting to match any of the provided patterns.
///
bool applyPatternsGreedily(Operation *op,
const OwningRewritePatternList &patterns);
} // end namespace mlir
#endif // MLIR_PATTERNMATCHER_H