| //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/IntegerSet.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/StandardTypes.h" |
| using namespace mlir; |
| |
| Builder::Builder(Module *module) : context(module->getContext()) {} |
| |
| Identifier Builder::getIdentifier(StringRef str) { |
| return Identifier::get(str, context); |
| } |
| |
| Module *Builder::createModule() { return new Module(context); } |
| |
| //===----------------------------------------------------------------------===// |
| // Locations. |
| //===----------------------------------------------------------------------===// |
| |
| UnknownLoc Builder::getUnknownLoc() { return UnknownLoc::get(context); } |
| |
| UniquedFilename Builder::getUniquedFilename(StringRef filename) { |
| return UniquedFilename::get(filename, context); |
| } |
| |
| FileLineColLoc Builder::getFileLineColLoc(UniquedFilename filename, |
| unsigned line, unsigned column) { |
| return FileLineColLoc::get(filename, line, column, context); |
| } |
| |
| Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) { |
| return FusedLoc::get(locs, metadata, context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Types. |
| //===----------------------------------------------------------------------===// |
| |
| FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } |
| |
| FloatType Builder::getF16Type() { return FloatType::getF16(context); } |
| |
| FloatType Builder::getF32Type() { return FloatType::getF32(context); } |
| |
| FloatType Builder::getF64Type() { return FloatType::getF64(context); } |
| |
| IndexType Builder::getIndexType() { return IndexType::get(context); } |
| |
| IntegerType Builder::getI1Type() { return IntegerType::get(1, context); } |
| |
| IntegerType Builder::getIntegerType(unsigned width) { |
| return IntegerType::get(width, context); |
| } |
| |
| FunctionType Builder::getFunctionType(ArrayRef<Type> inputs, |
| ArrayRef<Type> results) { |
| return FunctionType::get(inputs, results, context); |
| } |
| |
| MemRefType Builder::getMemRefType(ArrayRef<int64_t> shape, Type elementType, |
| ArrayRef<AffineMap> affineMapComposition, |
| unsigned memorySpace) { |
| return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); |
| } |
| |
| VectorType Builder::getVectorType(ArrayRef<int64_t> shape, Type elementType) { |
| return VectorType::get(shape, elementType); |
| } |
| |
| RankedTensorType Builder::getTensorType(ArrayRef<int64_t> shape, |
| Type elementType) { |
| return RankedTensorType::get(shape, elementType); |
| } |
| |
| UnrankedTensorType Builder::getTensorType(Type elementType) { |
| return UnrankedTensorType::get(elementType); |
| } |
| |
| TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) { |
| return TupleType::get(elementTypes, context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Attributes. |
| //===----------------------------------------------------------------------===// |
| |
| NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) { |
| return NamedAttribute(getIdentifier(name), val); |
| } |
| |
| BoolAttr Builder::getBoolAttr(bool value) { |
| return BoolAttr::get(value, context); |
| } |
| |
| IntegerAttr Builder::getI64IntegerAttr(int64_t value) { |
| return IntegerAttr::get(getIntegerType(64), APInt(64, value)); |
| } |
| |
| IntegerAttr Builder::getI32IntegerAttr(int32_t value) { |
| return IntegerAttr::get(getIntegerType(32), APInt(32, value)); |
| } |
| |
| IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) { |
| if (type.isIndex()) |
| return IntegerAttr::get(type, APInt(64, value)); |
| return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value)); |
| } |
| |
| IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) { |
| return IntegerAttr::get(type, value); |
| } |
| |
| FloatAttr Builder::getF64FloatAttr(double value) { |
| return FloatAttr::get(getF64Type(), APFloat(value)); |
| } |
| |
| FloatAttr Builder::getF32FloatAttr(float value) { |
| return FloatAttr::get(getF32Type(), APFloat(value)); |
| } |
| |
| FloatAttr Builder::getFloatAttr(Type type, double value) { |
| return FloatAttr::get(type, value); |
| } |
| |
| FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) { |
| return FloatAttr::get(type, value); |
| } |
| |
| StringAttr Builder::getStringAttr(StringRef bytes) { |
| return StringAttr::get(bytes, context); |
| } |
| |
| ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) { |
| return ArrayAttr::get(value, context); |
| } |
| |
| AffineMapAttr Builder::getAffineMapAttr(AffineMap map) { |
| return AffineMapAttr::get(map); |
| } |
| |
| IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { |
| return IntegerSetAttr::get(set); |
| } |
| |
| TypeAttr Builder::getTypeAttr(Type type) { |
| return TypeAttr::get(type, context); |
| } |
| |
| FunctionAttr Builder::getFunctionAttr(const Function *value) { |
| return FunctionAttr::get(value, context); |
| } |
| |
| ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, |
| Attribute elt) { |
| return SplatElementsAttr::get(type, elt); |
| } |
| |
| ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, |
| ArrayRef<char> data) { |
| return DenseElementsAttr::get(type, data); |
| } |
| |
| ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, |
| ArrayRef<Attribute> values) { |
| return DenseElementsAttr::get(type, values); |
| } |
| |
| ElementsAttr Builder::getDenseIntElementsAttr(VectorOrTensorType type, |
| ArrayRef<int64_t> values) { |
| return DenseIntElementsAttr::get(type, values); |
| } |
| |
| ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, |
| DenseIntElementsAttr indices, |
| DenseElementsAttr values) { |
| return SparseElementsAttr::get(type, indices, values); |
| } |
| |
| ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, |
| VectorOrTensorType type, |
| StringRef bytes) { |
| return OpaqueElementsAttr::get(dialect, type, bytes); |
| } |
| |
| Attribute Builder::getZeroAttr(Type type) { |
| switch (type.getKind()) { |
| case StandardTypes::F32: |
| return getF32FloatAttr(0); |
| case StandardTypes::F64: |
| return getF64FloatAttr(0); |
| case StandardTypes::Integer: { |
| auto width = type.cast<IntegerType>().getWidth(); |
| if (width == 1) |
| return getBoolAttr(false); |
| return getIntegerAttr(type, APInt(width, 0)); |
| } |
| case StandardTypes::Vector: |
| case StandardTypes::RankedTensor: { |
| auto vtType = type.cast<VectorOrTensorType>(); |
| auto element = getZeroAttr(vtType.getElementType()); |
| if (!element) |
| return {}; |
| return getSplatElementsAttr(vtType, element); |
| } |
| default: |
| break; |
| } |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Affine Expressions, Affine Maps, and Integet Sets. |
| //===----------------------------------------------------------------------===// |
| |
| AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount, |
| ArrayRef<AffineExpr> results, |
| ArrayRef<AffineExpr> rangeSizes) { |
| return AffineMap::get(dimCount, symbolCount, results, rangeSizes); |
| } |
| |
| AffineExpr Builder::getAffineDimExpr(unsigned position) { |
| return mlir::getAffineDimExpr(position, context); |
| } |
| |
| AffineExpr Builder::getAffineSymbolExpr(unsigned position) { |
| return mlir::getAffineSymbolExpr(position, context); |
| } |
| |
| AffineExpr Builder::getAffineConstantExpr(int64_t constant) { |
| return mlir::getAffineConstantExpr(constant, context); |
| } |
| |
| IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, |
| ArrayRef<AffineExpr> constraints, |
| ArrayRef<bool> isEq) { |
| return IntegerSet::get(dimCount, symbolCount, constraints, isEq); |
| } |
| |
| AffineMap Builder::getConstantAffineMap(int64_t val) { |
| return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, |
| {getAffineConstantExpr(val)}, {}); |
| } |
| |
| AffineMap Builder::getDimIdentityMap() { |
| return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, |
| {getAffineDimExpr(0)}, {}); |
| } |
| |
| AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { |
| SmallVector<AffineExpr, 4> dimExprs; |
| dimExprs.reserve(rank); |
| for (unsigned i = 0; i < rank; ++i) |
| dimExprs.push_back(getAffineDimExpr(i)); |
| return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, {}); |
| } |
| |
| AffineMap Builder::getSymbolIdentityMap() { |
| return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, |
| {getAffineSymbolExpr(0)}, {}); |
| } |
| |
| AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { |
| // expr = d0 + shift. |
| auto expr = getAffineDimExpr(0) + shift; |
| return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {}); |
| } |
| |
| AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { |
| SmallVector<AffineExpr, 4> shiftedResults; |
| shiftedResults.reserve(map.getNumResults()); |
| for (auto resultExpr : map.getResults()) { |
| shiftedResults.push_back(resultExpr + shift); |
| } |
| return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults, |
| map.getRangeSizes()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Instructions. |
| //===----------------------------------------------------------------------===// |
| |
| /// Add new block and set the insertion point to the end of it. If an |
| /// 'insertBefore' block is passed, the block will be placed before the |
| /// specified block. If not, the block will be appended to the end of the |
| /// current function. |
| Block *FuncBuilder::createBlock(Block *insertBefore) { |
| Block *b = new Block(); |
| |
| // If we are supposed to insert before a specific block, do so, otherwise add |
| // the block to the end of the function. |
| if (insertBefore) |
| function->getBlocks().insert(Function::iterator(insertBefore), b); |
| else |
| function->push_back(b); |
| |
| setInsertionPointToEnd(b); |
| return b; |
| } |
| |
| /// Create an operation given the fields represented as an OperationState. |
| Instruction *FuncBuilder::createOperation(const OperationState &state) { |
| assert(block && "createOperation() called without setting builder's block"); |
| auto *op = Instruction::create( |
| state.location, state.name, state.operands, state.types, state.attributes, |
| state.successors, state.numRegions, state.resizableOperandList, context); |
| block->getInstructions().insert(insertPoint, op); |
| return op; |
| } |