blob: d3e00d04dfd2c73bfb5af1219f72b05712ee6bef [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
namespace {
// Returns the next value of a process-wide 64-bit Mersenne Twister seeded with
// the constant 42, so the sequence is identical from run to run. Calls are
// serialized through a mutex. The mutex is heap-allocated on first use and is
// never freed.
int64 GlobalRandomValue() {
  static tensorflow::mutex* const global_rng_mutex = new tensorflow::mutex();
  static std::mt19937_64 global_rng(42);
  tensorflow::mutex_lock guard(*global_rng_mutex);
  const int64 next_value = global_rng();
  return next_value;
}
// Emits IR that rounds the floating-point value `x` (of primitive type
// `src_ty`) to the precision of a float format with `dest_exponent_bits`
// exponent bits and `dest_mantissa_bits` explicit mantissa bits. The result
// keeps `x`'s original IR type.
//
//  * Mantissa (only when dest_mantissa_bits < the source's): round to nearest,
//    ties to even, then clear the truncated bits.
//  * Exponent (only when dest_exponent_bits < the source's): values whose
//    exponent is above the reduced range become signed infinity. Values at or
//    below the reduced minimum (including all denormals) become signed zero.
//  * NaN inputs are passed through unchanged. The exception is
//    dest_mantissa_bits == 0, where they become +infinity. This NaN fix-up is
//    not emitted when the builder's fast-math flags assume no NaNs.
//
// Returns Unimplemented if `src_ty` is not a floating-point type.
StatusOr<llvm::Value*> EmitReducePrecisionIR(PrimitiveType src_ty,
                                             llvm::Value* x,
                                             int64 dest_exponent_bits,
                                             int64 dest_mantissa_bits,
                                             llvm::IRBuilder<>* b) {
  using llvm::APInt;
  if (!primitive_util::IsFloatingPointType(src_ty)) {
    return Unimplemented(
        "ReducePrecision cannot accept non-floating-point type %s.",
        PrimitiveType_Name(src_ty));
  }
  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  int64 nbits = float_type->getPrimitiveSizeInBits();
  llvm::IntegerType* int_type = b->getIntNTy(nbits);
  // SignificandWidth includes the implicit extra bit.
  int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
  // Whatever is not sign or mantissa is exponent.
  int src_exponent_bits = nbits - 1 - src_mantissa_bits;
  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
  if (dest_mantissa_bits < src_mantissa_bits) {
    // Last remaining mantissa bit.
    APInt last_mantissa_bit_mask(nbits, 1);
    last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;
    // Compute rounding bias for round-to-nearest with ties to even. This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (src_mantissa_bits - dest_mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));
    // Add rounding bias, and mask out truncated bits. Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
    x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
    x_as_int = b->CreateAnd(x_as_int,
                            llvm::ConstantInt::get(int_type, truncation_mask));
  }
  if (dest_exponent_bits < src_exponent_bits) {
    APInt sign_bit_mask(nbits, 1);
    sign_bit_mask <<= nbits - 1;
    // All exponent bits set, in their in-register position.
    APInt exp_bits_mask(nbits, 1);
    exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
                    << src_mantissa_bits;
    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, for an f32 source, the exponent corresponding to the highest
    // non-infinite exponent for a bit size of n is (2^7-1) + (2^(n-1)-1), and
    // the exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - (2^(n-1)-1). The code below uses the source type's own
    // exponent bias in place of 2^7-1.
    //
    // NOTE(review): the shift below requires dest_exponent_bits >= 1. This
    // function does not check it; presumably the caller/HLO verifier does --
    // confirm.
    APInt exponent_bias(nbits, 1);
    exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;
    APInt reduced_exponent_bias(nbits, 1);
    reduced_exponent_bias =
        (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;
    APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
    APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;
    // Do we overflow or underflow? (Unsigned compares on the masked exponent
    // field; the sign bit has been masked away.)
    llvm::Value* x_exponent =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_max_exponent << src_mantissa_bits));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_min_exponent << src_mantissa_bits));
    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));
    // Force to zero or infinity if overflow or underflow. (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }
  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);
  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory). This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.
  //
  // If the fast-math flags are set to assume no NaNs, the comparison is likely
  // to be optimized away, so there's no point in even emitting it.
  if (!b->getFastMathFlags().noNaNs()) {
    llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
    if (dest_mantissa_bits > 0) {
      result = b->CreateSelect(x_is_nan, x, result);
    } else {
      result = b->CreateSelect(
          x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
    }
  }
  return result;
}
// Narrows an f32 value to bf16, which is returned as an i16 holding the bf16
// bit pattern. The value is first rounded to bf16's exponent/mantissa widths
// via EmitReducePrecisionIR. The high 16 bits of the rounded f32 bit pattern
// are then kept.
StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
                                     llvm::IRBuilder<>* b) {
  TF_ASSIGN_OR_RETURN(
      llvm::Value * rounded,
      EmitReducePrecisionIR(
          /*src_ty=*/F32, f32_value,
          /*dest_exponent_bits=*/primitive_util::kBFloat16ExponentBits,
          /*dest_mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b));
  llvm::Value* rounded_bits = b->CreateBitCast(rounded, b->getInt32Ty());
  llvm::Value* high_half = b->CreateLShr(rounded_bits, 16);
  // The truncation already produces an i16. A further bitcast to i16 would be
  // a no-op (IRBuilder returns the operand unchanged when types match).
  return b->CreateTrunc(high_half, b->getInt16Ty());
}
// Widens a bf16 value to f32. EmitF32ToBF16 keeps the high 16 bits of an f32;
// this reverses that layout by moving the 16 bf16 bits into the upper half of
// a 32-bit word, with a zero lower half.
llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
  llvm::Value* raw_bits = b->CreateBitCast(bf16_value, b->getInt16Ty());
  llvm::Value* widened_bits = b->CreateZExt(raw_bits, b->getInt32Ty());
  llvm::Value* f32_bits = b->CreateShl(widened_bits, 16);
  return b->CreateBitCast(f32_bits, b->getFloatTy());
}
// Converts `integer_value` (an element of integral type `from_type`, or PRED)
// to the IR type of floating-point `to_type`. Signed sources use a signed
// conversion (sitofp). Unsigned sources and PRED use an unsigned one (uitofp).
llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* b) {
  const bool is_signed = primitive_util::IsSignedIntegralType(from_type);
  if (!is_signed) {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
  }
  llvm::Type* float_ir_type = llvm_ir::PrimitiveTypeToIrType(to_type, module);
  return is_signed ? b->CreateSIToFP(integer_value, float_ir_type)
                   : b->CreateUIToFP(integer_value, float_ir_type);
}
} // namespace
// Dispatches an elementwise unary HLO to the integer, complex, or
// floating-point emitter, based on the element type of its operand. PRED
// operands take the integer path.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  const Shape& operand_shape = op->operand(0)->shape();
  const bool integer_like = ShapeUtil::ElementIsIntegral(operand_shape) ||
                            operand_shape.element_type() == PRED;
  if (integer_like) {
    return EmitIntegerUnaryOp(op, operand_value);
  }
  if (ShapeUtil::ElementIsComplex(operand_shape)) {
    return EmitComplexUnaryOp(op, operand_value);
  }
  return EmitFloatUnaryOp(op, operand_value);
}
// Emits IR for an elementwise unary HLO whose operand element type is integral
// or PRED. `operand_value` is the already-emitted operand element.
//
// Returns Unimplemented for opcodes or conversions not handled here. Returns
// InvalidArgument for a bitcast-convert between types of different widths.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
          << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (to_type == PRED) {
        // PRED is true iff the integer is nonzero. The i1 compare result is
        // zero-extended to PRED's IR type.
        return b_->CreateZExt(
            ICmpNE(operand_value,
                   llvm::ConstantInt::get(operand_value->getType(), 0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsIntegralType(to_type)) {
        // The signedness flag is taken from the *source* type, so widening
        // sign-extends signed inputs and zero-extends unsigned/PRED inputs.
        return IntCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_),
                       primitive_util::IsSignedIntegralType(from_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        // BF16 is produced by converting to F32 and then narrowing.
        if (to_type == BF16) {
          return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
                                                      F32, module_, b_),
                               b_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, b_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        // The integer becomes the real part. nullptr is passed for the
        // imaginary part.
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op, SIToFP(operand_value, to_ir_component_type), nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op, UIToFP(operand_value, to_ir_component_type), nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      // A bitcast only reinterprets bits, so the widths must match.
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        // x >= 0 ? x : -x. The minimum value of the type negates to itself
        // (two's-complement wraparound).
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto cmp = ICmpSGE(operand_value, GetZero(type));
        return Select(cmp, operand_value, Neg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kClz: {
      // is_zero_undef = false: ctlz(0) is defined and equals the bit width.
      auto is_zero_undef = b_->getFalse();
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
                                          {operand_value, is_zero_undef},
                                          {operand_value->getType()}, b_);
    }
    case HloOpcode::kSign: {
      CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
          << op->shape().element_type();
      // x == 0 ? 0 : ((x >> (bits-1)) | 1). The arithmetic shift yields 0 for
      // positive x and -1 for negative x; OR-ing with 1 maps those to 1 / -1.
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
      auto cmp = ICmpEQ(operand_value, GetZero(type));
      auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
      return Select(cmp, GetZero(type), Or(ashr, 1));
    }
    case HloOpcode::kNegate:
      return Neg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
                              llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return Not(operand_value);
      }
      // Report the type by name, consistent with the other errors here.
      return Unimplemented("unary op Not is not defined for type '%s'",
                           PrimitiveType_Name(type));
    }
    case HloOpcode::kPopulationCount: {
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
// Emits IR for an elementwise unary HLO whose operand element type is a real
// floating-point type. EmitUnaryOp routes integral/PRED and complex operands
// to other emitters. `operand_value` is the already-emitted operand element.
//
// Returns Unimplemented for opcodes or conversions not handled here. Returns
// InvalidArgument for a bitcast-convert between types of different widths.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      // A BF16 source is first widened to F32 by bit manipulation. From here
      // on it is handled as an F32 value.
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, b_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      // Real -> complex: the value (cast to the complex component type if
      // needed) becomes the real part. nullptr is passed for the imaginary
      // part.
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            FPCast(operand_value,
                   llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (to_type == BF16) {
        // Cast to F32 first. Other floating point formats are not supported by
        // EmitReducePrecisionIR.
        if (from_type != F32) {
          operand_value = b_->CreateFPCast(
              operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
        }
        return EmitF32ToBF16(operand_value, b_);
      }
      // Unordered not-equal: any nonzero value, and also NaN, converts to
      // true. The i1 result is zero-extended to PRED's IR type.
      if (to_type == PRED) {
        return b_->CreateZExt(
            FCmpUNE(operand_value,
                    llvm::ConstantFP::get(operand_value->getType(), 0.0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        return FPCast(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      // NOTE(review): LLVM's fptosi/fptoui yield poison when the value does
      // not fit in the destination type (this includes NaN/inf). No clamping
      // is done here.
      if (primitive_util::IsSignedIntegralType(to_type)) {
        return FPToSI(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        return FPToUI(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      // A bitcast only reinterprets bits, so the widths must match.
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    // Transcendental functions are delegated to the Emit* hooks, which take
    // the element type.
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value);
    case HloOpcode::kExpm1:
      return EmitExpm1(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kLog1p:
      return EmitLog1p(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kSqrt:
      return EmitSqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kRsqrt:
      return EmitRsqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kCbrt:
      return EmitCbrt(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    // llvm.round rounds halfway cases away from zero ("Afz").
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kSign: {
      // copysign(x != 0 ? 1.0 : 0.0, x): nonzero values give +/-1, and +/-0
      // keeps its sign. The ordered compare is false for NaN, so NaN is then
      // restored explicitly by the final select.
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto ne0_i1 = FCmpONE(operand_value, zero);
      auto ne0_float = UIToFP(ne0_i1, type);
      llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {ne0_float, operand_value},
          {operand_value->getType()}, b_);
      auto is_nan = FCmpUNO(operand_value, operand_value);
      result = Select(is_nan, operand_value, result);
      return result;
    }
    case HloOpcode::kIsFinite: {
      // abs(x) o!= inf, this works because the comparison returns false if
      // either operand is NaN.
      auto type = operand_value->getType();
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = FCmpONE(abs_value, infinity);
      return b_->CreateZExt(not_infinite,
                            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return FNeg(operand_value);
    // For a real operand, Real is the identity and Imag is a constant 0.
    case HloOpcode::kReal:
      return operand_value;
    case HloOpcode::kImag:
      return llvm::ConstantFP::get(operand_value->getType(), 0.0);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
// Emits IR for an elementwise unary HLO whose operand is complex.
// `operand_value` is the already-emitted complex operand element. Its real and
// imaginary parts are read with EmitExtractReal/EmitExtractImag, and results
// are built with EmitComposeComplex. `component_type` is the real type of each
// component. It is used to select the scalar math routines (EmitExp, EmitCos,
// ...).
//
// Returns Unimplemented for opcodes not handled here.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ? primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a));
      TF_ASSIGN_OR_RETURN(llvm::Value * abs,
                          EmitComplexAbs(component_type, operand_value));
      TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
      return EmitComposeComplex(op, log_abs, angle);
    }
    case HloOpcode::kLog1p: {
      // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
      auto a_plus_one = FAdd(a, one);
      auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      // Complex -> complex: cast each component independently.
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
          FPCast(EmitExtractImag(operand_value), to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
    }
    case HloOpcode::kExpm1: {
      // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
      TF_ASSIGN_OR_RETURN(
          auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
      auto real_result = FSub(FMul(exp_a, cos_b), one);
      auto imag_result = FMul(exp_a, sin_b);
      return EmitComposeComplex(op, real_result, imag_result);
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
                                FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
                                FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b)=cos(-b), sin(-b)=-sin(b)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(e^(2a)-e^(-2a) +
               i*[cos(b)sin(b)(e^(2a)+2+e^(-2a))-cos(b)sin(b)(e^(2a)-2+e^(-2a))])
              / (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(-2a)))
             =(e^(2a)-e^(-2a) +
               i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
              ([cos(b)^2 + sin(b)^2][e^(2a)+e^(-2a)] + 2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
              (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
              (e^(2a)+e^(-2a)+2*[cos(2b)])
             =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
      */
      llvm::Value* a = EmitExtractReal(operand_value);
      llvm::Value* b = EmitExtractImag(operand_value);
      llvm::Type* type = a->getType();
      llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
      llvm::Value* two_a = FAdd(a, a);
      llvm::Value* neg_2a = FMul(neg_one, two_a);
      // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
      // values of `a`, we will get a ULP of 2^-23 using the exp function. Using
      // expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1] allows our
      // ULP to be arbitrarily small. For larger values of `a`, calculating the
      // numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a) return virtually
      // identical results.
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
                          EmitExpm1(component_type, two_a));
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
                          EmitExpm1(component_type, neg_2a));
      llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);
      // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
      // = 2cos(b)^2. This gives us the ability to be more precise when the
      // denominator is close to zero.
      TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
      llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
      llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
      llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);
      // Similarly we can compute sin(2b) with the formula sin(2b) =
      // 2*sin(b)*cos(b).
      TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
      llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));
      // exp_sum_m2 = e^(2a) + e^(-2a) - 2 = [e^(2a)-1] + [e^(-2a)-1]. Its
      // Taylor series is (2a)^2 + (2a)^4/12 + ..., so for small `a` it equals
      // 4a^2 to a relative error of about a^2/3. Summing the two expm1 terms
      // instead cancels catastrophically, since each is about +/-2a.
      //
      // The approximation must be (2a)^2 = 4a^2, not a^2. With a^2 the
      // denominator is up to 4x too small whenever cos(b)^2 is also tiny
      // (e.g. b near pi/2, where tanh(a + i*pi/2) = coth(a) ~ 1/a).
      llvm::Value* a_sqr = FMul(a, a);
      llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
      llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);
      llvm::Value* four_a_sqr = FMul(four, a_sqr);
      llvm::Value* exp_sum_m2 =
          Select(use_approx, four_a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
      llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);
      // As `a` grows toward +inf and -inf, the real numerator will grow towards
      // +inf and -inf respectively, while the denominator will always grow
      // towards +inf. The result is real_numerator/denom = NaN, when it should
      // equal +1 and -1 respectively. Therefore, if our denominator is +inf,
      // we just hardcode the limits for the real numbers.
      llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
      llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
      llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);
      llvm::Value* real =
          Select(is_inf, real_limit, FDiv(real_numerator, denom));
      llvm::Value* imag = FDiv(imag_numerator, denom);
      // The complex tanh functions have a few corner cases:
      // 1. (x, +/-0) => (tanh(x), +/-0) - See below
      // 2. (x, +Inf) => (NaN, NaN) - See below
      // 3. (x, NaN) => (NaN, NaN) - See below
      // 4. (+inf, y) => (1, +0) - Handled normally
      // 5. (+Inf, +Inf) => (1, +/-0) - See below
      // 6. (+Inf, NaN) => (1, +/-0) - See below
      // 7. (NaN, +/-0) => (NaN, +/-0) - See below
      // 8. (NaN, y) => (NaN, NaN) - Handled normally
      // 9. (NaN, NaN) => (NaN, NaN) - Handled normally
      //
      // For the cases that aren't handled normally:
      // 1/7) When b is +/-0, the imaginary result is b itself, which keeps
      //      the sign of zero (tanh(conj(z)) == conj(tanh(z))). It also
      //      covers a = NaN, where the denominator is NaN and the normal
      //      calculation would give an imaginary NaN.
      // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
      //      then we return (+/-1, +/-0). However, this is only true if we
      //      assume that a is infinity or b is finite. In the event that both a
      //      is finite and b is either +/-Inf or NaN, then our normal
      //      calculation would end up returning (+/-1, NaN), as opposed to
      //      (NaN, NaN).
      // 5/6) We always calculate the imaginary value as sin(2b)/denominator.
      //      When the denominator is infinity, this assures us that the zero is
      //      the correct sign. However if our imaginary input results in
      //      sin(2b) = NaN, we calculate our imaginary result as NaN.
      if (!(b_->getFastMathFlags().noNaNs() &&
            b_->getFastMathFlags().noInfs())) {
        llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                          {a}, {type}, b_);
        llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
        llvm::Value* nan = llvm::ConstantFP::getNaN(type);
        llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
        llvm::Value* b_is_zero = FCmpOEQ(b, zero);
        // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
        // imag_numerator is NaN.
        llvm::Value* sin_2b_is_nan =
            b_->CreateFCmpUNO(imag_numerator, imag_numerator);
        llvm::Value* real_is_nan =
            b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
        llvm::Value* inf_a_with_nan_sin =
            b_->CreateAnd(a_is_inf, sin_2b_is_nan);
        real = Select(real_is_nan, nan, real);
        imag = Select(inf_a_with_nan_sin, zero, imag);
        // Return b itself (a signed zero) rather than +0 to keep the sign.
        imag = Select(b_is_zero, b, imag);
      }
      return EmitComposeComplex(op, real, imag);
    }
    case HloOpcode::kAbs: {
      return EmitComplexAbs(component_type, operand_value);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      TF_ASSIGN_OR_RETURN(auto cplx_abs,
                          EmitComplexAbs(component_type, operand_value));
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = FCmpOEQ(cplx_abs, zero);
      return Select(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
                             FDiv(EmitExtractImag(operand_value), cplx_abs)));
    }
    case HloOpcode::kSqrt: {
      return EmitComplexSqrt(op, component_type, operand_value);
    }
    case HloOpcode::kRsqrt: {
      return EmitComplexRsqrt(op, component_type, operand_value);
    }
    case HloOpcode::kCbrt: {
      return EmitComplexCbrt(op, component_type, operand_value);
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
                                FNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
// Dispatches an elementwise binary HLO by the element type of its first
// operand. Integral and PRED operands go to the integer emitter, together with
// their signedness. Complex operands go to the complex emitter. Everything
// else is treated as real floating point.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  const Shape& lhs_shape = op->operand(0)->shape();
  const PrimitiveType operand_type = lhs_shape.element_type();
  if (ShapeUtil::ElementIsIntegral(lhs_shape) || operand_type == PRED) {
    const bool is_signed = primitive_util::IsSignedIntegralType(operand_type);
    return EmitIntegerBinaryOp(op, lhs_value, rhs_value, is_signed);
  }
  if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  }
  return EmitFloatBinaryOp(op, lhs_value, rhs_value);
}
// Emits IR for an elementwise binary HLO on real floating-point operands.
// kComplex combines two reals into a complex value. kCompare produces an i1
// comparison result via llvm_ir::EmitComparison. Max/min/pow/atan2 are
// delegated to their Emit* hooks.
//
// Returns Unimplemented for unhandled opcodes or an unexpected comparison
// direction.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return FAdd(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return FSub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return FMul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return FDiv(lhs_value, rhs_value);
    case HloOpcode::kRemainder:
      return FRem(lhs_value, rhs_value);
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                         rhs_value, b_);
      }
      // Every ComparisonDirection is handled above. The switch deliberately
      // has no `default`, so -Wswitch still flags newly added directions.
      // Without this return, an unexpected value would silently fall through
      // into the kMaximum case below and emit a max instead of a compare.
      return Unimplemented("unhandled floating-point comparison direction %d",
                           static_cast<int>(op->comparison_direction()));
    }
    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value);
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value);
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
// Computing |a + bi| as sqrt(a^2 + b^2) can overflow in the intermediate
// squares even when the result itself is representable. Instead use
//   sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
//                   = |a| * sqrt(1 + (b/a)^2)
// with the assumption that |a| >= |b|.
//
// Returns the tuple (min, max, term):
//   min  = min(|real|, |imag|)
//   max  = max(|real|, |imag|)
//   term = sqrt(1 + (min/max)^2)  if `return_sqrt` is true,
//          1 + (min/max)^2        if `return_sqrt` is false.
// The final product max * term is left to the caller. This lets callers
// avoid an overflow in that product; for example, the square root of the
// absolute value can take sqrt(max) and the fourth root of the un-rooted
// term separately.
StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
                                         llvm::Value* operand_value,
                                         bool return_sqrt) {
  llvm::Value* real = EmitExtractReal(operand_value);
  llvm::Value* imag = EmitExtractImag(operand_value);
  llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
  llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
  llvm::Value* max = EmitFloatMax(abs_real, abs_imag);
  llvm::Value* min = EmitFloatMin(abs_real, abs_imag);
  llvm::Value* div = FDiv(min, max);
  llvm::Value* div_sq = FMul(div, div);
  llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
  llvm::Value* one_p_div_sq = FAdd(one, div_sq);
  if (!return_sqrt) {
    // The caller applies its own root to the term, so don't emit a sqrt
    // whose result would never be used.
    return std::make_tuple(min, max, one_p_div_sq);
  }
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
  return std::make_tuple(min, max, sqrt);
}
// |a + bi| computed as max(|a|, |b|) * sqrt(1 + (min/max)^2), using the
// overflow-avoiding decomposition from EmitComplexAbsHelper.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  std::tuple<llvm::Value*, llvm::Value*, llvm::Value*> parts;
  TF_ASSIGN_OR_RETURN(parts, EmitComplexAbsHelper(prim_type, operand_value,
                                                  /*return_sqrt=*/true));
  llvm::Value* smaller = std::get<0>(parts);
  llvm::Value* larger = std::get<1>(parts);
  llvm::Value* root_term = std::get<2>(parts);
  llvm::Value* magnitude = FMul(larger, root_term);
  // (min, max) of (0, 0), (inf, inf), or anything involving NaN makes the
  // product NaN; `min` is returned for those inputs instead.
  llvm::Value* magnitude_is_nan = FCmpUNO(magnitude, magnitude);
  return Select(magnitude_is_nan, smaller, magnitude);
}
// Square root of |a + bi| without materializing |a + bi| itself:
//   sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
// where |a| >= |b|.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* smaller;
  llvm::Value* larger;
  llvm::Value* unrooted_term;
  TF_ASSIGN_OR_RETURN(std::tie(smaller, larger, unrooted_term),
                      EmitComplexAbsHelper(prim_type, operand_value,
                                           /*return_sqrt=*/false));
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_larger, EmitSqrt(prim_type, larger));
  llvm::Value* quarter = llvm::ConstantFP::get(larger->getType(), .25);
  TF_ASSIGN_OR_RETURN(llvm::Value * fourth_root,
                      EmitPow(prim_type, unrooted_term, quarter));
  llvm::Value* product = FMul(sqrt_larger, fourth_root);
  // Degenerate (min, max) pairs -- (0, 0), (inf, inf), or NaN -- make the
  // product NaN; `min` is returned for those inputs instead.
  llvm::Value* product_is_nan = FCmpUNO(product, product);
  return Select(product_is_nan, smaller, product);
}
// Reciprocal square root of |a + bi| without materializing |a + bi|:
//   rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
// where |a| >= |b|.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* smaller;
  llvm::Value* larger;
  llvm::Value* root_term;
  TF_ASSIGN_OR_RETURN(std::tie(smaller, larger, root_term),
                      EmitComplexAbsHelper(prim_type, operand_value,
                                           /*return_sqrt=*/true));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_larger,
                      EmitRsqrt(prim_type, larger));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_term,
                      EmitRsqrt(prim_type, root_term));
  llvm::Value* product = FMul(rsqrt_larger, rsqrt_term);
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_smaller,
                      EmitRsqrt(prim_type, smaller));
  // Degenerate (min, max) pairs -- (0, 0), (inf, inf), or NaN -- make the
  // product NaN; rsqrt(min) is returned for those inputs instead.
  llvm::Value* product_is_nan = FCmpUNO(product, product);
  return Select(product_is_nan, rsqrt_smaller, product);
}
// Complex square root via the polar form. Using our EmitComplexPower formula,
// but setting c=0.5 and d=0, we get:
//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
//   = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
//   = r^0.5 * [cos(t/2) + i*sin(t/2)]
//   = sqrt(r) * [cos(t/2) + i*sin(t/2)]
// where r = |a+bi| and t = atan2(b,a).
// sqrt(r) comes from EmitSqrtComplexAbs, which avoids forming |a+bi| directly.
// Unless the builder's fast-math flags promise both no-NaNs and no-Infs,
// infinite and NaN inputs are special-cased below.
// TODO(bixia): See doc for implementation without atan2.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  // Component (real/imag) IR type of the complex struct.
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);
  // Note: despite its name, `r` holds sqrt(|a+bi|), not |a+bi|.
  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitSqrtComplexAbs(prim_type, operand_value));
  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
  // Half angle t/2.
  llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
  llvm::Value* real_part;
  llvm::Value* imag_part;
  llvm::Value* zero = llvm::ConstantFP::get(type, 0);
  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);
    // Real part: +inf if |b| is inf or a is +inf; 0 if a is -inf and b is
    // not infinite; otherwise the generic sqrt(r-magnitude) * cos(t/2).
    real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
                       Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
                              zero, FMul(r, cos)));
    llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
    // Imaginary part: an infinity carrying b's sign if |b| is inf or a is
    // -inf; NaN if `r` is NaN; otherwise the generic r * sin(t/2). When
    // sin(t/2) is +-0 it is returned as-is, keeping the signed zero and
    // avoiding r * 0, which is NaN for an infinite r.
    imag_part =
        Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
               Select(FCmpUNO(r, r), nan,
                      Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
  } else {
    // Fast-math promises no NaNs/Infs: only the generic formula is needed.
    real_part = FMul(r, cos);
    imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
  }
  // sqrt(0 + 0i) is returned as (0, 0).
  return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
                EmitComposeComplex(op, real_part, imag_part));
}
// Complex reciprocal square root. Similar to Sqrt, we can use our
// EmitComplexPower formula, but set c=-0.5 and d=0. We get:
//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
//   = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
//   = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
//   = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
// where r = |a+bi| and t = atan2(b,a).
// Special cases:
//   * (0, 0) -> (inf, NaN), always.
//   * Unless fast-math promises both no-NaNs and no-Infs: when |a| is inf, or
//     |b| is inf and a is NaN, the result is (copysign(0, a), -copysign(0, b)).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  // Component (real/imag) IR type of the complex struct.
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);
  // Note: despite its name, `r` holds rsqrt(|a+bi|), not |a+bi|.
  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitRsqrtComplexAbs(prim_type, operand_value));
  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
  llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  // Generic result; the selects below override it for special inputs.
  llvm::Value* real_part = FMul(r, cos);
  llvm::Value* imag_part = FMul(r, sin);

  llvm::Value* zero = llvm::ConstantFP::get(type, 0);
  llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
  llvm::Value* nan = llvm::ConstantFP::getNaN(type);
  llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));

  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
    llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
    // copysign(0, b) is always +-0, so negating it flips only the sign bit.
    llvm::Value* neg_b_signed_zero = FNeg(b_signed_zero);
    llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {a}, {a->getType()}, b_);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);
    llvm::Value* gives_signed_zero =
        Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf));
    real_part = Select(gives_signed_zero, a_signed_zero, real_part);
    imag_part = Select(gives_signed_zero, neg_b_signed_zero, imag_part);
  }
  // rsqrt(0 + 0i) takes precedence over every other case.
  real_part = Select(is_zero_zero, inf, real_part);
  imag_part = Select(is_zero_zero, nan, imag_part);
  return EmitComposeComplex(op, real_part, imag_part);
}
// Complex cube root, evaluated as (a + bi)^(1/3 + 0i) through
// EmitComplexPower (c = 1.0/3.0, d = 0).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  llvm::Type* component_ir_type =
      llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  llvm::Value* exponent_real =
      llvm::ConstantFP::get(component_ir_type, 1.0 / 3.0);
  llvm::Value* exponent_imag = llvm::ConstantFP::get(component_ir_type, 0);
  llvm::Value* base_real = EmitExtractReal(operand_value);
  llvm::Value* base_imag = EmitExtractImag(operand_value);
  return EmitComplexPower(op, base_real, base_imag, exponent_real,
                          exponent_imag);
}
// Complex power (a+bi)^(c+di), using the polar form of the base:
//   (a+bi)^(c+di) =
//       (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
//   where q = c*atan2(b,a) + 0.5d*ln(a*a+b*b).
// Note that a*a+b*b is formed directly, so it can overflow or underflow for
// components of extreme magnitude.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
    const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
    llvm::Value* d) {
  PrimitiveType component_type =
      primitive_util::ComplexComponentType(op->shape().element_type());
  auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
  auto zero = llvm::ConstantFP::get(a->getType(), 0);
  auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
  auto one = llvm::ConstantFP::get(a->getType(), 1);
  auto half_c = FMul(one_half, c);
  TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
                      EmitPow(component_type, aa_p_bb, half_c));
  auto neg_d = FNeg(d);
  TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
  auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
  TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
                      EmitExp(component_type, neg_d_arg_lhs));
  auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
  TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
  auto half_d = FMul(one_half, d);
  auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
  TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
  TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
  // The general formula breaks down for a zero base. ln(0) = -inf, so with
  // d == 0 the term 0.5d*ln(a*a+b*b) is 0 * -inf = NaN.
  // For a zero base raised to a real, non-negative exponent c, return 0^c
  // directly: 0 when c > 0 and 1 when c == 0. 0^0 is defined to be 1.0; see
  // W. Kahan, "Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit", Section 10.
  auto zero_base_nonnegative_real_exponent =
      And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c));
  return Select(
      zero_base_nonnegative_real_exponent,
      EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
      EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
}
// Emits a binary elementwise op on complex operands (structs of two floating
// point components).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kAdd:
      return EmitComposeComplex(
          op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
          FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
    case HloOpcode::kSubtract:
      return EmitComposeComplex(
          op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
          FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
    case HloOpcode::kMultiply:
      return EmitComposeComplex(
          op,
          FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
               FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
          FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
               FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
    case HloOpcode::kDivide: {
      // Division of complex numbers is implemented here, taking into account
      // over/underflow, NaN and Inf values. The numerator is a = a_r + a_i*i
      // and the denominator is b = b_r + b_i*i.
      auto a_r = EmitExtractReal(lhs_value);
      auto a_i = EmitExtractImag(lhs_value);
      auto b_r = EmitExtractReal(rhs_value);
      auto b_i = EmitExtractImag(rhs_value);
      auto type = a_r->getType();
      auto emit_fabs = [&](llvm::Value* v) {
        return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {v}, {type},
                                            b_);
      };
      auto emit_copysign = [&](llvm::Value* magnitude, llvm::Value* sign) {
        return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                            {magnitude, sign}, {type}, b_);
      };
      auto abs_a_r = emit_fabs(a_r);
      auto abs_a_i = emit_fabs(a_i);
      auto abs_b_r = emit_fabs(b_r);
      auto abs_b_i = emit_fabs(b_i);
      // Smith's algorithm to divide complex numbers. It is just a bit smarter
      // way to compute the following formula:
      //  (a_r + a_i * i) / (b_r + b_i * i)
      //    = (a_r + a_i * i) (b_r - b_i * i) / ((b_r + b_i * i)(b_r - b_i * i))
      //    = ((a_r * b_r + a_i * b_i) + (a_i * b_r - a_r * b_i) * i) / ||b||^2
      //
      // Depending on whether |b_r| < |b_i| we compute either
      //   b_r_b_i_ratio = b_r / b_i
      //   b_r_b_i_denom = b_i + b_r * b_r_b_i_ratio
      //   c_r = (a_r * b_r_b_i_ratio + a_i ) / b_r_b_i_denom
      //   c_i = (a_i * b_r_b_i_ratio - a_r ) / b_r_b_i_denom
      //
      // or
      //
      //   b_i_b_r_ratio = b_i / b_r
      //   b_i_b_r_denom = b_r + b_i * b_i_b_r_ratio
      //   c_r = (a_r + a_i * b_i_b_r_ratio ) / b_i_b_r_denom
      //   c_i = (a_i - a_r * b_i_b_r_ratio ) / b_i_b_r_denom
      //
      // See https://dl.acm.org/citation.cfm?id=368661 for more details.
      auto b_r_b_i_ratio = FDiv(b_r, b_i);
      auto b_r_b_i_denom = FAdd(b_i, FMul(b_r_b_i_ratio, b_r));
      auto b_i_b_r_ratio = FDiv(b_i, b_r);
      auto b_i_b_r_denom = FAdd(b_r, FMul(b_i_b_r_ratio, b_i));
      auto b_r_lt_b_i = FCmpOLT(abs_b_r, abs_b_i);
      auto c_r = Select(
          b_r_lt_b_i, FDiv(FAdd(FMul(b_r_b_i_ratio, a_r), a_i), b_r_b_i_denom),
          FDiv(FAdd(FMul(b_i_b_r_ratio, a_i), a_r), b_i_b_r_denom));
      auto c_i = Select(
          b_r_lt_b_i, FDiv(FSub(FMul(b_r_b_i_ratio, a_i), a_r), b_r_b_i_denom),
          FDiv(FSub(a_i, FMul(b_i_b_r_ratio, a_r)), b_i_b_r_denom));
      auto result = EmitComposeComplex(op, c_r, c_i);

      // Smith's formula yields (NaN, NaN) for some combinations of zeros and
      // infinities. Recover the expected result in those cases, following the
      // C99 Annex G reference implementation of complex division.
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto one = llvm::ConstantFP::get(type, 1.0);
      auto inf = llvm::ConstantFP::getInfinity(type);
      // "is_inf" matches +inf and -inf; "finite" excludes +-inf and NaN.
      auto a_r_is_inf = FCmpOEQ(abs_a_r, inf);
      auto a_i_is_inf = FCmpOEQ(abs_a_i, inf);
      auto b_r_is_inf = FCmpOEQ(abs_b_r, inf);
      auto b_i_is_inf = FCmpOEQ(abs_b_i, inf);
      auto a_r_finite = FCmpONE(abs_a_r, inf);
      auto a_i_finite = FCmpONE(abs_a_i, inf);
      auto b_r_finite = FCmpONE(abs_b_r, inf);
      auto b_i_finite = FCmpONE(abs_b_i, inf);

      // Case 1. Zero denominator, and some numerator component is non-zero
      // and not NaN (otherwise the result stays NaN anyway). Scale the
      // numerator by an infinity carrying the sign of the denominator's real
      // part, e.g. (-1 + 0i) / (+0 + 0i) -> (-inf, NaN).
      auto zero_denominator =
          And(And(FCmpOEQ(b_r, zero), FCmpOEQ(b_i, zero)),
              Or(FCmpONE(a_r, zero), FCmpONE(a_i, zero)));
      auto inf_with_sign_of_b_r = emit_copysign(inf, b_r);
      auto zero_denominator_result =
          EmitComposeComplex(op, FMul(inf_with_sign_of_b_r, a_r),
                             FMul(inf_with_sign_of_b_r, a_i));

      // Case 2. Infinite numerator, finite denominator. Each numerator
      // component becomes +-1 if it is infinite and +-0 otherwise.
      auto inf_num_finite_denom =
          And(Or(a_r_is_inf, a_i_is_inf), And(b_r_finite, b_i_finite));
      auto a_r_inf_with_sign =
          emit_copysign(Select(a_r_is_inf, one, zero), a_r);
      auto a_i_inf_with_sign =
          emit_copysign(Select(a_i_is_inf, one, zero), a_i);
      auto inf_num_finite_denom_result =
          EmitComposeComplex(op,
                             FMul(inf, FAdd(FMul(a_r_inf_with_sign, b_r),
                                            FMul(a_i_inf_with_sign, b_i))),
                             FMul(inf, FSub(FMul(a_i_inf_with_sign, b_r),
                                            FMul(a_r_inf_with_sign, b_i))));

      // Case 3. Finite numerator, infinite denominator. Each denominator
      // component becomes +-1 if it is infinite and +-0 otherwise.
      auto finite_num_inf_denom =
          And(Or(b_r_is_inf, b_i_is_inf), And(a_r_finite, a_i_finite));
      auto b_r_inf_with_sign =
          emit_copysign(Select(b_r_is_inf, one, zero), b_r);
      auto b_i_inf_with_sign =
          emit_copysign(Select(b_i_is_inf, one, zero), b_i);
      auto finite_num_inf_denom_result =
          EmitComposeComplex(op,
                             FMul(zero, FAdd(FMul(a_r, b_r_inf_with_sign),
                                             FMul(a_i, b_i_inf_with_sign))),
                             FMul(zero, FSub(FMul(a_i, b_r_inf_with_sign),
                                             FMul(a_r, b_i_inf_with_sign))));

      // Only apply the recovery when Smith's formula produced (NaN, NaN).
      auto c_nan = And(FCmpUNO(c_r, zero), FCmpUNO(c_i, zero));
      return Select(
          c_nan,
          Select(zero_denominator, zero_denominator_result,
                 Select(inf_num_finite_denom, inf_num_finite_denom_result,
                        Select(finite_num_inf_denom,
                               finite_num_inf_denom_result, result))),
          result);
    }
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractReal(lhs_value),
                                             EmitExtractReal(rhs_value), b_),
                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractImag(lhs_value),
                                             EmitExtractImag(rhs_value), b_));
        case ComparisonDirection::kNe:
          return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractReal(lhs_value),
                                            EmitExtractReal(rhs_value), b_),
                    llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractImag(lhs_value),
                                            EmitExtractImag(rhs_value), b_));
        default:
          return Unimplemented(
              "complex comparison '%s'",
              ComparisonDirectionToString(op->comparison_direction()));
      }
    }
    case HloOpcode::kPower: {
      auto a = EmitExtractReal(lhs_value);
      auto b = EmitExtractImag(lhs_value);
      auto c = EmitExtractReal(rhs_value);
      auto d = EmitExtractImag(rhs_value);
      return EmitComplexPower(op, a, b, c, d);
    }
    default:
      return Unimplemented("binary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
// Floating-point maximum, delegated to llvm_ir::EmitFloatMax; the
// fast_min_max() setting is forwarded to control how it is lowered.
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value) {
  const bool use_fast_min_max = fast_min_max();
  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, use_fast_min_max);
}
// Floating-point minimum, delegated to llvm_ir::EmitFloatMin; the
// fast_min_max() setting is forwarded to control how it is lowered.
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value) {
  const bool use_fast_min_max = fast_min_max();
  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, use_fast_min_max);
}
// Natural logarithm via the llvm.log intrinsic, overloaded on the operand's
// IR type. `prim_type` is not consulted here.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  llvm::Type* overload_type = value->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
                                      {overload_type}, b_);
}
// Computes log(1 + x). Near zero, forming 1 + x directly discards the
// low-order bits of x, so a rational approximation is used for small |x|; the
// direct log(x + 1) is used otherwise.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto negative_half = llvm::ConstantFP::get(type, -0.5);
  // When x is large, the naive evaluation of ln(x + 1) is more
  // accurate than the Taylor series.
  TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
  // When x is small, defined as |x| < sqrt(2) - 1 (~0.4142, see
  // kAntilogarithmIsSmallThreshold), use the rational approximation
  //   log1p(x) ~= x - x^2/2 + x^3 * N(x) / D(x),
  // where N and D are the polynomials with coefficients kNumeratorCoeffs and
  // kDenominatorCoeffs. The approximation is based on one from the Cephes
  // Mathematical Library.
  //
  // sqrt(2) - 1.
  const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
  static const std::array<double, 7> kDenominatorCoeffs{
      1.,
      1.5062909083469192043167E1,
      8.3047565967967209469434E1,
      2.2176239823732856465394E2,
      3.0909872225312059774938E2,
      2.1642788614495947685003E2,
      6.0118660497603843919306E1,
  };
  static const std::array<double, 7> kNumeratorCoeffs{
      4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
      6.5787325942061044846969E0,  2.9911919328553073277375E1,
      6.0949667980987787057556E1,  5.7112963590585538103336E1,
      2.0039553499201281259648E1,
  };
  auto x_squared = FMul(x, x);
  TF_ASSIGN_OR_RETURN(auto denominator,
                      EvaluatePolynomial(type, x, kDenominatorCoeffs));
  TF_ASSIGN_OR_RETURN(auto numerator,
                      EvaluatePolynomial(type, x, kNumeratorCoeffs));
  // for_small_x = x + (-0.5 * x^2 + x^3 * N(x) / D(x)).
  auto for_small_x = FDiv(numerator, denominator);
  for_small_x = FMul(FMul(x, x_squared), for_small_x);
  for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
  for_small_x = FAdd(x, for_small_x);
  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {x}, {type}, b_);
  auto x_is_small = FCmpOLT(
      abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
  return Select(x_is_small, for_small_x, for_large_x);
}
// Square root via the llvm.sqrt intrinsic, overloaded on the operand's IR
// type. The PrimitiveType argument is unused.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
                                                    llvm::Value* value) {
  llvm::Type* overload_type = value->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
                                      {overload_type}, b_);
}
// Reciprocal square root, 1 / sqrt(x): EmitSqrt followed by a divide.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  TF_ASSIGN_OR_RETURN(llvm::Value * root, EmitSqrt(prim_type, value));
  llvm::Value* one = llvm::ConstantFP::get(root->getType(), 1.0);
  return FDiv(one, root);
}
// Sine via the llvm.sin intrinsic, overloaded on the operand's IR type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  llvm::Type* overload_type = value->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
                                      {overload_type}, b_);
}
// Cosine via the llvm.cos intrinsic, overloaded on the operand's IR type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  llvm::Type* overload_type = value->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
                                      {overload_type}, b_);
}
// e^x via the llvm.exp intrinsic, overloaded on the operand's IR type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  llvm::Type* overload_type = value->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
                                      {overload_type}, b_);
}
// Computes exp(x) - 1. For |x| < 0.009 it uses the second-order Taylor
// polynomial x + x^2/2, because exp(x) - 1 suffers from cancellation there;
// otherwise it uses exp(x) - 1 directly.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  llvm::Type* ir_type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  llvm::Value* one = llvm::ConstantFP::get(ir_type, 1.0);
  llvm::Value* half = llvm::ConstantFP::get(ir_type, 0.5);
  // Away from zero, the direct formula beats the truncated series.
  TF_ASSIGN_OR_RETURN(llvm::Value * exp_value, EmitExp(prim_type, value));
  llvm::Value* direct = FSub(exp_value, one);
  // exp(x) = 1 + x + x^2/2 + x^3/6 + ..., so exp(x) - 1 ~= x + x^2/2.
  llvm::Value* value_squared = FMul(value, value);
  llvm::Value* taylor = FAdd(value, FMul(value_squared, half));
  // Around this threshold the relative errors of the two forms are about
  // equal, roughly 2^-16.
  constexpr double kExponentIsSmallThreshold = 0.009;
  llvm::Value* magnitude = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {value}, {ir_type}, b_);
  llvm::Value* use_taylor = FCmpOLT(
      magnitude, llvm::ConstantFP::get(ir_type, kExponentIsSmallThreshold));
  return Select(use_taylor, taylor, direct);
}
// lhs^rhs via the llvm.pow intrinsic, overloaded on lhs's IR type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                   llvm::Value* lhs,
                                                   llvm::Value* rhs) {
  llvm::Type* overload_type = lhs->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
                                      {overload_type}, b_);
}
// Real cube root computed as copysign(|x|^(1/3), x). pow() is applied to |x|
// because pow of a negative base with a fractional exponent yields NaN.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  llvm::Type* ir_type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  llvm::Value* one_third = llvm::ConstantFP::get(ir_type, 1.0 / 3.0);
  llvm::Value* magnitude = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {value}, {ir_type}, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * magnitude_root,
                      EmitPow(prim_type, magnitude, one_third));
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                      {magnitude_root, value}, {ir_type}, b_);
}
// No atan2 lowering is provided at this level; this always reports
// Unimplemented. Presumably derived emitters supply one (TODO confirm).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                     llvm::Value* lhs,
                                                     llvm::Value* rhs) {
  auto status = Unimplemented("atan2");
  return status;
}
// No tanh lowering is provided at this level; this always reports
// Unimplemented. Presumably derived emitters supply one (TODO confirm).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  auto status = Unimplemented("tanh");
  return status;
}
// Rounds `x` (of operand 0's element type) to the exponent/mantissa widths
// requested by the reduce-precision HLO.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) {
  const PrimitiveType source_type = hlo->operand(0)->shape().element_type();
  const auto exponent_bits = hlo->exponent_bits();
  const auto mantissa_bits = hlo->mantissa_bits();
  return EmitReducePrecisionIR(/*src_ty=*/source_type, x,
                               /*dest_exponent_bits=*/exponent_bits,
                               /*dest_mantissa_bits=*/mantissa_bits, b_);
}
// Returns `shift_result` when the shift amount `rhs` is below the bit width
// of `lhs` (unsigned comparison). Otherwise it returns a fixed value: the
// sign fill of `lhs` (-1 if lhs < 0, else 0) when `saturate_to_sign_bit` is
// set, or 0 when it is not.
static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
                                             llvm::Value* lhs, llvm::Value* rhs,
                                             llvm::Value* shift_result,
                                             bool saturate_to_sign_bit) {
  auto* int_type = llvm::cast<llvm::IntegerType>(lhs->getType());
  llvm::ConstantInt* zero = llvm::ConstantInt::get(int_type, 0);
  llvm::Value* saturated_value = zero;
  if (saturate_to_sign_bit) {
    llvm::ConstantInt* all_ones = llvm::ConstantInt::get(int_type, -1);
    llvm::Value* lhs_is_negative = b->CreateICmpSLT(lhs, zero);
    saturated_value = b->CreateSelect(lhs_is_negative, all_ones, zero);
  }
  llvm::ConstantInt* bit_width =
      llvm::ConstantInt::get(int_type, int_type->getBitWidth());
  llvm::Value* amount_in_range = b->CreateICmpULT(rhs, bit_width, "shft.chk");
  return b->CreateSelect(amount_in_range, shift_result, saturated_value);
}
// Integer constant 1 of `type`, which must be an llvm::IntegerType.
llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(integer_type, 1);
}
// Integer constant 0 of `type`, which must be an llvm::IntegerType.
llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(integer_type, 0);
}
// Most negative signed value (INT_SMIN) of the integer type `type`.
llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  const unsigned bit_width = integer_type->getBitWidth();
  const llvm::APInt signed_min = llvm::APInt::getSignedMinValue(bit_width);
  return llvm::ConstantInt::get(integer_type, signed_min);
}
// All-ones constant (-1 when read as signed) of the integer type `type`.
llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  const unsigned bit_width = integer_type->getBitWidth();
  const llvm::APInt all_ones = llvm::APInt::getAllOnesValue(bit_width);
  return llvm::ConstantInt::get(integer_type, all_ones);
}
// i1 that is true when the integer value `v` equals zero.
llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
  llvm::Value* zero = llvm::ConstantInt::get(v->getType(), 0);
  return ICmpEQ(v, zero);
}
// i1 that is true for the overflowing signed division INT_SMIN / -1.
llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
                                                          llvm::Value* rhs) {
  llvm::Value* lhs_is_int_min = ICmpEQ(lhs, GetIntSMin(lhs->getType()));
  llvm::Value* rhs_is_minus_one = ICmpEQ(rhs, GetMinusOne(rhs->getType()));
  return And(lhs_is_int_min, rhs_is_minus_one);
}
// Integer division with defined results for inputs that are undefined for
// LLVM's udiv/sdiv:
//   X / 0           == -1 (all bits set)
//   INT_SMIN /s -1  == INT_SMIN
// For those inputs the divisor is replaced by 1 before dividing, and the
// quotient is then overridden with a select.
llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
                                                   llvm::Value* rhs,
                                                   bool is_signed) {
  llvm::Type* int_type = lhs->getType();
  llvm::Value* divisor_is_zero = IsZero(rhs);
  if (is_signed) {
    llvm::Value* is_smin_overflow = IsIntMinDivisionOverflow(lhs, rhs);
    llvm::Value* unsafe = Or(is_smin_overflow, divisor_is_zero);
    llvm::Value* safe_rhs = Select(unsafe, GetOne(int_type), rhs);
    llvm::Value* quotient = SDiv(lhs, safe_rhs);
    llvm::Value* overflow_patched =
        Select(is_smin_overflow, GetIntSMin(int_type), quotient);
    return Select(divisor_is_zero, GetMinusOne(int_type), overflow_patched);
  }
  llvm::Value* safe_rhs = Select(divisor_is_zero, GetOne(int_type), rhs);
  llvm::Value* quotient = UDiv(lhs, safe_rhs);
  return Select(divisor_is_zero, GetMinusOne(int_type), quotient);
}
// Integer remainder with defined results for inputs that are undefined for
// LLVM's urem/srem:
//   X % 0           == X
//   INT_SMIN %s -1  == 0
// For those inputs the divisor is replaced by 1 before the remainder, and
// the result is then overridden with a select.
llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      bool is_signed) {
  llvm::Type* int_type = lhs->getType();
  llvm::Value* divisor_is_zero = IsZero(rhs);
  if (is_signed) {
    llvm::Value* is_smin_overflow = IsIntMinDivisionOverflow(lhs, rhs);
    llvm::Value* unsafe = Or(is_smin_overflow, divisor_is_zero);
    llvm::Value* safe_rhs = Select(unsafe, GetOne(int_type), rhs);
    llvm::Value* remainder = SRem(lhs, safe_rhs);
    llvm::Value* overflow_patched =
        Select(is_smin_overflow, GetZero(int_type), remainder);
    return Select(divisor_is_zero, lhs, overflow_patched);
  }
  llvm::Value* safe_rhs = Select(divisor_is_zero, GetOne(int_type), rhs);
  llvm::Value* remainder = URem(lhs, safe_rhs);
  return Select(divisor_is_zero, lhs, remainder);
}
// Emits a binary elementwise op on integer (or PRED) operands. `is_signed`
// selects signed vs. unsigned division, remainder, comparison, and min/max.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
    bool is_signed) {
  switch (op->opcode()) {
    // TODO(jingyue): add the "nsw" attribute for signed types.
    case HloOpcode::kAdd:
      return Add(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return Sub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return Mul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
    case HloOpcode::kRemainder:
      return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
              lhs_value, rhs_value, b_);
      }
      // Every known direction returns above. Without this statement an
      // unexpected enum value would fall through into the kMinimum case below
      // and silently emit a min instead of a comparison.
      return Unimplemented(
          "integer comparison '%s'",
          ComparisonDirectionToString(op->comparison_direction()));
    }
    case HloOpcode::kMinimum:
      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
    case HloOpcode::kMaximum:
      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
    case HloOpcode::kAnd:
      return And(lhs_value, rhs_value);
    case HloOpcode::kOr:
      return Or(lhs_value, rhs_value);
    case HloOpcode::kXor:
      return Xor(lhs_value, rhs_value);
    // Shifting out bits >= the number of bits in the type being shifted
    // produces a poison value in LLVM which is basically "deferred undefined
    // behavior" -- doing something observable with such a value precipitates
    // UB. We replace the poison value with a constant to avoid this deferred
    // UB.
    case HloOpcode::kShiftRightArithmetic:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      AShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/true);
    case HloOpcode::kShiftLeft:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      Shl(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    case HloOpcode::kShiftRightLogical:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      LShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    default:
      return Unimplemented("binary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
// Integer maximum. Ties pick lhs_value (the comparison is >=).
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) {
  const llvm::CmpInst::Predicate greater_or_equal =
      is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE;
  llvm::Value* lhs_wins = b_->CreateICmp(greater_or_equal, lhs_value, rhs_value);
  return Select(lhs_wins, lhs_value, rhs_value);
}
// Integer minimum. Ties pick lhs_value (the comparison is <=).
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) {
  const llvm::CmpInst::Predicate less_or_equal =
      is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE;
  llvm::Value* lhs_wins = b_->CreateICmp(less_or_equal, lhs_value, rhs_value);
  return Select(lhs_wins, lhs_value, rhs_value);
}
// Elementwise select: generates operand 0 (predicate), operand 1 (on-true)
// and operand 2 (on-false) at `index`, in that order. The predicate element
// is truncated to i1 and used to pick between the two values.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  auto generate_operand = [&](int64 operand_number) {
    return operand_to_generator.at(hlo->operand(operand_number))(index);
  };
  TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, generate_operand(0));
  TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, generate_operand(1));
  TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, generate_operand(2));
  llvm::Value* pred_bit = Trunc(pred_value, b_->getInt1Ty());
  return Select(pred_bit, on_true_value, on_false_value);
}
// Elementwise clamp(min, arg, max) = min(max, max(min, arg)). Operands 0, 1
// and 2 are generated at `index` in that order. Floating-point and integral
// element types are supported; others report Unimplemented.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  auto generate_operand = [&](int64 operand_number) {
    return operand_to_generator.at(hlo->operand(operand_number))(index);
  };
  TF_ASSIGN_OR_RETURN(llvm::Value * min_value, generate_operand(0));
  TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, generate_operand(1));
  TF_ASSIGN_OR_RETURN(llvm::Value * max_value, generate_operand(2));
  const PrimitiveType prim_type = hlo->shape().element_type();
  if (primitive_util::IsFloatingPointType(prim_type)) {
    llvm::Value* lower_bounded = EmitFloatMax(min_value, arg_value);
    return EmitFloatMin(max_value, lower_bounded);
  }
  if (primitive_util::IsIntegralType(prim_type)) {
    const bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
    llvm::Value* lower_bounded =
        EmitIntegralMax(min_value, arg_value, is_signed);
    return EmitIntegralMin(max_value, lower_bounded, is_signed);
  }
  return Unimplemented("Clamp unimplemented for %s",
                       PrimitiveType_Name(prim_type));
}
// Emits IR computing the element of a kConcatenate result at `target_index`.
//
// Control flow emitted (starting from the current insert block, `init_block`):
//   init_block: a chain of comparisons. For operand k, with running offset
//     off_k = sum of concat-dim sizes of operands 0..k-1, branch to the block
//     of operand k if target_index[concat_dim] < off_k + size_k, else fall
//     to the next comparison block. The last fallthrough block is unreachable.
//   one block per *distinct* operand: a PHI selects off_k (an operand may
//     appear several times in the operand list), the operand index is the
//     target index with concat_dim reduced by that offset, the operand's
//     generator is invoked, and control branches to `exit_block`.
//   exit_block ("merge"): a PHI `output` collects the generated element.
// On return the builder's insert point is restored to where it was in
// exit_block after the PHI was created.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& target_index) {
  const int64 concat_dim = hlo->dimensions(0);
  auto source_index = target_index;
  llvm::BasicBlock* init_block = b_->GetInsertBlock();
  // A terminator should be present iff we're emitting code
  // into the middle (as opposed to the end) of a basic block.
  CHECK_EQ(b_->GetInsertPoint() == init_block->end(),
           init_block->getTerminator() == nullptr);
  llvm::BasicBlock* exit_block;
  if (b_->GetInsertPoint() == init_block->end()) {
    exit_block = llvm_ir::CreateBasicBlock(
        /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
  } else {
    // splitBasicBlock moves everything from the insert point onwards into the
    // new "merge" block and ends init_block with a branch to it. That branch
    // is erased because the comparison chain below terminates init_block.
    exit_block =
        init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
    init_block->getTerminator()->eraseFromParent();
  }
  llvm_ir::SetToFirstInsertPoint(exit_block, b_);
  // Result PHI: one incoming value per distinct-operand block (added below).
  llvm::PHINode* output =
      PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
          hlo->operands().size());
  auto prior_insert_point = b_->GetInsertPoint();
  b_->SetInsertPoint(init_block);
  // Assign a unique id for each *different* operand, and count how often each
  // operand is used. If all operands are different, the usage count will be 1
  // for each operand.
  absl::flat_hash_map<const HloInstruction*, int64> to_unique_operand_id;
  std::vector<int64> operand_usage_count;
  for (const auto* operand : hlo->operands()) {
    if (to_unique_operand_id.contains(operand)) {
      ++operand_usage_count[to_unique_operand_id[operand]];
    } else {
      int64 unique_operand_id = to_unique_operand_id.size();
      to_unique_operand_id[operand] = unique_operand_id;
      operand_usage_count.push_back(1);
    }
  }
  // To avoid that we emit the same operand more than once, we create one basic
  // block for each *different* operand with a PHI node for the different source
  // index inputs.
  std::vector<llvm::BasicBlock*> emit_operand_blocks(
      to_unique_operand_id.size(), nullptr);
  std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
                                                nullptr);
  for (const auto* operand : hlo->operands()) {
    int64 operand_id = to_unique_operand_id[operand];
    if (emit_operand_blocks[operand_id] != nullptr) {
      continue;
    }
    emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
    auto saved_insert_point = b_->GetInsertPoint();
    llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
    // This PHI receives, from each comparison block that branches here, the
    // concat-dim offset at which this operand starts in the output.
    source_index_phis[operand_id] =
        PHI(source_index.GetType(), operand_usage_count[operand_id]);
    std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
    // Subtract the operand's start offset along the concat dimension to get
    // the index into the operand itself.
    operand_multi_index[concat_dim] =
        NSWSub(operand_multi_index[concat_dim], source_index_phis[operand_id]);
    // Create the terminator of the block before calling operand generators,
    // because they require non-degenerate basic blocks.
    b_->SetInsertPoint(llvm::BranchInst::Create(
        exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
    llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
                                          source_index.GetType());
    TF_ASSIGN_OR_RETURN(llvm::Value * value,
                        operand_to_generator.at(operand)(operand_index));
    // Use the current insert block (not emit_operand_blocks[operand_id]) as
    // the predecessor: the generator may have emitted additional blocks.
    output->addIncoming(value, b_->GetInsertBlock());
    b_->SetInsertPoint(init_block, saved_insert_point);
  }
  // Running sum of concat-dim sizes of the operands visited so far.
  int64 concat_dim_size = 0;
  for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
       ++operand_idx) {
    const HloInstruction* operand = hlo->operand(operand_idx);
    auto false_block = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_);
    int64 operand_id = to_unique_operand_id[operand];
    // Tell the operand block where this occurrence of the operand starts.
    source_index_phis[operand_id]->addIncoming(
        source_index.GetConstantWithIndexType(concat_dim_size),
        b_->GetInsertBlock());
    concat_dim_size += operand->shape().dimensions(concat_dim);
    CondBr(ICmpULT(source_index[concat_dim],
                   source_index.GetConstantWithIndexType(concat_dim_size)),
           emit_operand_blocks[operand_id], false_block);
    // Not from this operand: emit the next comparison in false_block. The
    // offset subtraction happens in the operand block via the PHI above.
    b_->SetInsertPoint(false_block);
  }
  // An in-range target index always matches some operand, so the final
  // fallthrough block cannot be reached.
  Unreachable();
  b_->SetInsertPoint(exit_block, prior_insert_point);
  return output;
}
// Emits IR computing the element of a kDynamicSlice result at `index`.
//
// Operand 0 is the sliced array; operands 1..rank are scalar start indices,
// one per dimension. Each start index is clamped to
// [0, operand_dim_size - output_dim_size] so the slice lies inside the
// operand, and the element is read from operand 0 at (start + index).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* input_hlo = hlo->operand(0);
  const int64 rank = input_hlo->shape().rank();
  // Use the same index type for all tensor accesses in the same kernel.
  llvm::Type* index_type = index.GetType();
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    // Emit IR to read the dynamic start index for dimension i.
    const HloInstruction* start_hlo = hlo->operand(1 + i);
    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
                        operand_to_generator.at(start_hlo)(zero_index));
    const bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());

    // Clamp in the wider of (start index type, index_type), then narrow.
    // - Truncating first would wrap an out-of-range start (e.g. an s64 value
    //   of 2^32 + 5 into i32 5) before the clamp could see it.
    // - Unsigned starts must be zero-extended: sign-extending e.g. a u8 200
    //   yields a huge unsigned value that the clamp saturates to the upper
    //   bound instead of leaving it at 200.
    llvm::Type* start_type = start_index_value->getType();
    llvm::Type* clamp_type =
        start_type->getScalarSizeInBits() >= index_type->getScalarSizeInBits()
            ? start_type
            : index_type;
    if (start_type != clamp_type) {
      start_index_value =
          is_signed ? b_->CreateSExt(start_index_value, clamp_type)
                    : b_->CreateZExt(start_index_value, clamp_type);
    }

    // Clamp the start index so that the sliced portion fits in the operand:
    // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
    int64 largest_valid_start_index =
        input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);
    start_index_value = EmitIntegralMin(
        llvm::ConstantInt::get(clamp_type, largest_valid_start_index),
        EmitIntegralMax(llvm::ConstantInt::get(clamp_type, 0),
                        start_index_value, is_signed),
        is_signed);
    // The clamped value is in [0, largest_valid_start_index]. Narrowing is
    // lossless provided dimension sizes fit index_type, which the other index
    // arithmetic in this kernel also assumes.
    if (clamp_type != index_type) {
      start_index_value = Trunc(start_index_value, index_type);
    }
    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
  }
  std::vector<llvm::Value*> input_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    // Emit IR which computes:
    //   input_index = start_index + offset_index
    input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
  }
  llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
                                      index_type);
  return operand_to_generator.at(input_hlo)(input_index);
}
// Emits IR computing the element of a kGather result at `index`.
//
// Operand 0 is the gathered-from array and operand 1 the start-indices tensor.
// The operand index is built in two parts:
//  1. Window (offset) output dimensions are copied into the corresponding
//     non-collapsed operand dimensions; collapsed dimensions start at 0.
//  2. The start index vector is read from the indices tensor at the batch
//     position of `index`. Each component is clamped so the gathered window
//     fits in the operand, and then added to the operand dimension given by
//     start_index_map.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape& indices_shape = hlo->operand(1)->shape();
  const Shape& output_shape = hlo->shape();
  const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
  const llvm_ir::ElementGenerator& operand_generator =
      operand_to_generator.at(hlo->operand(0));
  const llvm_ir::ElementGenerator& indices_generator =
      operand_to_generator.at(hlo->operand(1));
  llvm::Type* index_type = index.GetType();
  // Signedness of the start-indices element type. It decides both how narrow
  // index components are widened and how they are compared when clamping.
  const bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
  // This is the index into `operand` that holds the element we want to
  // generate.
  std::vector<llvm::Value*> operand_multi_index;
  // First copy in the window indices to operand_index. Also collect a mapping
  // from operand dimension to output window dimension. Elided window dimensions
  // map to -1.
  std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
  for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
       i < e; i++) {
    if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
      operand_multi_index.push_back(index.GetConstantWithIndexType(0));
    } else {
      int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
      operand_to_output_dim[i] = output_window_dim;
      operand_multi_index.push_back(index[output_window_dim]);
    }
  }
  // This is the index of the index vector in the start_indices tensor.
  std::vector<llvm::Value*> gather_index_index_components;
  {
    for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
      if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
        gather_index_index_components.push_back(index[i]);
      }
    }
    // Reserve a slot for the index-vector dimension when it is an explicit
    // dimension of the indices tensor; it is filled in per component below.
    if (gather_index_index_components.size() !=
        indices_shape.dimensions_size()) {
      gather_index_index_components.insert(
          gather_index_index_components.begin() +
              dim_numbers.index_vector_dim(),
          nullptr);
    }
  }
  auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
    llvm::Type* index_component_type = index_component->getType();
    llvm::Type* extended_type = index_component_type->getScalarSizeInBits() >=
                                        index_type->getScalarSizeInBits()
                                    ? index_component_type
                                    : index_type;
    // Possibly extend the value at the beginning to ensure clamping logic stays
    // in bounds. Unsigned components must be zero-extended: sign-extending
    // e.g. a u8 value of 200 would give a huge unsigned number that the
    // unsigned clamp below saturates to largest_valid_start_index.
    llvm::Value* maybe_extended_index = index_component;
    if (index_component_type != extended_type) {
      maybe_extended_index =
          is_signed ? b_->CreateSExt(index_component, extended_type)
                    : b_->CreateZExt(index_component, extended_type);
    }
    int64 operand_dim = dim_numbers.start_index_map(dim);
    int64 output_dim = operand_to_output_dim[operand_dim];
    // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
    // This means we set the iteration index to 0, so for the purpose of the
    // following calculations we can consider the output dimension size to be 1.
    int64 output_dim_size =
        output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
    int64 largest_valid_start_index =
        operand_shape.dimensions(operand_dim) - output_dim_size;
    CHECK_GE(largest_valid_start_index, 0);
    // Clamp the gather index so that the gather region fits in the operand.
    //   clamped_index =
    //       clamp(gather_dim_component_extended, 0, largest_valid_start_index);
    llvm::Value* clamped_index = EmitIntegralMin(
        llvm::ConstantInt::get(extended_type, largest_valid_start_index),
        EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0),
                        maybe_extended_index, is_signed),
        is_signed);
    // Truncate at the end to the optimized index size.
    llvm::Value* maybe_truncated_clamped_index =
        extended_type != index_type ? Trunc(clamped_index, index_type)
                                    : clamped_index;
    operand_multi_index[operand_dim] =
        Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index);
  };
  if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
    // Implicit trailing index-vector dimension: one scalar start component.
    IrArray::Index gather_index_index(gather_index_index_components,
                                      indices_shape, index_type);
    TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                        indices_generator(gather_index_index));
    add_to_operand_index(gather_dim_component, 0);
  } else {
    int64 index_vector_size =
        indices_shape.dimensions(dim_numbers.index_vector_dim());
    for (int64 i = 0; i < index_vector_size; i++) {
      gather_index_index_components[dim_numbers.index_vector_dim()] =
          index.GetConstantWithIndexType(i);
      IrArray::Index gather_index_index(gather_index_index_components,
                                        indices_shape, index_type);
      TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                          indices_generator(gather_index_index));
      add_to_operand_index(gather_dim_component, i);
    }
  }
  IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
  return operand_generator(operand_index);
}
// Emits IR computing the element of a kDynamicUpdateSlice result at `index`.
//
// Operand 0 is the input array, operand 1 the update array, and operands
// 2..rank+1 are scalar start indices, one per dimension. Each start index is
// clamped to [0, input_dim_size - update_dim_size] so the update region lies
// inside the input. If `index` lies within [start, start + update_dim_size) in
// every dimension, the element is read from `update` at (index - start).
// Otherwise it is read from `input` at `index`.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* input_hlo = hlo->operand(0);
  const HloInstruction* update_hlo = hlo->operand(1);
  // Calculate slice start/end indices.
  const int64 rank = input_hlo->shape().rank();
  // Use the same index type for all tensor accesses in the same kernel.
  llvm::Type* index_type = index.GetType();
  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_type, c);
  };
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  std::vector<llvm::Value*> slice_limit_multi_index(rank);
  // Slice intersection gathers (ANDs) conditions on all ranks for which
  // 'input' is set to 'update'
  llvm::Value* slice_intersection = b_->getTrue();
  for (int64 i = 0; i < rank; ++i) {
    const HloInstruction* start_hlo = hlo->operand(2 + i);
    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
                        operand_to_generator.at(start_hlo)(zero_index));
    const bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());

    // Clamp in the wider of (start index type, index_type), then narrow.
    // - Truncating first would wrap an out-of-range start into range before
    //   the clamp could see it.
    // - Unsigned starts must be zero-extended: sign-extending e.g. a u8 200
    //   yields a huge unsigned value that the clamp saturates to the upper
    //   bound instead of leaving it at 200.
    llvm::Type* start_type = start_index_value->getType();
    llvm::Type* clamp_type =
        start_type->getScalarSizeInBits() >= index_type->getScalarSizeInBits()
            ? start_type
            : index_type;
    if (start_type != clamp_type) {
      start_index_value =
          is_signed ? b_->CreateSExt(start_index_value, clamp_type)
                    : b_->CreateZExt(start_index_value, clamp_type);
    }

    // Clamp the start index so that the update region fits in the operand.
    //   start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
    llvm::Value* update_dim_size =
        index_typed_const(update_hlo->shape().dimensions(i));
    int64 largest_valid_start_index =
        input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);
    start_index_value = EmitIntegralMin(
        llvm::ConstantInt::get(clamp_type, largest_valid_start_index),
        EmitIntegralMax(llvm::ConstantInt::get(clamp_type, 0),
                        start_index_value, is_signed),
        is_signed);
    // The clamped value is in [0, largest_valid_start_index]. Narrowing is
    // lossless provided dimension sizes fit index_type, which the other index
    // arithmetic in this kernel also assumes.
    if (clamp_type != index_type) {
      start_index_value = Trunc(start_index_value, index_type);
    }
    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
    slice_limit_multi_index[i] =
        Add(slice_start_multi_index[i], update_dim_size);

    slice_intersection =
        And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
            "slice_intersection");
    slice_intersection =
        And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
            "slice_intersection");
  }
  // Emit:
  // if (slice_intersection) -> return data from 'update'.
  // else                    -> return data from 'input'.
  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "ret_value_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);
  // Handle true BB (return data from 'update')
  SetToFirstInsertPoint(if_data.true_block, b_);
  // Compute update index for intersection case.
  std::vector<llvm::Value*> update_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
  }
  llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
                                       index_type);
  TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
                      operand_to_generator.at(update_hlo)(update_index));
  Store(true_value, ret_value_addr);
  // Handle false BB (return data from 'input')
  SetToFirstInsertPoint(if_data.false_block, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
                      operand_to_generator.at(input_hlo)(index));
  Store(false_value, ret_value_addr);
  SetToFirstInsertPoint(if_data.after_block, b_);
  // Results go through an alloca rather than a PHI because the generators
  // above may create new basic blocks.
  return Load(ret_value_addr);
}
// Emits IR computing the element of a kPad result at `padded_index`.
//
// For each dimension i, the padded coordinate p maps back to an operand
// coordinate as follows:
//   q = p - edge_padding_low
//   in bounds iff q >= 0, q % (interior_padding + 1) == 0 and
//                 q / (interior_padding + 1) < operand_dim_size
// In-bounds positions read operand 0 at the mapped index. All other positions
// read the scalar padding value from operand 1.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& padded_index) {
  // Created once and captures by reference, so the Index (which owns several
  // vectors) is not copied for every dimension.
  auto index_typed_const = [&padded_index](int64 n) {
    return padded_index.GetConstantWithIndexType(n);
  };
  std::vector<llvm::Value*> multi_index = padded_index.multidim();
  llvm::Value* in_bounds = b_->getTrue();
  for (size_t i = 0; i < multi_index.size(); ++i) {
    const auto& pad_dim = hlo->padding_config().dimensions(i);
    multi_index[i] =
        Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
    in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
                    "in_bounds");
    // The unsigned remainder is only meaningful for q >= 0, which is already
    // ANDed in above.
    in_bounds =
        And(in_bounds,
            ICmpEQ(index_typed_const(0),
                   URem(multi_index[i],
                        index_typed_const(pad_dim.interior_padding() + 1))),
            "in_bounds");
    multi_index[i] =
        SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
    in_bounds =
        And(in_bounds,
            ICmpSLT(multi_index[i],
                    index_typed_const(hlo->operand(0)->shape().dimensions(i))),
            "in_bounds");
  }
  // if (in_bounds) {
  //   ret_value = operand0[index];  // source
  // } else {
  //   ret_value = *operand1;        // padding
  // }
  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "pad_result_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);
  llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
                                padded_index.GetType());
  TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  Store(operand_value, ret_value_addr);
  SetToFirstInsertPoint(if_data.false_block, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
                      operand_to_generator.at(hlo->operand(1))(
                          IrArray::Index(index.GetType())));
  Store(padding_value, ret_value_addr);
  SetToFirstInsertPoint(if_data.after_block, b_);
  // Don't create phi(operand_value, padding_value) here, because invoking
  // operand_to_generator may create new basic blocks, making the parent
  // of operand_value or padding_value no longer a predecessor of
  // if_data.after_block.
  return Load(ret_value_addr);
}
// Emits IR computing one element of a kDot result at `dot_result_index` as an
// inner loop over the contracted dimension, accumulating into a stack slot:
//   acc = 0; for t in [0, T): acc += lhs[..., t, ...] * rhs[..., t, ...]
// If the operands have shapes [A,B,C,T] and [D,T,E], the result has shape
// [A,B,C,D,E], and output index [a,b,c,d,e] yields
//   sum(lhs[a,b,c,t] * rhs[d,t,e] for t in [0, T)).
// Only the first entry of the lhs/rhs contracting-dimension lists is read.
// NOTE(review): this presumably relies on callers routing only
// single-contracting-dimension dots here -- confirm.
// Complex types use (a+bi)(c+di) = (ac-bd) + (ad+bc)i. Floating-point types
// use FMul/FAdd; everything else uses integer Mul/Add.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& dot_result_index) {
  const HloInstruction* lhs = hlo->operand(0);
  const HloInstruction* rhs = hlo->operand(1);
  const llvm_ir::ElementGenerator lhs_generator = operand_to_generator.at(lhs);
  const llvm_ir::ElementGenerator rhs_generator = operand_to_generator.at(rhs);
  const DotDimensionNumbers& dnums = hlo->dot_dimension_numbers();
  const int64 lhs_contract = dnums.lhs_contracting_dimensions(0);
  const int64 rhs_contract = dnums.rhs_contracting_dimensions(0);
  const int64 reduction_length = lhs->shape().dimensions(lhs_contract);
  const int64 lhs_rank = lhs->shape().dimensions_size();
  const int64 rhs_rank = rhs->shape().dimensions_size();
  llvm::Type* index_type = dot_result_index.GetType();
  auto make_index_const = [index_type](uint64 value) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_type, value);
  };

  std::unique_ptr<llvm_ir::ForLoop> reduction_loop =
      llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), make_index_const(0),
                                    make_index_const(reduction_length),
                                    make_index_const(1), b_);

  // Zero the accumulator in the loop preheader.
  SetToFirstInsertPoint(reduction_loop->GetPreheaderBasicBlock(), b_);
  const PrimitiveType element_type = hlo->shape().element_type();
  llvm::Type* element_ir_type =
      llvm_ir::PrimitiveTypeToIrType(element_type, module_);
  llvm::Value* acc_slot =
      llvm_ir::EmitAllocaAtFunctionEntry(element_ir_type, "dot_acc", b_);
  Store(llvm::Constant::getNullValue(element_ir_type), acc_slot);

  // Loop body: one multiply-accumulate step.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);
  llvm::Value* t = reduction_loop->GetIndVarValue();

  // lhs index: the leading (lhs_rank - 1) result coordinates, with the loop
  // variable spliced in at the contracting position.
  std::vector<llvm::Value*> lhs_coords;
  for (int64 d = 0; d + 1 < lhs_rank; ++d) {
    lhs_coords.push_back(dot_result_index[d]);
  }
  lhs_coords.insert(lhs_coords.begin() + lhs_contract, t);
  IrArray::Index lhs_index(lhs_coords, lhs->shape(), index_type);

  // rhs index: batch coordinates first, then the result coordinates that
  // follow the lhs ones, with the loop variable spliced in at the contracting
  // position.
  const int64 batch_count = dnums.rhs_batch_dimensions_size();
  std::vector<llvm::Value*> rhs_coords;
  for (int64 b = 0; b < batch_count; ++b) {
    rhs_coords.push_back(dot_result_index[dnums.rhs_batch_dimensions(b)]);
  }
  for (int64 d = 0; d < rhs_rank - 1 - batch_count; ++d) {
    rhs_coords.push_back(dot_result_index[lhs_rank - 1 + d]);
  }
  rhs_coords.insert(rhs_coords.begin() + rhs_contract, t);
  IrArray::Index rhs_index(rhs_coords, rhs->shape(), index_type);

  llvm::Value* acc = Load(acc_slot);
  TF_ASSIGN_OR_RETURN(llvm::Value * lhs_elem, lhs_generator(lhs_index));
  TF_ASSIGN_OR_RETURN(llvm::Value * rhs_elem, rhs_generator(rhs_index));

  llvm::Value* updated_acc = nullptr;
  if (primitive_util::IsComplexType(element_type)) {
    llvm::Value* re =
        FSub(FMul(EmitExtractReal(lhs_elem), EmitExtractReal(rhs_elem)),
             FMul(EmitExtractImag(lhs_elem), EmitExtractImag(rhs_elem)));
    llvm::Value* im =
        FAdd(FMul(EmitExtractReal(lhs_elem), EmitExtractImag(rhs_elem)),
             FMul(EmitExtractImag(lhs_elem), EmitExtractReal(rhs_elem)));
    updated_acc =
        InsertValue(acc, FAdd(EmitExtractReal(acc), re), {0});
    updated_acc =
        InsertValue(updated_acc, FAdd(EmitExtractImag(acc), im), {1});
  } else if (primitive_util::IsFloatingPointType(element_type)) {
    updated_acc = FAdd(acc, FMul(lhs_elem, rhs_elem));
  } else {
    updated_acc = Add(acc, Mul(lhs_elem, rhs_elem));
  }
  Store(updated_acc, acc_slot);

  // After the loop, the accumulator holds the dot product.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);
  return Load(acc_slot);
}
// Builds the element generator for `hlo`: a callable that takes an index into
// hlo's output shape and emits IR producing the element at that index.
// Operand elements are obtained from `operand_to_generator`.
//
// The returned lambdas capture `hlo` by pointer and `operand_to_generator` by
// reference, so both must outlive every invocation of the generator.
// Opcodes without a case below get a generator that returns an Unimplemented
// error when invoked (not when constructed).
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
  switch (hlo->opcode()) {
    // Element-wise unary ops: evaluate operand 0 at the same index.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitUnaryOp(hlo, operand_value);
      };
    // Element-wise binary ops: evaluate both operands at the same index.
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSubtract:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* lhs = hlo->operand(0);
        const HloInstruction* rhs = hlo->operand(1);
        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
                            operand_to_generator.at(lhs)(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
                            operand_to_generator.at(rhs)(index));
        return EmitBinaryOp(hlo, lhs_value, rhs_value);
      };
    case HloOpcode::kSelect:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalSelect(hlo, operand_to_generator, index);
      };
    case HloOpcode::kClamp:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalClamp(hlo, operand_to_generator, index);
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitReducePrecision(hlo, operand_value);
      };
    case HloOpcode::kConcatenate:
      // Index taken by const reference, like every other generator here, so
      // it is not copied (together with its vectors) on each element.
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        return EmitElementalConcatenate(hlo, operand_to_generator,
                                        target_index);
      };
    case HloOpcode::kReverse:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        std::vector<llvm::Value*> source_multi_index = target_index.multidim();
        // Reversed dimension d maps target coordinate x to (size_d - 1 - x).
        for (int64 dim : hlo->dimensions()) {
          source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
                                            hlo->shape().dimensions(dim) - 1),
                                        target_index[dim]);
        }
        llvm_ir::IrArray::Index source_index(
            source_multi_index, operand->shape(), target_index.GetType());
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kBroadcast:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        // The `dimensions` member of the broadcast instruction maps from
        // input dimensions to output dimensions.
        return operand_to_generator.at(operand)(
            target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
                                                hlo->dimensions(), b_));
      };
    case HloOpcode::kIota:
      return [this, hlo](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        auto* iota = Cast<HloIotaInstruction>(hlo);
        PrimitiveType element_type = iota->shape().element_type();
        // For rank > 1, project the target index onto the iota dimension.
        IrArray::Index elem_index =
            iota->shape().rank() > 1
                ? target_index.SourceIndexOfBroadcast(
                      iota->shape(),
                      ShapeUtil::MakeShapeWithDescendingLayout(
                          element_type,
                          {iota->shape().dimensions(iota->iota_dimension())}),
                      {iota->iota_dimension()}, b_)
                : target_index;
        llvm::Value* elem_index_linear = elem_index.linear();
        if (elem_index_linear == nullptr) {
          std::vector<int64> iota_bound = {
              iota->shape().dimensions(iota->iota_dimension())};
          elem_index_linear = elem_index.Linearize(iota_bound, b_);
        }
        Shape component_shape =
            ShapeUtil::ElementIsComplex(iota->shape())
                ? ShapeUtil::ComplexComponentShape(iota->shape())
                : iota->shape();
        PrimitiveType component_element_type = component_shape.element_type();
        llvm::Value* iota_result;
        if (primitive_util::IsIntegralType(component_element_type)) {
          iota_result = b_->CreateIntCast(
              elem_index_linear,
              llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
              /*isSigned=*/false);
        } else {
          TF_RET_CHECK(
              primitive_util::IsFloatingPointType(component_element_type))
              << component_element_type;
          // BF16 is produced by converting through F32.
          llvm::Type* float_ir_type;
          if (component_element_type == BF16) {
            float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
          } else {
            float_ir_type =
                llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
          }
          llvm::Value* float_val =
              b_->CreateUIToFP(elem_index_linear, float_ir_type);
          if (component_element_type == BF16) {
            TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
          } else {
            iota_result = float_val;
          }
        }
        if (ShapeUtil::ElementIsComplex(iota->shape())) {
          return EmitComposeComplex(iota, iota_result, nullptr);
        } else {
          return iota_result;
        }
      };
    case HloOpcode::kSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        IrArray::Index sliced_index = index.SourceIndexOfSlice(
            /*operand_shape=*/hlo->operand(0)->shape(),
            /*starts=*/hlo->slice_starts(),
            /*strides=*/hlo->slice_strides(), /*builder=*/b_);
        return operand_to_generator.at(hlo->operand(0))(sliced_index);
      };
    case HloOpcode::kDynamicSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
      };
    case HloOpcode::kGather:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalGather(hlo, operand_to_generator, index);
      };
    case HloOpcode::kDynamicUpdateSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
                                               index);
      };
    case HloOpcode::kBitcast:
      // Checked when the generator is built, not when it is invoked.
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kReshape:
      // Checked when the generator is built, not when it is invoked.
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kCopy:
      // Same multi-dimensional index, re-interpreted in the operand's shape.
      return [hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        IrArray::Index source_index(target_index.multidim(),
                                    hlo->operand(0)->shape(),
                                    target_index.GetType());
        TF_ASSIGN_OR_RETURN(
            llvm::Value * operand_value,
            operand_to_generator.at(hlo->operand(0))(source_index));
        return operand_value;
      };
    case HloOpcode::kTranspose:
      // SourceIndexOfTranspose needs no IR builder, so `this` is not captured.
      return [hlo,
              &operand_to_generator](const IrArray::Index& target_index) {
        return operand_to_generator.at(hlo->operand(0))(
            target_index.SourceIndexOfTranspose(
                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions()));
      };
    case HloOpcode::kPad:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
        return EmitElementalPad(hlo, operand_to_generator, padded_index);
      };
    case HloOpcode::kDot:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& dot_result_index)
                 -> StatusOr<llvm::Value*> {
        return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
      };
    case HloOpcode::kMap:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        std::vector<llvm::Value*> operands;
        for (int i = 0; i < hlo->operand_count(); i++) {
          TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                              operand_to_generator.at(hlo->operand(i))(index));
          operands.push_back(operand_value);
        }
        return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
      };
    case HloOpcode::kReduceWindow:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        return EmitElementalReduceWindow(
            Cast<HloReduceWindowInstruction>(hlo),
            operand_to_generator.at(hlo->operand(0)),
            operand_to_generator.at(hlo->operand(1)), index);
      };
    case HloOpcode::kReduce:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
        std::vector<llvm_ir::ElementGenerator> input_generators;
        for (const HloInstruction* instr : reduce_instr->inputs()) {
          input_generators.push_back(operand_to_generator.at(instr));
        }

        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
        for (const HloInstruction* instr : reduce_instr->init_values()) {
          initial_value_generators.push_back(operand_to_generator.at(instr));
        }
        return EmitElementalReduce(reduce_instr, std::move(input_generators),
                                   std::move(initial_value_generators), index);
      };
    case HloOpcode::kConvolution:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        return EmitConvolution(hlo, operand_to_generator, index);
      };
    default:
      return [hlo](const IrArray::Index& index) {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()));
      };
  }
}
// Returns field 0 (the real component) of the complex aggregate `value`.
llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
  llvm::Value* real_part = ExtractValue(value, {0});
  return real_part;
}
// Returns field 1 (the imaginary component) of the complex aggregate `value`.
llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
  llvm::Value* imag_part = ExtractValue(value, {1});
  return imag_part;
}
// Builds a complex value of op's element type from `real` and `imag`.
// Starts from an all-zero aggregate, so a null `imag` yields a zero imaginary
// part (field 1).
llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
                                                    llvm::Value* real,
                                                    llvm::Value* imag) {
  llvm::Type* complex_ir_type =
      llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
  auto* zero_complex = llvm::ConstantAggregateZero::get(complex_ir_type);
  llvm::Value* with_real = InsertValue(zero_complex, real, {0});
  if (imag == nullptr) {
    return with_real;
  }
  return InsertValue(with_real, imag, {1});
}
// Emits one element of kMap by calling the mapped computation (to_apply) on
// the already-generated operand elements. The computation must produce exactly
// one value; this is enforced with a CHECK.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
    const HloMapInstruction* map_instr,
    absl::Span<llvm::Value* const> elemental_operands) {
  const HloComputation& mapped_computation = *map_instr->to_apply();
  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> results,
      EmitThreadLocalCall(mapped_computation, elemental_operands,
                          llvm_ir::IrName(map_instr)));
  CHECK_EQ(results.size(), 1);
  return results.front();
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
    const HloReduceWindowInstruction* reduce_window,
    const llvm_ir::ElementGenerator& input_generator,
    const llvm_ir::ElementGenerator& initial_value_generator,
    const llvm_ir::IrArray::Index& index) {
  // Emits IR computing the output element at index O = `index` of a
  // non-variadic reduce-window. Pseudocode:
  //
  //   value = init_value
  //   for each index W in window
  //     for each dimension i from 0 to rank - 1
  //       // Position in the padded, base-dilated input.
  //       D[i] = O[i] * stride[i] + W[i] * window_dilation[i] - pad_low[i]
  //       // W contributes nothing if it lands in a base-dilation hole.
  //       require D[i] % base_dilation[i] == 0
  //       I[i] = D[i] / base_dilation[i]
  //     if I in bounds of input (otherwise W lands in padding)
  //       value = function(value, input[I])
  //   output[O] = value
  if (reduce_window->shape().IsTuple()) {
    return Unimplemented(
        "Variadic reduce window op is not yet fully supported.");
  }
  const HloInstruction* operand = reduce_window->operand(0);
  const Window& window = reduce_window->window();
  PrimitiveType operand_element_type = operand->shape().element_type();
  // The running accumulator lives in a function-entry alloca: seeded with the
  // init value here, updated in the window loop body, and read after the loop.
  llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
      "reduce_window_accum_ptr", b_);
  {
    // The init value is generated at an empty (rank-0) index.
    TF_ASSIGN_OR_RETURN(
        llvm::Value* const init_value,
        initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
    Store(init_value, accum_ptr);
  }
  llvm::Type* index_type = index.GetType();
  // Takes a signed value: window fields such as padding_low() are int64 and
  // may be negative, so routing them through uint64 only worked by wraparound.
  auto index_typed_const = [&](int64 c) -> llvm::Constant* {
    return index.GetConstantWithIndexType(c);
  };
  llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
  std::vector<int64> window_size;
  window_size.reserve(window.dimensions_size());
  for (const auto& dim : window.dimensions()) {
    window_size.push_back(dim.size());
  }
  const IrArray::Index window_index = loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_type, window_size), "window");
  CHECK_EQ(window_index.size(), index.size());
  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
  std::vector<llvm::Value*> input_multi_index(index.size());
  llvm::Value* in_bounds = b_->getInt1(true);
  for (size_t i = 0; i < index.size(); ++i) {
    const WindowDimension& window_dim = window.dimensions(i);
    llvm::Value* stridden_index =
        NSWMul(index[i], index_typed_const(window_dim.stride()));
    input_multi_index[i] = NSWSub(
        NSWAdd(stridden_index,
               NSWMul(window_index[i],
                      index_typed_const(window_dim.window_dilation()))),
        index_typed_const(window_dim.padding_low()));
    // We need to verify that we are not in the dilated base area.
    llvm::Value* dilation_condition =
        ICmpEQ(SRem(input_multi_index[i],
                    index_typed_const(window_dim.base_dilation())),
               index_typed_const(0));
    in_bounds = And(in_bounds, dilation_condition);
    // Apply base dilation to the index.
    input_multi_index[i] =
        SDiv(input_multi_index[i],
             index_typed_const(window_dim.base_dilation()));
    // We must check whether 0 <= input_multi_index[i] < bound, as otherwise we
    // are in the pad and so can skip the computation. This comparison is
    // equivalent to the unsigned comparison input_multi_index[i] < bound, as a
    // negative value wraps to a large positive value.
    in_bounds = And(in_bounds,
                    ICmpULT(input_multi_index[i],
                            index_typed_const(operand->shape().dimensions(i))));
  }
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);
  // We are not in pad, so do the computation.
  IrArray::Index input_index(input_multi_index, operand->shape(), index_type);
  TF_ASSIGN_OR_RETURN(llvm::Value * input_value, input_generator(input_index));
  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> accum_values,
      EmitThreadLocalCall(*reduce_window->to_apply(),
                          {Load(accum_ptr), input_value}, "reducer_function"));
  CHECK_EQ(accum_values.size(), 1);
  Store(accum_values[0], accum_ptr);
  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
  return Load(accum_ptr);
}
// Emits IR computing the output element at `index` of a (possibly variadic)
// reduce. For every position along the reduced dimensions, the reducer
// computation is called on (accumulators..., input elements...) and its
// results are stored back into the accumulators. A plain reduce returns the
// single accumulator value; a variadic (tuple-shaped) reduce returns an LLVM
// struct holding one value per accumulator.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
    const HloReduceInstruction* reduce,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  const Shape& out_shape = reduce->shape();
  bool is_variadic = !out_shape.IsArray();
  int accumulators_count = 1;
  if (is_variadic) {
    CHECK(out_shape.IsTuple());
    accumulators_count = out_shape.tuple_shapes_size();
  }
  // Accumulator i reads input_generators[i] and initial_value_generators[i]
  // below; fail loudly instead of indexing past the end of either vector.
  const size_t num_accumulators = static_cast<size_t>(accumulators_count);
  CHECK_GE(input_generators.size(), num_accumulators);
  CHECK_GE(initial_value_generators.size(), num_accumulators);

  absl::Span<const int64> reduced_dimensions(reduce->dimensions());
  std::vector<llvm::Value*> accumulator_addrs;
  std::vector<llvm::Type*> accumulator_types;
  accumulator_addrs.reserve(num_accumulators);
  accumulator_types.reserve(num_accumulators);
  llvm::Type* index_type = index.GetType();
  for (int i = 0; i < accumulators_count; i++) {
    const Shape& element_shape =
        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
    PrimitiveType accumulator_type = element_shape.element_type();
    llvm::Type* accumulator_llvm_type =
        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
    accumulator_types.push_back(accumulator_llvm_type);

    // Initialize an accumulator with init_value, generated at an empty
    // (rank-0) index.
    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        accumulator_llvm_type, absl::StrCat("accumulator_", i), b());
    TF_ASSIGN_OR_RETURN(
        llvm::Value* const init_value,
        initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
    Store(init_value, accumulator_addr);
    accumulator_addrs.push_back(accumulator_addr);
  }

  // The enclosing loops go over all the target elements. Now we have to compute
  // the actual target element. For this, we build a new loop nest to iterate
  // over all the reduction dimensions in the argument.
  // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
  // are placed for each dimension in dimensions, and all the rest are nullptrs.
  llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
  const HloInstruction* arg = reduce->operand(0);
  std::vector<llvm::Value*> input_multi_index =
      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
                                         "reduction_dim");
  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());

  // Build a full index for the input argument, using input_multi_index as the
  // base. In input_multi_index only the reduction dimensions are filled in. We
  // fill in the rest of the dimensions with induction Value*s taken from
  // 'index' which iterates over the target array. See the high-level
  // description in the XLA documentation for details.
  auto it = index.begin();
  for (llvm::Value*& dim_index : input_multi_index) {
    if (dim_index == nullptr) {
      // Every non-reduced input dimension must be fed by an output dimension.
      CHECK(it != index.end());
      dim_index = *it++;
    }
  }
  CHECK(index.end() == it);
  llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                      index_type);

  // Reducer operands: all current accumulator values, then all input elements.
  std::vector<llvm::Value*> reduction_operands;
  reduction_operands.reserve(2 * num_accumulators);
  for (llvm::Value* accum : accumulator_addrs) {
    reduction_operands.push_back(Load(accum));
  }
  for (int i = 0; i < accumulators_count; i++) {
    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
                        input_generators[i](input_index));
    reduction_operands.push_back(input_element);
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> results,
      EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
                          "reduce_function"));
  CHECK_EQ(results.size(), num_accumulators);
  for (int i = 0; i < accumulators_count; i++) {
    Store(results[i], accumulator_addrs[i]);
  }
  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());

  if (!is_variadic) {
    CHECK_EQ(accumulator_addrs.size(), 1);
    return Load(accumulator_addrs[0]);
  }
  // Emit a structure, as that is what the LoopEmitter expects.
  llvm::Value* returned_structure = llvm::UndefValue::get(
      llvm::StructType::get(b()->getContext(), accumulator_types));
  for (int i = 0; i < accumulators_count; i++) {
    llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
    returned_structure =
        b()->CreateInsertValue(returned_structure, accumulator_value, i);
  }
  return returned_structure;
}
// Fallback for elemental convolution. It ignores its arguments and always
// returns an UNIMPLEMENTED status.
// NOTE(review): presumably backends that can emit convolutions elementally
// override this; confirm against the header's virtual declarations.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const Status unimplemented =
      Unimplemented("Elemental convolution is not implemented");
  return unimplemented;
}
// Evaluates a polynomial at `x` using Horner's method and returns the emitted
// value. `coefficients` run from the highest-degree term down to the constant
// term, so {c0, c1, ..., cn} evaluates c0 * x^n + c1 * x^(n-1) + ... + cn.
// An empty coefficient list evaluates to 0.
//
// The recurrence is seeded with the leading coefficient rather than 0, so no
// `0 * x` product is emitted. Besides saving an fmul/fadd pair, `0 * x` is NaN
// when `x` is infinite. That NaN would poison the whole result even where the
// polynomial has a well-defined (constant or infinite) value.
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
    llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
  if (coefficients.empty()) {
    llvm::Value* zero = llvm::ConstantFP::get(type, 0.0);
    return zero;
  }
  llvm::Value* poly = llvm::ConstantFP::get(type, coefficients[0]);
  for (const double c : coefficients.subspan(1)) {
    poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
  }
  return poly;
}
} // namespace xla