Make XPlaneBuilder to use reserved metadata ids for known stats in XPlaneSchema. Also, remove more dependencies to MetadataMatcher.

PiperOrigin-RevId: 290198703
Change-Id: I46cdbfe42d0a4306ff6e14544a5aa239989ccaf0
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
index a531341..a28f1df 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
@@ -35,14 +35,16 @@
       thread1.AddEvent(*host_plane.GetOrCreateEventMetadata("event1"));
   event1.SetTimestampNs(150000);
   event1.SetDurationNs(10000);
-  event1.ParseAndAddStatValue(StatType::kTfOp, "Relu");
+  event1.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"),
+                              "Relu");
   XLineBuilder thread2 = host_plane.GetOrCreateLine(20);
   thread2.SetName("thread2");
   XEventBuilder event2 =
       thread2.AddEvent(*host_plane.GetOrCreateEventMetadata("event2"));
   event2.SetTimestampNs(160000);
   event2.SetDurationNs(10000);
-  event2.ParseAndAddStatValue(StatType::kTfOp, "Conv2D");
+  event2.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"),
+                              "Conv2D");
 
   device_plane.SetName("gpu:0");
   device_plane.SetId(1);
@@ -52,7 +54,8 @@
       stream1.AddEvent(*device_plane.GetOrCreateEventMetadata("kernel1"));
   event3.SetTimestampNs(180000);
   event3.SetDurationNs(10000);
-  event3.ParseAndAddStatValue(StatType::kCorrelationId, "55");
+  event3.ParseAndAddStatValue(
+      *device_plane.GetOrCreateStatMetadata("correlation id"), "55");
 }
 
 TEST(ConvertXPlaneToTraceEvents, Convert) {
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
index 2ecafff..f98912a 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
@@ -130,7 +130,7 @@
   ASSERT_EQ(plane.name(), kHostThreads);
   ASSERT_EQ(plane.lines_size(), 1);
   ASSERT_EQ(plane.event_metadata_size(), 6);
-  ASSERT_EQ(plane.stat_metadata_size(), GetNumStatTypes() + 2);
+  ASSERT_EQ(plane.stat_metadata_size(), 2);
   const auto& event_metadata = plane.event_metadata();
   const auto& stat_metadata = plane.stat_metadata();
   const auto& line = plane.lines(0);
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
index 5dbc47a..9255583 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
@@ -96,7 +96,7 @@
               xplane.GetOrCreateStatMetadata(xstat_metadata_by_name.size());
           xstat_metadata->set_name(string(metadata.key));
         }
-        xevent.ParseAndAddStatValue(xstat_metadata->id(), metadata.value);
+        xevent.ParseAndAddStatValue(*xstat_metadata, metadata.value);
       }
     }
   }
diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer.cc b/tensorflow/core/profiler/internal/gpu/device_tracer.cc
index 523f32d..71dae46 100644
--- a/tensorflow/core/profiler/internal/gpu/device_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/device_tracer.cc
@@ -61,11 +61,13 @@
   xevent.SetTimestampNs(event.start_time_ns + offset_ns);
   xevent.SetEndTimestampNs(event.end_time_ns + offset_ns);
   if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) {
-    xevent.AddStatValue(StatType::kCorrelationId, event.correlation_id);
+    xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+                            GetStatTypeStr(StatType::kCorrelationId)),
+                        event.correlation_id);
   }
   if (event.context_id != CuptiTracerEvent::kInvalidContextId) {
     xevent.AddStatValue(
-        StatType::kContextId,
+        *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)),
         absl::StrCat("$$", static_cast<uint64>(event.context_id)));
   }
   if (event.type == CuptiTracerEventType::Kernel) {
@@ -76,7 +78,9 @@
                         event.kernel_info.grid_x, event.kernel_info.grid_y,
                         event.kernel_info.grid_z, event.kernel_info.block_x,
                         event.kernel_info.block_y, event.kernel_info.block_z);
