blob: 568ea7f76fa2c570c6715cfd350ab644b68a60ff [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include <string>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
// Builds the fully qualified MHLO operation name for `opcode`: the HLO opcode
// string with every '-' replaced by '_', prefixed with the MHLO dialect
// namespace and a '.' separator.
static std::string GetMlirOpName(HloOpcode opcode) {
  std::string mnemonic = HloOpcodeString(opcode);
  absl::c_replace(mnemonic, '-', '_');
  const std::string dialect_prefix =
      mlir::mhlo::MhloDialect::getDialectNamespace().str();
  return dialect_prefix + "." + mnemonic;
}
// Returns the text produced by printing `ty` into a string-backed stream.
static std::string ToString(mlir::Type ty) {
  std::string rendered;
  llvm::raw_string_ostream os(rendered);
  ty.print(os);
  // Push any buffered output into `rendered` before it is returned.
  os.flush();
  return rendered;
}
// Creates a rank-1 dense elements attribute with 64-bit integer element type
// that holds a copy of `values`, in order.
static mlir::DenseIntElementsAttr GetI64ElementsAttr(
    absl::Span<const int64_t> values, mlir::Builder* builder) {
  const int64_t num_values = static_cast<int64_t>(values.size());
  auto tensor_type = mlir::RankedTensorType::get({num_values},
                                                 builder->getIntegerType(64));
  auto elements = llvm::makeArrayRef(values.data(), values.size());
  return mlir::DenseIntElementsAttr::get(tensor_type, elements);
}
// Packs `padding` into a [padding.size(), 2] dense attribute of 64-bit
// integers. Each pair contributes one row: `first` followed by `second`.
static mlir::DenseIntElementsAttr ConvertPadding(
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    mlir::Builder* builder) {
  const int64_t num_rows = static_cast<int64_t>(padding.size());
  llvm::SmallVector<int64_t, 8> flat;
  flat.reserve(padding.size() * 2);
  for (const auto& [first, second] : padding) {
    flat.push_back(first);
    flat.push_back(second);
  }
  auto tensor_type = mlir::RankedTensorType::get({num_rows, 2},
                                                 builder->getIntegerType(64));
  return mlir::DenseIntElementsAttr::get(tensor_type, flat);
}
// The destructor is defaulted here, out of line, in the implementation file.
MlirHloBuilder::~MlirHloBuilder() = default;
// Wraps `val` in an XlaOp owned by this builder.
//
// The XLA shape derived from the value's type is recorded in
// `handle_to_shape_`, keyed by the value's opaque pointer reinterpreted as an
// integer handle. Returns InvalidArgument when the derived shape has an
// invalid element type.
StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
  mlir::Type value_type = val.getType();
  auto xla_shape = std::make_unique<Shape>(TypeToShape(value_type));
  const bool element_type_is_invalid =
      xla_shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID;
  if (element_type_is_invalid) {
    return InvalidArgument("unsupported type: %s",
                           ToString(value_type).c_str());
  }

  const int64_t handle = reinterpret_cast<int64_t>(val.getAsOpaquePointer());
  handle_to_shape_[handle] = std::move(xla_shape);
  return XlaOp(handle, this);
}
// Creates an mlir::mhlo::ConstantOp whose value is the dense elements
// attribute converted from `literal`. Failures are routed through
// ReportErrorOrReturn.
XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
  auto build_constant = [&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr value_attr,
                        CreateDenseElementsAttrFromLiteral(literal, builder_));
    auto constant_op =
        builder_.create<mlir::mhlo::ConstantOp>(loc_, value_attr);
    return MakeXlaOp(constant_op);
  };
  return ReportErrorOrReturn(build_constant);
}
// Creates an mlir::mhlo::ConvolutionOp with result type converted from
// `shape`. Strides, padding and dilations are converted to dense i64
// attributes; the window-reversal attribute is left null. The precision
// config attribute is only populated when `precision_config` is non-null.
// The `window` argument is not consulted here.
StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));

  mlir::ArrayAttr precision_attr;
  if (precision_config) {
    precision_attr = ConvertPrecisionConfig(precision_config, &builder_);
  }

  auto strides_attr = GetI64ElementsAttr(window_strides, &builder_);
  auto padding_attr = ConvertPadding(padding, &builder_);
  auto lhs_dilation_attr = GetI64ElementsAttr(lhs_dilation, &builder_);
  auto rhs_dilation_attr = GetI64ElementsAttr(rhs_dilation, &builder_);
  auto dnums_attr = ConvertConvDimensionNumbers(dimension_numbers, &builder_);

  auto conv_op = builder_.create<mlir::mhlo::ConvolutionOp>(
      loc_, result_type, GetValue(lhs), GetValue(rhs), strides_attr,
      padding_attr, lhs_dilation_attr, rhs_dilation_attr,
      /*window_reversal=*/nullptr, dnums_attr,
      builder_.getI64IntegerAttr(feature_group_count),
      builder_.getI64IntegerAttr(batch_group_count), precision_attr);
  return MakeXlaOp(conv_op);
}
// Creates an mlir::mhlo::FftOp over `operand` with result type converted from
// `shape`.
//
// The XLA FFT type is mapped to the MHLO enum by name. If no MHLO enumerator
// matches the name, an error status is returned instead of dereferencing an
// empty optional.
StatusOr<XlaOp> MlirHloBuilder::FftInternal(
    const Shape& shape, XlaOp operand, FftType fft_type,
    absl::Span<const int64_t> fft_length) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto fft_type_attr = mlir::mhlo::symbolizeFftType(FftType_Name(fft_type));
  TF_RET_CHECK(fft_type_attr.hasValue())
      << "Unsupported FFT type: " << FftType_Name(fft_type);
  auto op = builder_.create<mlir::mhlo::FftOp>(
      loc_, ty, GetValue(operand),
      mlir::mhlo::FftTypeAttr::get(builder_.getContext(),
                                   fft_type_attr.getValue()),
      GetI64ElementsAttr(fft_length, &builder_));
  return MakeXlaOp(op);
}
// Creates an mlir::mhlo::CustomCallOp.
//
// Unsupported features are rejected up front: output/operand aliasing, a
// literal, a window, convolution dimension numbers, and any schedule other
// than SCHEDULE_NONE.
//
// When `operand_shapes_with_layout` is provided, "operand_layouts" and
// "result_layouts" attributes are attached; the result layouts come from
// `shape`, unpacked element-wise when it is a tuple. The "api_version",
// "call_target_name", "has_side_effect" and "backend_config" (from `opaque`)
// attributes are always attached. The returned XlaOp wraps result 0 of the op.
StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
    const std::string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const std::string& opaque,
    std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, std::optional<Window> window,
    std::optional<ConvolutionDimensionNumbers> dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  TF_RET_CHECK(output_operand_aliasing.empty())
      << "MLIR CustomCallOp does not support output_operand_aliasing yet";
  TF_RET_CHECK(literal == nullptr)
      << "MLIR CustomCallOp does not support literal yet";
  // This check concerns `window`; the message previously named
  // ConvolutionDimensionNumbers, which belongs to the `dnums` check below.
  TF_RET_CHECK(!window.has_value())
      << "MLIR CustomCallOp does not support Window yet";
  TF_RET_CHECK(!dnums.has_value())
      << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
  TF_RET_CHECK(schedule == CustomCallSchedule::SCHEDULE_NONE)
      << "MLIR CustomCallOp does not support custom-call-schedule yet";

  llvm::SmallVector<mlir::NamedAttribute> attributes;
  if (operand_shapes_with_layout.has_value()) {
    TF_ASSIGN_OR_RETURN(mlir::ArrayAttr operand_layouts,
                        ExtractLayoutsFromShapes(
                            operand_shapes_with_layout.value(), &builder_));
    attributes.push_back(
        builder_.getNamedAttr("operand_layouts", operand_layouts));

    mlir::ArrayAttr result_layouts;
    if (shape.IsTuple()) {
      TF_ASSIGN_OR_RETURN(result_layouts,
                          ExtractLayoutsFromTuple(shape, &builder_));
    } else {
      TF_ASSIGN_OR_RETURN(result_layouts,
                          ExtractLayoutsFromShapes({shape}, &builder_));
    }
    attributes.push_back(
        builder_.getNamedAttr("result_layouts", result_layouts));
  }

  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  TF_ASSIGN_OR_RETURN(auto mlir_api_version,
                      ConvertCustomCallApiVersion(api_version));
  attributes.push_back(builder_.getNamedAttr(
      "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
                         builder_.getContext(), mlir_api_version)));
  attributes.push_back(builder_.getNamedAttr(
      "call_target_name", builder_.getStringAttr(call_target_name)));
  attributes.push_back(builder_.getNamedAttr(
      "has_side_effect", builder_.getBoolAttr(has_side_effect)));
  attributes.push_back(
      builder_.getNamedAttr("backend_config", builder_.getStringAttr(opaque)));

  auto op = builder_.create<mlir::mhlo::CustomCallOp>(
      loc_, ty, GetValues(operands), attributes);
  return MakeXlaOp(op.getResult(0));
}
// Creates an mlir::mhlo::ReduceOp and imports `computation` into its body
// region with region argument tuples flattened.
//
// `all_operands` carries the inputs followed by the init values; it is split
// at the midpoint. A single-result reduce is returned directly, otherwise the
// results are wrapped in an mlir::mhlo::TupleOp. `shape` is not consulted.
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64_t> dimensions_to_reduce) {
  const int64_t half = all_operands.size() / 2;
  absl::Span<const XlaOp> inputs = all_operands.first(half);
  absl::Span<const XlaOp> init_values = all_operands.subspan(half);

  auto reduce_op = builder_.create<mlir::mhlo::ReduceOp>(
      loc_, GetValues(inputs), GetValues(init_values),
      GetI64ElementsAttr(dimensions_to_reduce, &builder_));
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &reduce_op.body(),
                                       /*flatten_region_arg_tuple*/ true));

  if (reduce_op.getNumResults() != 1) {
    auto tuple_op =
        builder_.create<mlir::mhlo::TupleOp>(loc_, reduce_op.getResults());
    return MakeXlaOp(tuple_op);
  }
  return MakeXlaOp(reduce_op.getResult(0));
}
// Creates an mlir::mhlo::ReduceWindowOp from the per-dimension settings in
// `window` and imports `computation` into its body region with region argument
// tuples flattened. The returned XlaOp wraps result 0 of the op.
StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));

  llvm::SmallVector<int64_t, 4> window_sizes;
  llvm::SmallVector<int64_t, 4> window_strides;
  llvm::SmallVector<int64_t, 4> base_dilations;
  llvm::SmallVector<int64_t, 4> window_dilations;
  // Flattened (low, high) pairs, one pair per window dimension.
  llvm::SmallVector<int64_t, 8> flat_padding;
  for (const auto& window_dim : window.dimensions()) {
    window_sizes.push_back(window_dim.size());
    window_strides.push_back(window_dim.stride());
    base_dilations.push_back(window_dim.base_dilation());
    window_dilations.push_back(window_dim.window_dilation());
    flat_padding.push_back(window_dim.padding_low());
    flat_padding.push_back(window_dim.padding_high());
  }

  const int64_t num_padding_rows =
      static_cast<int64_t>(flat_padding.size()) / 2;
  auto padding_type = mlir::RankedTensorType::get({num_padding_rows, 2},
                                                  builder_.getIntegerType(64));
  auto padding_attr =
      mlir::DenseIntElementsAttr::get(padding_type, flat_padding);

  auto reduce_window_op = builder_.create<mlir::mhlo::ReduceWindowOp>(
      loc_, result_type, GetValue(operand), GetValue(init_value),
      GetI64ElementsAttr(window_sizes, &builder_),
      GetI64ElementsAttr(window_strides, &builder_),
      GetI64ElementsAttr(base_dilations, &builder_),
      GetI64ElementsAttr(window_dilations, &builder_), padding_attr);
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(),
                                       &reduce_window_op.body(),
                                       /*flatten_region_arg_tuple*/ true));
  return MakeXlaOp(reduce_window_op.getResult(0));
}
// Creates an mlir::mhlo::IotaOp with result type converted from `shape` and
// `iota_dimension` encoded as a 64-bit integer attribute. Failures are routed
// through ReportErrorOrReturn.
XlaOp MlirHloBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
  auto build_iota = [&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        mlir::Type result_type,
        ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
    auto dimension_attr =
        builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension);
    auto iota_op = builder_.create<mlir::mhlo::IotaOp>(loc_, result_type,
                                                       dimension_attr);
    return MakeXlaOp(iota_op);
  };
  return ReportErrorOrReturn(build_iota);
}
// Creates an mlir::mhlo::BitcastConvertOp on `operand` whose result type is
// converted from `shape`.
StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
                                                           XlaOp operand) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  mlir::Value input = GetValue(operand);
  auto bitcast_op =
      builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, result_type, input);
  return MakeXlaOp(bitcast_op);
}
// Creates an mlir::mhlo::TransposeOp on `operand` with `permutation` encoded
// as a dense i64 attribute and result type converted from `shape`.
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64_t> permutation) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto permutation_attr = GetI64ElementsAttr(permutation, &builder_);
  auto transpose_op = builder_.create<mlir::mhlo::TransposeOp>(
      loc_, result_type, GetValue(operand), permutation_attr);
  return MakeXlaOp(transpose_op);
}
// Creates an mlir::mhlo::ReverseOp on `operand` with `dimensions` encoded as a
// dense i64 attribute and result type converted from `shape`.
StatusOr<XlaOp> MlirHloBuilder::RevInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64_t> dimensions) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto dimensions_attr = GetI64ElementsAttr(dimensions, &builder_);
  auto reverse_op = builder_.create<mlir::mhlo::ReverseOp>(
      loc_, result_type, GetValue(operand), dimensions_attr);
  return MakeXlaOp(reverse_op);
}
// Creates an mlir::mhlo::SortOp and imports `comparator` into its comparator
// region.
//
// When the type converted from `shape` is a tuple, the op gets one result per
// tuple element and the results are re-wrapped in an mlir::mhlo::TupleOp.
// Otherwise the op has a single result, which is returned directly.
StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
                                             absl::Span<const XlaOp> operands,
                                             const XlaComputation& comparator,
                                             int64_t dimension,
                                             bool is_stable) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));

  auto tuple_type = result_type.dyn_cast<mlir::TupleType>();
  llvm::SmallVector<mlir::Type, 4> result_types;
  if (tuple_type) {
    result_types = llvm::to_vector<6>(tuple_type.getTypes());
  } else {
    result_types.push_back(result_type);
  }

  auto sort_op = builder_.create<mlir::mhlo::SortOp>(
      loc_, result_types, GetValues(operands),
      builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
  TF_RETURN_IF_ERROR(
      ImportComputation(comparator.proto(), &sort_op.comparator()));

  if (!tuple_type) return MakeXlaOp(sort_op.getResult(0));
  auto tuple_op =
      builder_.create<mlir::mhlo::TupleOp>(loc_, sort_op.getResults());
  return MakeXlaOp(tuple_op);
}
// Creates an mlir::mhlo::WhileOp over the flattened form of `init`, importing
// `condition` and `body` into the op's regions with region argument tuples
// flattened.
//
// When the type converted from `shape` is a tuple, the flat op results are
// reassembled with HloFunctionImporter::CreateTupleValue and the defining
// TupleOp is returned. Otherwise result 0 is returned.
StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
                                              const XlaComputation& condition,
                                              const XlaComputation& body,
                                              XlaOp init) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type loop_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));

  llvm::SmallVector<mlir::Value> flat_operands;
  llvm::SmallVector<mlir::Type> flat_operand_types;
  HloFunctionImporter::FlattenTupleType(loop_type, flat_operand_types);
  HloFunctionImporter::FlattenTupleValue(&builder_, loc_, GetValue(init),
                                         flat_operands);

  auto while_op = builder_.create<mlir::mhlo::WhileOp>(
      loc_, flat_operand_types, flat_operands);
  TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &while_op.cond(),
                                       /*flatten_region_arg_tuple*/ true));
  TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &while_op.body(),
                                       /*flatten_region_arg_tuple*/ true));

  if (!loop_type.isa<mlir::TupleType>()) {
    return MakeXlaOp(while_op.getResult(0));
  }

  llvm::SmallVector<mlir::Value> flat_results = while_op->getResults();
  llvm::MutableArrayRef<mlir::Value> flat_results_ref(flat_results);
  auto tuple_value = HloFunctionImporter::CreateTupleValue(
      &builder_, loc_, flat_results_ref, loop_type);
  auto tuple_op = tuple_value.getDefiningOp<mlir::mhlo::TupleOp>();
  return MakeXlaOp(tuple_op);
}
// Creates an mlir::mhlo::ReducePrecisionOp on `operand`, passing the exponent
// and mantissa bit counts as 32-bit integer attributes.
StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
    const Shape& shape, XlaOp operand, const int exponent_bits,
    const int mantissa_bits) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto exponent_attr = builder_.getI32IntegerAttr(exponent_bits);
  auto mantissa_attr = builder_.getI32IntegerAttr(mantissa_bits);
  auto reduce_precision_op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
      loc_, result_type, GetValue(operand), exponent_attr, mantissa_attr);
  return MakeXlaOp(reduce_precision_op);
}
// Creates an mlir::mhlo::GatherOp with result type converted from `shape`.
//
// `indices_are_sorted` is forwarded to the op as a bool attribute, the same
// way ScatterInternal forwards its flag; previously the argument was dropped
// and the op was built without it.
StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::GatherOp>(
      loc_, ty, GetValue(input), GetValue(start_indices),
      ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
      GetI64ElementsAttr(slice_sizes, &builder_),
      builder_.getBoolAttr(indices_are_sorted));
  return MakeXlaOp(op);
}
// Creates an mlir::mhlo::ScatterOp from exactly one input and one update, and
// imports `update_computation` into the op's update region. The returned XlaOp
// wraps result 0 of the op.
//
// Returns Unimplemented unless both `inputs` and `updates` hold exactly one
// element. `updates` is validated as well because `updates[0]` is read below.
StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
    const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
    absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  // TODO(b/230137437): Allow variadic scatter after adding mhlo support.
  if (inputs.size() != 1 || updates.size() != 1) {
    return Unimplemented("Variadic scatter not implemented in mhlo yet.");
  }
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::ScatterOp>(
      loc_, ty, GetValue(inputs[0]), GetValue(scatter_indices),
      GetValue(updates[0]),
      ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
      builder_.getBoolAttr(indices_are_sorted),
      builder_.getBoolAttr(unique_indices));
  TF_RETURN_IF_ERROR(
      ImportComputation(update_computation.proto(), &op.update_computation()));
  return MakeXlaOp(op.getResult(0));
}
// Creates an mlir::mhlo::SetDimensionSizeOp from `operand` and `val`, with
// `dimension` passed as a 64-bit integer attribute.
StatusOr<XlaOp> MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape,
                                                         XlaOp operand,
                                                         XlaOp val,
                                                         int64_t dimension) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto dimension_attr = builder_.getI64IntegerAttr(dimension);
  auto set_size_op = builder_.create<mlir::mhlo::SetDimensionSizeOp>(
      loc_, result_type, GetValue(operand), GetValue(val), dimension_attr);
  return MakeXlaOp(set_size_op);
}
// Creates an mlir::mhlo::RngOp for a uniform or normal distribution.
//
// The first two entries of `parameters` are passed as the op's leading
// operands, followed by a constant R1 tensor holding `shape`'s dimensions.
// Fails when `distribution` is neither RNG_UNIFORM nor RNG_NORMAL, when
// `shape` is dynamic, or when fewer than two parameters are supplied.
StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
    RandomDistribution distribution, absl::Span<const XlaOp> parameters,
    const Shape& shape) {
  mlir::mhlo::RngDistributionAttr attr;
  if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
    attr = mlir::mhlo::RngDistributionAttr::get(
        builder_.getContext(), mlir::mhlo::RngDistribution::UNIFORM);
  } else {
    TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
        << "Unexpected distribution: " << distribution;
    attr = mlir::mhlo::RngDistributionAttr::get(
        builder_.getContext(), mlir::mhlo::RngDistribution::NORMAL);
  }
  if (shape.is_dynamic())
    return Unimplemented("RngOp with dynamic dims not supported");
  // parameters[0] and parameters[1] are indexed below; reject shorter spans
  // rather than reading out of bounds.
  TF_RET_CHECK(parameters.size() >= 2)
      << "RngOp expects two parameters, got " << parameters.size();
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::RngOp>(
      loc_, ty, GetValue(parameters[0]), GetValue(parameters[1]),
      GetValue(
          ConstantLiteral(LiteralUtil::CreateR1<int64_t>(shape.dimensions()))),
      attr);
  return MakeXlaOp(op.getResult());
}
// Creates an mlir::mhlo::RngBitGeneratorOp with flattened result types.
//
// When the type converted from `full_result_shape` is a tuple, the flat op
// results are reassembled with HloFunctionImporter::CreateTupleValue and the
// defining TupleOp is returned; otherwise result 0 is returned. Fails if
// `algorithm` has no matching MHLO enumerator, instead of dereferencing an
// empty optional.
StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         full_result_shape, builder_));
  llvm::SmallVector<mlir::Type> flattened_ret_types;
  HloFunctionImporter::FlattenTupleType(ty, flattened_ret_types);

  auto mlir_algorithm = mlir::mhlo::symbolizeRngAlgorithm(algorithm);
  TF_RET_CHECK(mlir_algorithm.hasValue())
      << "Unsupported RNG algorithm: " << algorithm;
  auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
      builder_.getContext(), *mlir_algorithm);
  auto op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
      loc_, flattened_ret_types, algorithm_attr, GetValue(initial_state));

  if (ty.isa<mlir::TupleType>()) {
    llvm::SmallVector<mlir::Value> flattened_results = op->getResults();
    llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
    auto result = HloFunctionImporter::CreateTupleValue(
        &builder_, loc_, flattened_results_ref, ty);
    auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
    return MakeXlaOp(defining_tuple_op);
  }
  return MakeXlaOp(op.getResult(0));
}
// Creates an mlir::mhlo::ReshapeOp on `operand` with result type converted
// from `shape`. Any previously recorded builder error is propagated first.
// Only `inferred_dimension == -1` is accepted; other values yield
// Unimplemented.
StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
                                                XlaOp operand,
                                                int64_t inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error());
  const bool has_inferred_dimension = inferred_dimension != -1;
  if (has_inferred_dimension) {
    return Unimplemented("inferred_dimension not yet supported for Reshape op");
  }
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  mlir::Value input = GetValue(operand);
  auto reshape_op =
      builder_.create<mlir::mhlo::ReshapeOp>(loc_, result_type, input);
  return MakeXlaOp(reshape_op.getResult());
}
// Creates an mlir::mhlo::DotGeneralOp with result type converted from `shape`.
//
// `precision_config` is a pointer and may be null. As in
// ConvGeneralDilatedInternal, it is only converted when non-null; otherwise a
// null precision attribute is passed to the op.
StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs,
    const DotDimensionNumbers& dimension_number,
    const PrecisionConfig* precision_config) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  mlir::ArrayAttr config_attr;
  if (precision_config)
    config_attr = ConvertPrecisionConfig(precision_config, &builder_);
  auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
      loc_, ty, GetValue(lhs), GetValue(rhs),
      ConvertDotDimensionNumbers(dimension_number, &builder_), config_attr);
  return MakeXlaOp(op.getResult());
}
// Creates an mlir::mhlo::BroadcastInDimOp on `operand` with
// `broadcast_dimensions` encoded as a dense i64 attribute. Any previously
// recorded builder error is propagated first.
StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64_t> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error());
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  mlir::Value input = GetValue(operand);
  auto dimensions_attr = GetI64ElementsAttr(broadcast_dimensions, &builder_);
  auto broadcast_op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
      loc_, result_type, input, dimensions_attr);
  return MakeXlaOp(broadcast_op.getResult());
}
// Always returns Unimplemented, naming the opcode that was requested. The
// instruction proto and operands are not used.
StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
    HloInstructionProto&& instr, HloOpcode opcode,
    absl::Span<const XlaOp> operands) {
  const auto opcode_name = HloOpcodeString(opcode);
  return Unimplemented("MlirHloBuilder does not support op %s", opcode_name);
}
// Creates an mlir::mhlo::CompareOp on `lhs` and `rhs` with result type
// converted from `shape`.
//
// The XLA comparison direction and type are mapped to MHLO enums through
// their string forms. If either string has no matching MHLO enumerator, an
// error status is returned instead of dereferencing an empty optional.
StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
                                        XlaOp rhs,
                                        ComparisonDirection direction,
                                        Comparison::Type type) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto mlir_direction = mlir::mhlo::symbolizeComparisonDirection(
      ComparisonDirectionToString(direction));
  TF_RET_CHECK(mlir_direction.hasValue())
      << "Unsupported comparison direction: "
      << ComparisonDirectionToString(direction);
  auto mlir_type =
      mlir::mhlo::symbolizeComparisonType(ComparisonTypeToString(type));
  TF_RET_CHECK(mlir_type.hasValue())
      << "Unsupported comparison type: " << ComparisonTypeToString(type);

  auto op = builder_.create<mlir::mhlo::CompareOp>(
      loc_, ty, GetValue(lhs), GetValue(rhs),
      mlir::mhlo::ComparisonDirectionAttr::get(builder_.getContext(),
                                               mlir_direction.getValue()),
      mlir::mhlo::ComparisonTypeAttr::get(builder_.getContext(),
                                          mlir_type.getValue()));
  return MakeXlaOp(op.getResult());
}
// Creates the MHLO op named after `binop` with operands `lhs` and `rhs` and
// result shape `shape`. Failures are routed through ReportErrorOrReturn.
XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                          XlaOp lhs, XlaOp rhs) {
  auto build_binary_op = [&]() -> StatusOr<XlaOp> {
    return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs});
  };
  return ReportErrorOrReturn(build_binary_op);
}
// Creates the MHLO op named after `opcode` with the given operands and result
// shape by delegating to CreateOp.
StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
    HloOpcode opcode, const Shape& shape, absl::Span<const XlaOp> operands) {
  const std::string op_name = GetMlirOpName(opcode);
  // View the absl::Span as an llvm::ArrayRef without copying the elements.
  auto operand_ref =
      llvm::makeArrayRef<XlaOp>(operands.data(), operands.size());
  return CreateOp(op_name, shape, operand_ref);
}
// Creates an mlir::mhlo::CreateTokenOp producing a value of the MHLO token
// type. Failures are routed through ReportErrorOrReturn.
XlaOp MlirHloBuilder::CreateToken() {
  auto build_token = [&]() -> StatusOr<XlaOp> {
    auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
    return MakeXlaOp(
        builder_.create<mlir::mhlo::CreateTokenOp>(loc_, token_type));
  };
  return ReportErrorOrReturn(build_token);
}
// Creates an mlir::mhlo::TriangularSolveOp from `a`, `b` and the flags in
// `options`, with result type converted from `shape`.
//
// The transpose mode is mapped to the MHLO enum by name. If no MHLO
// enumerator matches, an error status is returned instead of dereferencing an
// empty optional.
StatusOr<XlaOp> MlirHloBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_ty,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto mlir_transpose = ::mlir::mhlo::symbolizeTranspose(
      TriangularSolveOptions::Transpose_Name(options.transpose_a()));
  TF_RET_CHECK(mlir_transpose.hasValue())
      << "Unsupported transpose mode: "
      << TriangularSolveOptions::Transpose_Name(options.transpose_a());
  auto op = builder_.create<mlir::mhlo::TriangularSolveOp>(
      loc_, result_ty, GetValue(a), GetValue(b),
      builder_.getBoolAttr(options.left_side()),
      builder_.getBoolAttr(options.lower()),
      builder_.getBoolAttr(options.unit_diagonal()),
      mlir::mhlo::TransposeAttr::get(builder_.getContext(),
                                     mlir_transpose.getValue()));
  return MakeXlaOp(op);
}
// Creates an mlir::mhlo::CholeskyOp on `a`, with `lower` passed as a bool
// attribute and result type converted from `shape`.
StatusOr<XlaOp> MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
                                                 bool lower) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto lower_attr = builder_.getBoolAttr(lower);
  auto cholesky_op = builder_.create<mlir::mhlo::CholeskyOp>(
      loc_, result_type, GetValue(a), lower_attr);
  return MakeXlaOp(cholesky_op);
}
// Creates an mlir::mhlo::InfeedOp whose result types are the flattened form of
// the type converted from `infeed_instruction_shape`; the layout attribute is
// left null. The flat results are reassembled with
// HloFunctionImporter::CreateTupleValue and the defining TupleOp is returned.
StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token,
    const std::string& config) {
  TF_ASSIGN_OR_RETURN(mlir::Type infeed_type,
                      ConvertShapeToType<mlir::RankedTensorType>(
                          infeed_instruction_shape, builder_));
  llvm::SmallVector<mlir::Type> flat_result_types;
  HloFunctionImporter::FlattenTupleType(infeed_type, flat_result_types);

  mlir::ArrayAttr null_layout;
  auto infeed_op = builder_.create<mlir::mhlo::InfeedOp>(
      loc_, flat_result_types, GetValue(token),
      /*infeed_config=*/config,
      /*layout=*/null_layout);

  llvm::SmallVector<mlir::Value> flat_results = infeed_op->getResults();
  llvm::MutableArrayRef<mlir::Value> flat_results_ref(flat_results);
  auto tuple_value = HloFunctionImporter::CreateTupleValue(
      &builder_, loc_, flat_results_ref, infeed_type);
  auto tuple_op = tuple_value.getDefiningOp<mlir::mhlo::TupleOp>();
  return MakeXlaOp(tuple_op);
}
// Creates an mlir::mhlo::OutfeedOp with token result type, taking the
// flattened form of `operand`, the `token` value and `outfeed_config`.
// `shape_with_layout` is not consulted.
StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
    const std::string& outfeed_config) {
  llvm::SmallVector<mlir::Value> flat_operands;
  HloFunctionImporter::FlattenTupleValue(&builder_, loc_, GetValue(operand),
                                         flat_operands);
  auto result_type = mlir::mhlo::TokenType::get(builder_.getContext());
  auto outfeed_op = builder_.create<mlir::mhlo::OutfeedOp>(
      loc_, result_type, flat_operands, GetValue(token), outfeed_config);
  return MakeXlaOp(outfeed_op);
}
// Creates an mlir::mhlo::ConcatenateOp over `operands`, with `dimension`
// passed as a 64-bit integer attribute and result type converted from `shape`.
StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type concat_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto inputs = GetValues(operands);
  auto dimension_attr = builder_.getI64IntegerAttr(dimension);
  auto concat_op = builder_.create<mlir::mhlo::ConcatenateOp>(
      loc_, concat_type, inputs, dimension_attr);
  return MakeXlaOp(concat_op);
}
// Creates an mlir::mhlo::GetTupleElementOp on `tuple_data`, with `index`
// passed as a 32-bit integer attribute and result type converted from `shape`.
StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
                                                        XlaOp tuple_data,
                                                        int64_t index) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type element_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto index_attr = builder_.getI32IntegerAttr(index);
  auto get_element_op = builder_.create<mlir::mhlo::GetTupleElementOp>(
      loc_, element_type, GetValue(tuple_data), index_attr);
  return MakeXlaOp(get_element_op);
}
// Creates an mlir::mhlo::SliceOp on `operand` with the start indices, limit
// indices and strides encoded as dense i64 attributes. No result type is
// passed to the op builder, so `shape` is not consulted.
StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64_t> start_indices,
    absl::Span<const int64_t> limit_indices,
    absl::Span<const int64_t> strides) {
  auto start_attr = GetI64ElementsAttr(start_indices, &builder_);
  auto limit_attr = GetI64ElementsAttr(limit_indices, &builder_);
  auto strides_attr = GetI64ElementsAttr(strides, &builder_);
  auto slice_op = builder_.create<mlir::mhlo::SliceOp>(
      loc_, GetValue(operand), start_attr, limit_attr, strides_attr);
  return MakeXlaOp(slice_op);
}
// Creates an mlir::mhlo::DynamicSliceOp on `operand` using the values of
// `start_indices` and `slice_sizes` encoded as a dense i64 attribute.
StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64_t> slice_sizes) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type slice_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto sizes_attr = GetI64ElementsAttr(slice_sizes, &builder_);
  auto dynamic_slice_op = builder_.create<mlir::mhlo::DynamicSliceOp>(
      loc_, slice_type, GetValue(operand), GetValues(start_indices),
      sizes_attr);
  return MakeXlaOp(dynamic_slice_op);
}
// Creates an mlir::mhlo::DynamicUpdateSliceOp from `operand`, `update` and the
// values of `start_indices`, with result type converted from `shape`.
StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type updated_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto update_op = builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
      loc_, updated_type, GetValue(operand), GetValue(update),
      GetValues(start_indices));
  return MakeXlaOp(update_op);
}
// Creates an mlir::mhlo::PadOp on `operand` with `padding_value`. The low
// edge, high edge and interior padding amounts are read per dimension from
// `padding_config` and passed as three dense i64 attributes.
StatusOr<XlaOp> MlirHloBuilder::PadInternal(
    const Shape& shape, XlaOp operand, XlaOp padding_value,
    const PaddingConfig& padding_config) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type padded_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));

  llvm::SmallVector<int64_t> edge_low;
  llvm::SmallVector<int64_t> edge_high;
  llvm::SmallVector<int64_t> interior;
  for (const auto& dim_config : padding_config.dimensions()) {
    edge_low.push_back(dim_config.edge_padding_low());
    edge_high.push_back(dim_config.edge_padding_high());
    interior.push_back(dim_config.interior_padding());
  }

  auto pad_op = builder_.create<mlir::mhlo::PadOp>(
      loc_, padded_type, GetValue(operand), GetValue(padding_value),
      GetI64ElementsAttr(edge_low, &builder_),
      GetI64ElementsAttr(edge_high, &builder_),
      GetI64ElementsAttr(interior, &builder_));
  return MakeXlaOp(pad_op);
}
// Creates an mlir::mhlo::TupleOp from the values behind `elements`, in order.
// `shape` is not consulted.
StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
    const Shape& shape, absl::Span<const XlaOp> elements) {
  mlir::SmallVector<mlir::Value, 4> element_values;
  element_values.reserve(elements.size());
  for (const XlaOp& element : elements) {
    element_values.push_back(GetValue(element));
  }
  auto tuple_op = builder_.create<mlir::mhlo::TupleOp>(loc_, element_values);
  return MakeXlaOp(tuple_op);
}
// Creates an operation by name through a generic mlir::OperationState with
// the given operands, a single result type converted from `shape`, and
// `attributes`. The returned XlaOp wraps result 0 of the created operation.
StatusOr<XlaOp> MlirHloBuilder::CreateOp(
    const std::string& op_name, const Shape& shape,
    llvm::ArrayRef<XlaOp> operands,
    llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  llvm::SmallVector<mlir::Value, 4> values;
  values.reserve(operands.size());
  for (const XlaOp& operand : operands) {
    values.push_back(GetValue(operand));
  }

  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  mlir::OperationState op_state(loc_, op_name, values, {result_type},
                                attributes);
  mlir::Operation* created = builder_.create(op_state);
  return MakeXlaOp(created->getResult(0));
}
// Rebuilds an HloModule from the `computation` proto, using a module config
// derived from the proto and default DebugOptions, then imports the module's
// entry computation into `region`. `flatten_region_arg_tuple` is forwarded to
// HloFunctionImporter::ImportAsRegion unchanged.
Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
                                         mlir::Region* region,
                                         bool flatten_region_arg_tuple) {
  TF_ASSIGN_OR_RETURN(auto config,
                      xla::HloModule::CreateModuleConfigFromProto(
                          computation, xla::DebugOptions()));
  TF_ASSIGN_OR_RETURN(auto module,
                      xla::HloModule::CreateFromProto(computation, config));
  const auto& entry = *module->entry_computation();
  return HloFunctionImporter::ImportAsRegion(entry, region, &builder_,
                                             flatten_region_arg_tuple);
}
// Looks up the shape recorded for `op` in `handle_to_shape_`.
//
// Propagates any previously recorded builder error and the result of
// CheckOpBuilder(op) first. Returns InvalidArgument when no entry exists for
// the op's handle.
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error());
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  const auto entry = handle_to_shape_.find(op.handle());
  if (entry != handle_to_shape_.end()) {
    return entry->second.get();
  }
  return InvalidArgument("No XlaOp with handle %d", op.handle());
}
} // namespace xla