blob: aa5a14e75d6f75a8f9ae3fb1a30faf2594bb1044 [file] [log] [blame]
//===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_PATTERNMATCHER_H
#define MLIR_PATTERNMATCHER_H
#include "mlir/IR/Builders.h"
namespace mlir {
class PatternRewriter;
//===----------------------------------------------------------------------===//
// PatternBenefit class
//===----------------------------------------------------------------------===//
/// Describes how profitable it is to apply a pattern, expressed as a unitless
/// score running from 0 (very little benefit) up to 65K. The customary unit
/// is the "number of operations matched" by the pattern.
///
/// One representation value is reserved as a sentinel for patterns that fail
/// to match; see impossibleToMatch() and isImpossibleToMatch().
///
class PatternBenefit {
public:
  /// Build a benefit from a plain integer score. The conversion is
  /// deliberately implicit. Defined out of line.
  /*implicit*/ PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Produce the sentinel value used for patterns that fail to match.
  static PatternBenefit impossibleToMatch() { return {}; }

  /// Return true when this object holds the "fails to match" sentinel.
  bool isImpossibleToMatch() const {
    return ImpossibleToMatchSentinel == representation;
  }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  // NOTE(review): these comparison operators are not const-qualified, so they
  // cannot be called with a const PatternBenefit on the left-hand side. They
  // are also declared `inline` while no definition is visible in this header;
  // confirm where they are defined before using them from other translation
  // units.
  inline bool operator==(const PatternBenefit& other);
  inline bool operator!=(const PatternBenefit& other);

private:
  /// Representation value reserved to mean "this pattern fails to match".
  enum { ImpossibleToMatchSentinel = 65535 };

  /// Only impossibleToMatch() uses this: it builds the sentinel value.
  PatternBenefit() : representation(ImpossibleToMatchSentinel) {}

  /// Either a benefit score or ImpossibleToMatchSentinel.
  unsigned short representation;
};
/// Base class for any data a pattern wants to carry from its match phase over
/// to its rewrite phase. A pattern that needs such data defines its own
/// pattern-specific subclass of this type.
class PatternState {
public:
  /// Virtual so that a subclass can be destroyed through a PatternState
  /// pointer.
  virtual ~PatternState() = default;

protected:
  // Must be subclassed: construction is only available to derived classes.
  PatternState() = default;
};
/// This is the type returned by a pattern match. A match failure returns a
/// None value. A match success returns a Some value wrapping whatever
/// PatternState the pattern wants to keep between match and rewrite; the
/// wrapped pointer may itself be null when the pattern has no state to keep.
using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
//===----------------------------------------------------------------------===//
// Pattern class
//===----------------------------------------------------------------------===//
/// A Pattern describes something that can be matched against SSA IR. What a
/// successful match is used for depends on the subclass and on the driver
/// doing the matching. For example, RewritePatterns implement a rewrite from
/// one matched pattern to a replacement DAG tile.
class Pattern {
public:
  /// Return the benefit (the inverse of "cost") of matching this pattern. The
  /// benefit of a Pattern is always static: a rewrite whose benefit would be
  /// dynamic can instead be instantiated multiple times (as different Pattern
  /// instances), once per benefit it may return, each guarded by a different
  /// match condition predicate.
  PatternBenefit getBenefit() const { return benefit; }

  /// Return the root node that this pattern matches. Patterns that can match
  /// multiple root types are instantiated once per root.
  OperationName getRootKind() const { return rootKind; }

  //===--------------------------------------------------------------------===//
  // Implementation hooks for patterns to implement.
  //===--------------------------------------------------------------------===//

  /// Attempt to match against code rooted at the specified operation, which
  /// is the same operation code as getRootKind(). On failure, this returns a
  /// None value. On success it returns a (possibly null) pattern-specific
  /// state wrapped in a Some. This state is passed back into the pattern's
  /// rewrite function if this match is selected.
  virtual PatternMatchResult match(OperationInst *op) const = 0;

  virtual ~Pattern() = default;

  //===--------------------------------------------------------------------===//
  // Helper methods to simplify pattern implementations
  //===--------------------------------------------------------------------===//

  /// This method indicates that no match was found.
  static PatternMatchResult matchFailure() { return None; }

