Use Timespan in derived XEvent generation

PiperOrigin-RevId: 450104961
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index 26a2188..4114dbe 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -430,6 +430,7 @@
         ":xplane_visitor",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/profiler/convert:xla_op_utils",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/util:stats_calculator_portable",
         "@com_google_absl//absl/algorithm:container",
diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc
index fd5ddd5..584e4fd 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.cc
+++ b/tensorflow/core/profiler/utils/derived_timeline.cc
@@ -27,6 +27,7 @@
 #include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/convert/xla_op_utils.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/gpu_event_stats.h"
 #include "tensorflow/core/profiler/utils/group_events.h"
@@ -44,14 +45,13 @@
 namespace profiler {
 namespace {
 
-XEvent CreateXEvent(const XEventMetadata& metadata, int64_t offset_ps,
-                    int64_t duration_ps, int64_t group_id_stat_metadata_id,
+XEvent CreateXEvent(const XEventMetadata& metadata, Timespan timespan,
+                    int64_t group_id_stat_metadata_id,
                     absl::optional<int64_t> group_id) {
   XEvent event;
   event.set_metadata_id(metadata.id());
-  // TODO(b/150498419): Normalize with the line start time.
-  event.set_offset_ps(offset_ps);
-  event.set_duration_ps(duration_ps);
+  event.set_offset_ps(timespan.begin_ps());
+  event.set_duration_ps(timespan.duration_ps());
   if (group_id) {
     XStat* stat = event.add_stats();
     stat->set_metadata_id(group_id_stat_metadata_id);
@@ -63,8 +63,8 @@
 }  // namespace
 
 void ProcessTfOpEvent(absl::string_view tf_op_full_name,
-                      absl::string_view low_level_event_name, int64_t offset_ps,
-                      int64_t duration_ps, absl::optional<int64_t> group_id,
+                      absl::string_view low_level_event_name, Timespan timespan,
+                      absl::optional<int64_t> group_id,
                       XPlaneBuilder* plane_builder,
                       DerivedXLineBuilder* tf_name_scope_line_builder,
                       DerivedXLineBuilder* tf_op_line_builder) {
@@ -76,9 +76,9 @@
   if (category == Category::kTensorFlow || category == Category::kJax) {
     std::vector<XEvent> name_scope_event_per_level;
     for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
-      name_scope_event_per_level.push_back(CreateXEvent(
-          *plane_builder->GetOrCreateEventMetadata(tf_name_scope), offset_ps,
-          duration_ps, group_id_stat_metadata_id, group_id));
+      name_scope_event_per_level.push_back(
+          CreateXEvent(*plane_builder->GetOrCreateEventMetadata(tf_name_scope),
+                       timespan, group_id_stat_metadata_id, group_id));
     }
     tf_name_scope_line_builder->ExpandOrAddEvents(
         name_scope_event_per_level, group_id, low_level_event_name);
@@ -89,8 +89,8 @@
   // the same color in the trace viewer.
   tf_op_event_metadata->set_display_name(TfOpEventName(tf_op));
   tf_op_line_builder->ExpandOrAddEvent(
-      CreateXEvent(*tf_op_event_metadata, offset_ps, duration_ps,
-                   group_id_stat_metadata_id, group_id),
+      CreateXEvent(*tf_op_event_metadata, timespan, group_id_stat_metadata_id,
+                   group_id),
       group_id, low_level_event_name);
 }
 
@@ -112,9 +112,10 @@
 
 void DerivedXEventBuilder::Expand(const XEvent& event,
                                   absl::string_view low_level_event_name) {
-  DCHECK_LE(event_.OffsetPs(), event.offset_ps());
-  event_.SetDurationPs((event.offset_ps() + event.duration_ps()) -
-                       event_.OffsetPs());
+  Timespan timespan = event_.GetTimespan();
+  DCHECK_LE(timespan.begin_ps(), event.offset_ps());
+  timespan.ExpandToInclude(Timespan(event.offset_ps(), event.duration_ps()));
+  event_.SetTimespan(timespan);
   if (!low_level_event_name.empty()) {
     low_level_event_names_.insert(std::string(low_level_event_name));
   }
@@ -197,13 +198,12 @@
 
   // Process events in order by start time.
   for (const XEventVisitor& event : events) {
-    int64_t offset_ps = event.OffsetPs();
-    int64_t duration_ps = event.DurationPs();
+    Timespan timespan = event.GetTimespan();
     GpuEventStats stats(&event);
     if (stats.group_id) {
       XEvent step_event = CreateXEvent(
           *plane.GetOrCreateEventMetadata(absl::StrCat(*stats.group_id)),
-          offset_ps, duration_ps, group_id_stat_metadata_id, stats.group_id);
+          timespan, group_id_stat_metadata_id, stats.group_id);
       if (auto group_metadata =
               gtl::FindOrNull(group_metadata_map, *stats.group_id)) {
         XStat* stat = step_event.add_stats();
@@ -220,13 +220,13 @@
     if (!stats.IsKernel()) continue;
 
     if (!stats.hlo_module_name.empty()) {
-      std::string name(stats.hlo_module_name);
-      if (stats.program_id.has_value()) {
-        absl::StrAppend(&name, "(", stats.program_id.value(), ")");
-      }
+      std::string name = stats.program_id
+                             ? HloModuleNameWithProgramId(stats.hlo_module_name,
+                                                          *stats.program_id)
+                             : std::string(stats.hlo_module_name);
       hlo_modules.ExpandOrAddEvent(
-          CreateXEvent(*plane.GetOrCreateEventMetadata(name), offset_ps,
-                       duration_ps, group_id_stat_metadata_id, stats.group_id));
+          CreateXEvent(*plane.GetOrCreateEventMetadata(std::move(name)),
+                       timespan, group_id_stat_metadata_id, stats.group_id));
     }
 
     if (stats.IsXlaOp()) {
@@ -234,29 +234,27 @@
       std::vector<XEvent> hlo_op_event_per_level;
       for (absl::string_view hlo_op_name : stats.hlo_op_names) {
         DCHECK(!hlo_op_name.empty());
-        hlo_op_event_per_level.push_back(CreateXEvent(
-            *plane.GetOrCreateEventMetadata(hlo_op_name), offset_ps,
-            duration_ps, group_id_stat_metadata_id, stats.group_id));
+        hlo_op_event_per_level.push_back(
+            CreateXEvent(*plane.GetOrCreateEventMetadata(hlo_op_name), timespan,
+                         group_id_stat_metadata_id, stats.group_id));
       }
       hlo_ops.ExpandOrAddEvents(hlo_op_event_per_level, stats.group_id);
       auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name,
                                     stats.hlo_op_names.back());
       if (!symbol.tf_op_name.empty()) {
         ProcessTfOpEvent(symbol.tf_op_name,
-                         /*low_level_event_name=*/event.Name(), offset_ps,
-                         duration_ps, stats.group_id, &plane, &tf_name_scope,
-                         &tf_ops);
+                         /*low_level_event_name=*/event.Name(), timespan,
+                         stats.group_id, &plane, &tf_name_scope, &tf_ops);
       }
       if (!symbol.source_info.empty()) {
-        source.ExpandOrAddEvent(CreateXEvent(
-            *plane.GetOrCreateEventMetadata(symbol.source_info), offset_ps,
-            duration_ps, group_id_stat_metadata_id, stats.group_id));
+        source.ExpandOrAddEvent(
+            CreateXEvent(*plane.GetOrCreateEventMetadata(symbol.source_info),
+                         timespan, group_id_stat_metadata_id, stats.group_id));
       }
     } else if (stats.IsTfOp()) {
       ProcessTfOpEvent(stats.tf_op_fullname,
-                       /*low_level_event_name=*/event.Name(), offset_ps,
-                       duration_ps, stats.group_id, &plane, &tf_name_scope,
-                       &tf_ops);
+                       /*low_level_event_name=*/event.Name(), timespan,
+                       stats.group_id, &plane, &tf_name_scope, &tf_ops);
     }
   }
   RemoveEmptyLines(device_trace);
@@ -349,7 +347,7 @@
   // to look up tensorflow op name from hlo_module/hlo_op.
   auto dummy_symbol_resolver =
       [](absl::optional<uint64_t> program_id, absl::string_view hlo_module,
-         absl::string_view hlo_op) { return tensorflow::profiler::Symbol(); };
+         absl::string_view hlo_op) { return Symbol(); };
   std::vector<XPlane*> device_traces =
       FindMutablePlanesWithPrefix(space, kGpuPlanePrefix);
   for (XPlane* plane : device_traces) {
diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h
index c24c574..2870c94 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.h
+++ b/tensorflow/core/profiler/utils/derived_timeline.h
@@ -25,6 +25,7 @@
 #include "absl/types/optional.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/group_events.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 
 namespace tensorflow {
@@ -89,7 +90,7 @@
 
  private:
   // If the last event of the given level has the same metadata, expands it to
-  // include the time until the given event's (offset_ps + duration_ps).
+  // include the time until the given event's end time.
   // Otherwise, adds a new event and clears last_event_by_level_ for the levels
   // below the given level and all levels of the dependent lines. Clearing
   // last_event_by_level_ prevents a nested event from growing larger than the
@@ -117,8 +118,8 @@
 // Derives TF name scope and op events from the TF op's fully qualified name
 // with the name of the originating low-level event.
 void ProcessTfOpEvent(absl::string_view tf_op_full_name,
-                      absl::string_view low_level_event_name, int64_t offset_ps,
-                      int64_t duration_ps, absl::optional<int64_t> group_id,
+                      absl::string_view low_level_event_name, Timespan timespan,
+                      absl::optional<int64_t> group_id,
                       XPlaneBuilder* plane_builder,
                       DerivedXLineBuilder* tf_name_scope_line_builder,
                       DerivedXLineBuilder* tf_op_line_builder);