[XLA:Python] Add ProfilerSession bindings.

This allows for programmatically collecting a profiler trace that can be viewed in TensorBoard.

TESTING: I manually collected a JAX trace in a Cloud TPU VM. I'll add unit tests in a subsequent JAX change.
PiperOrigin-RevId: 362628356
Change-Id: I7e00bf0504572a364b8f91c01c3dda249ce21233
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 09ffb3f..3d45e3a 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -450,9 +450,14 @@
     ],
     features = ["-use_header_modules"],
     deps = [
+        ":types",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:platform_port",
         "//tensorflow/core/profiler/lib:profiler_backends",
         "//tensorflow/core/profiler/lib:profiler_session",
         "//tensorflow/core/profiler/rpc:profiler_server_impl",
+        "//tensorflow/core/profiler/rpc/client:capture_profile",
         "//tensorflow/python/profiler/internal:traceme_wrapper",
         "@pybind11",
     ],
diff --git a/tensorflow/compiler/xla/python/profiler.cc b/tensorflow/compiler/xla/python/profiler.cc
index 1a4c588..f57f86d 100644
--- a/tensorflow/compiler/xla/python/profiler.cc
+++ b/tensorflow/compiler/xla/python/profiler.cc
@@ -16,6 +16,12 @@
 #include "tensorflow/compiler/xla/python/profiler.h"
 
 #include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/python/types.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
 #include "tensorflow/core/profiler/rpc/profiler_server.h"
 #include "tensorflow/python/profiler/internal/traceme_wrapper.h"
 
@@ -47,6 +53,24 @@
       },
       py::arg("port"));
 
+  py::class_<tensorflow::ProfilerSession> profiler_session_class(
+      profiler, "ProfilerSession");
+  profiler_session_class
+      .def(py::init([]() {
+        return tensorflow::ProfilerSession::Create(
+            tensorflow::ProfilerSession::DefaultOptions());
+      }))
+      .def("stop_and_export",
+           [](tensorflow::ProfilerSession* sess,
+              const std::string& tensorboard_dir) -> xla::Status {
+             tensorflow::profiler::XSpace xspace;
+             // Disables the ProfilerSession
+             TF_RETURN_IF_ERROR(sess->CollectData(&xspace));
+             xspace.add_hostnames(tensorflow::port::Hostname());
+             return tensorflow::profiler::ExportToTensorBoard(xspace,
+                                                              tensorboard_dir);
+           });
+
   py::class_<TraceMeWrapper> traceme_class(profiler, "TraceMe",
                                            py::module_local());
   traceme_class.def(py::init<py::str, py::kwargs>())
diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD
index 544ceac..c3e9cfa 100644
--- a/tensorflow/core/profiler/rpc/client/BUILD
+++ b/tensorflow/core/profiler/rpc/client/BUILD
@@ -23,6 +23,7 @@
     hdrs = ["capture_profile.h"],
     copts = tf_profiler_copts(),
     visibility = [
+        "//tensorflow/compiler/xla/python:__pkg__",
         "//tensorflow/python/profiler/internal:__pkg__",
     ],
     deps = [