blob: b44e3b461e291549cf5133f99734e516235cefb7 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include <unordered_map>
#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/statusor.h"
using llvm::APInt;
using llvm::makeArrayRef;
using mlir::DenseIntElementsAttr;
using mlir::NamedAttribute;
using mlir::Operation;
using mlir::RankedTensorType;
using mlir::Type;
using mlir::Value;
using mlir::func::FuncOp;
namespace xla {
namespace {
constexpr char kShardingAttr[] = "mhlo.sharding";
// Note: This sanitization function causes an irreversible many-to-one mapping
// and any solution to mitigate this would cause issues with the reverse
// direction. Longterm solution is to add a function attribute to maintain the
// original HLO naming.
std::string SanitizeFunctionName(llvm::StringRef name) {
std::string output(name);
llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
return output;
}
// Returns whether the instruction is a default dot operation.
bool DotIsDefault(const HloInstruction* instruction) {
const auto& operands = instruction->operands();
// eg. vector[3] dot matrix[3, 2] => [2] not default dot
if (operands[0]->shape().rank() < operands[1]->shape().rank()) {
return false;
}
auto dnums = instruction->dot_dimension_numbers();
DotDimensionNumbers default_dimension_numbers;
default_dimension_numbers.add_lhs_contracting_dimensions(
instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
default_dimension_numbers.add_rhs_contracting_dimensions(0);
return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
}
// Returns an MLIR Location generated from HLO Instruction. Uses instruction
// metadata if present or instruction name.
mlir::Location GenerateInstructionLocation(const HloInstruction* instruction,
mlir::OpBuilder* func_builder) {
const std::string& op_name = instruction->metadata().op_name();
if (op_name.empty()) {
return mlir::NameLoc::get(func_builder->getStringAttr(instruction->name()));
}
mlir::Location op_name_loc =
mlir::NameLoc::get(func_builder->getStringAttr(op_name));
const std::string& source_file = instruction->metadata().source_file();
if (source_file.empty()) {
return op_name_loc;
}
return func_builder->getFusedLoc(
{op_name_loc,
mlir::FileLineColLoc::get(func_builder->getContext(), source_file,
instruction->metadata().source_line(), 0)});
}
// Clean up the GetTupleElementOp, created during the flattening of
// tuple arguments and return values, if eligible for folding. Removal of
// get-tuple-element can transitively make the defining TupleOp dead to be
// removed subsequently.
void CleanUpTupleOps(mlir::Block* block, mlir::OpBuilder* builder) {
bool changed = true;
llvm::SmallVector<Value> folded_results;
while (changed) {
changed = false;
for (Operation& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<mlir::mhlo::GetTupleElementOp>(op)) {
folded_results.clear();
if (failed(builder->tryFold(&op, folded_results))) continue;
op.replaceAllUsesWith(folded_results);
op.erase();
changed = true;
} else if (llvm::isa<mlir::mhlo::TupleOp>(op) &&
mlir::isOpTriviallyDead(&op)) {
op.erase();
changed = true;
}
}
}
}
} // namespace
void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands) {
assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
"Unexpected mlir op in "
"HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands!");
int implicit_operand_index = 0;
for (auto& region : op->getRegions()) {
for (auto arg : region.getArguments()) {
assert(implicit_operand_index < implicit_operands.size());
arg.replaceAllUsesWith(implicit_operands[implicit_operand_index++]);
}
region.front().eraseArguments(
llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
}
}
mlir::Operation* HloFunctionImporter::CreateTupleFromOpResults(
mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op,
mlir::Type type) {
if (!type.isa<mlir::TupleType>()) return op;
llvm::SmallVector<Value> flattened_results = op->getResults();
llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
auto result =
CreateTupleValue(func_builder, loc, flattened_results_ref, type);
auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
assert(defining_tuple_op && "builder didn't return the right type");
auto tupleOp = defining_tuple_op.getOperation();
return tupleOp;
}
static bool IsNestedTupleInData(Type type) {
auto tuple_type = type.dyn_cast<mlir::TupleType>();
if (!tuple_type) return false;
assert(tuple_type.getType(1).isa<mlir::mhlo::TokenType>() &&
"Infeed: Non token type");
auto data_type = tuple_type.getType(0);
auto data_tuple_type = data_type.dyn_cast<mlir::TupleType>();
if (!data_tuple_type) return false;
for (auto child_type : data_tuple_type.getTypes()) {
if (child_type.isa<mlir::TupleType>()) return true;
}
return false;
}
void HloFunctionImporter::FlattenTupleType(
Type type, llvm::SmallVectorImpl<Type>& flattened_types) {
auto tuple_type = type.dyn_cast<mlir::TupleType>();
if (!tuple_type) {
flattened_types.push_back(type);
return;
}
for (auto child_type : tuple_type.getTypes()) {
FlattenTupleType(child_type, flattened_types);
}
}
void HloFunctionImporter::FlattenTupleValue(
mlir::OpBuilder* func_builder, mlir::Location loc, Value value,
llvm::SmallVectorImpl<Value>& flattened_values) {
auto tuple_type = value.getType().dyn_cast<mlir::TupleType>();
if (!tuple_type) {
flattened_values.push_back(value);
return;
}
int flattenIdx = 0;
for (auto child_type : tuple_type.getTypes()) {
auto sub_value = func_builder->create<mlir::mhlo::GetTupleElementOp>(
loc, child_type, value, func_builder->getI32IntegerAttr(flattenIdx++));
FlattenTupleValue(func_builder, loc, sub_value, flattened_values);
}
}
Value HloFunctionImporter::CreateTupleValue(
mlir::OpBuilder* func_builder, mlir::Location loc,
llvm::MutableArrayRef<Value>& flatten_values, Type type) {
auto tuple_type = type.dyn_cast<mlir::TupleType>();
if (!tuple_type) {
assert(!flatten_values.empty());
auto retval = flatten_values.front();
flatten_values = flatten_values.drop_front();
return retval;
}
llvm::SmallVector<mlir::Value> flatten_sub_values;
for (auto child_type : tuple_type.getTypes())
flatten_sub_values.push_back(
CreateTupleValue(func_builder, loc, flatten_values, child_type));
return func_builder->create<mlir::mhlo::TupleOp>(loc, flatten_sub_values)
.getResult();
}
Status HloFunctionImporter::ImportAsFunc(
const HloComputation& computation, mlir::ModuleOp module,
std::unordered_map<const HloComputation*, FuncOp>* function_map,
mlir::Builder* builder) {
HloFunctionImporter importer(module, function_map, builder);
return importer.ImportAsFunc(computation).status();
}
Status HloFunctionImporter::ImportAsRegion(
const xla::HloComputation& computation, mlir::Region* region,
mlir::Builder* builder, bool flatten_region_arg_tuple) {
HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
builder);
return importer.ImportAsRegion(computation, region, flatten_region_arg_tuple);
}
StatusOr<FuncOp> HloFunctionImporter::ImportAsFunc(
const HloComputation& computation) {
auto& imported = (*function_map_)[&computation];
if (imported) return imported;
llvm::SmallVector<Type, 4> args, rets;
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
auto func_type = mlir::FunctionType::get(context_, args, rets);
std::string computation_name =
computation.parent()->entry_computation() == &computation
? "main"
: SanitizeFunctionName(computation.name());
// Construct the MLIR function and map arguments.
llvm::ArrayRef<mlir::NamedAttribute> attrs;
auto function = FuncOp::create(mlir::UnknownLoc::get(context_),
computation_name, func_type, attrs);
auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
: FuncOp::Visibility::Private;
function.setVisibility(visibility);
module_.push_back(function);
// Add to the map right away for function calls.
imported = function;
mlir::Block* block = function.addEntryBlock();
TF_RETURN_IF_ERROR(ImportInstructions(computation, block,
/*flatten_region_arg_tuple=*/false));
return function;
}
tensorflow::Status HloFunctionImporter::ImportAsRegion(
const HloComputation& computation, mlir::Region* region,
bool flatten_region_arg_tuple) {
auto loc = region->getLoc();
// TODO(hinsu): Store computation name as an attribute for round-trip.
auto* block = new mlir::Block;
region->push_back(block);
llvm::SmallVector<Type, 4> args;
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
// Flatten the tuple-typed arguments.
if (flatten_region_arg_tuple) {
for (auto arg : args) {
llvm::SmallVector<Type> flattened_arg_types;
FlattenTupleType(arg, flattened_arg_types);
block->addArguments(
flattened_arg_types,
mlir::SmallVector<mlir::Location>(flattened_arg_types.size(), loc));
}
} else {
block->addArguments(args,
mlir::SmallVector<mlir::Location>(args.size(), loc));
}
return ImportInstructions(computation, block, flatten_region_arg_tuple);
}
StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
// Setup the input parameters.
const int num_parameters = computation.num_parameters();
for (int i = 0; i < num_parameters; i++) {
auto hlo_parameter = computation.parameter_instruction(i);
instruction_value_map_[hlo_parameter] = arguments[i];
}
for (auto instruction : computation.MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
TF_ASSIGN_OR_RETURN(
auto new_operation,
ImportInstructionWithLayout(instruction, operands, builder));
if (new_operation) {
instruction_value_map_[instruction] = new_operation->getResult(0);
}
}
// Setup the return type (HLO only supports a single return value).
return GetMlirValue(computation.root_instruction());
}
Status HloFunctionImporter::ImportInstructions(
const HloComputation& computation, mlir::Block* block,
bool flatten_region_arg_tuple) {
llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
// TODO(suderman): Add location tracking details.
mlir::Location loc = builder.getUnknownLoc();
Value result;
if (!llvm::isa<FuncOp>(block->getParentOp()) && flatten_region_arg_tuple) {
// 'effective_arguments' stores the mhlo value corresponding to each
// computation parameter. The value could be a BlockArgument, if the
// corresponding computation parameter is non-tuple typed, or a TupleOp,
// otherwise.
llvm::SmallVector<Value> effective_arguments;
llvm::SmallVector<Type> computation_arg_types;
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(),
&computation_arg_types));
int flatten_idx = 0;
for (Type computation_arg_type : computation_arg_types) {
auto orig_tuple_arg_type =
computation_arg_type.dyn_cast<mlir::TupleType>();
// If the computation-parameter type is non-tuple, no action is needed.
if (!orig_tuple_arg_type) {
effective_arguments.push_back(arguments[flatten_idx]);
flatten_idx++;
continue;
}
// For each tuple-typed computation parameter, create a mhlo::TupleOp
// value in the region body, using the already flattened values in
// 'arguments'. For example: With computation parameters: [tuple<T1>,
// tuple<T2, T4>] We have, 'arguments' = [T1 arg1, T2 arg2, T3 arg3] and
// we need to create two tuples tuples, one using arg1, and the other
// using arg2 and arg3.
llvm::SmallVector<Type> flattened_arg_type;
FlattenTupleType(orig_tuple_arg_type, flattened_arg_type);
llvm::MutableArrayRef<Value> sub_args(
arguments.begin() + flatten_idx,
arguments.begin() + flatten_idx + flattened_arg_type.size());
auto tupleVal =
CreateTupleValue(&builder, loc, sub_args, orig_tuple_arg_type);
effective_arguments.push_back(tupleVal);
flatten_idx += flattened_arg_type.size();
}
TF_ASSIGN_OR_RETURN(
result,
ImportInstructionsImpl(computation, effective_arguments, &builder));
} else {
TF_ASSIGN_OR_RETURN(
result, ImportInstructionsImpl(computation, arguments, &builder));
}
// Create terminator op depending on the parent op of this region.
if (llvm::isa<FuncOp>(block->getParentOp())) {
builder.create<mlir::func::ReturnOp>(loc, result);
} else {
if (flatten_region_arg_tuple) {
// Flatten tuples in results of this region.
llvm::SmallVector<Value> flattened_return_operands;
FlattenTupleValue(&builder, loc, result, flattened_return_operands);
builder.create<mlir::mhlo::ReturnOp>(loc, flattened_return_operands);
} else {
builder.create<mlir::mhlo::ReturnOp>(loc, result);
}
}
CleanUpTupleOps(block, &builder);
return tensorflow::Status::OK();
}
StatusOr<Value> HloFunctionImporter::ImportInstructions(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
mlir::Block* block = builder->getBlock();
if (block == nullptr)
return InvalidArgument(
"ImportInstructions requires a valid block in the builder");
HloFunctionImporter importer(
block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
return importer.ImportInstructionsImpl(computation, arguments, builder);
}
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
const xla::HloInstruction* instr,
const llvm::SmallVectorImpl<mlir::Value>& operands,
mlir::OpBuilder* builder, DynamicShapeHandlingMode mode) {
mlir::Block* block = builder->getBlock();
if (block == nullptr)
return InvalidArgument(
"ImportInstructions requires a valid block in the builder");
HloFunctionImporter importer(
block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
const HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
const Shape& instruction_shape = instruction->shape();
const Shape& shape = mode == DynamicShapeHandlingMode::kConvertToStatic
? xla::ShapeUtil::MakeStaticShape(instruction_shape)
: instruction_shape;
TF_ASSIGN_OR_RETURN(auto result_type,
ConvertShapeToType<RankedTensorType>(shape, *builder_));
mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);
llvm::SmallVector<NamedAttribute, 10> attributes;
if (instruction->has_sharding()) {
attributes.push_back(builder_->getNamedAttr(
kShardingAttr,
builder_->getStringAttr(
instruction->sharding().ToProto().SerializeAsString())));
}
switch (instruction->opcode()) {
case HloOpcode::kParameter: {
return nullptr;
}
case HloOpcode::kConstant: {
const Literal& literal = instruction->literal();
auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
if (!attr.ok()) return attr.status();
mlir::Operation* new_operation =
func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
for (auto attr : attributes) {
new_operation->setAttr(attr.getName(), attr.getValue());
}
return new_operation;
}
case HloOpcode::kIota: {
return func_builder
->create<mlir::mhlo::IotaOp>(
loc, result_type,
func_builder->getI64IntegerAttr(
Cast<HloIotaInstruction>(instruction)->iota_dimension()))
.getOperation();
}
case HloOpcode::kBroadcast: {
// Note that the HLO broadcast is more powerful than the XLA broadcast
// op. BroadcastInDim offers a superset of the HLO op's functionality.
attributes.push_back(
builder_->getNamedAttr("broadcast_dimensions",
ConvertDimensions(instruction->dimensions())));
return func_builder
->create<mlir::mhlo::BroadcastInDimOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
attributes.push_back(builder_->getNamedAttr(
"epsilon", builder_->getF32FloatAttr(instruction->epsilon())));
attributes.push_back(builder_->getNamedAttr(
"feature_index",
builder_->getI64IntegerAttr(instruction->feature_index())));
if (instruction->opcode() == HloOpcode::kBatchNormGrad) {
// Flatten the return type if they are tuple-typed.
llvm::SmallVector<Type> flattened_ret_types;
FlattenTupleType(result_type, flattened_ret_types);
auto op = func_builder
->create<mlir::mhlo::BatchNormGradOp>(
loc, flattened_ret_types, operands, attributes)
.getOperation();
return CreateTupleFromOpResults(func_builder, loc, op, result_type);
} else if (instruction->opcode() == HloOpcode::kBatchNormInference) {
return func_builder
->create<mlir::mhlo::BatchNormInferenceOp>(loc, result_type,
operands, attributes)
.getOperation();
} else {
assert(instruction->opcode() == HloOpcode::kBatchNormTraining);
// Flatten the return type if they are tuple-typed.
llvm::SmallVector<Type> flattened_ret_types;
FlattenTupleType(result_type, flattened_ret_types);
auto op = func_builder
->create<mlir::mhlo::BatchNormTrainingOp>(
loc, flattened_ret_types, operands, attributes)
.getOperation();
return CreateTupleFromOpResults(func_builder, loc, op, result_type);
}
case HloOpcode::kDot: {
attributes.push_back(builder_->getNamedAttr(
"precision_config",
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
// Consider consolidating DotOps together.
if (DotIsDefault(instruction)) {
return func_builder
->create<mlir::mhlo::DotOp>(loc, result_type, operands, attributes)
.getOperation();
}
attributes.push_back(builder_->getNamedAttr(
"dot_dimension_numbers",
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
builder_)));
return func_builder
->create<mlir::mhlo::DotGeneralOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kCall: {
TF_ASSIGN_OR_RETURN(FuncOp function,
ImportAsFunc(*instruction->to_apply()));
mlir::Operation* new_operation =
func_builder->create<mlir::func::CallOp>(loc, function, operands);
return new_operation;
}
case HloOpcode::kCollectivePermute: {
attributes.push_back(ConvertSourceTargetPairs(
instruction->source_target_pairs(), builder_));
return func_builder
->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kCustomCall: {
auto custom_call = Cast<HloCustomCallInstruction>(instruction);
const auto& called_computations = custom_call->called_computations();
if (!called_computations.empty()) {
llvm::SmallVector<mlir::Attribute> callees;
callees.reserve(called_computations.size());
for (HloComputation* callee : called_computations) {
TF_ASSIGN_OR_RETURN(FuncOp function, ImportAsFunc(*callee));
callees.push_back(mlir::FlatSymbolRefAttr::get(builder_->getContext(),
function.getName()));
}
attributes.push_back(builder_->getNamedAttr(
"called_computations",
mlir::ArrayAttr::get(builder_->getContext(), callees)));
}
if (custom_call->layout_constrained()) {
TF_ASSIGN_OR_RETURN(
mlir::ArrayAttr operand_layouts,
ExtractLayoutsFromShapes(custom_call->operand_shapes_with_layout(),
builder_));
attributes.push_back(
builder_->getNamedAttr("operand_layouts", operand_layouts));
mlir::ArrayAttr result_layouts;
if (custom_call->shape().IsTuple()) {
TF_ASSIGN_OR_RETURN(
result_layouts,
ExtractLayoutsFromTuple(custom_call->shape(), builder_));
} else {
TF_ASSIGN_OR_RETURN(
result_layouts,
ExtractLayoutsFromShapes({custom_call->shape()}, builder_));
}
attributes.push_back(
builder_->getNamedAttr("result_layouts", result_layouts));
}
TF_ASSIGN_OR_RETURN(
auto mlir_api_version,
ConvertCustomCallApiVersion(custom_call->api_version()));
attributes.push_back(builder_->getNamedAttr(
"call_target_name",
builder_->getStringAttr(custom_call->custom_call_target())));
attributes.push_back(builder_->getNamedAttr(
"has_side_effect",
builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
attributes.push_back(builder_->getNamedAttr(
"backend_config",
builder_->getStringAttr(custom_call->raw_backend_config_string())));
attributes.push_back(builder_->getNamedAttr(
"api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
builder_->getContext(), mlir_api_version)));
return func_builder
->create<mlir::mhlo::CustomCallOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kCompare: {
auto compare = Cast<HloCompareInstruction>(instruction);
attributes.push_back(ConvertComparisonDirection(compare->direction()));
auto default_type = Comparison::DefaultComparisonType(
compare->operand(0)->shape().element_type());
if (compare->type() != default_type)
attributes.push_back(ConvertComparisonType(compare->type()));
return func_builder
->create<mlir::mhlo::CompareOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kCholesky: {
attributes.push_back(builder_->getNamedAttr(
"lower",
builder_->getBoolAttr(instruction->cholesky_options().lower())));
return func_builder
->create<mlir::mhlo::CholeskyOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kGather: {
auto gather_instruction = Cast<HloGatherInstruction>(instruction);
attributes.push_back(builder_->getNamedAttr(
"dimension_numbers",
ConvertGatherDimensionNumbers(
gather_instruction->gather_dimension_numbers(), builder_)));
std::vector<int64_t> slice_sizes(
gather_instruction->gather_slice_sizes().begin(),
gather_instruction->gather_slice_sizes().end());
attributes.push_back(
builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
return func_builder
->create<mlir::mhlo::GatherOp>(loc, result_type, operands, attributes)
.getOperation();
}
case HloOpcode::kDynamicSlice: {
std::vector<int64_t> slice_sizes(
instruction->dynamic_slice_sizes().begin(),
instruction->dynamic_slice_sizes().end());
return func_builder
->create<mlir::mhlo::DynamicSliceOp>(
loc, result_type, operands[0],
makeArrayRef(operands).drop_front(), Convert(slice_sizes))
.getOperation();
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder
->create<mlir::mhlo::DynamicUpdateSliceOp>(
loc, result_type, operands[0], operands[1],
llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
.getOperation();
}
case HloOpcode::kInfeed: {
if (IsNestedTupleInData(result_type)) {
llvm_unreachable(
"Importing xla::kInfeed with nested tuple shape not supported");
}
attributes.push_back(builder_->getNamedAttr(
"infeed_config",
mlir::StringAttr::get(builder_->getContext(),
instruction->infeed_config())));
llvm::SmallVector<mlir::Attribute> flattened_attr;
TF_RETURN_IF_ERROR(
ConvertShapeToMlirLayout(instruction->shape(), flattened_attr));
attributes.push_back(builder_->getNamedAttr(
"layout", builder_->getArrayAttr(makeArrayRef(flattened_attr))));
// Flatten the return-type if they are tuple-typed.
llvm::SmallVector<Type> flattened_ret_types;
FlattenTupleType(result_type, flattened_ret_types);
auto op = func_builder->create<mlir::mhlo::InfeedOp>(
loc, flattened_ret_types, operands, attributes);
return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
result_type);
}
case HloOpcode::kOutfeed: {
attributes.push_back(builder_->getNamedAttr(
"outfeed_config",
mlir::StringAttr::get(builder_->getContext(),
instruction->outfeed_config())));
assert(operands.size() == 2 && "Expected 2 operands for HLO Infeed");
// In case operands[0] is a tuple, flatten it.
llvm::SmallVector<Value> flattened_operands;
FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
flattened_operands.push_back(operands[1]);
auto op = func_builder->create<mlir::mhlo::OutfeedOp>(
loc, result_type, flattened_operands, attributes);
return op.getOperation();
}
case HloOpcode::kPad: {
const auto& padding_config = instruction->padding_config();
llvm::SmallVector<int64_t, 4> edge_padding_low;
llvm::SmallVector<int64_t, 4> edge_padding_high;
llvm::SmallVector<int64_t, 4> interior_padding;
edge_padding_low.reserve(padding_config.dimensions_size());
edge_padding_high.reserve(padding_config.dimensions_size());
interior_padding.reserve(padding_config.dimensions_size());
for (const auto& dimension : padding_config.dimensions()) {
edge_padding_low.push_back(dimension.edge_padding_low());
edge_padding_high.push_back(dimension.edge_padding_high());
interior_padding.push_back(dimension.interior_padding());
}
return func_builder
->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
operands[1], Convert(edge_padding_low),
Convert(edge_padding_high),
Convert(interior_padding))
.getOperation();
}
case HloOpcode::kScatter: {
auto scatter = Cast<HloScatterInstruction>(instruction);
attributes.push_back(builder_->getNamedAttr(
"scatter_dimension_numbers",
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
builder_)));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(scatter->indices_are_sorted())));
attributes.push_back(builder_->getNamedAttr(
"unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
&scatter_op.update_computation(),
/*flatten_region_arg_tuple=*/true));
return scatter_op.getOperation();
}
case HloOpcode::kSelectAndScatter: {
auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
llvm::SmallVector<int64_t, 8> padding;
for (const auto& dim : select_scatter->window().dimensions()) {
window_strides.push_back(dim.stride());
window_dimensions.push_back(dim.size());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
attributes.push_back(
builder_->getNamedAttr("window_strides", Convert(window_strides)));
attributes.push_back(builder_->getNamedAttr("window_dimensions",
Convert(window_dimensions)));
attributes.push_back(ConvertPadding(padding));
auto select_scatter_op =
func_builder->create<mlir::mhlo::SelectAndScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
&select_scatter_op.select(),
/*flatten_region_arg_tuple=*/true));
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
&select_scatter_op.scatter(),
/*flatten_region_arg_tuple=*/true));
return select_scatter_op.getOperation();
}
case HloOpcode::kSetDimensionSize: {
attributes.push_back(builder_->getNamedAttr(
"dimension", builder_->getI64IntegerAttr(instruction->dimension())));
return func_builder
->create<mlir::mhlo::SetDimensionSizeOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kSlice: {
return func_builder
->create<mlir::mhlo::SliceOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->slice_starts()),
ConvertDimensions(instruction->slice_limits()),
ConvertDimensions(instruction->slice_strides()))
.getOperation();
}
case HloOpcode::kSort: {
auto sort_instruction = Cast<HloSortInstruction>(instruction);
llvm::SmallVector<Type, 4> return_types = {result_type};
if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
return_types = llvm::to_vector<6>(tuple_ty.getTypes());
}
auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
loc, return_types, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable()));
TF_RETURN_IF_ERROR(ImportAsRegion(*sort_instruction->to_apply(),
&sort_op.comparator(),
/*flatten_region_arg_tuple=*/true));
// Check if the output needs to be tupled.
if (return_types.size() == 1 && return_types.front() == result_type) {
return sort_op.getOperation();
}
return func_builder
->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
.getOperation();
}
case HloOpcode::kConditional: {
llvm::SmallVector<Type, 4> rets;
// Flatten the tuple-typed operands.
llvm::SmallVector<Value> flattened_operands;
for (auto& operand : operands)
FlattenTupleValue(func_builder, loc, operand, flattened_operands);
// If/Case Op has a single operand; we collect the other operands to
// replace the corresponding block arguments.
llvm::ArrayRef<Value> implicit_operands(flattened_operands.begin() + 1,
flattened_operands.end());
mlir::Type pred_or_index_type =
operands[0].getType().cast<mlir::TensorType>().getElementType();
// It is a predicated conditional if first argument is a boolean and
// should be mapped to If op.
if (pred_or_index_type.isInteger(1)) {
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->true_computation()->root_instruction()}, &rets));
// Flatten the return-type.
llvm::SmallVector<Type> flattened_ret_types;
assert(rets.size() == 1);
FlattenTupleType(rets[0], flattened_ret_types);
auto op = func_builder->create<mlir::mhlo::IfOp>(
loc, flattened_ret_types, flattened_operands[0], attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
&op.true_branch(),
/*flatten_region_arg_tuple=*/true));
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
&op.false_branch(),
/*flatten_region_arg_tuple=*/true));
// Replace the uses of block-arguments of the IfOp with the
// implicit_operands.
ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
implicit_operands);
return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
rets[0]);
}
// Otherwise, it is a indexed conditional and should be mapped to Case
// op.
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->branch_computation(0)->root_instruction()}, &rets));
// Flatten the return-type.
llvm::SmallVector<Type> flattened_ret_types;
assert(rets.size() == 1);
FlattenTupleType(rets[0], flattened_ret_types);
int num_branches = instruction->branch_count();
auto op = func_builder->create<mlir::mhlo::CaseOp>(
loc, flattened_ret_types, flattened_operands[0], attributes,
num_branches);
for (const auto& index_and_computation :
llvm::enumerate(instruction->branch_computations())) {
auto index = index_and_computation.index();
HloComputation* computation = index_and_computation.value();
TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index],
/*flatten_region_arg_tuple=*/true));
}
// Replace the uses of block-arguments of the CaseOp with the
// implicit_operands.
ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
implicit_operands);
return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
rets[0]);
}
case HloOpcode::kConcatenate: {
// TODO(b/132057942): Support taking an uint64_t instead of an
// IntegerAttr for concatenate dimension.
return func_builder
->create<mlir::mhlo::ConcatenateOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
.getOperation();
}
case HloOpcode::kAllGather: {
auto all_gather = Cast<HloAllGatherInstruction>(instruction);
attributes.push_back(builder_->getNamedAttr(
"all_gather_dim",
builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
attributes.push_back(
ConvertReplicaGroups(all_gather->replica_groups(), builder_));
if (all_gather->channel_id().has_value())
attributes.push_back(
ConvertChannelHandle(all_gather->channel_id().value()));
return func_builder
->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kAllReduce: {
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
attributes.push_back(
ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
if (all_reduce->channel_id().has_value())
attributes.push_back(
ConvertChannelHandle(all_reduce->channel_id().value()));
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
&all_reduce_op.computation()));
return all_reduce_op.getOperation();
}
case HloOpcode::kAllToAll: {
// TODO(b/207152612): all-to-all HLO can either have pre-split operands
// (and returns a tuple) or a single operand that is split across
// `split_dimension` into the number of replicas in a group. Only the
// latter case (array all-to-all) is supported in importer right now and
// the former (tuple all-to-all) is not supported yet.
auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
if (all_to_all->shape().IsTuple())
return tensorflow::errors::Unimplemented(
"Importing tuple all-to-all HLO is not supported yet");
// Check invariants of array all-to-all. This is a sanity check and is
// verified by the HLO verifier.
if (!all_to_all->split_dimension().has_value() || operands.size() != 1 ||
all_to_all->replica_groups().empty())
return tensorflow::errors::InvalidArgument(
"Array all-to-all should have a split dimension, one operand and "
"non-empty replica groups");
auto replica_groups_attr =
ConvertReplicaGroups(all_to_all->replica_groups(), builder_)
.getValue()
.cast<DenseIntElementsAttr>();
uint64_t split_dim = all_to_all->split_dimension().value();
uint64_t concat_dim = split_dim;
uint64_t split_count = all_to_all->replica_groups()[0].replica_ids_size();
return func_builder
->create<mlir::mhlo::AllToAllOp>(loc, result_type, operands[0],
split_dim, concat_dim, split_count,
replica_groups_attr)
.getOperation();
}
case HloOpcode::kReduce: {
// Operands in the first half are reduction inputs and the remaining
// operands are corresponding initial values.
size_t num_inputs = operands.size() / 2;
llvm::SmallVector<Type, 4> return_types = {result_type};
if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
return_types = llvm::to_vector<6>(tuple_ty.getTypes());
}
auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
loc, return_types,
llvm::makeArrayRef(operands).take_front(num_inputs),
llvm::makeArrayRef(operands).drop_front(num_inputs),
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
&reduce.body(),
/*flatten_region_arg_tuple=*/true));
// Check if the output needs to be tupled.
if (return_types.size() == 1 && return_types.front() == result_type) {
return reduce.getOperation();
}
return func_builder
->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
.getOperation();
}
case HloOpcode::kReverse: {
return func_builder
->create<mlir::mhlo::ReverseOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->dimensions()))
.getOperation();
}
case HloOpcode::kRng: {
auto shape = func_builder->create<mlir::mhlo::ConstOp>(
loc, Convert(result_type.cast<RankedTensorType>().getShape()));
switch (instruction->random_distribution()) {
case xla::RNG_UNIFORM:
return func_builder
->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();
case xla::RNG_NORMAL:
return func_builder
->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();
default:
return tensorflow::errors::InvalidArgument(absl::StrCat(
"Unsupported distribution: ",
RandomDistributionToString(instruction->random_distribution())));
}
}
case HloOpcode::kRngBitGenerator: {
auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
// Flatten the return type if they are tuple-typed.
llvm::SmallVector<Type> flattened_ret_types;
FlattenTupleType(result_type, flattened_ret_types);
auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
loc, flattened_ret_types,
func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);
return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
result_type);
}
case HloOpcode::kRngGetAndUpdateState: {
return func_builder
->create<mlir::mhlo::XlaRngGetAndUpdateStateOp>(
loc, result_type,
func_builder->getI64IntegerAttr(
Cast<HloRngGetAndUpdateStateInstruction>(instruction)
->delta()))
.getOperation();
}
case HloOpcode::kWhile: {
llvm::SmallVector<Value> flattened_operands;
llvm::SmallVector<Type> flattened_operand_types;
FlattenTupleType(operands[0].getType(), flattened_operand_types);
FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
auto op = func_builder->create<mlir::mhlo::WhileOp>(
loc, flattened_operand_types, flattened_operands);
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_condition(),
&op.cond(),
/*flatten_region_arg_tuple=*/true));
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_body(), &op.body(),
/*flatten_region_arg_tuple=*/true));
return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
operands[0].getType());
}
case HloOpcode::kGetTupleElement: {
attributes.push_back(builder_->getNamedAttr(
"index", builder_->getIntegerAttr(builder_->getIntegerType(32),
instruction->tuple_index())));
return func_builder
->create<mlir::mhlo::GetTupleElementOp>(loc, result_type, operands,
attributes)
.getOperation();
};
case HloOpcode::kGetDimensionSize: {
attributes.push_back(builder_->getNamedAttr(
"dimension", builder_->getI64IntegerAttr(instruction->dimension())));
return func_builder
->create<mlir::mhlo::GetDimensionSizeOp>(loc, result_type, operands,
attributes)
.getOperation();
};
case HloOpcode::kTranspose: {
attributes.push_back(builder_->getNamedAttr(
"permutation", ConvertDimensions(instruction->dimensions())));
return func_builder
->create<mlir::mhlo::TransposeOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kTriangularSolve: {
attributes.push_back(builder_->getNamedAttr(
"left_side",
builder_->getBoolAttr(
instruction->triangular_solve_options().left_side())));
attributes.push_back(builder_->getNamedAttr(
"lower", builder_->getBoolAttr(
instruction->triangular_solve_options().lower())));
attributes.push_back(builder_->getNamedAttr(
"unit_diagonal",
builder_->getBoolAttr(
instruction->triangular_solve_options().unit_diagonal())));
auto transpose_a = mlir::mhlo::TransposeAttr::get(
builder_->getContext(),
mlir::mhlo::symbolizeTranspose(
TriangularSolveOptions::Transpose_Name(
instruction->triangular_solve_options().transpose_a()))
.getValue());
attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
return func_builder
->create<mlir::mhlo::TriangularSolveOp>(loc, result_type, operands,
attributes)
.getOperation();
}
case HloOpcode::kReduceScatter: {
auto reduce_scatter = Cast<HloReduceScatterInstruction>(instruction);
attributes.push_back(builder_->getNamedAttr(
"scatter_dimension",
builder_->getI64IntegerAttr(reduce_scatter->scatter_dimension())));
attributes.push_back(
ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
if (reduce_scatter->channel_id().has_value())
attributes.push_back(
ConvertChannelHandle(reduce_scatter->channel_id().value()));
auto reduce_scatter_op =
func_builder->create<mlir::mhlo::ReduceScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
&reduce_scatter_op.computation()));
return reduce_scatter_op.getOperation();
}
case HloOpcode::kReduceWindow: {
llvm::SmallVector<Type, 4> return_types = {result_type};
if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
return_types = llvm::to_vector<6>(tuple_ty.getTypes());
}
llvm::SmallVector<int64_t, 4> sizes, strides, base_dilations,
win_dilations;
llvm::SmallVector<int64_t, 8> padding;
for (const auto& dim : instruction->window().dimensions()) {
sizes.push_back(dim.size());
strides.push_back(dim.stride());
base_dilations.push_back(dim.base_dilation());
win_dilations.push_back(dim.window_dilation());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
attributes.push_back(builder_->getNamedAttr("window_dimensions",
ConvertDimensions(sizes)));
attributes.push_back(
builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
attributes.push_back(builder_->getNamedAttr(
"base_dilations", ConvertDimensions(base_dilations)));
attributes.push_back(builder_->getNamedAttr(
"window_dilations", ConvertDimensions(win_dilations)));
attributes.push_back(ConvertPadding(padding));
auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
loc, return_types, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
&reduce.body(),
/*flatten_region_arg_tuple=*/true));
// Check if the output needs to be tupled.
if (return_types.size() == 1 && return_types.front() == result_type) {
return reduce.getOperation();
}
return func_builder
->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
.getOperation();
}
case HloOpcode::kMap: {
auto op = func_builder->create<mlir::mhlo::MapOp>(
loc, result_type, operands,
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
&op.computation(),
/*flatten_region_arg_tuple=*/true));
return op.getOperation();
}
case HloOpcode::kConvolution: {
llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
llvm::SmallVector<int64_t, 8> paddings;
for (const auto& dim : instruction->window().dimensions()) {
strides.push_back(dim.stride());
lhs_dilations.push_back(dim.base_dilation());
rhs_dilations.push_back(dim.window_dilation());
paddings.push_back(dim.padding_low());
paddings.push_back(dim.padding_high());
}
attributes.push_back(
builder_->getNamedAttr("window_strides", Convert(strides)));
attributes.push_back(ConvertPadding(paddings));
attributes.push_back(
builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
attributes.push_back(
builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
attributes.push_back(builder_->getNamedAttr(
"dimension_numbers",
ConvertConvDimensionNumbers(
instruction->convolution_dimension_numbers(), builder_)));
attributes.push_back(builder_->getNamedAttr(
"feature_group_count",
builder_->getI64IntegerAttr(instruction->feature_group_count())));
attributes.push_back(builder_->getNamedAttr(
"batch_group_count",
builder_->getI64IntegerAttr(instruction->batch_group_count())));
attributes.push_back(builder_->getNamedAttr(
"precision_config",
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
return func_builder
->create<mlir::mhlo::ConvOp>(loc, result_type, operands, attributes)
.getOperation();
}
case HloOpcode::kFft: {
auto fft_type = mlir::mhlo::FftTypeAttr::get(
builder_->getContext(),
mlir::mhlo::symbolizeFftType(FftType_Name(instruction->fft_type()))
.getValue());
std::vector<int64_t> fft_length(instruction->fft_length().begin(),
instruction->fft_length().end());
attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
attributes.push_back(
builder_->getNamedAttr("fft_length", Convert(fft_length)));
return func_builder
->create<mlir::mhlo::FftOp>(loc, result_type, operands, attributes)
.getOperation();
}
case HloOpcode::kAdd: {
// HLO add ops on PRED elements are actually boolean or, but MHLO dialect
// AddOps on i1 are just addition with overflow; so, we have to implement
// the special behavior of HLO add ops on PRED here by creating an
// arith::OrIOp instead.
if (instruction->shape().element_type() == PRED) {
return func_builder
->create<mlir::mhlo::OrOp>(loc, result_type, operands, attributes)
.getOperation();
} else {
return func_builder
->create<mlir::mhlo::AddOp>(loc, result_type, operands, attributes)
.getOperation();
}
}
case HloOpcode::kAfterAll: {
// HLO AfterAll ops without any token input are used to just create a
// token. MHLO has a special op CreateToken for this case.
if (instruction->operands().empty()) {
return func_builder
->create<mlir::mhlo::CreateTokenOp>(loc, result_type, operands,
attributes)
.getOperation();
} else {
return func_builder
->create<mlir::mhlo::AfterAllOp>(loc, result_type, operands,
attributes)
.getOperation();
}
}
case HloOpcode::kConvert: {
// Convert to boolean is special, it requires a comparison to 0 instead of
// a truncation to i1, otherwise it is a 1-1 translation.
auto ranked_type = result_type.dyn_cast<mlir::RankedTensorType>();
mlir::IntegerType integer_type =
(ranked_type)
? ranked_type.getElementType().dyn_cast<mlir::IntegerType>()
: nullptr;
if (!integer_type || integer_type.getWidth() != 1) {
// Simple case: 1-1 mapping.
return {func_builder->create<mlir::mhlo::ConvertOp>(
loc, result_type, operands, attributes)};
}
// Return type is boolean, let's use `operand != 0` instead of Convert.
xla::Shape input_shape = instruction->operand(0)->shape();
TF_ASSIGN_OR_RETURN(mlir::Type type,
ConvertTensorShapeToType<mlir::RankedTensorType>(
input_shape, *func_builder));
auto zero = func_builder->create<mlir::mhlo::ConstOp>(
loc, func_builder->getZeroAttr(type));
return {func_builder->create<mlir::mhlo::CompareOp>(
loc, operands[0], zero, mlir::mhlo::ComparisonDirection::NE)};
}
case HloOpcode::kOptimizationBarrier: {
llvm::SmallVector<Value> flattened_operands;
llvm::SmallVector<Type> flattened_operand_types;
FlattenTupleType(operands[0].getType(), flattened_operand_types);
FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
auto op = func_builder->create<mlir::mhlo::OptimizationBarrierOp>(
loc, flattened_operand_types, flattened_operands);
return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
operands[0].getType());
}
#define NO_ATTRIBUTE_CASE(hlo_op_code, mlir_op) \
case HloOpcode::hlo_op_code: { \
return func_builder \
->create<mlir::mhlo::mlir_op>(loc, result_type, operands, attributes) \
.getOperation(); \
}
// broadcast dimensions are never added here because they don't exist as
// part of the HLO instruction. They are only a convenience in the XLA
// builder API.
NO_ATTRIBUTE_CASE(kAbs, AbsOp);
NO_ATTRIBUTE_CASE(kAnd, AndOp);
NO_ATTRIBUTE_CASE(kAtan2, Atan2Op);
NO_ATTRIBUTE_CASE(kBitcastConvert, BitcastConvertOp);
NO_ATTRIBUTE_CASE(kCbrt, CbrtOp);
NO_ATTRIBUTE_CASE(kClz, ClzOp);
NO_ATTRIBUTE_CASE(kCeil, CeilOp);
NO_ATTRIBUTE_CASE(kClamp, ClampOp);
NO_ATTRIBUTE_CASE(kComplex, ComplexOp);
NO_ATTRIBUTE_CASE(kCos, CosOp);
NO_ATTRIBUTE_CASE(kDivide, DivOp);
NO_ATTRIBUTE_CASE(kExp, ExpOp);
NO_ATTRIBUTE_CASE(kExpm1, Expm1Op);
NO_ATTRIBUTE_CASE(kFloor, FloorOp);
NO_ATTRIBUTE_CASE(kIsFinite, IsFiniteOp);
NO_ATTRIBUTE_CASE(kImag, ImagOp);
NO_ATTRIBUTE_CASE(kLog, LogOp);
NO_ATTRIBUTE_CASE(kLog1p, Log1pOp);
NO_ATTRIBUTE_CASE(kMaximum, MaxOp);
NO_ATTRIBUTE_CASE(kMinimum, MinOp);
NO_ATTRIBUTE_CASE(kMultiply, MulOp);
NO_ATTRIBUTE_CASE(kNegate, NegOp);
NO_ATTRIBUTE_CASE(kNot, NotOp);
NO_ATTRIBUTE_CASE(kOr, OrOp);
NO_ATTRIBUTE_CASE(kPopulationCount, PopulationCountOp);
NO_ATTRIBUTE_CASE(kPower, PowOp);
NO_ATTRIBUTE_CASE(kReal, RealOp);
NO_ATTRIBUTE_CASE(kRemainder, RemOp);
NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp);
NO_ATTRIBUTE_CASE(kLogistic, LogisticOp);
// The dimensions attribute is not present on the HLO Reshape
// instruction. If dimensions are non-default, the XLA builder
// implements it as a separate transpose.
NO_ATTRIBUTE_CASE(kReshape, ReshapeOp);
NO_ATTRIBUTE_CASE(kRoundNearestAfz, RoundOp);
NO_ATTRIBUTE_CASE(kRsqrt, RsqrtOp);
NO_ATTRIBUTE_CASE(kSelect, SelectOp);
NO_ATTRIBUTE_CASE(kShiftLeft, ShiftLeftOp);
NO_ATTRIBUTE_CASE(kShiftRightArithmetic, ShiftRightArithmeticOp);
NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp);
NO_ATTRIBUTE_CASE(kSign, SignOp);
NO_ATTRIBUTE_CASE(kSin, SinOp);
NO_ATTRIBUTE_CASE(kSqrt, SqrtOp);
NO_ATTRIBUTE_CASE(kSubtract, SubOp);
NO_ATTRIBUTE_CASE(kTanh, TanhOp);
NO_ATTRIBUTE_CASE(kTuple, TupleOp);
NO_ATTRIBUTE_CASE(kXor, XorOp);
// TODO(b/129422361) Copy needs special handling because it is not
// defined in tensorflow/compiler/xla/client/xla_builder.h. See
// operation semantics in
// g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
NO_ATTRIBUTE_CASE(kCopy, CopyOp);
#undef NO_ATTRIBUTE_CASE
case HloOpcode::kFusion: {
// Flatten the tuple-typed operands.
llvm::SmallVector<Value> flattened_operands;
for (auto& operand : operands)
FlattenTupleValue(func_builder, loc, operand, flattened_operands);
// Flatten the return type if they are tuple-typed.
llvm::SmallVector<Type> flattened_ret_types;
FlattenTupleType(result_type, flattened_ret_types);
auto fusion_kind = mlir::mhlo::symbolizeFusionKind(
xla::ToString(instruction->fusion_kind()));
auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
loc, flattened_ret_types, flattened_operands,
mlir::mhlo::FusionKindAttr::get(func_builder->getContext(),
fusion_kind.getValue()));
TF_RETURN_IF_ERROR(ImportAsRegion(
*instruction->fused_instructions_computation(),
&fusion.fused_computation(), /*flatten_region_arg_tuple=*/true));
return CreateTupleFromOpResults(func_builder, loc, fusion.getOperation(),
result_type);
}
case HloOpcode::kBitcast: {
auto bitcast = func_builder->create<mlir::mhlo::BitcastOp>(
loc, result_type, operands, attributes);
// Store the source and result layout as attributes. Although the MHLO
// Bitcast operates on tensors, these layouts are relevant as they define
// the mapping between the elements of the source and result.
SetLayoutForMlir(bitcast, instruction->shape(), "result_layout");
SetLayoutForMlir(bitcast, instruction->operand(0)->shape(),
"source_layout");
return bitcast.getOperation();
}
case HloOpcode::kReducePrecision: {
auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
loc, result_type, operands[0], attributes);
op.exponent_bitsAttr(func_builder->getIntegerAttr(
func_builder->getI32Type(), instruction->exponent_bits()));
op.mantissa_bitsAttr(func_builder->getIntegerAttr(
func_builder->getI32Type(), instruction->mantissa_bits()));
return op.getOperation();
}
case HloOpcode::kAddDependency:
// Arbitrary op code that I suspect we will not implement for quite a
// while and allows testing handling of unknown ops. Selected because it
// is not mentioned in xla client anywhere or in the hlo of our sample
// models.
default: {
mlir::OperationState result(loc, "mhlo.unknown");
result.addOperands(operands);
result.addTypes(result_type);
for (auto attr : attributes) {
result.attributes.push_back(attr);
}
return func_builder->create(result);
}
}
}
void SetXlaShape(mlir::Operation* op, const Shape& shape) {
op->setAttr("xla_shape",
mlir::Builder(op->getContext())
.getStringAttr(shape.ToString(/*print_layout=*/true)));
}
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionWithLayout(
const HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
TF_ASSIGN_OR_RETURN(
mlir::Operation * op,
ImportInstructionImpl(instruction, operands, func_builder, mode));
if (op == nullptr) return op;
// See MlirToHloConversionOptions for more about layouts.
//
// Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
// in physical minor-to-major order.
if (instruction->shape().IsArray()) {
if (!instruction->shape().layout().minor_to_major().empty() &&
instruction->shape().layout() !=
LayoutUtil::MakeDescendingLayout(
instruction->shape().dimensions().size())) {
SetXlaShape(op, instruction->shape());
}
} else {
SetXlaShape(op, instruction->shape());
}
return op;
}
StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
const HloInstruction* instruction) {
llvm::SmallVector<mlir::Value, 4> operands;
for (const auto& operand : instruction->operands()) {
auto input_it = instruction_value_map_.find(operand);
if (input_it == instruction_value_map_.end()) {
return tensorflow::errors::Internal(
absl::StrCat("Could not find input value: ", operand->name(),
" for instruction ", instruction->name()));
}
operands.push_back(input_it->second);
}
return operands;
}
tensorflow::Status HloFunctionImporter::GetMlirTypes(
const std::vector<HloInstruction*>& instructions,
llvm::SmallVectorImpl<mlir::Type>* types) {
for (auto instruction : instructions) {
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
instruction->shape(), *builder_));
types->push_back(ret_type);
}
return tensorflow::Status::OK();
}
StatusOr<Value> HloFunctionImporter::GetMlirValue(
const HloInstruction* instruction) {
auto lookup = instruction_value_map_.find(instruction);
if (lookup != instruction_value_map_.end()) {
return lookup->second;
}
return tensorflow::errors::Internal(absl::StrCat(
"Unable to find value for input: ", instruction->ToString()));
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
ComparisonDirection direction) {
return builder_->getNamedAttr(
"comparison_direction",
mlir::mhlo::ComparisonDirectionAttr::get(
builder_->getContext(), mlir::mhlo::symbolizeComparisonDirection(
ComparisonDirectionToString(direction))
.getValue()));
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
Comparison::Type type) {
return builder_->getNamedAttr(
"compare_type",
mlir::mhlo::ComparisonTypeAttr::get(
builder_->getContext(),
mlir::mhlo::symbolizeComparisonType(ComparisonTypeToString(type))
.getValue()));
}
mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
absl::Span<const int64_t> op_dimensions) {
llvm::SmallVector<APInt, 8> dimensions;
dimensions.reserve(op_dimensions.size());
for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
return DenseIntElementsAttr::get(
RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
dimensions);
}
mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
llvm::ArrayRef<int64_t> elements) {
return DenseIntElementsAttr::get(
RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
elements);
}
mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
llvm::ArrayRef<int64_t> padding) {
auto ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_->getIntegerType(64));
auto attr = DenseIntElementsAttr::get(ty, padding);
return builder_->getNamedAttr("padding", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
mlir::Builder* builder) {
std::vector<int64_t> attr(source_target_pairs.size() * 2);
for (const auto& p : llvm::enumerate(source_target_pairs)) {
attr[2 * p.index()] = p.value().first;
attr[2 * p.index() + 1] = p.value().second;
}
auto type = mlir::RankedTensorType::get(
{static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
return builder->getNamedAttr("source_target_pairs",
DenseIntElementsAttr::get(type, attr));
}
mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
const int64_t num_groups = replica_groups.size();
// Replica groups in HLO can be non-uniform in size, for example:
// replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
// tensor, pad the smaller sized replica groups with -1.
const int64_t group_size = absl::c_accumulate(
replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
return std::max<int64_t>(current, g.replica_ids_size());
});
// Initialize all elements to -1 to support non-uniform replica groups.
std::vector<int64_t> attr(num_groups * group_size, -1);
for (int i = 0; i < num_groups; ++i) {
int index = i * group_size;
for (const int64_t& id : replica_groups[i].replica_ids())
attr[index++] = id;
}
auto type = mlir::RankedTensorType::get({num_groups, group_size},
builder->getIntegerType(64));
return builder->getNamedAttr("replica_groups",
DenseIntElementsAttr::get(type, attr));
}
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
absl::optional<int64_t> channel_id) {
xla::ChannelHandle channel_handle;
if (channel_id) channel_handle.set_handle(*channel_id);
return ConvertChannelHandle(channel_handle);
}
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
const xla::ChannelHandle& channel) {
return builder_->getNamedAttr(
"channel_handle",
mlir::mhlo::ChannelHandle::get(
builder_->getI64IntegerAttr(channel.handle()),
builder_->getI64IntegerAttr(channel.type()), context_));
}
void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
const Shape& shape,
llvm::StringRef attr_name) {
llvm::SmallVector<int64_t, 4> minor_to_major(
shape.layout().minor_to_major().begin(),
shape.layout().minor_to_major().end());
op->setAttr(
attr_name,
mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
}
Status HloFunctionImporter::ConvertShapeToMlirLayout(
const xla::Shape& shape,
llvm::SmallVectorImpl<mlir::Attribute>& flattened_attr) {
if (shape.IsToken()) {
return tensorflow::Status::OK();
}
if (shape.IsTuple()) {
std::vector<mlir::Attribute> tuple_layouts;
for (int i = 0; i < shape.tuple_shapes_size(); i++) {
TF_RETURN_IF_ERROR(
ConvertShapeToMlirLayout(shape.tuple_shapes(i), flattened_attr));
}
return tensorflow::Status::OK();
}
if (shape.IsArray()) {
const xla::Layout l = shape.layout();
std::vector<mlir::Attribute> minor_to_major;
for (int64_t i : l.minor_to_major()) {
minor_to_major.push_back(builder_->getI64IntegerAttr(i));
}
llvm::ArrayRef<mlir::Attribute> array_ref(minor_to_major);
flattened_attr.push_back(builder_->getArrayAttr(array_ref));
return tensorflow::Status::OK();
}
return tensorflow::errors::Internal("Couldn't convert layout.");
}
} // namespace xla