//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Linalg/Utils/Utils.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
////////////////////////////////////////////////////////////////////////////////
// ForOp.
////////////////////////////////////////////////////////////////////////////////
// Checks that if `block` has a terminator, it is a `TerminatorOp`.
static LogicalResult checkHasTerminator(OpState &op, Block &block) {
if (block.empty() || isa<TerminatorOp>(block.back()))
return success();
return op.emitOpError("expects regions to end with '" +
TerminatorOp::getOperationName() + "'")
.attachNote()
<< "in custom textual format, the absence of terminator implies '"
<< TerminatorOp::getOperationName() << "'";
}
// Inserts a `linalg.terminator` at the end of the only block of the ForOp's
// only region if that block does not already have a terminator. If a new
// `linalg.terminator` is inserted, its location is specified by `loc`. If the
// region is empty, a new block is inserted first.
static void ensureTerminator(Region &region, Builder &builder, Location loc) {
impl::ensureRegionTerminator<TerminatorOp>(region, builder, loc);
}
void mlir::linalg::ForOp::build(Builder *builder, OperationState *result,
Value *lb, Value *ub, Value *step) {
result->addOperands({lb, ub, step});
Region *bodyRegion = result->addRegion();
Block *body = new Block();
body->addArgument(IndexType::get(builder->getContext()));
bodyRegion->push_back(body);
ensureTerminator(*bodyRegion, *builder, result->location);
}
LogicalResult mlir::linalg::ForOp::verify() {
if (!getLowerBound()->getType().isa<IndexType>())
return emitOpError("lower bound operand must be an index");
if (!getUpperBound()->getType().isa<IndexType>())
return emitOpError("upper bound operand must be an index");
if (!getStep()->getType().isa<IndexType>())
return emitOpError("step operand must be an index");
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(getStep()->getDefiningOp()))
if (cst.getValue() <= 0)
return emitOpError("constant step operand must be positive");
if (std::next(getOperation()->getRegions().begin()) !=
getOperation()->getRegions().end())
return emitOpError("operation expected to have exactly one region");
auto &bodyRegion = getOperation()->getRegion(0);
// The body region must contain a single basic block.
if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end())
return emitOpError("expected body region to have a single block");
// Check that the body defines a single block argument for the induction
// variable.
auto *body = getBody();
if (body->getNumArguments() != 1 ||
!body->getArgument(0)->getType().isIndex())
return emitOpError("expected body to have a single index argument for "
"the induction variable");
if (failed(checkHasTerminator(*this, *body)))
return failure();
return success();
}
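// A ForOp prints as (illustrative; the accepted syntax is defined by the
// `parse` function below):
//
// ```{.mlir}
// linalg.for %i = %lb to %ub step %s {
//   ...
// }
// ```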
void mlir::linalg::ForOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getInductionVar() << " = "
<< *getLowerBound() << " to " << *getUpperBound() << " step "
<< *getStep();
p->printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
p->printOptionalAttrDict(getAttrs());
}
ParseResult mlir::linalg::ForOp::parse(OpAsmParser *parser,
OperationState *result) {
auto &builder = parser->getBuilder();
OpAsmParser::OperandType inductionVariable, lb, ub, step;
// Parse the induction variable followed by '='.
if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
return failure();
// Parse loop bounds.
Type indexType = builder.getIndexType();
if (parser->parseOperand(lb) ||
parser->resolveOperand(lb, indexType, result->operands) ||
parser->parseKeyword("to") || parser->parseOperand(ub) ||
parser->resolveOperand(ub, indexType, result->operands) ||
parser->parseKeyword("step") || parser->parseOperand(step) ||
parser->resolveOperand(step, indexType, result->operands))
return failure();
// Parse the body region.
Region *body = result->addRegion();
if (parser->parseRegion(*body, inductionVariable, indexType))
return failure();
ensureTerminator(*body, builder, result->location);
// Parse the optional attribute list.
if (parser->parseOptionalAttributeDict(result->attributes))
return failure();
return success();
}
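// Returns the ForOp whose induction variable is `val`, or a null ForOp when
// `val` is not the induction variable of such a loop.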
mlir::linalg::ForOp mlir::linalg::getForInductionVarOwner(Value *val) {
auto *ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg)
return ForOp();
assert(ivArg->getOwner() && "unlinked block argument");
auto *containingInst = ivArg->getOwner()->getContainingOp();
return dyn_cast_or_null<ForOp>(containingInst);
}
////////////////////////////////////////////////////////////////////////////////
// LoadOp.
////////////////////////////////////////////////////////////////////////////////
void mlir::linalg::LoadOp::build(Builder *b, OperationState *result,
Value *view, ArrayRef<Value *> indices) {
auto viewType = view->getType().cast<ViewType>();
result->addOperands(view);
result->addOperands(indices);
result->addTypes(viewType.getElementType());
}
// A LoadOp prints as:
//
// ```{.mlir}
// %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
// ```
void mlir::linalg::LoadOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getView() << '[';
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << getViewType();
}
ParseResult mlir::linalg::LoadOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType viewInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
ViewType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return failure(
parser->parseOperand(viewInfo) ||
parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(viewInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type.getElementType(), result->types));
}
LogicalResult mlir::linalg::LoadOp::verify() {
if (getNumOperands() == 0)
return emitOpError("expected a view to load from");
auto viewType = getView()->getType().dyn_cast<ViewType>();
if (!viewType)
return emitOpError("first operand must be a view");
if (getType() != viewType.getElementType())
return emitOpError("result type must match element type of the view");
if (getRank() != getNumOperands() - 1)
return emitOpError("incorrect number of indices for load");
for (auto *idx : getIndices())
if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
return success();
}
//////////////////////////////////////////////////////////////////////////////
// RangeOp
//////////////////////////////////////////////////////////////////////////////
void mlir::linalg::RangeOp::build(Builder *b, OperationState *result,
Value *min, Value *max, Value *step) {
result->addOperands({min, max, step});
result->addTypes({RangeType::get(b->getContext())});
}
// Verification simply checks that a RangeOp takes 3 ssa-values of index type.
LogicalResult mlir::linalg::RangeOp::verify() {
if (!min() || !min()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!max() || !max()->getType().isa<IndexType>())
return emitOpError("second operand should be of type index");
if (!step() || !step()->getType().isa<IndexType>())
return emitOpError("third operand should be of type index");
return success();
}
// A RangeOp prints as:
//
// ```{.mlir}
// linalg.range %0:%1:%2 : !linalg.range
// ```
void mlir::linalg::RangeOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
<< " : " << getType();
}
ParseResult mlir::linalg::RangeOp::parse(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return failure(
parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types));
}
//////////////////////////////////////////////////////////////////////////////
// SliceOp
//////////////////////////////////////////////////////////////////////////////
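// Indexing a view with an ssa-value of index type reduces the rank of the
// result view by one, while indexing with a range keeps the rank unchanged;
// the result type of a SliceOp is derived accordingly.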
void mlir::linalg::SliceOp::build(Builder *b, OperationState *result,
Value *base, ArrayRef<Value *> indexings) {
result->addOperands({base});
result->addOperands(indexings);
ViewType viewType = base->getType().cast<ViewType>();
unsigned rank = viewType.getRank();
for (auto *i : indexings)
if (!i->getType().isa<RangeType>())
rank--;
Type elementType = viewType.getElementType();
result->addTypes({ViewType::get(b->getContext(), elementType, rank)});
}
LogicalResult mlir::linalg::SliceOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a view operand followed by 'rank' indices");
unsigned rank = getBaseViewRank();
if (llvm::size(getIndexings()) != rank) {
return emitOpError("requires at least a view operand followed by ")
<< rank << " indexings";
}
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>() &&
!indexing->getType().isa<IndexType>()) {
return emitOpError() << index
<< "^th index must be of range or index type";
}
if (indexing->getType().isa<IndexType>())
--rank;
++index;
}
if (getRank() != rank) {
return emitOpError()
<< "the rank of the view must be the number of its range indices ("
<< rank << ") but got: " << getRank();
}
return success();
}
ParseResult mlir::linalg::SliceOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType baseInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
SmallVector<Type, 8> types;
if (parser->parseOperand(baseInfo) ||
parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types))
return failure();
if (types.size() != 2 + indexingsInfo.size())
return parser->emitError(parser->getNameLoc(),
"unexpected number of types ");
ViewType baseViewType = types[0].dyn_cast<ViewType>();
if (!baseViewType)
return parser->emitError(parser->getNameLoc(),
"view type expected for first type");
if (indexingsInfo.size() != baseViewType.getRank())
return parser->emitError(parser->getNameLoc(), "expected ")
<< baseViewType.getRank() << " indexings";
ViewType viewType = types.back().dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
ArrayRef<Type> indexingTypes =
ArrayRef<Type>(types).drop_front(1).drop_back(1);
if (indexingTypes.size() != baseViewType.getRank())
return parser->emitError(parser->getNameLoc(), "expected ")
<< baseViewType.getRank() << " indexing types";
return failure(
parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, indexingTypes,
indexingsInfo.front().location,
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}
// A SliceOp prints as:
//
// ```{.mlir}
// linalg.slice %0[%1, %2] :
// !linalg.view<?x?xf32>, [indexing-types], !linalg.view<?x?xf32>
// ```
//
// where %0 is an ssa-value holding a view created from a buffer, and %1 and %2
// are ssa-values, each holding a range.
void mlir::linalg::SliceOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getBaseView() << "[";
interleave(
getIndexings().begin(), getIndexings().end(), [p](Value *v) { *p << *v; },
[p]() { *p << ", "; });
*p << "] : " << getBaseViewType();
for (auto indexing : getIndexings()) {
*p << ", " << indexing->getType();
}
*p << ", " << getType();
}
ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
return cast<ViewOp>(getOperand(0)->getDefiningOp());
}
ViewType mlir::linalg::SliceOp::getBaseViewType() {
return getOperand(0)->getType().cast<ViewType>();
}
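// Returns only the indexings of range type, dropping the rank-reducing
// indexings of index type.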
SmallVector<Value *, 8> mlir::linalg::SliceOp::getRanges() {
llvm::SmallVector<Value *, 8> res;
for (auto *operand : getIndexings()) {
if (!operand->getType().isa<IndexType>()) {
res.push_back(operand);
}
}
return res;
}
////////////////////////////////////////////////////////////////////////////////
// StoreOp.
////////////////////////////////////////////////////////////////////////////////
void mlir::linalg::StoreOp::build(Builder *b, OperationState *result,
Value *valueToStore, Value *view,
ArrayRef<Value *> indices) {
result->addOperands(valueToStore);
result->addOperands(view);
result->addOperands(indices);
}
// A StoreOp prints as:
//
// ```{.mlir}
// linalg.store %f, %V[%c0] : !linalg.view<?xf32>
// ```
void mlir::linalg::StoreOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getValueToStore();
*p << ", " << *getView() << '[';
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << getViewType();
}
ParseResult mlir::linalg::StoreOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType viewInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
ViewType viewType;
auto affineIntTy = parser->getBuilder().getIndexType();
return failure(
parser->parseOperand(storeValueInfo) || parser->parseComma() ||
parser->parseOperand(viewInfo) ||
parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(viewType) ||
parser->resolveOperand(storeValueInfo, viewType.getElementType(),
result->operands) ||
parser->resolveOperand(viewInfo, viewType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands));
}
LogicalResult mlir::linalg::StoreOp::verify() {
if (getNumOperands() < 2)
return emitOpError("expected a value to store and a view");
// Second operand must be of view type.
auto viewType = getView()->getType().dyn_cast<ViewType>();
if (!viewType)
return emitOpError("second operand must be a view");
// First operand must have the same type as the view's element type.
if (getValueToStore()->getType() != viewType.getElementType())
return emitOpError("first operand must have same element type as the view");
if (getNumOperands() != 2 + viewType.getRank())
return emitOpError("store index operand count not equal to view rank");
for (auto *idx : getIndices())
if (!idx->getType().isIndex())
return emitOpError("index to store must have 'index' type");
return success();
}
//////////////////////////////////////////////////////////////////////////////
// ViewOp
//////////////////////////////////////////////////////////////////////////////
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
Value *buffer, ArrayRef<Value *> indexings) {
BufferType bufferType = buffer->getType().cast<BufferType>();
result->addOperands({buffer});
result->addOperands(indexings);
assert(
std::none_of(indexings.begin(), indexings.end(),
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
"linalg.view takes only arguments of type linalg.range");
Type elementType = bufferType.getElementType();
result->addTypes(
{ViewType::get(b->getContext(), elementType, indexings.size())});
}
LogicalResult mlir::linalg::ViewOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a buffer operand followed by indexings");
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
if (!bufferType)
return emitOpError("first operand must be of BufferType");
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>()) {
return emitOpError() << index << "^th index must be of range type";
}
++index;
}
if (getViewType().getRank() != index)
return emitOpError()
<< "the rank of the view must be the number of its indexings";
return success();
}
ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return failure();
ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != indexingsInfo.size())
return parser->emitError(parser->getNameLoc(), "expected")
<< viewType.getRank() << " range indexings";
return failure(
parser->resolveOperand(
bufferInfo,
BufferType::get(type.getContext(), viewType.getElementType()),
result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}
// A ViewOp prints as:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
// ```
//
// where %0 is an ssa-value holding a buffer, and %1 and %2 are ssa-values,
// each holding a range.
void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
interleave(
getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
[&]() { *p << ", "; });
*p << "] : " << getType();
}
///////////////////// Operations defined with Tablegen /////////////////////////
// For operations that do not correspond to library calls (i.e. those defined
// in LinalgOps.td), we define an overloaded `print` function and a
// `parse<ClassName>` function.
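// A BufferAllocOp prints as its op name followed by the optional dynamic-size
// operand, the attribute dictionary, and the buffer type (see
// `parseBufferAllocOp` below for the accepted syntax).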
static void print(OpAsmPrinter *p, BufferAllocOp op) {
*p << op.getOperationName() << " ";
if (!llvm::empty(op.size()))
*p << *op.getOperand(0);
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getBufferType();
}
static ParseResult parseBufferAllocOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
BufferType bufferType;
auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType))
return failure();
if (sizeInfo.empty())
return parser->addTypeToList(bufferType, result->types);
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types));
}
static LogicalResult verify(BufferAllocOp op) {
if (!op.getBufferType().hasConstantSize()) {
if (llvm::size(op.size()) != 1 ||
!op.getOperand(0)->getType().isa<IndexType>())
return op.emitOpError(
"one operand of type index expected for dynamic buffer");
} else { // op.getBufferType().hasConstantSize()
if (!llvm::empty(op.size()))
return op.emitOpError("unexpected static buffer operand");
if (op.getBufferType().getBufferSize().getValue() <= 0)
return op.emitOpError("expected nonnegative static buffer size");
}
if (!VectorType::isValidElementType(op.getElementType()) &&
!op.getElementType().isa<VectorType>())
return op.emitOpError("unsupported buffer element type");
return success();
}
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
*p << op.getOperationName() << " " << *op.buffer();
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getBufferType();
}
static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType bufferInfo;
BufferType bufferType;
if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType))
return failure();
return parser->resolveOperands(bufferInfo, bufferType, result->operands);
}
static void print(OpAsmPrinter *p, BufferSizeOp op) {
*p << op.getOperationName() << " " << *op.getOperand();
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getOperand()->getType();
}
static ParseResult parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType op;
Type type;
return failure(parser->parseOperand(op) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(op, type, result->operands) ||
parser->addTypeToList(parser->getBuilder().getIndexType(),
result->types));
}
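// A DimOp prints as (illustrative; op mnemonic per LinalgOps.td):
//
// ```{.mlir}
// %1 = linalg.dim %0, 1 : !linalg.view<?x?xf32>
// ```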
static void print(OpAsmPrinter *p, linalg::DimOp op) {
*p << op.getOperationName() << " " << *op.getOperand() << ", "
<< op.getIndex();
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
*p << " : " << op.getOperand()->getType();
}
static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr;
Type type;
Type indexType = parser->getBuilder().getIndexType();
return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, indexType, "index",
result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, result->operands) ||
parser->addTypeToList(indexType, result->types));
}
static void print(OpAsmPrinter *p, RangeIntersectOp op) {
*p << op.getOperationName() << " " << *op.getOperand(0) << ", "
<< *op.getOperand(1);
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getOperand(0)->getType();
}
static ParseResult parseRangeIntersectOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type type;
return failure(parser->parseOperandList(ops) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands) ||
parser->addTypeToList(type, result->types));
}
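// A SubViewOp prints its view operand followed by a flat list of
// (min, max, step) index operands per dimension, e.g. (illustrative; see also
// the TODO in the parser below):
//
// ```{.mlir}
// linalg.subview %0[%1, %2, %3, %4, %5, %6] : !linalg.view<?x?xf32>
// ```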
static void print(OpAsmPrinter *p, SubViewOp op) {
*p << op.getOperationName() << " " << *op.getOperand(0) << "[";
auto ranges = op.getRanges();
interleaveComma(ranges, *p, [&p](const SubViewOp::Range &i) {
*p << *i.min << ", " << *i.max << ", " << *i.step;
});
*p << "]";
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getViewType();
}
static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType inputView, resultView;
Type viewType;
if (parser->parseOperand(inputView))
return failure();
SmallVector<OpAsmParser::OperandType, 12> ops;
// TODO(ntv) evolve parsing from
// linalg.subview %0[%1, %2, %3, %4, %5, %6]
// to something resembling
// linalg.subview %0[%1:%2:%3][%4:%5:%6]
if (parser->parseOperandList(ops, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(viewType))
return failure();
auto indexTy = parser->getBuilder().getIndexType();
return failure(
parser->resolveOperand(inputView, viewType, result->operands) ||
parser->resolveOperands(ops, indexTy, result->operands) ||
parser->addTypeToList(viewType, result->types));
}
/////// Operations corresponding to library calls defined with Tablegen ////////
// For operations that correspond to library calls (i.e. those defined in
// LinalgLibraryOps.td), we define an overloaded `print` function and a
// `parse<ClassName>` function.
// A LinalgLibraryOp prints as:
//
// ```{.mlir}
// concrete_op_name (ssa-inputs, ssa-outputs) : view-types
// ```
//
// for example:
//
// ```
// linalg.matmul(%0, %1, %2) :
// !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
// ```
//
// Where %0, %1 and %2 are ssa-values of type ViewType.
static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
*p << op->getName().getStringRef() << "(";
interleave(
op->getOperands().begin(), op->getOperands().end(),
[&](Value *v) { *p << *v; }, [&]() { *p << ", "; });
*p << ")";
p->printOptionalAttrDict(op->getAttrs());
*p << " : ";
interleave(
op->getOperands().begin(), op->getOperands().end(),
[&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
}
static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(parser->parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types) ||
parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands));
}
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputViewType(0);
auto fillType = op.getValue()->getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
return success();
}
static LogicalResult verify(CopyOp op) {
auto outputViewType = op.getOutputViewType(0);
auto inputViewType = op.getInputViewType(0);
if (inputViewType.getElementType() != outputViewType.getElementType())
return op.emitOpError("expects views of the same type");
if (inputViewType.getRank() != outputViewType.getRank())
return op.emitOpError("expects views of the same rank");
auto rank = op.getNumParallelLoops();
auto inputPermutationMap = op.inputPermutation();
if (inputPermutationMap) {
if (inputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional input_permutation map of rank ")
<< rank;
if (!inputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional input_permutation map to be a permutation");
}
auto outputPermutationMap = op.outputPermutation();
if (outputPermutationMap) {
if (outputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional output_permutation map of rank ")
<< rank;
if (!outputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional output_permutation map to be a permutation");
}
if (rank == 0 && inputPermutationMap)
return op.emitOpError("expected no input permutation when rank == 0");
if (rank == 0 && outputPermutationMap)
return op.emitOpError("expected no output permutation when rank == 0");
return success();
}
static LogicalResult
verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
auto strideOrDilation = isStride ? "stride" : "dilation";
if (attrs.size() != op.getNumWindowLoops())
return op.emitOpError("expects num ")
<< strideOrDilation
<< "s equal to number of window dimensions: " << attrs.size()
<< " vs " << op.getNumWindowLoops();
return success();
}
static LogicalResult verify(ConvOp op) {
auto oType = op.output()->getType().cast<ViewType>();
auto fType = op.filter()->getType().cast<ViewType>();
auto iType = op.input()->getType().cast<ViewType>();
if (oType.getElementType() != iType.getElementType() ||
oType.getElementType() != fType.getElementType())
return op.emitOpError("expects view elemental types to match");
if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
return op.emitOpError("expects view ranks to match");
if (auto strides = op.strides()) {
if (failed(
verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
return failure();
}
if (auto dilations = op.dilations()) {
if (failed(verifyStrideOrDilation(op, dilations->getValue(),
/*isStride=*/false)))
return failure();
}
return success();
}
llvm::raw_ostream &mlir::linalg::operator<<(llvm::raw_ostream &os,
SubViewOp::Range &range) {
return os << "range " << *range.min << ":" << *range.max << ":"
<< *range.step;
}
namespace mlir {
namespace linalg {
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
} // namespace linalg
} // namespace mlir
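// Returns `maybeMap` when it holds a value; otherwise returns the
// `rank`-dimensional identity map, or the empty AffineMap when `rank` is 0.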
static AffineMap extractOrIdentityMap(llvm::Optional<AffineMap> maybeMap,
unsigned rank, MLIRContext *context) {
if (maybeMap)
return maybeMap.getValue();
if (rank == 0)
return AffineMap();
return AffineMap::getMultiDimIdentityMap(rank, context);
}
// Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num)
// and increments `curIdx` to `curIdx + num`.
static SmallVector<AffineExpr, 4>
makeAffineDimExprs(unsigned num, unsigned &curIdx, MLIRContext *context) {
SmallVector<AffineExpr, 4> res;
res.reserve(num);
for (unsigned i = 0; i < num; ++i)
res.push_back(getAffineDimExpr(curIdx++, context));
return res;
}
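// Returns the expressions `op.getStride(i) * a[i] + op.getDilation(i) * b[i]`,
// used to compute the convolution input access from the output (`a`) and
// window (`b`) dimensions.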
static SmallVector<AffineExpr, 4>
weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b) {
assert(a.size() == b.size());
SmallVector<AffineExpr, 4> res;
res.reserve(a.size());
for (unsigned i = 0, e = a.size(); i < e; ++i) {
res.push_back(op.getStride(i) * a[i] + op.getDilation(i) * b[i]);
}
return res;
}
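// Returns the concatenation of `a` followed by `b` as a new vector.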
static SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b) {
SmallVector<AffineExpr, 4> res;
res.reserve(a.size() + b.size());
res.assign(a.begin(), a.end());
res.append(b.begin(), b.end());
return res;
}
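// Creates one folded affine.apply op per result expression of `map`, applied
// to `vals` (after canonicalizing the map and its operands), and returns the
// resulting values as ValueHandles.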
static SmallVector<ValueHandle, 8>
foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<Value *> vals, OperationFolder &folder) {
assert(map.getNumSymbols() == 0);
assert(map.getNumInputs() == vals.size());
SmallVector<ValueHandle, 8> res;
res.reserve(map.getNumResults());
auto dims = map.getNumDims();
for (auto e : map.getResults()) {
auto exprMap = AffineMap::get(dims, 0, e);
SmallVector<Value *, 4> operands(vals.begin(), vals.end());
canonicalizeMapAndOperands(&exprMap, &operands);
res.push_back(
ValueHandle(folder.create<AffineApplyOp>(b, loc, exprMap, operands)));
}
return res;
}
// Note: both functions below would completely disappear with a simple tensor
// kernel language.
//
// Ideally this should all be Tablegen'd but there is no good story for
// AffineMap for now.
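//
// loopToOperandRangesMaps returns, for each operand view of `op`, the affine
// map from the op's loop induction variables to the indices used to access
// that operand (e.g. (i, j, k) -> (i, k) for the A operand of matmul).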
SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
MLIRContext *context = op->getContext();
if (auto copyOp = dyn_cast<CopyOp>(op)) {
// I(input_perm(ivs)) -> O(output_perm(ivs))
auto maybeInputMap = copyOp.inputPermutation();
auto maybeOutputMap = copyOp.outputPermutation();
unsigned inputRank = copyOp.getInputViewType(0).getRank();
unsigned outputRank = copyOp.getOutputViewType(0).getRank();
return SmallVector<AffineMap, 4>{
extractOrIdentityMap(maybeInputMap, inputRank, context),
extractOrIdentityMap(maybeOutputMap, outputRank, context)};
}
if (auto fillOp = dyn_cast<FillOp>(op)) {
// filling_value -> O(ivs)
unsigned rank = fillOp.getNumParallelLoops();
return SmallVector<AffineMap, 4>{
extractOrIdentityMap(llvm::None, rank, context)};
}
auto i = getAffineDimExpr(0, context);
auto j = getAffineDimExpr(1, context);
auto k = getAffineDimExpr(2, context);
if (isa<DotOp>(op))
// A(r_i) * B(r_i) -> C()
return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}),
AffineMap::get(1, 0, {i}), AffineMap()};
if (isa<MatvecOp>(op))
// A(i, r_j) * B(r_j) -> C(i)
return SmallVector<AffineMap, 4>{AffineMap::get(2, 0, {i, j}),
AffineMap::get(2, 0, {j}),
AffineMap::get(2, 0, {i})};
if (isa<MatmulOp>(op))
// A(i, r_k) * B(r_k, j) -> C(i, j)
return SmallVector<AffineMap, 4>{AffineMap::get(3, 0, {i, k}),
AffineMap::get(3, 0, {k, j}),
AffineMap::get(3, 0, {i, j})};
if (auto convOp = dyn_cast<ConvOp>(op)) {
// F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) ->
// O(b, x0, ..., xN-1, k)
// for N equal to `nWindow`.
auto nWin = convOp.getNumWindowLoops();
assert(nWin > 0 && "expected at least one window dimension");
unsigned idx = 0;
// In the following, AffineDimExprs are indexed in loop order:
//   [ b, xs, k,   q,   zs ]
// where (b, xs, k) are the parallel dims, q is the non-window reduction dim,
// and zs are the window dims.
//
// Parallel dims are exactly the dimensions indexing `output`:
// output[b, x[0], ..., x[N-1], k]; i.e.
// * batch dimensions (bs with #bs = 1 for now)
// * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks)
// * output filter dimensions (ks with #ks = 1 for now)
auto bs = makeAffineDimExprs(convOp.getNumBatchDimensions(), idx, context);
auto xs = makeAffineDimExprs(nWin, idx, context);
auto ks = makeAffineDimExprs(convOp.getNumOutputFeatureDimensions(), idx,
context);
// Non-window reduction dim: the input feature dimension `q` of
// sum_{z[0], ..., z[N-1], q}.
auto qs =
makeAffineDimExprs(convOp.getNumInputFeatureDimensions(), idx, context);
// Window reduction dims: z[0], ..., z[N-1] of sum_{z[0], ..., z[N-1], q}.
auto zs = makeAffineDimExprs(nWin, idx, context);
// Construct the weighted sum expressions used to index the input view.
auto ws = weightedConvInputIndex(convOp, xs, zs);
return SmallVector<AffineMap, 4>{
// filter[z[0], ..., z[N-1], q, k]
AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
// input[b,
// x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1],
// q]
AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
// output[b, x[0], ..., x[N-1], k]
AffineMap::get(idx, 0, concat(concat(bs, xs), ks))};
}
llvm_unreachable("Missing loopToOperandRangesMaps for op");
}
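// Returns `ivs` permuted by `permutation` when a permutation map is provided
// (materializing the corresponding affine.apply ops through `state`);
// otherwise returns `ivs` unchanged.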
static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
Optional<AffineMap> permutation,
OperationFolder &state) {
return permutation ? applyMapToValues(ScopedContext::getBuilder(),
ScopedContext::getLocation(),
permutation.getValue(), ivs, state)
: SmallVector<Value *, 4>(ivs.begin(), ivs.end());
}
// Ideally this should all be Tablegen'd but there is no good story for op
// expansion directly in MLIR for now.
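//
// Emits the scalar computation performed in the innermost loop body of
// `linalgOp`, e.g. C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j) for matmul.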
void mlir::linalg::emitScalarImplementation(
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs,
llvm::ArrayRef<Value *> windowIvs, LinalgOp &linalgOp,
OperationFolder &folder) {
using linalg_load = ValueBuilder<linalg::LoadOp>;
using linalg_store = OperationBuilder<linalg::StoreOp>;
using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
using edsc::op::operator+;
using edsc::op::operator*;
using edsc::op::operator==;
using edsc::intrinsics::select;
auto nPar = parallelIvs.size();
auto nRed = reductionIvs.size();
auto nWin = windowIvs.size();
SmallVector<Value *, 8> allIvs;
allIvs.reserve(nPar + nRed + nWin);
allIvs.assign(parallelIvs.begin(), parallelIvs.end());
allIvs.append(reductionIvs.begin(), reductionIvs.end());
allIvs.append(windowIvs.begin(), windowIvs.end());
// Default OpBuilder supports 0-D case (no loops).
OpBuilder b(linalgOp.getOperation());
auto nLoops = nPar + nRed + nWin;
if (nLoops > 0) {
auto innermostLoop = linalg::getForInductionVarOwner(allIvs.back());
// The body builder accounts for the `linalg.terminator` at the end of the
// loop body.
b = innermostLoop.getBodyBuilder();
}
auto loc = linalgOp.getLoc();
ScopedContext scope(b, loc);
auto *op = linalgOp.getOperation();
if (auto copyOp = dyn_cast<CopyOp>(op)) {
OperationFolder state;
auto inputIvs = permuteIvs(parallelIvs, copyOp.inputPermutation(), state);
auto outputIvs = permuteIvs(parallelIvs, copyOp.outputPermutation(), state);
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
// clang-format off
IndexedValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
nLoops > 0 ?
O(oivs) = I(iivs) :
O() = I();
// clang-format on
return;
}
if (auto fillOp = dyn_cast<FillOp>(op)) {
SmallVector<IndexHandle, 8> ivs(parallelIvs.begin(), parallelIvs.end());
// clang-format off
IndexedValue O(fillOp.getOutput(0));
nLoops > 0 ?
O(ivs) = ValueHandle(fillOp.getValue()) :
O() = ValueHandle(fillOp.getValue());
// clang-format on
return;
}
if (auto dotOp = dyn_cast<DotOp>(op)) {
IndexHandle r_i(reductionIvs[0]);
IndexedValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
C(dotOp.getOutput(0));
C() = C() + A(r_i) * B(r_i);
return;
}
if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
IndexedValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
C(matvecOp.getOutput(0));
C(i) = C(i) + A(i, r_j) * B(r_j);
return;
}
if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
IndexedValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
C(matmulOp.getOutput(0));
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
return;
}
if (auto convOp = dyn_cast<ConvOp>(op)) {
auto maps = loopToOperandRangesMaps(op);
SmallVector<ValueHandle, 8> fIdx(
foldedAffineApplies(b, loc, maps[0], allIvs, folder));
SmallVector<ValueHandle, 8> imIdx(
foldedAffineApplies(b, loc, maps[1], allIvs, folder));
SmallVector<ValueHandle, 8> oIdx(
foldedAffineApplies(b, loc, maps[2], allIvs, folder));
IndexedValue F(convOp.filter()), I(convOp.input()), O(convOp.output());
O(oIdx) += F(fIdx) * I(imIdx);
return;
}
llvm_unreachable("Missing emitScalarImplementation for op");
}