[tfrt:tf] Use Tensorflow profiler to trace tf_cpurt kernels
TFRT_TRACE traces are not connected to the request context, so switch to the TensorFlow profiler's TraceMe, encoding the request id into the trace.
PiperOrigin-RevId: 389907697
Change-Id: Idda571c0e0077c09da519a6aece464e2eb9474e1
diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD
index fd64b15..e86bacf 100644
--- a/tensorflow/compiler/mlir/tfrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/BUILD
@@ -204,6 +204,7 @@
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:framework",
"//tensorflow/core/platform:dynamic_annotations",
+ "//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
"//tensorflow/core/tfrt/utils:fallback_tensor",
"@llvm-project//mlir:Async",
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_kernels.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_kernels.cc
index 235affd..3593047 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_kernels.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_kernels.cc
@@ -27,6 +27,7 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/cpu/jit/async_runtime.h" // from @tf_runtime
@@ -46,7 +47,6 @@
#include "tfrt/support/string_util.h" // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h" // from @tf_runtime
#include "tfrt/tensor/tensor_shape.h" // from @tf_runtime
-#include "tfrt/tracing/tracing.h" // from @tf_runtime
namespace tensorflow {
namespace tfrt {
@@ -266,7 +266,12 @@
RepeatedArguments<FallbackTensor> operands,
RemainingResults results,
const ExecutionContext& exec_ctx) {
- TFRT_TRACE_SCOPE(Default, StrCat("tf_cpurt.Execute: @", executable.name()));
+ // Bind execution trace to the request context.
+ profiler::TraceMe trace_me([&] {
+ return profiler::TraceMeEncode("tf_cpurt.Execute",
+ {{"id", exec_ctx.request_ctx()->id()},
+ {"executable", executable.name()}});
+ });
// Keep track of memory address to tensor mapping for result conversion.
auto ctx = std::make_unique<TensorflowConversionContext>(operands.size());