[spirv] NFC: Add getZero() and getOne() static method to ConstantOp
Getting constant zero or one is very common so it merits a special handy
method on spirv::ConstantOp itself.
PiperOrigin-RevId: 282832572
Change-Id: Ifb6fe54acef73f7ce2af6b995bb06b94a35fd294
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index a5b3fc2..8faa90c 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -56,8 +56,6 @@
protected:
/// Type lowering class.
SPIRVTypeConverter &typeConverter;
-
-private:
};
#include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc"
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index 104a479..353004b 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -26,6 +26,8 @@
#include "mlir/IR/Function.h"
namespace mlir {
+class OpBuilder;
+
namespace spirv {
#define GET_OP_CLASSES
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 1ec825a..34b386e 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -118,6 +118,13 @@
let extraClassDeclaration = [{
// Returns true if a constant can be built for the given `type`.
static bool isBuildableWith(Type type);
+
+ // Creates a constant zero/one of the given `type` at the current insertion
+ // point of `builder` and returns it.
+ static spirv::ConstantOp getZero(Type type, Location loc,
+ OpBuilder *builder);
+ static spirv::ConstantOp getOne(Type type, Location loc,
+ OpBuilder *builder);
}];
let hasOpcode = 0;
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 62cabf6..4a3d25f 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -145,8 +145,7 @@
// Need to add a '0' at the beginning of the index list for accessing into the
// struct that wraps the nested array types.
- Value *zero = builder.create<spirv::ConstantOp>(
- loc, indexType, builder.getIntegerAttr(indexType, 0));
+ Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
SmallVector<Value *, 4> accessIndices;
accessIndices.reserve(1 + indices.size());
accessIndices.push_back(zero);
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ae7643f..e824200 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1169,6 +1169,35 @@
return true;
}
+spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
+ OpBuilder *builder) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ unsigned width = intType.getWidth();
+ Attribute val;
+ if (width == 1)
+ return builder->create<spirv::ConstantOp>(loc, type,
+ builder->getBoolAttr(false));
+ return builder->create<spirv::ConstantOp>(
+ loc, type, builder->getIntegerAttr(type, APInt(width, 0)));
+ }
+
+ llvm_unreachable("unimplemented types for ConstantOp::getZero()");
+}
+
+spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
+ OpBuilder *builder) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ unsigned width = intType.getWidth();
+ if (width == 1)
+ return builder->create<spirv::ConstantOp>(loc, type,
+ builder->getBoolAttr(true));
+ return builder->create<spirv::ConstantOp>(
+ loc, type, builder->getIntegerAttr(type, APInt(width, 1)));
+ }
+
+ llvm_unreachable("unimplemented types for ConstantOp::getOne()");
+}
+
//===----------------------------------------------------------------------===//
// spv.ControlBarrier
//===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index e9d36f6..d48b31f 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -194,8 +194,8 @@
if (isScalarOrVectorType(argType.value())) {
auto indexType =
typeConverter.convertType(IndexType::get(funcOp.getContext()));
- auto zero = rewriter.create<spirv::ConstantOp>(
- funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+ auto zero =
+ spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), replacement, zero.constant());
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr,