blob: 88268094204182393d2cdd196b66ac8d40ab0df4 [file] [log] [blame]
//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_IR_SYMBOLTABLE_H
#define MLIR_IR_SYMBOLTABLE_H
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
class Identifier;
class MLIRContext;
class Operation;
/// This class allows for representing and managing the symbol table used by
/// operations with the 'SymbolTable' trait.
class SymbolTable {
public:
/// Build a symbol table with the symbols within the given operation.
SymbolTable(Operation *op);
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
Operation *lookup(StringRef name) const;
template <typename T> T lookup(StringRef name) const {
return dyn_cast_or_null<T>(lookup(name));
}
/// Erase the given symbol from the table.
void erase(Operation *symbol);
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
void insert(Operation *symbol);
/// Returns the context held by this symbol table.
MLIRContext *getContext() const { return context; }
/// Return the name of the attribute used for symbol names.
static StringRef getSymbolAttrName() { return "sym_name"; }
private:
MLIRContext *context;
/// This is a mapping from a name to the symbol with that name.
llvm::StringMap<Operation *> symbolTable;
/// This is used when name conflicts are detected.
unsigned uniquingCounter = 0;
};
//===----------------------------------------------------------------------===//
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
namespace OpTrait {
namespace impl {
/// Out-of-line verifier that the SymbolTable trait's `verifyTrait` forwards
/// to, passing the operation the trait is attached to. Only the declaration
/// is visible in this header.
/// NOTE(review): presumably this fails when the operation does not hold
/// exactly one region or when symbol names directly within that region are
/// not unique, as the trait documentation describes -- confirm against the
/// definition.
LogicalResult verifySymbolTable(Operation *op);
} // namespace impl
/// Trait that gives a region operation symbol table functionality. The
/// operation must hold exactly 1 region. Once the trait is attached, every
/// operation directly within that region (i.e. not those nested within child
/// regions) that carries a 'SymbolTable::getSymbolAttrName()' StringAttr is
/// verified to ensure that the names are uniqued.
template <typename ConcreteType>
class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
public:
  /// Trait verification entry point; forwards to `impl::verifySymbolTable`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySymbolTable(op);
  }

  /// Find the symbol called `name` among the operations of region 0, returning
  /// null if there is no such symbol. Symbol names never include the leading
  /// '@'. Note: every call walks the operations of each block in the region
  /// one by one until a match is found.
  Operation *lookupSymbol(StringRef name) {
    auto &&symbolRegion = this->getOperation()->getRegion(0);
    for (auto &block : symbolRegion) {
      for (auto &candidate : block) {
        auto symName = candidate.template getAttrOfType<StringAttr>(
            mlir::SymbolTable::getSymbolAttrName());
        // Skip operations that have no string symbol name, or whose name
        // differs from the one requested.
        if (!symName || symName.getValue() != name)
          continue;
        return &candidate;
      }
    }
    return nullptr;
  }

  /// Typed variant of `lookupSymbol`: performs the untyped lookup and passes
  /// the result through `dyn_cast_or_null<T>`.
  template <typename T> T lookupSymbol(StringRef name) {
    Operation *found = lookupSymbol(name);
    return dyn_cast_or_null<T>(found);
  }
};
} // end namespace OpTrait
} // end namespace mlir
#endif // MLIR_IR_SYMBOLTABLE_H