Add Select operation to SPIR-V dialect.

The SelectOp models the semantics of OpSelect from SPIR-V spec.

PiperOrigin-RevId: 266849559
diff --git a/include/mlir/Dialect/SPIRV/SPIRVBase.td b/include/mlir/Dialect/SPIRV/SPIRVBase.td
index e90e816..7dea586 100644
--- a/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -121,6 +121,7 @@
 def SPV_OC_OpSMod                   : I32EnumAttrCase<"OpSMod", 139>;
 def SPV_OC_OpFRem                   : I32EnumAttrCase<"OpFRem", 140>;
 def SPV_OC_OpFMod                   : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpSelect                 : I32EnumAttrCase<"OpSelect", 169>;
 def SPV_OC_OpIEqual                 : I32EnumAttrCase<"OpIEqual", 170>;
 def SPV_OC_OpINotEqual              : I32EnumAttrCase<"OpINotEqual", 171>;
 def SPV_OC_OpUGreaterThan           : I32EnumAttrCase<"OpUGreaterThan", 172>;
@@ -164,7 +165,7 @@
       SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
       SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
       SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
-      SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
+      SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
       SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
       SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
       SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
@@ -217,16 +218,13 @@
     SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
   ]>;
 
-class SPV_ScalarOrVectorOf<Type type> :
-    Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
-         "scalar/vector of " # type.description>;
+class SPV_ScalarOrVectorOf<Type type> : AnyTypeOf<[type, VectorOf<[type]>]>;
 
 // TODO(antiagainst): Use a more appropriate way to model optional operands
 class SPV_Optional<Type type> : Variadic<type>;
 
-def SPV_IsEntryPointType :
-    CPred<"$_self.isa<::mlir::spirv::EntryPointType>()">;
-def SPV_EntryPoint : Type<SPV_IsEntryPointType, "SPIR-V entry point type">;
+// TODO(ravishankarm): From 1.4, this should also include Composite type.
+def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
 
 //===----------------------------------------------------------------------===//
 // SPIR-V extension definitions
diff --git a/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
index 51781d8..1e9a547 100644
--- a/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
+++ b/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
@@ -634,6 +634,63 @@
 
 // -----
 