  /// This method indicates that a match was found. The optional `state` is
  /// the pattern-specific state to wrap in the result; it defaults to null.
  PatternMatchResult
  matchSuccess(std::unique_ptr<PatternState> state = {}) const {
    PatternMatchResult matched(std::move(state));
    return matched;
  }

protected:
  /// Patterns must specify the root operation name they match against, and
  /// can also specify the benefit of the pattern matching. Defined out of
  /// line.
  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);

private:
  /// The kind of operation this pattern is rooted at.
  const OperationName rootKind;

  /// The static benefit reported by getBenefit().
  const PatternBenefit benefit;

  /// Out-of-line virtual method; its definition is not in this header.
  virtual void anchor();
};
/// RewritePattern is the common base class for all DAG to DAG replacements.
/// After a RewritePattern is matched, its replacement is performed by invoking
/// the "rewrite" method that the instance implements.
///
/// Two "rewrite" overloads are declared: one that receives the PatternState
/// produced by the match, and one that does not. Neither is pure virtual;
/// their default bodies are not visible in this header.
///
class RewritePattern : public Pattern {
public:
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// rewriter. `state` is the pattern-specific state that the match returned
/// (it may be null). If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// rewriter. This overload takes no match state. If an unexpected error is
/// encountered (an internal compiler error), it is emitted through the
/// normal MLIR diagnostic hooks and the IR is left in a valid state.
virtual void rewrite(OperationInst *op, PatternRewriter &rewriter) const;
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching. All three arguments are
/// forwarded unchanged to the Pattern base-class constructor.
RewritePattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: Pattern(rootName, benefit, context) {}
};
//===----------------------------------------------------------------------===//
// PatternRewriter class
//===----------------------------------------------------------------------===//
/// This class coordinates the application of a pattern to the current function,
/// providing a way to create operations and keep track of what gets deleted.
///
/// This class serves two purposes:
///  1) It is the interface that patterns interact with to make mutations to
///     the IR they are being applied to.
///  2) It is a base class that clients of the PatternMatcher use when they
///     want to apply patterns and observe their effects (e.g. to keep
///     worklists or other data structures up to date).
///
class PatternRewriter : public Builder {
public:
  /// Create operation of specific op type at the current insertion point
  /// without verifying to see if it is valid.
  template <typename OpTy, typename... Args>
  OpPointer<OpTy> create(Location location, Args... args) {
    OperationState opState(getContext(), location, OpTy::getOperationName());
    OpTy::build(this, &opState, args...);
    // Go through the createOperation() hook so subclasses see the new op.
    auto *inst = createOperation(opState);
    auto typedOp = inst->dyn_cast<OpTy>();
    assert(typedOp && "Builder didn't return the right type");
    return typedOp;
  }

  /// Creates an operation of specific op type at the current insertion point.
  /// If the result is an invalid op (the verifier hook fails), emit an error
  /// and return null.
  template <typename OpTy, typename... Args>
  OpPointer<OpTy> createChecked(Location location, Args... args) {
    OperationState opState(getContext(), location, OpTy::getOperationName());
    OpTy::build(this, &opState, args...);
    auto *inst = createOperation(opState);

    // A true result from verifyInvariants() means the instruction is invalid.
    // The error message got emitted already, so just remove the instruction
    // we made and hand back a null OpPointer.
    if (OpTy::verifyInvariants(inst)) {
      inst->erase();
      return OpPointer<OpTy>();
    }

    // The instruction we produced is valid: return it as the requested type.
    auto typedOp = inst->dyn_cast<OpTy>();
    assert(typedOp && "Builder didn't return the right type");
    return typedOp;
  }

  /// This method performs the final replacement for a pattern, where the
  /// results of the operation are updated to use the specified list of SSA
  /// values. In addition to replacing and removing the specified operation,
  /// clients can specify a list of other nodes that this replacement may make
  /// (perhaps transitively) dead. If any of those values are dead, this will
  /// remove them as well.
  void replaceOp(OperationInst *op, ArrayRef<Value *> newValues,
                 ArrayRef<Value *> valuesToRemoveIfDead = {});

  /// Replaces the result op with a new op that is created without
  /// verification. The result values of the two ops must be the same types.
  template <typename OpTy, typename... Args>
  void replaceOpWithNewOp(OperationInst *op, Args... args) {
    auto replacement = create<OpTy>(op->getLoc(), args...);
    replaceOpWithResultsOfAnotherOp(op, replacement->getInstruction(), {});
  }

