Open source hlo_proto_to_memory_visualization_utils

PiperOrigin-RevId: 415595986
Change-Id: I989b82fe56005fbe515b49280c5243b6ee278671
diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index aff2523..da71689 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -772,3 +772,28 @@
         "@com_google_absl//absl/strings",
     ],
 )
+
+# Converts an XLA HloProto (buffer assignment + heap simulator traces) into
+# the PreprocessResult proto consumed by the profiler's memory viewer.
+cc_library(
+    name = "hlo_proto_to_memory_visualization_utils",
+    srcs = ["hlo_proto_to_memory_visualization_utils.cc"],
+    hdrs = ["hlo_proto_to_memory_visualization_utils.h"],
+    copts = tf_profiler_copts(),
+    visibility = ["//tensorflow/core/profiler/protobuf:memory_viewer_friends"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/profiler/protobuf:memory_viewer_preprocess_proto_cc",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/container:node_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc
new file mode 100644
index 0000000..f022f03
--- /dev/null
+++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc
@@ -0,0 +1,592 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/container/node_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace {
+
+using absl::StrFormat;
+using ::xla::BufferAllocationProto;
+using ::xla::HloProto;
+using ::xla::LayoutUtil;
+using ::xla::LogicalBufferProto;
+using ::xla::HloInstructionProto;
+using ::xla::Shape;
+using ::xla::ShapeUtil;
+
+// Converts a byte count to mebibytes (2^20 bytes) as a double.
+double BytesToMiB(int64_t bytes) {
+  return static_cast<double>(bytes) / tensorflow::MathUtil::IPow(2, 20);
+}
+
+// Populates the fields shared by both HeapObject colorings (label, logical
+// buffer id, and byte sizes converted to MiB); callers set the color.
+HeapObject MakeHeapObjectCommon(std::string label, int logical_buffer_id,
+                                int64_t logical_buffer_size_bytes,
+                                int64_t unpadded_shape_bytes) {
+  HeapObject result;
+  result.set_label(std::move(label));
+  result.set_logical_buffer_id(logical_buffer_id);
+  result.set_logical_buffer_size_mib(BytesToMiB(logical_buffer_size_bytes));
+  result.set_unpadded_shape_mib(BytesToMiB(unpadded_shape_bytes));
+  return result;
+}
+
+// Makes a HeapObject whose color is chosen from a numbered palette.
+HeapObject MakeHeapObject(int color, std::string label, int logical_buffer_id,
+                          int64_t logical_buffer_size_bytes,
+                          int64_t unpadded_shape_bytes) {
+  HeapObject result =
+      MakeHeapObjectCommon(std::move(label), logical_buffer_id,
+                           logical_buffer_size_bytes, unpadded_shape_bytes);
+  result.set_numbered(color);
+  return result;
+}
+
+// Makes a HeapObject with an explicitly named color (e.g. "gray").
+HeapObject MakeHeapObject(std::string color, std::string label,
+                          int logical_buffer_id,
+                          int64_t logical_buffer_size_bytes,
+                          int64_t unpadded_shape_bytes) {
+  HeapObject result =
+      MakeHeapObjectCommon(std::move(label), logical_buffer_id,
+                           logical_buffer_size_bytes, unpadded_shape_bytes);
+  result.set_named(std::move(color));
+  return result;
+}
+
+// Wraps a start/limit pair of heap-trace event positions into a BufferSpan.
+// NOTE(review): `int32` here is the TF typedef, unlike the `int64_t` used
+// elsewhere in this file -- confirm this is intentional.
+BufferSpan MakeBufferSpan(int32 start, int32 limit) {
+  BufferSpan result;
+  result.set_start(start);
+  result.set_limit(limit);
+  return result;
+}
+
+// Walks `shape_index` through nested tuple shapes and returns the subshape it
+// designates; returns `shape` itself when the index is empty.
+const Shape* ResolveShapeIndex(const Shape* shape,
+                               absl::Span<const int64_t> shape_index) {
+  for (int64_t value : shape_index) {
+    shape = &shape->tuple_shapes(value);
+  }
+  return shape;
+}
+
+// A wrapper around ShapeUtil::ByteSizeOf that clears out the layout/padding,
+// since that is considered in the ByteSizeOf calculation.
+// `shape` is taken by value because its layout is modified in place.
+int64_t UnpaddedSize(Shape shape) {
+  // Ensure the layout has no padding by making it the default layout.
+  LayoutUtil::SetToDefaultLayout(&shape);
+  // Note: we make a simplifying assumption here that a "minimal" size for a
+  // tuple member would be the size of a `void*` -- there may be even fancier
+  // ways of doing things, but this should give a good enough approximation of
+  // what a minimal tuple size is.
+  return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
+}
+
+// Fills `result` (a memory-viewer LogicalBuffer) from an assignment within a
+// buffer allocation, resolving the defining HLO instruction and the exact
+// (sub)shape of the logical buffer.
+void Convert(const xla::BufferAllocationProto_Assigned& assigned,
+             const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
+                 id_to_logical_buffer,
+             const absl::node_hash_map<std::string, const HloInstructionProto*>&
+                 name_to_hlo,
+             LogicalBuffer* result) {
+  result->set_id(assigned.logical_buffer_id());
+  result->set_size_mib(BytesToMiB(assigned.size()));
+  const LogicalBufferProto* proto =
+      id_to_logical_buffer.at(assigned.logical_buffer_id());
+  const std::string& instruction_name = proto->defined_at().instruction_name();
+  result->set_hlo_name(instruction_name);
+  result->mutable_shape_index()->CopyFrom(proto->defined_at().shape_index());
+  // The shape index drills into the defining instruction's (possibly tuple)
+  // shape to find the subshape this logical buffer corresponds to.
+  const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
+  const Shape* shape =
+      ResolveShapeIndex(&top_level_shape, proto->defined_at().shape_index());
+  result->set_shape(ShapeUtil::HumanStringWithLayout(*shape));
+}
+
+// An allocation is considered reusable unless it is thread-local or a tuple.
+bool IsReusable(const BufferAllocationProto& buffer_allocation) {
+  return !buffer_allocation.is_thread_local() && !buffer_allocation.is_tuple();
+}
+
+// Fills `result` (a memory-viewer BufferAllocation) from an XLA buffer
+// allocation proto: id, size, human-readable attribute strings, and one
+// LogicalBuffer entry per assigned logical buffer.
+void Convert(const BufferAllocationProto& proto,
+             const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
+                 id_to_logical_buffer,
+             const absl::node_hash_map<std::string, const HloInstructionProto*>&
+                 name_to_hlo,
+             BufferAllocation* result) {
+  result->set_id(proto.index());
+  result->set_size_mib(BytesToMiB(proto.size()));
+  if (proto.is_entry_computation_parameter()) {
+    result->add_attributes("entry computation parameter");
+  }
+  if (proto.maybe_live_out()) {
+    result->add_attributes("may-be live out");
+  }
+  if (IsReusable(proto)) {
+    result->add_attributes("reusable");
+  }
+  for (const auto& assigned : proto.assigned()) {
+    Convert(assigned, id_to_logical_buffer, name_to_hlo,
+            result->add_logical_buffers());
+  }
+  // Check whether all logical buffers for this buffer allocation have a common
+  // shape.
+  if (!result->logical_buffers().empty()) {
+    std::string common_shape = result->logical_buffers(0).shape();
+    for (int64_t i = 1; i < result->logical_buffers_size(); ++i) {
+      if (result->logical_buffers(i).shape() != common_shape) {
+        // Shapes disagree; leave common_shape unset.
+        common_shape = "";
+        break;
+      }
+    }
+    if (!common_shape.empty()) {
+      result->set_common_shape(common_shape);
+    }
+  }
+}
+
+// Accumulates the byte totals of entry-computation-parameter, non-reusable,
+// and maybe-live-out buffer allocations, records every allocation as an
+// indefinite-lifetime entry, and stores the MiB totals on `result`.
+void NoteSpecialAllocations(
+    const absl::flat_hash_set<const BufferAllocationProto*>&
+        all_buffer_allocations,
+    const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
+        id_to_logical_buffer,
+
+    const absl::node_hash_map<std::string, const HloInstructionProto*>&
+        name_to_hlo,
+    int64_t small_buffer_size, PreprocessResult* result) {
+  int64_t entry_parameters_bytes = 0;
+  int64_t non_reusable_bytes = 0;
+  int64_t maybe_live_out_bytes = 0;
+  for (const BufferAllocationProto* buffer_allocation :
+       all_buffer_allocations) {
+    if (buffer_allocation->is_entry_computation_parameter()) {
+      entry_parameters_bytes += buffer_allocation->size();
+    }
+    if (!IsReusable(*buffer_allocation)) {
+      non_reusable_bytes += buffer_allocation->size();
+    }
+    if (buffer_allocation->maybe_live_out()) {
+      // Only log the large maybe-live-out allocations; small ones are noisy.
+      if (buffer_allocation->size() > small_buffer_size) {
+        VLOG(1) << "Maybe live out buffer allocation: "
+                << buffer_allocation->size()
+                << " bytes :: " << buffer_allocation->ShortDebugString();
+      }
+      maybe_live_out_bytes += buffer_allocation->size();
+    }
+    Convert(*buffer_allocation, id_to_logical_buffer, name_to_hlo,
+            result->add_indefinite_lifetimes());
+  }
+
+  result->set_entry_computation_parameters_mib(
+      BytesToMiB(entry_parameters_bytes));
+  result->set_non_reusable_mib(BytesToMiB(non_reusable_bytes));
+  result->set_maybe_live_out_mib(BytesToMiB(maybe_live_out_bytes));
+}
+
+}  // namespace
+
+absl::StatusOr<PreprocessResult> ConvertHloProtoToPreprocessResult(
+    const HloProto& hlo_proto, int64_t small_buffer_size,
+    int64_t heap_simulator_trace_id) {
+  // Construct a mapping from name to HLO proto.
+  absl::node_hash_map<std::string, const HloInstructionProto*> name_to_hlo;
+  for (const auto& computation : hlo_proto.hlo_module().computations()) {
+    for (const auto& instruction : computation.instructions()) {
+      name_to_hlo[instruction.name()] = &instruction;
+      VLOG(1) << "HLO: " << instruction.ShortDebugString();
+    }
+  }
+
+  // Mapping from logical buffer ID to logical buffer, and set of all logical
+  // buffer protos.
+  absl::flat_hash_map<int64_t, const LogicalBufferProto*> id_to_logical_buffer;
+  absl::flat_hash_set<const LogicalBufferProto*> all_logical_buffers;
+  for (const auto& logical_buffer :
+       hlo_proto.buffer_assignment().logical_buffers()) {
+    VLOG(1) << "Logical buffer: " << logical_buffer.ShortDebugString();
+    id_to_logical_buffer[logical_buffer.id()] = &logical_buffer;
+    all_logical_buffers.insert(&logical_buffer);
+  }
+
+  // Mapping from logical buffer proto to the buffer allocation that it exists
+  // inside (there must be only one).
+  //
+  // Also a reverse mapping from buffer allocation proto to the set of logical
+  // buffer protos that exist inside of it.
+  absl::flat_hash_map<const LogicalBufferProto*, const BufferAllocationProto*>
+      logical_buffer_to_buffer_allocation;
+  absl::node_hash_map<const BufferAllocationProto*,
+                      absl::flat_hash_set<const LogicalBufferProto*>>
+      buffer_allocation_to_logical_buffers;
+  absl::flat_hash_set<const BufferAllocationProto*> all_buffer_allocations;
+  for (const BufferAllocationProto& buffer_allocation :
+       hlo_proto.buffer_assignment().buffer_allocations()) {
+    all_buffer_allocations.insert(&buffer_allocation);
+    for (const xla::BufferAllocationProto_Assigned& assigned :
+         buffer_allocation.assigned()) {
+      const LogicalBufferProto* logical_buffer =
+          id_to_logical_buffer.at(assigned.logical_buffer_id());
+      buffer_allocation_to_logical_buffers[&buffer_allocation].insert(
+          logical_buffer);
+      auto insert_result = logical_buffer_to_buffer_allocation.insert(
+          {logical_buffer, &buffer_allocation});
+      if (!insert_result.second) {
+        return absl::InvalidArgumentError(
+            "A logical buffer appears to be associated with multiple buffer "
+            "allocations.");
+      }
+    }
+  }
+
+  std::vector<int64_t> logical_buffers;
+  std::vector<int64_t> peak_logical_buffers;
+
+  int64_t heap_size_bytes = 0;
+  int64_t unpadded_heap_size_bytes = 0;
+
+  int64_t peak_heap_size_bytes = 0;
+  int64_t unpadded_peak_heap_size_bytes = 0;  // Unpadded size at peak.
+  int64_t peak_heap_size_position = 0;
+  std::vector<double> heap_sizes;
+  std::vector<double> unpadded_heap_sizes;
+
+  // Maps a logical buffer ID to the [alloc, free] span of heap-trace event
+  // positions over which it is live.
+  absl::node_hash_map<int64_t, std::pair<int64_t, absl::optional<int64_t>>>
+      logical_buffer_spans;
+  absl::flat_hash_set<const LogicalBufferProto*> seen;
+  absl::flat_hash_set<const BufferAllocationProto*> seen_buffer_allocations;
+
+  // Run through all the simulator events in the given trace, and simulate the
+  // heap in order to find the point of peak memory usage and record its
+  // associated metadata.
+  if (heap_simulator_trace_id >= 0 &&
+      heap_simulator_trace_id <
+          hlo_proto.buffer_assignment().heap_simulator_traces_size()) {
+    const auto& simulator_events =
+        hlo_proto.buffer_assignment()
+            .heap_simulator_traces(heap_simulator_trace_id)
+            .events();
+    for (const auto& event : simulator_events) {
+      heap_sizes.push_back(BytesToMiB(heap_size_bytes));
+      unpadded_heap_sizes.push_back(BytesToMiB(unpadded_heap_size_bytes));
+      const auto* logical_buffer = id_to_logical_buffer.at(event.buffer_id());
+      seen.insert(logical_buffer);
+      seen_buffer_allocations.insert(
+          logical_buffer_to_buffer_allocation.at(logical_buffer));
+      const auto& instruction_name =
+          logical_buffer->defined_at().instruction_name();
+      const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
+      const Shape* shape = ResolveShapeIndex(
+          &top_level_shape, logical_buffer->defined_at().shape_index());
+      if (event.kind() == xla::HeapSimulatorTrace_Event::ALLOC ||
+          event.kind() == xla::HeapSimulatorTrace_Event::SHARE_WITH) {
+        logical_buffers.push_back(event.buffer_id());
+        heap_size_bytes += logical_buffer->size();
+        unpadded_heap_size_bytes += UnpaddedSize(*shape);
+        // Initialize the buffer span from the current event to the last event.
+        logical_buffer_spans[event.buffer_id()] = {heap_sizes.size() - 1,
+                                                   simulator_events.size() - 1};
+        int64_t prior_peak_heap_size_bytes = peak_heap_size_bytes;
+        peak_heap_size_bytes = std::max(peak_heap_size_bytes, heap_size_bytes);
+        if (prior_peak_heap_size_bytes != peak_heap_size_bytes) {
+          peak_heap_size_position = heap_sizes.size() - 1;
+          unpadded_peak_heap_size_bytes = unpadded_heap_size_bytes;
+          VLOG(1) << StrFormat("New peak heap size on %d: %s :: %d bytes",
+                               peak_heap_size_position, instruction_name,
+                               peak_heap_size_bytes);
+          peak_logical_buffers = logical_buffers;
+        }
+      } else if (event.kind() == xla::HeapSimulatorTrace_Event::FREE) {
+        logical_buffers.erase(
+            std::remove(logical_buffers.begin(), logical_buffers.end(),
+                        event.buffer_id()),
+            logical_buffers.end());
+        heap_size_bytes -= logical_buffer->size();
+        unpadded_heap_size_bytes -= UnpaddedSize(*shape);
+        logical_buffer_spans[event.buffer_id()].second = heap_sizes.size() - 1;
+        if (heap_size_bytes < 0) {
+          return absl::InvalidArgumentError(absl::StrCat(
+              "heap_size_bytes should be non-negative: ", heap_size_bytes));
+        }
+      } else {
+        return absl::InvalidArgumentError(
+            absl::StrCat("Unhandled event kind: ", event.kind()));
+      }
+    }
+
+    if (seen_buffer_allocations.size() != 1) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("All heap simulation should work out of a single buffer "
+                       "allocation, actual seen_buffer_allocations.size():",
+                       seen_buffer_allocations.size()));
+    }
+  }
+
+  VLOG(1) << "Found " << peak_logical_buffers.size()
+          << " logical buffers alive at point of peak heap usage.";
+
+  VLOG(1) << "Peak logical buffers: ["
+          << absl::StrJoin(peak_logical_buffers, ",") << "]";
+
+  int64_t indefinite_memory_usage_bytes = 0;
+  std::vector<HeapObject> max_heap;
+  int colorno = 0;
+  int64_t rest = 0;
+
+  // Helper lambda that adds the logical buffer as an element in the "max heap"
+  // view with constituent logical buffers.
+  auto add_heap_object = [&](const LogicalBufferProto* logical_buffer) {
+    if (logical_buffer->size() <= small_buffer_size) {
+      rest += logical_buffer->size();
+      return;
+    }
+    const std::string& instruction_name =
+        logical_buffer->defined_at().instruction_name();
+    const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
+    const Shape* shape = ResolveShapeIndex(
+        &top_level_shape, logical_buffer->defined_at().shape_index());
+    std::string shape_string = ShapeUtil::HumanStringWithLayout(*shape);
+    int64_t unpadded_shape_bytes = UnpaddedSize(*shape);
+    const std::string& metadata =
+        name_to_hlo.at(instruction_name)->metadata().op_name();
+    std::string label =
+        StrFormat("%s: %s # %s", instruction_name, shape_string, metadata);
+    max_heap.push_back(
+        MakeHeapObject(colorno++, std::move(label), logical_buffer->id(),
+                       logical_buffer->size(), unpadded_shape_bytes));
+  };
+
+  // Now look for all logical buffers which have not been seen, and assume they
+  // have indefinite lifetime if they are not in thread-local buffer
+  // allocations.
+  absl::flat_hash_set<const LogicalBufferProto*> unseen;
+  for (const LogicalBufferProto* logical_buffer : all_logical_buffers) {
+    if (!seen.contains(logical_buffer)) {
+      unseen.insert(logical_buffer);
+    }
+  }
+  for (const LogicalBufferProto* logical_buffer : unseen) {
+    const BufferAllocationProto* buffer_allocation =
+        logical_buffer_to_buffer_allocation.at(logical_buffer);
+    if (buffer_allocation->is_thread_local()) {
+      continue;
+    }
+    // Clear out the assigned logical buffers when stringifying the buffer
+    // allocation, as it can be a long list.
+    auto to_string = [](const BufferAllocationProto* p) {
+      BufferAllocationProto copy = *p;
+      copy.mutable_assigned()->Clear();
+      return copy.ShortDebugString();
+    };
+    if (seen_buffer_allocations.insert(buffer_allocation).second) {
+      indefinite_memory_usage_bytes += buffer_allocation->size();
+      const auto& logical_buffers =
+          buffer_allocation_to_logical_buffers.at(buffer_allocation);
+      if (logical_buffers.size() == 1) {
+        add_heap_object(*logical_buffers.begin());
+      } else {
+        VLOG(1) << "Indefinite lifetime, no heap object shown due to "
+                   "multiple logical buffers in buffer allocation: "
+                << logical_buffer->ShortDebugString()
+                << " :: " << to_string(buffer_allocation);
+      }
+      if (buffer_allocation->size() > small_buffer_size) {
+        VLOG(1) << "Indefinite memory usage now: "
+                << indefinite_memory_usage_bytes << " bytes (+"
+                << buffer_allocation->size() << " bytes)";
+      }
+    }
+  }
+
+  // For the buffers that have indefinite lifetime (that is, lifetime not
+  // reflected by the heap simulation) add it to the peak values and the vectors
+  // of heap sizes.
+  peak_heap_size_bytes += indefinite_memory_usage_bytes;
+  unpadded_peak_heap_size_bytes += indefinite_memory_usage_bytes;
+  double addend = BytesToMiB(indefinite_memory_usage_bytes);
+  for (size_t i = 0; i < heap_sizes.size(); ++i) {
+    heap_sizes[i] += addend;
+    unpadded_heap_sizes[i] += addend;
+  }
+
+  // Accumulate data for use in a stacked bar plot.
+  //
+  // We accumulate it in "program order" -- the order in which it was placed
+  // into the logical_buffers sequence above was program order, and we iterate
+  // that order to create data points.
+  for (int logical_buffer_id : peak_logical_buffers) {
+    const auto* logical_buffer = id_to_logical_buffer.at(logical_buffer_id);
+    add_heap_object(logical_buffer);
+  }
+  if (rest != 0) {
+    max_heap.push_back(MakeHeapObject(
+        "gray", StrFormat("small (<%d bytes)", small_buffer_size), -1, rest,
+        0));
+  }
+
+  std::vector<const HeapObject*> max_heap_by_size;
+  max_heap_by_size.reserve(max_heap.size());
+  for (const auto& object : max_heap) {
+    max_heap_by_size.push_back(&object);
+  }
+  std::sort(max_heap_by_size.begin(), max_heap_by_size.end(),
+            [](const HeapObject* a, const HeapObject* b) {
+              return a->logical_buffer_size_mib() >
+                     b->logical_buffer_size_mib();
+            });
+
+  // Quadratic, but the number of large heap objects is expected to be small.
+  std::vector<int> max_heap_to_by_size;
+  max_heap_to_by_size.reserve(max_heap.size());
+  for (const auto& object : max_heap) {
+    auto it =
+        std::find(max_heap_by_size.begin(), max_heap_by_size.end(), &object);
+    int index = std::distance(max_heap_by_size.begin(), it);
+    max_heap_to_by_size.push_back(index);
+  }
+
+  std::vector<int> by_size_to_max_heap;
+  for (const auto* object : max_heap_by_size) {
+    int index = object - &max_heap[0];
+    by_size_to_max_heap.push_back(index);
+  }
+
+  PreprocessResult result;
+  result.set_module_name(hlo_proto.hlo_module().name());
+  result.set_entry_computation_name(
+      hlo_proto.hlo_module().entry_computation_name());
+  *result.mutable_heap_sizes() = {heap_sizes.begin(), heap_sizes.end()};
+  *result.mutable_unpadded_heap_sizes() = {unpadded_heap_sizes.begin(),
+                                           unpadded_heap_sizes.end()};
+  *result.mutable_max_heap() = {max_heap.begin(), max_heap.end()};
+  for (const HeapObject* o : max_heap_by_size) {
+    *result.add_max_heap_by_size() = *o;
+  }
+  *result.mutable_max_heap_to_by_size() = {max_heap_to_by_size.begin(),
+                                           max_heap_to_by_size.end()};
+  *result.mutable_by_size_to_max_heap() = {by_size_to_max_heap.begin(),
+                                           by_size_to_max_heap.end()};
+  result.set_peak_heap_mib(BytesToMiB(peak_heap_size_bytes));
+  result.set_peak_unpadded_heap_mib(BytesToMiB(unpadded_peak_heap_size_bytes));
+  result.set_peak_heap_size_position(peak_heap_size_position);
+
+  for (const auto& item : logical_buffer_spans) {
+    (*result.mutable_logical_buffer_spans())[item.first] =
+        MakeBufferSpan(item.second.first, item.second.second.value());
+  }
+
+  NoteSpecialAllocations(all_buffer_allocations, id_to_logical_buffer,
+                         name_to_hlo, small_buffer_size, &result);
+  return result;
+}
+
+// From a list of heap simulator traces, identify the one that has the largest
+// number of HBM (color = 0) memory events.
+// If unable to find the heap simulator trace, return -1, and
+// ConvertHloProtoToPreprocessResult will not consider heap_simulator_traces
+// during preprocess.
+// NOTE(review): this helper has external linkage but no declaration in the
+// header; consider marking it static or moving it into the anonymous
+// namespace.
+int64_t GetHeapSimulatorTraceIdFromEvents(const HloProto& proto) {
+  absl::flat_hash_map<int64_t, const xla::LogicalBufferProto*>
+      id_to_logical_buffer;
+  for (const auto& logical_buffer :
+       proto.buffer_assignment().logical_buffers()) {
+    id_to_logical_buffer[logical_buffer.id()] = &logical_buffer;
+  }
+  // Track the trace with the most color-0 events seen so far.
+  int64_t best_index = -1;
+  int64_t best_event_count = 0;
+  for (int64_t i = 0;
+       i < proto.buffer_assignment().heap_simulator_traces_size(); i++) {
+    const auto& heap_simulator_trace =
+        proto.buffer_assignment().heap_simulator_traces(i);
+    int64_t event_count = 0;
+    for (const auto& event : heap_simulator_trace.events()) {
+      const auto iter = id_to_logical_buffer.find(event.buffer_id());
+      if (iter == id_to_logical_buffer.end()) {
+        continue;
+      }
+      // TODO(tianrun): Add a "memory space color" query parameter.
+      if (iter->second->color() == 0) {
+        event_count++;
+      }
+    }
+    if (event_count > best_event_count) {
+      best_index = i;
+      best_event_count = event_count;
+    }
+  }
+
+  return best_index;
+}
+
+// Tries to get the correct heap simulator trace based on
+// buffer_allocation_index. Returns -1 when no trace matches.
+int64_t GetHeapSimulatorTraceIdFromBufferAllocationIndex(
+    const HloProto& proto) {
+  absl::flat_hash_map<int64_t, const xla::BufferAllocationProto*>
+      id_to_buffer_allocation;
+  for (const auto& buffer_allocation :
+       proto.buffer_assignment().buffer_allocations()) {
+    id_to_buffer_allocation[buffer_allocation.index()] = &buffer_allocation;
+  }
+  for (int64_t i = 0;
+       i < proto.buffer_assignment().heap_simulator_traces_size(); ++i) {
+    int64_t buffer_allocation_index = proto.buffer_assignment()
+                                          .heap_simulator_traces(i)
+                                          .buffer_allocation_index();
+    const auto iter = id_to_buffer_allocation.find(buffer_allocation_index);
+    // NOTE(review): an index of 0 is treated as "unset" (proto3 default) and
+    // skipped -- confirm that a real allocation can never have index 0 here.
+    if (buffer_allocation_index && iter != id_to_buffer_allocation.end()) {
+      // TODO(tianrun): Add a "memory space color" query parameter.
+      // Find the heap simulator trace that corresponds to the HLO temporaries
+      // buffer allocation, where is_thread_local,
+      // is_entry_computation_parameter, is_constant, and maybe_live_out will
+      // all be false.
+      const auto* buffer_allocation = iter->second;
+      if (buffer_allocation->color() == 0 &&
+          !buffer_allocation->is_thread_local() &&
+          !buffer_allocation->is_entry_computation_parameter() &&
+          !buffer_allocation->is_constant() &&
+          !buffer_allocation->maybe_live_out()) {
+        return i;
+      }
+    }
+  }
+  return -1;
+}
+
+// Returns the heap simulator trace to display: first tries to match via the
+// trace's buffer_allocation_index, then falls back to counting HBM events;
+// returns -1 when neither approach finds a trace.
+int64_t GetHeapSimulatorTraceId(const HloProto& proto) {
+  int64_t id = GetHeapSimulatorTraceIdFromBufferAllocationIndex(proto);
+  if (id != -1) {
+    return id;
+  }
+  return GetHeapSimulatorTraceIdFromEvents(proto);
+}
+
+}  // namespace profiler
+}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h
new file mode 100644
index 0000000..c97d2e9
--- /dev/null
+++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h
@@ -0,0 +1,43 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_
+#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_
+
+#include <cstdint>
+
+#include "absl/status/statusor.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Convert HloProto to PreprocessResult proto for memory visualization.
+// small_buffer_size sets the byte size within which we collapse buffer entries
+// for the max-heap display.
+// heap_simulator_trace_id sets the index of heap simulator trace to be
+// displayed. If it is set to -1, then HLOProto.heap_simulator_traces will not
+// be considered during the preprocess.
+absl::StatusOr<PreprocessResult> ConvertHloProtoToPreprocessResult(
+    const xla::HloProto& hlo_proto, int64_t small_buffer_size,
+    int64_t heap_simulator_trace_id);
+
+int64_t GetHeapSimulatorTraceId(const xla::HloProto& proto);
+
+}  // namespace profiler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_