/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
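
// This file registers the MHLO <-> XLA HLO (and LHLO) translations with the
// mlir-translate tooling, along with the command-line flags that control them.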
#include "llvm/Support/raw_os_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"

namespace {
// NOLINTNEXTLINE
llvm::cl::opt<bool> emit_use_tuple_arg(
    "emit-use-tuple-args",
    llvm::cl::desc(
        "Emit HLO modules using tuples as args for the entry computation"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> emit_return_tuple(
    "emit-return-tuple",
    llvm::cl::desc("Emit HLO modules with entry computations returning tuple"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> optimize_xla_hlo(
    "optimize-xla-hlo",
    llvm::cl::desc("Enable optimizations when translating XLA HLO -> LHLO"),
    llvm::cl::init(true));

// NOLINTNEXTLINE
llvm::cl::opt<bool> legalize_node_names(
    "legalize-node-names",
    llvm::cl::desc("Legalize node names when translating MHLO->XLA HLO"),
    llvm::cl::init(true));

// NOLINTNEXTLINE
llvm::cl::opt<bool> with_layouts(
    "with-layouts",
    llvm::cl::desc("Propagate layouts when translating MHLO->XLA HLO"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> print_layouts(
    "print-layouts", llvm::cl::desc("Print layouts in the generated HLO text"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> via_builder(
    "via-builder", llvm::cl::desc("Translate MHLO->XLA HLO via XLA Builder"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> import_all_computations(
    "hlo-import-all-computations",
    llvm::cl::desc("Enable importing unreachable computations."));
}  // namespace

namespace xla {
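
// Translates the given MHLO module to an HloProto and prints the proto's
// DebugString to `output`.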
static mlir::LogicalResult MlirHloToHloTranslateFunction(
    mlir::ModuleOp module, llvm::raw_ostream& output) {
  if (!module) return mlir::failure();

  HloProto hloProto;
  Status status = mlir::ConvertMlirHloToHlo(
      module, &hloProto, emit_use_tuple_arg, emit_return_tuple);
  if (!status.ok()) {
    LOG(ERROR) << "Module conversion failed: " << status;
    return mlir::failure();
  }

  output << hloProto.DebugString();
  return mlir::success();
}
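
// Builds an HloModule from the HloProto in `hlo_proto`, deriving the module
// config from the proto and the current debug-options flags.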
static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
    const HloProto& hlo_proto) {
  const HloModuleProto& module_proto = hlo_proto.hlo_module();
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          module_proto, GetDebugOptionsFromFlags()));
  return HloModule::CreateFromProto(module_proto, module_config);
}

// Wraps BuildHloFromMlirHlo to output an HloProto that's the same as
// ConvertMlirHloToHlo.
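// Note: this path assumes the module contains an entry function named "main".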
Status ConvertMlirHloToHloViaBuilder(mlir::ModuleOp module,
                                     ::xla::HloProto* hlo_proto,
                                     mlir::MlirToHloConversionOptions options) {
  mlir::func::FuncOp main = module.lookupSymbol<mlir::func::FuncOp>("main");
  mlir::Block& block = main.getRegion().front();
  xla::XlaBuilder builder("main");

  // Create an XLA parameter for each block argument of the entry function.
  std::vector<xla::XlaOp> xla_params;
  for (mlir::BlockArgument& arg : block.getArguments()) {
    auto num = arg.getArgNumber();
    xla::Shape shape = xla::TypeToShape(arg.getType());
    XlaOp argop =
        xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
    xla_params.push_back(argop);
  }
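
  // Lower the ops in the entry block onto the builder; `returns` receives the
  // lowered values of the function's results.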
  std::vector<xla::XlaOp> returns(1);
  TF_RETURN_IF_ERROR(
      mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns, options));

  xla::XlaOp return_value;
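  // A single result is returned directly; multiple results are packed into a
  // tuple so the output matches ConvertMlirHloToHlo.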
  if (returns.size() == 1)
    return_value = returns[0];
  else if (returns.size() > 1)
    return_value = xla::Tuple(&builder, returns);

  TF_ASSIGN_OR_RETURN(
      xla::XlaComputation computation,
      return_value.valid() ? builder.Build(return_value) : builder.Build());
  auto hlo_module = computation.proto();
  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);

  return Status::OK();
}
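
// Translates the given MHLO module to HLO text. The conversion goes through
// ConvertMlirHloToHlo or, when -via-builder is set, through the XlaBuilder
// path above; input/output alias information is appended as comments.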
static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
    mlir::ModuleOp module, llvm::raw_ostream& output) {
  if (!module) return mlir::failure();

  HloProto hloProto;
  mlir::MlirToHloConversionOptions options;
  options.propagate_layouts = with_layouts;
  options.legalize_node_names = legalize_node_names;
  Status status =
      via_builder
          ? ConvertMlirHloToHloViaBuilder(module, &hloProto, options)
          : mlir::ConvertMlirHloToHlo(module, &hloProto, emit_use_tuple_arg,
                                      emit_return_tuple,
                                      /*shape_determination_fns=*/{}, options);
  if (!status.ok()) {
    LOG(ERROR) << "Module conversion failed: " << status;
    return mlir::failure();
  }

  auto statusOrHloModule = HloModuleFromProto(hloProto);
  if (!statusOrHloModule.ok()) {
    LOG(ERROR) << "Conversion to HLO module failed: "
               << statusOrHloModule.status();
    return mlir::failure();
  }

  HloModule* hlo_module = statusOrHloModule.ValueOrDie().get();
  output << hlo_module->ToString(
      HloPrintOptions().set_include_layout_in_shapes(print_layouts));

  // Output alias information as comments in the HLO text.
  hlo_module->input_output_alias_config().ForEachAlias(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) {
        output << "// OutputIndex " << output_index.ToString()
               << " aliases with input " << alias.parameter_number << " at "
               << alias.parameter_index.ToString() << "\n";
      });

  return mlir::success();
}

}  // namespace xla

//----------------------------------------------------------------------------//
// Hooks for tf-mlir-translate
//----------------------------------------------------------------------------//
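
// For reference, the registrations below are exercised through the
// tf-mlir-translate binary; e.g. (the input file names are illustrative):
//   tf-mlir-translate --mlir-hlo-to-hlo-text -print-layouts input.mlir
//   tf-mlir-translate --hlo-text-to-mlir-hlo input.hlotxt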
static mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslate(
    llvm::StringRef input, mlir::MLIRContext* context) {
  return xla::HloToMlirHloTranslateFunction(input, context,
                                            import_all_computations);
}

static mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslate(
    llvm::StringRef input, mlir::MLIRContext* context) {
  return xla::HloTextToMlirHloTranslateFunction(input, context,
                                                import_all_computations);
}
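
// Dialects that may appear in the input of the MLIR-to-HLO translations.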
static void RegisterInputDialects(mlir::DialectRegistry& registry) {
  registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
                  mlir::mhlo::MhloDialect, mlir::tensor::TensorDialect>();
}

static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate(
    "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction,
    RegisterInputDialects);

static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate(
    "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction,
    RegisterInputDialects);

static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
    "hlo-to-mlir-hlo", HloToMlirHloTranslate);

static mlir::TranslateToMLIRRegistration HloTextToHloMlirTranslate(
    "hlo-text-to-mlir-hlo", HloTextToMlirHloTranslate);

// MHLO doesn't support explicit layouts, while the XLA service does.
// TODO(timshen): remove this once MHLO supports explicit layouts.
static mlir::TranslateToMLIRRegistration HloTextToLhloMlirTranslate(
    "hlo-text-to-lhlo", [](llvm::StringRef input, mlir::MLIRContext* context) {
      return mlir::HloTextToLhloTranslateFunction(input, context,
                                                  optimize_xla_hlo);
    });