[spirv] Implement inliner interface

We just need to implement a few interface hooks to DialectInlinerInterface
and CallOpInterface to gain the benefits of an inliner. :)

Right now only supports some trivial cases:
* Inlining single block with spv.Return/spv.ReturnValue
* Inlining multi block with spv.Return
* Inlining spv.selection/spv.loop without return ops

More advanced cases will require block argument and Phi support.

PiperOrigin-RevId: 275151132
diff --git a/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index c442ee1..ae4b2be 100644
--- a/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -29,6 +29,11 @@
 include "mlir/SPIRV/SPIRVBase.td"
 #endif // SPIRV_BASE
 
+#ifdef MLIR_CALLINTERFACES
+#else
+include "mlir/Analysis/CallInterfaces.td"
+#endif // MLIR_CALLINTERFACES
+
 // -----
 
 def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
@@ -151,7 +156,8 @@
 
 // -----
 
-def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> {
+def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [
+    InFunctionScope, DeclareOpInterfaceMethods<CallOpInterface>]> {
   let summary = "Call a function.";
 
   let description = [{
diff --git a/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index ba7c61b..d9e3787 100644
--- a/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -264,7 +264,8 @@
 }
 
 def SPV_ModuleOp : SPV_Op<"module",
-                          [SingleBlockImplicitTerminator<"ModuleEndOp">,
+                          [IsolatedFromAbove,
+                           SingleBlockImplicitTerminator<"ModuleEndOp">,
                            NativeOpTrait<"SymbolTable">]> {
   let summary = "The top-level op that defines a SPIR-V module";
 
diff --git a/lib/Dialect/SPIRV/SPIRVDialect.cpp b/lib/Dialect/SPIRV/SPIRVDialect.cpp
index af50c8e..96777b1 100644
--- a/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/StringExtras.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
@@ -35,6 +36,67 @@
 using namespace mlir::spirv;
 
 //===----------------------------------------------------------------------===//
+// InlinerInterface
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
+static inline bool containsReturn(Region &region) {
+  return llvm::any_of(region, [](Block &block) {
+    Operation *terminator = block.getTerminator();
+    return isa<spirv::ReturnOp>(terminator) ||
+           isa<spirv::ReturnValueOp>(terminator);
+  });
+}
+
+namespace {
+/// This class defines the interface for inlining within the SPIR-V dialect.
+struct SPIRVInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  /// Returns true if the given region 'src' can be inlined into the region
+  /// 'dest' that is attached to an operation registered to the current dialect.
+  bool isLegalToInline(Operation *op, Region *dest,
+                       BlockAndValueMapping &) const final {
+    // TODO(antiagainst): Enable inlining structured control flows with return.
+    if ((isa<spirv::SelectionOp>(op) || isa<spirv::LoopOp>(op)) &&
+        containsReturn(op->getRegion(0)))
+      return false;
+    // TODO(antiagainst): we need to filter OpKill here to avoid inlining it to
+    // a loop continue construct:
+    // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
+    // However OpKill is fragment shader specific and we don't support it yet.
+    return true;
+  }
+
+  /// Handle the given inlined terminator by replacing it with a new operation
+  /// as necessary.
+  void handleTerminator(Operation *op, Block *newDest) const final {
+    if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
+      OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
+      op->erase();
+    } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
+      llvm_unreachable("unimplemented spv.ReturnValue in inliner");
+    }
+  }
+
+  /// Handle the given inlined terminator by replacing it with a new operation
+  /// as necessary.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value *> valuesToRepl) const final {
+    // Only spv.ReturnValue needs to be handled here.
+    auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
+    if (!retValOp)
+      return;
+
+    // Replace the values directly with the return operands.
+    assert(valuesToRepl.size() == 1 &&
+           "spv.ReturnValue expected to only handle one result");
+    valuesToRepl.front()->replaceAllUsesWith(retValOp.value());
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
@@ -48,6 +110,8 @@
 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
       >();
 
+  addInterfaces<SPIRVInlinerInterface>();
+
   // Allow unknown operations because SPIR-V is extensible.
   allowUnknownOperations();
 }
diff --git a/lib/Dialect/SPIRV/SPIRVOps.cpp b/lib/Dialect/SPIRV/SPIRVOps.cpp
index 9662e05..7e51cc0 100644
--- a/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -21,6 +21,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 
+#include "mlir/Analysis/CallInterfaces.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
@@ -1199,6 +1200,14 @@
   return success();
 }
 
+CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
+  return getAttrOfType<SymbolRefAttr>(kCallee);
+}
+
+Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
+  return arguments();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.globalVariable
 //===----------------------------------------------------------------------===//
diff --git a/test/Dialect/SPIRV/Transforms/inlining.mlir b/test/Dialect/SPIRV/Transforms/inlining.mlir
new file mode 100644
index 0000000..9837d7b
--- /dev/null
+++ b/test/Dialect/SPIRV/Transforms/inlining.mlir
@@ -0,0 +1,182 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline)' -mlir-disable-inline-simplify | FileCheck %s
+
+spv.module "Logical" "GLSL450" {
+  func @callee() {
+    spv.Return
+  }
+
+  // CHECK-LABEL: func @calling_single_block_ret_func
+  func @calling_single_block_ret_func() {
+    // CHECK-NEXT: spv.Return
+    spv.FunctionCall @callee() : () -> ()
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  func @callee() -> i32 {
+    %0 = spv.constant 42 : i32
+    spv.ReturnValue %0 : i32
+  }
+
+  // CHECK-LABEL: func @calling_single_block_retval_func
+  func @calling_single_block_retval_func() -> i32 {
+    // CHECK-NEXT: %[[CST:.*]] = spv.constant 42
+    %0 = spv.FunctionCall @callee() : () -> (i32)
+    // CHECK-NEXT: spv.ReturnValue %[[CST]]
+    spv.ReturnValue %0 : i32
+  }
+}
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  spv.globalVariable @data bind(0, 0) : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
+  func @callee() {
+    %0 = spv._address_of @data : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
+    %1 = spv.constant 0: i32
+    %2 = spv.AccessChain %0[%1, %1] : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
+    spv.Branch ^next
+
+  ^next:
+    %3 = spv.constant 42: i32
+    spv.Store "StorageBuffer" %2, %3 : i32
+    spv.Return
+  }
+
+  // CHECK-LABEL: func @calling_multi_block_ret_func
+  func @calling_multi_block_ret_func() {
+    // CHECK-NEXT:   spv._address_of
+    // CHECK-NEXT:   spv.constant 0
+    // CHECK-NEXT:   spv.AccessChain
+    // CHECK-NEXT:   spv.Branch ^bb1
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT:   spv.constant
+    // CHECK-NEXT:   spv.Store
+    // CHECK-NEXT:   spv.Branch ^bb2
+    spv.FunctionCall @callee() : () -> ()
+    // CHECK-NEXT: ^bb2:
+    // CHECK-NEXT:   spv.Return
+    spv.Return
+  }
+}
+
+// TODO: calling_multi_block_retval_func
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  func @callee(%cond : i1) -> () {
+    spv.selection {
+      spv.BranchConditional %cond, ^then, ^merge
+    ^then:
+      spv.Return
+    ^merge:
+      spv._merge
+    }
+    spv.Return
+  }
+
+  // CHECK-LABEL: calling_selection_ret_func
+  func @calling_selection_ret_func() {
+    %0 = spv.constant true
+    // CHECK: spv.FunctionCall
+    spv.FunctionCall @callee(%0) : (i1) -> ()
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  func @callee(%cond : i1) -> () {
+    spv.selection {
+      spv.BranchConditional %cond, ^then, ^merge
+    ^then:
+      spv.Branch ^merge
+    ^merge:
+      spv._merge
+    }
+    spv.Return
+  }
+
+  // CHECK-LABEL: calling_selection_no_ret_func
+  func @calling_selection_no_ret_func() {
+    // CHECK-NEXT: %[[TRUE:.*]] = spv.constant true
+    %0 = spv.constant true
+    // CHECK-NEXT: spv.selection
+    // CHECK-NEXT:   spv.BranchConditional %[[TRUE]], ^bb1, ^bb2
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT:   spv.Branch ^bb2
+    // CHECK-NEXT: ^bb2:
+    // CHECK-NEXT:   spv._merge
+    spv.FunctionCall @callee(%0) : (i1) -> ()
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  func @callee(%cond : i1) -> () {
+    spv.loop {
+      spv.Branch ^header
+    ^header:
+      spv.BranchConditional %cond, ^body, ^merge
+    ^body:
+      spv.Return
+    ^continue:
+      spv.Branch ^header
+    ^merge:
+      spv._merge
+    }
+    spv.Return
+  }
+
+  // CHECK-LABEL: calling_loop_ret_func
+  func @calling_loop_ret_func() {
+    %0 = spv.constant true
+    // CHECK: spv.FunctionCall
+    spv.FunctionCall @callee(%0) : (i1) -> ()
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  func @callee(%cond : i1) -> () {
+    spv.loop {
+      spv.Branch ^header
+    ^header:
+      spv.BranchConditional %cond, ^body, ^merge
+    ^body:
+      spv.Branch ^continue
+    ^continue:
+      spv.Branch ^header
+    ^merge:
+      spv._merge
+    }
+    spv.Return
+  }
+
+  // CHECK-LABEL: calling_loop_no_ret_func
+  func @calling_loop_no_ret_func() {
+    // CHECK-NEXT: %[[TRUE:.*]] = spv.constant true
+    %0 = spv.constant true
+    // CHECK-NEXT: spv.loop
+    // CHECK-NEXT:   spv.Branch ^bb1
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT:   spv.BranchConditional %[[TRUE]], ^bb2, ^bb4
+    // CHECK-NEXT: ^bb2:
+    // CHECK-NEXT:   spv.Branch ^bb3
+    // CHECK-NEXT: ^bb3:
+    // CHECK-NEXT:   spv.Branch ^bb1
+    // CHECK-NEXT: ^bb4:
+    // CHECK-NEXT:   spv._merge
+    spv.FunctionCall @callee(%0) : (i1) -> ()
+    spv.Return
+  }
+}