blob: 6998da595eecdd6bd5425a377ca583e0b80d2a49 [file] [log] [blame]
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
using namespace mlir;
using namespace mlir::linalg;
//////////////////////////////////////////////////////////////////////////////
// BufferAllocOp
//////////////////////////////////////////////////////////////////////////////
/// Fills `result` for a BufferAllocOp: `size` becomes the sole operand and
/// `type` the sole result type. The builder `b` is not used.
void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
                                        Type type, Value *size) {
  // Operands and result types are stored in separate lists of the state, so
  // the order in which they are registered does not matter.
  result->addTypes({type});
  result->addOperands(size);
}
/// Checks the invariants of a BufferAllocOp:
///   - the size operand exists and has `index` type;
///   - the element type is either accepted by
///     `VectorType::isValidElementType` or is itself a vector type.
LogicalResult mlir::linalg::BufferAllocOp::verify() {
  auto *sizeOperand = size();
  const bool sizeIsIndex =
      sizeOperand && sizeOperand->getType().isa<IndexType>();
  if (!sizeIsIndex)
    return emitOpError("first operand should be of type index");

  auto elementType = getElementType();
  const bool elementTypeSupported =
      VectorType::isValidElementType(elementType) ||
      elementType.isa<VectorType>();
  if (!elementTypeSupported)
    return emitOpError("unsupported buffer element type");

  return success();
}
// Custom assembly form of a BufferAllocOp:
//
// ```{.mlir}
// linalg.alloc %0 : !linalg.buffer<f32>
// ```
//
// i.e. the operation name, the size operand, then a colon followed by the
// result type.
void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
  OpAsmPrinter &os = *p;
  os << getOperationName();
  os << " " << *size();
  os << " : " << getType();
}
/// Parses the custom form `<size-operand> : <buffer-type>`.
///
/// The size operand is resolved against the `index` type and the parsed
/// buffer type becomes the single result type. A buffer whose element type is
/// not f32 is rejected with an error.
ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
                                               OperationState *result) {
  OpAsmParser::OperandType sizeOperand;
  BufferType resultType;
  if (parser->parseOperand(sizeOperand))
    return failure();
  if (parser->parseColonType(resultType))
    return failure();

  // Only f32 buffers are accepted by this parser for now.
  if (resultType.getElementType() != parser->getBuilder().getF32Type())
    return parser->emitError(parser->getNameLoc(),
                             "Only buffer<f32> supported until "
                             "mlir::linalg::Parser pieces are exposed");

  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperands(sizeOperand, indexType, result->operands))
    return failure();
  if (parser->addTypeToList(resultType, result->types))
    return failure();
  return success();
}
//////////////////////////////////////////////////////////////////////////////
// BufferDeallocOp
//////////////////////////////////////////////////////////////////////////////
/// Fills `result` for a BufferDeallocOp: `buffer` is the only operand and the
/// op has no result type. The builder `b` is not used.
void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
                                          Value *buffer) {
  Value *operands[] = {buffer};
  result->addOperands(operands);
}
/// Verifies that the operand being deallocated is of BufferType.
///
/// A bare `!getType()` only tests whether the type handle is null, which
/// lets operands of any non-null type through; the check must inspect the
/// kind of the type to match the diagnostic below.
LogicalResult mlir::linalg::BufferDeallocOp::verify() {
  if (!getBuffer()->getType().isa<BufferType>())
    return emitOpError("first operand should be of type buffer");
  return success();
}
// Custom assembly form of a BufferDeallocOp:
//
// ```{.mlir}
// linalg.dealloc %0 : !linalg.buffer<f32>
// ```
//
// i.e. the operation name, the buffer operand, then a colon followed by the
// buffer type.
void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
  OpAsmPrinter &os = *p;
  os << getOperationName();
  os << " " << *getBuffer();
  os << " : " << getBufferType();
}
/// Parses the custom form `<buffer-operand> : <buffer-type>` and resolves the
/// operand against the parsed buffer type. The op produces no result.
ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
                                                 OperationState *result) {
  OpAsmParser::OperandType bufferOperand;
  BufferType bufferType;
  if (parser->parseOperand(bufferOperand) ||
      parser->parseColonType(bufferType))
    return failure();
  if (parser->resolveOperands(bufferOperand, bufferType, result->operands))
    return failure();
  return success();
}
////////////////////////////////////////////////////////////////////////////////
// LoadOp.
////////////////////////////////////////////////////////////////////////////////
/// Fills `result` for a LoadOp. The operands are `view` followed by
/// `indices`; the result type is the element type of the view. The type of
/// `view` is `cast` to ViewType, so it must be a view.
void mlir::linalg::LoadOp::build(Builder *b, OperationState *result,
                                 Value *view, ArrayRef<Value *> indices) {
  result->addTypes(view->getType().cast<ViewType>().getElementType());
  result->addOperands(view);
  result->addOperands(indices);
}
// Custom assembly form of a LoadOp:
//
// ```{.mlir}
// %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
// ```
//
// i.e. the view operand, the bracketed list of indices, the optional
// attribute dictionary, then a colon followed by the view type.
void mlir::linalg::LoadOp::print(OpAsmPrinter *p) {
  OpAsmPrinter &os = *p;
  os << getOperationName() << " ";
  os << *getView();
  os << '[';
  p->printOperands(getIndices());
  os << ']';
  p->printOptionalAttrDict(getAttrs());
  os << " : ";
  os << getViewType();
}
/// Parses the custom form `%view[%indices...] {attrs}? : view-type`.
///
/// The view operand is resolved against the parsed view type, every index
/// against the `index` type, and the element type of the view becomes the
/// result type.
ParseResult mlir::linalg::LoadOp::parse(OpAsmParser *parser,
                                        OperationState *result) {
  OpAsmParser::OperandType viewOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  ViewType viewType;

  // Syntactic part.
  if (parser->parseOperand(viewOperand) ||
      parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(viewType))
    return failure();

  // Operand resolution: the view first, then the indices.
  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperand(viewOperand, viewType, result->operands) ||
      parser->resolveOperands(indexOperands, indexType, result->operands))
    return failure();

  if (parser->addTypeToList(viewType.getElementType(), result->types))
    return failure();
  return success();
}
/// Checks the invariants of a LoadOp:
///   - there is at least one operand and the first one is a view;
///   - the result type equals the element type of the view;
///   - the number of remaining operands equals the rank;
///   - every index operand has `index` type.
LogicalResult mlir::linalg::LoadOp::verify() {
  if (getNumOperands() == 0)
    return emitOpError("expected a view to load from");

  auto viewType = getView()->getType().dyn_cast<ViewType>();
  if (!viewType)
    return emitOpError("first operand must be a view");

  if (getType() != viewType.getElementType())
    return emitOpError("result type must match element type of the view");

  // All operands but the first one (the view) are indices.
  if (getRank() != getNumOperands() - 1)
    return emitOpError("incorrect number of indices for load");

  for (Value *index : getIndices()) {
    if (index->getType().isIndex())
      continue;
    return emitOpError("index to load must have 'index' type");
  }
  return success();
}
//////////////////////////////////////////////////////////////////////////////
// RangeOp
//////////////////////////////////////////////////////////////////////////////
/// Fills `result` for a RangeOp: the operands are `min`, `max` and `step`, in
/// that order, and the single result type is the RangeType of the builder's
/// context.
void mlir::linalg::RangeOp::build(Builder *b, OperationState *result,
                                  Value *min, Value *max, Value *step) {
  Type rangeType = RangeType::get(b->getContext());
  result->addTypes(rangeType);

  Value *bounds[] = {min, max, step};
  result->addOperands(bounds);
}
// A RangeOp is valid when each of its three operands (min, max, step) is
// present and has `index` type. The operands are checked in order and the
// first offending one is reported.
LogicalResult mlir::linalg::RangeOp::verify() {
  auto isIndexValue = [](Value *v) {
    return v && v->getType().isa<IndexType>();
  };

  if (!isIndexValue(min()))
    return emitOpError("first operand should be of type index");
  if (!isIndexValue(max()))
    return emitOpError("second operand should be of type index");
  if (!isIndexValue(step()))
    return emitOpError("third operand should be of type index");
  return success();
}
// Custom assembly form of a RangeOp:
//
// ```{.mlir}
// linalg.range %0:%1:%2 : !linalg.range
// ```
//
// i.e. the min, max and step operands separated by colons, then a colon
// followed by the result type.
void mlir::linalg::RangeOp::print(OpAsmPrinter *p) {
  OpAsmPrinter &os = *p;
  os << getOperationName() << " ";
  os << *min() << ":";
  os << *max() << ":";
  os << *step();
  os << " : " << getType();
}
/// Parses the custom form `%min:%max:%step : range-type`.
///
/// The three operands are resolved against the `index` type and the parsed
/// range type becomes the single result type.
ParseResult mlir::linalg::RangeOp::parse(OpAsmParser *parser,
                                         OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 3> bounds(3);
  RangeType rangeType;

  // min `:` max `:` step `:` type
  if (parser->parseOperand(bounds[0]) || parser->parseColon() ||
      parser->parseOperand(bounds[1]) || parser->parseColon() ||
      parser->parseOperand(bounds[2]) || parser->parseColonType(rangeType))
    return failure();

  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperands(bounds, indexType, result->operands))
    return failure();
  if (parser->addTypeToList(rangeType, result->types))
    return failure();
  return success();
}
//////////////////////////////////////////////////////////////////////////////
// SliceOp
//////////////////////////////////////////////////////////////////////////////
/// Fills `result` for a SliceOp. The operands are `base` followed by
/// `indexings`. The result is a view with the element type of `base` whose
/// rank is the rank of `base` minus the number of indexings that are not of
/// RangeType. The type of `base` is `cast` to ViewType, so it must be a view.
void mlir::linalg::SliceOp::build(Builder *b, OperationState *result,
                                  Value *base, ArrayRef<Value *> indexings) {
  auto baseViewType = base->getType().cast<ViewType>();

  // Every indexing that is not a range removes one dimension from the result.
  unsigned resultRank = baseViewType.getRank();
  for (Value *indexing : indexings) {
    if (indexing->getType().isa<RangeType>())
      continue;
    --resultRank;
  }

  result->addOperands(base);
  result->addOperands(indexings);
  result->addTypes({ViewType::get(b->getContext(),
                                  baseViewType.getElementType(), resultRank)});
}
/// Checks the invariants of a SliceOp:
///   - it has at least one operand, and that operand is produced by a ViewOp;
///   - the number of indexings equals the rank of the base view;
///   - every indexing has range or index type;
///   - the rank of the result equals the base rank minus the number of
///     index-typed indexings.
LogicalResult mlir::linalg::SliceOp::verify() {
  if (llvm::empty(getOperands()))
    return emitOpError(
        "requires at least a view operand followed by 'rank' indices");

  if (!dyn_cast_or_null<ViewOp>(getOperand(0)->getDefiningOp()))
    return emitOpError("first operand must come from a ViewOp");

  const unsigned baseRank = getBaseViewRank();
  if (llvm::size(getIndexings()) != baseRank) {
    return emitOpError("requires at least a view operand followed by ")
           << baseRank << " indexings";
  }

  // Walk the indexings, validating their types and removing one dimension
  // from the expected result rank for each index-typed indexing.
  unsigned expectedRank = baseRank;
  unsigned position = 0;
  for (auto indexing : getIndexings()) {
    auto indexingType = indexing->getType();
    const bool isIndex = indexingType.isa<IndexType>();
    if (!isIndex && !indexingType.isa<RangeType>()) {
      return emitOpError() << position
                           << "^th index must be of range or index type";
    }
    if (isIndex)
      --expectedRank;
    ++position;
  }

  if (getRank() != expectedRank) {
    return emitOpError()
           << "the rank of the view must be the number of its range indices ("
           << expectedRank << ") but got: " << getRank();
  }
  return success();
}
/// Parses the custom form
///   `%base[%indexings...] {attrs}? : base-view-type, indexing-types...,
///   result-view-type`.
///
/// The type list must contain exactly one type per indexing plus the base
/// view type (first) and the result view type (last). The number of
/// indexings must equal the rank of the base view type.
ParseResult mlir::linalg::SliceOp::parse(OpAsmParser *parser,
                                         OperationState *result) {
  OpAsmParser::OperandType baseOperand;
  SmallVector<OpAsmParser::OperandType, 8> indexingOperands;
  SmallVector<Type, 8> allTypes;

  // Syntactic part.
  if (parser->parseOperand(baseOperand) ||
      parser->parseOperandList(indexingOperands, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonTypeList(allTypes))
    return failure();

  // One type for the base, one per indexing, one for the result.
  if (allTypes.size() != 2 + indexingOperands.size())
    return parser->emitError(parser->getNameLoc(),
                             "unexpected number of types ");

  ViewType baseViewType = allTypes[0].dyn_cast<ViewType>();
  if (!baseViewType)
    return parser->emitError(parser->getNameLoc(),
                             "view type expected for first type");

  if (indexingOperands.size() != baseViewType.getRank())
    return parser->emitError(parser->getNameLoc(), "expected ")
           << baseViewType.getRank() << " indexings";

  ViewType resultViewType = allTypes.back().dyn_cast<ViewType>();
  if (!resultViewType)
    return parser->emitError(parser->getNameLoc(), "view type expected");

  // The indexing types are everything between the first and the last type.
  ArrayRef<Type> typesWithoutBase = ArrayRef<Type>(allTypes).drop_front(1);
  ArrayRef<Type> indexingTypes = typesWithoutBase.drop_back(1);
  if (indexingTypes.size() != baseViewType.getRank())
    return parser->emitError(parser->getNameLoc(), "expected ")
           << baseViewType.getRank() << " indexing types";

  // Operand resolution: the base view first, then the indexings (if any).
  if (parser->resolveOperand(baseOperand, baseViewType, result->operands))
    return failure();
  if (!indexingOperands.empty() &&
      parser->resolveOperands(indexingOperands, indexingTypes,
                              indexingOperands.front().location,
                              result->operands))
    return failure();
  if (parser->addTypeToList(resultViewType, result->types))
    return failure();
  return success();
}
// Custom assembly form of a SliceOp:
//
// ```{.mlir}
// linalg.slice %0[%1, %2] :
// !linalg.view<?x?xf32>, [indexing-types], !linalg.view<?x?xf32>
// ```
//
// %0 is the base view and %1, %2 are the indexings. After the colon come the
// base view type, one type per indexing, and finally the result type, all
// separated by commas.
void mlir::linalg::SliceOp::print(OpAsmPrinter *p) {
  OpAsmPrinter &os = *p;
  os << getOperationName() << " " << *getBaseView() << "[";

  // Comma-separated list of the indexing operands.
  bool needsSeparator = false;
  for (Value *indexing : getIndexings()) {
    if (needsSeparator)
      os << ", ";
    os << *indexing;
    needsSeparator = true;
  }

  os << "] : " << getBaseViewType();
  for (auto indexing : getIndexings())
    os << ", " << indexing->getType();
  os << ", " << getType();
}
/// Returns the operation defining the first operand, `cast` to ViewOp.
ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
  Operation *definingOp = getOperand(0)->getDefiningOp();
  return definingOp->cast<ViewOp>();
}
/// Returns the type of the base ViewOp, `cast` to ViewType.
ViewType mlir::linalg::SliceOp::getBaseViewType() {
  ViewOp baseView = getBaseViewOp();
  return baseView.getType().cast<ViewType>();
}
/// Returns, in operand order, the indexings whose type is not IndexType.
SmallVector<Value *, 8> mlir::linalg::SliceOp::getRanges() {
  SmallVector<Value *, 8> ranges;
  for (Value *indexing : getIndexings()) {
    if (indexing->getType().isa<IndexType>())
      continue;
    ranges.push_back(indexing);
  }
  return ranges;
}
////////////////////////////////////////////////////////////////////////////////
// StoreOp.
////////////////////////////////////////////////////////////////////////////////
/// Fills `result` for a StoreOp. The operands are, in order: the value to
/// store, the view, then the indices. The op has no result type.
void mlir::linalg::StoreOp::build(Builder *b, OperationState *result,
                                  Value *valueToStore, Value *view,
                                  ArrayRef<Value *> indices) {
  Value *leadingOperands[] = {valueToStore, view};
  result->addOperands(leadingOperands);
  result->addOperands(indices);
}
// Custom assembly form of a StoreOp:
//
// ```{.mlir}
// linalg.store %f, %V[%c0] : !linalg.view<?xf32>
// ```
//
// i.e. the value to store, a comma, the view operand with its bracketed list
// of indices, the optional attribute dictionary, then a colon followed by the
// view type.
void mlir::linalg::StoreOp::print(OpAsmPrinter *p) {
  OpAsmPrinter &os = *p;
  os << getOperationName() << " ";
  os << *getValueToStore() << ", ";
  os << *getView();
  os << '[';
  p->printOperands(getIndices());
  os << ']';
  p->printOptionalAttrDict(getAttrs());
  os << " : ";
  os << getViewType();
}
/// Parses the custom form `%value, %view[%indices...] {attrs}? : view-type`.
///
/// The stored value is resolved against the element type of the parsed view
/// type, the view operand against the view type itself, and every index
/// against the `index` type. The op produces no result.
ParseResult mlir::linalg::StoreOp::parse(OpAsmParser *parser,
                                         OperationState *result) {
  OpAsmParser::OperandType valueOperand;
  OpAsmParser::OperandType viewOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  ViewType viewType;

  // Syntactic part.
  if (parser->parseOperand(valueOperand) || parser->parseComma() ||
      parser->parseOperand(viewOperand) ||
      parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(viewType))
    return failure();

  // Operand resolution, in operand order: value, view, indices.
  if (parser->resolveOperand(valueOperand, viewType.getElementType(),
                             result->operands))
    return failure();
  if (parser->resolveOperand(viewOperand, viewType, result->operands))
    return failure();

  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperands(indexOperands, indexType, result->operands))
    return failure();
  return success();
}
/// Checks the invariants of a StoreOp:
///   - there are at least two operands (value to store, view);
///   - the second operand is a view;
///   - the stored value has the element type of the view;
///   - the number of remaining operands equals the rank of the view;
///   - every index operand has `index` type.
LogicalResult mlir::linalg::StoreOp::verify() {
  if (getNumOperands() < 2)
    return emitOpError("expected a value to store and a view");

  // The second operand must be a view.
  auto viewType = getView()->getType().dyn_cast<ViewType>();
  if (!viewType)
    return emitOpError("second operand must be a view");

  // The first operand must have the element type of the view.
  if (getValueToStore()->getType() != viewType.getElementType())
    return emitOpError("first operand must have same element type as the view");

  // Two leading operands (value, view) followed by one index per dimension.
  if (getNumOperands() != 2 + viewType.getRank())
    return emitOpError("store index operand count not equal to view rank");

  for (Value *index : getIndices()) {
    if (index->getType().isIndex())
      continue;
    return emitOpError("index to store must have 'index' type");
  }
  return success();
}
//////////////////////////////////////////////////////////////////////////////
// ViewOp
//////////////////////////////////////////////////////////////////////////////
/// Fills `result` for a ViewOp. The operands are `buffer` followed by
/// `indexings`; the result is a view with the element type of the buffer and
/// one dimension per indexing. The type of `buffer` is `cast` to BufferType,
/// and every indexing is asserted to be of RangeType.
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
                                 Value *buffer, ArrayRef<Value *> indexings) {
  auto bufferType = buffer->getType().cast<BufferType>();

  assert(
      std::none_of(indexings.begin(), indexings.end(),
                   [](Value *v) { return !v->getType().isa<RangeType>(); }) &&
      "linalg.view takes only arguments of type linalg.range");

  result->addOperands(buffer);
  result->addOperands(indexings);

  const auto rank = indexings.size();
  result->addTypes(
      {ViewType::get(b->getContext(), bufferType.getElementType(), rank)});
}
/// Checks the invariants of a ViewOp:
///   - it has at least one operand, and that operand is of BufferType;
///   - every indexing is of RangeType;
///   - the rank of the resulting view equals the number of indexings.
LogicalResult mlir::linalg::ViewOp::verify() {
  if (llvm::empty(getOperands()))
    return emitOpError(
        "requires at least a buffer operand followed by indexings");

  if (!getOperand(0)->getType().isa<BufferType>())
    return emitOpError("first operand must be of BufferType");

  // Validate the indexings while counting them.
  unsigned numIndexings = 0;
  for (auto indexing : getIndexings()) {
    if (!indexing->getType().isa<RangeType>()) {
      return emitOpError() << numIndexings
                           << "^th index must be of range type";
    }
    ++numIndexings;
  }

  if (getViewType().getRank() != numIndexings)
    return emitOpError()
           << "the rank of the view must be the number of its indexings";
  return success();
}
/// Parses the custom form `%buffer[%ranges...] {attrs}? : view-type`.
///
/// The parsed type must be a ViewType whose rank equals the number of
/// indexings. The buffer operand is resolved against the BufferType built
/// from the element type of the view, and every indexing is resolved against
/// RangeType. The view type becomes the single result type.
ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
                                        OperationState *result) {
  OpAsmParser::OperandType bufferInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
  Type type;
  if (parser->parseOperand(bufferInfo) ||
      parser->parseOperandList(indexingsInfo, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(type))
    return failure();

  ViewType viewType = type.dyn_cast<ViewType>();
  if (!viewType)
    return parser->emitError(parser->getNameLoc(), "view type expected");

  // The trailing space keeps the streamed rank from being glued to the word
  // "expected" in the diagnostic.
  if (viewType.getRank() != indexingsInfo.size())
    return parser->emitError(parser->getNameLoc(), "expected ")
           << viewType.getRank() << " range indexings";

  return failure(
      parser->resolveOperand(
          bufferInfo,
          BufferType::get(type.getContext(), viewType.getElementType()),
          result->operands) ||
      // Nothing to resolve when the view has no indexings.
      (!indexingsInfo.empty() &&
       parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
                               result->operands)) ||
      parser->addTypeToList(viewType, result->types));
}
// Custom assembly form of a ViewOp:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
// ```
//
// %0 is the supporting buffer and %1, %2 are the indexings. The result type
// follows the colon.
void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
  OpAsmPrinter &os = *p;
  os << getOperationName() << " " << *getSupportingBuffer() << "[";

  // Comma-separated list of the indexing operands.
  bool needsSeparator = false;
  for (Value *indexing : getIndexings()) {
    if (needsSeparator)
      os << ", ";
    os << *indexing;
    needsSeparator = true;
  }

  os << "] : " << getType();
}
namespace mlir {
namespace linalg {
namespace impl {
// Forward declarations of the custom print/parse helpers that are defined
// further down in this file.
// NOTE(review): presumably these are referenced by the generated op classes
// pulled in through LinalgOps.cpp.inc — confirm against the TableGen
// definitions.

// Prints `op` in the Linalg library-operation form.
void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op);
// Parses the Linalg library-operation form into `result`.
ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
// Prints `op` in the buffer-size form.
void printBufferSizeOp(OpAsmPrinter *p, Operation *op);
// Parses the buffer-size form into `result`.
ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
} // namespace impl
} // namespace linalg
/// Custom assembly form of a BufferSizeOp:
///
/// ``` {.mlir}
/// %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
/// ```
///
/// i.e. the operation name, the first operand, the optional attribute
/// dictionary, then a colon followed by the type of that operand. `op` must
/// be a registered operation (asserted) and is `cast` to BufferSizeOp.
void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
  assert(op->getAbstractOperation() && "unregistered operation");
  OpAsmPrinter &os = *p;
  os << op->cast<BufferSizeOp>().getOperationName();
  os << " " << *op->getOperand(0);
  p->printOptionalAttrDict(op->getAttrs());
  os << " : ";
  os << op->getOperand(0)->getType();
}
/// Parses the custom form `%operand {attrs}? : type`.
///
/// The operand is resolved against the parsed type, and the single result
/// type is `index`.
ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
                                                  OperationState *result) {
  OpAsmParser::OperandType operand;
  Type operandType;

  // Syntactic part.
  if (parser->parseOperand(operand) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(operandType))
    return failure();

  if (parser->resolveOperand(operand, operandType, result->operands))
    return failure();
  if (parser->addTypeToList(parser->getBuilder().getIndexType(),
                            result->types))
    return failure();
  return success();
}
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
} // namespace mlir
// Custom assembly form of a Linalg library operation:
//
// ```{.mlir}
// concrete_op_name (ssa-inputs, ssa-outputs) : view-types
// ```
//
// for example:
//
// ```
// linalg.matmul(%0, %1, %2) :
// !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
// ```
//
// The operation name is followed by the parenthesized, comma-separated list
// of all operands, then a colon and the comma-separated list of the operand
// types. `op` must be a registered operation (asserted).
void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
  assert(op->getAbstractOperation() && "unregistered operation");
  OpAsmPrinter &os = *p;

  // Operand list.
  os << op->getName().getStringRef() << "(";
  bool needsSeparator = false;
  for (Value *operand : op->getOperands()) {
    if (needsSeparator)
      os << ", ";
    os << *operand;
    needsSeparator = true;
  }

  // Operand type list.
  os << ") : ";
  needsSeparator = false;
  for (Value *operand : op->getOperands()) {
    if (needsSeparator)
      os << ", ";
    os << operand->getType();
    needsSeparator = true;
  }
}
/// Parses the custom form `(%operands...) {attrs}? : types...`.
///
/// Each operand is resolved against the type at the same position in the
/// type list. No result type is added.
ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
                                                     OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 3> operands;
  SmallVector<Type, 3> operandTypes;

  // Syntactic part.
  if (parser->parseOperandList(operands, -1, OpAsmParser::Delimiter::Paren) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonTypeList(operandTypes))
    return failure();

  if (parser->resolveOperands(operands, operandTypes, parser->getNameLoc(),
                              result->operands))
    return failure();
  return success();
}
// Returns three affine maps for `op`, which must be a DotOp, a MatvecOp or a
// MatmulOp; any other operation hits llvm_unreachable. Going by the
// per-operation notes below, the maps correspond to the two inputs followed
// by the output, and `r_x` marks a reduction dimension.
//
// Ideally these maps would be generated by Tablegen, but there is no good
// story for AffineMap for now.
SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
  MLIRContext *context = op->getContext();
  auto i = getAffineDimExpr(0, context);
  auto j = getAffineDimExpr(1, context);
  auto k = getAffineDimExpr(2, context);

  if (op->isa<DotOp>()) {
    // A(r_i) * B(r_i) -> C()
    auto inputA = AffineMap::get(1, 0, {i}, {});
    auto inputB = AffineMap::get(1, 0, {i}, {});
    return SmallVector<AffineMap, 4>{inputA, inputB, AffineMap()};
  }

  if (op->isa<MatvecOp>()) {
    // A(i, r_j) * B(r_j) -> C(i)
    auto inputA = AffineMap::get(2, 0, {i, j}, {});
    auto inputB = AffineMap::get(2, 0, {j}, {});
    auto outputC = AffineMap::get(2, 0, {i}, {});
    return SmallVector<AffineMap, 4>{inputA, inputB, outputC};
  }

  if (op->isa<MatmulOp>()) {
    // A(i, r_k) * B(r_k, j) -> C(i, j)
    auto inputA = AffineMap::get(3, 0, {i, k}, {});
    auto inputB = AffineMap::get(3, 0, {k, j}, {});
    auto outputC = AffineMap::get(3, 0, {i, j}, {});
    return SmallVector<AffineMap, 4>{inputA, inputB, outputC};
  }

  llvm_unreachable("Missing loopToOperandRangesMaps for op");
}