| //===- PatternMatch.cpp - Base classes for pattern match ------------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| |
| #include "mlir/IR/SSAValue.h" |
| #include "mlir/IR/Statements.h" |
| #include "mlir/StandardOps/StandardOps.h" |
| #include "mlir/Transforms/PatternMatch.h" |
| using namespace mlir; |
| |
| PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { |
| assert(representation == benefit && benefit != ImpossibleToMatchSentinel && |
| "This pattern match benefit is too large to represent"); |
| } |
| |
| unsigned short PatternBenefit::getBenefit() const { |
| assert(representation != ImpossibleToMatchSentinel && |
| "Pattern doesn't match"); |
| return representation; |
| } |
| |
| bool PatternBenefit::operator==(const PatternBenefit& other) { |
| if (isImpossibleToMatch()) |
| return other.isImpossibleToMatch(); |
| if (other.isImpossibleToMatch()) |
| return false; |
| return getBenefit() == other.getBenefit(); |
| } |
| |
| bool PatternBenefit::operator!=(const PatternBenefit& other) { |
| return !(*this == other); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pattern implementation |
| //===----------------------------------------------------------------------===// |
| |
| Pattern::Pattern(StringRef rootName, MLIRContext *context, |
| Optional<PatternBenefit> staticBenefit) |
| : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) { |
| } |
| |
| Pattern::Pattern(StringRef rootName, MLIRContext *context, |
| unsigned staticBenefit) |
| : rootKind(rootName, context), staticBenefit(staticBenefit) {} |
| |
| Optional<PatternBenefit> Pattern::getStaticBenefit() const { |
| return staticBenefit; |
| } |
| |
| OperationName Pattern::getRootKind() const { return rootKind; } |
| |
| void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state, |
| PatternRewriter &rewriter) const { |
| rewrite(op, rewriter); |
| } |
| |
| void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const { |
| llvm_unreachable("need to implement one of the rewrite functions!"); |
| } |
| |
| /// This method indicates that no match was found. |
| PatternMatchResult Pattern::matchFailure() { |
| return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()}; |
| } |
| |
| /// This method indicates that a match was found and has the specified cost. |
| PatternMatchResult |
| Pattern::matchSuccess(PatternBenefit benefit, |
| std::unique_ptr<PatternState> state) const { |
| assert((!getStaticBenefit().hasValue() || |
| getStaticBenefit().getValue() == benefit) && |
| "This version of matchSuccess must be called with a benefit that " |
| "matches the static benefit if set!"); |
| |
| return {benefit, std::move(state)}; |
| } |
| |
| /// This method indicates that a match was found for patterns that have a |
| /// known static benefit. |
| PatternMatchResult |
| Pattern::matchSuccess(std::unique_ptr<PatternState> state) const { |
| auto benefit = getStaticBenefit(); |
| assert(benefit.hasValue() && "Pattern doesn't have a static benefit"); |
| return matchSuccess(benefit.getValue(), std::move(state)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PatternRewriter implementation |
| //===----------------------------------------------------------------------===// |
| |
| PatternRewriter::~PatternRewriter() { |
| // Out of line to provide a vtable anchor for the class. |
| } |
| |
| /// This method is used as the final replacement hook for patterns that match |
| /// a single result value. In addition to replacing and removing the |
| /// specified operation, clients can specify a list of other nodes that this |
| /// replacement may make (perhaps transitively) dead. If any of those ops are |
| /// dead, this will remove them as well. |
| void PatternRewriter::replaceSingleResultOp( |
| Operation *op, SSAValue *newValue, ArrayRef<SSAValue *> opsToRemoveIfDead) { |
| // Notify the rewriter subclass that we're about to replace this root. |
| notifyRootReplaced(op); |
| |
| assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!"); |
| op->getResult(0)->replaceAllUsesWith(newValue); |
| |
| notifyOperationRemoved(op); |
| op->erase(); |
| |
| // TODO: Process the opsToRemoveIfDead list, removing things and calling the |
| // notifyOperationRemoved hook in the process. |
| } |
| |
| /// This method is used as the final notification hook for patterns that end |
| /// up modifying the pattern root in place, by changing its operands. This is |
| /// a minor efficiency win (it avoids creating a new instruction and removing |
| /// the old one) but also often allows simpler code in the client. |
| /// |
| /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter |
| /// should remove if they are dead at this point. |
| /// |
| void PatternRewriter::updatedRootInPlace( |
| Operation *op, ArrayRef<SSAValue *> opsToRemoveIfDead) { |
| // Notify the rewriter subclass that we're about to replace this root. |
| notifyRootUpdated(op); |
| |
| // TODO: Process the opsToRemoveIfDead list, removing things and calling the |
| // notifyOperationRemoved hook in the process. |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PatternMatcher implementation |
| //===----------------------------------------------------------------------===// |
| |
| /// Find the highest benefit pattern available in the pattern set for the DAG |
| /// rooted at the specified node. This returns the pattern if found, or null |
| /// if there are no matches. |
| auto PatternMatcher::findMatch(Operation *op) -> MatchResult { |
| // TODO: This is a completely trivial implementation, expand this in the |
| // future. |
| |
| // Keep track of the best match, the benefit of it, and any matcher specific |
| // state it is maintaining. |
| MatchResult bestMatch = {nullptr, nullptr}; |
| Optional<PatternBenefit> bestBenefit; |
| |
| for (auto *pattern : patterns) { |
| // Ignore patterns that are for the wrong root. |
| if (pattern->getRootKind() != op->getName()) |
| continue; |
| |
| // If we know the static cost of the pattern is worse than what we've |
| // already found then don't run it. |
| auto staticBenefit = pattern->getStaticBenefit(); |
| if (staticBenefit.hasValue() && bestBenefit.hasValue() && |
| staticBenefit.getValue().getBenefit() < |
| bestBenefit.getValue().getBenefit()) |
| continue; |
| |
| // Check to see if this pattern matches this node. |
| auto result = pattern->match(op); |
| auto benefit = result.first; |
| |
| // If this pattern failed to match, ignore it. |
| if (benefit.isImpossibleToMatch()) |
| continue; |
| |
| // If it matched but had lower benefit than our best match so far, then |
| // ignore it. |
| if (bestBenefit.hasValue() && |
| benefit.getBenefit() < bestBenefit.getValue().getBenefit()) |
| continue; |
| |
| // Okay we found a match that is better than our previous one, remember it. |
| bestBenefit = benefit; |
| bestMatch = {pattern, std::move(result.second)}; |
| } |
| |
| // If we found any match, return it. |
| return bestMatch; |
| } |