| //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| |
| #ifndef MLIR_IR_SYMBOLTABLE_H |
| #define MLIR_IR_SYMBOLTABLE_H |
| |
| #include "mlir/IR/OpDefinition.h" |
| #include "llvm/ADT/StringMap.h" |
| |
| namespace mlir { |
| class Identifier; |
| class MLIRContext; |
| class Operation; |
| |
| /// This class allows for representing and managing the symbol table used by |
| /// operations with the 'SymbolTable' trait. |
| class SymbolTable { |
| public: |
| /// Build a symbol table with the symbols within the given operation. |
| SymbolTable(Operation *op); |
| |
| /// Look up a symbol with the specified name, returning null if no such |
| /// name exists. Names never include the @ on them. |
| Operation *lookup(StringRef name) const; |
| template <typename T> T lookup(StringRef name) const { |
| return dyn_cast_or_null<T>(lookup(name)); |
| } |
| |
| /// Erase the given symbol from the table. |
| void erase(Operation *symbol); |
| |
| /// Insert a new symbol into the table, and rename it as necessary to avoid |
| /// collisions. |
| void insert(Operation *symbol); |
| |
| /// Returns the context held by this symbol table. |
| MLIRContext *getContext() const { return context; } |
| |
| /// Return the name of the attribute used for symbol names. |
| static StringRef getSymbolAttrName() { return "sym_name"; } |
| |
| private: |
| MLIRContext *context; |
| |
| /// This is a mapping from a name to the symbol with that name. |
| llvm::StringMap<Operation *> symbolTable; |
| |
| /// This is used when name conflicts are detected. |
| unsigned uniquingCounter = 0; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // SymbolTable Trait Types |
| //===----------------------------------------------------------------------===// |
| |
| namespace OpTrait { |
| namespace impl { |
| LogicalResult verifySymbolTable(Operation *op); |
| } // namespace impl |
| |
| /// A trait used to provide symbol table functionalities to a region operation. |
| /// This operation must hold exactly 1 region. Once attached, all operations |
| /// that are directly within the region, i.e not including those within child |
| /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will |
| /// be verified to ensure that the names are uniqued. |
| template <typename ConcreteType> |
| class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySymbolTable(op); |
| } |
| |
| /// Look up a symbol with the specified name, returning null if no such |
| /// name exists. Symbol names never include the @ on them. Note: This |
| /// performs a linear scan of held symbols. |
| Operation *lookupSymbol(StringRef name) { |
| // Look for a symbol with the given name. |
| for (auto &block : this->getOperation()->getRegion(0)) { |
| for (auto &op : block) { |
| auto nameAttr = op.template getAttrOfType<StringAttr>( |
| mlir::SymbolTable::getSymbolAttrName()); |
| if (nameAttr && nameAttr.getValue() == name) |
| return &op; |
| } |
| } |
| return nullptr; |
| } |
| template <typename T> T lookupSymbol(StringRef name) { |
| return dyn_cast_or_null<T>(lookupSymbol(name)); |
| } |
| }; |
| } // end namespace OpTrait |
| } // end namespace mlir |
| |
| #endif // MLIR_IR_SYMBOLTABLE_H |