/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace llvm_ir {
namespace {
// Note: this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK-fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
auto block = CHECK_NOTNULL(b->GetInsertBlock());
auto fn = CHECK_NOTNULL(block->getParent());
auto module = CHECK_NOTNULL(fn->getParent());
return module;
}
} // namespace
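// Clones `module` and drops every global variable's initializer, marking the
// globals external so the clone remains valid IR without the constant data.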
std::unique_ptr<llvm::Module> DropConstantInitializers(
const llvm::Module& module) {
std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
global_var.setInitializer(nullptr);
global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
}
return cloned_module;
}
std::string DumpModuleToString(const llvm::Module& module) {
std::string buffer_string;
llvm::raw_string_ostream ostream(buffer_string);
module.print(ostream, nullptr);
ostream.flush();
return buffer_string;
}
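// Emits a call to the LLVM intrinsic `intrinsic_id`, getting (or creating) its
// declaration in the builder's current module.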
llvm::CallInst* EmitCallToIntrinsic(
llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
absl::string_view name) {
llvm::Module* module = ModuleFromIRBuilder(b);
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
module, intrinsic_id, AsArrayRef(overloaded_types));
return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
}
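// Emits max(lhs_value, rhs_value). With no-NaNs fast math or
// enable_fast_min_max this is a single unordered compare + select; otherwise
// the emitted sequence also propagates NaN operands to the result.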
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b, bool enable_fast_min_max,
absl::string_view name) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
} else {
auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
}
}
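// Emits min(lhs_value, rhs_value); see EmitFloatMax above for how NaNs are
// handled in the two cases.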
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b, bool enable_fast_min_max,
absl::string_view name) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
} else {
auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
}
}
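// Emits an inbounds GEP that addresses element `index` of the buffer `array`.
// A GlobalVariable is the aggregate itself, so it needs a leading 0 index in
// front of `index`.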
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
llvm::Value* index, llvm::IRBuilder<>* b) {
llvm::Type* array_type = array->getType();
CHECK(array_type->isPointerTy());
llvm::PointerType* array_type_as_pointer =
llvm::cast<llvm::PointerType>(array_type);
CHECK(array_type_as_pointer->isOpaqueOrPointeeTypeMatches(element_type));
VLOG(2) << "EmitBufferIndexingGEP with type="
<< llvm_ir::DumpToString(*array_type)
<< " array=" << llvm_ir::DumpToString(*array)
<< " index=" << llvm_ir::DumpToString(*index);
return b->CreateInBoundsGEP(
element_type, array,
llvm::isa<llvm::GlobalVariable>(array)
? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
: index);
}
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
int64_t index, llvm::IRBuilder<>* b) {
return EmitBufferIndexingGEP(array, element_type, b->getInt64(index), b);
}
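// Returns the LLVM type used to store a value of the given XLA primitive
// type.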
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
llvm::Module* module) {
switch (element_type) {
case PRED:
case S8:
case U8:
return llvm::Type::getInt8Ty(module->getContext());
case S16:
case U16:
case BF16:
// For BF16 we just need some type that is 16 bits wide so that it will
// take up the right amount of space in memory. LLVM does not have a BF16
// type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
// we can't map it directly to an LLVM type. We will not map a BF16
// addition to an addition on this type (int16_t) - this is just the type
// used for storage.
return llvm::Type::getInt16Ty(module->getContext());
case F16:
return llvm::Type::getHalfTy(module->getContext());
case S32:
case U32:
return llvm::Type::getInt32Ty(module->getContext());
case S64:
case U64:
return llvm::Type::getInt64Ty(module->getContext());
case F32:
return llvm::Type::getFloatTy(module->getContext());
case F64:
return llvm::Type::getDoubleTy(module->getContext());
case C64: {
auto cplx_t =
llvm::StructType::getTypeByName(module->getContext(), "complex64");
if (cplx_t == nullptr) {
// C++ standard dictates the memory layout of std::complex is contiguous
// real followed by imaginary. C++11 section 26.4 [complex.numbers]:
// If z is an lvalue expression of type cv std::complex<T> then the
// expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
// reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
// z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
// imaginary part of z.
return llvm::StructType::create(
{llvm::Type::getFloatTy(module->getContext()),
llvm::Type::getFloatTy(module->getContext())},
"complex64", /*isPacked=*/true);
}
return cplx_t;
}
case C128: {
auto cplx_t =
llvm::StructType::getTypeByName(module->getContext(), "complex128");
if (cplx_t == nullptr) {
return llvm::StructType::create(
{llvm::Type::getDoubleTy(module->getContext()),
llvm::Type::getDoubleTy(module->getContext())},
"complex128", /*isPacked=*/true);
}
return cplx_t;
}
// A Tuple contains an array of pointers. Use i8*.
case TUPLE:
// An Opaque is like a void*, use i8*.
case OPAQUE_TYPE:
return llvm::Type::getInt8PtrTy(module->getContext());
case TOKEN:
// Tokens do not have a physical representation, but the compiler needs
// some placeholder type, so use int8_t*.
return llvm::Type::getInt8PtrTy(module->getContext());
default:
LOG(FATAL) << "unsupported type " << element_type;
}
}
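// Returns the size of `type` in bits; packed structs are summed element by
// element, and any other type must have a primitive size.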
int GetSizeInBits(llvm::Type* type) {
const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
if (struct_ty) {
CHECK(struct_ty->isPacked());
int bits = 0;
for (auto element_type : struct_ty->elements()) {
bits += GetSizeInBits(element_type);
}
return bits;
}
int bits = type->getPrimitiveSizeInBits();
CHECK_GT(bits, 0) << "type is not sized";
return bits;
}
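// Builds the LLVM type for `shape`: a tuple becomes an array of pointers, and
// an array shape becomes nested LLVM arrays with the minor dimension
// innermost.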
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
if (shape.IsTuple()) {
// A tuple buffer is an array of pointers.
result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
} else if (shape.IsArray()) {
for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) {
result_type =
llvm::ArrayType::get(result_type, shape.dimensions(dimension));
}
}
return result_type;
}
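// Serializes `shape` into a global string constant, returns a pointer to it,
// and reports the encoded byte count through *shape_size.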
StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
int32_t* shape_size,
llvm::IRBuilder<>* b) {
std::string encoded_shape = shape.SerializeAsString();
if (encoded_shape.size() > std::numeric_limits<int32_t>::max()) {
return InternalError("Encoded shape size exceeded int32_t size limit.");
}
*shape_size = static_cast<int32_t>(encoded_shape.size());
return b->CreateGlobalStringPtr(encoded_shape);
}
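// Reinterprets the literal's raw bytes as an LLVM constant data array; this is
// only valid because host and target endianness match (checked below).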
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
llvm::Module* module) {
const char* data = static_cast<const char*>(literal.untyped_data());
CHECK_EQ(module->getDataLayout().isLittleEndian(),
tensorflow::port::kLittleEndian);
return llvm::ConstantDataArray::getString(
module->getContext(), llvm::StringRef(data, literal.size_bytes()),
/*AddNull=*/false);
}
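// Creates a global variable in the GPU shared-memory address space to back a
// tile of `tile_type`.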
llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
llvm::Type* tile_type,
absl::string_view name) {
// Both AMDGPU and NVPTX use the same address space for shared memory.
const int kGPUSharedMemoryAddrSpace = 3;
return new llvm::GlobalVariable(
*module, tile_type,
/*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
absl::string_view name,
llvm::IRBuilder<>* b,
int alignment) {
return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}
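// Emits the alloca in the function's entry block so it executes once per call
// and stays visible to LLVM's alloca-promotion passes.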
llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
llvm::Value* element_count,
absl::string_view name,
llvm::IRBuilder<>* b,
int alignment) {
llvm::IRBuilder<>::InsertPointGuard guard(*b);
llvm::Function* function = b->GetInsertBlock()->getParent();
b->SetInsertPoint(&function->getEntryBlock(),
function->getEntryBlock().getFirstInsertionPt());
llvm::AllocaInst* alloca =
b->CreateAlloca(type, element_count, AsStringRef(name));
if (alignment != 0) {
alloca->setAlignment(llvm::Align(alignment));
}
return alloca;
}
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
absl::string_view name,
llvm::IRBuilder<>* b) {
return llvm::BasicBlock::Create(
/*Context=*/b->getContext(),
/*Name=*/AsStringRef(name),
/*Parent=*/b->GetInsertBlock()->getParent(),
/*InsertBefore*/ insert_before);
}
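// Emits the branch structure for an if (or if/else): creates the true, false,
// and after blocks, wires up the conditional branch, and leaves the builder at
// the start of the after block.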
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::IRBuilder<>* b, bool emit_else) {
llvm_ir::LlvmIfData if_data;
if_data.if_block = b->GetInsertBlock();
if_data.true_block =
CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
if_data.false_block =
emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
: nullptr;
// Add a terminator to the if block, if necessary.
if (if_data.if_block->getTerminator() == nullptr) {
b->SetInsertPoint(if_data.if_block);
if_data.after_block =
CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
b->CreateBr(if_data.after_block);
} else {
if_data.after_block = if_data.if_block->splitBasicBlock(
b->GetInsertPoint(), absl::StrCat(name, "-after"));
}
// Our basic block should now end with an unconditional branch. Remove it;
// we're going to replace it with a conditional branch.
if_data.if_block->getTerminator()->eraseFromParent();
b->SetInsertPoint(if_data.if_block);
b->CreateCondBr(condition, if_data.true_block,
emit_else ? if_data.false_block : if_data.after_block);
b->SetInsertPoint(if_data.true_block);
b->CreateBr(if_data.after_block);
if (emit_else) {
b->SetInsertPoint(if_data.false_block);
b->CreateBr(if_data.after_block);
}
b->SetInsertPoint(if_data.after_block,
if_data.after_block->getFirstInsertionPt());
return if_data;
}
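// Emits an integer or floating-point comparison (depending on the operand
// type) and widens the i1 result to the PRED storage type.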
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b, absl::string_view name) {
llvm::Value* comparison_result;
if (lhs_value->getType()->isIntegerTy()) {
comparison_result =
b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
} else {
comparison_result =
b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
}
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
// arrays, so we extend it to i8 to make it addressable.
return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
PRED, ModuleFromIRBuilder(b)));
}
// Internal helper that is called from emitted code to log an int64_t value with
// a tag.
static void LogS64(const char* tag, int64_t value) {
LOG(INFO) << tag << " (int64_t): " << value;
}
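// Emits a call (through the raw address of LogS64) that logs `value` with
// `tag` at runtime.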
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
llvm::FunctionType* log_function_type = llvm::FunctionType::get(
b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
b->CreateCall(log_function_type,
b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64_t>(&LogS64)),
log_function_type->getPointerTo()),
{b->getInt64(absl::bit_cast<int64_t>(tag)), value});
}
void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
llvm::LLVMContext& context = load->getContext();
llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
llvm::Constant* alignment_constant =
llvm::ConstantInt::get(int64_ty, alignment);
llvm::MDBuilder metadata_builder(context);
auto* alignment_metadata =
metadata_builder.createConstant(alignment_constant);
load->setMetadata(llvm::LLVMContext::MD_align,
llvm::MDNode::get(context, alignment_metadata));
}
void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
uint64_t dereferenceable_bytes) {
llvm::LLVMContext& context = load->getContext();
llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
llvm::Constant* dereferenceable_bytes_constant =
llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
llvm::MDBuilder metadata_builder(context);
auto* dereferenceable_bytes_metadata =
metadata_builder.createConstant(dereferenceable_bytes_constant);
load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}
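// Attaches !range metadata [lower, upper) to `inst` and returns it.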
llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper,
llvm::Instruction* inst) {
llvm::LLVMContext& context = inst->getParent()->getContext();
llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
inst->setMetadata(
llvm::LLVMContext::MD_range,
llvm::MDNode::get(
context,
{llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
return inst;
}
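// Strips '%' characters so the result can be used as an LLVM IR value name.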
std::string IrName(absl::string_view a) {
std::string s(a);
s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
return s;
}
std::string IrName(absl::string_view a, absl::string_view b) {
if (!a.empty() && !b.empty()) {
return IrName(absl::StrCat(a, ".", b));
}
return IrName(absl::StrCat(a, b));
}
std::string IrName(const HloInstruction* a, absl::string_view b) {
return IrName(a->name(), b);
}
std::string SanitizeFunctionName(std::string function_name) {
// The backend with the strictest requirements on function names is NVPTX, so
// we sanitize to its requirements.
//
// A slightly stricter version of the NVPTX requirements is that names match
// /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
// are illegal.
// Sanitize chars in function_name.
std::transform(function_name.begin(), function_name.end(),
function_name.begin(), [](char c) {
if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
('0' <= c && c <= '9') || c == '_' || c == '$') {
return c;
}
return '_';
});
// Ensure the name isn't empty.
if (function_name.empty()) {
function_name = "__unnamed";
}
// Ensure the name doesn't start with a number.
if (!function_name.empty() && function_name[0] >= '0' &&
function_name[0] <= '9') {
function_name.insert(function_name.begin(), '_');
}
// Ensure the name isn't "_" or "$".
if (function_name == "_" || function_name == "$") {
function_name += '_';
}
return function_name;
}
void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}
void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
if (llvm::Instruction* terminator = blk->getTerminator()) {
builder->SetInsertPoint(terminator);
} else {
builder->SetInsertPoint(blk);
}
}
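// Emits a rotate-right of `rotand` by `rotor` (taken modulo the bit width)
// using a shift-and-or sequence.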
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
llvm::IRBuilder<>* builder) {
auto size = rotand->getType()->getPrimitiveSizeInBits();
auto size_value = builder->getIntN(size, size);
auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
return builder->CreateOr(
builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
builder->CreateLShr(rotand, mod(rotor)));
}
int64_t ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
unsigned pointer_size = data_layout.getPointerSize();
return ShapeUtil::ByteSizeOf(shape, pointer_size);
}
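// Translates the XLA:CPU fast-math debug options into the corresponding LLVM
// FastMathFlags.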
llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
llvm::FastMathFlags flags;
const auto& options = module_config.debug_options();
if (!options.xla_cpu_enable_fast_math()) {
return flags;
}
// Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
// AllowContract, and ApproxFunc.
flags.setFast();
flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
return flags;
}
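// Merges two instruction metadata maps: !alias.scope scope lists are unioned
// and !noalias scope lists are intersected, which is the conservative choice
// when combining the metadata of two instructions.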
std::map<int, llvm::MDNode*> MergeMetadata(
llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
const std::map<int, llvm::MDNode*>& b) {
// We should extend this as needed to deal with other kinds of metadata like
// !dereferenceable and !range.
std::map<int, llvm::MDNode*> result;
for (auto kind_md_pair : a) {
if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
for (const auto& scope_a : kind_md_pair.second->operands()) {
scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
}
auto it = b.find(kind_md_pair.first);
if (it != b.end()) {
for (const auto& scope_b : it->second->operands()) {
if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
}
}
}
result[llvm::LLVMContext::MD_alias_scope] =
llvm::MDNode::get(*context, union_of_scopes);
} else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
for (const auto& scope_a : kind_md_pair.second->operands()) {
scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
}
auto it = b.find(kind_md_pair.first);
if (it != b.end()) {
for (const auto& scope_b : it->second->operands()) {
if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
}
}
}
if (!intersection_of_scopes.empty()) {
result[llvm::LLVMContext::MD_noalias] =
llvm::MDNode::get(*context, intersection_of_scopes);
}
}
}
return result;
}
static Status CreateAndWriteStringToFile(const std::string& directory_name,
const std::string& file_name,
const std::string& text) {
std::unique_ptr<tensorflow::WritableFile> f;
TF_RETURN_IF_ERROR(
tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
TF_RETURN_IF_ERROR(
tensorflow::Env::Default()->NewWritableFile(file_name, &f));
TF_RETURN_IF_ERROR(f->Append(text));
TF_RETURN_IF_ERROR(f->Close());
return ::tensorflow::OkStatus();
}
void DumpIrIfEnabled(const HloModule& hlo_module,
const llvm::Module& llvm_module, bool optimized,
absl::string_view filename_suffix) {
const auto& debug_opts = hlo_module.config().debug_options();
if (!DumpingEnabledForHloModule(hlo_module)) {
return;
}
// We can end up compiling different modules with the same name when using
// XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
// dumped from the same process in such cases.
std::string suffix =
absl::StrCat("ir-", optimized ? "with" : "no", "-opt",
filename_suffix.empty() ? "" : ".", filename_suffix);
DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
DumpModuleToString(llvm_module));
// For some models the embedded constants can be huge, so also dump the module
// with the constants stripped to get IR that is easier to manipulate. Skip
// this if we're dumping to stdout; there's no point in duplicating everything
// when writing to the terminal.
if (!DumpingToStdout(debug_opts)) {
DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
DumpModuleToString(*DropConstantInitializers(llvm_module)));
}
}
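// Creates a function with the C calling convention and the attributes XLA:CPU
// code relies on (unwind tables, denormal flushing, and OptimizeForSize when
// requested).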
llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
const HloModuleConfig& module_config,
absl::string_view name,
llvm::Module* module) {
llvm::Function* function =
llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
function->setCallingConv(llvm::CallingConv::C);
function->addFnAttr("no-frame-pointer-elim", "false");
// Generate unwind information so that GDB can crawl through the stack frames
// created by the JIT compiled code.
function->setUWTableKind(llvm::UWTableKind::Default);
// TensorFlow always flushes denormals to zero; let LLVM know that flushing
// denormals is safe. This allows vectorization using ARM's NEON instruction
// set.
function->addFnAttr("denormal-fp-math", "preserve-sign");
// Add the optimize attribute to the function if optimizing for size. This
// controls internal behavior of some optimization passes (e.g. loop
// unrolling).
if (cpu::options::OptimizeForSizeRequested(module_config)) {
function->addFnAttr(llvm::Attribute::OptimizeForSize);
}
return function;
}
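// Emits a 32 x 32 -> 64 bit unsigned multiply and returns the {low, high}
// 32-bit halves of the product.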
std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
llvm::Value* src0,
llvm::Value* src1) {
CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
llvm::Type* int64_ty = b->getInt64Ty();
src0 = b->CreateZExt(src0, int64_ty);
src1 = b->CreateZExt(src1, int64_ty);
return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}
std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
llvm::Type* int32_ty = b->getInt32Ty();
llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
llvm::Value* high_32bits =
b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
return std::make_pair(low_32bits, high_32bits);
}
unsigned GetGlobalMemoryAddressSpace() { return 1; }
llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
llvm::IRBuilder<>* b) {
static const char* kRngStateVariableName = "rng_state";
llvm::GlobalVariable* state_ptr =
module->getNamedGlobal(kRngStateVariableName);
if (!state_ptr) {
llvm::Type* state_type = b->getInt128Ty();
// Use a non-zero initial value, as an all-zero state can cause the first
// generated random numbers to fail the chi-square test. The value used here
// was chosen arbitrarily; any non-zero value should be fine.
state_ptr = new llvm::GlobalVariable(
/*M=*/*module,
/*Ty=*/state_type,
/*isConstant=*/false,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
/*Name=*/kRngStateVariableName,
/*InsertBefore=*/nullptr,
/*TLMode=*/llvm::GlobalValue::NotThreadLocal,
/*AddressSpace=*/GetGlobalMemoryAddressSpace(),
/*isExternallyInitialized=*/false);
}
return state_ptr;
}
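// Loads the RNG state, advances it by `delta`, stores the new state back, and
// returns the value that was loaded.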
llvm::Value* RngGetAndUpdateState(uint64_t delta, llvm::Module* module,
llvm::IRBuilder<>* builder) {
llvm::GlobalVariable* state_ptr =
GetOrCreateVariableForRngState(module, builder);
llvm::LoadInst* state_value_old =
builder->CreateLoad(state_ptr->getValueType(), state_ptr, "load_state");
llvm::Value* state_value_new = builder->CreateAdd(
state_value_old,
llvm::ConstantInt::get(state_value_old->getType(), delta));
builder->CreateStore(state_value_new, state_ptr);
return state_value_old;
}
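// Creates a block holding a single `ret void`, for use as an early-return
// target.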
llvm::BasicBlock* EmitReturnBlock(llvm::IRBuilder<>* b) {
llvm::Function* function = b->GetInsertBlock()->getParent();
llvm::Module* module = b->GetInsertBlock()->getModule();
llvm::IRBuilder<>::InsertPointGuard guard(*b);
llvm::BasicBlock* early_return =
llvm::BasicBlock::Create(/*Context=*/module->getContext(),
/*Name=*/"early_return",
/*Parent=*/function);
b->SetInsertPoint(early_return);
b->CreateRetVoid();
return early_return;
}
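// Emits a conditional early return: branches to `return_block` (creating one
// if none was supplied) when `condition` is false, and otherwise continues in
// a fresh block at the current insert point.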
void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilder<>* b,
llvm::BasicBlock* return_block) {
if (!return_block) {
return_block = EmitReturnBlock(b);
}
llvm::BasicBlock* continued;
// Implicitly check whether we are already at the end of an unterminated
// block.
if (b->GetInsertBlock()->getTerminator() == nullptr) {
// If we are generating code into an incomplete basic block we can just
// create a new basic block to jump to after our conditional branch.
continued = llvm_ir::CreateBasicBlock(/*insert_before=*/nullptr,
/*name=*/"", b);
} else {
// If we are generating code into a basic block that already has code, we
// need to split that block so as to not disturb the existing code.
auto original = b->GetInsertBlock();
continued = original->splitBasicBlock(b->GetInsertPoint());
// Remove the auto-generated unconditional branch and replace it with our
// conditional branch.
original->getTerminator()->eraseFromParent();
b->SetInsertPoint(original);
}
b->CreateCondBr(condition, continued, return_block);
b->SetInsertPoint(continued, continued->getFirstInsertionPt());
}
} // namespace llvm_ir
} // namespace xla