  /// Replaces the result op with a new op that is created without
  /// verification. The result values of the two ops must be the same types.
  /// This allows specifying a list of ops that may be removed if dead.
  template <typename OpTy, typename... Args>
  void replaceOpWithNewOp(OperationInst *op,
                          ArrayRef<Value *> valuesToRemoveIfDead,
                          Args... args) {
    auto replacement = create<OpTy>(op->getLoc(), args...);
    replaceOpWithResultsOfAnotherOp(op, replacement->getInstruction(),
                                    valuesToRemoveIfDead);
  }

  /// This method is used as the final notification hook for patterns that end
  /// up modifying the pattern root in place, by changing its operands. This
  /// is a minor efficiency win (it avoids creating a new instruction and
  /// removing the old one) but also often allows simpler code in the client.
  ///
  /// The valuesToRemoveIfDead list is an optional list of values that the
  /// rewriter should remove if they are dead at this point.
  ///
  void updatedRootInPlace(OperationInst *op,
                          ArrayRef<Value *> valuesToRemoveIfDead = {});

protected:
  PatternRewriter(MLIRContext *context) : Builder(context) {}
  virtual ~PatternRewriter();

  // These are the callback methods that subclasses can choose to implement if
  // they would like to be notified about certain types of mutations.

  /// This is implemented to create the specified operations and serves as a
  /// notification hook for rewriters that want to know about new operations.
  virtual OperationInst *createOperation(const OperationState &state) = 0;

  /// Notify the pattern rewriter that the specified operation has been
  /// mutated in place. This is called after the mutation is done. The default
  /// implementation does nothing.
  virtual void notifyRootUpdated(OperationInst *op) {}

  /// Notify the pattern rewriter that the specified operation is about to be
  /// replaced with another set of operations. This is called before the uses
  /// of the operation have been changed. The default implementation does
  /// nothing.
  virtual void notifyRootReplaced(OperationInst *op) {}

  /// This is called on an operation that a pattern match is removing, right
  /// before the operation is deleted. At this point, the operation has zero
  /// uses. The default implementation does nothing.
  virtual void notifyOperationRemoved(OperationInst *op) {}

private:
  /// op and newOp are known to have the same number of results, replace the
  /// uses of op with uses of newOp.
  void replaceOpWithResultsOfAnotherOp(OperationInst *op, OperationInst *newOp,
                                       ArrayRef<Value *> valuesToRemoveIfDead);
};
//===----------------------------------------------------------------------===//
// PatternMatcher class
//===----------------------------------------------------------------------===//
/// This is a vector that owns the patterns inside of it: each element is a
/// std::unique_ptr to a Pattern.
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
/// This class manages optimization and execution of a group of patterns,
/// providing an API for finding the best match against a given node.
///
class PatternMatcher {
public:
  /// What findMatch() returns: the pattern that matched together with any
  /// state that pattern produced.
  using MatchResult = std::pair<Pattern *, std::unique_ptr<PatternState>>;

  /// Create a PatternMatcher with the specified set of patterns.
  explicit PatternMatcher(OwningPatternList &&patterns)
      : patterns(std::move(patterns)) {}

  /// Support matching from subclasses of Pattern: every element of the given
  /// list is moved into this matcher's own pattern list.
  template <typename T>
  explicit PatternMatcher(std::vector<std::unique_ptr<T>> &&patternSubclasses) {
    patterns.reserve(patternSubclasses.size());
    for (std::unique_ptr<T> &subclassPattern : patternSubclasses)
      patterns.push_back(std::move(subclassPattern));
  }

  /// Find the highest benefit pattern available in the pattern set for the
  /// DAG rooted at the specified node. This returns the pattern (and any
  /// state it needs) if found, or null if there are no matches.
  MatchResult findMatch(OperationInst *op);

private:
  // A matcher is not copyable.
  PatternMatcher(const PatternMatcher &) = delete;
  void operator=(const PatternMatcher &) = delete;

  /// The group of patterns that are matched for optimization through this
  /// matcher.
  OwningPatternList patterns;
};
//===----------------------------------------------------------------------===//
// Pattern-driven rewriters
//===----------------------------------------------------------------------===//
/// This is a vector that owns the rewrite patterns inside of it: each element
/// is a std::unique_ptr to a RewritePattern.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner.
///
/// `fn` is the function to rewrite; `patterns` is the list of rewrite patterns
/// to apply, passed as an rvalue reference. The definition is not in this
/// header.
///
void applyPatternsGreedily(Function *fn, OwningRewritePatternList &&patterns);
} // end namespace mlir
#endif // MLIR_PATTERNMATCHER_H