make sure we capture up to 1 million trace events to avoid crashing the trace viewer. add an unit-test

PiperOrigin-RevId: 298403826
Change-Id: Idfcbf9501e9d2a626828c83501b2c548b8df48dd
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
index 565ca2c..cb43a9a 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
@@ -36,6 +36,27 @@
 }
 }  // namespace
 
+void MaybeDropEventsForTraceViewer(Trace* trace, uint32 limit) {
+  auto* trace_events = trace->mutable_trace_events();
+  size_t trace_event_size = trace_events->size();
+  if (trace_event_size <= limit) return;  // Nothing to do.
+  // Sort the events according to start time.
+  std::vector<uint64> timestamps;
+  timestamps.reserve(trace_event_size);
+  for (const auto& event : *trace_events) {
+    timestamps.push_back(event.timestamp_ps());
+  }
+  std::partial_sort(timestamps.begin(), timestamps.begin() + limit,
+                    timestamps.end(), std::less<uint64>());
+  uint64 cutoff_timestamp = timestamps[limit - 1];
+  trace_events->erase(std::remove_if(trace_events->begin(), trace_events->end(),
+                                     [&](const TraceEvent& event) {
+                                       return event.timestamp_ps() >
+                                              cutoff_timestamp;
+                                     }),
+                      trace_events->end());
+}
+
 void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace) {
   auto* trace_devices = trace->mutable_devices();
 
@@ -69,6 +90,11 @@
       });
     });
   }
+
+  // Trace viewer (non-streaming) has scalability issues, we need to drop
+  // events to avoid loading failure for trace viewer.
+  constexpr uint64 kMaxEvents = 1000000;
+  MaybeDropEventsForTraceViewer(trace, kMaxEvents);
 }
 
 }  // namespace profiler
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.h b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
index 40e03c5..b8e5f00 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.h
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
@@ -26,6 +26,9 @@
 
 void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace);
 
+// Not Public API, Testing only.
+void MaybeDropEventsForTraceViewer(Trace* trace, uint32 limit);
+
 }  // namespace profiler
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
index 42c50b7..afff5e6 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
@@ -71,6 +71,22 @@
   EXPECT_EQ(trace.trace_events_size(), 3);
 }
 
+TEST(ConvertXPlaneToTraceEvents, Drop) {
+  Trace trace;
+  for (int i = 0; i < 100; i++) {
+    trace.add_trace_events()->set_timestamp_ps((100 - i) % 50);
+  }
+
+  MaybeDropEventsForTraceViewer(&trace, 150);
+  EXPECT_EQ(trace.trace_events_size(), 100);  // No dropping.
+
+  MaybeDropEventsForTraceViewer(&trace, 50);
+  EXPECT_EQ(trace.trace_events_size(), 50);
+  for (const auto& event : trace.trace_events()) {
+    EXPECT_LT(event.timestamp_ps(), 25);
+  }
+}
+
 }  // namespace
 }  // namespace profiler
 }  // namespace tensorflow