/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include <unordered_map>
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
using llvm::APInt;
using llvm::makeArrayRef;
using mlir::DenseIntElementsAttr;
using mlir::FuncOp;
using mlir::NamedAttribute;
using mlir::Operation;
using mlir::RankedTensorType;
using mlir::Type;
using mlir::Value;
namespace xla {
namespace {
// Note: This sanitization function causes an irreversible many-to-one mapping,
// and any attempt to mitigate this would cause issues with the reverse
// direction. The long-term solution is to add a function attribute that
// preserves the original HLO name.
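// For example, both "reduce-window.1" and "reduce_window.1" are sanitized to
// "reduce_window.1", so the original HLO name cannot be recovered from the
// MLIR name alone.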
string SanitizeFunctionName(llvm::StringRef name) {
string output(name);
llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
return output;
}
// Returns whether the instruction is a "default" dot operation, i.e. one that
// contracts dimension 1 of the lhs (dimension 0 for a rank-1 lhs) with
// dimension 0 of the rhs and has no batch dimensions.
bool DotIsDefault(const HloInstruction* instruction) {
auto dnums = instruction->dot_dimension_numbers();
DotDimensionNumbers default_dimension_numbers;
default_dimension_numbers.add_lhs_contracting_dimensions(
instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
default_dimension_numbers.add_rhs_contracting_dimensions(0);
return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
}
// Returns an MLIR Location generated from an HLO instruction. Uses the op name
// from the instruction metadata if present, otherwise the instruction name,
// and fuses in the source file and line when the metadata records them.
mlir::Location GenerateInstructionLocation(HloInstruction* instruction,
mlir::OpBuilder* func_builder) {
const std::string& op_name = instruction->metadata().op_name();
if (op_name.empty()) {
return mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()),
func_builder->getContext());
}
mlir::Location op_name_loc = mlir::NameLoc::get(
func_builder->getIdentifier(op_name), func_builder->getContext());
const std::string& source_file = instruction->metadata().source_file();
if (source_file.empty()) {
return op_name_loc;
}
return mlir::FusedLoc::get(
{op_name_loc, mlir::FileLineColLoc::get(
source_file, instruction->metadata().source_line(), 0,
func_builder->getContext())},
func_builder->getContext());
}
} // namespace
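
// Imports `computation` into `module` as an MLIR function, reusing and
// updating `function_map` so that each computation is only imported once.
//
// A minimal usage sketch (assuming the caller already has an mlir::ModuleOp
// `module`, an mlir::Builder `builder`, and an xla::HloModule `hlo_module`):
//
//   std::unordered_map<const xla::HloComputation*, mlir::FuncOp> function_map;
//   for (xla::HloComputation* computation : hlo_module->computations()) {
//     TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsFunc(
//         *computation, module, &function_map, &builder));
//   }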
Status HloFunctionImporter::ImportAsFunc(
const HloComputation& computation, mlir::ModuleOp module,
std::unordered_map<const HloComputation*, FuncOp>* function_map,
mlir::Builder* builder) {
HloFunctionImporter importer(module, function_map, builder);
return importer.ImportAsFunc(computation).status();
}
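
// Imports `computation` into the given region (for example the body of a
// while or reduce op), looking up the enclosing module from the region's
// parent for any nested function imports.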
Status HloFunctionImporter::ImportAsRegion(
const xla::HloComputation& computation, mlir::Region* region,
mlir::Builder* builder) {
HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
builder);
return importer.ImportAsRegion(computation, region);
}
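
// Imports the computation as an MLIR function and caches it in function_map_.
// The module's entry computation becomes the public "main" function; every
// other computation becomes a private function with a sanitized name.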
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
const HloComputation& computation) {
auto& imported = (*function_map_)[&computation];
if (imported) return imported;
llvm::SmallVector<Type, 4> args, rets;
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
auto func_type = mlir::FunctionType::get(args, rets, context_);
string computation_name =
computation.parent()->entry_computation() == &computation
? "main"
: SanitizeFunctionName(computation.name());
// Construct the MLIR function and map arguments.
llvm::ArrayRef<mlir::NamedAttribute> attrs;
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
computation_name, func_type, attrs);
auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
: FuncOp::Visibility::Private;
function.setVisibility(visibility);
module_.push_back(function);
// Add to the map right away so that calls to this function made while
// importing its body can be resolved.
imported = function;
mlir::Block* block = function.addEntryBlock();
TF_RETURN_IF_ERROR(ImportInstructions(computation, block));
return function;
}
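
// Imports the computation into `region` as a single block whose arguments
// correspond to the computation's parameters.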
tensorflow::Status HloFunctionImporter::ImportAsRegion(
const HloComputation& computation, mlir::Region* region) {
// TODO(hinsu): Store computation name as an attribute for round-trip.
auto* block = new mlir::Block;
region->push_back(block);
llvm::SmallVector<Type, 4> args;
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
block->addArguments(args);
return ImportInstructions(computation, block);
}
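
// Maps the computation's parameters to the given MLIR values, imports the
// remaining instructions in post order, and returns the MLIR value of the
// root instruction.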
StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
// Set up the input parameters.
const int num_parameters = computation.num_parameters();
if (arguments.size() != num_parameters)
return InvalidArgument("Caller vs callee argument sizes do not match");
for (int i = 0; i < num_parameters; i++) {
auto hlo_parameter = computation.parameter_instruction(i);
instruction_value_map_[hlo_parameter] = arguments[i];
}
for (auto instruction : computation.MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(auto new_operation,
ImportInstruction(instruction, builder));
if (new_operation) {
instruction_value_map_[instruction] = new_operation->getResult(0);
}
}
// Return the value of the root instruction (HLO only supports a single
// return value).
return GetMlirValue(computation.root_instruction());
}
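
// Imports the computation's instructions into `block` and terminates the
// block with std.return when the parent op is a FuncOp, and with mhlo.return
// otherwise.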
Status HloFunctionImporter::ImportInstructions(
const HloComputation& computation, mlir::Block* block) {
llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
TF_ASSIGN_OR_RETURN(Value result,
ImportInstructionsImpl(computation, arguments, &builder));
// TODO(suderman): Add location tracking details.
mlir::Location loc = builder.getUnknownLoc();
// Create terminator op depending on the parent op of this region.
if (llvm::isa<FuncOp>(block->getParentOp())) {
builder.create<mlir::ReturnOp>(loc, result);
} else {
builder.create<mlir::mhlo::ReturnOp>(loc, result);
}
return tensorflow::Status::OK();
}
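
// Imports the computation's instructions at the builder's current insertion
// point, wiring `arguments` to the computation's parameters, and returns the
// value of the root instruction.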
StatusOr<Value> HloFunctionImporter::ImportInstructions(
const xla::HloComputation& computation,
const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
mlir::Block* block = builder->getBlock();
if (block == nullptr)
return InvalidArgument(
"ImportInstructions requires a valid block in the builder");
HloFunctionImporter importer(
block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
return importer.ImportInstructionsImpl(computation, arguments, builder);
}
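
// Translates a single HLO instruction into the corresponding mhlo (or
// standard dialect) operation. Returns nullptr for parameters, which are
// materialized as block arguments rather than ops.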
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
instruction->shape(), *builder_));
mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);
llvm::SmallVector<NamedAttribute, 10> attributes;
switch (instruction->opcode()) {
case HloOpcode::kParameter: {
return nullptr;
}
case HloOpcode::kConstant: {
const Literal& literal = instruction->literal();
auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
if (!attr.ok()) return attr.status();
mlir::Operation* new_operation =
func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
for (auto attr : attributes) {
new_operation->setAttr(attr.first, attr.second);
}
return new_operation;
}
case HloOpcode::kIota: {
return func_builder
->create<mlir::mhlo::IotaOp>(
loc, result_type,
func_builder->getI64IntegerAttr(
Cast<HloIotaInstruction>(instruction)->iota_dimension()))
.getOperation();
}
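// Creates an mhlo op of the given kind from the already computed location,
// result type, operands, and attributes, and returns it from the enclosing
// switch case.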
#define MakeAndReturn(mlir_op) \
{ \
mlir::Operation* new_operation = \
func_builder->create<mlir::mhlo::mlir_op>(loc, result_type, operands, \
attributes); \
return new_operation; \
}
case HloOpcode::kBroadcast: {
// Note that the HLO broadcast is more powerful than the XLA client
// builder's broadcast op; mhlo::BroadcastInDimOp offers a superset of the
// HLO op's functionality, so it is used here.
attributes.push_back(
builder_->getNamedAttr("broadcast_dimensions",
ConvertDimensions(instruction->dimensions())));
MakeAndReturn(BroadcastInDimOp);
}
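// Adds the epsilon and feature_index attributes common to all batch-norm
// instructions before emitting the corresponding mhlo op.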
#define MakeAndReturnBatchNormOp(batch_norm_op) \
{ \
attributes.push_back(builder_->getNamedAttr( \
"epsilon", builder_->getF32FloatAttr(instruction->epsilon()))); \
attributes.push_back(builder_->getNamedAttr( \
"feature_index", \
builder_->getI64IntegerAttr(instruction->feature_index()))); \
MakeAndReturn(batch_norm_op); \
}
case HloOpcode::kBatchNormGrad:
MakeAndReturnBatchNormOp(BatchNormGradOp);
case HloOpcode::kBatchNormInference:
MakeAndReturnBatchNormOp(BatchNormInferenceOp);
case HloOpcode::kBatchNormTraining:
MakeAndReturnBatchNormOp(BatchNormTrainingOp);
#undef MakeAndReturnBatchNormOp
case HloOpcode::kDot: {
attributes.push_back(builder_->getNamedAttr(
"precision_config",
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
// Consider consolidating DotOp and DotGeneralOp.
if (DotIsDefault(instruction)) {
MakeAndReturn(DotOp);
}
attributes.push_back(builder_->getNamedAttr(
"dot_dimension_numbers",
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
builder_)));
MakeAndReturn(DotGeneralOp);
}
case HloOpcode::kCall: {
TF_ASSIGN_OR_RETURN(FuncOp function,
ImportAsFunc(*instruction->to_apply()));
mlir::Operation* new_operation =
func_builder->create<mlir::CallOp>(loc, function, operands);
return new_operation;
}
case HloOpcode::kCollectivePermute: {
attributes.push_back(
ConvertSourceTargetPairs(instruction->source_target_pairs()));
MakeAndReturn(CollectivePermuteOp);
}
case HloOpcode::kCustomCall: {
auto custom_call = Cast<HloCustomCallInstruction>(instruction);
attributes.push_back(builder_->getNamedAttr(
"call_target_name",
builder_->getStringAttr(custom_call->custom_call_target())));
attributes.push_back(builder_->getNamedAttr(
"has_side_effect",
builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
attributes.push_back(builder_->getNamedAttr(
"backend_config",
builder_->getStringAttr(custom_call->raw_backend_config_string())));
MakeAndReturn(CustomCallOp);
}
case HloOpcode::kCompare: {
auto compare = Cast<HloCompareInstruction>(instruction);
attributes.push_back(ConvertComparisonDirection(compare->direction()));
auto default_type = Comparison::DefaultComparisonType(
compare->operand(0)->shape().element_type());
if (compare->type() != default_type)
attributes.push_back(ConvertComparisonType(compare->type()));
MakeAndReturn(CompareOp);
}
case HloOpcode::kCholesky: {
attributes.push_back(builder_->getNamedAttr(
"lower",
builder_->getBoolAttr(instruction->cholesky_options().lower())));
MakeAndReturn(CholeskyOp);
}
case HloOpcode::kGather: {
auto gather_instruction = Cast<HloGatherInstruction>(instruction);
attributes.push_back(builder_->getNamedAttr(
"dimension_numbers",
ConvertGatherDimensionNumbers(
gather_instruction->gather_dimension_numbers(), builder_)));
std::vector<int64_t> slice_sizes(
gather_instruction->gather_slice_sizes().begin(),
gather_instruction->gather_slice_sizes().end());
attributes.push_back(
builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
MakeAndReturn(GatherOp);
}
case HloOpcode::kDynamicSlice: {
std::vector<int64_t> slice_sizes(
instruction->dynamic_slice_sizes().begin(),
instruction->dynamic_slice_sizes().end());
return func_builder
->create<mlir::mhlo::DynamicSliceOp>(
loc, result_type, operands[0],
makeArrayRef(operands).drop_front(), Convert(slice_sizes))
.getOperation();
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder
->create<mlir::mhlo::DynamicUpdateSliceOp>(
loc, result_type, operands[0], operands[1],
llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
.getOperation();
}
case HloOpcode::kInfeed: {
attributes.push_back(builder_->getNamedAttr(
"infeed_config", mlir::StringAttr::get(instruction->infeed_config(),
builder_->getContext())));
MakeAndReturn(InfeedOp);
}
case HloOpcode::kOutfeed: {
attributes.push_back(builder_->getNamedAttr(
"outfeed_config", mlir::StringAttr::get(instruction->outfeed_config(),
builder_->getContext())));
MakeAndReturn(OutfeedOp);
}
case HloOpcode::kPad: {
const auto& padding_config = instruction->padding_config();
llvm::SmallVector<int64_t, 4> edge_padding_low;
llvm::SmallVector<int64_t, 4> edge_padding_high;
llvm::SmallVector<int64_t, 4> interior_padding;
edge_padding_low.reserve(padding_config.dimensions_size());
edge_padding_high.reserve(padding_config.dimensions_size());
interior_padding.reserve(padding_config.dimensions_size());
for (const auto& dimension : padding_config.dimensions()) {
edge_padding_low.push_back(dimension.edge_padding_low());
edge_padding_high.push_back(dimension.edge_padding_high());
interior_padding.push_back(dimension.interior_padding());
}
return func_builder
->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
operands[1], Convert(edge_padding_low),
Convert(edge_padding_high),
Convert(interior_padding))
.getOperation();
}
case HloOpcode::kScatter: {
auto scatter = Cast<HloScatterInstruction>(instruction);
attributes.push_back(builder_->getNamedAttr(
"scatter_dimension_numbers",
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
builder_)));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(scatter->indices_are_sorted())));
attributes.push_back(builder_->getNamedAttr(
"unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
&scatter_op.update_computation()));
return scatter_op.getOperation();
}
case HloOpcode::kSelectAndScatter: {
auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
llvm::SmallVector<int64_t, 8> padding;
for (const auto& dim : select_scatter->window().dimensions()) {
window_strides.push_back(dim.stride());
window_dimensions.push_back(dim.size());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
attributes.push_back(
builder_->getNamedAttr("window_strides", Convert(window_strides)));
attributes.push_back(builder_->getNamedAttr("window_dimensions",
Convert(window_dimensions)));
attributes.push_back(ConvertPadding(padding));
auto select_scatter_op =
func_builder->create<mlir::mhlo::SelectAndScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
&select_scatter_op.select()));
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
&select_scatter_op.scatter()));
return select_scatter_op.getOperation();
}
case HloOpcode::kSetDimensionSize: {
attributes.push_back(builder_->getNamedAttr(
"dimension", builder_->getI64IntegerAttr(instruction->dimension())));
MakeAndReturn(SetDimensionSizeOp);
}
case HloOpcode::kSlice: {
return func_builder
->create<mlir::mhlo::SliceOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->slice_starts()),
ConvertDimensions(instruction->slice_limits()),
ConvertDimensions(instruction->slice_strides()))
.getOperation();
}
case HloOpcode::kSort: {
auto sort_instruction = Cast<HloSortInstruction>(instruction);
llvm::SmallVector<Type, 4> return_types = {result_type};
if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
return_types = llvm::to_vector<6>(tuple_ty.getTypes());
}
auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
loc, return_types, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable()));
TF_RETURN_IF_ERROR(
ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
// Check if the output needs to be tupled.
if (return_types.size() == 1 && return_types.front() == result_type) {
return sort_op.getOperation();
}
return func_builder
->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
.getOperation();
}
case HloOpcode::kConditional: {
llvm::SmallVector<Type, 4> rets;
mlir::Type pred_or_index_type =
operands[0].getType().cast<mlir::TensorType>().getElementType();
// It is a predicated conditional if the first operand is a boolean and
// should be mapped to an If op.
if (pred_or_index_type.isInteger(1)) {
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->true_computation()->root_instruction()}, &rets));
auto op = func_builder->create<mlir::mhlo::IfOp>(loc, rets, operands,
attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
&op.true_branch()));
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
&op.false_branch()));
return op.getOperation();
}
// Otherwise, it is an indexed conditional and should be mapped to a Case
// op.
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->branch_computation(0)->root_instruction()}, &rets));
int num_branches = instruction->branch_count();
auto op = func_builder->create<mlir::mhlo::CaseOp>(
loc, rets, operands, attributes, num_branches);
for (auto index_and_computation :
llvm::enumerate(instruction->branch_computations())) {
auto index = index_and_computation.index();
HloComputation* computation = index_and_computation.value();
TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
}
return op.getOperation();
}
case HloOpcode::kConcatenate: {
// TODO(b/132057942): Support taking a uint64_t instead of an
// IntegerAttr for the concatenate dimension.
return func_builder
->create<mlir::mhlo::ConcatenateOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
.getOperation();
}
case HloOpcode::kAllReduce: {
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
attributes.push_back(
ConvertReplicaGroups(all_reduce->replica_groups(), *builder_));
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
&all_reduce_op.computation()));
return all_reduce_op.getOperation();
}
case HloOpcode::kReduce: {
// Operands in the first half are reduction inputs and the remaining
// operands are the corresponding initial values.
size_t num_inputs = operands.size() / 2;
auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs),
llvm::makeArrayRef(operands).drop_front(num_inputs),
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation();
}
case HloOpcode::kReverse: {
return func_builder
->create<mlir::mhlo::ReverseOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->dimensions()))
.getOperation();
}
case HloOpcode::kRng: {
auto shape = func_builder->create<mlir::ConstantOp>(
loc, Convert(result_type.cast<RankedTensorType>().getShape()));
switch (instruction->random_distribution()) {
case xla::RNG_UNIFORM:
return func_builder
->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();
case xla::RNG_NORMAL:
return func_builder
->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();
default:
return tensorflow::errors::InvalidArgument(absl::StrCat(
"Unsupported distribution: ",
RandomDistributionToString(instruction->random_distribution())));
}
}
case HloOpcode::kRngBitGenerator: {
auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
loc, result_type,
func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);
return op.getOperation();
}
case HloOpcode::kWhile: {
auto op = func_builder->create<mlir::mhlo::WhileOp>(
loc, operands[0].getType(), operands[0]);
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->while_condition(), &op.cond()));
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->while_body(), &op.body()));
return op.getOperation();
}
case HloOpcode::kGetTupleElement: {
attributes.push_back(builder_->getNamedAttr(
"index", builder_->getIntegerAttr(builder_->getIntegerType(32),
instruction->tuple_index())));
MakeAndReturn(GetTupleElementOp);
};
case HloOpcode::kGetDimensionSize: {
attributes.push_back(builder_->getNamedAttr(
"dimension", builder_->getI64IntegerAttr(instruction->dimension())));
MakeAndReturn(GetDimensionSizeOp);
};
case HloOpcode::kTranspose: {
attributes.push_back(builder_->getNamedAttr(
"permutation", ConvertDimensions(instruction->dimensions())));
MakeAndReturn(TransposeOp);
}
case HloOpcode::kTriangularSolve: {
attributes.push_back(builder_->getNamedAttr(
"left_side",
builder_->getBoolAttr(
instruction->triangular_solve_options().left_side())));
attributes.push_back(builder_->getNamedAttr(
"lower", builder_->getBoolAttr(
instruction->triangular_solve_options().lower())));
attributes.push_back(builder_->getNamedAttr(
"unit_diagonal",
builder_->getBoolAttr(
instruction->triangular_solve_options().unit_diagonal())));
auto transpose_a =
builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
instruction->triangular_solve_options().transpose_a()));
attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
MakeAndReturn(TriangularSolveOp);
}
case HloOpcode::kReduceWindow: {
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
llvm::SmallVector<int64_t, 8> padding;
for (const auto& dim : instruction->window().dimensions()) {
sizes.push_back(dim.size());
strides.push_back(dim.stride());
base_dilations.push_back(dim.base_dilation());
win_dilations.push_back(dim.window_dilation());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
attributes.push_back(builder_->getNamedAttr("window_dimensions",
ConvertDimensions(sizes)));
attributes.push_back(
builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
attributes.push_back(builder_->getNamedAttr(
"base_dilations", ConvertDimensions(base_dilations)));
attributes.push_back(builder_->getNamedAttr(
"window_dilations", ConvertDimensions(win_dilations)));
attributes.push_back(ConvertPadding(padding));
auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation();
}
case HloOpcode::kMap: {
auto op = func_builder->create<mlir::mhlo::MapOp>(
loc, result_type, operands,
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->to_apply(), &op.computation()));
return op.getOperation();
}
case HloOpcode::kConvolution: {
llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
llvm::SmallVector<int64_t, 8> paddings;
for (const auto& dim : instruction->window().dimensions()) {
strides.push_back(dim.stride());
lhs_dilations.push_back(dim.base_dilation());
rhs_dilations.push_back(dim.window_dilation());
paddings.push_back(dim.padding_low());
paddings.push_back(dim.padding_high());
}
attributes.push_back(
builder_->getNamedAttr("window_strides", Convert(strides)));
attributes.push_back(ConvertPadding(paddings));
attributes.push_back(
builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
attributes.push_back(
builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
attributes.push_back(builder_->getNamedAttr(
"dimension_numbers",
ConvertConvDimensionNumbers(
instruction->convolution_dimension_numbers(), builder_)));
attributes.push_back(builder_->getNamedAttr(
"feature_group_count",
builder_->getI64IntegerAttr(instruction->feature_group_count())));
attributes.push_back(builder_->getNamedAttr(
"batch_group_count",
builder_->getI64IntegerAttr(instruction->batch_group_count())));
attributes.push_back(builder_->getNamedAttr(
"precision_config",
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
MakeAndReturn(ConvOp);
}
case HloOpcode::kFft: {
auto fft_type =
builder_->getStringAttr(FftType_Name(instruction->fft_type()));
std::vector<int64_t> fft_length(instruction->fft_length().begin(),
instruction->fft_length().end());
attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
attributes.push_back(
builder_->getNamedAttr("fft_length", Convert(fft_length)));
MakeAndReturn(FftOp);
}
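// Imports HLO ops that carry no extra attributes by mapping them one-to-one
// onto their mhlo counterparts.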
#define NoAttributeCase(hlo_op_code, mlir_op) \
case HloOpcode::hlo_op_code: { \
MakeAndReturn(mlir_op); \
}
// Broadcast dimensions are never added here because they don't exist as
// part of the HLO instruction. They are only a convenience in the XLA
// builder API.
NoAttributeCase(kAbs, AbsOp);
NoAttributeCase(kAdd, AddOp);
NoAttributeCase(kAfterAll, AfterAllOp);
NoAttributeCase(kAnd, AndOp);
NoAttributeCase(kAtan2, Atan2Op);
NoAttributeCase(kBitcastConvert, BitcastConvertOp);
NoAttributeCase(kCbrt, CbrtOp);
NoAttributeCase(kConvert, ConvertOp);
NoAttributeCase(kCeil, CeilOp);
NoAttributeCase(kClamp, ClampOp);
NoAttributeCase(kComplex, ComplexOp);
NoAttributeCase(kCos, CosOp);
NoAttributeCase(kDivide, DivOp);
NoAttributeCase(kExp, ExpOp);
NoAttributeCase(kExpm1, Expm1Op);
NoAttributeCase(kFloor, FloorOp);
NoAttributeCase(kIsFinite, IsFiniteOp);
NoAttributeCase(kImag, ImagOp);
NoAttributeCase(kLog, LogOp);
NoAttributeCase(kLog1p, Log1pOp);
NoAttributeCase(kMaximum, MaxOp);
NoAttributeCase(kMinimum, MinOp);
NoAttributeCase(kMultiply, MulOp);
NoAttributeCase(kNegate, NegOp);
NoAttributeCase(kNot, NotOp);
NoAttributeCase(kOr, OrOp);
NoAttributeCase(kPopulationCount, PopulationCountOp);
NoAttributeCase(kPower, PowOp);
NoAttributeCase(kReal, RealOp);
NoAttributeCase(kRemainder, RemOp);
NoAttributeCase(kReplicaId, ReplicaIdOp);
// The dimensions attribute is not present on the HLO Reshape
// instruction. If dimensions are non-default, the XLA builder
// implements it as a separate transpose.
NoAttributeCase(kReshape, ReshapeOp);
NoAttributeCase(kRoundNearestAfz, RoundOp);
NoAttributeCase(kRsqrt, RsqrtOp);
NoAttributeCase(kSelect, SelectOp);
NoAttributeCase(kShiftLeft, ShiftLeftOp);
NoAttributeCase(kShiftRightArithmetic, ShiftRightArithmeticOp);
NoAttributeCase(kShiftRightLogical, ShiftRightLogicalOp);
NoAttributeCase(kSign, SignOp);
NoAttributeCase(kSin, SinOp);
NoAttributeCase(kSqrt, SqrtOp);
NoAttributeCase(kSubtract, SubOp);
NoAttributeCase(kTanh, TanhOp);
NoAttributeCase(kTuple, TupleOp);
NoAttributeCase(kXor, XorOp);
// TODO(b/129422361) Copy needs special handling because it is not
// defined in tensorflow/compiler/xla/client/xla_builder.h. See
// operation semantics in
// g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
NoAttributeCase(kCopy, CopyOp);
#undef NoAttributeCase
#undef MakeAndReturn
case HloOpcode::kFusion: {
auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
loc, result_type, operands,
builder_->getStringAttr(xla::ToString(instruction->fusion_kind())));
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->fused_instructions_computation(),
&fusion.fused_computation()));
return fusion.getOperation();
}
case HloOpcode::kBitcast:
return func_builder
->create<mlir::mhlo::BitcastOp>(loc, result_type, operands,
attributes)
.getOperation();
case HloOpcode::kReducePrecision: {
auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
loc, result_type, operands[0], attributes);
op.exponent_bitsAttr(func_builder->getIntegerAttr(
func_builder->getI32Type(), instruction->exponent_bits()));
op.mantissa_bitsAttr(func_builder->getIntegerAttr(
func_builder->getI32Type(), instruction->mantissa_bits()));
return op.getOperation();
}
case HloOpcode::kAddDependency:
// An arbitrary opcode that we likely will not implement for quite a
// while; it exercises the handling of unknown ops below. It was chosen
// because it is not mentioned anywhere in the XLA client API or in the
// HLO of our sample models.
default: {
mlir::OperationState result(loc, "mhlo.unknown");
result.addOperands(operands);
result.addTypes(result_type);
for (auto attr : attributes) {
result.attributes.push_back(attr);
}
return func_builder->createOperation(result);
}
}
}
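
// Imports an instruction and, for array-shaped results whose layout is
// neither empty nor the default descending layout, records the layout on the
// op via SetLayoutForMlir.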
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
TF_ASSIGN_OR_RETURN(mlir::Operation * op,
ImportInstructionImpl(instruction, func_builder));
if (op == nullptr) return op;
// See MlirToHloConversionOptions for more about layouts.
//
// Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
// in physical minor-to-major order.
if (instruction->shape().IsArray() &&
!instruction->shape().layout().minor_to_major().empty() &&
instruction->shape().layout() !=
LayoutUtil::MakeDescendingLayout(
instruction->shape().dimensions().size())) {
SetLayoutForMlir(op, instruction->shape());
}
return op;
}
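
// Looks up the previously imported MLIR values for all operands of the
// instruction; fails if any operand has not been imported yet.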
StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
HloInstruction* instruction) {
llvm::SmallVector<mlir::Value, 4> operands;
for (const auto& operand : instruction->operands()) {
auto input_it = instruction_value_map_.find(operand);
if (input_it == instruction_value_map_.end()) {
return tensorflow::errors::Internal(
absl::StrCat("Could not find input value: ", operand->name(),
" for instruction ", instruction->name()));
}
operands.push_back(input_it->second);
}
return operands;
}
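
// Converts the shapes of the given instructions into MLIR ranked tensor
// types.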
tensorflow::Status HloFunctionImporter::GetMlirTypes(
const std::vector<HloInstruction*>& instructions,
llvm::SmallVectorImpl<mlir::Type>* types) {
for (auto instruction : instructions) {
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
instruction->shape(), *builder_));
types->push_back(ret_type);
}
return tensorflow::Status::OK();
}
StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
auto lookup = instruction_value_map_.find(instruction);
if (lookup != instruction_value_map_.end()) {
return lookup->second;
}
return tensorflow::errors::Internal(absl::StrCat(
"Unable to find value for input: ", instruction->ToString()));
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
ComparisonDirection direction) {
return builder_->getNamedAttr(
"comparison_direction",
builder_->getStringAttr(ComparisonDirectionToString(direction)));
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
Comparison::Type type) {
return builder_->getNamedAttr(
"compare_type", builder_->getStringAttr(ComparisonTypeToString(type)));
}
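
// Converts a list of dimension indices into a 1-D i64 elements attribute.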
mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
llvm::ArrayRef<int64> op_dimensions) {
llvm::SmallVector<APInt, 8> dimensions;
dimensions.reserve(op_dimensions.size());
for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
return DenseIntElementsAttr::get(
RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
dimensions);
}
mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
llvm::ArrayRef<int64_t> elements) {
return DenseIntElementsAttr::get(
RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
elements);
}
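
// Packs a flattened list of (low, high) padding pairs into an Nx2 i64
// elements attribute named "padding", e.g. {1, 2, 0, 0} becomes
// dense<[[1, 2], [0, 0]]> : tensor<2x2xi64>.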
mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
llvm::ArrayRef<int64_t> padding) {
auto ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_->getIntegerType(64));
auto attr = DenseIntElementsAttr::get(ty, padding);
return builder_->getNamedAttr("padding", attr);
}
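
// Flattens (source, target) replica pairs into an Nx2 i64 elements attribute
// named "source_target_pairs".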
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
source_target_pairs) {
std::vector<int64_t> attr(source_target_pairs.size() * 2);
for (auto p : llvm::enumerate(source_target_pairs)) {
attr[2 * p.index()] = p.value().first;
attr[2 * p.index() + 1] = p.value().second;
}
auto type = mlir::RankedTensorType::get(
{static_cast<int64_t>(attr.size() / 2), 2}, builder_->getIntegerType(64));
return builder_->getNamedAttr("source_target_pairs",
DenseIntElementsAttr::get(type, attr));
}
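
// Flattens the replica groups into a num_groups x group_size i64 elements
// attribute named "replica_groups". Assumes all groups have the same size,
// as asserted below.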
mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder) {
int64_t num_groups = replica_groups.size();
int64_t group_size =
num_groups == 0 ? 0 : replica_groups[0].replica_ids_size();
std::vector<int64_t> attr(num_groups * group_size);
int flat_index = 0;
for (const auto& group : replica_groups) {
assert(group_size == group.replica_ids_size());
for (int i = 0; i < group_size; ++i)
attr[flat_index++] = group.replica_ids(i);
}
auto type = mlir::RankedTensorType::get({num_groups, group_size},
builder.getIntegerType(64));
return builder.getNamedAttr("replica_groups",
DenseIntElementsAttr::get(type, attr));
}
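
// Builds a "channel_handle" attribute from an optional channel id; when the
// id is absent, a default (zero) handle is used.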
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
absl::optional<tensorflow::int64> channel_id) {
xla::ChannelHandle channel_handle;
if (channel_id.has_value()) channel_handle.set_handle(channel_id.value());
return ConvertChannelHandle(channel_handle);
}
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
const xla::ChannelHandle& channel) {
return builder_->getNamedAttr(
"channel_handle",
mlir::mhlo::ChannelHandle::get(
builder_->getI64IntegerAttr(channel.handle()),
builder_->getI64IntegerAttr(channel.type()), context_));
}
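
// Records the shape's minor-to-major layout on the op as an index tensor
// attribute named "minor_to_major".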
void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
const Shape& shape) {
llvm::SmallVector<int64_t, 4> minor_to_major(
shape.layout().minor_to_major().begin(),
shape.layout().minor_to_major().end());
op->setAttr(
"minor_to_major",
mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
}
} // namespace xla