//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {

/// Underlying storage for NestedMatch. The `matches` array holds
/// (Instruction *, NestedMatch) entries allocated from NestedMatch's arena.
struct NestedMatchStorage {
  MutableArrayRef<NestedMatch::EntryType> matches;
};

/// Underlying storage for NestedPattern.
struct NestedPatternStorage {
  NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c,
                       FilterFunctionType filter, Instruction *skip)
      : kind(k), nestedPatterns(c), filter(filter), skip(skip) {}

  Instruction::Kind kind;
  ArrayRef<NestedPattern> nestedPatterns;
  FilterFunctionType filter;
  /// `skip` is needed so that we can implement match without switching on
  /// the type of the Instruction.
  /// The idea is that a NestedPattern first checks if it matches locally
  /// and then recursively applies its nested patterns to everything nested
  /// under `elem`.
  /// Since we want to rely on the InstWalker impl rather than duplicate its
  /// logic, we allow an off-by-one traversal to account for the fact that we
  /// write:
  ///
  ///   void match(Instruction *elem) {
  ///     for (auto &c : getNestedPatterns()) {
  ///       NestedPattern childPattern(c);
  ///       childPattern.storage->skip = elem; // Needs the off-by-one skip ...
  ///       childPattern.walkPostOrder(elem);  // ... because the walk starts
  ///     }                                    // at `elem` itself.
  ///   }
  Instruction *skip;
};

} // end namespace mlir

using namespace mlir;

llvm::BumpPtrAllocator *&NestedMatch::allocator() {
  static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
  return allocator;
}
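
// A minimal usage sketch (an assumption about the intended setup, not part of
// this file): callers install a BumpPtrAllocator before building patterns and
// matches, and reset the pointers once the results are no longer needed.
//   llvm::BumpPtrAllocator allocator;
//   NestedMatch::allocator() = &allocator;
//   NestedPattern::allocator() = &allocator;
//   ... build patterns, run matches, consume results ...
//   NestedMatch::allocator() = nullptr;
//   NestedPattern::allocator() = nullptr;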

/// Copies `elements` into memory owned by the arena allocator and returns a
/// NestedMatch wrapping that storage.
NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) {
  auto *matches =
      allocator()->Allocate<NestedMatch::EntryType>(elements.size());
  std::uninitialized_copy(elements.begin(), elements.end(), matches);
  auto *storage = allocator()->Allocate<NestedMatchStorage>();
  new (storage) NestedMatchStorage();
  storage->matches =
      MutableArrayRef<NestedMatch::EntryType>(matches, elements.size());
  auto *result = allocator()->Allocate<NestedMatch>();
  new (result) NestedMatch(storage);
  return *result;
}

NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); }
NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); }
NestedMatch::EntryType &NestedMatch::front() {
  return *storage->matches.begin();
}
NestedMatch::EntryType &NestedMatch::back() {
  return *(storage->matches.begin() + size() - 1);
}

/// Walks `function` post-order and returns the accumulated matches.
NestedMatch NestedPattern::match(Function *function) {
  assert(!matches && "NestedPattern already matched!");
  this->walkPostOrder(function);
  return matches;
}

/// Walks `instruction` post-order and returns the accumulated matches.
NestedMatch NestedPattern::match(Instruction *instruction) {
  assert(!matches && "NestedPattern already matched!");
  this->walkPostOrder(instruction);
  return matches;
}
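
// Example (a hedged sketch; `f` is a hypothetical Function *): match every
// 2-D loop nest whose innermost body contains a load or a store.
//   auto loadsAndStores =
//       matcher::For(matcher::For(matcher::Op(matcher::isLoadOrStore)));
//   NestedMatch m = loadsAndStores.match(f);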

unsigned NestedPattern::getDepth() {
  auto nested = getNestedPatterns();
  if (nested.empty()) {
    return 1;
  }
  unsigned depth = 0;
  for (auto c : nested) {
    depth = std::max(depth, c.getDepth());
  }
  return depth + 1;
}
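
// For instance, matcher::For(matcher::For(matcher::Op(...))) has depth 3:
// the Op pattern contributes 1 and each enclosing For adds one level.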

/// Matches a single instruction in the following way:
///   1. checks the kind of instruction against the matcher, if different then
///      there is no match;
///   2. calls the customizable filter function to refine the single
///      instruction match with extra semantic constraints;
///   3. if all is good, recursively matches the nested patterns;
///   4. if all nested match then the single instruction matches too and is
///      appended to the list of matches;
///   5. TODO(ntv) Optionally applies actions (lambda), in which case we will
///      want to traverse in post-order DFS to avoid invalidating iterators.
void NestedPattern::matchOne(Instruction *elem) {
  if (storage->skip == elem) {
    return;
  }
  // Structural filter.
  if (elem->getKind() != getKind()) {
    return;
  }
  // Local custom filter function.
  if (!getFilterFunction()(*elem)) {
    return;
  }

  SmallVector<NestedMatch::EntryType, 8> nestedEntries;
  for (auto c : getNestedPatterns()) {
    // Create a new nested pattern here because a pattern holds its results:
    // we concretely need one copy of a given matcher per matching result.
    NestedPattern nestedPattern(c);
    // Skip elem in the walk that immediately follows. Without this we would
    // essentially need to reimplement walkPostOrder here.
    nestedPattern.storage->skip = elem;
    nestedPattern.walkPostOrder(elem);
    // All nested patterns must match for `elem` to match.
    if (!nestedPattern.matches) {
      return;
    }
    for (auto m : nestedPattern.matches) {
      nestedEntries.push_back(m);
    }
  }

  // Append (elem, its nested matches) to the accumulated matches.
  SmallVector<NestedMatch::EntryType, 8> newEntries(
      matches.storage->matches.begin(), matches.storage->matches.end());
  newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries)));
  matches = NestedMatch::build(newEntries);
}
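
// The resulting NestedMatch mirrors the pattern's nesting (a hedged sketch of
// how results can be consumed):
//   for (auto &entry : m) {
//     Instruction *matched = entry.first; // the matched instruction
//     NestedMatch inner = entry.second;   // matches of its nested patterns
//   }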

llvm::BumpPtrAllocator *&NestedPattern::allocator() {
  static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
  return allocator;
}

NestedPattern::NestedPattern(Instruction::Kind k,
                             ArrayRef<NestedPattern> nested,
                             FilterFunctionType filter)
    : storage(allocator()->Allocate<NestedPatternStorage>()),
      matches(NestedMatch::build({})) {
  auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size());
  std::uninitialized_copy(nested.begin(), nested.end(), newChildren);
  // Initialize with placement new.
  new (storage) NestedPatternStorage(
      k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter,
      nullptr /* skip */);
}

Instruction::Kind NestedPattern::getKind() { return storage->kind; }

ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() {
  return storage->nestedPatterns;
}

FilterFunctionType NestedPattern::getFilterFunction() {
  return storage->filter;
}

static bool isAffineIfOp(const Instruction &inst) {
  return isa<OperationInst>(inst) &&
         cast<OperationInst>(inst).isa<AffineIfOp>();
}

namespace mlir {
namespace matcher {

NestedPattern Op(FilterFunctionType filter) {
  return NestedPattern(Instruction::Kind::OperationInst, {}, filter);
}

NestedPattern If(NestedPattern child) {
  return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
  return NestedPattern(Instruction::Kind::OperationInst, child,
                       [filter](const Instruction &inst) {
                         return isAffineIfOp(inst) && filter(inst);
                       });
}
NestedPattern If(ArrayRef<NestedPattern> nested) {
  return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
  return NestedPattern(Instruction::Kind::OperationInst, nested,
                       [filter](const Instruction &inst) {
                         return isAffineIfOp(inst) && filter(inst);
                       });
}

NestedPattern For(NestedPattern child) {
  return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction);
}
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
  return NestedPattern(Instruction::Kind::For, child, filter);
}
NestedPattern For(ArrayRef<NestedPattern> nested) {
  return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction);
}
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
  return NestedPattern(Instruction::Kind::For, nested, filter);
}
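
// Composition example (a hedged sketch): a doubly nested loop whose body
// contains an affine if that guards a load or a store.
//   auto pattern = For(For(If(Op(isLoadOrStore))));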

// TODO(ntv): parallel annotation on loops.
bool isParallelLoop(const Instruction &inst) {
  // The cast asserts that `inst` is a ForInst; the predicate is trivially
  // true until loops carry a parallel annotation.
  const auto *loop = cast<ForInst>(&inst);
  return (void *)loop || true; // loop->isParallel();
}

// TODO(ntv): reduction annotation on loops.
bool isReductionLoop(const Instruction &inst) {
  // Same placeholder as above, pending a reduction annotation on loops.
  const auto *loop = cast<ForInst>(&inst);
  return (void *)loop || true; // loop->isReduction();
}

bool isLoadOrStore(const Instruction &inst) {
  const auto *opInst = dyn_cast<OperationInst>(&inst);
  return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
}

} // end namespace matcher
} // end namespace mlir