blob: 868dc03c62865f05664895399f23fc6ce07095f1 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include <memory>
#include <string>
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
auto* mlir_function_pass_failed_fallback = monitoring::Counter<0>::New(
"/tensorflow/core/mlir_pass_failed_fallback",
"Failure count of MLIR pass runs when fallback used");
auto* mlir_function_pass_succeeded_fallback = monitoring::Counter<0>::New(
"/tensorflow/core/mlir_pass_succeeded_fallback",
"Success count of MLIR pass runs when fallback enabled");
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
return {ref.data(), ref.size()};
}
// Dumps the MLIR module to disk.
// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can
// be created).
static void DumpModule(mlir::ModuleOp module, std::string file_prefix) {
std::string prefix = GetDumpDirFromEnvVar();
if (prefix.empty()) return;
auto* env = tensorflow::Env::Default();
auto status = env->RecursivelyCreateDir(prefix);
if (!status.ok()) {
LOG(WARNING) << "cannot create directory '" + prefix +
"': " + status.error_message();
return;
}
prefix += "/" + file_prefix;
if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
return;
}
std::unique_ptr<WritableFile> file_writer;
status = env->NewWritableFile(prefix, &file_writer);
if (!status.ok()) {
LOG(WARNING) << "cannot open file '" + prefix +
"': " + status.error_message();
return;
}
// Print the module to a string before writing to the file.
std::string txt_module;
{
llvm::raw_string_ostream os(txt_module);
module.print(os);
}
status = file_writer->Append(txt_module);
if (!status.ok()) {
LOG(WARNING) << "error writing to file '" + prefix +
"': " + status.error_message();
return;
}
(void)file_writer->Close();
VLOG(1) << "Dumped MLIR module to " << prefix;
}
MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
static auto* global = new MlirOptimizationPassRegistry();
return *global;
}
static void RegisterDialects(mlir::DialectRegistry& registry) {
// clang-format off
registry.insert<mlir::StandardOpsDialect,
mlir::TF::TensorFlowDialect,
mlir::shape::ShapeDialect,
mlir::tf_device::TensorFlowDeviceDialect,
mlir::tf_executor::TensorFlowExecutorDialect>();
// clang-format on
}
Status MlirFunctionOptimizationPass::Run(
const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) {
// overall_state equals to:
// Enabled if at least one pass is Enabled.
// Disabled if all passes are Disabled.
// FallbackEnabled if there are no Enabled passes and there is at least one
// FallbackEnabled pass.
MlirOptimizationPassState overall_state = MlirOptimizationPassState::Disabled;
// Cache per pass state and reuse it during pass execution.
std::vector<MlirOptimizationPassState> per_pass_state;
per_pass_state.reserve(registry_->passes().size());
int num_passes_enabled = 0, num_passes_disabled = 0,
num_passes_fallback_enabled = 0;
for (const auto& pass_registration : registry_->passes()) {
MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState(
&device_set, config_proto, **graph, *flib_def);
per_pass_state.push_back(pass_state);
switch (pass_state) {
case MlirOptimizationPassState::FallbackEnabled: {
if (overall_state != MlirOptimizationPassState::Enabled)
overall_state = MlirOptimizationPassState::FallbackEnabled;
++num_passes_fallback_enabled;
break;
}
case MlirOptimizationPassState::Enabled: {
overall_state = MlirOptimizationPassState::Enabled;
++num_passes_enabled;
break;
}
case MlirOptimizationPassState::Disabled: {
++num_passes_disabled;
break;
}
}
}
// Capture stats on graph properties analyzed before running the MLIR bridge.
// We set `uses_uninitialized_resource_args` to false here because function
// optimization is not affected by uninitialized resource args.
GetMlirBridgeRolloutPolicy(**graph, flib_def, config_proto,
/*uses_uninitialized_resource_args=*/false,
/*record_stats=*/true);
if (overall_state == MlirOptimizationPassState::Disabled) {
LOG_FIRST_N(INFO, 1) << "None of the MLIR Optimization Passes are enabled "
<< "(registered " << registry_->passes().size() << ")";
return Status::OK();
}
LOG_FIRST_N(INFO, 1) << "MLIR Graph Optimization Passes."
<< " Enabled: " << num_passes_enabled
<< ", Disabled: " << num_passes_disabled
<< ", FallbackEnabled: " << num_passes_fallback_enabled
<< ", Total: " << registry_->passes().size();
GraphDebugInfo debug_info;
mlir::DialectRegistry registry;
RegisterDialects(registry);
mlir::MLIRContext context(registry);
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
import_config.upgrade_legacy = true;
// Disable shape inference during import as some TensorFlow op fails during
// shape inference with dynamic shaped operands. This in turn causes the
// import to fail. Shape inference during import is going to be removed and
// the shape inference pass is run early in the pass pipeline, shape inference
// during import is not necessary.
import_config.enable_shape_inference = false;
auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context);
if (!module_ref_status.ok()) {
// If at least one pass is enabled, return failure to the caller
// immediately.
if (overall_state == MlirOptimizationPassState::Enabled) {
return module_ref_status.status();
}
// Do not fail, just keep the original TF graph unchanged in fallback mode.
return Status::OK();
}
mlir::OwningModuleRef module_ref = std::move(module_ref_status.ValueOrDie());
AddDevicesToOp(*module_ref, &device_set);
int per_pass_state_index = 0;
for (auto& pass_registration : registry_->passes()) {
llvm::StringRef name = pass_registration.pass->name();
VLOG(2) << "Run MLIR graph optimization pass: " << StringRefToView(name);
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
}
Status pass_status = Status::OK();
auto pass_state = per_pass_state[per_pass_state_index++];
if (pass_state == MlirOptimizationPassState::Enabled) {
pass_status = pass_registration.pass->Run(config_proto, *module_ref,
**graph, *flib_def);
} else if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
// Make sure when the pass is FallbackEnabled, it only modifies the MLIR
// module in case of no failures.
auto module_ref_clone = module_ref->clone();
pass_status = pass_registration.pass->Run(config_proto, module_ref_clone,
**graph, *flib_def);
if (pass_status.ok())
module_ref = module_ref_clone;
else
module_ref_clone->destroy();
}
if (!pass_status.ok()) {
// If pass failed and it is:
// FallbackEnabled - only collect metrics, do not propagate
// error to the caller.
// Enabled - return error back to the caller.
if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
LOG(WARNING) << StringRefToView(name)
<< " pass failed, continuing without the pass because the "
"pass has fallback enabled";
mlir_function_pass_failed_fallback->GetCell()->IncrementBy(1);
} else if (pass_state == MlirOptimizationPassState::Enabled) {
return pass_status;
}
} else {
if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
mlir_function_pass_succeeded_fallback->GetCell()->IncrementBy(1);
}
}
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
}
}
GraphExportConfig export_config;
absl::flat_hash_set<Node*> control_ret_nodes;
// Some or all passes are enabled. Convert MLIR module and return back
// resulted graph.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
&control_ret_nodes),
"Error converting MLIR module back to graph");
control_ret_node_names->clear();
control_ret_node_names->reserve(control_ret_nodes.size());
for (const auto* node : control_ret_nodes)
control_ret_node_names->push_back(node->name());
*control_rets_updated = true;
return Status::OK();
}
MlirV1CompatOptimizationPassRegistry&
MlirV1CompatOptimizationPassRegistry::Global() {
static auto* global = new MlirV1CompatOptimizationPassRegistry();
return *global;
}
Status MlirV1CompatGraphOptimizationPass::Run(
const GraphOptimizationPassOptions& options) {
// Skip function graphs as MlirOptimizationPassRegistry_ will be used instead.
// Skip if no underlying pass was registered.
if (options.is_function_graph || !registry_->pass()) return Status::OK();
auto pass = registry_->pass();
auto pass_state =
pass->GetPassState(options.device_set, options.session_options->config,
**options.graph, *options.flib_def);
if (pass_state == MlirOptimizationPassState::Disabled) {
LOG_FIRST_N(INFO, 1) << "MLIR V1 optimization pass is not enabled";
return Status::OK();
}
LOG_FIRST_N(INFO, 1) << "Running MLIR Graph Optimization V1 Compat Pass";
GraphDebugInfo debug_info;
mlir::DialectRegistry registry;
RegisterDialects(registry);
mlir::MLIRContext context(registry);
GraphImportConfig import_config;
import_config.upgrade_legacy = true;
// Restrict functionalization to TPU nodes to avoid problems in v1 session
// runtime.
import_config.restrict_functionalization_to_tpu_nodes = true;
auto module_ref_status = ConvertGraphToMlir(
**options.graph, debug_info, *options.flib_def, import_config, &context);
if (!module_ref_status.ok()) {
return (pass_state == MlirOptimizationPassState::Enabled)
? module_ref_status.status()
: Status::OK();
}
mlir::OwningModuleRef module_ref = std::move(module_ref_status.ValueOrDie());
AddDevicesToOp(*module_ref, options.device_set);
llvm::StringRef name = pass->name();
VLOG(2) << "Run MLIR V1 graph optimization pass: " << StringRefToView(name);
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
}
Status pass_status = pass->Run(options, *module_ref);
if (!pass_status.ok()) {
if (pass_state == MlirOptimizationPassState::Enabled) return pass_status;
if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
LOG(WARNING) << StringRefToView(name)
<< " pass failed, continuing without the pass because the "
"pass has fallback enabled";
mlir_function_pass_failed_fallback->GetCell()->IncrementBy(1);
return Status::OK();
}
} else {
if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
mlir_function_pass_succeeded_fallback->GetCell()->IncrementBy(1);
}
}
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
}
GraphExportConfig export_config;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, options.graph,
options.flib_def),
"Error converting MLIR module back to graph");
return Status::OK();
}
} // namespace tensorflow