-    xevent.AddStatValue(StatType::kKernelDetails, kernel_details);
+    xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+                            GetStatTypeStr(StatType::kKernelDetails)),
+                        kernel_details);
   }
   if (event.type == CuptiTracerEventType::MemcpyH2D ||
       event.type == CuptiTracerEventType::MemcpyD2H ||
@@ -87,19 +91,23 @@
     std::string memcpy_details =
         absl::StrFormat("size:%u dest:%u async:%u", memcpy_info.num_bytes,
                         memcpy_info.destination, memcpy_info.async);
-    xevent.AddStatValue(StatType::kMemcpyDetails, memcpy_details);
+    xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+                            GetStatTypeStr(StatType::kMemcpyDetails)),
+                        memcpy_details);
   }
   if (event.type == CuptiTracerEventType::MemoryAlloc) {
     std::string memalloc_details =
         absl::StrFormat("num_bytes:%u", event.memalloc_info.num_bytes);
-    xevent.AddStatValue(StatType::kMemallocDetails, memalloc_details);
+    xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+                            GetStatTypeStr(StatType::kMemallocDetails)),
+                        memalloc_details);
   }
 
   std::vector<Annotation> annotation_stack =
       ParseAnnotationStack(event.annotation);
   for (int i = 0; i < annotation_stack.size(); ++i) {
     xevent.AddStatValue(
-        plane->GetOrCreateStatMetadata(absl::StrCat("level ", i))->id(),
+        *plane->GetOrCreateStatMetadata(absl::StrCat("level ", i)),
         annotation_stack[i].name);
   }
   // If multiple metadata have the same key name, show the values from the top
@@ -113,7 +121,7 @@
         continue;  // ignored, obtained from HLO proto via DebugInfoMap
       } else if (key_set.insert(metadata.key).second) {
         xevent.ParseAndAddStatValue(
-            plane->GetOrCreateStatMetadata(metadata.key)->id(), metadata.value);
+            *plane->GetOrCreateStatMetadata(metadata.key), metadata.value);
       }
     }
   }
@@ -328,14 +336,19 @@
       auto clock_rate_in_khz =
           GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_CLOCK_RATE);
       if (clock_rate_in_khz) {
-        device_plane->AddStatValue(StatType::kDevCapClockRateKHz,
-                                   *clock_rate_in_khz);
+        device_plane->AddStatValue(
+            *device_plane->GetOrCreateStatMetadata(
+                GetStatTypeStr(StatType::kDevCapClockRateKHz)),
+            *clock_rate_in_khz);
       }
 
       auto core_count =
           GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT);
       if (core_count) {
-        device_plane->AddStatValue(StatType::kDevCapCoreCount, *core_count);
+        device_plane->AddStatValue(
+            *device_plane->GetOrCreateStatMetadata(
+                GetStatTypeStr(StatType::kDevCapCoreCount)),
+            *core_count);
       }
 
       auto mem_clock_khz =
@@ -347,27 +360,35 @@
         // data lane.
         auto memory_bandwidth =
             2ULL * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8;
