blob: 8bd875804578b23d71cef101de19f5b7aa36755d [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
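// Loads an HLO module from HLO text ("hlo"/"txt", optionally prefixed with log
// headers) or from a serialized protobuf ("pb" binary / "pbtxt" text), either
// from in-memory data or from a file.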
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace {
// Writes the replica and partition counts from `ovr_config` into `*config`,
// replacing whatever values `*config` previously held. The returned status is
// always OK; the Status return type lets callers use TF_RETURN_IF_ERROR.
Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config,
                      HloModuleConfig* config) {
  HloModuleConfig& target = *config;
  target.set_replica_count(ovr_config.num_replicas);
  target.set_num_partitions(ovr_config.num_partitions);
  return ::tensorflow::OkStatus();
}
} // namespace
std::string StripLogHeaders(const std::string& hlo_string) {
// I0521 12:04:45.883483 1509 service.cc:186] ...
static RE2* matcher = new RE2(
"[IWEF]\\d{4} "
"\\d{2}:\\d{2}:\\d{2}\\.\\d+\\s+\\d+\\s+[^:]+:\\d+\\]\\s?(.*)");
absl::string_view matches[4];
std::vector<std::string> lines = absl::StrSplit(hlo_string, '\n');
for (auto& line : lines) {
if (matcher->Match(line, 0, line.size(), RE2::ANCHOR_START, matches, 4)) {
line = std::string(matches[1]);
}
}
return absl::StrJoin(lines, "\n",
[](std::string* out, const std::string& line) {
absl::StrAppend(out, line);
});
}
StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
const std::string& data, const std::string& format,
hlo_module_loader_details::Config ovr_config,
const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
DebugOptions debug_options = GetDebugOptionsFromFlags();
std::unique_ptr<HloModule> module;
if (format == "hlo" || format == "txt") {
std::string hlo_string = StripLogHeaders(data);
HloModuleConfig config;
config.set_debug_options(debug_options);
TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
if (config_modifier_hook) {
config_modifier_hook(&config);
}
TF_ASSIGN_OR_RETURN(module,
ParseAndReturnUnverifiedModule(hlo_string, config));
} else {
HloSnapshot proto;
if (format == "pb") {
if (!proto.ParseFromString(data) &&
!proto.mutable_hlo()->ParseFromString(data) &&
!proto.mutable_hlo()->mutable_hlo_module()->ParseFromString(data)) {
return InvalidArgument("Failed to parse input as HLO protobuf binary");
}
} else if (format == "pbtxt") {
if (!tensorflow::protobuf::TextFormat::ParseFromString(data, &proto) &&
!tensorflow::protobuf::TextFormat::ParseFromString(
data, proto.mutable_hlo()) &&
!tensorflow::protobuf::TextFormat::ParseFromString(
data, proto.mutable_hlo()->mutable_hlo_module())) {
return InvalidArgument("Failed to parse input as HLO protobuf text");
}
} else {
return InvalidArgument(
"Invalid format from file extension: '%s'. Expected: hlo, txt, pb, "
"or pbtxt",
format);
}
TF_ASSIGN_OR_RETURN(HloModuleConfig config,
HloModule::CreateModuleConfigFromProto(
proto.hlo().hlo_module(), debug_options));
TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
if (config_modifier_hook) {
config_modifier_hook(&config);
}
TF_ASSIGN_OR_RETURN(
module, HloModule::CreateFromProto(proto.hlo().hlo_module(), config));
}
return std::move(module);
}
// Reads the entire file at `path` and builds an HloModule from its contents
// via LoadModuleFromData. When `format` is empty, the file's extension (as
// returned by tensorflow::io::Extension) is used as the format instead.
// File read errors are propagated unchanged.
StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
    const std::string& path, hlo_module_loader_details::Config ovr_config,
    std::string format,
    const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
  // An explicitly requested format wins over the extension-derived one.
  const std::string resolved_format =
      format.empty() ? std::string(tensorflow::io::Extension(path))
                     : std::move(format);
  std::string contents;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  path, &contents));
  return LoadModuleFromData(contents, resolved_format, ovr_config,
                            config_modifier_hook);
}
} // namespace xla