/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/run_hlo_module.h"
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/compiler/xla/tools/prepare_reference_module.h"
#include "tensorflow/compiler/xla/tools/run_hlo_module.pb.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
namespace xla {
namespace {
// Writes the given literal to a binary proto file and a text file in the
// test's output (or temporary) directory.
void WriteLiteralToTempFile(const LiteralSlice& literal,
const std::string& name) {
// Bazel likes for tests to write "debugging outputs" like these to
// TEST_UNDECLARED_OUTPUTS_DIR. This plays well with tools that inspect test
// results, especially when they're run on remote machines.
auto* env = tensorflow::Env::Default();
std::string binary_filename;
std::string text_filename;
std::string outdir;
if (tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) {
std::string filename = tensorflow::io::JoinPath(
outdir, absl::StrFormat("tempfile-%d-%s", env->NowMicros(), name));
binary_filename = absl::StrCat(filename, ".pb");
text_filename = absl::StrCat(filename, ".txt");
} else {
binary_filename =
tensorflow::io::GetTempFilename(absl::StrCat(name, ".pb"));
text_filename = tensorflow::io::GetTempFilename(absl::StrCat(name, ".txt"));
}
TF_CHECK_OK(
tensorflow::WriteBinaryProto(env, binary_filename, literal.ToProto()));
TF_CHECK_OK(
tensorflow::WriteStringToFile(env, text_filename, literal.ToString()));
LOG(ERROR) << "wrote Literal to " << name << " binary: " << binary_filename
<< " text: " << text_filename;
}

// Callback helper that dumps literals to temporary files in the event of a
// miscomparison.
void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
const LiteralSlice& mismatches,
const ShapeIndex& /*shape_index*/,
const literal_comparison::ErrorBuckets& /*error_buckets*/) {
LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " "
<< literal_comparison::ToStringTruncated(expected);
LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " "
<< literal_comparison::ToStringTruncated(actual);
LOG(INFO) << "Dumping literals to temp files...";
WriteLiteralToTempFile(expected, "expected");
WriteLiteralToTempFile(actual, "actual");
WriteLiteralToTempFile(mismatches, "mismatches");
}
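
// Verifies `module`, then compiles and runs it on `runner` with the given
// arguments, logging how long compilation plus execution took. Aborts (via
// QCHECK) if verification or execution fails; otherwise returns the result
// literal.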
Literal ExecuteWithRunner(std::unique_ptr<HloModule> module,
absl::Span<const Literal> args,
HloRunnerInterface* runner, bool run_hlo_passes) {
TF_QCHECK_OK(VerifyHloModule(module.get(), /*layout_sensitive=*/false,
/*allow_mixed_precision=*/true))
<< " (on " << runner->Name() << ")";
std::cerr << "Running HLO module with runner " << runner->Name() << "...\n";
XLA_VLOG_LINES(1, module->ToString());
const auto start = std::chrono::high_resolution_clock::now();
auto result_status = runner->Execute(std::move(module), args, run_hlo_passes);
const auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
std::cerr << "... compiled and ran in " << diff.count() << "s.\n";
TF_QCHECK_OK(result_status.status())
<< "Failed to execute on " << runner->Name() << "\n";
return result_status.ConsumeValueOrDie();
}

}  // namespace
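
// Runs `test_module` on `test_runner` using fake arguments (or the literals
// supplied in `iteration_literals_proto`), optionally runs a reference copy of
// the module on `reference_runner`, and compares the two results within the
// error bounds given in `options`.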
Status RunAndCompare(
std::unique_ptr<HloModule> test_module, HloRunnerInterface* test_runner,
HloRunnerInterface* reference_runner, std::minstd_rand0* engine,
const RunHloModuleOptions& options,
xla::RunHloModuleIterationLiterals* iteration_literals_proto,
std::function<Status(const HloModule&, HloRunnerInterface*, HloModule*)>
reference_module_modifier_hook,
std::function<void(HloModuleConfig*)> config_modifier_hook) {
if (!config_modifier_hook) {
config_modifier_hook = [](HloModuleConfig* config) {
config->set_seed(42);
};
}
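// Optionally flatten control flow (e.g. force while loops to run a single
// iteration) before execution.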
if (options.flatten_control_flow) {
HloControlFlowFlattening control_flow_flattening(
HloControlFlowFlattening::Options{/*while_execution_count=*/1});
TF_RETURN_IF_ERROR(control_flow_flattening.Run(test_module.get()).status());
}
const HloModuleProto test_module_proto = test_module->ToProto();
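// Generate fake input literals that match the shapes of the module's entry
// computation parameters.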
std::vector<Literal> args = MakeFakeArguments(test_module.get(), engine,
options.use_large_float_range)
.ConsumeValueOrDie();
// Use provided input literals as arguments, if any.
if (iteration_literals_proto != nullptr &&
iteration_literals_proto->arguments_size() != 0) {
if (iteration_literals_proto->arguments_size() != args.size()) {
return xla::InvalidArgument(
"Failed to use input literals as arguments; mismatched "
"number of expected arguments.");
} else {
for (int i = 0; i < args.size(); ++i) {
if (!literal_comparison::EqualShapes(
xla::Shape(args[i].shape()),
xla::Shape(iteration_literals_proto->arguments(i).shape()))
.ok()) {
return xla::InvalidArgument(
"Failed to use input literals for argument %d "
"because of a shape mismatch.",
i);
}
TF_ASSIGN_OR_RETURN(args[i],
xla::Literal::CreateFromProto(
iteration_literals_proto->arguments(i)));
}
}
}
if (options.print_literals) {
for (int i = 0; i < args.size(); ++i) {
std::cout << "\n** Argument " << i << " **\n"
<< args[i].ToString() << "\n";
}
}
if (iteration_literals_proto != nullptr &&
iteration_literals_proto->arguments_size() == 0) {
for (int i = 0; i < args.size(); ++i) {
*iteration_literals_proto->add_arguments() = args[i].ToProto();
}
}
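// Prepare a reference copy of the module to run on the reference runner, if
// one was provided, so its output can be compared against the test result.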
std::unique_ptr<HloModule> reference_module;
if (reference_runner != nullptr) {
// PrepareReferenceModule needs to know the *test* runner, in order to
// properly match the test runner's numerics.
reference_module =
PrepareReferenceModule(*test_module, test_runner, config_modifier_hook,
reference_module_modifier_hook)
.ConsumeValueOrDie();
}
Literal test_result = ExecuteWithRunner(
std::move(test_module), args, test_runner, options.run_test_hlo_passes);
if (options.print_literals) {
std::cout << "\n** Result with test runner " << test_runner->Name()
<< " **\n"
<< test_result.ToString() << "\n";
}
if (iteration_literals_proto != nullptr) {
LiteralProto test_result_proto = test_result.ToProto();
iteration_literals_proto->mutable_result()->Swap(&test_result_proto);
}
if (reference_module == nullptr) {
std::cerr << "Skipping reference runner\n";
return ::tensorflow::OkStatus();
}
Literal reference_result =
ExecuteWithRunner(std::move(reference_module), args, reference_runner,
options.run_reference_hlo_passes);
if (options.print_literals) {
std::cout << "\n** Result with reference runner "
<< reference_runner->Name() << " **\n"
<< reference_result.ToString() << "\n";
}
if (iteration_literals_proto != nullptr) {
LiteralProto reference_result_proto = reference_result.ToProto();
iteration_literals_proto->mutable_reference_result()->Swap(
&reference_result_proto);
}
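// Compare the reference result against the test result within the configured
// absolute/relative error bounds, dumping the literals to temp files on a
// miscompare.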
ErrorSpec error_spec(static_cast<float>(options.abs_error_bound),
static_cast<float>(options.rel_error_bound));
return literal_comparison::Near(/*expected=*/reference_result,
/*actual=*/test_result,
/*error=*/error_spec,
/*detailed_message=*/true, &OnMiscompare);
}
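
// Overload that loads the HLO module from `hlo_filename` (in the format given
// by `options.input_format`) and then delegates to the overload above.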
Status RunAndCompare(
const std::string& hlo_filename, HloRunnerInterface* test_runner,
HloRunnerInterface* reference_runner, std::minstd_rand0* engine,
const RunHloModuleOptions& options,
xla::RunHloModuleIterationLiterals* iteration_literals_proto,
std::function<Status(const HloModule&, HloRunnerInterface*, HloModule*)>
reference_module_modifier_hook,
std::function<void(HloModuleConfig*)> config_modifier_hook) {
std::unique_ptr<HloModule> test_module =
LoadModuleFromFile(hlo_filename, hlo_module_loader_details::Config(),
options.input_format, config_modifier_hook)
.ValueOrDie();
return RunAndCompare(std::move(test_module), test_runner, reference_runner,
engine, options, iteration_literals_proto,
reference_module_modifier_hook, config_modifier_hook);
}

}  // namespace xla