/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace llvm_ir {

namespace {
// Emits the body of the inner comparison loop: computes the pair of indices
// to compare from 'element_pair_index' and 'xor_mask', invokes the 'compare'
// computation on the elements at those indices, and swaps the elements if the
// comparison returns true.
Status EmitCompareLoopBody(
    int64_t iteration_bound, int64_t num_values,
    llvm::Value* element_pair_index, int64_t xor_mask, llvm::Type* index_type,
    std::function<llvm::Value*(int64_t operand, llvm::Value* index)>
        element_address,
    std::function<llvm::Type*(int64_t operand, llvm::Value* index)>
        element_address_pointee_type,
    std::function<void(int64_t operand, llvm::Value* index, llvm::Value* value)>
        write_element,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
  auto index_typed_constant = [&](int64_t value) {
    return llvm::ConstantInt::get(index_type, value);
  };
  // The 'xor_mask' determines which elements are compared against each other.
  // Index 'current_keys_index' will be compared with 'current_keys_index' xor
  // 'xor_mask'. This means that we will always compare a block of consecutive
  // elements against elements from the adjacent block of the same size. When
  // 'xor_mask' is a power of 2, it immediately identifies the size of such a
  // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
  // that case, we essentially flip the last 'k' bits when computing the
  // position of the element to compare to, so the block size is 2^(k - 1).
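  // For example, with 'xor_mask' = 4 = 2^2, element 1 is compared with
  // 1 ^ 4 = 5: the block {0, 1, 2, 3} is compared elementwise against
  // {4, 5, 6, 7}, and the block size is 4. With 'xor_mask' = 7 = 2^3 - 1,
  // element 1 is compared with 1 ^ 7 = 6 and element 2 with 5: the same two
  // blocks are compared against each other in reverse order, and the block
  // size is (7 + 1) / 2 = 4.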
  int64_t block_size = xor_mask;
  // Check whether 'xor_mask' has the form 2^k - 1.
  if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
    block_size = (xor_mask + 1) / 2;
  }
  auto current_keys_index = element_pair_index;
  if (block_size == 1) {
    // If the block size is 1, we take every second element and compare it to
    // the next one.
    current_keys_index =
        b->CreateMul(current_keys_index, index_typed_constant(2));
  } else if (block_size * 2 < iteration_bound) {
    // current_keys_index iterates through the 'left' elements of the element
    // pairs to be compared. We first need to compute the comparison block to
    // which the element belongs. The block id of that block is index /
    // block_size.
    auto block_id =
        b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
    // The index of the 'left' element within its block is simply the remainder
    // when dividing by 'block_size'.
    auto index_within_block =
        b->CreateURem(current_keys_index, index_typed_constant(block_size));
    // The first element of the 'left' block of elements that is compared
    // against elements from the adjacent 'right' block of elements is
    // 'block_id' * (2 * 'block_size').
    auto first_element_in_block =
        b->CreateMul(block_id, index_typed_constant(2 * block_size));
    current_keys_index =
        b->CreateAdd(first_element_in_block, index_within_block);
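    // For example, with 'block_size' = 4, 'element_pair_index' = 5 maps to
    // block_id = 1, index_within_block = 1, and thus current_keys_index =
    // 1 * 8 + 1 = 9: the second 'left' element of the second pair of blocks.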
  }
  auto compare_keys_index =
      b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
  // current_keys_index < compare_keys_index
  llvm::Value* is_smaller_index =
      b->CreateICmpSLT(current_keys_index, compare_keys_index);
  // compare_keys_index < iteration_bound
  llvm::Value* index_is_inbounds = b->CreateICmpSLT(
      compare_keys_index, index_typed_constant(iteration_bound));
  llvm::Value* do_comparison =
      needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
                          : b->getInt1(true);

  // if (is_smaller_index && index_is_inbounds)
  KernelSupportLibrary ksl(b);
  return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
    std::vector<llvm::Value*> values_to_compare;
    std::vector<llvm::Type*> values_to_compare_types;
    for (int64_t i = 0; i < num_values; ++i) {
      values_to_compare.push_back(element_address(i, compare_keys_index));
      values_to_compare_types.push_back(
          element_address_pointee_type(i, compare_keys_index));

      values_to_compare.push_back(element_address(i, current_keys_index));
      values_to_compare_types.push_back(
          element_address_pointee_type(i, current_keys_index));
    }
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Type* pred_type = llvm_ir::PrimitiveTypeToIrType(PRED, module);
    llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        pred_type, "compare_return_buffer", b);
    TF_RETURN_IF_ERROR(
        emit_compare_callback(values_to_compare, compare_return_buffer));
    llvm::Value* result = b->CreateLoad(pred_type, compare_return_buffer);

    // Check if the 'compare' function returns true.
    llvm::Value* is_smaller_than =
        b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
                        "boolean_predicate");
    ksl.If("is_smaller_than", is_smaller_than, [&]() {
      for (int64_t i = 0; i < num_values; ++i) {
        // Swap the values.
        auto value1 = b->CreateLoad(values_to_compare_types[i * 2],
                                    values_to_compare[i * 2]);
        auto value2 = b->CreateLoad(values_to_compare_types[i * 2 + 1],
                                    values_to_compare[i * 2 + 1]);
        write_element(i, current_keys_index, value1);
        write_element(i, compare_keys_index, value2);
      }
    });
    return ::tensorflow::OkStatus();
  });
}
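// Emits a compare loop that works on tiles cached in shared memory: the
// threads of a block first copy one tile of each operand from the operand
// buffers into the corresponding shared memory buffer, then run all the
// comparison stages in 'xor_masks' on the cached tiles (synchronizing between
// stages), and finally copy the tiles back to the operand buffers.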
Status EmitTiledCompareLoop(
    const IrArray::Index& tiled_keys_index, int64_t dimension_to_sort,
    int64_t dimension_to_sort_bound, absl::Span<const int64_t> xor_masks,
    const std::vector<IrArray>& params,
    const std::vector<llvm::GlobalVariable*>& param_shmem_buffers,
    int64_t tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b) {
  KernelSupportLibrary ksl(b);
  llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic(
      gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b);
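  // There is one thread per element pair of a tile, so thread ids range over
  // [0, tile_size / 2).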
  llvm_ir::AddRangeMetadata(0, tile_size / 2,
                            llvm::cast<llvm::Instruction>(thread_id));
  thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
                               /*isSigned=*/true, "thread.id.x");

  auto copy_loop_body =
      [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
              read_or_write) {
        auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
        auto current_keys_index =
            b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
        // We want to copy two adjacent elements. We first check whether the
        // first index position is within bounds.
        ksl.If(
            "smaller_keys_index",
            b->CreateICmpSLT(current_keys_index,
                             tiled_keys_index.GetConstantWithIndexType(
                                 dimension_to_sort_bound)),
            [&]() {
              auto cache_index = b->CreateShl(thread_id, value_one);
              read_or_write(cache_index, current_keys_index);
              // Increment to go to the next index position.
              current_keys_index = b->CreateAdd(current_keys_index, value_one);
              // Here we check whether the next index position is within
              // bounds.
              ksl.If("inner_smaller_keys_index",
                     b->CreateICmpSLT(current_keys_index,
                                      tiled_keys_index.GetConstantWithIndexType(
                                          dimension_to_sort_bound)),
                     [&]() {
                       cache_index = b->CreateAdd(cache_index, value_one);
                       read_or_write(cache_index, current_keys_index);
                     });
            });
      };

  // Copy operand tiles from the operand buffers to shared memory.
  std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
  for (int64_t i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = params[i].EmitReadArrayElement(keys_index, b);
      b->CreateStore(
          value,
          b->CreateGEP(
              param_shmem_buffers[i]->getValueType(), param_shmem_buffers[i],
              {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
    });
  }
  // Wait until all threads have finished copying their elements into shared
  // memory.
  gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {}, b);

  // Now emit the bodies of the comparison loops.
  auto element_address = [&](int64_t operand, llvm::Value* index) {
    auto shared_memory_address =
        b->CreateGEP(param_shmem_buffers[operand]->getValueType(),
                     param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index});
    auto ptr_type = shared_memory_address->getType();
    // We need a generic pointer with address space 0 instead of a pointer to
    // shared memory (address space 3) so that we can pass it to the comparison
    // computation.
    return b->CreateAddrSpaceCast(shared_memory_address,
                                  llvm::PointerType::getWithSamePointeeType(
                                      llvm::cast<llvm::PointerType>(ptr_type),
                                      /*AddressSpace=*/0));
  };
  auto element_address_pointee_type = [&](int64_t operand,
                                          llvm::Value* index) {
    return llvm::GetElementPtrInst::getIndexedType(
        param_shmem_buffers[operand]->getValueType(),
        {tiled_keys_index.GetConstantWithIndexType(0), index});
  };
  auto write_element = [&](int64_t operand, llvm::Value* index,
                           llvm::Value* value) {
    b->CreateStore(
        value,
        b->CreateGEP(param_shmem_buffers[operand]->getValueType(),
                     param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index}));
  };
  for (int64_t xor_mask : xor_masks) {
    // The index of the element pair to be compared within the tile stored in
    // shared memory. We order the element pairs by the element with the
    // smaller index.
    auto element_pair_index = thread_id;
    // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we
    // don't need any bounds checks.
    if (dimension_to_sort_bound % tile_size) {
      // Otherwise we need a bounds check for the last tile. The last tile has
      // size 'dimension_to_sort_bound' % 'tile_size'.
      TF_RETURN_IF_ERROR(ksl.IfWithStatus(
          "is_last_tile",
          b->CreateICmpUGE(
              b->CreateMul(tiled_keys_index[dimension_to_sort],
                           tiled_keys_index.GetConstantWithIndexType(2)),
              tiled_keys_index.GetConstantWithIndexType(
                  RoundDownTo(dimension_to_sort_bound, tile_size))),
          [&]() {
            return EmitCompareLoopBody(
                dimension_to_sort_bound % tile_size, params.size(),
                element_pair_index, xor_mask, tiled_keys_index.GetType(),
                element_address, element_address_pointee_type, write_element,
                emit_compare_callback, b);
          },
          [&]() {
            return EmitCompareLoopBody(
                tile_size, params.size(), element_pair_index, xor_mask,
                tiled_keys_index.GetType(), element_address,
                element_address_pointee_type, write_element,
                emit_compare_callback, b,
                /*needs_bounds_checks=*/false);
          }));
    } else {
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          tile_size, params.size(), element_pair_index, xor_mask,
          tiled_keys_index.GetType(), element_address,
          element_address_pointee_type, write_element, emit_compare_callback,
          b,
          /*needs_bounds_checks=*/false));
    }
    // Wait until all comparisons have happened.
    gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {},
                                   b);
  }

  // Copy the operand tiles back from shared memory to the operand buffers.
  for (int64_t i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto gep = b->CreateGEP(
          param_shmem_buffers[i]->getValueType(), param_shmem_buffers[i],
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index});
      auto gep_type = llvm::GetElementPtrInst::getIndexedType(
          param_shmem_buffers[i]->getValueType(),
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index});
      auto value = b->CreateLoad(gep_type, gep);
      params[i].EmitWriteArrayElement(keys_index, value, b);
    });
  }
  // We should normally synchronize here to make sure all writes have happened.
  // However, the very next thing each thread does is to read 2 elements from
  // the operand buffer and write them into the same location in shared memory
  // from which it previously copied them to the operand buffer, and we
  // synchronize after this has happened. We can be sure that a thread always
  // writes to the same location in shared memory because we have exactly
  // tile_size / 2 threads, and the linear index calculated by
  // ParallelLoopEmitter uses
  // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  return ::tensorflow::OkStatus();
}
}  // namespace

Status EmitSortInPlace(
    int64_t dimension_to_sort, const std::vector<IrArray>& values_arrays,
    absl::string_view name, absl::Span<const int64_t> xor_masks,
    llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
    int64_t num_iterations_in_sort_dim, const int64_t tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback) {
  // Iterate through the keys shape in physical order, but skip the dimension
  // to sort and make it the innermost loop, which is the loop where the
  // comparisons happen. In the dimension to sort, if we use tiling, we iterate
  // through it in tiles of 'tile_size' elements each, so we use another loop
  // that happens within one thread to process this tile worth of data (thereby
  // combining several comparison stages of the bitonic sort algorithm, because
  // they all happen within those 'tile_size' elements and are therefore
  // independent of the other comparisons).
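  //
  // A naive C++ sketch of this iteration structure (a simplification: the
  // actual loop nest is emitted via ParallelLoopEmitter and parallelized
  // across GPU threads; 'iteration_space_without_sort_dimension' is just a
  // name for exposition):
  //
  //   for (auto index : iteration_space_without_sort_dimension) {
  //     for (int64_t i = 0; i < num_iterations_in_sort_dim; ++i) {
  //       // One comparison stage, or all the stages for one tile when
  //       // tiling is used; see compare_loop_body_emitter below.
  //     }
  //   }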
  const Shape& keys_shape = values_arrays[0].GetShape();
  int64_t rank = keys_shape.rank();
  int64_t dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::vector<int64_t> dimensions_in_iteration_order(rank);
  std::vector<int64_t> iteration_order_to_logical_order(rank);
  int64_t dim = 0;
  for (int64_t dimension : LayoutUtil::MinorToMajor(keys_shape)) {
    if (dimension != dimension_to_sort) {
      dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
      iteration_order_to_logical_order[dim++] = dimension;
    }
  }
  dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
  iteration_order_to_logical_order[dim] = dimension_to_sort;
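  // For example, for a rank-3 shape with the default row-major layout sorted
  // along dimension 1, MinorToMajor() yields {2, 1, 0}, so
  // 'iteration_order_to_logical_order' becomes [2, 0, 1]: dimension 2 first,
  // then dimension 0, and the sort dimension last (i.e. innermost).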
  Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                               dimensions_in_iteration_order);

  // Allocate shared memory for the tiled compare loop.
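  // Each buffer is a global variable in shared memory with array type
  // [tile_size x element_type] (e.g. [64 x float] for an F32 operand and a
  // tile size of 64), one buffer per sort operand.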
  std::vector<llvm::GlobalVariable*> param_shmem_buffers(values_arrays.size(),
                                                         nullptr);
  if (xor_masks.size() > 1) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    for (int64_t i = 0; i < values_arrays.size(); ++i) {
      llvm::Type* tile_type = llvm::ArrayType::get(
          llvm_ir::PrimitiveTypeToIrType(
              values_arrays[i].GetShape().element_type(), module),
          tile_size);
      param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
          module, tile_type, absl::StrCat(name, "_tile_param_", i));
    }
  }

  auto compare_loop_body_emitter =
      [&](const IrArray::Index& tiles_index) -> Status {
    // Naive C++ code for the inner compare loop:
    //
    // for (int64_t i = 0; i < dimension_to_sort_bound; ++i) {
    //   int64_t j = i ^ xor_mask;
    //   /* emitted in EmitCompareLoopBody() */
    //   if (i < j && j < dimension_to_sort_bound) {
    //     int64_t min_key = std::min(keys[i], keys[j]);
    //     keys[j] = std::max(keys[i], keys[j]);
    //     keys[i] = min_key;
    //   }
    // }
    //
    // This follows the algorithm described on Wikipedia:
    // https://en.wikipedia.org/wiki/Bitonic_sorter
    std::vector<llvm::Value*> keys_multi_index(rank);
    for (int64_t i = 0; i < rank; ++i) {
      keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
    }
    if (xor_masks.size() > 1) {
      IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
                                tiles_index.GetType());
      TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
          keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
          values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
          b));
    } else {
      auto element_address = [&](int64_t operand, llvm::Value* index) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
      };
      auto element_address_pointee_type = [&](int64_t operand, llvm::Value*) {
        return values_arrays[operand].GetElementLlvmType();
      };
      auto write_element = [&](int64_t operand, llvm::Value* index,
                               llvm::Value* value) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
      };
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
          xor_masks[0], tiles_index.GetType(), element_address,
          element_address_pointee_type, write_element, emit_compare_callback,
          b));
    }
    return ::tensorflow::OkStatus();
  };
  return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
                                  launch_dimensions, b)
      .EmitLoop(name);
}

}  // namespace llvm_ir
}  // namespace xla