Make type and rank explicit in mcuMemHostRegister function.

Fix registered size of indirect MemRefType kernel arguments.

PiperOrigin-RevId: 281362940
Change-Id: I99c3fbbc4cfc22129be7e24b8dcdef458c1ad996
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index d933242..9d8c894 100644
--- a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -49,7 +49,7 @@
 static constexpr const char *cuLaunchKernelName = "mcuLaunchKernel";
 static constexpr const char *cuGetStreamHelperName = "mcuGetStreamHelper";
 static constexpr const char *cuStreamSynchronizeName = "mcuStreamSynchronize";
-static constexpr const char *kMcuMemHostRegisterPtr = "mcuMemHostRegisterPtr";
+static constexpr const char *kMcuMemHostRegister = "mcuMemHostRegister";
 
 static constexpr const char *kCubinAnnotation = "nvvm.cubin";
 static constexpr const char *kCubinStorageSuffix = "_cubin_cst";
@@ -228,13 +228,13 @@
                                       getPointerType() /* CUstream stream */,
                                       /*isVarArg=*/false));
   }
-  if (!module.lookupSymbol(kMcuMemHostRegisterPtr)) {
+  if (!module.lookupSymbol(kMcuMemHostRegister)) {
     builder.create<LLVM::LLVMFuncOp>(
-        loc, kMcuMemHostRegisterPtr,
+        loc, kMcuMemHostRegister,
         LLVM::LLVMType::getFunctionTy(getVoidType(),
                                       {
                                           getPointerType(), /* void *ptr */
-                                          getInt32Type()    /* int32 flags*/
+                                          getInt64Type()    /* int64 sizeBytes*/
                                       },
                                       /*isVarArg=*/false));
   }
@@ -277,12 +277,14 @@
     //   the descriptor pointer is registered via @mcuMemHostRegisterPtr
     if (llvmType.isStructTy()) {
       auto registerFunc =
-          getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegisterPtr);
-      auto zero = builder.create<LLVM::ConstantOp>(
-          loc, getInt32Type(), builder.getI32IntegerAttr(0));
+          getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
+      auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
+      auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
+                                             ArrayRef<Value *>{nullPtr, one});
+      auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
       builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                    builder.getSymbolRefAttr(registerFunc),
-                                   ArrayRef<Value *>{casted, zero});
+                                   ArrayRef<Value *>{casted, size});
       Value *memLocation = builder.create<LLVM::AllocaOp>(
           loc, getPointerPointerType(), one, /*alignment=*/1);
       builder.create<LLVM::StoreOp>(loc, casted, memLocation);
diff --git a/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 31b6f6f..ac77258 100644
--- a/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -21,8 +21,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <assert.h>
-#include <memory.h>
+#include <cassert>
+#include <numeric>
 
 #include "llvm/Support/raw_ostream.h"
 
@@ -80,6 +80,13 @@
 
 /// Helper functions for writing mlir example code
 
+// Registers a byte array with the CUDA runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void mcuMemHostRegister(void *ptr, uint64_t sizeBytes) {
+  reportErrorIfAny(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0),
+                   "MemHostRegister");
+}
+
 // A struct that corresponds to how MLIR represents memrefs.
 template <typename T, int N> struct MemRefType {
   T *basePtr;
@@ -89,23 +96,22 @@
   int64_t strides[N];
 };
 
-// Allows to register a pointer with the CUDA runtime. Helpful until
-// we have transfer functions implemented.
-extern "C" void mcuMemHostRegister(const MemRefType<float, 1> *arg,
-                                   int32_t flags) {
-  reportErrorIfAny(
-      cuMemHostRegister(arg->data, arg->sizes[0] * sizeof(float), flags),
-      "MemHostRegister");
-  for (int pos = 0; pos < arg->sizes[0]; pos++) {
-    arg->data[pos] = 1.23f;
-  }
+// Fills a MemRef with value and registers it with the CUDA runtime. Helpful
+// until transfer functions exist. The count is computed as int64_t, not int.
+template <typename T, int N>
+void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
+  auto count = std::accumulate(arg->sizes, arg->sizes + N, int64_t{1},
+                               std::multiplies<int64_t>());
+  std::fill_n(arg->data, count, value);
+  mcuMemHostRegister(arg->data, count * sizeof(T));
 }
-
-// Allows to register a pointer with the CUDA runtime. Helpful until
-// we have transfer functions implemented.
-extern "C" void mcuMemHostRegisterPtr(void *ptr, int32_t flags) {
-  reportErrorIfAny(cuMemHostRegister(ptr, sizeof(void *), flags),
-                   "MemHostRegister");
+extern "C" void
+mcuMemHostRegisterMemRef1dFloat(const MemRefType<float, 1> *arg) {
+  mcuMemHostRegisterMemRef(arg, 1.23f);
+}
+extern "C" void
+mcuMemHostRegisterMemRef3dFloat(const MemRefType<float, 3> *arg) {
+  mcuMemHostRegisterMemRef(arg, 1.23f);
 }
 
 /// Prints the given float array to stderr.