[spirv] NFC: Add getZero() and getOne() static method to ConstantOp

Getting constant zero or one is very common so it merits a special handy
method on spirv::ConstantOp itself.

PiperOrigin-RevId: 282832572
Change-Id: Ifb6fe54acef73f7ce2af6b995bb06b94a35fd294
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index a5b3fc2..8faa90c 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -56,8 +56,6 @@
 protected:
   /// Type lowering class.
   SPIRVTypeConverter &typeConverter;
-
-private:
 };
 
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc"
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index 104a479..353004b 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -26,6 +26,8 @@
 #include "mlir/IR/Function.h"
 
 namespace mlir {
+class OpBuilder;
+
 namespace spirv {
 
 #define GET_OP_CLASSES
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 1ec825a..34b386e 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -118,6 +118,13 @@
   let extraClassDeclaration = [{
     // Returns true if a constant can be built for the given `type`.
     static bool isBuildableWith(Type type);
+
+    // Creates a constant zero/one of the given `type` at the current insertion
+    // point of `builder` and returns it.
+    static spirv::ConstantOp getZero(Type type, Location loc,
+                                     OpBuilder *builder);
+    static spirv::ConstantOp getOne(Type type, Location loc,
+                                    OpBuilder *builder);
   }];
 
   let hasOpcode = 0;
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 62cabf6..4a3d25f 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -145,8 +145,7 @@
 
   // Need to add a '0' at the beginning of the index list for accessing into the
   // struct that wraps the nested array types.
-  Value *zero = builder.create<spirv::ConstantOp>(
-      loc, indexType, builder.getIntegerAttr(indexType, 0));
+  Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
   SmallVector<Value *, 4> accessIndices;
   accessIndices.reserve(1 + indices.size());
   accessIndices.push_back(zero);
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ae7643f..e824200 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1169,6 +1169,35 @@
   return true;
 }
 
+spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
+                                             OpBuilder *builder) {
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    unsigned width = intType.getWidth();
+    Attribute val;
+    if (width == 1)
+      return builder->create<spirv::ConstantOp>(loc, type,
+                                                builder->getBoolAttr(false));
+    return builder->create<spirv::ConstantOp>(
+        loc, type, builder->getIntegerAttr(type, APInt(width, 0)));
+  }
+
+  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
+}
+
+spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
+                                            OpBuilder *builder) {
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    unsigned width = intType.getWidth();
+    if (width == 1)
+      return builder->create<spirv::ConstantOp>(loc, type,
+                                                builder->getBoolAttr(true));
+    return builder->create<spirv::ConstantOp>(
+        loc, type, builder->getIntegerAttr(type, APInt(width, 1)));
+  }
+
+  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
+}
+
 //===----------------------------------------------------------------------===//
 // spv.ControlBarrier
 //===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index e9d36f6..d48b31f 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -194,8 +194,8 @@
     if (isScalarOrVectorType(argType.value())) {
       auto indexType =
           typeConverter.convertType(IndexType::get(funcOp.getContext()));
-      auto zero = rewriter.create<spirv::ConstantOp>(
-          funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+      auto zero =
+          spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);
       auto loadPtr = rewriter.create<spirv::AccessChainOp>(
           funcOp.getLoc(), replacement, zero.constant());
       replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr,