Clone called functions into nested GPU module.
PiperOrigin-RevId: 270891190
diff --git a/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 9bf4cf6..f38a2e8 100644
--- a/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -144,13 +144,10 @@
public:
void runOnModule() override {
ModuleManager moduleManager(getModule());
- auto context = getModule().getContext();
- Builder builder(context);
for (auto func : getModule().getOps<FuncOp>()) {
// Insert just after the function.
Block::iterator insertPt(func.getOperation()->getNextNode());
func.walk([&](gpu::LaunchOp op) {
- // TODO(b/141098412): Handle called functions and globals.
FuncOp outlinedFunc = outlineKernelFunc(op);
// Potentially renames outlinedFunc to make symbol unique.
@@ -164,14 +161,41 @@
kernelFunc.getBody().takeBody(outlinedFunc.getBody());
// Create nested module and insert kernelFunc.
- auto kernelModule = ModuleOp::create(UnknownLoc::get(context));
- kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
- builder.getUnitAttr());
- kernelModule.push_back(kernelFunc);
+ auto kernelModule = createKernelModule(kernelFunc, moduleManager);
getModule().insert(insertPt, kernelModule);
});
}
}
+
+private:
+  // Returns a kernel module containing kernelFunc and a clone of every
+  // function it transitively calls, looked up in parentModuleManager.
+  ModuleOp createKernelModule(FuncOp kernelFunc,
+                              const ModuleManager &parentModuleManager) {
+    auto context = getModule().getContext();
+    auto kernelModule = ModuleOp::create(UnknownLoc::get(context));
+    kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
+                         UnitAttr::get(context));
+    ModuleManager moduleManager(kernelModule);
+    moduleManager.insert(kernelFunc);
+    // Each clone is inserted as soon as it is created, so a callee that is
+    // called more than once is found by the lookup below and cloned only once.
+    // TODO(b/141098412): Support any op with a callable interface.
+    llvm::SmallVector<FuncOp, 8> worklist = {kernelFunc};
+    while (!worklist.empty()) {
+      FuncOp func = worklist.pop_back_val();
+      func.walk([&](CallOp call) {
+        auto callee = call.callee();
+        if (moduleManager.lookupSymbol<FuncOp>(callee))
+          return;
+        // NOTE(review): assumes callee resolves in the parent module.
+        auto clone = parentModuleManager.lookupSymbol<FuncOp>(callee).clone();
+        moduleManager.insert(clone);
+        worklist.push_back(clone);
+      });
+    }
+    return kernelModule;
+  }
};
} // namespace
diff --git a/test/Dialect/GPU/outlining.mlir b/test/Dialect/GPU/outlining.mlir
index fdfe8d0..5f31486 100644
--- a/test/Dialect/GPU/outlining.mlir
+++ b/test/Dialect/GPU/outlining.mlir
@@ -113,8 +113,7 @@
%grid_z = %cst)
threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
%block_z = %cst) {
- // TODO(b/141098412): Support function calls.
- // expected-error @+1 {{'device_function' does not reference a valid function}}
+ call @device_function() : () -> ()
call @device_function() : () -> ()
gpu.return
}
@@ -122,5 +121,16 @@
}
func @device_function() {
+ call @recursive_device_function() : () -> ()
gpu.return
}
+
+func @recursive_device_function() {
+ call @recursive_device_function() : () -> ()
+ gpu.return
+}
+
+// CHECK: @device_function
+// CHECK: @recursive_device_function
+// CHECK: @device_function
+// CHECK: @recursive_device_function