| //===- StandardOps.cpp - Standard MLIR Operations -------------------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| |
| #include "mlir/IR/StandardOps.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/OperationSet.h" |
| #include "mlir/IR/SSAValue.h" |
| #include "mlir/IR/Types.h" |
| #include "mlir/Support/STLExtras.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| using namespace mlir; |
| |
| static void printDimAndSymbolList(Operation::const_operand_iterator begin, |
| Operation::const_operand_iterator end, |
| unsigned numDims, OpAsmPrinter *p) { |
| *p << '('; |
| p->printOperands(begin, begin + numDims); |
| *p << ')'; |
| |
| if (begin + numDims != end) { |
| *p << '['; |
| p->printOperands(begin + numDims, end); |
| *p << ']'; |
| } |
| } |
| |
| // Parses dimension and symbol list, and sets 'numDims' to the number of |
| // dimension operands parsed. |
| // Returns 'false' on success and 'true' on error. |
| static bool |
| parseDimAndSymbolList(OpAsmParser *parser, |
| SmallVector<SSAValue *, 4> &operands, unsigned &numDims) { |
| SmallVector<OpAsmParser::OperandType, 8> opInfos; |
| if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) |
| return true; |
| // Store number of dimensions for validation by caller. |
| numDims = opInfos.size(); |
| |
| // Parse the optional symbol operands. |
| auto *affineIntTy = parser->getBuilder().getAffineIntType(); |
| if (parser->parseOperandList(opInfos, -1, |
| OpAsmParser::Delimiter::OptionalSquare) || |
| parser->resolveOperands(opInfos, affineIntTy, operands)) |
| return true; |
| return false; |
| } |
| |
| /// If this is a vector type, or a tensor type, return the scalar element type |
| /// that it is built around, otherwise return the type unmodified. |
| static Type *getTensorOrVectorElementType(Type *type) { |
| if (auto *vec = dyn_cast<VectorType>(type)) |
| return vec->getElementType(); |
| |
| // Look through tensor<vector<...>> to find the underlying element type. |
| if (auto *tensor = dyn_cast<TensorType>(type)) |
| return getTensorOrVectorElementType(tensor->getElementType()); |
| return type; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AddFOp |
| //===----------------------------------------------------------------------===// |
| |
| void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs, |
| SSAValue *rhs) { |
| assert(lhs->getType() == rhs->getType()); |
| result->addOperands({lhs, rhs}); |
| result->types.push_back(lhs->getType()); |
| } |
| |
| bool AddFOp::parse(OpAsmParser *parser, OperationState *result) { |
| SmallVector<OpAsmParser::OperandType, 2> ops; |
| Type *type; |
| return parser->parseOperandList(ops, 2) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type) || |
| parser->resolveOperands(ops, type, result->operands) || |
| parser->addTypeToList(type, result->types); |
| } |
| |
| void AddFOp::print(OpAsmPrinter *p) const { |
| *p << "addf " << *getOperand(0) << ", " << *getOperand(1); |
| p->printOptionalAttrDict(getAttrs()); |
| *p << " : " << *getType(); |
| } |
| |
| bool AddFOp::verify() const { |
| if (!isa<FloatType>(getTensorOrVectorElementType(getType()))) |
| return emitOpError("requires a floating point type"); |
| |
| return false; |
| } |
| |
| Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands, |
| MLIRContext *context) const { |
| assert(operands.size() == 2 && "addf takes two operands"); |
| |
| if (auto *lhs = dyn_cast<FloatAttr>(operands[0])) { |
| if (auto *rhs = dyn_cast<FloatAttr>(operands[1])) |
| return FloatAttr::get(lhs->getValue() + rhs->getValue(), context); |
| } |
| |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineApplyOp |
| //===----------------------------------------------------------------------===// |
| |
| bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { |
| auto &builder = parser->getBuilder(); |
| auto *affineIntTy = builder.getAffineIntType(); |
| |
| AffineMapAttr *mapAttr; |
| unsigned numDims; |
| if (parser->parseAttribute(mapAttr, "map", result->attributes) || |
| parseDimAndSymbolList(parser, result->operands, numDims) || |
| parser->parseOptionalAttributeDict(result->attributes)) |
| return true; |
| auto *map = mapAttr->getValue(); |
| |
| if (map->getNumDims() != numDims || |
| numDims + map->getNumSymbols() != result->operands.size()) { |
| return parser->emitError(parser->getNameLoc(), |
| "dimension or symbol index mismatch"); |
| } |
| |
| result->types.append(map->getNumResults(), affineIntTy); |
| return false; |
| } |
| |
| void AffineApplyOp::print(OpAsmPrinter *p) const { |
| auto *map = getAffineMap(); |
| *p << "affine_apply " << *map; |
| printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p); |
| p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); |
| } |
| |
| bool AffineApplyOp::verify() const { |
| // Check that affine map attribute was specified. |
| auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map"); |
| if (!affineMapAttr) |
| return emitOpError("requires an affine map"); |
| |
| // Check input and output dimensions match. |
| auto *map = affineMapAttr->getValue(); |
| |
| // Verify that operand count matches affine map dimension and symbol count. |
| if (getNumOperands() != map->getNumDims() + map->getNumSymbols()) |
| return emitOpError( |
| "operand count and affine map dimension and symbol count must match"); |
| |
| // Verify that result count matches affine map result count. |
| if (getNumResults() != map->getNumResults()) |
| return emitOpError("result count and affine map result count must match"); |
| |
| return false; |
| } |
| |
| // The result of the affine apply operation can be used as a dimension id if it |
| // is a CFG value or if it is an MLValue, and all the operands are valid |
| // dimension ids. |
| bool AffineApplyOp::isValidDim() const { |
| for (auto *op : getOperands()) { |
| if (auto *v = dyn_cast<MLValue>(op)) |
| if (!v->isValidDim()) |
| return false; |
| } |
| return true; |
| } |
| |
| // The result of the affine apply operation can be used as a symbol if it is |
| // a CFG value or if it is an MLValue, and all the operands are symbols. |
| bool AffineApplyOp::isValidSymbol() const { |
| for (auto *op : getOperands()) { |
| if (auto *v = dyn_cast<MLValue>(op)) |
| if (!v->isValidSymbol()) |
| return false; |
| } |
| return true; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AllocOp |
| //===----------------------------------------------------------------------===// |
| |
| void AllocOp::build(Builder *builder, OperationState *result, |
| MemRefType *memrefType, ArrayRef<SSAValue *> operands) { |
| result->addOperands(operands); |
| result->types.push_back(memrefType); |
| } |
| |
| void AllocOp::print(OpAsmPrinter *p) const { |
| MemRefType *type = cast<MemRefType>(getMemRef()->getType()); |
| *p << "alloc"; |
| // Print dynamic dimension operands. |
| printDimAndSymbolList(operand_begin(), operand_end(), |
| type->getNumDynamicDims(), p); |
| p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); |
| *p << " : " << *type; |
| } |
| |
| bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { |
| MemRefType *type; |
| |
| // Parse the dimension operands and optional symbol operands, followed by a |
| // memref type. |
| unsigned numDimOperands; |
| if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type)) |
| return true; |
| |
| // Check numDynamicDims against number of question marks in memref type. |
| // Note: this check remains here (instead of in verify()), because the |
| // partition between dim operands and symbol operands is lost after parsing. |
| // Verification still checks that the total number of operands matches |
| // the number of symbols in the affine map, plus the number of dynamic |
| // dimensions in the memref. |
| if (numDimOperands != type->getNumDynamicDims()) { |
| return parser->emitError(parser->getNameLoc(), |
| "dimension operand count does not equal memref " |
| "dynamic dimension count"); |
| } |
| result->types.push_back(type); |
| return false; |
| } |
| |
| bool AllocOp::verify() const { |
| auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); |
| if (!memRefType) |
| return emitOpError("result must be a memref"); |
| |
| unsigned numSymbols = 0; |
| if (!memRefType->getAffineMaps().empty()) { |
| AffineMap *affineMap = memRefType->getAffineMaps()[0]; |
| // Store number of symbols used in affine map (used in subsequent check). |
| numSymbols = affineMap->getNumSymbols(); |
| // Verify that the layout affine map matches the rank of the memref. |
| if (affineMap->getNumDims() != memRefType->getRank()) |
| return emitOpError("affine map dimension count must equal memref rank"); |
| } |
| unsigned numDynamicDims = memRefType->getNumDynamicDims(); |
| // Check that the total number of operands matches the number of symbols in |
| // the affine map, plus the number of dynamic dimensions specified in the |
| // memref type. |
| if (getOperation()->getNumOperands() != numDynamicDims + numSymbols) { |
| return emitOpError( |
| "operand count does not equal dimension plus symbol operand count"); |
| } |
| // Verify that all operands are of type AffineInt. |
| for (auto *operand : getOperands()) { |
| if (!operand->getType()->isAffineInt()) |
| return emitOpError("requires operands to be of type AffineInt"); |
| } |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallOp |
| //===----------------------------------------------------------------------===// |
| |
| void CallOp::build(Builder *builder, OperationState *result, Function *callee, |
| ArrayRef<SSAValue *> operands) { |
| result->addOperands(operands); |
| result->addAttribute("callee", builder->getFunctionAttr(callee)); |
| result->addTypes(callee->getType()->getResults()); |
| } |
| |
| bool CallOp::parse(OpAsmParser *parser, OperationState *result) { |
| StringRef calleeName; |
| llvm::SMLoc calleeLoc; |
| FunctionType *calleeType = nullptr; |
| SmallVector<OpAsmParser::OperandType, 4> operands; |
| Function *callee = nullptr; |
| if (parser->parseFunctionName(calleeName, calleeLoc) || |
| parser->parseOperandList(operands, /*requiredOperandCount=*/-1, |
| OpAsmParser::Delimiter::Paren) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(calleeType) || |
| parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || |
| parser->addTypesToList(calleeType->getResults(), result->types) || |
| parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc, |
| result->operands)) |
| return true; |
| |
| result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee)); |
| return false; |
| } |
| |
| void CallOp::print(OpAsmPrinter *p) const { |
| *p << "call "; |
| p->printFunctionReference(getCallee()); |
| *p << '('; |
| p->printOperands(getOperands()); |
| *p << ')'; |
| p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); |
| *p << " : " << *getCallee()->getType(); |
| } |
| |
| bool CallOp::verify() const { |
| // Check that the callee attribute was specified. |
| auto *fnAttr = getAttrOfType<FunctionAttr>("callee"); |
| if (!fnAttr) |
| return emitOpError("requires a 'callee' function attribute"); |
| |
| // Verify that the operand and result types match the callee. |
| auto *fnType = fnAttr->getValue()->getType(); |
| if (fnType->getNumInputs() != getNumOperands()) |
| return emitOpError("incorrect number of operands for callee"); |
| |
| for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { |
| if (getOperand(i)->getType() != fnType->getInput(i)) |
| return emitOpError("operand type mismatch"); |
| } |
| |
| if (fnType->getNumResults() != getNumResults()) |
| return emitOpError("incorrect number of results for callee"); |
| |
| for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { |
| if (getResult(i)->getType() != fnType->getResult(i)) |
| return emitOpError("result type mismatch"); |
| } |
| |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallIndirectOp |
| //===----------------------------------------------------------------------===// |
| |
| void CallIndirectOp::build(Builder *builder, OperationState *result, |
| SSAValue *callee, ArrayRef<SSAValue *> operands) { |
| auto *fnType = cast<FunctionType>(callee->getType()); |
| result->operands.push_back(callee); |
| result->addOperands(operands); |
| result->addTypes(fnType->getResults()); |
| } |
| |
| bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { |
| FunctionType *calleeType = nullptr; |
| OpAsmParser::OperandType callee; |
| llvm::SMLoc operandsLoc; |
| SmallVector<OpAsmParser::OperandType, 4> operands; |
| return parser->parseOperand(callee) || |
| parser->getCurrentLocation(&operandsLoc) || |
| parser->parseOperandList(operands, /*requiredOperandCount=*/-1, |
| OpAsmParser::Delimiter::Paren) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(calleeType) || |
| parser->resolveOperand(callee, calleeType, result->operands) || |
| parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc, |
| result->operands) || |
| parser->addTypesToList(calleeType->getResults(), result->types); |
| } |
| |
| void CallIndirectOp::print(OpAsmPrinter *p) const { |
| *p << "call_indirect "; |
| p->printOperand(getCallee()); |
| *p << '('; |
| auto operandRange = getOperands(); |
| p->printOperands(++operandRange.begin(), operandRange.end()); |
| *p << ')'; |
| p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); |
| *p << " : " << *getCallee()->getType(); |
| } |
| |
| bool CallIndirectOp::verify() const { |
| // The callee must be a function. |
| auto *fnType = dyn_cast<FunctionType>(getCallee()->getType()); |
| if (!fnType) |
| return emitOpError("callee must have function type"); |
| |
| // Verify that the operand and result types match the callee. |
| if (fnType->getNumInputs() != getNumOperands() - 1) |
| return emitOpError("incorrect number of operands for callee"); |
| |
| for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { |
| if (getOperand(i + 1)->getType() != fnType->getInput(i)) |
| return emitOpError("operand type mismatch"); |
| } |
| |
| if (fnType->getNumResults() != getNumResults()) |
| return emitOpError("incorrect number of results for callee"); |
| |
| for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { |
| if (getResult(i)->getType() != fnType->getResult(i)) |
| return emitOpError("result type mismatch"); |
| } |
| |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Constant*Op |
| //===----------------------------------------------------------------------===// |
| |
| /// Builds a constant op with the specified attribute value and result type. |
| void ConstantOp::build(Builder *builder, OperationState *result, |
| Attribute *value, Type *type) { |
| result->addAttribute("value", value); |
| result->types.push_back(type); |
| } |
| |
| void ConstantOp::print(OpAsmPrinter *p) const { |
| *p << "constant " << *getValue(); |
| p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); |
| |
| if (!isa<FunctionAttr>(getValue())) |
| *p << " : " << *getType(); |
| } |
| |
| bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { |
| Attribute *valueAttr; |
| Type *type; |
| |
| if (parser->parseAttribute(valueAttr, "value", result->attributes) || |
| parser->parseOptionalAttributeDict(result->attributes)) |
| return true; |
| |
| // 'constant' taking a function reference doesn't get a redundant type |
| // specifier. The attribute itself carries it. |
| if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr)) |
| return parser->addTypeToList(fnAttr->getValue()->getType(), result->types); |
| |
| return parser->parseColonType(type) || |
| parser->addTypeToList(type, result->types); |
| } |
| |
| /// The constant op requires an attribute, and furthermore requires that it |
| /// matches the return type. |
| bool ConstantOp::verify() const { |
| auto *value = getValue(); |
| if (!value) |
| return emitOpError("requires a 'value' attribute"); |
| |
| auto *type = this->getType(); |
| if (isa<IntegerType>(type) || type->isAffineInt()) { |
| if (!isa<IntegerAttr>(value)) |
| return emitOpError( |
| "requires 'value' to be an integer for an integer result type"); |
| return false; |
| } |
| |
| if (isa<FloatType>(type)) { |
| if (!isa<FloatAttr>(value)) |
| return emitOpError("requires 'value' to be a floating point constant"); |
| return false; |
| } |
| |
| if (type->isTFString()) { |
| if (!isa<StringAttr>(value)) |
| return emitOpError("requires 'value' to be a string constant"); |
| return false; |
| } |
| |
| if (isa<FunctionType>(type)) { |
| if (!isa<FunctionAttr>(value)) |
| return emitOpError("requires 'value' to be a function reference"); |
| return false; |
| } |
| |
| return emitOpError( |
| "requires a result type that aligns with the 'value' attribute"); |
| } |
| |
| Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands, |
| MLIRContext *context) const { |
| assert(operands.empty() && "constant has no operands"); |
| return getValue(); |
| } |
| |
| void ConstantFloatOp::build(Builder *builder, OperationState *result, |
| double value, FloatType *type) { |
| ConstantOp::build(builder, result, builder->getFloatAttr(value), type); |
| } |
| |
| bool ConstantFloatOp::isClassFor(const Operation *op) { |
| return ConstantOp::isClassFor(op) && |
| isa<FloatType>(op->getResult(0)->getType()); |
| } |
| |
| /// ConstantIntOp only matches values whose result type is an IntegerType. |
| bool ConstantIntOp::isClassFor(const Operation *op) { |
| return ConstantOp::isClassFor(op) && |
| isa<IntegerType>(op->getResult(0)->getType()); |
| } |
| |
| void ConstantIntOp::build(Builder *builder, OperationState *result, |
| int64_t value, unsigned width) { |
| ConstantOp::build(builder, result, builder->getIntegerAttr(value), |
| builder->getIntegerType(width)); |
| } |
| |
| /// ConstantAffineIntOp only matches values whose result type is AffineInt. |
| bool ConstantAffineIntOp::isClassFor(const Operation *op) { |
| return ConstantOp::isClassFor(op) && |
| op->getResult(0)->getType()->isAffineInt(); |
| } |
| |
| void ConstantAffineIntOp::build(Builder *builder, OperationState *result, |
| int64_t value) { |
| ConstantOp::build(builder, result, builder->getIntegerAttr(value), |
| builder->getAffineIntType()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineApplyOp |
| //===----------------------------------------------------------------------===// |
| |
| void AffineApplyOp::build(Builder *builder, OperationState *result, |
| AffineMap *map, ArrayRef<SSAValue *> operands) { |
| result->addOperands(operands); |
| result->types.append(map->getNumResults(), builder->getAffineIntType()); |
| result->addAttribute("map", builder->getAffineMapAttr(map)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DeallocOp |
| //===----------------------------------------------------------------------===// |
| |
| void DeallocOp::build(Builder *builder, OperationState *result, |
| SSAValue *memref) { |
| result->addOperands(memref); |
| } |
| |
| void DeallocOp::print(OpAsmPrinter *p) const { |
| *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType(); |
| } |
| |
| bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType memrefInfo; |
| MemRefType *type; |
| |
| return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || |
| parser->resolveOperand(memrefInfo, type, result->operands); |
| } |
| |
| bool DeallocOp::verify() const { |
| if (!isa<MemRefType>(getMemRef()->getType())) |
| return emitOpError("operand must be a memref"); |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DimOp |
| //===----------------------------------------------------------------------===// |
| |
| void DimOp::print(OpAsmPrinter *p) const { |
| *p << "dim " << *getOperand() << ", " << getIndex(); |
| p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index"); |
| *p << " : " << *getOperand()->getType(); |
| } |
| |
| bool DimOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType operandInfo; |
| IntegerAttr *indexAttr; |
| Type *type; |
| |
| return parser->parseOperand(operandInfo) || parser->parseComma() || |
| parser->parseAttribute(indexAttr, "index", result->attributes) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(operandInfo, type, result->operands) || |
| parser->addTypeToList(parser->getBuilder().getAffineIntType(), |
| result->types); |
| } |
| |
| bool DimOp::verify() const { |
| // Check that we have an integer index operand. |
| auto indexAttr = getAttrOfType<IntegerAttr>("index"); |
| if (!indexAttr) |
| return emitOpError("requires an integer attribute named 'index'"); |
| uint64_t index = (uint64_t)indexAttr->getValue(); |
| |
| auto *type = getOperand()->getType(); |
| if (auto *tensorType = dyn_cast<RankedTensorType>(type)) { |
| if (index >= tensorType->getRank()) |
| return emitOpError("index is out of range"); |
| } else if (auto *memrefType = dyn_cast<MemRefType>(type)) { |
| if (index >= memrefType->getRank()) |
| return emitOpError("index is out of range"); |
| |
| } else if (isa<UnrankedTensorType>(type)) { |
| // ok, assumed to be in-range. |
| } else { |
| return emitOpError("requires an operand with tensor or memref type"); |
| } |
| |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ExtractElementOp |
| //===----------------------------------------------------------------------===// |
| |
| void ExtractElementOp::build(Builder *builder, OperationState *result, |
| SSAValue *aggregate, |
| ArrayRef<SSAValue *> indices) { |
| auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType()); |
| result->addOperands(aggregate); |
| result->addOperands(indices); |
| result->types.push_back(aggregateType->getElementType()); |
| } |
| |
| void ExtractElementOp::print(OpAsmPrinter *p) const { |
| *p << "extract_element " << *getAggregate() << '['; |
| p->printOperands(getIndices()); |
| *p << ']'; |
| p->printOptionalAttrDict(getAttrs()); |
| *p << " : " << *getAggregate()->getType(); |
| } |
| |
| bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType aggregateInfo; |
| SmallVector<OpAsmParser::OperandType, 4> indexInfo; |
| VectorOrTensorType *type; |
| |
| auto affineIntTy = parser->getBuilder().getAffineIntType(); |
| return parser->parseOperand(aggregateInfo) || |
| parser->parseOperandList(indexInfo, -1, |
| OpAsmParser::Delimiter::Square) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(aggregateInfo, type, result->operands) || |
| parser->resolveOperands(indexInfo, affineIntTy, result->operands) || |
| parser->addTypeToList(type->getElementType(), result->types); |
| } |
| |
| bool ExtractElementOp::verify() const { |
| if (getNumOperands() == 0) |
| return emitOpError("expected an aggregate to index into"); |
| |
| auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType()); |
| if (!aggregateType) |
| return emitOpError("first operand must be a vector or tensor"); |
| |
| if (getResult()->getType() != aggregateType->getElementType()) |
| return emitOpError("result type must match element type of aggregate"); |
| |
| for (auto *idx : getIndices()) |
| if (!idx->getType()->isAffineInt()) |
| return emitOpError("index to extract_element must have 'affineint' type"); |
| |
| // Verify the # indices match if we have a ranked type. |
| auto aggregateRank = aggregateType->getRankIfPresent(); |
| if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) |
| return emitOpError("incorrect number of indices for extract_element"); |
| |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LoadOp |
| //===----------------------------------------------------------------------===// |
| |
| void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, |
| ArrayRef<SSAValue *> indices) { |
| auto *memrefType = cast<MemRefType>(memref->getType()); |
| result->addOperands(memref); |
| result->addOperands(indices); |
| result->types.push_back(memrefType->getElementType()); |
| } |
| |
| void LoadOp::print(OpAsmPrinter *p) const { |
| *p << "load " << *getMemRef() << '['; |
| p->printOperands(getIndices()); |
| *p << ']'; |
| p->printOptionalAttrDict(getAttrs()); |
| *p << " : " << *getMemRef()->getType(); |
| } |
| |
| bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType memrefInfo; |
| SmallVector<OpAsmParser::OperandType, 4> indexInfo; |
| MemRefType *type; |
| |
| auto affineIntTy = parser->getBuilder().getAffineIntType(); |
| return parser->parseOperand(memrefInfo) || |
| parser->parseOperandList(indexInfo, -1, |
| OpAsmParser::Delimiter::Square) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(type) || |
| parser->resolveOperand(memrefInfo, type, result->operands) || |
| parser->resolveOperands(indexInfo, affineIntTy, result->operands) || |
| parser->addTypeToList(type->getElementType(), result->types); |
| } |
| |
| bool LoadOp::verify() const { |
| if (getNumOperands() == 0) |
| return emitOpError("expected a memref to load from"); |
| |
| auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); |
| if (!memRefType) |
| return emitOpError("first operand must be a memref"); |
| |
| if (getResult()->getType() != memRefType->getElementType()) |
| return emitOpError("result type must match element type of memref"); |
| |
| if (memRefType->getRank() != getNumOperands() - 1) |
| return emitOpError("incorrect number of indices for load"); |
| |
| for (auto *idx : getIndices()) |
| if (!idx->getType()->isAffineInt()) |
| return emitOpError("index to load must have 'affineint' type"); |
| |
| // TODO: Verify we have the right number of indices. |
| |
| // TODO: in MLFunction verify that the indices are parameters, IV's, or the |
| // result of an affine_apply. |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReturnOp |
| //===----------------------------------------------------------------------===// |
| |
| void ReturnOp::build(Builder *builder, OperationState *result, |
| ArrayRef<SSAValue *> results) { |
| result->addOperands(results); |
| } |
| |
| bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { |
| SmallVector<OpAsmParser::OperandType, 2> opInfo; |
| SmallVector<Type *, 2> types; |
| llvm::SMLoc loc; |
| return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || |
| (!opInfo.empty() && parser->parseColonTypeList(types)) || |
| parser->resolveOperands(opInfo, types, loc, result->operands); |
| } |
| |
| void ReturnOp::print(OpAsmPrinter *p) const { |
| *p << "return"; |
| if (getNumOperands() > 0) { |
| *p << " "; |
| p->printOperands(operand_begin(), operand_end()); |
| *p << " : "; |
| interleave(operand_begin(), operand_end(), |
| [&](const SSAValue *e) { p->printType(e->getType()); }, |
| [&]() { *p << ", "; }); |
| } |
| } |
| |
| bool ReturnOp::verify() const { |
| // ReturnOp must be part of an ML function. |
| if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) { |
| StmtBlock *block = stmt->getBlock(); |
| if (!block || !isa<MLFunction>(block) || &block->back() != stmt) |
| return emitOpError("must be the last statement in the ML function"); |
| |
| // Return success. Checking that operand types match those in the function |
| // signature is performed in the ML function verifier. |
| return false; |
| } |
| return emitOpError("cannot occur in a CFG function"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ShapeCastOp |
| //===----------------------------------------------------------------------===// |
| |
| void ShapeCastOp::build(Builder *builder, OperationState *result, |
| SSAValue *input, Type *resultType) { |
| result->addOperands(input); |
| result->addTypes(resultType); |
| } |
| |
| bool ShapeCastOp::verify() const { |
| auto *opType = dyn_cast<TensorType>(getOperand()->getType()); |
| auto *resType = dyn_cast<TensorType>(getResult()->getType()); |
| if (!opType || !resType) |
| return emitOpError("requires input and result types to be tensors"); |
| |
| if (opType == resType) |
| return emitOpError("requires the input and result type to be different"); |
| |
| if (opType->getElementType() != resType->getElementType()) |
| return emitOpError( |
| "requires input and result element types to be the same"); |
| |
| // If the source or destination are unranked, then the cast is valid. |
| auto *opRType = dyn_cast<RankedTensorType>(opType); |
| auto *resRType = dyn_cast<RankedTensorType>(resType); |
| if (!opRType || !resRType) |
| return false; |
| |
| // If they are both ranked, they have to have the same rank, and any specified |
| // dimensions must match. |
| if (opRType->getRank() != resRType->getRank()) |
| return emitOpError("requires input and result ranks to match"); |
| |
| for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) { |
| int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i); |
| if (opDim != -1 && resultDim != -1 && opDim != resultDim) |
| return emitOpError("requires static dimensions to match"); |
| } |
| |
| return false; |
| } |
| |
| void ShapeCastOp::print(OpAsmPrinter *p) const { |
| *p << "shape_cast " << *getOperand() << " : " << *getOperand()->getType() |
| << " to " << *getType(); |
| } |
| |
| bool ShapeCastOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType srcInfo; |
| Type *srcType, *dstType; |
| return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || |
| parser->resolveOperand(srcInfo, srcType, result->operands) || |
| parser->parseKeywordType("to", dstType) || |
| parser->addTypeToList(dstType, result->types); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // StoreOp |
| //===----------------------------------------------------------------------===// |
| |
| void StoreOp::build(Builder *builder, OperationState *result, |
| SSAValue *valueToStore, SSAValue *memref, |
| ArrayRef<SSAValue *> indices) { |
| result->addOperands(valueToStore); |
| result->addOperands(memref); |
| result->addOperands(indices); |
| } |
| |
| void StoreOp::print(OpAsmPrinter *p) const { |
| *p << "store " << *getValueToStore(); |
| *p << ", " << *getMemRef() << '['; |
| p->printOperands(getIndices()); |
| *p << ']'; |
| p->printOptionalAttrDict(getAttrs()); |
| *p << " : " << *getMemRef()->getType(); |
| } |
| |
| bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { |
| OpAsmParser::OperandType storeValueInfo; |
| OpAsmParser::OperandType memrefInfo; |
| SmallVector<OpAsmParser::OperandType, 4> indexInfo; |
| MemRefType *memrefType; |
| |
| auto affineIntTy = parser->getBuilder().getAffineIntType(); |
| return parser->parseOperand(storeValueInfo) || parser->parseComma() || |
| parser->parseOperand(memrefInfo) || |
| parser->parseOperandList(indexInfo, -1, |
| OpAsmParser::Delimiter::Square) || |
| parser->parseOptionalAttributeDict(result->attributes) || |
| parser->parseColonType(memrefType) || |
| parser->resolveOperand(storeValueInfo, memrefType->getElementType(), |
| result->operands) || |
| parser->resolveOperand(memrefInfo, memrefType, result->operands) || |
| parser->resolveOperands(indexInfo, affineIntTy, result->operands); |
| } |
| |
| bool StoreOp::verify() const { |
| if (getNumOperands() < 2) |
| return emitOpError("expected a value to store and a memref"); |
| |
| // Second operand is a memref type. |
| auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); |
| if (!memRefType) |
| return emitOpError("second operand must be a memref"); |
| |
| // First operand must have same type as memref element type. |
| if (getValueToStore()->getType() != memRefType->getElementType()) |
| return emitOpError("first operand must have same type memref element type"); |
| |
| if (getNumOperands() != 2 + memRefType->getRank()) |
| return emitOpError("store index operand count not equal to memref rank"); |
| |
| for (auto *idx : getIndices()) |
| if (!idx->getType()->isAffineInt()) |
| return emitOpError("index to load must have 'affineint' type"); |
| |
| // TODO: Verify we have the right number of indices. |
| |
| // TODO: in MLFunction verify that the indices are parameters, IV's, or the |
| // result of an affine_apply. |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Register operations. |
| //===----------------------------------------------------------------------===// |
| |
| /// Install the standard operations in the specified operation set. |
| void mlir::registerStandardOperations(OperationSet &opSet) { |
| opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp, |
| ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp, |
| ReturnOp, ShapeCastOp, StoreOp>( |
| /*prefix=*/""); |
| } |