-        device_plane->AddStatValue(StatType::kDevCapMemoryBandwidth,
-                                   memory_bandwidth);
+        device_plane->AddStatValue(
+            *device_plane->GetOrCreateStatMetadata(
+                GetStatTypeStr(StatType::kDevCapMemoryBandwidth)),
+            memory_bandwidth);
       }
 
       size_t total_memory = 0;
       if (cuDeviceTotalMem(&total_memory, device) == CUDA_SUCCESS) {
-        device_plane->AddStatValue(StatType::kDevCapMemorySize,
-                                   static_cast<uint64>(total_memory));
+        device_plane->AddStatValue(
+            *device_plane->GetOrCreateStatMetadata(
+                GetStatTypeStr(StatType::kDevCapMemorySize)),
+            static_cast<uint64>(total_memory));
       }
 
       auto compute_capability_major = GetDeviceAttribute(
           device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR);
       if (compute_capability_major) {
-        device_plane->AddStatValue(StatType::kDevCapComputeCapMajor,
-                                   *compute_capability_major);
+        device_plane->AddStatValue(
+            *device_plane->GetOrCreateStatMetadata(
+                GetStatTypeStr(StatType::kDevCapComputeCapMajor)),
+            *compute_capability_major);
       }
       auto compute_capability_minor = GetDeviceAttribute(
           device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR);
       if (compute_capability_minor) {
-        device_plane->AddStatValue(StatType::kDevCapComputeCapMinor,
-                                   *compute_capability_minor);
+        device_plane->AddStatValue(
+            *device_plane->GetOrCreateStatMetadata(
+                GetStatTypeStr(StatType::kDevCapComputeCapMinor)),
+            *compute_capability_minor);
       }
     }
 
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index fc3eb63..41e1fa2 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -117,7 +117,6 @@
     deps = [
         ":tf_op_utils",
         ":time_utils",
-        ":xplane_schema",
         "//tensorflow/core:lib",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
diff --git a/tensorflow/core/profiler/utils/metadata_matcher.cc b/tensorflow/core/profiler/utils/metadata_matcher.cc
index 9d95161..7abdd77 100644
--- a/tensorflow/core/profiler/utils/metadata_matcher.cc
+++ b/tensorflow/core/profiler/utils/metadata_matcher.cc
@@ -21,7 +21,9 @@
 namespace profiler {
 namespace {
 
+using ::tensorflow::profiler::XEvent;
 using ::tensorflow::profiler::XPlane;
+using ::tensorflow::profiler::XStat;
 
 absl::flat_hash_map<int64, int> CreateEventMetadataMap(
     const XPlane& xplane,
@@ -49,17 +51,95 @@
   return id_to_event_type_map;
 }
 
+absl::flat_hash_map<int64, int> CreateStatMetadataMap(
+    const XPlane& xplane,
+    const absl::Span<const absl::string_view> stat_type_str_map) {
+  absl::flat_hash_map<int64, int> id_to_stat_type_map;
+  for (const auto& id_and_stat_metadata : xplane.stat_metadata()) {
+    int64 id = id_and_stat_metadata.first;
+    absl::string_view stat_name = id_and_stat_metadata.second.name();
+    for (int stat_type = 0; stat_type < stat_type_str_map.size(); ++stat_type) {
+      if (stat_type_str_map[stat_type] == stat_name) {
+        id_to_stat_type_map[id] = stat_type;
+        break;
+      }
+    }
+  }
+  return id_to_stat_type_map;
+}
+
 }  // namespace
 
 MetadataMatcher::MetadataMatcher(
     const XPlane& xplane,
     const std::vector<std::pair<const absl::Span<const absl::string_view>,
                                 /*first_event_type*/ int>>&
-        event_type_metadata_maps)
+        event_type_metadata_maps,
+    const absl::Span<const absl::string_view> stat_type_str_map)
     : id_to_event_type_map_(
           CreateEventMetadataMap(xplane, event_type_metadata_maps)),
+      id_to_stat_type_map_(CreateStatMetadataMap(xplane, stat_type_str_map)),
       event_type_to_id_map_(gtl::ReverseMap<decltype(event_type_to_id_map_)>(
-          id_to_event_type_map_)) {}
+          id_to_event_type_map_)),
+      stat_type_to_id_map_(gtl::ReverseMap<decltype(stat_type_to_id_map_)>(
+          id_to_stat_type_map_)) {}
+
+const XStat* MetadataMatcher::GetStat(const XEvent& event,
+                                      int stat_type) const {
+  for (const auto& stat : event.stats()) {
+    if (GetStatType(stat) == stat_type) {
+      return &stat;
+    }
+  }
+  return nullptr;
+}
+
+absl::optional<std::tuple<const XStat*, const XStat*>>
+MetadataMatcher::GetStats(const XEvent& event, int first_stat_type,
+                          int second_stat_type) const {
+  const XStat* first_stat = nullptr;
+  const XStat* second_stat = nullptr;
+  for (const auto& stat : event.stats()) {
+    if (GetStatType(stat) == first_stat_type) {
+      first_stat = &stat;
+    } else if (GetStatType(stat) == second_stat_type) {
+      second_stat = &stat;
+    }
+  }
+  if (first_stat && second_stat) {
+    return std::make_tuple(first_stat, second_stat);
+  }
+  return absl::nullopt;
+}
+
+absl::optional<std::tuple<const XStat*, const XStat*, const XStat*>>
+MetadataMatcher::GetStats(const XEvent& event, int first_stat_type,
+                          int second_stat_type, int third_stat_type) const {
+  const XStat* first_stat = nullptr;
+  const XStat* second_stat = nullptr;
+  const XStat* third_stat = nullptr;
+  for (const auto& stat : event.stats()) {
+    if (GetStatType(stat) == first_stat_type) {
+      first_stat = &stat;
+    } else if (GetStatType(stat) == second_stat_type) {
+      second_stat = &stat;
+    } else if (GetStatType(stat) == third_stat_type) {
+      third_stat = &stat;
+    }
+  }
+  if (first_stat && second_stat && third_stat) {
+    return std::make_tuple(first_stat, second_stat, third_stat);
+  }
+  return absl::nullopt;
+}
+
+absl::optional<int64> MetadataMatcher::GetIntStatValue(const XEvent& event,
+                                                       int stat_type) const {
+  if (const XStat* stat = GetStat(event, stat_type)) {
+    return stat->int64_value();
+  }
+  return absl::nullopt;
+}
 
 }  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/metadata_matcher.h b/tensorflow/core/profiler/utils/metadata_matcher.h
