[KernelGen][JIT] Return a dummy memref in case JIT compilation fails

Return a dummy memref to prevent the process from crashing. This way, TF can
propagate the error to the user.

PiperOrigin-RevId: 412011705
Change-Id: I7549d39e8b6bb951f00cbea57d64f9c887f833dc
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
index eed417b..4f4f665 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h"
 
+#include <cstddef>
 #include <string>
 #include <utility>
 
@@ -286,7 +287,18 @@
                                             void* result, int64_t num_args,
                                             void* args_ptr) {
   // JIT compilation must have failed earlier if there is no callable ptr.
-  if (callable == nullptr) return;
+  // Return some empty memory descriptor to prevent a crash.
+  if (callable == nullptr) {
+    auto* desc = static_cast<::UnrankedMemRefType<void>*>(result);
+    desc->rank = 0;
+    auto* inner_desc = static_cast<StridedMemRefType<int8_t, 0>*>(
+        malloc(sizeof(StridedMemRefType<int8_t, 0>)));
+    inner_desc->basePtr = nullptr;
+    inner_desc->data = nullptr;
+    inner_desc->offset = 0;
+    desc->descriptor = inner_desc;
+    return;
+  }
 
   // Build the argument array according to `ExecutionEngine`'s calling
   // convention.
diff --git a/tensorflow/core/kernels/mlir_generated/base_op.h b/tensorflow/core/kernels/mlir_generated/base_op.h
index 8539a71..c7e9254 100644
--- a/tensorflow/core/kernels/mlir_generated/base_op.h
+++ b/tensorflow/core/kernels/mlir_generated/base_op.h
@@ -92,7 +92,7 @@
       args.push_back(ConvertTensorToDescriptor(ctx->input(i), buffers[i]));
     }
 
-    auto result_desc = Invoke(ctx, args);
+    UnrankedMemRef result_desc = Invoke(ctx, args);
     if (!ctx->status().ok()) {
       free(result_desc.descriptor);
       return;