| //===- Ops.cpp - Standard MLIR Operations ---------------------------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| |
| #include "mlir/StandardOps/Ops.h" |
| |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Support/MathExtras.h" |
| #include "mlir/Support/STLExtras.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/raw_ostream.h" |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // StandardOpsDialect |
| //===----------------------------------------------------------------------===// |
| |
| /// A custom binary operation printer that omits the "std." prefix from the |
| /// operation names. |
| static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { |
| assert(op->getNumOperands() == 2 && "binary op should have two operands"); |
| assert(op->getNumResults() == 1 && "binary op should have one result"); |
| |
| // If not all the operand and result types are the same, just use the |
| // generic assembly form to avoid omitting information in printing. |
| auto resultType = op->getResult(0)->getType(); |
| if (op->getOperand(0)->getType() != resultType || |
| op->getOperand(1)->getType() != resultType) { |
| p->printGenericOp(op); |
| return; |
| } |
| |
| *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' |
| << *op->getOperand(0) << ", " << *op->getOperand(1); |
| p->printOptionalAttrDict(op->getAttrs()); |
| |
| // Now we can output only one type for all operands and the result. |
| *p << " : " << op->getResult(0)->getType(); |
| } |
| |
| /// A custom cast operation printer that omits the "std." prefix from the |
| /// operation names. |
| static void printStandardCastOp(Operation *op, OpAsmPrinter *p) { |
| *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' |
| << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " |
| << op->getResult(0)->getType(); |
| } |
| |
| /// A custom cast operation verifier. |
| template <typename T> static LogicalResult verifyCastOp(T op) { |
| auto opType = op.getOperand()->getType(); |
| auto resType = op.getType(); |
| if (!T::areCastCompatible(opType, resType)) |
| return op.emitError("operand type ") << opType << " and result type " |
| << resType << " are cast incompatible"; |
| |
| return success(); |
| } |
| |
| StandardOpsDialect::StandardOpsDialect(MLIRContext *context) |
| : Dialect(/*name=*/"std", context) { |
| addOperations<CondBranchOp, DmaStartOp, DmaWaitOp, |
| #define GET_OP_LIST |
| #include "mlir/StandardOps/Ops.cpp.inc" |
| >(); |
| } |
| |
| void mlir::printDimAndSymbolList(Operation::operand_iterator begin, |
| Operation::operand_iterator end, |
| unsigned numDims, OpAsmPrinter *p) { |
| *p << '('; |
| p->printOperands(begin, begin + numDims); |
| *p << ')'; |
| |
| if (begin + numDims != end) { |
| *p << '['; |
| p->printOperands(begin + numDims, end); |
| *p << ']'; |
| } |
| } |
| |
| // Parses dimension and symbol list, and sets 'numDims' to the number of |
| // dimension operands parsed. |
| // Returns 'false' on success and 'true' on error. |
| ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser, |
| SmallVector<Value *, 4> &operands, |
| unsigned &numDims) { |
| SmallVector<OpAsmParser::OperandType, 8> opInfos; |
| if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) |
| return failure(); |
| // Store number of dimensions for validation by caller. |
| numDims = opInfos.size(); |
| |
| // Parse the optional symbol operands. |
| auto affineIntTy = parser->getBuilder().getIndexType(); |
| if (parser->parseOperandList(opInfos, -1, |
| OpAsmParser::Delimiter::OptionalSquare) || |
| parser->resolveOperands(opInfos, affineIntTy, operands)) |
| return failure(); |
| return success(); |
| } |
| |
| /// Matches a ConstantIndexOp. |
| /// TODO: This should probably just be a general matcher that uses m_Constant |
| /// and checks the operation for an index type. |
| static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() { |
| return detail::op_matcher<ConstantIndexOp>(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Common canonicalization pattern support logic |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This is a common class used for patterns of the form |
| /// "someop(memrefcast) -> someop". It folds the source of any memref_cast |
| /// into the root operation directly. |
| struct MemRefCastFolder : public RewritePattern { |
| /// The rootOpName is the name of the root operation to match against. |
| MemRefCastFolder(StringRef rootOpName, MLIRContext *context) |
| : RewritePattern(rootOpName, 1, context) {} |
| |
| PatternMatchResult match(Operation *op) const override { |
| for (auto *operand : op->getOperands()) |
| if (matchPattern(operand, m_Op<MemRefCastOp>())) |
| return matchSuccess(); |
| |
| return matchFailure(); |
| } |
| |
| void rewrite(Operation *op, PatternRewriter &rewriter) const override { |
| for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) |
| if (auto *memref = op->getOperand(i)->getDefiningOp()) |
| if (auto cast = dyn_cast<MemRefCastOp>(memref)) |
| op->setOperand(i, cast.getOperand()); |
| rewriter.updatedRootInPlace(op); |
| } |
| }; |
| |
| /// Performs const folding `calculate` with element-wise behavior on the two |
| /// attributes in `operands` and returns the result if possible. |
| template <class AttrElementT, |
| class ElementValueT = typename AttrElementT::ValueType, |
| class CalculationT = |
| std::function<ElementValueT(ElementValueT, ElementValueT)>> |
| Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, |
| const CalculationT &calculate) { |
| assert(operands.size() == 2 && "binary op takes two operands"); |
| |
| if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) { |
| auto rhs = operands[1].dyn_cast_or_null<AttrElementT>(); |
| if (!rhs || lhs.getType() != rhs.getType()) |
| return {}; |
| |
| return AttrElementT::get(lhs.getType(), |
| calculate(lhs.getValue(), rhs.getValue())); |
| } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) { |
| auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>(); |
| if (!rhs || lhs.getType() != rhs.getType()) |
| return {}; |
| |
| auto elementResult = constFoldBinaryOp<AttrElementT>( |
| {lhs.getValue(), rhs.getValue()}, calculate); |
| if (!elementResult) |
| return {}; |
| |
| return SplatElementsAttr::get(lhs.getType(), elementResult); |
| } |
| return {}; |
| } |
| } // end anonymous namespace. |
| |
| //===----------------------------------------------------------------------===// |
| // AddFOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldBinaryOp<FloatAttr>( |
| operands, [](APFloat a, APFloat b) { return a + b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AddIOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) { |
| /// addi(x, 0) -> x |
| if (matchPattern(rhs(), m_Zero())) |
| return lhs(); |
| |
| return constFoldBinaryOp<IntegerAttr>(operands, |
| [](APInt a, APInt b) { return a + b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AllocOp |
| //===----------------------------------------------------------------------===// |
| |
| static void print(OpAsmPrinter *p, AllocOp op) { |
| *p << "alloc"; |
| |
| // Print dynamic dimension operands. |
| MemRefType type = op.getType(); |
| printDimAndSymbolList(op.operand_begin(), op.operand_end(), |
| type.getNumDynamicDims(), p); |
| p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); |
| *p << " : " << type; |
| } |
| |
| static ParseResult parseAllocOp(OpAsmParser *parser, OperationState *result) { |
| MemRefType type; |
| |
| // Parse the dimension operands and optional symbol operands, followed by a |
| // memref type. |
| unsigned numDimOperands; |
| if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type)) |
| return failure(); |
| |
| // Check numDynamicDims against number of question marks in memref type. |
| // Note: this check remains here (instead of in verify()), because the |
| // partition between dim operands and symbol operands is lost after parsing. |
| // Verification still checks that the total number of operands matches |
| // the number of symbols in the affine map, plus the number of dynamic |
| // dimensions in the memref. |
| if (numDimOperands != type.getNumDynamicDims()) |
| return parser->emitError(parser->getNameLoc()) |
| << "dimension operand count does not equal memref dynamic dimension " |
| "count"; |
| result->types.push_back(type); |
| return success(); |
| } |
| |
| static LogicalResult verify(AllocOp op) { |
| auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>(); |
| if (!memRefType) |
| return op.emitOpError("result must be a memref"); |
| |
| unsigned numSymbols = 0; |
| if (!memRefType.getAffineMaps().empty()) { |
| AffineMap affineMap = memRefType.getAffineMaps()[0]; |
| // Store number of symbols used in affine map (used in subsequent check). |
| numSymbols = affineMap.getNumSymbols(); |
| // TODO(zinenko): this check does not belong to AllocOp, or any other op but |
| // to the type system itself. It has been partially hoisted to Parser but |
| // remains here in case an AllocOp gets constructed programmatically. |
| // Remove when we can emit errors directly from *Type::get(...) functions. |
| // |
| // Verify that the layout affine map matches the rank of the memref. |
| if (affineMap.getNumDims() != memRefType.getRank()) |
| return op.emitOpError( |
| "affine map dimension count must equal memref rank"); |
| } |
| unsigned numDynamicDims = memRefType.getNumDynamicDims(); |
| // Check that the total number of operands matches the number of symbols in |
| // the affine map, plus the number of dynamic dimensions specified in the |
| // memref type. |
| if (op.getOperation()->getNumOperands() != numDynamicDims + numSymbols) |
| return op.emitOpError( |
| "operand count does not equal dimension plus symbol operand count"); |
| |
| // Verify that all operands are of type Index. |
| for (auto operandType : op.getOperandTypes()) |
| if (!operandType.isIndex()) |
| return op.emitOpError("requires operands to be of type Index"); |
| return success(); |
| } |
| |
| namespace { |
| /// Fold constant dimensions into an alloc operation. |
| struct SimplifyAllocConst : public OpRewritePattern<AllocOp> { |
| using OpRewritePattern<AllocOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(AllocOp alloc, |
| PatternRewriter &rewriter) const override { |
| // Check to see if any dimensions operands are constants. If so, we can |
| // substitute and drop them. |
| if (llvm::none_of(alloc.getOperands(), [](Value *operand) { |
| return matchPattern(operand, m_ConstantIndex()); |
| })) |
| return matchFailure(); |
| |
| auto memrefType = alloc.getType(); |
| |
| // Ok, we have one or more constant operands. Collect the non-constant ones |
| // and keep track of the resultant memref type to build. |
| SmallVector<int64_t, 4> newShapeConstants; |
| newShapeConstants.reserve(memrefType.getRank()); |
| SmallVector<Value *, 4> newOperands; |
| SmallVector<Value *, 4> droppedOperands; |
| |
| unsigned dynamicDimPos = 0; |
| for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { |
| int64_t dimSize = memrefType.getDimSize(dim); |
| // If this is already static dimension, keep it. |
| if (dimSize != -1) { |
| newShapeConstants.push_back(dimSize); |
| continue; |
| } |
| auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); |
| if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { |
| // Dynamic shape dimension will be folded. |
| newShapeConstants.push_back(constantIndexOp.getValue()); |
| // Record to check for zero uses later below. |
| droppedOperands.push_back(constantIndexOp); |
| } else { |
| // Dynamic shape dimension not folded; copy operand from old memref. |
| newShapeConstants.push_back(-1); |
| newOperands.push_back(alloc.getOperand(dynamicDimPos)); |
| } |
| dynamicDimPos++; |
| } |
| |
| // Create new memref type (which will have fewer dynamic dimensions). |
| auto newMemRefType = MemRefType::get( |
| newShapeConstants, memrefType.getElementType(), |
| memrefType.getAffineMaps(), memrefType.getMemorySpace()); |
| assert(newOperands.size() == newMemRefType.getNumDynamicDims()); |
| |
| // Create and insert the alloc op for the new memref. |
| auto newAlloc = |
| rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands); |
| // Insert a cast so we have the same type as the old alloc. |
| auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc, |
| alloc.getType()); |
| |
| rewriter.replaceOp(alloc, {resultCast}, droppedOperands); |
| return matchSuccess(); |
| } |
| }; |
| |
| /// Fold alloc operations with no uses. Alloc has side effects on the heap, |
| /// but can still be deleted if it has zero uses. |
| struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { |
| using OpRewritePattern<AllocOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(AllocOp alloc, |
| PatternRewriter &rewriter) const override { |
| // Check if the alloc'ed value has any uses. |
| if (!alloc.use_empty()) |
| return matchFailure(); |
| |
| // If it doesn't, we can eliminate it. |
| alloc.erase(); |
| return matchSuccess(); |
| } |
| }; |
| } // end anonymous namespace. |
| |
| void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results, |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BranchOp |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) { |
| Block *dest; |
| SmallVector<Value *, 4> destOperands; |
| if (parser->parseSuccessorAndUseList(dest, destOperands)) |
| return failure(); |
| result->addSuccessor(dest, destOperands); |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter *p, BranchOp op) { |
| *p << "br "; |
| p->printSuccessorAndUseList(op.getOperation(), 0); |
| } |
| |
| Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } |
| |
| void BranchOp::setDest(Block *block) { |
| return getOperation()->setSuccessor(block, 0); |
| } |
| |
| void BranchOp::eraseOperand(unsigned index) { |
| getOperation()->eraseSuccessorOperand(0, index); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallOp |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { |
| StringRef calleeName; |
| llvm::SMLoc calleeLoc; |
| FunctionType calleeType; |
| SmallVector<OpAsmParser::OperandType, 4> operands; |
| if (parser->parseFunctionName(calleeName, calleeLoc) || |
| parser->parseOperandList(operands, /*requiredOperandCount=*/-1, |
| OpAsmParser::Delimiter::Paren) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(calleeType) || |
| parser->addTypesToList(calleeType.getResults(), result->types) || |
| parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, |
| result->operands)) |
| return failure(); |
| |
| result->addAttribute("callee", |
| parser->getBuilder().getFunctionAttr(calleeName)); |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter *p, CallOp op) { |
| *p << "call " << op.getAttr("callee") << '('; |
| p->printOperands(op.getOperands()); |
| *p << ')'; |
| p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); |
| *p << " : "; |
| p->printType(op.getCalleeType()); |
| } |
| |
| static LogicalResult verify(CallOp op) { |
| // Check that the callee attribute was specified. |
| auto fnAttr = op.getAttrOfType<FunctionAttr>("callee"); |
| if (!fnAttr) |
| return op.emitOpError("requires a 'callee' function attribute"); |
| auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( |
| fnAttr.getValue()); |
| if (!fn) |
| return op.emitOpError() << "'" << fnAttr.getValue() |
| << "' does not reference a valid function"; |
| |
| // Verify that the operand and result types match the callee. |
| auto fnType = fn->getType(); |
| if (fnType.getNumInputs() != op.getNumOperands()) |
| return op.emitOpError("incorrect number of operands for callee"); |
| |
| for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
| if (op.getOperand(i)->getType() != fnType.getInput(i)) |
| return op.emitOpError("operand type mismatch"); |
| |
| if (fnType.getNumResults() != op.getNumResults()) |
| return op.emitOpError("incorrect number of results for callee"); |
| |
| for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
| if (op.getResult(i)->getType() != fnType.getResult(i)) |
| return op.emitOpError("result type mismatch"); |
| |
| return success(); |
| } |
| |
| FunctionType CallOp::getCalleeType() { |
| SmallVector<Type, 4> resultTypes(getResultTypes()); |
| SmallVector<Type, 8> argTypes(getOperandTypes()); |
| return FunctionType::get(argTypes, resultTypes, getContext()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallIndirectOp |
| //===----------------------------------------------------------------------===// |
| namespace { |
| /// Fold indirect calls that have a constant function as the callee operand. |
| struct SimplifyIndirectCallWithKnownCallee |
| : public OpRewritePattern<CallIndirectOp> { |
| using OpRewritePattern<CallIndirectOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, |
| PatternRewriter &rewriter) const override { |
| // Check that the callee is a constant callee. |
| FunctionAttr calledFn; |
| if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) |
| return matchFailure(); |
| |
| // Replace with a direct call. |
| SmallVector<Type, 8> callResults(indirectCall.getResultTypes()); |
| SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands()); |
| rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(), |
| callResults, callOperands); |
| return matchSuccess(); |
| } |
| }; |
| } // end anonymous namespace. |
| |
| static ParseResult parseCallIndirectOp(OpAsmParser *parser, |
| OperationState *result) { |
| FunctionType calleeType; |
| OpAsmParser::OperandType callee; |
| llvm::SMLoc operandsLoc; |
| SmallVector<OpAsmParser::OperandType, 4> operands; |
| return failure( |
| parser->parseOperand(callee) || |
| parser->getCurrentLocation(&operandsLoc) || |
| parser->parseOperandList(operands, /*requiredOperandCount=*/-1, |
| OpAsmParser::Delimiter::Paren) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(calleeType) || |
| parser->resolveOperand(callee, calleeType, result->operands) || |
| parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, |
| result->operands) || |
| parser->addTypesToList(calleeType.getResults(), result->types)); |
| } |
| |
| static void print(OpAsmPrinter *p, CallIndirectOp op) { |
| *p << "call_indirect "; |
| p->printOperand(op.getCallee()); |
| *p << '('; |
| auto operandRange = op.getOperands(); |
| p->printOperands(++operandRange.begin(), operandRange.end()); |
| *p << ')'; |
| p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); |
| *p << " : " << op.getCallee()->getType(); |
| } |
| |
| static LogicalResult verify(CallIndirectOp op) { |
| // The callee must be a function. |
| auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>(); |
| if (!fnType) |
| return op.emitOpError("callee must have function type"); |
| |
| // Verify that the operand and result types match the callee. |
| if (fnType.getNumInputs() != op.getNumOperands() - 1) |
| return op.emitOpError("incorrect number of operands for callee"); |
| |
| for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
| if (op.getOperand(i + 1)->getType() != fnType.getInput(i)) |
| return op.emitOpError("operand type mismatch"); |
| |
| if (fnType.getNumResults() != op.getNumResults()) |
| return op.emitOpError("incorrect number of results for callee"); |
| |
| for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
| if (op.getResult(i)->getType() != fnType.getResult(i)) |
| return op.emitOpError("result type mismatch"); |
| |
| return success(); |
| } |
| |
| void CallIndirectOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.push_back( |
| llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // General helpers for comparison ops |
| //===----------------------------------------------------------------------===// |
| |
| // Return the type of the same shape (scalar, vector or tensor) containing i1. |
| static Type getCheckedI1SameShape(Builder *build, Type type) { |
| auto i1Type = build->getI1Type(); |
| if (type.isIntOrIndexOrFloat()) |
| return i1Type; |
| if (auto tensorType = type.dyn_cast<RankedTensorType>()) |
| return build->getTensorType(tensorType.getShape(), i1Type); |
| if (auto tensorType = type.dyn_cast<UnrankedTensorType>()) |
| return build->getTensorType(i1Type); |
| if (auto vectorType = type.dyn_cast<VectorType>()) |
| return build->getVectorType(vectorType.getShape(), i1Type); |
| return Type(); |
| } |
| |
| static Type getI1SameShape(Builder *build, Type type) { |
| Type res = getCheckedI1SameShape(build, type); |
| assert(res && "expected type with valid i1 shape"); |
| return res; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CmpIOp |
| //===----------------------------------------------------------------------===// |
| |
| // Returns an array of mnemonics for CmpIPredicates indexed by values thereof. |
| static inline const char *const *getCmpIPredicateNames() { |
| static const char *predicateNames[]{ |
| /*EQ*/ "eq", |
| /*NE*/ "ne", |
| /*SLT*/ "slt", |
| /*SLE*/ "sle", |
| /*SGT*/ "sgt", |
| /*SGE*/ "sge", |
| /*ULT*/ "ult", |
| /*ULE*/ "ule", |
| /*UGT*/ "ugt", |
| /*UGE*/ "uge", |
| }; |
| static_assert(std::extent<decltype(predicateNames)>::value == |
| (size_t)CmpIPredicate::NumPredicates, |
| "wrong number of predicate names"); |
| return predicateNames; |
| } |
| |
| // Returns a value of the predicate corresponding to the given mnemonic. |
| // Returns NumPredicates (one-past-end) if there is no such mnemonic. |
| CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { |
| return llvm::StringSwitch<CmpIPredicate>(name) |
| .Case("eq", CmpIPredicate::EQ) |
| .Case("ne", CmpIPredicate::NE) |
| .Case("slt", CmpIPredicate::SLT) |
| .Case("sle", CmpIPredicate::SLE) |
| .Case("sgt", CmpIPredicate::SGT) |
| .Case("sge", CmpIPredicate::SGE) |
| .Case("ult", CmpIPredicate::ULT) |
| .Case("ule", CmpIPredicate::ULE) |
| .Case("ugt", CmpIPredicate::UGT) |
| .Case("uge", CmpIPredicate::UGE) |
| .Default(CmpIPredicate::NumPredicates); |
| } |
| |
| static void buildCmpIOp(Builder *build, OperationState *result, |
| CmpIPredicate predicate, Value *lhs, Value *rhs) { |
| result->addOperands({lhs, rhs}); |
| result->types.push_back(getI1SameShape(build, lhs->getType())); |
| result->addAttribute( |
| CmpIOp::getPredicateAttrName(), |
| build->getI64IntegerAttr(static_cast<int64_t>(predicate))); |
| } |
| |
| static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) { |
| SmallVector<OpAsmParser::OperandType, 2> ops; |
| SmallVector<NamedAttribute, 4> attrs; |
| Attribute predicateNameAttr; |
| Type type; |
| if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), |
| attrs) || |
| parser->parseComma() || parser->parseOperandList(ops, 2) || |
| parser->parseOptionalAttributeDict(attrs) || |
| parser->parseColonType(type) || |
| parser->resolveOperands(ops, type, result->operands)) |
| return failure(); |
| |
| if (!predicateNameAttr.isa<StringAttr>()) |
| return parser->emitError(parser->getNameLoc(), |
| "expected string comparison predicate attribute"); |
| |
| // Rewrite string attribute to an enum value. |
| StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); |
| auto predicate = CmpIOp::getPredicateByName(predicateName); |
| if (predicate == CmpIPredicate::NumPredicates) |
| return parser->emitError(parser->getNameLoc()) |
| << "unknown comparison predicate \"" << predicateName << "\""; |
| |
| auto builder = parser->getBuilder(); |
| Type i1Type = getCheckedI1SameShape(&builder, type); |
| if (!i1Type) |
| return parser->emitError(parser->getNameLoc(), |
| "expected type with valid i1 shape"); |
| |
| attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); |
| result->attributes = attrs; |
| |
| result->addTypes({i1Type}); |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter *p, CmpIOp op) { |
| *p << "cmpi "; |
| |
| auto predicateValue = |
| op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt(); |
| assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) && |
| predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) && |
| "unknown predicate index"); |
| Builder b(op.getContext()); |
| auto predicateStringAttr = |
| b.getStringAttr(getCmpIPredicateNames()[predicateValue]); |
| p->printAttribute(predicateStringAttr); |
| |
| *p << ", "; |
| p->printOperand(op.lhs()); |
| *p << ", "; |
| p->printOperand(op.rhs()); |
| p->printOptionalAttrDict(op.getAttrs(), |
| /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); |
| *p << " : " << op.lhs()->getType(); |
| } |
| |
| static LogicalResult verify(CmpIOp op) { |
| auto predicateAttr = |
| op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()); |
| if (!predicateAttr) |
| return op.emitOpError("requires an integer attribute named 'predicate'"); |
| auto predicate = predicateAttr.getInt(); |
| if (predicate < (int64_t)CmpIPredicate::FirstValidValue || |
| predicate >= (int64_t)CmpIPredicate::NumPredicates) |
| return op.emitOpError("'predicate' attribute value out of range"); |
| |
| return success(); |
| } |
| |
| // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer |
| // comparison predicates. |
| static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, |
| const APInt &rhs) { |
| switch (predicate) { |
| case CmpIPredicate::EQ: |
| return lhs.eq(rhs); |
| case CmpIPredicate::NE: |
| return lhs.ne(rhs); |
| case CmpIPredicate::SLT: |
| return lhs.slt(rhs); |
| case CmpIPredicate::SLE: |
| return lhs.sle(rhs); |
| case CmpIPredicate::SGT: |
| return lhs.sgt(rhs); |
| case CmpIPredicate::SGE: |
| return lhs.sge(rhs); |
| case CmpIPredicate::ULT: |
| return lhs.ult(rhs); |
| case CmpIPredicate::ULE: |
| return lhs.ule(rhs); |
| case CmpIPredicate::UGT: |
| return lhs.ugt(rhs); |
| case CmpIPredicate::UGE: |
| return lhs.uge(rhs); |
| default: |
| llvm_unreachable("unknown comparison predicate"); |
| } |
| } |
| |
| // Constant folding hook for comparisons. |
| OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 2 && "cmpi takes two arguments"); |
| |
| auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
| auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
| if (!lhs || !rhs) |
| return {}; |
| |
| auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); |
| return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CmpFOp |
| //===----------------------------------------------------------------------===// |
| |
| // Returns an array of mnemonics for CmpFPredicates indexed by values thereof. |
| static inline const char *const *getCmpFPredicateNames() { |
| static const char *predicateNames[] = { |
| /*FALSE*/ "false", |
| /*OEQ*/ "oeq", |
| /*OGT*/ "ogt", |
| /*OGE*/ "oge", |
| /*OLT*/ "olt", |
| /*OLE*/ "ole", |
| /*ONE*/ "one", |
| /*ORD*/ "ord", |
| /*UEQ*/ "ueq", |
| /*UGT*/ "ugt", |
| /*UGE*/ "uge", |
| /*ULT*/ "ult", |
| /*ULE*/ "ule", |
| /*UNE*/ "une", |
| /*UNO*/ "uno", |
| /*TRUE*/ "true", |
| }; |
| static_assert(std::extent<decltype(predicateNames)>::value == |
| (size_t)CmpFPredicate::NumPredicates, |
| "wrong number of predicate names"); |
| return predicateNames; |
| } |
| |
| // Returns a value of the predicate corresponding to the given mnemonic. |
| // Returns NumPredicates (one-past-end) if there is no such mnemonic. |
| CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { |
| return llvm::StringSwitch<CmpFPredicate>(name) |
| .Case("false", CmpFPredicate::FALSE) |
| .Case("oeq", CmpFPredicate::OEQ) |
| .Case("ogt", CmpFPredicate::OGT) |
| .Case("oge", CmpFPredicate::OGE) |
| .Case("olt", CmpFPredicate::OLT) |
| .Case("ole", CmpFPredicate::OLE) |
| .Case("one", CmpFPredicate::ONE) |
| .Case("ord", CmpFPredicate::ORD) |
| .Case("ueq", CmpFPredicate::UEQ) |
| .Case("ugt", CmpFPredicate::UGT) |
| .Case("uge", CmpFPredicate::UGE) |
| .Case("ult", CmpFPredicate::ULT) |
| .Case("ule", CmpFPredicate::ULE) |
| .Case("une", CmpFPredicate::UNE) |
| .Case("uno", CmpFPredicate::UNO) |
| .Case("true", CmpFPredicate::TRUE) |
| .Default(CmpFPredicate::NumPredicates); |
| } |
| |
| static void buildCmpFOp(Builder *build, OperationState *result, |
| CmpFPredicate predicate, Value *lhs, Value *rhs) { |
| result->addOperands({lhs, rhs}); |
| result->types.push_back(getI1SameShape(build, lhs->getType())); |
| result->addAttribute( |
| CmpFOp::getPredicateAttrName(), |
| build->getI64IntegerAttr(static_cast<int64_t>(predicate))); |
| } |
| |
| static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) { |
| SmallVector<OpAsmParser::OperandType, 2> ops; |
| SmallVector<NamedAttribute, 4> attrs; |
| Attribute predicateNameAttr; |
| Type type; |
| if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), |
| attrs) || |
| parser->parseComma() || parser->parseOperandList(ops, 2) || |
| parser->parseOptionalAttributeDict(attrs) || |
| parser->parseColonType(type) || |
| parser->resolveOperands(ops, type, result->operands)) |
| return failure(); |
| |
| if (!predicateNameAttr.isa<StringAttr>()) |
| return parser->emitError(parser->getNameLoc(), |
| "expected string comparison predicate attribute"); |
| |
| // Rewrite string attribute to an enum value. |
| StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); |
| auto predicate = CmpFOp::getPredicateByName(predicateName); |
| if (predicate == CmpFPredicate::NumPredicates) |
| return parser->emitError(parser->getNameLoc(), |
| "unknown comparison predicate \"" + predicateName + |
| "\""); |
| |
| auto builder = parser->getBuilder(); |
| Type i1Type = getCheckedI1SameShape(&builder, type); |
| if (!i1Type) |
| return parser->emitError(parser->getNameLoc(), |
| "expected type with valid i1 shape"); |
| |
| attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); |
| result->attributes = attrs; |
| |
| result->addTypes({i1Type}); |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter *p, CmpFOp op) { |
| *p << "cmpf "; |
| |
| auto predicateValue = |
| op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt(); |
| assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) && |
| predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) && |
| "unknown predicate index"); |
| Builder b(op.getContext()); |
| auto predicateStringAttr = |
| b.getStringAttr(getCmpFPredicateNames()[predicateValue]); |
| p->printAttribute(predicateStringAttr); |
| |
| *p << ", "; |
| p->printOperand(op.lhs()); |
| *p << ", "; |
| p->printOperand(op.rhs()); |
| p->printOptionalAttrDict(op.getAttrs(), |
| /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); |
| *p << " : " << op.lhs()->getType(); |
| } |
| |
| static LogicalResult verify(CmpFOp op) { |
| auto predicateAttr = |
| op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()); |
| if (!predicateAttr) |
| return op.emitOpError("requires an integer attribute named 'predicate'"); |
| auto predicate = predicateAttr.getInt(); |
| if (predicate < (int64_t)CmpFPredicate::FirstValidValue || |
| predicate >= (int64_t)CmpFPredicate::NumPredicates) |
| return op.emitOpError("'predicate' attribute value out of range"); |
| |
| return success(); |
| } |
| |
| // Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point |
| // comparison predicates. |
| static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, |
| const APFloat &rhs) { |
| auto cmpResult = lhs.compare(rhs); |
| switch (predicate) { |
| case CmpFPredicate::FALSE: |
| return false; |
| case CmpFPredicate::OEQ: |
| return cmpResult == APFloat::cmpEqual; |
| case CmpFPredicate::OGT: |
| return cmpResult == APFloat::cmpGreaterThan; |
| case CmpFPredicate::OGE: |
| return cmpResult == APFloat::cmpGreaterThan || |
| cmpResult == APFloat::cmpEqual; |
| case CmpFPredicate::OLT: |
| return cmpResult == APFloat::cmpLessThan; |
| case CmpFPredicate::OLE: |
| return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
| case CmpFPredicate::ONE: |
| return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; |
| case CmpFPredicate::ORD: |
| return cmpResult != APFloat::cmpUnordered; |
| case CmpFPredicate::UEQ: |
| return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; |
| case CmpFPredicate::UGT: |
| return cmpResult == APFloat::cmpUnordered || |
| cmpResult == APFloat::cmpGreaterThan; |
| case CmpFPredicate::UGE: |
| return cmpResult == APFloat::cmpUnordered || |
| cmpResult == APFloat::cmpGreaterThan || |
| cmpResult == APFloat::cmpEqual; |
| case CmpFPredicate::ULT: |
| return cmpResult == APFloat::cmpUnordered || |
| cmpResult == APFloat::cmpLessThan; |
| case CmpFPredicate::ULE: |
| return cmpResult == APFloat::cmpUnordered || |
| cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
| case CmpFPredicate::UNE: |
| return cmpResult != APFloat::cmpEqual; |
| case CmpFPredicate::UNO: |
| return cmpResult == APFloat::cmpUnordered; |
| case CmpFPredicate::TRUE: |
| return true; |
| default: |
| llvm_unreachable("unknown comparison predicate"); |
| } |
| } |
| |
| // Constant folding hook for comparisons. |
| OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 2 && "cmpf takes two arguments"); |
| |
| auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); |
| auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); |
| if (!lhs || !rhs || |
| // TODO(b/122019992) Implement and test constant folding for nan/inf when |
| // it is possible to have constant nan/inf |
| !lhs.getValue().isFinite() || !rhs.getValue().isFinite()) |
| return {}; |
| |
| auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); |
| return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CondBranchOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// cond_br true, ^bb1, ^bb2 -> br ^bb1 |
| /// cond_br false, ^bb1, ^bb2 -> br ^bb2 |
| /// |
| struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { |
| using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(CondBranchOp condbr, |
| PatternRewriter &rewriter) const override { |
| // Check that the condition is a constant. |
| if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>())) |
| return matchFailure(); |
| |
| Block *foldedDest; |
| SmallVector<Value *, 4> branchArgs; |
| |
| // If the condition is known to evaluate to false we fold to a branch to the |
| // false destination. Otherwise, we fold to a branch to the true |
| // destination. |
| if (matchPattern(condbr.getCondition(), m_Zero())) { |
| foldedDest = condbr.getFalseDest(); |
| branchArgs.assign(condbr.false_operand_begin(), |
| condbr.false_operand_end()); |
| } else { |
| foldedDest = condbr.getTrueDest(); |
| branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); |
| } |
| |
| rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs); |
| return matchSuccess(); |
| } |
| }; |
| } // end anonymous namespace. |
| |
| void CondBranchOp::build(Builder *builder, OperationState *result, |
| Value *condition, Block *trueDest, |
| ArrayRef<Value *> trueOperands, Block *falseDest, |
| ArrayRef<Value *> falseOperands) { |
| result->addOperands(condition); |
| result->addSuccessor(trueDest, trueOperands); |
| result->addSuccessor(falseDest, falseOperands); |
| } |
| |
| ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { |
| SmallVector<Value *, 4> destOperands; |
| Block *dest; |
| OpAsmParser::OperandType condInfo; |
| |
| // Parse the condition. |
| Type int1Ty = parser->getBuilder().getI1Type(); |
| if (parser->parseOperand(condInfo) || parser->parseComma() || |
| parser->resolveOperand(condInfo, int1Ty, result->operands)) { |
| return parser->emitError(parser->getNameLoc(), |
| "expected condition type was boolean (i1)"); |
| } |
| |
| // Parse the true successor. |
| if (parser->parseSuccessorAndUseList(dest, destOperands)) |
| return failure(); |
| result->addSuccessor(dest, destOperands); |
| |
| // Parse the false successor. |
| destOperands.clear(); |
| if (parser->parseComma() || |
| parser->parseSuccessorAndUseList(dest, destOperands)) |
| return failure(); |
| result->addSuccessor(dest, destOperands); |
| |
| return success(); |
| } |
| |
| void CondBranchOp::print(OpAsmPrinter *p) { |
| *p << "cond_br "; |
| p->printOperand(getCondition()); |
| *p << ", "; |
| p->printSuccessorAndUseList(getOperation(), trueIndex); |
| *p << ", "; |
| p->printSuccessorAndUseList(getOperation(), falseIndex); |
| } |
| |
| LogicalResult CondBranchOp::verify() { |
| if (!getCondition()->getType().isInteger(1)) |
| return emitOpError("expected condition type was boolean (i1)"); |
| return success(); |
| } |
| |
| void CondBranchOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context)); |
| } |
| |
| Block *CondBranchOp::getTrueDest() { |
| return getOperation()->getSuccessor(trueIndex); |
| } |
| |
| Block *CondBranchOp::getFalseDest() { |
| return getOperation()->getSuccessor(falseIndex); |
| } |
| |
| unsigned CondBranchOp::getNumTrueOperands() { |
| return getOperation()->getNumSuccessorOperands(trueIndex); |
| } |
| |
| void CondBranchOp::eraseTrueOperand(unsigned index) { |
| getOperation()->eraseSuccessorOperand(trueIndex, index); |
| } |
| |
| unsigned CondBranchOp::getNumFalseOperands() { |
| return getOperation()->getNumSuccessorOperands(falseIndex); |
| } |
| |
| void CondBranchOp::eraseFalseOperand(unsigned index) { |
| getOperation()->eraseSuccessorOperand(falseIndex, index); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Constant*Op |
| //===----------------------------------------------------------------------===// |
| |
| static void print(OpAsmPrinter *p, ConstantOp &op) { |
| *p << "constant "; |
| p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); |
| |
| if (op.getAttrs().size() > 1) |
| *p << ' '; |
| p->printAttributeAndType(op.getValue()); |
| |
| // If the value is a function, print a trailing type. |
| if (op.getValue().isa<FunctionAttr>()) { |
| *p << " : "; |
| p->printType(op.getType()); |
| } |
| } |
| |
| static ParseResult parseConstantOp(OpAsmParser *parser, |
| OperationState *result) { |
| Attribute valueAttr; |
| if (parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseAttribute(valueAttr, "value", result->attributes)) |
| return failure(); |
| |
| // If the attribute is a function, then we expect a trailing type. |
| Type type; |
| if (!valueAttr.isa<FunctionAttr>()) |
| type = valueAttr.getType(); |
| else if (parser->parseColonType(type)) |
| return failure(); |
| |
| // Add the attribute type to the list. |
| return parser->addTypeToList(type, result->types); |
| } |
| |
| /// The constant op requires an attribute, and furthermore requires that it |
| /// matches the return type. |
| static LogicalResult verify(ConstantOp &op) { |
| auto value = op.getValue(); |
| if (!value) |
| return op.emitOpError("requires a 'value' attribute"); |
| |
| auto type = op.getType(); |
| if (!value.getType().isa<NoneType>() && type != value.getType()) |
| return op.emitOpError() << "requires attribute's type (" << value.getType() |
| << ") to match op's return type (" << type << ")"; |
| |
| if (type.isa<IndexType>()) |
| return success(); |
| |
| if (auto intAttr = value.dyn_cast<IntegerAttr>()) { |
| // If the type has a known bitwidth we verify that the value can be |
| // represented with the given bitwidth. |
| auto bitwidth = type.cast<IntegerType>().getWidth(); |
| auto intVal = intAttr.getValue(); |
| if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) |
| return op.emitOpError("requires 'value' to be an integer within the " |
| "range of the integer result type"); |
| return success(); |
| } |
| |
| if (type.isa<FloatType>()) { |
| if (!value.isa<FloatAttr>()) |
| return op.emitOpError("requires 'value' to be a floating point constant"); |
| return success(); |
| } |
| |
| if (type.isa<ShapedType>()) { |
| if (!value.isa<ElementsAttr>()) |
| return op.emitOpError("requires 'value' to be a shaped constant"); |
| return success(); |
| } |
| |
| if (type.isa<FunctionType>()) { |
| auto fnAttr = value.dyn_cast<FunctionAttr>(); |
| if (!fnAttr) |
| return op.emitOpError("requires 'value' to be a function reference"); |
| |
| // Try to find the referenced function. |
| auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( |
| fnAttr.getValue()); |
| if (!fn) |
| return op.emitOpError("reference to undefined function 'bar'"); |
| |
| // Check that the referenced function has the correct type. |
| if (fn->getType() != type) |
| return op.emitOpError("reference to function with mismatched type"); |
| |
| return success(); |
| } |
| |
| return op.emitOpError( |
| "requires a result type that aligns with the 'value' attribute"); |
| } |
| |
| OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.empty() && "constant has no operands"); |
| return getValue(); |
| } |
| |
| void ConstantFloatOp::build(Builder *builder, OperationState *result, |
| const APFloat &value, FloatType type) { |
| ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); |
| } |
| |
| bool ConstantFloatOp::classof(Operation *op) { |
| return ConstantOp::classof(op) && |
| op->getResult(0)->getType().isa<FloatType>(); |
| } |
| |
| /// ConstantIntOp only matches values whose result type is an IntegerType. |
| bool ConstantIntOp::classof(Operation *op) { |
| return ConstantOp::classof(op) && |
| op->getResult(0)->getType().isa<IntegerType>(); |
| } |
| |
| void ConstantIntOp::build(Builder *builder, OperationState *result, |
| int64_t value, unsigned width) { |
| Type type = builder->getIntegerType(width); |
| ConstantOp::build(builder, result, type, |
| builder->getIntegerAttr(type, value)); |
| } |
| |
| /// Build a constant int op producing an integer with the specified type, |
| /// which must be an integer type. |
| void ConstantIntOp::build(Builder *builder, OperationState *result, |
| int64_t value, Type type) { |
| assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type"); |
| ConstantOp::build(builder, result, type, |
| builder->getIntegerAttr(type, value)); |
| } |
| |
| /// ConstantIndexOp only matches values whose result type is Index. |
| bool ConstantIndexOp::classof(Operation *op) { |
| return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex(); |
| } |
| |
| void ConstantIndexOp::build(Builder *builder, OperationState *result, |
| int64_t value) { |
| Type type = builder->getIndexType(); |
| ConstantOp::build(builder, result, type, |
| builder->getIntegerAttr(type, value)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DeallocOp |
| //===----------------------------------------------------------------------===// |
| namespace { |
| /// Fold Dealloc operations that are deallocating an AllocOp that is only used |
| /// by other Dealloc operations. |
| struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { |
| using OpRewritePattern<DeallocOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(DeallocOp dealloc, |
| PatternRewriter &rewriter) const override { |
| // Check that the memref operand's defining operation is an AllocOp. |
| Value *memref = dealloc.memref(); |
| if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp())) |
| return matchFailure(); |
| |
| // Check that all of the uses of the AllocOp are other DeallocOps. |
| for (auto *user : memref->getUsers()) |
| if (!isa<DeallocOp>(user)) |
| return matchFailure(); |
| |
| // Erase the dealloc operation. |
| rewriter.replaceOp(dealloc, llvm::None); |
| return matchSuccess(); |
| } |
| }; |
| } // end anonymous namespace. |
| |
| static void print(OpAsmPrinter *p, DeallocOp op) { |
| *p << "dealloc " << *op.memref() << " : " << op.memref()->getType(); |
| } |
| |
| static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType memrefInfo; |
| MemRefType type; |
| |
| return failure(parser->parseOperand(memrefInfo) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(memrefInfo, type, result->operands)); |
| } |
| |
| static LogicalResult verify(DeallocOp op) { |
| if (!op.memref()->getType().isa<MemRefType>()) |
| return op.emitOpError("operand must be a memref"); |
| return success(); |
| } |
| |
| void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| /// dealloc(memrefcast) -> dealloc |
| results.push_back( |
| llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); |
| results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DimOp |
| //===----------------------------------------------------------------------===// |
| |
| static void print(OpAsmPrinter *p, DimOp op) { |
| *p << "dim " << *op.getOperand() << ", " << op.getIndex(); |
| p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); |
| *p << " : " << op.getOperand()->getType(); |
| } |
| |
| static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType operandInfo; |
| IntegerAttr indexAttr; |
| Type type; |
| Type indexType = parser->getBuilder().getIndexType(); |
| |
| return failure(parser->parseOperand(operandInfo) || parser->parseComma() || |
| parser->parseAttribute(indexAttr, indexType, "index", |
| result->attributes) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(operandInfo, type, result->operands) || |
| parser->addTypeToList(indexType, result->types)); |
| } |
| |
| static LogicalResult verify(DimOp op) { |
| // Check that we have an integer index operand. |
| auto indexAttr = op.getAttrOfType<IntegerAttr>("index"); |
| if (!indexAttr) |
| return op.emitOpError("requires an integer attribute named 'index'"); |
| uint64_t index = indexAttr.getValue().getZExtValue(); |
| |
| auto type = op.getOperand()->getType(); |
| if (auto tensorType = type.dyn_cast<RankedTensorType>()) { |
| if (index >= static_cast<uint64_t>(tensorType.getRank())) |
| return op.emitOpError("index is out of range"); |
| } else if (auto memrefType = type.dyn_cast<MemRefType>()) { |
| if (index >= memrefType.getRank()) |
| return op.emitOpError("index is out of range"); |
| |
| } else if (type.isa<UnrankedTensorType>()) { |
| // ok, assumed to be in-range. |
| } else { |
| return op.emitOpError("requires an operand with tensor or memref type"); |
| } |
| |
| return success(); |
| } |
| |
| OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { |
| // Constant fold dim when the size along the index referred to is a constant. |
| auto opType = getOperand()->getType(); |
| int64_t indexSize = -1; |
| if (auto tensorType = opType.dyn_cast<RankedTensorType>()) |
| indexSize = tensorType.getShape()[getIndex()]; |
| else if (auto memrefType = opType.dyn_cast<MemRefType>()) |
| indexSize = memrefType.getShape()[getIndex()]; |
| |
| if (indexSize >= 0) |
| return IntegerAttr::get(IndexType::get(getContext()), indexSize); |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DivISOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 2 && "binary operation takes two operands"); |
| |
| auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
| auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
| if (!lhs || !rhs) |
| return {}; |
| |
| // Don't fold if it requires division by zero. |
| if (rhs.getValue().isNullValue()) |
| return {}; |
| |
| // Don't fold if it would overflow. |
| bool overflow; |
| auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow); |
| return overflow ? IntegerAttr{} : IntegerAttr::get(lhs.getType(), result); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DivIUOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 2 && "binary operation takes two operands"); |
| |
| auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
| auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
| if (!lhs || !rhs) |
| return {}; |
| |
| // Don't fold if it requires division by zero. |
| if (rhs.getValue().isNullValue()) { |
| return {}; |
| } |
| |
| return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhs.getValue())); |
| } |
| |
| // --------------------------------------------------------------------------- |
| // DmaStartOp |
| // --------------------------------------------------------------------------- |
| |
| void DmaStartOp::build(Builder *builder, OperationState *result, |
| Value *srcMemRef, ArrayRef<Value *> srcIndices, |
| Value *destMemRef, ArrayRef<Value *> destIndices, |
| Value *numElements, Value *tagMemRef, |
| ArrayRef<Value *> tagIndices, Value *stride, |
| Value *elementsPerStride) { |
| result->addOperands(srcMemRef); |
| result->addOperands(srcIndices); |
| result->addOperands(destMemRef); |
| result->addOperands(destIndices); |
| result->addOperands(numElements); |
| result->addOperands(tagMemRef); |
| result->addOperands(tagIndices); |
| if (stride) { |
| result->addOperands(stride); |
| result->addOperands(elementsPerStride); |
| } |
| } |
| |
| void DmaStartOp::print(OpAsmPrinter *p) { |
| *p << "dma_start " << *getSrcMemRef() << '['; |
| p->printOperands(getSrcIndices()); |
| *p << "], " << *getDstMemRef() << '['; |
| p->printOperands(getDstIndices()); |
| *p << "], " << *getNumElements(); |
| *p << ", " << *getTagMemRef() << '['; |
| p->printOperands(getTagIndices()); |
| *p << ']'; |
| if (isStrided()) { |
| *p << ", " << *getStride(); |
| *p << ", " << *getNumElementsPerStride(); |
| } |
| p->printOptionalAttrDict(getAttrs()); |
| *p << " : " << getSrcMemRef()->getType(); |
| *p << ", " << getDstMemRef()->getType(); |
| *p << ", " << getTagMemRef()->getType(); |
| } |
| |
| // Parse DmaStartOp. |
| // Ex: |
| // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, |
| // %tag[%index], %stride, %num_elt_per_stride : |
| // : memref<3076 x f32, 0>, |
| // memref<1024 x f32, 2>, |
| // memref<1 x i32> |
| // |
| ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType srcMemRefInfo; |
| SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; |
| OpAsmParser::OperandType dstMemRefInfo; |
| SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; |
| OpAsmParser::OperandType numElementsInfo; |
| OpAsmParser::OperandType tagMemrefInfo; |
| SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; |
| SmallVector<OpAsmParser::OperandType, 2> strideInfo; |
| |
| SmallVector<Type, 3> types; |
| auto indexType = parser->getBuilder().getIndexType(); |
| |
| // Parse and resolve the following list of operands: |
| // *) source memref followed by its indices (in square brackets). |
| // *) destination memref followed by its indices (in square brackets). |
| // *) dma size in KiB. |
| if (parser->parseOperand(srcMemRefInfo) || |
| parser->parseOperandList(srcIndexInfos, -1, |
| OpAsmParser::Delimiter::Square) || |
| parser->parseComma() || parser->parseOperand(dstMemRefInfo) || |
| parser->parseOperandList(dstIndexInfos, -1, |
| OpAsmParser::Delimiter::Square) || |
| parser->parseComma() || parser->parseOperand(numElementsInfo) || |
| parser->parseComma() || parser->parseOperand(tagMemrefInfo) || |
| parser->parseOperandList(tagIndexInfos, -1, |
| OpAsmParser::Delimiter::Square)) |
| return failure(); |
| |
| // Parse optional stride and elements per stride. |
| if (parser->parseTrailingOperandList(strideInfo)) { |
| return failure(); |
| } |
| if (!strideInfo.empty() && strideInfo.size() != 2) { |
| return parser->emitError(parser->getNameLoc(), |
| "expected two stride related operands"); |
| } |
| bool isStrided = strideInfo.size() == 2; |
| |
| if (parser->parseColonTypeList(types)) |
| return failure(); |
| |
| if (types.size() != 3) |
| return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); |
| |
| if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || |
| parser->resolveOperands(srcIndexInfos, indexType, result->operands) || |
| parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || |
| parser->resolveOperands(dstIndexInfos, indexType, result->operands) || |
| // size should be an index. |
| parser->resolveOperand(numElementsInfo, indexType, result->operands) || |
| parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || |
| // tag indices should be index. |
| parser->resolveOperands(tagIndexInfos, indexType, result->operands)) |
| return failure(); |
| |
| if (!types[0].isa<MemRefType>()) |
| return parser->emitError(parser->getNameLoc(), |
| "expected source to be of memref type"); |
| |
| if (!types[1].isa<MemRefType>()) |
| return parser->emitError(parser->getNameLoc(), |
| "expected destination to be of memref type"); |
| |
| if (!types[2].isa<MemRefType>()) |
| return parser->emitError(parser->getNameLoc(), |
| "expected tag to be of memref type"); |
| |
| if (isStrided) { |
| if (parser->resolveOperand(strideInfo[0], indexType, result->operands) || |
| parser->resolveOperand(strideInfo[1], indexType, result->operands)) |
| return failure(); |
| } |
| |
| // Check that source/destination index list size matches associated rank. |
| if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() || |
| dstIndexInfos.size() != types[1].cast<MemRefType>().getRank()) |
| return parser->emitError(parser->getNameLoc(), |
| "memref rank not equal to indices count"); |
| |
| if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank()) |
| return parser->emitError(parser->getNameLoc(), |
| "tag memref rank not equal to indices count"); |
| |
| return success(); |
| } |
| |
| LogicalResult DmaStartOp::verify() { |
| // DMAs from different memory spaces supported. |
| if (getSrcMemorySpace() == getDstMemorySpace()) { |
| return emitOpError("DMA should be between different memory spaces"); |
| } |
| |
| if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + |
| getDstMemRefRank() + 3 + 1 && |
| getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + |
| getDstMemRefRank() + 3 + 1 + 2) { |
| return emitOpError("incorrect number of operands"); |
| } |
| return success(); |
| } |
| |
| void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| /// dma_start(memrefcast) -> dma_start |
| results.push_back( |
| llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); |
| } |
| |
| // --------------------------------------------------------------------------- |
| // DmaWaitOp |
| // --------------------------------------------------------------------------- |
| |
| void DmaWaitOp::build(Builder *builder, OperationState *result, |
| Value *tagMemRef, ArrayRef<Value *> tagIndices, |
| Value *numElements) { |
| result->addOperands(tagMemRef); |
| result->addOperands(tagIndices); |
| result->addOperands(numElements); |
| } |
| |
| void DmaWaitOp::print(OpAsmPrinter *p) { |
| *p << "dma_wait "; |
| // Print operands. |
| p->printOperand(getTagMemRef()); |
| *p << '['; |
| p->printOperands(getTagIndices()); |
| *p << "], "; |
| p->printOperand(getNumElements()); |
| p->printOptionalAttrDict(getAttrs()); |
| *p << " : " << getTagMemRef()->getType(); |
| } |
| |
| // Parse DmaWaitOp. |
| // Eg: |
| // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> |
| // |
| ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType tagMemrefInfo; |
| SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; |
| Type type; |
| auto indexType = parser->getBuilder().getIndexType(); |
| OpAsmParser::OperandType numElementsInfo; |
| |
| // Parse tag memref, its indices, and dma size. |
| if (parser->parseOperand(tagMemrefInfo) || |
| parser->parseOperandList(tagIndexInfos, -1, |
| OpAsmParser::Delimiter::Square) || |
| parser->parseComma() || parser->parseOperand(numElementsInfo) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(tagMemrefInfo, type, result->operands) || |
| parser->resolveOperands(tagIndexInfos, indexType, result->operands) || |
| parser->resolveOperand(numElementsInfo, indexType, result->operands)) |
| return failure(); |
| |
| if (!type.isa<MemRefType>()) |
| return parser->emitError(parser->getNameLoc(), |
| "expected tag to be of memref type"); |
| |
| if (tagIndexInfos.size() != type.cast<MemRefType>().getRank()) |
| return parser->emitError(parser->getNameLoc(), |
| "tag memref rank not equal to indices count"); |
| |
| return success(); |
| } |
| |
| void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| /// dma_wait(memrefcast) -> dma_wait |
| results.push_back( |
| llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ExtractElementOp |
| //===----------------------------------------------------------------------===// |
| |
| static void print(OpAsmPrinter *p, ExtractElementOp op) { |
| *p << "extract_element " << *op.getAggregate() << '['; |
| p->printOperands(op.getIndices()); |
| *p << ']'; |
| p->printOptionalAttrDict(op.getAttrs()); |
| *p << " : " << op.getAggregate()->getType(); |
| } |
| |
| static ParseResult parseExtractElementOp(OpAsmParser *parser, |
| OperationState *result) { |
| OpAsmParser::OperandType aggregateInfo; |
| SmallVector<OpAsmParser::OperandType, 4> indexInfo; |
| ShapedType type; |
| |
| auto affineIntTy = parser->getBuilder().getIndexType(); |
| return failure( |
| parser->parseOperand(aggregateInfo) || |
| parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(aggregateInfo, type, result->operands) || |
| parser->resolveOperands(indexInfo, affineIntTy, result->operands) || |
| parser->addTypeToList(type.getElementType(), result->types)); |
| } |
| |
| static LogicalResult verify(ExtractElementOp op) { |
| auto aggregateType = op.getAggregate()->getType().cast<ShapedType>(); |
| |
| // This should be possible with tablegen type constraints |
| if (op.getType() != aggregateType.getElementType()) |
| return op.emitOpError("result type must match element type of aggregate"); |
| |
| // TODO(b/132908002) This should be covered by the op specification in |
| // tablegen, but for some reason it's not. |
| for (auto *idx : op.getIndices()) |
| if (!idx->getType().isIndex()) |
| return op.emitOpError("index to extract_element must have 'index' type"); |
| |
| // Verify the # indices match if we have a ranked type. |
| auto aggregateRank = aggregateType.getRank(); |
| if (aggregateRank != -1 && aggregateRank != op.getNumOperands() - 1) |
| return op.emitOpError("incorrect number of indices for extract_element"); |
| |
| return success(); |
| } |
| |
| OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) { |
| assert(!operands.empty() && "extract_element takes atleast one operand"); |
| |
| // The aggregate operand must be a known constant. |
| Attribute aggregate = operands.front(); |
| if (!aggregate) |
| return {}; |
| |
| // If this is a splat elements attribute, simply return the value. All of the |
| // elements of a splat attribute are the same. |
| if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>()) |
| return splatAggregate.getValue(); |
| |
| // Otherwise, collect the constant indices into the aggregate. |
| SmallVector<uint64_t, 8> indices; |
| for (Attribute indice : llvm::drop_begin(operands, 1)) { |
| if (!indice || !indice.isa<IntegerAttr>()) |
| return {}; |
| indices.push_back(indice.cast<IntegerAttr>().getInt()); |
| } |
| |
| // If this is an elements attribute, query the value at the given indices. |
| if (auto elementsAttr = aggregate.dyn_cast<ElementsAttr>()) |
| return elementsAttr.getValue(indices); |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LoadOp |
| //===----------------------------------------------------------------------===// |
| |
| static void print(OpAsmPrinter *p, LoadOp op) { |
| *p << "load " << *op.getMemRef() << '['; |
| p->printOperands(op.getIndices()); |
| *p << ']'; |
| p->printOptionalAttrDict(op.getAttrs()); |
| *p << " : " << op.getMemRefType(); |
| } |
| |
| static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType memrefInfo; |
| SmallVector<OpAsmParser::OperandType, 4> indexInfo; |
| MemRefType type; |
| |
| auto affineIntTy = parser->getBuilder().getIndexType(); |
| return failure( |
| parser->parseOperand(memrefInfo) || |
| parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(memrefInfo, type, result->operands) || |
| parser->resolveOperands(indexInfo, affineIntTy, result->operands) || |
| parser->addTypeToList(type.getElementType(), result->types)); |
| } |
| |
| static LogicalResult verify(LoadOp op) { |
| if (op.getType() != op.getMemRefType().getElementType()) |
| return op.emitOpError("result type must match element type of memref"); |
| |
| if (op.getMemRefType().getRank() != op.getNumOperands() - 1) |
| return op.emitOpError("incorrect number of indices for load"); |
| |
| for (auto *idx : op.getIndices()) |
| if (!idx->getType().isIndex()) |
| return op.emitOpError("index to load must have 'index' type"); |
| |
| // TODO: Verify we have the right number of indices. |
| |
| // TODO: in Function verify that the indices are parameters, IV's, or the |
| // result of an affine.apply. |
| return success(); |
| } |
| |
| void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| /// load(memrefcast) -> load |
| results.push_back( |
| llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MemRefCastOp |
| //===----------------------------------------------------------------------===// |
| |
| bool MemRefCastOp::areCastCompatible(Type a, Type b) { |
| auto aT = a.dyn_cast<MemRefType>(); |
| auto bT = b.dyn_cast<MemRefType>(); |
| |
| if (!aT || !bT) |
| return false; |
| if (aT.getElementType() != bT.getElementType()) |
| return false; |
| if (aT.getAffineMaps() != bT.getAffineMaps()) |
| return false; |
| if (aT.getMemorySpace() != bT.getMemorySpace()) |
| return false; |
| |
| // They must have the same rank, and any specified dimensions must match. |
| if (aT.getRank() != bT.getRank()) |
| return false; |
| |
| for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { |
| int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); |
| if (aDim != -1 && bDim != -1 && aDim != bDim) |
| return false; |
| } |
| |
| return true; |
| } |
| |
| OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) { |
| return impl::foldCastOp(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MulFOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldBinaryOp<FloatAttr>( |
| operands, [](APFloat a, APFloat b) { return a * b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MulIOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) { |
| /// muli(x, 0) -> 0 |
| if (matchPattern(rhs(), m_Zero())) |
| return rhs(); |
| /// muli(x, 1) -> x |
| if (matchPattern(rhs(), m_One())) |
| return getOperand(0); |
| |
| // TODO: Handle the overflow case. |
| return constFoldBinaryOp<IntegerAttr>(operands, |
| [](APInt a, APInt b) { return a * b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RemISOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 2 && "remis takes two operands"); |
| |
| auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
| if (!rhs) |
| return {}; |
| |
| // x % 1 = 0 |
| if (rhs.getValue().isOneValue()) |
| return IntegerAttr::get(rhs.getType(), |
| APInt(rhs.getValue().getBitWidth(), 0)); |
| |
| // Don't fold if it requires division by zero. |
| if (rhs.getValue().isNullValue()) |
| return {}; |
| |
| auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
| if (!lhs) |
| return {}; |
| |
| return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhs.getValue())); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RemIUOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 2 && "remiu takes two operands"); |
| |
| auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
| if (!rhs) |
| return {}; |
| |
| // x % 1 = 0 |
| if (rhs.getValue().isOneValue()) |
| return IntegerAttr::get(rhs.getType(), |
| APInt(rhs.getValue().getBitWidth(), 0)); |
| |
| // Don't fold if it requires division by zero. |
| if (rhs.getValue().isNullValue()) |
| return {}; |
| |
| auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
| if (!lhs) |
| return {}; |
| |
| return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhs.getValue())); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReturnOp |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { |
| SmallVector<OpAsmParser::OperandType, 2> opInfo; |
| SmallVector<Type, 2> types; |
| llvm::SMLoc loc; |
| return failure(parser->getCurrentLocation(&loc) || |
| parser->parseOperandList(opInfo) || |
| (!opInfo.empty() && parser->parseColonTypeList(types)) || |
| parser->resolveOperands(opInfo, types, loc, result->operands)); |
| } |
| |
| static void print(OpAsmPrinter *p, ReturnOp op) { |
| *p << "return"; |
| if (op.getNumOperands() > 0) { |
| *p << ' '; |
| p->printOperands(op.operand_begin(), op.operand_end()); |
| *p << " : "; |
| interleave( |
| op.operand_begin(), op.operand_end(), |
| [&](Value *e) { p->printType(e->getType()); }, [&]() { *p << ", "; }); |
| } |
| } |
| |
| static LogicalResult verify(ReturnOp op) { |
| auto *function = op.getOperation()->getFunction(); |
| |
| // The operand number and types must match the function signature. |
| const auto &results = function->getType().getResults(); |
| if (op.getNumOperands() != results.size()) |
| return op.emitOpError("has ") |
| << op.getNumOperands() |
| << " operands, but enclosing function returns " << results.size(); |
| |
| for (unsigned i = 0, e = results.size(); i != e; ++i) |
| if (op.getOperand(i)->getType() != results[i]) |
| return op.emitError() |
| << "type of return operand " << i << " (" |
| << op.getOperand(i)->getType() |
| << ") doesn't match function result type (" << results[i] << ")"; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SelectOp |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { |
| SmallVector<OpAsmParser::OperandType, 3> ops; |
| SmallVector<NamedAttribute, 4> attrs; |
| Type type; |
| |
| if (parser->parseOperandList(ops, 3) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type)) |
| return failure(); |
| |
| auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); |
| if (!i1Type) |
| return parser->emitError(parser->getNameLoc(), |
| "expected type with valid i1 shape"); |
| |
| SmallVector<Type, 3> types = {i1Type, type, type}; |
| return failure(parser->resolveOperands(ops, types, parser->getNameLoc(), |
| result->operands) || |
| parser->addTypeToList(type, result->types)); |
| } |
| |
| static void print(OpAsmPrinter *p, SelectOp op) { |
| *p << "select "; |
| p->printOperands(op.getOperands()); |
| *p << " : " << op.getTrueValue()->getType(); |
| p->printOptionalAttrDict(op.getAttrs()); |
| } |
| |
| static LogicalResult verify(SelectOp op) { |
| auto trueType = op.getTrueValue()->getType(); |
| auto falseType = op.getFalseValue()->getType(); |
| |
| if (trueType != falseType) |
| return op.emitOpError( |
| "requires 'true' and 'false' arguments to be of the same type"); |
| |
| return success(); |
| } |
| |
| OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { |
| auto *condition = getCondition(); |
| |
| // select true, %0, %1 => %0 |
| if (matchPattern(condition, m_One())) |
| return getTrueValue(); |
| |
| // select false, %0, %1 => %1 |
| if (matchPattern(condition, m_Zero())) |
| return getFalseValue(); |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // StoreOp |
| //===----------------------------------------------------------------------===// |
| |
| static void print(OpAsmPrinter *p, StoreOp op) { |
| *p << "store " << *op.getValueToStore(); |
| *p << ", " << *op.getMemRef() << '['; |
| p->printOperands(op.getIndices()); |
| *p << ']'; |
| p->printOptionalAttrDict(op.getAttrs()); |
| *p << " : " << op.getMemRefType(); |
| } |
| |
| static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType storeValueInfo; |
| OpAsmParser::OperandType memrefInfo; |
| SmallVector<OpAsmParser::OperandType, 4> indexInfo; |
| MemRefType memrefType; |
| |
| auto affineIntTy = parser->getBuilder().getIndexType(); |
| return failure( |
| parser->parseOperand(storeValueInfo) || parser->parseComma() || |
| parser->parseOperand(memrefInfo) || |
| parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(memrefType) || |
| parser->resolveOperand(storeValueInfo, memrefType.getElementType(), |
| result->operands) || |
| parser->resolveOperand(memrefInfo, memrefType, result->operands) || |
| parser->resolveOperands(indexInfo, affineIntTy, result->operands)); |
| } |
| |
| static LogicalResult verify(StoreOp op) { |
| // First operand must have same type as memref element type. |
| if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) |
| return op.emitOpError( |
| "first operand must have same type memref element type"); |
| |
| if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) |
| return op.emitOpError("store index operand count not equal to memref rank"); |
| |
| for (auto *idx : op.getIndices()) |
| if (!idx->getType().isIndex()) |
| return op.emitOpError("index to load must have 'index' type"); |
| |
| // TODO: Verify we have the right number of indices. |
| |
| // TODO: in Function verify that the indices are parameters, IV's, or the |
| // result of an affine.apply. |
| return success(); |
| } |
| |
| void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| /// store(memrefcast) -> store |
| results.push_back( |
| llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SubFOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldBinaryOp<FloatAttr>( |
| operands, [](APFloat a, APFloat b) { return a - b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SubIOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) { |
| // subi(x,x) -> 0 |
| if (getOperand(0) == getOperand(1)) |
| return Builder(getContext()).getZeroAttr(getType()); |
| |
| return constFoldBinaryOp<IntegerAttr>(operands, |
| [](APInt a, APInt b) { return a - b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AndOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) { |
| /// and(x, 0) -> 0 |
| if (matchPattern(rhs(), m_Zero())) |
| return rhs(); |
| /// and(x,x) -> x |
| if (lhs() == rhs()) |
| return rhs(); |
| |
| return constFoldBinaryOp<IntegerAttr>(operands, |
| [](APInt a, APInt b) { return a & b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OrOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) { |
| /// or(x, 0) -> x |
| if (matchPattern(rhs(), m_Zero())) |
| return lhs(); |
| /// or(x,x) -> x |
| if (lhs() == rhs()) |
| return rhs(); |
| |
| return constFoldBinaryOp<IntegerAttr>(operands, |
| [](APInt a, APInt b) { return a | b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XOrOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) { |
| /// xor(x, 0) -> x |
| if (matchPattern(rhs(), m_Zero())) |
| return lhs(); |
| /// xor(x,x) -> 0 |
| if (lhs() == rhs()) |
| return Builder(getContext()).getZeroAttr(getType()); |
| |
| return constFoldBinaryOp<IntegerAttr>(operands, |
| [](APInt a, APInt b) { return a ^ b; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TensorCastOp |
| //===----------------------------------------------------------------------===// |
| |
| bool TensorCastOp::areCastCompatible(Type a, Type b) { |
| auto aT = a.dyn_cast<TensorType>(); |
| auto bT = b.dyn_cast<TensorType>(); |
| if (!aT || !bT) |
| return false; |
| |
| if (aT.getElementType() != bT.getElementType()) |
| return false; |
| |
| // If the either are unranked, then the cast is valid. |
| auto aRType = aT.dyn_cast<RankedTensorType>(); |
| auto bRType = bT.dyn_cast<RankedTensorType>(); |
| if (!aRType || !bRType) |
| return true; |
| |
| // If they are both ranked, they have to have the same rank, and any specified |
| // dimensions must match. |
| if (aRType.getRank() != bRType.getRank()) |
| return false; |
| |
| for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) { |
| int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i); |
| if (aDim != -1 && bDim != -1 && aDim != bDim) |
| return false; |
| } |
| |
| return true; |
| } |
| |
| OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) { |
| return impl::foldCastOp(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TableGen'd op method definitions |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "mlir/StandardOps/Ops.cpp.inc" |