+def SPV_SelectOp : SPV_Op<"Select", []> {
+  let summary = [{
+    Select between two objects. Before version 1.4, results are only
+    computed per component.
+  }];
+
+  let description = [{
+    Before version 1.4, Result Type must be a pointer, scalar, or vector.
+
+     The types of Object 1 and Object 2 must be the same as Result Type.
+
+    Condition must be a scalar or vector of Boolean type.
+
+    If Condition is a scalar and true, the result is Object 1. If Condition
+    is a scalar and false, the result is Object 2.
+
+    If Condition is a vector, Result Type must be a vector with the same
+    number of components as Condition and the result is a mix of Object 1
+    and Object 2: When a component of Condition is true, the corresponding
+    component in the result is taken from Object 1, otherwise it is taken
+    from Object 2.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    scalar-type ::= integer-type | float-type | boolean-type
+    select-object-type ::= scalar-type
+                           | `vector<` integer-literal `x` scalar-type `>`
+                           | pointer-type
+    select-condition-type ::= boolean-type
+                              | `vector<` integer-literal `x` boolean-type `>`
+    select-op ::= ssa-id `=` `spv.Select` ssa-use, ssa-use, ssa-use
+                  `:` select-condition-type `,` select-object-type
+    ```
+
+    For example:
+
+    ```
+    %3 = spv.Select %0, %1, %2 : i1, f32
+    %3 = spv.Select %0, %1, %2 : i1, vector<3xi32>
+    %3 = spv.Select %0, %1, %2 : vector<3xi1>, vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<SPV_Bool>:$condition,
+    SPV_SelectType:$true_value,
+    SPV_SelectType:$false_value
+  );
+
+  let results = (outs
+    SPV_SelectType:$result
+  );
+}
+
+// -----
+
 def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than  Operand 2.
diff --git a/lib/Dialect/SPIRV/SPIRVOps.cpp b/lib/Dialect/SPIRV/SPIRVOps.cpp
index 2b1248b..66b2f5d 100644
--- a/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1260,6 +1260,64 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv.Select
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *state) {
+  OpAsmParser::OperandType condition;
+  SmallVector<OpAsmParser::OperandType, 2> operands;
+  SmallVector<Type, 2> types;
+  auto loc = parser->getCurrentLocation();
+  if (parser->parseOperand(condition) || parser->parseComma() ||
+      parser->parseOperandList(operands, 2) ||
+      parser->parseColonTypeList(types)) {
+    return failure();
+  }
+  if (types.size() != 2) {
+    return parser->emitError(
+        loc, "need exactly two trailing types for select condition and object");
+  }
+  if (parser->resolveOperand(condition, types[0], state->operands) ||
+      parser->resolveOperands(operands, types[1], state->operands)) {
+    return failure();
+  }
+  return parser->addTypesToList(types[1], state->types);
+}
+
+static void print(spirv::SelectOp op, OpAsmPrinter *printer) {
+  *printer << spirv::SelectOp::getOperationName() << " ";
+
+  // Print the operands.
+  printer->printOperands(op.getOperands());
+
+  // Print colon and types.
+  *printer << " : " << op.condition()->getType() << ", "
+           << op.result()->getType();
+}
+
+static LogicalResult verify(spirv::SelectOp op) {
+  auto resultTy = op.result()->getType();
+  if (op.true_value()->getType() != resultTy) {
+    return op.emitOpError("result type and true value type must be the same");
+  }
+  if (op.false_value()->getType() != resultTy) {
+    return op.emitOpError("result type and false value type must be the same");
+  }
+  if (auto conditionTy = op.condition()->getType().dyn_cast<VectorType>()) {
+    auto resultVectorTy = resultTy.dyn_cast<VectorType>();
+    if (!resultVectorTy) {
+      return op.emitOpError("result expected to be of vector type when "
+                            "condition is of vector type");
+    }
+    if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
+      return op.emitOpError("result should have the same number of elements as "
+                            "the condition when condition is of vector type");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.specConstant
 //===----------------------------------------------------------------------===//
 
diff --git a/test/Dialect/SPIRV/Serialization/select.mlir b/test/Dialect/SPIRV/Serialization/select.mlir
new file mode 100644
index 0000000..aec39e8
--- /dev/null
+++ b/test/Dialect/SPIRV/Serialization/select.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+spv.module "Logical" "VulkanKHR" {
+  spv.specConstant @condition_scalar = true
+  func @select() -> () {
+    %0 = spv.constant 4.0 : f32
+    %1 = spv.constant 5.0 : f32
+    %2 = spv._reference_of @condition_scalar : i1
+    // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : i1, f32
+    %3 = spv.Select %2, %0, %1 : i1, f32
+    %4 = spv.constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xf32>
+    %5 = spv.constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xf32>
+    // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xf32>
+    %6 = spv.Select %2, %4, %5 : i1, vector<4xf32>
+    %7 = spv.constant dense<[true, true, true, true]> : vector<4xi1>
+    // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xf32>
+    %8 = spv.Select %7, %4, %5 : vector<4xi1>, vector<4xf32>
+    spv.Return
+  }
+}
\ No newline at end of file
diff --git a/test/Dialect/SPIRV/arithmetic-ops.mlir b/test/Dialect/SPIRV/arithmetic-ops.mlir
index ea12268..9369962 100644
--- a/test/Dialect/SPIRV/arithmetic-ops.mlir
+++ b/test/Dialect/SPIRV/arithmetic-ops.mlir
@@ -55,7 +55,7 @@
 // -----
 
 func @fmul_i32(%arg: i32) -> i32 {
-  // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
   %0 = spv.FMul %arg, %arg : i32
   return %0 : i32
 }
@@ -63,7 +63,7 @@
 // -----
 
 func @fmul_bf16(%arg: bf16) -> bf16 {
-  // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
   %0 = spv.FMul %arg, %arg : bf16
   return %0 : bf16
 }
@@ -71,7 +71,7 @@
 // -----
 
 func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
-  // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
   %0 = spv.FMul %arg, %arg : tensor<4xf32>
   return %0 : tensor<4xf32>
 }
diff --git a/test/Dialect/SPIRV/ops.mlir b/test/Dialect/SPIRV/ops.mlir
index 3e32b90..524f1d2 100644
--- a/test/Dialect/SPIRV/ops.mlir
+++ b/test/Dialect/SPIRV/ops.mlir
@@ -431,6 +431,110 @@
 // -----
 
 //===----------------------------------------------------------------------===//
+// spv.SelectOp
+//===----------------------------------------------------------------------===//
+
+func @select_op_bool(%arg0: i1) -> () {
+  %0 = spv.constant true
+  %1 = spv.constant false
+  // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, i1
+  %2 = spv.Select %arg0, %0, %1 : i1, i1
+  return
+}
+
+func @select_op_int(%arg0: i1) -> () {
+  %0 = spv.constant 2 : i32
+  %1 = spv.constant 3 : i32
+  // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, i32
+  %2 = spv.Select %arg0, %0, %1 : i1, i32
+  return
+}
+
+func @select_op_float(%arg0: i1) -> () {
+  %0 = spv.constant 2.0 : f32
+  %1 = spv.constant 3.0 : f32
+  // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, f32
+  %2 = spv.Select %arg0, %0, %1 : i1, f32
+  return
+}
+
+func @select_op_ptr(%arg0: i1) -> () {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  %1 = spv.Variable : !spv.ptr<f32, Function>
+  // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, !spv.ptr<f32, Function>
+  %2 = spv.Select %arg0, %0, %1 : i1, !spv.ptr<f32, Function>
+  return
+}
+
+func @select_op_vec(%arg0: i1) -> () {
+  %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+  %1 = spv.constant dense<[5.0, 6.0, 7.0]> : vector<3xf32>
+  // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, vector<3xf32>
+  %2 = spv.Select %arg0, %0, %1 : i1, vector<3xf32>
+  return
+}
+
+func @select_op_vec_condn_vec(%arg0: vector<3xi1>) -> () {
+  %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+  %1 = spv.constant dense<[5.0, 6.0, 7.0]> : vector<3xf32>
+  // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi1>, vector<3xf32>
+  %2 = spv.Select %arg0, %0, %1 : vector<3xi1>, vector<3xf32>
+  return
+}
+
+// -----
+
+func @select_op(%arg0: i1) -> () {
+  %0 = spv.constant 2 : i32
+  %1 = spv.constant 3 : i32
+  // expected-error @+1 {{need exactly two trailing types for select condition and object}}
+  %2 = spv.Select %arg0, %0, %1 : i1
+  return
+}
+
+// -----
+
+func @select_op(%arg1: vector<3xi1>) -> () {
+  %0 = spv.constant 2 : i32
+  %1 = spv.constant 3 : i32
+  // expected-error @+1 {{result expected to be of vector type when condition is of vector type}}
+  %2 = spv.Select %arg1, %0, %1 : vector<3xi1>, i32
+  return
+}
+
+// -----
+
+func @select_op(%arg1: vector<4xi1>) -> () {
+  %0 = spv.constant dense<[2, 3, 4]> : vector<3xi32>
+  %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
+  // expected-error @+1 {{result should have the same number of elements as the condition when condition is of vector type}}
+  %2 = spv.Select %arg1, %0, %1 : vector<4xi1>, vector<3xi32>
+  return
+}
+
+// -----
+
+func @select_op(%arg1: vector<4xi1>) -> () {
+  %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+  %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
+  // expected-error @+1 {{op result type and true value type must be the same}}
+  %2 = "spv.Select"(%arg1, %0, %1) : (vector<4xi1>, vector<3xf32>, vector<3xi32>) -> vector<3xi32>
+  return
+}
+
+// -----
+
+func @select_op(%arg1: vector<4xi1>) -> () {
+  %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+  %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
+  // expected-error @+1 {{op result type and false value type must be the same}}
+  %2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.StoreOp
 //===----------------------------------------------------------------------===//