[XLA] Adding method to construct log of buffer liveness information in BufferAssignment.
PiperOrigin-RevId: 354190100
Change-Id: Ie6d6c1583b44d190d507555fce50df0d96cf7df4
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f48ba3d..a9cb87d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1325,6 +1325,7 @@
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
@@ -1361,6 +1362,7 @@
"//tensorflow/core:test",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 3821c83..0fab934 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -22,6 +22,7 @@
#include <ostream>
#include <utility>
+#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
@@ -794,6 +795,78 @@
return output;
}
+string BufferAssignment::BufferInfoString() const {
+ string binfo;
+ // Columns in buffer information:
+ // buffer_id: int. This value can be used to match the allocation in
+ // allocation information.
+ // buffer_name: string.
+ // offset: int. Starting position of the buffer in the memory space.
+ // size: int. Size of the buffer in bytes.
+ // definition_time: int. Position in the schedule where the buffer starts
+ // being live (inclusive).
+ // end_time: int. Position in the schedule where the buffer stops being live
+ // (exclusive).
+ // num_uses: int. Number of uses of the buffer.
+ // use_names: string. This is a semicolon-separated list of string
+ // representation of uses.
+ // Append the column names.
+ absl::StrAppend(&binfo,
+ "buffer_id,buffer_name,offset,size,"
+ "definition_time,end_time,num_uses,use_times,use_names\n");
+ const HloLiveRange& live_ranges = hlo_live_range();
+ const auto& instruction_schedule = live_ranges.instruction_schedule();
+ const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
+ // Sort the buffers by Id.
+ std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
+ for (const BufferAllocation& allocation : allocations_) {
+ absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
+ }
+ absl::c_sort(
+ buffers,
+ [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
+ const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
+ return b1.first->id() < b2.first->id();
+ });
+ for (const auto& buffer_pair : buffers) {
+ const HloValue& buffer = *buffer_pair.first;
+ const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
+ if (!buffer_live_ranges.contains(&buffer)) {
+ continue;
+ }
+ // Ordering uses by their use position.
+ std::vector<std::pair<int64, std::string>> uses;
+ uses.reserve(buffer.uses().size());
+ for (const HloUse& use : buffer.uses()) {
+ uses.emplace_back(instruction_schedule.at(use.instruction),
+ use.ToString());
+ }
+ absl::c_sort(uses);
+ std::vector<int64> use_positions;
+ std::vector<std::string> use_names;
+ use_positions.reserve(uses.size());
+ use_names.reserve(uses.size());
+ for (const auto& use : uses) {
+ use_positions.push_back(use.first);
+ use_names.push_back(use.second);
+ }
+ const int64 definition_time =
+ instruction_schedule.at(buffer.defining_position().instruction);
+ const int64 end_t = buffer_live_ranges.at(&buffer).end;
+ absl::StrAppend(&binfo, buffer.id(), ",");
+ absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
+ absl::StrAppend(&binfo, offset_size.offset, ",");
+ absl::StrAppend(&binfo, offset_size.size, ",");
+ absl::StrAppend(&binfo, definition_time, ",");
+ absl::StrAppend(&binfo, end_t, ",");
+ absl::StrAppend(&binfo, use_positions.size(), ",");
+ absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
+ absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
+ absl::StrAppend(&binfo, "\n");
+ }
+ return binfo;
+}
+
BufferAssignmentProto BufferAssignment::ToProto() const {
BufferAssignmentProto proto;
// NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 6f4c98c..de9e665 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -466,6 +466,7 @@
const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
string ToString() const;
+ string BufferInfoString() const;
BufferAssignmentProto ToProto() const;
// Statistics for the assignment. Values initialized to -1 are not always
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index b49ca64..aa88e5e 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -23,6 +23,7 @@
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
@@ -2518,6 +2519,41 @@
GetAllocation(*buffers, param0, {1, 1}));
}
+TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
+ absl::string_view module_str = R"(
+HloModule test_module
+
+ENTRY %test_module {
+ %param.0 = s32[1024]{0} parameter(0)
+ %param.1 = s32[1024]{0} parameter(1)
+ %mul = s32[1024]{0} multiply(%param.0, %param.1)
+ %add = s32[1024]{0} add(%mul, %param.0)
+ ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
+})";
+
+ absl::string_view reference_str =
+ R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
+0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
+1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
+2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
+3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
+4,"<4 bcast @0>",0,4194304,4,5,0,"",""
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+ HloInstruction* const param0 = FindInstruction(m.get(), "param.0");
+ HloInstruction* const param1 = FindInstruction(m.get(), "param.1");
+ HloInstruction* const mul = FindInstruction(m.get(), "mul");
+ HloInstruction* const add = FindInstruction(m.get(), "add");
+ HloInstruction* const bcast = FindInstruction(m.get(), "bcast");
+ // Run buffer assignment.
+ auto assignment = RunBufferAssignmentWithInstructionSequence(
+ m.get(), {param0, param1, mul, add, bcast});
+ const std::string buffer_info_str = assignment->BufferInfoString();
+
+ EXPECT_EQ(buffer_info_str, reference_str);
+}
+
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());