Move BufferAllocOp and BufferDeallocOp to ODS

This CL also fixes a parsing issue in the BufferType, adds LLVM lowering support for handling the static constant buffer size and a roundtrip test.

PiperOrigin-RevId: 255834356
diff --git a/include/mlir/Linalg/IR/LinalgOps.h b/include/mlir/Linalg/IR/LinalgOps.h
index ff04c36..72eabc2 100644
--- a/include/mlir/Linalg/IR/LinalgOps.h
+++ b/include/mlir/Linalg/IR/LinalgOps.h
@@ -29,56 +29,6 @@
 
 namespace linalg {
 
-/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
-/// upon which a base view can be laid out to give it indexing semantics.
-/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
-/// (in number of elements).
-///
-/// ```{.mlir}
-///     %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
-/// ```
-class BufferAllocOp
-    : public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
-public:
-  using Op::Op;
-
-  // Hooks to customize the behavior of this op.
-  static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
-  static void build(Builder *b, OperationState *result, Type type, Value *size);
-  LogicalResult verify();
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-
-  // Op-specific functionality.
-  Value *size() { return getOperand(); }
-  BufferType getBufferType() { return getType().cast<BufferType>(); }
-  Type getElementType() { return getBufferType().getElementType(); }
-};
-
-/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
-///
-/// ```{.mlir}
-///     linalg.buffer_dealloc %0 : !linalg.buffer<f32>
-/// ```
-class BufferDeallocOp
-    : public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
-public:
-  using Op::Op;
-
-  // Hooks to customize the behavior of this op.
-  static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
-  static void build(Builder *b, OperationState *result, Value *buffer);
-  LogicalResult verify();
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-
-  // Op-specific functionality.
-  Value *getBuffer() { return getOperand(); }
-  BufferType getBufferType() {
-    return getOperand()->getType().cast<BufferType>();
-  }
-};
-
 /// The "linalg.for" operation represents a loop nest taking 3 SSA value as
 /// operands that represent the lower bound, upper bound and step respectively.
 /// The operation defines an SSA value for its induction variable. It has one
diff --git a/include/mlir/Linalg/IR/LinalgOps.td b/include/mlir/Linalg/IR/LinalgOps.td
index 051d16e..49b75be 100644
--- a/include/mlir/Linalg/IR/LinalgOps.td
+++ b/include/mlir/Linalg/IR/LinalgOps.td
@@ -39,6 +39,65 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+def BufferAllocOp :
+    Linalg_Op<"buffer_alloc">,
+    Arguments<(ins Variadic<Index>:$size)>,
+    Results<(outs Buffer)> {
+  let summary = "buffer allocation operation";
+  let description = [{
+    The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
+    upon which a base view can be laid out to give it indexing semantics.
+    "buffer_alloc" takes a single argument, the size of the buffer to allocate
+    (in number of elements).
+
+    ```{.mlir}
+        %0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
+    ```
+
+    The size argument may be omitted if it is statically known, in which case it
+    must be reflected in the type.
+
+    ```{.mlir}
+        %0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, BufferType bufferType", [{
+       result->types.push_back(bufferType);
+     }]
+  >];
+  let extraClassDeclaration = [{
+    BufferType getBufferType() { return getType().cast<BufferType>(); }
+    Type getElementType() { return getBufferType().getElementType(); }
+  }];
+}
+
+def BufferDeallocOp :
+    Linalg_Op<"buffer_dealloc">,
+    Arguments<(ins Buffer:$buffer)>,
+    Results<(outs)> {
+  let summary = "buffer allocation operation";
+  let description = [{
+    The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
+
+    ```{.mlir}
+        linalg.buffer_dealloc %0 : !linalg.buffer<f32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, BufferType bufferType", [{
+       result->types.push_back(bufferType);
+     }]
+  >];
+  let extraClassDeclaration = [{
+    BufferType getBufferType() {
+      return getOperand()->getType().cast<BufferType>();
+    }
+  }];
+  // Fully specified by traits.
+  let verifier = ?;
+}
+
 def BufferSizeOp :
     Linalg_Op<"buffer_size", [NoSideEffect]>,
     Arguments<(ins Buffer)>,
