Make type and rank explicit in mcuMemHostRegister function.
Fix registered size of indirect MemRefType kernel arguments.
PiperOrigin-RevId: 281362940
Change-Id: I99c3fbbc4cfc22129be7e24b8dcdef458c1ad996
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index d933242..9d8c894 100644
--- a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -49,7 +49,7 @@
static constexpr const char *cuLaunchKernelName = "mcuLaunchKernel";
static constexpr const char *cuGetStreamHelperName = "mcuGetStreamHelper";
static constexpr const char *cuStreamSynchronizeName = "mcuStreamSynchronize";
-static constexpr const char *kMcuMemHostRegisterPtr = "mcuMemHostRegisterPtr";
+static constexpr const char *kMcuMemHostRegister = "mcuMemHostRegister";
static constexpr const char *kCubinAnnotation = "nvvm.cubin";
static constexpr const char *kCubinStorageSuffix = "_cubin_cst";
@@ -228,13 +228,13 @@
getPointerType() /* CUstream stream */,
/*isVarArg=*/false));
}
- if (!module.lookupSymbol(kMcuMemHostRegisterPtr)) {
+ if (!module.lookupSymbol(kMcuMemHostRegister)) {
builder.create<LLVM::LLVMFuncOp>(
- loc, kMcuMemHostRegisterPtr,
+ loc, kMcuMemHostRegister,
LLVM::LLVMType::getFunctionTy(getVoidType(),
{
getPointerType(), /* void *ptr */
- getInt32Type() /* int32 flags*/
+ getInt64Type() /* int64 sizeBytes*/
},
/*isVarArg=*/false));
}
@@ -277,12 +277,14 @@
// the descriptor pointer is registered via @mcuMemHostRegisterPtr
if (llvmType.isStructTy()) {
auto registerFunc =
- getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegisterPtr);
- auto zero = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), builder.getI32IntegerAttr(0));
+ getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
+ auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
+ auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
+ ArrayRef<Value *>{nullPtr, one});
+ auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
builder.getSymbolRefAttr(registerFunc),
- ArrayRef<Value *>{casted, zero});
+ ArrayRef<Value *>{casted, size});
Value *memLocation = builder.create<LLVM::AllocaOp>(
loc, getPointerPointerType(), one, /*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, casted, memLocation);
diff --git a/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 31b6f6f..ac77258 100644
--- a/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -21,8 +21,8 @@
//
//===----------------------------------------------------------------------===//
-#include <assert.h>
-#include <memory.h>
+#include <cassert>
+#include <numeric>
#include "llvm/Support/raw_ostream.h"
@@ -80,6 +80,13 @@
/// Helper functions for writing mlir example code
+// Allows to register byte array with the CUDA runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void mcuMemHostRegister(void *ptr, uint64_t sizeBytes) {
+ reportErrorIfAny(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0),
+ "MemHostRegister");
+}
+
// A struct that corresponds to how MLIR represents memrefs.
template <typename T, int N> struct MemRefType {
T *basePtr;
@@ -89,23 +96,22 @@
int64_t strides[N];
};
-// Allows to register a pointer with the CUDA runtime. Helpful until
-// we have transfer functions implemented.
-extern "C" void mcuMemHostRegister(const MemRefType<float, 1> *arg,
- int32_t flags) {
- reportErrorIfAny(
- cuMemHostRegister(arg->data, arg->sizes[0] * sizeof(float), flags),
- "MemHostRegister");
- for (int pos = 0; pos < arg->sizes[0]; pos++) {
- arg->data[pos] = 1.23f;
- }
+// Allows to register a MemRef with the CUDA runtime. Initializes array with
+// value. Helpful until we have transfer functions implemented.
+template <typename T, int N>
+void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
+ auto count = std::accumulate(arg->sizes, arg->sizes + N, 1,
+ std::multiplies<int64_t>());
+ std::fill_n(arg->data, count, value);
+ mcuMemHostRegister(arg->data, count * sizeof(T));
}
-
-// Allows to register a pointer with the CUDA runtime. Helpful until
-// we have transfer functions implemented.
-extern "C" void mcuMemHostRegisterPtr(void *ptr, int32_t flags) {
- reportErrorIfAny(cuMemHostRegister(ptr, sizeof(void *), flags),
- "MemHostRegister");
+extern "C" void
+mcuMemHostRegisterMemRef1dFloat(const MemRefType<float, 1> *arg) {
+ mcuMemHostRegisterMemRef(arg, 1.23f);
+}
+extern "C" void
+mcuMemHostRegisterMemRef3dFloat(const MemRefType<float, 3> *arg) {
+ mcuMemHostRegisterMemRef(arg, 1.23f);
}
/// Prints the given float array to stderr.