| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" |
| |
| #include <climits> |
| #include <memory> |
| #include <tuple> |
| |
| #include "absl/algorithm/container.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project |
| #include "mlir/IR/AffineExpr.h" // from @llvm-project |
| #include "mlir/IR/AffineMap.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Dialect.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/OpDefinition.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/IR/SymbolTable.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Pass/PassOptions.h" // from @llvm-project |
| #include "mlir/Translation.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h" |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" |
| #include "tensorflow/compiler/mlir/xla/attribute_importer.h" |
| #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" |
| #include "tensorflow/compiler/mlir/xla/hlo_utils.h" |
| #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" |
| #include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" |
| #include "tensorflow/compiler/xla/debug_options_flags.h" |
| #include "tensorflow/compiler/xla/service/backend.h" |
| #include "tensorflow/compiler/xla/service/buffer_assignment.h" |
| #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" |
| #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/service/hlo_parser.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/util.h" |
| |
| using xla::BufferAllocation; |
| using xla::BufferAssignment; |
| using xla::HloComputation; |
| using xla::HloCustomCallInstruction; |
| using xla::HloInfeedInstruction; |
| using xla::HloInstruction; |
| using xla::HloModule; |
| using xla::HloModuleProto; |
| using xla::HloOutfeedInstruction; |
| using xla::HloProto; |
| using xla::Shape; |
| using xla::StatusOr; |
| |
| namespace mlir { |
| namespace { |
| |
| absl::string_view StringRefToView(llvm::StringRef ref) { |
| return {ref.data(), ref.size()}; |
| } |
| |
| StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto( |
| const HloProto& hlo_proto) { |
| const HloModuleProto& module_proto = hlo_proto.hlo_module(); |
| TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config, |
| HloModule::CreateModuleConfigFromProto( |
| module_proto, xla::GetDebugOptionsFromFlags())); |
| return HloModule::CreateFromProto(module_proto, module_config); |
| } |
| |
| // Converts the MLIR `module` from the HLO dialect to the LHLO dialect using |
| // XLA for the given platform. |
| Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module, |
| StringRef platform_name) { |
| auto platform = xla::se::MultiPlatformManager::PlatformWithName( |
| StringRefToView(platform_name)); |
| if (!platform.ok()) { |
| std::string error_msg; |
| llvm::raw_string_ostream os(error_msg); |
| os << "failed to get platform: " << platform.status().ToString() |
| << " (available Platform: "; |
| std::vector<std::string> available_platforms; |
| (void)xla::se::MultiPlatformManager::PlatformsWithFilter( |
| [&](const stream_executor::Platform* p) { |
| available_platforms.push_back(p->Name()); |
| return false; |
| }); |
| llvm::interleaveComma(available_platforms, os); |
| os << ")"; |
| return xla::InvalidArgument("%s", os.str().c_str()); |
| } |
| |
| xla::BackendOptions backend_options; |
| backend_options.set_platform(platform.ValueOrDie()); |
| auto backend_or_err = xla::Backend::CreateBackend(backend_options); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(), |
| "failed to create XLA Backend "); |
| auto backend = std::move(backend_or_err.ValueOrDie()); |
| |
| // Run all HLO passes to produce an optimized module. |
| auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement( |
| std::move(hlo_module), backend->default_stream_executor(), |
| optimize_xla_hlo, {backend->memory_allocator()}); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(), |
| "running XLA pass pipeline"); |
| std::unique_ptr<HloModule> optimized_hlo_module = |
| std::move(std::get<0>(result_or.ValueOrDie())); |
| std::unique_ptr<BufferAssignment> assignment = |
| std::move(std::get<1>(result_or.ValueOrDie())); |
| |
| // Clear the module before populating it back with the result of the |
| // conversion. |
| module.getBody()->clear(); |
| OpBuilder builder(module); |
| module.ensureTerminator(module.getBodyRegion(), builder, module.getLoc()); |
| |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| HloToLhloModule(*assignment, *optimized_hlo_module, module), |
| "converting HLO to LHLO"); |
| |
| return Status::OK(); |
| } |
| |
| // This pass takes an MLIR HLO module, converts it to XLA HLO, runs XLA's HLO |
| // optimization pipeline for the requested platform, and then converts the |
| // result back to MLIR LHLO. |
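| // |
| // The pass is registered below as "xla-hlo-to-lhlo-with-xla" and can be run |
| // from any mlir-opt-style driver that links it in, for example (the driver |
| // name is tool-dependent): |
| //   <opt-driver> --xla-hlo-to-lhlo-with-xla=platform=CUDA input.mlir |
| // The "platform" option defaults to "Host". |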
| class XlaHloToLhloPass |
| : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> { |
| void getDependentDialects(DialectRegistry& registry) const override { |
| registry |
| .insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect, |
| mlir::lmhlo::LmhloDialect, mlir::lmhlo_gpu::LmhloGpuDialect>(); |
| } |
| |
| public: |
| XlaHloToLhloPass() = default; |
| XlaHloToLhloPass(const XlaHloToLhloPass&) {} |
| |
| private: |
| void runOnOperation() final { |
| ModuleOp module = getOperation(); |
| |
| auto status = [&module, this]() -> Status { |
| SymbolTable symbol_table(module); |
| if (!symbol_table.lookup("main")) { |
| return xla::InvalidArgument( |
| "conversion to HLO module failed: missing main()"); |
| } |
| HloProto hlo_proto; |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| ConvertMlirHloToHlo(module, &hlo_proto, |
| /*use_tuple_args=*/false, |
| /*return_tuple=*/false, |
| /*shape_representation_fn=*/nullptr), |
| "conversion to XLA HLO proto failed"); |
| |
| auto statusOrHloModule = HloModuleFromProto(hlo_proto); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(), |
| "parsing HLO proto to HLO module failed"); |
| std::unique_ptr<HloModule> hlo_module = |
| std::move(statusOrHloModule.ValueOrDie()); |
| |
| return ConvertModule(std::move(hlo_module), module, platform_); |
| }(); |
| if (!status.ok()) { |
| module.emitError() << status.ToString(); |
| return signalPassFailure(); |
| } |
| } |
| |
| Option<std::string> platform_{ |
| *this, "platform", |
| llvm::cl::desc("The platform to use for the XLA optimization pipeline."), |
| llvm::cl::init("Host")}; |
| }; |
| |
| } // namespace |
| |
| // Creates MLIR operands corresponding to the operands and results of the XLA |
| // HLO instruction. If `num_operands` is set, only the first `num_operands` |
| // operands of the HLO instruction are considered. |
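| // |
| // For example, an HLO instruction with two array operands and a two-element |
| // tuple result yields `operands` == [operand0, operand1, result0, result1], |
| // with num_arguments == 2 and num_results == 2. |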
| Status LhloDialectEmitter::CreateOperands( |
| const HloInstruction* instr, absl::optional<xla::int64> num_operands, |
| llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments, |
| size_t& num_results) { |
| if (num_operands.value_or(0) > instr->operand_count()) |
| return xla::InvalidArgument("num_operands must be <= operand count"); |
| for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count()); |
| ++i) { |
| TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands)); |
| } |
| num_arguments = operands.size(); |
| TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands)); |
| num_results = operands.size() - num_arguments; |
| return Status::OK(); |
| } |
| |
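| // Creates an op of type `OpType` at the location of `instr`, taking |
| // `operands` as buffer operands and attaching only a "name" attribute derived |
| // from the instruction's name. |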
| template <typename OpType> |
| OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr, |
| ValueRange operands) { |
| Location loc = getLocation(instr); |
| NamedAttribute attrs[] = {{Identifier::get("name", builder_.getContext()), |
| builder_.getStringAttr(instr->name())}}; |
| return builder_.create<OpType>(loc, llvm::None, operands, attrs); |
| } |
| |
| template <typename OpType> |
| StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs( |
| const HloInstruction* instr, size_t& num_arguments, size_t& num_results, |
| absl::optional<xla::int64> num_operands) { |
| llvm::SmallVector<Value, 4> operands; |
| TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands, |
| num_arguments, num_results)); |
| return CreateOpWithoutAttrs<OpType>(instr, operands); |
| } |
| |
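| // Dispatches on the HLO opcode: attribute-free ops are created directly via |
| // CreateOpWithoutAttrs, while ops that carry attributes or regions are |
| // handled by the dedicated Emit* methods below. |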
| StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp( |
| const HloInstruction* instr) { |
| using xla::HloOpcode; |
| switch (instr->opcode()) { |
| case HloOpcode::kAbs: |
| return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr); |
| case HloOpcode::kAdd: |
| return CreateOpWithoutAttrs<lmhlo::AddOp>(instr); |
| case HloOpcode::kAllReduce: |
| return EmitAllReduceOp(instr); |
| case HloOpcode::kAnd: |
| return CreateOpWithoutAttrs<lmhlo::AndOp>(instr); |
| case HloOpcode::kAtan2: |
| return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr); |
| case HloOpcode::kBitcastConvert: |
| return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr); |
| case HloOpcode::kCeil: |
| return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr); |
| case HloOpcode::kCbrt: |
| return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr); |
| case HloOpcode::kClamp: |
| return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr); |
| case HloOpcode::kClz: |
| return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr); |
| case HloOpcode::kCompare: |
| return EmitCompareOp(instr); |
| case HloOpcode::kComplex: |
| return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr); |
| case HloOpcode::kConvert: |
| return CreateOpWithoutAttrs<lmhlo::ConvertOp>(instr); |
| case HloOpcode::kCopy: |
| return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr); |
| case HloOpcode::kCos: |
| return CreateOpWithoutAttrs<lmhlo::CosOp>(instr); |
| case HloOpcode::kDivide: |
| return CreateOpWithoutAttrs<lmhlo::DivOp>(instr); |
| case HloOpcode::kExp: |
| return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr); |
| case HloOpcode::kExpm1: |
| return CreateOpWithoutAttrs<lmhlo::Expm1Op>(instr); |
| case HloOpcode::kFloor: |
| return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr); |
| case HloOpcode::kImag: |
| return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr); |
| case HloOpcode::kInfeed: |
| return EmitInfeedOp(instr); |
| case HloOpcode::kIsFinite: |
| return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr); |
| case HloOpcode::kLog: |
| return CreateOpWithoutAttrs<lmhlo::LogOp>(instr); |
| case HloOpcode::kLog1p: |
| return CreateOpWithoutAttrs<lmhlo::Log1pOp>(instr); |
| case HloOpcode::kMap: |
| return EmitMapOp(instr); |
| case HloOpcode::kMaximum: |
| return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr); |
| case HloOpcode::kMinimum: |
| return CreateOpWithoutAttrs<lmhlo::MinOp>(instr); |
| case HloOpcode::kMultiply: |
| return CreateOpWithoutAttrs<lmhlo::MulOp>(instr); |
| case HloOpcode::kNegate: |
| return CreateOpWithoutAttrs<lmhlo::NegOp>(instr); |
| case HloOpcode::kNot: |
| return CreateOpWithoutAttrs<lmhlo::NotOp>(instr); |
| case HloOpcode::kOr: |
| return CreateOpWithoutAttrs<lmhlo::OrOp>(instr); |
| case HloOpcode::kOutfeed: |
| return EmitOutfeedOp(instr); |
| case HloOpcode::kPopulationCount: |
| return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr); |
| case HloOpcode::kPower: |
| return CreateOpWithoutAttrs<lmhlo::PowOp>(instr); |
| case HloOpcode::kReal: |
| return CreateOpWithoutAttrs<lmhlo::RealOp>(instr); |
| case HloOpcode::kReducePrecision: |
| return EmitReducePrecisionOp(instr); |
| case HloOpcode::kRemainder: |
| return CreateOpWithoutAttrs<lmhlo::RemOp>(instr); |
| case HloOpcode::kRoundNearestAfz: |
| return CreateOpWithoutAttrs<lmhlo::RoundOp>(instr); |
| case HloOpcode::kRsqrt: |
| return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr); |
| case HloOpcode::kSelect: |
| return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr); |
| case HloOpcode::kShiftLeft: |
| return CreateOpWithoutAttrs<lmhlo::ShiftLeftOp>(instr); |
| case HloOpcode::kShiftRightLogical: |
| return CreateOpWithoutAttrs<lmhlo::ShiftRightLogicalOp>(instr); |
| case HloOpcode::kShiftRightArithmetic: |
| return CreateOpWithoutAttrs<lmhlo::ShiftRightArithmeticOp>(instr); |
| case HloOpcode::kSign: |
| return CreateOpWithoutAttrs<lmhlo::SignOp>(instr); |
| case HloOpcode::kSin: |
| return CreateOpWithoutAttrs<lmhlo::SinOp>(instr); |
| case HloOpcode::kSqrt: |
| return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr); |
| case HloOpcode::kSubtract: |
| return CreateOpWithoutAttrs<lmhlo::SubOp>(instr); |
| case HloOpcode::kTanh: |
| return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr); |
| case HloOpcode::kXor: |
| return CreateOpWithoutAttrs<lmhlo::XorOp>(instr); |
| case HloOpcode::kSort: |
| return EmitSortOp(instr); |
| case HloOpcode::kFusion: |
| return EmitFusionOp(instr); |
| case HloOpcode::kScatter: |
| return EmitScatterOp(instr); |
| case HloOpcode::kSelectAndScatter: |
| return EmitSelectAndScatterOp(instr); |
| case HloOpcode::kCustomCall: |
| return EmitCustomCallOp(instr); |
| case HloOpcode::kConstant: |
| return EmitConstant(instr); |
| case HloOpcode::kReduce: |
| return EmitReduceOp(instr); |
| default: |
| llvm::errs() << instr->ToString(); |
| return tensorflow::errors::Internal( |
| absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()), |
| " is not supported.")); |
| } |
| } |
| |
| Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) { |
| return EmitOp(instr).status(); |
| } |
| |
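| // Emits lmhlo::SortOp, copying the sort dimension and stability flag from the |
| // HLO instruction and importing the comparator computation as the op's |
| // region. |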
| StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr)); |
| auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr); |
| sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension())); |
| sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable())); |
| TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( |
| *sort_instr->called_computations()[0], &sort.comparator(), &builder_)); |
| return sort; |
| } |
| |
| // Recursively walks mhlo::TupleOp trees in post order, invoking `visitor` on |
| // each non-tuple value. |
| Status WalkTuplePostOrder(Value v, |
| const std::function<Status(Value)>& visitor) { |
| if (auto* op = v.getDefiningOp()) { |
| if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) { |
| for (Value sub_v : tuple.val()) { |
| TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor)); |
| } |
| return Status::OK(); |
| } |
| } |
| return visitor(v); |
| } |
| |
| // This function removes all uses of a fused region argument and rewires those |
| // uses to a `tensor_load %memref`, where `%memref` is the caller's argument. |
| // |
| // It also flattens all input/output tuples into more region arguments / |
| // results. |
| StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand( |
| const HloInstruction* root, const Shape& shape, |
| xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) { |
| if (shape.IsTuple()) { |
| llvm::SmallVector<Value, 4> values; |
| for (int i = 0; i < shape.tuple_shapes_size(); ++i) { |
| shape_index->push_back(i); |
| TF_ASSIGN_OR_RETURN( |
| auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index, |
| b, loc)); |
| values.push_back(v); |
| shape_index->pop_back(); |
| } |
| return Value(b->create<mhlo::TupleOp>(loc, values)); |
| } |
| TF_ASSIGN_OR_RETURN(Value memref, |
| GetOrCreateArrayView(root, shape, *shape_index)); |
| auto load = b->create<TensorLoadOp>(loc, memref); |
| if (shape.layout() != |
| xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) { |
| llvm::SmallVector<int64_t, 4> minor_to_major( |
| shape.layout().minor_to_major().begin(), |
| shape.layout().minor_to_major().end()); |
| load->setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major)); |
| } |
| return load.getResult(); |
| } |
| |
| StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp( |
| const HloInstruction* instr) { |
| Location loc = getLocation(instr); |
| |
| auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr); |
| |
| auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr)); |
| auto after_fusion = builder_.saveInsertionPoint(); |
| builder_ = mlir::OpBuilder(fusion); |
| |
| auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front()); |
| |
| llvm::SmallVector<Value, 8> arguments; |
| for (int i = 0; i < instr->operands().size(); ++i) { |
| const HloInstruction* operand = instr->operand(i); |
| xla::ShapeIndex shape_index; |
| TF_ASSIGN_OR_RETURN( |
| auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index, |
| ®ion_builder, loc)); |
| arguments.push_back(arg); |
| } |
| |
| TF_ASSIGN_OR_RETURN(Value result, |
| xla::HloFunctionImporter::ImportInstructions( |
| *fusion_instr->fused_instructions_computation(), |
| arguments, ®ion_builder)); |
| |
| { |
| int i = 0; |
| llvm::SmallVector<Value, 4> output; |
| TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output)); |
| TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable { |
| region_builder.create<TensorStoreOp>(loc, v, output[i++]); |
| return Status::OK(); |
| })); |
| if (i != output.size()) { |
| return xla::InternalError("output sizes don't match"); |
| } |
| } |
| |
| // Fold GTE/Tuple pairs. |
| // |
| // Since the fused region refers to values in its parent region, we can't |
| // call applyPatternAndFoldGreedily. We optimize it manually. |
| // |
| // Only walk once, because post-ordering is exactly what we need for GTE |
| // optimizations. |
| fusion.region().walk([](mhlo::GetTupleElementOp gte) { |
| SmallVector<Value, 4> folded_values; |
| if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) { |
| gte.replaceAllUsesWith(folded_values[0]); |
| } |
| }); |
| |
| // Effectively a DCE on the region. |
| { |
| llvm::SmallVector<mlir::Operation*, 4> ops; |
| fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); }); |
| // Visit users before their defining ops. |
| std::reverse(ops.begin(), ops.end()); |
| for (auto op : ops) { |
| if (isOpTriviallyDead(op)) op->erase(); |
| } |
| } |
| |
| builder_.restoreInsertionPoint(after_fusion); |
| return fusion; |
| } |
| |
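| // Converts the XLA ScatterDimensionNumbers proto of a scatter instruction |
| // into the equivalent mhlo::ScatterDimensionNumbers attribute. |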
| StatusOr<mhlo::ScatterDimensionNumbers> |
| LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr) { |
| auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr); |
| |
| const xla::ScatterDimensionNumbers& xla_scatter_dim = |
| scatter_instr->scatter_dimension_numbers(); |
| auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get( |
| GetI64DenseElementsAttr(xla_scatter_dim.update_window_dims()), |
| GetI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()), |
| GetI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()), |
| builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()), |
| module_.getContext()); |
| return scatter_dimension_numbers; |
| } |
| |
| StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto scatter, |
| CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr)); |
| |
| // Copy the scatter attributes. |
| auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr); |
| |
| TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers, |
| GetScatterDimensionNumbers(instr)); |
| scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers); |
| scatter.indices_are_sortedAttr( |
| builder_.getBoolAttr(scatter_instr->indices_are_sorted())); |
| scatter.unique_indicesAttr( |
| builder_.getBoolAttr(scatter_instr->unique_indices())); |
| |
| // Import the update computation as a region. |
| TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( |
| *scatter_instr->called_computations()[0], &scatter.update_computation(), |
| &builder_)); |
| |
| return scatter; |
| } |
| |
| StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto select_and_scatter, |
| CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr)); |
| |
| // Copy the window attributes. |
| auto* select_and_scatter_instr = |
| xla::Cast<xla::HloSelectAndScatterInstruction>(instr); |
| const xla::Window& window = select_and_scatter_instr->window(); |
| |
| select_and_scatter.window_dimensionsAttr( |
| GetWindowElements(window, [](const xla::WindowDimension& dim) { |
| return static_cast<int64_t>(dim.size()); |
| })); |
| select_and_scatter.window_stridesAttr( |
| GetWindowElements(window, [](const xla::WindowDimension& dim) { |
| return static_cast<int64_t>(dim.stride()); |
| })); |
| select_and_scatter.paddingAttr( |
| GetWindowElements(window, [](const xla::WindowDimension& dim) { |
| return static_cast<int64_t>(dim.padding_low()); |
| })); |
| |
| // Import the select and scatter computations as regions. |
| TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( |
| *select_and_scatter_instr->select(), &select_and_scatter.select(), |
| &builder_)); |
| TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( |
| *select_and_scatter_instr->scatter(), &select_and_scatter.scatter(), |
| &builder_)); |
| return select_and_scatter; |
| } |
| |
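| // Emits a custom call, special-casing the calls that XLA:GPU lowers to |
| // dedicated LMHLO GPU ops (cuSOLVER Cholesky, cuBLAS GEMM, and cuDNN |
| // convolution / batch norm); anything else becomes a generic |
| // lmhlo::CustomCallOp. |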
| StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp( |
| const HloInstruction* instr) { |
| auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr); |
| |
| if (xla::gpu::IsCustomCallToCusolver(*instr)) { |
| return EmitCholesky(custom_call_instr); |
| } |
| |
| if (xla::gpu::IsCublasGemm(*instr)) { |
| return EmitGemm(custom_call_instr); |
| } |
| |
| if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) { |
| return EmitDnnConvolution(custom_call_instr); |
| } |
| |
| if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) { |
| return EmitDnnBatchNorm(custom_call_instr); |
| } |
| |
| size_t num_arguments, num_results; |
| TF_ASSIGN_OR_RETURN(auto custom_call, |
| CreateOpWithoutAttrs<lmhlo::CustomCallOp>( |
| instr, num_arguments, num_results)); |
| custom_call.call_target_nameAttr( |
| builder_.getStringAttr(custom_call_instr->custom_call_target())); |
| custom_call.backend_configAttr( |
| builder_.getStringAttr(custom_call_instr->opaque())); |
| const int32_t segments[2] = {static_cast<int32_t>(num_arguments), |
| static_cast<int32_t>(num_results)}; |
| custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(), |
| builder_.getI32VectorAttr(segments)); |
| return custom_call.getOperation(); |
| } |
| |
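| // Emits lmhlo_gpu::CholeskyOp, reading the lower/upper-triangle choice from |
| // the custom call's CholeskyOptions backend config. |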
| StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky( |
| const HloCustomCallInstruction* custom_call) { |
| TF_ASSIGN_OR_RETURN(auto cholesky_op, |
| CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call)); |
| TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options, |
| custom_call->backend_config<xla::CholeskyOptions>()); |
| cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower())); |
| return cholesky_op; |
| } |
| |
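| // Emits lmhlo_gpu::GEMMOp for a plain GEMM (two operands) or |
| // lmhlo_gpu::GEMM_BiasOp when a bias is present (three operands), copying the |
| // dot dimension numbers, scaling factors, and selected algorithm from the |
| // GemmBackendConfig. |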
| StatusOr<Operation*> LhloDialectEmitter::EmitGemm( |
| const HloCustomCallInstruction* custom_call) { |
| TF_ASSIGN_OR_RETURN( |
| auto const config, |
| custom_call->backend_config<xla::gpu::GemmBackendConfig>()); |
| |
| auto set_common_attributes = [&](auto op) -> Operation* { |
| auto hlo_dims = config.dot_dimension_numbers(); |
| auto mlir_dims = mhlo::DotDimensionNumbers::get( |
| GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()), |
| GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()), |
| GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()), |
| GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()), |
| builder_.getContext()); |
| op.dot_dimension_numbersAttr(mlir_dims); |
| op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real())); |
| op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag())); |
| op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size())); |
| if (config.algorithm_case() == |
| xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { |
| op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm())); |
| } |
| return op.getOperation(); |
| }; |
| |
| if (custom_call->operand_count() == 2) { |
| TF_ASSIGN_OR_RETURN(auto gemm, |
| CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call)); |
| return set_common_attributes(gemm); |
| } |
| |
| if (custom_call->operand_count() == 3) { |
| TF_ASSIGN_OR_RETURN( |
| auto gemm_bias, |
| CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call)); |
| gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta())); |
| return set_common_attributes(gemm_bias); |
| } |
| |
| return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands"); |
| } |
| |
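| // Maps a StreamExecutor DNN activation mode to the corresponding LMHLO GPU |
| // activation enum value. |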
| static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation( |
| stream_executor::dnn::ActivationMode activation) { |
| switch (activation) { |
| case stream_executor::dnn::kNone: |
| return mlir::lmhlo_gpu::Activation::None; |
| case stream_executor::dnn::kSigmoid: |
| return mlir::lmhlo_gpu::Activation::Sigmoid; |
| case stream_executor::dnn::kRelu: |
| return mlir::lmhlo_gpu::Activation::Relu; |
| case stream_executor::dnn::kRelu6: |
| return mlir::lmhlo_gpu::Activation::Relu6; |
| case stream_executor::dnn::kReluX: |
| return mlir::lmhlo_gpu::Activation::ReluX; |
| case stream_executor::dnn::kTanh: |
| return mlir::lmhlo_gpu::Activation::Tanh; |
| case stream_executor::dnn::kBandPass: |
| return mlir::lmhlo_gpu::Activation::BandPass; |
| default: |
| return xla::InternalError("Unknown activation"); |
| } |
| } |
| |
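| // Emits one of the lmhlo_gpu convolution ops (forward, backward-input, |
| // backward-filter, or fused forward), translating the HLO window, dimension |
| // numbers, and cuDNN backend config into attributes. |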
| StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution( |
| const HloCustomCallInstruction* custom_call) { |
| TF_ASSIGN_OR_RETURN( |
| auto const backend_config, |
| custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>()); |
| |
| TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind, |
| xla::gpu::GetCudnnConvKind(custom_call)); |
| |
| auto get_layout_attribute = [&](const xla::Layout& layout) { |
| std::vector<int64_t> minor_to_major(layout.minor_to_major_size()); |
| absl::c_transform(layout.minor_to_major(), minor_to_major.begin(), |
| [](xla::int64 x) { return static_cast<int64_t>(x); }); |
| return builder_.getI64ArrayAttr(minor_to_major); |
| }; |
| |
| auto set_common_conv_attributes = [&, this](auto op) -> Operation* { |
| const xla::Window& window = custom_call->window(); |
| // The window size for cuDNN convolution is the same as the kernel size. |
| op.window_stridesAttr( |
| GetWindowElements(window, [](const xla::WindowDimension& dim) { |
| return static_cast<int64_t>(dim.stride()); |
| })); |
| // cuDNN convolution requires the low and high padding to be equal. |
| op.paddingAttr( |
| GetWindowElements(window, [](const xla::WindowDimension& dim) { |
| return static_cast<int64_t>(dim.padding_low()); |
| })); |
| // LHS dilation is encoded in the window's base_dilation field. |
| // RHS dilation is encoded in the window's window_dilation field. |
| op.lhs_dilationAttr( |
| GetWindowElements(window, [](const xla::WindowDimension& dim) { |
| return static_cast<int64_t>(dim.base_dilation()); |
| })); |
| op.rhs_dilationAttr( |
| GetWindowElements(window, [](const xla::WindowDimension& dim) { |
| return static_cast<int64_t>(dim.window_dilation()); |
| })); |
| // Set up window reversal. |
| auto window_reversal = llvm::to_vector<4>(llvm::map_range( |
| window.dimensions(), |
| [](const xla::WindowDimension& dim) { return dim.window_reversal(); })); |
| auto type = RankedTensorType::get(op.window_strides()->getType().getShape(), |
| builder_.getIntegerType(/*width=*/1)); |
| op.window_reversalAttr(DenseElementsAttr::get(type, window_reversal)); |
| |
| op.dimension_numbersAttr(xla::ConvertConvDimensionNumbers( |
| custom_call->convolution_dimension_numbers(), &builder_)); |
| op.feature_group_countAttr( |
| builder_.getI64IntegerAttr(custom_call->feature_group_count())); |
| op.batch_group_countAttr( |
| builder_.getI64IntegerAttr(custom_call->batch_group_count())); |
| op.precision_configAttr(xla::ConvertPrecisionConfig( |
| &custom_call->precision_config(), &builder_)); |
| op.result_scaleAttr( |
| builder_.getF64FloatAttr(backend_config.conv_result_scale())); |
| auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get( |
| builder_.getI64IntegerAttr(backend_config.algorithm()), |
| builder_.getBoolAttr(backend_config.tensor_ops_enabled()), |
| get_layout_attribute(custom_call->operand(0)->shape().layout()), |
| get_layout_attribute(custom_call->operand(1)->shape().layout()), |
| get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()), |
| builder_.getContext()); |
| op.backend_configAttr(config); |
| |
| return op.getOperation(); |
| }; |
| |
| auto set_activation = [&, this](auto op) -> Status { |
| auto se_activation = static_cast<stream_executor::dnn::ActivationMode>( |
| backend_config.activation_mode()); |
| TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation, |
| GetLHLOActivation(se_activation)); |
| StringAttr activation_attr = builder_.getStringAttr( |
| mlir::lmhlo_gpu::stringifyActivation(activation)); |
| op.activation_modeAttr(activation_attr); |
| return Status::OK(); |
| }; |
| |
| switch (kind) { |
| case xla::gpu::CudnnConvKind::kForward: { |
| TF_ASSIGN_OR_RETURN( |
| auto cnn_forward, |
| CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call)); |
| return set_common_conv_attributes(cnn_forward); |
| } |
| case xla::gpu::CudnnConvKind::kBackwardInput: { |
| TF_ASSIGN_OR_RETURN( |
| auto cnn_backward, |
| CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call)); |
| return set_common_conv_attributes(cnn_backward); |
| } |
| case xla::gpu::CudnnConvKind::kBackwardFilter: { |
| TF_ASSIGN_OR_RETURN( |
| auto cnn_backward, |
| CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call)); |
| return set_common_conv_attributes(cnn_backward); |
| } |
| case xla::gpu::CudnnConvKind::kForwardActivation: { |
| // Fused conv can be either with side input or without. |
| if (custom_call->operand_count() == 3) { |
| TF_ASSIGN_OR_RETURN( |
| auto cnn_fused, |
| CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call)); |
| TF_RETURN_IF_ERROR(set_activation(cnn_fused)); |
| return set_common_conv_attributes(cnn_fused); |
| } |
| |
| TF_RET_CHECK(custom_call->operand_count() == 4); |
| TF_ASSIGN_OR_RETURN( |
| auto cnn_fused_side_input, |
| CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>( |
| custom_call)); |
| cnn_fused_side_input.side_input_scaleAttr( |
| builder_.getF64FloatAttr(backend_config.side_input_scale())); |
| TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input)); |
| return set_common_conv_attributes(cnn_fused_side_input); |
| } |
| } |
| } |
| |
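| // Emits one of the lmhlo_gpu batch-norm ops. The epsilon and feature_index |
| // values arrive as constant trailing operands of the custom call; they are |
| // turned into attributes, so only the first num_operands - 2 operands are |
| // forwarded to the op. |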
| StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm( |
| const HloCustomCallInstruction* custom_call) { |
| const xla::int64 num_operands = custom_call->operand_count(); |
| auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> { |
| // The last 2 operands of a custom call for batch norm are the epsilon and |
| // feature_index. |
| const HloInstruction* epsilon = custom_call->operand(num_operands - 2); |
| TF_RET_CHECK(epsilon->IsConstant()); |
| float epsilon_value = epsilon->literal().Get<float>({}); |
| |
| const HloInstruction* feature_index = |
| custom_call->operand(num_operands - 1); |
| TF_RET_CHECK(feature_index->IsConstant()); |
| xla::int64 feature_index_value = |
| feature_index->literal().Get<xla::int64>({}); |
| |
| op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value)); |
| op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value)); |
| return op.getOperation(); |
| }; |
| |
| const std::string& target = custom_call->custom_call_target(); |
| if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) { |
| TF_ASSIGN_OR_RETURN(auto fwd_training, |
| CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>( |
| custom_call, num_operands - 2)); |
| return set_batchnorm_attributes(fwd_training); |
| } |
| |
| if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) { |
| TF_ASSIGN_OR_RETURN(auto backward, |
| CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>( |
| custom_call, num_operands - 2)); |
| return set_batchnorm_attributes(backward); |
| } |
| |
| if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) { |
| TF_ASSIGN_OR_RETURN(auto fwd_inference, |
| CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>( |
| custom_call, num_operands - 2)); |
| return set_batchnorm_attributes(fwd_inference); |
| } |
| |
| return xla::Unimplemented("Unsupported batch norm operation"); |
| } |
| |
| // Converts an XLA HLO constant to a global_memref + get_global_memref pair. |
| StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant( |
| const HloInstruction* instr) { |
| // Insert a global_memref in the module. |
| Location loc = getLocation(instr); |
| |
| auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr); |
| TF_RET_CHECK(const_instr->shape().IsArray() && |
| const_instr->shape().is_static()); |
| TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>( |
| const_instr->shape(), builder_)); |
| auto memref_type = type.dyn_cast<MemRefType>(); |
| TF_RET_CHECK(memref_type != nullptr); |
| |
| TF_ASSIGN_OR_RETURN( |
| DenseElementsAttr initial_value, |
| CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_)); |
| |
| std::string constant_name = xla::llvm_ir::ConstantHloToGlobalName(*instr); |
| |
| // Insert the global memref at the top level. |
| { |
| OpBuilder::InsertionGuard guard(builder_); |
| builder_.clearInsertionPoint(); |
| auto global_var = builder_.create<GlobalMemrefOp>( |
| loc, constant_name, builder_.getStringAttr("private"), |
| TypeAttr::get(memref_type), initial_value, true); |
| SymbolTable(module_).insert(global_var); |
| global_var.getOperation()->moveBefore(&module_.front()); |
| |
| // For operations that do not fold this constant value in their codegen, we |
| // still need to materialize it into a buffer. Since buffer allocation is |
| // already done, annotate the global_memref with the information to get to |
| // the allocated buffer slice for this constant if need be. |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, |
| assignment_.GetUniqueTopLevelSlice(instr)); |
| global_var->setAttr("lmhlo.alloc", builder_.getIndexAttr(slice.index())); |
| TF_RET_CHECK(slice.offset() == 0) |
| << "Each constant should have its own allocation from BufferAssignment"; |
| TF_RET_CHECK(slice.allocation()->size() == slice.size()) |
| << "Each constant should have its own allocation from BufferAssignment"; |
| } |
| |
| auto get_global_memref = |
| builder_.create<GetGlobalMemrefOp>(loc, memref_type, constant_name); |
| |
| // Update the cache to remember this value. |
| auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())]; |
| TF_RET_CHECK(cached_value == nullptr); |
| cached_value = get_global_memref; |
| return get_global_memref; |
| } |
| |
| StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto reduce_op, |
| CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr)); |
| auto* reduce = xla::Cast<xla::HloReduceInstruction>(instr); |
| std::vector<int64_t> dimensions(reduce->dimensions().begin(), |
| reduce->dimensions().end()); |
| reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions)); |
| TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( |
| *instr->called_computations()[0], &reduce_op.body(), &builder_)); |
| return reduce_op; |
| } |
| |
| StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto map_op, CreateOpWithoutAttrs<lmhlo::MapOp>(instr)); |
| auto* map = xla::Cast<xla::HloMapInstruction>(instr); |
| std::vector<int64_t> dimensions(map->dimensions().begin(), |
| map->dimensions().end()); |
| map_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions)); |
| TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( |
| *instr->called_computations()[0], &map_op.computation(), &builder_)); |
| return map_op; |
| } |
| |
| StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto compare_op, |
| CreateOpWithoutAttrs<lmhlo::CompareOp>(instr)); |
| |
| auto* compare = xla::Cast<xla::HloCompareInstruction>(instr); |
| auto direction = [&]() { |
| switch (compare->direction()) { |
| case xla::ComparisonDirection::kEq: |
| return mhlo::ComparisonDirection::EQ; |
| case xla::ComparisonDirection::kNe: |
| return mhlo::ComparisonDirection::NE; |
| case xla::ComparisonDirection::kGe: |
| return mhlo::ComparisonDirection::GE; |
| case xla::ComparisonDirection::kGt: |
| return mhlo::ComparisonDirection::GT; |
| case xla::ComparisonDirection::kLe: |
| return mhlo::ComparisonDirection::LE; |
| case xla::ComparisonDirection::kLt: |
| return mhlo::ComparisonDirection::LT; |
| } |
| }(); |
| compare_op.comparison_directionAttr( |
| builder_.getStringAttr(stringifyComparisonDirection(direction))); |
| auto compare_type = [&]() { |
| switch (compare->type()) { |
| case xla::Comparison::Type::kFloat: |
| return mhlo::ComparisonType::FLOAT; |
| case xla::Comparison::Type::kFloatTotalOrder: |
| return mhlo::ComparisonType::TOTALORDER; |
| case xla::Comparison::Type::kSigned: |
| return mhlo::ComparisonType::SIGNED; |
| case xla::Comparison::Type::kUnsigned: |
| return mhlo::ComparisonType::UNSIGNED; |
| } |
| }(); |
| compare_op.compare_typeAttr( |
| builder_.getStringAttr(stringifyComparisonType(compare_type))); |
| return compare_op; |
| } |
| |
| StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto reduce_precision_op, |
| CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr)); |
| auto* reduce_precision = xla::Cast<xla::HloReducePrecisionInstruction>(instr); |
| reduce_precision_op.exponent_bitsAttr( |
| builder_.getI32IntegerAttr(reduce_precision->exponent_bits())); |
| reduce_precision_op.mantissa_bitsAttr( |
| builder_.getI32IntegerAttr(reduce_precision->mantissa_bits())); |
| return reduce_precision_op; |
| } |
| |
| StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp( |
| const HloInstruction* instr) { |
| TF_ASSIGN_OR_RETURN(auto all_reduce_op, |
| CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr)); |
| auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr); |
| auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups( |
| all_reduce->replica_groups(), builder_); |
| all_reduce_op->setAttr(replica_groups_attr.first, replica_groups_attr.second); |
| all_reduce_op.constrain_layoutAttr( |
| builder_.getBoolAttr(all_reduce->constrain_layout())); |
| all_reduce_op.channel_idAttr(mlir::mhlo::ChannelHandle::get( |
| builder_.getI64IntegerAttr(all_reduce->channel_id().value_or(0)), |
| builder_.getI64IntegerAttr(0), builder_.getContext())); |
| all_reduce_op.use_global_device_idsAttr( |
| builder_.getBoolAttr(all_reduce->use_global_device_ids())); |
| TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( |
| *instr->called_computations()[0], &all_reduce_op.computation(), |
| &builder_)); |
| return all_reduce_op; |
| } |
| |
| StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp( |
| const HloInstruction* instr) { |
| const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr); |
| // The HLO infeed instruction has a single token operand and produces a tuple |
| // of buffers plus a token. The LMHLO InfeedOp needs neither the token operand |
| // nor the token result, so drop them. |
| SmallVector<Value, 2> operands; |
| TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0})); |
| auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands); |
| infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config())); |
| return infeed_op; |
| } |
| |
| StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp( |
| const HloInstruction* instr) { |
| const HloOutfeedInstruction* outfeed = |
| xla::Cast<HloOutfeedInstruction>(instr); |
| // The HLO outfeed instruction has two operands, the source and a token, and |
| // a single token output. The LMHLO OutfeedOp needs neither the token operand |
| // nor the token result, so drop them. |
| SmallVector<Value, 2> operands; |
| TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands)); |
| auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands); |
| outfeed_op.configAttr(builder_.getStringAttr(outfeed->outfeed_config())); |
| return outfeed_op; |
| } |
| |
| StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView( |
| const xla::HloInstruction* instr, const xla::Shape& current_shape, |
| const xla::ShapeIndex& shape_index) { |
| // Cache the generated ViewOp and MemRefReinterpretCastOp by (instruction, |
| // shape_index). |
| auto& cached_value = slices_[std::make_pair(instr, shape_index)]; |
| if (cached_value) { |
| return cached_value; |
| } |
| |
| if (instr->IsConstant() && shape_index.empty()) { |
| TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr)); |
| return cached_value = constant_memref; |
| } |
| |
| // If the shape happens to have dynamic dimensions, create the memref using |
| // the underlying static shape. |
| // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape |
| // but static bounds in MLIR. |
| const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape); |
| |
| TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>( |
| static_shape, builder_)); |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, |
| assignment_.GetUniqueSlice(instr, shape_index)); |
| Value alloc = allocations_[slice.allocation()]; |
| if (alloc.getType() == out_type && slice.offset() == 0) { |
| return cached_value = alloc; |
| } |
| |
| auto out_memref_type = out_type.dyn_cast<MemRefType>(); |
| if (!out_memref_type) |
| return tensorflow::errors::Internal( |
| "Expected memref type when creating a view for leaf type of a " |
| "tuple."); |
| |
| Value byte_shift = |
| builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset()); |
| |
| xla::Shape physical_shape = |
| xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( |
| static_shape); |
| TF_ASSIGN_OR_RETURN( |
| Type physical_out_type, |
| xla::ConvertShapeToType<MemRefType>(physical_shape, builder_)); |
| |
| // TODO(timshen): revisit location handling. |
| Location loc = builder_.getUnknownLoc(); |
| |
| // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp produce |
| // the physical shape (where dimensions are ordered in major to minor) first, |
| // then follow up with a MemRefReinterpretCast to cast the resulting memref to |
| // the original layout. |
| Value result = |
| builder_.create<ViewOp>(loc, physical_out_type, alloc, byte_shift, |
| /*sizes=*/ValueRange{}); |
| if (physical_out_type != out_type) { |
| int64_t out_offset; |
| SmallVector<int64_t, 4> out_strides; |
| if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset))) |
| return tensorflow::errors::Internal( |
| "Failed to get strides and offset from the output type."); |
| result = builder_.create<MemRefReinterpretCastOp>( |
| loc, out_memref_type, result, out_offset, out_memref_type.getShape(), |
| out_strides, llvm::None, llvm::None, llvm::None); |
| } |
| return cached_value = result; |
| } |
| |
| Status LhloDialectEmitter::GetOrCreateViewImpl( |
| const HloInstruction* instr, const Shape& current_shape, |
| xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) { |
| if (current_shape.IsTuple()) { |
| for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) { |
| current_shape_index->push_back(i); |
| TF_RETURN_IF_ERROR(GetOrCreateViewImpl( |
| instr, current_shape.tuple_shapes(i), current_shape_index, values)); |
| current_shape_index->pop_back(); |
| } |
| return Status::OK(); |
| } |
| if (current_shape.IsArray()) { |
| TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape, |
| *current_shape_index)); |
| values->push_back(v); |
| return Status::OK(); |
| } |
| return xla::InternalError("Unexpected shape kind for %s and shape index %s", |
| instr->ToString(), current_shape_index->ToString()); |
| } |
| |
| // Returns a view for the result of an instruction. |
| // We first get a view for the slice in the allocation, and then may need to |
| // create another view to adjust the slice for the shape of the instruction. |
| Status LhloDialectEmitter::GetOrCreateView( |
| const HloInstruction* instr, SmallVectorImpl<Value>* values, |
| const xla::ShapeIndex& result_subset) { |
| xla::ShapeIndex shape_index = result_subset; |
| const Shape& sub_shape = |
| xla::ShapeUtil::GetSubshape(instr->shape(), shape_index); |
| return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values); |
| } |
| |
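| // Creates the function for the computation, adding one memref argument per |
| // buffer allocation (entry parameters first) and leaving the builder |
| // positioned before the function's terminator. As a rough sketch (names and |
| // sizes are illustrative), an entry computation with a single f32[4] |
| // parameter could produce: |
| // |
| //   func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, |
| //                                    lmhlo.params = 0 : index}, |
| //              %arg1: memref<16xi8> {lmhlo.alloc = 1 : index, |
| //                                    lmhlo.liveout = true}) { |
| //     ... |
| //     return |
| //   } |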
| Status LhloDialectEmitter::Initialize() { |
| std::string function_name = |
| computation_.name().empty() ? "__compute" : computation_.name(); |
| |
| // Create the function as () -> (); we'll compute the arguments from the |
| // buffer allocations and update the type afterwards. |
| auto func_op = FuncOp::create(builder_.getUnknownLoc(), function_name, |
| builder_.getFunctionType({}, {})); |
| Block* block = func_op.addEntryBlock(); |
| |
| llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations; |
| for (const BufferAllocation& alloc : assignment_.Allocations()) |
| ordered_allocations.push_back(&alloc); |
| |
| if (computation_.IsEntryComputation()) { |
| // Sort the rather arbitrarily ordered allocations to match the input/output |
| // parameters. Specifically we want to sort buffer allocations in the |
| // following order: |
| // * Parameters always order before non-parameters. |
| // * Different parameters order by parameter number. |
| // * Different allocations for the same parameter order by the shape index. |
| // |
| // TODO(timshen): there should be only one non-parameter buffer, the temp |
| // buffer. Check on that. |
| const auto allocation_comparator = [](const BufferAllocation* lhs, |
| const BufferAllocation* rhs) { |
| if (lhs->is_entry_computation_parameter() != |
| rhs->is_entry_computation_parameter()) { |
| return lhs->is_entry_computation_parameter() > |
| rhs->is_entry_computation_parameter(); |
| } |
| if (lhs->is_entry_computation_parameter()) { |
| return std::tuple<int, const xla::ShapeIndex&>( |
| lhs->parameter_number(), lhs->param_shape_index()) < |
| std::tuple<int, const xla::ShapeIndex&>( |
| rhs->parameter_number(), rhs->param_shape_index()); |
| } |
| return false; |
| }; |
| |
| std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(), |
| allocation_comparator); |
| } |
| |
| // The function signature will be composed of: |
| // - one memref for each of the parameters. |
| // - one memref for each other buffer allocation. |
| llvm::SmallVector<DictionaryAttr, 8> args_attrs; |
| for (const BufferAllocation* alloc : ordered_allocations) { |
| if (computation_.IsEntryComputation() && |
| alloc->is_entry_computation_parameter()) { |
| const xla::Shape& buffer_shape = xla::ShapeUtil::GetSubshape( |
| computation_.parameter_instruction(alloc->parameter_number()) |
| ->shape(), |
| alloc->param_shape_index()); |
| |
| TF_ASSIGN_OR_RETURN(auto arg_type, xla::ConvertShapeToType<MemRefType>( |
| buffer_shape, builder_)); |
| |
| // First map parameters to memrefs on the operation. |
| block->addArgument(arg_type); |
| allocations_[alloc] = block->getArguments().back(); |
| NamedAttrList arg_attr_list; |
| arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index())); |
| arg_attr_list.set("lmhlo.params", |
| builder_.getIndexAttr(alloc->parameter_number())); |
| args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext())); |
| } else { |
| block->addArgument(MemRefType::get({alloc->size()}, i8_type_)); |
| allocations_[alloc] = block->getArguments().back(); |
| |
| NamedAttrList arg_attr_list; |
| arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index())); |
| arg_attr_list.set("lmhlo.liveout", builder_.getBoolAttr(true)); |
| args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext())); |
| } |
| } |
| |
| FunctionType function_type = |
| builder_.getFunctionType(block->getArgumentTypes(), {}); |
| func_op.setType(function_type); |
| func_op.setAllArgAttrs(args_attrs); |
| |
| SymbolTable symbol_table(module_); |
| symbol_table.insert(func_op); |
| builder_.setInsertionPointToEnd(block); |
| |
| auto return_op = builder_.create<ReturnOp>(builder_.getUnknownLoc()); |
| builder_ = OpBuilder(return_op); |
| |
| return Status::OK(); |
| } |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() { |
| return std::make_unique<XlaHloToLhloPass>(); |
| } |
| |
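| // Converts the entry computation of `hlo_module` into an LMHLO function in |
| // `module`, visiting instructions in the sequential order recorded by the |
| // buffer assignment's HLO ordering. |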
| Status HloToLhloModule(const BufferAssignment& assignment, |
| const HloModule& hlo_module, ModuleOp module) { |
| module.getContext() |
| ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect, |
| lmhlo_gpu::LmhloGpuDialect>(); |
| const HloComputation* computation = hlo_module.entry_computation(); |
| |
| LhloDialectEmitter emitter(assignment, *computation, module); |
| TF_RETURN_IF_ERROR(emitter.Initialize()); |
| |
| const xla::HloInstructionSequence* schedule = |
| assignment.hlo_ordering().SequentialOrder(*computation); |
| if (!schedule) |
| return xla::Unimplemented("Missing sequential order for the computation"); |
| const std::vector<HloInstruction*>& ordering = schedule->instructions(); |
| return computation->AcceptOrdered(&emitter, ordering); |
| } |
| |
| OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input, |
| MLIRContext* context) { |
| StatusOr<std::unique_ptr<HloModule>> maybe_module = |
| xla::ParseAndReturnUnverifiedModule( |
| absl::string_view(input.data(), input.size())); |
| TF_CHECK_OK(maybe_module.status()); |
| |
| OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context)); |
| |
| TF_CHECK_OK( |
| ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host")); |
| |
| return module; |
| } |
| |
| static PassRegistration<XlaHloToLhloPass> registration( |
| "xla-hlo-to-lhlo-with-xla", |
| "Emit LHLO from HLO using the existing XLA implementation"); |
| |
| } // namespace mlir |