index 40f0e5f..beaba5e 100644
--- a/tensorflow/core/profiler/utils/metadata_matcher.h
+++ b/tensorflow/core/profiler/utils/metadata_matcher.h
@@ -27,18 +27,19 @@
 namespace tensorflow {
 namespace profiler {
 
-// Builds mapping between metadata ids and interesting event types. Event types
-// are represented in integer ids. Multiple spans of event types can be passed
-// with offset values (i.e., first_event_type) to be used to calculate integer
-// ids for event types. Spans and offset values are expected to result in a
-// unique integer id for each event type.
+// Builds mapping between metadata ids and interesting event and stat types.
+// Event and stat types are represented in integer ids. Multiple spans of event
+// types can be passed with offset values (i.e., first_event_type) to be
+// used to calculate integer ids for event types. Spans and offset values are
+// expected to result in a unique integer id for each event type.
 class MetadataMatcher {
  public:
   explicit MetadataMatcher(
       const XPlane& xplane,
       const std::vector<std::pair<const absl::Span<const absl::string_view>,
                                   /*first_event_type*/ int>>&
-          event_type_metadata_maps);
+          event_type_metadata_maps,
+      const absl::Span<const absl::string_view> stat_type_str_map);
 
   // Returns EventType if input is one of interesting event types.
   // Otherwise, it returns kUnknownEventType.
@@ -63,12 +64,42 @@
     return absl::nullopt;
   }
 
+  // Returns StatType if input is one of interesting stat types.
+  // Otherwise, it returns kUnknownStatType.
+  int GetStatType(const XStat& xstat) const {
+    return gtl::FindWithDefault(id_to_stat_type_map_, xstat.metadata_id(),
+                                /*kUnknownStatType*/ 0);
+  }
+
+  // Returns metadata id if xplane has the input stat type.
+  absl::optional<int64> GetStatMetadataId(int stat_type) const {
+    if (const int64* id = gtl::FindOrNull(stat_type_to_id_map_, stat_type)) {
+      return *id;
+    }
+    return absl::nullopt;
+  }
+
+  const XStat* GetStat(const XEvent& event, int stat_type) const;
+
+  absl::optional<std::tuple<const XStat*, const XStat*>> GetStats(
+      const XEvent& event, int first_stat_type, int second_stat_type) const;
+
+  absl::optional<std::tuple<const XStat*, const XStat*, const XStat*>> GetStats(
+      const XEvent& event, int first_stat_type, int second_stat_type,
+      int third_stat_type) const;
+
+  absl::optional<int64> GetIntStatValue(const XEvent& event,
+                                        int stat_type) const;
+
  private:
-  // Maps from metada ids to interesting event types. Uninteresting event types
-  // are not cached in these maps and considered to be kUnknownEvent.
+  // Maps from metada ids to interesting event and stat types.
+  // Uninteresting event and stat types are not cached in these maps and
+  // considered to be kUnknown*.
   const absl::flat_hash_map<int64, int> id_to_event_type_map_;
+  const absl::flat_hash_map<int64, int> id_to_stat_type_map_;
   // Reverse of the above.
   const absl::flat_hash_map<int, int64> event_type_to_id_map_;
+  const absl::flat_hash_map<int, int64> stat_type_to_id_map_;
 };
 
 }  // namespace profiler
diff --git a/tensorflow/core/profiler/utils/metadata_matcher_test.cc b/tensorflow/core/profiler/utils/metadata_matcher_test.cc
index bfbfc9a..d430b44 100644
--- a/tensorflow/core/profiler/utils/metadata_matcher_test.cc
+++ b/tensorflow/core/profiler/utils/metadata_matcher_test.cc
@@ -26,6 +26,7 @@
 
 using ::tensorflow::profiler::XEventMetadata;
 using ::tensorflow::profiler::XPlane;
+using ::tensorflow::profiler::XStatMetadata;
 
 TEST(MetadataMatcherTest, GetHostEventTypeTest) {
   for (int event_type = HostEventType::kFirstHostEventType;
@@ -37,13 +38,32 @@
         GetHostEventTypeStr(static_cast<HostEventType>(event_type))));
     MetadataMatcher metadata_matcher(
         xplane,
-        {{GetHostEventTypeStrMap(), HostEventType::kFirstHostEventType}});
+        {{GetHostEventTypeStrMap(), HostEventType::kFirstHostEventType}},
+        GetStatTypeStrMap());
     XEvent event;
     event.set_metadata_id(0);
     EXPECT_EQ(metadata_matcher.GetEventType(event), event_type);
   }
 }
 
