[KernelGen][JIT] Return a dummy memref in case JIT compilation fails
Return a dummy memref to prevent the process from crashing. This way, TF can
propagate the error to the user.
PiperOrigin-RevId: 412011705
Change-Id: I7549d39e8b6bb951f00cbea57d64f9c887f833dc
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
index eed417b..4f4f665 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
@@ -15,6 +15,7 @@
#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h"
+#include <cstddef>
#include <string>
#include <utility>
@@ -286,7 +287,18 @@
void* result, int64_t num_args,
void* args_ptr) {
// JIT compilation must have failed earlier if there is no callable ptr.
- if (callable == nullptr) return;
+ // Return some empty memory descriptor to prevent a crash.
+ if (callable == nullptr) {
+ auto* desc = static_cast<::UnrankedMemRefType<void>*>(result);
+ desc->rank = 0;
+ auto* inner_desc = static_cast<StridedMemRefType<int8_t, 0>*>(
+ malloc(sizeof(StridedMemRefType<int8_t, 0>)));
+ inner_desc->basePtr = nullptr;
+ inner_desc->data = nullptr;
+ inner_desc->offset = 0;
+ desc->descriptor = inner_desc;
+ return;
+ }
// Build the argument array according to `ExecutionEngine`'s calling
// convention.
diff --git a/tensorflow/core/kernels/mlir_generated/base_op.h b/tensorflow/core/kernels/mlir_generated/base_op.h
index 8539a71..c7e9254 100644
--- a/tensorflow/core/kernels/mlir_generated/base_op.h
+++ b/tensorflow/core/kernels/mlir_generated/base_op.h
@@ -92,7 +92,7 @@
args.push_back(ConvertTensorToDescriptor(ctx->input(i), buffers[i]));
}
- auto result_desc = Invoke(ctx, args);
+ UnrankedMemRef result_desc = Invoke(ctx, args);
if (!ctx->status().ok()) {
free(result_desc.descriptor);
return;