[XLA:Python] Add ProfilerSession bindings.
This allows for programmatically collecting a profiler trace that can be viewed in TensorBoard.
TESTING: I manually collected a JAX trace in a Cloud TPU VM. I'll add unit tests in a subsequent JAX change.
PiperOrigin-RevId: 362628356
Change-Id: I7e00bf0504572a364b8f91c01c3dda249ce21233
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 09ffb3f..3d45e3a 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -450,9 +450,14 @@
],
features = ["-use_header_modules"],
deps = [
+ ":types",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/core/platform:errors",
+ "//tensorflow/core/platform:platform_port",
"//tensorflow/core/profiler/lib:profiler_backends",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/rpc:profiler_server_impl",
+ "//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/python/profiler/internal:traceme_wrapper",
"@pybind11",
],
diff --git a/tensorflow/compiler/xla/python/profiler.cc b/tensorflow/compiler/xla/python/profiler.cc
index 1a4c588..f57f86d 100644
--- a/tensorflow/compiler/xla/python/profiler.cc
+++ b/tensorflow/compiler/xla/python/profiler.cc
@@ -16,6 +16,12 @@
#include "tensorflow/compiler/xla/python/profiler.h"
#include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/python/types.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
#include "tensorflow/python/profiler/internal/traceme_wrapper.h"
@@ -47,6 +53,24 @@
},
py::arg("port"));
+ py::class_<tensorflow::ProfilerSession> profiler_session_class(
+ profiler, "ProfilerSession");
+ profiler_session_class
+ .def(py::init([]() {
+ return tensorflow::ProfilerSession::Create(
+ tensorflow::ProfilerSession::DefaultOptions());
+ }))
+ .def("stop_and_export",
+ [](tensorflow::ProfilerSession* sess,
+ const std::string& tensorboard_dir) -> xla::Status {
+ tensorflow::profiler::XSpace xspace;
+ // Disables the ProfilerSession
+ TF_RETURN_IF_ERROR(sess->CollectData(&xspace));
+ xspace.add_hostnames(tensorflow::port::Hostname());
+ return tensorflow::profiler::ExportToTensorBoard(xspace,
+ tensorboard_dir);
+ });
+
py::class_<TraceMeWrapper> traceme_class(profiler, "TraceMe",
py::module_local());
traceme_class.def(py::init<py::str, py::kwargs>())
diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD
index 544ceac..c3e9cfa 100644
--- a/tensorflow/core/profiler/rpc/client/BUILD
+++ b/tensorflow/core/profiler/rpc/client/BUILD
@@ -23,6 +23,7 @@
hdrs = ["capture_profile.h"],
copts = tf_profiler_copts(),
visibility = [
+ "//tensorflow/compiler/xla/python:__pkg__",
"//tensorflow/python/profiler/internal:__pkg__",
],
deps = [