Use HloModule::unique_id() to identify distinct HLO modules
PiperOrigin-RevId: 449316603
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fa24410..4eed032 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -6264,6 +6264,7 @@
":xla_debug_info_manager",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 35cb1fc..2f2b1e9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -51,15 +51,6 @@
namespace xla {
namespace cpu {
-static std::string ModuleUniqueName(absl::string_view module_name,
- const HloModule* module) {
- std::string unique_id;
- if (module != nullptr) {
- unique_id = absl::StrCat("module.", module->unique_id(), ".");
- }
- return absl::StrCat(unique_id, module_name);
-}
-
CpuExecutable::CpuExecutable(
std::unique_ptr<SimpleOrcJIT> jit,
std::unique_ptr<const BufferAssignment> assignment,
@@ -75,9 +66,10 @@
if (assignment_) {
buffer_assignment_.reset(new BufferAssignmentProto(assignment_->ToProto()));
}
- XlaDebugInfoManager::Get()->RegisterModule(
- ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
- buffer_assignment_);
+ if (has_module()) {
+ XlaDebugInfoManager::Get()->RegisterModule(
+ module().unique_id(), shared_module(), buffer_assignment_);
+ }
// Resolve symbols in the constructor rather than at execution time to avoid
// races because FindSymbol is not thread safe.
@@ -95,9 +87,9 @@
}
CpuExecutable::~CpuExecutable() {
- XlaDebugInfoManager::Get()->UnregisterModule(
- ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
- buffer_assignment_);
+ if (has_module()) {
+ XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
+ }
}
static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index a027227..98d5501 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -121,15 +121,6 @@
}
}
-static std::string ModuleUniqueName(absl::string_view module_name,
- const HloModule* module) {
- std::string unique_id;
- if (module != nullptr) {
- unique_id = absl::StrCat("module.", module->unique_id(), ".");
- }
- return absl::StrCat(unique_id, module_name);
-}
-
} // namespace
void GpuExecutable::BefBufferDeleter::operator()(uint8_t* ptr) const {
@@ -336,9 +327,10 @@
params.verbose_buffer_assignment_string_dumper),
constants_(std::move(params.constants)),
output_info_(std::move(params.output_info)) {
- XlaDebugInfoManager::Get()->RegisterModule(
- ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
- debug_buffer_assignment_);
+ if (has_module()) {
+ XlaDebugInfoManager::Get()->RegisterModule(
+ module().unique_id(), shared_module(), debug_buffer_assignment_);
+ }
}
GpuExecutable::GpuExecutable(
@@ -356,15 +348,16 @@
allocations_(std::move(allocations)),
output_info_(std::move(output_info)),
bef_executable_(bef_executable) {
- XlaDebugInfoManager::Get()->RegisterModule(
- ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
- debug_buffer_assignment_);
+ if (has_module()) {
+ XlaDebugInfoManager::Get()->RegisterModule(
+ module().unique_id(), shared_module(), debug_buffer_assignment_);
+ }
}
GpuExecutable::~GpuExecutable() {
- XlaDebugInfoManager::Get()->UnregisterModule(
- ModuleUniqueName(module_name_, shared_module().get()), shared_module(),
- debug_buffer_assignment_);
+ if (has_module()) {
+ XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
+ }
{
// We could have issued host->device mem copies in ResolveConstantGlobals.
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 537ba89..c390fb2 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -292,8 +292,12 @@
HloModuleProto proto;
proto.set_id(unique_id_);
proto.set_name(name_);
- proto.set_entry_computation_name(entry_computation_->name());
- proto.set_entry_computation_id(entry_computation_->unique_id());
+ if (entry_computation_) {
+ proto.set_entry_computation_name(entry_computation_->name());
+ proto.set_entry_computation_id(entry_computation_->unique_id());
+ *proto.mutable_host_program_shape() =
+ entry_computation_layout().ComputeProgramShape().ToProto();
+ }
for (const HloComputation* computation : MakeComputationPostOrder()) {
HloComputationProto computation_proto = computation->ToProto();
proto.add_computations()->Swap(&computation_proto);
@@ -301,8 +305,6 @@
if (has_schedule()) {
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
}
- *proto.mutable_host_program_shape() =
- entry_computation_layout().ComputeProgramShape().ToProto();
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
*proto.mutable_dynamic_parameter_binding() =
dynamic_parameter_binding().ToProto();
diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc
index 5c398fd..7b11c0d 100644
--- a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc
+++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc
@@ -15,49 +15,39 @@
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
+#include <memory>
+#include <string>
+#include <utility>
+
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
namespace xla {
void XlaDebugInfoManager::RegisterModule(
- const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
+ ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
+ CHECK(hlo_module != nullptr && module_id == hlo_module->unique_id());
absl::MutexLock lock(&mutex_);
- if (active_modules_.find(module_id) != active_modules_.end()) {
- active_modules_[module_id].instances.emplace_back(hlo_module,
- buffer_assignment);
- } else {
- XlaModuleEntry m;
- m.module_id = module_id;
- m.instances.emplace_back(hlo_module, buffer_assignment);
- active_modules_[module_id] = std::move(m);
- }
+ auto result = modules_.try_emplace(module_id);
+ CHECK(result.second);
+ XlaModuleEntry& m = result.first->second;
+ m.hlo_module = std::move(hlo_module);
+ m.buffer_assignment = std::move(buffer_assignment);
+ m.active = true;
}
// Unregister an active module, when the last active module of the same
// module id is out of scope, we remove it from our database.
// However during tracing, we will defer the cleanup after serialization.
-void XlaDebugInfoManager::UnregisterModule(
- const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
- std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
+void XlaDebugInfoManager::UnregisterModule(ModuleIdentifier module_id) {
absl::MutexLock lock(&mutex_);
- CHECK(active_modules_.find(module_id) != active_modules_.end());
- XlaModuleEntry& active_module = active_modules_[module_id];
- auto instance_it =
- absl::c_find_if(active_module.instances, [&](XlaModuleInstance& e) {
- return e.hlo_module == hlo_module &&
- e.buffer_assignment == buffer_assignment;
- });
-
- CHECK(instance_it != active_module.instances.end());
-
+ auto it = modules_.find(module_id);
+ CHECK(it != modules_.end());
if (!tracing_active_) {
- active_module.instances.erase(instance_it);
- if (active_module.instances.empty()) {
- active_modules_.erase(module_id);
- }
+ modules_.erase(it);
} else {
- instance_it->active = false;
+ XlaModuleEntry& m = it->second;
+ m.active = false;
}
}
@@ -67,41 +57,23 @@
}
void XlaDebugInfoManager::StopTracing(
- std::vector<XlaModuleDebugInfo>* module_debug_info) {
+ std::vector<std::unique_ptr<HloProto>>* module_debug_info) {
std::vector<XlaModuleEntry> modules_to_serialize;
{
absl::MutexLock lock(&mutex_);
if (!tracing_active_) return;
tracing_active_ = false;
- for (const auto& traced_module_id : active_modules_) {
- const XlaModuleEntry& active_module = traced_module_id.second;
- // Copy the instance so that we can serialize without holding the lock.
- // All instances are equivalent from the perspective of symbolization.
- // We only use the first one.
- if (!active_module.instances.empty()) {
- XlaModuleEntry e;
- e.module_id = active_module.module_id;
- e.instances.push_back(active_module.instances[0]);
- modules_to_serialize.push_back(std::move(e));
- }
- }
-
- // Remove all active modules which have an instance count equal to zero.
- for (auto it = active_modules_.begin(); it != active_modules_.end();) {
- auto& active_module = it->second;
- for (auto instance = active_module.instances.begin();
- instance != active_module.instances.end();) {
- if (instance->active) {
- ++instance;
- } else {
- instance = active_module.instances.erase(instance);
- }
- }
-
- if (active_module.instances.empty()) {
- active_modules_.erase(it++);
+ // Copy all modules so we can serialize without holding the lock, and remove
+ // all inactive modules.
+ modules_to_serialize.reserve(modules_.size());
+ for (auto it = modules_.begin(); it != modules_.end();) {
+ auto& m = it->second;
+ if (!m.active) {
+ modules_to_serialize.emplace_back(std::move(m));
+ modules_.erase(it++);
} else {
+ modules_to_serialize.emplace_back(m);
++it;
}
}
@@ -110,18 +82,14 @@
if (module_debug_info) {
module_debug_info->clear();
for (const auto& m : modules_to_serialize) {
- XlaModuleDebugInfo info;
- info.module_id = m.module_id;
// In real world, hlo_module and buffer_assignment will always be
// non-nullptr. Due to the inconvenience of creation of buffer_assignment
// object in test, we set it to nullptr and guard this for it.
- if (m.instances[0].hlo_module && m.instances[0].buffer_assignment) {
- info.hlo_proto = absl::make_unique<HloProto>(
- MakeHloProto(*m.instances[0].hlo_module));
- *info.hlo_proto->mutable_buffer_assignment() =
- *m.instances[0].buffer_assignment;
+ auto hlo_proto = absl::make_unique<HloProto>(MakeHloProto(*m.hlo_module));
+ if (m.buffer_assignment != nullptr) {
+ *hlo_proto->mutable_buffer_assignment() = *m.buffer_assignment;
}
- module_debug_info->emplace_back(std::move(info));
+ module_debug_info->emplace_back(std::move(hlo_proto));
}
}
}
diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.h b/tensorflow/compiler/xla/service/xla_debug_info_manager.h
index 6f05222..8f29f1a 100644
--- a/tensorflow/compiler/xla/service/xla_debug_info_manager.h
+++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.h
@@ -26,22 +26,11 @@
namespace xla {
-using ModuleIdentifier = std::string;
+using ModuleIdentifier = int;
-struct XlaModuleDebugInfo {
- ModuleIdentifier module_id;
- // The hlo proto associated with this xla program.
- std::unique_ptr<HloProto> hlo_proto;
-};
-
-// Debug info manager keeps track of all the debug information (symbol table,
-// HLO proto etc) during tracing period. Because tracing period can start
-// during module execution, therefore even when tracing is off, we still need
-// minimum level of monitoring (i.e. which program is running lately).
-// We allow multiple programs with the same module_id, however from tracing
-// debug information perspective, same module id implies the same debug
-// information. We will only keep track unique debug information, identified
-// by module_id.
+// XlaDebugInfoManager tracks all XLA programs (Executables) throughout their
+// lifetime. Because the tracing period can start during an Executable's
+// execution, we need to track Executables even when tracing is off.
// This class is thread-safe.
class XlaDebugInfoManager {
public:
@@ -50,71 +39,42 @@
return singleton;
}
- // Register an active module to XlaDebugInfoManager. We will keep track all
- // existing HloModules within the process.
- // Modules with same module id can be registered and tracked separately.
+ // Registers an active module with XlaDebugInfoManager.
+ // The module_id is expected to be unique per process.
void RegisterModule(
- const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
+ ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
- // Unregister an active module. When the last active module of the same
- // module id is out of scope, we remove it from our database.
- // However during tracing, we will defer the cleanup after serialization.
- void UnregisterModule(
- const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
- std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
+ // Unregisters an active module.
+ void UnregisterModule(ModuleIdentifier module_id);
// Start tracing, began to collecting debug information for all the running
// modules during the tracing period.
void StartTracing();
- // Stop tracing and drop all instances that have been stoped during tracing,
- // Then drop all modules that have no instances registered. Dump debug
- // information for all the running modules to module_debug_info if specified.
+ // Stops tracing.
+ // If module_debug_info is not null, it is cleared and filled with the debug
+ // information of all the modules that were alive at any point since StartTracing().
void StopTracing(
- std::vector<XlaModuleDebugInfo>* module_debug_info = nullptr);
+ std::vector<std::unique_ptr<HloProto>>* module_debug_info = nullptr);
- friend class XlaDebugInfoManagerTest;
+ friend class XlaDebugInfoManagerTestPeer;
private:
- XlaDebugInfoManager() {}
+ XlaDebugInfoManager() = default;
- std::set<ModuleIdentifier> GetActiveModules() {
- absl::MutexLock lock(&mutex_);
- std::set<ModuleIdentifier> active;
- for (const auto& id : active_modules_) {
- active.insert(id.first);
- }
- return active;
- }
-
- // We track each instance of GpuExecutable. Assuming multiple GpuExecutable
- // can have same unique id if they are actually same program. From the
- // perspective of symbol table, they are identical, but for the life time
- // tracking, they need to be tracked separately.
- struct XlaModuleInstance {
- XlaModuleInstance(std::shared_ptr<HloModule> m,
- std::shared_ptr<const BufferAssignmentProto> b)
- : hlo_module(std::move(m)), buffer_assignment(std::move(b)) {}
- std::shared_ptr<HloModule> hlo_module;
- std::shared_ptr<const BufferAssignmentProto> buffer_assignment;
- bool active = true;
- };
-
- // Each XlaModuleEntry can have multiple XlaModuleInstance's if XlA registers
- // them with the same ModuleIdentifier.
struct XlaModuleEntry {
- // The module symbol table/debug info that shared by all instances.
- ModuleIdentifier module_id;
- std::vector<XlaModuleInstance> instances;
+ std::shared_ptr<const HloModule> hlo_module;
+ std::shared_ptr<const BufferAssignmentProto> buffer_assignment;
+ bool active = false;
};
- absl::Mutex mutex_;
+ mutable absl::Mutex mutex_;
bool tracing_active_ ABSL_GUARDED_BY(mutex_) = false;
// Active modules are those still tracked by us. There could be much more
// active modules than running modules, we will try to reduce the trace size
// by only transfer those modules that were running during tracing period.
- absl::flat_hash_map<ModuleIdentifier, XlaModuleEntry> active_modules_
+ absl::flat_hash_map<ModuleIdentifier, XlaModuleEntry> modules_
ABSL_GUARDED_BY(mutex_);
};
diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc
index 608e009..1aa459e 100644
--- a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc
+++ b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc
@@ -14,14 +14,57 @@
==============================================================================*/
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
+#include <memory>
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
+class XlaDebugInfoManagerTestPeer {
+ public:
+ void RegisterModule(
+ ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
+ std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
+ return xla_debug_info_manager_.RegisterModule(module_id, hlo_module,
+ buffer_assignment);
+ }
+
+ void UnregisterModule(ModuleIdentifier module_id) {
+ return xla_debug_info_manager_.UnregisterModule(module_id);
+ }
+
+ void StartTracing() { return xla_debug_info_manager_.StartTracing(); }
+
+ absl::flat_hash_set<ModuleIdentifier> StopTracing() {
+ std::vector<std::unique_ptr<HloProto>> module_debug_info;
+ xla_debug_info_manager_.StopTracing(&module_debug_info);
+ absl::flat_hash_set<ModuleIdentifier> module_ids;
+ for (const auto& hlo_proto : module_debug_info) {
+ module_ids.insert(hlo_proto->hlo_module().id());
+ }
+ return module_ids;
+ }
+
+ absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
+ absl::flat_hash_set<ModuleIdentifier> module_ids;
+ absl::MutexLock lock(&xla_debug_info_manager_.mutex_);
+ for (const auto& it : xla_debug_info_manager_.modules_) {
+ module_ids.insert(it.first);
+ }
+ return module_ids;
+ }
+
+ private:
+ XlaDebugInfoManager xla_debug_info_manager_;
+};
+
+namespace {
+
+using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
class XlaDebugInfoManagerTest : public HloTestBase {
@@ -29,89 +72,79 @@
struct DebugMetadata {
// We allow same id to be registered multiple times. we need unique id to
// know which program is referenced (such as in UnregisterProgram).
- int unique_id;
- std::string id;
+ ModuleIdentifier unique_id;
std::shared_ptr<HloModule> module;
std::shared_ptr<BufferAssignmentProto> buffer_assignment;
};
// Return unique id of this module.
- int RegisterProgram(const std::string& module_id) {
+ ModuleIdentifier RegisterProgram(const std::string& module_name) {
DebugMetadata debug_info;
HloModuleConfig config;
- debug_info.unique_id = ++serial_;
- debug_info.id = module_id;
- debug_info.module = std::make_shared<HloModule>(module_id, config);
+ debug_info.module = std::make_shared<HloModule>(module_name, config);
debug_info.buffer_assignment = nullptr;
- xla_debug_info_manager_.RegisterModule(module_id, debug_info.module,
+ ModuleIdentifier unique_id = debug_info.module->unique_id();
+ debug_info.unique_id = unique_id;
+ xla_debug_info_manager_.RegisterModule(unique_id, debug_info.module,
debug_info.buffer_assignment);
external_references_.push_back(std::move(debug_info));
- return serial_;
+ return unique_id;
}
- void UnregisterProgram(int unique_id) {
+ void UnregisterProgram(ModuleIdentifier unique_id) {
for (int i = 0; i < external_references_.size(); i++) {
if (external_references_[i].unique_id == unique_id) {
- xla_debug_info_manager_.UnregisterModule(
- external_references_[i].id, external_references_[i].module,
- external_references_[i].buffer_assignment);
+ xla_debug_info_manager_.UnregisterModule(unique_id);
external_references_.erase(external_references_.begin() + i);
break;
}
}
}
- std::set<ModuleIdentifier> GetActiveModule() {
- return xla_debug_info_manager_.GetActiveModules();
+ absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
+ return xla_debug_info_manager_.GetModuleIds();
}
void StartTrace() { xla_debug_info_manager_.StartTracing(); }
- std::set<ModuleIdentifier> StopTrace() {
- std::vector<XlaModuleDebugInfo> module_debug_info;
- xla_debug_info_manager_.StopTracing(&module_debug_info);
- std::set<ModuleIdentifier> serialized;
- for (const auto& module : module_debug_info) {
- serialized.insert(module.module_id);
- }
- return serialized;
+ absl::flat_hash_set<ModuleIdentifier> StopTrace() {
+ return xla_debug_info_manager_.StopTracing();
}
- int serial_ = 0;
-
// Simulation of compilation cache.
std::vector<DebugMetadata> external_references_;
// Use an instance per test instead of singleton to avoid interferences.
- XlaDebugInfoManager xla_debug_info_manager_;
+ XlaDebugInfoManagerTestPeer xla_debug_info_manager_;
};
// Test the cases where no trace session is involved.
TEST_F(XlaDebugInfoManagerTest, NoTraceBasic) {
auto program0 = RegisterProgram("program0");
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0));
auto program1 = RegisterProgram("program1");
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0, program1));
UnregisterProgram(program0);
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program1"));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program1));
UnregisterProgram(program1);
- EXPECT_TRUE(GetActiveModule().empty());
+ EXPECT_TRUE(GetModuleIds().empty());
}
TEST_F(XlaDebugInfoManagerTest, NoTraceDuplicateIds) {
auto program0A = RegisterProgram("program0");
auto program0B = RegisterProgram("program0"); // duplicates
auto program1 = RegisterProgram("program1");
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+ EXPECT_THAT(GetModuleIds(),
+ UnorderedElementsAre(program0A, program0B, program1));
UnregisterProgram(program1);
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0A, program0B));
UnregisterProgram(program0A);
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0B));
UnregisterProgram(program0B);
- EXPECT_TRUE(GetActiveModule().empty());
+ EXPECT_THAT(GetModuleIds(), IsEmpty());
}
// Test the cases where an active trace session is involved.
@@ -120,25 +153,24 @@
auto program0B = RegisterProgram("program0"); // duplicates
auto program1 = RegisterProgram("program1");
- // Case 1: Trace starts when no program is running.
StartTrace();
auto program2 = RegisterProgram("program2");
EXPECT_THAT(StopTrace(),
- UnorderedElementsAre("program0", "program1", "program2"));
+ UnorderedElementsAre(program0A, program0B, program1, program2));
- // Case 1: Trace starts during program is running.
StartTrace();
EXPECT_THAT(StopTrace(),
- UnorderedElementsAre("program0", "program1", "program2"));
+ UnorderedElementsAre(program0A, program0B, program1, program2));
UnregisterProgram(program2);
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+ EXPECT_THAT(GetModuleIds(),
+ UnorderedElementsAre(program0A, program0B, program1));
UnregisterProgram(program0A);
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0", "program1"));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0B, program1));
UnregisterProgram(program0B);
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program1"));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program1));
UnregisterProgram(program1);
- EXPECT_TRUE(GetActiveModule().empty());
+ EXPECT_THAT(GetModuleIds(), IsEmpty());
}
TEST_F(XlaDebugInfoManagerTest, UnregisterDuringTrace) {
@@ -149,10 +181,12 @@
StartTrace();
UnregisterProgram(program1);
UnregisterProgram(program0B);
- EXPECT_THAT(StopTrace(), UnorderedElementsAre("program0", "program1"));
- EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
+ EXPECT_THAT(StopTrace(),
+ UnorderedElementsAre(program0A, program0B, program1));
+ EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0A));
UnregisterProgram(program0A);
}
+} // namespace
} // namespace xla
diff --git a/tensorflow/core/profiler/backends/cpu/metadata_collector.cc b/tensorflow/core/profiler/backends/cpu/metadata_collector.cc
index f59241a..5955a62 100644
--- a/tensorflow/core/profiler/backends/cpu/metadata_collector.cc
+++ b/tensorflow/core/profiler/backends/cpu/metadata_collector.cc
@@ -76,9 +76,13 @@
XPlaneBuilder xplane(plane);
const XStatMetadata& hlo_proto_stat =
*xplane.GetOrCreateStatMetadata(kHloProto);
- for (auto& p : debug_info_) {
- xplane.AddStatValue(hlo_proto_stat, *p.hlo_proto);
- p.hlo_proto.reset();
+ for (auto& hlo_proto : debug_info_) {
+ XEventMetadata* metadata =
+ xplane.GetOrCreateEventMetadata(hlo_proto->hlo_module().id());
+ metadata->set_name(hlo_proto->hlo_module().name());
+ XStatsBuilder<XEventMetadata> stats(metadata, &xplane);
+ stats.AddStatValue(hlo_proto_stat, *hlo_proto);
+ hlo_proto.reset();
}
debug_info_.clear();
}
@@ -86,7 +90,7 @@
}
private:
- std::vector<xla::XlaModuleDebugInfo> debug_info_;
+ std::vector<std::unique_ptr<xla::HloProto>> debug_info_;
bool trace_active_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(MetadataCollector);