[XLA] Adding method to construct log of buffer liveness information in BufferAssignment.

PiperOrigin-RevId: 354190100
Change-Id: Ie6d6c1583b44d190d507555fce50df0d96cf7df4
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f48ba3d..a9cb87d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1325,6 +1325,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
@@ -1361,6 +1362,7 @@
         "//tensorflow/core:test",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 3821c83..0fab934 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -22,6 +22,7 @@
 #include <ostream>
 #include <utility>
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
@@ -794,6 +795,78 @@
   return output;
 }
 
+string BufferAssignment::BufferInfoString() const {
+  string binfo;
+  // Columns in buffer information:
+  // buffer_id: int. This value can be used to match the allocation in
+  // allocation information.
+  // buffer_name: string.
+  // offset: int. Starting position of the buffer in the memory space.
+  // size: int. Size of the buffer in bytes.
+  // definition_time: int. Position in the schedule where the buffer starts
+  // being live (inclusive).
+  // end_time: int. Position in the schedule where the buffer stops being live
+  // (exclusive).
+  // num_uses: int. Number of uses of the buffer.
+  // use_names: string. This is a semicolon-separated list of string
+  // representation of uses.
+  // Append the column names.
+  absl::StrAppend(&binfo,
+                  "buffer_id,buffer_name,offset,size,"
+                  "definition_time,end_time,num_uses,use_times,use_names\n");
+  const HloLiveRange& live_ranges = hlo_live_range();
+  const auto& instruction_schedule = live_ranges.instruction_schedule();
+  const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
+  // Sort the buffers by Id.
+  std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
+  for (const BufferAllocation& allocation : allocations_) {
+    absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
+  }
+  absl::c_sort(
+      buffers,
+      [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
+         const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
+        return b1.first->id() < b2.first->id();
+      });
+  for (const auto& buffer_pair : buffers) {
+    const HloValue& buffer = *buffer_pair.first;
+    const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
+    if (!buffer_live_ranges.contains(&buffer)) {
+      continue;
+    }
+    // Ordering uses by their use position.
+    std::vector<std::pair<int64, std::string>> uses;
+    uses.reserve(buffer.uses().size());
+    for (const HloUse& use : buffer.uses()) {
+      uses.emplace_back(instruction_schedule.at(use.instruction),
+                        use.ToString());
+    }
+    absl::c_sort(uses);
+    std::vector<int64> use_positions;
+    std::vector<std::string> use_names;
+    use_positions.reserve(uses.size());
+    use_names.reserve(uses.size());
+    for (const auto& use : uses) {
+      use_positions.push_back(use.first);
+      use_names.push_back(use.second);
+    }
+    const int64 definition_time =
+        instruction_schedule.at(buffer.defining_position().instruction);
+    const int64 end_t = buffer_live_ranges.at(&buffer).end;
+    absl::StrAppend(&binfo, buffer.id(), ",");
+    absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
+    absl::StrAppend(&binfo, offset_size.offset, ",");
+    absl::StrAppend(&binfo, offset_size.size, ",");
+    absl::StrAppend(&binfo, definition_time, ",");
+    absl::StrAppend(&binfo, end_t, ",");
+    absl::StrAppend(&binfo, use_positions.size(), ",");
+    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
+    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
+    absl::StrAppend(&binfo, "\n");
+  }
+  return binfo;
+}
+
 BufferAssignmentProto BufferAssignment::ToProto() const {
   BufferAssignmentProto proto;
   // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 6f4c98c..de9e665 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -466,6 +466,7 @@
   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
 
   string ToString() const;
+  string BufferInfoString() const;
   BufferAssignmentProto ToProto() const;
 
   // Statistics for the assignment.  Values initialized to -1 are not always
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index b49ca64..aa88e5e 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -23,6 +23,7 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/buffer_value.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
@@ -2518,6 +2519,41 @@
             GetAllocation(*buffers, param0, {1, 1}));
 }
 
+TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
+  absl::string_view module_str = R"(
+HloModule test_module
+
+ENTRY %test_module {
+  %param.0 = s32[1024]{0} parameter(0)
+  %param.1 = s32[1024]{0} parameter(1)
+  %mul = s32[1024]{0} multiply(%param.0, %param.1)
+  %add = s32[1024]{0} add(%mul, %param.0)
+  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
+})";
+
+  absl::string_view reference_str =
+      R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
+0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
+1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
+2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
+3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
+4,"<4 bcast @0>",0,4194304,4,5,0,"",""
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  HloInstruction* const param0 = FindInstruction(m.get(), "param.0");
+  HloInstruction* const param1 = FindInstruction(m.get(), "param.1");
+  HloInstruction* const mul = FindInstruction(m.get(), "mul");
+  HloInstruction* const add = FindInstruction(m.get(), "add");
+  HloInstruction* const bcast = FindInstruction(m.get(), "bcast");
+  // Run buffer assignment.
+  auto assignment = RunBufferAssignmentWithInstructionSequence(
+      m.get(), {param0, param1, mul, add, bcast});
+  const std::string buffer_info_str = assignment->BufferInfoString();
+
+  EXPECT_EQ(buffer_info_str, reference_str);
+}
+
 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
   auto module = CreateNewVerifiedModule();
   auto builder = HloComputation::Builder(TestName());