HLO proto to tools data conversion.
PiperOrigin-RevId: 467818247
diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index d3503ce..a12bb08 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -1,6 +1,6 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test")
-load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts", "tf_profiler_xla_proto_header")
+load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts", "tf_profiler_pybind_cc_library_wrapper", "tf_profiler_xla_proto_header")
package(
default_visibility = ["//tensorflow/core/profiler:internal"],
@@ -712,6 +712,7 @@
hdrs = ["xplane_to_tools_data.h"],
copts = tf_profiler_copts(),
deps = [
+ ":hlo_to_tools_data_headers_only",
":op_stats_to_input_pipeline_analysis",
":op_stats_to_op_profile",
":op_stats_to_overview_page",
@@ -808,6 +809,32 @@
)
cc_library(
+ name = "hlo_to_tools_data_impl",
+ srcs = ["hlo_to_tools_data.cc"],
+ hdrs = ["hlo_to_tools_data.h"],
+ copts = tf_profiler_copts(),
+ visibility = [
+ "//tensorflow/python:__pkg__",
+ ],
+ deps = [
+ ":hlo_proto_to_memory_visualization_utils",
+ ":xplane_to_hlo",
+ "//tensorflow/compiler/xla/service:hlo_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = True,
+)
+
+tf_profiler_pybind_cc_library_wrapper(
+ name = "hlo_to_tools_data_headers_only",
+ actual = ":hlo_to_tools_data_impl",
+)
+
+cc_library(
name = "hlo_proto_to_memory_visualization_utils",
srcs = ["hlo_proto_to_memory_visualization_utils.cc"],
hdrs = ["hlo_proto_to_memory_visualization_utils.h"],
diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc
new file mode 100644
index 0000000..0a33547
--- /dev/null
+++ b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc
@@ -0,0 +1,114 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/profiler/convert/hlo_to_tools_data.h"
+
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h"
+#include "tensorflow/core/profiler/convert/xplane_to_hlo.h"
+
+namespace tensorflow {
+namespace profiler {
+
+namespace {
+
+std::pair<std::string, bool> ConvertHloProtoToMemoryViewer(
+ const xla::HloProto& hlo_proto) {
+ static constexpr int kSmallBufferSize = 16 * 1024; // 16KB
+ static constexpr int kMemorySpaceColor = 0; // HBM
+
+ auto result_or = ConvertHloProtoToPreprocessResult(
+ hlo_proto, kSmallBufferSize,
+ GetHeapSimulatorTraceId(hlo_proto, kMemorySpaceColor), kMemorySpaceColor);
+ if (!result_or.ok()) {
+ LOG(ERROR) << "Failed to convert HLO proto to memory viewer result: "
+ << result_or.status().message();
+ return std::make_pair("", false);
+ }
+
+ std::string json_output;
+ tensorflow::protobuf::util::JsonPrintOptions options;
+ options.always_print_primitive_fields = true;
+ auto encoded_status = tensorflow::protobuf::util::MessageToJsonString(
+ result_or.value(), &json_output, options);
+ if (!encoded_status.ok()) {
+ LOG(ERROR) << "Failed to convert memory viewer result to JSON format: "
+ << encoded_status.message();
+ return std::make_pair("", false);
+ }
+
+ return std::make_pair(json_output, true);
+}
+
+} // namespace
+
+std::pair<std::string, bool> ConvertHloProtoToToolData(
+ const std::vector<std::string>& xspace_paths,
+ const absl::string_view tool_name,
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options) {
+ if (xspace_paths.empty()) {
+ return std::make_pair("", false);
+ }
+
+ // <options> must provide a hlo_module_name field to identify the HLO module.
+ auto* result = gtl::FindOrNull(options, "hlo_module_name");
+ if (!result) {
+ LOG(ERROR) << "Can not find HLO module name from options.";
+ return std::make_pair("", false);
+ }
+ const std::string* hlo_module_name = std::get_if<std::string>(result);
+ if (!hlo_module_name || hlo_module_name->empty()) {
+ LOG(ERROR) << "Can not find HLO module name from options.";
+ return std::make_pair("", false);
+ }
+
+ // Load HLO module from file.
+ absl::string_view base_dir = tensorflow::io::Dirname(xspace_paths[0]);
+ std::string hlo_proto_file_name =
+ GetHloProtoFileName(base_dir, *hlo_module_name);
+ xla::HloProto hlo_proto;
+ tensorflow::Status status = tensorflow::ReadBinaryProto(
+ tensorflow::Env::Default(), hlo_proto_file_name, &hlo_proto);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to read HLO proto: " << status.error_message();
+ return std::make_pair("", false);
+ }
+
+ // Convert from HLO proto to tools data.
+ if (tool_name == "memory_viewer") {
+ return ConvertHloProtoToMemoryViewer(hlo_proto);
+ } else {
+ LOG(ERROR) << "Can not find tool: " << tool_name
+ << ". Please update to the latest version of Tensorflow.";
+ return std::make_pair("", false);
+ }
+}
+
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.h b/tensorflow/core/profiler/convert/hlo_to_tools_data.h
new file mode 100644
index 0000000..d65d59a
--- /dev/null
+++ b/tensorflow/core/profiler/convert/hlo_to_tools_data.h
@@ -0,0 +1,46 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_
+#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_
+
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Convert HLO proto to tool specific data.
+// <options> must provide a "hlo_module_name" field to identify which HLO proto
+// is used for the conversion.
+// The file path of the HLO proto is automatically inferred from <xspace_paths>
+// and <options>.
+// Return the serialized string of tool specific data and whether the conversion
+// is successful.
+std::pair<std::string, bool> ConvertHloProtoToToolData(
+ const std::vector<std::string>& xspace_paths,
+ const absl::string_view tool_name,
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options);
+
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_
diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.cc b/tensorflow/core/profiler/convert/xplane_to_hlo.cc
index 5f360a6..10a62ef 100644
--- a/tensorflow/core/profiler/convert/xplane_to_hlo.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_hlo.cc
@@ -41,6 +41,11 @@
} // namespace
+std::string GetHloProtoFileName(const absl::string_view base_dir,
+ const absl::string_view module_name) {
+ return ProfilerJoinPath(base_dir, absl::StrCat(module_name, kHloProtoSuffix));
+}
+
Status GetHloProtoFromMultiXSpaceAndSaveToFile(
const std::vector<XSpace>& xspaces,
const std::vector<std::string>& xspace_file_names) {
diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.h b/tensorflow/core/profiler/convert/xplane_to_hlo.h
index 098f318..2616394 100644
--- a/tensorflow/core/profiler/convert/xplane_to_hlo.h
+++ b/tensorflow/core/profiler/convert/xplane_to_hlo.h
@@ -19,12 +19,17 @@
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
namespace tensorflow {
namespace profiler {
+// Helper function to get the filename of a HLO proto given HLO module name.
+std::string GetHloProtoFileName(const absl::string_view base_dir,
+ const absl::string_view module_name);
+
// Extracts and deduplicates the HLO protos from all the XSpace <xspaces>.
// Stores the HLO protos as file in the same directory as the xspace files.
Status GetHloProtoFromMultiXSpaceAndSaveToFile(
diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
index 5018564..6a3c045 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
@@ -23,6 +23,7 @@
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/profiler/convert/hlo_to_tools_data.h"
#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
#include "tensorflow/core/profiler/convert/op_stats_to_op_profile.h"
#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h"
@@ -268,6 +269,8 @@
return std::make_pair("", true);
} else if (tool_name == "op_profile") {
return ConvertMultiXSpacesToOpProfileViewer(xspaces);
+ } else if (tool_name == "memory_viewer") {
+ return ConvertHloProtoToToolData(filenames, tool_name, options);
} else {
LOG(WARNING) << "Can not find tool: " << tool_name << ". Please update to "
<< "the latest version of Tensorflow.";
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b34724f..c1a1a4d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3319,6 +3319,7 @@
"//tensorflow/core/debug",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/platform:stacktrace_handler",
+ "//tensorflow/core/profiler/convert:hlo_to_tools_data_impl",
"//tensorflow/core/profiler/rpc:profiler_server_impl",
"//tensorflow/core/profiler/rpc/client:profiler_client_impl",
"//tensorflow/core/profiler/internal:print_model_analysis",
@@ -3419,6 +3420,7 @@
"//tensorflow/core/platform:tensor_float_32_utils", # tensor_float_32
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
"//tensorflow/core/profiler/backends/cpu:traceme_recorder_impl", # profiler
+ "//tensorflow/core/profiler/convert:hlo_to_tools_data_impl", # profiler
"//tensorflow/core/profiler/lib:profiler_session_impl", # profiler
"//tensorflow/core/profiler/rpc:profiler_server_impl", # profiler
"//tensorflow/core/profiler/rpc/client:profiler_client_impl", # profiler
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 330f9b0..640da9b 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -372,6 +372,9 @@
tensorflow::profiler::ProfilerServer::StartProfilerServer
tensorflow::profiler::ProfilerServer::~ProfilerServer
+[//tensorflow/core/profiler/convert:hlo_to_tools_data_impl] # profiler
+tensorflow::profiler::ConvertHloProtoToToolData
+
[//tensorflow/core/profiler/rpc/client:profiler_client_impl] # profiler
tensorflow::profiler::ProfileGrpc
tensorflow::profiler::NewSessionGrpc