Make XPlaneBuilder to use reserved metadata ids for known stats in XPlaneSchema. Also, remove more dependencies to MetadataMatcher.
PiperOrigin-RevId: 290198703
Change-Id: I46cdbfe42d0a4306ff6e14544a5aa239989ccaf0
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
index a531341..a28f1df 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
@@ -35,14 +35,16 @@
thread1.AddEvent(*host_plane.GetOrCreateEventMetadata("event1"));
event1.SetTimestampNs(150000);
event1.SetDurationNs(10000);
- event1.ParseAndAddStatValue(StatType::kTfOp, "Relu");
+ event1.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"),
+ "Relu");
XLineBuilder thread2 = host_plane.GetOrCreateLine(20);
thread2.SetName("thread2");
XEventBuilder event2 =
thread2.AddEvent(*host_plane.GetOrCreateEventMetadata("event2"));
event2.SetTimestampNs(160000);
event2.SetDurationNs(10000);
- event2.ParseAndAddStatValue(StatType::kTfOp, "Conv2D");
+ event2.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"),
+ "Conv2D");
device_plane.SetName("gpu:0");
device_plane.SetId(1);
@@ -52,7 +54,8 @@
stream1.AddEvent(*device_plane.GetOrCreateEventMetadata("kernel1"));
event3.SetTimestampNs(180000);
event3.SetDurationNs(10000);
- event3.ParseAndAddStatValue(StatType::kCorrelationId, "55");
+ event3.ParseAndAddStatValue(
+ *device_plane.GetOrCreateStatMetadata("correlation id"), "55");
}
TEST(ConvertXPlaneToTraceEvents, Convert) {
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
index 2ecafff..f98912a 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
@@ -130,7 +130,7 @@
ASSERT_EQ(plane.name(), kHostThreads);
ASSERT_EQ(plane.lines_size(), 1);
ASSERT_EQ(plane.event_metadata_size(), 6);
- ASSERT_EQ(plane.stat_metadata_size(), GetNumStatTypes() + 2);
+ ASSERT_EQ(plane.stat_metadata_size(), 2);
const auto& event_metadata = plane.event_metadata();
const auto& stat_metadata = plane.stat_metadata();
const auto& line = plane.lines(0);
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
index 5dbc47a..9255583 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
@@ -96,7 +96,7 @@
xplane.GetOrCreateStatMetadata(xstat_metadata_by_name.size());
xstat_metadata->set_name(string(metadata.key));
}
- xevent.ParseAndAddStatValue(xstat_metadata->id(), metadata.value);
+ xevent.ParseAndAddStatValue(*xstat_metadata, metadata.value);
}
}
}
diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer.cc b/tensorflow/core/profiler/internal/gpu/device_tracer.cc
index 523f32d..71dae46 100644
--- a/tensorflow/core/profiler/internal/gpu/device_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/device_tracer.cc
@@ -61,11 +61,13 @@
xevent.SetTimestampNs(event.start_time_ns + offset_ns);
xevent.SetEndTimestampNs(event.end_time_ns + offset_ns);
if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) {
- xevent.AddStatValue(StatType::kCorrelationId, event.correlation_id);
+ xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kCorrelationId)),
+ event.correlation_id);
}
if (event.context_id != CuptiTracerEvent::kInvalidContextId) {
xevent.AddStatValue(
- StatType::kContextId,
+ *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)),
absl::StrCat("$$", static_cast<uint64>(event.context_id)));
}
if (event.type == CuptiTracerEventType::Kernel) {
@@ -76,7 +78,9 @@
event.kernel_info.grid_x, event.kernel_info.grid_y,
event.kernel_info.grid_z, event.kernel_info.block_x,
event.kernel_info.block_y, event.kernel_info.block_z);
- xevent.AddStatValue(StatType::kKernelDetails, kernel_details);
+ xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kKernelDetails)),
+ kernel_details);
}
if (event.type == CuptiTracerEventType::MemcpyH2D ||
event.type == CuptiTracerEventType::MemcpyD2H ||
@@ -87,19 +91,23 @@
std::string memcpy_details =
absl::StrFormat("size:%u dest:%u async:%u", memcpy_info.num_bytes,
memcpy_info.destination, memcpy_info.async);
- xevent.AddStatValue(StatType::kMemcpyDetails, memcpy_details);
+ xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kMemcpyDetails)),
+ memcpy_details);
}
if (event.type == CuptiTracerEventType::MemoryAlloc) {
std::string memalloc_details =
absl::StrFormat("num_bytes:%u", event.memalloc_info.num_bytes);
- xevent.AddStatValue(StatType::kMemallocDetails, memalloc_details);
+ xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kMemallocDetails)),
+ memalloc_details);
}
std::vector<Annotation> annotation_stack =
ParseAnnotationStack(event.annotation);
for (int i = 0; i < annotation_stack.size(); ++i) {
xevent.AddStatValue(
- plane->GetOrCreateStatMetadata(absl::StrCat("level ", i))->id(),
+ *plane->GetOrCreateStatMetadata(absl::StrCat("level ", i)),
annotation_stack[i].name);
}
// If multiple metadata have the same key name, show the values from the top
@@ -113,7 +121,7 @@
continue; // ignored, obtained from HLO proto via DebugInfoMap
} else if (key_set.insert(metadata.key).second) {
xevent.ParseAndAddStatValue(
- plane->GetOrCreateStatMetadata(metadata.key)->id(), metadata.value);
+ *plane->GetOrCreateStatMetadata(metadata.key), metadata.value);
}
}
}
@@ -328,14 +336,19 @@
auto clock_rate_in_khz =
GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_CLOCK_RATE);
if (clock_rate_in_khz) {
- device_plane->AddStatValue(StatType::kDevCapClockRateKHz,
- *clock_rate_in_khz);
+ device_plane->AddStatValue(
+ *device_plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kDevCapClockRateKHz)),
+ *clock_rate_in_khz);
}
auto core_count =
GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT);
if (core_count) {
- device_plane->AddStatValue(StatType::kDevCapCoreCount, *core_count);
+ device_plane->AddStatValue(
+ *device_plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kDevCapCoreCount)),
+ *core_count);
}
auto mem_clock_khz =
@@ -347,27 +360,35 @@
// data lane.
auto memory_bandwidth =
2ULL * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8;
- device_plane->AddStatValue(StatType::kDevCapMemoryBandwidth,
- memory_bandwidth);
+ device_plane->AddStatValue(
+ *device_plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kDevCapMemoryBandwidth)),
+ memory_bandwidth);
}
size_t total_memory = 0;
if (cuDeviceTotalMem(&total_memory, device) == CUDA_SUCCESS) {
- device_plane->AddStatValue(StatType::kDevCapMemorySize,
- static_cast<uint64>(total_memory));
+ device_plane->AddStatValue(
+ *device_plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kDevCapMemorySize)),
+ static_cast<uint64>(total_memory));
}
auto compute_capability_major = GetDeviceAttribute(
device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR);
if (compute_capability_major) {
- device_plane->AddStatValue(StatType::kDevCapComputeCapMajor,
- *compute_capability_major);
+ device_plane->AddStatValue(
+ *device_plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kDevCapComputeCapMajor)),
+ *compute_capability_major);
}
auto compute_capability_minor = GetDeviceAttribute(
device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR);
if (compute_capability_minor) {
- device_plane->AddStatValue(StatType::kDevCapComputeCapMinor,
- *compute_capability_minor);
+ device_plane->AddStatValue(
+ *device_plane->GetOrCreateStatMetadata(
+ GetStatTypeStr(StatType::kDevCapComputeCapMinor)),
+ *compute_capability_minor);
}
}
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index fc3eb63..41e1fa2 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -117,7 +117,6 @@
deps = [
":tf_op_utils",
":time_utils",
- ":xplane_schema",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
diff --git a/tensorflow/core/profiler/utils/metadata_matcher.cc b/tensorflow/core/profiler/utils/metadata_matcher.cc
index 9d95161..7abdd77 100644
--- a/tensorflow/core/profiler/utils/metadata_matcher.cc
+++ b/tensorflow/core/profiler/utils/metadata_matcher.cc
@@ -21,7 +21,9 @@
namespace profiler {
namespace {
+using ::tensorflow::profiler::XEvent;
using ::tensorflow::profiler::XPlane;
+using ::tensorflow::profiler::XStat;
absl::flat_hash_map<int64, int> CreateEventMetadataMap(
const XPlane& xplane,
@@ -49,17 +51,95 @@
return id_to_event_type_map;
}
+absl::flat_hash_map<int64, int> CreateStatMetadataMap(
+ const XPlane& xplane,
+ const absl::Span<const absl::string_view> stat_type_str_map) {
+ absl::flat_hash_map<int64, int> id_to_stat_type_map;
+ for (const auto& id_and_stat_metadata : xplane.stat_metadata()) {
+ int64 id = id_and_stat_metadata.first;
+ absl::string_view stat_name = id_and_stat_metadata.second.name();
+ for (int stat_type = 0; stat_type < stat_type_str_map.size(); ++stat_type) {
+ if (stat_type_str_map[stat_type] == stat_name) {
+ id_to_stat_type_map[id] = stat_type;
+ break;
+ }
+ }
+ }
+ return id_to_stat_type_map;
+}
+
} // namespace
MetadataMatcher::MetadataMatcher(
const XPlane& xplane,
const std::vector<std::pair<const absl::Span<const absl::string_view>,
/*first_event_type*/ int>>&
- event_type_metadata_maps)
+ event_type_metadata_maps,
+ const absl::Span<const absl::string_view> stat_type_str_map)
: id_to_event_type_map_(
CreateEventMetadataMap(xplane, event_type_metadata_maps)),
+ id_to_stat_type_map_(CreateStatMetadataMap(xplane, stat_type_str_map)),
event_type_to_id_map_(gtl::ReverseMap<decltype(event_type_to_id_map_)>(
- id_to_event_type_map_)) {}
+ id_to_event_type_map_)),
+ stat_type_to_id_map_(gtl::ReverseMap<decltype(stat_type_to_id_map_)>(
+ id_to_stat_type_map_)) {}
+
+const XStat* MetadataMatcher::GetStat(const XEvent& event,
+ int stat_type) const {
+ for (const auto& stat : event.stats()) {
+ if (GetStatType(stat) == stat_type) {
+ return &stat;
+ }
+ }
+ return nullptr;
+}
+
+absl::optional<std::tuple<const XStat*, const XStat*>>
+MetadataMatcher::GetStats(const XEvent& event, int first_stat_type,
+ int second_stat_type) const {
+ const XStat* first_stat = nullptr;
+ const XStat* second_stat = nullptr;
+ for (const auto& stat : event.stats()) {
+ if (GetStatType(stat) == first_stat_type) {
+ first_stat = &stat;
+ } else if (GetStatType(stat) == second_stat_type) {
+ second_stat = &stat;
+ }
+ }
+ if (first_stat && second_stat) {
+ return std::make_tuple(first_stat, second_stat);
+ }
+ return absl::nullopt;
+}
+
+absl::optional<std::tuple<const XStat*, const XStat*, const XStat*>>
+MetadataMatcher::GetStats(const XEvent& event, int first_stat_type,
+ int second_stat_type, int third_stat_type) const {
+ const XStat* first_stat = nullptr;
+ const XStat* second_stat = nullptr;
+ const XStat* third_stat = nullptr;
+ for (const auto& stat : event.stats()) {
+ if (GetStatType(stat) == first_stat_type) {
+ first_stat = &stat;
+ } else if (GetStatType(stat) == second_stat_type) {
+ second_stat = &stat;
+ } else if (GetStatType(stat) == third_stat_type) {
+ third_stat = &stat;
+ }
+ }
+ if (first_stat && second_stat && third_stat) {
+ return std::make_tuple(first_stat, second_stat, third_stat);
+ }
+ return absl::nullopt;
+}
+
+absl::optional<int64> MetadataMatcher::GetIntStatValue(const XEvent& event,
+ int stat_type) const {
+ if (const XStat* stat = GetStat(event, stat_type)) {
+ return stat->int64_value();
+ }
+ return absl::nullopt;
+}
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/metadata_matcher.h b/tensorflow/core/profiler/utils/metadata_matcher.h
index 40f0e5f..beaba5e 100644
--- a/tensorflow/core/profiler/utils/metadata_matcher.h
+++ b/tensorflow/core/profiler/utils/metadata_matcher.h
@@ -27,18 +27,19 @@
namespace tensorflow {
namespace profiler {
-// Builds mapping between metadata ids and interesting event types. Event types
-// are represented in integer ids. Multiple spans of event types can be passed
-// with offset values (i.e., first_event_type) to be used to calculate integer
-// ids for event types. Spans and offset values are expected to result in a
-// unique integer id for each event type.
+// Builds mapping between metadata ids and interesting event and stat types.
+// Event and stat types are represented in integer ids. Multiple spans of event
+// types can be passed with offset values (i.e., first_event_type) to be
+// used to calculate integer ids for event types. Spans and offset values are
+// expected to result in a unique integer id for each event type.
class MetadataMatcher {
public:
explicit MetadataMatcher(
const XPlane& xplane,
const std::vector<std::pair<const absl::Span<const absl::string_view>,
/*first_event_type*/ int>>&
- event_type_metadata_maps);
+ event_type_metadata_maps,
+ const absl::Span<const absl::string_view> stat_type_str_map);
// Returns EventType if input is one of interesting event types.
// Otherwise, it returns kUnknownEventType.
@@ -63,12 +64,42 @@
return absl::nullopt;
}
+ // Returns StatType if input is one of interesting stat types.
+ // Otherwise, it returns kUnknownStatType.
+ int GetStatType(const XStat& xstat) const {
+ return gtl::FindWithDefault(id_to_stat_type_map_, xstat.metadata_id(),
+ /*kUnknownStatType*/ 0);
+ }
+
+ // Returns metadata id if xplane has the input stat type.
+ absl::optional<int64> GetStatMetadataId(int stat_type) const {
+ if (const int64* id = gtl::FindOrNull(stat_type_to_id_map_, stat_type)) {
+ return *id;
+ }
+ return absl::nullopt;
+ }
+
+ const XStat* GetStat(const XEvent& event, int stat_type) const;
+
+ absl::optional<std::tuple<const XStat*, const XStat*>> GetStats(
+ const XEvent& event, int first_stat_type, int second_stat_type) const;
+
+ absl::optional<std::tuple<const XStat*, const XStat*, const XStat*>> GetStats(
+ const XEvent& event, int first_stat_type, int second_stat_type,
+ int third_stat_type) const;
+
+ absl::optional<int64> GetIntStatValue(const XEvent& event,
+ int stat_type) const;
+
private:
- // Maps from metada ids to interesting event types. Uninteresting event types
- // are not cached in these maps and considered to be kUnknownEvent.
+ // Maps from metada ids to interesting event and stat types.
+ // Uninteresting event and stat types are not cached in these maps and
+ // considered to be kUnknown*.
const absl::flat_hash_map<int64, int> id_to_event_type_map_;
+ const absl::flat_hash_map<int64, int> id_to_stat_type_map_;
// Reverse of the above.
const absl::flat_hash_map<int, int64> event_type_to_id_map_;
+ const absl::flat_hash_map<int, int64> stat_type_to_id_map_;
};
} // namespace profiler
diff --git a/tensorflow/core/profiler/utils/metadata_matcher_test.cc b/tensorflow/core/profiler/utils/metadata_matcher_test.cc
index bfbfc9a..d430b44 100644
--- a/tensorflow/core/profiler/utils/metadata_matcher_test.cc
+++ b/tensorflow/core/profiler/utils/metadata_matcher_test.cc
@@ -26,6 +26,7 @@
using ::tensorflow::profiler::XEventMetadata;
using ::tensorflow::profiler::XPlane;
+using ::tensorflow::profiler::XStatMetadata;
TEST(MetadataMatcherTest, GetHostEventTypeTest) {
for (int event_type = HostEventType::kFirstHostEventType;
@@ -37,13 +38,32 @@
GetHostEventTypeStr(static_cast<HostEventType>(event_type))));
MetadataMatcher metadata_matcher(
xplane,
- {{GetHostEventTypeStrMap(), HostEventType::kFirstHostEventType}});
+ {{GetHostEventTypeStrMap(), HostEventType::kFirstHostEventType}},
+ GetStatTypeStrMap());
XEvent event;
event.set_metadata_id(0);
EXPECT_EQ(metadata_matcher.GetEventType(event), event_type);
}
}
+TEST(MetadataMatcherTest, GetStatTypeTest) {
+ for (int stat_type = StatType::kFirstStatType;
+ stat_type <= StatType::kLastStatType; ++stat_type) {
+ XPlane xplane;
+ XStatMetadata& metadata = (*xplane.mutable_stat_metadata())[0];
+ metadata.set_id(0);
+ metadata.set_name(
+ std::string(GetStatTypeStr(static_cast<StatType>(stat_type))));
+ MetadataMatcher metadata_matcher(
+ xplane,
+ {{GetHostEventTypeStrMap(), HostEventType::kFirstHostEventType}},
+ GetStatTypeStrMap());
+ XStat stat;
+ stat.set_metadata_id(0);
+ EXPECT_EQ(metadata_matcher.GetStatType(stat), stat_type);
+ }
+}
+
} // namespace
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/xplane_builder.cc b/tensorflow/core/profiler/utils/xplane_builder.cc
index b6230be..e2aec65 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.cc
+++ b/tensorflow/core/profiler/utils/xplane_builder.cc
@@ -14,9 +14,7 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
namespace tensorflow {
namespace profiler {
@@ -28,23 +26,10 @@
std::max<int64>(last_event_metadata_id_, iter.second.id());
event_metadata_by_name_.try_emplace(iter.second.name(), &iter.second);
}
- if (plane->stat_metadata_size() == 0) {
- // Add reserved stat metadata.
- for (const auto& stat_name_and_type : GetStatTypeMap()) {
- XStatMetadata* metadata =
- GetOrCreateStatMetadata(stat_name_and_type.second);
- metadata->set_name(std::string(stat_name_and_type.first));
- stat_metadata_by_name_.try_emplace(stat_name_and_type.first, metadata);
- }
- last_stat_metadata_id_ = kLastStatType;
- } else {
- // If plane is not empty, reserved stat metadata should have been added
- // the first time XPlaneBuilder was called.
- for (auto& iter : *plane->mutable_stat_metadata()) {
- last_stat_metadata_id_ =
- std::max<int64>(last_stat_metadata_id_, iter.second.id());
- stat_metadata_by_name_.try_emplace(iter.second.name(), &iter.second);
- }
+ for (auto& iter : *plane->mutable_stat_metadata()) {
+ last_stat_metadata_id_ =
+ std::max<int64>(last_stat_metadata_id_, iter.second.id());
+ stat_metadata_by_name_.try_emplace(iter.second.name(), &iter.second);
}
for (XLine& line : *plane->mutable_lines()) {
lines_by_id_.try_emplace(line.id(), &line);
diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h
index 2a5e4c8..99a554d 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.h
+++ b/tensorflow/core/profiler/utils/xplane_builder.h
@@ -31,26 +31,26 @@
public:
explicit XStatsBuilder(T* stats_owner) : stats_owner_(stats_owner) {}
- void AddStatValue(int64 metadata_id, uint32 value) {
- AddStat(metadata_id)->set_uint64_value(value);
+ void AddStatValue(const XStatMetadata& metadata, uint32 value) {
+ AddStat(metadata)->set_uint64_value(value);
}
- void AddStatValue(int64 metadata_id, uint64 value) {
- AddStat(metadata_id)->set_uint64_value(value);
+ void AddStatValue(const XStatMetadata& metadata, uint64 value) {
+ AddStat(metadata)->set_uint64_value(value);
}
- void AddStatValue(int64 metadata_id, int32 value) {
- AddStat(metadata_id)->set_int64_value(value);
+ void AddStatValue(const XStatMetadata& metadata, int32 value) {
+ AddStat(metadata)->set_int64_value(value);
}
- void AddStatValue(int64 metadata_id, int64 value) {
- AddStat(metadata_id)->set_int64_value(value);
+ void AddStatValue(const XStatMetadata& metadata, int64 value) {
+ AddStat(metadata)->set_int64_value(value);
}
- void AddStatValue(int64 metadata_id, double value) {
- AddStat(metadata_id)->set_double_value(value);
+ void AddStatValue(const XStatMetadata& metadata, double value) {
+ AddStat(metadata)->set_double_value(value);
}
- void AddStatValue(int64 metadata_id, absl::string_view value) {
- AddStat(metadata_id)->set_str_value(string(value));
+ void AddStatValue(const XStatMetadata& metadata, absl::string_view value) {
+ AddStat(metadata)->set_str_value(string(value));
}
- void AddStatValue(int64 metadata_id, string&& value) {
- AddStat(metadata_id)->set_str_value(std::move(value));
+ void AddStatValue(const XStatMetadata& metadata, string&& value) {
+ AddStat(metadata)->set_str_value(std::move(value));
}
void AddStat(const XStatMetadata& metadata, const XStat& stat) {
@@ -58,18 +58,19 @@
*stats_owner_->add_stats() = stat;
}
- void ParseAndAddStatValue(int64 metadata_id, absl::string_view value) {
+ void ParseAndAddStatValue(const XStatMetadata& metadata,
+ absl::string_view value) {
int64 int_value;
uint64 uint_value;
double double_value;
if (absl::SimpleAtoi(value, &int_value)) {
- AddStatValue(metadata_id, int_value);
+ AddStatValue(metadata, int_value);
} else if (absl::SimpleAtoi(value, &uint_value)) {
- AddStatValue(metadata_id, uint_value);
+ AddStatValue(metadata, uint_value);
} else if (absl::SimpleAtod(value, &double_value)) {
- AddStatValue(metadata_id, double_value);
+ AddStatValue(metadata, double_value);
} else {
- AddStatValue(metadata_id, value);
+ AddStatValue(metadata, value);
}
}
void ReserveStats(size_t num_stats) {
@@ -77,9 +78,9 @@
}
private:
- XStat* AddStat(int64 metadata_id) {
+ XStat* AddStat(const XStatMetadata& metadata) {
XStat* stat = stats_owner_->add_stats();
- stat->set_metadata_id(metadata_id);
+ stat->set_metadata_id(metadata.id());
return stat;
}
diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc
index 767c01d..39e14ef 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.cc
+++ b/tensorflow/core/profiler/utils/xplane_schema.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -94,7 +95,6 @@
"memcpy_details",
"memalloc_details",
"kernel_details",
- "stream",
"group_id",
"step_name",
"level 0",
@@ -121,8 +121,6 @@
return absl::MakeConstSpan(kStatTypeStrMap, kNumStatTypes);
}
-int GetNumStatTypes() { return kNumStatTypes; }
-
const absl::flat_hash_map<absl::string_view, StatType>& GetStatTypeMap() {
static absl::flat_hash_map<absl::string_view, StatType>* stats_type_map =
new absl::flat_hash_map<absl::string_view, StatType>({
@@ -155,7 +153,6 @@
{"memcpy_details", kMemcpyDetails},
{"memalloc_details", kMemallocDetails},
{"kernel_details", kKernelDetails},
- {"stream", kStream},
// Stats added when processing traces.
{"group_id", kGroupId},
{"step_name", kStepName},
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index fcd1d8d..743fedf 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -16,7 +16,6 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
-#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
@@ -65,9 +64,8 @@
kLastHostEventType = kPartitionedCallOp,
};
-// TODO(jihochoi): Rename it to ReservedStatMetadataId.
enum StatType {
- kFirstStatType = 1 << 10,
+ kFirstStatType = 0,
kUnknownStatType = kFirstStatType,
// TraceMe arguments.
kStepId,
@@ -97,7 +95,6 @@
kMemcpyDetails,
kMemallocDetails,
kKernelDetails,
- kStream,
// Stats added when processing traces.
kGroupId,
kStepName,
@@ -129,19 +126,15 @@
absl::Span<const absl::string_view> GetStatTypeStrMap();
inline absl::string_view GetStatTypeStr(StatType stat_type) {
- return GetStatTypeStrMap()[stat_type - StatType::kFirstStatType];
+ return GetStatTypeStrMap()[stat_type];
}
inline bool IsStatType(StatType stat_type, absl::string_view stat_name) {
return GetStatTypeStr(stat_type) == stat_name;
}
-const absl::flat_hash_map<absl::string_view, StatType>& GetStatTypeMap();
-
StatType GetStatType(absl::string_view stat_name);
-int GetNumStatTypes();
-
} // namespace profiler
} // namespace tensorflow