+TEST(MetadataMatcherTest, GetStatTypeTest) {
+  for (int stat_type = StatType::kFirstStatType;
+       stat_type <= StatType::kLastStatType; ++stat_type) {
+    XPlane xplane;
+    XStatMetadata& metadata = (*xplane.mutable_stat_metadata())[0];
+    metadata.set_id(0);
+    metadata.set_name(
+        std::string(GetStatTypeStr(static_cast<StatType>(stat_type))));
+    MetadataMatcher metadata_matcher(
+        xplane,
+        {{GetHostEventTypeStrMap(), HostEventType::kFirstHostEventType}},
+        GetStatTypeStrMap());
+    XStat stat;
+    stat.set_metadata_id(0);
+    EXPECT_EQ(metadata_matcher.GetStatType(stat), stat_type);
+  }
+}
+
 }  // namespace
 }  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/utils/xplane_builder.cc b/tensorflow/core/profiler/utils/xplane_builder.cc
index b6230be..e2aec65 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.cc
+++ b/tensorflow/core/profiler/utils/xplane_builder.cc
@@ -14,9 +14,7 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
 
 namespace tensorflow {
 namespace profiler {
@@ -28,23 +26,10 @@
         std::max<int64>(last_event_metadata_id_, iter.second.id());
     event_metadata_by_name_.try_emplace(iter.second.name(), &iter.second);
   }
-  if (plane->stat_metadata_size() == 0) {
-    // Add reserved stat metadata.
-    for (const auto& stat_name_and_type : GetStatTypeMap()) {
-      XStatMetadata* metadata =
-          GetOrCreateStatMetadata(stat_name_and_type.second);
-      metadata->set_name(std::string(stat_name_and_type.first));
-      stat_metadata_by_name_.try_emplace(stat_name_and_type.first, metadata);
-    }
-    last_stat_metadata_id_ = kLastStatType;
-  } else {
-    // If plane is not empty, reserved stat metadata should have been added
-    // the first time XPlaneBuilder was called.
-    for (auto& iter : *plane->mutable_stat_metadata()) {
-      last_stat_metadata_id_ =
-          std::max<int64>(last_stat_metadata_id_, iter.second.id());
-      stat_metadata_by_name_.try_emplace(iter.second.name(), &iter.second);
-    }
+  for (auto& iter : *plane->mutable_stat_metadata()) {
+    last_stat_metadata_id_ =
+        std::max<int64>(last_stat_metadata_id_, iter.second.id());
+    stat_metadata_by_name_.try_emplace(iter.second.name(), &iter.second);
   }
   for (XLine& line : *plane->mutable_lines()) {
     lines_by_id_.try_emplace(line.id(), &line);
diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h
index 2a5e4c8..99a554d 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.h
+++ b/tensorflow/core/profiler/utils/xplane_builder.h
@@ -31,26 +31,26 @@
  public:
   explicit XStatsBuilder(T* stats_owner) : stats_owner_(stats_owner) {}
 
-  void AddStatValue(int64 metadata_id, uint32 value) {
-    AddStat(metadata_id)->set_uint64_value(value);
+  void AddStatValue(const XStatMetadata& metadata, uint32 value) {
+    AddStat(metadata)->set_uint64_value(value);
   }
-  void AddStatValue(int64 metadata_id, uint64 value) {
-    AddStat(metadata_id)->set_uint64_value(value);
+  void AddStatValue(const XStatMetadata& metadata, uint64 value) {
+    AddStat(metadata)->set_uint64_value(value);
   }
-  void AddStatValue(int64 metadata_id, int32 value) {
-    AddStat(metadata_id)->set_int64_value(value);
+  void AddStatValue(const XStatMetadata& metadata, int32 value) {
+    AddStat(metadata)->set_int64_value(value);
   }
-  void AddStatValue(int64 metadata_id, int64 value) {
-    AddStat(metadata_id)->set_int64_value(value);
+  void AddStatValue(const XStatMetadata& metadata, int64 value) {
+    AddStat(metadata)->set_int64_value(value);
   }
-  void AddStatValue(int64 metadata_id, double value) {
-    AddStat(metadata_id)->set_double_value(value);
+  void AddStatValue(const XStatMetadata& metadata, double value) {
+    AddStat(metadata)->set_double_value(value);
   }
-  void AddStatValue(int64 metadata_id, absl::string_view value) {
-    AddStat(metadata_id)->set_str_value(string(value));
+  void AddStatValue(const XStatMetadata& metadata, absl::string_view value) {
+    AddStat(metadata)->set_str_value(string(value));
   }
-  void AddStatValue(int64 metadata_id, string&& value) {
-    AddStat(metadata_id)->set_str_value(std::move(value));
+  void AddStatValue(const XStatMetadata& metadata, string&& value) {
+    AddStat(metadata)->set_str_value(std::move(value));
   }
 
   void AddStat(const XStatMetadata& metadata, const XStat& stat) {
@@ -58,18 +58,19 @@
     *stats_owner_->add_stats() = stat;
   }
 
-  void ParseAndAddStatValue(int64 metadata_id, absl::string_view value) {
+  void ParseAndAddStatValue(const XStatMetadata& metadata,
+                            absl::string_view value) {
     int64 int_value;
     uint64 uint_value;
     double double_value;
     if (absl::SimpleAtoi(value, &int_value)) {
-      AddStatValue(metadata_id, int_value);
+      AddStatValue(metadata, int_value);
     } else if (absl::SimpleAtoi(value, &uint_value)) {
-      AddStatValue(metadata_id, uint_value);
+      AddStatValue(metadata, uint_value);
     } else if (absl::SimpleAtod(value, &double_value)) {
-      AddStatValue(metadata_id, double_value);
+      AddStatValue(metadata, double_value);
     } else {
-      AddStatValue(metadata_id, value);
+      AddStatValue(metadata, value);
     }
   }
   void ReserveStats(size_t num_stats) {
@@ -77,9 +78,9 @@
   }
 
  private:
-  XStat* AddStat(int64 metadata_id) {
+  XStat* AddStat(const XStatMetadata& metadata) {
     XStat* stat = stats_owner_->add_stats();
-    stat->set_metadata_id(metadata_id);
+    stat->set_metadata_id(metadata.id());
     return stat;
   }
 
diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc
index 767c01d..39e14ef 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.cc
+++ b/tensorflow/core/profiler/utils/xplane_schema.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 
@@ -94,7 +95,6 @@
     "memcpy_details",
     "memalloc_details",
     "kernel_details",
-    "stream",
     "group_id",
     "step_name",
     "level 0",
@@ -121,8 +121,6 @@
   return absl::MakeConstSpan(kStatTypeStrMap, kNumStatTypes);
 }
 
-int GetNumStatTypes() { return kNumStatTypes; }
-
 const absl::flat_hash_map<absl::string_view, StatType>& GetStatTypeMap() {
   static absl::flat_hash_map<absl::string_view, StatType>* stats_type_map =
       new absl::flat_hash_map<absl::string_view, StatType>({
@@ -155,7 +153,6 @@
           {"memcpy_details", kMemcpyDetails},
           {"memalloc_details", kMemallocDetails},
           {"kernel_details", kKernelDetails},
-          {"stream", kStream},
           // Stats added when processing traces.
           {"group_id", kGroupId},
           {"step_name", kStepName},
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index fcd1d8d..743fedf 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -16,7 +16,6 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
 
-#include "absl/container/flat_hash_map.h"
 #include "absl/strings/match.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -65,9 +64,8 @@
   kLastHostEventType = kPartitionedCallOp,
 };
 
-// TODO(jihochoi): Rename it to ReservedStatMetadataId.
 enum StatType {
-  kFirstStatType = 1 << 10,
+  kFirstStatType = 0,
   kUnknownStatType = kFirstStatType,
   // TraceMe arguments.
   kStepId,
@@ -97,7 +95,6 @@
   kMemcpyDetails,
   kMemallocDetails,
   kKernelDetails,
-  kStream,
   // Stats added when processing traces.
   kGroupId,
   kStepName,
@@ -129,19 +126,15 @@
 absl::Span<const absl::string_view> GetStatTypeStrMap();
 
 inline absl::string_view GetStatTypeStr(StatType stat_type) {
-  return GetStatTypeStrMap()[stat_type - StatType::kFirstStatType];
+  return GetStatTypeStrMap()[stat_type];
 }
 
 inline bool IsStatType(StatType stat_type, absl::string_view stat_name) {
   return GetStatTypeStr(stat_type) == stat_name;
 }
 
-const absl::flat_hash_map<absl::string_view, StatType>& GetStatTypeMap();
-
 StatType GetStatType(absl::string_view stat_name);
 
-int GetNumStatTypes();
-
 }  // namespace profiler
 }  // namespace tensorflow