| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" |
| |
| #include <algorithm> |
| #include <cstring> |
| #include <iterator> |
| #include <memory> |
| #include <string> |
| #include <type_traits> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/memory/memory.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_format.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/IR/BasicBlock.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/LLVMContext.h" |
| #include "llvm/IR/Module.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Verifier.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" |
| #include "tensorflow/compiler/mlir/utils/name_utils.h" |
| #include "tensorflow/compiler/mlir/xla/attribute_exporter.h" |
| #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" |
| #include "tensorflow/compiler/mlir/xla/hlo_utils.h" |
| #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" |
| #include "tensorflow/compiler/mlir/xla/type_to_shape.h" |
| #include "tensorflow/compiler/xla/layout_util.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/permutation_util.h" |
| #include "tensorflow/compiler/xla/service/collective_ops_utils.h" |
| #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" |
| #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" |
| #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" |
| #include "tensorflow/compiler/xla/service/gpu/bef_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" |
| #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" |
| #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" |
| #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" |
| #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" |
| #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" |
| #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" |
| #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" |
| #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" |
| #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/target_util.h" |
| #include "tensorflow/compiler/xla/service/gpu/thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" |
| #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" |
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" |
| #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" |
| #include "tensorflow/compiler/xla/service/name_uniquer.h" |
| #include "tensorflow/compiler/xla/service/pattern_matcher.h" |
| #include "tensorflow/compiler/xla/service/shape_inference.h" |
| #include "tensorflow/compiler/xla/service/while_loop_analysis.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/compiler/xla/union_find.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/compiler/xla/window_util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/lib/core/bits.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/logging.h" |
| |
| #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h" |
| #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| |
| namespace xla { |
| namespace gpu { |
| |
| namespace { |
| |
| using absl::InlinedVector; |
| using absl::nullopt; |
| using absl::optional; |
| using absl::StrCat; |
| using llvm_ir::IrArray; |
| using llvm_ir::IrName; |
| |
| const auto kDimX = TilingScheme::DimX; |
| const auto kDimY = TilingScheme::DimY; |
| const auto kDimZ = TilingScheme::DimZ; |
| const auto kDimTot = TilingScheme::DimTot; |
| |
| const auto kLinearIndexingX = TilingScheme::LinearIndexingX; |
| const auto kStridedIndexingX = TilingScheme::StridedIndexingX; |
| const auto kStridedLinearIndexingX = TilingScheme::StridedLinearIndexingX; |
| |
| // If a dimensions is smaller than this, untiled transposition may be more |
| // efficient. |
| const int64_t kMinDimensionToTransposeTiled = 16; |
| |
| void AnnotateWithInt32Value(string name, int64_t value, |
| const std::string& kernel_name, |
| llvm::Module* llvm_module) { |
| llvm::NamedMDNode* nvvm_annotations_node = |
| llvm_module->getOrInsertNamedMetadata("nvvm.annotations"); |
| llvm::Function* ir_kernel = llvm_module->getFunction(kernel_name.c_str()); |
| llvm::LLVMContext& llvm_context = llvm_module->getContext(); |
| |
| nvvm_annotations_node->addOperand(llvm::MDNode::get( |
| llvm_context, |
| {llvm::ConstantAsMetadata::get(ir_kernel), |
| llvm::MDString::get(llvm_context, name), |
| llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( |
| llvm::IntegerType::get(llvm_context, /*NumBits=*/32), value))})); |
| } |
| |
| // Annotates the launch dimensions of the corresponding IR kernel in |
| // `llvm_module`. |
| void AnnotateThunkLaunchDimensions(const LaunchDimensions& launch_dims, |
| const std::string& kernel_name, |
| llvm::Module* llvm_module) { |
| // Add __launch_bounds__ to metadata. This limits registers per thread to |
| // avoid out-of-resources launching errors. |
| |
| // Our launch bounds are exact, so we can specify them as |
| // reqntid[xyz] rather than maxntid[xyz]. |
| AnnotateWithInt32Value("reqntidx", launch_dims.thread_counts_per_block().x, |
| kernel_name, llvm_module); |
| if (launch_dims.thread_counts_per_block().y > 1) { |
| AnnotateWithInt32Value("reqntidy", launch_dims.thread_counts_per_block().y, |
| kernel_name, llvm_module); |
| } |
| if (launch_dims.thread_counts_per_block().z > 1) { |
| AnnotateWithInt32Value("reqntidz", launch_dims.thread_counts_per_block().z, |
| kernel_name, llvm_module); |
| } |
| } |
| |
| bool BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements, |
| int64_t v) { |
| mlir::APInt value(sizeof(int64_t) * 8, v, /*isSigned=*/true); |
| return std::binary_search( |
| elements.begin(), elements.end(), value, |
| [](const mlir::APInt& x, const mlir::APInt& y) { return x.slt(y); }); |
| } |
| |
| bool MhloOpIsElementwise(mlir::Operation* op) { |
| CHECK(op->getDialect() == |
| op->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>()); |
| auto opcode = *MhloToHloOpcode(op); |
| if (HloInstruction::IsOpElementwise(opcode)) { |
| return true; |
| } |
| if (opcode == HloOpcode::kMap) { |
| int iota = 0; |
| for (const llvm::APInt& i : |
| mlir::cast<mlir::mhlo::MapOp>(op).dimensions()) { |
| if (i.getZExtValue() != iota) { |
| return false; |
| } |
| iota++; |
| } |
| return true; |
| } |
| // TODO(timshen): not sure about whether porting |
| // HloFusionInstruction::IsElementwiseImpl() is necessary. HandleFusion() |
| // doesn't use such information. |
| return false; |
| } |
| |
| bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) { |
| int instruction_count = 0; |
| for (mlir::Operation& instr : fusion.region().front()) { |
| if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp, |
| mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp>( |
| &instr)) { |
| continue; |
| } |
| instruction_count++; |
| } |
| return instruction_count == 1; |
| } |
| |
| bool MayPreventVectorization(mlir::Operation* op) { |
| // An empirically chosen constant: unrolling concat with a large amount of |
| // arguments causes excessive register spilling. |
| static constexpr int kMaxConcatArgumentsForUnrolling = 10; |
| |
| auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op); |
| const bool is_single_instruction = IsSingleInstructionFusion(fusion); |
| |
| for (mlir::Operation& instr : fusion.region().front()) { |
| if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp, |
| mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp>( |
| &instr)) { |
| continue; |
| } |
| if (is_single_instruction) { |
| auto instr_opcode = *MhloToHloOpcode(&instr); |
| if (MhloOpIsElementwise(&instr)) { |
| switch (instr_opcode) { |
| case HloOpcode::kSin: |
| case HloOpcode::kCos: |
| case HloOpcode::kPower: |
| case HloOpcode::kAtan2: |
| return true; |
| default: |
| return false; |
| } |
| } else if (instr_opcode == HloOpcode::kReduce && |
| instr.getNumResults() == 1) { |
| // TODO(timshen): check if the to_apply() attribute contains |
| // instructions that break LLVM vectorization. |
| return false; |
| } |
| return true; |
| } |
| |
| CHECK(instr.getDialect() == |
| instr.getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>()) |
| << MlirToString(op); |
| switch (*MhloToHloOpcode(&instr)) { |
| case HloOpcode::kReduceWindow: |
| case HloOpcode::kSort: |
| case HloOpcode::kDot: |
| case HloOpcode::kSin: |
| case HloOpcode::kCos: |
| case HloOpcode::kPower: |
| case HloOpcode::kAtan2: |
| return true; |
| case HloOpcode::kConcatenate: |
| if (instr.getOperands().size() > kMaxConcatArgumentsForUnrolling) { |
| return true; |
| } |
| break; |
| case HloOpcode::kReduce: |
| if (instr.getNumResults() > 1) { |
| return true; |
| } |
| break; |
| default: |
| break; |
| } |
| } |
| return false; |
| } |
| |
| // Computes the maximum valid unroll factor for a given instruction. |
| int ComputeMaxUnrollFactor(mlir::Type type, |
| const HloModuleConfig& hlo_module_config) { |
| int max_unroll_factor = |
| hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor(); |
| |
| // Find the largest possible power of two to unroll by. |
| // TODO(kramerb): Make this smarter. |
| |
| auto shaped_type = type.cast<mlir::ShapedType>(); |
| int64_t num_elements = std::accumulate( |
| shaped_type.getShape().begin(), shaped_type.getShape().end(), int64_t{1}, |
| std::multiplies<int64_t>()); |
| for (int i = max_unroll_factor; i > 1; i /= 2) { |
| if (num_elements % i == 0) { |
| return i; |
| } |
| } |
| |
| // Cannot unroll. |
| return 1; |
| } |
| |
| // Computes the maximum valid unroll factor for a given instruction. |
| int ComputeMaxUnrollFactor(mlir::Operation* op, |
| const HloModuleConfig& hlo_module_config) { |
| mlir::Type element_shape = [&] { |
| if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) { |
| return fusion.getFusionRoots()[0]->getResult(0).getType(); |
| } |
| return GetHloOutputs(op)[0].getType(); |
| }(); |
| return ComputeMaxUnrollFactor(element_shape, hlo_module_config); |
| } |
| |
| // Returns the llvm type for the indices used in the kernel that contains the |
| // hlo instruction. Such indices include the index for the parallel loop and |
| // the indices for the tensors accessed by the kernel. The return type is i32 |
| // iff the following conditions are met: |
| // . The launch_size of the kernel is within the range of i32. |
| // . The sizes of all the tensors accessed within the kernel are within the |
| // range of i32. |
| // Otherwise, the return type is i64. |
| llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, |
| int64_t launch_size, llvm::IRBuilder<>* b) { |
| // Find the unnested hlo instruction for which the kernel is generated for. |
| const HloInstruction* unnested_hlo = hlo; |
| const HloComputation* computation = hlo->parent(); |
| if (computation->IsFusionComputation()) { |
| unnested_hlo = computation->FusionInstruction(); |
| } |
| |
| auto shape_in_range = [&](const Shape& s) { |
| bool in_range = true; |
| ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, |
| const ShapeIndex& /*index*/) { |
| if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { |
| in_range = false; |
| } |
| }); |
| |
| return in_range; |
| }; |
| |
| llvm::Type* i64_ty = b->getInt64Ty(); |
| // Check launch dimension |
| if (!IsInt32(launch_size)) { |
| return i64_ty; |
| } |
| |
| // Check the size of result tensors |
| if (!shape_in_range(unnested_hlo->shape())) { |
| return i64_ty; |
| } |
| |
| auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool { |
| return shape_in_range(operand->shape()); |
| }; |
| |
| // Check the size of input tensors |
| if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { |
| return i64_ty; |
| } |
| |
| // Check the size of the internal result tensors |
| if (unnested_hlo->opcode() == HloOpcode::kFusion) { |
| if (!absl::c_all_of( |
| unnested_hlo->fused_instructions_computation()->instructions(), |
| hlo_shape_in_range)) { |
| return i64_ty; |
| } |
| } |
| |
| return b->getInt32Ty(); |
| } |
| |
| // The same as GetIndexTypeForKernel, but works with MLIR ops. |
| llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size, |
| llvm::IRBuilder<>* b) { |
| auto shape_in_range = [&](const Shape& s) { |
| bool in_range = true; |
| ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, |
| const ShapeIndex& /*index*/) { |
| if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { |
| in_range = false; |
| } |
| }); |
| |
| return in_range; |
| }; |
| |
| llvm::Type* i64_ty = b->getInt64Ty(); |
| // Check launch dimension |
| if (!IsInt32(launch_size)) { |
| return i64_ty; |
| } |
| |
| // Check the size of result tensors |
| for (auto result : GetHloOutputs(op)) { |
| if (!shape_in_range(GetShape(result))) { |
| return i64_ty; |
| } |
| } |
| |
| auto hlo_shape_in_range = [&](mlir::Value operand) -> bool { |
| return shape_in_range(GetShape(operand)); |
| }; |
| |
| // Check the size of input tensors |
| if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) { |
| return i64_ty; |
| } |
| |
| // Check the size of the internal result tensors |
| if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) { |
| auto result = fusion.region().walk([&](mlir::Operation* op) { |
| for (mlir::Value result : op->getResults()) { |
| if (!hlo_shape_in_range(result)) { |
| return mlir::WalkResult::interrupt(); |
| } |
| } |
| return mlir::WalkResult::advance(); |
| }); |
| if (result.wasInterrupted()) { |
| return i64_ty; |
| } |
| } |
| |
| return b->getInt32Ty(); |
| } |
| |
| // Gets the input shape of the ROOT slices, which will be used as the kernel |
| // launch dims. The slice input fusion requires the input shapes of the ROOT |
| // slices to be the same although the (slice) output shapes can be different. |
| // |
| // Returns the input shape of the ROOT slices if all the input shapes of ROOT |
| // slices are the same and the slices are non-strided. Otherwise, returns |
| // FailedPrecondition. |
| StatusOr<Shape> GetConsistentInputShapeForRootSlices( |
| const HloComputation* fused_computation) { |
| const HloInstruction& root = *fused_computation->root_instruction(); |
| if (root.opcode() == HloOpcode::kSlice) { |
| return root.operands()[0]->shape(); |
| } |
| |
| CHECK_EQ(root.opcode(), HloOpcode::kTuple); |
| const Shape& first_slice_operand_shape = |
| root.operands()[0]->operands()[0]->shape(); |
| for (size_t i = 1; i < root.operands().size(); ++i) { |
| const HloInstruction* slice = root.operands()[i]; |
| const Shape& operand_shape = slice->operands()[0]->shape(); |
| if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, |
| operand_shape)) { |
| return FailedPrecondition( |
| "Fused slices do not have the same input shape, fused computation = " |
| "%s.", |
| root.parent()->name()); |
| } |
| } |
| |
| return first_slice_operand_shape; |
| } |
| |
| } // namespace |
| |
| IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, |
| IrEmitterContext* ir_emitter_context) |
| : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {} |
| |
| StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create( |
| const HloModuleConfig& hlo_module_config, |
| IrEmitterContext* ir_emitter_context) { |
| return std::unique_ptr<IrEmitterUnnested>( |
| new IrEmitterUnnested(hlo_module_config, ir_emitter_context)); |
| } |
| |
| llvm::Function* IrEmitterUnnested::BuildKernelPrototype( |
| absl::string_view name, absl::Span<const BufferAllocation* const> args) { |
| // Compute the kernel name. The opcode string may contain "-" which cannot be |
| // in a PTX function name, so sanitize the name before uniquifying it. |
| string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( |
| llvm_ir::SanitizeFunctionName(std::string(name))); |
| |
| // Create the kernel and add it to the module. |
| llvm::Module* module = ir_emitter_context_->llvm_module(); |
| llvm::LLVMContext& context = module->getContext(); |
| llvm::FunctionType* kernel_type = llvm::FunctionType::get( |
| /*Result=*/llvm::Type::getVoidTy(context), |
| std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()), |
| /*isVarArg=*/false); |
| llvm::Function* kernel = |
| llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, |
| kernel_name.c_str(), module); |
| |
| // Add dereferenceable and alignment information to each of the kernel's |
| // parameters. |
| auto arg_it = kernel->arg_begin(); |
| for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) { |
| const BufferAllocation* alloc = args[arg_no]; |
| llvm::Argument* fn_arg = &*arg_it; |
| ++arg_it; |
| |
| kernel->addDereferenceableParamAttr(arg_no, alloc->size()); |
| |
| const int64_t alignment = [&] { |
| if (alloc->is_entry_computation_parameter()) { |
| return kEntryParameterAlignBytes; |
| } else if (alloc->is_constant()) { |
| return kConstantBufferAlignBytes; |
| } else { |
| return kXlaAllocatedBufferAlignBytes; |
| } |
| }(); |
| |
| kernel->addParamAttr( |
| arg_no, |
| llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment)); |
| |
| if (alloc->IsPreallocatedTempBuffer()) { |
| fn_arg->setName("temp_buf"); |
| } else { |
| fn_arg->setName(StrCat("alloc", alloc->index())); |
| } |
| } |
| |
| AnnotateFunctionAsGpuKernel(module, kernel, &b_); |
| |
| // TODO(b/65380986): Investigate if adding fast math flags for generated |
| // kernels makes sense. |
| |
| // Update the insert point to the entry basic block. |
| llvm::BasicBlock* entry_bb = |
| llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel); |
| |
| // Emit a "return void" at entry_bb's end, and set the insert point before |
| // that return instruction. |
| b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); |
| |
| return kernel; |
| } |
| |
| StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSlice( |
| mlir::Value v, std::string* constant_name) { |
| return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(), |
| constant_name); |
| } |
| |
| Status IrEmitterUnnested::EmitConstant(mlir::Operation* op) { |
| auto get_global = mlir::cast<mlir::memref::GetGlobalOp>(op); |
| auto module = get_global->getParentOfType<mlir::ModuleOp>(); |
| auto global = mlir::cast<mlir::memref::GlobalOp>( |
| module.lookupSymbol(get_global.name())); |
| |
| auto literal = global.initial_value()->dyn_cast<mlir::DenseElementsAttr>(); |
| TF_RET_CHECK(literal); |
| |
| const bool should_emit_initializer = literal.getType().getNumElements() <= 1; |
| |
| TF_ASSIGN_OR_RETURN(int element_bytes, |
| GetElementTypeBytes(literal.getType().getElementType())); |
| llvm::ArrayType* global_type = llvm::ArrayType::get( |
| b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes); |
| |
| GpuExecutable::ConstantInfo info; |
| llvm::Constant* initializer; |
| if (should_emit_initializer) { |
| std::vector<uint8> content; |
| TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content)); |
| initializer = llvm::ConstantDataArray::get<uint8>( |
| ir_emitter_context_->llvm_module()->getContext(), content); |
| } else { |
| TF_RETURN_IF_ERROR( |
| CopyDenseElementsDataToXlaFormat(literal, &info.content)); |
| initializer = llvm::ConstantAggregateZero::get(global_type); |
| } |
| |
| // These globals will be looked up by name by GpuExecutable so we need to |
| // give them an external linkage. Not all of their uses are visible in |
| // the LLVM IR so we can't give then a linkage that merely preserves their |
| // names (like available_externally), we also need to ensure that they stick |
| // around even if they're "unused". |
| // |
| // We may have to be more clever here in the future if we notice that we're |
| // keeping around too many globals because of their linkage. |
| llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( |
| global_type, /*isConstant=*/should_emit_initializer, |
| llvm::GlobalValue::ExternalLinkage, |
| /*Initializer=*/initializer, global.sym_name(), |
| /*TLMode=*/llvm::GlobalValue::NotThreadLocal, |
| /*AddressSpace=*/0, |
| /*isExternallyInitialized=*/false); |
| global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes)); |
| ir_emitter_context_->llvm_module()->getGlobalList().push_back( |
| global_for_const); |
| |
| info.symbol_name.assign(global.sym_name().begin(), global.sym_name().end()); |
| |
| info.allocation_index = |
| global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt(); |
| ir_emitter_context_->constants().push_back(std::move(info)); |
| return Status::OK(); |
| } |
| |
| static ConditionalThunkConfig GetConditionalThunkConfig( |
| mlir::lmhlo::CaseOp op, std::vector<ThunkSequence> branch_thunk_sequences) { |
| ConditionalThunkConfig config; |
| config.branch_index_is_bool = |
| op.index().getType().cast<mlir::ShapedType>().getElementType().isInteger( |
| /*width=*/1); |
| config.branch_count = op.branches().size(); |
| // Pass nullptr as the HloInstruction* to the branch_thunks |
| // constructors because these SequentialThunks are logically "part of" |
| // this ConditionalThunk, and shouldn't be profiled separately from it. |
| config.branch_thunks.reserve(branch_thunk_sequences.size()); |
| for (auto& branch_thunk_sequence : branch_thunk_sequences) { |
| config.branch_thunks.emplace_back(new SequentialThunk( |
| Thunk::ThunkInfo(), std::move(branch_thunk_sequence))); |
| } |
| return config; |
| } |
| |
| Status IrEmitterUnnested::EmitConditional(mlir::Operation* op) { |
| auto conditional = mlir::cast<mlir::lmhlo::CaseOp>(op); |
| |
| std::vector<ThunkSequence> branch_thunks; |
| |
| int branch_count = conditional.branches().size(); |
| branch_thunks.reserve(branch_count); |
| |
| for (int j = 0; j < branch_count; ++j) { |
| mlir::Region* branch_computation = &conditional.branches()[j]; |
| TF_ASSIGN_OR_RETURN( |
| auto ir_emitter, |
| IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_)); |
| TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(branch_computation)); |
| branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); |
| } |
| |
| ConditionalThunkConfig config = |
| GetConditionalThunkConfig(conditional, std::move(branch_thunks)); |
| |
| TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(conditional.index())); |
| AddThunkToThunkSequence(std::unique_ptr<Thunk>( |
| new ConditionalThunk(GetThunkInfo(op), std::move(config), slice))); |
| return Status::OK(); |
| } |
| |
| // Input = {dynamic array(with dynamic dimension meta data at the end)} |
| // Output = {static array, dynamic_dim0, dynamic_dim1} |
| Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { |
| // TODO(jurahul): Create an op to represent PadToStatic. |
| auto pad_to_static = mlir::cast<mlir::lmhlo::CustomCallOp>(op); |
| int unroll_factor = 1; |
| std::string ir_name = mlir::GetNameFromLoc(pad_to_static.getLoc()); |
| |
| const Shape& input_shape = GetShape(pad_to_static.args().front()); |
| TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions( |
| input_shape, ir_emitter_context_->gpu_device_info(), |
| {unroll_factor})); |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(auto kernel_thunk, |
| BuildKernelThunk(pad_to_static, GetThunkInfo(op), |
| &ir_arrays, launch_dimensions)); |
| |
| const llvm_ir::IrArray source_array = ir_arrays[0]; |
| const llvm_ir::IrArray output_array = ir_arrays[1]; |
| auto output_dim_arrays = |
| absl::Span<const llvm_ir::IrArray>(ir_arrays).subspan(2); |
| |
| // pseudo code for PadToStatic on a 2d array |
| // int* source_array = input[0]; |
| // int* dest_array = output[0]; |
| llvm::Value* source_buffer = source_array.GetBasePointer(); |
| llvm::Value* raw_buffer = |
| b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo()); |
| |
| // TODO(jurahul): input_shape here is the static shape of the input (which has |
| // a dynamic shape in XLA). Currently, we are mapping that to a static shaped |
| // memref. When we change that to a more appropriate representation in MLIR, |
| // fix this code to correctly deduce the static shape backing the dynamically |
| // shaped memref. |
| int64_t raw_data_size = ShapeUtil::ByteSizeOf(input_shape); |
| |
| // int* dyn_dim0_size = source_array + meta_data_offset; |
| // int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int); |
| std::vector<llvm::Value*> dynamic_dims; |
| for (int64_t i = 1; i < pad_to_static.output().size(); ++i) { |
| // Dynamic size of each dimension is attached at the end of the source |
| // array(operand(0)). We need to extract these value. |
| const Shape& dim_shape = GetShape(pad_to_static.output()[i]); |
| TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32))); |
| |
| const int64_t dim_index = i - 1; |
| llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( |
| b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); |
| llvm::Value* dyn_dim_size = b_.CreateLoad( |
| b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()), |
| "dyn_dim_size"); |
| dynamic_dims.push_back(dyn_dim_size); |
| } |
| |
| // only one thread need to store the dynamic index |
| // int thread_id = GetThreadId(); |
| // int block_id = GetBlockId(); |
| // if (thread_id == 0 && block_id == 0) { |
| // *output[1] = *dyn_dim0_size; |
| // *output[2] = *dyn_dim1_size; |
| // } |
| KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] { |
| for (int64_t i = 1; i < pad_to_static.output().size(); ++i) { |
| const int64_t dim_index = i - 1; |
| llvm::Value* dest_dim_size_address = |
| output_dim_arrays[dim_index].GetBasePointer(); |
| // output[i] stores dynamic_dim_(i-1) |
| b_.CreateStore(dynamic_dims[i - 1], |
| b_.CreateBitCast(dest_dim_size_address, |
| b_.getInt32Ty()->getPointerTo())); |
| } |
| }); |
| |
| // int dyn_element_total = 1; |
| // dyn_element_total *= *dyn_dim0_size; |
| // dyn_element_total *= *dyn_dim1_size; |
| llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1); |
| for (llvm::Value* dynamic_dim : dynamic_dims) { |
| dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim, |
| /*Name=*/"dyn_element_total"); |
| } |
| |
| // linear_index = block_id * threads_per_block + thread_id; |
| // if (linear_index < max_num_element) { |
| // Index static_index = |
| // delinerized(linerized_index, static_dim0_size, static_dim1_size); |
| // if (linerized_index < dyn_element_total) { |
| // Index dyn_index = |
| // delinerized(linerized_index, *dyn_dim0_size, *dyn_dim1_size); |
| // dest_array[dyn_index.dim0][dyn_index.dim1] = |
| // source_array[static_index.dim0][static_index.dim1]; |
| // } |
| // } |
| llvm_ir::LoopEmitter::BodyEmitter body_generator = |
| [&](const llvm_ir::IrArray::Index& array_index) -> Status { |
| llvm::Value* linearIndex = |
| array_index.Linearize(input_shape.dimensions(), &b_); |
| auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse( |
| b_.CreateICmpULT(linearIndex, dyn_element_total), |
| llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false); |
| // Set IR builder insertion point to the body of the if structure. |
| llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_); |
| llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape, |
| absl::MakeSpan(dynamic_dims), &b_); |
| output_array.EmitWriteArrayElement( |
| dyn_index, |
| source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_, |
| /*use_linear_index=*/false); |
| return Status::OK(); |
| }; |
| |
| const Shape& data_shape = GetShape(pad_to_static.output().front()); |
| TF_RETURN_IF_ERROR( |
| ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_, |
| {unroll_factor}) |
| .EmitLoop(ir_name, |
| GetIndexTypeForKernel( |
| pad_to_static, launch_dimensions.launch_bound(), &b_))); |
| thunk_sequence_.emplace_back(std::move(kernel_thunk)); |
| return Status::OK(); |
| } |
| |
// Emits the kernel for the SliceToDynamic custom call (an lmhlo::CustomCallOp).
// args()[0] is the statically shaped source array and args()[1..] each hold
// the runtime size of one dynamic dimension. The kernel copies source elements
// into the single output buffer, then (from block 0 / thread 0 only) stores
// the dimension sizes as i32 metadata starting ByteSizeOf(data_shape) bytes
// into the output buffer.
| // Input = {static array, dynamic_dim0, dynamic_dim1} |
| // Output = {dynamic array (with dynamic dimension metadata at the end)} |
| Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { |
| // TODO(jurahul): Create an op to represent SliceToDynamic. |
| auto slice_to_dynamic = mlir::cast<mlir::lmhlo::CustomCallOp>(op); |
| int unroll_factor = 1; |
| std::string ir_name = mlir::GetNameFromLoc(slice_to_dynamic.getLoc()); |
| |
| const Shape& input_shape = GetShape(slice_to_dynamic.args().front()); |
| TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions( |
| input_shape, ir_emitter_context_->gpu_device_info(), |
| {unroll_factor})); |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(auto kernel_thunk, |
| BuildKernelThunk(slice_to_dynamic, GetThunkInfo(op), |
| &ir_arrays, launch_dimensions)); |
| |
| TF_RET_CHECK(slice_to_dynamic.output().size() == 1); |
| const Shape& data_shape = GetShape(slice_to_dynamic.output().front()); |
| |
| // TODO(jurahul): data_shape here is the static shape of the output (which has |
| // a dynamic shape in XLA). Currently, we are mapping that to a static shaped |
| // memref. When we change that to a more appropriate representation in MLIR, |
| // fix this code to correctly deduce the static shape backing the dynamically |
| // shaped memref. |
| |
| // Calculate the location where the metadata needs to be inserted: |
| // int* dyn_dim0_size = dest_array + meta_data_offset; |
| // int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int); |
// NOTE(review): ShapeUtil::ByteSizeOf's result is narrowed to 32 bits here.
| int32_t raw_data_size = ShapeUtil::ByteSizeOf(data_shape); |
| |
| // Pseudo code for SliceToDynamic on a 2d array: |
| // int* source_array = input[0]; |
| // int* dest_array = output[0]; |
// ir_arrays.back() is the output; ir_arrays[i] for i < args().size() lines up
// with args()[i] (see the loads below).
| const llvm_ir::IrArray data_array = ir_arrays.back(); |
| llvm::Value* dest_buffer = data_array.GetBasePointer(); |
| llvm::Value* raw_buffer = |
| b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo()); |
| |
| // Load dynamic dimensions from memory. |
// Presumably each size arg is an s32 scalar, since it is later stored through
// an i32 pointer -- TODO confirm.
| std::vector<llvm::Value*> dynamic_dims; |
| for (int64_t i = 1; i < slice_to_dynamic.args().size(); ++i) { |
| // args()[i] holds the size of dynamic dimension i - 1. |
| llvm::Value* source_buffer = ir_arrays[i].GetBasePointer(); |
| llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size"); |
| dynamic_dims.push_back(dyn_dim_size); |
| } |
| |
| // Only one thread needs to store the dynamic dimension sizes: |
| // int thread_id = GetThreadId(); |
| // int block_id = GetBlockId(); |
| // if (thread_id == 0 && block_id == 0) { |
| // *dyn_dim0_size = *input[1]; |
| // *dyn_dim1_size = *input[2]; |
| // } |
| KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] { |
| for (int64_t i = 1; i < slice_to_dynamic.args().size(); ++i) { |
| const int64_t dim_index = i - 1; |
| llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( |
| b_.getInt8Ty(), raw_buffer, |
| raw_data_size + dim_index * sizeof(int32)); |
| // Metadata slot dim_index receives the size loaded from args()[i]. |
| b_.CreateStore( |
| dynamic_dims[dim_index], |
| b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); |
| } |
| }); |
| |
| // int dyn_element_total = 1; |
| // dyn_element_total *= dyn_dim0_size; |
| // dyn_element_total *= dyn_dim1_size; |
| llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1); |
| for (llvm::Value* dynamic_dim : dynamic_dims) { |
| dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim, |
| /*Name=*/"dyn_element_total"); |
| } |
| |
| // linear_index = block_id * threads_per_block + thread_id; |
| // if (linear_index < max_num_element) { |
| // Index static_index = |
| // delinearize(linear_index, static_dim0_size, static_dim1_size); |
| // if (linear_index < dyn_element_total) { |
| // Index dyn_index = |
| // delinearize(linear_index, *dyn_dim0_size, *dyn_dim1_size); |
| // dest_array[static_index.dim0][static_index.dim1] = |
| // source_array[dyn_index.dim0][dyn_index.dim1]; |
| // } |
| // } |
// NOTE(review): the loop iterates over data_shape but linearizes with
// input_shape's dimensions; this assumes both have the same static dims.
| llvm_ir::LoopEmitter::BodyEmitter body_generator = |
| [&](const llvm_ir::IrArray::Index& array_index) -> Status { |
| llvm::Value* linearIndex = |
| array_index.Linearize(input_shape.dimensions(), &b_); |
| auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse( |
| b_.CreateICmpULT(linearIndex, dyn_element_total), |
| llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false); |
| // Set IR builder insertion point to the body of the if structure. |
| llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_); |
| llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape, |
| absl::MakeSpan(dynamic_dims), &b_); |
| |
| data_array.EmitWriteArrayElement( |
| array_index, |
| ir_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"", |
| /*use_linear_index=*/false), |
| &b_); |
| return Status::OK(); |
| }; |
| |
| TF_RETURN_IF_ERROR( |
| ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_, |
| {unroll_factor}) |
| .EmitLoop(ir_name, GetIndexTypeForKernel( |
| slice_to_dynamic, |
| launch_dimensions.launch_bound(), &b_))); |
| thunk_sequence_.emplace_back(std::move(kernel_thunk)); |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { |
| using mlir::dyn_cast; |
| using mlir::lmhlo_gpu::Activation; |
| using mlir::lmhlo_gpu::ConvBackwardFilterOp; |
| using mlir::lmhlo_gpu::ConvBackwardInputOp; |
| using mlir::lmhlo_gpu::ConvForwardFusedOp; |
| using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; |
| using mlir::lmhlo_gpu::ConvForwardOp; |
| |
| // Last 2 operands of the convolution operation are the result and scratch. |
| std::vector<BufferAllocation::Slice> operand_slices; |
| int64_t num_operands = op->getNumOperands(); |
| operand_slices.reserve(num_operands - 2); |
| for (mlir::Value operand : op->getOperands().drop_back(2)) { |
| TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand)); |
| operand_slices.push_back(slice); |
| } |
| |
| mlir::Value conv_result = op->getOperand(num_operands - 2); |
| mlir::Value scratch_result = op->getOperand(num_operands - 1); |
| TF_ASSIGN_OR_RETURN(auto conv_result_slice, GetAllocationSlice(conv_result)); |
| TF_ASSIGN_OR_RETURN(auto scratch_slice, GetAllocationSlice(scratch_result)); |
| |
| auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) { |
| mlir::SmallVector<int64_t, 4> minor_to_major = llvm::to_vector<4>( |
| llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64_t { |
| return static_cast<int64_t>(a.cast<mlir::IntegerAttr>().getInt()); |
| })); |
| return ShapeUtil::MakeShapeWithLayout(shape.element_type(), |
| shape.dimensions(), minor_to_major); |
| }; |
| |
| GpuConvDescriptor descriptor; |
| |
| auto fill_conv_descriptor = [&](auto op) { |
| descriptor.operand0_shape = apply_layout( |
| GetShape(op->getOperand(0)), op.backend_config().operand_0_layout()); |
| descriptor.operand1_shape = apply_layout( |
| GetShape(op->getOperand(1)), op.backend_config().operand_1_layout()); |
| descriptor.result_shape = apply_layout(GetShape(conv_result), |
| op.backend_config().result_layout()); |
| descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers()); |
| descriptor.scratch_size = scratch_slice.size(); |
| mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue(); |
| mlir::DenseIntElementsAttr padding = op.padding().getValue(); |
| mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue(); |
| mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue(); |
| mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue(); |
| for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) { |
| WindowDimension* dim = descriptor.window.add_dimensions(); |
| // Window size for a convolution is the same as the kernel size. |
| // Kernel size of the convolution is operand1_shape. We need to look at |
| // the convolution dimension numbers kernel spatial dimensions to get |
| // the window size. |
| int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index); |
| dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim)); |
| dim->set_stride(window_strides.getValue<int64_t>(index)); |
| dim->set_padding_low(padding.getValue<int64_t>(index)); |
| dim->set_padding_high(padding.getValue<int64_t>(index)); |
| dim->set_base_dilation(lhs_dilation.getValue<int64_t>(index)); |
| dim->set_window_dilation(rhs_dilation.getValue<int64_t>(index)); |
| dim->set_window_reversal(window_reversal.getValue<bool>(index)); |
| } |
| descriptor.feature_group_count = op.feature_group_count(); |
| descriptor.backend_config.set_algorithm( |
| op.backend_config().algorithm().getInt()); |
| descriptor.backend_config.set_tensor_ops_enabled( |
| op.backend_config().tensor_ops_enabled().getValue()); |
| descriptor.backend_config.set_conv_result_scale( |
| op.result_scale().convertToDouble()); |
| }; |
| |
| auto set_activation_mode = [&](auto op) -> Status { |
| TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode, |
| ConvertConvActivationMode(op.activation_mode())); |
| descriptor.backend_config.set_activation_mode( |
| static_cast<int64_t>(activation_mode)); |
| return Status::OK(); |
| }; |
| |
| if (auto conv = dyn_cast<ConvForwardOp>(op)) { |
| descriptor.kind = CudnnConvKind::kForward; |
| fill_conv_descriptor(conv); |
| } else if (auto conv = dyn_cast<ConvBackwardInputOp>(op)) { |
| descriptor.kind = CudnnConvKind::kBackwardInput; |
| fill_conv_descriptor(conv); |
| } else if (auto conv = dyn_cast<ConvBackwardFilterOp>(op)) { |
| descriptor.kind = CudnnConvKind::kBackwardFilter; |
| fill_conv_descriptor(conv); |
| } else if (auto conv = dyn_cast<ConvForwardFusedOp>(op)) { |
| descriptor.kind = CudnnConvKind::kForwardActivation; |
| fill_conv_descriptor(conv); |
| TF_RETURN_IF_ERROR(set_activation_mode(conv)); |
| } else if (auto conv = dyn_cast<ConvForwardFusedSideInputOp>(op)) { |
| descriptor.kind = CudnnConvKind::kForwardActivation; |
| fill_conv_descriptor(conv); |
| TF_RETURN_IF_ERROR(set_activation_mode(conv)); |
| descriptor.backend_config.set_side_input_scale( |
| conv.side_input_scale().convertToDouble()); |
| } else { |
| return InternalError("Unexpected operation"); |
| } |
| TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, "")); |
| AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>( |
| GetThunkInfo(op), std::move(config), std::move(operand_slices), |
| conv_result_slice, scratch_slice)); |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) { |
| auto make_bef_thunk = |
| [&](auto op, absl::optional<BufferAllocation::Slice> bias = |
| absl::nullopt) -> StatusOr<std::unique_ptr<Thunk>> { |
| TF_ASSIGN_OR_RETURN(auto lhs, GetAllocationSlice(op.lhs())); |
| TF_ASSIGN_OR_RETURN(auto rhs, GetAllocationSlice(op.rhs())); |
| TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(op.output())); |
| std::vector<BufferAllocation::Slice> buffers = {lhs, rhs}; |
| if (bias.has_value()) { |
| buffers.push_back(bias.value()); |
| } |
| buffers.push_back(output); |
| return CreateBefThunk(GetThunkInfo(op), op, std::move(buffers)); |
| }; |
| |
| auto make_gemm_thunk = |
| [&](auto op, absl::optional<double> gemm_bias_beta = absl::nullopt, |
| bool implements_whole_instruction = |
| true) -> StatusOr<std::unique_ptr<Thunk>> { |
| TF_ASSIGN_OR_RETURN(auto lhs, GetAllocationSlice(op.lhs())); |
| TF_ASSIGN_OR_RETURN(auto rhs, GetAllocationSlice(op.rhs())); |
| TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(op.output())); |
| |
| GpuGemmConfig config; |
| GemmBackendConfig& backend = config.backend_config; |
| config.output_shape = GetShape(op.output()); |
| config.lhs_shape = GetShape(op.lhs()); |
| config.rhs_shape = GetShape(op.rhs()); |
| backend.Clear(); |
| if (op.algorithm()) { |
| backend.set_selected_algorithm(*op.algorithm()); |
| } |
| backend.set_alpha_real(op.alpha_real().convertToDouble()); |
| backend.set_alpha_imag(op.alpha_imag().convertToDouble()); |
| backend.set_batch_size(op.batch_size()); |
| if (gemm_bias_beta.has_value()) { |
| backend.set_beta(gemm_bias_beta.value()); |
| } |
| backend.set_lhs_stride(op.lhs_stride()); |
| backend.set_rhs_stride(op.rhs_stride()); |
| |
| auto& dims = *backend.mutable_dot_dimension_numbers(); |
| auto mlir_dims = op.dot_dimension_numbers(); |
| |
| auto fill_dims = [](llvm::ArrayRef<int64_t> mlir_dim, auto* config_attrs) { |
| for (int64_t e : mlir_dim) config_attrs->Add(e); |
| }; |
| fill_dims(mlir_dims.getLhsBatchingDimensions(), |
| dims.mutable_lhs_batch_dimensions()); |
| fill_dims(mlir_dims.getRhsBatchingDimensions(), |
| dims.mutable_rhs_batch_dimensions()); |
| fill_dims(mlir_dims.getLhsContractingDimensions(), |
| dims.mutable_lhs_contracting_dimensions()); |
| fill_dims(mlir_dims.getRhsContractingDimensions(), |
| dims.mutable_rhs_contracting_dimensions()); |
| |
| return std::unique_ptr<Thunk>( |
| new GemmThunk(GetThunkInfo(op), std::move(config), lhs, rhs, output, |
| implements_whole_instruction)); |
| }; |
| |
| TF_ASSIGN_OR_RETURN(auto thunk, [&]() -> StatusOr<std::unique_ptr<Thunk>> { |
| if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(op)) { |
| // TODO(loreno): TFRT support for zero-strided gemm calls |
| if (IsBefThunkEnabled() && gemm.lhs_stride() && gemm.rhs_stride()) |
| return make_bef_thunk(gemm); |
| return make_gemm_thunk(gemm); |
| } |
| |
| if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(op)) { |
| double gemm_bias_beta = gemm.beta().convertToDouble(); |
| TF_ASSIGN_OR_RETURN(auto bias, GetAllocationSlice(gemm.bias())); |
| TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(gemm.output())); |
| |
| // TODO(loreno): TFRT support for zero-strided gemm calls |
| if (IsBefThunkEnabled() && gemm.lhs_stride() && gemm.rhs_stride()) |
| return make_bef_thunk(gemm, bias); |
| |
| // The bias is passed inside the output buffer. If those buffers are |
| // shared we can just use it, otherwise copy the bias values into the |
| // output buffer first. |
| if (bias == output) { |
| return make_gemm_thunk(gemm, gemm_bias_beta); |
| } |
| |
| ThunkSequence thunks; |
| thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |
| Thunk::ThunkInfo(), |
| /*source_buffer=*/bias, |
| /*destination_buffer=*/output, |
| /*mem_size=*/ |
| ShapeUtil::ByteSizeOf(GetShape(gemm.output())))); |
| TF_ASSIGN_OR_RETURN( |
| auto thunk, make_gemm_thunk(gemm, gemm_bias_beta, |
| /*implements_whole_instruction=*/false)); |
| thunks.push_back(std::move(thunk)); |
| return std::unique_ptr<Thunk>( |
| new SequentialThunk(GetThunkInfo(op), std::move(thunks))); |
| } |
| |
| return tensorflow::errors::Internal("Unexpected op."); |
| }()); |
| |
| AddThunkToThunkSequence(std::move(thunk)); |
| return Status::OK(); |
| } |
| |
| namespace { |
| // An MLIR value and its name as defined in the ODS spec. |
| struct NamedValue { |
| mlir::Value value; |
| absl::string_view name; |
| }; |
| |
| // Verifies that the given batch norm is well formed for thunk emission. This |
| // requires that all statistics operands (mean, stddev etc) are F32 types and |
| // all the non-statistics operands need to match in shape, element type, and |
| // layout (which maps to them having the same memref type). |
| Status VerifyBatchNormForThunkEmission( |
| mlir::ArrayRef<NamedValue> statistics_operands, |
| mlir::ArrayRef<NamedValue> other_operands) { |
| for (const NamedValue& v : statistics_operands) { |
| // Note: MLIR verification will ensure that the operands of the batchnorm |
| // LHLO are valid memref types. |
| if (!v.value.getType().cast<mlir::MemRefType>().getElementType().isF32()) { |
| return Unimplemented("Operand %s of batch norm should have F32 type", |
| v.name); |
| } |
| } |
| if (other_operands.empty()) { |
| return Status::OK(); |
| } |
| |
| mlir::Type first_type = other_operands.front().value.getType(); |
| absl::string_view first_name = other_operands.front().name; |
| |
| for (const NamedValue& v : other_operands.drop_front(1)) { |
| if (v.value.getType() != first_type) { |
| return Unimplemented("%s and %s for batch norm should have same types", |
| v.name, first_name); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Determine if we enable the row optimized codegen. When we have a |
| // fusion with only point-wise operations, scalar broadcasting and row |
| // broadcasting, we can trigger a kernel that vectorize the row loads. |
| // This speed up the kernel, in particular on A100. |
| // Returns a pair<bool, int>. The bool mean should we try to enable |
| // row vectorization. The int is the number of inputs with the higher |
| // rank. |
| std::pair<bool, int> RowVectorizationEnabled(mlir::lmhlo::FusionOp fusion) { |
| const auto is_row_major = [](mlir::Value value) { |
| // Only tested when the inputs are row-major. So only |
| // enable that case. Maybe it would works if only the |
| // inner dimensions is contiguous. |
| return LayoutUtil::IsMonotonicWithDim0Major(GetShape(value).layout()); |
| }; |
| bool row_vectorized = |
| fusion.getFusionResults().size() == 1 && // Not tested with MOF. |
| absl::c_all_of(GetHloOperands(fusion), is_row_major) && |
| absl::c_all_of(GetHloOutputs(fusion), is_row_major); |
| |
| // Check that the operations in the fusion are supported. Each |
| // supported operation (or category) must be manually vetted as XLA |
| // only unrolls and relies on LLVM to vectorize. But this is brittle. |
| // Currently tested and supported operations: |
| // Elementwise, scalar and row broadcasting. |
| // |
| // We also detect at the same time if there is a row broadcasting |
| // operation. |
| bool some_row_broadcasting = false; |
| auto out_rank = |
| fusion.getFusionResults()[0].getType().cast<mlir::ShapedType>().getRank(); |
| int num_big_inputs = 0; |
| for (mlir::Operation& op : fusion.region().front()) { |
| if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) { |
| auto rank = load.getResult().getType().cast<mlir::ShapedType>().getRank(); |
| num_big_inputs += static_cast<int>(rank == out_rank); |
| continue; |
| } else if (mlir::isa<mlir::memref::TensorStoreOp, mlir::lmhlo::TerminatorOp, |
| mlir::mhlo::ReturnOp, mlir::mhlo::ConstOp, |
| mlir::lmhlo::ConstOp>(op)) { |
| continue; |
| } |
| HloOpcode opcode = *MhloToHloOpcode(&op); |
| if (HloInstruction::IsOpElementwise(opcode)) { |
| continue; |
| } |
| |
| if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) { |
| if (broadcast.broadcast_dimensions().size() == 0) { |
| continue; |
| } |
| std::vector<int64_t> broadcast_dimensions; |
| for (const llvm::APInt& int_value : broadcast.broadcast_dimensions()) { |
| broadcast_dimensions.push_back(int_value.getSExtValue()); |
| } |
| |
| auto rank = GetShape(broadcast.getResult()).rank(); |
| if (broadcast_dimensions.size() == 1 && |
| broadcast_dimensions.back() == (rank - 1)) { |
| some_row_broadcasting = true; |
| continue; |
| } |
| } |
| VLOG(2) << "Row vectorization not enabled due to this op: " |
| << MlirToString(&op); |
| return std::make_pair(false, 0); |
| } |
| // Trigger only when there is a row broadcasting. |
| return std::make_pair(row_vectorized && some_row_broadcasting, |
| num_big_inputs); |
| } |
| } // namespace |
| |
| Status IrEmitterUnnested::EmitBatchNormThunk(mlir::Operation* op) { |
| auto get_batch_norm_config = [](auto op, mlir::Value output) { |
| CudnnBatchNormConfig config; |
| config.output_shape = GetShape(output); |
| config.output_type = config.output_shape.element_type(); |
| config.epsilon = op.epsilon().convertToFloat(); |
| config.feature_index = op.feature_index(); |
| return config; |
| }; |
| |
| // The statistics operands for batch norm operations need to be FP32 type. |
| // And the rest of the operands need match in shape, layout, and element type |
| // to match. |
| if (auto bn_train = |
| mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormTrainingOp>(op)) { |
| TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission( |
| /*statistics_operands=*/ |
| {{bn_train.scale(), "scale"}, |
| {bn_train.offset(), "offset"}, |
| {bn_train.batch_mean(), "batch_mean"}, |
| {bn_train.batch_stddev(), "batch_stddev"}}, |
| /*other_operands=*/ |
| {{bn_train.operand(), "operand"}, {bn_train.output(), "output"}})); |
| TF_ASSIGN_OR_RETURN(auto operand, GetAllocationSlice(bn_train.operand())); |
| TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_train.scale())); |
| TF_ASSIGN_OR_RETURN(auto offset, GetAllocationSlice(bn_train.offset())); |
| |
| // BatchNormTraining returns a tuple of three elements: data, calculated |
| // mean, and calculated 1/sqrt(variance + epsilon). |
| TF_ASSIGN_OR_RETURN(auto output_data, |
| GetAllocationSlice(bn_train.output())); |
| TF_ASSIGN_OR_RETURN(auto output_mean, |
| GetAllocationSlice(bn_train.batch_mean())); |
| TF_ASSIGN_OR_RETURN(auto output_inv_stddev, |
| GetAllocationSlice(bn_train.batch_stddev())); |
| |
| AddThunkToThunkSequence( |
| absl::make_unique<CudnnBatchNormForwardTrainingThunk>( |
| GetThunkInfo(op), |
| /*config=*/get_batch_norm_config(bn_train, bn_train.output()), |
| /*operand=*/operand, |
| /*scale=*/scale, |
| /*offset=*/offset, |
| /*output_data=*/output_data, |
| /*output_mean=*/output_mean, |
| /*output_inv_stddev=*/output_inv_stddev)); |
| return Status::OK(); |
| } |
| |
| if (auto bn_grad = mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormGradOp>(op)) { |
| TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission( |
| /*statistics_operands=*/ |
| {{bn_grad.scale(), "scale"}, |
| {bn_grad.mean(), "mean"}, |
| {bn_grad.stddev(), "stddev"}, |
| {bn_grad.grad_scale(), "grad_scale"}, |
| {bn_grad.grad_offset(), "grad_offset"}}, |
| /*other_operands=*/ |
| {{bn_grad.operand(), "operand"}, |
| {bn_grad.grad_output(), "grad_output"}, |
| {bn_grad.grad_operand(), "grad_operand"}})); |
| |
| TF_ASSIGN_OR_RETURN(auto operand, GetAllocationSlice(bn_grad.operand())); |
| TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_grad.scale())); |
| TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSlice(bn_grad.mean())); |
| TF_ASSIGN_OR_RETURN(auto inv_stddev, GetAllocationSlice(bn_grad.stddev())); |
| TF_ASSIGN_OR_RETURN(auto grad_output, |
| GetAllocationSlice(bn_grad.grad_output())); |
| |
| // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale, |
| // grad_offset. |
| TF_ASSIGN_OR_RETURN(auto output_grad_data, |
| GetAllocationSlice(bn_grad.grad_operand())); |
| TF_ASSIGN_OR_RETURN(auto output_grad_scale, |
| GetAllocationSlice(bn_grad.grad_scale())); |
| TF_ASSIGN_OR_RETURN(auto output_grad_offset, |
| GetAllocationSlice(bn_grad.grad_offset())); |
| |
| CudnnBatchNormConfig config; |
| config.output_shape = GetShape(bn_grad.grad_output()); |
| config.output_type = config.output_shape.element_type(); |
| config.epsilon = bn_grad.epsilon().convertToFloat(); |
| config.feature_index = bn_grad.feature_index(); |
| |
| AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>( |
| GetThunkInfo(op), |
| /*config=*/get_batch_norm_config(bn_grad, bn_grad.grad_output()), |
| /*operand=*/operand, |
| /*scale=*/scale, |
| /*mean=*/mean, |
| /*inv_stddev=*/inv_stddev, |
| /*grad_output=*/grad_output, |
| /*output_grad_data=*/output_grad_data, |
| /*output_grad_scale=*/output_grad_scale, |
| /*output_grad_offset=*/output_grad_offset)); |
| return Status::OK(); |
| } |
| |
| if (auto bn_inference = |
| mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormInferenceOp>(op)) { |
| TF_RETURN_IF_ERROR( |
| VerifyBatchNormForThunkEmission(/*statistics_operands=*/ |
| {{bn_inference.scale(), "scale"}, |
| {bn_inference.offset(), "offset"}, |
| {bn_inference.mean(), "mean"}, |
| {bn_inference.stddev(), "stddev"}}, |
| /*other_operands=*/ |
| {{bn_inference.operand(), "operand"}, |
| {bn_inference.output(), "output"}})); |
| |
| TF_ASSIGN_OR_RETURN(auto operand, |
| GetAllocationSlice(bn_inference.operand())); |
| TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_inference.scale())); |
| TF_ASSIGN_OR_RETURN(auto offset, GetAllocationSlice(bn_inference.offset())); |
| TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSlice(bn_inference.mean())); |
| TF_ASSIGN_OR_RETURN(auto variance, |
| GetAllocationSlice(bn_inference.stddev())); |
| TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(bn_inference.output())); |
| |
| AddThunkToThunkSequence(absl::make_unique< |
| CudnnBatchNormForwardInferenceThunk>( |
| GetThunkInfo(op), |
| /*config=*/get_batch_norm_config(bn_inference, bn_inference.output()), |
| /*operand=*/operand, |
| /*scale=*/scale, |
| /*offset=*/offset, |
| /*mean=*/mean, |
| /*variance=*/variance, |
| /*output=*/output)); |
| return Status::OK(); |
| } |
| |
| return Unimplemented("Unsupported batch norm operation"); |
| } |
| |
| #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) { |
| auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op); |
| |
| const Shape shape = GetShape(cholesky_op.input()); |
| int ndim = shape.dimensions_size(); |
| CHECK_GE(ndim, 2); |
| int64_t n = shape.dimensions(ndim - 1); |
| |
| const auto& dims = shape.dimensions(); |
| int64_t batch_size = |
| std::accumulate(dims.begin(), dims.end() - 2, int64_t{1}, |
| [](int64_t a, int64_t b) { return a * b; }); |
| |
| TF_ASSIGN_OR_RETURN(auto operand_buffer, |
| GetAllocationSlice(cholesky_op.input())); |
| TF_ASSIGN_OR_RETURN(auto a_buffer, GetAllocationSlice(cholesky_op.output())); |
| TF_ASSIGN_OR_RETURN(auto workspace_buffer, |
| GetAllocationSlice(cholesky_op.scratch())); |
| TF_ASSIGN_OR_RETURN(auto info_buffer, GetAllocationSlice(cholesky_op.info())); |
| |
| if (IsBefThunkEnabled()) { |
| std::vector<BufferAllocation::Slice> buffers = { |
| operand_buffer, a_buffer, workspace_buffer, info_buffer}; |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<Thunk> thunk, |
| CreateBefThunk(GetThunkInfo(op), op, std::move(buffers))); |
| AddThunkToThunkSequence(std::move(thunk)); |
| return Status::OK(); |
| } |
| |
| ThunkSequence thunks; |
| |
| if (operand_buffer != a_buffer) { |
| thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |
| GetThunkInfo(op), |
| /*source_address=*/operand_buffer, |
| /*destination_buffer=*/a_buffer, |
| /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); |
| } |
| |
| CholeskyOptions options; |
| options.set_lower(cholesky_op.is_lower()); |
| thunks.push_back(absl::make_unique<CholeskyThunk>( |
| GetThunkInfo(op), options, a_buffer, workspace_buffer, info_buffer, |
| shape.element_type(), batch_size, n)); |
| |
| // Elide the sequential thunk if there's no copy. |
| if (thunks.size() == 1) { |
| AddThunkToThunkSequence(std::move(thunks[0])); |
| } else { |
| AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( |
| GetThunkInfo(op), std::move(thunks))); |
| } |
| |
| return Status::OK(); |
| } |
| #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| |
| Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) { |
| auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op); |
| const std::string call_target_name = custom_call.call_target_name().str(); |
| |
| void* call_target = CustomCallTargetRegistry::Global()->Lookup( |
| call_target_name, std::string(platform_name())); |
| if (!call_target) { |
| return Unimplemented( |
| "No registered implementation for custom call to \"%s\"", |
| call_target_name); |
| } |
| |
| std::vector<CustomCallThunk::OptionalSlice> operands; |
| std::vector<CustomCallThunk::OptionalSlice> results; |
| |
| if (custom_call.target_arg_mapping()) { |
| auto values_to_slices_with_token_holes = |
| [&](mlir::ValueRange operands, mlir::ArrayAttr op_to_target_mapping, |
| mlir::IntegerAttr num_target) |
| -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> { |
| std::vector<CustomCallThunk::OptionalSlice> slices(num_target.getInt()); |
| for (auto index_and_value_it : |
| llvm::zip(op_to_target_mapping, operands)) { |
| mlir::Attribute index_attr = std::get<0>(index_and_value_it); |
| mlir::Value value = std::get<1>(index_and_value_it); |
| int64_t index = index_attr.cast<mlir::IntegerAttr>().getInt(); |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, |
| GetAllocationSlice(value)); |
| slices[index] = slice; |
| } |
| return slices; |
| }; |
| |
| mlir::lmhlo::CustomCallTargetArgMapping target_mapping = |
| *custom_call.target_arg_mapping(); |
| TF_ASSIGN_OR_RETURN( |
| operands, values_to_slices_with_token_holes( |
| custom_call.args(), target_mapping.args_to_target_args(), |
| target_mapping.num_args())); |
| TF_ASSIGN_OR_RETURN(results, values_to_slices_with_token_holes( |
| custom_call.output(), |
| target_mapping.results_to_target_results(), |
| target_mapping.num_results())); |
| } else { |
| auto values_to_slices = [&](mlir::ValueRange values) |
| -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> { |
| std::vector<CustomCallThunk::OptionalSlice> slices; |
| for (mlir::Value value : values) { |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, |
| GetAllocationSlice(value)); |
| slices.push_back(slice); |
| } |
| return slices; |
| }; |
| |
| TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.args())); |
| TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.output())); |
| } |
| |
| CustomCallThunk::CustomCallTarget custom_call_target; |
| |
| // TODO(hanbinyoon): Move this to a location that will serve both |
| // ir_emitter_unnested and BEF Executable. |
| // For information about this calling convention, see |
| // xla/g3doc/custom_call.md. |
| switch (custom_call.api_version()) { |
| case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL: |
| using original_call_type = |
| void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, |
| const char* /*opaque*/, size_t /*opaque_len*/); |
| custom_call_target = [call_target](CustomCallThunk::Stream stream, |
| void** buffers, const char* opaque, |
| size_t opaque_len, |
| XlaCustomCallStatus*) { |
| auto typed_call_target = |
| reinterpret_cast<original_call_type>(call_target); |
| typed_call_target(stream, buffers, opaque, opaque_len); |
| }; |
| break; |
| case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING: |
| using status_returning_call_type = |
| void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, |
| const char* /*opaque*/, size_t /*opaque_len*/, |
| XlaCustomCallStatus* /*status*/); |
| custom_call_target = |
| reinterpret_cast<status_returning_call_type>(call_target); |
| break; |
| default: |
| return InternalError("Unknown custom-call API version enum value: %d", |
| custom_call.api_version()); |
| } |
| |
| std::unique_ptr<Thunk> thunk; |
| if (IsBefThunkEnabled()) { |
| auto values_to_non_optional_slices = [&](mlir::ValueRange values) |
| -> StatusOr<std::vector<BufferAllocation::Slice>> { |
| std::vector<BufferAllocation::Slice> slices; |
| for (mlir::Value value : values) { |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, |
| GetAllocationSlice(value)); |
| slices.push_back(slice); |
| } |
| return slices; |
| }; |
| |
| TF_ASSIGN_OR_RETURN(std::vector<BufferAllocation::Slice> inputs, |
| values_to_non_optional_slices(custom_call.args())); |
| TF_ASSIGN_OR_RETURN(std::vector<BufferAllocation::Slice> outputs, |
| values_to_non_optional_slices(custom_call.output())); |
| std::vector<BufferAllocation::Slice> buffers; |
| buffers.reserve(inputs.size() + outputs.size()); |
| for (const auto& buffer : inputs) { |
| buffers.push_back(buffer); |
| } |
| for (const auto& buffer : outputs) { |
| buffers.push_back(buffer); |
| } |
| TF_ASSIGN_OR_RETURN(thunk, CreateBefCustomCallThunk( |
| GetThunkInfo(op), op, std::move(buffers), |
| std::move(custom_call_target))); |
| } else { |
| thunk = absl::make_unique<CustomCallThunk>( |
| GetThunkInfo(op), std::move(custom_call_target), std::move(operands), |
| std::move(results), custom_call.backend_config().str()); |
| } |
| AddThunkToThunkSequence(std::move(thunk)); |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) { |
| auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(op); |
| const Shape operand_shape = GetShape(fft_op.operand()); |
| const Shape output_shape = GetShape(fft_op.output()); |
| TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout())); |
| TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())); |
| |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice, |
| GetAllocationSlice(fft_op.operand())); |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice, |
| GetAllocationSlice(fft_op.output())); |
| TF_ASSIGN_OR_RETURN(xla::FftType fft_type, ConvertFftType(fft_op.fft_type())); |
| auto fft_length_values = fft_op.fft_length().getValues<int64_t>(); |
| std::vector<int64_t> fft_length(fft_length_values.begin(), |
| fft_length_values.end()); |
| AddThunkToThunkSequence( |
| absl::make_unique<FftThunk>(GetThunkInfo(op), fft_type, fft_length, |
| /*input_buffer=*/arg_slice, |
| /*output_buffer=*/dest_slice, |
| /*input_shape=*/operand_shape, |
| /*output_shape=*/output_shape)); |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitTriangularSolve(mlir::Operation* op) { |
| auto triangular_solve_op = mlir::cast<mlir::lmhlo::TriangularSolveOp>(op); |
| auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) { |
| int64_t n = layout_attr.getNumElements(); |
| return layout_attr.getValue<int64_t>({0}) == n - 2 && |
| layout_attr.getValue<int64_t>({1}) == n - 1; |
| }; |
| TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a())); |
| TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b())); |
| TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output())); |
| |
| const Shape b_shape = GetShape(triangular_solve_op.b()); |
| |
| const Shape output_shape = GetShape(triangular_solve_op.output()); |
| |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice, |
| GetAllocationSlice(triangular_solve_op.a())); |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice, |
| GetAllocationSlice(triangular_solve_op.b())); |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, |
| GetAllocationSlice(triangular_solve_op.output())); |
| TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a, |
| ConvertTranspose(triangular_solve_op.transpose_a())); |
| |
| if (IsBefThunkEnabled()) { |
| std::vector<BufferAllocation::Slice> buffers = {a_slice, b_slice, |
| output_slice}; |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<Thunk> thunk, |
| CreateBefThunk(GetThunkInfo(op), op, std::move(buffers))); |
| AddThunkToThunkSequence(std::move(thunk)); |
| return Status::OK(); |
| } |
| |
| ThunkSequence thunks; |
| |
| // Triangular solve is in-place on 'b', so copy 'b' to the output if they |
| // aren't the same buffer. |
| if (b_slice != output_slice) { |
| thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |
| Thunk::ThunkInfo(), |
| /*source_address=*/b_slice, |
| /*destination_buffer=*/output_slice, |
| /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape))); |
| } |
| |
| int64_t m = b_shape.dimensions(b_shape.rank() - 2); |
| int64_t n = b_shape.dimensions(b_shape.rank() - 1); |
| int64_t batch_size = std::accumulate( |
| b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1}, |
| [](int64_t a, int64_t b) { return a * b; }); |
| int64_t elem_size = |
| ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type()); |
| int64_t a_batch_stride = |
| triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size; |
| int64_t b_batch_stride = m * n * elem_size; |
| TriangularSolveOptions options; |
| options.set_left_side(triangular_solve_op.left_side()); |
| options.set_lower(triangular_solve_op.lower()); |
| options.set_unit_diagonal(triangular_solve_op.unit_diagonal()); |
| options.set_transpose_a(transpose_a); |
| thunks.push_back(absl::make_unique<TriangularSolveThunk>( |
| GetThunkInfo(op), options, |
| /*a_input_buffer=*/a_slice, |
| /*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size, |
| m, n, a_batch_stride, b_batch_stride)); |
| |
| // Elide the sequential thunk if there's no copy. |
| if (thunks.size() == 1) { |
| AddThunkToThunkSequence(std::move(thunks[0])); |
| } else { |
| AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( |
| GetThunkInfo(op), std::move(thunks))); |
| } |
| return Status::OK(); |
| } |
| |
| // Convert the following form of fusion region: |
| // fusion() { |
| // %0 = tensor_load %external_memref0 |
| // %1 = tensor_load %external_memref1 |
| // ... |
| // tensor_store %ret, %external_memref2 |
| // } |
| // to |
| // fusion(%external_memref0, %external_memref1) (^bb(%0, %1) { |
| // ... |
| // mhlo.return %ret |
| // }) |
| // |
| // So that it's suitable for MHLO -> XLA HLO conversion. |
| // This function won't be needed once ElementalIrEmitter migrates to take MHLO |
| // instead. |
| static Status ProcessFusionForConversion(mlir::Region* region, |
| std::vector<Shape>* operand_shapes, |
| std::vector<Shape>* output_shapes) { |
| std::vector<mlir::memref::TensorLoadOp> loads; |
| std::vector<mlir::memref::TensorStoreOp> stores; |
| |
| region->walk([&](mlir::memref::TensorLoadOp load) { |
| if (load.memref().getParentRegion() != region) { |
| loads.push_back(load); |
| } |
| }); |
| |
| region->walk([&](mlir::memref::TensorStoreOp store) { |
| if (store.memref().getParentRegion() != region) { |
| stores.push_back(store); |
| } |
| }); |
| |
| for (auto load : loads) { |
| auto arg = region->addArgument(load.getType()); |
| load.replaceAllUsesWith(arg); |
| Shape shape = GetShape(load.getResult()); |
| operand_shapes->push_back(std::move(shape)); |
| load.erase(); |
| } |
| |
| std::vector<mlir::Value> returned_values; |
| for (auto store : stores) { |
| Shape shape = GetShape(store.memref()); |
| output_shapes->push_back(shape); |
| |
| returned_values.push_back(store.tensor()); |
| store.erase(); |
| } |
| |
| region->back().back().erase(); |
| auto b = mlir::OpBuilder::atBlockEnd(®ion->back()); |
| auto loc = returned_values[0].getLoc(); |
| b.create<mlir::mhlo::ReturnOp>(loc, returned_values); |
| return Status::OK(); |
| } |
| |
| // TODO(timshen): update the comment once the HandleFusion code path deleted. |
| // |
| // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the |
| // subclass. The logic is de-virtualized and less scattered. |
| Status IrEmitterUnnested::EmitLoopFusion(mlir::Operation* op) { |
| auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op); |
| MlirEmitterContext context; |
| context.SetOperation(fusion); |
| |
| TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, |
| GetOrCreateSubComputationFromRegion(&fusion.region(), |
| /*is_fusion=*/true)); |
| |
| int unroll_factor; |
| if (!MayPreventVectorization(fusion)) { |
| unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_); |
| } else { |
| unroll_factor = 1; |
| } |
| |
| bool row_vectorized; |
| int num_big_inputs; |
| std::tie(row_vectorized, num_big_inputs) = RowVectorizationEnabled(fusion); |
| bool few_waves = [fusion, row_vectorized, num_big_inputs]() mutable { |
| for (mlir::Operation& op : fusion.region().front()) { |
| if (mlir::isa<mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp, |
| mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp, |
| mlir::mhlo::ConstOp>(op)) { |
| continue; |
| } |
| HloOpcode opcode = *MhloToHloOpcode(&op); |
| if (HloInstruction::IsOpElementwise(opcode)) { |
| continue; |
| } |
| if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) { |
| if (broadcast.broadcast_dimensions().empty() || |
| // More then 2 bit inputs cause one speed regression. |
| (row_vectorized && num_big_inputs <= 3)) { |
| continue; |
| } |
| } |
| VLOG(2) << "few_waves not enabled due to: " << MlirToString(&op); |
| return false; |
| } |
| return true; |
| }(); |
| |
| Shape element_shape = context.output_shapes[0]; |
| LaunchDimensionsConfig launch_config{unroll_factor, few_waves, |
| row_vectorized}; |
| // Check that the shapes is supported. |
| if (launch_config.row_vectorized && |
| ThreadsPerBlockRowVectorized(element_shape, |
| ir_emitter_context_->gpu_device_info(), |
| launch_config) <= 0) { |
| VLOG(2) << "Cancelling row_vectorization as the shape isn't supported."; |
| launch_config.row_vectorized = false; |
| launch_config.few_waves = false; |
| } |
| |
| TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions( |
| element_shape, ir_emitter_context_->gpu_device_info(), |
| launch_config)); |
| |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| Thunk* kernel_thunk; |
| { |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk_ptr, |
| BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays, |
| launch_dimensions)); |
| kernel_thunk = kernel_thunk_ptr.get(); |
| thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr)); |
| } |
| |
| auto operand_arrays = |
| absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()); |
| auto output_element_arrays = absl::MakeSpan(ir_arrays).subspan( |
| context.operand_shapes.size(), context.output_shapes.size()); |
| |
| GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, |
| GetNestedComputer()); |
| FusedIrEmitter fused_emitter(&elemental_emitter); |
| |
| for (int i = 0; i < context.operand_shapes.size(); i++) { |
| auto* builder = &b_; |
| auto ir_array = operand_arrays[i]; |
| fused_emitter.BindGenerator( |
| fused_computation->parameter_instruction(i), |
| [builder, ir_array](llvm_ir::IrArray::Index index) { |
| return ir_array.EmitReadArrayElement(index, builder); |
| }); |
| } |
| TF_ASSIGN_OR_RETURN( |
| auto element_generator, |
| fused_emitter.GetGenerator(fused_computation->root_instruction())); |
| |
| llvm::Type* index_type = |
| GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_); |
| |
| if (context.output_shapes.size() > 1) { |
| // For multioutput fusion, we need to emit each operand and the root. |
| TF_RETURN_IF_ERROR( |
| ParallelLoopEmitter(element_generator, output_element_arrays, |
| launch_dimensions, &b_, launch_config) |
| .EmitLoop(context.name, index_type)); |
| } else { |
| TF_RETURN_IF_ERROR( |
| ParallelLoopEmitter(element_generator, output_element_arrays[0], |
| launch_dimensions, &b_, launch_config) |
| .EmitLoop(context.name, index_type)); |
| } |
| |
| b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); |
| return Status::OK(); |
| } |
| |
| // Returns whether any of the rooots of the fusion are unnested reductions. |
| static bool HasAnyUnnestedReductionRoot(mlir::lmhlo::FusionOp fusion) { |
| return absl::c_any_of(fusion.getFusionRoots(), [&](mlir::Operation* op) { |
| return IsReductionFromOrToContiguousDimensions(op); |
| }); |
| } |
| |
| Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { |
| auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(op); |
| const bool is_single_instruction = IsSingleInstructionFusion(fusion_op); |
| |
| if (HasAnyUnnestedReductionRoot(fusion_op)) { |
| return EmitUnnestedReduction(fusion_op); |
| } |
| |
| llvm::SmallVector<mlir::Value, 6> fusion_results = |
| fusion_op.getFusionResults(); |
| TF_RET_CHECK(!fusion_results.empty()); |
| if (fusion_results.size() > 1) { |
| // In the case of root tuple, it can be either reduce or slice input |
| // fusion. |
| if (IsInputFusibleSlices(op, /*verify_no_strides=*/true)) { |
| // The emitter doesn't support all cases. If it's not supported, fallback |
| // to ElementalIrEmitter. |
| auto status = EmitInputFusibleNonStridedSlices(op); |
| if (status.code() == tensorflow::error::FAILED_PRECONDITION) { |
| return EmitLoopFusion(op); |
| } |
| return status; |
| } |
| } |
| |
| mlir::Operation* fusion_root = fusion_results[0].getDefiningOp(); |
| if (mlir::isa<mlir::mhlo::ScatterOp>(fusion_root)) { |
| TF_ASSIGN_OR_RETURN( |
| const HloComputation* fused_computation, |
| GetOrCreateSubComputationFromRegion(&fusion_op.region(), |
| /*is_fusion=*/true)); |
| auto* root = fused_computation->root_instruction(); |
| |
| ThunkSequence thunks; |
| // The initialization from 'operand' is using different loop bounds, so |
| // emit it in a separate kernel. Treat it like a loop fusion, writing to |
| // the output buffer. |
| { |
| auto unroll_factor = |
| ComputeMaxUnrollFactor(fusion_op, hlo_module_config_); |
| const Shape& element_shape = root->shape(); |
| TF_ASSIGN_OR_RETURN( |
| LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions(element_shape, |
| ir_emitter_context_->gpu_device_info(), |
| {unroll_factor, /*few_waves=*/false})); |
| |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(auto operand_thunk, |
| BuildKernelThunk(op, Thunk::ThunkInfo(), &ir_arrays, |
| launch_dimensions)); |
| thunks.push_back(std::move(operand_thunk)); |
| |
| GpuElementalIrEmitter operand_elemental_emitter( |
| hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, |
| GetNestedComputer()); |
| FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter); |
| for (int i = 0; i < fused_computation->num_parameters(); i++) { |
| auto fused_operand = fused_computation->parameter_instruction(i); |
| operand_fused_emitter.BindGenerator( |
| fused_operand, [this, &ir_arrays, i, |
| fused_operand](llvm_ir::IrArray::Index index) { |
| return ir_arrays[i].EmitReadArrayElement(index, &b_, |
| fused_operand->name()); |
| }); |
| } |
| TF_ASSIGN_OR_RETURN(auto generator, |
| operand_fused_emitter.GetGenerator(root->operand(0))); |
| |
| TF_RETURN_IF_ERROR( |
| ParallelLoopEmitter(generator, ir_arrays.back(), launch_dimensions, |
| &b_, {unroll_factor}) |
| .EmitLoop(IrName(mlir::GetNameFromLoc(fusion_op.getLoc())), |
| GetIndexTypeForKernel( |
| fusion_op, launch_dimensions.launch_bound(), &b_))); |
| } |
| |
| // Now build the actual scatter, reading and writing to the freshly |
| // filled output buffer. |
| { |
| const Shape& updates_shape = root->operand(2)->shape(); |
| TF_ASSIGN_OR_RETURN( |
| LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions(updates_shape, |
| ir_emitter_context_->gpu_device_info())); |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(auto scatter_thunk, |
| BuildKernelThunk(op, Thunk::ThunkInfo(), &ir_arrays, |
| launch_dimensions)); |
| thunks.push_back(std::move(scatter_thunk)); |
| // Spin up a new fused emitter for the scatter kernel and emit it. |
| GpuElementalIrEmitter scatter_elemental_emitter( |
| hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, |
| GetNestedComputer()); |
| FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter); |
| for (int i = 0; i < fused_computation->num_parameters(); i++) { |
| auto fused_operand = fused_computation->parameter_instruction(i); |
| scatter_fused_emitter.BindGenerator( |
| fused_operand, [this, &ir_arrays, i, |
| fused_operand](llvm_ir::IrArray::Index index) { |
| return ir_arrays[i].EmitReadArrayElement(index, &b_, |
| fused_operand->name()); |
| }); |
| } |
| |
| TF_ASSIGN_OR_RETURN(const auto dim_numbers, |
| mlir::LhloDialectEmitter::GetScatterDimensionNumbers( |
| root, fusion_op.getContext())); |
| |
| ScatterDescriptor desc; |
| desc.name = IrName(root); |
| desc.operand_shape = root->operand(0)->shape(); |
| desc.scatter_indices_shape = root->operand(1)->shape(); |
| desc.updates_shape = updates_shape; |
| desc.dim_numbers = dim_numbers; |
| desc.unique_indices = root->unique_indices(); |
| desc.update_computation = root->called_computations()[0]; |
| desc.output = ir_arrays.back(); |
| TF_ASSIGN_OR_RETURN(desc.scatter_indices_gen, |
| scatter_fused_emitter.GetGenerator(root->operand(1))); |
| TF_ASSIGN_OR_RETURN(desc.updates_gen, |
| scatter_fused_emitter.GetGenerator(root->operand(2))); |
| desc.get_index_type = [&](int64_t launch_size) { |
| return GetIndexTypeForKernel(root, launch_size, &b_); |
| }; |
| |
| TF_RETURN_IF_ERROR( |
| EmitScatter(desc, thunks.back().get(), launch_dimensions)); |
| } |
| AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( |
| GetThunkInfo(op), std::move(thunks))); |
| return Status::OK(); |
| } |
| |
| if (!is_single_instruction && |
| CanEmitFusedDynamicUpdateSliceInPlaceForGpu( |
| fusion_op, ir_emitter_context_->allocations())) { |
| // Fusion node with dynamic-update-slice as the root where the op's input |
| // (i.e. array to update) shares the same slice as its output. In this case |
| // we have a special algorithm that modifies the output in place without |
| // touching the un-updated elements. |
| CHECK_EQ(1, GetHloOutputs(op).size()); |
| |
| TF_ASSIGN_OR_RETURN( |
| const HloComputation* fused_computation, |
| GetOrCreateSubComputationFromRegion(&fusion_op.region(), |
| /*is_fusion=*/true)); |
| |
| // Shape of the dynamic-update-slice's "update" operand. |
| Shape update_shape = |
| fused_computation->root_instruction()->operand(1)->shape(); |
| |
| TF_ASSIGN_OR_RETURN( |
| LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions(update_shape, |
| ir_emitter_context_->gpu_device_info())); |
| |
| // Set up kernel thunk and fused ir emitter. |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(auto fusion_thunk, |
| BuildKernelThunk(fusion_op, GetThunkInfo(op), |
| &ir_arrays, launch_dimensions)); |
| AddThunkToThunkSequence(std::move(fusion_thunk)); |
| |
| GpuElementalIrEmitter elemental_emitter(hlo_module_config_, |
| ir_emitter_context_->llvm_module(), |
| &b_, GetNestedComputer()); |
| |
| FusedIrEmitter fused_emitter(&elemental_emitter); |
| |
| for (int i = 0; i < fused_computation->num_parameters(); i++) { |
| auto fused_operand = fused_computation->parameter_instruction(i); |
| fused_emitter.BindGenerator( |
| fused_operand, [this, &ir_arrays, i, |
| fused_operand](const llvm_ir::IrArray::Index& index) { |
| return ir_arrays[i].EmitReadArrayElement(index, &b_, |
| fused_operand->name()); |
| }); |
| } |
| |
| // Array to write into. Because this is an in-place operation, this is the |
| // same as operand 0's array. |
| const IrArray& output_array = ir_arrays.back(); |
| |
| return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( |
| fused_computation, output_array, &fused_emitter, launch_dimensions, |
| &b_); |
| } |
| |
| if (auto copy = mlir::dyn_cast<mlir::mhlo::CopyOp>(fusion_root)) { |
| if (IsSingleInstructionFusion(fusion_op)) { |
| auto operands = GetHloOperands(fusion_op); |
| auto outputs = GetHloOutputs(fusion_op); |
| TF_RET_CHECK(operands.size() == 1); |
| TF_RET_CHECK(outputs.size() == 1); |
| |
| auto operand_shape = GetShape(operands[0]); |
| auto output_shape = GetShape(outputs[0]); |
| |
| CHECK(ShapeUtil::Compatible(operand_shape, output_shape)); |
| auto maybe_slice = GetAllocationSlice(operands[0]); |
| if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) && |
| maybe_slice.ok()) { |
| // Copy the operand into the output if it's not the same buffer already. |
| auto operand_buffer = *maybe_slice; |
| auto destination_buffer = *GetAllocationSlice(outputs[0]); |
| if (operand_buffer != destination_buffer) { |
| AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( |
| GetThunkInfo(op), |
| /*source_address=*/operand_buffer, |
| /*destination_buffer=*/destination_buffer, |
| /*mem_size=*/ |
| ByteSizeOf(operand_shape))); |
| } |
| return Status::OK(); |
| } |
| } |
| } |
| |
| TF_ASSIGN_OR_RETURN(const bool matched_021, CheckAndEmitHloWithTile021(op)); |
| if (matched_021) { |
| return Status::OK(); |
| } |
| |
| return EmitLoopFusion(op); |
| } |
| |
| Status IrEmitterUnnested::EmitExtraOutputsForReduce( |
| const ReductionOutputMap& result_ir_arrays, const IrArray::Index& index, |
| bool use_linear_index, |
| absl::Span<const std::pair<llvm_ir::ElementGenerator, int>> |
| extra_output_gens) { |
| // Compute all extra output values before writing them. This avoids |
| // overwriting aliased input/output buffers before all reads occured. |
| absl::InlinedVector<llvm::Value*, 8> extra_output_ir_values; |
| for (int i = 0; i < extra_output_gens.size(); ++i) { |
| TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, |
| extra_output_gens[i].first(index)); |
| extra_output_ir_values.push_back(extra_output_ir_value); |
| } |
| for (int i = 0; i < extra_output_gens.size(); ++i) { |
| int idx = extra_output_gens[i].second; |
| CHECK_EQ(result_ir_arrays.at(idx).size(), 1); |
| result_ir_arrays.at(idx)[0].EmitWriteArrayElement( |
| index, extra_output_ir_values[i], &b_, use_linear_index); |
| } |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) { |
| if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) { |
| return Unimplemented( |
| "HLO instruction %s does not have a deterministic implementation, " |
| "but run-to-run determinism is required by " |
| "--xla_gpu_deterministic_ops.", |
| op_name); |
| } |
| return Status::OK(); |
| } |
| |
// Lowers an lmhlo.select_and_scatter op into a SequentialThunk of two thunks:
// an initializer thunk that fills `out` from `init_value`, followed by a kernel
// that, for every source element, selects one operand element in its window
// (using the `select` region) and combines the source value into `out` at the
// selected index (using the `scatter` region, emitted as an atomic operation).
//
// Returns an error via AssertNonDeterminismIsOkay when deterministic ops are
// required.
Status IrEmitterUnnested::EmitSelectAndScatter(mlir::Operation* op) {
  auto select_and_scatter_op = mlir::cast<mlir::lmhlo::SelectAndScatterOp>(op);

  const Shape source_shape = GetShape(select_and_scatter_op.source());
  const Shape operand_shape = GetShape(select_and_scatter_op.operand());
  const int64_t rank = operand_shape.rank();

  CHECK_EQ(rank, source_shape.rank());
  // NOTE(review): window_dimensions is only rank-checked when present here,
  // but it is dereferenced unconditionally inside loop_body_emitter below.
  if (select_and_scatter_op.window_dimensions()) {
    CHECK_EQ(rank, select_and_scatter_op.window_dimensions()->size());
  }

  TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(
      mlir::GetNameFromLoc(select_and_scatter_op.getLoc())));

  std::string name = mlir::GetNameFromLoc(select_and_scatter_op.getLoc());

  // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
  // consisting of two thunks, an initializer KernelThunk that initializes
  // the output and another KernelThunk that accumulates the scattered
  // elements.
  ThunkSequence thunks;
  thunks.emplace_back();
  TF_ASSIGN_OR_RETURN(thunks.back(), BuildInitializerThunk(
                                         op, select_and_scatter_op.init_value(),
                                         select_and_scatter_op.out()));

  // The scatter kernel launches one logical thread per source element.
  TF_ASSIGN_OR_RETURN(
      LaunchDimensions launch_dimensions,
      CalculateLaunchDimensions(source_shape,
                                ir_emitter_context_->gpu_device_info()));
  std::vector<llvm_ir::IrArray> ir_arrays;
  thunks.emplace_back();
  // Init value is not needed in IR emission.
  TF_ASSIGN_OR_RETURN(
      thunks.back(),
      BuildKernelThunk(
          select_and_scatter_op,
          {select_and_scatter_op.operand(), select_and_scatter_op.source(),
           select_and_scatter_op.out()},
          Thunk::ThunkInfo(), &ir_arrays, launch_dimensions));

  // ir_arrays follow the value order passed to BuildKernelThunk above.
  CHECK_EQ(ir_arrays.size(), 3);
  const IrArray& operand_array = ir_arrays[0];
  const IrArray& source_array = ir_arrays[1];
  const IrArray& out_array = ir_arrays[2];

  // Both thunks are moved into the SequentialThunk here. It is appended to the
  // thunk sequence at the end of this function, just before the loop is
  // emitted.
  auto select_and_scatter_thunk =
      absl::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks));

  llvm::Type* index_type = GetIndexTypeForKernel(
      select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
  auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_type, c);
  };

  // kSelectAndScatter is implemented as two kernel launches: the first launch
  // initializes the output array to the given initial value,
  // and the second accumulates the "source" matrix to the
  // selected elements in the output array. The first launch is already
  // implemented by the initializer thunk generated earlier, so this function
  // only needs to take care of the select-and-scatter part.
  //
  // Pseudo code for select-and-scatter:
  //
  // for (coordinates S in the source):  # This loop is parallel.
  //   initialized_flag = false
  //   for (coordinates W in the window):
  //     I = S * stride + W - pad_low
  //     if I within bounds of operand:
  //       if !(initialized_flag and select(selected_value, operand(I))):
  //         selected_value = operand(I)
  //         selected_index = I
  //         initialized_flag = true
  //   output(selected_index) = scatter(output(selected_index), source(S))
  auto loop_body_emitter = [&](const IrArray::Index& source_index) -> Status {
    // Allocate space to keep the currently selected value, its index, and a
    // boolean flag if the value is initialized. The initialized_flag is set
    // false.
    llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(),
                                       ir_emitter_context_->llvm_module()),
        "selected_value_address", &b_);

    llvm::Value* selected_index_address =
        llvm_ir::EmitAllocaAtFunctionEntryWithCount(
            index_type, index_typed_constant(rank), "selected_index_address",
            &b_);

    llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
        b_.getInt1Ty(), "initialized_flag_address", &b_);
    Store(b_.getInt1(false), initialized_flag_address);

    // Create the inner loop to iterate over the window.
    llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_,
                                      index_type);

    // NOTE(review): window_dimensions() is dereferenced here without the
    // presence check used earlier; presumably the attribute is always set for
    // ops reaching this emitter. TODO confirm.
    DimensionVector window_size;
    mlir::DenseIntElementsAttr window_dimensions =
        select_and_scatter_op.window_dimensions().getValue();
    for (const auto& dim : window_dimensions) {
      window_size.push_back(dim.getSExtValue());
      CHECK_GT(dim.getSExtValue(), 0);
    }

    const IrArray::Index window_index = window_loops.AddLoopsForShape(
        ShapeUtil::MakeShape(operand_shape.element_type(), window_size),
        "window");
    llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
                                   &b_);

    // Compute the operand index to visit and evaluate the condition whether the
    // operand index is within the bounds. The unsigned comparison includes
    // checking whether the operand index >= 0.
    std::vector<llvm::Value*> operand_multi_index(source_index.size());
    llvm::Value* in_bounds_condition = b_.getInt1(true);

    // NOTE(review): window_strides and padding are also dereferenced without a
    // presence check. llvm::zip pairs the i-th stride with the i-th padding
    // element, and that element is subtracted as pad_low. This assumes the
    // padding attribute holds exactly one (low) value per dimension; verify
    // against the op producer.
    auto strides = *select_and_scatter_op.window_strides();
    auto paddings = *select_and_scatter_op.padding();

    for (auto stride_and_padding :
         llvm::enumerate(llvm::zip(strides, paddings))) {
      const int i = stride_and_padding.index();
      int64_t stride = std::get<0>(stride_and_padding.value()).getSExtValue();
      int64_t padding = std::get<1>(stride_and_padding.value()).getSExtValue();

      // I[i] = S[i] * stride + W[i] - pad_low  (see pseudo code above).
      llvm::Value* strided_index =
          NSWMul(source_index[i], index_typed_constant(stride));
      operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
                                      index_typed_constant(padding));
      llvm::Value* index_condition = ICmpULT(
          operand_multi_index[i],
          index_typed_constant(ShapeUtil::GetDimension(operand_shape, i)));
      in_bounds_condition = And(in_bounds_condition, index_condition);
    }

    // Only need to do something if the operand index is within the bounds.
    // First check if the initialized_flag is set.
    llvm_ir::LlvmIfData if_in_bounds =
        llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
    llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
    llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
        Load(initialized_flag_address), "initialized", &b_);

    // If the initialized_flag is false, initialize the selected value and index
    // with the currently visiting operand.
    llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
    const auto save_operand_index = [&](const IrArray::Index& operand_index) {
      for (int64_t i = 0; i < rank; ++i) {
        llvm::Value* selected_index_address_slot =
            InBoundsGEP(selected_index_address, {b_.getInt32(i)});
        Store(operand_index[i], selected_index_address_slot);
      }
    };
    IrArray::Index operand_index(operand_multi_index, operand_shape,
                                 index_type);
    llvm::Value* operand_data =
        operand_array.EmitReadArrayElement(operand_index, &b_);
    Store(operand_data, selected_value_address);
    save_operand_index(operand_index);
    Store(b_.getInt1(true), initialized_flag_address);

    // If the initialized_flag is true, call the `select` function to
    // potentially update the selected value and index with the currently
    // visiting operand.
    llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
    llvm::Value* operand_address =
        operand_array.EmitArrayElementAddress(operand_index, &b_);
    llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(PRED,
                                       ir_emitter_context_->llvm_module()),
        "select_return_buffer", &b_);

    TF_ASSIGN_OR_RETURN(
        const HloComputation* select_computation,
        GetOrCreateSubComputationFromRegion(&select_and_scatter_op.select(),
                                            /*is_fusion=*/false));

    TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
        *select_computation, {selected_value_address, operand_address},
        select_return_buffer));
    llvm::Value* result = Load(select_return_buffer);

    // If the 'select' function returns false, update the selected value and the
    // index to the currently visiting operand.
    llvm::Value* cond = ICmpNE(
        result,
        llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
                                   PRED, ir_emitter_context_->llvm_module()),
                               0),
        "boolean_predicate");
    llvm_ir::LlvmIfData if_select_lhs =
        llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
    llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
    Store(Load(operand_address), selected_value_address);
    save_operand_index(operand_index);

    // After iterating over the window elements, scatter the source element to
    // the selected index of the output. The value we store at the output
    // location is computed by calling the `scatter` function with the source
    // value and the current output value.
    llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
                                   &b_);
    std::vector<llvm::Value*> selected_multi_index;
    for (int64_t i = 0; i < rank; ++i) {
      llvm::Value* selected_index_address_slot =
          InBoundsGEP(selected_index_address, {b_.getInt32(i)});
      selected_multi_index.push_back(Load(selected_index_address_slot));
    }
    const Shape output_shape = GetShape(select_and_scatter_op.out());
    llvm::Value* source_value_address =
        source_array.EmitArrayElementAddress(source_index, &b_);
    IrArray::Index selected_index(selected_multi_index, output_shape,
                                  operand_index.GetType());
    llvm::Value* output_value_address =
        out_array.EmitArrayElementAddress(selected_index, &b_);

    TF_ASSIGN_OR_RETURN(
        const HloComputation* scatter_computation,
        GetOrCreateSubComputationFromRegion(&select_and_scatter_op.scatter(),
                                            /*is_fusion=*/false));

    // The read-modify-write of output(selected_index) is emitted as an atomic
    // operation.
    return EmitAtomicOperationForNestedComputation(
        *scatter_computation, output_value_address, source_value_address);
  };

  // Register the two-thunk sequence, then emit the parallel loop over the
  // source shape. Each iteration runs loop_body_emitter for one source index.
  AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
  return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions,
                             &b_)
      .EmitLoop(name, index_type);
}
| |
| Status IrEmitterUnnested::EmitWhile(mlir::Operation* op) { |
| auto while_op = mlir::cast<mlir::lmhlo::WhileOp>(op); |
| |
| auto cond_result = GetHloOutputs(while_op); |
| TF_RET_CHECK(cond_result.size() == 1); |
| TF_RET_CHECK(cond_result[0] |
| .getType() |
| .cast<mlir::ShapedType>() |
| .getElementType() |
| .isInteger(/*width=*/1)) |
| << "While condition computation must return bool"; |
| |
| // Build ForThunk for conformant while loops, otherwise build WhileThunk. |
| if (while_op.trip_count()) { |
| TF_ASSIGN_OR_RETURN(auto thunk, BuildForThunk(while_op, GetThunkInfo(op), |
| *while_op.trip_count())); |
| AddThunkToThunkSequence(std::move(thunk)); |
| } else { |
| TF_ASSIGN_OR_RETURN(auto thunk, |
| BuildWhileThunk(while_op, GetThunkInfo(op))); |
| AddThunkToThunkSequence(std::move(thunk)); |
| } |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitRngGetAndUpdateState(mlir::Operation* op) { |
| auto rng_op = mlir::dyn_cast<mlir::lmhlo::RngGetAndUpdateStateOp>(op); |
| |
| // Emit a kernel to increment the global state for Philox RNG algorithm. |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(auto kernel_thunk, |
| BuildKernelThunk(rng_op, rng_op.state(), GetThunkInfo(op), |
| &ir_arrays, LaunchDimensions())); |
| AddThunkToThunkSequence(std::move(kernel_thunk)); |
| |
| llvm::Value* old_state = |
| llvm_ir::RngGetAndUpdateState(rng_op.delta(), module_, &b_); |
| |
| const Shape shape = GetShape(rng_op.state()); |
| |
| llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress( |
| llvm_ir::IrArray::Index( |
| /*linear=*/b_.getInt64(0), shape, &b_), |
| &b_, "rng_state_address"); |
| output_address = BitCast( |
| output_address, llvm::PointerType::get( |
| old_state->getType(), |
| output_address->getType()->getPointerAddressSpace())); |
| Store(old_state, output_address); |
| |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitScatter(mlir::Operation* op) { |
| ThunkSequence thunks; |
| |
| auto scatter_op = mlir::cast<mlir::lmhlo::ScatterOp>(op); |
| |
| if (!scatter_op.unique_indices()) { |
| TF_RETURN_IF_ERROR( |
| AssertNonDeterminismIsOkay(mlir::GetNameFromLoc(scatter_op.getLoc()))); |
| } |
| |
| TF_ASSIGN_OR_RETURN(auto operand_buffer, |
| GetAllocationSlice(scatter_op.operand())); |
| TF_ASSIGN_OR_RETURN(auto output_buffer, |
| GetAllocationSlice(scatter_op.output())); |
| |
| // Copy the operand into the output if it's not the same buffer already. |
| if (operand_buffer != output_buffer) { |
| thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |
| Thunk::ThunkInfo(), |
| /*source_address=*/operand_buffer, |
| /*destination_buffer=*/output_buffer, |
| /*mem_size=*/ |
| ShapeUtil::ByteSizeOf(GetShape(scatter_op.output())))); |
| } |
| |
| const Shape& data_shape = GetShape(scatter_op.updates()); |
| TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions( |
| data_shape, ir_emitter_context_->gpu_device_info())); |
| |
| // Create kernel thunk for all operands except the first one (`operand`). The |
| // code generated for scatter below assumes that the input operand is already |
| // copied into the output, so does not use it in codegen. |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| thunks.emplace_back(); |
| TF_ASSIGN_OR_RETURN( |
| thunks.back(), |
| BuildKernelThunk(scatter_op, scatter_op.getOperands().drop_front(), |
| GetThunkInfo(op), &ir_arrays, launch_dimensions)); |
| |
| CHECK_EQ(ir_arrays.size(), 3); |
| const IrArray& scatter_indices = ir_arrays[0]; |
| const IrArray& updates = ir_arrays[1]; |
| const IrArray& output = ir_arrays[2]; |
| |
| auto get_index_type = [&](int64_t launch_size) { |
| return GetIndexTypeForKernel(scatter_op, launch_size, &b_); |
| }; |
| |
| TF_RETURN_IF_ERROR(EmitScatter( |
| thunks.back().get(), scatter_op, launch_dimensions, output, |
| /*scatter_indices_gen=*/ |
| [&](const IrArray::Index& index) { |
| return scatter_indices.EmitReadArrayElement(index, &b_, |
| "scatter_index"); |
| }, |
| /*updates_gen=*/ |
| [&](const IrArray::Index& index) { |
| return updates.EmitReadArrayElement(index, &b_, "update"); |
| }, |
| /* get_index_type=*/ |
| get_index_type)); |
| |
| // Elide the sequential thunk if there's no copy. |
| if (thunks.size() == 1) { |
| AddThunkToThunkSequence(std::move(thunks[0])); |
| } else { |
| AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( |
| GetThunkInfo(op), std::move(thunks))); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitScatter( |
| Thunk* thunk, mlir::lmhlo::ScatterOp scatter, |
| const LaunchDimensions& launch_dimensions, const llvm_ir::IrArray& output, |
| const llvm_ir::ElementGenerator& scatter_indices_gen, |
| const llvm_ir::ElementGenerator& updates_gen, |
| std::function<llvm::Type*(int64_t)> get_index_type) { |
| const Shape operand_shape = GetShape(scatter.operand()); |
| CHECK(ShapeUtil::Equal(GetShape(scatter.output()), operand_shape)); |
| |
| TF_ASSIGN_OR_RETURN( |
| const HloComputation* update_computation, |
| GetOrCreateSubComputationFromRegion(&scatter.update_computation(), |
| /*is_fusion=*/false)); |
| |
| ScatterDescriptor desc; |
| desc.name = mlir::GetNameFromLoc(scatter.getLoc()); |
| desc.operand_shape = operand_shape; |
| desc.scatter_indices_shape = GetShape(scatter.scatter_indices()); |
| desc.updates_shape = GetShape(scatter.updates()); |
| desc.dim_numbers = scatter.scatter_dimension_numbers(); |
| desc.unique_indices = scatter.unique_indices(); |
| desc.update_computation = update_computation; |
| desc.output = output; |
| desc.scatter_indices_gen = scatter_indices_gen; |
| desc.updates_gen = updates_gen; |
| desc.get_index_type = get_index_type; |
| return EmitScatter(desc, thunk, launch_dimensions); |
| } |
| |
// Emits the scatter loop described by `desc` as a parallel loop over every
// element of the updates tensor (`desc.updates_shape`).
//
// For each update element at `index` the generated code:
//   1. splits `index` into window dimensions (those listed in
//      update_window_dims) and scatter dimensions (the rest);
//   2. expands the window index to operand rank, inserting a constant-0
//      index with bound 1 for each dimension in inserted_window_dims;
//   3. for each entry of scatter_dims_to_operand_dims, reads one scatter index
//      component (along index_vector_dim) and adds it to the window index of
//      the corresponding operand dimension;
//   4. skips the element unless every scatter index component is within
//      [0, dim_size - window_size] (checked with one unsigned compare);
//   5. applies `desc.update_computation` to the output element and the update
//      value, atomically unless `desc.unique_indices` is set.
//
// `thunk` is not referenced in this body.
Status IrEmitterUnnested::EmitScatter(
    const ScatterDescriptor& desc, Thunk* thunk,
    const LaunchDimensions& launch_dimensions) {
  // Non-unique indices use the atomic update path at the end of the loop body;
  // AssertNonDeterminismIsOkay decides whether that is permitted.
  if (!desc.unique_indices) {
    TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
  }
  // Emits the IR for a single element of the updates tensor at `index`.
  auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
    std::vector<llvm::Value*> raw_window_multidim;
    std::vector<llvm::Value*> input_scatter_multidim;
    std::vector<int64_t> raw_window_bounds;

    // Partition the index into window indices and scatter indices.
    for (int64_t i = 0, e = index.size(); i != e; ++i) {
      // For window indices also remember the window size, this comes in handy
      // later.
      if (llvm::is_contained(desc.dim_numbers.getUpdateWindowDims(), i)) {
        raw_window_multidim.push_back(index[i]);
        raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
      } else {
        input_scatter_multidim.push_back(index[i]);
      }
    }
    DCHECK_EQ(raw_window_multidim.size(),
              desc.dim_numbers.getUpdateWindowDims().size());

    // Apply inserted_window_dims to the window dimensions.
    // After this loop input_window_multidim / input_window_bounds have one
    // entry per operand dimension.
    int64_t raw_window_multidim_idx = 0;
    std::vector<llvm::Value*> input_window_multidim;
    std::vector<int64_t> input_window_bounds;

    for (int64_t i = 0, e = desc.operand_shape.rank(); i != e; ++i) {
      if (llvm::is_contained(desc.dim_numbers.getInsertedWindowDims(), i)) {
        input_window_bounds.push_back(1);  // Trivial dimension.
        input_window_multidim.push_back(index.GetConstantWithIndexType(0));
      } else {
        input_window_bounds.push_back(
            raw_window_bounds[raw_window_multidim_idx]);
        input_window_multidim.push_back(
            raw_window_multidim[raw_window_multidim_idx]);
        ++raw_window_multidim_idx;
      }
    }
    DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());

    // Insert a 1 dimension at the end if index_vector_dim requests one.
    // (index_vector_dim == rank means the index vector is implicit, so the
    // indices are treated as having a trailing size-1 dimension.)
    Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
    if (desc.dim_numbers.getIndexVectorDim() ==
        desc.scatter_indices_shape.rank()) {
      scatter_indices_shape_fixed.add_dimensions(1);
      scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
          desc.dim_numbers.getIndexVectorDim());
    }

    // Now load the indices corresponding to the current window from
    // scatter_indices.
    // The slot at index_vector_dim is a placeholder (nullptr) that the loop
    // below fills with the component number being read.
    std::vector<llvm::Value*> raw_scatter_index_multidim =
        input_scatter_multidim;
    raw_scatter_index_multidim.insert(raw_scatter_index_multidim.begin() +
                                          desc.dim_numbers.getIndexVectorDim(),
                                      nullptr);
    llvm::Value* is_in_bounds = b_.getTrue();
    for (int64_t i = 0,
                 e = desc.dim_numbers.getScatterDimsToOperandDims().size();
         i != e; ++i) {
      // Our index is stored along index_vector_dim, insert that into the lookup
      // index into scatter_indices.
      raw_scatter_index_multidim[desc.dim_numbers.getIndexVectorDim()] =
          index.GetConstantWithIndexType(i);
      llvm_ir::IrArray::Index raw_scatter_index_index(
          raw_scatter_index_multidim, scatter_indices_shape_fixed,
          index.GetType());

      int64_t operand_dim = desc.dim_numbers.getScatterDimsToOperandDims()[i];
      // Map the index in the (possibly 1-extended) shape back to the real
      // scatter_indices shape before reading through the generator.
      TF_ASSIGN_OR_RETURN(
          llvm::Value* const loaded_scatter_index,
          desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
              scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
      // And add the index to our window index. This yields the output index.
      llvm::Value* casted_scatter_index =
          IntCast(loaded_scatter_index, index.GetType(),
                  /*isSigned=*/true);
      llvm::Value* dim_offset =
          Add(input_window_multidim[operand_dim], casted_scatter_index);
      input_window_multidim[operand_dim] = dim_offset;

      // Also do the bounds check now.
      int64_t max_index = desc.operand_shape.dimensions(operand_dim) -
                          input_window_bounds[operand_dim] + 1;
      // is_in_bounds = index >= 0 && index < dim_size-window_size+1
      //   --> index u< dim_size-window_size+1
      // (a negative index becomes a large unsigned value and fails the
      // compare).
      is_in_bounds =
          And(is_in_bounds, ICmpULT(casted_scatter_index,
                                    index.GetConstantWithIndexType(max_index)));
    }

    llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
        is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
    llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
    // All done, now just read from the calculated input from the window, and do
    // an atomic store to the calculated location in the output.
    llvm_ir::IrArray::Index input_window_index(
        input_window_multidim, desc.output.GetShape(), index.GetType());
    llvm::Value* output_address =
        desc.output.EmitArrayElementAddress(input_window_index, &b_);
    // The update value is spilled to an alloca because the nested-computation
    // emitters below take addresses for their operands.
    llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
                                       module_),
        "input_address", &b_);
    TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
                        desc.updates_gen(index));
    Store(input_ir_value, input_address);

    // Duplicate indices may hit the same output element from several threads,
    // so they need the atomic path; unique indices can use a plain call that
    // writes the result back to output_address.
    if (!desc.unique_indices) {
      return EmitAtomicOperationForNestedComputation(
          *desc.update_computation, output_address, input_address);
    } else {
      return EmitCallToNestedComputation(*desc.update_computation,
                                         {output_address, input_address},
                                         output_address);
    }
  };

  // Launch a kernel that reads every element in the updates tensor. We could
  // also do one kernel per window instead if bounds checks turn out to be a
  // bottleneck.
  return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
                             launch_dimensions, &b_)
      .EmitLoop(desc.name,
                desc.get_index_type(launch_dimensions.launch_bound()));
}
| |
| // This transformation should be migrated off. See b/171334474. |
| StatusOr<HloComputation*> |
| IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region, |
| bool is_fusion) { |
| std::unique_ptr<HloModule>& module = scratch_nested_computations_[region]; |
| if (module == nullptr) { |
| std::vector<Shape> operand_shapes, output_shapes; |
| if (is_fusion) { |
| mlir::Operation* clone = region->getParentOp()->clone(); |
| region = &mlir::cast<mlir::lmhlo::FusionOp>(clone).region(); |
| TF_RETURN_IF_ERROR( |
| ProcessFusionForConversion(region, &operand_shapes, &output_shapes)); |
| } |
| |
| xla::XlaComputation xla_computation; |
| mlir::MlirToHloConversionOptions options; |
| options.propagate_layouts = true; |
| options.propagate_bitcast_layouts_to_backend_config = true; |
| TF_RETURN_IF_ERROR( |
| ConvertRegionToComputation(region, &xla_computation, options)); |
| |
| if (is_fusion) { |
| region->getParentOp()->erase(); |
| } |
| |
| TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape()); |
| TF_ASSIGN_OR_RETURN( |
| module, HloModule::CreateFromProto(xla_computation.proto(), |
| HloModuleConfig(program_shape))); |
| |
| if (is_fusion) { |
| HloComputation* fused_computation = module->entry_computation(); |
| |
| CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters()); |
| for (int i = 0; i < fused_computation->num_parameters(); i++) { |
| *fused_computation->parameter_instruction(i) |
| ->mutable_shape() |
| ->mutable_layout() = operand_shapes[i].layout(); |
| } |
| HloInstruction* root = fused_computation->root_instruction(); |
| // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a. |
| // FusedIrEmitter doesn't take GTE ops because we aim to elimiate tuples |
| // as much as possible. |
| if (root->opcode() == HloOpcode::kTuple) { |
| [&] { |
| HloInstruction* real_root = nullptr; |
| int expected_tuple_index = 0; |
| for (HloInstruction* operand : root->operands()) { |
| if (operand->opcode() != HloOpcode::kGetTupleElement) { |
| return; |
| } |
| if (real_root == nullptr) { |
| real_root = operand->mutable_operand(0); |
| } else if (real_root != operand->operand(0)) { |
| return; |
| } |
| if (expected_tuple_index != operand->tuple_index()) { |
| return; |
| } |
| expected_tuple_index++; |
| } |
| fused_computation->set_root_instruction(real_root); |
| std::vector<HloInstruction*> to_be_removed; |
| to_be_removed.push_back(root); |
| for (HloInstruction* operand : root->operands()) { |
| to_be_removed.push_back(operand); |
| } |
| for (auto instr : to_be_removed) { |
| TF_CHECK_OK(fused_computation->RemoveInstruction(instr)); |
| } |
| |
| root = real_root; |
| }(); |
| } |
| |
| if (output_shapes.size() > 1) { |
| CHECK(root->shape().IsTuple()); |
| CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size()); |
| |
| for (int i = 0; i < output_shapes.size(); i++) { |
| *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i); |
| } |
| } else { |
| CHECK_EQ(1, output_shapes.size()); |
| *root->mutable_shape() = output_shapes[0]; |
| } |
| } |
| // Post-process the generated computation: |
| // * Sanitize constant names, so that they can be used as LLVM global |
| // symbols. |
| // * Propagate layouts for tuple types. |
| for (HloComputation* computation : module->computations()) { |
| for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { |
| if (instr->opcode() == HloOpcode::kConstant) { |
| // Notice that IR emitters use the name of constants as LLVM symbol |
| // names, therefore it's important to not let these constants in the |
| // new module collide with constants in the original module by names. |
| // Unique them by prepending the module name. |
| // |
| // TODO(timshen): A better solution would be to plumb the exact |
| // constant names through original HLO -> LHLO -> MHLO -> HLO. This is |
| // hard because XLA builder doesn't support setting names. Revisit |
| // this once we get rid of this function, or don't rely on the op name |
| // (which shouldn't be the identity) to generate LLVM symbols. |
| instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName( |
| module->name() + "_" + instr->name())); |
| } |
| if (instr->shape().IsTuple() && |
| computation == module->entry_computation() && |
| instr != computation->root_instruction()) { |
| return InternalError("Non-root tuple types are not handled."); |
| } |
| } |
| } |
| } |
| return module->entry_computation(); |
| } |
| |
// Emits an in-place bitonic sort of the sort op's operands along
// `sort_op.dimension()`.
//
// Each operand whose buffer differs from its output buffer is first copied
// into the output with a DeviceToDeviceCopyThunk, because the sort kernels
// work in place on the outputs. The sort itself is a series of kernels:
//   * one per xor mask, or
//   * when tiling is possible, one per run of consecutive masks smaller than
//     kTileSize.
// All thunks are wrapped in a single SequentialThunk. The comparator region is
// converted to an HloComputation and invoked through
// EmitCallToNestedComputation from the generated comparison code.
//
// Every operand and output must have the same layout as the keys (operand 0);
// otherwise TF_RET_CHECK returns an error.
Status IrEmitterUnnested::EmitSort(mlir::Operation* op) {
  auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(op);
  MlirEmitterContext context;
  context.SetOperation(sort_op);

  ThunkSequence thunks;

  const Shape& keys_shape = context.operand_shapes[0];
  int64_t dimension_to_sort = sort_op.dimension();
  for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
    // We assume that the layout of all involved operands and outputs is the
    // same.
    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
                                                  context.operand_shapes[i]));
    TF_RET_CHECK(
        LayoutUtil::LayoutsInShapesEqual(keys_shape, context.output_shapes[i]));

    // If possible, we share buffers. If that is not possible, we need to copy
    // the values, because the emitter does the sorting in-place.
    TF_ASSIGN_OR_RETURN(auto destination_buffer,
                        GetAllocationSlice(sort_op.output()[i]));
    TF_ASSIGN_OR_RETURN(auto source_address,
                        GetAllocationSlice(sort_op.operands()[i]));
    if (destination_buffer != source_address) {
      // TODO(b/26783907): Figure out why we never seem to share buffers for
      // key/value sort.
      VLOG(2) << context.name << " requires initial D2D copy for operand " << i;
      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
          Thunk::ThunkInfo(),
          /*source_address=*/source_address,
          /*destination_buffer=*/destination_buffer,
          /*mem_size=*/ShapeUtil::ByteSizeOf(context.operand_shapes[i])));
    }
  }

  // num_stages = ceil(log2(bound)), so 2^(num_stages-1) < bound <= 2^num_stages
  // (checked below).
  // NOTE(review): for a sort dimension of size 1, num_stages is presumably 0
  // and `1ULL << (num_stages - 1)` here and below is a negative shift (UB).
  // This appears to rely on size-1 sorts never reaching the emitter; confirm.
  uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  int64_t num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
  VLOG(2) << context.name << " requires " << num_stages << " stages.";
  CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
  CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);

  // Naive C++ code for the outer loops:
  //
  // for (int64_t stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
  //     ++stage) {
  //   int64_t first_xor_mask = (1LL << (stage + 1)) - 1;
  //   SortInPlace(first_xor_mask);
  //   for (int64_t mask = stage - 1; mask >= 0; --mask) {
  //     int64_t later_xor_mask = 1LL << mask;
  //     SortInPlace(later_xor_mask);
  //   }
  // }
  //
  // This follows the alternative representation of the algorithm described on
  // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
  //
  // Each mask specifies how to derive from one position in the array the
  // position with which it should be compared (we calculate the xor of the
  // position with the mask).
  // As an optimization, we can move the 'mask' loop to inside the
  // sorting/comparison loop if the comparisons happen within a small block of
  // the array. To make this work, we collect all consecutive masks that are
  // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
  // Each thread then processes one tile of data.

  const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);

  // If we cannot combine several xor masks together, we don't use tiling, so we
  // calculate the standard launch dimensions for the shape. However we only
  // need to iterate through ~half of the dimension to sort (rounded up to the
  // next highest power of 2), because each iteration compares one pair of
  // elements.
  Shape standard_iteration_shape = keys_shape;
  uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
  standard_iteration_shape.set_dimensions(dimension_to_sort,
                                          standard_num_iterations_in_sort_dim);
  TF_ASSIGN_OR_RETURN(
      LaunchDimensions standard_launch_dimensions,
      CalculateLaunchDimensions(standard_iteration_shape,
                                ir_emitter_context_->gpu_device_info()));

  // Calculate the launch dimensions for the case where we use tiling. We split
  // the dimension that should be sorted into tiles of size 'kTileSize'. This
  // means we first need to round 'dimension_to_sort_bound' up to be a multiple
  // of the tile size.
  int64_t rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
  Shape iteration_shape = keys_shape;

  // We iterate through the element pairs that should be compared.
  uint64 num_iterations_in_sort_dim = rounded_bound / 2;
  iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
  uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);

  // For correctness reasons we need exactly 'kTileSize' / 2 many threads per
  // block. Each thread is responsible for copying exactly two adjacent elements
  // into shared memory, and then does a comparison of two possibly different
  // elements taken from shared memory.
  const uint64 kThreadsPerBlock = kTileSize / 2;

  // Check whether we should use any tiling. We might not be able to use it if
  // we have not enough threads, or not enough shared memory. Also it does not
  // give a speedup if the tile size is < 128.
  int64_t total_shared_memory_needed = 0;
  for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
    total_shared_memory_needed +=
        kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
                        context.operand_shapes[i].element_type());
  }
  bool no_tiling =
      kTileSize < 128 ||
      kThreadsPerBlock >
          ir_emitter_context_->gpu_device_info().threads_per_block_limit ||
      total_shared_memory_needed >
          ir_emitter_context_->gpu_device_info().shared_memory_per_block;
  VLOG(2) << absl::StreamFormat(
      "%s %s use tiling. No tiling if any of the following is true: "
      "kTileSize=%d < 128, "
      "kThreadsPerBlock=%d > threads_per_block_limit=%d, "
      "total_shared_memory_needed=%d > shared_memory_per_block=%d",
      context.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
      ir_emitter_context_->gpu_device_info().threads_per_block_limit,
      total_shared_memory_needed,
      ir_emitter_context_->gpu_device_info().shared_memory_per_block);

  uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
  LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
  VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
                                context.name, num_blocks, kThreadsPerBlock);

  // NOTE(review): ir_arrays is shared by every emit_kernel call. This relies
  // on BuildKernelThunk repopulating it for each new kernel, since
  // values_arrays reads ir_arrays[0..n); confirm.
  std::vector<llvm_ir::IrArray> ir_arrays;
  // Emits one sort kernel covering `xor_masks`. A single mask uses the
  // standard (untiled) launch; several masks use the tiled launch.
  auto emit_kernel = [&](absl::Span<const int64_t> xor_masks) {
    VLOG(2) << absl::StreamFormat(
        "%s uses kernel for xor masks [%s]", context.name,
        absl::StrJoin(xor_masks, ", ", [](std::string* out, int64_t xor_mask) {
          absl::StrAppendFormat(out, "0x%x", xor_mask);
        }));
    thunks.emplace_back();
    LaunchDimensions launch_dimensions = xor_masks.size() > 1
                                             ? tiled_launch_dimensions
                                             : standard_launch_dimensions;
    TF_ASSIGN_OR_RETURN(
        thunks.back(),
        BuildKernelThunk(sort_op, sort_op.output(), Thunk::ThunkInfo(),
                         &ir_arrays, launch_dimensions));
    std::vector<IrArray> values_arrays;
    values_arrays.reserve(context.operand_shapes.size());
    for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
      values_arrays.push_back(ir_arrays[i]);
    }
    TF_ASSIGN_OR_RETURN(const HloComputation* comparator,
                        GetOrCreateSubComputationFromRegion(
                            &sort_op.comparator(), /*is_fusion=*/false));
    return llvm_ir::EmitSortInPlace(
        dimension_to_sort, values_arrays, IrName(context.name), xor_masks, &b_,
        launch_dimensions,
        xor_masks.size() > 1 ? num_iterations_in_sort_dim
                             : standard_num_iterations_in_sort_dim,
        kTileSize,
        [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
          return EmitCallToNestedComputation(*comparator, operands, output);
        });
  };
  // Walk the bitonic schedule. Masks smaller than kTileSize are batched into
  // one tiled kernel; any larger mask (or no_tiling) flushes the batch and
  // gets its own untiled kernel.
  std::vector<int64_t> xor_masks;
  for (int64_t stage = 0; stage < num_stages; ++stage) {
    for (int64_t mask = stage; mask >= 0; --mask) {
      int64_t xor_mask;
      if (mask == stage) {
        xor_mask = (1LL << (stage + 1)) - 1;
      } else {
        xor_mask = 1LL << mask;
      }
      if (xor_mask >= kTileSize || no_tiling) {
        if (!xor_masks.empty()) {
          TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
          xor_masks.clear();
        }
        TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
      } else {
        xor_masks.push_back(xor_mask);
      }
    }
  }
  if (!xor_masks.empty()) {
    TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
  }
  VLOG(2) << absl::StreamFormat(
      "%s requires %d thunks (including any D2D copies)", context.name,
      thunks.size());

  AddThunkToThunkSequence(
      absl::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
  return Status::OK();
}
| |
| template <typename ThunkType, typename OpT> |
| Status IrEmitterUnnested::EmitReplicaOrPartitionId(mlir::Operation* op) { |
| auto casted = mlir::cast<OpT>(op); |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, |
| GetAllocationSlice(casted.getOperand())); |
| AddThunkToThunkSequence( |
| absl::make_unique<ThunkType>(GetThunkInfo(op), result_slice)); |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitCollectivePermute(mlir::Operation* op) { |
| auto collective_permute_op = mlir::cast<mlir::lmhlo::CollectivePermuteOp>(op); |
| |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice, |
| GetAllocationSlice(collective_permute_op.operand())); |
| TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, |
| GetAllocationSlice(collective_permute_op.output())); |
| |
| const Shape shape = GetShape(collective_permute_op.operand()); |
| const int64_t replica_count = hlo_module_config_.replica_count(); |
| const int64_t partition_count = hlo_module_config_.num_partitions(); |
| |
| if (NcclCollectivePermuteThunk::IsDegenerate( |
| collective_permute_op, replica_count, partition_count)) { |
| // For a degenerate collective permute, just generate a copy thunk. |
| AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( |
| GetThunkInfo(op), |
| /*source_address=*/source_slice, |
| /*destination_buffer=*/result_slice, |
| /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); |
| } else { |
| std::unique_ptr<Thunk> thunk; |
| if (IsBefThunkEnabled()) { |
| std::vector<BufferAllocation::Slice> buffers = {source_slice, |
| result_slice}; |
| TF_ASSIGN_OR_RETURN(thunk, CreateBefCollectivePermuteThunk( |
| GetThunkInfo(op), op, std::move(buffers), |
| replica_count, partition_count)); |
| } else { |
| const NcclCollectivePermuteThunk::Buffer buffer = { |
| /*element_count=*/ShapeUtil::ElementsIn(shape), |
| /*source_buffer=*/source_slice, |
| /*destination_buffer=*/result_slice}; |
| thunk = absl::make_unique<NcclCollectivePermuteThunk>( |
| GetThunkInfo(op), collective_permute_op, replica_count, |
| partition_count, buffer); |
| } |
| AddThunkToThunkSequence(std::move(thunk)); |
| } |
| return Status::OK(); |
| } |
| |
| Status MaybeAddAllReduceStartThunkToMap( |
| absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>& |
| all_reduce_start_thunks, |
| mlir::Operation* op, Thunk* thunk) { |
| if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) { |
| TF_RET_CHECK(all_reduce_start_thunks |
| .emplace(op, static_cast<NcclAllReduceStartThunk*>(thunk)) |
| .second) |
| << "all-reduce-start with this unique ID already seen"; |
| } |
| return Status::OK(); |
| } |
| |
// Emits the thunk(s) for a collective op of type `OpTy`, implemented by
// `NcclThunkType`.
//
//   * Not degenerate and NcclThunkType::CanImplement(op): a single collective
//     thunk is emitted. It is a BEF thunk for the op types listed below when
//     BEF thunks are enabled, and an NcclThunkType otherwise.
//   * Not degenerate but not implementable: an Unimplemented error describing
//     the configuration is returned.
//   * Degenerate: each operand is copied to its result with a
//     DeviceToDeviceCopyThunk, wrapped in a SequentialThunk when there is more
//     than one.
//
// For all-reduce-start ops the created thunk (or nullptr when no collective
// thunk is created) is recorded in all_reduce_start_thunks_, where
// EmitAllReduceDone looks it up.
template <typename NcclThunkType, typename OpTy>
Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) {
  OpTy op = mlir::cast<OpTy>(untyped_op);
  int64_t replica_count = hlo_module_config_.replica_count();
  int64_t partition_count = hlo_module_config_.num_partitions();
  VLOG(2) << NcclThunkType::GetName() << "; replica count: " << replica_count
          << "; partition count: " << partition_count
          << "; operand count: " << op.operands().size()
          << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled();

  // A given collective op can be degenerate if across all groups formed
  // by it are singleton. In such a case, we don't need to do any communication
  // and we can just copy the input to the output.
  bool is_degenerate =
      NcclThunkType::IsDegenerate(op, replica_count, partition_count);
  bool should_use_nccl_thunk =
      !is_degenerate && NcclThunkType::CanImplement(op);

  // Stash relevant information in NcclCollectiveThunk::Buffer even if we may
  // not generate an NcclCollectiveThunk.
  // Operand i is paired with result i.
  std::vector<NcclCollectiveThunk::Buffer> buffers;
  buffers.reserve(op.operands().size());
  for (auto it : llvm::zip(op.operands(), op.results())) {
    mlir::Value operand = std::get<0>(it);
    mlir::Value result = std::get<1>(it);
    const Shape shape = GetShape(operand);
    TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand));
    TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result));
    buffers.push_back(NcclCollectiveThunk::Buffer{
        /*element_count=*/ShapeUtil::ElementsIn(shape),
        /*source_buffer=*/source_slice,
        /*destination_buffer=*/dest_slice});
  }

  if (should_use_nccl_thunk) {
    std::unique_ptr<Thunk> thunk;
    if (IsBefThunkEnabled() && (mlir::isa<mlir::lmhlo::AllGatherOp>(op) ||
                                mlir::isa<mlir::lmhlo::AllReduceOp>(op) ||
                                mlir::isa<mlir::lmhlo::ReduceScatterOp>(op) ||
                                mlir::isa<mlir::lmhlo::AllToAllOp>(op))) {
      // BEF thunks take all source slices first, then all destination slices.
      std::vector<BufferAllocation::Slice> arg_buffers;
      arg_buffers.reserve(buffers.size() * 2);
      for (const auto& buffer : buffers) {
        arg_buffers.push_back(buffer.source_buffer);
      }
      for (const auto& buffer : buffers) {
        arg_buffers.push_back(buffer.destination_buffer);
      }
      TF_ASSIGN_OR_RETURN(
          thunk, CreateBefThunk(GetThunkInfo(op), op, std::move(arg_buffers)));
    } else {
      thunk = absl::make_unique<NcclThunkType>(GetThunkInfo(op), op,
                                               /*buffers=*/std::move(buffers));
    }
    // Record thunks for all-reduce-start ops as the done ops need them.
    TF_RETURN_IF_ERROR(MaybeAddAllReduceStartThunkToMap(
        all_reduce_start_thunks_, op, thunk.get()));
    AddThunkToThunkSequence(std::move(thunk));
    return Status::OK();
  }

  // Signal that all-reduce-start thunk not created with nullptr.
  // This is done before the Unimplemented check below, so the entry exists
  // on that error path as well.
  TF_RETURN_IF_ERROR(
      MaybeAddAllReduceStartThunkToMap(all_reduce_start_thunks_, op, nullptr));

  if (!is_degenerate) {
    CollectiveOpGroupMode group_mode = NcclThunkType::GetGroupMode(op);

    string message = absl::StrFormat(
        "Requested %s not implemented on GPU; replica_count: %d; "
        "partition_count: %d, group_mode: %s, operand_count: %d; NCCL support: "
        "%d",
        NcclThunkType::GetName(), replica_count, partition_count,
        CollectiveOpGroupModeToString(group_mode), op.operands().size(),
        NcclThunkType::NcclIsEnabled());
    if (!op.operands().empty()) {
      const Shape shape = GetShape(op.operands().front());
      absl::StrAppendFormat(&message, "; first operand array element-type: %s",
                            PrimitiveType_Name(shape.element_type()));
    }
    return Unimplemented("%s", message);
  }

  // A degenerate collective is the identity function (e.g. all-gather with
  // one replica). Buffer assignment expects a copy, so that's what we do.
  ThunkSequence thunks;
  for (int64_t i = 0; i < buffers.size(); i++) {
    const Shape shape = GetShape(op.operands()[i]);
    // A lone copy carries the op's ThunkInfo itself; with several copies the
    // enclosing SequentialThunk carries it instead.
    thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
        buffers.size() == 1 ? GetThunkInfo(op) : Thunk::ThunkInfo(),
        /*source_address=*/buffers[i].source_buffer,
        /*destination_buffer=*/buffers[i].destination_buffer,
        /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
  }
  if (thunks.size() == 1) {
    AddThunkToThunkSequence(std::move(thunks[0]));
  } else {
    AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
        GetThunkInfo(op), std::move(thunks)));
  }
  return Status::OK();
}
| |
| Status IrEmitterUnnested::EmitAllReduceDone(mlir::Operation* op) { |
| auto done_op = mlir::cast<mlir::lmhlo_gpu::AllReduceDoneOp>(op); |
| auto start_op = |
| done_op.token().getDefiningOp<mlir::lmhlo_gpu::AllReduceStartOp>(); |
| auto it = all_reduce_start_thunks_.find(start_op); |
| TF_RET_CHECK(it != all_reduce_start_thunks_.end()) |
| << "couldn't find thunk for all-reduce-start op"; |
| |
| // Can be null if no all-reduce-start thunk was created (e.g. if the start op |
| // is degenerate), in which case there's nothing to do here. |
| if (it->second != nullptr) { |
| AddThunkToThunkSequence(absl::make_unique<NcclAllReduceDoneThunk>( |
| GetThunkInfo(op), *it->second)); |
| } |
| all_reduce_start_thunks_.erase(it); |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitInfeed(mlir::Operation* op) { |
| auto infeed_op = mlir::cast<mlir::lmhlo::InfeedOp>(op); |
| |
| std::vector<ShapedSlice> dest_slices; |
| dest_slices.reserve(infeed_op.outputs().size()); |
| |
| for (mlir::Value output : infeed_op.outputs()) { |
| TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(output)); |
| const Shape& shape = GetShape(output); |
| dest_slices.push_back(ShapedSlice{slice, shape}); |
| } |
| |
| AddThunkToThunkSequence( |
| absl::make_unique<InfeedThunk>(GetThunkInfo(op), std::move(dest_slices))); |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitOutfeed(mlir::Operation* op) { |
| auto outfeed_op = mlir::cast<mlir::lmhlo::OutfeedOp>(op); |
| |
| std::vector<ShapedSlice> source_slices; |
| source_slices.reserve(outfeed_op.operands().size()); |
| |
| for (mlir::Value operand : outfeed_op.operands()) { |
| TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand)); |
| const Shape& shape = GetShape(operand); |
| source_slices.push_back(ShapedSlice{slice, shape}); |
| } |
| |
| AddThunkToThunkSequence(absl::make_unique<OutfeedThunk>( |
| GetThunkInfo(op), std::move(source_slices))); |
| return Status::OK(); |
| } |
| |
// Builds the LLVM kernel prototype for `slices` and returns a thunk that
// launches that kernel with `launch_dimensions`.
//
// Kernel parameters are one pointer per distinct non-constant
// BufferAllocation that is either referenced by `slices` or is the first
// preallocated temp buffer. Parameters are sorted by allocation index, and
// every parameter is marked `noalias`.
//
// On return, `ir_arrays` is cleared and then refilled with one IrArray per
// entry of `slices`, in the same order:
//   * slices with a `constant_name` bind to that module global;
//   * all others bind to a GEP of the matching kernel parameter at the
//     slice's offset.
// Slices that are never marked `written` are annotated as invariant over the
// whole program.
//
// Returns a BEF kernel thunk when IsBefThunkEnabled(), otherwise a KernelThunk.
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunkImpl(
    absl::string_view name, Thunk::ThunkInfo thunk_info,
    absl::Span<const BufferSlice> slices,
    std::vector<llvm_ir::IrArray>* ir_arrays,
    const LaunchDimensions& launch_dimensions) {
  // Figure out which buffer allocations need to be passed as arguments to our
  // kernel. This is simply all of the allocations referenced in slices,
  // plus the XLA temp buffer (if we have it). We always include the temp
  // buffer because even if the kernel itself doesn't use it, a nested
  // subcomputation within the kernel (e.g. a kMap's computation) might.
  std::unordered_set<const BufferAllocation*> buffers_needed;
  for (const auto& slice : slices) {
    buffers_needed.insert(slice.buffer_slice.allocation());
  }
  absl::optional<const BufferAllocation*> temp_buffer;
  for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
    if (alloc.IsPreallocatedTempBuffer()) {
      if (!temp_buffer.has_value()) {
        // Retrieve the first seen temp buffer.
        temp_buffer = &alloc;
      }
    }
  }
  if (temp_buffer.has_value()) {
    buffers_needed.insert(*temp_buffer);
  }

  // We'll pass a pointer to each of the elements of `buffers` to our kernel, in
  // this order.
  // Constant allocations are excluded: slices that refer to them are read
  // through their named module globals below instead of a kernel argument.
  std::vector<const BufferAllocation*> non_constant_buffers;
  absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
                  [](const BufferAllocation* allocation) {
                    return !allocation->is_constant();
                  });

  // Sorting by allocation index makes the argument order deterministic
  // (`buffers_needed` is an unordered set).
  absl::c_sort(non_constant_buffers,
               [](const BufferAllocation* a, const BufferAllocation* b) {
                 return a->index() < b->index();
               });

  llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);

  // Build a map from a BufferAllocation to the corresponding argument in our
  // kernel.
  // NOTE(review): the loop below advances the argument and buffer iterators in
  // lockstep and bounds only on the argument side; it assumes
  // BuildKernelPrototype creates exactly one argument per entry of
  // `non_constant_buffers`.
  std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
  {
    auto arg_it = kernel->arg_begin();
    auto buffers_it = non_constant_buffers.begin();
    for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
      kernel_args[*buffers_it] = arg_it;

      // Annotate all allocations with LLVM's `noalias`.
      // There are three kinds of allocations:
      // * Read-only allocations, aka input parameters that are not aliased with
      // outputs.
      // * Read-write allocations, including all output buffers, some of which
      // may alias with input HLO parameters, but aliased HLO buffers are always
      // assigned with the same allocation.
      // * The temp buffer.
      //
      // Read-only allocations may overlap with each other, but since they are
      // not mutated, they can always be annotated with `noalias` per LLVM
      // semantics.
      //
      // Read-write allocations and the temp buffer don't overlap with any
      // allocations, therefore they can also be annotated with `noalias`.
      kernel->addParamAttr(
          arg_it->getArgNo(),
          llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
    }
  }

  // Slices that some entry marks as written; every other slice is treated as
  // read-only for the lifetime of the program.
  absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
  for (const auto& slice : slices) {
    if (slice.written) {
      buffers_written.insert(slice.buffer_slice);
    }
  }

  ir_arrays->clear();

  // For each buffer our kernel might want to touch, bind it to a value derived
  // from our kernel args.
  for (const BufferSlice& slice : slices) {
    const BufferAllocation::Slice& buffer_slice = slice.buffer_slice;

    llvm::Value* loc;
    if (!slice.constant_name.empty()) {
      loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
          slice.constant_name);
      CHECK_NE(loc, nullptr);
    } else {
      CHECK(!buffer_slice.allocation()->is_constant());
      loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
                        {b_.getInt64(buffer_slice.offset())});
    }

    llvm_ir::IrArray ir_array(CastToTypedValue(slice.shape, loc, &b_),
                              slice.shape);
    if (!buffers_written.contains(slice.buffer_slice)) {
      ir_array.MarkInvariantOverWholeProgram(&loc->getContext());
    }

    ir_arrays->push_back(ir_array);
  }

  AnnotateThunkLaunchDimensions(launch_dimensions,
                                std::string(kernel->getName()),
                                ir_emitter_context_->llvm_module());

  if (IsBefThunkEnabled()) {
    return CreateBefKernelThunk(thunk_info, non_constant_buffers,
                                std::string(kernel->getName()),
                                launch_dimensions);
  } else {
    return {absl::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
                                           std::string(kernel->getName()),
                                           launch_dimensions)};
  }
}
| |
| StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunk( |
| mlir::Operation* op, mlir::ValueRange operands, Thunk::ThunkInfo thunk_info, |
| std::vector<llvm_ir::IrArray>* ir_arrays, |
| const LaunchDimensions& launch_dimensions) { |
| TF_RET_CHECK(!mlir::isa<mlir::lmhlo::FusionOp>(op)); |
| |
| std::vector<BufferSlice> slices; |
| for (mlir::Value operand : operands) { |
| slices.emplace_back(); |
| auto& slice = slices.back(); |
| TF_ASSIGN_OR_RETURN(slice.buffer_slice, |
| GetAllocationSlice(operand, &slice.constant_name)); |
| slice.written = WritesMlirBuffer(op, operand); |
| slice.shape = GetShape(operand); |
| } |
| std::string name = mlir::GetNameFromLoc(op->getLoc()); |
| return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays, |
| launch_dimensions); |
| } |
| |
| StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunk( |
| mlir::Operation* op, Thunk::ThunkInfo thunk_info, |
| std::vector<llvm_ir::IrArray>* ir_arrays, |
| const LaunchDimensions& launch_dimensions) { |
| if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) { |
| auto operands = GetHloOperands(op); |
| auto outputs = GetHloOutputs(op); |
| |
| std::vector<BufferSlice> slices; |
| for (mlir::Value operand : operands) { |
| slices.emplace_back(); |
| BufferSlice& slice = slices.back(); |
| TF_ASSIGN_OR_RETURN(slice.buffer_slice, |
| GetAllocationSlice(operand, &slice.constant_name)); |
| slice.written = false; |
| slice.shape = GetShape(operand); |
| } |
| for (mlir::Value output : outputs) { |
| slices.emplace_back(); |
| BufferSlice& slice = slices.back(); |
| TF_ASSIGN_OR_RETURN(slice.buffer_slice, |
| GetAllocationSlice(output, &slice.constant_name)); |
| slice.written = true; |
| slice.shape = GetShape(output); |
| } |
| std::string name = mlir::GetNameFromLoc(op->getLoc()); |
| return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays, |
| launch_dimensions); |
| } |
| return BuildKernelThunk(op, op->getOperands(), thunk_info, ir_arrays, |
| launch_dimensions); |
| } |
| |
| std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk( |
| absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest, |
| const Shape& output_shape) { |
| int64_t num_bytes = init_value.size(); |
| if (absl::c_all_of(init_value, [](uint8 byte) { return byte == 0; })) { |
| return absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest); |
| } |
| |
| // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by |
| // repeating the literal 4 or 2 times, so long as the destination buffer is |
| // an even multiple of 32 bits long. |
| if ((num_bytes == 1 || num_bytes == 2) && |
| ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { |
| uint16 pattern16; |
| if (num_bytes == 1) { |
| uint8 b = init_value.front(); |
| pattern16 = uint16{b} | (uint16{b} << 8); |
| } else { |
| memcpy(&pattern16, init_value.data(), sizeof(pattern16)); |
| } |
| uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); |
| return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), |
| pattern32, dest); |
| } |
| |
| // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit |
| // memset so long as all 32-bit words of the scalar are equal to each other. |
| if (num_bytes >= 4 && num_bytes % 4 == 0 && |
| memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == |
| 0) { |
| uint32 word; |
| memcpy(&word, init_value.data(), sizeof(word)); |
| return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word, |
| dest); |
| } |
| |
| return nullptr; |
| } |
| |
| StatusOr<std::unique_ptr<Thunk>> |
| IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value, |
| mlir::Value dest) { |
| mlir::DenseElementsAttr const_init; |
| if (auto get_global_memref = |
| mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>( |
| init_value.getDefiningOp())) { |
| auto global_memref = |
| mlir::SymbolTable::lookupNearestSymbolFrom<mlir::memref::GlobalOp>( |
| get_global_memref, get_global_memref.nameAttr()); |
| if (global_memref.constant() && global_memref.initial_value()) { |
| // If the initial value happens to be a constant, generate a specialized |
| // thunk. |
| const_init = global_memref.initial_value() |
| .getValue() |
| .cast<mlir::DenseElementsAttr>(); |
| } |
| } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstOp>( |
| init_value.getDefiningOp())) { |
| const_init = constant.value().dyn_cast<mlir::DenseElementsAttr>(); |
| } |
| |
| if (const_init) { |
| std::vector<uint8> literal_bytes; |
| TF_RETURN_IF_ERROR( |
| CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes)); |
| |
| TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest)); |
| |
| const Shape dest_shape = GetShape(dest); |
| auto thunk = |
| BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape); |
| if (thunk) { |
| return {std::move(thunk)}; |
| } |
| } |
| return std::unique_ptr<Thunk>(); |
| } |
| |
| StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( |
| mlir::Operation* op, mlir::Value init_value, mlir::Value dest) { |
| // initial value must be a scalar memref. |
| auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>(); |
| TF_RET_CHECK(init_type.getRank() == 0); |
| |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk, |
| TryBuildConstantInitializerThunk(init_value, dest)); |
| if (constant_init_thunk) { |
| return {std::move(constant_init_thunk)}; |
| } |
| |
| // Otherwise fall back to our slow initializer code. The thunk in this case |
| // will just need the IR arrays for the initial value and the destination. |
| const Shape dest_shape = GetShape(dest); |
| TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions( |
| dest_shape, ir_emitter_context_->gpu_device_info())); |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<Thunk> kernel_thunk, |
| BuildKernelThunk(op, {init_value, dest}, Thunk::ThunkInfo(), &ir_arrays, |
| launch_dimensions)); |
| |
| const llvm_ir::IrArray init_array = ir_arrays[0]; |
| const llvm_ir::IrArray dest_array = ir_arrays[1]; |
| |
| std::string name = mlir::GetNameFromLoc(op->getLoc()); |
| TF_RETURN_IF_ERROR(ParallelLoopEmitter( |
| [=](const IrArray::Index& index) { |
| return init_array.EmitReadArrayElement(index, &b_); |
| }, |
| dest_array, launch_dimensions, &b_) |
| .EmitLoop(mlir::GetNameFromLoc(op->getLoc()))); |
| |
| return std::move(kernel_thunk); |
| } |
| |
| StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildFusedInitializerThunk( |
| mlir::lmhlo::FusionOp fusion, int output_index) { |
| auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>( |
| fusion.getFusionRoots()[output_index]); |
| |
| TF_RET_CHECK(reduce); |
| TF_RET_CHECK(reduce.getNumResults() == 1); |
| |
| mlir::Value init_value = reduce.init_values()[0]; |
| mlir::Value dest = fusion.getOutputBuffers()[output_index]; |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk, |
| TryBuildConstantInitializerThunk(init_value, dest)); |
| if (constant_init_thunk) { |
| return {std::move(constant_init_thunk)}; |
| } |
| |
| auto input_buffers = fusion.getInputBuffers(); |
| |
| const Shape dest_shape = GetShape(dest); |
| TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions( |
| dest_shape, ir_emitter_context_->gpu_device_info())); |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk, |
| BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays, |
| launch_dimensions)); |
| |
| const llvm_ir::IrArray dest_array = |
| ir_arrays[input_buffers.size() + output_index]; |
| |
| const HloComputation* fused_computation = |
| *GetOrCreateSubComputationFromRegion(&fusion.region(), |
| /*is_fusion=*/true); |
| |
| // If init_value was fused into this reduce we have to generate it first. |
| GpuElementalIrEmitter elemental_emitter(hlo_module_config_, |
| ir_emitter_context_->llvm_module(), |
| &b_, GetNestedComputer()); |
| |
| FusedIrEmitter fused_emitter(&elemental_emitter); |
| for (int i = 0; i < fused_computation->num_parameters(); i++) { |
| fused_emitter.BindGenerator( |
| fused_computation->parameter_instruction(i), |
| [this, &ir_arrays, i](llvm_ir::IrArray::Index index) { |
| return ir_arrays[i].EmitReadArrayElement(index, &b_); |
| }); |
| } |
| HloInstruction* instr = fused_computation->root_instruction(); |
| if (instr->opcode() != HloOpcode::kTuple) { |
| CHECK_EQ(0, output_index); |
| } else { |
| instr = instr->mutable_operand(output_index); |
| } |
| TF_RET_CHECK(instr->shape().IsArray()); |
| TF_ASSIGN_OR_RETURN(auto generator, |
| fused_emitter.GetGenerator(instr->operand(1))); |
| TF_RETURN_IF_ERROR( |
| ParallelLoopEmitter(generator, dest_array, launch_dimensions, &b_) |
| .EmitLoop(mlir::GetNameFromLoc(fusion.getLoc()))); |
| return {std::move(kernel_thunk)}; |
| } |
| |
| StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk( |
| mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) { |
| // Generate thunk sequence for while 'condition'. |
| mlir::Region* condition = &while_op.cond(); |
| TF_ASSIGN_OR_RETURN( |
| auto ir_emitter_condition, |
| IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_)); |
| |
| TF_RETURN_IF_ERROR(ir_emitter_condition->EmitLmhloRegion(condition)); |
| |
| // Generate thunk sequence for while 'body'. |
| mlir::Region* body = &while_op.body(); |
| TF_ASSIGN_OR_RETURN( |
| auto ir_emitter_body, |
| IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_)); |
| |
| TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(body)); |
| |
| // Extract the condition value from the last op (exlucidng the terminator op) |
| // in the condition region. |
| auto cond_result = GetHloOutputs(while_op); |
| TF_RET_CHECK(cond_result.size() == 1); |
| TF_ASSIGN_OR_RETURN(auto cond_result_slice, |
| GetAllocationSlice(cond_result[0])); |
| |
| return std::unique_ptr<Thunk>( |
| new WhileThunk(thunk_info, cond_result_slice, |
| ir_emitter_condition->ConsumeThunkSequence(), |
| ir_emitter_body->ConsumeThunkSequence())); |
| } |
| |
| StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk( |
| mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info, |
| const int64_t loop_limit) { |
| // Generate thunk sequence for while 'body' (will be used a For loop body). |
| TF_ASSIGN_OR_RETURN( |
| auto ir_emitter_body, |
| IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_)); |
| TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(&while_op.body())); |
| |
| return std::unique_ptr<Thunk>(new ForThunk( |
| thunk_info, loop_limit, ir_emitter_body->ConsumeThunkSequence())); |
| } |
| |
| Status IrEmitterUnnested::EmitTargetElementLoop( |
| const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) { |
| return InternalError("This should be unreachable"); |
| } |
| |
| // Gets the output offset as calculated from thread_id.x (to be applied to the |
| // offset calculated from block_id and thread_id.y). |
| static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, |
| llvm::Value* thread_id_x, |
| llvm::Type* index_ty, |
| llvm::IRBuilder<>* b) { |
| auto constant = [&](int64_t val) { |
| return llvm::ConstantInt::get(index_ty, val); |
| }; |
| if (tiling_scheme.GetIndexingOrder() == kStridedIndexingX) { |
| return thread_id_x; |
| } else if (tiling_scheme.GetIndexingOrder() == kStridedLinearIndexingX) { |
| return b->CreateMul(thread_id_x, constant(tiling_scheme.GetVectorSize())); |
| } |
| CHECK_EQ(tiling_scheme.GetIndexingOrder(), kLinearIndexingX); |
| return b->CreateMul(thread_id_x, |
| constant(tiling_scheme.GetTileSizeFor(kDimX))); |
| } |
| |
| // Calls `emit_elem_function()` `x_num_steps` times. If |
| // `vector_size`==1, then each element index passed to |
| // `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1, |
| // then it must be a multiple of `x_num_steps`. In that case, it |
| // triggers a different indexing order that is vectorizable by |
| // LLVM. It generates many groups of calls to `emit_elem_function`. Each |
| // group is separated by `step_x` elements. Inside a group, elements |
| // are consecutive. If `check_x_tile_bounds` is true, then it will check |
| // if the element index is in bound compared to `tile_width` before |
| // calling `emit_elem_function`. |
| static void UnrollInnerTileLoop( |
| bool check_x_tile_bounds, int64_t x_num_steps, int64_t step_x, |
| int64_t vector_size, const string& loop_name, KernelSupportLibrary* ksl, |
| llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width, |
| const IrArray::Index& source_idx, llvm::IRBuilder<>* b, |
| const IrEmitterUnnested::EmitElementFunction* emit_elem_function) { |
| llvm::Type* index_ty = tile_width->getType(); |
| auto constant = [&](int64_t val) { |
| return llvm::ConstantInt::get(index_ty, val); |
| }; |
| IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b); |
| for (int64_t j = 0; j < x_num_steps / vector_size; j++) { |
| for (int64_t i = 0; i < vector_size; i++) { |
| int64_t linear_index = j * vector_size + i; |
| llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i), |
| start_offset_x, "x_loc"); |
| IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim( |
| constant(j * step_x * vector_size + i), kDimX, b); |
| auto emit_element = [&] { |
| return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index); |
| }; |
| if (check_x_tile_bounds) { |
| ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width), |
| emit_element); |
| } else { |
| emit_element(); |
| } |
| } |
| } |
| } |
| |
| void IrEmitterUnnested::EmitTile( |
| const TilingScheme& tiling_scheme, const IrArray::Index& tile_origin_index, |
| const string& loop_name, KernelSupportLibrary* ksl, |
| const ThreadIdInfo& thread_id_info, llvm::Value* tile_height, |
| llvm::Value* tile_width, |
| const IrEmitterUnnested::EmitElementFunction& emit_elem_function) { |
| llvm::Type* index_ty = tile_width->getType(); |
| auto constant = [&](int64_t val) { |
| return llvm::ConstantInt::get(index_ty, val); |
| }; |
| int64_t num_threads_x = tiling_scheme.GetNumThreadsFor(kDimX); |
| llvm::Value* num_threads_y = constant(tiling_scheme.GetNumThreadsFor(kDimY)); |
| llvm::Value* start_offset_x = |
| GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, index_ty, &b_); |
| |
| // Using dilated mapping scheme, each thread steps with a stride of number |
| // of threads. |
| // Otherwise, the stride is one, but we multiply each offset by the limit of |
| // number of steps which can be made. |
| int64_t step_x = |
| tiling_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x; |
| int64_t vector_size = tiling_scheme.GetVectorSize(); |
| |
| IrArray::Index source_idx = |
| tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_); |
| |
| ksl->For( |
| loop_name + "_y_in_tile", |
| /*start=*/thread_id_info.thread_id_y, |
| /*end=*/ |
| tile_height, |
| /*step=*/num_threads_y, [&](llvm::Value* y_loc) { |
| auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) { |
| return UnrollInnerTileLoop( |
| check_x_tile_bounds, tiling_scheme.GetTileSizeFor(kDimX), step_x, |
| vector_size, loop_name, ksl, start_offset_x, y_loc, tile_width, |
| source_idx, &b_, &emit_elem_function); |
| }; |
| |
| // Only take this path when we unroll in a way vectorizable by |
| // LLVM. Special case when the tile doesn't fit completely for even |
| // row size. For odd row size every other row isn't aligned to the |
| // vectorized size, so it can't be vectorized by LLVM. |
| if (tiling_scheme.GetIndexingOrder() == kStridedLinearIndexingX) { |
| ksl->If( |
| loop_name + "_is_full_tile", |
| // For the last block, tile_width will be the number of |
| // elements left. |
| b_.CreateICmpEQ( |
| constant(tiling_scheme.GetBlockTileSizeFor(kDimX)), |
| tile_width), |
| [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); }, |
| [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); }); |
| } else { |
| unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); |
| } |
| }); |
| } |
| |
| static IrArray::Index GetUnnormalizedIndex( |
| const IrArray::Index& normalized_shape_index, |
| const Shape& unnormalized_shape, llvm::IRBuilder<>* b_, |
| const TilingScheme& tiling_scheme) { |
| DCHECK_EQ(normalized_shape_index.size(), 3); |
| // If the normalization only add a new dimensions of size 1, |
| // generate simpler indexing. LLVM doesn't always simplify the more |
| // complicated indexing and this prevents it from vectorizing some |
| // cases. We do this only for major_to_minor memory layout. |
| if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && |
| unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && |
| unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && |
| unnormalized_shape.layout().minor_to_major(1) == 0) { |
| CHECK_EQ(normalized_shape_index.dims()[0], 1); |
| auto multidim = normalized_shape_index.multidim(); |
| return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape, |
| normalized_shape_index.GetType()); |
| } |
| if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && |
| unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && |
| unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && |
| unnormalized_shape.layout().minor_to_major(1) == 1) { |
| CHECK_EQ(normalized_shape_index.dims()[0], 1); |
| auto multidim = normalized_shape_index.multidim(); |
| return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape, |
| normalized_shape_index.GetType()); |
| } |
| llvm::Value* linear = |
| normalized_shape_index.Linearize(tiling_scheme.GetDimsInElems(), b_); |
| return IrArray::Index(linear, unnormalized_shape, b_); |
| } |
| |
// Emits code that computes and stores one output element of a kLoop fusion
// whose parameters are 0-2-1 transposes of its outputs.
//
// fusion: the fusion whose fused computation is evaluated.
// operand_arrays: IrArrays for the fusion parameters, indexed like the fused
//   computation's parameter instructions.
// output_arrays: IrArrays for the fusion outputs. With more than one output,
//   the root value is expected to be a struct with one element per output.
// index: the element's index in the normalized tensor, i.e. the tensor
//   obtained by collapsing contiguous dimensions that play the same role in
//   the transpose. It is mapped back to the output shape with
//   GetUnnormalizedIndex.
// tiling_scheme: used by GetUnnormalizedIndex when the index has to be
//   linearized.
// y_loc, x_loc: coordinates inside the tile, used when reading a parameter
//   from its tile buffer.
// param_shmem_buffers: one entry per parameter. A non-null entry is a tile
//   buffer read at [0][x_loc][y_loc]; a null entry means the parameter is read
//   directly from `operand_arrays`.
//   NOTE(review): assumed to have at least operand_arrays.size() entries.
void IrEmitterUnnested::EmitTileElementForFusion(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const llvm_ir::IrArray> operand_arrays,
    absl::Span<const llvm_ir::IrArray> output_arrays,
    const llvm_ir::IrArray::Index& index, const TilingScheme& tiling_scheme,
    llvm::Value* y_loc, llvm::Value* x_loc,
    absl::Span<llvm::Value* const> param_shmem_buffers) {
  // NOTE(review): `*` on the result crashes if creating the sub-computation
  // fails; this function returns void, so the error cannot be propagated.
  const HloComputation* fused_computation =
      *GetOrCreateSubComputationFromRegion(&fusion.region(),
                                           /*is_fusion=*/true);
  GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
                                     GetNestedComputer());
  FusedIrEmitter fused_emitter(&elem_emitter);
  for (int i = 0; i < operand_arrays.size(); i++) {
    llvm_ir::ElementGenerator gen;
    if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
      // The incoming index is ignored here: the value is loaded from the tile
      // buffer at [0][x_loc][y_loc], i.e. with X and Y swapped.
      gen = [this, param_tile_buffer, x_loc,
             y_loc](llvm_ir::IrArray::Index index) {
        // TODO(jlebar): Add AA metadata to this load. Tile buffers are
        // global variables, so LLVM's points-to analysis doesn't help us
        // much. And we want the AA info to be present before address
        // spaces are inferred (which is pretty late in the pipeline), so
        // even if we had address-space-based AA in LLVM, it wouldn't help
        // us much here.
        return b_.CreateLoad(
            b_.CreateGEP(param_tile_buffer,
                         {index.GetConstantWithIndexType(0), x_loc, y_loc}),
            "tiled_buffer");
      };
    } else {
      auto array = operand_arrays[i];
      auto name = fused_computation->parameter_instruction(i)->name();
      gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
        return array.EmitReadArrayElement(index, &b_, name);
      };
    }
    fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
                                std::move(gen));
  }
  IrArray::Index untiled_index = GetUnnormalizedIndex(
      index, output_arrays[0].GetShape(), &b_, tiling_scheme);
  llvm_ir::ElementGenerator output_generator =
      *fused_emitter.GetGenerator(fused_computation->root_instruction());
  llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
  if (output_arrays.size() > 1) {
    // Multi-output fusion: the root yields a struct; write each element to its
    // own output array at the same index.
    DCHECK(output_value->getType()->isStructTy());
    DCHECK_EQ(output_value->getType()->getStructNumElements(),
              output_arrays.size());
    for (int64_t i = 0; i < output_arrays.size(); ++i) {
      output_arrays[i].EmitWriteArrayElement(
          untiled_index, ExtractValue(output_value, i), &b_);
    }
  } else {
    output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
  }
}
| |
| static HloInstruction* GetFusionOutput(HloComputation* fusion, int index) { |
| HloInstruction* root = fusion->root_instruction(); |
| if (root->opcode() == HloOpcode::kTuple) { |
| root = root->mutable_operand(index); |
| } else { |
| CHECK_EQ(0, index); |
| } |
| return root; |
| } |
| |
// Emits the prologue shared by all reductions in `reduce_instr_index_group`
// and returns the resulting codegen state.
//
// For each fusion root listed in the group (each must be an mhlo.reduce), and
// for each result of that reduce, this function:
//   * emits two function-entry allocas: one for the reduction input, and one
//     array of GetNumPartialResults() partial results;
//   * stores the reduce's init value (generated through `fused_emitter` from
//     the matching HLO reduce in `fused_computation`) into every partial
//     result;
//   * allocates a tile via AllocateSharedMemoryTile, sized as shown below;
//   * records these values, plus the input generator, with
//     SetCalculationStateFor(…, root index, result index).
//
// The shared tile is:
//   * row reduction: [num_partial_results][kWarpSize];
//   * column reduction: [num_partial_results][num_threads_x][num_threads_x + 1],
//     which CHECKs that the X and Y thread counts are equal.
ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState(
    mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info,
    absl::Span<const int> reduce_instr_index_group,
    HloComputation* fused_computation, FusedIrEmitter* fused_emitter) {
  ReductionCodegenState reduction_codegen_state(reduction_info);
  VLOG(10) << "Emit prologue for reduction: " << MlirToString(fusion);

  for (int index : reduce_instr_index_group) {
    auto reduce_inst =
        mlir::cast<mlir::mhlo::ReduceOp>(fusion.getFusionRoots()[index]);

    VLOG(10) << "Emit prologue for reduction: " << MlirToString(reduce_inst);
    int num_partial_results = reduction_codegen_state.GetNumPartialResults();
    // The HLO counterpart of `reduce_inst`; its operands drive the init-value
    // and input generators below.
    const auto* reduce_hlo =
        Cast<HloReduceInstruction>(GetFusionOutput(fused_computation, index));

    for (int op_result_idx = 0; op_result_idx < reduce_inst.getNumResults();
         op_result_idx++) {
      const mlir::Value& result = reduce_inst.getResult(op_result_idx);
      Shape result_shape = GetShape(result);

      llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
          result_shape.element_type(), ir_emitter_context_->llvm_module());
      llvm::AllocaInst* reduction_input_address =
          llvm_ir::EmitAllocaAtFunctionEntry(element_type,
                                             "reduction_input_address", &b_);

      llvm::AllocaInst* partial_result_address =
          llvm_ir::EmitAllocaAtFunctionEntryWithCount(
              element_type, /*element_count=*/b_.getInt32(num_partial_results),
              ("partial_reduction_result." + llvm::Twine(index)).str(), &b_);

      const HloInstruction* init_value =
          reduce_hlo->init_values()[op_result_idx];

      // Initialize the partial result with the initial value of the reduction.
      // The init value is a scalar, so it is generated at an empty index.
      llvm::Value* init_ir_value = (*fused_emitter->GetGenerator(init_value))(
                                       IrArray::Index(b_.getInt32Ty()))
                                       .ValueOrDie();

      for (int i = 0; i < num_partial_results; ++i) {
        b_.CreateStore(
            init_ir_value,
            b_.CreateInBoundsGEP(partial_result_address, {b_.getInt32(i)}));
      }

      const TilingScheme& tiling_scheme =
          reduction_codegen_state.GetTilingScheme();
      int64_t num_threads_x = tiling_scheme.GetNumThreadsFor(kDimX);
      llvm::Type* buffer_type = [&] {
        if (reduction_codegen_state.IsRowReduction()) {
          // Allocate __shared__ cache[num_partial_results][kWarpSize].
          // TODO(cheshire): Do we need the same trick as below to avoid bank
          // conflicts?
          return llvm::ArrayType::get(
              llvm::ArrayType::get(element_type, kWarpSize),
              num_partial_results);
        } else {
          // Allocate __shared__
          // cache[num_partial_results][num_threads][num_threads + 1], where
          // num_threads == num_threads_x == num_threads_y.  The "+1" is used to
          // avoid bank conflicts.
          CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(kDimY));
          return llvm::ArrayType::get(
              llvm::ArrayType::get(
                  llvm::ArrayType::get(element_type, num_threads_x + 1),
                  num_threads_x),
              num_partial_results);
        }
      }();
      llvm::GlobalVariable* shared_cache_per_reduce =
          llvm_ir::AllocateSharedMemoryTile(
              b_.GetInsertBlock()->getModule(), buffer_type,
              absl::StrCat("shared_cache_", index));

      llvm_ir::ElementGenerator input_gen =
          *fused_emitter->GetGenerator(reduce_hlo->inputs()[op_result_idx]);

      auto calculation_state = ReductionCodegenState::ReductionCalculationState{
          /*shared_cache=*/shared_cache_per_reduce,
          /*initial_value=*/init_ir_value,
          /*partial_result_address=*/partial_result_address,
          /*input_address=*/reduction_input_address,
          /*input_gen=*/input_gen};
      reduction_codegen_state.SetCalculationStateFor(calculation_state, index,
                                                     op_result_idx);
    }
  }

  return reduction_codegen_state;
}
| |
| void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce( |
| const HloComputation* reducer, |
| absl::Span<llvm::Value* const> partial_result_addresses, |
| int threads_per_block) { |
| // This only works when the block size is a multiple of 32 threads. |
| |
| // We check this here as a mistake in the number of threads per |
| // block is very hard to detect. |
| CHECK_EQ(threads_per_block % 32, 0); |
| |
| for (int distance = 16; distance >= 1; distance /= 2) { |
| absl::InlinedVector<llvm::Value*, 2> reduction_params; |
| |
| for (llvm::Value* acc : partial_result_addresses) { |
| reduction_params.push_back(acc); |
| } |
| |
| for (int oidx = 0; oidx < partial_result_addresses.size(); ++oidx) { |
| llvm::Type* element_type = llvm::cast<llvm::PointerType>( |
| partial_result_addresses[oidx]->getType()) |
| ->getElementType(); |
| int bit_width = llvm_ir::GetSizeInBits(element_type); |
| llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( |
| element_type, "result_from_other_lane", &b_); |
| |
| reduction_params.push_back(result_from_other_lane); |
| |
| // Bitcast cannot be applied to aggregate types (even packed ones), so |
| // we bitcast addresses of load/store to intN* of the same bit-width. |
| llvm::Type* shuffled_value_type = |
| element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type; |
| auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { |
| return b_.CreatePointerBitCastOrAddrSpaceCast( |
| ptr, shuffled_value_type->getPointerTo()); |
| }; |
| |
| llvm::Value* partial_result_address = partial_result_addresses[oidx]; |
| llvm::Value* partial_result = |
| b_.CreateLoad(convert_pointer_for_shuffle(partial_result_address), |
| "partial_reduction_result"); |
| b_.CreateStore( |
| EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), |
| convert_pointer_for_shuffle(result_from_other_lane)); |
| } |
| |
| StatusOr<std::vector<llvm::Value*>> returned_scalars = |
| ComputeNestedElementFromAddrs(*reducer, reduction_params); |
| TF_CHECK_OK(returned_scalars.status()); |
| |
| for (int i = 0; i < returned_scalars->size(); i++) { |
| b_.CreateStore(/*Val=*/returned_scalars->at(i), |
| /*Ptr=*/partial_result_addresses[i]); |
| } |
| } |
| } |
| |
| // Given the IrArray index of a reduction input, returns the linear address of |
| // the reduction output as if the reduction were going to keep the input shape |
| // with the dimensions being reduced moved. |
| static llvm::Value* GetUntransposedOutputLinearAddress( |
| llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index, |
| const ReductionCodegenState& reduction_info) { |
| const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); |
| if (reduction_info.IsRowReduction()) { |
| // For row-reduction, y-coordinate determines which row we write into. |
| return index[kDimY]; |
| } |
| // For column reduction, we get the transposed address. |
| absl::Span<const int64_t> dims_in_elem = tiling_scheme.GetDimsInElems(); |
| llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[kDimX]); |
| llvm::Value* x_block_offset = b->CreateMul(index[kDimZ], x_dim_size); |
| return b->CreateAdd(x_block_offset, index[kDimX]); |
| } |
| |
| static llvm::Value* GetOutputAddressForReduction( |
| const IrArray::Index& element_index, |
| const IrEmitterUnnested::ReductionOutputMap& output_arrays, |
| int instruction_idx, int output_idx, llvm::IRBuilder<>* b) { |
| const IrArray& output_array = output_arrays.at(instruction_idx)[output_idx]; |
| IrArray::Index output_index(element_index.multidim(), output_array.GetShape(), |
| element_index.GetType()); |
| return output_array.EmitArrayElementAddress(output_index, b, |
| "output_element_address"); |
| } |
| |
| void IrEmitterUnnested::EmitReductionOutput( |
| llvm::Type* index_ty, mlir::lmhlo::FusionOp fusion, |
| absl::Span<const int> reduce_instr_index_group, |
| const ReductionOutputMap& result_ir_arrays, |
| const absl::flat_hash_map<int, HloComputation*> reducers, |
| const ReductionCodegenState& reduction_codegen_state, |
| const TilingKernelInfo& tiling_kernel_info) { |
| const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); |
| auto constant = [&](uint64 c) -> llvm::Constant* { |
| return llvm::ConstantInt::get(index_ty, c); |
| }; |
| |
| IrEmitterUnnested::ThreadIdInfo thread_id_info = |
| EmitThreadIdInfo(tiling_scheme.GetNumThreadsPerBlock(), index_ty, |
| tiling_scheme.GetNumThreadsFor(kDimX)); |
| |
| IrArray::Index start_offset = [&] { |
| llvm::Value* x_loc = thread_id_info.thread_id_x; |
| llvm::Value* y_loc = thread_id_info.thread_id_y; |
| if (!reduction_codegen_state.IsRowReduction()) { |
| std::swap(x_loc, y_loc); |
| } |
| llvm::Value* start_offset_x = |
| GetStartOffsetX(tiling_scheme, x_loc, index_ty, &b_); |
| return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_) |
| .AddOffsetToDim(start_offset_x, kDimX, &b_); |
| }(); |
| |
| for (int instruction_idx : reduce_instr_index_group) { |
| mlir::Operation* fusion_output = fusion.getFusionRoots()[instruction_idx]; |
| |
| auto reduce_hlo = mlir::cast<mlir::mhlo::ReduceOp>(fusion_output); |
| Shape operand_shape = GetShape(reduce_hlo.getOperand(0)); |
| Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions( |
| [&](int64_t dim) { |
| return !absl::c_linear_search(reduce_hlo.dimensions(), dim); |
| }, |
| operand_shape); |
| |
| const HloComputation* reducer = reducers.at(instruction_idx); |
| |
| for (int partial_result_idx = 0; |
| partial_result_idx < reduction_codegen_state.GetNumPartialResults(); |
| ++partial_result_idx) { |
| llvm::Value* untransposed_output_linear_address = |
| GetUntransposedOutputLinearAddress( |
| &b_, |
| start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX, |
| &b_), |
| reduction_codegen_state); |
| |
| // A reduction is allowed to transpose its output. For example, suppose |
| // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are |
| // allowed to produce as output either f32[10,30]{1,0} (no transpose) or |
| // f32[10,30]{0,1} (transposing the two output dims). |
| // |
| // At this point in the function we have a "partial sum" of input elements |
| // (stored in partial_result_addresses), and we need to accumulate it into |
| // the correct output element. |
| IrArray::Index element_index( |
| /*linear=*/untransposed_output_linear_address, |
| reduction_kept_element_shape, &b_); |
| if (reduction_codegen_state.IsRowReduction()) { |
| EmitReductionOutputForRowReduction(reducer, thread_id_info, |
| reduction_codegen_state, index_ty, |
| result_ir_arrays, element_index, |
| instruction_idx, partial_result_idx); |
| } else { |
| EmitReductionOutputForColumnReduction( |
| reducer, thread_id_info, reduction_codegen_state, index_ty, |
| result_ir_arrays, element_index, instruction_idx, |
| partial_result_idx, tiling_kernel_info); |
| } |
| } |
| } |
| } |
| |
| llvm::Value* IrEmitterUnnested::EmitBlockId() { |
| return gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBlockIdx, {}, |
| {}, &b_); |
| } |
| |
| void IrEmitterUnnested::EmitPrintfWithThreadId( |
| absl::string_view fmt, absl::Span<llvm::Value* const> arguments, |
| absl::optional<int64_t> thread_id_filter, |
| absl::optional<int64_t> block_id_filter) { |
| llvm::Value* thread_id = EmitThreadId(1024, b_.getInt32Ty()); |
| llvm::Value* block_id = EmitBlockId(); |
| std::vector<llvm::Value*> updated_arguments = {thread_id, block_id}; |
| updated_arguments.insert(updated_arguments.end(), arguments.begin(), |
| arguments.end()); |
| llvm::Value* constraint = b_.getTrue(); |
| if (thread_id_filter) { |
| constraint = b_.CreateAnd( |
| constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter))); |
| } |
| if (block_id_filter) { |
| constraint = b_.CreateAnd( |
| constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter))); |
| } |
| KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); |
| ksl.If(constraint, [&] { |
| xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"), |
| updated_arguments, &b_); |
| }); |
| } |
| |
| llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, |
| llvm::Twine name) { |
| return b_.CreateAddrSpaceCast( |
| input, |
| llvm::PointerType::get(input->getType()->getPointerElementType(), |
| /*AddressSpace=*/0), |
| name); |
| } |
| |
// Finishes a row reduction for partial result `partial_result_idx` of
// reduction `reduction_idx`: combines the partial results of the threads of
// the block and writes the final value(s) to the output.
//
// Emitted sequence:
//   1. A full-warp shuffle-down reduction of each thread's partial result.
//   2. Lane 0 of every warp stores its warp's result into
//      `state.shared_cache` at [partial_result_idx][warp_id].
//   3. A block-wide barrier (EmitSyncThreads).
//   4. Warp 0 loads one per-warp result per lane (lanes for which no warp
//      exists use the reduction's initial value instead) and shuffle-reduces
//      those values again.
//   5. Thread 0 writes the result: a plain store per output when the
//      reduction is race-free, otherwise EmitAtomicOperationForNestedComputation
//      (single-output reductions only).
void IrEmitterUnnested::EmitReductionOutputForRowReduction(
    const HloComputation* reducer,
    const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
    const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty,
    const ReductionOutputMap& output_arrays,
    const llvm_ir::IrArray::Index& element_index, int reduction_idx,
    int partial_result_idx) {
  auto constant = [&](uint64 c) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_ty, c);
  };
  auto is_zero = [&](llvm::Value* value) {
    return b_.CreateICmpEQ(value, constant(0));
  };

  // The reducer takes (accumulators..., values...), so it has two parameters
  // per output.
  int num_outputs = reducer->num_parameters() / 2;
  const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
  // Address of this thread's partial result for each output.
  absl::InlinedVector<llvm::Value*, 2> current_outputs;
  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
    const ReductionCodegenState::ReductionCalculationState& state =
        reduction_codegen_state.GetCalculationStateFor(reduction_idx,
                                                       output_idx);
    current_outputs.push_back(
        b_.CreateInBoundsGEP(state.partial_result_address,
                             {constant(partial_result_idx)}, "current_output"));
  }

  // Step 1: reduce within each warp.
  EmitFullWarpShuffleDownLoopForReduce(reducer, current_outputs,
                                       tiling_scheme.GetNumThreadsPerBlock());

  KernelSupportLibrary ksl(&b_);
  llvm::Value* warp_id =
      b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));

  // Step 2: lane 0 of each warp publishes its warp's result, indexed by warp.
  ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
    for (int oidx = 0; oidx < num_outputs; oidx++) {
      const ReductionCodegenState::ReductionCalculationState& state =
          reduction_codegen_state.GetCalculationStateFor(reduction_idx, oidx);
      llvm::Value* shmem_output_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
          state.shared_cache,
          {b_.getInt32(0), constant(partial_result_idx), warp_id}));
      b_.CreateStore(b_.CreateLoad(current_outputs[oidx]), shmem_output_addr);
    }
  });

  // Step 3: make every warp's published result visible to warp 0.
  // TODO(cheshire): Don't we want to sync it once for everything in the
  // output? Not once per each?
  EmitSyncThreads();
  // Step 4: warp 0 combines the per-warp results.
  ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
    absl::InlinedVector<llvm::Value*, 2> selected_values;
    for (int oidx = 0; oidx < num_outputs; oidx++) {
      const ReductionCodegenState::ReductionCalculationState& state =
          reduction_codegen_state.GetCalculationStateFor(reduction_idx, oidx);
      // Inside warp 0, lane i reads the result published by warp i.
      llvm::Value* block_accum_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
          state.shared_cache, {b_.getInt32(0), constant(partial_result_idx),
                               thread_id_info.lane_id}));

      llvm::Type* element_type =
          state.partial_result_address->getType()->getElementType();

      // Fallback value for lanes that have no corresponding warp.
      llvm::Value* initial_value_addr =
          CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry(
              element_type, "initial_value_addr", &b_));
      b_.CreateStore(state.initial_value, initial_value_addr);

      // True iff this lane's index is below the number of warps along X.
      llvm::Value* warp_exists = b_.CreateICmpULT(
          thread_id_info.thread_id_x,
          constant(tiling_scheme.GetNumThreadsFor(kDimX) / kWarpSize));

      llvm::Value* selected_value =
          b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr);

      selected_values.push_back(selected_value);
    }

    EmitFullWarpShuffleDownLoopForReduce(reducer, selected_values,
                                         tiling_scheme.GetNumThreadsPerBlock());

    // Step 5: a single thread writes the block's result.
    ksl.If("reduction_write_output", is_zero(thread_id_info.thread_id_x), [&] {
      if (reduction_codegen_state.IsRaceFree()) {
        VLOG(10) << "Using deterministic reductions: writing out "
                    "the value directly";
        for (int oidx = 0; oidx < num_outputs; oidx++) {
          llvm::Value* output_address = GetOutputAddressForReduction(
              element_index, output_arrays, reduction_idx, oidx, &b_);

          b_.CreateStore(b_.CreateLoad(selected_values[oidx], "output"),
                         output_address);
        }
      } else {
        // Other blocks may write the same output element, so combine
        // atomically. NOTE(review): this is the atomic path, yet the message
        // below says "non-atomic"; the wording looks inverted - confirm.
        CHECK_EQ(selected_values.size(), 1)
            << "Variadic non-atomic reductions not supported";
        llvm::Value* output_address = GetOutputAddressForReduction(
            element_index, output_arrays, reduction_idx,
            /*output_idx=*/0, &b_);
        TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
            *reducer, output_address, selected_values[0]));
      }
    });
  });
}
| |
// Finishes a column reduction for partial result `partial_result_idx` of
// reduction `reduction_idx`.
//
// Each thread stores its partial result into `state.shared_cache` at
// [partial_result_idx][thread_id_x][thread_id_y]. After a barrier it reads back
// the transposed slot [thread_id_y][thread_id_x]. A warp therefore gathers the
// values written by the threads whose thread_id_x equals its thread_id_y
// (presumably the partials of one output column). Those values are
// shuffle-reduced, and lane 0 of each warp that maps to an in-bounds output
// element writes the result. The write is a plain store when the reduction is
// race-free, and EmitAtomicOperationForNestedComputation otherwise.
void IrEmitterUnnested::EmitReductionOutputForColumnReduction(
    const HloComputation* reducer,
    const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
    const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty,
    const ReductionOutputMap& output_arrays,
    const llvm_ir::IrArray::Index& element_index, int reduction_idx,
    int partial_result_idx, const TilingKernelInfo& tiling_kernel_info) {
  KernelSupportLibrary ksl(&b_);

  auto constant = [&](uint64 c) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_ty, c);
  };
  auto is_zero = [&](llvm::Value* value) {
    return b_.CreateICmpEQ(value, constant(0));
  };
  const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();

  // Store the transpose in shared memory.
  // One iteration per reducer output (the reducer has two parameters per
  // output: accumulator and value).
  for (int output_idx = 0; output_idx < reducer->num_parameters() / 2;
       output_idx++) {
    const ReductionCodegenState::ReductionCalculationState& state =
        reduction_codegen_state.GetCalculationStateFor(reduction_idx,
                                                       output_idx);
    llvm::GlobalVariable* shared_cache = state.shared_cache;
    llvm::Value* shmem_output_addr = CastSharedToGlobal(
        b_.CreateInBoundsGEP(
            shared_cache,
            {b_.getInt32(0), constant(partial_result_idx),
             thread_id_info.thread_id_x, thread_id_info.thread_id_y}),
        "shmem_output_address");
    llvm::Value* current_output =
        b_.CreateInBoundsGEP(state.partial_result_address,
                             {constant(partial_result_idx)}, "current_output");

    llvm::Value* current_output_value = b_.CreateLoad(current_output);
    b_.CreateStore(current_output_value, shmem_output_addr);
  }

  // Every thread's store must be visible before the transposed reads below.
  EmitSyncThreads();

  // Get transposed element from shared memory.
  // Note the swapped (thread_id_y, thread_id_x) order compared to the store.
  absl::InlinedVector<llvm::Value*, 1> shmem_transposed_addrs;

  for (int output_idx = 0; output_idx < reducer->num_parameters() / 2;
       output_idx++) {
    const ReductionCodegenState::ReductionCalculationState& state =
        reduction_codegen_state.GetCalculationStateFor(reduction_idx,
                                                       output_idx);
    llvm::Value* shmem_transposed_addr =
        CastSharedToGlobal(b_.CreateInBoundsGEP(
            state.shared_cache,
            {b_.getInt32(0), constant(partial_result_idx),
             thread_id_info.thread_id_y, thread_id_info.thread_id_x},
            "shmem_transposed_addr"));
    shmem_transposed_addrs.push_back(shmem_transposed_addr);
  }

  // Reduce the transposed values in place; lane 0 ends up with the result.
  EmitFullWarpShuffleDownLoopForReduce(reducer, shmem_transposed_addrs,
                                       tiling_scheme.GetNumThreadsPerBlock());

  // Some warps in the block are completely outside of the bound of the
  // tensor, so they should not write any output at all.
  // A thread writes only if its transposed X start offset (derived from
  // thread_id_y) is inside the tile's X bound and its thread_id_x is inside the
  // tile's Y bound.
  llvm::Value* has_output = b_.CreateAnd(
      b_.CreateICmpULT(
          GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty,
                          &b_),
          tiling_kernel_info.output_tile_bounds[kDimX]),
      b_.CreateICmpULT(thread_id_info.thread_id_x,
                       tiling_kernel_info.output_tile_bounds[kDimY]));

  ksl.If("reduction_write_output",
         b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
           if (reduction_codegen_state.IsRaceFree()) {
             VLOG(10) << "Using deterministic reductions: writing out "
                         "the value directly";
             for (int output_idx = 0;
                  output_idx < reducer->num_parameters() / 2; output_idx++) {
               llvm::Value* output_address =
                   GetOutputAddressForReduction(element_index, output_arrays,
                                                reduction_idx, output_idx, &b_);
               b_.CreateStore(b_.CreateLoad(shmem_transposed_addrs[output_idx],
                                            "output_value"),
                              output_address);
             }
           } else {
             // Other blocks may write the same output element, so combine
             // atomically. NOTE(review): this is the atomic path, yet the
             // message below says "non-atomic"; the wording looks inverted -
             // confirm.
             CHECK_EQ(shmem_transposed_addrs.size(), 1)
                 << "Variadic non-atomic reductions not supported";
             llvm::Value* output_address = GetOutputAddressForReduction(
                 element_index, output_arrays, reduction_idx,
                 /*output_idx=*/0, &b_);
             TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
                 *reducer, output_address, shmem_transposed_addrs[0]));
           }
         });
}
| |
| llvm::Value* IrEmitterUnnested::EmitThreadId(int64_t threads_per_block, |
| llvm::Type* index_ty) { |
| // Calculate (y, x) coordinates respectively in the 2D view of thread block, |
| // defined by (num_thread_y, num_thread_x) from thread_id. |
| llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic( |
| gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_); |
| llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw); |
| return b_.CreateIntCast(thread_id_raw, index_ty, |
| /*isSigned=*/true, "thread.id.x"); |
| } |
| |
| IrEmitterUnnested::ThreadIdInfo IrEmitterUnnested::EmitThreadIdInfo( |
| int64_t threads_per_block, llvm::Type* index_ty, int64_t num_threads_x) { |
| auto constant = [&](uint64 c) -> llvm::Constant* { |
| return llvm::ConstantInt::get(index_ty, c); |
| }; |
| llvm::Value* thread_id = EmitThreadId(threads_per_block, index_ty); |
| llvm::Value* num_threads_x_v = constant(num_threads_x); |
| return { |
| /*thread_id=*/thread_id, |
| /*thread_id_x=*/b_.CreateURem(thread_id, num_threads_x_v, "thread_id.x"), |
| /*thread_id_y=*/b_.CreateUDiv(thread_id, num_threads_x_v, "thread_id.y"), |
| /*lane_id=*/b_.CreateURem(thread_id, constant(kWarpSize), "lane_id")}; |
| } |
| |
| IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel( |
| const TilingScheme& tiling_scheme, llvm::Type* index_ty, |
| const TileElementGenerator& tile_element_generator) { |
| absl::Span<const int64_t> dims_in_elems = tiling_scheme.GetDimsInElems(); |
| Vector3 dims_in_blocks = tiling_scheme.GetDimsInBlocks(); |
| auto constant = [&](uint64 c) -> llvm::Constant* { |
| return llvm::ConstantInt::get(index_ty, c); |
| }; |
| |
| IrEmitterUnnested::ThreadIdInfo thread_id_info = |
| EmitThreadIdInfo(tiling_scheme.GetNumThreadsPerBlock(), index_ty, |
| tiling_scheme.GetNumThreadsFor(kDimX)); |
| |
| KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); |
| |
| const IrArray::Index block_coords = [&] { |
| llvm::Value* block_id = EmitBlockId(); |
| llvm_ir::AddRangeMetadata(0, tiling_scheme.GetNumberOfBlocks(), |
| llvm::cast<llvm::Instruction>(block_id)); |
| llvm::Value* linear_block_id = |
| b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x"); |
| IrArray::Index starting_block(linear_block_id, |
| ShapeUtil::MakeShapeWithDescendingLayout( |
| PRED /*arbitrary*/, dims_in_blocks), |
| &b_); |
| |
| std::vector<llvm::Value*> multidim = { |
| b_.CreateMul(starting_block[0], |
| constant(tiling_scheme.GetBlockTileSizeFor(0)), |
| "block_origin.z"), |
| starting_block[1], starting_block[2]}; |
| return IrArray::Index(multidim, dims_in_blocks, index_ty); |
| }(); |
| |
| std::array<llvm::Value*, 3> output_tile_bounds; |
| for (int i = kDimY; i < kDimTot; ++i) { |
| int64_t tile_size_for_dim = tiling_scheme.GetBlockTileSizeFor(i); |
| // Only last row or column may not have full size. |
| llvm::Value* is_last = |
| b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1)); |
| int64_t partial_row = |
| dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim; |
| output_tile_bounds[i] = |
| b_.CreateSelect(is_last, constant(partial_row), |
| constant(tile_size_for_dim), "tile_bound"); |
| } |
| |
| IrArray::Index tile_origin = [&] { |
| std::vector<llvm::Value*> elem_multi_index = block_coords.multidim(); |
| llvm::Type* index_ty = block_coords.GetType(); |
| for (int i = kDimY; i < kDimTot; ++i) { |
| elem_multi_index[i] = |
| b_.CreateMul(block_coords[i], |
| llvm::ConstantInt::get( |
| index_ty, tiling_scheme.GetBlockTileSizeFor(i)), |
| "tile_origin." + std::to_string(i)); |
| } |
| return IrArray::Index(elem_multi_index, tiling_scheme.GetDimsInElems(), |
| index_ty); |
| }(); |
| |
| auto emit_tile = [&](const IrArray::Index& tile) { |
| tile_element_generator(thread_id_info, tile, "output", |
| output_tile_bounds[1], output_tile_bounds[2], &ksl); |
| }; |
| |
| if (tiling_scheme.GetBlockTileSizeFor(kDimZ) == 1) { |
| emit_tile(tile_origin); |
| } else { |
| llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ]; |
| llvm::Value* block_size_for_dim = |
| constant(tiling_scheme.GetBlockTileSizeFor(kDimZ)); |
| llvm::Value* block_id_for_dim = |
| b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); |
| llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1); |
| llvm::Value* last_block_size_for_dim = constant( |
| dims_in_elems[kDimZ] - |
| (dims_in_blocks[kDimZ] - 1) * tiling_scheme.GetBlockTileSizeFor(kDimZ)); |
| |
| llvm::Value* num_tiles_in_block = |
| b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim), |
| last_block_size_for_dim, block_size_for_dim); |
| ksl.For("loop_z", |
| /*start=*/constant(0), |
| /*end=*/num_tiles_in_block, |
| /*step=*/1, [&](llvm::Value* block_dim_induction_var) { |
| IrArray::Index tile_index = tile_origin.AddOffsetToDim( |
| block_dim_induction_var, kDimZ, &b_); |
| emit_tile(tile_index); |
| }); |
| } |
| return {output_tile_bounds, tile_origin}; |
| } |
| |
| llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { |
| return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); |
| } |
| |
| // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose |
| // algorithm to improve the memory access patterns for the input parameters |
| // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller |
| // is responsible for making sure that it is safe to apply the shared memory |
| // transpose on the input parameters. |
| // |
| // |
| // For the purpose of tiling, the output tensors have a logical shape of three |
| // components 0-2-1 while the relevant input parameters have a logical shape |
| // of three components 0-1-2 in the order major to minor. The x- and y- |
| // dimensions of the tensors are tiled in square tiles with an edge length |
| // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads |
| // transposes one tile: each thread copies kTileSize/kNumRows elements from |
| // the input to a shared memory tile, then the otherwise "regular HLO kernel" |
| // reads from the shared memory instead of the original input. |
| // |
| // This is similar to the following CUDA algorithm in TensorFlow: |
| // https://goo.gl/MStRV6. |
| // |
| // `kTileSize` should usually be same as warp size. We currently choose 32 for |
| // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. |
| // |
| // TODO(b/33320379): Here each block transposes 1 tile. It may be more |
| // efficient to launch fewer blocks so each transposes many tiles. |
| void IrEmitterUnnested::EmitHlo021Tile( |
| mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context, |
| absl::Span<const llvm_ir::IrArray> operand_arrays, |
| absl::Span<const llvm_ir::IrArray> output_arrays, |
| absl::Span<const int64_t> reduced_output_dims, |
| absl::Span<const int64_t> tiled_param_ids, |
| const TilingScheme& tiling_scheme, |
| const LaunchDimensions& launch_dimensions) { |
| std::string name = mlir::GetNameFromLoc(op->getLoc()); |
| |
| llvm::Type* index_type = |
| GetIndexTypeForKernel(op, launch_dimensions.launch_bound(), &b_); |
| std::vector<IrArray> param_arrays; |
| |
| // For each tiled parameter, cast its input IrArray to the corresponding |
| // reduced shape and keep the reduced shape live during IR emission. |
| std::vector<IrArray> param_in_reduced_shape_arrays; |
| std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(), |
| nullptr); |
| |
| auto get_shared_memory_buffer = [&](llvm::Type* elem_ty, |
| absl::string_view buffer_name) { |
| // For Nvidia GPUs, the warp size is 32 threads and the shared memory bank |
| // is organized into 32-way. We usually use the warp size or a multiplier or |
| // a the warp size as the size for tiling. This may cause all elements in |
| // the same column of a tile use the same memory bank and therefore shared |
| // memory bank conflicts. Adding 1 to the minor dimension of the shared |
| // memory buffer can reduce such shared memory bank conflicts. |
| llvm::Type* buffer_type = llvm::ArrayType::get( |
| llvm::ArrayType::get(elem_ty, |
| tiling_scheme.GetBlockTileSizeFor(kDimX) + 1), |
| tiling_scheme.GetBlockTileSizeFor(kDimY)); |
| return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(), |
| buffer_type, buffer_name); |
| }; |
| |
| for (int64_t id = 0; id < context.operand_shapes.size(); id++) { |
| const Shape& param_shape = context.operand_shapes[id]; |
| param_arrays.push_back(operand_arrays[id]); |
| |
| if (absl::c_linear_search(tiled_param_ids, id)) { |
| param_shmem_buffers[id] = get_shared_memory_buffer( |
| llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_), |
| IrName(name, StrCat("tile", id))); |
| VLOG(3) << "Added shmem buffer for parameter " << id << ": " |
| << llvm_ir::DumpToString(*param_shmem_buffers[id]); |
| Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( |
| param_shape.element_type(), Permute(reduced_output_dims, {0, 2, 1})); |
| param_in_reduced_shape_arrays.push_back( |
| param_arrays[id].CastToShape(reduced_shape, &b_)); |
| } else { |
| param_in_reduced_shape_arrays.push_back(IrArray()); |
| } |
| } |
| |
| EmitElementFunction element_generator = |
| [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, |
| llvm::Value* x_loc, int64_t x_iter_num) { |
| auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op); |
| EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index, |
| tiling_scheme, y_loc, x_loc, |
| param_shmem_buffers); |
| }; |
| |
| TileElementGenerator tile_generator = |
| [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, |
| const string& loop_name, llvm::Value* tile_height, |
| llvm::Value* tile_width, KernelSupportLibrary* ksl) { |
| // If shared memory transpose is needed, wait for all threads to reach |
| // this point, lest we copy a value from tile to output before the other |
| // thread copies it from input to tile. This is `__syncthreads` in CUDA. |
| if (!tiled_param_ids.empty()) { |
| // Calculate the input tile origin from the output tile origin. |
| const IrArray::Index input_tile_origin( |
| Permute(index.multidim(), {0, 2, 1}), |
| Permute(index.dims(), {0, 2, 1}), index.GetType()); |
| |
| // Copy input parameter values to shared memory buffers: |
| // tile[thread_id_y, thread_id_x] = input[index] |
| // Note that tile_width and tile_height are flipped here because we |
| // are reading a transposed tile. |
| EmitTile(tiling_scheme, input_tile_origin, "input", ksl, |
| thread_id_info, tile_width, tile_height, |
| [&](const IrArray::Index& index, llvm::Value* y_loc, |
| llvm::Value* x_loc, int64_t /*x_iter_num*/) { |
| for (int64_t id : tiled_param_ids) { |
| IrArray& input_in_logical_shape = |
| param_in_reduced_shape_arrays.at(id); |
| |
| llvm::Value* shmem_buffer = param_shmem_buffers.at(id); |
| llvm::Value* zero = |
| llvm::ConstantInt::get(index_type, 0); |
| // TODO(jlebar): Add AA metadata to this store. Tile |
| // buffers are global variables, so LLVM can't infer much |
| // about it. |
| auto value = input_in_logical_shape.EmitReadArrayElement( |
| index, &b_, "input_element"); |
| auto addr = GEP(shmem_buffer, {zero, y_loc, x_loc}); |
| Store(value, addr); |
| } |
| }); |
| |
| // Wait for all threads to reach this point using `__syncthreads` in |
| // CUDA. |
| EmitSyncThreads(); |
| } |
| |
| EmitTile(tiling_scheme, index, loop_name, ksl, thread_id_info, |
| tile_height, tile_width, element_generator); |
| bool block_contains_multi_tiles = |
| tiling_scheme.GetBlockTileSizeFor(kDimZ) > 1; |
| |
| // If a tile block contains multiple tiles and shared memory buffers are |
| // used, we need to wait for all threads to finish using the shared |
| // memory buffer for the current tile before we move on to process the |
| // next tile and overwrite the shared memory buffers. |
| if (block_contains_multi_tiles && !tiled_param_ids.empty()) { |
| EmitSyncThreads(); |
| } |
| }; |
| |
| EmitTilingKernel(tiling_scheme, index_type, tile_generator); |
| } |
| |
| namespace { |
| |
| // A recursive function to inspect the users of a parameter to determine |
| // whether it's safe for a parameter to participate in a shared-memory |
| // transpose. |
| // |
| // Consider a fusion parameter P for which we might want to use a shmem |
| // transpose. If we do, we use a GPU thread block to preload a tile of P with |
| // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices |
| // cooperatively, where z, y, x are the indices for the normalized input/output |
| // tensor (see the document for FindTranspose021 for the definition of |
| // normalized tensor for 0-2-1 transpose). This shmem transpose implementation |
| // requires that the computation of the output tile only read elements within |
| // the preload tile. If this is not true, we can't use a shmem transpose for P. |
| // |
| // If the computation of output element [z, y, x] only requires the element of |
| // P with the same indices, the shmem transpose implementation can be applied |
| // to P safely. This is a sufficient but not necessary condition. We check all |
| // the transitive users of P to see if we can find a user that may cause an |
| // exception to the situation. If such a user is not found, we conclude that P |
| // is safe for shmem transpose. |
| // |
| // This is trivially true for elementwise operations and some "data-movement" |
| // ops like kTuple. However, it's not true for operations that can change the |
| // dimensions of the inputs (e.g. pad, slice) and bitcast operation. |
| // For example: |
| // |
| // fused_computation { |
| // param_0 = f32[64,64]{1,0} parameter(0) |
| // ROOT bitcast = f32[64,64]{0,1} bitcast(param_0) |
| // } |
| // The output element at logical address [0, 63] depends on the input element |
| // at logical address [63, 0], which would not be within the shared-memory |
| // block. |
| // |
| // TODO(bixia): In order to extend this for kInput fusion, that is reduction |
| // with transpose, we only need to end the use-chain checking with the input of |
| // a reduce operations. In this case, the above description on "output" apply |
| // to the result of such a use-chain, which provides the input to the reduce |
| // operation. |
| bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) { |
| if (mlir::isa<mlir::memref::TensorStoreOp>(op)) { |
| return true; |
| } |
| |
| HloOpcode opcode; |
| if (mlir::isa<mlir::memref::TensorLoadOp>(op)) { |
| opcode = HloOpcode::kParameter; |
| } else { |
| opcode = *MhloToHloOpcode(op); |
| } |
| if (HloInstruction::IsOpElementwise(opcode)) { |
| for (mlir::Value v : op->getResults()) { |
| for (mlir::OpOperand use : v.getUsers()) { |
| if (!IsInstructionSafeForShmemTranspose(use.getOwner())) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| |
| switch (opcode) { |
| // Non-elementwise instructions that don't cause the shmem transpose |
| // to be unsafe, including the instructions that don't currently fuse. |
| case HloOpcode::kGetDimensionSize: |
| // The result of the operation doesn't rely on the content of the |
| // tensor. As such, there is no need to further inspect its users. |
| return true; |
| case HloOpcode::kGetTupleElement: |
| case HloOpcode::kMap: |
| case HloOpcode::kParameter: |
| case HloOpcode::kTuple: |
| case HloOpcode::kTupleSelect: |
| for (mlir::Value v : op->getResults()) { |
| for (mlir::OpOperand use : v.getUsers()) { |
| if (!IsInstructionSafeForShmemTranspose(use.getOwner())) { |
| return false; |
| } |
| } |
| } |
| return true; |
| |
| default: |
| return false; |
| } |
| } |
| |
| // Given a group of input parameters that are 0-2-1 transpose of the outputs of |
| // a fusion kernel, returns the input parameters that are safe for the shared |
| // memory transpose implementation. |
| // |
| // When a tile based shared memory transpose is used to implement an input with |
| // 0-2-1 transpose, we preload a tile of the input elements |
| // [z, y..y+31, x..x+31] to compute the output tile elements of the same |
| // indices. Preloading the input tile this way is only safe when the computation |
| // of the output tile elements do not need any input element outside the |
| // preloaded tile. We inspect all the transitive users of the input parameter |
| // up to the fusion root instruction to see if we can find any instruction |
| // that can make preloading the input tile unsafe. |
| std::vector<int64_t> FilterInputsForShmemTranspose( |
| mlir::lmhlo::FusionOp fusion, std::vector<int64_t> input_ids) { |
| std::vector<mlir::Value> params = ToStdVector(fusion.getFusionParameters()); |
| |
| std::vector<int64_t> filtered_input_ids; |
| for (int64_t input_id : input_ids) { |
| mlir::Value input = params.at(input_id); |
| if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) { |
| filtered_input_ids.push_back(input_id); |
| } |
| } |
| return filtered_input_ids; |
| } |
| |
| } // namespace |
| |
| StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021( |
| mlir::Operation* op) { |
| CHECK(mlir::isa<mlir::lmhlo::FusionOp>(op)); |
| |
| MlirEmitterContext context; |
| context.SetOperation(op); |
| |
| // If the output_shape is reduced to 021 shape, find all the parameters of |
| // the HLO that are in the corresponding 012 shape. |
| std::vector<int64_t> params_012; |
| optional<std::vector<int64_t>> reduced_dims_021; |
| for (int64_t operand_idx = 0; operand_idx < context.operand_shapes.size(); |
| ++operand_idx) { |
| const Shape& operand_shape = context.operand_shapes[operand_idx]; |
| auto find_transpose_result = |
| ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]); |
| if (!find_transpose_result.has_value()) { |
| continue; |
| } |
| const std::vector<int64_t>& curr_reduced_dims_021 = *find_transpose_result; |
| if (!reduced_dims_021.has_value()) { |
| reduced_dims_021 = curr_reduced_dims_021; |
| } |
| if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) { |
| // There is more than one possible transpose. Instead of picking one |
| // transpose, we simply give up here. |
| return false; |
| } |
| params_012.push_back(operand_idx); |
| } |
| |
| if (!reduced_dims_021.has_value()) { |
| return false; |
| } |
| |
| if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled || |
| (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) { |
| return false; |
| } |
| |
| if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) { |
| params_012 = FilterInputsForShmemTranspose(fusion_op, params_012); |
| if (params_012.empty()) { |
| return false; |
| } |
| } |
| |
| // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the |
| // elements are of size 4 bytes), and CUDA has an architectural limit of |
| // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we |
| // don't use this, in part because it eats into our L1 cache space.) |
| // |
| // For correctness we need to ensure that we don't make more than 48kb worth |
| // of shmem tiles per block. And for performance, we'd probably like to use |
| // significantly less, so that we can fit more than one block at a time on a |
| // gpu core. |
| // |
| // We say without benchmarks that we want at least 3 threads/block, |
| // corresponding to 3 shmem tiles if the elements are 32 bits wide. We |
| // choose which params get the shmem transpose treatment arbitrarily; it's |
| // not clear if there's a Right Choice. |
| // |
| // This is only sound if tiled transposes are the only place where we use |
| // shared memory in fusions. If in the future other fusible ops use shared |
| // memory, we'll have to adjust this heuristic. |
| constexpr int kMinBlocksPerCore = 3; |
| constexpr int64_t kShmemPerCore = 48 * 1024; |
| int64_t shmem_used = 0; |
| for (int64_t i = 0; i < params_012.size(); ++i) { |
| const Shape& operand_shape = context.operand_shapes[params_012[i]]; |
| shmem_used += |
| 32 * 33 * |
| ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()); |
| |
| if (kMinBlocksPerCore * shmem_used > kShmemPerCore) { |
| // Erase this element and everything after it from params_012. |
| params_012.resize(i); |
| break; |
| } |
| } |
| |
| if (params_012.empty()) { |
| return false; |
| } |
| |
| constexpr int kNumRows = 4; |
| CHECK_EQ(kWarpSize % kNumRows, 0); |
| TilingScheme tiling_scheme(*reduced_dims_021, |
| /*tile_sizes=*/{1, kWarpSize / kNumRows, 1}, |
| /*num_threads=*/{1, kNumRows, kWarpSize}, |
| /*indexing_order=*/kLinearIndexingX, |
| /*vector_size=*/1); |
| LaunchDimensions launch_dimensions(tiling_scheme.GetNumberOfBlocks(), |
| tiling_scheme.GetNumThreadsPerBlock()); |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<Thunk> kernel_thunk, |
| BuildKernelThunk(op, GetThunkInfo(op), &ir_arrays, launch_dimensions)); |
| |
| EmitHlo021Tile( |
| op, kernel_thunk.get(), context, |
| absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()), |
| absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()), |
| *reduced_dims_021, params_012, tiling_scheme, launch_dimensions); |
| AddThunkToThunkSequence(std::move(kernel_thunk)); |
| return true; |
| } |
| |
| namespace { |
| |
| // Returns true if all the transitive users of hlo before hitting users in |
| // use_chain_endings are elementwise operations. |
| bool AreUsersElementwise( |
| mlir::Value value, |
| const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) { |
| return absl::c_all_of(value.getUsers(), [&](mlir::OpOperand use) { |
| mlir::Operation* user = use.getOwner(); |
| CHECK_EQ(1, user->getNumResults()); |
| return use_chain_endings.count(user) || |
| (HloInstruction::IsOpElementwise(*MhloToHloOpcode(user)) && |
| AreUsersElementwise(user->getResult(0), use_chain_endings)); |
| }); |
| } |
| |
| // Returns the number of fusion inputs that have the same dimension as the |
| // given shape, and involve in only elementwise operations. |
| int64_t NumInputsInvolveInOnlyElementwiseOps( |
| mlir::lmhlo::FusionOp fusion, const Shape& op_shape, |
| const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) { |
| return absl::c_count_if( |
| fusion.getFusionParameters(), [&](mlir::Value parameter) { |
| Shape parameter_shape = GetShape(parameter); |
| return ShapeUtil::SameDimensions(op_shape, parameter_shape) && |
| AreUsersElementwise(parameter, use_chain_endings); |
| }); |
| } |
| |
| // Returns the number of fusion inputs that have more elements than the given |
| // shape. |
| int64_t NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion, |
| const Shape& shape) { |
| int64_t num_elements = ShapeUtil::ElementsIn(shape); |
| return absl::c_count_if( |
| fusion.getFusionParameters(), [&](mlir::Value parameter) { |
| Shape parameter_shape = GetShape(parameter); |
| return ShapeUtil::ElementsIn(parameter_shape) > num_elements; |
| }); |
| } |
| |
| // The benefit of unrolling a kInput fusion that is a column reduction comes |
| // from the vectorization of non-reduction fusion outputs and fusion inputs. |
| // On the other hand, unrolling can also introduce factors that can cause |
| // the kernel to run slower. This routine uses a simple heuristic to estimate |
| // the benefit as well as the overhead of unrolling in order to decide whether |
| // unrolling is beneficial for the given kInput fusion. |
| bool IsUnrollingColumnReductionBeneficial(mlir::lmhlo::FusionOp fusion, |
| const Shape& input_shape, |
| int64_t num_kept_minor) { |
| if (num_kept_minor % (kWarpSize * 2) != 0) { |
| return false; |
| } |
| |
| if (input_shape.dimensions()[input_shape.rank() - 1] < 64) { |
| return false; |
| } |
| |
| int64_t can_be_vectorized = 0; |
| int64_t cannot_be_vectorized = 0; |
| llvm::SmallVector<mlir::Operation*> fusion_roots = fusion.getFusionRoots(); |
| absl::flat_hash_set<mlir::Operation*> use_chain_endings; |
| if (fusion_roots.size() == 1) { |
| if (IsReductionFromOrToContiguousDimensions(fusion_roots[0])) { |
| use_chain_endings.insert(fusion_roots[0]); |
| // Atomic.add of the reduction result can't be vectorized. |
| cannot_be_vectorized++; |
| } |
| } else { |
| for (mlir::Operation* op : fusion_roots) { |
| if (IsReductionFromOrToContiguousDimensions(op)) { |
| // Atomic.add of the reduction result can't be vectorized. |
| cannot_be_vectorized++; |
| } else { |
| // Write of the non-reduction result can be vectorized. |
| can_be_vectorized++; |
| } |
| use_chain_endings.insert(op); |
| } |
| } |
| // Fusion inputs that have the same dimension as the reduce input and |
| // only involve in elementwise operations can be vectorized. |
| can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(fusion, input_shape, |
| use_chain_endings); |
| // Fusion inputs with more elements than the reduce op input must participate |
| // in non-elementwise operations and we assume that they are not vectorizable |
| // for the purpose of estimating the benefit of unrolling. If the kernel is |
| // unrolled even with such an assumption, and the accesses to those inputs |
| // turn out to be vectorizable, the compiler will still vectorize them. |
| cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape); |
| return can_be_vectorized >= cannot_be_vectorized; |
| } |
| |
| int64_t NearestPowerOfTwo(int64_t v) { |
| if (v < 0) { |
| return 0; |
| } |
| int64_t upper = tensorflow::NextPowerOfTwo64(v); |
| int64_t lower = upper >> 1; |
| return upper - v < v - lower ? upper : lower; |
| } |
| |
| } // namespace |
| |
| // Returns primitive bitwidth for shape of the value. |
| static int GetPrimitiveBitwidth(mlir::Value i) { |
| // TODO(timshen): may not be efficient. |
| return primitive_util::BitWidth(GetShape(i).element_type()); |
| } |
| |
| ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( |
| mlir::lmhlo::FusionOp fusion, mlir::mhlo::ReduceOp first_reduce) { |
| Shape input_shape = GetShape(first_reduce->getOperand(0)); |
| ReductionDimensions reduction_dimensions = |
| GetReductionKindAndContiguousComponents(first_reduce); |
| VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction |
| << " " << reduction_dimensions.dimensions[0] << " " |
| << reduction_dimensions.dimensions[1] << " " |
| << reduction_dimensions.dimensions[2]; |
| std::array<int64_t, 3> reduction_tiling = GetReductionTiling( |
| reduction_dimensions, ir_emitter_context_->cuda_compute_capability()); |
| |
| int64_t num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize; |
| int64_t num_threads_x = [&] { |
| if (reduction_dimensions.is_row_reduction) { |
| // Use 512 as default block size (threads per block) for row reductions. |
| // For multi-output fusions, reduce the block size further to decrease |
| // register pressure when multiple outputs are computed by each thread. |
| int64_t fan_out = fusion.getFusionRoots().size(); |
| int64_t max_block_size = |
| std::max(kMinThreadsXRowReduction, |
| static_cast<int64_t>(512LL / NearestPowerOfTwo(fan_out))); |
| return std::min( |
| max_block_size, |
| RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2], |
| reduction_tiling[2]), |
| kWarpSize)); |
| } |
| return kWarpSize; |
| }(); |
| |
| se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability(); |
| |
| int num_partial_results = 1; |
| |
| int smallest_input_dtype_bits = std::numeric_limits<int>::max(); |
| for (mlir::Value operand : fusion.getInputBuffers()) { |
| smallest_input_dtype_bits = |
| std::min(GetPrimitiveBitwidth(operand), smallest_input_dtype_bits); |
| } |
| |
| TilingScheme::IndexingOrder indexing_order = [&]() { |
| if (reduction_dimensions.is_row_reduction && |
| // P100, only try to vectorize+coales memory access when the |
| // tile size fits exactly and dtypes <= 32 bits |
| ((cc.major == 6 && smallest_input_dtype_bits <= 32 && |
| reduction_dimensions.dimensions[kDimX] % |
| (reduction_tiling[2] * num_threads_x) == |
| 0) || |
| // On V100, only try to vectorize+coales memory access for |
| // rows of even size. For odd row sizes, every other row |
| // isn't aligned, so it can't be vectorized. |
| (cc.major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) { |
| return kStridedLinearIndexingX; |
| } else if (!reduction_dimensions.is_row_reduction && |
| IsUnrollingColumnReductionBeneficial( |
| fusion, input_shape, reduction_dimensions.dimensions[2])) { |
| num_partial_results = 2; |
| reduction_tiling[2] *= num_partial_results; |
| return kLinearIndexingX; |
| } else { |
| return kStridedIndexingX; |
| } |
| }(); |
| VLOG(3) << "Each threads will produce " << num_partial_results |
| << " output(s)"; |
| |
| int vector_size = 1; |
| if (indexing_order == kStridedLinearIndexingX) { |
| // Assuming XLA will perform the unrolling and LLVM will vectorize, |
| // disable the unroll for the cases that LLVM doesn't vectorize. |
| if (reduction_dimensions.dimensions[2] % 2 == 0 && |
| !MayPreventVectorization(fusion)) { |
| vector_size = 2; |
| } else { |
| indexing_order = kStridedIndexingX; |
| } |
| } |
| // Reduction constructor |
| std::vector<int64_t> num_threads = {1, num_threads_y, num_threads_x}; |
| TilingScheme tiling_scheme(reduction_dimensions.dimensions, reduction_tiling, |
| num_threads, indexing_order, vector_size); |
| return ReductionCodegenInfo( |
| tiling_scheme, num_partial_results, reduction_dimensions.is_row_reduction, |
| ReductionIsRaceFree(reduction_dimensions, reduction_tiling)); |
| } |
| |
| // Generate a single element of the tile (update the accumulator state) for a |
| // given reducer of index `i`. |
| void IrEmitterUnnested::GenerateElementForReducer( |
| int i, int partial_result_index, const HloComputation* reducer, |
| const ReductionCodegenState& codegen_state, |
| const llvm_ir::IrArray::Index& index_without_linear, |
| const IrArray::Index& input_index, int num_partial_results, |
| const ReductionOutputMap& result_ir_arrays) { |
| CHECK_EQ(reducer->num_parameters() % 2, 0); |
| |
| absl::InlinedVector<llvm::Value*, 2> reduction_accumulators; |
| absl::InlinedVector<llvm::Value*, 2> reduction_input_value; |
| for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { |
| const ReductionCodegenState::ReductionCalculationState& state = |
| codegen_state.GetCalculationStateFor(i, red_idx); |
| |
| llvm::AllocaInst* input_address = state.input_address; |
| llvm::AllocaInst* partial_reduction_result_address = |
| state.partial_result_address; |
| llvm::Value* const input_ir_value = *state.input_gen( |
| num_partial_results > 1 ? index_without_linear : input_index); |
| b_.CreateStore(input_ir_value, input_address); |
| llvm::Value* partial_result_address = b_.CreateInBoundsGEP( |
| partial_reduction_result_address, {b_.getInt32(partial_result_index)}); |
| reduction_accumulators.push_back(partial_result_address); |
| reduction_input_value.push_back(input_address); |
| } |
| |
| absl::InlinedVector<llvm::Value*, 4> reduction_params; |
| for (llvm::Value* acc : reduction_accumulators) { |
| reduction_params.push_back(acc); |
| } |
| for (llvm::Value* value : reduction_input_value) { |
| reduction_params.push_back(value); |
| } |
| |
| // Emit a call to the variadic reducer. Since it may be returning a |
| // tuple, we can't return it directly as a value. Instead, before |
| // the call, we create N (N = # arguments in the tuple) allocas, one |
| // for each returned argument, then when we make the call we pass N |
| // pointers as last parameters, the called computation writes into |
| // those pointers, and we have returned values on the stack (as well |
| // as pointers to them). |
| StatusOr<std::vector<llvm::Value*>> returned_scalars = |
| ComputeNestedElementFromAddrs(*reducer, reduction_params); |
| TF_CHECK_OK(returned_scalars.status()); |
| |
| for (int i = 0; i < returned_scalars->size(); i++) { |
| b_.CreateStore(returned_scalars->at(i), reduction_accumulators[i]); |
| } |
| } |
| |
| void IrEmitterUnnested::EmitIRForReduction( |
| mlir::lmhlo::FusionOp fusion, absl::Span<const int> instr_index_group, |
| HloComputation* fused_computation, FusedIrEmitter* fused_emitter, |
| const ReductionOutputMap& result_ir_arrays, |
| const ReductionCodegenInfo& reduction_info, const Shape& input_shape) { |
| absl::flat_hash_map<int, HloComputation*> reducers; |
| std::vector<int> reduce_instr_index_group; |
| std::vector<std::pair<llvm_ir::ElementGenerator, int>> extra_output_gens; |
| |
| for (int index : instr_index_group) { |
| const HloInstruction* hlo = GetFusionOutput(fused_computation, index); |
| if (IsReductionFromOrToContiguousDimensions( |
| fusion.getFusionRoots()[index])) { |
| reduce_instr_index_group.push_back(index); |
| reducers[index] = hlo->to_apply(); |
| } else { |
| extra_output_gens.emplace_back(*fused_emitter->GetGenerator(hlo), index); |
| } |
| } |
| |
| CHECK(!reducers.empty()) << " expect at least one reduce instructions."; |
| const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); |
| CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % 32, 0); |
| LaunchDimensions launch_dimensions(tiling_scheme.GetNumberOfBlocks(), |
| tiling_scheme.GetNumThreadsPerBlock()); |
| llvm::Type* index_ty = |
| GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_); |
| ReductionCodegenState codegen_state = GenerateReductionCodegenState( |
| fusion, reduction_info, reduce_instr_index_group, fused_computation, |
| fused_emitter); |
| |
| EmitElementFunction emit_reduction_element = |
| [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, |
| llvm::Value* x_loc, int64_t x_iter_num) { |
| IrArray::Index input_index = GetUnnormalizedIndex( |
| index, input_shape, &b_, codegen_state.GetTilingScheme()); |
| |
| int partial_result_index = |
| codegen_state.IsRowReduction() ? 0 : x_iter_num; |
| |
| // Clear the linear index field of the IrArray::Index to enable the use |
| // of GetElementPointer with array types. This enables the vectorization |
| // of the computation for different partial results. Use this index if |
| // 'num_partial_results > 1'. |
| int num_partial_results = codegen_state.GetNumPartialResults(); |
| llvm_ir::IrArray::Index index_without_linear = IrArray::Index( |
| input_index.multidim(), input_shape, input_index.GetType()); |
| |
| // Emit code to generate the input and perform the reduction computation |
| // for each reduction instruction. |
| for (const auto& p : reducers) { |
| GenerateElementForReducer(p.first, partial_result_index, p.second, |
| codegen_state, index_without_linear, |
| input_index, num_partial_results, |
| result_ir_arrays); |
| } |
| |
| // Emit code to generate the output for the non-reduction instructions |
| // in the fusion, if any. |
| TF_CHECK_OK(EmitExtraOutputsForReduce( |
| result_ir_arrays, input_index, |
| /*use_linear_index=*/codegen_state.GetNumPartialResults() == 1, |
| extra_output_gens)); |
| }; |
| |
| TilingKernelInfo tiling_kernel_info = EmitTilingKernel( |
| tiling_scheme, index_ty, |
| [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, |
| const string& loop_name, llvm::Value* tile_height, |
| llvm::Value* tile_width, KernelSupportLibrary* ksl) { |
| EmitTile(codegen_state.GetTilingScheme(), index, loop_name, ksl, |
| thread_id_info, tile_height, tile_width, |
| emit_reduction_element); |
| }); |
| |
| EmitReductionOutput(index_ty, fusion, reduce_instr_index_group, |
| result_ir_arrays, reducers, codegen_state, |
| tiling_kernel_info); |
| } |
| |
| namespace { |
| |
| // Returns whether the `instr` is either a constant, a scalar, or a |
| // broadcasted constant/scalar. |
| bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) { |
| return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) || |
| (HloOpcode::kBroadcast == instr.opcode() && |
| (instr.operand(0)->IsConstant() || |
| ShapeUtil::IsScalar(instr.operand(0)->shape()))); |
| } |
| |
| // Divides `num_reduces` reduces into groups. Different groups will be executed |
| // in parallel. Generally speaking, we'd like to run the reduce instructions |
| // in parallel without incurring too much recomputation overhead. The current |
| // heuristic is to place reduce instructions who share nothing or only |
| // (broadcasted) scalars/constants into different groups; otherwise, they are |
| // placed in the same group. Non-reduce instructions always go with the reduce |
| // instructions into the same group so long as they share any predecessors. |
| std::vector<std::vector<int>> GroupDisjointReductions( |
| HloComputation* fused_computation) { |
| const Shape& root_shape = fused_computation->root_instruction()->shape(); |
| int num_fusion_outputs = |
| fused_computation->root_instruction()->opcode() == HloOpcode::kTuple |
| ? root_shape.tuple_shapes_size() |
| : 1; |
| CHECK_NE(0, num_fusion_outputs); |
| if (num_fusion_outputs == 1) { |
| return {{0}}; |
| } |
| |
| std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets( |
| num_fusion_outputs); |
| for (size_t i = 0; i < num_fusion_outputs; ++i) { |
| disjoint_sets[i].Get() = |
| fused_computation->root_instruction()->mutable_operand(i); |
| } |
| |
| std::unique_ptr<HloReachabilityMap> reachability_map = |
| HloReachabilityMap::Build(fused_computation); |
| for (HloInstruction* instr : fused_computation->instructions()) { |
| std::vector<int64_t> reached_output_ids; |
| for (size_t oid = 0; oid < num_fusion_outputs; ++oid) { |
| HloInstruction* reduce = |
| fused_computation->root_instruction()->mutable_operand(oid); |
| if (HloOpcode::kReduce == reduce->opcode() && |
| (IsBroadcastedConstantOrScalar(*instr))) { |
| // Do not group output reduce instructions through broadcasted |
| // constants or scalars, as the recomputation should be acceptable. |
| VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString(); |
| continue; |
| } |
| // Now group output instructions if they have common predecessors. |
| if (reachability_map->IsReachable(instr, reduce)) { |
| VLOG(3) << "Reaching " << reduce->ToString() << " from " |
| << instr->ToString(); |
| reached_output_ids.push_back(oid); |
| } |
| } |
| for (size_t j = 1; j < reached_output_ids.size(); ++j) { |
| disjoint_sets[reached_output_ids[0]].Merge( |
| &disjoint_sets[reached_output_ids[j]]); |
| } |
| } |
| // Place output instructions in the same set into the same group. |
| HloInstructionMap<std::vector<int>> groups; |
| for (size_t oid = 0; oid < num_fusion_outputs; ++oid) { |
| groups[disjoint_sets[oid].Get()].push_back(oid); |
| } |
| |
| std::vector<std::vector<int>> ret; |
| absl::c_for_each( |
| groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); |
| return ret; |
| } |
| |
| } // namespace |
| |
| Status IrEmitterUnnested::EmitUnnestedReduction(mlir::lmhlo::FusionOp fusion) { |
| llvm::SmallVector<mlir::Operation*> fusion_roots = fusion.getFusionRoots(); |
| |
| // Build a kernel thunk to compute all the outputs. |
| mlir::mhlo::ReduceOp first_reduce; |
| for (mlir::Operation* output_instruction : fusion.getFusionRoots()) { |
| if (IsReductionFromOrToContiguousDimensions(output_instruction)) { |
| first_reduce = mlir::cast<mlir::mhlo::ReduceOp>(output_instruction); |
| break; |
| } |
| } |
| |
| CHECK(first_reduce) << MlirToString(fusion); |
| |
| Shape input_shape = GetShape(first_reduce->getOperand(0)); |
| // The layout of a reduction input is either set by LayoutAssignment for |
| // unnested kReduce or by InstructionFusion for fused kReduce. |
| CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " |
| "doesn't set the input layout of " |
| << MlirToString(first_reduce); |
| |
| HloComputation* fused_computation = nullptr; |
| TF_ASSIGN_OR_RETURN(fused_computation, |
| GetOrCreateSubComputationFromRegion(&fusion.region(), |
| /*is_fusion=*/true)); |
| |
| // Group disjoint reductions in groups, to be executed in parallel. |
| std::vector<std::vector<int>> instr_index_groups = |
| GroupDisjointReductions(fused_computation); |
| |
| VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ", |
| MlirToString(fusion)); |
| |
| ReductionCodegenInfo reduction_info = |
| ComputeReductionCodegenInfo(fusion, first_reduce); |
| const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); |
| |
| // block_y_count is set to instr_index_groups.size(), so that each reduction |
| // group can be run in parallel by a different BlockIdy. |
| LaunchDimensions launch_dimensions( |
| {/*x=*/tiling_scheme.GetNumberOfBlocks(), |
| /*y=*/static_cast<int64_t>(instr_index_groups.size()), |
| /*z=*/1}, |
| {/*x=*/tiling_scheme.GetNumThreadsPerBlock(), /*y=*/1, /*z=*/1}); |
| VLOG(3) << "Launch dimensions of " << mlir::GetNameFromLoc(fusion.getLoc()) |
| << ": number of blocks: " << tiling_scheme.GetNumberOfBlocks() |
| << " - threads per block: " << tiling_scheme.GetNumThreadsPerBlock(); |
| |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk, |
| BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays, |
| launch_dimensions)); |
| |
| GpuElementalIrEmitter elemental_emitter(hlo_module_config_, |
| ir_emitter_context_->llvm_module(), |
| &b_, GetNestedComputer()); |
| FusedIrEmitter fused_emitter(&elemental_emitter); |
| CHECK_LT(fused_computation->num_parameters(), ir_arrays.size()); |
| for (int i = 0; i < fused_computation->num_parameters(); i++) { |
| llvm_ir::IrArray ir_array = ir_arrays[i]; |
| HloInstruction* fused_operand = fused_computation->parameter_instruction(i); |
| fused_emitter.BindGenerator( |
| fused_operand, |
| [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) { |
| return ir_array.EmitReadArrayElement(index, &b_, |
| fused_operand->name()); |
| }); |
| } |
| |
| // Get outputs. |
| ReductionOutputMap result_ir_arrays; |
| |
| // Skip all parameter buffers first. |
| int ir_arrays_idx = fused_computation->num_parameters(); |
| for (int root_idx = 0; root_idx < fusion_roots.size(); root_idx++) { |
| mlir::Operation* root = fusion_roots[root_idx]; |
| result_ir_arrays[root_idx] = |
| absl::MakeSpan(ir_arrays).subspan(ir_arrays_idx, root->getNumResults()); |
| ir_arrays_idx += root->getNumResults(); |
| } |
| |
| // We always use the first reduce as representative to construct |
| // ReductionCodegenInfo, since all the reductions are required to have the |
| // same shape and layout as verified by `IsFusedReductionOutputConsistent()`. |
| ReductionCodegenInfo reduction_codegen_info = |
| ComputeReductionCodegenInfo(fusion, first_reduce); |
| |
| KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); |
| for (size_t i = 0; i < instr_index_groups.size(); ++i) { |
| // Use raw block_id_y to select the i-th parallel reduction to run. Using |
| // block_id_y instead of block_id_x simplifies the index calculation |
| // for reduction code generation as the block_id_y is orthogonal to |
| // the indices used within the reductions. |
| llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( |
| gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); |
| llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), |
| llvm::cast<llvm::Instruction>(raw_block_id_y)); |
| |
| ksl.If(StrCat("reduce-group-", i), |
| b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] { |
| EmitIRForReduction(fusion, instr_index_groups[i], |
| fused_computation, &fused_emitter, |
| result_ir_arrays, reduction_codegen_info, |
| input_shape); |
| }); |
| } |
| |
| if (hlo_module_config_.debug_options().xla_gpu_deterministic_reductions() && |
| !reduction_codegen_info.IsRaceFree()) { |
| return InternalError( |
| "All reductions should be race-free if deterministic reductions are " |
| "enabled"); |
| } |
| |
| ThunkSequence thunks; |
| |
| // Build an initializer thunk to initialize each reduction output. |
| if (!reduction_codegen_info.IsRaceFree()) { |
| for (int i = 0; i < fusion_roots.size(); ++i) { |
| mlir::Operation* output_instruction = fusion_roots[i]; |
| if (IsReductionFromOrToContiguousDimensions(output_instruction)) { |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk, |
| BuildFusedInitializerThunk(fusion, i)); |
| thunks.push_back(std::move(initializer_thunk)); |
| } |
| } |
| } |
| |
| thunks.push_back(std::move(kernel_thunk)); |
| auto sequential_thunk = absl::make_unique<SequentialThunk>( |
| GetThunkInfo(fusion), std::move(thunks)); |
| AddThunkToThunkSequence(std::move(sequential_thunk)); |
| |
| return Status::OK(); |
| } |
| |
| // Emits code for slices based on the below structure. An if statement with |
| // a guarding condition is generated for each ROOT slice. |
| // |
| // Pseudo code: |
| // |
| // Compute values of slice input operands |
| // |
| // Compute guarding_cond0 |
| // if (guarding_cond0) { |
| // Write to output of slice0 |
| // } |
| // |
| // Compute guarding_cond1 |
| // if (guarding_cond1) { |
| // Write to output of slice1 |
| // } |
| // |
| Status IrEmitterUnnested::EmitElementForInputFusibleSlices( |
| const HloComputation* fused_computation, |
| absl::Span<const llvm_ir::IrArray> ir_arrays, |
| const llvm_ir::IrArray::Index& index) { |
| VLOG(10) << "Emitting slice input fusion for " |
| << fused_computation->ToString(); |
| |
| HloInstruction* slice_or_tuple = fused_computation->root_instruction(); |
| auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> { |
| if (slice_or_tuple->opcode() == HloOpcode::kSlice) { |
| return absl::Span<HloInstruction* const>(&slice_or_tuple, 1); |
| } |
| CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple); |
| return slice_or_tuple->operands(); |
| }(); |
| |
| // Emit input operand values of slices. |
| std::vector<llvm::Value*> input_ir_values; |
| GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, |
| GetNestedComputer()); |
| FusedIrEmitter fused_emitter(&elem_emitter); |
| for (int i = 0; i < fused_computation->num_parameters(); i++) { |
| fused_emitter.BindGenerator( |
| fused_computation->parameter_instruction(i), |
| [this, &ir_arrays, i](llvm_ir::IrArray::Index index) { |
| return ir_arrays[i].EmitReadArrayElement(index, &b_); |
| }); |
| } |
| for (const HloInstruction* slice : slice_instructions) { |
| auto input_generator = *fused_emitter.GetGenerator(slice->operand(0)); |
| input_ir_values.push_back(input_generator(index).ValueOrDie()); |
| } |
| |
| // Emit for slice_instructions. |
| KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); |
| for (int64_t i = 0; i < slice_instructions.size(); ++i) { |
| HloInstruction* slice = slice_instructions[i]; |
| |
| // guarding_cond := index >= start && index < limit, for each dim. |
| std::vector<llvm::Value*> index_within_ranges; |
| for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) { |
| CHECK_EQ(slice->slice_strides(dim), 1); |
| auto larger_or_equal_than_start = b_.CreateICmpSGE( |
| index.multidim()[dim], |
| index.GetConstantWithIndexType(slice->slice_starts(dim))); |
| llvm::Value* smaller_than_limit = b_.CreateICmpSLT( |
| index.multidim()[dim], |
| index.GetConstantWithIndexType(slice->slice_limits(dim))); |
| llvm::Value* within_range = |
| b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit); |
| index_within_ranges.push_back(within_range); |
| } |
| llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges); |
| |
| auto emit_slice_elem_func = [&] { |
| const std::vector<llvm::Value*>& src_multidim = index.multidim(); |
| std::vector<llvm::Value*> dst_multidim(src_multidim.size()); |
| for (size_t dim = 0; dim < src_multidim.size(); ++dim) { |
| dst_multidim[dim] = |
| Sub(src_multidim[dim], |
| index.GetConstantWithIndexType(slice->slice_starts(dim))); |
| } |
| llvm_ir::IrArray src_ir_array = |
| ir_arrays[fused_computation->num_parameters() + i]; |
| IrArray::Index slice_dst_index(dst_multidim, slice->shape(), |
| index.GetType()); |
| src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i], |
| &b_); |
| }; |
| |
| ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func); |
| } |
| return Status::OK(); |
| } |
| |
| Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( |
| mlir::Operation* op) { |
| auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op); |
| |
| constexpr int unroll_factor = 1; |
| |
| TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, |
| GetOrCreateSubComputationFromRegion(&fusion.region(), |
| /*is_fusion=*/true)); |
| |
| TF_ASSIGN_OR_RETURN(Shape element_shape, |
| GetConsistentInputShapeForRootSlices(fused_computation)); |
| TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, |
| CalculateLaunchDimensions( |
| element_shape, ir_emitter_context_->gpu_device_info(), |
| {unroll_factor})); |
| |
| std::vector<llvm_ir::IrArray> ir_arrays; |
| TF_ASSIGN_OR_RETURN(auto kernel_thunk, |
| BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays, |
| launch_dimensions)); |
| |
| Status emit_status = |
| ParallelLoopEmitter( |
| [&](const llvm_ir::IrArray::Index index) -> Status { |
| return EmitElementForInputFusibleSlices(fused_computation, |
| ir_arrays, index); |
| }, |
| element_shape, launch_dimensions, &b_) |
| .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())), |
| GetIndexTypeForKernel( |
| fusion, launch_dimensions.launch_bound(), &b_)); |
| |
| thunk_sequence_.emplace_back(std::move(kernel_thunk)); |
| |
| return emit_status; |
| } |
| |
| Status IrEmitterUnnested::EmitOp(mlir::Operation* op) { |
| if (mlir::isa<mlir::ConstantOp, mlir::memref::ViewOp, |
| mlir::memref::ReinterpretCastOp, mlir::ReturnOp, |
| mlir::lmhlo::TerminatorOp>(op)) { |
| return Status::OK(); |
| } |
| |
| if (mlir::isa<mlir::memref::GetGlobalOp>(op)) { |
| return EmitConstant(op); |
| } |
| |
| if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(op)) { |
| if (call.call_target_name() == "PadToStatic") { |
| return EmitPadToStatic(op); |
| } |
| if (call.call_target_name() == "SliceToDynamic") { |
| return EmitSliceToDynamic(op); |
| } |
| return EmitCustomCallThunk(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(op)) { |
| return EmitGemmThunk(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp, |
| mlir::lmhlo_gpu::ConvForwardFusedOp, |
| mlir::lmhlo_gpu::ConvForwardFusedSideInputOp, |
| mlir::lmhlo_gpu::ConvBackwardFilterOp, |
| mlir::lmhlo_gpu::ConvBackwardInputOp>(op)) { |
| return EmitConvolutionThunk(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo_gpu::BatchNormTrainingOp, |
| mlir::lmhlo_gpu::BatchNormInferenceOp, |
| mlir::lmhlo_gpu::BatchNormGradOp>(op)) { |
| return EmitBatchNormThunk(op); |
| } |
| |
| #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(op)) { |
| return EmitCholeskyThunk(op); |
| } |
| #endif // GOOGLE_CUDA |
| |
| if (mlir::isa<mlir::lmhlo::FftOp>(op)) { |
| return EmitFftThunk(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::TriangularSolveOp>(op)) { |
| return EmitTriangularSolve(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::FusionOp>(op)) { |
| return EmitFusion(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::SelectAndScatterOp>(op)) { |
| return EmitSelectAndScatter(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::RngGetAndUpdateStateOp>(op)) { |
| return EmitRngGetAndUpdateState(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::ScatterOp>(op)) { |
| return EmitScatter(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::SortOp>(op)) { |
| return EmitSort(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::ReplicaIdOp>(op)) { |
| return EmitReplicaOrPartitionId<ReplicaIdThunk, mlir::lmhlo::ReplicaIdOp>( |
| op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::PartitionIdOp>(op)) { |
| return EmitReplicaOrPartitionId<PartitionIdThunk, |
| mlir::lmhlo::PartitionIdOp>(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(op)) { |
| return EmitCollectivePermute(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::AllGatherOp>(op)) { |
| return EmitNcclThunk<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::AllReduceOp>(op)) { |
| return EmitNcclThunk<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) { |
| return EmitNcclThunk<NcclAllReduceStartThunk, |
| mlir::lmhlo_gpu::AllReduceStartOp>(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo_gpu::AllReduceDoneOp>(op)) { |
| return EmitAllReduceDone(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::ReduceScatterOp>(op)) { |
| return EmitNcclThunk<NcclReduceScatterThunk, mlir::lmhlo::ReduceScatterOp>( |
| op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::AllToAllOp>(op)) { |
| return EmitNcclThunk<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::InfeedOp>(op)) { |
| return EmitInfeed(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::OutfeedOp>(op)) { |
| return EmitOutfeed(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::CaseOp>(op)) { |
| return EmitConditional(op); |
| } |
| |
| if (mlir::isa<mlir::lmhlo::WhileOp>(op)) { |
| return EmitWhile(op); |
| } |
| |
| return InternalError("Unrecognized op: %s", MlirToString(op)); |
| } |
| |
| Status IrEmitterUnnested::EmitLmhloRegion(mlir::Region* region) { |
| for (mlir::Operation& op : llvm::make_early_inc_range(region->front())) { |
| TF_RETURN_IF_ERROR(EmitOp(&op)); |
| } |
| return Status::OK(); |
| } |
| |
| Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(mlir::Operation* op) { |
| auto module = op->getParentOfType<mlir::ModuleOp>(); |
| Thunk::ThunkInfo thunk_info; |
| thunk_info.profile_annotation = absl::StrFormat( |
| "Thunk:#hlo_op=%s,hlo_module=%s#", mlir::GetNameFromLoc(op->getLoc()), |
| mlir::GetNameFromLoc(module->getLoc())); |
| return thunk_info; |
| } |
| |
| void MlirEmitterContext::SetOperation(mlir::Operation* op) { |
| this->name = mlir::GetNameFromLoc(op->getLoc()); |
| |
| auto operands = GetHloOperands(op); |
| auto outputs = GetHloOutputs(op); |
| for (auto operand : operands) { |
| operand_shapes.push_back(GetShape(operand)); |
| } |
| for (auto output : outputs) { |
| output_shapes.push_back(GetShape(output)); |
| } |
| } |
| |
| } // namespace gpu |
| } // namespace xla |