diff --git a/lib/Linalg/IR/LinalgOps.cpp b/lib/Linalg/IR/LinalgOps.cpp
index 232379f..c23c601 100644
--- a/lib/Linalg/IR/LinalgOps.cpp
+++ b/lib/Linalg/IR/LinalgOps.cpp
@@ -37,76 +37,6 @@
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 
-//////////////////////////////////////////////////////////////////////////////
-// BufferAllocOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
-                                        Type type, Value *size) {
-  result->addOperands({size});
-  result->addTypes(type);
-}
-
-LogicalResult mlir::linalg::BufferAllocOp::verify() {
-  if (!size() || !size()->getType().isa<IndexType>())
-    return emitOpError("first operand should be of type index");
-  if (!VectorType::isValidElementType(getElementType()) &&
-      !getElementType().isa<VectorType>())
-    return emitOpError("unsupported buffer element type");
-  return success();
-}
-
-// A BufferAllocOp prints as:
-//
-// ```{.mlir}
-//   linalg.alloc %0 : !linalg.buffer<f32>
-// ```
-void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
-  *p << getOperationName() << " " << *size() << " : " << getType();
-}
-
-ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
-                                               OperationState *result) {
-  OpAsmParser::OperandType sizeInfo;
-  BufferType bufferType;
-  auto indexTy = parser->getBuilder().getIndexType();
-  if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
-    return failure();
-  return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
-                 parser->addTypeToList(bufferType, result->types));
-}
-
-//////////////////////////////////////////////////////////////////////////////
-// BufferDeallocOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
-                                          Value *buffer) {
-  result->addOperands({buffer});
-}
-
-LogicalResult mlir::linalg::BufferDeallocOp::verify() {
-  if (!getBuffer()->getType())
-    return emitOpError("first operand should be of type buffer");
-  return success();
-}
-
-// A BufferDeallocOp prints as:
-//
-// ```{.mlir}
-//   linalg.dealloc %0 : !linalg.buffer<f32>
-// ```
-void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
-  *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
-}
-
-ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
-                                                 OperationState *result) {
-  OpAsmParser::OperandType sizeInfo;
-  BufferType bufferType;
-  return failure(
-      parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
-      parser->resolveOperands(sizeInfo, bufferType, result->operands));
-}
-
 ////////////////////////////////////////////////////////////////////////////////
 // ForOp.
 ////////////////////////////////////////////////////////////////////////////////
@@ -605,6 +535,60 @@
 // LinalgOps.td), we define an overloaded `print` function and a
 // parse`className` function.
 
