[ODS] Define ConstantOp.
Add ConstantOp's Op Definition Spec. Currently we don't use convertFromStorage in the generated patterns and so needed to add a few casts to support patterns restricted to ElementsAttrs and to compensate for old rules where we defined ConstantOp to have a ElementsAttr to simplify writing the pattern.
--
PiperOrigin-RevId: 246361555
diff --git a/include/mlir/IR/OpBase.td b/include/mlir/IR/OpBase.td
index 275acff..f45f748 100644
--- a/include/mlir/IR/OpBase.td
+++ b/include/mlir/IR/OpBase.td
@@ -576,7 +576,7 @@
Attr<condition, description> {
let storageType = [{ ElementsAttr }];
let returnType = [{ ElementsAttr }];
- let convertFromStorage = "$_self";
+ let convertFromStorage = "cast<ElementsAttr>($_self)";
}
def ElementsAttr: ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">,
diff --git a/include/mlir/StandardOps/Ops.h b/include/mlir/StandardOps/Ops.h
index cdf0ee6..fbb5cde 100644
--- a/include/mlir/StandardOps/Ops.h
+++ b/include/mlir/StandardOps/Ops.h
@@ -356,36 +356,6 @@
}
};
-/// The "constant" operation requires a single attribute named "value".
-/// It returns its value as an SSA value. For example:
-///
-/// %1 = "std.constant"(){value: 42} : i32
-/// %2 = "std.constant"(){value: @foo} : (f32)->f32
-///
-class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
- OpTrait::OneResult, OpTrait::HasNoSideEffect> {
-public:
- using Op::Op;
-
- /// Builds a constant op with the specified attribute value and result type.
- static void build(Builder *builder, OperationState *result, Type type,
- Attribute value);
-
- /// Builds a constant op with the specified attribute value and the
- /// attribute's type.
- static void build(Builder *builder, OperationState *result, Attribute value);
-
- Attribute getValue() { return getAttr("value"); }
-
- static StringRef getOperationName() { return "std.constant"; }
-
- // Hooks to customize behavior of this op.
- static bool parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p);
- LogicalResult verify();
- Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
-};
-
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
diff --git a/include/mlir/StandardOps/Ops.td b/include/mlir/StandardOps/Ops.td
index c83790c..12424f5 100644
--- a/include/mlir/StandardOps/Ops.td
+++ b/include/mlir/StandardOps/Ops.td
@@ -90,6 +90,27 @@
let hasFolder = 1;
}
+def ConstantOp : Op<Standard_Dialect, "constant", [NoSideEffect]> {
+ let summary = "constant";
+
+ let arguments = (ins AnyAttr:$value);
+ let results = (outs AnyType);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState *result, Attribute value",
+ [{ build(builder, result, value.getType(), value); }]>];
+
+ let parser = [{ return parseConstantOp(parser, result); }];
+ let printer = [{ return printConstantOp(p, *this); }];
+ let verifier = [{ return ::verify(*this); }];
+
+ let extraClassDeclaration = [{
+ Attribute getValue() { return getAttr("value"); }
+ }];
+
+ let hasConstantFolder = 1;
+}
+
def DivFOp : FloatArithmeticOp<"divf"> {
let summary = "floating point division operation";
}
diff --git a/lib/StandardOps/Ops.cpp b/lib/StandardOps/Ops.cpp
index a05d83d..4cc192b 100644
--- a/lib/StandardOps/Ops.cpp
+++ b/lib/StandardOps/Ops.cpp
@@ -62,9 +62,8 @@
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*name=*/"std", context) {
addOperations<AllocOp, BranchOp, CallOp, CallIndirectOp, CmpIOp, CondBranchOp,
- ConstantOp, DeallocOp, DimOp, DmaStartOp, DmaWaitOp,
- ExtractElementOp, LoadOp, MemRefCastOp, ReturnOp, SelectOp,
- StoreOp, TensorCastOp,
+ DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
+ LoadOp, MemRefCastOp, ReturnOp, SelectOp, StoreOp, TensorCastOp,
#define GET_OP_LIST
#include "mlir/StandardOps/Ops.cpp.inc"
>();
@@ -913,32 +912,18 @@
// Constant*Op
//===----------------------------------------------------------------------===//
-/// Builds a constant op with the specified attribute value and result type.
-void ConstantOp::build(Builder *builder, OperationState *result, Type type,
- Attribute value) {
- result->addAttribute("value", value);
- result->types.push_back(type);
-}
-
-/// Builds a constant with the specified attribute value and type extracted
-/// from the attribute. The attribute must have a type.
-void ConstantOp::build(Builder *builder, OperationState *result,
- Attribute value) {
- return build(builder, result, value.getType(), value);
-}
-
-void ConstantOp::print(OpAsmPrinter *p) {
+static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
*p << "constant ";
- p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"value"});
+ p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
- if (getAttrs().size() > 1)
+ if (op.getAttrs().size() > 1)
*p << ' ';
- *p << getValue();
- if (!getValue().isa<FunctionAttr>())
- *p << " : " << getType();
+ *p << op.getValue();
+ if (!op.getValue().isa<FunctionAttr>())
+ *p << " : " << op.getType();
}
-bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
+static bool parseConstantOp(OpAsmParser *parser, OperationState *result) {
Attribute valueAttr;
Type type;
@@ -963,16 +948,16 @@
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
-LogicalResult ConstantOp::verify() {
- auto value = getValue();
+static LogicalResult verify(ConstantOp &op) {
+ auto value = op.getValue();
if (!value)
- return emitOpError("requires a 'value' attribute");
+ return op.emitOpError("requires a 'value' attribute");
- auto type = this->getType();
+ auto type = op.getType();
if (type.isa<IntegerType>() || type.isIndex()) {
auto intAttr = value.dyn_cast<IntegerAttr>();
if (!intAttr)
- return emitOpError(
+ return op.emitOpError(
"requires 'value' to be an integer for an integer result type");
// If the type has a known bitwidth we verify that the value can be
@@ -981,36 +966,36 @@
auto bitwidth = type.cast<IntegerType>().getWidth();
auto intVal = intAttr.getValue();
if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
- return emitOpError("requires 'value' to be an integer within the range "
- "of the integer result type");
+ return op.emitOpError(
+ "requires 'value' to be an integer within the range "
+ "of the integer result type");
}
return success();
}
if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
- return emitOpError("requires 'value' to be a floating point constant");
+ return op.emitOpError("requires 'value' to be a floating point constant");
return success();
}
if (type.isa<VectorOrTensorType>()) {
if (!value.isa<ElementsAttr>())
- return emitOpError("requires 'value' to be a vector/tensor constant");
+ return op.emitOpError("requires 'value' to be a vector/tensor constant");
return success();
}
if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
- return emitOpError("requires 'value' to be a function reference");
+ return op.emitOpError("requires 'value' to be a function reference");
return success();
}
- auto attrType = value.getType();
- if (attrType != type)
- return emitOpError("requires the type of the 'value' attribute to match "
- "that of the operation result");
+ if (value.getType() != type)
+ return op.emitOpError("requires the type of the 'value' attribute to match "
+ "that of the operation result");
- return emitOpError(
+ return op.emitOpError(
"requires a result type that aligns with the 'value' attribute");
}