Add ProfilerSession method to collect data in XSpace format
PiperOrigin-RevId: 284887825
Change-Id: Ibddc95b819f9ed8d559572e84924e086f7cde274
diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD
index 54b85b0..b220369 100644
--- a/tensorflow/core/profiler/lib/BUILD
+++ b/tensorflow/core/profiler/lib/BUILD
@@ -17,6 +17,7 @@
":profiler_utils",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/internal:profiler_factory",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/strings",
] + select({
"//tensorflow:android": [],
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
index 3f69e5a..c822dd2 100644
--- a/tensorflow/core/profiler/lib/profiler_session.cc
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -184,6 +184,28 @@
return status_;
}
+Status ProfilerSession::CollectData(profiler::XSpace* space) {
+ mutex_lock l(mutex_);
+ if (!status_.ok()) return status_;
+ for (auto& profiler : profilers_) {
+ profiler->Stop().IgnoreError();
+ }
+
+ for (auto& profiler : profilers_) {
+ profiler->CollectData(space).IgnoreError();
+ }
+
+ if (active_) {
+ // Allow another session to start.
+#if !defined(IS_MOBILE_PLATFORM)
+ profiler::ReleaseProfilerLock();
+#endif
+ active_ = false;
+ }
+
+ return Status::OK();
+}
+
Status ProfilerSession::CollectData(RunMetadata* run_metadata) {
mutex_lock l(mutex_);
if (!status_.ok()) return status_;
diff --git a/tensorflow/core/profiler/lib/profiler_session.h b/tensorflow/core/profiler/lib/profiler_session.h
index 85b9901..8e6a682 100644
--- a/tensorflow/core/profiler/lib/profiler_session.h
+++ b/tensorflow/core/profiler/lib/profiler_session.h
@@ -22,6 +22,7 @@
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@@ -45,6 +46,9 @@
tensorflow::Status Status() LOCKS_EXCLUDED(mutex_);
+ tensorflow::Status CollectData(profiler::XSpace* space)
+ LOCKS_EXCLUDED(mutex_);
+
tensorflow::Status CollectData(RunMetadata* run_metadata)
LOCKS_EXCLUDED(mutex_);