remove callsite where str_value is directly used in case that it use ref_value.
also don't compare XStat::metadata_id() which is not an invariant.
PiperOrigin-RevId: 305153479
Change-Id: I62e35cf3df603a7fa8870a68ed1914fda7f79614
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
index 9fd12a2..3f5e702 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
@@ -208,19 +208,21 @@
first_op_offset_ps = std::min(first_op_offset_ps, event.OffsetPs());
last_op_offset_ps = std::max(last_op_offset_ps, event.EndOffsetPs());
- const XStat* stat = event.GetStats(StatType::kLevel0);
- if (!stat) return;
- absl::string_view tf_op_fullname = stat->str_value();
- if (tf_op_fullname.empty()) return;
- TfOp tf_op = ParseTfOpFullname(tf_op_fullname);
- TfOpRoofLineCostEstimator::OpRoofLineStats costs;
- if (tf_op.category != Category::kUnknown) {
- costs = op_level_cost_estimator.Predict(event);
- }
- device_op_metrics_db_builder.EnterOp(
- /*program_id=*/0, tf_op.name, tf_op.type, tf_op_fullname,
- /*occurrences=*/1, event.DurationPs(),
- /*children_time_ps=*/0, costs.flops, costs.bytes_accessed);
+ event.ForEachStat([&](const XStatVisitor& stat) {
+ if (stat.Type() == StatType::kLevel0) {
+ auto tf_op_fullname = stat.ToString();
+ if (tf_op_fullname.empty()) return;
+ TfOp tf_op = ParseTfOpFullname(tf_op_fullname);
+ TfOpRoofLineCostEstimator::OpRoofLineStats costs;
+ if (tf_op.category != Category::kUnknown) {
+ costs = op_level_cost_estimator.Predict(event);
+ }
+ device_op_metrics_db_builder.EnterOp(
+ /*program_id=*/0, tf_op.name, tf_op.type, tf_op_fullname,
+ /*occurrences=*/1, event.DurationPs(),
+ /*children_time_ps=*/0, costs.flops, costs.bytes_accessed);
+ }
+ });
});
});
result.set_total_time_ps(last_op_offset_ps - first_op_offset_ps);
diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD
index 1001fec..a7005a6 100644
--- a/tensorflow/core/profiler/internal/cpu/BUILD
+++ b/tensorflow/core/profiler/internal/cpu/BUILD
@@ -53,6 +53,7 @@
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:xplane_schema",
+ "//tensorflow/core/profiler/utils:xplane_visitor",
"@com_google_absl//absl/types:optional",
"@com_google_googletest//:gtest_main",
],
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
index 9944a10..6c1ef02 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
@@ -24,6 +24,7 @@
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@@ -139,45 +140,79 @@
ASSERT_EQ(space.planes_size(), 1);
const auto& plane = space.planes(0);
+ XPlaneVisitor xplane(&plane);
ASSERT_EQ(plane.name(), kHostThreads);
ASSERT_EQ(plane.lines_size(), 1);
ASSERT_EQ(plane.event_metadata_size(), 6);
ASSERT_EQ(plane.stat_metadata_size(), 2);
- const auto& event_metadata = plane.event_metadata();
- const auto& stat_metadata = plane.stat_metadata();
const auto& line = plane.lines(0);
EXPECT_EQ(line.id(), thread_id);
EXPECT_EQ(line.name(), thread_name);
ASSERT_EQ(line.events_size(), 6);
const auto& events = line.events();
- EXPECT_EQ(events[0].metadata_id(), 1);
- EXPECT_EQ(event_metadata.at(1).name(), "hello");
+
+ XEventVisitor e0(&xplane, &line, &events[0]);
+ EXPECT_EQ(e0.Name(), "hello");
ASSERT_EQ(events[0].stats_size(), 0);
- EXPECT_EQ(events[1].metadata_id(), 2);
- EXPECT_EQ(event_metadata.at(2).name(), "world");
+
+ XEventVisitor e1(&xplane, &line, &events[1]);
+ EXPECT_EQ(e1.Name(), "world");
ASSERT_EQ(events[1].stats_size(), 0);
- EXPECT_EQ(events[2].metadata_id(), 3);
- EXPECT_EQ(event_metadata.at(3).name(), "contains#inside");
+
+ XEventVisitor e2(&xplane, &line, &events[2]);
+ EXPECT_EQ(e2.Name(), "contains#inside");
ASSERT_EQ(events[2].stats_size(), 0);
- EXPECT_EQ(events[3].metadata_id(), 4);
- EXPECT_EQ(event_metadata.at(4).name(), "good");
+
+ XEventVisitor e3(&xplane, &line, &events[3]);
+ EXPECT_EQ(e3.Name(), "good");
ASSERT_EQ(events[3].stats_size(), 1);
- EXPECT_EQ(events[3].stats(0).metadata_id(), 1);
- EXPECT_EQ(stat_metadata.at(1).name(), "key1");
- EXPECT_EQ(events[3].stats(0).str_value(), "value1");
- EXPECT_EQ(events[4].metadata_id(), 5);
- EXPECT_EQ(event_metadata.at(5).name(), "morning");
+ {
+ absl::optional<std::string> value;
+ e3.ForEachStat([&](const XStatVisitor& stat) {
+ if (stat.Name() == "key1") value = stat.ToString();
+ });
+ ASSERT_TRUE(value); // The stat key is present.
+ EXPECT_EQ(*value, "value1"); // The stat value is expected.
+ }
+
+ XEventVisitor e4(&xplane, &line, &events[4]);
+ EXPECT_EQ(e4.Name(), "morning");
ASSERT_EQ(events[4].stats_size(), 2);
- EXPECT_EQ(events[4].stats(0).metadata_id(), 1);
- EXPECT_EQ(events[4].stats(0).str_value(), "value1");
- EXPECT_EQ(events[4].stats(1).metadata_id(), 2);
- EXPECT_EQ(stat_metadata.at(2).name(), "key2");
- EXPECT_EQ(events[4].stats(1).str_value(), "value2");
+ {
+ absl::optional<std::string> value1, value2;
+ e4.ForEachStat([&](const XStatVisitor& stat) {
+ if (stat.Name() == "key1") {
+ value1 = stat.ToString();
+ } else if (stat.Name() == "key2") {
+ value2 = stat.ToString();
+ }
+ });
+ ASSERT_TRUE(value1 && value2); // The stat keys are presents.
+ EXPECT_EQ(*value1, "value1"); // The stat value1 is expected.
+ EXPECT_EQ(*value2, "value2"); // The stat value2 is expected.
+ }
+
+ XEventVisitor e5(&xplane, &line, &events[5]);
+ EXPECT_EQ(e5.Name(), "incomplete");
+ ASSERT_EQ(events[5].stats_size(), 1);
+ {
+ absl::optional<std::string> value1, value2;
+ e5.ForEachStat([&](const XStatVisitor& stat) {
+ if (stat.Name() == "key1") {
+ value1 = stat.ToString();
+ } else if (stat.Name() == "key2") {
+ value2 = stat.ToString();
+ }
+ });
+ ASSERT_TRUE(value1 && !value2); // One of the stat key is present.
+ EXPECT_EQ(*value1, "value1"); // The stat value is expected.
+ }
+#if 0
EXPECT_EQ(events[5].metadata_id(), 6);
EXPECT_EQ(event_metadata.at(6).name(), "incomplete");
ASSERT_EQ(events[5].stats_size(), 1);
- EXPECT_EQ(events[5].stats(0).metadata_id(), 1);
- EXPECT_EQ(events[5].stats(0).str_value(), "value1");
+ EXPECT_EQ(GetXStatString(events[5].stats(0), plane), "value1");
+#endif
}
} // namespace
diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc
index c1adade..5a3bb65 100644
--- a/tensorflow/core/profiler/utils/group_events.cc
+++ b/tensorflow/core/profiler/utils/group_events.cc
@@ -23,6 +23,7 @@
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
@@ -144,7 +145,8 @@
std::string EventNode::GetGroupName() const {
std::vector<std::string> name_parts;
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
- name_parts.push_back(graph_type_stat->str_value());
+ XStatVisitor stat(visitor_, graph_type_stat);
+ name_parts.push_back(stat.ToString());
}
int64 step_num = group_id_.value_or(0);
if (const XStat* step_num_stat = GetContextStat(StatType::kStepNum)) {