[spirv] Add spv.loop

SPIR-V can explicitly declare structured control-flow constructs using merge
instructions. These explicitly declare a header block before the control
flow diverges and a merge block where control flow subsequently converges.
These blocks delimit constructs that must nest, and can only be entered
and exited in structured ways.

Instead of having a `spv.LoopMerge` op to directly model loop merge
instruction for indicating the merge and continue target, we use regions
to delimit the boundary of the loop: the merge target is the next op
following the `spv.loop` op and the continue target is the block that
has a back-edge pointing to the entry block inside the `spv.loop`'s region.
This way it's easier to discover all blocks belonging to a construct and
it plays nicer with the MLIR system.

Updated the SPIR-V.md doc.

PiperOrigin-RevId: 267431010
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 7dea586..0accb05 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -144,6 +144,7 @@
 def SPV_OC_OpFUnordLessThanEqual    : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>;
 def SPV_OC_OpFOrdGreaterThanEqual   : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>;
 def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>;
+def SPV_OC_OpLoopMerge              : I32EnumAttrCase<"OpLoopMerge", 246>;
 def SPV_OC_OpLabel                  : I32EnumAttrCase<"OpLabel", 248>;
 def SPV_OC_OpBranch                 : I32EnumAttrCase<"OpBranch", 249>;
 def SPV_OC_OpBranchConditional      : I32EnumAttrCase<"OpBranchConditional", 250>;
@@ -173,9 +174,9 @@
       SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
       SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
-      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpLabel,
-      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue
+      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
+      SPV_OC_OpLoopMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
+      SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue
       ]> {
     let returnType = "::mlir::spirv::Opcode";
     let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@@ -924,6 +925,28 @@
   let cppNamespace = "::mlir::spirv";
 }
 
+def SPV_LC_None               : I32EnumAttrCase<"None", 0x0000>;
+def SPV_LC_Unroll             : I32EnumAttrCase<"Unroll", 0x0001>;
+def SPV_LC_DontUnroll         : I32EnumAttrCase<"DontUnroll", 0x0002>;
+def SPV_LC_DependencyInfinite : I32EnumAttrCase<"DependencyInfinite", 0x0004>;
+def SPV_LC_DependencyLength   : I32EnumAttrCase<"DependencyLength", 0x0008>;
+def SPV_LC_MinIterations      : I32EnumAttrCase<"MinIterations", 0x0010>;
+def SPV_LC_MaxIterations      : I32EnumAttrCase<"MaxIterations", 0x0020>;
+def SPV_LC_IterationMultiple  : I32EnumAttrCase<"IterationMultiple", 0x0040>;
+def SPV_LC_PeelCount          : I32EnumAttrCase<"PeelCount", 0x0080>;
+def SPV_LC_PartialCount       : I32EnumAttrCase<"PartialCount", 0x0100>;
+
+def SPV_LoopControlAttr :
+    I32EnumAttr<"LoopControl", "valid SPIR-V LoopControl", [
+      SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite,
+      SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
+      SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount
+    ]> {
+  let returnType = "::mlir::spirv::LoopControl";
+  let convertFromStorage = "static_cast<::mlir::spirv::LoopControl>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
 def SPV_MA_None                    : I32EnumAttrCase<"None", 0x0000>;
 def SPV_MA_Volatile                : I32EnumAttrCase<"Volatile", 0x0001>;
 def SPV_MA_Aligned                 : I32EnumAttrCase<"Aligned", 0x0002>;
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 0927684..ffefa14 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -137,6 +137,71 @@
 
 // -----
 
+def SPV_LoopOp : SPV_Op<"loop"> {
+  let summary = "Define a structured loop.";
+
+  let description = [{
+    SPIR-V can explicitly declare structured control-flow constructs using merge
+    instructions. These explicitly declare a header block before the control
+    flow diverges and a merge block where control flow subsequently converges.
+    These blocks delimit constructs that must nest, and can only be entered
+    and exited in structured ways. See "2.11. Structured Control Flow" of the
+    SPIR-V spec for more details.
+
+    Instead of having a `spv.LoopMerge` op to directly model loop merge
+    instruction for indicating the merge and continue target, we use regions
+    to delimit the boundary of the loop: the merge target is the next op
+    following the `spv.loop` op and the continue target is the block that
+    has a back-edge pointing to the entry block inside the `spv.loop`'s region.
+    This way it's easier to discover all blocks belonging to a construct and
+    it plays nicer with the MLIR system.
+
+    The `spv.loop` region should contain at least four blocks: one entry block,
+    one loop header block, one loop continue block, one loop merge block.
+    The entry block should be the first block and it should jump to the loop
+    header block, which is the second block. The loop merge block should be the
+    last block. The merge block should only contain a `spv._merge` op.
+    The continue block should be the second to last block and it should have a
+    branch to the loop header block. The loop continue block should be the only
+    block, except the entry block, branching to the header block.
+  }];
+
+  let arguments = (ins
+    SPV_LoopControlAttr:$loop_control
+  );
+
+  let results = (outs);
+
+  let regions = (region AnyRegion:$body);
+
+  let hasOpcode = 0;
+}
+
+// -----
+
+def SPV_MergeOp : SPV_Op<"_merge", [HasParent<"LoopOp">, Terminator]> {
+  let summary = "A special terminator for merging a structured selection/loop.";
+
+  let description = [{
+    We use `spv.selection`/`spv.loop` for modelling structured selection/loop.
+    This op is a terminator used inside their regions to mean jumping to the
+    merge point, which is the next op following the `spv.selection` or
+    `spv.loop` op. This op does not have a corresponding instruction in the
+    SPIR-V binary format; it's solely for structural purpose.
+  }];
+
+  let arguments = (ins);
+
+  let results = (outs);
+
+  let parser = [{ return parseNoIOOp(parser, result); }];
+  let printer = [{ printNoIOOp(getOperation(), p); }];
+
+  let hasOpcode = 0;
+}
+
+// -----
+
 def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
   let summary = "Return with no value from a function with void return type.";
 
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 3338c10..0873eb0 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1018,6 +1018,143 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv.loop
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLoopOp(OpAsmParser *parser, OperationState *state) {
+  // TODO(antiagainst): support loop control properly
+  Builder builder = parser->getBuilder();
+  state->addAttribute("loop_control",
+                      builder.getI32IntegerAttr(
+                          static_cast<uint32_t>(spirv::LoopControl::None)));
+
+  return parser->parseRegion(*state->addRegion(), /*arguments=*/{},
+                             /*argTypes=*/{});
+}
+
+static void print(spirv::LoopOp loopOp, OpAsmPrinter *printer) {
+  auto *op = loopOp.getOperation();
+
+  *printer << spirv::LoopOp::getOperationName();
+  printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
+                       /*printBlockTerminators=*/true);
+}
+
+/// Returns true if the given `block` only contains one `spv._merge` op.
+static inline bool isMergeBlock(Block &block) {
+  return std::next(block.begin()) == block.end() &&
+         isa<spirv::MergeOp>(block.front());
+}
+
+/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
+/// given `dstBlock`.
+static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
+  // Check that there is only one op in the `srcBlock`.
+  if (std::next(srcBlock.begin()) != srcBlock.end())
+    return false;
+
+  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
+  return branchOp && branchOp.getSuccessor(0) == &dstBlock;
+}
+
+static LogicalResult verify(spirv::LoopOp loopOp) {
+  auto *op = loopOp.getOperation();
+
+  // We need to verify that the blocks follow the following layout:
+  //
+  //                     +-------------+
+  //                     | entry block |
+  //                     +-------------+
+  //                            |
+  //                            v
+  //                     +-------------+
+  //                     | loop header | <-----+
+  //                     +-------------+       |
+  //                                           |
+  //                           ...             |
+  //                          \ | /            |
+  //                            v              |
+  //                    +---------------+      |
+  //                    | loop continue | -----+
+  //                    +---------------+
+  //
+  //                           ...
+  //                          \ | /
+  //                            v
+  //                     +-------------+
+  //                     | merge block |
+  //                     +-------------+
+
+  auto &region = op->getRegion(0);
+  // Allow empty region as a degenerated case, which can come from
+  // optimizations.
+  if (region.empty())
+    return success();
+
+  // The last block is the merge block.
+  Block &merge = region.back();
+  if (!isMergeBlock(merge))
+    return loopOp.emitOpError(
+        "last block must be the merge block with only one 'spv._merge' op");
+
+  if (std::next(region.begin()) == region.end())
+    return loopOp.emitOpError(
+        "must have an entry block branching to the loop header block");
+  // The first block is the entry block.
+  Block &entry = region.front();
+
+  if (std::next(region.begin(), 2) == region.end())
+    return loopOp.emitOpError(
+        "must have a loop header block branched from the entry block");
+  // The second block is the loop header block.
+  Block &header = *std::next(region.begin(), 1);
+
+  if (!hasOneBranchOpTo(entry, header))
+    return loopOp.emitOpError(
+        "entry block must only have one 'spv.Branch' op to the second block");
+
+  if (std::next(region.begin(), 3) == region.end())
+    return loopOp.emitOpError(
+        "requires a loop continue block branching to the loop header block");
+  // The second to last block is the loop continue block.
+  Block &cont = *std::prev(region.end(), 2);
+
+  // Make sure that we have a branch from the loop continue block to the loop
+  // header block.
+  if (llvm::none_of(
+          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
+          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
+    return loopOp.emitOpError("second to last block must be the loop continue "
+                              "block that branches to the loop header block");
+
+  // Make sure that no other blocks (except the entry and loop continue block)
+  // branches to the loop header block.
+  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
+                                      std::prev(region.end(), 2))) {
+    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
+      if (block.getSuccessor(i) == &header) {
+        return loopOp.emitOpError("can only have the entry and loop continue "
+                                  "block branching to the loop header block");
+      }
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv._merge
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(spirv::MergeOp mergeOp) {
+  Block &parentLastBlock = mergeOp.getParentRegion()->back();
+  if (mergeOp.getOperation() != parentLastBlock.getTerminator())
+    return mergeOp.emitOpError(
+        "can only be used in the last block of 'spv.loop'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.module
 //===----------------------------------------------------------------------===//