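Use Timespan in derived XEvent generation

Replace the separate offset_ps/duration_ps arguments of CreateXEvent and
ProcessTfOpEvent with a single Timespan, make DerivedXEventBuilder::Expand
grow its event through Timespan::ExpandToInclude, and build HLO module event
names with HloModuleNameWithProgramId from xla_op_utils.

PiperOrigin-RevId: 450104961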
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index 26a2188..4114dbe 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -430,6 +430,7 @@
":xplane_visitor",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/profiler/convert:xla_op_utils",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/util:stats_calculator_portable",
"@com_google_absl//absl/algorithm:container",
diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc
index fd5ddd5..584e4fd 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.cc
+++ b/tensorflow/core/profiler/utils/derived_timeline.cc
@@ -27,6 +27,7 @@
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/convert/xla_op_utils.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/gpu_event_stats.h"
#include "tensorflow/core/profiler/utils/group_events.h"
@@ -44,14 +45,13 @@
namespace profiler {
namespace {
-XEvent CreateXEvent(const XEventMetadata& metadata, int64_t offset_ps,
- int64_t duration_ps, int64_t group_id_stat_metadata_id,
+XEvent CreateXEvent(const XEventMetadata& metadata, Timespan timespan,
+ int64_t group_id_stat_metadata_id,
absl::optional<int64_t> group_id) {
XEvent event;
event.set_metadata_id(metadata.id());
- // TODO(b/150498419): Normalize with the line start time.
- event.set_offset_ps(offset_ps);
- event.set_duration_ps(duration_ps);
+ event.set_offset_ps(timespan.begin_ps());
+ event.set_duration_ps(timespan.duration_ps());
if (group_id) {
XStat* stat = event.add_stats();
stat->set_metadata_id(group_id_stat_metadata_id);
@@ -63,8 +63,8 @@
} // namespace
void ProcessTfOpEvent(absl::string_view tf_op_full_name,
- absl::string_view low_level_event_name, int64_t offset_ps,
- int64_t duration_ps, absl::optional<int64_t> group_id,
+ absl::string_view low_level_event_name, Timespan timespan,
+ absl::optional<int64_t> group_id,
XPlaneBuilder* plane_builder,
DerivedXLineBuilder* tf_name_scope_line_builder,
DerivedXLineBuilder* tf_op_line_builder) {
@@ -76,9 +76,9 @@
if (category == Category::kTensorFlow || category == Category::kJax) {
std::vector<XEvent> name_scope_event_per_level;
for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
- name_scope_event_per_level.push_back(CreateXEvent(
- *plane_builder->GetOrCreateEventMetadata(tf_name_scope), offset_ps,
- duration_ps, group_id_stat_metadata_id, group_id));
+ name_scope_event_per_level.push_back(
+ CreateXEvent(*plane_builder->GetOrCreateEventMetadata(tf_name_scope),
+ timespan, group_id_stat_metadata_id, group_id));
}
tf_name_scope_line_builder->ExpandOrAddEvents(
name_scope_event_per_level, group_id, low_level_event_name);
@@ -89,8 +89,8 @@
// the same color in the trace viewer.
tf_op_event_metadata->set_display_name(TfOpEventName(tf_op));
tf_op_line_builder->ExpandOrAddEvent(
- CreateXEvent(*tf_op_event_metadata, offset_ps, duration_ps,
- group_id_stat_metadata_id, group_id),
+ CreateXEvent(*tf_op_event_metadata, timespan, group_id_stat_metadata_id,
+ group_id),
group_id, low_level_event_name);
}
@@ -112,9 +112,10 @@
void DerivedXEventBuilder::Expand(const XEvent& event,
absl::string_view low_level_event_name) {
- DCHECK_LE(event_.OffsetPs(), event.offset_ps());
- event_.SetDurationPs((event.offset_ps() + event.duration_ps()) -
- event_.OffsetPs());
+ Timespan timespan = event_.GetTimespan();
+ DCHECK_LE(timespan.begin_ps(), event.offset_ps());
+ timespan.ExpandToInclude(Timespan(event.offset_ps(), event.duration_ps()));
+ event_.SetTimespan(timespan);
if (!low_level_event_name.empty()) {
low_level_event_names_.insert(std::string(low_level_event_name));
}
@@ -197,13 +198,12 @@
// Process events in order by start time.
for (const XEventVisitor& event : events) {
- int64_t offset_ps = event.OffsetPs();
- int64_t duration_ps = event.DurationPs();
+ Timespan timespan = event.GetTimespan();
GpuEventStats stats(&event);
if (stats.group_id) {
XEvent step_event = CreateXEvent(
*plane.GetOrCreateEventMetadata(absl::StrCat(*stats.group_id)),
- offset_ps, duration_ps, group_id_stat_metadata_id, stats.group_id);
+ timespan, group_id_stat_metadata_id, stats.group_id);
if (auto group_metadata =
gtl::FindOrNull(group_metadata_map, *stats.group_id)) {
XStat* stat = step_event.add_stats();
@@ -220,13 +220,13 @@
if (!stats.IsKernel()) continue;
if (!stats.hlo_module_name.empty()) {
- std::string name(stats.hlo_module_name);
- if (stats.program_id.has_value()) {
- absl::StrAppend(&name, "(", stats.program_id.value(), ")");
- }
+ std::string name = stats.program_id
+ ? HloModuleNameWithProgramId(stats.hlo_module_name,
+ *stats.program_id)
+ : std::string(stats.hlo_module_name);
hlo_modules.ExpandOrAddEvent(
- CreateXEvent(*plane.GetOrCreateEventMetadata(name), offset_ps,
- duration_ps, group_id_stat_metadata_id, stats.group_id));
+ CreateXEvent(*plane.GetOrCreateEventMetadata(std::move(name)),
+ timespan, group_id_stat_metadata_id, stats.group_id));
}
if (stats.IsXlaOp()) {
@@ -234,29 +234,27 @@
std::vector<XEvent> hlo_op_event_per_level;
for (absl::string_view hlo_op_name : stats.hlo_op_names) {
DCHECK(!hlo_op_name.empty());
- hlo_op_event_per_level.push_back(CreateXEvent(
- *plane.GetOrCreateEventMetadata(hlo_op_name), offset_ps,
- duration_ps, group_id_stat_metadata_id, stats.group_id));
+ hlo_op_event_per_level.push_back(
+ CreateXEvent(*plane.GetOrCreateEventMetadata(hlo_op_name), timespan,
+ group_id_stat_metadata_id, stats.group_id));
}
hlo_ops.ExpandOrAddEvents(hlo_op_event_per_level, stats.group_id);
auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name,
stats.hlo_op_names.back());
if (!symbol.tf_op_name.empty()) {
ProcessTfOpEvent(symbol.tf_op_name,
- /*low_level_event_name=*/event.Name(), offset_ps,
- duration_ps, stats.group_id, &plane, &tf_name_scope,
- &tf_ops);
+ /*low_level_event_name=*/event.Name(), timespan,
+ stats.group_id, &plane, &tf_name_scope, &tf_ops);
}
if (!symbol.source_info.empty()) {
- source.ExpandOrAddEvent(CreateXEvent(
- *plane.GetOrCreateEventMetadata(symbol.source_info), offset_ps,
- duration_ps, group_id_stat_metadata_id, stats.group_id));
+ source.ExpandOrAddEvent(
+ CreateXEvent(*plane.GetOrCreateEventMetadata(symbol.source_info),
+ timespan, group_id_stat_metadata_id, stats.group_id));
}
} else if (stats.IsTfOp()) {
ProcessTfOpEvent(stats.tf_op_fullname,
- /*low_level_event_name=*/event.Name(), offset_ps,
- duration_ps, stats.group_id, &plane, &tf_name_scope,
- &tf_ops);
+ /*low_level_event_name=*/event.Name(), timespan,
+ stats.group_id, &plane, &tf_name_scope, &tf_ops);
}
}
RemoveEmptyLines(device_trace);
@@ -349,7 +347,7 @@
// to look up tensorflow op name from hlo_module/hlo_op.
auto dummy_symbol_resolver =
[](absl::optional<uint64_t> program_id, absl::string_view hlo_module,
- absl::string_view hlo_op) { return tensorflow::profiler::Symbol(); };
+ absl::string_view hlo_op) { return Symbol(); };
std::vector<XPlane*> device_traces =
FindMutablePlanesWithPrefix(space, kGpuPlanePrefix);
for (XPlane* plane : device_traces) {
diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h
index c24c574..2870c94 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.h
+++ b/tensorflow/core/profiler/utils/derived_timeline.h
@@ -25,6 +25,7 @@
#include "absl/types/optional.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/group_events.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
namespace tensorflow {
@@ -89,7 +90,7 @@
private:
// If the last event of the given level has the same metadata, expands it to
- // include the time until the given event's (offset_ps + duration_ps).
+ // include the time until the given event's end time.
// Otherwise, adds a new event and clears last_event_by_level_ for the levels
// below the given level and all levels of the dependent lines. Clearing
// last_event_by_level_ prevents a nested event from growing larger than the
@@ -117,8 +118,8 @@
// Derives TF name scope and op events from the TF op's fully qualified name
// with the name of the originating low-level event.
void ProcessTfOpEvent(absl::string_view tf_op_full_name,
- absl::string_view low_level_event_name, int64_t offset_ps,
- int64_t duration_ps, absl::optional<int64_t> group_id,
+ absl::string_view low_level_event_name, Timespan timespan,
+ absl::optional<int64_t> group_id,
XPlaneBuilder* plane_builder,
DerivedXLineBuilder* tf_name_scope_line_builder,
DerivedXLineBuilder* tf_op_line_builder);