blob: f07495618dff43e1ef6e862676c86b65e79fedfb [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include <algorithm>
#include <array>
#include <vector>
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/stream_executor/device_description.h"
namespace xla {
namespace gpu {
namespace {
// Returns true when `shape` has exactly two dimensions left once its
// `batch_dimensions_size` batch dimensions are discounted.
bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) {
  const int64_t non_batch_rank = shape.rank() - batch_dimensions_size;
  return non_batch_rank == 2;
}
// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major then the given dimensions, minor is the size of
// dimensions more minor then the given dimensions, and middle is the size of
// the given dimensions.
std::array<int64, 3> PartitionShapeByMiddleDimensions(
const Shape& shape, absl::Span<const int64> dims_middle) {
CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
std::array<int64, 3> values = {1, 1, 1};
enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
Segment cur_segment = kMinor;
for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) {
if (cur_segment != kMajor) {
// Handle change of segments.
bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
if (cur_segment == kMinor) {
if (cur_dim_in_middle) {
cur_segment = kMiddle;
}
} else if (cur_segment == kMiddle) {
if (!cur_dim_in_middle) {
cur_segment = kMajor;
}
}
}
values[cur_segment] *= shape.dimensions(cur_dim);
}
return values;
}
// Builds the XLA shape of a tensor-typed `value`. The layout comes from the
// "minor_to_major" attribute on the defining op when that attribute is
// present; otherwise a descending layout of the tensor's rank is used.
// `value` must be produced by an operation and must have a tensor type.
Shape GetShapeFromTensorType(mlir::Value value) {
  constexpr char kDefaultLayoutAttrName[] = "minor_to_major";

  mlir::Operation* defining_op = value.getDefiningOp();
  CHECK(defining_op);
  CHECK(value.getType().isa<mlir::TensorType>());

  Shape result = TypeToShape(value.getType());
  auto layout_attr = defining_op->getAttrOfType<mlir::DenseIntElementsAttr>(
      kDefaultLayoutAttrName);
  if (!layout_attr) {
    // No explicit layout recorded on the op: fall back to descending order.
    const auto rank =
        value.getType().cast<mlir::ShapedType>().getShape().size();
    *result.mutable_layout() = LayoutUtil::MakeDescendingLayout(rank);
    return result;
  }

  std::vector<int64> minor_to_major;
  for (const llvm::APInt& index : layout_attr) {
    minor_to_major.push_back(index.getZExtValue());
  }
  *result.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
  return result;
}
} // namespace
// Returns true if `dot` is a kDot instruction that qualifies as a (possibly
// batched) matrix multiplication:
//   - the output element type is F16, BF16, F32, F64, C64 or C128, or the
//     output is S32 with both operands being S8;
//   - lhs, rhs and the result are all rank 2 once batch dimensions are
//     excluded;
//   - neither operand is a zero-element array.
bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  PrimitiveType output_primitive_type = dot.shape().element_type();
  // The integer case requires *both* operands to be S8. The previous code
  // tested the lhs element type twice and never looked at the rhs.
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == BF16 ||
       output_primitive_type == F32 || output_primitive_type == F64 ||
       output_primitive_type == C64 || output_primitive_type == C128) ||
      (output_primitive_type == S32 && lhs_shape.element_type() == S8 &&
       rhs_shape.element_type() == S8);

  bool shapes_are_valid =
      type_is_allowed &&
      IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) &&
      !ShapeUtil::IsZeroElementArray(lhs_shape) &&
      !ShapeUtil::IsZeroElementArray(rhs_shape);
  if (!shapes_are_valid) {
    return false;
  }

  // The size of the reduction dimension should match. The shape inference
  // guarantees this invariant, so the check here is for programming
  // errors.
  CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
           rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
  return true;
}
// Returns true if `hlo` is a custom call whose target is kGemmCallTarget.
bool IsCublasGemm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  return hlo.custom_call_target() == kGemmCallTarget;
}
// Returns the three-element tiling for a reduction described by
// `reduction_dimensions`. Column reductions always use {1, 128, 1}; row
// reductions use {tile_z, 1, 64}, where tile_z is the first reduction
// dimension capped at kBatchedReductionRaceFreeBound. The dtype width and
// compute capability arguments are not consulted by this implementation.
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions,
    int smallest_input_dtype_bits,
    se::CudaComputeCapability cuda_compute_capability) {
  if (!reduction_dimensions.is_row_reduction) {
    // Column reduction.
    return {1, 128, 1};
  }

  // Row reduction.
  const int64_t batch_tile = std::min(reduction_dimensions.dimensions[0],
                                      kBatchedReductionRaceFreeBound);
  return {batch_tile, 1, 64};
}
// Custom-call target strings for the cuDNN batch-normalization calls.
// IsCustomCallToDnnBatchNorm() compares an instruction's custom_call_target()
// against these three values.
const char* const kCudnnBatchNormForwardInferenceCallTarget =
"__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
"__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
"__cudnn$batchNormalizationBackward";
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
return false;
}
const auto& target = hlo.custom_call_target();
return target == kCudnnBatchNormForwardInferenceCallTarget ||
target == kCudnnBatchNormForwardTrainingCallTarget ||
target == kCudnnBatchNormBackwardCallTarget;
}
// Custom-call target string that IsCublasGemm() matches against.
const char* const kGemmCallTarget = "__cublas$gemm";
// Custom-call target strings for the cuDNN convolution calls. They are matched
// by IsCustomCallToDnnConvolution() and mapped to a CudnnConvKind by
// GetCudnnConvKind().
const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
"__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
"__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
"__cudnn$convBiasActivationForward";
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
return false;
}
const auto& target = hlo.custom_call_target();
return target == kCudnnConvForwardCallTarget ||
target == kCudnnConvBackwardInputCallTarget ||
target == kCudnnConvBackwardFilterCallTarget ||
target == kCudnnConvBiasActivationForwardCallTarget;
}
// Custom-call target string that IsCustomCallToCusolver() matches against.
const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";
// Returns true if `hlo` is a custom call whose target is
// kCusolverCholeskyCallTarget.
bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  const bool is_custom_call = hlo.opcode() == HloOpcode::kCustomCall;
  return is_custom_call &&
         hlo.custom_call_target() == kCusolverCholeskyCallTarget;
}
// Returns true if `hlo` is a cuBLAS GEMM, a cuDNN batch-norm or a cuDNN
// convolution custom call.
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  if (IsCublasGemm(hlo)) {
    return true;
  }
  if (IsCustomCallToDnnBatchNorm(hlo)) {
    return true;
  }
  return IsCustomCallToDnnConvolution(hlo);
}
// Classifies a reduction of `input_shape` over `dims_to_reduce` as a row or a
// column reduction and collapses the shape into three contiguous components.
//
// When the kept dimensions are consecutive in the layout the shape is
// partitioned around them; otherwise it is partitioned around the reduced
// dimensions (PartitionShapeByMiddleDimensions CHECKs that those are
// consecutive).
static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
    const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
  // Every dimension that is not reduced is kept.
  DimensionVector kept_dims;
  const int64_t rank = input_shape.rank();
  for (int64_t d = 0; d < rank; ++d) {
    const bool is_reduced = absl::c_linear_search(dims_to_reduce, d);
    if (!is_reduced) {
      kept_dims.push_back(d);
    }
  }

  // Nothing is kept: the whole input collapses into one row.
  if (kept_dims.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  const bool kept_dims_are_consecutive =
      LayoutUtil::AreDimensionsConsecutive(input_shape.layout(), kept_dims);

  if (!kept_dims_are_consecutive) {
    // Partition as (major, reduced, minor).
    std::array<int64, 3> around_reduced =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);
    if (around_reduced[2] == 1) {
      return {/*is_row_reduction=*/true,
              {1, around_reduced[0], around_reduced[1]}};
    }
    return {/*is_row_reduction=*/false, around_reduced};
  }

  // Partition as (major, kept, minor).
  std::array<int64, 3> around_kept =
      PartitionShapeByMiddleDimensions(input_shape, kept_dims);
  if (around_kept[1] == 1) {
    return {/*is_row_reduction=*/true,
            {1, 1, around_kept[0] * around_kept[2]}};
  }
  if (around_kept[2] == 1) {
    return {/*is_row_reduction=*/false,
            {1, around_kept[0], around_kept[1]}};
  }
  return {/*is_row_reduction=*/true, around_kept};
}
// Returns true if the reduced extent is at least kWarpSize.
//
// For a row reduction the tile block is 1 x tile_size_x and the reduction runs
// along tile_size_x (component 2). For a column reduction the tile block is
// tile_size_y x tile_size_x and the reduction runs along tile_size_y
// (component 1). In both cases only the reduced extent needs to be large
// enough to make the tiling implementation efficient.
static bool IsUnnestedReductionFasterThanElemental(
    const ReductionDimensions& reduction_dimensions) {
  const int reduced_component = reduction_dimensions.is_row_reduction ? 2 : 1;
  return reduction_dimensions.dimensions[reduced_component] >= kWarpSize;
}
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
if (HloOpcode::kReduce != reduce.opcode()) {
return false;
}
// TODO(b/129698548): Remove this check after fixing the bug.
if (reduce.shape().element_type() == C128) {
return false;
}
const HloInstruction* input = reduce.operand(0);
std::vector<int64> dims_to_keep;
for (int64_t dim = 0; dim < input->shape().dimensions().size(); ++dim) {
if (!absl::c_linear_search(reduce.dimensions(), dim)) {
dims_to_keep.push_back(dim);
}
}
// We support fast codegen for three cases:
// 1) Row reduction: (K, R)
// 2) Column reduction: (K, R, K)
// 3) "Batched" row reduction: (R, K, R)
if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) &&
!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
reduce.dimensions())) {
return false;
}
return IsUnnestedReductionFasterThanElemental(
GetReductionKindAndContiguousComponents(reduce));
}
// Constructs the fusion layout analysis object by using a heuristic to infer
// the layout of a fusion internal value. In general, if the value is derived
// from a fusion parameter (which by definition has a layout) using elementwise
// operations, it will inherit the layout of that parameter. OTOH if the value
// is written to a fusion output, it will inherit the layout of that output.
// If the heuristic fails, the default layout will be inferred.
//
// The ops of the fusion region's first block (terminator excluded) are visited
// once, in order, and the inferred layouts are recorded in `layouts_`.
FusionLayoutAnalysis::FusionLayoutAnalysis(mlir::lmhlo::FusionOp fusion_op) {
VLOG(3) << "Analyzing \n" << MlirToString(fusion_op);
// Records `layout` for `v` (overwriting any earlier entry) and logs it.
auto add_layout = [this](mlir::Value v, const Layout& layout) {
layouts_[v] = layout;
VLOG(3) << "===============\n";
VLOG(3) << "For value \n" << MlirToString(v.getDefiningOp());
VLOG(3) << "Layout = " << layout.ToString() << "\n";
VLOG(3) << "===============\n";
};
// Propagate layouts inside fusion region.
for (mlir::Operation& op : fusion_op.region().front().without_terminator()) {
if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) {
// A loaded tensor takes the layout of the memref it is loaded from.
add_layout(load, GetShape(load.memref()).layout());
} else if (auto store = mlir::dyn_cast<mlir::memref::TensorStoreOp>(op)) {
// Propagate the stored memref layout to the value if it does not have a
// inferred layout already. This prefers load coalescing over stores.
if (layouts_.count(store.tensor()) == 0) {
add_layout(store.tensor(), GetShape(store.memref()).layout());
}
} else if (auto bitcast = mlir::dyn_cast<mlir::mhlo::BitcastOp>(op)) {
// A bitcast carries explicit layouts for its result ("result_layout") and
// for its operand ("source_layout"); both are recorded.
// NOTE(review): neither attribute is null-checked before it is iterated —
// presumably every bitcast in a fusion carries both; confirm.
auto attr =
bitcast->getAttrOfType<mlir::DenseIntElementsAttr>("result_layout");
std::vector<int64> minor_to_major;
absl::c_transform(
attr, std::back_inserter(minor_to_major),
std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
add_layout(bitcast, LayoutUtil::MakeLayout(minor_to_major));
attr =
bitcast->getAttrOfType<mlir::DenseIntElementsAttr>("source_layout");
minor_to_major.clear();
absl::c_transform(
attr, std::back_inserter(minor_to_major),
std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
add_layout(bitcast.operand(), LayoutUtil::MakeLayout(minor_to_major));
} else {
// Any other op only propagates a layout when it is elementwise.
// NOTE(review): the result of MhloToHloOpcode is dereferenced without an
// error check — presumably it cannot fail for ops found here; confirm.
HloOpcode opcode = *xla::MhloToHloOpcode(&op);
if (!HloInstruction::IsOpElementwise(opcode)) {
continue;
}
// If any operand has a layout, infer that layout for the result of the
// operation. If 2 operands have a conflicting layout, we still need to
// choose one of them, so we will arbitrarily choose the first one.
for (mlir::Value operand : op.getOperands()) {
auto it = layouts_.find(operand);
if (it != layouts_.end()) {
// Do not pass in a reference to an entry in the map when adding a new
// entry. The map may expand when adding, and the reference may become
// invalid. To avoid this, create a local copy of the layout.
const Layout operand_layout = it->second;
add_layout(op.getResult(0), operand_layout);
break;
}
}
}
}
}
// Returns the XLA shape of `value`. For non-memref values the layout is
// replaced by the one recorded in `layouts_`, when there is one; memref values
// keep whatever TypeToShape() produced.
Shape FusionLayoutAnalysis::GetShape(mlir::Value value) const {
  Shape result = TypeToShape(value.getType());
  if (value.getType().isa<mlir::MemRefType>()) {
    return result;
  }
  const auto inferred = layouts_.find(value);
  if (inferred != layouts_.end()) {
    *result.mutable_layout() = inferred->second;
  }
  return result;
}
// MLIR counterpart of IsReductionFromOrToContiguousDimensions(const
// HloInstruction&). Returns true if `reduce` is an mhlo::ReduceOp for which
// either the kept or the reduced dimensions are consecutive in the layout of
// operand 0, and whose reduced extent passes
// IsUnnestedReductionFasterThanElemental(). The operand layout is taken from
// `layout_analysis` rather than read from the IR.
bool IsReductionFromOrToContiguousDimensions(
mlir::Operation* reduce, const FusionLayoutAnalysis& layout_analysis) {
if (!mlir::isa<mlir::mhlo::ReduceOp>(reduce)) {
return false;
}
// Exactly one result is expected (CHECKed below).
std::vector<mlir::Value> results = GetHloOutputs(reduce);
CHECK_EQ(1, results.size());
// complex<f64>, used to detect C128 results.
auto c128_type =
mlir::ComplexType::get(mlir::FloatType::getF64(reduce->getContext()));
// TODO(b/129698548): Remove this check after fixing the bug.
if (results[0].getType().cast<mlir::ShapedType>().getElementType() ==
c128_type) {
return false;
}
mlir::Value input = reduce->getOperand(0);
const Shape operand_shape = layout_analysis.GetShape(input);
// Enable this code to check mismatch between the inferred layout and what was
// there before. Based on actual runs, some mismatches are expected.
#if 0
Shape operand_shape_ir = GetShape(input);
if (auto tensor_type = input.getType().dyn_cast<mlir::TensorType>()) {
if (auto attr = mlir::GetLayoutFromMlirHlo(input.getDefiningOp())) {
std::vector<int64> minor_to_major;
absl::c_transform(
attr, std::back_inserter(minor_to_major),
std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
*operand_shape_ir.mutable_layout() =
LayoutUtil::MakeLayout(minor_to_major);
}
}
bool match = ShapeUtil::Equal(operand_shape, operand_shape_ir);
llvm::errs() << "inferred shape = " << operand_shape.ToString(true) << "\n";
llvm::errs() << "Actual shape in IR = " << operand_shape_ir.ToString(true)
<< "\n";
if (!match) {
llvm::errs() << "Unable to infer layout for reduce op operand(0)\n";
llvm::errs() << "\nreduce = \n";
reduce->dump();
llvm::errs() << "\nparent = \n";
reduce->getParentOp()->dump();
CHECK(0);
}
#endif
// Reduced dimensions, read from the op's "dimensions" attribute.
std::vector<int64> dimensions;
{
auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
CHECK(attr);
absl::c_transform(
attr, std::back_inserter(dimensions),
std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
}
// Every operand dimension that is not reduced is kept.
std::vector<int64> dims_to_keep;
for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
if (!absl::c_linear_search(dimensions, dim)) {
dims_to_keep.push_back(dim);
}
}
// We support fast codegen for three cases:
// 1) Row reduction: (K, R)
// 2) Column reduction: (K, R, K)
// 3) "Batched" row reduction: (R, K, R)
if (!LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
dims_to_keep) &&
!LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
dimensions)) {
return false;
}
return IsUnnestedReductionFasterThanElemental(
GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions));
}
// Returns true if `unnested_hlo` is an lmhlo::FusionOp whose fusion results
// are all produced by mhlo::SliceOp. When `verify_no_strides` is set, every
// stride of every such slice must additionally equal 1.
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                          bool verify_no_strides) {
  auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
  if (!fusion_op) {
    return false;
  }

  for (mlir::Value result : fusion_op.getFusionResults()) {
    auto slice_op =
        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(result.getDefiningOp());
    if (!slice_op) {
      return false;
    }
    if (!verify_no_strides) {
      continue;
    }
    for (const llvm::APInt& stride : slice_op.strides()) {
      if (!(stride == 1)) {
        return false;
      }
    }
  }
  return true;
}
// Classifies the HLO `reduce` instruction from the shape of its first operand
// and its reduced dimensions.
ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  const Shape& input_shape = reduce.operand(0)->shape();
  return GetReductionKindAndContiguousComponentsImpl(input_shape,
                                                     reduce.dimensions());
}
// Classifies the MLIR `reduce` operation from the shape of its first operand
// (as reported by GetShape()) and its "dimensions" attribute, which must be
// present.
ReductionDimensions GetReductionKindAndContiguousComponents(
    mlir::Operation* reduce) {
  Shape input_shape = GetShape(reduce->getOperand(0));

  auto dimensions_attr =
      reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
  CHECK(dimensions_attr);

  std::vector<int64> reduced_dims;
  for (const llvm::APInt& dim : dimensions_attr) {
    reduced_dims.push_back(dim.getZExtValue());
  }
  return GetReductionKindAndContiguousComponentsImpl(input_shape,
                                                     reduced_dims);
}
// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
//
// `fmt` is emitted as a global string. `arguments` are promoted as described
// below, stored into a struct allocated with an alloca, and passed to vprintf
// as an i8* pointing at that struct. Returns the emitted call instruction.
llvm::Value* EmitPrintf(absl::string_view fmt,
absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder) {
std::vector<llvm::Type*> argument_types;
// Variadic arguments implicit promotion [1] converts float to double,
// and bool/char/short are converted to int.
// [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
auto requires_int32_promotion = [](llvm::Type* type) {
return type->isIntegerTy(/*BitWidth=*/1) ||
type->isIntegerTy(/*BitWidth=*/8) ||
type->isIntegerTy(/*BitWidth=*/16);
};
auto requires_double_promotion = [](llvm::Type* type) {
return type->isFloatingPointTy();
};
// First pass: compute the promoted type of every argument; these become the
// field types of the argument struct.
for (auto argument : arguments) {
llvm::Type* type = argument->getType();
if (requires_double_promotion(type)) {
argument_types.push_back(builder->getDoubleTy());
} else if (requires_int32_promotion(type)) {
argument_types.push_back(builder->getInt32Ty());
} else {
argument_types.push_back(type);
}
}
auto* arguments_type = llvm::StructType::create(argument_types);
llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
// Second pass: cast each argument to its promoted type and store it into the
// matching struct field.
for (size_t i = 0; i < arguments.size(); ++i) {
llvm::Value* value = arguments[i];
llvm::Type* type = value->getType();
if (requires_double_promotion(type)) {
value = builder->CreateFPCast(value, builder->getDoubleTy());
} else if (requires_int32_promotion(type)) {
// NOTE(review): the cast is signed, so an i1 `true` presumably prints as
// -1 rather than 1 — confirm this is intended.
value = builder->CreateIntCast(value, builder->getInt32Ty(),
/*isSigned=*/true);
}
builder->CreateStore(
value, builder->CreateGEP(arguments_ptr, {builder->getInt64(0),
builder->getInt32(i)}));
}
// Both vprintf parameters are declared as i8*, so the struct pointer is cast
// before the call.
llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
return builder->CreateCall(
builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
"vprintf",
llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
/*isVarArg=*/false)),
{builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}
// Emits a call to the AMDGPU device function "__ockl_readuplane_i32" on a
// 32-bit `value` with the given `offset`. The value is bitcast to i32 for the
// call and the i32 result is bitcast back to the type of `value`.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);

  auto* i32_type = b->getInt32Ty();
  llvm::FunctionType* readuplane_type = llvm::FunctionType::get(
      /*Result=*/i32_type, {i32_type, i32_type}, /*isVarArg=*/false);
  llvm::FunctionCallee readuplane = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"), readuplane_type);

  // AMDGPU device function requires first argument as i32.
  llvm::Value* value_as_i32 = b->CreateBitCast(value, i32_type);
  llvm::Value* shuffled_i32 = b->CreateCall(readuplane, {value_as_i32, offset});

  // AMDGPU device function always returns an i32 type.
  return b->CreateBitCast(shuffled_i32, value->getType());
}
// Emits a call to the NVPTX shfl.sync.down intrinsic on a 32-bit `value`: the
// f32 variant when `value` is a float, the i32 variant otherwise. The call is
// made with member mask -1 and a final operand of kWarpSize - 1.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);

  const llvm::Intrinsic::ID shfl_id =
      value->getType()->isFloatTy()
          ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
          : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  llvm::Function* shfl_fn =
      llvm::Intrinsic::getDeclaration(module, shfl_id, {});

  llvm::Value* member_mask = b->getInt32(-1);
  llvm::Value* clamp = b->getInt32(kWarpSize - 1);
  return b->CreateCall(shfl_fn, {member_mask, value, offset, clamp});
}
// Emits a shuffle-down of `value` by `offset` for values of any bit width,
// dispatching to the NVPTX or AMDGPU helper according to the module's target
// triple (any other triple is fatal).
//
// A 32-bit float is shuffled directly. Any other value is zero-extended to a
// multiple of 32 bits, viewed as a vector of i32 segments, shuffled one
// segment at a time, and then truncated and bitcast back to its own type.
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  llvm::Type* const value_type = value->getType();
  int bit_width = value_type->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Shuffles a single 32-bit word with the helper matching the target.
  auto shuffle_word = [&](llvm::Value* word) -> llvm::Value* {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(word, offset, builder);
    }
    if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(word, offset, builder);
    }
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  };

  // Special case for efficiency
  if (value_type->isFloatTy() && bit_width == 32) {
    return shuffle_word(value);
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* as_integer =
      builder->CreateBitCast(value, builder->getIntNTy(bit_width));
  llvm::Value* padded =
      builder->CreateZExt(as_integer, builder->getIntNTy(32 * num_segments));
  llvm::Value* segments = builder->CreateBitCast(
      padded,
      llvm::VectorType::get(builder->getInt32Ty(), num_segments, false));

  for (int segment = 0; segment < num_segments; ++segment) {
    llvm::Value* shuffled =
        shuffle_word(builder->CreateExtractElement(segments, segment));
    segments = builder->CreateInsertElement(segments, shuffled, segment);
  }

  llvm::Value* repacked = builder->CreateBitCast(
      segments, builder->getIntNTy(32 * num_segments));
  llvm::Value* narrowed =
      builder->CreateTrunc(repacked, builder->getIntNTy(bit_width));
  return builder->CreateBitCast(narrowed, value_type);
}
// Maps the custom-call target of `instr` to its CudnnConvKind. Returns an
// internal error when the target is not one of the cuDNN convolution targets.
StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  struct TargetToKind {
    const char* target;
    CudnnConvKind kind;
  };
  const TargetToKind known_targets[] = {
      {kCudnnConvForwardCallTarget, CudnnConvKind::kForward},
      {kCudnnConvBackwardInputCallTarget, CudnnConvKind::kBackwardInput},
      {kCudnnConvBackwardFilterCallTarget, CudnnConvKind::kBackwardFilter},
      {kCudnnConvBiasActivationForwardCallTarget,
       CudnnConvKind::kForwardActivation},
  };

  absl::string_view target = instr->custom_call_target();
  for (const TargetToKind& entry : known_targets) {
    if (target == entry.target) {
      return entry.kind;
    }
  }
  return InternalError("Unexpected call target: %s", target);
}
// Returns a human-readable name for `kind`.
//
// Every enumerator is handled by the switch. The trailing return exists only
// so that a value outside the enumerators (e.g. a corrupted or miscast value)
// does not let control fall off the end of a value-returning function, which
// is undefined behaviour.
string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
  // Not reachable for valid enumerators.
  return "unknown";
}
// Emits IR that evaluates to true when both the kThreadIdx and the kBlockIdx
// target intrinsics return 0.
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  llvm::Value* thread_index =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b);
  llvm::Value* thread_is_zero = b->CreateICmpEQ(b->getInt32(0), thread_index);

  llvm::Value* block_index =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b);
  llvm::Value* block_is_zero = b->CreateICmpEQ(b->getInt32(0), block_index);

  return b->CreateAnd(thread_is_zero, block_is_zero);
}
bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
const HloInstruction* first_reduce) {
if (IsReductionFromOrToContiguousDimensions(*inst)) {
// Shapes, layouts and dimensions must be the same for all reduces
// inside of this fusion.
// TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
ShapeUtil::Equal(first_reduce->operand(0)->shape(),
inst->operand(0)->shape()) &&
ShapeUtil::Equal(first_reduce->operand(1)->shape(),
inst->operand(1)->shape()) &&
first_reduce->dimensions() == inst->dimensions();
}
return ShapeUtil::CompatibleIgnoringElementType(
first_reduce->operand(0)->shape(), inst->shape()) &&
LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
inst->shape().layout());
}
// MLIR counterpart of IsFusedReductionOutputConsistent(const HloInstruction*,
// const HloInstruction*). All shapes are obtained through `layout_analysis`.
// Both reduce ops must have exactly one result.
bool IsFusedReductionOutputConsistent(
    mlir::mhlo::ReduceOp inst, mlir::mhlo::ReduceOp first_reduce,
    const FusionLayoutAnalysis& layout_analysis) {
  CHECK_EQ(1, first_reduce.getNumResults());
  Shape reference_input_shape =
      layout_analysis.GetShape(first_reduce.inputs()[0]);
  CHECK_EQ(1, inst.getNumResults());
  Shape candidate_result_shape = layout_analysis.GetShape(inst.getResult(0));

  if (!IsReductionFromOrToContiguousDimensions(inst, layout_analysis)) {
    // Not a fast-path reduction: only compatibility with the reference input
    // (ignoring element type) and an identical layout are required.
    return ShapeUtil::CompatibleIgnoringElementType(reference_input_shape,
                                                    candidate_result_shape) &&
           LayoutUtil::Equal(reference_input_shape.layout(),
                             candidate_result_shape.layout());
  }

  Shape reference_result_shape =
      layout_analysis.GetShape(first_reduce.getResult(0));
  Shape reference_init_shape =
      layout_analysis.GetShape(first_reduce.init_values()[0]);
  Shape candidate_input_shape = layout_analysis.GetShape(inst.inputs()[0]);
  Shape candidate_init_shape = layout_analysis.GetShape(inst.init_values()[0]);

  // Shapes, layouts and dimensions must be the same for all reduces
  // inside of this fusion.
  // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
  return ShapeUtil::Equal(reference_result_shape, candidate_result_shape) &&
         ShapeUtil::Equal(reference_input_shape, candidate_input_shape) &&
         ShapeUtil::Equal(reference_init_shape, candidate_init_shape) &&
         absl::c_equal(first_reduce.dimensions(), inst.dimensions());
}
// Given an LMHLO op, returns the operand index of the first output operand.
//
// Notice that an operand aliased to an output isn't an output, even though in
// that case WritesMlirBuffer() returns true on that operand.
//
// An operand either is !WritesMlirBuffer() or equals (aliases) a later
// operand. An output is the opposite: it is WritesMlirBuffer() and does not
// equal any later operand.
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));

  auto all_operands = op->getOperands();
  // Scan from the back; `first_output` moves towards the front for as long as
  // the operand just before it still qualifies as an output.
  int first_output = all_operands.size();
  while (first_output > 0) {
    const int candidate = first_output - 1;
    mlir::Value candidate_value = op->getOperand(candidate);
    const bool aliases_later_operand =
        std::find(all_operands.begin() + candidate + 1, all_operands.end(),
                  candidate_value) != all_operands.end();
    if (!WritesMlirBuffer(op, candidate_value) || aliases_later_operand) {
      break;
    }
    first_output = candidate;
  }
  return first_output;
}
// Returns the values that act as HLO operands of `op`:
//   - for an lmhlo::FusionOp, its input buffers;
//   - for any other "lmhlo" op, the operands that come before the first
//     output operand;
//   - for an "mhlo" op, all of its operands.
// Any other op is fatal.
std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
  if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion_op.getInputBuffers());
  }

  auto* op_dialect = op->getDialect();
  auto* context = op->getContext();

  if (op_dialect == context->getLoadedDialect("lmhlo")) {
    const int first_output = PartitionLmhloOperandsAndOutputs(op);
    auto all_operands = op->getOperands();
    return std::vector<mlir::Value>(all_operands.begin(),
                                    all_operands.begin() + first_output);
  }
  if (op_dialect == context->getLoadedDialect("mhlo")) {
    auto all_operands = op->getOperands();
    return std::vector<mlir::Value>(all_operands.begin(), all_operands.end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}
// Returns the values that act as HLO outputs of `op`:
//   - for an lmhlo::FusionOp, its output buffers;
//   - for any other "lmhlo" op, the operands from the first output operand
//     onwards;
//   - for an "mhlo" op, all of its results.
// Any other op is fatal.
std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op) {
  if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion_op.getOutputBuffers());
  }

  auto* op_dialect = op->getDialect();
  auto* context = op->getContext();

  if (op_dialect == context->getLoadedDialect("lmhlo")) {
    const int first_output = PartitionLmhloOperandsAndOutputs(op);
    auto all_operands = op->getOperands();
    return std::vector<mlir::Value>(all_operands.begin() + first_output,
                                    all_operands.end());
  }
  if (op_dialect == context->getLoadedDialect("mhlo")) {
    auto all_results = op->getResults();
    return std::vector<mlir::Value>(all_results.begin(), all_results.end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}
// Returns true if any memory effect that `op` reports on `operand` is a
// write. `op` must implement MemoryEffectOpInterface.
bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> operand_effects;
  auto effect_interface = mlir::cast<mlir::MemoryEffectOpInterface>(op);
  effect_interface.getEffectsOnValue(operand, operand_effects);

  for (const mlir::MemoryEffects::EffectInstance& effect : operand_effects) {
    if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) {
      return true;
    }
  }
  return false;
}
// Returns the size in bytes of a memref of the given `type`.
static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
  // For i1 memrefs, the underlying allocation is 8 bits, i.e. one byte per
  // element.
  const bool is_i1 = type.getElementType().isInteger(/*width=*/1);
  if (is_i1) {
    return type.getNumElements();
  }
  return type.getSizeInBits() / CHAR_BIT;
}
// Returns the argument number of `func_arg`, which is used as the allocation
// index. The argument must belong to a region whose parent op is an
// mlir::FuncOp. If `constant_name` is non-null and the argument carries an
// "lmhlo.constant_name" string attribute, that attribute's value is written to
// `*constant_name`; otherwise `*constant_name` is left untouched.
static int64_t GetAllocationIndex(mlir::BlockArgument func_arg,
                                  std::string* constant_name) {
  auto enclosing_func =
      mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
  const auto arg_number = func_arg.getArgNumber();

  if (constant_name == nullptr) {
    return arg_number;
  }
  auto name_attr = enclosing_func.getArgAttrOfType<mlir::StringAttr>(
      arg_number, "lmhlo.constant_name");
  if (name_attr) {
    *constant_name = name_attr.getValue().str();
  }
  return arg_number;
}
// Resolves the memref value `v` to the slice of `allocations` that backs it.
//
// If `constant_name` is non-null it is cleared first and then, where a name
// is available (an "lmhlo.constant_name" argument attribute or the name of a
// global memref), set to that name.
//
// Returns an Unimplemented error when `v` matches none of the supported
// patterns.
StatusOr<BufferAllocation::Slice> GetAllocationSlice(
    mlir::Value v, absl::Span<const BufferAllocation> allocations,
    std::string* constant_name) {
  if (constant_name != nullptr) {
    constant_name->clear();
  }

  // The slice size is taken from the original value, before any reinterpret
  // cast is looked through.
  const int64_t size_in_bytes =
      GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());

  // We match the following patterns here:
  //  base := ViewOp(arg) | get_global_memref (global_memref) | arg
  //  root := base | MemRefReinterpretCastOp(base)
  if (auto reinterpret_cast_op =
          mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>(
              v.getDefiningOp())) {
    v = reinterpret_cast_op.getViewSource();
  }

  // base := ViewOp(arg)
  if (auto view_op =
          mlir::dyn_cast_or_null<mlir::memref::ViewOp>(v.getDefiningOp())) {
    TF_RET_CHECK(view_op.source().isa<mlir::BlockArgument>());
    const int64_t allocation_index = GetAllocationIndex(
        view_op.source().cast<mlir::BlockArgument>(), constant_name);
    const int64_t byte_offset =
        mlir::cast<mlir::ConstantOp>(view_op.byte_shift().getDefiningOp())
            .value()
            .cast<mlir::IntegerAttr>()
            .getValue()
            .getSExtValue();
    return BufferAllocation::Slice(&allocations[allocation_index], byte_offset,
                                   size_in_bytes);
  }

  // base := get_global_memref(global_memref)
  if (auto get_global_op = mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
          v.getDefiningOp())) {
    auto module_op = get_global_op->getParentOfType<mlir::ModuleOp>();
    if (constant_name != nullptr) {
      *constant_name = get_global_op.name().str();
    }
    auto global_op = mlir::cast<mlir::memref::GlobalOp>(
        module_op.lookupSymbol(get_global_op.name()));
    const int64_t allocation_index =
        global_op->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
    // A global covers its whole allocation.
    const BufferAllocation& allocation = allocations[allocation_index];
    return BufferAllocation::Slice(&allocation, 0, allocation.size());
  }

  // base := arg
  if (auto block_arg = v.dyn_cast<mlir::BlockArgument>()) {
    const int64_t allocation_index =
        GetAllocationIndex(block_arg, constant_name);
    return BufferAllocation::Slice(&allocations[allocation_index], 0,
                                   size_in_bytes);
  }

  return Unimplemented(
      "Operand has to be in the form of ViewOp(arg) or "
      "StaticMemRefCastOp(ViewOp(arg)) or arg");
}
// Returns true if `fusion` can emit its dynamic-update-slice in place, i.e.:
//   - the fusion has exactly one result, produced by an
//     mhlo::DynamicUpdateSliceOp;
//   - the updated operand of that op is produced by a memref::TensorLoadOp;
//   - the loaded memref and the fusion's single output buffer resolve to the
//     same allocation slice.
//
// Values that have no defining op (for example block arguments) make
// getDefiningOp() return null. dyn_cast_or_null is used so that such values
// simply yield `false` instead of feeding a null pointer to dyn_cast.
bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const BufferAllocation> allocations) {
  auto results = fusion.getFusionResults();
  if (results.size() != 1) {
    return false;
  }
  auto dus = mlir::dyn_cast_or_null<mlir::mhlo::DynamicUpdateSliceOp>(
      results[0].getDefiningOp());
  if (!dus) {
    return false;
  }

  auto output_buffers = fusion.getOutputBuffers();
  CHECK_EQ(1, output_buffers.size());
  auto parameter = mlir::dyn_cast_or_null<mlir::memref::TensorLoadOp>(
      dus.operand().getDefiningOp());
  if (!parameter) {
    return false;
  }

  // In-place emission requires that the updated operand and the output live
  // in exactly the same buffer slice.
  auto maybe_lhs = GetAllocationSlice(parameter.memref(), allocations);
  auto maybe_rhs = GetAllocationSlice(output_buffers[0], allocations);
  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}
// Returns the XLA shape of `value`. Tensor-typed values go through
// GetShapeFromTensorType(); memref- and tuple-typed values are converted
// directly with TypeToShape(). Any other type is fatal.
Shape GetShape(mlir::Value value) {
  auto value_type = value.getType();
  if (value_type.isa<mlir::TensorType>()) {
    return GetShapeFromTensorType(value);
  }
  if (value_type.isa<mlir::MemRefType>() ||
      value_type.isa<mlir::TupleType>()) {
    return TypeToShape(value_type);
  }
  LOG(FATAL) << "Unexpected value type to get shape for";
  return {};
}
// Returns true if the reduction is race free for the given tiling:
//   - a row reduction needs dimensions[2] to be at most
//     kMinThreadsXRowReduction * reduction_tiling[2] and dimensions[0] to be
//     at most kBatchedReductionRaceFreeBound;
//   - a column reduction needs dimensions[1] to be at most
//     kWarpSize * reduction_tiling[1].
bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions,
                         const std::array<int64_t, 3>& reduction_tiling) {
  const auto& extents = reduction_dimensions.dimensions;
  if (reduction_dimensions.is_row_reduction) {
    return extents[2] <= kMinThreadsXRowReduction * reduction_tiling[2] &&
           extents[0] <= kBatchedReductionRaceFreeBound;
  }
  return extents[1] <= kWarpSize * reduction_tiling[1];
}
} // namespace gpu
} // namespace xla