Use HloModule::unique_id() to identify distinct HLO modules

PiperOrigin-RevId: 449316603
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fa24410..4eed032 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -6264,6 +6264,7 @@
         ":xla_debug_info_manager",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 35cb1fc..2f2b1e9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -51,15 +51,6 @@
 namespace xla {
 namespace cpu {
 
-static std::string ModuleUniqueName(absl::string_view module_name,
-                                    const HloModule* module) {
-  std::string unique_id;
-  if (module != nullptr) {
-    unique_id = absl::StrCat("module.", module->unique_id(), ".");
-  }
-  return absl::StrCat(unique_id, module_name);
-}
-
 CpuExecutable::CpuExecutable(
     std::unique_ptr<SimpleOrcJIT> jit,
     std::unique_ptr<const BufferAssignment> assignment,
@@ -75,9 +66,10 @@
   if (assignment_) {
     buffer_assignment_.reset(new BufferAssignmentProto(assignment_->ToProto()));
   }
-  XlaDebugInfoManager::Get()->RegisterModule(
-      ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
-      buffer_assignment_);
+  if (has_module()) {
+    XlaDebugInfoManager::Get()->RegisterModule(
+        module().unique_id(), shared_module(), buffer_assignment_);
+  }
 
   // Resolve symbols in the constructor rather than at execution time to avoid
   // races because FindSymbol is not thread safe.
@@ -95,9 +87,9 @@
 }
 
 CpuExecutable::~CpuExecutable() {
-  XlaDebugInfoManager::Get()->UnregisterModule(
-      ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
-      buffer_assignment_);
+  if (has_module()) {
+    XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
+  }
 }
 
 static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index a027227..98d5501 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -121,15 +121,6 @@
   }
 }
 
-static std::string ModuleUniqueName(absl::string_view module_name,
-                                    const HloModule* module) {
-  std::string unique_id;
-  if (module != nullptr) {
-    unique_id = absl::StrCat("module.", module->unique_id(), ".");
-  }
-  return absl::StrCat(unique_id, module_name);
-}
-
 }  // namespace
 
 void GpuExecutable::BefBufferDeleter::operator()(uint8_t* ptr) const {
@@ -336,9 +327,10 @@
           params.verbose_buffer_assignment_string_dumper),
       constants_(std::move(params.constants)),
       output_info_(std::move(params.output_info)) {
-  XlaDebugInfoManager::Get()->RegisterModule(
-      ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
-      debug_buffer_assignment_);
+  if (has_module()) {
+    XlaDebugInfoManager::Get()->RegisterModule(
+        module().unique_id(), shared_module(), debug_buffer_assignment_);
+  }
 }
 
 GpuExecutable::GpuExecutable(
@@ -356,15 +348,16 @@
       allocations_(std::move(allocations)),
       output_info_(std::move(output_info)),
       bef_executable_(bef_executable) {
-  XlaDebugInfoManager::Get()->RegisterModule(
-      ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
-      debug_buffer_assignment_);
+  if (has_module()) {
+    XlaDebugInfoManager::Get()->RegisterModule(
+        module().unique_id(), shared_module(), debug_buffer_assignment_);
+  }
 }
 
 GpuExecutable::~GpuExecutable() {
-  XlaDebugInfoManager::Get()->UnregisterModule(
-      ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
-      debug_buffer_assignment_);
+  if (has_module()) {
+    XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
+  }
 
   {
     // We could have issued host->device mem copies in ResolveConstantGlobals.
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 537ba89..c390fb2 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -292,8 +292,12 @@
   HloModuleProto proto;
   proto.set_id(unique_id_);
   proto.set_name(name_);
-  proto.set_entry_computation_name(entry_computation_->name());
-  proto.set_entry_computation_id(entry_computation_->unique_id());
+  if (entry_computation_) {
+    proto.set_entry_computation_name(entry_computation_->name());
+    proto.set_entry_computation_id(entry_computation_->unique_id());
+    *proto.mutable_host_program_shape() =
+        entry_computation_layout().ComputeProgramShape().ToProto();
+  }
   for (const HloComputation* computation : MakeComputationPostOrder()) {
     HloComputationProto computation_proto = computation->ToProto();
     proto.add_computations()->Swap(&computation_proto);
@@ -301,8 +305,6 @@
   if (has_schedule()) {
     *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
   }
-  *proto.mutable_host_program_shape() =
-      entry_computation_layout().ComputeProgramShape().ToProto();
   *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
   *proto.mutable_dynamic_parameter_binding() =
       dynamic_parameter_binding().ToProto();
diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc
index 5c398fd..7b11c0d 100644
--- a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc
+++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc
@@ -15,49 +15,39 @@
 
 #include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
 
+#include <memory>
+#include <string>
+#include <utility>
+
 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
 
 namespace xla {
 
 void XlaDebugInfoManager::RegisterModule(
-    const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
+    ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
     std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
+  CHECK(hlo_module != nullptr && module_id == hlo_module->unique_id());
   absl::MutexLock lock(&mutex_);
-  if (active_modules_.find(module_id) != active_modules_.end()) {
-    active_modules_[module_id].instances.emplace_back(hlo_module,
-                                                      buffer_assignment);
-  } else {
-    XlaModuleEntry m;
-    m.module_id = module_id;
-    m.instances.emplace_back(hlo_module, buffer_assignment);
-    active_modules_[module_id] = std::move(m);
-  }
+  auto result = modules_.try_emplace(module_id);
+  CHECK(result.second);
+  XlaModuleEntry& m = result.first->second;
+  m.hlo_module = std::move(hlo_module);
+  m.buffer_assignment = std::move(buffer_assignment);
+  m.active = true;
 }
 
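-// Unregister an active module, when the last active module of the same
-// module id is out of scope, we remove it from our database.
-// However during tracing, we will defer the cleanup after serialization.
+// Unregisters a module. Outside of a tracing session the entry is removed
+// immediately; during tracing it is only marked inactive so that StopTracing()
+// can still serialize it before dropping it.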
-void XlaDebugInfoManager::UnregisterModule(
-    const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
-    std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
+void XlaDebugInfoManager::UnregisterModule(ModuleIdentifier module_id) {
   absl::MutexLock lock(&mutex_);
-  CHECK(active_modules_.find(module_id) != active_modules_.end());
-  XlaModuleEntry& active_module = active_modules_[module_id];
-  auto instance_it =
-      absl::c_find_if(active_module.instances, [&](XlaModuleInstance& e) {
-        return e.hlo_module == hlo_module &&
-               e.buffer_assignment == buffer_assignment;
-      });
-
-  CHECK(instance_it != active_module.instances.end());
-
+  auto it = modules_.find(module_id);
+  CHECK(it != modules_.end());
   if (!tracing_active_) {
-    active_module.instances.erase(instance_it);
-    if (active_module.instances.empty()) {
-      active_modules_.erase(module_id);
-    }
+    modules_.erase(it);
   } else {
-    instance_it->active = false;
+    XlaModuleEntry& m = it->second;
+    m.active = false;
   }
 }
 
@@ -67,41 +57,23 @@
 }
 
 void XlaDebugInfoManager::StopTracing(
-    std::vector<XlaModuleDebugInfo>* module_debug_info) {
+    std::vector<std::unique_ptr<HloProto>>* module_debug_info) {
   std::vector<XlaModuleEntry> modules_to_serialize;
   {
     absl::MutexLock lock(&mutex_);
     if (!tracing_active_) return;
     tracing_active_ = false;
-    for (const auto& traced_module_id : active_modules_) {
-      const XlaModuleEntry& active_module = traced_module_id.second;
 
-      // Copy the instance so that we can serialize without holding the lock.
-      // All instances are equivalent from the perspective of symbolization.
-      // We only use the first one.
-      if (!active_module.instances.empty()) {
-        XlaModuleEntry e;
-        e.module_id = active_module.module_id;
-        e.instances.push_back(active_module.instances[0]);
-        modules_to_serialize.push_back(std::move(e));
-      }
-    }
-
-    // Remove all active modules which have an instance count equal to zero.
-    for (auto it = active_modules_.begin(); it != active_modules_.end();) {
-      auto& active_module = it->second;
-      for (auto instance = active_module.instances.begin();
-           instance != active_module.instances.end();) {
-        if (instance->active) {
-          ++instance;
-        } else {
-          instance = active_module.instances.erase(instance);
-        }
-      }
-
-      if (active_module.instances.empty()) {
-        active_modules_.erase(it++);
+    // Copy all modules so we can serialize without holding the lock, and remove
+    // all inactive modules.
+    modules_to_serialize.reserve(modules_.size());
+    for (auto it = modules_.begin(); it != modules_.end();) {
+      auto& m = it->second;
+      if (!m.active) {
+        modules_to_serialize.emplace_back(std::move(m));
+        modules_.erase(it++);
       } else {
+        modules_to_serialize.emplace_back(m);
         ++it;
       }
     }
@@ -110,18 +82,14 @@
   if (module_debug_info) {
     module_debug_info->clear();
     for (const auto& m : modules_to_serialize) {
-      XlaModuleDebugInfo info;
-      info.module_id = m.module_id;
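-      // In real world, hlo_module and buffer_assignment will always be
-      // non-nullptr. Due to the inconvenience of creation of buffer_assignment
-      // object in test, we set it to nullptr and guard this for it.
+      // hlo_module is never null here (RegisterModule CHECKs it), but
+      // buffer_assignment may be: tests pass nullptr because constructing one
+      // is inconvenient, so it is copied only when present.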
-      if (m.instances[0].hlo_module && m.instances[0].buffer_assignment) {
-        info.hlo_proto = absl::make_unique<HloProto>(
-            MakeHloProto(*m.instances[0].hlo_module));
-        *info.hlo_proto->mutable_buffer_assignment() =
-            *m.instances[0].buffer_assignment;
+      auto hlo_proto = absl::make_unique<HloProto>(MakeHloProto(*m.hlo_module));
+      if (m.buffer_assignment != nullptr) {
+        *hlo_proto->mutable_buffer_assignment() = *m.buffer_assignment;
       }
-      module_debug_info->emplace_back(std::move(info));
+      module_debug_info->emplace_back(std::move(hlo_proto));
     }
   }
 }
diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.h b/tensorflow/compiler/xla/service/xla_debug_info_manager.h
index 6f05222..8f29f1a 100644
--- a/tensorflow/compiler/xla/service/xla_debug_info_manager.h
+++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.h
@@ -26,22 +26,11 @@
 
 namespace xla {
 
-using ModuleIdentifier = std::string;
+using ModuleIdentifier = int;
 
-struct XlaModuleDebugInfo {
-  ModuleIdentifier module_id;
-  // The hlo proto associated with this xla program.
-  std::unique_ptr<HloProto> hlo_proto;
-};
-
-// Debug info manager keeps track of all the debug information (symbol table,
-// HLO proto etc) during tracing period. Because tracing period can start
-// during module execution, therefore even when tracing is off, we still need
-// minimum level of monitoring (i.e. which program is running lately).
-// We allow multiple programs with the same module_id, however from tracing
-// debug information perspective, same module id implies the same debug
-// information. We will only keep track unique debug information, identified
-// by module_id.
+// XlaDebugInfoManager tracks all XLA programs (Executables) throughout their
+// lifetime. Because the tracing period can start during an Executable's
+// execution, we need to track Executables even when tracing is off.
 // This class is thread-safe.
 class XlaDebugInfoManager {
  public:
@@ -50,71 +39,42 @@
     return singleton;
   }
 
-  // Register an active module to XlaDebugInfoManager. We will keep track all
-  // existing HloModules within the process.
-  // Modules with same module id can be registered and tracked separately.
+  // Registers an active module to XlaDebugInfoManager.
+  // The module_id is expected to be unique per process.
   void RegisterModule(
-      const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
+      ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
       std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
 
-  // Unregister an active module. When the last active module of the same
-  // module id is out of scope, we remove it from our database.
-  // However during tracing, we will defer the cleanup after serialization.
-  void UnregisterModule(
-      const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
-      std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
+  // Unregisters an active module.
+  void UnregisterModule(ModuleIdentifier module_id);
 
-  // Start tracing, began to collecting debug information for all the running
+  // Starts tracing: begins collecting debug information for all the running
   // modules during the tracing period.
   void StartTracing();
 
-  // Stop tracing and drop all instances that have been stoped during tracing,
-  // Then drop all modules that have no instances registered. Dump debug
-  // information for all the running modules to module_debug_info if specified.
+  // Stops tracing.
+  // If module_debug_info is not null, fills it with debug information for all
+  // the modules that were alive at any point since StartTracing().
   void StopTracing(
-      std::vector<XlaModuleDebugInfo>* module_debug_info = nullptr);
+      std::vector<std::unique_ptr<HloProto>>* module_debug_info = nullptr);
 
-  friend class XlaDebugInfoManagerTest;
+  friend class XlaDebugInfoManagerTestPeer;
 
  private:
-  XlaDebugInfoManager() {}
+  XlaDebugInfoManager() = default;
 
-  std::set<ModuleIdentifier> GetActiveModules() {
-    absl::MutexLock lock(&mutex_);
-    std::set<ModuleIdentifier> active;
-    for (const auto& id : active_modules_) {
-      active.insert(id.first);
-    }
-    return active;
-  }
-
-  // We track each instance of GpuExecutable. Assuming multiple GpuExecutable
-  // can have same unique id if they are actually same program. From the
-  // perspective of symbol table, they are identical, but for the life time
-  // tracking, they need to be tracked separately.
-  struct XlaModuleInstance {
-    XlaModuleInstance(std::shared_ptr<HloModule> m,
-                      std::shared_ptr<const BufferAssignmentProto> b)
-        : hlo_module(std::move(m)), buffer_assignment(std::move(b)) {}
-    std::shared_ptr<HloModule> hlo_module;
-    std::shared_ptr<const BufferAssignmentProto> buffer_assignment;
-    bool active = true;
-  };
-
-  // Each XlaModuleEntry can have multiple XlaModuleInstance's if XlA registers
-  // them with the same ModuleIdentifier.
   struct XlaModuleEntry {
-    // The module symbol table/debug info that shared by all instances.
-    ModuleIdentifier module_id;
-    std::vector<XlaModuleInstance> instances;
+    std::shared_ptr<const HloModule> hlo_module;
+    std::shared_ptr<const BufferAssignmentProto> buffer_assignment;
+    bool active = false;
   };
 
-  absl::Mutex mutex_;
+  mutable absl::Mutex mutex_;
   bool tracing_active_ ABSL_GUARDED_BY(mutex_) = false;
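-  // Active modules are those still tracked by us. There could be much more
-  // active modules than running modules, we will try to reduce the trace size
-  // by only transfer those modules that were running during tracing period.
+  // Active modules are those still tracked by us. There could be many more
+  // active modules than running modules; we will try to reduce the trace size
+  // by only transferring modules that were running during the tracing period.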
-  absl::flat_hash_map<ModuleIdentifier, XlaModuleEntry> active_modules_
+  absl::flat_hash_map<ModuleIdentifier, XlaModuleEntry> modules_
       ABSL_GUARDED_BY(mutex_);
 };
 
diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc
index 608e009..1aa459e 100644
--- a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc
+++ b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc
@@ -14,14 +14,57 @@
 ==============================================================================*/
 #include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
 
+#include <memory>
 #include <string>
 #include <utility>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 
 namespace xla {
 
+class XlaDebugInfoManagerTestPeer {
+ public:
+  void RegisterModule(
+      ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
+      std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
+    return xla_debug_info_manager_.RegisterModule(module_id, hlo_module,
+                                                  buffer_assignment);
+  }
+
+  void UnregisterModule(ModuleIdentifier module_id) {
+    return xla_debug_info_manager_.UnregisterModule(module_id);
+  }
+
+  void StartTracing() { return xla_debug_info_manager_.StartTracing(); }
+
+  absl::flat_hash_set<ModuleIdentifier> StopTracing() {
+    std::vector<std::unique_ptr<HloProto>> module_debug_info;
+    xla_debug_info_manager_.StopTracing(&module_debug_info);
+    absl::flat_hash_set<ModuleIdentifier> module_ids;
+    for (const auto& hlo_proto : module_debug_info) {
+      module_ids.insert(hlo_proto->hlo_module().id());
+    }
+    return module_ids;
+  }
+
+  absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
+    absl::flat_hash_set<ModuleIdentifier> module_ids;
+    absl::MutexLock lock(&xla_debug_info_manager_.mutex_);
+    for (const auto& it : xla_debug_info_manager_.modules_) {
+      module_ids.insert(it.first);
+    }
+    return module_ids;
+  }
+
+ private:
+  XlaDebugInfoManager xla_debug_info_manager_;
+};
+
+namespace {
+
+using ::testing::IsEmpty;
 using ::testing::UnorderedElementsAre;
 
 class XlaDebugInfoManagerTest : public HloTestBase {
@@ -29,89 +72,79 @@
   struct DebugMetadata {
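-    // We allow same id to be registered multiple times. we need unique id to
-    // know which program is referenced (such as in UnregisterProgram).
+    // Several programs may share a module name, so each one is identified by
+    // its HloModule::unique_id() (such as in UnregisterProgram).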
-    int unique_id;
-    std::string id;
+    ModuleIdentifier unique_id;
     std::shared_ptr<HloModule> module;
     std::shared_ptr<BufferAssignmentProto> buffer_assignment;
   };
 
   // Return unique id of this module.
-  int RegisterProgram(const std::string& module_id) {
+  ModuleIdentifier RegisterProgram(const std::string& module_name) {
     DebugMetadata debug_info;
     HloModuleConfig config;
-    debug_info.unique_id = ++serial_;
-    debug_info.id = module_id;
-    debug_info.module = std::make_shared<HloModule>(module_id, config);
+    debug_info.module = std::make_shared<HloModule>(module_name, config);
     debug_info.buffer_assignment = nullptr;
-    xla_debug_info_manager_.RegisterModule(module_id, debug_info.module,
+    ModuleIdentifier unique_id = debug_info.module->unique_id();
+    debug_info.unique_id = unique_id;
+    xla_debug_info_manager_.RegisterModule(unique_id, debug_info.module,
                                            debug_info.buffer_assignment);
     external_references_.push_back(std::move(debug_info));
-    return serial_;
+    return unique_id;
   }
 
-  void UnregisterProgram(int unique_id) {
+  void UnregisterProgram(ModuleIdentifier unique_id) {
     for (int i = 0; i < external_references_.size(); i++) {
       if (external_references_[i].unique_id == unique_id) {
-        xla_debug_info_manager_.UnregisterModule(
-            external_references_[i].id, external_references_[i].module,
-            external_references_[i].buffer_assignment);
+        xla_debug_info_manager_.UnregisterModule(unique_id);
         external_references_.erase(external_references_.begin() + i);
         break;
       }
     }
   }
 
-  std::set<ModuleIdentifier> GetActiveModule() {
-    return xla_debug_info_manager_.GetActiveModules();
+  absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
+    return xla_debug_info_manager_.GetModuleIds();
   }
 
   void StartTrace() { xla_debug_info_manager_.StartTracing(); }
 
-  std::set<ModuleIdentifier> StopTrace() {
-    std::vector<XlaModuleDebugInfo> module_debug_info;
-    xla_debug_info_manager_.StopTracing(&module_debug_info);
-    std::set<ModuleIdentifier> serialized;
-    for (const auto& module : module_debug_info) {
-      serialized.insert(module.module_id);
-    }
-    return serialized;
+  absl::flat_hash_set<ModuleIdentifier> StopTrace() {
+    return xla_debug_info_manager_.StopTracing();
   }
 
-  int serial_ = 0;
-
   // Simulation of compilation cache.
   std::vector<DebugMetadata> external_references_;
 
   // Use an instance per test instead of singleton to avoid interferences.
-  XlaDebugInfoManager xla_debug_info_manager_;
+  XlaDebugInfoManagerTestPeer xla_debug_info_manager_;
 };
 
 // Test the cases where no trace session is involved.
 TEST_F(XlaDebugInfoManagerTest, NoTraceBasic) {
   auto program0 = RegisterProgram("program0");
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0));
 
   auto program1 = RegisterProgram("program1");
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0, program1));
 
   UnregisterProgram(program0);
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program1"));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program1));
   UnregisterProgram(program1);
-  EXPECT_TRUE(GetActiveModule().empty());
+  EXPECT_TRUE(GetModuleIds().empty());
 }
 
 TEST_F(XlaDebugInfoManagerTest, NoTraceDuplicateIds) {
   auto program0A = RegisterProgram("program0");
   auto program0B = RegisterProgram("program0");  // duplicates
   auto program1 = RegisterProgram("program1");
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+  EXPECT_THAT(GetModuleIds(),
+              UnorderedElementsAre(program0A, program0B, program1));
 
   UnregisterProgram(program1);
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0A, program0B));
   UnregisterProgram(program0A);
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0B));
   UnregisterProgram(program0B);
-  EXPECT_TRUE(GetActiveModule().empty());
+  EXPECT_THAT(GetModuleIds(), IsEmpty());
 }
 
 // Test the cases where an active trace session is involved.
@@ -120,25 +153,24 @@
   auto program0B = RegisterProgram("program0");  // duplicates
   auto program1 = RegisterProgram("program1");
 
-  // Case 1: Trace starts when no program is running.
   StartTrace();
   auto program2 = RegisterProgram("program2");
   EXPECT_THAT(StopTrace(),
-              UnorderedElementsAre("program0", "program1", "program2"));
+              UnorderedElementsAre(program0A, program0B, program1, program2));
 
-  // Case 1: Trace starts during program is running.
   StartTrace();
   EXPECT_THAT(StopTrace(),
-              UnorderedElementsAre("program0", "program1", "program2"));
+              UnorderedElementsAre(program0A, program0B, program1, program2));
 
   UnregisterProgram(program2);
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+  EXPECT_THAT(GetModuleIds(),
+              UnorderedElementsAre(program0A, program0B, program1));
   UnregisterProgram(program0A);
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0B, program1));
   UnregisterProgram(program0B);
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program1"));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program1));
   UnregisterProgram(program1);
-  EXPECT_TRUE(GetActiveModule().empty());
+  EXPECT_THAT(GetModuleIds(), IsEmpty());
 }
 
 TEST_F(XlaDebugInfoManagerTest, UnregisterDuringTrace) {
@@ -149,10 +181,12 @@
   StartTrace();
   UnregisterProgram(program1);
   UnregisterProgram(program0B);
-  EXPECT_THAT(StopTrace(), UnorderedElementsAre("program0", "program1"));
-  EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+  EXPECT_THAT(StopTrace(),
+              UnorderedElementsAre(program0A, program0B, program1));
+  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0A));
 
   UnregisterProgram(program0A);
 }
 
+}  // namespace
 }  // namespace xla
diff --git a/tensorflow/core/profiler/backends/cpu/metadata_collector.cc b/tensorflow/core/profiler/backends/cpu/metadata_collector.cc
index f59241a..5955a62 100644
--- a/tensorflow/core/profiler/backends/cpu/metadata_collector.cc
+++ b/tensorflow/core/profiler/backends/cpu/metadata_collector.cc
@@ -76,9 +76,13 @@
       XPlaneBuilder xplane(plane);
       const XStatMetadata& hlo_proto_stat =
           *xplane.GetOrCreateStatMetadata(kHloProto);
-      for (auto& p : debug_info_) {
-        xplane.AddStatValue(hlo_proto_stat, *p.hlo_proto);
-        p.hlo_proto.reset();
+      for (auto& hlo_proto : debug_info_) {
+        XEventMetadata* metadata =
+            xplane.GetOrCreateEventMetadata(hlo_proto->hlo_module().id());
+        metadata->set_name(hlo_proto->hlo_module().name());
+        XStatsBuilder<XEventMetadata> stats(metadata, &xplane);
+        stats.AddStatValue(hlo_proto_stat, *hlo_proto);
+        hlo_proto.reset();
       }
       debug_info_.clear();
     }
@@ -86,7 +90,7 @@
   }
 
  private:
-  std::vector<xla::XlaModuleDebugInfo> debug_info_;
+  std::vector<std::unique_ptr<xla::HloProto>> debug_info_;
   bool trace_active_ = false;
 
   TF_DISALLOW_COPY_AND_ASSIGN(MetadataCollector);