blob: 7b2a07237152a6942d3788cb3c1f02dd6b2555b8 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include <climits>
#include <memory>
#include <tuple>
#include "absl/algorithm/container.h"
#include "absl/cleanup/cleanup.h"
#include "absl/types/optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
using xla::BufferAllocation;
using xla::BufferAssignment;
using xla::HloComputation;
using xla::HloCustomCallInstruction;
using xla::HloInfeedInstruction;
using xla::HloInstruction;
using xla::HloModule;
using xla::HloModuleProto;
using xla::HloOutfeedInstruction;
using xla::HloProto;
using xla::Shape;
using xla::StatusOr;
namespace mlir {
namespace {
// Returns an absl::string_view over the same bytes as `ref`. Nothing is
// copied, so the view is only usable while the storage behind `ref` is alive.
absl::string_view StringRefToView(llvm::StringRef ref) {
  return absl::string_view(ref.data(), ref.size());
}
// Builds an HloModule from the module proto embedded in `hlo_proto`.
//
// The module config is derived from that proto together with the debug
// options obtained from xla::GetDebugOptionsFromFlags(). A failure to create
// the config or the module is returned as a non-OK status.
StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
    const HloProto& hlo_proto) {
  TF_ASSIGN_OR_RETURN(
      const xla::HloModuleConfig config,
      HloModule::CreateModuleConfigFromProto(hlo_proto.hlo_module(),
                                             xla::GetDebugOptionsFromFlags()));
  return HloModule::CreateFromProto(hlo_proto.hlo_module(), config);
}
} // namespace
// Converts the MLIR `module` from HLO dialect to LHLO dialect using XLA for
// the platform named `platform_name`.
//
// `hlo_module` is the XLA module to lower; ownership is taken. When
// `optimize_xla_hlo` is true, the backend compiler's HLO passes are run on it
// first; otherwise it is used as-is. Buffers are then assigned, the body of
// `module` is cleared, and the module is repopulated by HloToLhloModule.
//
// Returns a non-OK status (with added context) when the platform cannot be
// found, the backend cannot be created, or the optimization, buffer
// assignment or conversion step fails. The body of `module` is only cleared
// once buffer assignment has succeeded.
Status OptimizeAndConvertHloToLmhlo(std::unique_ptr<HloModule> hlo_module,
                                    ModuleOp module, StringRef platform_name,
                                    bool optimize_xla_hlo) {
  auto platform = xla::se::MultiPlatformManager::PlatformWithName(
      StringRefToView(platform_name));
  if (!platform.ok()) {
    // Report the lookup failure together with the names of the platforms
    // that are actually available.
    std::string error_msg;
    llvm::raw_string_ostream os(error_msg);
    os << "failed to get platform: " << platform.status().ToString()
       << " (available Platform: ";
    std::vector<std::string> available_platforms;
    // The filter is used only for its side effect of recording each name; it
    // always returns false and the call's result is deliberately discarded.
    (void)xla::se::MultiPlatformManager::PlatformsWithFilter(
        [&](const stream_executor::Platform* p) {
          available_platforms.push_back(p->Name());
          return false;
        });
    llvm::interleaveComma(available_platforms, os);
    os << ")";
    return xla::InvalidArgument("%s", os.str().c_str());
  }

  xla::BackendOptions backend_options;
  backend_options.set_platform(platform.ValueOrDie());
  auto backend_or_err = xla::Backend::CreateBackend(backend_options);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(),
                                  "failed to create XLA Backend ");
  auto backend = std::move(backend_or_err.ValueOrDie());

  StatusOr<std::unique_ptr<HloModule>> optimized_hlo_module;
  if (optimize_xla_hlo) {
    // Run all HLO passes to produce an optimized module.
    optimized_hlo_module = backend->compiler()->RunHloPasses(
        std::move(hlo_module), backend->default_stream_executor(),
        backend->memory_allocator());
    TF_RETURN_WITH_CONTEXT_IF_ERROR(optimized_hlo_module.status(),
                                    "running XLA pass pipeline");
  } else {
    optimized_hlo_module = std::move(hlo_module);
  }

  StatusOr<std::unique_ptr<BufferAssignment>> assignment =
      backend->compiler()->AssignBuffers(optimized_hlo_module->get());
  TF_RETURN_WITH_CONTEXT_IF_ERROR(assignment.status(),
                                  "running XLA buffer assigment");

  // Clear the module before populating it back with the result of the
  // conversion.
  module.getBody()->clear();
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      HloToLhloModule(**assignment, **optimized_hlo_module, module),
      "converting HLO to LHLO");
  return ::tensorflow::OkStatus();
}
namespace {
// Module pass that lowers MLIR HLO to MLIR LHLO by going through XLA: the
// module is exported to an XLA HLO proto, turned into an HloModule, handed to
// OptimizeAndConvertHloToLmhlo for the configured platform, and the module is
// repopulated with the LHLO result.
class XlaHloToLhloPass
    : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
  // Registers the dialects this pass depends on.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<arith::ArithmeticDialect, bufferization::BufferizationDialect,
                func::FuncDialect, memref::MemRefDialect, mhlo::MhloDialect,
                lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();
  }

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaHloToLhloPass)

  XlaHloToLhloPass() = default;
  // User-provided copy constructor that copies nothing; the options below are
  // re-initialized from their in-class initializers.
  XlaHloToLhloPass(const XlaHloToLhloPass&) {}

  StringRef getArgument() const final { return "xla-hlo-to-lhlo-with-xla"; }
  StringRef getDescription() const final {
    return "Emit LHLO from HLO using the existing XLA implementation";
  }

 private:
  // Runs the conversion on the current module. Any failure is reported as a
  // diagnostic on the module and the pass is marked as failed.
  void runOnOperation() final {
    ModuleOp module = getOperation();

    auto convert = [this, &module]() -> Status {
      // The exported HLO module needs an entry function called "main".
      if (!SymbolTable(module).lookup("main")) {
        return xla::InvalidArgument(
            "conversion to HLO module failed: missing main()");
      }

      HloProto proto;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ConvertMlirHloToHlo(module, &proto,
                              /*use_tuple_args=*/false,
                              /*return_tuple=*/false,
                              /*shape_determination_fns=*/{}),
          "conversion to XLA HLO proto failed");

      auto hlo_module_or = HloModuleFromProto(proto);
      TF_RETURN_WITH_CONTEXT_IF_ERROR(hlo_module_or.status(),
                                      "parsing HLO proto to HLO module failed");

      return OptimizeAndConvertHloToLmhlo(
          std::move(hlo_module_or.ValueOrDie()), module, platform_,
          optimize_xla_hlo_);
    };

    const Status status = convert();
    if (status.ok()) return;

    module.emitError() << status.ToString();
    signalPassFailure();
  }

  // Name of the XLA platform whose compiler is used; defaults to "Host".
  Option<std::string> platform_{
      *this, "platform",
      llvm::cl::desc("The platform to use for the XLA optimization pipeline."),
      llvm::cl::init("Host")};
  // Whether the XLA HLO pass pipeline is run before lowering; defaults to on.
  Option<bool> optimize_xla_hlo_{
      *this, "optimize-xla-hlo",
      llvm::cl::desc("Whether to apply HLO optimizations."),
      llvm::cl::init(true)};
};
} // namespace
// Collects the MLIR values for the operands and results of the XLA HLO
// instruction `instr` into `operands`.
//
// When `num_operands` holds a value, only the first `num_operands` operands of
// `instr` are considered; a value larger than the operand count is rejected
// with InvalidArgument. Views of the considered operands are appended first,
// followed by the views of the instruction's own result. On success
// `num_arguments` is the size of `operands` after the operand views were
// appended and `num_results` is the number of values appended for the result.
// `token_mode` is forwarded to GetOrCreateView.
Status LhloDialectEmitter::CreateOperands(
    const HloInstruction* instr, std::optional<int64_t> num_operands,
    TokenLoweringMode token_mode, llvm::SmallVectorImpl<Value>& operands,
    size_t& num_arguments, size_t& num_results) {
  if (num_operands.value_or(0) > instr->operand_count()) {
    return xla::InvalidArgument("num_operands must be <= operand count");
  }

  const int64_t operands_to_visit =
      num_operands.value_or(instr->operand_count());
  for (int64_t index = 0; index < operands_to_visit; ++index) {
    TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(index), &operands,
                                       /*result_subset=*/{}, token_mode));
  }
  num_arguments = operands.size();

  // Everything appended from here on belongs to the instruction's result.
  TF_RETURN_IF_ERROR(
      GetOrCreateView(instr, &operands, /*result_subset=*/{}, token_mode));
  num_results = operands.size() - num_arguments;
  return ::tensorflow::OkStatus();
}
template <typename OpType>
OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
ValueRange operands) {
Location loc = getLocation(instr);
return builder_.create<OpType>(loc, llvm::None, operands,
llvm::ArrayRef<NamedAttribute>{});
}
// Gathers the operand/result buffers of `instr` with CreateOperands (tokens
// are lowered in kFailToLower mode) and creates an attribute-less `OpType`
// from them. `num_arguments` and `num_results` are filled in by
// CreateOperands; `num_operands` optionally limits how many HLO operands are
// considered. Any error from CreateOperands is returned unchanged.
template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
    const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
    std::optional<int64_t> num_operands) {
  llvm::SmallVector<Value, 4> buffers;
  TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands,
                                    TokenLoweringMode::kFailToLower, buffers,
                                    num_arguments, num_results));
  return CreateOpWithoutAttrs<OpType>(instr, buffers);
}
// Emits `instr` as a single tensor-level op nested in a new lmhlo.fusion.
//
// `buffer_operands` holds the argument buffers followed by the result
// buffers; the first `num_arguments` entries are arguments and the following
// `num_results` entries are results. Inside the fusion region every argument
// buffer is turned into a tensor with bufferization.to_tensor, the op is
// imported on those tensors, and each op result is written to the matching
// result buffer with memref.tensor_store.
//
// Returns the op created inside the fusion (not the fusion itself), or an
// error if a check fails or the import fails.
StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
    const HloInstruction* instr, ValueRange buffer_operands,
    size_t num_arguments, size_t num_results) {
  Location loc = getLocation(instr);

  // Split the flat buffer list into its argument and result parts.
  std::vector<Value> buffers(buffer_operands.begin(), buffer_operands.end());
  const absl::Span<Value> all_buffers = absl::MakeSpan(buffers);
  const absl::Span<Value> arguments = all_buffers.subspan(0, num_arguments);
  const absl::Span<Value> results =
      all_buffers.subspan(num_arguments, num_results);

  auto fusion = builder_.create<mlir::lmhlo::FusionOp>(loc);
  mlir::OpBuilder b(&fusion.getRegion());

  // Load every argument buffer as a tensor. A layout other than the default
  // descending one is recorded on the load as an "xla_shape" attribute.
  llvm::SmallVector<mlir::Value, 4> loads;
  for (Value arg : arguments) {
    auto load = b.create<mlir::bufferization::ToTensorOp>(loc, arg);
    Shape shape = xla::TypeToShape(arg.getType());
    TF_RET_CHECK(shape.IsArray());
    const bool has_custom_layout =
        shape.layout() !=
        xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size());
    if (has_custom_layout) {
      load->setAttr("xla_shape",
                    b.getStringAttr(shape.ToString(/*print_layout=*/true)));
    }
    loads.push_back(load);
  }

  mlir::Operation* op = nullptr;
  if (instr->opcode() != xla::HloOpcode::kReduce) {
    TF_ASSIGN_OR_RETURN(
        op,
        xla::HloFunctionImporter::ImportInstruction(
            instr, loads, &b, xla::DynamicShapeHandlingMode::kConvertToStatic));
  } else {
    // Reduce is built by hand: the first half of the loads are the inputs,
    // the second half the init values, and the body is imported from the
    // instruction's first called computation.
    TF_RET_CHECK(loads.size() % 2 == 0);
    const size_t half = loads.size() / 2;
    llvm::ArrayRef<mlir::Value> inputs =
        llvm::makeArrayRef(loads).take_front(half);
    llvm::ArrayRef<mlir::Value> init_values =
        llvm::makeArrayRef(loads).drop_front(half);
    std::vector<int64_t> dimensions(instr->dimensions().begin(),
                                    instr->dimensions().end());
    auto reduce_op = b.create<mhlo::ReduceOp>(
        loc, inputs, init_values, GetI64DenseElementsAttr(dimensions));
    TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
        *instr->called_computations()[0], &reduce_op.body(), &builder_,
        /*flatten_region_arg_tuple=*/true));
    op = reduce_op;
  }
  TF_RET_CHECK(op->getNumResults() == num_results);

  // Store each op result into its result buffer.
  for (size_t index = 0; index < results.size(); ++index) {
    b.create<mlir::memref::TensorStoreOp>(loc, op->getResult(index),
                                          results[index]);
  }
  return op;
}
// Convenience overload: gathers all operand and result buffers of `instr`
// (tokens are lowered in kFailToLower mode), emits the instruction inside a
// new fusion, and returns the parent of the emitted op.
StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
    const HloInstruction* instr) {
  size_t num_arguments = 0;
  size_t num_results = 0;
  llvm::SmallVector<Value, 4> buffers;
  TF_RETURN_IF_ERROR(CreateOperands(instr, std::nullopt,
                                    TokenLoweringMode::kFailToLower, buffers,
                                    num_arguments, num_results));
  TF_ASSIGN_OR_RETURN(
      mlir::Operation * fused_op,
      CreateOpInFusion(instr, buffers, num_arguments, num_results));
  return fused_op->getParentOp();
}
// Emits the LMHLO form of `instr`, dispatching on its opcode.
//
// Returns the created operation, or nullptr for opcodes that intentionally
// produce no operation. Opcodes that are not listed yield an Internal error
// after the instruction has been printed to llvm::errs().
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
    const HloInstruction* instr) {
  using xla::HloOpcode;
  switch (instr->opcode()) {
    // Opcodes for which no operation is emitted.
    case HloOpcode::kAddDependency:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
      return nullptr;
    case HloOpcode::kAfterAll:
      // LMHLO is already ordered. This assumption may be broken after
      // introducing async regions and partial orders.
      return nullptr;

    // Opcodes with a dedicated emitter.
    case HloOpcode::kAllToAll:
      return EmitAllToAllOp(instr);
    case HloOpcode::kAllGather:
      return EmitAllGatherOp(instr);
    case HloOpcode::kAllReduce:
      return EmitAllReduceOp(instr);
    case HloOpcode::kAllReduceStart:
      return EmitAllReduceStartOp(instr);
    case HloOpcode::kAllReduceDone:
      return EmitAllReduceDoneOp(instr);
    case HloOpcode::kReduceScatter:
      return EmitReduceScatterOp(instr);
    case HloOpcode::kBitcast:
      return EmitBitcast(instr);
    case HloOpcode::kCollectivePermute:
      return EmitCollectivePermuteOp(instr);
    case HloOpcode::kConditional:
      return EmitCaseOp(instr);
    case HloOpcode::kFft:
      return EmitFftOp(instr);
    case HloOpcode::kInfeed:
      return EmitInfeedOp(instr);
    case HloOpcode::kOutfeed:
      return EmitOutfeedOp(instr);
    case HloOpcode::kTriangularSolve:
      return EmitTriangularSolveOp(instr);
    case HloOpcode::kSort:
      return EmitSortOp(instr);
    case HloOpcode::kFusion:
      return EmitFusionOp(instr);
    case HloOpcode::kScatter:
      return EmitScatterOp(instr);
    case HloOpcode::kSelectAndScatter:
      return EmitSelectAndScatterOp(instr);
    case HloOpcode::kCustomCall:
      return EmitCustomCallOp(instr);
    case HloOpcode::kConstant:
      return EmitConstant(instr);
    case HloOpcode::kRngGetAndUpdateState:
      return EmitRngGetAndUpdateStateOp(instr);
    case HloOpcode::kWhile:
      return EmitWhileOp(instr);

    // Opcodes that map directly onto an attribute-less LMHLO op.
    case HloOpcode::kPartitionId:
      return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
    case HloOpcode::kReplicaId:
      return CreateOpWithoutAttrs<lmhlo::ReplicaIdOp>(instr);

    // Opcodes that are emitted as a single op wrapped in a fusion.
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCbrt:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDot:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kGather:
    case HloOpcode::kImag:
    case HloOpcode::kIota:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kPad:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReshape:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRoundNearestEven:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTranspose:
    case HloOpcode::kXor:
    case HloOpcode::kCopy:
    case HloOpcode::kReduce:
      return CreateOpInFusion(instr);

    default:
      llvm::errs() << instr->ToString();
      return tensorflow::errors::Internal(
          absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
                       " is not supported."));
  }
}
// Emits `instr` through EmitOp and returns only the resulting status; the
// created operation, if any, is not handed back to the caller.
Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
  StatusOr<mlir::Operation*> emitted = EmitOp(instr);
  return emitted.status();
}
// Emits an lmhlo.sort for the HLO sort instruction `instr`: creates the op,
// copies the sort dimension and stability flag onto it, and imports the first
// called computation as the comparator region.
StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto sort_op, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));

  const auto* hlo_sort = xla::Cast<xla::HloSortInstruction>(instr);
  auto dimension_attr = builder_.getI64IntegerAttr(hlo_sort->sort_dimension());
  sort_op.setDimensionAttr(dimension_attr);
  auto is_stable_attr = builder_.getBoolAttr(hlo_sort->is_stable());
  sort_op.setIsStableAttr(is_stable_attr);

  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *hlo_sort->called_computations()[0], &sort_op.getComparator(),
      &builder_));
  return sort_op;
}
// Recursively walks through mhlo.tuple ops starting at `v` and calls
// `visitor` on every value that is not itself produced by an mhlo.tuple.
// Tuple elements are visited in operand order. The walk stops at, and
// returns, the first non-OK status.
Status WalkTuplePostOrder(Value v,
                          const std::function<Status(Value)>& visitor) {
  auto tuple = v.getDefiningOp<mhlo::TupleOp>();
  if (!tuple) return visitor(v);

  for (Value element : tuple.val()) {
    TF_RETURN_IF_ERROR(WalkTuplePostOrder(element, visitor));
  }
  return ::tensorflow::OkStatus();
}
// Builds the tensor-level value that stands for the sub-shape `shape` of
// `root` at `*shape_index`, emitting ops through `b` at location `loc`.
//
// Tuple shapes are handled recursively: every element is rewritten with its
// index pushed onto `shape_index` and the results are combined into an
// mhlo.tuple. For any other shape the array view of the buffer is obtained
// with GetOrCreateArrayView and loaded with bufferization.to_tensor; when the
// layout is not the default descending layout, the shape (including layout)
// is recorded on the load as an "xla_shape" string attribute.
//
// `shape_index` is used as scratch state; on success it is left as it was on
// entry. Returns the created value, or the first error encountered.
StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
    const HloInstruction* root, const Shape& shape,
    xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
  if (shape.IsTuple()) {
    llvm::SmallVector<Value, 4> values;
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      shape_index->push_back(i);
      TF_ASSIGN_OR_RETURN(
          auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
                                       b, loc));
      values.push_back(v);
      shape_index->pop_back();
    }
    return Value(b->create<mhlo::TupleOp>(loc, values));
  }

  TF_ASSIGN_OR_RETURN(Value memref,
                      GetOrCreateArrayView(root, shape, *shape_index));
  auto load = b->create<bufferization::ToTensorOp>(loc, memref);
  if (shape.layout() !=
      xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
    load->setAttr("xla_shape",
                  b->getStringAttr(shape.ToString(/*print_layout=*/true)));
  }
  return load.getResult();
}
// Emit a lmhlo.fusion based on XLA HLO fusion. Structurally they are not neatly
// equivalent. Specifically, XLA HLO fusion:
//     fused_computation {
//       %p0 = parameter(0)
//       %p1 = parameter(1)
//       ...
//       ROOT %ret = ...
//     }
// will be converted to
//     lmhlo.fusion() {  // no explicit operands
//       // capturing outside buffers
//       %p0 = bufferization.to_tensor(%arg0) : memref<...> -> tensor<...>
//       %p1 = bufferization.to_tensor(%arg1) : memref<...> -> tensor<...>
//       ...
//       tensor_store ..., %ret // store a tensor to a memref
//     }
//
// Returns the created fusion op. Fails if an operand cannot be rewritten, the
// fused computation cannot be imported, or the number of leaf values of the
// fused root differs from the number of output buffers of `instr`.
StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
    const HloInstruction* instr) {
  Location loc = getLocation(instr);
  auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);
  auto fusion = builder_.create<lmhlo::FusionOp>(loc);

  // While the fusion body is built, `builder_` is re-seated on the fusion op;
  // the cleanup restores the saved insertion point on every exit path.
  auto after_fusion = builder_.saveInsertionPoint();
  auto reverter = absl::MakeCleanup(
      [this, after_fusion] { builder_.restoreInsertionPoint(after_fusion); });
  builder_ = mlir::OpBuilder(fusion);
  auto region_builder = OpBuilder::atBlockBegin(&fusion.getRegion().front());

  // Turn every HLO operand into a tensor value inside the fusion region.
  llvm::SmallVector<Value, 8> arguments;
  for (const HloInstruction* operand : instr->operands()) {
    xla::ShapeIndex shape_index;
    TF_ASSIGN_OR_RETURN(
        auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
                                       &region_builder, loc));
    arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(Value result,
                      xla::HloFunctionImporter::ImportInstructions(
                          *fusion_instr->fused_instructions_computation(),
                          arguments, &region_builder));

  // Store each leaf of the (possibly tuple-shaped) result into the matching
  // output buffer.
  {
    llvm::SmallVector<Value, 4> output;
    TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
    size_t next_output = 0;
    TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) -> Status {
      // Check the bound before indexing so a result with more leaves than
      // output buffers is reported instead of reading past `output`.
      if (next_output >= output.size()) {
        return xla::InternalError("output sizes don't match");
      }
      region_builder.create<memref::TensorStoreOp>(loc, v,
                                                   output[next_output++]);
      return ::tensorflow::OkStatus();
    }));
    if (next_output != output.size()) {
      return xla::InternalError("output sizes don't match");
    }
  }

  // Fold GTE/Tuple pairs.
  //
  // Since the fused region refers to values in its parent region, we can't
  // call applyPatternAndFoldGreedily. We optimize it manually.
  //
  // Only walk once, because post-ordering is exactly what we need for GTE
  // optimizations.
  fusion.getRegion().walk([](mhlo::GetTupleElementOp gte) {
    SmallVector<Value, 4> folded_values;
    if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
      gte.replaceAllUsesWith(folded_values[0]);
    }
  });

  // Effectively a DCE on the region.
  {
    llvm::SmallVector<mlir::Operation*, 4> ops;
    fusion.getRegion().walk([&](mlir::Operation* op) { ops.push_back(op); });
    // Visit the user first.
    std::reverse(ops.begin(), ops.end());
    for (mlir::Operation* op : ops) {
      if (isOpTriviallyDead(op)) op->erase();
    }
  }
  return fusion;
}
// Converts the scatter dimension numbers of the HLO scatter instruction
// `instr` into an mhlo::ScatterDimensionNumbersAttr created in `context`.
StatusOr<mhlo::ScatterDimensionNumbersAttr>
LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr,
                                               mlir::MLIRContext* context) {
  const xla::ScatterDimensionNumbers& dims =
      xla::Cast<xla::HloScatterInstruction>(instr)
          ->scatter_dimension_numbers();

  // Views a span of int64 values as an ArrayRef without copying.
  const auto as_array_ref = [](absl::Span<const int64_t> values) {
    return ArrayRef<int64_t>(values.data(),
                             static_cast<size_t>(values.size()));
  };

  return mhlo::ScatterDimensionNumbersAttr::get(
      context, as_array_ref(dims.update_window_dims()),
      as_array_ref(dims.inserted_window_dims()),
      as_array_ref(dims.scatter_dims_to_operand_dims()),
      dims.index_vector_dim());
}
// Emits an lmhlo.scatter for the HLO scatter instruction `instr`: creates the
// op, copies the dimension numbers and the indices_are_sorted /
// unique_indices flags, and imports the first called computation as the
// update-computation region.
StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto scatter_op,
                      CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));

  // Copy the attributes over from the HLO instruction.
  const auto* hlo_scatter = xla::Cast<xla::HloScatterInstruction>(instr);
  TF_ASSIGN_OR_RETURN(auto dimension_numbers,
                      GetScatterDimensionNumbers(instr, builder_.getContext()));
  scatter_op.setScatterDimensionNumbersAttr(dimension_numbers);

  auto sorted_attr = builder_.getBoolAttr(hlo_scatter->indices_are_sorted());
  scatter_op.setIndicesAreSortedAttr(sorted_attr);
  auto unique_attr = builder_.getBoolAttr(hlo_scatter->unique_indices());
  scatter_op.setUniqueIndicesAttr(unique_attr);

  // Import the update computation as the op's region.
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *hlo_scatter->called_computations()[0],
      &scatter_op.getUpdateComputation(), &builder_));
  return scatter_op;
}
// Emits an lmhlo.select_and_scatter for `instr`: creates the op, copies the
// window sizes, strides and low padding from the HLO window, and imports the
// select and scatter computations as regions. Windows with dilation are
// rejected with an Unimplemented error.
StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));

  const auto* hlo_op = xla::Cast<xla::HloSelectAndScatterInstruction>(instr);
  const xla::Window& window = hlo_op->window();
  if (xla::window_util::HasDilation(window)) {
    return xla::Unimplemented("Dilation for SelectAndScatter is not supported");
  }

  // Per-dimension accessors for the window attributes that are copied.
  const auto window_size = [](const xla::WindowDimension& dim) {
    return static_cast<int64_t>(dim.size());
  };
  const auto window_stride = [](const xla::WindowDimension& dim) {
    return static_cast<int64_t>(dim.stride());
  };
  const auto window_padding_low = [](const xla::WindowDimension& dim) {
    return static_cast<int64_t>(dim.padding_low());
  };
  op.setWindowDimensionsAttr(GetWindowElements(window, window_size));
  op.setWindowStridesAttr(GetWindowElements(window, window_stride));
  op.setPaddingAttr(GetWindowElements(window, window_padding_low));

  // Import the select and scatter computations as regions.
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *hlo_op->select(), &op.getSelect(), &builder_));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *hlo_op->scatter(), &op.getScatter(), &builder_));
  return op;
}
// Emits the LMHLO form of the HLO custom call `instr`.
//
// Custom calls recognized as cusolver, cuBLAS GEMM or DNN convolution calls
// are forwarded to EmitCholesky, EmitGemm and EmitDnnConvolution
// respectively. Every other custom call becomes a generic
// lmhlo::CustomCallOp carrying the call target name, the opaque backend
// config, the API version and the operand segment sizes. Returns the created
// operation or the first error encountered.
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
const HloInstruction* instr) {
auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);
if (xla::gpu::IsCustomCallToCusolver(*instr)) {
return EmitCholesky(custom_call_instr);
}
if (xla::gpu::IsCublasGemm(*instr)) {
return EmitGemm(custom_call_instr);
}
if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
return EmitDnnConvolution(custom_call_instr);
}
// For custom call, if there are any token operands or results, they will not
// be represented in LHLO so we need to remember the mapping. First create
// operands where each token is replaced with a null Value.
llvm::SmallVector<Value, 4> operands;
size_t num_arguments, num_results;
TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/std::nullopt,
TokenLoweringMode::kUseNull, operands,
num_arguments, num_results));
// Now check if any of the operands is Null, which would indicate the presence
// of a token in the input or output.
bool has_token = llvm::any_of(operands, [](Value v) { return !v; });
lmhlo::CustomCallTargetArgMappingAttr target_mapping;
if (has_token) {
// If there was a token, squeeze all the non-token arguments and results
// (in-place) and remember the mapping.
// `next_index` is the write position of the compaction; it never runs ahead
// of the read position `i`, so no unread entry is overwritten.
int next_index = 0;
llvm::SmallVector<int64_t> arg_to_target_arg_mapping;
for (int i = 0; i < num_arguments; ++i) {
if (operands[i]) {
arg_to_target_arg_mapping.push_back(i);
operands[next_index++] = operands[i];
}
}
// Size of arg_to_target_arg_mapping is the number of arguments in LHLO.
llvm::SmallVector<int64_t> result_to_target_result_mapping;
for (int i = num_arguments; i < operands.size(); ++i) {
if (operands[i]) {
// Result indices are recorded relative to the first result.
result_to_target_result_mapping.push_back(i - num_arguments);
operands[next_index++] = operands[i];
}
}
// Build the mapping attribute.
// At this point num_arguments/num_results still hold the counts reported by
// CreateOperands, i.e. including the token positions.
target_mapping = lmhlo::CustomCallTargetArgMappingAttr::get(
builder_.getContext(), num_arguments, num_results,
arg_to_target_arg_mapping, result_to_target_result_mapping);
// Drop the remaining operands and adjust num_arguments and num_results
// for LMHLO creation.
operands.resize(next_index);
num_arguments = arg_to_target_arg_mapping.size();
num_results = result_to_target_result_mapping.size();
}
auto custom_call = CreateOpWithoutAttrs<lmhlo::CustomCallOp>(instr, operands);
TF_ASSIGN_OR_RETURN(
auto mlir_api_version,
ConvertCustomCallApiVersion(custom_call_instr->api_version()));
custom_call.setCallTargetNameAttr(
builder_.getStringAttr(custom_call_instr->custom_call_target()));
custom_call.setBackendConfigAttr(
builder_.getStringAttr(custom_call_instr->opaque()));
custom_call.setApiVersionAttr(mhlo::CustomCallApiVersionAttr::get(
builder_.getContext(), mlir_api_version));
// Record how the flat operand list splits into arguments and results.
const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
static_cast<int32_t>(num_results)};
custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
builder_.getI32VectorAttr(segments));
// The mapping attribute is only set when a token was present.
if (target_mapping) custom_call.setTargetArgMappingAttr(target_mapping);
return custom_call.getOperation();
}
// Emits an lmhlo_gpu.cholesky for `custom_call`. The `is_lower` attribute is
// taken from the xla::CholeskyOptions stored in the call's backend config.
StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
  TF_ASSIGN_OR_RETURN(xla::CholeskyOptions cholesky_options,
                      custom_call->backend_config<xla::CholeskyOptions>());
  const bool is_lower = cholesky_options.lower();
  op.setIsLowerAttr(builder_.getBoolAttr(is_lower));
  return op;
}
// Emits the LMHLO GPU op for a GEMM custom call.
//
// A call with two operands becomes an lmhlo_gpu::GEMMOp; a call with three
// operands becomes an lmhlo_gpu::GEMM_BiasOp, which additionally carries the
// `beta` value of the backend config. Both get the dot dimension numbers, the
// real and imaginary parts of alpha and, when one is selected, the algorithm
// from the GemmBackendConfig. Any other operand count is an InvalidArgument
// error.
StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(
      auto const config,
      custom_call->backend_config<xla::gpu::GemmBackendConfig>());

  // Sets the attributes shared by both GEMM flavours and returns the
  // underlying operation.
  auto set_common_attributes = [&](auto op) -> Operation* {
    auto arrayref = [](absl::Span<const int64_t> array) {
      return llvm::ArrayRef<int64_t>{array.data(), array.size()};
    };
    const auto& dot_dims = config.dot_dimension_numbers();
    auto dot_dims_attr = mhlo::DotDimensionNumbersAttr::get(
        builder_.getContext(), arrayref(dot_dims.lhs_batch_dimensions()),
        arrayref(dot_dims.rhs_batch_dimensions()),
        arrayref(dot_dims.lhs_contracting_dimensions()),
        arrayref(dot_dims.rhs_contracting_dimensions()));
    op.setDotDimensionNumbersAttr(dot_dims_attr);
    op.setAlphaRealAttr(builder_.getF64FloatAttr(config.alpha_real()));
    op.setAlphaImagAttr(builder_.getF64FloatAttr(config.alpha_imag()));
    const bool has_selected_algorithm =
        config.algorithm_case() ==
        xla::gpu::GemmBackendConfig::kSelectedAlgorithm;
    if (has_selected_algorithm) {
      op.setAlgorithmAttr(
          builder_.getI64IntegerAttr(config.selected_algorithm()));
    }
    return op.getOperation();
  };

  switch (custom_call->operand_count()) {
    case 2: {
      TF_ASSIGN_OR_RETURN(
          auto gemm, CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call));
      return set_common_attributes(gemm);
    }
    case 3: {
      TF_ASSIGN_OR_RETURN(
          auto gemm_bias,
          CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call));
      gemm_bias.setBetaAttr(builder_.getF64FloatAttr(config.beta()));
      return set_common_attributes(gemm_bias);
    }
    default:
      return xla::InvalidArgument(
          "GEMM custom call should have 2 or 3 operands");
  }
}
// Maps a StreamExecutor DNN activation mode onto the corresponding
// lmhlo_gpu::Activation value. Modes without a mapping produce an
// InternalError.
static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
    stream_executor::dnn::ActivationMode activation) {
  namespace dnn = stream_executor::dnn;
  using LhloActivation = mlir::lmhlo_gpu::Activation;

  switch (activation) {
    case dnn::kNone:
      return LhloActivation::None;
    case dnn::kSigmoid:
      return LhloActivation::Sigmoid;
    case dnn::kRelu:
      return LhloActivation::Relu;
    case dnn::kRelu6:
      return LhloActivation::Relu6;
    case dnn::kReluX:
      return LhloActivation::ReluX;
    case dnn::kTanh:
      return LhloActivation::Tanh;
    case dnn::kBandPass:
      return LhloActivation::BandPass;
    default:
      return xla::InternalError("Unknown activation");
  }
}
// Lowers a cuDNN convolution custom call to the lmhlo_gpu convolution op that
// corresponds to its CudnnConvKind (forward, backward-input, backward-filter,
// or fused forward with/without a side input).
//
// The op is created through CreateOpWithoutAttrs and then decorated with the
// window description, dimension numbers, group counts, precision config,
// result scale and the algorithm/layout backend config. Fused ops additionally
// get the activation mode, and the side-input variant gets the side input
// scale.
//
// Returns an error if the backend config or the convolution kind cannot be
// read, if the activation mode is unknown, if a fused convolution has neither
// 3 nor 4 operands, or if the convolution kind is not one of the handled cases.
StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
    const HloCustomCallInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(
      auto const backend_config,
      custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());

  TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
                      xla::gpu::GetCudnnConvKind(custom_call));

  // Copies a layout's minor-to-major order into a plain vector.
  auto get_layout_attribute = [](const xla::Layout& layout) {
    return std::vector<int64_t>(layout.minor_to_major().begin(),
                                layout.minor_to_major().end());
  };

  // Sets every attribute shared by all convolution flavours. Starts from the
  // op's current attribute dictionary, so attributes set beforehand (e.g. the
  // activation mode of fused ops) are preserved.
  auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
    const xla::Window& window = custom_call->window();
    // Window size for Cudnn Conv is same as the kernel size.
    NamedAttrList attrs(op->getAttrDictionary());
    DenseIntElementsAttr window_strides;
    attrs.set(op.getWindowStridesAttrName(),
              window_strides = GetWindowElements(
                  window, [](const xla::WindowDimension& dim) {
                    return static_cast<int64_t>(dim.stride());
                  }));
    // Cudnn Conv requires low and high padding to be equal.
    attrs.set(op.getPaddingAttrName(),
              GetWindowElements(window, [](const xla::WindowDimension& dim) {
                return static_cast<int64_t>(dim.padding_low());
              }));
    // LHS dilation is encoded in base_dilation of the backend config.
    // RHS dilation is encoded in window_dilation of the backend config.
    attrs.set(op.getLhsDilationAttrName(),
              GetWindowElements(window, [](const xla::WindowDimension& dim) {
                return static_cast<int64_t>(dim.base_dilation());
              }));
    attrs.set(op.getRhsDilationAttrName(),
              GetWindowElements(window, [](const xla::WindowDimension& dim) {
                return static_cast<int64_t>(dim.window_dilation());
              }));
    // Setup window reversal. The i1 tensor reuses the shape of the strides
    // attribute, i.e. one entry per window dimension.
    auto window_reversal = llvm::to_vector<4>(llvm::map_range(
        window.dimensions(),
        [](const xla::WindowDimension& dim) { return dim.window_reversal(); }));
    auto type = RankedTensorType::get(window_strides.getType().getShape(),
                                      builder_.getIntegerType(/*width=*/1));
    attrs.set(op.getWindowReversalAttrName(),
              DenseElementsAttr::get(type, window_reversal));

    attrs.set(op.getDimensionNumbersAttrName(),
              xla::ConvertConvDimensionNumbers(
                  custom_call->convolution_dimension_numbers(), &builder_));
    attrs.set(op.getFeatureGroupCountAttrName(),
              builder_.getI64IntegerAttr(custom_call->feature_group_count()));
    attrs.set(op.getBatchGroupCountAttrName(),
              builder_.getI64IntegerAttr(custom_call->batch_group_count()));
    attrs.set(op.getPrecisionConfigAttrName(),
              xla::ConvertPrecisionConfig(&custom_call->precision_config(),
                                          &builder_));
    attrs.set(op.getResultScaleAttrName(),
              builder_.getF64FloatAttr(backend_config.conv_result_scale()));

    // Split the algorithm's tuning knobs into parallel id/value arrays.
    const auto& algorithm = backend_config.algorithm();
    std::vector<int64_t> knob_ids;
    std::vector<int64_t> knob_values;
    knob_ids.reserve(algorithm.tuning_knobs().size());
    knob_values.reserve(algorithm.tuning_knobs().size());
    for (const auto& entry : algorithm.tuning_knobs()) {
      knob_ids.push_back(entry.first);
      knob_values.push_back(entry.second);
    }

    // A workspace size of -1 marks "not specified" in the emitted attribute.
    auto config = mlir::lmhlo_gpu::ConvolutionBackendConfigAttr::get(
        builder_.getContext(), algorithm.algo_id(),
        algorithm.math_type() ==
            stream_executor::dnn::AlgorithmProto::TENSOR_OP_MATH,
        knob_ids, knob_values, algorithm.is_cudnn_frontend(),
        algorithm.has_workspace_size() ? algorithm.workspace_size().value()
                                       : -1,
        get_layout_attribute(custom_call->operand(0)->shape().layout()),
        get_layout_attribute(custom_call->operand(1)->shape().layout()),
        get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()));
    attrs.set(op.getBackendConfigAttrName(), config);
    op->setAttrs(attrs.getDictionary(op->getContext()));

    return op.getOperation();
  };

  // Translates the backend config's activation mode and stores it on `op`.
  auto set_activation = [&, this](auto op) -> Status {
    auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
        backend_config.activation_mode());
    TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
                        GetLHLOActivation(se_activation));
    auto activation_attr = ::mlir::lmhlo_gpu::ActivationAttr::get(
        getLocation(custom_call).getContext(), activation);
    op.setActivationModeAttr(activation_attr);
    return ::tensorflow::OkStatus();
  };

  switch (kind) {
    case xla::gpu::CudnnConvKind::kForward: {
      TF_ASSIGN_OR_RETURN(
          auto cnn_forward,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
      return set_common_conv_attributes(cnn_forward);
    }
    case xla::gpu::CudnnConvKind::kBackwardInput: {
      TF_ASSIGN_OR_RETURN(
          auto cnn_backward,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
      return set_common_conv_attributes(cnn_backward);
    }
    case xla::gpu::CudnnConvKind::kBackwardFilter: {
      TF_ASSIGN_OR_RETURN(
          auto cnn_backward,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
      return set_common_conv_attributes(cnn_backward);
    }
    case xla::gpu::CudnnConvKind::kForwardActivation: {
      // Fused conv can be either with side input or without.
      if (custom_call->operand_count() == 3) {
        TF_ASSIGN_OR_RETURN(
            auto cnn_fused,
            CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
        TF_RETURN_IF_ERROR(set_activation(cnn_fused));
        return set_common_conv_attributes(cnn_fused);
      }

      TF_RET_CHECK(custom_call->operand_count() == 4);
      TF_ASSIGN_OR_RETURN(
          auto cnn_fused_side_input,
          CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
              custom_call));
      cnn_fused_side_input.setSideInputScaleAttr(
          builder_.getF64FloatAttr(backend_config.side_input_scale()));
      TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
      return set_common_conv_attributes(cnn_fused_side_input);
    }
  }

  // The switch has no default, so a kind outside the handled enumerators would
  // otherwise flow off the end of this non-void function (undefined
  // behaviour). Report it as an error instead.
  return xla::InternalError("Unknown convolution kind");
}
// Converts an XLA HLO constant to a memref.global + memref.get_global pair.
//
// The result is cached in `slices_` under (instr, empty shape index): when a
// value is already cached for the constant, the op defining that value is
// returned and nothing new is emitted.
//
// Otherwise a private, constant memref.global initialised from the literal is
// inserted at the front of the module, annotated with "lmhlo.alloc" (the
// index of the function argument that holds the constant's buffer
// allocation), and a memref.get_global referring to it is created at the
// current insertion point.
//
// Returns an error if the constant is not a static array, if type/literal
// conversion fails, if the constant does not own a whole buffer allocation,
// or if no function argument is known for that allocation.
StatusOr<mlir::memref::GetGlobalOp> LhloDialectEmitter::EmitConstant(
    const HloInstruction* instr) {
  auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())];
  if (cached_value) {
    return dyn_cast<mlir::memref::GetGlobalOp>(cached_value.getDefiningOp());
  }

  Location loc = getLocation(instr);
  auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr);
  TF_RET_CHECK(const_instr->shape().IsArray() &&
               const_instr->shape().is_static());
  TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>(
                                     const_instr->shape(), builder_));
  auto memref_type = type.dyn_cast<MemRefType>();
  TF_RET_CHECK(memref_type != nullptr);

  TF_ASSIGN_OR_RETURN(
      DenseElementsAttr initial_value,
      CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));

  // Validate the buffer assignment before touching the module, so that a
  // failure does not leave a half-annotated global behind.
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                      assignment_.GetUniqueTopLevelSlice(instr));
  TF_RET_CHECK(slice.offset() == 0)
      << "Each constant should have its own allocation from BufferAssignment";
  TF_RET_CHECK(slice.allocation()->size() == slice.size())
      << "Each constant should have its own allocation from BufferAssignment";
  // An unchecked find()->second would dereference end() for an allocation
  // that has no function argument.
  auto alloc_it = allocations_.find(slice.allocation());
  TF_RET_CHECK(alloc_it != allocations_.end())
      << "No function argument found for the allocation of constant "
      << instr->name();

  std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName(
      xla::llvm_ir::SanitizeConstantName(instr->name()));

  // Insert the global memref at the top level.
  {
    // The guard restores the builder's insertion point when the scope ends.
    OpBuilder::InsertionGuard guard(builder_);
    builder_.clearInsertionPoint();
    auto global_var = builder_.create<memref::GlobalOp>(
        loc, constant_name, builder_.getStringAttr("private"), memref_type,
        initial_value, true, /*alignment=*/IntegerAttr());
    SymbolTable(module_).insert(global_var);
    global_var.getOperation()->moveBefore(&module_.front());

    // For operations that do not fold this constant value in their codegen, we
    // still need to materialize it into a buffer. Since buffer allocation is
    // already done, annotate the global_memref with the information to get to
    // the allocated buffer slice for this constant if need be.
    global_var->setAttr(
        "lmhlo.alloc",
        builder_.getIndexAttr(
            alloc_it->second.cast<BlockArgument>().getArgNumber()));
  }

  auto get_global_memref =
      builder_.create<memref::GetGlobalOp>(loc, memref_type, constant_name);

  // Update the cache to remember this value.
  cached_value = get_global_memref;
  return get_global_memref;
}
namespace {

// Copies the channel id of `instr` onto `op` as a ChannelHandleAttr. Leaves
// `op` untouched when the instruction carries no channel id.
template <typename OpT>
void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
                             mlir::Builder builder) {
  if (!instr->channel_id().has_value()) return;

  auto channel_handle = mlir::mhlo::ChannelHandleAttr::get(
      builder.getContext(), *instr->channel_id(), 0);
  op.setChannelIdAttr(channel_handle);
}

