Move BufferAllocOp and BufferDeallocOp to ODS
This CL also fixes a parsing issue in the BufferType, adds LLVM lowering support for handling the static constant buffer size and a roundtrip test.
PiperOrigin-RevId: 255834356
diff --git a/include/mlir/Linalg/IR/LinalgOps.h b/include/mlir/Linalg/IR/LinalgOps.h
index ff04c36..72eabc2 100644
--- a/include/mlir/Linalg/IR/LinalgOps.h
+++ b/include/mlir/Linalg/IR/LinalgOps.h
@@ -29,56 +29,6 @@
namespace linalg {
-/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
-/// upon which a base view can be laid out to give it indexing semantics.
-/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
-/// (in number of elements).
-///
-/// ```{.mlir}
-/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
-/// ```
-class BufferAllocOp
- : public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
-public:
- using Op::Op;
-
- // Hooks to customize the behavior of this op.
- static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
- static void build(Builder *b, OperationState *result, Type type, Value *size);
- LogicalResult verify();
- static ParseResult parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p);
-
- // Op-specific functionality.
- Value *size() { return getOperand(); }
- BufferType getBufferType() { return getType().cast<BufferType>(); }
- Type getElementType() { return getBufferType().getElementType(); }
-};
-
-/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
-///
-/// ```{.mlir}
-/// linalg.buffer_dealloc %0 : !linalg.buffer<f32>
-/// ```
-class BufferDeallocOp
- : public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
-public:
- using Op::Op;
-
- // Hooks to customize the behavior of this op.
- static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
- static void build(Builder *b, OperationState *result, Value *buffer);
- LogicalResult verify();
- static ParseResult parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p);
-
- // Op-specific functionality.
- Value *getBuffer() { return getOperand(); }
- BufferType getBufferType() {
- return getOperand()->getType().cast<BufferType>();
- }
-};
-
/// The "linalg.for" operation represents a loop nest taking 3 SSA value as
/// operands that represent the lower bound, upper bound and step respectively.
/// The operation defines an SSA value for its induction variable. It has one
diff --git a/include/mlir/Linalg/IR/LinalgOps.td b/include/mlir/Linalg/IR/LinalgOps.td
index 051d16e..49b75be 100644
--- a/include/mlir/Linalg/IR/LinalgOps.td
+++ b/include/mlir/Linalg/IR/LinalgOps.td
@@ -39,6 +39,65 @@
let parser = [{ return ::parse$cppClass(parser, result); }];
}
+def BufferAllocOp :
+ Linalg_Op<"buffer_alloc">,
+ Arguments<(ins Variadic<Index>:$size)>,
+ Results<(outs Buffer)> {
+ let summary = "buffer allocation operation";
+ let description = [{
+ The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
+ upon which a base view can be laid out to give it indexing semantics.
+ "buffer_alloc" takes a single argument, the size of the buffer to allocate
+ (in number of elements).
+
+ ```{.mlir}
+ %0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
+ ```
+
+ The size argument may be omitted if it is statically known, in which case it
+ must be reflected in the type.
+
+ ```{.mlir}
+ %0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
+ ```
+ }];
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState *result, BufferType bufferType", [{
+ result->types.push_back(bufferType);
+ }]
+ >];
+ let extraClassDeclaration = [{
+ BufferType getBufferType() { return getType().cast<BufferType>(); }
+ Type getElementType() { return getBufferType().getElementType(); }
+ }];
+}
+
+def BufferDeallocOp :
+ Linalg_Op<"buffer_dealloc">,
+ Arguments<(ins Buffer:$buffer)>,
+ Results<(outs)> {
+ let summary = "buffer allocation operation";
+ let description = [{
+ The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
+
+ ```{.mlir}
+ linalg.buffer_dealloc %0 : !linalg.buffer<f32>
+ ```
+ }];
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState *result, BufferType bufferType", [{
+ result->types.push_back(bufferType);
+ }]
+ >];
+ let extraClassDeclaration = [{
+ BufferType getBufferType() {
+ return getOperand()->getType().cast<BufferType>();
+ }
+ }];
+ // Fully specified by traits.
+ let verifier = ?;
+}
+
def BufferSizeOp :
Linalg_Op<"buffer_size", [NoSideEffect]>,
Arguments<(ins Buffer)>,
diff --git a/lib/Linalg/IR/LinalgOps.cpp b/lib/Linalg/IR/LinalgOps.cpp
index 232379f..c23c601 100644
--- a/lib/Linalg/IR/LinalgOps.cpp
+++ b/lib/Linalg/IR/LinalgOps.cpp
@@ -37,76 +37,6 @@
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
-//////////////////////////////////////////////////////////////////////////////
-// BufferAllocOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
- Type type, Value *size) {
- result->addOperands({size});
- result->addTypes(type);
-}
-
-LogicalResult mlir::linalg::BufferAllocOp::verify() {
- if (!size() || !size()->getType().isa<IndexType>())
- return emitOpError("first operand should be of type index");
- if (!VectorType::isValidElementType(getElementType()) &&
- !getElementType().isa<VectorType>())
- return emitOpError("unsupported buffer element type");
- return success();
-}
-
-// A BufferAllocOp prints as:
-//
-// ```{.mlir}
-// linalg.alloc %0 : !linalg.buffer<f32>
-// ```
-void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
- *p << getOperationName() << " " << *size() << " : " << getType();
-}
-
-ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
- OperationState *result) {
- OpAsmParser::OperandType sizeInfo;
- BufferType bufferType;
- auto indexTy = parser->getBuilder().getIndexType();
- if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
- return failure();
- return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
- parser->addTypeToList(bufferType, result->types));
-}
-
-//////////////////////////////////////////////////////////////////////////////
-// BufferDeallocOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
- Value *buffer) {
- result->addOperands({buffer});
-}
-
-LogicalResult mlir::linalg::BufferDeallocOp::verify() {
- if (!getBuffer()->getType())
- return emitOpError("first operand should be of type buffer");
- return success();
-}
-
-// A BufferDeallocOp prints as:
-//
-// ```{.mlir}
-// linalg.dealloc %0 : !linalg.buffer<f32>
-// ```
-void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
- *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
-}
-
-ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
- OperationState *result) {
- OpAsmParser::OperandType sizeInfo;
- BufferType bufferType;
- return failure(
- parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
- parser->resolveOperands(sizeInfo, bufferType, result->operands));
-}
-
////////////////////////////////////////////////////////////////////////////////
// ForOp.
////////////////////////////////////////////////////////////////////////////////
@@ -605,6 +535,60 @@
// LinalgOps.td), we define an overloaded `print` function and a
// parse`className` function.
+static void print(OpAsmPrinter *p, BufferAllocOp op) {
+ *p << op.getOperationName() << " ";
+ if (!llvm::empty(op.size()))
+ *p << *op.getOperand(0);
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.getBufferType();
+}
+
+static ParseResult parseBufferAllocOp(OpAsmParser *parser,
+ OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
+ BufferType bufferType;
+ auto indexTy = parser->getBuilder().getIndexType();
+ if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType))
+ return failure();
+ if (sizeInfo.empty())
+ return parser->addTypeToList(bufferType, result->types);
+ return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
+ parser->addTypeToList(bufferType, result->types));
+}
+
+static LogicalResult verify(BufferAllocOp op) {
+ if (!op.getBufferType().hasConstantSize()) {
+ if (llvm::size(op.size()) != 1 ||
+ !op.getOperand(0)->getType().isa<IndexType>())
+ return op.emitOpError(
+ "one operand of type index expected for dynamic buffer");
+ } else { // op.getBufferType().hasConstantSize()
+ if (!llvm::empty(op.size()))
+ return op.emitOpError("unexpected static buffer operand");
+ if (op.getBufferType().getBufferSize().getValue() <= 0)
+ return op.emitOpError("expected nonnegative static buffer size");
+ }
+ if (!VectorType::isValidElementType(op.getElementType()) &&
+ !op.getElementType().isa<VectorType>())
+ return op.emitOpError("unsupported buffer element type");
+ return success();
+}
+
+static void print(OpAsmPrinter *p, BufferDeallocOp op) {
+ *p << op.getOperationName() << " " << *op.buffer();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.getBufferType();
+}
+
+static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
+ OperationState *result) {
+ OpAsmParser::OperandType bufferInfo;
+ BufferType bufferType;
+ if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType))
+ return failure();
+ return parser->resolveOperands(bufferInfo, bufferType, result->operands);
+}
+
static void print(OpAsmPrinter *p, BufferSizeOp op) {
*p << op.getOperationName() << " " << *op.getOperand();
p->printOptionalAttrDict(op.getAttrs());
diff --git a/lib/Linalg/IR/LinalgTypes.cpp b/lib/Linalg/IR/LinalgTypes.cpp
index 82be170..9cf9c55 100644
--- a/lib/Linalg/IR/LinalgTypes.cpp
+++ b/lib/Linalg/IR/LinalgTypes.cpp
@@ -34,8 +34,7 @@
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<BufferType, RangeType, ViewType>();
- addOperations<BufferAllocOp, BufferDeallocOp, ForOp, LoadOp, RangeOp, StoreOp,
- SliceOp, ViewOp>();
+ addOperations<ForOp, LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
@@ -119,8 +118,8 @@
// Check for '?'
int64_t bufferSize = -1;
if (!spec.consume_front("?")) {
- unsigned parsedBufferSize;
- if (!spec.consumeInteger(10, parsedBufferSize)) {
+ unsigned long long parsedBufferSize = 0;
+ if (spec.consumeInteger(10, parsedBufferSize)) {
emitError(loc, "expected buffer size to be an unsigned integer");
return Type();
}
diff --git a/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index d43a2e6..a8099aa 100644
--- a/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -168,7 +168,7 @@
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
- auto int64Ty = lowering.convertType(operands[0]->getType());
+ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
Function *mallocFunc = module->getNamedFunction("malloc");
@@ -187,14 +187,19 @@
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
- auto elementPtrType = getPtrToElementType(
- allocOp.getResult()->getType().cast<BufferType>(), lowering);
+ auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
+ auto elementPtrType = getPtrToElementType(bufferType, lowering);
auto bufferDescriptorType =
convertLinalgType(allocOp.getResult()->getType(), lowering);
// Emit IR for creating a new buffer descriptor with an underlying malloc.
edsc::ScopedContext context(rewriter, op->getLoc());
- Value *size = operands[0];
+ auto constantSize = bufferType.getBufferSize();
+ Value *size =
+ constantSize
+ ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
+ .getValue()
+ : operands[0];
Value *allocSize =
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
Value *allocated =
diff --git a/test/Linalg/roundtrip.mlir b/test/Linalg/roundtrip.mlir
index 2133b24..c927818 100644
--- a/test/Linalg/roundtrip.mlir
+++ b/test/Linalg/roundtrip.mlir
@@ -13,12 +13,16 @@
func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
+ %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
+ linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
return
}
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
+// CHECK-NEXT: %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
+// CHECK-NEXT: linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {