make sure we capture up to 1 million trace events to avoid crashing the trace viewer. add an unit-test
PiperOrigin-RevId: 298403826
Change-Id: Idfcbf9501e9d2a626828c83501b2c548b8df48dd
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
index 565ca2c..cb43a9a 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
@@ -36,6 +36,27 @@
}
} // namespace
+void MaybeDropEventsForTraceViewer(Trace* trace, uint32 limit) {
+ auto* trace_events = trace->mutable_trace_events();
+ size_t trace_event_size = trace_events->size();
+ if (trace_event_size <= limit) return; // Nothing to do.
+ // Sort the events according to start time.
+ std::vector<uint64> timestamps;
+ timestamps.reserve(trace_event_size);
+ for (const auto& event : *trace_events) {
+ timestamps.push_back(event.timestamp_ps());
+ }
+ std::partial_sort(timestamps.begin(), timestamps.begin() + limit,
+ timestamps.end(), std::less<uint64>());
+ uint64 cutoff_timestamp = timestamps[limit - 1];
+ trace_events->erase(std::remove_if(trace_events->begin(), trace_events->end(),
+ [&](const TraceEvent& event) {
+ return event.timestamp_ps() >
+ cutoff_timestamp;
+ }),
+ trace_events->end());
+}
+
void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace) {
auto* trace_devices = trace->mutable_devices();
@@ -69,6 +90,11 @@
});
});
}
+
+ // Trace viewer (non-streaming) has scalability issues, we need to drop
+ // events to avoid loading failure for trace viewer.
+ constexpr uint64 kMaxEvents = 1000000;
+ MaybeDropEventsForTraceViewer(trace, kMaxEvents);
}
} // namespace profiler
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.h b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
index 40e03c5..b8e5f00 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.h
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
@@ -26,6 +26,9 @@
void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace);
+// Not Public API, Testing only.
+void MaybeDropEventsForTraceViewer(Trace* trace, uint32 limit);
+
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
index 42c50b7..afff5e6 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
@@ -71,6 +71,22 @@
EXPECT_EQ(trace.trace_events_size(), 3);
}
+TEST(ConvertXPlaneToTraceEvents, Drop) {
+ Trace trace;
+ for (int i = 0; i < 100; i++) {
+ trace.add_trace_events()->set_timestamp_ps((100 - i) % 50);
+ }
+
+ MaybeDropEventsForTraceViewer(&trace, 150);
+ EXPECT_EQ(trace.trace_events_size(), 100); // No dropping.
+
+ MaybeDropEventsForTraceViewer(&trace, 50);
+ EXPECT_EQ(trace.trace_events_size(), 50);
+ for (const auto& event : trace.trace_events()) {
+ EXPECT_LT(event.timestamp_ps(), 25);
+ }
+}
+
} // namespace
} // namespace profiler
} // namespace tensorflow