// Sets the attributes shared by collective ops on `op`: replica groups,
// constrain_layout and (when present) the channel id. `instr` must be an
// HloCollectiveInstruction. Always returns OK.
template <typename OpT>
Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
                                         mlir::OpBuilder& builder) {
  auto* hlo_collective = xla::Cast<xla::HloCollectiveInstruction>(instr);

  // The importer hands back a (name, value) pair; attach it under that name.
  auto replica_groups = xla::HloFunctionImporter::ConvertReplicaGroups(
      hlo_collective->replica_groups(), &builder);
  op->setAttr(replica_groups.getName(), replica_groups.getValue());

  auto constrain_layout =
      builder.getBoolAttr(hlo_collective->constrain_layout());
  op.setConstrainLayoutAttr(constrain_layout);

  SetupChannelIdAttribute(op, hlo_collective, builder);
  return ::tensorflow::OkStatus();
}

}  // namespace
// Emits an lmhlo::AllToAllOp for `instr`, carrying over the common collective
// attributes and, when the HLO instruction specifies one, the split dimension.
StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op, CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
  auto* hlo_all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
  TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(op, instr, builder_));

  // The split dimension is optional; only forward it when it is set.
  const auto& split_dimension = hlo_all_to_all->split_dimension();
  if (!split_dimension.has_value()) {
    return op;
  }
  op.setSplitDimensionAttr(builder_.getI64IntegerAttr(*split_dimension));
  return op;
}
// Emits an lmhlo::AllGatherOp for `instr` with the common collective
// attributes plus use_global_device_ids and the all-gather dimension.
StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
  auto* hlo_all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
  TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(op, instr, builder_));

  auto use_global_ids =
      builder_.getBoolAttr(hlo_all_gather->use_global_device_ids());
  op.setUseGlobalDeviceIdsAttr(use_global_ids);

  auto gather_dimension =
      builder_.getI64IntegerAttr(hlo_all_gather->all_gather_dimension());
  op.setAllGatherDimensionAttr(gather_dimension);
  return op;
}
// Emits an lmhlo::AllReduceOp for `instr`: common collective attributes,
// use_global_device_ids, and the first called computation imported into the
// op's computation region.
StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
  auto* hlo_all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
  TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(op, instr, builder_));

  auto use_global_ids =
      builder_.getBoolAttr(hlo_all_reduce->use_global_device_ids());
  op.setUseGlobalDeviceIdsAttr(use_global_ids);

  // Import the reduction computation as the op's region.
  const auto& reduction = *instr->called_computations()[0];
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      reduction, &op.getComputation(), &builder_));
  return op;
}
// Emits an lmhlo_gpu::AllReduceStartOp for `instr`. The op takes views of all
// HLO operands followed by views of the instruction's results, and produces a
// single token. The created op is recorded in `all_reduce_start_ops_` so that
// the matching all-reduce-done can find it; lowering the same instruction
// twice is an error.
StatusOr<lmhlo_gpu::AllReduceStartOp> LhloDialectEmitter::EmitAllReduceStartOp(
    const HloInstruction* instr) {
  // Inputs first, then the output buffers.
  llvm::SmallVector<Value, 4> buffers;
  for (const HloInstruction* hlo_operand : instr->operands()) {
    TF_RETURN_IF_ERROR(GetOrCreateView(hlo_operand, &buffers));
  }
  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &buffers, /*result_subset=*/{}));

  Location loc = getLocation(instr);
  std::array<mlir::Type, 1> result_types = {
      mlir::mhlo::TokenType::get(builder_.getContext())};
  lmhlo_gpu::AllReduceStartOp start_op =
      builder_.create<lmhlo_gpu::AllReduceStartOp>(loc, result_types, buffers);

  auto* hlo_all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
  TF_RETURN_IF_ERROR(
      SetupCommonCollectiveOpAttributes(start_op, instr, builder_));
  start_op.setUseGlobalDeviceIdsAttr(
      builder_.getBoolAttr(hlo_all_reduce->use_global_device_ids()));

  // Import the reduction computation as the op's region.
  const auto& reduction = *instr->called_computations()[0];
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      reduction, &start_op.getComputation(), &builder_));

  const bool newly_recorded =
      all_reduce_start_ops_.emplace(instr, start_op).second;
  TF_RET_CHECK(newly_recorded) << "all-reduce-start already lowered";
  return start_op;
}
// Emits an lmhlo_gpu::AllReduceDoneOp for `instr`. Its first operand is the
// token produced by the all-reduce-start op recorded for `instr->operand(0)`;
// the remaining operands are views of the HLO operands. The recorded start op
// is removed from `all_reduce_start_ops_` once its token has been taken.
StatusOr<lmhlo_gpu::AllReduceDoneOp> LhloDialectEmitter::EmitAllReduceDoneOp(
    const HloInstruction* instr) {
  auto start_it = all_reduce_start_ops_.find(instr->operand(0));
  TF_RET_CHECK(start_it != all_reduce_start_ops_.end())
      << "didn't find all-reduce-start op";

  // Grab the token before dropping the bookkeeping entry.
  llvm::SmallVector<Value, 4> done_operands;
  done_operands.push_back(start_it->second.getToken());
  all_reduce_start_ops_.erase(start_it);

  for (const HloInstruction* hlo_operand : instr->operands()) {
    TF_RETURN_IF_ERROR(GetOrCreateView(hlo_operand, &done_operands));
  }

  // We don't need to add buffers for the outputs, as these always alias inputs.
  Location loc = getLocation(instr);
  return builder_.create<lmhlo_gpu::AllReduceDoneOp>(
      loc, /*resultTypes=*/llvm::None, done_operands);
}
// Emits an lmhlo::ReduceScatterOp for `instr`: common collective attributes,
// use_global_device_ids, the first called computation imported as the op's
// region, and the scatter dimension.
StatusOr<lmhlo::ReduceScatterOp> LhloDialectEmitter::EmitReduceScatterOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo::ReduceScatterOp>(instr));
  auto* hlo_reduce_scatter =
      xla::Cast<xla::HloReduceScatterInstruction>(instr);
  TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(op, instr, builder_));

  auto use_global_ids =
      builder_.getBoolAttr(hlo_reduce_scatter->use_global_device_ids());
  op.setUseGlobalDeviceIdsAttr(use_global_ids);

  // Import the reduction computation as the op's region.
  const auto& reduction = *instr->called_computations()[0];
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      reduction, &op.getComputation(), &builder_));

  auto scatter_dimension =
      builder_.getI64IntegerAttr(hlo_reduce_scatter->scatter_dimension());
  op.setScatterDimensionAttr(scatter_dimension);
  return op;
}
// Emits an lmhlo::CollectivePermuteOp for `instr`, carrying over the channel
// id (when present) and the source/target pairs.
StatusOr<lmhlo::CollectivePermuteOp>
LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
  auto* hlo_permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
  SetupChannelIdAttribute(op, hlo_permute, builder_);

  // The importer hands back a (name, value) pair; attach it under that name.
  mlir::NamedAttribute pairs =
      xla::HloFunctionImporter::ConvertSourceTargetPairs(
          hlo_permute->source_target_pairs(), &builder_);
  op->setAttr(pairs.getName(), pairs.getValue());
  return op;
}
// Emits an lmhlo::InfeedOp for `instr` and copies over the infeed config
// string.
StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
    const HloInstruction* instr) {
  const HloInfeedInstruction* hlo_infeed =
      xla::Cast<HloInfeedInstruction>(instr);

  // The HLO infeed takes a single token operand and produces a tuple of
  // (buffers, token). LMHLO infeed needs neither the token operand nor the
  // token result, so only views of result element {0} become operands.
  SmallVector<Value, 2> buffers;
  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &buffers, /*result_subset=*/{0}));

  auto op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, buffers);
  auto config = builder_.getStringAttr(hlo_infeed->infeed_config());
  op.setConfigAttr(config);
  return op;
}
// Emits an lmhlo::OutfeedOp for `instr` and copies over the outfeed config
// string.
StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
    const HloInstruction* instr) {
  const HloOutfeedInstruction* hlo_outfeed =
      xla::Cast<HloOutfeedInstruction>(instr);

  // The HLO outfeed has two operands (the source and a token) and a single
  // token result. LMHLO outfeed needs neither the token operand nor the
  // result, so drop them and keep only views of the source operand.
  SmallVector<Value, 2> buffers;
  TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &buffers));

  auto op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, buffers);
  auto config = builder_.getStringAttr(hlo_outfeed->outfeed_config());
  op.setConfigAttr(config);
  return op;
}
// Emits an lmhlo::RngGetAndUpdateStateOp for `instr` and copies the delta of
// the HLO instruction onto it as an i64 attribute.
xla::StatusOr<lmhlo::RngGetAndUpdateStateOp>
LhloDialectEmitter::EmitRngGetAndUpdateStateOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(
      auto op, CreateOpWithoutAttrs<lmhlo::RngGetAndUpdateStateOp>(instr));
  auto* hlo_rng = xla::Cast<xla::HloRngGetAndUpdateStateInstruction>(instr);
  auto delta = builder_.getI64IntegerAttr(hlo_rng->delta());
  op.setDeltaAttr(delta);
  return op;
}
// Emits an lmhlo::FftOp for `instr`, setting the FFT type and FFT length.
// Fails if the op cannot be created or the HLO FFT type cannot be converted.
xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
    const HloInstruction* instr) {
  auto* hlo_fft = xla::Cast<xla::HloFftInstruction>(instr);
  TF_ASSIGN_OR_RETURN(auto op, CreateOpWithoutAttrs<lmhlo::FftOp>(instr));

  TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType converted_type,
                      xla::ConvertFftType(hlo_fft->fft_type()));
  auto type_attr =
      mlir::mhlo::FftTypeAttr::get(builder_.getContext(), converted_type);
  op.setFftTypeAttr(type_attr);

  auto length_attr = GetI64DenseElementsAttr(instr->fft_length());
  op.setFftLengthAttr(length_attr);
  return op;
}
// Emits an lmhlo::TriangularSolveOp for `instr`. The solve options
// (left_side, lower, unit_diagonal, transpose_a) and the layouts of both
// operands and of the result are attached as attributes. Fails if the op
// cannot be created or the transpose option cannot be converted.
xla::StatusOr<lmhlo::TriangularSolveOp>
LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
  auto* hlo_solve = xla::Cast<xla::HloTriangularSolveInstruction>(instr);
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));

  // Boolean solve options.
  const xla::TriangularSolveOptions& opts =
      hlo_solve->triangular_solve_options();
  op.setLeftSideAttr(builder_.getBoolAttr(opts.left_side()));
  op.setLowerAttr(builder_.getBoolAttr(opts.lower()));
  op.setUnitDiagonalAttr(builder_.getBoolAttr(opts.unit_diagonal()));

  // Transpose mode of operand A.
  TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose_a,
                      xla::ConvertTranspose(opts.transpose_a()));
  auto transpose_attr =
      mlir::mhlo::TransposeAttr::get(builder_.getContext(), transpose_a);
  op.setTransposeAAttr(transpose_attr);

  // Layouts of A, B and the output.
  const xla::Layout& layout_a = instr->operand(0)->shape().layout();
  const xla::Layout& layout_b = instr->operand(1)->shape().layout();
  const xla::Layout& layout_out = instr->shape().layout();
  op.setLayoutAAttr(GetLayoutAttribute(layout_a, &builder_));
  op.setLayoutBAttr(GetLayoutAttribute(layout_b, &builder_));
  op.setLayoutOutputAttr(GetLayoutAttribute(layout_out, &builder_));
  return op;
}
// Handles an HLO bitcast. No operation is emitted: the function only verifies
// that buffer assignment gave the bitcast's operand and its result the same
// top-level slice, and returns a null Operation* on success.
xla::StatusOr<Operation*> LhloDialectEmitter::EmitBitcast(
    const xla::HloInstruction* instr) {
  // XLA buffer assignment should assign the same slice to a bitcast input and
  // output.
  const xla::ShapeIndex top_level;
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
                      assignment_.GetUniqueSlice(instr, top_level));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice operand_slice,
                      assignment_.GetUniqueSlice(instr->operand(0), top_level));

  const bool slices_differ = operand_slice != output_slice;
  if (slices_differ) {
    return xla::InvalidArgument(
        "Bitcast input and result slice should be same");
  }
  return nullptr;
}
// Builds an index-typed dense tensor attribute holding the minor-to-major
// dimension order of `layout`.
mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
    const xla::Layout& layout, Builder* builder) {
  const auto& order = layout.minor_to_major();
  llvm::SmallVector<int64_t, 4> minor_to_major(order.begin(), order.end());
  return builder->getIndexTensorAttr(minor_to_major);
}
// Emits the instructions of `computation`, in their sequential order, into
// `region` and closes the region with an lmhlo::TerminatorOp. The builder's
// insertion point is redirected into `region` for the duration of the call and
// restored on every exit path. Fails when the computation has no sequential
// order or when visiting an instruction fails.
Status LhloDialectEmitter::ImportAsLmhloRegion(xla::HloComputation* computation,
                                               mlir::Region* region) {
  // Remember where we were emitting and come back there when leaving.
  auto saved_insertion_point = builder_.saveInsertionPoint();
  auto restore_insertion_point =
      absl::MakeCleanup([this, saved_insertion_point] {
        builder_.restoreInsertionPoint(saved_insertion_point);
      });

  builder_ = OpBuilder(region);

  const xla::HloInstructionSequence* sequence =
      assignment_.hlo_ordering().SequentialOrder(*computation);
  if (sequence == nullptr) {
    return xla::Unimplemented("Missing sequential order for the computation");
  }

  TF_RETURN_IF_ERROR(
      computation->AcceptOrdered(this, sequence->instructions()));

  builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
  return ::tensorflow::OkStatus();
}
// Emits an lmhlo::CaseOp for `instr`. The first collected operand is used as
// the op's operand, and each branch computation is imported into the
// corresponding branch region.
StatusOr<lmhlo::CaseOp> LhloDialectEmitter::EmitCaseOp(
    const HloInstruction* instr) {
  Location loc = getLocation(instr);

  llvm::SmallVector<Value, 4> operands;
  size_t num_arguments = 0;
  size_t num_results = 0;
  TF_RETURN_IF_ERROR(CreateOperands(instr, 1, TokenLoweringMode::kUseNull,
                                    operands, num_arguments, num_results));

  const int num_branches = instr->branch_count();
  auto case_op = builder_.create<lmhlo::CaseOp>(loc, operands[0], num_branches);

  for (int branch = 0; branch < num_branches; ++branch) {
    // Every region needs an entry block before instructions can be emitted
    // into it.
    auto& branch_region = case_op.getBranches()[branch];
    branch_region.push_back(new mlir::Block());
    TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[branch],
                                           &branch_region));
  }
  return case_op;
}
// Emits an lmhlo::WhileOp for `instr`. The op's operand is the view of the
// root of called computation #1 (imported as the condition region); called
// computation #0 is imported as the body region. When the backend config
// carries a known trip count, it is attached; otherwise a null attribute is
// passed.
xla::StatusOr<lmhlo::WhileOp> LhloDialectEmitter::EmitWhileOp(
    const xla::HloInstruction* instr) {
  Location loc = getLocation(instr);
  auto* const body_computation = instr->called_computations()[0];
  auto* const cond_computation = instr->called_computations()[1];

  // The condition's root must map to exactly one buffer view.
  SmallVector<Value, 1> cond_result;
  TF_RETURN_IF_ERROR(
      GetOrCreateView(cond_computation->root_instruction(), &cond_result));
  TF_RET_CHECK(cond_result.size() == 1);

  TF_ASSIGN_OR_RETURN(auto loop_config,
                      instr->backend_config<xla::WhileLoopBackendConfig>());
  mlir::IntegerAttr trip_count =
      loop_config.has_known_trip_count()
          ? builder_.getI64IntegerAttr(loop_config.known_trip_count().n())
          : mlir::IntegerAttr();

  lmhlo::WhileOp while_op =
      builder_.create<lmhlo::WhileOp>(loc, cond_result[0], trip_count);

  // Give both regions an entry block, then fill condition first, body second.
  while_op.getCond().push_back(new mlir::Block());
  while_op.getBody().push_back(new mlir::Block());
  TF_RETURN_IF_ERROR(
      ImportAsLmhloRegion(cond_computation, &while_op.getCond()));
  TF_RETURN_IF_ERROR(
      ImportAsLmhloRegion(body_computation, &while_op.getBody()));

  return while_op;
}
// Returns a memref value viewing the buffer slice assigned to the array at
// `shape_index` inside the result of `instr`, creating it on first use.
//
// Constants at the top level are delegated to EmitConstant. For everything
// else the value is cached in `slices_` by (instr, shape_index). A new view is
// built as a memref::ViewOp over the allocation's function argument at the
// slice's byte offset, followed by a memref::ReinterpretCastOp when the
// requested layout differs from the physical (descending) one.
//
// `current_shape` is the shape of the array at `shape_index`; dynamic
// dimensions are replaced by their static bounds.
//
// Returns an error if type conversion or slice lookup fails, if no function
// argument is known for the slice's allocation, or if strides/offset cannot be
// derived for the output type.
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
    const xla::HloInstruction* instr, const xla::Shape& current_shape,
    const xla::ShapeIndex& shape_index) {
  // For constants, the cache is managed inside EmitConstant since it can
  // be called either from here or when we see a top-level HloConstant instr.
  if (instr->IsConstant() && shape_index.empty()) {
    TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
    return constant_memref;
  }

  // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
  // shape_index).
  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
  if (cached_value) {
    return cached_value;
  }

  // If the shape happens to have dynamic dimensions, create the memref using
  // the underlying static shape.
  // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
  // but static bounds in MLIR.
  const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape);

  TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>(
                                         static_shape, builder_));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                      assignment_.GetUniqueSlice(instr, shape_index));

  // Look the allocation up without inserting. operator[] would default-insert
  // a null Value for an unknown allocation, and the getLoc() call below would
  // then dereference it.
  auto alloc_it = allocations_.find(slice.allocation());
  TF_RET_CHECK(alloc_it != allocations_.end() && alloc_it->second)
      << "No function argument found for the allocation backing "
      << instr->name();
  Value alloc = alloc_it->second;

  // TODO(timshen): revisit location handling.
  Location loc = builder_.getUnknownLoc();

  Value byte_shift =
      builder_.create<arith::ConstantIndexOp>(alloc.getLoc(), slice.offset());

  xla::Shape physical_shape =
      xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
          static_shape);
  TF_ASSIGN_OR_RETURN(
      Type physical_out_type,
      xla::ConvertShapeToType<MemRefType>(physical_shape, builder_));

  // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp
  // produce the physical shape (where dimensions are ordered in major to
  // minor) first, then follow up with a MemRefReinterpretCast to cast the
  // resulting memref to the original layout.
  Value result =
      builder_.create<memref::ViewOp>(loc, physical_out_type, alloc, byte_shift,
                                      /*sizes=*/ValueRange{});
  if (result.getType() != out_type) {
    int64_t out_offset;
    SmallVector<int64_t, 4> out_strides;
    auto out_memref_type = out_type.dyn_cast<MemRefType>();
    if (!out_memref_type)
      return tensorflow::errors::Internal(
          "Expected memref type when creating a view for leaf type of a "
          "tuple.");
    if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
      return tensorflow::errors::Internal(
          "Failed to get strides and offset from the output type.");
    result = builder_.create<memref::ReinterpretCastOp>(
        loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
        out_strides);
  }
  return cached_value = result;
}
// Recursively appends to `values` one entry per leaf of `current_shape`,
// visiting tuple elements in order.
//
// `current_shape_index` is the position of `current_shape` within the result
// of `instr`; it is extended while descending into tuple elements and is back
// to its incoming value when a tuple has been processed successfully. Array
// leaves contribute their buffer view. Token leaves either fail
// (kFailToLower) or contribute a null Value (kUseNull). Any other shape kind
// is an internal error.
Status LhloDialectEmitter::GetOrCreateViewImpl(
    const HloInstruction* instr, const Shape& current_shape,
    xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values,
    TokenLoweringMode token_mode) {
  if (current_shape.IsTuple()) {
    const int num_elements =
        static_cast<int>(current_shape.tuple_shapes().size());
    for (int element = 0; element < num_elements; ++element) {
      current_shape_index->push_back(element);
      TF_RETURN_IF_ERROR(
          GetOrCreateViewImpl(instr, current_shape.tuple_shapes(element),
                              current_shape_index, values, token_mode));
      current_shape_index->pop_back();
    }
    return ::tensorflow::OkStatus();
  }

  if (current_shape.IsArray()) {
    TF_ASSIGN_OR_RETURN(
        auto view,
        GetOrCreateArrayView(instr, current_shape, *current_shape_index));
    values->push_back(view);
    return ::tensorflow::OkStatus();
  }

  if (current_shape.IsToken()) {
    if (token_mode == TokenLoweringMode::kFailToLower) {
      return xla::InternalError(
          "Unexpected token kind for %s and shape index %s",
          instr->ToString(), current_shape_index->ToString());
    }
    if (token_mode == TokenLoweringMode::kUseNull) {
      values->push_back(Value{});
      return ::tensorflow::OkStatus();
    }
  }

  return xla::InternalError("Unexpected shape kind for %s and shape index %s",
                            instr->ToString(), current_shape_index->ToString());
}
// Appends to `values` the views for the result of `instr`, restricted to the
// sub-shape selected by `result_subset`.
// Each view starts from the slice in the allocation and may need a further
// cast so that it matches the shape of the instruction. `token_mode` decides
// how token-typed leaves are handled.
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
                                           SmallVectorImpl<Value>* values,
                                           const xla::ShapeIndex& result_subset,
                                           TokenLoweringMode token_mode) {
  // The recursive helper mutates the index while walking tuples, so hand it a
  // private copy.
  xla::ShapeIndex current_index(result_subset);
  const Shape& selected_shape =
      xla::ShapeUtil::GetSubshape(instr->shape(), current_index);
  return GetOrCreateViewImpl(instr, selected_shape, &current_index, values,
                             token_mode);
}
// Prepares the module for emission of the entry computation.
//
// Creates the function that will hold the lowered computation, adds one i8
// memref argument per non-thread-local buffer allocation (recording the
// allocation -> argument mapping in `allocations_`), annotates the arguments
// with the lmhlo.* attributes derived from the buffer assignment, inserts the
// function into the module, and leaves `builder_` positioned right before the
// function's terminator.
//
// Fails if `computation_` is not the entry computation, or if an output
// sub-shape does not own a whole buffer allocation.
Status LhloDialectEmitter::Initialize() {
  TF_RET_CHECK(computation_.IsEntryComputation());

  mlir::IntegerAttr unique_id =
      builder_.getI32IntegerAttr(computation_.parent()->unique_id());
  module_->setAttr("hlo.unique_id", unique_id);
  // Fall back to a fixed name when the computation is unnamed.
  std::string function_name =
      computation_.name().empty() ? "__compute" : computation_.name();

  // Create the function as () -> (), we'll compute the arguments from the
  // buffer allocation and update the type then.
  auto func_op = func::FuncOp::create(builder_.getUnknownLoc(), function_name,
                                      builder_.getFunctionType({}, {}));

  {
    // This is an optional attribute used by the XLA backend. If the resulting
    // LMHLO doesn't go through XLA, this is not needed.
    const Shape& shape = computation_.root_instruction()->shape();
    func_op->setAttr(
        "result_xla_shape",
        builder_.getStringAttr(shape.ToString(/*print_layout=*/true)));
  }
  Block* block = func_op.addEntryBlock();

  // Collect pointers to all allocations so they can be reordered below.
  llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
  for (const BufferAllocation& alloc : assignment_.Allocations())
    ordered_allocations.push_back(&alloc);

  // NOTE(review): given the TF_RET_CHECK at the top of this function, this
  // condition is always true when reached.
  if (computation_.IsEntryComputation()) {
    // Sort the rather arbitrarily ordered allocations to match the input/output
    // parameters. Specifically we want to sort buffer allocations in the
    // following order:
    // * Parameters always order before non-parameters.
    // * Different parameters order by parameter number.
    // * Different allocations for the same parameter order by the shape index.
    //
    // TODO(timshen): there should be only one non-parameter buffer, the temp
    // buffer. Check on that.
    const auto allocation_comparator = [](const BufferAllocation* lhs,
                                          const BufferAllocation* rhs) {
      if (lhs->is_entry_computation_parameter() !=
          rhs->is_entry_computation_parameter()) {
        return lhs->is_entry_computation_parameter() >
               rhs->is_entry_computation_parameter();
      }
      if (lhs->is_entry_computation_parameter()) {
        return std::tuple<int, const xla::ShapeIndex&>(
                   lhs->parameter_number(), lhs->param_shape_index()) <
               std::tuple<int, const xla::ShapeIndex&>(
                   rhs->parameter_number(), rhs->param_shape_index());
      }
      // Two non-parameter allocations compare equal; the stable sort keeps
      // their original relative order.
      return false;
    };

    std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
                     allocation_comparator);
  }

  // Maps each allocation that backs part of the root instruction's result to
  // the sub-shape and shape index it holds.
  absl::flat_hash_map<const BufferAllocation*,
                      std::pair<const Shape*, xla::ShapeIndex>>
      allocation_to_output_info;
  TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
      computation_.root_instruction()->shape(),
      [&](const Shape& sub_shape, xla::ShapeIndex index) -> Status {
        TF_ASSIGN_OR_RETURN(
            auto slice,
            assignment_.GetUniqueSlice(computation_.root_instruction(), index));
        const BufferAllocation* alloc = slice.allocation();
        // Each output sub-shape must cover its allocation completely.
        TF_RET_CHECK(slice.offset() == 0);
        TF_RET_CHECK(slice.size() == alloc->size());
        allocation_to_output_info[alloc] = std::make_pair(&sub_shape, index);
        return ::tensorflow::OkStatus();
      }));

  // The function signature will be composed of:
  // - one memref for each of the parameters.
  // - one memref for each other buffer allocation.
  llvm::SmallVector<DictionaryAttr, 8> args_attrs;
  for (const BufferAllocation* alloc : ordered_allocations) {
    // Thread-local allocations do not become function arguments.
    if (alloc->is_thread_local()) {
      continue;
    }

    // There are optional attributes to help the program run through XLA. XLA
    // defines ExecutionInput and ExecutionOutput structures to carry
    // input-output type and buffer information, therefore any information they
    // need (mainly the type structure, potentially containing tuples) needs to
    // be preserved. They are not needed if the generated LMHLO is not sent to
    // XLA.
    NamedAttrList arg_attr_list;
    // Every allocation is exposed as a flat i8 memref of the allocation's size.
    mlir::Type arg_type = MemRefType::get({alloc->size()}, i8_type_);

    // Propagate source location information for every HLOInstruction that
    // uses this allocation.
    std::vector<mlir::Location> buf_locs;
    buf_locs.reserve(alloc->assigned_buffers().size());
    for (const auto& entry : alloc->assigned_buffers()) {
      const xla::HloValue* hlo_value = entry.first;
      buf_locs.push_back(getLocation(hlo_value->instruction()));
    }
    mlir::Location loc = builder_.getFusedLoc(buf_locs);

    if (alloc->is_entry_computation_parameter()) {
      arg_attr_list.set("lmhlo.params",
                        builder_.getIndexAttr(alloc->parameter_number()));
      // The shape index is only recorded when it is non-empty.
      if (!alloc->param_shape_index().empty()) {
        arg_attr_list.set("lmhlo.param_shape_index",
                          builder_.getI64TensorAttr(llvm::makeArrayRef(
                              alloc->param_shape_index().begin(),
                              alloc->param_shape_index().end())));
      }
    }
    // Optional: an attribute for optimization. If a kernel uses this
    // allocation, but the allocation has lmhlo.constant_name, then the kernel
    // will instead use the global value indicated by the name for potentially
    // more optimizations (e.g. constant propagation).
    if (alloc->is_constant()) {
      arg_attr_list.set(
          "lmhlo.constant_name",
          builder_.getStringAttr(
              xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc)));
    }
    auto iter = allocation_to_output_info.find(alloc);
    if (iter != allocation_to_output_info.end()) {
      const Shape* sub_shape = iter->second.first;
      const xla::ShapeIndex& shape_index = iter->second.second;
      // An allocation recorded for a non-array output sub-shape is skipped
      // entirely: no function argument is added for it.
      if (!sub_shape->IsArray()) {
        continue;
      }
      arg_attr_list.set("lmhlo.output_index",
                        builder_.getI64TensorAttr(llvm::makeArrayRef(
                            shape_index.begin(), shape_index.end())));
      // Mark outputs that the module's alias config requires to alias a
      // parameter.
      if (auto alias = computation_.parent()
                           ->input_output_alias_config()
                           .GetAliasedParameter(shape_index)) {
        if (alias->must_alias()) {
          arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr());
        }
      }
    }
    // Add the argument and remember which argument backs this allocation.
    block->addArgument(arg_type, loc);
    allocations_[alloc] = block->getArguments().back();
    args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
  }

  // Now that all arguments exist, give the function its real type and
  // per-argument attributes.
  FunctionType function_type =
      builder_.getFunctionType(block->getArgumentTypes(), {});
  func_op.setType(function_type);
  func_op.setAllArgAttrs(args_attrs);

  SymbolTable symbol_table(module_);
  symbol_table.insert(func_op);
  builder_.setInsertionPointToEnd(block);

  // Terminate the function body, then re-point the builder at the terminator
  // so that subsequently created ops are placed before it.
  auto return_op =
      builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
  builder_ = OpBuilder(return_op);

  return ::tensorflow::OkStatus();
}
// Creates a new XlaHloToLhloPass instance and returns it as a module pass.
std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<XlaHloToLhloPass>();
  return pass;
}
// Populates `module` with the LMHLO lowering of the entry computation of
// `hlo_module`, using `assignment` for buffer information.
//
// Loads the dialects the emitter produces, names the module after the HLO
// module, records the HLO module's unique id, and then emits every instruction
// of the entry computation in its sequential order. MLIR diagnostics raised
// during emission or by the final verification are reported through the
// returned Status.
Status HloToLhloModule(const BufferAssignment& assignment,
                       const HloModule& hlo_module, ModuleOp module) {
  MLIRContext* context = module.getContext();
  context->loadDialect<arith::ArithmeticDialect,
                       bufferization::BufferizationDialect, func::FuncDialect,
                       memref::MemRefDialect, mhlo::MhloDialect,
                       lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();

  auto module_name = mlir::StringAttr::get(context, hlo_module.name());
  module->setLoc(mlir::NameLoc::get(module_name));

  // Store the HloModule's unique_id in the MLIR module.
  Builder attr_builder(context);
  module->setAttr("mhlo.unique_id",
                  attr_builder.getI64IntegerAttr(hlo_module.unique_id()));

  const HloComputation* entry = hlo_module.entry_computation();

  LhloDialectEmitter emitter(assignment, *entry, module);
  TF_RETURN_IF_ERROR(emitter.Initialize());

  const xla::HloInstructionSequence* sequence =
      assignment.hlo_ordering().SequentialOrder(*entry);
  if (sequence == nullptr) {
    return xla::Unimplemented("Missing sequential order for the computation");
  }

  // From here on, MLIR diagnostics are collected by the handler.
  StatusScopedDiagnosticHandler diagnostics(context);

  const std::vector<HloInstruction*>& instructions = sequence->instructions();
  TF_RETURN_IF_ERROR(entry->AcceptOrdered(&emitter, instructions));
  TF_RETURN_IF_ERROR(diagnostics.ConsumeStatus());

  // The verifier's own result is dropped on purpose; the status returned is
  // whatever the diagnostic handler holds afterwards.
  (void)mlir::verify(module);
  return diagnostics.ConsumeStatus();
}
// Parses `input` as HLO text and converts it into a freshly created MLIR
// module via OptimizeAndConvertHloToLmhlo, using the "Host" platform name.
// `optimize_xla_hlo` is forwarded to that conversion. Parse or conversion
// failures abort the process through TF_CHECK_OK.
OwningOpRef<mlir::ModuleOp> HloTextToLhloTranslateFunction(
    llvm::StringRef input, MLIRContext* context, bool optimize_xla_hlo) {
  const absl::string_view hlo_text(input.data(), input.size());
  StatusOr<std::unique_ptr<HloModule>> parsed =
      xla::ParseAndReturnUnverifiedModule(hlo_text);
  TF_CHECK_OK(parsed.status());

  OwningOpRef<mlir::ModuleOp> result =
      ModuleOp::create(UnknownLoc::get(context));

  TF_CHECK_OK(OptimizeAndConvertHloToLmhlo(parsed.ConsumeValueOrDie(),
                                           result.get(), "Host",
                                           optimize_xla_hlo));

  return result;
}
// Registers XlaHloToLhloPass by constructing a PassRegistration object for it.
// The object is a function-local static, so it is constructed once, on the
// first call; later calls have no further effect.
void RegisterMhloToLhloWithXlaPass() {
  static PassRegistration<XlaHloToLhloPass> registration;
}
} // namespace mlir