Clone called functions into nested GPU module.

PiperOrigin-RevId: 270891190
diff --git a/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 9bf4cf6..f38a2e8 100644
--- a/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -144,13 +144,10 @@
 public:
   void runOnModule() override {
     ModuleManager moduleManager(getModule());
-    auto context = getModule().getContext();
-    Builder builder(context);
     for (auto func : getModule().getOps<FuncOp>()) {
       // Insert just after the function.
       Block::iterator insertPt(func.getOperation()->getNextNode());
       func.walk([&](gpu::LaunchOp op) {
-        // TODO(b/141098412): Handle called functions and globals.
         FuncOp outlinedFunc = outlineKernelFunc(op);
 
         // Potentially renames outlinedFunc to make symbol unique.
@@ -164,14 +161,41 @@
         kernelFunc.getBody().takeBody(outlinedFunc.getBody());
 
         // Create nested module and insert kernelFunc.
-        auto kernelModule = ModuleOp::create(UnknownLoc::get(context));
-        kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
-                             builder.getUnitAttr());
-        kernelModule.push_back(kernelFunc);
+        auto kernelModule = createKernelModule(kernelFunc, moduleManager);
         getModule().insert(insertPt, kernelModule);
       });
     }
   }
+
+private:
+  // Returns a module containing kernelFunc and clones of every function it
+  // calls, directly or transitively; callees come from parentModuleManager.
+  ModuleOp createKernelModule(FuncOp kernelFunc,
+                              const ModuleManager &parentModuleManager) {
+    auto context = getModule().getContext();
+    auto kernelModule = ModuleOp::create(UnknownLoc::get(context));
+    kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
+                         UnitAttr::get(context));
+    ModuleManager moduleManager(kernelModule);
+    moduleManager.insert(kernelFunc);
+
+    llvm::SmallVector<FuncOp, 8> worklist = {kernelFunc};
+    while (!worklist.empty()) {
+      FuncOp func = worklist.pop_back_val();
+      // TODO(b/141098412): Support any op with a callable interface.
+      func.walk([&](CallOp call) {
+        auto callee = call.callee();
+        if (moduleManager.lookupSymbol<FuncOp>(callee))
+          return;
+        // Insert the clone right away: a later call to the same callee then
+        // finds it here instead of queueing a second, redundant clone.
+        FuncOp clone = parentModuleManager.lookupSymbol<FuncOp>(callee).clone();
+        moduleManager.insert(clone);
+        worklist.push_back(clone);
+      });
+    }
+    return kernelModule;
+  }
 };
 
 } // namespace
diff --git a/test/Dialect/GPU/outlining.mlir b/test/Dialect/GPU/outlining.mlir
index fdfe8d0..5f31486 100644
--- a/test/Dialect/GPU/outlining.mlir
+++ b/test/Dialect/GPU/outlining.mlir
@@ -113,8 +113,7 @@
                                        %grid_z = %cst)
              threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
                                         %block_z = %cst) {
-    // TODO(b/141098412): Support function calls.
-    // expected-error @+1 {{'device_function' does not reference a valid function}}
+    call @device_function() : () -> ()
     call @device_function() : () -> ()
     gpu.return
   }
@@ -122,5 +121,16 @@
 }
 
 func @device_function() {
+  call @recursive_device_function() : () -> ()  // Makes @recursive_device_function an indirect callee of the launch body above.
   gpu.return
 }
+
+func @recursive_device_function() {  // In the lines visible here, called by @device_function and by itself.
+  call @recursive_device_function() : () -> ()  // Self-call, i.e. direct recursion.
+  gpu.return
+}
+
+// CHECK: @device_function
+// CHECK: @recursive_device_function
+// CHECK: @device_function
+// CHECK: @recursive_device_function