/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include <memory>
#include <vector>
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project
#include "mlir/EDSC/Builders.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
using llvm_ir::SetToFirstInsertPoint;
namespace cpu {
namespace {
// Returns true if we should call into multi-threaded Eigen routines.
bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
return config.debug_options().xla_cpu_multi_thread_eigen();
}
// Represents a dot operation. We use this in lieu of an `HloInstruction`
// because we want to be able to create this for the "inner" dot operation in a
// batch dot, for which there is no separate HLO instruction.
struct DotInfo {
Shape lhs_shape;
Shape rhs_shape;
Shape result_shape;
DotDimensionNumbers dim_nums;
DotInfo() = default;
explicit DotInfo(const HloInstruction& instr) {
CHECK_EQ(instr.opcode(), HloOpcode::kDot);
lhs_shape = instr.operand(0)->shape();
rhs_shape = instr.operand(1)->shape();
result_shape = instr.shape();
dim_nums = instr.dot_dimension_numbers();
}
};
// Dictates how a dot operation is implemented.
enum class DotImplementationStrategy {
// The dot operation is lowered into LLVM IR that implements a naive nested
// loop that computes the result one element at a time. This is our
// "fallback"; we don't really want this to kick in for any non-trival dot
// operation.
kNaiveLlvmIr,
// The dot operation is lowered into LLVM IR that implements a tiled
// Matrix*Vector operation. This strategy also allows fusing in a bias add
// into the dot. The matrix can be row major or column major; both are
// supported.
kTiledLlvmIrGemv,
// The dot operation is lowered into LLVM IR that implements a tiled
// Matrix*Matrix operation. No fusions are supported. The two inputs
// and the output have to be row major.
kTiledLlvmIrGemm,
// The dot operation is lowered into linalg.matmul op and lowered to LLVM IR.
kLinalgMatmul,
// The dot operation is lowered into a call into an Eigen routine. No fusions
// are supported today. The two inputs and the output have to be row major.
// However, we do allow transposing either the LHS or the RHS as part of the
// GEMM -- we expose this flexibility as flexibility in the contraction
// dimensions, but we can also see this as flexibility in the input layouts.
kEigen,
};
// Returns the implementation strategy for a dot with the configuration
// `dot_info`.
DotImplementationStrategy GetDotImplementationStrategy(
const HloModuleConfig& config, const DotInfo& dot_info,
const TargetMachineFeatures& target_machine_features);
// Helper class for emitting LLVM IR to perform the dot operation.
class DotOpEmitter {
public:
explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
llvm::Value* executable_run_options_value,
llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features);
// Emits the IR to perform the dot operation.
Status Emit();
private:
// Emits instructions to perform a scalar dot product (a multiply of the
// LHS and RHS) and store the results in the target.
Status EmitScalarDot();
// Emits a call to the CPU runtime to perform the matrix multiply.
Status EmitCallToRuntime();
// Represents the dimensions of a matrix-matrix multiply operation.
struct MatMultDims {
// The number of rows in the LHS.
int64 m;
// The number of columns in the LHS, which must also be equal to the
// number of rows in the RHS.
int64 k;
// The number of columns in the RHS.
int64 n;
// True if the LHS matrix is column major.
bool lhs_column_major;
// True if the LHS contraction dimension is 1.
bool lhs_canonical;
// True if the RHS matrix is column major.
bool rhs_column_major;
// True if the RHS contraction dimension is 0.
bool rhs_canonical;
};
// Get the MatMultDims instance for the dot product this DotOpEmitter
// represents. Precondition: the dot is of rank 2 (and thus its operands are
// of rank 2 as well).
MatMultDims GetMatMultDims() const;
// Lowers the dot operation as a tiled Matrix*Vector loop.
void EmitTiledLlvmIrGemv();
// Lowers the dot operation as a tiled Matrix*Matrix loop.
void EmitTiledLlvmIrGemm();
// Lowers the dot operation through MLIR's linalg.matmul.
Status EmitLinalgMatmul();
// Lowers the dot operation as a naive nested loop that computes the result
// one element at a time.
void EmitNaiveLlvmIrGemm();
// When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
// registers.
int64 GetGemvTilingFactor() const {
const int64 kDefaultTilingFactor = 8;
return options::LlvmIrGemvTilingFactor(hlo_module_config_)
.value_or(kDefaultTilingFactor);
}
std::tuple<int64, int64, int64> GetGemmTileSize() const {
// Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
//
// TODO(b/80093688): Tune for other architectures and centralize this
// information in one place.
const std::tuple<int64, int64, int64> kDefaultTileSize =
std::tuple<int64, int64, int64>(11, 9, 1);
return options::LlvmIrGemmTileSize(hlo_module_config_)
.value_or(kDefaultTileSize);
}
DotInfo dot_info_;
string dot_hlo_name_;
const llvm_ir::IrArray& target_array_;
const llvm_ir::IrArray& lhs_array_;
const llvm_ir::IrArray& rhs_array_;
const llvm_ir::IrArray* addend_array_;
llvm::Value* executable_run_options_value_;
llvm::IRBuilder<>* b_;
mlir::MLIRContext* mlir_context_;
const HloModuleConfig& hlo_module_config_;
const TargetMachineFeatures& target_machine_features_;
};
} // namespace
DotOpEmitter::DotOpEmitter(
DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features)
: dot_info_(std::move(dot_info)),
dot_hlo_name_(std::move(dot_hlo_name)),
target_array_(target_array),
lhs_array_(lhs_array),
rhs_array_(rhs_array),
addend_array_(addend_array),
executable_run_options_value_(executable_run_options_value),
b_(b),
mlir_context_(mlir_context),
hlo_module_config_(hlo_module_config),
target_machine_features_(target_machine_features) {}
Status DotOpEmitter::EmitLinalgMatmul() {
Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
rhs_array_.GetBasePointer()};
llvm::Value* target_ptr = target_array_.GetBasePointer();
// Zero out the output buffer.
int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
/*Align=*/llvm::MaybeAlign(1));
std::string name =
absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
dot_info_.lhs_shape.ToString(true), "_",
dot_info_.rhs_shape.ToString(true));
return EmitMlirFuncAndCall(
mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
mlir::edsc::ScopedContext scope(*builder, function.getLoc());
mlir::Value a = function.getArgument(0), b = function.getArgument(1),
c = function.getArgument(2);
mlir::edsc::intrinsics::linalg_matmul(b, c, a);
mlir::edsc::intrinsics::std_ret();
});
}
void DotOpEmitter::EmitTiledLlvmIrGemm() {
PrimitiveType primitive_type = dot_info_.result_shape.element_type();
MatMultDims mat_mult_dims = GetMatMultDims();
llvm::Value* lhs = lhs_array_.GetBasePointer();
llvm::Value* rhs = rhs_array_.GetBasePointer();
llvm::Value* target = target_array_.GetBasePointer();
int64 m = mat_mult_dims.m;
int64 k = mat_mult_dims.k;
int64 n = mat_mult_dims.n;
if (mat_mult_dims.lhs_column_major) {
std::swap(lhs, rhs);
std::swap(m, n);
}
int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
/*Align=*/llvm::MaybeAlign(1));
int64 max_target_vector_width =
target_machine_features_.vector_register_num_elements(
*b_->GetInsertBlock()->getParent(), primitive_type);
int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
GetGemmTileSize();
EmitSmallGemm(
/*scalar_type=*/primitive_type,
/*m=*/m, /*k=*/k, /*n=*/n,
/*max_vectorization_width=*/max_target_vector_width,
/*max_vector_count=*/tile_size_n_in_vector_width,
/*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
/*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}
void DotOpEmitter::EmitTiledLlvmIrGemv() {
PrimitiveType primitive_type = dot_info_.result_shape.element_type();
CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
primitive_util::IsIntegralType(primitive_type));
MatMultDims mat_mult_dims = GetMatMultDims();
bool is_column_major_matrix_vector_gemv = false;
bool is_row_major_matrix_vector_gemv = false;
int64 m, k;
bool swap_operands;
if (mat_mult_dims.m == 1) {
// Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
// we actually want V*M. We implement V*M as follows (Tr(X) = Transpose of
// X):
//
// V*M = Tr(Tr(V*M)) // Tr(Tr(X)) == X
// = Tr(Tr(M) * Tr(V)) // Tr(A * B) == Tr(B) * Tr(A)
//
// Since transposing a vector is physically a no-op, this is really
// equivalent to `Tr(M) * V`. We further implement Tr(M) by pretending that
// M is row major if it is actually column major and vice-versa.
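// For example, a [1 x K] · [K x N] dot whose RHS is effectively column
// major is emitted below as a row-major GEMV with m = N and k = K, and the
// operands are swapped so that the RHS matrix is passed as the GEMV's
// matrix operand.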
bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
? mat_mult_dims.rhs_column_major
: !mat_mult_dims.rhs_column_major;
if (rhs_effectively_column_major) {
k = mat_mult_dims.k;
m = mat_mult_dims.n;
// We set is_row_major_matrix_vector_gemv and not
// is_column_major_matrix_vector_gemv to implement the Transpose trick
// mentioned above.
is_row_major_matrix_vector_gemv = true;
swap_operands = true;
} else {
k = mat_mult_dims.k;
m = mat_mult_dims.n;
// We set is_column_major_matrix_vector_gemv and not
// is_row_major_matrix_vector_gemv to implement the Transpose trick
// mentioned above.
is_column_major_matrix_vector_gemv = true;
swap_operands = true;
}
}
if (mat_mult_dims.n == 1) {
bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
? mat_mult_dims.lhs_column_major
: !mat_mult_dims.lhs_column_major;
if (lhs_effectively_column_major) {
m = mat_mult_dims.m;
k = mat_mult_dims.k;
is_column_major_matrix_vector_gemv = true;
swap_operands = false;
} else {
m = mat_mult_dims.m;
k = mat_mult_dims.k;
is_row_major_matrix_vector_gemv = true;
swap_operands = false;
}
}
CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);
int64 tiling_factor = GetGemvTilingFactor();
CHECK_GT(tiling_factor, 0);
llvm::Value* result_op = target_array_.GetBasePointer();
llvm::Value* lhs_op =
swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
llvm::Value* rhs_op =
swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
const int target_vector_register_element_size =
target_machine_features_.vector_register_num_elements(
*b_->GetInsertBlock()->getParent(), primitive_type);
// We may not always know the vector register size for the target we're
// compiling against, in which case target_vector_register_element_size is 0.
// In these cases we choose a default LLVM IR register size.
const int kUnknownTargetVectorRegisterSize = 4;
const int vector_register_element_size =
target_vector_register_element_size == 0
? kUnknownTargetVectorRegisterSize
: target_vector_register_element_size;
if (is_column_major_matrix_vector_gemv) {
VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
<< " and k = " << k;
EmitColumnMajorGemv(
/*scalar_type=*/primitive_type,
/*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
/*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
/*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
/*result=*/result_op, b_, hlo_module_config_);
} else {
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
<< " and k = " << k;
EmitRowMajorGemv(
/*scalar_type=*/primitive_type,
/*tile_rows=*/tiling_factor,
/*tile_cols=*/vector_register_element_size,
/*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
/*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
/*result=*/result_op, b_, hlo_module_config_);
}
}
Status DotOpEmitter::Emit() {
// The dot operation performs a sum of products over dimension 0 of the left
// hand side operand and dimension 1 of the right hand side operand.
//
// Let the shapes of lhs and rhs be defined as below:
//
// lhs = [L{n-1} x L{n-2} x ... L{0}]
// rhs = [R{m-1} x R{m-2} x ... R{0}]
//
// The sum-of-products dimension in the lhs has size L{0} and the dimension in
// the rhs has size R{1}. Necessarily, then:
//
// L{0} == R{1}
//
// The output of the operation has the following shape:
//
// output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
//
// To perform the operation we construct a loop nest with one for-loop for
// each dimension of the output. Inside this loop nest is another for-loop
// which performs the sum-of-products (the reduction loop) before storing
// the result in the output buffer.
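// For example, for rank-2 operands lhs = [L{1} x L{0}] and rhs =
// [R{1} x R{0}] with L{0} == R{1}, this is an ordinary matrix multiply and
// the output shape is [L{1} x R{0}].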
const Shape& lhs_shape = lhs_array_.GetShape();
const Shape& rhs_shape = rhs_array_.GetShape();
if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
// If the operands are scalar, don't emit any loops.
TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
ShapeUtil::IsScalar(rhs_shape));
return EmitScalarDot();
}
switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
target_machine_features_)) {
case DotImplementationStrategy::kNaiveLlvmIr:
EmitNaiveLlvmIrGemm();
return Status::OK();
case DotImplementationStrategy::kTiledLlvmIrGemv:
EmitTiledLlvmIrGemv();
return Status::OK();
case DotImplementationStrategy::kTiledLlvmIrGemm:
EmitTiledLlvmIrGemm();
return Status::OK();
case DotImplementationStrategy::kLinalgMatmul:
return EmitLinalgMatmul();
case DotImplementationStrategy::kEigen:
return EmitCallToRuntime();
}
}
void DotOpEmitter::EmitNaiveLlvmIrGemm() {
CHECK_EQ(addend_array_, nullptr);
const Shape& lhs_shape = lhs_array_.GetShape();
const Shape& rhs_shape = rhs_array_.GetShape();
const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
// Reduce along the contracting dimension of the LHS and of the RHS, as given
// by the dot dimension numbers. Vectors are a special case where the
// contracting dimension is 0 for both LHS and RHS, which results in a vector
// dot product producing a scalar.
int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);
// Verify that the reduction dimensions of the two operands have the same size.
CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
rhs_shape.dimensions(rhs_reduction_dimension));
bool lhs_reduction_along_minor_dimension =
lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
bool rhs_reduction_along_minor_dimension =
rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);
// Create loop nests which loop through the LHS operand dimensions and the RHS
// operand dimensions. The reduction dimensions of the LHS and RHS are handled
// in a separate innermost loop which performs the sum of products.
llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
std::vector<llvm::Value*> lhs_multi_index =
loop_nest.EmitOperandArrayLoopNest(
lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
std::vector<llvm::Value*> rhs_multi_index =
loop_nest.EmitOperandArrayLoopNest(
rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
// Create the loop which does the sum of products reduction.
//
// Requesting kNoUnroll here works around a deficiency in LLVM's loop
// vectorization pipeline, wherein in some cases unrolling a loop can prevent
// effective vectorization. Since we know that the IR we generate when
// reducing across the minor dimension in both LHS and RHS is vectorized well
// by the loop vectorizer, we block unrolling in that case to stop loop unroll
// from messing up the vectorization.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
/*unroll_mode=*/
(lhs_reduction_along_minor_dimension &&
rhs_reduction_along_minor_dimension)
? xla::llvm_ir::UnrollMode::kNoUnroll
: xla::llvm_ir::UnrollMode::kDefaultUnroll);
// The final entry in the rhs and lhs indexes is the indvar of the
// reduction loop.
lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
b_->getInt64Ty());
rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
b_->getInt64Ty());
// For computing the sum of products we alloca a single location to store the
// dot product result as we accumulate it within the reduction loop. After the
// reduction loop we load the result and store into the output array.
// Function entry basic block.
// - Emit alloca for accumulator
llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
SetToFirstInsertPoint(&func->getEntryBlock(), b_);
llvm::Type* accum_type = target_array_.GetElementLlvmType();
llvm::Value* accum_address =
b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");
// Preheader basic block of reduction loop:
// - Initialize accumulator to zero.
llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
b_->SetInsertPoint(preheader_bb->getTerminator());
b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);
// Body basic block of reduction loop:
// - Load elements from lhs and rhs array.
// - Multiply lhs-element and rhs-element.
// - Load accumulator and add to product.
// - Store sum back into accumulator.
SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);
llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);
llvm::Value* accum = b_->CreateLoad(accum_address);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
llvm::Value* product_real =
b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
llvm::Value* product_imag =
b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
b_->CreateFMul(imag(lhs_element), real(rhs_element)));
updated_accum = b_->CreateInsertValue(
accum, b_->CreateFAdd(real(accum), product_real), {0});
updated_accum = b_->CreateInsertValue(
updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
} else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
updated_accum = b_->CreateAdd(accum, product);
} else {
llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
updated_accum = b_->CreateFAdd(accum, product);
}
b_->CreateStore(updated_accum, accum_address);
// Exit basic block of reduction loop.
// - Load accumulator value (the result).
// - Store into output array.
SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);
llvm::Value* result = b_->CreateLoad(accum_address);
// Create index into target address. The target index is the concatenation of
// the lhs and rhs indexes with the reduction dimensions removed. The terms
// from the lhs index are the lower-numbered dimensions in the target index,
// so we add them first.
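// For example, with canonical contracting dimensions the lhs index is (i, r)
// and the rhs index is (r, j), where r is the reduction loop's induction
// variable; the resulting target index is (i, j).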
std::vector<llvm::Value*> target_multi_index;
for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
if (dimension != lhs_reduction_dimension) {
target_multi_index.push_back(lhs_index[dimension]);
}
}
for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
if (dimension != rhs_reduction_dimension) {
target_multi_index.push_back(rhs_index[dimension]);
}
}
llvm_ir::IrArray::Index target_index(
target_multi_index, target_array_.GetShape(), lhs_index.GetType());
target_array_.EmitWriteArrayElement(target_index, result, b_);
// Set the IR builder insert point to the exit basic block of the outer most
// loop.
b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
}
Status DotOpEmitter::EmitScalarDot() {
// A scalar dot is just a scalar multiply.
llvm::Value* result;
// Use the same index_type for all tensor accesses in the same kernel.
llvm::Type* index_type = b_->getInt64Ty();
llvm_ir::IrArray::Index element_index(index_type);
llvm::Value* lhs_value =
lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
llvm::Value* rhs_value =
rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
auto get_real = [&](llvm::Value* x) {
return b_->CreateExtractValue(x, {0});
};
auto get_imag = [&](llvm::Value* x) {
return b_->CreateExtractValue(x, {1});
};
llvm::Value* real = b_->CreateFSub(
b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
llvm::Value* imag = b_->CreateFAdd(
b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
result = b_->CreateInsertValue(result, real, {0});
result = b_->CreateInsertValue(result, imag, {1});
} else {
result = b_->CreateFMul(lhs_value, rhs_value);
}
target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
return Status::OK();
}
Status DotOpEmitter::EmitCallToRuntime() {
// The signature of the Eigen runtime matmul function is:
//
// (void)(void* run_options, float* out, float* lhs, float* rhs,
// int64 m, int64 n, int64 k, int32 transpose_lhs,
// int32 transpose_rhs);
// The two transpose_... parameters are actually booleans, but we use int32
// to avoid target-dependent calling convention details.
bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
PrimitiveType type = target_array_.GetShape().element_type();
llvm::Type* float_type;
const char* fn_name;
switch (type) {
case F16:
fn_name = multi_threaded
? runtime::kEigenMatMulF16SymbolName
: runtime::kEigenSingleThreadedMatMulF16SymbolName;
float_type = b_->getHalfTy();
break;
case F32:
fn_name = multi_threaded
? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
: runtime::kEigenMatMulF32SymbolName)
: (use_mkl_dnn
? runtime::kMKLSingleThreadedMatMulF32SymbolName
: runtime::kEigenSingleThreadedMatMulF32SymbolName);
float_type = b_->getFloatTy();
break;
case F64:
fn_name = multi_threaded
? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
: runtime::kEigenMatMulF64SymbolName)
: (use_mkl_dnn
? runtime::kMKLSingleThreadedMatMulF64SymbolName
: runtime::kEigenSingleThreadedMatMulF64SymbolName);
float_type = b_->getDoubleTy();
break;
case S32:
fn_name = multi_threaded
? runtime::kEigenMatMulS32SymbolName
: runtime::kEigenSingleThreadedMatMulS32SymbolName;
float_type = b_->getInt32Ty();
break;
default:
return Unimplemented("Invalid type %s for dot operation",
PrimitiveType_Name(type));
}
llvm::Type* float_ptr_type = float_type->getPointerTo();
llvm::Type* int64_type = b_->getInt64Ty();
llvm::Type* int32_type = b_->getInt32Ty();
llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
llvm::FunctionType* matmul_type = llvm::FunctionType::get(
b_->getVoidTy(),
{int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
int64_type, int64_type, int64_type, int32_type, int32_type},
/*isVarArg=*/false);
llvm::Function* function = b_->GetInsertBlock()->getParent();
llvm::Module* module = function->getParent();
llvm::FunctionCallee matmul_func =
module->getOrInsertFunction(fn_name, matmul_type);
if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
fn->setCallingConv(llvm::CallingConv::C);
fn->setDoesNotThrow();
fn->setOnlyAccessesArgMemory();
}
// The Eigen runtime function expects column-major layout. If the matrices are
// row major, then use the following identity to compute the product:
//
// (A x B)^T = B^T x A^T
//
// The connection between this identity and memory layout is that the
// transpose operation can also be considered as an operation that changes the
// memory layout of a matrix from row-major to column-major or vice versa.
//
// Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
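// For example, for row-major operands, C = A x B is computed by asking the
// runtime for the column-major product Tr(C) = Tr(B) x Tr(A); the bytes of
// the column-major Tr(C) are exactly the bytes of the row-major C.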
MatMultDims mat_mult_dims = GetMatMultDims();
CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
const llvm_ir::IrArray* lhs = &lhs_array_;
const llvm_ir::IrArray* rhs = &rhs_array_;
bool transpose_lhs = !mat_mult_dims.lhs_canonical;
bool transpose_rhs = !mat_mult_dims.rhs_canonical;
if (!mat_mult_dims.lhs_column_major) {
std::swap(mat_mult_dims.m, mat_mult_dims.n);
std::swap(lhs, rhs);
std::swap(transpose_lhs, transpose_rhs);
}
b_->CreateCall(
matmul_func,
{b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
b_->getInt32(transpose_rhs)});
return Status::OK();
}
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);
const Shape& lhs_shape = lhs_array_.GetShape();
const Shape& rhs_shape = rhs_array_.GetShape();
const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
auto is_column_major = [](const Shape& shape) {
return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
};
// Non-contracting dots should never make it here.
CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);
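// Since the operands are at most rank 2, the non-contracting dimension of an
// operand is (1 - contracting_dimension); operands of rank <= 1 contribute a
// dimension of size 1 to the product below.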
return {
/*m=*/lhs_shape.rank() <= 1
? 1LL
: lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
/*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
/*n=*/rhs_shape.rank() <= 1
? 1LL
: rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
/*lhs_column_major=*/is_column_major(lhs_shape),
/*lhs_canonical=*/lhs_shape.rank() <= 1 ||
dim_nums.lhs_contracting_dimensions(0) == 1,
/*rhs_column_major=*/is_column_major(rhs_shape),
/*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}
// Returns the index of the dot operand (if any) that it is profitable to make
// column major. For vector-matrix dot products, it is always profitable to
// make the RHS column major.
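// For example, for an f32[128] · f32[128, 1024] dot (a vector-matrix
// product) this returns 1, i.e. operand 1, the RHS.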
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
if (hlo.operand(0)->shape().rank() != 1 ||
hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
return {};
}
// Don't bother if the other operand is tiny; switching to column major
// wouldn't use tiling.
constexpr int kColumnMajorThresholdInBytes = 32;
int64 lhs_size =
ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
ShapeUtil::ElementsIn(hlo.operand(0)->shape());
if (lhs_size < kColumnMajorThresholdInBytes) {
return {};
}
return 1;
}
if (hlo.IsOutputFusion()) {
auto* fusion_root =
hlo.fused_instructions_computation()->root_instruction();
if (fusion_root->opcode() != HloOpcode::kAdd) {
return {};
}
for (auto* fusion_root_op : fusion_root->operands()) {
if (fusion_root_op->opcode() != HloOpcode::kDot) {
continue;
}
if (auto operand_num =
ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
auto* operand = fusion_root_op->operand(*operand_num);
if (operand->opcode() == HloOpcode::kParameter &&
operand->user_count() == 1) {
return operand->parameter_number();
}
}
}
}
return {};
}
namespace {
// Return whether the given shape is rank 2.
bool IsRank2(const Shape& shape) { return shape.rank() == 2; }
bool IsSimpleLayout(const Layout& layout) {
return layout.tiles().empty() && layout.format() == DENSE;
}
// Returns whether the given shapes are valid for a GEMM operation
// output = lhs * rhs.
bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape,
const TargetMachineFeatures& target_machine_features) {
CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
<< lhs_shape.DebugString();
CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
<< rhs_shape.DebugString();
CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
<< output_shape.DebugString();
switch (output_shape.element_type()) {
case F64:
case F32:
case F16:
case S32:
return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
default:
return false;
}
}
bool IsAlignedGemm(const DotInfo& dot_info,
const TargetMachineFeatures& target_machine_features) {
if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
return false;
}
return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
dot_info.result_shape, target_machine_features);
}
bool CanEmitTiledLlvmIrGemm(
const HloModuleConfig& config, const DotInfo& dot_info,
const TargetMachineFeatures& target_machine_features) {
CHECK(IsAlignedGemm(dot_info, target_machine_features));
if (ShouldUseMultiThreadedEigen(config)) {
return false;
}
int m = dot_info.result_shape.dimensions(0);
int k = dot_info.lhs_shape.dimensions(
dot_info.dim_nums.lhs_contracting_dimensions(0));
int n = dot_info.result_shape.dimensions(1);
if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
// TODO(sanjoy): We should make these numbers micro-arch specific.
bool small_gemm =
k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
if (!small_gemm) {
return false;
}
}
bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;
if (!(lhs_canonical && rhs_canonical)) {
return false;
}
if (dot_info.result_shape.element_type() == F16) {
// TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
// adding this comment NFC.
return false;
}
return true;
}
DotImplementationStrategy GetDotImplementationStrategy(
const HloModuleConfig& config, const DotInfo& dot_info,
const TargetMachineFeatures& target_machine_features) {
PrimitiveType element_type = dot_info.result_shape.element_type();
// Any Matrix-Vector product of floating point or integral type, or
// a transpose-dot fusion of the same, can be lowered to a tiled LLVM
// IR implementation.
if ((dot_info.result_shape.dimensions_size() <= 1 ||
(dot_info.result_shape.dimensions_size() == 2 &&
(dot_info.result_shape.dimensions(0) == 1 ||
dot_info.result_shape.dimensions(1) == 1))) &&
(primitive_util::IsFloatingPointType(element_type) ||
primitive_util::IsIntegralType(element_type))) {
return DotImplementationStrategy::kTiledLlvmIrGemv;
}
if (IsAlignedGemm(dot_info, target_machine_features)) {
if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
return options::UseLinalgForDot(config)
? DotImplementationStrategy::kLinalgMatmul
: DotImplementationStrategy::kTiledLlvmIrGemm;
}
return DotImplementationStrategy::kEigen;
}
return DotImplementationStrategy::kNaiveLlvmIr;
}
Status EmitNonBatchDotOperation(
DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features) {
PrimitiveType type = target_array.GetShape().element_type();
TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type ||
C64 == type || C128 == type);
DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
target_array, lhs_array, rhs_array, addend_array,
executable_run_options_value, b, mlir_context,
hlo_module_config, target_machine_features);
return dot_emitter.Emit();
}
Shape DropFirstDim(const Shape& shape) {
absl::Span<int64 const> array_shape_dims(shape.dimensions());
array_shape_dims.remove_prefix(1);
return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
array_shape_dims);
}
Shape CollapseFirstNDims(const Shape& shape, int64 n) {
absl::Span<int64 const> input_shape_dims(shape.dimensions());
int64 prefix_dim =
std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
1ll, std::multiplies<int64>());
DimensionVector result_dims;
result_dims.push_back(prefix_dim);
std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
std::back_inserter(result_dims));
return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
result_dims);
}
llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
const llvm_ir::IrArray& array, int64 n) {
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
const Shape& shape = array.GetShape();
CHECK(shape.has_layout() &&
LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
CHECK_GE(shape.dimensions_size(), n);
Shape new_shape = CollapseFirstNDims(shape, n);
llvm::Value* new_value = b->CreateBitCast(
array.GetBasePointer(),
llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
return llvm_ir::IrArray(new_value, std::move(new_shape));
}
Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
// Checks some invariants that do not hold in general but that DotDecomposer
// should have established for us. This is just a debugging aid.
TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
absl::c_iota(batch_dim_numbers, 0);
TF_RET_CHECK(
absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
TF_RET_CHECK(
absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
return Status::OK();
}
// Slice out the inner array at batch index `batch_index` from `outer_array`.
llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
llvm::Value* batch_index,
llvm::IRBuilder<>* b) {
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
Shape inner_shape = DropFirstDim(outer_array.GetShape());
std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
b->getInt64(0));
multidim_index[0] = batch_index;
llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
batch_index->getType());
llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
llvm::Type* slice_ptr_type =
llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
std::move(inner_shape));
}
Status EmitBatchDotOperation(
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features) {
TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
// Lower a batch dot into a sequence of non-batch dot operations.
int64 num_batch_dims =
dot.dot_dimension_numbers().lhs_batch_dimensions_size();
// First reshape the inputs to make sure we only have one batch dimension.
// This is a no-op bitcast because the operands have to be in row-major layout
// (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
// dimensions (established by DotDecomposer and checked by
// ValidateDotDimensionNumbers above).
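// For example, an f32[2,3,4,5] x f32[2,3,5,6] batch dot (two batch
// dimensions) is reshaped to f32[6,4,5] x f32[6,5,6] and then lowered to a
// loop over six f32[4,5] x f32[5,6] dots, with the contracting dimension
// numbers reduced by the number of batch dimensions.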
llvm_ir::IrArray lhs_array_reshaped =
CollapseFirstNDims(b, lhs_array, num_batch_dims);
llvm_ir::IrArray rhs_array_reshaped =
CollapseFirstNDims(b, rhs_array, num_batch_dims);
llvm_ir::IrArray target_array_reshaped =
CollapseFirstNDims(b, target_array, num_batch_dims);
int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0);
KernelSupportLibrary ksl(b);
return ksl.ForWithStatus(
llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
/*step=*/1, [&](llvm::Value* indvar) {
DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
adjusted_dim_numbers.clear_lhs_batch_dimensions();
adjusted_dim_numbers.clear_rhs_batch_dimensions();
// Create a DotInfo representing the "inner" non-batch dot operation.
DotInfo dot_info;
dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
dot_info.dim_nums = dot.dot_dimension_numbers();
dot_info.dim_nums.clear_lhs_batch_dimensions();
dot_info.dim_nums.clear_rhs_batch_dimensions();
dot_info.dim_nums.set_lhs_contracting_dimensions(
0,
dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
dot_info.dim_nums.set_rhs_contracting_dimensions(
0,
dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);
llvm_ir::IrArray lhs_slice =
SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
llvm_ir::IrArray rhs_slice =
SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
llvm_ir::IrArray target_slice = SliceOutInnerArray(
target_array_reshaped, /*batch_index=*/indvar, b);
// Emit the inner non-batch dot operation.
return EmitNonBatchDotOperation(
dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
executable_run_options_value, b, mlir_context, hlo_module_config,
target_machine_features);
});
}
bool IsBatchDot(const HloInstruction& instr) {
if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
}
return false;
}
} // namespace
bool DotImplementationCanHandleTranspose(
const HloInstruction& dot_instr,
const TargetMachineFeatures& target_machine_features) {
DotImplementationStrategy impl_strategy =
GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
DotInfo(dot_instr), target_machine_features);
return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
impl_strategy == DotImplementationStrategy::kEigen;
}
bool DotOperandsAndResultMustHaveRowMajorLayout(
const HloInstruction& dot_instr,
const TargetMachineFeatures& target_machine_features) {
// Batched dots require the batch dimensions to be major. DotDecomposer always
// moves batch dimensions to the front of the shape, so force a row-major
// layout.
if (IsBatchDot(dot_instr)) {
return true;
}
DotImplementationStrategy impl_strategy =
GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
DotInfo(dot_instr), target_machine_features);
return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
impl_strategy == DotImplementationStrategy::kEigen;
}
Status EmitDotOperation(const HloInstruction& dot,
const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
llvm::Value* executable_run_options_value,
llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features) {
// This routine assumes that the dot operation is not in a parallelized
// enclosing computation.
CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());
if (IsBatchDot(dot)) {
TF_RET_CHECK(addend_array == nullptr);
return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
executable_run_options_value, b, mlir_context,
hlo_module_config, target_machine_features);
}
return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
lhs_array, rhs_array, addend_array,
executable_run_options_value, b, mlir_context,
hlo_module_config, target_machine_features);
}
} // namespace cpu
} // namespace xla