+static void print(OpAsmPrinter *p, BufferAllocOp op) {
+  *p << op.getOperationName() << " ";
+  if (!llvm::empty(op.size()))
+    *p << *op.getOperand(0);
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getBufferType();
+}
+
+static ParseResult parseBufferAllocOp(OpAsmParser *parser,
+                                      OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
+  BufferType bufferType;
+  auto indexTy = parser->getBuilder().getIndexType();
+  if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType))
+    return failure();
+  if (sizeInfo.empty())
+    return parser->addTypeToList(bufferType, result->types);
+  return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
+                 parser->addTypeToList(bufferType, result->types));
+}
+
+static LogicalResult verify(BufferAllocOp op) {
+  if (!op.getBufferType().hasConstantSize()) {
+    if (llvm::size(op.size()) != 1 ||
+        !op.getOperand(0)->getType().isa<IndexType>())
+      return op.emitOpError(
+          "one operand of type index expected for dynamic buffer");
+  } else { // op.getBufferType().hasConstantSize()
+    if (!llvm::empty(op.size()))
+      return op.emitOpError("unexpected static buffer operand");
+    if (op.getBufferType().getBufferSize().getValue() <= 0)
+      return op.emitOpError("expected nonnegative static buffer size");
+  }
+  if (!VectorType::isValidElementType(op.getElementType()) &&
+      !op.getElementType().isa<VectorType>())
+    return op.emitOpError("unsupported buffer element type");
+  return success();
+}
+
+static void print(OpAsmPrinter *p, BufferDeallocOp op) {
+  *p << op.getOperationName() << " " << *op.buffer();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getBufferType();
+}
+
+static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
+                                        OperationState *result) {
+  OpAsmParser::OperandType bufferInfo;
+  BufferType bufferType;
+  if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType))
+    return failure();
+  return parser->resolveOperands(bufferInfo, bufferType, result->operands);
+}
+
 static void print(OpAsmPrinter *p, BufferSizeOp op) {
   *p << op.getOperationName() << " " << *op.getOperand();
   p->printOptionalAttrDict(op.getAttrs());
diff --git a/lib/Linalg/IR/LinalgTypes.cpp b/lib/Linalg/IR/LinalgTypes.cpp
index 82be170..9cf9c55 100644
--- a/lib/Linalg/IR/LinalgTypes.cpp
+++ b/lib/Linalg/IR/LinalgTypes.cpp
@@ -34,8 +34,7 @@
 mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addTypes<BufferType, RangeType, ViewType>();
-  addOperations<BufferAllocOp, BufferDeallocOp, ForOp, LoadOp, RangeOp, StoreOp,
-                SliceOp, ViewOp>();
+  addOperations<ForOp, LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Linalg/IR/LinalgOps.cpp.inc"
@@ -119,8 +118,8 @@
       // Check for '?'
       int64_t bufferSize = -1;
       if (!spec.consume_front("?")) {
-        unsigned parsedBufferSize;
-        if (!spec.consumeInteger(10, parsedBufferSize)) {
+        unsigned long long parsedBufferSize = 0;
+        if (spec.consumeInteger(10, parsedBufferSize)) {
           emitError(loc, "expected buffer size to be an unsigned integer");
           return Type();
         }
diff --git a/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index d43a2e6..a8099aa 100644
--- a/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -168,7 +168,7 @@
     auto indexType = IndexType::get(op->getContext());
     auto voidPtrTy =
         LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
-    auto int64Ty = lowering.convertType(operands[0]->getType());
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
     // Insert the `malloc` declaration if it is not already present.
     auto *module = op->getFunction()->getModule();
     Function *mallocFunc = module->getNamedFunction("malloc");
@@ -187,14 +187,19 @@
                     llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
     else
       elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
-    auto elementPtrType = getPtrToElementType(
-        allocOp.getResult()->getType().cast<BufferType>(), lowering);
+    auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
+    auto elementPtrType = getPtrToElementType(bufferType, lowering);
     auto bufferDescriptorType =
         convertLinalgType(allocOp.getResult()->getType(), lowering);
 
     // Emit IR for creating a new buffer descriptor with an underlying malloc.
     edsc::ScopedContext context(rewriter, op->getLoc());
-    Value *size = operands[0];
+    auto constantSize = bufferType.getBufferSize();
+    Value *size =
+        constantSize
+            ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
+                  .getValue()
+            : operands[0];
     Value *allocSize =
         mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
     Value *allocated =
diff --git a/test/Linalg/roundtrip.mlir b/test/Linalg/roundtrip.mlir
index 2133b24..c927818 100644
--- a/test/Linalg/roundtrip.mlir
+++ b/test/Linalg/roundtrip.mlir
@@ -13,12 +13,16 @@
 func @buffer(%arg0: index, %arg1: index) {
   %0 = muli %arg0, %arg0 : index
   %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
+  %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
+  linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
   linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
   return
 }
 // CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
 //  CHECK-NEXT:  %0 = muli %arg0, %arg0 : index
 //  CHECK-NEXT:  %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
+//  CHECK-NEXT:  %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
+//  CHECK-NEXT:  linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
 //  CHECK-NEXT:  linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
 
 func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {