[spirv] Add spv.loop
SPIR-V can explicitly declare structured control-flow constructs using merge
instructions. These explicitly declare a header block before the control
flow diverges and a merge block where control flow subsequently converges.
These blocks delimit constructs that must nest, and can only be entered
and exited in structured ways.
Instead of having a `spv.LoopMerge` op to directly model loop merge
instruction for indicating the merge and continue target, we use regions
to delimit the boundary of the loop: the merge target is the next op
following the `spv.loop` op and the continue target is the block that
has a back-edge pointing to the entry block inside the `spv.loop`'s region.
This way it's easier to discover all blocks belonging to a construct and
it plays nicer with the MLIR system.
Updated the SPIR-V.md doc.
PiperOrigin-RevId: 267431010
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 7dea586..0accb05 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -144,6 +144,7 @@
def SPV_OC_OpFUnordLessThanEqual : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>;
def SPV_OC_OpFOrdGreaterThanEqual : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>;
def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>;
+def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>;
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
@@ -173,9 +174,9 @@
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
- SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpLabel,
- SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
- SPV_OC_OpReturnValue
+ SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
+ SPV_OC_OpLoopMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
+ SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@@ -924,6 +925,28 @@
let cppNamespace = "::mlir::spirv";
}
+def SPV_LC_None : I32EnumAttrCase<"None", 0x0000>;
+def SPV_LC_Unroll : I32EnumAttrCase<"Unroll", 0x0001>;
+def SPV_LC_DontUnroll : I32EnumAttrCase<"DontUnroll", 0x0002>;
+def SPV_LC_DependencyInfinite : I32EnumAttrCase<"DependencyInfinite", 0x0004>;
+def SPV_LC_DependencyLength : I32EnumAttrCase<"DependencyLength", 0x0008>;
+def SPV_LC_MinIterations : I32EnumAttrCase<"MinIterations", 0x0010>;
+def SPV_LC_MaxIterations : I32EnumAttrCase<"MaxIterations", 0x0020>;
+def SPV_LC_IterationMultiple : I32EnumAttrCase<"IterationMultiple", 0x0040>;
+def SPV_LC_PeelCount : I32EnumAttrCase<"PeelCount", 0x0080>;
+def SPV_LC_PartialCount : I32EnumAttrCase<"PartialCount", 0x0100>;
+
+def SPV_LoopControlAttr :
+ I32EnumAttr<"LoopControl", "valid SPIR-V LoopControl", [
+ SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite,
+ SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
+ SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount
+ ]> {
+ let returnType = "::mlir::spirv::LoopControl";
+ let convertFromStorage = "static_cast<::mlir::spirv::LoopControl>($_self.getInt())";
+ let cppNamespace = "::mlir::spirv";
+}
+
def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>;
def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>;
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 0927684..ffefa14 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -137,6 +137,71 @@
// -----
+def SPV_LoopOp : SPV_Op<"loop"> {
+ let summary = "Define a structured loop.";
+
+ let description = [{
+ SPIR-V can explicitly declare structured control-flow constructs using merge
+ instructions. These explicitly declare a header block before the control
+ flow diverges and a merge block where control flow subsequently converges.
+ These blocks delimit constructs that must nest, and can only be entered
+ and exited in structured ways. See "2.11. Structured Control Flow" of the
+ SPIR-V spec for more details.
+
+ Instead of having a `spv.LoopMerge` op to directly model loop merge
+ instruction for indicating the merge and continue target, we use regions
+ to delimit the boundary of the loop: the merge target is the next op
+ following the `spv.loop` op and the continue target is the block that
+ has a back-edge pointing to the entry block inside the `spv.loop`'s region.
+ This way it's easier to discover all blocks belonging to a construct and
+ it plays nicer with the MLIR system.
+
+ The `spv.loop` region should contain at least four blocks: one entry block,
+ one loop header block, one loop continue block, one loop merge block.
+ The entry block should be the first block and it should jump to the loop
+ header block, which is the second block. The loop merge block should be the
+ last block. The merge block should only contain a `spv._merge` op.
+ The continue block should be the second to last block and it should have a
+ branch to the loop header block. The loop continue block should be the only
+ block, except the entry block, branching to the header block.
+ }];
+
+ let arguments = (ins
+ SPV_LoopControlAttr:$loop_control
+ );
+
+ let results = (outs);
+
+ let regions = (region AnyRegion:$body);
+
+ let hasOpcode = 0;
+}
+
+// -----
+
+def SPV_MergeOp : SPV_Op<"_merge", [HasParent<"LoopOp">, Terminator]> {
+ let summary = "A special terminator for merging a structured selection/loop.";
+
+ let description = [{
+ We use `spv.selection`/`spv.loop` for modelling structured selection/loop.
+ This op is a terminator used inside their regions to mean jumping to the
+ merge point, which is the next op following the `spv.selection` or
+ `spv.loop` op. This op does not have a corresponding instruction in the
+ SPIR-V binary format; it's solely for structural purpose.
+ }];
+
+ let arguments = (ins);
+
+ let results = (outs);
+
+ let parser = [{ return parseNoIOOp(parser, result); }];
+ let printer = [{ printNoIOOp(getOperation(), p); }];
+
+ let hasOpcode = 0;
+}
+
+// -----
+
def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
let summary = "Return with no value from a function with void return type.";
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 3338c10..0873eb0 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1018,6 +1018,143 @@
}
//===----------------------------------------------------------------------===//
+// spv.loop
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLoopOp(OpAsmParser *parser, OperationState *state) {
+ // TODO(antiagainst): support loop control properly
+ Builder builder = parser->getBuilder();
+ state->addAttribute("loop_control",
+ builder.getI32IntegerAttr(
+ static_cast<uint32_t>(spirv::LoopControl::None)));
+
+ return parser->parseRegion(*state->addRegion(), /*arguments=*/{},
+ /*argTypes=*/{});
+}
+
+static void print(spirv::LoopOp loopOp, OpAsmPrinter *printer) {
+ auto *op = loopOp.getOperation();
+
+ *printer << spirv::LoopOp::getOperationName();
+ printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+}
+
+/// Returns true if the given `block` only contains one `spv._merge` op.
+static inline bool isMergeBlock(Block &block) {
+ return std::next(block.begin()) == block.end() &&
+ isa<spirv::MergeOp>(block.front());
+}
+
+/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
+/// given `dstBlock`.
+static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
+ // Check that there is only one op in the `srcBlock`.
+ if (std::next(srcBlock.begin()) != srcBlock.end())
+ return false;
+
+ auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
+ return branchOp && branchOp.getSuccessor(0) == &dstBlock;
+}
+
+static LogicalResult verify(spirv::LoopOp loopOp) {
+ auto *op = loopOp.getOperation();
+
+ // We need to verify that the blocks follow the following layout:
+ //
+ // +-------------+
+ // | entry block |
+ // +-------------+
+ // |
+ // v
+ // +-------------+
+ // | loop header | <-----+
+ // +-------------+ |
+ // |
+ // ... |
+ // \ | / |
+ // v |
+ // +---------------+ |
+ // | loop continue | -----+
+ // +---------------+
+ //
+ // ...
+ // \ | /
+ // v
+ // +-------------+
+ // | merge block |
+ // +-------------+
+
+ auto ®ion = op->getRegion(0);
+ // Allow empty region as a degenerated case, which can come from
+ // optimizations.
+ if (region.empty())
+ return success();
+
+ // The last block is the merge block.
+ Block &merge = region.back();
+ if (!isMergeBlock(merge))
+ return loopOp.emitOpError(
+ "last block must be the merge block with only one 'spv._merge' op");
+
+ if (std::next(region.begin()) == region.end())
+ return loopOp.emitOpError(
+ "must have an entry block branching to the loop header block");
+ // The first block is the entry block.
+ Block &entry = region.front();
+
+ if (std::next(region.begin(), 2) == region.end())
+ return loopOp.emitOpError(
+ "must have a loop header block branched from the entry block");
+ // The second block is the loop header block.
+ Block &header = *std::next(region.begin(), 1);
+
+ if (!hasOneBranchOpTo(entry, header))
+ return loopOp.emitOpError(
+ "entry block must only have one 'spv.Branch' op to the second block");
+
+ if (std::next(region.begin(), 3) == region.end())
+ return loopOp.emitOpError(
+ "requires a loop continue block branching to the loop header block");
+ // The second to last block is the loop continue block.
+ Block &cont = *std::prev(region.end(), 2);
+
+ // Make sure that we have a branch from the loop continue block to the loop
+ // header block.
+ if (llvm::none_of(
+ llvm::seq<unsigned>(0, cont.getNumSuccessors()),
+ [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
+ return loopOp.emitOpError("second to last block must be the loop continue "
+ "block that branches to the loop header block");
+
+ // Make sure that no other blocks (except the entry and loop continue block)
+ // branches to the loop header block.
+ for (auto &block : llvm::make_range(std::next(region.begin(), 2),
+ std::prev(region.end(), 2))) {
+ for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
+ if (block.getSuccessor(i) == &header) {
+ return loopOp.emitOpError("can only have the entry and loop continue "
+ "block branching to the loop header block");
+ }
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv._merge
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(spirv::MergeOp mergeOp) {
+ Block &parentLastBlock = mergeOp.getParentRegion()->back();
+ if (mergeOp.getOperation() != parentLastBlock.getTerminator())
+ return mergeOp.emitOpError(
+ "can only be used in the last block of 'spv.loop'");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.module
//===----------------------------------------------------------------------===//