blob: 68b252e7b5a5b3ec1a9cfe366c91fef0451666cf [file] [log] [blame] [edit]
//===- FunctionUtils.cpp - Implementation of function utilities -----------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities on LLVM Functions that are used as part of the AD
// process.
//
//===----------------------------------------------------------------------===//
#include "FunctionUtils.h"
#include "DiffeGradientUtils.h"
#include "EnzymeLogic.h"
#include "GradientUtils.h"
#include "LibraryFuncs.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LazyValueInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include <set>
#if LLVM_VERSION_MAJOR < 16
#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
#endif
#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
#include "llvm/CodeGen/UnreachableBlockElim.h"
#include "llvm/Analysis/PhiValues.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScopedNoAliasAA.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Support/TimeProfiler.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/Mem2Reg.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h"
#include "llvm/Transforms/Scalar/DCE.h"
#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/IndVarSimplify.h"
#include "llvm/Transforms/Scalar/InstSimplifyPass.h"
#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h"
#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
#include "llvm/Transforms/Scalar/SROA.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LCSSA.h"
#include "llvm/Transforms/Utils/LowerInvoke.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Scalar/DCE.h"
#include "llvm/Transforms/Scalar/LoopDeletion.h"
#include "llvm/Transforms/Scalar/LoopRotation.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/LegacyPassManager.h"
#if LLVM_VERSION_MAJOR <= 16
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#endif
#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <optional>
#include "CacheUtility.h"
#define addAttribute addAttributeAtIndex
#define removeAttribute removeAttributeAtIndex
#define getAttribute getAttributeAtIndex
#define hasAttribute hasAttributeAtIndex
#define DEBUG_TYPE "enzyme"
using namespace llvm;
// Command-line flags (llvm::cl options) that configure Enzyme's function
// preprocessing. All are cl::Hidden. They are defined inside an extern "C"
// block; NOTE(review): presumably so they can be referenced with C linkage
// from other translation units / the C API -- confirm against the headers.
extern "C" {
// -enzyme-preopt (default true)
cl::opt<bool> EnzymePreopt("enzyme-preopt", cl::init(true), cl::Hidden,
cl::desc("Run enzyme preprocessing optimizations"));
// -enzyme-inline (default false)
cl::opt<bool> EnzymeInline("enzyme-inline", cl::init(false), cl::Hidden,
cl::desc("Force inlining of autodiff"));
// -enzyme-noalias (default false)
cl::opt<bool> EnzymeNoAlias("enzyme-noalias", cl::init(false), cl::Hidden,
cl::desc("Force noalias of autodiff"));
// -enzyme-aggressive-aa (default false); only defined when building against
// LLVM < 16.
#if LLVM_VERSION_MAJOR < 16
cl::opt<bool>
EnzymeAggressiveAA("enzyme-aggressive-aa", cl::init(false), cl::Hidden,
cl::desc("Use more unstable but aggressive LLVM AA"));
#endif
// -enzyme-lower-globals (default false)
cl::opt<bool> EnzymeLowerGlobals(
"enzyme-lower-globals", cl::init(false), cl::Hidden,
cl::desc("Lower globals to locals assuming the global values are not "
"needed outside of this gradient"));
// -enzyme-inline-count (default 10000)
cl::opt<int>
EnzymeInlineCount("enzyme-inline-count", cl::init(10000), cl::Hidden,
cl::desc("Limit of number of functions to inline"));
// -enzyme-coalese (default false). The flag name is spelled "coalese"; it is
// part of the user-visible interface and is therefore left as is.
cl::opt<bool> EnzymeCoalese("enzyme-coalese", cl::init(false), cl::Hidden,
cl::desc("Whether to coalese memory allocations"));
// -enzyme-phi-restructure (default false). Unlike its neighbours this option
// is declared static, i.e. it is local to this translation unit.
static cl::opt<bool> EnzymePHIRestructure(
"enzyme-phi-restructure", cl::init(false), cl::Hidden,
cl::desc("Whether to restructure phi's to have better unwrap behavior"));
// -enzyme-name-instructions (default false)
cl::opt<bool>
EnzymeNameInstructions("enzyme-name-instructions", cl::init(false),
cl::Hidden,
cl::desc("Have enzyme name all instructions"));
// -enzyme-select-opt (default true)
cl::opt<bool> EnzymeSelectOpt("enzyme-select-opt", cl::init(true), cl::Hidden,
cl::desc("Run Enzyme select optimization"));
// -enzyme-auto-sparsity (default false)
cl::opt<bool> EnzymeAutoSparsity("enzyme-auto-sparsity", cl::init(false),
cl::Hidden,
cl::desc("Run Enzyme auto sparsity"));
// -enzyme-post-opt-level (default 0)
cl::opt<int> EnzymePostOptLevel(
"enzyme-post-opt-level", cl::init(0), cl::Hidden,
cl::desc("Post optimization level within Enzyme differentiated function"));
// -enzyme-always-inline (default false)
cl::opt<bool> EnzymeAlwaysInlineDiff(
"enzyme-always-inline", cl::init(false), cl::Hidden,
cl::desc("Mark generated functions as always-inline"));
}
/// Conservatively decide whether passing `val` as an argument of the call
/// `CI` may capture it.
///
/// Returns true when the callee cannot be determined, when `val` is passed in
/// a position beyond the callee's declared parameters (variadic tail), or
/// when the matching declared parameter is not marked nocapture. Calls to the
/// memset/memcpy/memmove intrinsics are never considered capturing. Returns
/// false when `val` is not captured by any argument position (including when
/// it is not an argument at all).
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val) {
  // Resolve the callee, looking through a constant cast of a function.
  Function *Callee = CI->getCalledFunction();
  if (auto *CE = dyn_cast<ConstantExpr>(CI->getCalledOperand()))
    if (CE->isCast())
      if (auto *Target = dyn_cast<Function>(CE->getOperand(0)))
        Callee = Target;

  // Unknown callee: assume the worst.
  if (!Callee)
    return true;

  switch (Callee->getIntrinsicID()) {
  case Intrinsic::memset:
  case Intrinsic::memcpy:
  case Intrinsic::memmove:
    return false;
  default:
    break;
  }

  const size_t NumFormals = Callee->arg_size();
  for (size_t Idx = 0, NumActuals = CI->arg_size(); Idx != NumActuals; ++Idx) {
    if (CI->getArgOperand(Idx) != val)
      continue;
    // Passed through the variadic tail: no attribute to consult, assume
    // captured.
    if (Idx >= NumFormals)
      return true;
    if (!Callee->getArg(Idx)->hasNoCaptureAttr())
      return true;
  }

  // No argument position captures val.
  return false;
}
// Classification of a function recorded by IsFunctionRecursive while it walks
// the static call graph.
enum RecurType {
// The function is currently being visited; its status is not yet decided.
MaybeRecursive = 1,
// The visit finished without encountering the function again.
NotRecursive = 2,
// The function was encountered again while its own visit was in progress.
DefinitelyRecursive = 3,
};
/// Return whether `F` (statically) ends up calling itself.
///
/// `Results` memoizes the classification of every function visited so far and
/// doubles as the "currently on the visit stack" marker: a function is tagged
/// MaybeRecursive while its callees are being explored. Reaching a function
/// that still carries that tag means a cycle back to it was found, so it is
/// promoted to DefinitelyRecursive; if exploration finishes with the tag
/// untouched, the function is NotRecursive. Only direct calls/invokes of
/// functions with a body are followed.
static bool
IsFunctionRecursive(Function *F,
                    std::map<const Function *, RecurType> &Results) {
  auto Known = Results.find(F);
  if (Known != Results.end()) {
    // Seen before: if its visit is still in progress we just closed a cycle.
    if (Known->second == MaybeRecursive)
      Known->second = DefinitelyRecursive; // definitely recursive
  } else {
    Results[F] = MaybeRecursive; // staging

    // Recurse into a direct callee that has a definition.
    auto Explore = [&Results](Function *Callee) {
      if (Callee == nullptr || Callee->empty())
        return;
      IsFunctionRecursive(Callee, Results);
    };

    for (auto &BB : *F) {
      for (auto &I : BB) {
        if (auto *Call = dyn_cast<CallInst>(&I))
          Explore(Call->getCalledFunction());
        else if (auto *Invoke = dyn_cast<InvokeInst>(&I))
          Explore(Invoke->getCalledFunction());
      }
    }

    // Nobody promoted us while we explored our callees.
    if (Results[F] == MaybeRecursive)
      Results[F] = NotRecursive; // not recursive
  }

  assert(Results[F] != MaybeRecursive);
  return Results[F] == DefinitelyRecursive;
}
/// Return true if some user of `AI` is a direct call to one of the OpenMP
/// static-schedule initializers (__kmpc_for_static_init_{4,4u,8,8u}).
///
/// Note that, despite the name, other kinds of users do not make the result
/// false: the function only reports whether at least one such call exists.
static inline bool OnlyUsedInOMP(AllocaInst *AI) {
  for (auto *U : AI->users()) {
    auto *Call = dyn_cast<CallInst>(U);
    if (!Call)
      continue;
    Function *Callee = Call->getCalledFunction();
    if (!Callee)
      continue;
    StringRef Name = Callee->getName();
    if (Name == "__kmpc_for_static_init_4" ||
        Name == "__kmpc_for_static_init_4u" ||
        Name == "__kmpc_for_static_init_8" ||
        Name == "__kmpc_for_static_init_8u")
      return true;
  }
  return false;
}
// Replace every (transitive) use of the pointer `AI` with `rep`, where `rep`
// may have a different pointer type / address space, so a plain
// replaceAllUsesWith is not possible.
//
// Users are processed from a worklist of (replacement, replaced value, user)
// triples. Pointer-producing users (casts, GEPs, PHIs, the Intel subscript
// intrinsic) are re-created on top of the replacement and their own users
// are queued; loads, stores-to, memset and memcpy/memmove are rewritten;
// other calls receive an addrspacecast of the replacement back to the
// original type.
//
// `legal`: when true, calls to julia.write_barrier and
// julia.write_barrier_binding that use the pointer are erased instead of
// being rewritten.
//
// Every user must be an Instruction (enforced by cast<Instruction>). A user
// kind that is not handled ends in llvm_unreachable. Superseded instructions
// (including `AI` itself when it is an instruction) are erased at the end,
// and PostCacheStore is then invoked on every store whose pointer operand
// was rewritten.
void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
// Worklist entries: <value to use instead, value being replaced, user>.
SmallVector<std::tuple<Value *, Value *, Instruction *>, 1> Todo;
for (auto U : AI->users()) {
Todo.push_back(
std::make_tuple((Value *)rep, (Value *)AI, cast<Instruction>(U)));
}
// Instructions to delete once all rewriting is done (erased in reverse order
// of insertion, so users go before the values they use).
SmallVector<Instruction *, 1> toErase;
if (auto I = dyn_cast<Instruction>(AI))
toErase.push_back(I);
SmallVector<StoreInst *, 1> toPostCache;
while (Todo.size()) {
auto cur = Todo.back();
Todo.pop_back();
// Note: this local `rep` shadows the function parameter of the same name.
Value *rep = std::get<0>(cur);
Value *prev = std::get<1>(cur);
Value *inst = std::get<2>(cur);
// addrspacecast: if it already casts to the replacement's address space it
// is redundant and replaced outright; otherwise just retarget its operand.
if (auto ASC = dyn_cast<AddrSpaceCastInst>(inst)) {
auto AS = cast<PointerType>(rep->getType())->getAddressSpace();
if (AS == ASC->getDestAddressSpace()) {
ASC->replaceAllUsesWith(rep);
toErase.push_back(ASC);
continue;
}
ASC->setOperand(0, rep);
continue;
}
// Other casts: a non-pointer result (e.g. ptrtoint) only needs its operand
// swapped; a pointer result is re-created in the replacement's address
// space and its users are queued.
if (auto CI = dyn_cast<CastInst>(inst)) {
if (!CI->getType()->isPointerTy()) {
CI->setOperand(0, rep);
continue;
}
IRBuilder<> B(CI);
auto nCI0 = B.CreateCast(
CI->getOpcode(), rep,
#if LLVM_VERSION_MAJOR < 17
PointerType::get(CI->getType()->getPointerElementType(),
cast<PointerType>(rep->getType())->getAddressSpace())
#else
rep->getType()
#endif
);
// CreateCast may fold to something that is not a CastInst, hence the
// dyn_cast before taking over the name.
if (auto nCI = dyn_cast<CastInst>(nCI0))
nCI->takeName(CI);
for (auto U : CI->users()) {
Todo.push_back(
std::make_tuple((Value *)nCI0, (Value *)CI, cast<Instruction>(U)));
}
toErase.push_back(CI);
continue;
}
// GEP: rebuild with the same source element type and indices on top of the
// replacement pointer, then queue its users.
if (auto GEP = dyn_cast<GetElementPtrInst>(inst)) {
IRBuilder<> B(GEP);
SmallVector<Value *, 1> ind(GEP->indices());
auto nGEP = cast<GetElementPtrInst>(
B.CreateGEP(GEP->getSourceElementType(), rep, ind));
nGEP->takeName(GEP);
for (auto U : GEP->users()) {
Todo.push_back(
std::make_tuple((Value *)nGEP, (Value *)GEP, cast<Instruction>(U)));
}
toErase.push_back(GEP);
continue;
}
// PHI: can only be rebuilt once a replacement is known for every incoming
// value. Replacements come from the current entry and from any other
// pending worklist entries that target the same PHI.
if (auto P = dyn_cast<PHINode>(inst)) {
auto NumOperands = P->getNumIncomingValues();
SmallVector<Value *, 1> replacedOperands(NumOperands, nullptr);
for (size_t i = 0; i < NumOperands; i++)
if (P->getOperand(i) == prev)
replacedOperands[i] = rep;
for (auto tval : Todo) {
if (std::get<2>(tval) != P)
continue;
for (size_t i = 0; i < NumOperands; i++)
if (P->getOperand(i) == std::get<1>(tval)) {
replacedOperands[i] = std::get<0>(tval);
}
}
bool allReplaced = true;
for (size_t i = 0; i < NumOperands; i++) {
if (!replacedOperands[i]) {
allReplaced = false;
}
}
if (!allReplaced) {
// Not every incoming value is known yet. If non-PHI work remains, defer
// this PHI by re-inserting it at the front of the worklist (entries are
// popped from the back, so it is revisited after the others).
bool remainingArePHIs = true;
for (auto v : Todo) {
if (isa<PHINode>(std::get<2>(v))) {
} else {
remainingArePHIs = false;
}
}
if (!remainingArePHIs) {
Todo.insert(Todo.begin(), cur);
// NOTE(review): unconditional diagnostic print to stderr; looks like
// leftover debugging output -- confirm whether it should stay.
llvm::errs() << " continuing\n";
continue;
}
// NOTE(review): when only PHIs remain, control falls through; none of
// the handlers below match a PHINode, so this appears to end at the
// llvm_unreachable at the bottom of the loop.
} else {
// All incoming values have replacements: build the new PHI, queue the
// old PHI's users and drop the now-redundant pending entries for it.
IRBuilder<> B(&(*P->getParent()->getFirstNonPHIOrDbgOrLifetime()));
auto nP = B.CreatePHI(rep->getType(), P->getNumOperands());
for (size_t i = 0; i < NumOperands; i++) {
nP->addIncoming(replacedOperands[i], P->getIncomingBlock(i));
}
nP->takeName(P);
for (auto U : P->users()) {
Todo.push_back(
std::make_tuple((Value *)nP, (Value *)P, cast<Instruction>(U)));
}
toErase.push_back(P);
for (int i = Todo.size() - 1; i >= 0; i--) {
if (std::get<2>(Todo[i]) != P)
continue;
Todo.erase(Todo.begin() + i);
}
continue;
}
}
// Intel subscript intrinsic: re-issue the call with the same index
// operands (0, 1, 2, 4) and the replacement as the pointer operand (3).
if (auto II = dyn_cast<IntrinsicInst>(inst)) {
if (isIntelSubscriptIntrinsic(*II)) {
const std::array<size_t, 4> idxArgsIndices{{0, 1, 2, 4}};
const size_t ptrArgIndex = 3;
SmallVector<Value *, 5> args(5);
for (auto i : idxArgsIndices) {
Value *idx = II->getOperand(i);
args[i] = idx;
}
args[ptrArgIndex] = rep;
IRBuilder<> B(II);
auto nII = cast<CallInst>(B.CreateCall(II->getCalledFunction(), args));
// Must copy the elementtype attribute as it is needed by the intrinsic
nII->addParamAttr(
ptrArgIndex,
II->getParamAttr(ptrArgIndex, Attribute::AttrKind::ElementType));
nII->takeName(II);
for (auto U : II->users()) {
Todo.push_back(
std::make_tuple((Value *)nII, (Value *)II, cast<Instruction>(U)));
}
toErase.push_back(II);
continue;
}
}
// Load: operand 0 is the pointer; the loaded value is unaffected.
if (auto LI = dyn_cast<LoadInst>(inst)) {
LI->setOperand(0, rep);
continue;
}
// Store *to* the pointer: operand 1 is the pointer operand. A store that
// uses the pointer as the stored value is not handled here and falls
// through to the handlers below.
if (auto SI = dyn_cast<StoreInst>(inst)) {
if (SI->getPointerOperand() == prev) {
SI->setOperand(1, rep);
toPostCache.push_back(SI);
continue;
}
}
// memset: re-declare the intrinsic for the new pointer type and re-emit
// the call with the replacement as destination.
if (auto MS = dyn_cast<MemSetInst>(inst)) {
IRBuilder<> B(MS);
Value *nargs[] = {rep, MS->getArgOperand(1), MS->getArgOperand(2),
MS->getArgOperand(3)};
Type *tys[] = {nargs[0]->getType(), nargs[2]->getType()};
auto nMS = cast<CallInst>(B.CreateCall(
getIntrinsicDeclaration(MS->getParent()->getParent()->getParent(),
Intrinsic::memset, tys),
nargs));
nMS->copyMetadata(*MS);
nMS->setAttributes(MS->getAttributes());
toErase.push_back(MS);
continue;
}
// memcpy/memmove: substitute whichever of dest/src is the replaced value
// and re-emit with the matching overloaded declaration.
if (auto MTI = dyn_cast<MemTransferInst>(inst)) {
IRBuilder<> B(MTI);
Value *nargs[4] = {MTI->getArgOperand(0), MTI->getArgOperand(1),
MTI->getArgOperand(2), MTI->getArgOperand(3)};
if (nargs[0] == prev)
nargs[0] = rep;
if (nargs[1] == prev)
nargs[1] = rep;
Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(),
nargs[2]->getType()};
auto nMTI = cast<CallInst>(B.CreateCall(
getIntrinsicDeclaration(MTI->getParent()->getParent()->getParent(),
MTI->getIntrinsicID(), tys),
nargs));
nMTI->copyMetadata(*MTI);
nMTI->setAttributes(MTI->getAttributes());
toErase.push_back(MTI);
continue;
}
// Any other call: Julia write barriers are dropped when `legal`; otherwise
// the call keeps its signature and is handed the replacement cast back to
// the original pointer type.
if (auto CI = dyn_cast<CallInst>(inst)) {
if (auto F = CI->getCalledFunction()) {
if (F->getName() == "julia.write_barrier" && legal) {
toErase.push_back(CI);
continue;
}
if (F->getName() == "julia.write_barrier_binding" && legal) {
toErase.push_back(CI);
continue;
}
}
IRBuilder<> B(CI);
auto Addr = B.CreateAddrSpaceCast(rep, prev->getType());
for (size_t i = 0; i < CI->arg_size(); i++) {
if (CI->getArgOperand(i) == prev) {
CI->setArgOperand(i, Addr);
}
}
continue;
}
// Unhandled user: dump the enclosing function for diagnosis and abort.
if (auto I = dyn_cast<Instruction>(inst))
llvm::errs() << *I->getParent()->getParent() << "\n";
llvm_unreachable("Illegal address space propagation");
}
for (auto I : llvm::reverse(toErase)) {
I->eraseFromParent();
}
// Run PostCacheStore with the builder positioned right after each rewritten
// store.
for (auto SI : toPostCache) {
IRBuilder<> B(SI->getNextNode());
PostCacheStore(SI, B);
}
}
/// Convert necessary stack allocations into mallocs for use in the reverse
/// pass. Specifically if we're not topLevel all allocations must be upgraded
/// Even if topLevel any allocations that aren't in the entry block (and
/// therefore may not be reachable in the reverse pass) must be upgraded.
///
/// Concretely: every alloca outside the blocks listed in `Unreachable` is
/// converted unless OnlyUsedInOMP returns true for it, or it sits in the
/// entry block while `mode` is DerivativeMode::ReverseModeCombined. The
/// replacement is produced by CreateAllocation and tagged with metadata
/// ("enzyme_fromstack" carrying the original alignment, plus
/// "enzyme_caststack" / "enzyme_zerostack" where applicable).
static inline void
UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode,
SmallPtrSetImpl<llvm::BasicBlock *> &Unreachable) {
// Collect first, rewrite afterwards, so the instruction lists are not
// modified while being iterated.
SmallVector<AllocaInst *, 4> ToConvert;
for (auto &BB : *NewF) {
if (Unreachable.count(&BB))
continue;
for (auto &I : BB) {
if (auto AI = dyn_cast<AllocaInst>(&I)) {
bool UsableEverywhere = AI->getParent() == &NewF->getEntryBlock();
// TODO use is_value_needed_in_reverse (requiring GradientUtils)
if (OnlyUsedInOMP(AI))
continue;
if (!UsableEverywhere || mode != DerivativeMode::ReverseModeCombined) {
ToConvert.push_back(AI);
}
}
}
}
#if LLVM_VERSION_MAJOR >= 22
// Lazily created variadic placeholder functions that stand in for the
// lifetime intrinsics on converted allocas (see below).
Function *start_lifetime = nullptr;
Function *end_lifetime = nullptr;
#endif
for (auto AI : ToConvert) {
// Move the name from the alloca to the new allocation.
std::string nam = AI->getName().str();
AI->setName("");
#if LLVM_VERSION_MAJOR >= 22
// Replace llvm.lifetime.start/end users of the alloca with calls to
// "llvm.enzyme.lifetime_start"/"llvm.enzyme.lifetime_end" carrying the same
// operands. PreProcessCache::LowerAllocAddr looks these placeholder
// functions up by name later on.
for (auto U : llvm::make_early_inc_range(AI->users())) {
if (auto II = dyn_cast<IntrinsicInst>(U)) {
if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
if (!start_lifetime) {
start_lifetime = cast<Function>(
NewF->getParent()
->getOrInsertFunction(
"llvm.enzyme.lifetime_start",
FunctionType::get(Type::getVoidTy(NewF->getContext()),
{}, true))
.getCallee());
}
IRBuilder<> B(II);
SmallVector<Value *, 2> args(II->arg_size());
for (unsigned i = 0; i < II->arg_size(); ++i) {
args[i] = II->getArgOperand(i);
}
auto newI = B.CreateCall(start_lifetime, args);
newI->takeName(II);
newI->setDebugLoc(II->getDebugLoc());
II->eraseFromParent();
continue;
}
if (II->getIntrinsicID() == Intrinsic::lifetime_end) {
if (!end_lifetime) {
end_lifetime = cast<Function>(
NewF->getParent()
->getOrInsertFunction(
"llvm.enzyme.lifetime_end",
FunctionType::get(Type::getVoidTy(NewF->getContext()),
{}, true))
.getCallee());
}
IRBuilder<> B(II);
SmallVector<Value *, 2> args(II->arg_size());
for (unsigned i = 0; i < II->arg_size(); ++i) {
args[i] = II->getArgOperand(i);
}
auto newI = B.CreateCall(end_lifetime, args);
newI->takeName(II);
newI->setDebugLoc(II->getDebugLoc());
II->eraseFromParent();
continue;
}
}
}
#endif
// Ensure we insert the malloc after the allocas
// NOTE(review): the loop advances to the last alloca of the consecutive run
// that starts at AI, and IRBuilder<>(Instruction *) inserts *before* the
// given instruction -- so the allocation is emitted before that last alloca
// (or before AI itself when no alloca follows). Confirm this matches the
// intent stated above.
Instruction *insertBefore = AI;
while (isa<AllocaInst>(insertBefore->getNextNode())) {
insertBefore = insertBefore->getNextNode();
assert(insertBefore);
}
auto i64 = Type::getInt64Ty(NewF->getContext());
IRBuilder<> B(insertBefore);
CallInst *CI = nullptr;
Instruction *ZeroInst = nullptr;
// CreateAllocation reports the underlying allocation call through CI and,
// when zeroing is requested (EnzymeZeroCache), the zeroing instruction
// through ZeroInst. The array size is normalized to i64 first.
auto rep = CreateAllocation(
B, AI->getAllocatedType(), B.CreateZExtOrTrunc(AI->getArraySize(), i64),
nam, &CI, /*ZeroMem*/ EnzymeZeroCache ? &ZeroInst : nullptr);
// Remember that this allocation came from the stack, and with which
// alignment.
auto align = AI->getAlign().value();
CI->setMetadata(
"enzyme_fromstack",
MDNode::get(CI->getContext(),
{ConstantAsMetadata::get(ConstantInt::get(
IntegerType::get(AI->getContext(), 64), align))}));
// Carry Enzyme/Julia annotations over from the alloca.
for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type",
"enzymejl_allocart", "enzymejl_allocart_name"})
if (auto M = AI->getMetadata(MD))
CI->setMetadata(MD, M);
// If the returned value is not the call itself (e.g. a cast of it), mark it.
if (rep != CI) {
cast<Instruction>(rep)->setMetadata("enzyme_caststack",
MDNode::get(CI->getContext(), {}));
}
if (ZeroInst) {
ZeroInst->setMetadata("enzyme_zerostack",
MDNode::get(CI->getContext(), {}));
}
// Same address space: a plain RAUW suffices. Otherwise every user has to
// be rewritten for the new pointer type.
auto PT0 = cast<PointerType>(rep->getType());
auto PT1 = cast<PointerType>(AI->getType());
if (PT0->getAddressSpace() != PT1->getAddressSpace()) {
RecursivelyReplaceAddressSpace(AI, rep, /*legal*/ false);
} else {
assert(rep->getType() == AI->getType());
AI->replaceAllUsesWith(rep);
AI->eraseFromParent();
}
}
}
// Create a stack variable (an alloca of type `T` at the start of NewF's entry
// block) that holds the size of the allocation `Ptr` points to, and return
// it.
//
// The size is determined by walking backwards from `Ptr` (as used by the
// realloc call `Loc`) through casts, PHIs and simplifiable loads until a
// recognised source is found. For each source a store of the size into the
// returned alloca is emitted at the corresponding program point:
//   * undef, null, or a small/negative integer constant  -> 0
//   * malloc(n)                                          -> n
//   * calloc(a, b)                                       -> a * b
//   * realloc(...)                                       -> reallocSizes[call]
//   * load of a pointer written by posix_memalign(&p, _, n) earlier in the
//     same block with no intervening memory write        -> n
//
// If a source cannot be classified, a "DynamicReallocSize" failure is
// emitted via EmitFailure and the (possibly partially initialised) alloca is
// returned immediately.
//
// `reallocSizes` must contain an entry for every realloc call that can be
// reached by the walk.
static inline AllocaInst *
OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
                  const std::map<CallInst *, Value *> &reallocSizes) {
  IRBuilder<> B(&*NewF->getEntryBlock().begin());
  AllocaInst *AI = B.CreateAlloca(T);

  // Worklist of (pointer value, instruction before which to store its size).
  std::set<std::pair<Value *, Instruction *>> seen;
  std::deque<std::pair<Value *, Instruction *>> todo = {{Ptr, Loc}};

  while (todo.size()) {
    auto next = todo.front();
    todo.pop_front();
    if (seen.count(next))
      continue;
    seen.insert(next);

    // Look through cast instructions; the size is then stored right before
    // the cast.
    if (auto CI = dyn_cast<CastInst>(next.first)) {
      todo.push_back({CI->getOperand(0), CI});
      continue;
    }

    // Assume zero size if realloc of undef pointer
    if (isa<UndefValue>(next.first)) {
      B.SetInsertPoint(next.second);
      B.CreateStore(ConstantInt::get(T, 0), AI);
      continue;
    }

    // Look through constant-expression casts.
    if (auto CE = dyn_cast<ConstantExpr>(next.first)) {
      if (CE->isCast()) {
        todo.push_back({CE->getOperand(0), next.second});
        continue;
      }
    }

    // A null pointer has no previous allocation.
    if (auto C = dyn_cast<Constant>(next.first)) {
      if (C->isNullValue()) {
        B.SetInsertPoint(next.second);
        B.CreateStore(ConstantInt::get(T, 0), AI);
        continue;
      }
    }

    if (auto CI = dyn_cast<ConstantInt>(next.first)) {
      // if negative or below 0xFFF this cannot possibly represent
      // a real pointer, so ignore this case by setting to 0
      if (CI->isNegative() || CI->getLimitedValue() <= 0xFFF) {
        B.SetInsertPoint(next.second);
        B.CreateStore(ConstantInt::get(T, 0), AI);
        continue;
      }
    }

    // Todo consider more general method for selects
    // Only a select between two "impossible pointer" constants is handled;
    // any other select falls through to the checks below.
    if (auto SI = dyn_cast<SelectInst>(next.first)) {
      if (auto C1 = dyn_cast<ConstantInt>(SI->getTrueValue())) {
        // if negative or below 0xFFF this cannot possibly represent
        // a real pointer, so ignore this case by setting to 0
        if (C1->isNegative() || C1->getLimitedValue() <= 0xFFF) {
          if (auto C2 = dyn_cast<ConstantInt>(SI->getFalseValue())) {
            if (C2->isNegative() || C2->getLimitedValue() <= 0xFFF) {
              B.SetInsertPoint(next.second);
              B.CreateStore(ConstantInt::get(T, 0), AI);
              continue;
            }
          }
        }
      }
    }

    // For a PHI, record the size of each incoming value at the end of the
    // corresponding predecessor.
    if (auto PN = dyn_cast<PHINode>(next.first)) {
      for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
        todo.push_back({PN->getIncomingValue(i),
                        PN->getIncomingBlock(i)->getTerminator()});
      }
      continue;
    }

    // Known allocation functions.
    if (auto CI = dyn_cast<CallInst>(next.first)) {
      if (auto F = CI->getCalledFunction()) {
        if (F->getName() == "malloc") {
          B.SetInsertPoint(next.second);
          B.CreateStore(CI->getArgOperand(0), AI);
          continue;
        }
        if (F->getName() == "calloc") {
          B.SetInsertPoint(next.second);
          B.CreateStore(B.CreateMul(CI->getArgOperand(0), CI->getArgOperand(1)),
                        AI);
          continue;
        }
        if (F->getName() == "realloc") {
          auto known = reallocSizes.find(CI);
          assert(known != reallocSizes.end());
          B.SetInsertPoint(next.second);
          B.CreateStore(known->second, AI);
          continue;
        }
      }
    }

    if (auto LI = dyn_cast<LoadInst>(next.first)) {
      // Scan backwards within the block for the posix_memalign call that
      // filled the loaded slot; stop at the first instruction that may write
      // memory.
      bool success = false;
      for (Instruction *prev = LI->getPrevNode(); prev != nullptr;
           prev = prev->getPrevNode()) {
        if (auto CI = dyn_cast<CallInst>(prev)) {
          if (auto F = CI->getCalledFunction()) {
            if (F->getName() == "posix_memalign" &&
                CI->getArgOperand(0) == LI->getOperand(0)) {
              B.SetInsertPoint(next.second);
              B.CreateStore(CI->getArgOperand(2), AI);
              success = true;
              break;
            }
          }
        }
        if (prev->mayWriteToMemory()) {
          break;
        }
      }
      if (success)
        continue;

      // Otherwise try to resolve the load to the value it reads.
      auto v2 = simplifyLoad(LI);
      if (v2) {
        todo.push_back({v2, next.second});
        continue;
      }
    }

    // Unknown source: report and give up. (Code that used to follow this
    // return -- a malloc_usable_size/malloc_size/_msize based fallback -- was
    // unreachable and has been removed.)
    EmitFailure("DynamicReallocSize", Loc->getDebugLoc(), Loc,
                "could not statically determine size of realloc ", *Loc,
                " - because of - ", *next.first);
    return AI;
  }

  return AI;
}
/// Inline every direct call in `NewF` whose callee carries the AlwaysInline
/// function attribute.
///
/// Before collecting calls, cached analyses for `NewF` are invalidated except
/// AssumptionAnalysis and TargetLibraryAnalysis. While scanning, instructions
/// tagged with "enzyme_zerostack" metadata whose first operand is based on an
/// alloca are erased.
void PreProcessCache::AlwaysInline(Function *NewF) {
  PreservedAnalyses Kept;
  Kept.preserve<AssumptionAnalysis>();
  Kept.preserve<TargetLibraryAnalysis>();
  FAM.invalidate(*NewF, Kept);

  SmallVector<CallInst *, 2> ToInline;
  // TODO this logic should be combined with the dynamic loop emission
  // to minimize the number of branches if the realloc is used for multiple
  // values with the same bound.
  for (auto &BB : *NewF) {
    // Early-inc iteration: the current instruction may be erased.
    for (auto &I : make_early_inc_range(BB)) {
      if (hasMetadata(&I, "enzyme_zerostack") &&
          isa<AllocaInst>(getBaseObject(I.getOperand(0)))) {
        I.eraseFromParent();
        continue;
      }

      auto *Call = dyn_cast<CallInst>(&I);
      if (!Call)
        continue;
      Function *Callee = Call->getCalledFunction();
      if (Callee && Callee->hasFnAttribute(Attribute::AlwaysInline))
        ToInline.push_back(Call);
    }
  }

  for (auto *Call : ToInline) {
    InlineFunctionInfo IFI;
#if LLVM_VERSION_MAJOR >= 18 && LLVM_VERSION_MAJOR < 21
    // Caller and callee must agree on the debug-info record format before
    // inlining; convert the callee to match the call site's block.
    Function *Callee = Call->getCalledFunction();
    const bool CallerIsNew = Call->getParent()->IsNewDbgInfoFormat;
    if (CallerIsNew != Callee->IsNewDbgInfoFormat) {
      if (CallerIsNew)
        Callee->convertToNewDbgValues();
      else
        Callee->convertFromNewDbgValues();
    }
#endif
    InlineFunction(*Call, IFI);
  }
}
// Replace extractvalue instructions by the value they extract whenever
// GradientUtils::extractMeta can find it without creating a fallback, then
// delete insertvalue instructions that are left without any use.
void simplifyExtractions(Function *NewF) {
  // Phase 1: rewrite/remove extractions. The last instruction of each block
  // is deliberately not visited.
  for (auto &BB : *NewF) {
    IRBuilder<> B(&BB);
    if (BB.empty())
      continue;

    const auto Stop = std::prev(BB.end());
    auto Cur = BB.begin();
    while (Cur != Stop) {
      // Step the iterator before touching the instruction, since the
      // instruction may be erased below.
      Instruction *Inst = &*Cur;
      ++Cur;

      auto *Extract = dyn_cast<ExtractValueInst>(Inst);
      if (!Extract)
        continue;

      auto Replacement = GradientUtils::extractMeta(
          B, Extract->getAggregateOperand(), Extract->getIndices(),
          Extract->getName(), /*fallback*/ false);
      if (!Replacement)
        continue;

      Extract->replaceAllUsesWith(Replacement);
      Extract->eraseFromParent();
    }
  }

  // Phase 2: remove dead insertions. A worklist is used because erasing one
  // insertvalue can leave the aggregate it was built from without uses too.
  SmallVector<InsertValueInst *, 1> Dead;
  for (auto &BB : *NewF)
    for (auto &Inst : BB)
      if (auto *Insert = dyn_cast<InsertValueInst>(&Inst))
        if (Insert->use_empty())
          Dead.push_back(Insert);

  while (!Dead.empty()) {
    InsertValueInst *Insert = Dead.pop_back_val();
    auto *Parent = dyn_cast<InsertValueInst>(Insert->getAggregateOperand());
    Insert->eraseFromParent();
    if (Parent && Parent->use_empty())
      Dead.push_back(Parent);
  }
}
// Replace every instruction tagged with "enzyme_backstack" metadata by the
// alloca it is derived from.
//
// After running simplifyExtractions, each tagged instruction's operand 0
// (looked through a single bitcast) must be an AllocaInst; all users of the
// tagged instruction are then rewritten onto that alloca via
// RecursivelyReplaceAddressSpace with legal=true. On LLVM >= 22 the
// "llvm.enzyme.lifetime_start"/"llvm.enzyme.lifetime_end" placeholder calls
// are finally turned back into lifetime intrinsics or dropped.
void PreProcessCache::LowerAllocAddr(Function *NewF) {
simplifyExtractions(NewF);
// Collect first; the rewriting below erases instructions.
SmallVector<Instruction *, 1> Todo;
for (auto &BB : *NewF) {
for (auto &I : BB) {
if (hasMetadata(&I, "enzyme_backstack")) {
Todo.push_back(&I);
// TODO
// I.eraseMetadata("enzyme_backstack");
}
}
}
for (auto T : Todo) {
auto T0 = T->getOperand(0);
if (auto CI = dyn_cast<BitCastInst>(T0))
T0 = CI->getOperand(0);
// cast<> asserts that the underlying object really is an alloca.
auto AI = cast<AllocaInst>(T0);
llvm::Value *AIV = AI;
#if LLVM_VERSION_MAJOR < 17
// With typed pointers the alloca may need a bitcast so that its pointee
// type matches the instruction being replaced.
if (AIV->getType()->getPointerElementType() !=
T->getType()->getPointerElementType()) {
IRBuilder<> B(AI->getNextNode());
AIV = B.CreateBitCast(
AIV, PointerType::get(
T->getType()->getPointerElementType(),
cast<PointerType>(AI->getType())->getAddressSpace()));
}
#endif
RecursivelyReplaceAddressSpace(T, AIV, /*legal*/ true);
}
#if LLVM_VERSION_MAJOR >= 22
{
// Either lookup may return null if the placeholder was never created; the
// comparisons below are then simply never true for that function.
auto start_lifetime =
NewF->getParent()->getFunction("llvm.enzyme.lifetime_start");
auto end_lifetime =
NewF->getParent()->getFunction("llvm.enzyme.lifetime_end");
// Note: this Todo shadows the outer list of the same name.
SmallVector<CallInst *, 1> Todo;
for (auto &BB : *NewF) {
for (auto &I : BB) {
if (auto CB = dyn_cast<CallInst>(&I)) {
if (!CB->getCalledFunction())
continue;
if (CB->getCalledFunction() == start_lifetime ||
CB->getCalledFunction() == end_lifetime) {
Todo.push_back(CB);
}
}
}
}
for (auto CB : Todo) {
// A placeholder whose pointer operand is not an alloca is just removed.
// NOTE(review): this reads operand 1 as the pointer and operand 0 as a
// constant size, i.e. it assumes the placeholder carries (size, ptr)
// operands -- confirm against the lifetime intrinsic signature of the
// LLVM version being targeted.
if (!isa<AllocaInst>(CB->getArgOperand(1))) {
CB->eraseFromParent();
continue;
}
IRBuilder<> B(CB);
if (CB->getCalledFunction() == start_lifetime) {
B.CreateLifetimeStart(CB->getArgOperand(1),
cast<ConstantInt>(CB->getArgOperand(0)));
} else {
B.CreateLifetimeEnd(CB->getArgOperand(1),
cast<ConstantInt>(CB->getArgOperand(0)));
}
CB->eraseFromParent();
}
}
#endif
}
/// Replace calls to realloc in `NewF` with an explicit implementation built
/// from malloc, memcpy and free.
///
/// Each realloc(p, req) becomes: load the tracked size `old` of `p` (see
/// OldAllocationSize); if req <= old keep `p`, otherwise allocate
/// nextPowerOfTwo(req) bytes, copy `old` bytes over and free `p`.
///
/// If `mem2reg` is set, PromotePass and GVN are run on the function first.
/// PromotePass is always run again at the end.
void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
if (mem2reg) {
auto PA = PromotePass().run(*NewF, FAM);
FAM.invalidate(*NewF, PA);
#if !defined(FLANG)
PA = GVNPass().run(*NewF, FAM);
#else
PA = GVN().run(*NewF, FAM);
#endif
FAM.invalidate(*NewF, PA);
}
SmallVector<CallInst *, 4> ToConvert;
// For every realloc call: the value holding the size of the buffer it
// returns. Initially an empty placeholder PHI placed right after the call,
// which is substituted by the real size once the call has been lowered.
std::map<CallInst *, Value *> reallocSizes;
// Integer type of realloc's size argument (taken from the last call seen).
IntegerType *T = nullptr;
for (auto &BB : *NewF) {
for (auto &I : BB) {
if (auto CI = dyn_cast<CallInst>(&I)) {
if (auto F = CI->getCalledFunction()) {
if (F->getName() == "realloc") {
ToConvert.push_back(CI);
IRBuilder<> B(CI->getNextNode());
T = cast<IntegerType>(CI->getArgOperand(1)->getType());
reallocSizes[CI] = B.CreatePHI(T, 0);
}
}
}
}
}
// NOTE(review): declared but never used within this function.
SmallVector<AllocaInst *, 4> memoryLocations;
for (auto CI : ToConvert) {
assert(T);
// Stack slot that holds the current size of the buffer being reallocated.
AllocaInst *AI =
OldAllocationSize(CI->getArgOperand(0), CI, NewF, T, reallocSizes);
BasicBlock *resize =
BasicBlock::Create(CI->getContext(), "resize" + CI->getName(), NewF);
assert(resize->getParent() == NewF);
// Split the block at the realloc call and replace the unconditional branch
// created by the split with the size check.
BasicBlock *splitParent = CI->getParent();
BasicBlock *nextBlock = splitParent->splitBasicBlock(CI);
splitParent->getTerminator()->eraseFromParent();
IRBuilder<> B(splitParent);
Value *p = CI->getArgOperand(0);
Value *req = CI->getArgOperand(1);
Value *old = B.CreateLoad(AI->getAllocatedType(), AI);
Value *cmp = B.CreateICmpULE(req, old);
// if (req <= old) (unsigned): the existing buffer is large enough, skip
// the resize block.
B.CreateCondBr(cmp, nextBlock, resize);
B.SetInsertPoint(resize);
// size_t newsize = nextPowerOfTwo(req);
// void* next = malloc(newsize);
// memcpy(next, p, old);
// free(p);
// return { next, req };
Value *newsize = nextPowerOfTwo(B, req);
Module *M = NewF->getParent();
Type *BPTy = getInt8PtrTy(NewF->getContext());
auto MallocFunc =
M->getOrInsertFunction("malloc", BPTy, newsize->getType());
auto next = B.CreateCall(MallocFunc, newsize);
B.SetInsertPoint(resize);
// Non-volatile memcpy of the `old` bytes currently held by p.
auto volatile_arg = ConstantInt::getFalse(CI->getContext());
Value *nargs[] = {next, p, old, volatile_arg};
Type *tys[] = {next->getType(), p->getType(), old->getType()};
auto memcpyF =
getIntrinsicDeclaration(NewF->getParent(), Intrinsic::memcpy, tys);
auto mem = cast<CallInst>(B.CreateCall(memcpyF, nargs));
mem->setCallingConv(memcpyF->getCallingConv());
Type *VoidTy = Type::getVoidTy(M->getContext());
auto FreeFunc = M->getOrInsertFunction("free", VoidTy, BPTy);
B.CreateCall(FreeFunc, p);
B.SetInsertPoint(resize);
B.CreateBr(nextBlock);
// else
// return { p, old }
// Merge both paths: the resulting pointer is a PHI, the resulting size a
// select on the same comparison.
B.SetInsertPoint(&*nextBlock->begin());
PHINode *retPtr = B.CreatePHI(CI->getType(), 2);
retPtr->addIncoming(p, splitParent);
retPtr->addIncoming(next, resize);
CI->replaceAllUsesWith(retPtr);
std::string nam = CI->getName().str();
CI->setName("");
retPtr->setName(nam);
Value *nextSize = B.CreateSelect(cmp, old, req);
// Resolve the placeholder PHI created during collection.
reallocSizes[CI]->replaceAllUsesWith(nextSize);
cast<PHINode>(reallocSizes[CI])->eraseFromParent();
reallocSizes[CI] = nextSize;
}
// The original calls are erased only after all of them have been lowered.
for (auto CI : ToConvert) {
CI->eraseFromParent();
}
// A default-constructed PreservedAnalyses preserves nothing, so this drops
// all cached analyses for NewF before promoting the size slots.
PreservedAnalyses PA;
FAM.invalidate(*NewF, PA);
PA = PromotePass().run(*NewF, FAM);
FAM.invalidate(*NewF, PA);
}
// Return (creating it on first use) an internal wrapper around the
// two-parameter function `F`, named "enzyme_wrapmpi$$<F>#".
//
// The wrapper takes only F's first parameter. It allocates a stack slot of
// F's return type, passes that slot as F's second argument (converted with
// ptrtoint when F's second parameter is an integer rather than a pointer),
// and returns the value loaded from the slot afterwards. The wrapper is
// annotated as willreturn, mustprogress, speculatable, nounwind,
// alwaysinline, nofree, nosync, read-only, inaccessible-memory-only, and
// "enzyme_inactive".
Function *CreateMPIWrapper(Function *F) {
  Module *M = F->getParent();
  const std::string WrapperName =
      ("enzyme_wrapmpi$$" + F->getName() + "#").str();
  if (Function *Existing = M->getFunction(WrapperName))
    return Existing;

  FunctionType *OrigTy = F->getFunctionType();
  Type *RetTy = F->getReturnType();
  Type *WrapperParams[] = {OrigTy->getParamType(0)};
  FunctionType *WrapperTy =
      FunctionType::get(RetTy, WrapperParams, /*isVarArg*/ false);
  Function *W = Function::Create(WrapperTy, GlobalVariable::InternalLinkage,
                                 WrapperName, M);

  const llvm::Attribute::AttrKind FnAttrs[] = {
    Attribute::WillReturn,
    Attribute::MustProgress,
#if LLVM_VERSION_MAJOR < 16
    Attribute::ReadOnly,
#endif
    Attribute::Speculatable,
    Attribute::NoUnwind,
    Attribute::AlwaysInline,
    Attribute::NoFree,
    Attribute::NoSync,
#if LLVM_VERSION_MAJOR < 16
    Attribute::InaccessibleMemOnly
#endif
  };
  for (auto Kind : FnAttrs)
    W->addFnAttr(Kind);
#if LLVM_VERSION_MAJOR >= 16
  // From LLVM 16 on, memory effects are set through dedicated setters.
  W->setOnlyAccessesInaccessibleMemory();
  W->setOnlyReadsMemory();
#endif
  W->addFnAttr(Attribute::get(F->getContext(), "enzyme_inactive"));

  BasicBlock *Entry = BasicBlock::Create(W->getContext(), "entry", W);
  IRBuilder<> B(Entry);

  // Slot that F writes its result into.
  auto *Slot = B.CreateAlloca(RetTy);
  Value *CallArgs[] = {W->arg_begin(), Slot};

  // F may take the out-pointer as an integer; convert in that case.
  Type *OutParamTy = OrigTy->getParamType(1);
  if (!isa<PointerType>(OutParamTy)) {
    assert(isa<IntegerType>(OutParamTy));
    CallArgs[1] = B.CreatePtrToInt(CallArgs[1], OutParamTy);
  }

  B.CreateCall(F, CallArgs);
  B.CreateRet(B.CreateLoad(RetTy, Slot));
  return W;
}
/// Rewrite MPI rank/size queries and OpenMP static-init bound arguments in
/// \p NewF.
///
/// For every direct call/invoke of MPI_Comm_rank, PMPI_Comm_rank,
/// MPI_Comm_size or PMPI_Comm_size:
///   * the call is replaced by a call to the wrapper produced by
///     CreateMPIWrapper, which returns the queried value directly;
///   * uses of the original call's result are replaced by the constant 0;
///   * the wrapper's result is stored to the original out-pointer (looking
///     through casts), and when that pointer is an alloca used only by loads,
///     loads dominated by the wrapper call are replaced by its result.
///
/// For every direct call of __kmpc_for_static_init_{4,4u,8,8u}, argument
/// operands 4..6 (which must be allocas) are redirected to fresh "_smpl"
/// allocas, with the value copied in before the call and copied back after it,
/// and the call-site arguments are passed to addCallSiteNoCapture.
///
/// \param NewF function to rewrite in place.
/// \param FAM  analysis manager; the dominator tree is read from it and
///             cached analyses are invalidated on exit except those listed in
///             the preserved set at the bottom of the function.
static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(NewF);
  SmallVector<CallBase *, 4> Todo;
  SmallVector<CallBase *, 4> OMPBounds;

  // Collect first and rewrite afterwards so the instruction lists are not
  // modified while they are being iterated.
  for (auto &BB : NewF) {
    for (auto &I : BB) {
      if (auto CI = dyn_cast<CallBase>(&I)) {
        Function *Fn = CI->getCalledFunction();
        if (Fn == nullptr)
          continue;
        if (Fn->getName() == "MPI_Comm_rank" ||
            Fn->getName() == "PMPI_Comm_rank" ||
            Fn->getName() == "MPI_Comm_size" ||
            Fn->getName() == "PMPI_Comm_size") {
          Todo.push_back(CI);
        }
        if (Fn->getName() == "__kmpc_for_static_init_4" ||
            Fn->getName() == "__kmpc_for_static_init_4u" ||
            Fn->getName() == "__kmpc_for_static_init_8" ||
            Fn->getName() == "__kmpc_for_static_init_8u") {
          OMPBounds.push_back(CI);
        }
      }
    }
  }
  if (Todo.size() == 0 && OMPBounds.size() == 0)
    return;

  for (auto CI : Todo) {
    IRBuilder<> B(CI);
    // The wrapper only takes the first argument of the original query.
    Value *arg[] = {CI->getArgOperand(0)};
    SmallVector<OperandBundleDef, 2> Defs;
    CI->getOperandBundlesAsDefs(Defs);
    CallBase *res = nullptr;
    if (auto II = dyn_cast<InvokeInst>(CI))
      res = B.CreateInvoke(CreateMPIWrapper(CI->getCalledFunction()),
                           II->getNormalDest(), II->getUnwindDest(), arg, Defs);
    else
      res = B.CreateCall(CreateMPIWrapper(CI->getCalledFunction()), arg, Defs);
    Value *storePointer = CI->getArgOperand(1);
    // Comm_rank and Comm_size return Err, assume 0 is success
    CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0));
    CI->eraseFromParent();

    // Look through casts on the out-pointer, deleting any that became dead
    // now that the original call is gone.
    while (auto Cast = dyn_cast<CastInst>(storePointer)) {
      storePointer = Cast->getOperand(0);
      if (Cast->use_empty())
        Cast->eraseFromParent();
    }
    B.SetInsertPoint(res);
    if (auto PT = dyn_cast<PointerType>(storePointer->getType())) {
      (void)PT;
#if LLVM_VERSION_MAJOR < 17
      // With typed pointers the pointee type must match the stored value.
      if (PT->getContext().supportsTypedPointers()) {
        if (PT->getPointerElementType() != res->getType())
          storePointer = B.CreateBitCast(
              storePointer,
              PointerType::get(res->getType(), PT->getAddressSpace()));
      }
#endif
    } else {
      assert(isa<IntegerType>(storePointer->getType()));
      storePointer = B.CreateIntToPtr(storePointer,
                                      PointerType::getUnqual(res->getType()));
    }
    if (isa<AllocaInst>(storePointer)) {
      // If the slot is only ever loaded from, immediately forward the result
      // to every load dominated by the wrapper call.
      SmallVector<LoadInst *, 2> LI;
      bool nonload = false;
      for (auto &U : storePointer->uses()) {
        if (auto L = dyn_cast<LoadInst>(U.getUser())) {
          LI.push_back(L);
        } else
          nonload = true;
      }
      if (!nonload) {
        for (auto L : LI) {
          // Only forward loads of exactly the result type: replacing a value
          // with one of a different type is invalid IR (and asserts). Loads
          // that are skipped still observe the value through the store
          // emitted below.
          if (L->getType() == res->getType() && DT.dominates(res, L)) {
            L->replaceAllUsesWith(res);
            L->eraseFromParent();
          }
        }
      }
    }
    // Materialize the result in memory right after the wrapper call (or at
    // the start of the normal destination for an invoke).
    if (auto II = dyn_cast<InvokeInst>(res)) {
      B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI());
    } else {
      B.SetInsertPoint(res->getNextNode());
    }
    B.CreateStore(res, storePointer);
  }

  for (auto Bound : OMPBounds) {
    for (int i = 4; i <= 6; i++) {
      auto AI = cast<AllocaInst>(Bound->getArgOperand(i));
      IRBuilder<> B(AI);
      // Shadow slot handed to the runtime call in place of the original.
      auto AI2 = B.CreateAlloca(AI->getAllocatedType(), nullptr,
                                AI->getName() + "_smpl");
      // Copy the current value in before the call ...
      B.SetInsertPoint(Bound);
      B.CreateStore(B.CreateLoad(AI->getAllocatedType(), AI), AI2);
      Bound->setArgOperand(i, AI2);
      // ... and copy the updated value back out after it.
      if (auto II = dyn_cast<InvokeInst>(Bound)) {
        B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI());
      } else {
        B.SetInsertPoint(Bound->getNextNode());
      }
      B.CreateStore(B.CreateLoad(AI2->getAllocatedType(), AI2), AI);
      addCallSiteNoCapture(Bound, i);
    }
  }

  // No blocks or edges were added or removed above, so CFG-level analyses
  // are kept; everything else is dropped.
  PreservedAnalyses PA;
  PA.preserve<AssumptionAnalysis>();
  PA.preserve<TargetLibraryAnalysis>();
  PA.preserve<LoopAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  PA.preserve<PostDominatorTreeAnalysis>();
  FAM.invalidate(NewF, PA);
}
/// Repeatedly inline direct calls inside \p NewF, performing at most \p Limit
/// inlining attempts.
///
/// Each round scans the function from the top for the first call that is a
/// candidate and attempts to inline it; the scan then restarts because
/// inlining rewrites the instruction stream. The process stops early as soon
/// as a scan finds no candidate.
static void ForceRecursiveInlining(Function *NewF, size_t Limit) {
  std::map<const Function *, RecurType> RecurResults;

  // Locate the first call (in block/instruction order) that may be inlined,
  // or nullptr when none is left.
  auto firstCandidate = [&]() -> CallInst * {
    for (auto &BB : *NewF) {
      for (auto &I : BB) {
        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;

        // Need a direct callee with a body.
        Function *Callee = Call->getCalledFunction();
        if (Callee == nullptr)
          continue;
        if (Callee->empty())
          continue;

        // Never inline these callees (matched by name prefix).
        if (startsWith(Callee->getName(), "_ZN3std2io5stdio6_print"))
          continue;
        if (startsWith(Callee->getName(), "_ZN4core3fmt"))
          continue;
        if (startsWith(Callee->getName(), "enzyme_wrapmpi$$"))
          continue;

        // Respect attributes that forbid inlining.
        if (Callee->hasFnAttribute(Attribute::ReturnsTwice) ||
            Callee->hasFnAttribute(Attribute::NoInline))
          continue;

        if (IsFunctionRecursive(Callee, RecurResults)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "not inlining recursive "
                     << Call->getCalledFunction()->getName() << "\n");
          continue;
        }
        return Call;
      }
    }
    return nullptr;
  };

  for (size_t round = 0; round < Limit; ++round) {
    CallInst *Call = firstCandidate();
    // No functions were inlined, stop.
    if (!Call)
      break;
    InlineFunctionInfo IFI;
    InlineFunction(*Call, IFI);
  }
}
/// Put every loop of \p F into simplified form and give it a canonical
/// 64-bit induction variable named "iv".
///
/// After running LoopSimplifyPass, each loop (visited in preorder) receives a
/// new induction variable via InsertNewCanonicalIV, and RemoveRedundantIVs is
/// then invoked on the loop header with callbacks that replace and erase
/// instructions directly.
///
/// \param F   function to transform in place.
/// \param FAM analysis manager supplying the dominator tree, loop info,
///            assumption cache and target library info; cached analyses
///            other than those preserved at the end are invalidated.
void CanonicalizeLoops(Function *F, FunctionAnalysisManager &FAM) {
  {
    // LoopSimplify may change the CFG (for instance by adding preheaders).
    // Drop whatever it reports as not preserved, so that analyses preserved
    // at the end of this function (e.g. the post-dominator tree) are not
    // stale cached copies.
    auto PA = LoopSimplifyPass().run(*F, FAM);
    FAM.invalidate(*F, PA);
  }

  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(*F);
  LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
  AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(*F);
  TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(*F);
  MustExitScalarEvolution SE(*F, TLI, AC, DT, LI);

  for (Loop *L : LI.getLoopsInPreorder()) {
    auto pair =
        InsertNewCanonicalIV(L, Type::getInt64Ty(F->getContext()), "iv");
    PHINode *CanonicalIV = pair.first;
    assert(CanonicalIV);
    RemoveRedundantIVs(
        L->getHeader(), CanonicalIV, pair.second, SE,
        [&](Instruction *I, Value *V) { I->replaceAllUsesWith(V); },
        [&](Instruction *I) { I->eraseFromParent(); });
  }

  // The induction-variable rewrite only touches instructions, so CFG-level
  // and stateless alias analyses are kept.
  PreservedAnalyses PA;
  PA.preserve<AssumptionAnalysis>();
  PA.preserve<TargetLibraryAnalysis>();
  PA.preserve<LoopAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  PA.preserve<PostDominatorTreeAnalysis>();
  PA.preserve<TypeBasedAA>();
  PA.preserve<BasicAA>();
  PA.preserve<ScopedNoAliasAA>();
  FAM.invalidate(*F, PA);
}
/// Remove PHI nodes of \p F whose incoming values all resolve to one value.
///
/// For each PHI, the incoming values are explored transitively through other
/// PHIs, ignoring undef. If exactly one distinct value is found and it is
/// either not an instruction or an instruction that dominates the PHI, the
/// PHI is replaced by that value and erased. A different PHI reached while
/// nothing else is pending, and which dominates the original, is itself
/// taken as that single value.
void RemoveRedundantPHI(Function *F, FunctionAnalysisManager &FAM) {
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(*F);
  for (BasicBlock &BB : *F) {
    BasicBlock::iterator It = BB.begin();
    while (isa<PHINode>(It)) {
      PHINode *Phi = cast<PHINode>(It);
      // Step past the PHI first: it may be erased below.
      ++It;

      SmallPtrSet<Value *, 2> Sources;
      SmallPtrSet<PHINode *, 2> Visited;
      SmallVector<PHINode *, 2> Worklist = {Phi};

      // Stop exploring as soon as a second distinct source shows up.
      while (!Worklist.empty() && Sources.size() <= 1) {
        PHINode *Cur = Worklist.pop_back_val();
        if (!Visited.insert(Cur).second)
          continue;

        // A dominating PHI reached with nothing else pending can stand in
        // for the whole chain.
        if (Sources.empty() && Worklist.empty() && Phi != Cur &&
            DT.dominates(Cur, Phi)) {
          Sources.insert(Cur);
          break;
        }

        for (Use &U : Cur->incoming_values()) {
          Value *In = U.get();
          if (isa<UndefValue>(In))
            continue;
          if (auto *InPhi = dyn_cast<PHINode>(In)) {
            Worklist.push_back(InPhi);
          } else {
            Sources.insert(In);
            if (Sources.size() > 1)
              break;
          }
        }
      }

      if (Sources.size() != 1)
        continue;

      Value *Only = *Sources.begin();
      auto *OnlyInst = dyn_cast<Instruction>(Only);
      // An instruction may only replace the PHI if it dominates it.
      if (OnlyInst && !DT.dominates(OnlyInst, Phi))
        continue;
      Phi->replaceAllUsesWith(Only);
      Phi->eraseFromParent();
    }
  }
}
/// Set up the module, function and loop analysis managers (MAM, FAM, LAM)
/// owned by this cache.
///
/// The alias analyses and the AAManager are registered by hand before the
/// PassBuilder registers its analyses at the end of the constructor.
PreProcessCache::PreProcessCache() {
// Explicitly chose AA passes that are stateless
// and will not be invalidated
FAM.registerPass([] { return TypeBasedAA(); });
FAM.registerPass([] { return BasicAA(); });
MAM.registerPass([] { return GlobalsAA(); });
// CallGraphAnalysis required for GlobalsAA
MAM.registerPass([] { return CallGraphAnalysis(); });
FAM.registerPass([] { return ScopedNoAliasAA(); });
// SCEVAA causes some breakage/segfaults
// disable for now, consider enabling in future
// FAM.registerPass([] { return SCEVAA(); });
#if LLVM_VERSION_MAJOR < 16
// CFLSteensAA is only referenced before LLVM 16, and only when
// EnzymeAggressiveAA is set.
if (EnzymeAggressiveAA)
FAM.registerPass([] { return CFLSteensAA(); });
#endif
// Cross-register the managers through proxies so an analysis at one level
// can reach the managers of the other levels.
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
LAM.registerPass([&] { return FunctionAnalysisManagerLoopProxy(FAM); });
FAM.registerPass([&] { return LoopAnalysisManagerFunctionProxy(LAM); });
// Aggregate alias analysis built from the individual analyses registered
// above.
FAM.registerPass([] {
auto AM = AAManager();
AM.registerFunctionAnalysis<BasicAA>();
AM.registerFunctionAnalysis<TypeBasedAA>();
AM.registerModuleAnalysis<GlobalsAA>();
AM.registerFunctionAnalysis<ScopedNoAliasAA>();
// broken for different reasons
// AM.registerFunctionAnalysis<SCEVAA>();
#if LLVM_VERSION_MAJOR < 16
if (EnzymeAggressiveAA)
AM.registerFunctionAnalysis<CFLSteensAA>();
#endif
return AM;
});
// Let the PassBuilder register its module, function and loop analyses with
// the same managers.
PassBuilder PB;
PB.registerModuleAnalyses(MAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
}
/// Fetch the aggregated alias-analysis results for \p NewF from this cache's
/// function analysis manager.
llvm::AAResults &
PreProcessCache::getAAResultsFromFunction(llvm::Function *NewF) {
  llvm::AAResults &Results = FAM.getResult<AAManager>(*NewF);
  return Results;
}
/// Mark every call and invoke instruction inside \p NewF with the
/// `willreturn` and `mustprogress` call-site function attributes.
void setFullWillReturn(Function *NewF) {
  for (BasicBlock &Block : *NewF) {
    for (Instruction &Inst : Block) {
      // Only plain calls and invokes are annotated; other call-like
      // instructions are left untouched.
      if (!isa<CallInst>(&Inst) && !isa<InvokeInst>(&Inst))
        continue;
      auto *Site = cast<CallBase>(&Inst);
      Site->addFnAttr(Attribute::WillReturn);
      Site->addFnAttr(Attribute::MustProgress);
    }
  }
}
/// Split struct-typed PHI and select instructions of \p F into one PHI/select
/// per struct element.
///
/// Only instructions whose every user is an extractvalue with at least one
/// index are split. Each such user is rewritten to read the per-element
/// replacement (followed by a further extractvalue for any remaining
/// indices), after which the original aggregate instruction is erased. Newly
/// created PHIs and selects are queued so nested structs are split as well.
void SplitPHIs(llvm::Function &F) {
  // Seed the worklist with every PHI and select in the function.
  SetVector<Instruction *> Worklist;
  for (BasicBlock &Block : F)
    for (Instruction &Inst : Block)
      if (isa<PHINode>(&Inst) || isa<SelectInst>(&Inst))
        Worklist.insert(&Inst);

  while (!Worklist.empty()) {
    Instruction *Agg = Worklist.pop_back_val();
    IRBuilder<> B(Agg);

    auto *StructTy = dyn_cast<StructType>(Agg->getType());
    if (!StructTy)
      continue;

    // Every user must be an extractvalue carrying at least one index.
    const bool OnlyExtracted = llvm::all_of(Agg->users(), [](User *U) {
      auto *Extract = dyn_cast<ExtractValueInst>(U);
      return Extract && Extract->getIndices().size() != 0;
    });
    if (!OnlyExtracted)
      continue;

    // Build one replacement value per struct element.
    SmallVector<Value *, 1> Pieces;
    const size_t NumElems = StructTy->getNumElements();
    for (size_t i = 0; i < NumElems; i++) {
      const std::string PieceName =
          (Agg->getName() + ".extract." + std::to_string(i)).str();
      if (auto *AggPhi = dyn_cast<PHINode>(Agg)) {
        PHINode *ElemPhi =
            B.CreatePHI(StructTy->getElementType(i),
                        AggPhi->getNumIncomingValues(), PieceName);
        // Extract the element at the end of each incoming block.
        for (unsigned k = 0, e = AggPhi->getNumIncomingValues(); k != e; ++k) {
          BasicBlock *From = AggPhi->getIncomingBlock(k);
          Value *Incoming = AggPhi->getIncomingValue(k);
          IRBuilder<> EdgeB(From->getTerminator());
          ElemPhi->addIncoming(GradientUtils::extractMeta(EdgeB, Incoming, i),
                               From);
        }
        Pieces.push_back(ElemPhi);
        Worklist.insert(ElemPhi);
      } else {
        auto *AggSel = cast<SelectInst>(Agg);
        Value *ElemSel = B.CreateSelect(
            AggSel->getCondition(),
            GradientUtils::extractMeta(B, AggSel->getTrueValue(), i),
            GradientUtils::extractMeta(B, AggSel->getFalseValue(), i),
            PieceName);
        Pieces.push_back(ElemSel);
        // The builder may have folded the select into something else.
        if (auto *NewSel = dyn_cast<SelectInst>(ElemSel))
          Worklist.insert(NewSel);
      }
    }

    // Redirect each extractvalue to the matching piece.
    for (Use &U : make_early_inc_range(Agg->uses())) {
      auto *Extract = cast<ExtractValueInst>(U.getUser());
      ArrayRef<unsigned> Path = Extract->getIndices();
      Value *Piece = Pieces[Path[0]];
      IRBuilder<> UserB(Extract);
      if (Path.size() > 1)
        Piece = UserB.CreateExtractValue(Piece, Path.slice(1));
      assert(Piece->getType() == Extract->getType());
      Extract->replaceAllUsesWith(Piece);
      Extract->eraseFromParent();
    }
    Agg->eraseFromParent();
  }
}
/// Return a preprocessed clone of \p F suitable for differentiation in
/// \p mode, creating and caching it on first request.
///
/// The clone is named "preprocess_<F's name>" and is transformed in this
/// order:
///   1. optional `noalias` on pointer arguments (EnzymeNoAlias), and
///      `willreturn`/`mustprogress` on the function and all call sites;
///   2. optional forced inlining (EnzymePreopt && EnzymeInline);
///   3. `__enzyme_iter` calls get their first operand replaced by the sum of
///      their first two operands;
///   4. equality comparisons of an allocation result against null are folded
///      as if the allocation never returns null;
///   5. SimplifyMPIQueries and mem2reg;
///   6. optional lowering of globals into local copies (EnzymeLowerGlobals);
///   7. optional cleanup passes (EnzymePreopt);
///   8. SplitPHIs, ReplaceReallocs (not in forward mode), alloca-to-malloc
///      upgrade (reverse modes), loop canonicalization, redundant PHI removal,
///      LoopSimplify, removal of zero-length memory transfers;
///   9. optional PHI restructuring (EnzymePHIRestructure);
///  10. verification; a failure is fatal.
///
/// ReverseModeGradient and ForwardModeSplit are treated as ReverseModePrimal
/// both for the cache key and for the processing itself.
///
/// \param F    function to clone; it may be a declaration.
/// \param mode derivative mode the clone is prepared for.
/// \returns the (possibly cached) preprocessed clone.
Function *PreProcessCache::preprocessForClone(Function *F,
                                              DerivativeMode mode) {
  TimeTraceScope timeScope("preprocessForClone", F->getName());

  if (mode == DerivativeMode::ReverseModeGradient)
    mode = DerivativeMode::ReverseModePrimal;
  if (mode == DerivativeMode::ForwardModeSplit)
    mode = DerivativeMode::ReverseModePrimal;

  // If we've already processed this, return the previous version
  // and derive aliasing information
  {
    auto found = cache.find(std::make_pair(F, mode));
    if (found != cache.end())
      return found->second;
  }

  Function *NewF =
      Function::Create(F->getFunctionType(), F->getLinkage(),
                       "preprocess_" + F->getName(), F->getParent());

  // Map each original argument onto the clone's argument and copy its name.
  ValueToValueMapTy VMap;
  for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
    VMap[i] = j;
    j->setName(i->getName());
    if (EnzymeNoAlias && j->getType()->isPointerTy()) {
      j->addAttr(Attribute::NoAlias);
    }
    ++i;
    ++j;
  }

  SmallVector<ReturnInst *, 4> Returns;
  if (!F->empty()) {
    CloneFunctionInto(
        NewF, F, VMap,
        /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly,
        Returns, "", nullptr);
  }
  CloneOrigin[NewF] = F;
  // Copying the attribute list replaces the attributes set above, so noalias
  // is applied again afterwards.
  NewF->setAttributes(F->getAttributes());
  if (EnzymeNoAlias)
    for (auto j = NewF->arg_begin(); j != NewF->arg_end(); j++) {
      if (j->getType()->isPointerTy()) {
        j->addAttr(Attribute::NoAlias);
      }
    }
  NewF->addFnAttr(Attribute::WillReturn);
  NewF->addFnAttr(Attribute::MustProgress);
  setFullWillReturn(NewF);

  if (EnzymePreopt) {
    if (EnzymeInline) {
      ForceRecursiveInlining(NewF, /*Limit*/ EnzymeInlineCount);
      // Inlined bodies bring new call sites; annotate those too.
      setFullWillReturn(NewF);
      PreservedAnalyses PA;
      FAM.invalidate(*NewF, PA);
    }
  }

  // Rewrite __enzyme_iter(a, b, ...) so its first operand becomes a + b.
  {
    SmallVector<CallInst *, 4> ItersToErase;
    for (auto &BB : *NewF) {
      for (auto &I : BB) {
        if (auto CI = dyn_cast<CallInst>(&I)) {
          Function *called = CI->getCalledFunction();
          // Look through a constant cast of the callee.
          if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
            if (castinst->isCast()) {
              if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
                called = fn;
            }
          }
          if (called && called->getName() == "__enzyme_iter") {
            ItersToErase.push_back(CI);
          }
        }
      }
    }
    for (auto Call : ItersToErase) {
      IRBuilder<> B(Call);
      Call->setArgOperand(
          0, B.CreateAdd(Call->getArgOperand(0), Call->getArgOperand(1)));
    }
  }

  // Assume allocations do not return null
  {
    // NOTE(review): the library info is requested for F although the scan
    // below runs over NewF - confirm whether *NewF was intended.
    TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(*F);
    SmallVector<Instruction *, 4> CmpsToErase;
    SmallVector<BasicBlock *, 4> BranchesToErase;
    for (auto &BB : *NewF) {
      for (auto &I : BB) {
        if (auto IC = dyn_cast<ICmpInst>(&I)) {
          if (!IC->isEquality())
            continue;
          // Accept the null constant on either side of the comparison.
          for (int i = 0; i < 2; i++) {
            if (isa<ConstantPointerNull>(IC->getOperand(1 - i)))
              if (isAllocationCall(IC->getOperand(i), TLI)) {
                for (auto U : IC->users()) {
                  if (auto BI = dyn_cast<BranchInst>(U))
                    BranchesToErase.push_back(BI->getParent());
                }
                // "alloc != null" is true, "alloc == null" is false.
                IC->replaceAllUsesWith(
                    IC->getPredicate() == ICmpInst::ICMP_NE
                        ? ConstantInt::getTrue(I.getContext())
                        : ConstantInt::getFalse(I.getContext()));
                CmpsToErase.push_back(&I);
                break;
              }
          }
        }
      }
    }
    for (auto I : CmpsToErase)
      I->eraseFromParent();
    // Branches that now test a constant can be folded.
    for (auto BE : BranchesToErase)
      ConstantFoldTerminator(BE);
  }

  SimplifyMPIQueries(*NewF, FAM);
  {
    auto PA = PromotePass().run(*NewF, FAM);
    FAM.invalidate(*NewF, PA);
  }

  if (EnzymeLowerGlobals) {
    SmallVector<CallInst *, 4> Calls;
    SmallVector<ReturnInst *, 4> Returns;
    for (BasicBlock &BB : *NewF) {
      for (Instruction &I : BB) {
        if (auto CI = dyn_cast<CallInst>(&I)) {
          Calls.push_back(CI);
        }
        if (auto RI = dyn_cast<ReturnInst>(&I)) {
          Returns.push_back(RI);
        }
      }
    }

    // TODO consider using TBAA and globals as well
    // instead of just BasicAA
    AAResults AA2(FAM.getResult<TargetLibraryAnalysis>(*NewF));
    AA2.addAAResult(FAM.getResult<BasicAA>(*NewF));
    AA2.addAAResult(FAM.getResult<TypeBasedAA>(*NewF));
    AA2.addAAResult(FAM.getResult<ScopedNoAliasAA>(*NewF));

    for (auto &g : NewF->getParent()->globals()) {
      // Is the global referenced (directly or through constants) by an
      // instruction inside NewF?
      bool inF = false;
      {
        std::set<Constant *> seen;
        std::deque<Constant *> todo = {(Constant *)&g};
        while (todo.size()) {
          auto GV = todo.front();
          todo.pop_front();
          if (!seen.insert(GV).second)
            continue;
          for (auto u : GV->users()) {
            if (auto C = dyn_cast<Constant>(u)) {
              todo.push_back(C);
            } else if (auto I = dyn_cast<Instruction>(u)) {
              if (I->getParent()->getParent() == NewF) {
                inF = true;
                goto doneF;
              }
            }
          }
        }
      }
    doneF:;
      if (inF) {
        bool activeCall = false;
        bool hasWrite = false;
        MemoryLocation Loc =
            MemoryLocation(&g, LocationSize::beforeOrAfterPointer());

        // Any call in NewF that may read or write the global (other than the
        // exempted callees below) blocks the lowering.
        for (CallInst *CI : Calls) {
          if (isa<IntrinsicInst>(CI))
            continue;
          Function *Callee = CI->getCalledFunction();
          if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
            if (castinst->isCast())
              if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
                Callee = fn;
              }
          }
          if (Callee && isMemFreeLibMFunction(Callee->getName())) {
            continue;
          }
          if (Callee && Callee->getName().contains("__enzyme_integer")) {
            continue;
          }
          if (Callee && Callee->getName().contains("__enzyme_pointer")) {
            continue;
          }
          if (Callee && Callee->getName().contains("__enzyme_float")) {
            continue;
          }
          if (Callee && Callee->getName().contains("__enzyme_double")) {
            continue;
          }
          if (Callee && (startsWith(Callee->getName(), "f90io") ||
                         Callee->getName() == "ftnio_fmt_write64" ||
                         Callee->getName() == "__mth_i_ipowi" ||
                         Callee->getName() == "f90_pausea")) {
            continue;
          }
          if (llvm::isModOrRefSet(AA2.getModRefInfo(CI, Loc))) {
            llvm::errs() << " failed to inline global: " << g << " due to "
                         << *CI << "\n";
            activeCall = true;
            break;
          }
        }

        // Walk the transitive users of the global looking for a write or a
        // potentially capturing call.
        if (!activeCall) {
          std::set<Value *> seen;
          std::deque<Value *> todo = {(Value *)&g};
          while (todo.size()) {
            auto GV = todo.front();
            todo.pop_front();
            if (!seen.insert(GV).second)
              continue;
            for (auto u : GV->users()) {
              if (isa<Constant>(u) || isa<GetElementPtrInst>(u) ||
                  isa<CastInst>(u) || isa<LoadInst>(u)) {
                todo.push_back(u);
                continue;
              }

              if (auto II = dyn_cast<IntrinsicInst>(u)) {
                if (isIntelSubscriptIntrinsic(*II)) {
                  todo.push_back(u);
                  continue;
                }
              }

              if (auto CI = dyn_cast<CallInst>(u)) {
                Function *Callee = CI->getCalledFunction();
                if (auto castinst =
                        dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
                  if (castinst->isCast())
                    if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
                      Callee = fn;
                    }
                }
                if (Callee && isMemFreeLibMFunction(Callee->getName())) {
                  continue;
                }
                if (Callee && Callee->getName().contains("__enzyme_integer")) {
                  continue;
                }
                if (Callee && Callee->getName().contains("__enzyme_pointer")) {
                  continue;
                }
                if (Callee && Callee->getName().contains("__enzyme_float")) {
                  continue;
                }
                if (Callee && Callee->getName().contains("__enzyme_double")) {
                  continue;
                }
                if (Callee && (startsWith(Callee->getName(), "f90io") ||
                               Callee->getName() == "ftnio_fmt_write64" ||
                               Callee->getName() == "__mth_i_ipowi" ||
                               Callee->getName() == "f90_pausea")) {
                  continue;
                }

                if (couldFunctionArgumentCapture(CI, GV)) {
                  hasWrite = true;
                  goto endCheck;
                }

                if (llvm::isModSet(AA2.getModRefInfo(CI, Loc))) {
                  hasWrite = true;
                  goto endCheck;
                }
              }

              else if (auto I = dyn_cast<Instruction>(u)) {
                if (llvm::isModSet(AA2.getModRefInfo(I, Loc))) {
                  hasWrite = true;
                  goto endCheck;
                }
              }
            }
          }
        }

      endCheck:;
        if (!activeCall && hasWrite) {
          // Replace uses of the global inside NewF by a local copy that is
          // filled on entry and written back before every return.
          IRBuilder<> bb(&NewF->getEntryBlock(), NewF->getEntryBlock().begin());
          AllocaInst *antialloca = bb.CreateAlloca(
              g.getValueType(), g.getType()->getPointerAddressSpace(), nullptr,
              g.getName() + "_local");

          if (g.getAlignment()) {
            antialloca->setAlignment(Align(g.getAlignment()));
          }

          // Constant (or constant expression) -> value to use inside NewF.
          std::map<Constant *, Value *> remap;
          remap[&g] = antialloca;

          std::deque<Constant *> todo = {&g};
          while (todo.size()) {
            auto GV = todo.front();
            todo.pop_front();
            if (&g != GV && remap.find(GV) != remap.end())
              continue;
            Value *replaced = nullptr;
            if (remap.find(GV) != remap.end()) {
              replaced = remap[GV];
            } else if (auto CE = dyn_cast<ConstantExpr>(GV)) {
              // Materialize the constant expression as an instruction whose
              // first operand is the already remapped value.
              auto I = CE->getAsInstruction();
              bb.Insert(I);
              assert(isa<Constant>(I->getOperand(0)));
              assert(remap[cast<Constant>(I->getOperand(0))]);
              I->setOperand(0, remap[cast<Constant>(I->getOperand(0))]);
              replaced = remap[GV] = I;
            }
            assert(replaced && "unhandled constantexpr");

            // Gather the uses first; rewriting while iterating would modify
            // the use list being walked.
            SmallVector<std::pair<Instruction *, size_t>, 4> uses;
            for (Use &U : GV->uses()) {
              if (auto I = dyn_cast<Instruction>(U.getUser())) {
                if (I->getParent()->getParent() == NewF) {
                  uses.emplace_back(I, U.getOperandNo());
                }
              }
              if (auto C = dyn_cast<Constant>(U.getUser())) {
                assert(C != &g);
                todo.push_back(C);
              }
            }
            for (auto &U : uses) {
              U.first->setOperand(U.second, replaced);
            }
          }

          // memcpy(local, global, size, /*volatile*/ false) on entry.
          Value *args[] = {
              bb.CreateBitCast(antialloca, getInt8PtrTy(g.getContext())),
              bb.CreateBitCast(&g, getInt8PtrTy(g.getContext())),
              ConstantInt::get(
                  Type::getInt64Ty(g.getContext()),
                  g.getParent()->getDataLayout().getTypeAllocSizeInBits(
                      g.getValueType()) /
                      8),
              ConstantInt::getFalse(g.getContext())};

          Type *tys[] = {args[0]->getType(), args[1]->getType(),
                         args[2]->getType()};
          auto intr =
              getIntrinsicDeclaration(g.getParent(), Intrinsic::memcpy, tys);
          {
            auto cal = bb.CreateCall(intr, args);
            if (g.getAlignment()) {
              cal->addParamAttr(
                  0, Attribute::getWithAlignment(g.getContext(),
                                                 Align(g.getAlignment())));
              cal->addParamAttr(
                  1, Attribute::getWithAlignment(g.getContext(),
                                                 Align(g.getAlignment())));
            }
          }

          // memcpy(global, local, ...) before each return.
          std::swap(args[0], args[1]);

          for (ReturnInst *RI : Returns) {
            IRBuilder<> IB(RI);
            auto cal = IB.CreateCall(intr, args);
            if (g.getAlignment()) {
              cal->addParamAttr(
                  0, Attribute::getWithAlignment(g.getContext(),
                                                 Align(g.getAlignment())));
              cal->addParamAttr(
                  1, Attribute::getWithAlignment(g.getContext(),
                                                 Align(g.getAlignment())));
            }
          }
        }
      }
    }

    auto Level = OptimizationLevel::O2;

    PassBuilder PB;
    FunctionPassManager FPM =
        PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None);
    // NOTE(review): the pipeline is run on (and invalidated for) the original
    // function F, not on the clone NewF that was just rewritten above -
    // confirm whether *NewF was intended here.
    auto PA = FPM.run(*F, FAM);
    FAM.invalidate(*F, PA);
  }

  if (EnzymePreopt) {
    {
      auto PA = LowerInvokePass().run(*NewF, FAM);
      FAM.invalidate(*NewF, PA);
    }
    {
      auto PA = UnreachableBlockElimPass().run(*NewF, FAM);
      FAM.invalidate(*NewF, PA);
    }

    {
      auto PA = PromotePass().run(*NewF, FAM);
      FAM.invalidate(*NewF, PA);
    }

    {
#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG)
      auto PA = SROAPass(llvm::SROAOptions::ModifyCFG).run(*NewF, FAM);
#elif !defined(FLANG)
      auto PA = SROAPass().run(*NewF, FAM);
#else
      auto PA = SROA().run(*NewF, FAM);
#endif
      FAM.invalidate(*NewF, PA);
    }

    if (mode != DerivativeMode::ForwardMode)
      ReplaceReallocs(NewF);

    {
#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG)
      auto PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*NewF, FAM);
#elif !defined(FLANG)
      auto PA = SROAPass().run(*NewF, FAM);
#else
      auto PA = SROA().run(*NewF, FAM);
#endif
      FAM.invalidate(*NewF, PA);
    }

    SimplifyCFGOptions scfgo;
    {
      auto PA = SimplifyCFGPass(scfgo).run(*NewF, FAM);
      FAM.invalidate(*NewF, PA);
    }
  }

  {
    SplitPHIs(*NewF);
    PreservedAnalyses PA;
    PA.preserve<AssumptionAnalysis>();
    PA.preserve<TargetLibraryAnalysis>();
    PA.preserve<LoopAnalysis>();
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<PostDominatorTreeAnalysis>();
    PA.preserve<TypeBasedAA>();
    PA.preserve<BasicAA>();
    PA.preserve<ScopedNoAliasAA>();
    PA.preserve<ScalarEvolutionAnalysis>();
    PA.preserve<PhiValuesAnalysis>();
    // SplitPHIs creates and erases instructions; drop every cached analysis
    // that is not explicitly preserved above.
    FAM.invalidate(*NewF, PA);
  }

  if (mode != DerivativeMode::ForwardMode)
    ReplaceReallocs(NewF);

  if (mode == DerivativeMode::ReverseModePrimal ||
      mode == DerivativeMode::ReverseModeGradient ||
      mode == DerivativeMode::ReverseModeCombined) {
    // For subfunction calls upgrade stack allocations to mallocs
    // to ensure availability in the reverse pass
    auto unreachable = getGuaranteedUnreachable(NewF);
    UpgradeAllocasToMallocs(NewF, mode, unreachable);
  }

  CanonicalizeLoops(NewF, FAM);
  RemoveRedundantPHI(NewF, FAM);

  // Run LoopSimplifyPass to ensure preheaders exist on all loops
  {
    auto PA = LoopSimplifyPass().run(*NewF, FAM);
    FAM.invalidate(*NewF, PA);
  }

  {
    // Drop memory transfer intrinsics whose constant length is zero.
    for (auto &BB : *NewF) {
      for (auto &I : make_early_inc_range(BB)) {
        if (auto MTI = dyn_cast<MemTransferInst>(&I)) {

          if (auto CI = dyn_cast<ConstantInt>(MTI->getOperand(2))) {
            if (CI->getValue() == 0) {
              MTI->eraseFromParent();
            }
          }
        }
      }
    }

    PreservedAnalyses PA;
    PA.preserve<AssumptionAnalysis>();
    PA.preserve<TargetLibraryAnalysis>();
    PA.preserve<LoopAnalysis>();
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<PostDominatorTreeAnalysis>();
    PA.preserve<TypeBasedAA>();
    PA.preserve<BasicAA>();
    PA.preserve<ScopedNoAliasAA>();
    PA.preserve<ScalarEvolutionAnalysis>();
    PA.preserve<PhiValuesAnalysis>();
    FAM.invalidate(*NewF, PA);

    // Optionally give every unnamed argument, block and non-void instruction
    // a name.
    if (EnzymeNameInstructions) {
      for (auto &Arg : NewF->args()) {
        if (!Arg.hasName())
          Arg.setName("arg");
      }
      for (BasicBlock &BB : *NewF) {
        if (!BB.hasName())
          BB.setName("bb");

        for (Instruction &I : BB) {
          if (!I.hasName() && !I.getType()->isVoidTy())
            I.setName("i");
        }
      }
    }
  }

  if (EnzymePHIRestructure) {
    // Re-entry point: after a block is inserted below, all analyses are
    // dropped and the whole restructuring scan starts over.
    if (false) {
    reset:;
      PreservedAnalyses PA;
      FAM.invalidate(*NewF, PA);
    }

    // Candidate merge points: blocks with at least three predecessors.
    SmallVector<BasicBlock *, 4> MultiBlocks;
    for (auto &B : *NewF) {
      if (B.hasNPredecessorsOrMore(3))
        MultiBlocks.push_back(&B);
    }

    LoopInfo &LI = FAM.getResult<LoopAnalysis>(*NewF);
    for (BasicBlock *B : MultiBlocks) {

      // Map of function edges to list of values possible
      std::map<std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
               std::set<BasicBlock *>>
          done;
      {
        std::deque<std::tuple<
            std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
            BasicBlock *>>
            Q; // newblock, target

        for (auto P : predecessors(B)) {
          Q.emplace_back(std::make_pair(P, B), P);
        }

        // Walk backwards from each predecessor of B, recording for every edge
        // which direct predecessors of B can be reached through it.
        for (std::tuple<
                 std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
                 BasicBlock *>
                 trace;
             Q.size() > 0;) {
          trace = Q.front();
          Q.pop_front();
          auto edge = std::get<0>(trace);
          auto block = edge.first;
          auto target = std::get<1>(trace);

          if (done[edge].count(target))
            continue;
          done[edge].insert(target);

          Loop *blockLoop = LI.getLoopFor(block);

          for (BasicBlock *Pred : predecessors(block)) {
            // Don't go up the backedge as we can use the last value if desired
            // via lcssa
            if (blockLoop && blockLoop->getHeader() == block &&
                blockLoop == LI.getLoopFor(Pred))
              continue;

            Q.push_back(
                std::tuple<std::pair<BasicBlock *, BasicBlock *>, BasicBlock *>(
                    std::make_pair(Pred, block), target));
          }
        }
      }

      SmallPtrSet<BasicBlock *, 4> Preds;
      for (auto &pair : done) {
        Preds.insert(pair.first.first);
      }

      for (auto BB : Preds) {
        bool illegal = false;
        SmallPtrSet<BasicBlock *, 2> UnionSet;
        size_t numSuc = 0;
        for (BasicBlock *sucI : successors(BB)) {
          numSuc++;
          const auto &SI = done[std::make_pair(BB, sucI)];
          if (SI.size() == 0) {
            // sucI->getName();
            illegal = true;
            break;
          }
          for (auto si : SI) {
            UnionSet.insert(si);

            // A target reachable through two different successors of BB
            // makes BB unusable as a split point.
            for (BasicBlock *sucJ : successors(BB)) {
              if (sucI == sucJ)
                continue;
              if (done[std::make_pair(BB, sucJ)].count(si)) {
                illegal = true;
                goto endIllegal;
              }
            }
          }
        }
      endIllegal:;

        if (!illegal && numSuc > 1 && !B->hasNPredecessors(UnionSet.size())) {
          // Route the predecessors in UnionSet through a new block that
          // carries their PHI values, then restart the analysis.
          BasicBlock *Ins =
              BasicBlock::Create(BB->getContext(), "tmpblk", BB->getParent());
          IRBuilder<> Builder(Ins);
          for (auto &phi : B->phis()) {
            auto nphi = Builder.CreatePHI(phi.getType(), 2);
            SmallVector<BasicBlock *, 4> Blocks;

            for (auto blk : UnionSet) {
              nphi->addIncoming(phi.getIncomingValueForBlock(blk), blk);
              phi.removeIncomingValue(blk, /*deleteifempty*/ false);
            }

            phi.addIncoming(nphi, Ins);
          }
          Builder.CreateBr(B);
          for (auto blk : UnionSet) {
            auto term = blk->getTerminator();
            for (unsigned Idx = 0, NumSuccessors = term->getNumSuccessors();
                 Idx != NumSuccessors; ++Idx)
              if (term->getSuccessor(Idx) == B)
                term->setSuccessor(Idx, Ins);
          }
          goto reset;
        }
      }
    }
  }

  if (EnzymePrint)
    llvm::errs() << "after simplification :\n" << *NewF << "\n";

  if (llvm::verifyFunction(*NewF, &llvm::errs())) {
    llvm::errs() << *NewF << "\n";
    report_fatal_error("function failed verification (1)");
  }
  cache[std::make_pair(F, mode)] = NewF;
  return NewF;
}
/// Compute the function type of the derivative clone of a function of type
/// \p FTy.
///
/// Parameters: every original parameter is kept; a DUP_ARG or DUP_NONEED
/// parameter is followed by its shadow type. If \p diffeReturnArg is set, the
/// shadow of the return type is appended, and \p additionalArg (when
/// non-null) comes last.
///
/// Return type: assembled from \p returnValue, \p returnType and one shadow
/// entry per OUT_DIFF parameter. Tape-returning kinds use a struct starting
/// with the default anonymous tape type; `Return` yields the single entry
/// itself; an empty list yields void; everything else is a literal struct.
///
/// \p mode is not consulted. \p constant_args is indexed by parameter
/// position and must have an entry for every parameter of \p FTy.
FunctionType *getFunctionTypeForClone(
    llvm::FunctionType *FTy, DerivativeMode mode, unsigned width,
    llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
    bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType) {
  Type *const PrimalRetTy = FTy->getReturnType();

  // Type of the "derivative" return slot: the shadow type unless the return
  // is constant or an out-diff, in which case the primal type is used.
  const bool shadowedReturn = returnType != DIFFE_TYPE::CONSTANT &&
                              returnType != DIFFE_TYPE::OUT_DIFF;
  auto derivativeSlotTy = [&]() -> Type * {
    if (shadowedReturn)
      return GradientUtils::getShadowType(PrimalRetTy, width);
    return PrimalRetTy;
  };

  SmallVector<Type *, 4> RetTypes;
  switch (returnValue) {
  case ReturnType::ArgsWithReturn:
  case ReturnType::Return:
    RetTypes.push_back(derivativeSlotTy());
    break;
  case ReturnType::ArgsWithTwoReturns:
  case ReturnType::TwoReturns:
    RetTypes.push_back(PrimalRetTy);
    RetTypes.push_back(derivativeSlotTy());
    break;
  default:
    break;
  }

  SmallVector<Type *, 4> ArgTypes;

  // The user might be deleting arguments to the function by specifying them in
  // the VMap. If so, we need to not add the arguments to the arg ty vector
  for (unsigned argno = 0, numParams = FTy->getNumParams(); argno != numParams;
       ++argno) {
    Type *ParamTy = FTy->getParamType(argno);
    ArgTypes.push_back(ParamTy);

    const DIFFE_TYPE Activity = constant_args[argno];
    if (Activity == DIFFE_TYPE::DUP_ARG ||
        Activity == DIFFE_TYPE::DUP_NONEED) {
      // Duplicated arguments carry their shadow right after the primal.
      ArgTypes.push_back(GradientUtils::getShadowType(ParamTy, width));
    } else if (Activity == DIFFE_TYPE::OUT_DIFF) {
      // Out-diff arguments have their derivative returned instead.
      RetTypes.push_back(GradientUtils::getShadowType(ParamTy, width));
    }
  }

  if (diffeReturnArg) {
    assert(!FTy->getReturnType()->isVoidTy());
    ArgTypes.push_back(GradientUtils::getShadowType(PrimalRetTy, width));
  }
  if (additionalArg)
    ArgTypes.push_back(additionalArg);

  Type *RetType = StructType::get(FTy->getContext(), RetTypes);

  const bool returnsTape = returnValue == ReturnType::TapeAndTwoReturns ||
                           returnValue == ReturnType::TapeAndReturn ||
                           returnValue == ReturnType::Tape;
  if (returnsTape) {
    // Tape-returning variants discard everything gathered so far and start
    // over with the tape type in front.
    RetTypes.clear();
    RetTypes.push_back(getDefaultAnonymousTapeType(FTy->getContext()));
    if (returnValue == ReturnType::TapeAndTwoReturns) {
      RetTypes.push_back(PrimalRetTy);
      RetTypes.push_back(GradientUtils::getShadowType(PrimalRetTy, width));
    } else if (returnValue == ReturnType::TapeAndReturn) {
      RetTypes.push_back(derivativeSlotTy());
    }
    RetType = StructType::get(FTy->getContext(), RetTypes);
  } else if (returnValue == ReturnType::Return) {
    // A lone return value is not wrapped in a struct.
    assert(RetTypes.size() == 1);
    RetType = RetTypes[0];
  } else if (returnValue == ReturnType::TwoReturns) {
    assert(RetTypes.size() == 2);
  }

  // Nothing to return at all: use void instead of an empty struct.
  if (RetTypes.size() == 0)
    RetType = Type::getVoidTy(RetType->getContext());

  // Create a new function type...
  return FunctionType::get(RetType, ArgTypes, FTy->isVarArg());
}
Function *PreProcessCache::CloneFunctionWithReturns(
DerivativeMode mode, unsigned width, Function *&F,
ValueToValueMapTy &ptrInputs, ArrayRef<DIFFE_TYPE> constant_args,
SmallPtrSetImpl<Value *> &constants, SmallPtrSetImpl<Value *> &nonconstant,
SmallPtrSetImpl<Value *> &returnvals, ReturnType returnValue,
DIFFE_TYPE returnType, const Twine &name,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
bool diffeReturnArg, llvm::Type *additionalArg) {
if (!F->empty())
F = preprocessForClone(F, mode);
llvm::ValueToValueMapTy VMap;
llvm::FunctionType *FTy = getFunctionTypeForClone(
F->getFunctionType(), mode, width, additionalArg, constant_args,
diffeReturnArg, returnValue, returnType);
for (BasicBlock &BB : *F) {
if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {
if (auto rv = ri->getReturnValue()) {
returnvals.insert(rv);
}
}
}
// Create the new function...
Function *NewF = Function::Create(FTy, F->getLinkage(), name, F->getParent());
if (diffeReturnArg) {
auto I = NewF->arg_end();
I--;
if (additionalArg)
I--;
I->setName("differeturn");
}
if (additionalArg) {
auto I = NewF->arg_end();
I--;
I->setName("tapeArg");
}
{
unsigned ii = 0;
for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
VMap[i] = j;
++j;
++i;
if (constant_args[ii] == DIFFE_TYPE::DUP_ARG ||
constant_args[ii] == DIFFE_TYPE::DUP_NONEED) {
++j;
}
++ii;
}
}
// Loop over the arguments, copying the names of the mapped arguments over...
Function::arg_iterator DestI = NewF->arg_begin();
for (const Argument &I : F->args())
if (VMap.count(&I) == 0) { // Is this argument preserved?
DestI->setName(I.getName()); // Copy the name over...
VMap[&I] = &*DestI++; // Add mapping to VMap
}
SmallVector<ReturnInst *, 4> Returns;
if (!F->empty()) {
CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
Returns, "", nullptr);
}
if (NewF->empty()) {
auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF);
IRBuilder<> B(entry);
B.CreateUnreachable();
}
CloneOrigin[NewF] = F;
if (VMapO) {
for (const auto &data : VMap)
VMapO->insert(std::pair<const llvm::Value *, AssertingReplacingVH>(
data.first, (llvm::Value *)data.second));
VMapO->getMDMap() = VMap.getMDMap();
}
for (auto attr : {"enzyme_ta_norecur", "frame-pointer"})
if (F->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) {
NewF->addAttribute(
AttributeList::FunctionIndex,
F->getAttributes().getAttribute(AttributeList::FunctionIndex, attr));
}
for (auto attr :
{"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"})
if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) {
NewF->addAttribute(
AttributeList::ReturnIndex,
F->getAttributes().getAttribute(AttributeList::ReturnIndex, attr));
}
bool hasPtrInput = false;
unsigned ii = 0, jj = 0;
for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
if (F->hasParamAttribute(ii, Attribute::StructRet)) {
NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret"));
// TODO
// NewF->addParamAttr(
// jj,
// Attribute::get(
// F->getContext(), Attribute::AttrKind::ElementType,
// F->getParamAttribute(ii,
// Attribute::StructRet).getValueAsType()));
}
if (F->getAttributes().hasParamAttr(ii, "enzymejl_returnRoots")) {
NewF->addParamAttr(
jj, F->getAttributes().getParamAttr(ii, "enzymejl_returnRoots"));
// TODO
// NewF->addParamAttr(jj, F->getParamAttribute(ii,
// Attribute::ElementType));
}
for (auto attr :
{"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"})
if (F->getAttributes().hasParamAttr(ii, attr)) {
NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr));
for (auto ty : PrimalParamAttrsToPreserve)
if (F->getAttributes().hasParamAttr(ii, ty)) {
auto attr = F->getAttributes().getParamAttr(ii, ty);
NewF->addParamAttr(jj, attr);
}
}
if (constant_args[ii] == DIFFE_TYPE::CONSTANT) {
if (!i->hasByValAttr())
constants.insert(i);
if (EnzymePrintActivity)
llvm::errs() << "in new function " << NewF->getName()
<< " constant arg " << *j << "\n";
} else {
nonconstant.insert(i);
if (EnzymePrintActivity)
llvm::errs() << "in new function " << NewF->getName()
<< " nonconstant arg " << *j << "\n";
}
// Always remove nonnull/noundef since the caller may choose to pass
// undef as an arg if provably it will not be used in the reverse pass
if (constant_args[ii] == DIFFE_TYPE::DUP_NONEED ||
mode == DerivativeMode::ReverseModeGradient) {
if (F->hasParamAttribute(ii, Attribute::NonNull)) {
NewF->removeParamAttr(jj, Attribute::NonNull);
}
if (F->hasParamAttribute(ii, Attribute::NoUndef)) {
NewF->removeParamAttr(jj, Attribute::NoUndef);
}
}
if (constant_args[ii] == DIFFE_TYPE::DUP_ARG ||
constant_args[ii] == DIFFE_TYPE::DUP_NONEED) {
hasPtrInput = true;
ptrInputs[i] = (j + 1);
// TODO: find a way to keep the attributes in vector mode.
if (width == 1)
for (auto ty : ShadowParamAttrsToPreserve)
if (F->getAttributes().hasParamAttr(ii, ty)) {
auto attr = F->getAttributes().getParamAttr(ii, ty);
NewF->addParamAttr(jj + 1, attr);
}
for (auto attr :
{"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"})
if (F->getAttributes().hasParamAttr(ii, attr)) {
if (width == 1)
NewF->addParamAttr(jj + 1,
F->getAttributes().getParamAttr(ii, attr));
}
if (F->getAttributes().hasParamAttr(ii, "enzymejl_returnRoots")) {
if (width == 1) {
NewF->addParamAttr(jj + 1, F->getAttributes().getParamAttr(
ii, "enzymejl_returnRoots"));
} else {
NewF->addParamAttr(jj + 1, Attribute::get(F->getContext(),
"enzymejl_returnRoots_v"));
}
// TODO
// NewF->addParamAttr(jj + 1,
// F->getParamAttribute(ii,
// Attribute::ElementType));
}
if (F->hasParamAttribute(ii, Attribute::StructRet)) {
if (width == 1) {
NewF->addParamAttr(jj + 1,
Attribute::get(F->getContext(), "enzyme_sret"));
// TODO
// NewF->addParamAttr(
// jj + 1,
// Attribute::get(F->getContext(),
// Attribute::AttrKind::ElementType,
// F->getParamAttribute(ii,
// Attribute::StructRet)
// .getValueAsType()));
} else {
NewF->addParamAttr(jj + 1,
Attribute::get(F->getContext(), "enzyme_sret_v"));
// TODO
// NewF->addParamAttr(
// jj + 1,
// Attribute::get(F->getContext(),
// Attribute::AttrKind::ElementType,
// F->getParamAttribute(ii,
// Attribute::StructRet)
// .getValueAsType()));
}
}
j->setName(i->getName());
++j;
j->setName(i->getName() + "'");
nonconstant.insert(j);
++j;
jj += 2;
++i;
} else {
j->setName(i->getName());
++j;
++jj;
++i;
}
++ii;
}
if (hasPtrInput && (mode == DerivativeMode::ReverseModeCombined ||
mode == DerivativeMode::ReverseModeGradient)) {
if (NewF->hasFnAttribute(Attribute::ReadOnly)) {
NewF->removeFnAttr(Attribute::ReadOnly);
}
#if LLVM_VERSION_MAJOR >= 16
auto eff = NewF->getMemoryEffects();
for (auto loc : MemoryEffects::locations()) {
if (loc == MemoryEffects::Location::InaccessibleMem)
continue;
auto mr = eff.getModRef(loc);
if (isModSet(mr))
eff |= MemoryEffects(loc, ModRefInfo::Ref);
if (isRefSet(mr))
eff |= MemoryEffects(loc, ModRefInfo::Mod);
}
NewF->setMemoryEffects(eff);
#endif
}
NewF->setLinkage(Function::LinkageTypes::InternalLinkage);
if (EnzymeAlwaysInlineDiff)
NewF->addFnAttr(Attribute::AlwaysInline);
assert(NewF->hasLocalLinkage());
return NewF;
}
/// Merges several `malloc` calls that live in the same basic block into one
/// larger allocation.
///
/// A malloc is a candidate when it has a matching `free`: either a `free`
/// call among its users that it dominates, or, failing that, the unique
/// `free` whose "enzyme_cache_free" metadata operand equals the malloc's
/// "enzyme_cache_alloc" metadata operand. For every block with at least two
/// candidates whose size operands all dominate the earliest candidate, the
/// other candidates are replaced by byte offsets into a single allocation
/// emitted in front of the earliest one, and their `free` calls are erased.
///
/// \param F  function to transform in place.
/// \param DT dominator tree used for all ordering queries.
void CoaleseTrivialMallocs(Function &F, DominatorTree &DT) {
  // True when `Call` directly calls a function with the given name.
  auto callsNamed = [](CallInst *Call, const char *Name) {
    Function *Callee = Call->getCalledFunction();
    return Callee && Callee->getName() == Name;
  };

  // Pass 1: index every `free` carrying cache metadata by its metadata
  // operand.
  std::map<Metadata *, std::vector<CallInst *>> CacheFrees;
  for (BasicBlock &Block : F) {
    for (Instruction &Inst : Block) {
      auto *Call = dyn_cast<CallInst>(&Inst);
      if (!Call || !callsNamed(Call, "free"))
        continue;
      if (auto MD = hasMetadata(Call, "enzyme_cache_free")) {
        Metadata *Tag = MD->getOperand(0);
        CacheFrees[Tag].push_back(Call);
      }
    }
  }

  // Pass 2: pair each malloc with its free and group the pairs by the block
  // containing the malloc.
  std::map<BasicBlock *, std::vector<std::pair<CallInst *, CallInst *>>>
      Groups;
  for (BasicBlock &Block : F) {
    for (Instruction &Inst : Block) {
      auto *Alloc = dyn_cast<CallInst>(&Inst);
      if (!Alloc || !callsNamed(Alloc, "malloc"))
        continue;

      // Prefer a free that directly uses this allocation and is dominated
      // by it.
      CallInst *Release = nullptr;
      for (auto U : Alloc->users()) {
        auto *UserCall = dyn_cast<CallInst>(U);
        if (!UserCall || !callsNamed(UserCall, "free"))
          continue;
        if (!DT.dominates(Alloc, UserCall))
          continue;
        Release = UserCall;
        break;
      }

      // Otherwise fall back to the metadata link, but only if it is unique.
      if (!Release) {
        if (auto MD = hasMetadata(Alloc, "enzyme_cache_alloc")) {
          Metadata *Tag = MD->getOperand(0);
          auto &Matches = CacheFrees[Tag];
          if (Matches.size() == 1)
            Release = Matches[0];
        }
      }

      if (Release)
        Groups[&Block].emplace_back(Alloc, Release);
    }
  }

  // Pass 3: rewrite each group that has something to merge.
  for (auto &Entry : Groups) {
    auto &Pairs = Entry.second;
    if (Pairs.size() < 2)
      continue;

    // Locate the candidate that no other candidate precedes.
    CallInst *First = Pairs[0].first;
    for (auto &P : Pairs) {
      if (!DT.dominates(First, P.first))
        First = P.first;
    }

    // Every size must be available at the point of the first allocation.
    bool Legal = true;
    for (auto &P : Pairs) {
      if (auto SizeInst = dyn_cast<Instruction>(P.first->getArgOperand(0)))
        if (!DT.dominates(SizeInst, First))
          Legal = false;
    }
    if (!Legal)
      continue;

    // All size arithmetic is emitted immediately before `First`.
    IRBuilder<> HeadBuilder(First);
    Value *Total = First->getArgOperand(0);
    for (auto &P : Pairs) {
      CallInst *Alloc = P.first;
      if (Alloc == First)
        continue;

      // Round the running total up to a multiple of 16:
      // ((Total - 1) | 15) + 1.
      Value *LessOne = HeadBuilder.CreateSub(
          Total, ConstantInt::get(Total->getType(), 1));
      Value *LowBitsSet = HeadBuilder.CreateOr(
          LessOne, ConstantInt::get(Total->getType(), 15));
      Total = HeadBuilder.CreateAdd(LowBitsSet,
                                    ConstantInt::get(Total->getType(), 1));

      // The merged allocation is released by the free paired with `First`.
      P.second->eraseFromParent();

      // Replace this allocation by an offset into the merged buffer.
      IRBuilder<> AtAlloc(Alloc);
      Value *Slice = AtAlloc.CreateInBoundsGEP(
          Type::getInt8Ty(First->getContext()), First, Total);
      Alloc->replaceAllUsesWith(Slice);

      Total = HeadBuilder.CreateAdd(Total, Alloc->getArgOperand(0));
      Alloc->eraseFromParent();
    }

    // Emit the merged allocation and retire the original first malloc.
    auto NewMalloc = cast<CallInst>(
        HeadBuilder.CreateCall(First->getCalledFunction(), Total));
    NewMalloc->copyIRFlags(First);
    NewMalloc->setMetadata("enzyme_cache_alloc",
                           hasMetadata(First, "enzyme_cache_alloc"));
    First->replaceAllUsesWith(NewMalloc);
    First->eraseFromParent();
  }
}
/// Forwards select operands to uses that are only reachable through one side
/// of a conditional branch on the same condition.
///
/// For each block ending in a conditional branch, every select in that block
/// whose condition is the branch condition has its uses rewritten: a use
/// dominated by the edge to successor 0 receives the select's true value, a
/// use dominated by the edge to successor 1 receives its false value. Other
/// uses are left untouched.
void SelectOptimization(Function *F) {
  DominatorTree DT(*F);
  for (auto &Block : *F) {
    auto *Branch = dyn_cast<BranchInst>(Block.getTerminator());
    if (!Branch || !Branch->isConditional())
      continue;

    for (auto &Inst : Block) {
      auto *Sel = dyn_cast<SelectInst>(&Inst);
      if (!Sel || Sel->getCondition() != Branch->getCondition())
        continue;

      const BasicBlockEdge TakenEdge(&Block, Branch->getSuccessor(0));
      const BasicBlockEdge NotTakenEdge(&Block, Branch->getSuccessor(1));

      // Advance before rewriting: Use::set unlinks the use from this list.
      Value::use_iterator Cursor = Sel->use_begin();
      const Value::use_iterator Stop = Sel->use_end();
      while (Cursor != Stop) {
        Use &Current = *Cursor;
        ++Cursor;
        if (DT.dominates(TakenEdge, Current)) {
          Current.set(Sel->getTrueValue());
        } else if (DT.dominates(NotTakenEdge, Current)) {
          Current.set(Sel->getFalseValue());
        }
      }
    }
  }
}
/// Redirects uses of "specification" functions to their implementations.
///
/// A function in \p M that carries the string attribute "implements" or
/// "implements2" names another function (the specification) through the
/// attribute value. Each use of the specification is replaced by the
/// implementation, bitcast to the specification's type, except:
///   * uses by instructions inside the implementation itself, and
///   * uses by constants that are not global values; those are uniqued and
///     cannot be updated in place through Use::set, so they are left alone.
/// A call instruction whose callee became the implementation also takes over
/// the implementation's calling convention.
void ReplaceFunctionImplementation(Module &M) {
  for (Function &Impl : M) {
    for (auto attr : {"implements", "implements2"}) {
      if (!Impl.hasFnAttribute(attr))
        continue;
      const Attribute A = Impl.getFnAttribute(attr);
      const StringRef SpecificationName = A.getValueAsString();
      Function *Specification = M.getFunction(SpecificationName);
      if (!Specification) {
        LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
                          << "' but no matching specification with name '"
                          << SpecificationName
                          << "', potentially inlined and/or eliminated.\n");
        continue;
      }
      LLVM_DEBUG(dbgs() << "Replace specification '" << Specification->getName()
                        << "' with implementation '" << Impl.getName()
                        << "'\n");

      // The replacement value does not depend on the individual use.
      auto cext = ConstantExpr::getBitCast(&Impl, Specification->getType());

      for (auto I = Specification->use_begin(), UE = Specification->use_end();
           I != UE;) {
        // Advance first: Use::set removes the use from this list.
        auto &use = *I;
        ++I;

        auto *TheUser = use.getUser();
        if (auto *UserInst = dyn_cast<Instruction>(TheUser)) {
          // Leave references from within the implementation untouched.
          if (UserInst->getParent()->getParent() == &Impl)
            continue;
        } else if (isa<Constant>(TheUser) && !isa<GlobalValue>(TheUser)) {
          // Users are not necessarily instructions. Uniqued constants must
          // not be mutated through Use::set.
          continue;
        }

        use.set(cext);
        if (auto CI = dyn_cast<CallInst>(TheUser)) {
          if (CI->getCalledOperand() == cext ||
              CI->getCalledFunction() == &Impl) {
            CI->setCallingConv(Impl.getCallingConv());
          }
        }
      }
    }
  }
}
/// Runs a fixed sequence of cleanup passes over \p F, invalidating the cached
/// analyses in FAM after each one.
///
/// Order: mem2reg promotion, GVN, SROA; then, if EnzymeSelectOpt is set,
/// SimplifyCFG, correlated value propagation and SelectOptimization; then
/// malloc coalescing if EnzymeCoalese is set; then
/// ReplaceFunctionImplementation on the parent module. Finally, when
/// EnzymePostOptLevel selects O1-O3, the PassBuilder function simplification
/// pipeline for that level is run.
void PreProcessCache::optimizeIntermediate(Function *F) {
  // Run one pass over F and drop whatever analyses it did not preserve.
  auto runAndInvalidate = [&](auto &&Pass) {
    PreservedAnalyses Preserved = Pass.run(*F, FAM);
    FAM.invalidate(*F, Preserved);
  };

  runAndInvalidate(PromotePass());
#if !defined(FLANG)
  runAndInvalidate(GVNPass());
#else
  runAndInvalidate(GVN());
#endif
#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG)
  runAndInvalidate(SROAPass(llvm::SROAOptions::PreserveCFG));
#elif !defined(FLANG)
  runAndInvalidate(SROAPass());
#else
  runAndInvalidate(SROA());
#endif

  if (EnzymeSelectOpt) {
    SimplifyCFGOptions scfgo;
    runAndInvalidate(SimplifyCFGPass(scfgo));
    runAndInvalidate(CorrelatedValuePropagationPass());
    SelectOptimization(F);
  }
  // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);

  if (EnzymeCoalese)
    CoaleseTrivialMallocs(*F, FAM.getResult<DominatorTreeAnalysis>(*F));

  ReplaceFunctionImplementation(*F->getParent());

  // The hand-written transforms above report nothing; invalidate using a
  // default-constructed preserved set.
  FAM.invalidate(*F, PreservedAnalyses());

  // Translate the numeric option into an optimization level; anything
  // outside 1..3 leaves the level at O0.
  OptimizationLevel Level = OptimizationLevel::O0;
  switch (EnzymePostOptLevel) {
  case 1:
    Level = OptimizationLevel::O1;
    break;
  case 2:
    Level = OptimizationLevel::O2;
    break;
  case 3:
    Level = OptimizationLevel::O3;
    break;
  default:
    break;
  }

  if (Level != OptimizationLevel::O0) {
    PassBuilder PB;
    FunctionPassManager FPM =
        PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None);
    runAndInvalidate(FPM);
  }
  // TODO actually run post optimizations.
}
/// Resets this cache: clears the LAM, FAM and MAM members (presumably the
/// loop, function and module analysis managers -- confirm against the class
/// declaration) and then empties `cache`.
void PreProcessCache::clear() {
LAM.clear();
FAM.clear();
MAM.clear();
cache.clear();
}
// Returns whether a is guaranteed to be equivalent to not b.
//
// Two shapes are recognised:
//   * a and b are compares over the same operands (in the same order) whose
//     predicates are inverses of each other;
//   * one value is `xor other, <all-ones constant>` (operands in either
//     order).
static bool isNot(Value *a, Value *b) {
  // cmp pred, x, y  versus  cmp inverse(pred), x, y
  auto *CmpA = dyn_cast<CmpInst>(a);
  auto *CmpB = CmpA ? dyn_cast<CmpInst>(b) : nullptr;
  if (CmpA && CmpB) {
    const bool SameOperands = CmpA->getOperand(0) == CmpB->getOperand(0) &&
                              CmpA->getOperand(1) == CmpB->getOperand(1);
    if (SameOperands && CmpA->getPredicate() == CmpB->getInversePredicate())
      return true;
  }

  // True when Negated is `xor Plain, all-ones` or `xor all-ones, Plain`.
  auto isXorWithAllOnes = [](Value *Negated, Value *Plain) {
    auto *Inst = dyn_cast<Instruction>(Negated);
    if (!Inst || Inst->getOpcode() != Instruction::Xor)
      return false;
    for (int Idx = 0; Idx < 2; Idx++) {
      if (Inst->getOperand(Idx) != Plain)
        continue;
      auto *Mask = dyn_cast<ConstantInt>(Inst->getOperand(1 - Idx));
      if (!Mask)
        continue;
#if LLVM_VERSION_MAJOR > 16
      if (Mask->getValue().isAllOnes())
#else
      if (Mask->getValue().isAllOnesValue())
#endif
        return true;
    }
    return false;
  };

  // a := xor true, b   or   b := xor true, a
  return isXorWithAllOnes(a, b) || isXorWithAllOnes(b, a);
}
struct compare_insts {
public:
DominatorTree &DT;
LoopInfo &LI;
compare_insts(DominatorTree &DT, LoopInfo &LI) : DT(DT), LI(LI) {}
// return true if A appears later than B.
bool operator()(Instruction *A, Instruction *B) const {
if (A == B) {
return false;
}
if (A->getParent() == B->getParent()) {
return !A->comesBefore(B);
}
auto AB = A->getParent();
auto BB = B->getParent();
assert(AB->getParent() == BB->getParent());
for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) {
if (prev == AB)
return false;
}
return true;
}
};
/// A std::set of instructions ordered by compare_insts, with a few
/// worklist-style convenience members.
class DominatorOrderSet : public std::set<Instruction *, compare_insts> {
public:
  DominatorOrderSet(DominatorTree &DT, LoopInfo &LI)
      : std::set<Instruction *, compare_insts>(compare_insts(DT, LI)) {}

  /// Whether \p I is currently in the set.
  bool contains(Instruction *I) const { return find(I) != end(); }

  /// Removes \p I, which must be present (checked by assert).
  void remove(Instruction *I) {
    auto Pos = find(I);
    assert(Pos != end());
    erase(Pos);
  }

  /// Removes and returns the last element in set order. The set must not be
  /// empty.
  Instruction *pop_back_val() {
    auto Last = end();
    --Last;
    Instruction *Popped = *Last;
    erase(Last);
    return Popped;
  }
};
/// Returns true when \p z is a uitofp, sitofp, zext or sext instruction, or a
/// select with a zero ConstantInt as either its true or its false value.
bool directlySparse(Value *z) {
  if (isa<UIToFPInst>(z) || isa<SIToFPInst>(z) || isa<ZExtInst>(z) ||
      isa<SExtInst>(z))
    return true;

  auto *Sel = dyn_cast<SelectInst>(z);
  if (!Sel)
    return false;

  // Either arm being the integer constant zero qualifies.
  for (Value *Arm : {Sel->getTrueValue(), Sel->getFalseValue()}) {
    auto *ArmConst = dyn_cast<ConstantInt>(Arm);
    if (ArmConst && ArmConst->isZero())
      return true;
  }
  return false;
}
/// Queue type taken by reference by fixSparse_inner.
using QueueType = DominatorOrderSet;
/// Returns (declaring it in \p M if necessary) the variadic function
/// `__enzyme_product.<ty>` returning \p T, where `<ty>` is `f32`, `f64` or
/// `i<bitwidth>`. Any other type trips an assert. The declaration is
/// requested with the readnone, nounwind, nofree, nosync and willreturn
/// function attributes.
Function *getProductIntrinsic(llvm::Module &M, llvm::Type *T) {
  std::string name = "__enzyme_product.";
  if (auto *IntTy = dyn_cast<IntegerType>(T)) {
    name += "i" + std::to_string(IntTy->getBitWidth());
  } else if (T->isDoubleTy()) {
    name += "f64";
  } else if (T->isFloatTy()) {
    name += "f32";
  } else {
    assert(0);
  }

  // No fixed parameters; every operand is passed through the varargs.
  auto FT = llvm::FunctionType::get(T, /*isVarArg=*/true);

  AttributeList AL;
  for (Attribute::AttrKind Kind :
       {Attribute::ReadNone, Attribute::NoUnwind, Attribute::NoFree,
        Attribute::NoSync, Attribute::WillReturn})
    AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, Kind);

  return cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
}
/// Returns (declaring it in \p M if necessary) the variadic function
/// `__enzyme_sum.<ty>` returning \p T, where `<ty>` is `f32`, `f64` or
/// `i<bitwidth>`. Any other type trips an assert. The declaration is
/// requested with the readnone, nounwind, nofree, nosync and willreturn
/// function attributes.
Function *getSumIntrinsic(llvm::Module &M, llvm::Type *T) {
  std::string name = "__enzyme_sum.";
  if (auto *IntTy = dyn_cast<IntegerType>(T)) {
    name += "i" + std::to_string(IntTy->getBitWidth());
  } else if (T->isDoubleTy()) {
    name += "f64";
  } else if (T->isFloatTy()) {
    name += "f32";
  } else {
    assert(0);
  }

  // No fixed parameters; every operand is passed through the varargs.
  auto FT = llvm::FunctionType::get(T, /*isVarArg=*/true);

  AttributeList AL;
  for (Attribute::AttrKind Kind :
       {Attribute::ReadNone, Attribute::NoUnwind, Attribute::NoFree,
        Attribute::NoSync, Attribute::WillReturn})
    AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, Kind);

  return cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
}
/// If \p v is a call whose callee (as resolved by getFunctionFromCall) has a
/// name starting with "__enzyme_product", returns that call; otherwise
/// returns nullptr.
CallInst *isProduct(llvm::Value *v) {
  auto *Call = dyn_cast<CallInst>(v);
  if (!Call)
    return nullptr;
  auto Callee = getFunctionFromCall(Call);
  if (!Callee)
    return nullptr;
  return startsWith(Callee->getName(), "__enzyme_product") ? Call : nullptr;
}
/// If \p v is a call whose callee (as resolved by getFunctionFromCall) has a
/// name starting with "__enzyme_sum", returns that call; otherwise returns
/// nullptr.
CallInst *isSum(llvm::Value *v) {
  auto *Call = dyn_cast<CallInst>(v);
  if (!Call)
    return nullptr;
  auto Callee = getFunctionFromCall(Call);
  if (!Callee)
    return nullptr;
  return startsWith(Callee->getName(), "__enzyme_sum") ? Call : nullptr;
}
/// Copies the argument operands of \p CB (not the callee) into a vector, in
/// call order.
SmallVector<Value *, 1> callOperands(llvm::CallBase *CB) {
  SmallVector<Value *, 1> Collected;
  for (Use &Arg : CB->args())
    Collected.push_back(Arg.get());
  return Collected;
}
/// Recursive structural predicate over \p z:
///   * loads and constants                     -> true
///   * binary operator                         -> both operands qualify
///   * cast                                    -> its operand qualifies
///   * `__enzyme_sum` call                     -> at least one operand qualifies
///   * `__enzyme_product` call                 -> every operand qualifies
///   * sqrt / sin / cos intrinsics (plus sinh / cosh / tanh on LLVM >= 19)
///                                             -> first argument qualifies
///   * anything else                           -> false
bool guaranteedDataDependent(Value *z) {
  if (isa<LoadInst>(z) || isa<Constant>(z))
    return true;

  if (auto *Bin = dyn_cast<BinaryOperator>(z)) {
    if (!guaranteedDataDependent(Bin->getOperand(0)))
      return false;
    return guaranteedDataDependent(Bin->getOperand(1));
  }

  if (auto *Cast = dyn_cast<CastInst>(z))
    return guaranteedDataDependent(Cast->getOperand(0));

  if (auto SumCall = isSum(z)) {
    bool AnyQualifies = false;
    for (auto Term : callOperands(SumCall)) {
      if (guaranteedDataDependent(Term)) {
        AnyQualifies = true;
        break;
      }
    }
    return AnyQualifies;
  }

  if (auto ProductCall = isProduct(z)) {
    bool AllQualify = true;
    for (auto Factor : callOperands(ProductCall)) {
      if (!guaranteedDataDependent(Factor)) {
        AllQualify = false;
        break;
      }
    }
    return AllQualify;
  }

  auto *Intrin = dyn_cast<IntrinsicInst>(z);
  if (!Intrin)
    return false;

  switch (Intrin->getIntrinsicID()) {
  case Intrinsic::sqrt:
  case Intrinsic::sin:
  case Intrinsic::cos:
#if LLVM_VERSION_MAJOR >= 19
  case Intrinsic::sinh:
  case Intrinsic::cosh:
  case Intrinsic::tanh:
#endif
    return guaranteedDataDependent(Intrin->getArgOperand(0));
  default:
    return false;
  }
}
std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
QueueType &Q, DominatorTree &DT,
ScalarEvolution &SE, LoopInfo &LI,
const DataLayout &DL) {
auto push = [&](llvm::Value *V) {
if (V == cur)
return V;
assert(V);
if (auto I = dyn_cast<Instruction>(V)) {
Q.insert(I);
for (auto U : I->users()) {
if (auto I2 = dyn_cast<Instruction>(U)) {
if (I2 == cur)
continue;
Q.insert(I2);
}
}
}
return V;
};
auto pushcse = [&](llvm::Value *V) -> llvm::Value * {
if (auto I = dyn_cast<Instruction>(V)) {
for (size_t i = 0; i < I->getNumOperands(); i++) {
if (auto I2 = dyn_cast<Instruction>(I->getOperand(i))) {
Instruction *candidate = nullptr;
for (auto U : I2->users()) {
candidate = dyn_cast<Instruction>(U);
if (!candidate)
continue;
if (candidate == I && candidate->getType() != I->getType()) {
candidate = nullptr;
continue;
}
bool isSame = candidate->isIdenticalTo(I);
if (!isSame) {
if (auto P1 = isProduct(I))
if (auto P2 = isProduct(I2)) {
std::multiset<llvm::Value *> s1;
std::multiset<llvm::Value *> s2;
for (auto &v : callOperands(P1))
s1.insert(v);
for (auto &v : callOperands(P2))
s2.insert(v);
isSame = s1 == s2;
}
if (auto P1 = isSum(I))
if (auto P2 = isSum(I2)) {
std::multiset<llvm::Value *> s1;
std::multiset<llvm::Value *> s2;
for (auto &v : callOperands(P1))
s1.insert(v);
for (auto &v : callOperands(P2))
s2.insert(v);
isSame = s1 == s2;
}
}
if (!isSame) {
candidate = nullptr;
continue;
}
if (DT.dominates(candidate, I)) {
break;
}
candidate = nullptr;
}
if (candidate) {
I->eraseFromParent();
return candidate;
}
break;
}
}
return push(I);
}
return V;
};
auto replaceAndErase = [&](llvm::Instruction *I, llvm::Value *candidate) {
for (auto U : I->users())
push(U);
I->replaceAllUsesWith(candidate);
push(candidate);
SetVector<Instruction *> operands;
for (size_t i = 0; i < I->getNumOperands(); i++) {
if (auto I2 = dyn_cast<Instruction>(I->getOperand(i))) {
if ((!I2->mayWriteToMemory() ||
(isa<CallInst>(I2) && isReadOnly(cast<CallInst>(I2)))))
operands.insert(I2);
}
}
if (Q.contains(I)) {
Q.remove(I);
}
assert(!Q.contains(I));
I->eraseFromParent();
for (auto op : operands)
if (op->getNumUses() == 0) {
if (Q.contains(op))
Q.remove(op);
op->eraseFromParent();
}
};
if (!cur->getType()->isVoidTy() &&
(!cur->mayWriteToMemory() ||
(isa<CallInst>(cur) && isReadOnly(cast<CallInst>(cur))))) {
// DCE
if (cur->getNumUses() == 0) {
for (size_t i = 0; i < cur->getNumOperands(); i++)
push(cur->getOperand(i));
assert(!Q.contains(cur));
cur->eraseFromParent();
return "DCE";
}
// CSE
{
for (size_t i = 0; i < cur->getNumOperands(); i++) {
if (auto I = dyn_cast<Instruction>(cur->getOperand(i))) {
Instruction *candidate = nullptr;
bool reverse = false;
for (auto U : I->users()) {
candidate = dyn_cast<Instruction>(U);
if (!candidate)
continue;
if (candidate == cur && candidate->getType() != cur->getType()) {
candidate = nullptr;
continue;
}
bool isSame = candidate->isIdenticalTo(cur);
if (!isSame) {
if (auto P1 = isProduct(candidate))
if (auto P2 = isProduct(cur)) {
std::multiset<llvm::Value *> s1;
std::multiset<llvm::Value *> s2;
for (auto &v : callOperands(P1))
s1.insert(v);
for (auto &v : callOperands(P2))
s2.insert(v);
isSame = s1 == s2;
}
if (auto P1 = isSum(candidate))
if (auto P2 = isSum(cur)) {
std::multiset<llvm::Value *> s1;
std::multiset<llvm::Value *> s2;
for (auto &v : callOperands(P1))
s1.insert(v);
for (auto &v : callOperands(P2))
s2.insert(v);
isSame = s1 == s2;
}
}
if (!isSame) {
candidate = nullptr;
continue;
}
if (DT.dominates(candidate, cur)) {
break;
} else if (DT.dominates(cur, candidate)) {
reverse = true;
break;
}
candidate = nullptr;
}
if (candidate) {
if (reverse) {
if (Q.contains(candidate))
Q.remove(candidate);
auto tmp = candidate;
candidate = cur;
cur = tmp;
}
replaceAndErase(cur, candidate);
return "CSE";
}
break;
}
}
}
}
if (auto SI = dyn_cast<SelectInst>(cur))
if (auto CI = dyn_cast<ConstantInt>(SI->getCondition())) {
if (CI->isOne()) {
replaceAndErase(cur, SI->getTrueValue());
return "SelectToTrue";
} else {
replaceAndErase(cur, SI->getFalseValue());
return "SelectToFalse";
}
}
if (cur->getOpcode() == Instruction::Or) {
for (int i = 0; i < 2; i++) {
if (auto C = dyn_cast<ConstantInt>(cur->getOperand(i))) {
// or a, 0 -> a
if (C->isZero()) {
replaceAndErase(cur, cur->getOperand(1 - i));
return "OrZero";
}
// or a, 1 -> 1
if (C->isOne() && cur->getType()->isIntegerTy(1)) {
replaceAndErase(cur, C);
return "OrOne";
}
}
}
}
if (cur->getOpcode() == Instruction::And) {
for (int i = 0; i < 2; i++) {
if (auto C = dyn_cast<ConstantInt>(cur->getOperand(i))) {
// and a, 1 -> a
if (C->isOne() && cur->getType()->isIntegerTy(1)) {
replaceAndErase(cur, cur->getOperand(1 - i));
return "AndOne";
}
// and a, 0 -> 0
if (C->isZero()) {
replaceAndErase(cur, C);
return "AndZero";
}
}
}
}
IRBuilder<> B(cur);
if (auto CI = dyn_cast<CastInst>(cur))
if (auto C = dyn_cast<Constant>(CI->getOperand(0))) {
replaceAndErase(
cur, cast<Constant>(B.CreateCast(CI->getOpcode(), C, CI->getType())));
return "CastConstProp";
}
std::function<Value *(Value *, Value *, Value *)> replace = [&](Value *val,
Value *orig,
Value *with) {
if (val == orig) {
return with;
}
if (isNot(val, orig)) {
return pushcse(B.CreateNot(with));
}
if (isa<PHINode>(val))
return val;
if (auto I = dyn_cast<Instruction>(val)) {
if (I->mayWriteToMemory() &&
!(isa<CallInst>(I) && isReadOnly(cast<CallInst>(I))))
return val;
if (I->getOpcode() == Instruction::Add) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateAdd(lhs, rhs, "sel." + I->getName(),
I->hasNoUnsignedWrap(),
I->hasNoSignedWrap()));
}
if (I->getOpcode() == Instruction::Sub) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateSub(lhs, rhs, "sel." + I->getName(),
I->hasNoUnsignedWrap(),
I->hasNoSignedWrap()));
}
if (I->getOpcode() == Instruction::Mul) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateMul(lhs, rhs, "sel." + I->getName(),
I->hasNoUnsignedWrap(),
I->hasNoSignedWrap()));
}
if (I->getOpcode() == Instruction::And) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateAnd(lhs, rhs, "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::Or) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateOr(lhs, rhs, "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::Xor) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateXor(lhs, rhs, "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::FAdd) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateFAddFMF(lhs, rhs, I, "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::FSub) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateFSubFMF(lhs, rhs, I, "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::FMul) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(B.CreateFMulFMF(lhs, rhs, I, "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::ZExt) {
Value *op = replace(I->getOperand(0), orig, with);
if (op == I->getOperand(0))
return val;
push(I);
return pushcse(B.CreateZExt(op, I->getType(), "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::SExt) {
Value *op = replace(I->getOperand(0), orig, with);
if (op == I->getOperand(0))
return val;
push(I);
return pushcse(B.CreateSExt(op, I->getType(), "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::UIToFP) {
Value *op = replace(I->getOperand(0), orig, with);
if (op == I->getOperand(0))
return val;
push(I);
return pushcse(B.CreateUIToFP(op, I->getType(), "sel." + I->getName()));
}
if (I->getOpcode() == Instruction::SIToFP) {
Value *op = replace(I->getOperand(0), orig, with);
if (op == I->getOperand(0))
return val;
push(I);
return pushcse(B.CreateSIToFP(op, I->getType(), "sel." + I->getName()));
}
if (auto CI = dyn_cast<CmpInst>(I)) {
Value *lhs = replace(I->getOperand(0), orig, with);
Value *rhs = replace(I->getOperand(1), orig, with);
if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
return val;
push(I);
return pushcse(
B.CreateCmp(CI->getPredicate(), lhs, rhs, "sel." + I->getName()));
}
if (auto SI = dyn_cast<SelectInst>(I)) {
Value *cond = replace(SI->getCondition(), orig, with);
Value *tval = replace(SI->getTrueValue(), orig, with);
Value *fval = replace(SI->getFalseValue(), orig, with);
if (cond == SI->getCondition() && tval == SI->getTrueValue() &&
fval == SI->getFalseValue())
return val;
push(I);
if (auto CI = dyn_cast<ConstantInt>(cond)) {
if (CI->isOne())
return tval;
else
return fval;
}
return pushcse(B.CreateSelect(cond, tval, fval, "sel." + I->getName()));
}
if (isProduct(I) || isSum(I)) {
auto C = cast<CallBase>(I);
auto ops = callOperands(C);
bool changed = false;
for (auto &op : ops) {
auto next = replace(op, orig, with);
if (next != op) {
changed = true;
op = next;
}
}
if (!changed)
return (Value *)I;
push(I);
pushcse(
B.CreateCall(getFunctionFromCall(C), ops, "sel." + I->getName()));
}
}
return val;
};
if (auto II = dyn_cast<IntrinsicInst>(cur))
if (II->getIntrinsicID() == Intrinsic::fmuladd ||
II->getIntrinsicID() == Intrinsic::fma) {
B.setFastMathFlags(getFast());
auto mul = pushcse(B.CreateFMul(II->getOperand(0), II->getOperand(1)));
auto add = pushcse(B.CreateFAdd(mul, II->getOperand(2)));
replaceAndErase(cur, add);
return "FMulAddExpand";
}
if (auto BO = dyn_cast<BinaryOperator>(cur)) {
if (BO->getOpcode() == Instruction::FMul && BO->isFast()) {
Value *args[2] = {BO->getOperand(0), BO->getOperand(1)};
auto mul = pushcse(
B.CreateCall(getProductIntrinsic(*F.getParent(), BO->getType()), args,
cur->getName()));
replaceAndErase(cur, mul);
return "FMulToProduct";
}
if (BO->getOpcode() == Instruction::FDiv && BO->isFast()) {
auto c0 = dyn_cast<ConstantFP>(BO->getOperand(0));
if (!c0 || !c0->isExactlyValue(1.0)) {
B.setFastMathFlags(getFast());
auto div = pushcse(B.CreateFDivFMF(ConstantFP::get(BO->getType(), 1.0),
BO->getOperand(1), BO));
auto mul = pushcse(
B.CreateFMulFMF(BO->getOperand(0), div, BO, cur->getName()));
replaceAndErase(cur, mul);
return "FDivToFMul";
}
}
if (BO->getOpcode() == Instruction::FAdd && BO->isFast()) {
Value *args[2] = {BO->getOperand(0), BO->getOperand(1)};
auto mul = pushcse(
B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()), args));
replaceAndErase(cur, mul);
return "FAddToSum";
}
if (BO->getOpcode() == Instruction::FSub && BO->isFast()) {
B.setFastMathFlags(getFast());
Value *args[2] = {BO->getOperand(0),
pushcse(B.CreateFNeg(BO->getOperand(1)))};
auto mul =
pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()),
args, cur->getName()));
replaceAndErase(cur, mul);
return "FAddToSum";
}
}
if (cur->getOpcode() == Instruction::FNeg) {
B.setFastMathFlags(getFast());
auto mul =
pushcse(B.CreateFMulFMF(ConstantFP::get(cur->getType(), -1.0),
cur->getOperand(0), cur, cur->getName()));
replaceAndErase(cur, mul);
return "FNegToMul";
}
if (auto SI = dyn_cast<SelectInst>(cur)) {
if (auto tc = dyn_cast<ConstantFP>(SI->getTrueValue()))
if (auto fc = dyn_cast<ConstantFP>(SI->getFalseValue()))
if (fc->isZero()) {
if (tc->isExactlyValue(1.0)) {
auto res =
pushcse(B.CreateUIToFP(SI->getCondition(), tc->getType()));
replaceAndErase(cur, res);
return "SelToUIFP";
}
if (tc->isExactlyValue(-1.0)) {
auto res =
pushcse(B.CreateSIToFP(SI->getCondition(), tc->getType()));
replaceAndErase(cur, res);
return "SelToSIFP";
}
}
}
if (auto P = isProduct(cur)) {
SmallVector<Value *, 1> operands;
std::optional<APFloat> constval;
bool changed = false;
for (auto &v : callOperands(P))
{
if (auto P2 = isProduct(v)) {
for (auto &v2 : callOperands(P2)) {
push(v2);
operands.push_back(v2);
}
push(P2);
changed = true;
continue;
}
if (auto C = dyn_cast<ConstantFP>(v)) {
if (C->isExactlyValue(1.0)) {
changed = true;
continue;
}
if (C->isZero()) {
replaceAndErase(cur, C);
return "ZeroProduct";
}
if (!constval) {
constval = C->getValue();
continue;
}
constval = (*constval) * C->getValue();
changed = true;
continue;
}
if (auto op = dyn_cast<SelectInst>(v)) {
if (auto tc = dyn_cast<ConstantFP>(op->getTrueValue()))
if (tc->isZero()) {
operands.push_back(pushcse(B.CreateUIToFP(
pushcse(B.CreateNot(op->getCondition())), op->getType())));
operands.push_back(op->getFalseValue());
changed = true;
continue;
}
if (auto tc = dyn_cast<ConstantFP>(op->getFalseValue()))
if (tc->isZero()) {
operands.push_back(
pushcse(B.CreateUIToFP(op->getCondition(), op->getType())));
operands.push_back(op->getTrueValue());
changed = true;
continue;
}
}
operands.push_back(v);
}
if (constval)
operands.push_back(ConstantFP::get(cur->getType(), *constval));
if (operands.size() == 0) {
replaceAndErase(cur, ConstantFP::get(cur->getType(), 1.0));
return "EmptyProduct";
}
if (operands.size() == 1) {
replaceAndErase(cur, operands[0]);
return "SingleProduct";
}
if (changed) {
auto mul = pushcse(
B.CreateCall(getProductIntrinsic(*F.getParent(), cur->getType()),
operands, cur->getName()));
replaceAndErase(cur, mul);
return "ProductSimplification";
}
}
if (auto P = isSum(cur)) {
// map from operand, to number of counts
std::map<Value *, unsigned> operands;
std::optional<APFloat> constval;
bool changed = false;
for (auto &v : callOperands(P)) {
if (auto P2 = isSum(v)) {
for (auto &v2 : callOperands(P2)) {
push(v2);
operands[v2]++;
}
push(P2);
changed = true;
continue;
}
if (auto C = dyn_cast<ConstantFP>(v)) {
if (C->isExactlyValue(0.0)) {
changed = true;
continue;
}
if (!constval) {
constval = C->getValue();
continue;
}
constval = (*constval) + C->getValue();
changed = true;
continue;
}
operands[v]++;
}
if (constval)
operands[ConstantFP::get(cur->getType(), *constval)]++;
if (operands.size() == 0) {
replaceAndErase(cur, ConstantFP::get(cur->getType(), 0.0));
return "EmptySum";
}
SmallVector<Value *, 1> args;
for (auto &pair : operands) {
if (pair.second == 1) {
args.push_back(pair.first);
continue;
}
changed = true;
Value *sargs[] = {pair.first,
ConstantFP::get(cur->getType(), (double)pair.second)};
args.push_back(pushcse(B.CreateCall(
getProductIntrinsic(*F.getParent(), cur->getType()), sargs)));
}
if (args.size() == 1) {
replaceAndErase(cur, args[0]);
return "SingleSum";
}
if (changed) {
auto sum =
pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), cur->getType()),
args, cur->getName()));
replaceAndErase(cur, sum);
return "SumSimplification";
}
}
if (auto P = isProduct(cur)) {
SmallVector<Value *, 1> operands;
SmallVector<Value *, 1> conditions;
for (auto &v : callOperands(P)) {
// z = uitofp i1 c to float -> select c, (prod withot z), 0
if (auto op = dyn_cast<UIToFPInst>(v)) {
if (op->getOperand(0)->getType()->isIntegerTy(1)) {
conditions.push_back(op->getOperand(0));
continue;
}
}
// z = sitofp i1 c to float -> select c, (-prod withot z), 0
if (auto op = dyn_cast<SIToFPInst>(v)) {
if (op->getOperand(0)->getType()->isIntegerTy(1)) {
conditions.push_back(op->getOperand(0));
operands.push_back(ConstantFP::get(cur->getType(), -1.0));
continue;
}
}
if (auto op = dyn_cast<SelectInst>(v)) {
if (auto tc = dyn_cast<ConstantFP>(op->getTrueValue()))
if (tc->isZero()) {
conditions.push_back(pushcse(B.CreateNot(op->getCondition())));
operands.push_back(op->getFalseValue());
continue;
}
if (auto tc = dyn_cast<ConstantFP>(op->getFalseValue()))
if (tc->isZero()) {
conditions.push_back(op->getCondition());
operands.push_back(op->getTrueValue());
continue;
}
}
operands.push_back(v);
}
if (conditions.size()) {
auto mul = pushcse(B.CreateCall(
getProductIntrinsic(*F.getParent(), cur->getType()), operands));
Value *condition = nullptr;
for (auto v : conditions) {
assert(v->getType()->isIntegerTy(1));
if (condition == nullptr) {
condition = v;
continue;
}
condition = pushcse(B.CreateAnd(condition, v));
}
auto zero = ConstantFP::get(cur->getType(), 0.0);
auto sel = pushcse(B.CreateSelect(condition, mul, zero, cur->getName()));
replaceAndErase(cur, sel);
return "ProductSelect";
}
}
// TODO
if (auto P = isSum(cur)) {
  // When every addend of a sum is zero unless one common i1 guard holds,
  // hoist the guard:
  //   sum(select c, x, 0 ; uitofp c ; ...) -> select c, sum(x ; 1.0 ; ...), 0
  //
  // For each addend we record
  //   - conditions[k] = (guard, negated): the addend is non-zero only when
  //     `guard` is true (negated == false) or false (negated == true);
  //   - operands[k]: the value the addend takes when its guard is satisfied.
  SmallVector<std::pair<Value *, bool>, 1> conditions;
  SmallVector<Value *, 1> operands;
  bool legal = true;
  for (auto &v : callOperands(P)) {
    // z = uitofp i1 c to float  ==  select c, 1.0, 0
    if (auto op = dyn_cast<UIToFPInst>(v)) {
      if (op->getOperand(0)->getType()->isIntegerTy(1)) {
        conditions.emplace_back(op->getOperand(0), false);
        operands.push_back(ConstantFP::get(cur->getType(), 1.0));
        continue;
      }
    }
    // z = sitofp i1 c to float  ==  select c, -1.0, 0
    if (auto op = dyn_cast<SIToFPInst>(v)) {
      if (op->getOperand(0)->getType()->isIntegerTy(1)) {
        conditions.emplace_back(op->getOperand(0), false);
        operands.push_back(ConstantFP::get(cur->getType(), -1.0));
        continue;
      }
    }
    if (auto op = dyn_cast<SelectInst>(v)) {
      // select c, 0, x: non-zero only when c is false.
      if (auto tc = dyn_cast<ConstantFP>(op->getTrueValue()))
        if (tc->isZero()) {
          conditions.emplace_back(op->getCondition(), true);
          operands.push_back(op->getFalseValue());
          continue;
        }
      // select c, x, 0: non-zero only when c is true.
      if (auto tc = dyn_cast<ConstantFP>(op->getFalseValue()))
        if (tc->isZero()) {
          conditions.emplace_back(op->getCondition(), false);
          operands.push_back(op->getTrueValue());
          continue;
        }
    }
    // An unguarded addend makes the rewrite impossible.
    legal = false;
    break;
  }
  Value *condition = nullptr;
  if (legal)
    for (size_t i = 0; i < conditions.size(); i++) {
      // Count the addends whose guard is equivalent to addend i's guard:
      // either the same value with the same polarity, or the logical
      // negation of it with the opposite polarity. The polarity must be
      // compared between entries i and j; comparing an entry with itself
      // would merge `select c, 0, x` with `select c, y, 0`.
      size_t count = 0;
      for (size_t j = 0; j < conditions.size(); j++) {
        if (((conditions[i].first == conditions[j].first) &&
             (conditions[i].second == conditions[j].second)) ||
            ((isNot(conditions[i].first, conditions[j].first) &&
              (conditions[i].second != conditions[j].second))))
          count++;
      }
      // Only worthwhile when at least two addends share the guard.
      if (count == conditions.size() && count > 1) {
        condition = conditions[i].first;
        if (conditions[i].second)
          condition = pushcse(B.CreateNot(condition, "sumpnot"));
        break;
      }
    }
  if (condition) {
    auto sum = pushcse(B.CreateCall(
        getSumIntrinsic(*F.getParent(), cur->getType()), operands));
    auto zero = ConstantFP::get(cur->getType(), 0.0);
    auto sel = pushcse(B.CreateSelect(condition, sum, zero, cur->getName()));
    replaceAndErase(cur, sel);
    return "SumSelect";
  }
}
// Pull the factors shared by every product of a sum out of the sum:
//   (a1*b1) + (a1*c1) + (a1*d1) + ... -> a1 * (b1 + c1 + d1 + ...)
if (auto Sum = isSum(cur)) {
  // Multiset intersection: keeps, in `kept` order, each element that can be
  // paired with a distinct element of `pool`.
  auto intersect = [](const SmallVector<Value *, 1> &kept,
                      SmallVector<Value *, 1> pool) {
    SmallVector<Value *, 1> shared;
    for (auto candidate : kept) {
      for (auto &slot : pool) {
        if (slot != candidate)
          continue;
        // Consume the match so it cannot pair with another candidate.
        slot = nullptr;
        shared.push_back(candidate);
        break;
      }
    }
    return shared;
  };
  // Multiset difference: removes one occurrence of each `removed` element.
  auto without = [](SmallVector<Value *, 1> from,
                    const SmallVector<Value *, 1> &removed) {
    for (auto gone : removed) {
      auto pos = find(from, gone);
      assert(pos != from.end());
      from.erase(pos);
    }
    return from;
  };
  SmallVector<Value *, 1> common; // factors present in every term so far
  bool haveFirstTerm = false;
  bool allProducts = true;
  for (auto term : callOperands(Sum)) {
    auto Prod = isProduct(term);
    if (!Prod) {
      allProducts = false;
      break;
    }
    if (haveFirstTerm) {
      common = intersect(common, callOperands(Prod));
    } else {
      common = callOperands(Prod);
      haveFirstTerm = true;
    }
  }
  if (allProducts && !common.empty()) {
    // Rebuild each term without the common factors...
    SmallVector<Value *, 1> reducedTerms;
    for (auto term : callOperands(Sum)) {
      auto Prod = isProduct(term);
      push(term);
      auto remaining = without(callOperands(Prod), common);
      auto reduced = pushcse(B.CreateCall(
          getProductIntrinsic(*F.getParent(), Sum->getType()), remaining));
      reducedTerms.push_back(reduced);
    }
    // ...sum the reduced terms, then multiply by the common factors.
    auto reducedSum = pushcse(B.CreateCall(
        getSumIntrinsic(*F.getParent(), Sum->getType()), reducedTerms));
    common.push_back(reducedSum);
    auto factored = pushcse(B.CreateCall(
        getProductIntrinsic(*F.getParent(), Sum->getType()), common));
    replaceAndErase(cur, factored);
    return "SumFactor";
  }
}
/*
// add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1
!= c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2;
j++) if (auto c0 = dyn_cast<ZExtInst>(cur->getOperand(j))) if (auto cmp0 =
dyn_cast<ICmpInst>(c0->getOperand(0))) if (auto c1 =
dyn_cast<CastInst>(cur->getOperand(1-j))) if (auto cmp1 =
dyn_cast<ICmpInst>(c0->getOperand(0))) if (cmp0->getPredicate() ==
ICmpInst::ICMP_EQ && cmp1->getPredicate() == ICmpInst::ICMP_EQ)
{
for (size_t i0 = 0; i0 < 2; i0++)
for (size_t i1 = 0; i1 < 2; i1++)
if (cmp0->getOperand(1 - i0) == cmp1->getOperand(1 - i1))
auto e0 = SE.getSCEV(cmp0->getOperand(i0));
auto e1 = SE.getSCEV(cmp1->getOperand(i1));
auto m = SE.getMinusSCEV(e0, e1, SCEV::NoWrapMask);
if (auto C = dyn_cast<SCEVConstant>(m)) {
// if c1 == c2 don't need the and they are equivalent
if (C->getValue()->isZero()) {
} else {
auto sel0 = pushcse(B.CreateSelect(cmp0,
ConstantInt::get(cur->getType(), isa<ZExtInst>(cmp0) ? 1 : -1),
ConstantInt::get(cur->getType(), 0));
// if non one constant they must be distinct.
replaceAndErase(cur,
ConstantInt::getFalse(cur->getContext()));
return "AndNEExpr";
}
}
}
}
*/
// Floating-point (in)equality of a product or sum against the constant zero.
// Handles the OEQ/UEQ/UNE/ONE predicates with the zero on either side.
if (auto fcmp = dyn_cast<FCmpInst>(cur)) {
auto predicate = fcmp->getPredicate();
if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ ||
predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) {
// i is the index of the constant operand; 1 - i is the compared value.
for (int i = 0; i < 2; i++)
if (auto C = dyn_cast<ConstantFP>(fcmp->getOperand(i))) {
if (C->isZero()) {
// (a1*a2*...an) == 0 -> (a1 == 0) || (a2 == 0) || ... (an == 0)
// (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (an == 0) ]
if (auto P = isProduct(fcmp->getOperand(1 - i))) {
Value *res = nullptr;
// For the "not equal" predicates, build the disjunction with the inverse
// (equality) predicate and negate the whole result afterwards.
auto eq_predicate = predicate;
if (predicate == FCmpInst::FCMP_UNE ||
predicate == FCmpInst::FCMP_ONE)
eq_predicate = fcmp->getInversePredicate();
for (auto &v : callOperands(P)) {
auto ncmp1 = pushcse(B.CreateFCmp(eq_predicate, v, C));
if (!res)
res = ncmp1;
else
res = pushcse(B.CreateOr(res, ncmp1));
}
if (predicate == FCmpInst::FCMP_UNE ||
predicate == FCmpInst::FCMP_ONE) {
res = pushcse(B.CreateNot(res));
}
// NOTE(review): res stays null if the product has no operands; presumably
// that cannot happen here -- confirm.
replaceAndErase(cur, res);
return "CmpProductSplit";
}
// (a1*b1) + (a1*c1) + (a1*d1 ) + ... ?= 0 -> a1 * (b1 + c1 + d1 +
// ...) ?= 0
if (auto S = isSum(fcmp->getOperand(1 - i))) {
// Factors common to every product term seen so far.
SmallVector<Value *, 1> allOps;
// Multiset intersection of lhs and rhs, in lhs order; each rhs entry can be
// matched at most once (it is nulled out in the by-value copy).
auto combine = [](const SmallVector<Value *, 1> &lhs,
SmallVector<Value *, 1> rhs) {
SmallVector<Value *, 1> out;
for (auto v : lhs) {
bool seen = false;
for (auto &v2 : rhs) {
if (v == v2) {
v2 = nullptr;
seen = true;
break;
}
}
if (seen) {
out.push_back(v);
}
}
return out;
};
// Multiset difference: removes one occurrence of each rhs element from lhs;
// every rhs element must be present.
auto subtract = [](SmallVector<Value *, 1> lhs,
const SmallVector<Value *, 1> &rhs) {
for (auto v : rhs) {
auto found = find(lhs, v);
assert(found != lhs.end());
lhs.erase(found);
}
return lhs;
};
bool seen = false;
bool legal = true;
// Every addend must itself be a product; intersect their factor lists.
for (auto op : callOperands(S)) {
auto P = isProduct(op);
if (!P) {
legal = false;
break;
}
if (!seen) {
allOps = callOperands(P);
seen = true;
continue;
}
allOps = combine(allOps, callOperands(P));
}
if (legal && allOps.size() > 0) {
// Rebuild each addend without the common factors.
SmallVector<Value *, 1> operands;
for (auto op : callOperands(S)) {
auto P = isProduct(op);
push(op);
auto sub = subtract(callOperands(P), allOps);
auto newprod = pushcse(B.CreateCall(
getProductIntrinsic(*F.getParent(), C->getType()), sub));
operands.push_back(newprod);
}
auto newsum = pushcse(B.CreateCall(
getSumIntrinsic(*F.getParent(), C->getType()), operands));
// common factors * (sum of reduced addends)
allOps.push_back(newsum);
auto fprod = pushcse(B.CreateCall(
getProductIntrinsic(*F.getParent(), C->getType()), allOps));
// Note: this local intentionally-or-not shadows the outer `fcmp`; the new
// compare always places the constant on the right.
auto fcmp = pushcse(B.CreateCmp(predicate, fprod, C));
replaceAndErase(cur, fcmp);
return "CmpSumFactor";
}
}
}
}
}
}
// Split a floating-point (in)equality of a sum against zero into per-operand
// checks, provided no operand satisfies `guaranteedDataDependent`.
if (auto fcmp = dyn_cast<FCmpInst>(cur)) {
auto predicate = fcmp->getPredicate();
if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ ||
predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) {
// i is the index of the constant operand; 1 - i is the sum.
for (int i = 0; i < 2; i++)
if (auto C = dyn_cast<ConstantFP>(fcmp->getOperand(i))) {
if (C->isZero()) {
// a + b == 0 -> ( (a == 0 & b == 0) || a == -b)
if (auto S = isSum(fcmp->getOperand(1 - i))) {
auto allOps = callOperands(S);
if (!llvm::any_of(allOps, guaranteedDataDependent)) {
// "Not equal" predicates are handled by building the equality form with the
// inverse predicate and negating the final result.
auto eq_predicate = predicate;
if (predicate == FCmpInst::FCMP_UNE ||
predicate == FCmpInst::FCMP_ONE)
eq_predicate = fcmp->getInversePredicate();
// Conjunction of (operand == 0) over all operands.
Value *op_checks = nullptr;
for (auto a : allOps) {
auto a_e0 = pushcse(B.CreateFCmp(eq_predicate, a, C));
if (op_checks == nullptr)
op_checks = a_e0;
else
op_checks = pushcse(B.CreateAnd(op_checks, a_e0));
}
// All operands except the first; note this `i` shadows the outer loop index.
SmallVector<Value *, 1> slice;
for (size_t i = 1; i < allOps.size(); i++)
slice.push_back(allOps[i]);
// -first == sum(rest), reusing the callee of the original sum call.
// NOTE(review): allOps[0] assumes the sum has at least one operand -- confirm.
auto ane = pushcse(B.CreateFCmp(
eq_predicate, pushcse(B.CreateFNeg(allOps[0])),
pushcse(B.CreateCall(getFunctionFromCall(S), slice))));
auto ori = pushcse(B.CreateOr(op_checks, ane));
if (predicate == FCmpInst::FCMP_UNE ||
predicate == FCmpInst::FCMP_ONE) {
ori = pushcse(B.CreateNot(ori));
}
replaceAndErase(cur, ori);
return "Sum2ZeroSplit";
}
}
}
}
}
}
// Move one addend of an extended sum across an (in)equality with zero:
//   (ext a) + (cast b) ?= 0 -> (ext a) ?= -(cast b)
// where the first addend is a zext/sext, the second is any cast, and both
// casts have the same source type.
if (auto cmp = dyn_cast<CmpInst>(cur)) {
  const auto pred = cmp->getPredicate();
  if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) {
    for (int zeroIdx = 0; zeroIdx < 2; zeroIdx++) {
      auto zero = dyn_cast<ConstantInt>(cmp->getOperand(zeroIdx));
      if (!zero || !zero->isZero())
        continue;
      auto sum = dyn_cast<BinaryOperator>(cmp->getOperand(1 - zeroIdx));
      if (!sum || sum->getOpcode() != Instruction::Add)
        continue;
      auto lhsCast = dyn_cast<CastInst>(sum->getOperand(0));
      if (!lhsCast)
        continue;
      auto rhsCast = dyn_cast<CastInst>(sum->getOperand(1));
      if (!rhsCast)
        continue;
      if (lhsCast->getOperand(0)->getType() !=
          rhsCast->getOperand(0)->getType())
        continue;
      if (!isa<ZExtInst>(lhsCast) && !isa<SExtInst>(lhsCast))
        continue;
      auto negatedRhs = pushcse(B.CreateNeg(rhsCast));
      auto shuffled = pushcse(B.CreateCmp(pred, lhsCast, negatedRhs));
      replaceAndErase(cur, shuffled);
      return "CmpExt0Shuffle";
    }
  }
}
// Negating an extended boolean flips the kind of extension:
//   sub 0, (zext i1 to N) -> sext i1 to N
//   sub 0, (sext i1 to N) -> zext i1 to N
// Only zext/sext are matched. Any other cast with an i1 source (e.g. a no-op
// bitcast) is left alone instead of tripping an assert and, in builds without
// assertions, replacing `cur` with a null value.
if (auto sub = dyn_cast<BinaryOperator>(cur))
  if (sub->getOpcode() == Instruction::Sub)
    if (auto C = dyn_cast<ConstantInt>(sub->getOperand(0)))
      if (C->isZero())
        if (auto a0 = dyn_cast<CastInst>(sub->getOperand(1)))
          if ((isa<ZExtInst>(a0) || isa<SExtInst>(a0)) &&
              a0->getOperand(0)->getType()->isIntegerTy(1)) {
            Value *tmp = nullptr;
            if (isa<ZExtInst>(a0))
              tmp = pushcse(B.CreateSExt(a0->getOperand(0), a0->getType()));
            else
              tmp = pushcse(B.CreateZExt(a0->getOperand(0), a0->getType()));
            replaceAndErase(cur, tmp);
            return "NegSZExtI1";
          }
// Fold an exact division by a constant (udiv/sdiv, or lshr by a constant
// amount) into a constant operand of the divided value. With d the divisor
// (2^C2 for lshr) and d dividing C1:
//   div exact (mul a, C1), d -> mul a, (C1 / d)
//   div exact (add a, C1), d -> add (div exact a, d), (C1 / d)
// Both rewrites are only valid when the inner mul/add cannot wrap in the
// signedness of the division, so the matching nuw/nsw flag is required.
if ((cur->getOpcode() == Instruction::LShr ||
     cur->getOpcode() == Instruction::SDiv ||
     cur->getOpcode() == Instruction::UDiv) &&
    cur->isExact())
  if (auto C2 = dyn_cast<ConstantInt>(cur->getOperand(1)))
    if (auto mul = dyn_cast<BinaryOperator>(cur->getOperand(0))) {
      const bool isSigned = cur->getOpcode() == Instruction::SDiv;
      // The numeric divisor. For lshr this is 1 << C2, which becomes 0 when
      // the shift amount is at least the bit width.
      APInt rhs = C2->getValue();
      if (cur->getOpcode() == Instruction::LShr) {
        rhs = APInt(rhs.getBitWidth(), 1) << rhs;
      }
      // Computes lhs / rhs into `div`; returns true only when the division is
      // well defined (non-zero divisor) and leaves no remainder.
      auto dividesEvenly = [&](const APInt &lhs, APInt &div) {
        if (rhs == 0)
          return false;
        APInt rem;
        if (isSigned)
          APInt::sdivrem(lhs, rhs, div, rem);
        else
          APInt::udivrem(lhs, rhs, div, rem);
        return rem == 0;
      };
      // (div exact (mul a, C1), d) -> mul a, (C1 / d) if d divides C1
      if (mul->getOpcode() == Instruction::Mul &&
          (isSigned ? mul->hasNoSignedWrap() : mul->hasNoUnsignedWrap()))
        for (int i0 = 0; i0 < 2; i0++)
          if (auto C1 = dyn_cast<ConstantInt>(mul->getOperand(i0))) {
            APInt div;
            if (dividesEvenly(C1->getValue(), div)) {
              // The smaller product cannot wrap in the signedness that was
              // required above; nsw additionally survives the unsigned case
              // when the original mul carried both flags.
              auto res = pushcse(B.CreateMul(
                  mul->getOperand(1 - i0),
                  ConstantInt::get(cur->getType(), div),
                  "mdiv." + cur->getName(), /*HasNUW=*/!isSigned,
                  /*HasNSW=*/isSigned || mul->hasNoSignedWrap()));
              push(mul);
              replaceAndErase(cur, res);
              return "IMulDivConst";
            }
          }
      // (div exact (add a, C1), d) -> add (div exact a, d), (C1 / d) if d
      // divides C1. Since a + C1 does not wrap and both a + C1 and C1 are
      // multiples of d, so is a; hence the inner division is exact too.
      // A signed division by -1 is excluded because `a / -1` could overflow
      // even though `(a + C1) / -1` did not.
      if (mul->getOpcode() == Instruction::Add &&
          (isSigned ? mul->hasNoSignedWrap() : mul->hasNoUnsignedWrap()) &&
          !(isSigned && (~rhs) == 0))
        for (int i0 = 0; i0 < 2; i0++)
          if (auto C1 = dyn_cast<ConstantInt>(mul->getOperand(i0))) {
            APInt div;
            if (dividesEvenly(C1->getValue(), div)) {
              Value *a = mul->getOperand(1 - i0);
              Value *divisor = C2;
              // Re-apply the same exact operation to the non-constant addend.
              Value *quot = nullptr;
              switch (cur->getOpcode()) {
              case Instruction::LShr:
                quot = B.CreateLShr(a, divisor, "", /*isExact=*/true);
                break;
              case Instruction::UDiv:
                quot = B.CreateUDiv(a, divisor, "", /*isExact=*/true);
                break;
              default:
                quot = B.CreateSDiv(a, divisor, "", /*isExact=*/true);
                break;
              }
              quot = pushcse(quot);
              // The new sum equals the original quotient, so it cannot wrap
              // in the signedness of the division.
              auto res = pushcse(B.CreateAdd(
                  quot, ConstantInt::get(cur->getType(), div),
                  "madd." + cur->getName(), /*HasNUW=*/!isSigned,
                  /*HasNSW=*/isSigned));
              push(mul);
              replaceAndErase(cur, res);
              return "IAddDivConst";
            }
          }
    }
// Group the constants of two nested fast multiplications:
//   fmul (fmul a, const1), (fmul b, const2)
//     -> fmul (fmul a, b), (fmul const1, const2)
if (cur->getOpcode() == Instruction::FMul && cur->isFast()) {
  auto lhsMul = dyn_cast<Instruction>(cur->getOperand(0));
  const bool lhsIsFastMul = lhsMul &&
                            lhsMul->getOpcode() == Instruction::FMul &&
                            lhsMul->isFast();
  auto rhsMul =
      lhsIsFastMul ? dyn_cast<Instruction>(cur->getOperand(1)) : nullptr;
  const bool rhsIsFastMul = rhsMul &&
                            rhsMul->getOpcode() == Instruction::FMul &&
                            rhsMul->isFast();
  if (lhsIsFastMul && rhsIsFastMul) {
    // Try every pairing of a constant operand from each side.
    for (auto lhsConstIdx = 0; lhsConstIdx < 2; lhsConstIdx++)
      for (auto rhsConstIdx = 0; rhsConstIdx < 2; rhsConstIdx++) {
        if (!isa<Constant>(lhsMul->getOperand(lhsConstIdx)))
          continue;
        if (!isa<Constant>(rhsMul->getOperand(rhsConstIdx)))
          continue;
        auto variablePart = pushcse(
            B.CreateFMulFMF(lhsMul->getOperand(1 - lhsConstIdx),
                            rhsMul->getOperand(1 - rhsConstIdx), cur));
        auto constantPart =
            pushcse(B.CreateFMulFMF(lhsMul->getOperand(lhsConstIdx),
                                    rhsMul->getOperand(rhsConstIdx), cur));
        auto regrouped =
            pushcse(B.CreateFMulFMF(variablePart, constantPart, cur));
        push(lhsMul);
        push(rhsMul);
        replaceAndErase(cur, regrouped);
        return "MulMulConstConst";
      }
  }
}
// Merge two constant factors of a nested multiplication:
//   mul (mul a, const1), const2 -> mul a, (mul const1, const2)
// Applies to integer `mul`, and to `fmul` only when both the outer and the
// inner multiplication are fast (reassociation is otherwise not allowed).
if ((cur->getOpcode() == Instruction::FMul && cur->isFast()) ||
    cur->getOpcode() == Instruction::Mul)
  for (auto i1 = 0; i1 < 2; i1++)
    if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(i1)))
      // The inner instruction must be a fast fmul or an integer mul. The
      // second alternative tests `Mul`; testing `FMul` again would accept a
      // non-fast inner fmul and make the integer path unreachable.
      if ((mul1->getOpcode() == Instruction::FMul && mul1->isFast()) ||
          mul1->getOpcode() == Instruction::Mul)
        if (auto const2 = dyn_cast<Constant>(cur->getOperand(1 - i1)))
          for (auto i2 = 0; i2 < 2; i2++)
            if (auto const1 = dyn_cast<Constant>(mul1->getOperand(i2))) {
              Value *res = nullptr;
              if (cur->getOpcode() == Instruction::FMul) {
                auto const3 = pushcse(B.CreateFMulFMF(const1, const2, mul1));
                res = pushcse(
                    B.CreateFMulFMF(mul1->getOperand(1 - i2), const3, cur));
              } else {
                // Wrap flags are deliberately not carried over: the merged
                // product is not known to satisfy them.
                auto const3 = pushcse(B.CreateMul(const1, const2));
                res = pushcse(B.CreateMul(mul1->getOperand(1 - i2), const3));
              }
              push(mul1);
              replaceAndErase(cur, res);
              return "MulConstConst";
            }
// Ordered floating-point equality against the constant zero (zero on either
// side): push the comparison into the operands of the compared value.
if (auto fcmp = dyn_cast<FCmpInst>(cur)) {
if (fcmp->getPredicate() == FCmpInst::FCMP_OEQ) {
// i is the index of the constant operand; 1 - i is the compared value.
for (int i = 0; i < 2; i++)
if (auto C = dyn_cast<ConstantFP>(fcmp->getOperand(i))) {
if (C->isZero()) {
// Note: `fmul` names any binary operator here, not only an fmul.
if (auto fmul = dyn_cast<BinaryOperator>(fcmp->getOperand(1 - i))) {
// (a*b) == 0 -> (a == 0) || (b == 0)
if (fmul->getOpcode() == Instruction::FMul) {
auto ncmp1 = pushcse(
B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C));
auto ncmp2 = pushcse(
B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(1), C));
auto ori = pushcse(B.CreateOr(ncmp1, ncmp2));
replaceAndErase(cur, ori);
return "CmpFMulSplit";
}
// (a/b) == 0 -> (a == 0)
if (fmul->getOpcode() == Instruction::FDiv) {
auto ncmp1 = pushcse(
B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C));
replaceAndErase(cur, ncmp1);
return "CmpFDivSplit";
}
// (a - b) == 0 -> a == b
if (fmul->getOpcode() == Instruction::FSub) {
auto ncmp1 = pushcse(B.CreateFCmp(fcmp->getPredicate(),
fmul->getOperand(0),
fmul->getOperand(1)));
replaceAndErase(cur, ncmp1);
return "CmpFSubSplit";
}
}
// (sitofp x) == 0.0 -> x == 0
if (auto cast = dyn_cast<SIToFPInst>(fcmp->getOperand(1 - i))) {
auto ncmp1 = pushcse(B.CreateICmp(
ICmpInst::ICMP_EQ, cast->getOperand(0),
ConstantInt::get(cast->getOperand(0)->getType(), 0)));
replaceAndErase(cur, ncmp1);
return "SFCmpToICmp";
}
// (uitofp x) == 0.0 -> x == 0
if (auto cast = dyn_cast<UIToFPInst>(fcmp->getOperand(1 - i))) {
auto ncmp1 = pushcse(B.CreateICmp(
ICmpInst::ICMP_EQ, cast->getOperand(0),
ConstantInt::get(cast->getOperand(0)->getType(), 0)));
replaceAndErase(cur, ncmp1);
return "UFCmpToICmp";
}
// (select c, t, f) == 0 -> select c, (0 == t), (0 == f)
if (auto SI = dyn_cast<SelectInst>(fcmp->getOperand(1 - i))) {
auto res = pushcse(
B.CreateSelect(SI->getCondition(),
pushcse(B.CreateCmp(fcmp->getPredicate(), C,
SI->getTrueValue())),
pushcse(B.CreateCmp(fcmp->getPredicate(), C,
SI->getFalseValue()))));
replaceAndErase(cur, res);
return "FCmpSelect";
}
}
}
}
}
// Rewrites for (in)equality comparisons, integer or floating point. Despite
// its name, `fcmp` is any CmpInst here.
if (auto fcmp = dyn_cast<CmpInst>(cur)) {
  if (fcmp->getPredicate() == CmpInst::ICMP_EQ ||
      fcmp->getPredicate() == CmpInst::ICMP_NE ||
      fcmp->getPredicate() == CmpInst::FCMP_OEQ ||
      fcmp->getPredicate() == CmpInst::FCMP_ONE) {
    // a + c ?= a -> c ?= 0 , if fast
    for (int i = 0; i < 2; i++)
      if (auto inst = dyn_cast<Instruction>(fcmp->getOperand(i)))
        if (inst->getOpcode() == Instruction::FAdd && inst->isFast())
          for (int i2 = 0; i2 < 2; i2++)
            if (inst->getOperand(i2) == fcmp->getOperand(1 - i)) {
              auto res = pushcse(
                  B.CreateCmp(fcmp->getPredicate(), inst->getOperand(1 - i2),
                              ConstantFP::get(inst->getType(), 0)));
              replaceAndErase(cur, res);
              return "CmpFAddSame";
            }
    // Comparisons of booleans become pure logic:
    // a == b -> a & b | !a & !b
    // a != b -> a & !b | !a & b
    if (fcmp->getOperand(0)->getType()->isIntegerTy(1)) {
      auto a = fcmp->getOperand(0);
      auto b = fcmp->getOperand(1);
      if (fcmp->getPredicate() == CmpInst::ICMP_EQ) {
        auto res = pushcse(
            B.CreateOr(pushcse(B.CreateAnd(a, b)),
                       pushcse(B.CreateAnd(pushcse(B.CreateNot(a)),
                                           pushcse(B.CreateNot(b))))));
        replaceAndErase(cur, res);
        return "CmpI1EQ";
      }
      if (fcmp->getPredicate() == CmpInst::ICMP_NE) {
        auto res = pushcse(
            B.CreateOr(pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), b)),
                       pushcse(B.CreateAnd(a, pushcse(B.CreateNot(b))))));
        replaceAndErase(cur, res);
        return "CmpI1NE";
      }
    }
    // Integer comparisons against the constant zero (zero on either side).
    for (int i = 0; i < 2; i++)
      if (auto CI = dyn_cast<ConstantInt>(fcmp->getOperand(i)))
        if (CI->isZero()) {
          if (auto addI = dyn_cast<Instruction>(fcmp->getOperand(1 - i))) {
            // a + a ?= 0 -> a ?= 0
            // The opcode must be checked first: otherwise operand 1 is read
            // from instructions that may not have one, and unrelated
            // operations such as `sub a, a` would be rewritten as well.
            // NOTE(review): without nuw/nsw on the add this still differs
            // for a == the minimum signed value -- confirm that is intended.
            if (addI->getOpcode() == Instruction::Add &&
                addI->getOperand(0) == addI->getOperand(1)) {
              Value *res = pushcse(
                  B.CreateCmp(fcmp->getPredicate(), addI->getOperand(0), CI));
              replaceAndErase(cur, res);
              return "CmpAddAdd";
            }
            // (a-b) ?= 0 -> a ?= b
            if (addI->getOpcode() == Instruction::Sub) {
              auto ncmp1 = pushcse(B.CreateICmp(fcmp->getPredicate(),
                                                addI->getOperand(0),
                                                addI->getOperand(1)));
              replaceAndErase(cur, ncmp1);
              return "CmpISubSplit";
            }
          }
        }
    // Rules where both compared values are instructions of the same kind.
    // (The names mul1/mul2 are kept even for the non-multiplication rules.)
    if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(0)))
      if (auto mul2 = dyn_cast<Instruction>(cur->getOperand(1))) {
        // (a * b) == (c * b) -> (a == c) || b == 0
        // (a * b) != (c * b) -> (a != c) && b != 0
        // Requires both products to be nuw and nsw; only a common factor in
        // the same operand position is recognized.
        if (mul1->getOpcode() == Instruction::Mul &&
            mul2->getOpcode() == Instruction::Mul &&
            mul1->hasNoUnsignedWrap() && mul1->hasNoSignedWrap() &&
            mul2->hasNoUnsignedWrap() && mul2->hasNoSignedWrap()) {
          for (int i = 0; i < 2; i++) {
            if (mul1->getOperand(i) == mul2->getOperand(i)) {
              Value *res = pushcse(B.CreateICmp(fcmp->getPredicate(),
                                                mul1->getOperand(1 - i),
                                                mul2->getOperand(1 - i)));
              auto b = mul1->getOperand(i);
              if (fcmp->getPredicate() == CmpInst::ICMP_EQ) {
                Value *bZero = pushcse(B.CreateICmp(
                    CmpInst::ICMP_EQ, b, ConstantInt::get(b->getType(), 0)));
                res = pushcse(B.CreateOr(res, bZero));
              } else {
                Value *bZero = pushcse(B.CreateICmp(
                    ICmpInst::ICMP_NE, b, ConstantInt::get(b->getType(), 0)));
                res = pushcse(B.CreateAnd(res, bZero));
              }
              replaceAndErase(cur, res);
              return "CmpMulCommon";
            }
          }
        }
        // same as above but now with floats
        if (mul1->getOpcode() == Instruction::FMul &&
            mul2->getOpcode() == Instruction::FMul && mul1->isFast() &&
            mul2->isFast()) {
          for (int i = 0; i < 2; i++) {
            if (mul1->getOperand(i) == mul2->getOperand(i)) {
              Value *res = pushcse(B.CreateFCmp(fcmp->getPredicate(),
                                                mul1->getOperand(1 - i),
                                                mul2->getOperand(1 - i)));
              auto b = mul1->getOperand(i);
              if (fcmp->getPredicate() == CmpInst::FCMP_OEQ) {
                Value *bZero = pushcse(B.CreateCmp(
                    CmpInst::FCMP_OEQ, b, ConstantFP::get(b->getType(), 0)));
                res = pushcse(B.CreateOr(res, bZero));
              } else {
                Value *bZero = pushcse(B.CreateCmp(
                    CmpInst::FCMP_ONE, b, ConstantFP::get(b->getType(), 0)));
                res = pushcse(B.CreateAnd(res, bZero));
              }
              replaceAndErase(cur, res);
              return "CmpMulfCommon";
            }
          }
        }
        // (uitofp a ) ?= (uitofp b) -> a ?= b   (likewise for sitofp)
        for (auto cond : {Instruction::UIToFP, Instruction::SIToFP})
          if (mul1->getOpcode() == cond && mul2->getOpcode() == cond &&
              mul1->getOperand(0)->getType() ==
                  mul2->getOperand(0)->getType()) {
            Value *res = pushcse(B.CreateICmp(
                fcmp->getPredicate() == CmpInst::FCMP_OEQ ? CmpInst::ICMP_EQ
                                                          : CmpInst::ICMP_NE,
                mul1->getOperand(0), mul2->getOperand(0)));
            replaceAndErase(cur, res);
            return "CmpUIToFP";
          }
        // (zext a ) ?= (zext b) -> a ?= b
        if (mul1->getOpcode() == Instruction::ZExt &&
            mul2->getOpcode() == Instruction::ZExt &&
            mul1->getOperand(0)->getType() ==
                mul2->getOperand(0)->getType()) {
          Value *res =
              pushcse(B.CreateICmp(fcmp->getPredicate(), mul1->getOperand(0),
                                   mul2->getOperand(0)));
          replaceAndErase(cur, res);
          return "CmpZExt";
        }
        // (zext i1 a ) == (sext i1 b) -> (!a & !b)
        // (zext i1 a ) != (sext i1 b) -> (a | b)
        // zext yields 0/1 and sext yields 0/-1, so they agree only at 0.
        if (((mul1->getOpcode() == Instruction::ZExt &&
              mul2->getOpcode() == Instruction::SExt) ||
             (mul1->getOpcode() == Instruction::SExt &&
              mul2->getOpcode() == Instruction::ZExt)) &&
            mul1->getOperand(0)->getType() ==
                mul2->getOperand(0)->getType() &&
            mul1->getOperand(0)->getType()->isIntegerTy(1)) {
          Value *na = mul1->getOperand(0);
          Value *nb = mul2->getOperand(0);
          if (fcmp->getPredicate() == ICmpInst::ICMP_EQ) {
            na = pushcse(B.CreateNot(na));
            nb = pushcse(B.CreateNot(nb));
          }
          Value *res = nullptr;
          if (fcmp->getPredicate() == ICmpInst::ICMP_EQ)
            res = pushcse(B.CreateAnd(na, nb));
          else
            res = pushcse(B.CreateOr(na, nb));
          replaceAndErase(cur, res);
          return "CmpZExtSExt";
        }
      }
  }
  if (fcmp->getPredicate() == ICmpInst::ICMP_EQ) {
    for (int i = 0; i < 2; i++) {
      if (auto C = dyn_cast<ConstantInt>(fcmp->getOperand(i))) {
        if (C->isZero()) {
          if (auto fmul = dyn_cast<BinaryOperator>(fcmp->getOperand(1 - i))) {
            // (a*b) == 0 -> (a == 0) || (b == 0)
            // NOTE(review): no nuw/nsw is required on the mul, so a wrapping
            // product that happens to be 0 is treated differently -- confirm.
            if (fmul->getOpcode() == Instruction::Mul) {
              auto ncmp1 = pushcse(
                  B.CreateICmp(fcmp->getPredicate(), fmul->getOperand(0), C));
              auto ncmp2 = pushcse(
                  B.CreateICmp(fcmp->getPredicate(), fmul->getOperand(1), C));
              auto ori = pushcse(B.CreateOr(ncmp1, ncmp2));
              replaceAndErase(cur, ori);
              return "CmpIMulSplit";
            }
          }
        }
      }
    }
  }
}
// fadd fast x, x -> fmul fast x, 2.0
if (cur->getOpcode() == Instruction::FAdd &&
    cur->getOperand(0) == cur->getOperand(1) && cur->isFast()) {
  auto two = ConstantFP::get(cur->getType(), 2.0);
  auto doubled = pushcse(B.CreateFMulFMF(cur->getOperand(0), two, cur));
  replaceAndErase(cur, doubled);
  return "AddToMul2";
}
if (cur->getOpcode() == Instruction::Add) {
  // add x, (mul y, -1) -> sub x, y
  // The negated product may be either addend and the -1 either factor.
  for (int mulIdx = 0; mulIdx < 2; mulIdx++) {
    auto negMul = dyn_cast<Instruction>(cur->getOperand(mulIdx));
    if (!negMul || negMul->getOpcode() != Instruction::Mul)
      continue;
    for (int constIdx = 0; constIdx < 2; constIdx++) {
      auto factor = dyn_cast<ConstantInt>(negMul->getOperand(constIdx));
      if (!factor || !factor->isMinusOne())
        continue;
      auto difference = pushcse(B.CreateSub(cur->getOperand(1 - mulIdx),
                                            negMul->getOperand(1 - constIdx)));
      push(negMul);
      replaceAndErase(cur, difference);
      return "AddToSub";
    }
  }
}
// Rewrites rooted at a select instruction.
if (auto SI = dyn_cast<SelectInst>(cur)) {
  // A multiplication operand is moved out of the select when it is constant.
  auto shouldMove = [](Value *v) { return isa<Constant>(v); };
  /*
  // select c, 0, x -> fmul (uitofp (!c)), x
  if (auto C1 = dyn_cast<ConstantFP>(SI->getTrueValue())) {
  if (C1->isZero()) {
  auto n = pushcse(B.CreateNot(SI->getCondition()));
  auto val = pushcse(B.CreateUIToFP(n, SI->getType()));
  auto res = pushcse(B.CreateFMul(val, SI->getFalseValue()));
  if (auto I = dyn_cast<Instruction>(res))
  I->setFast(true);
  replaceAndErase(cur, res);
  return true;
  }
  }
  // select c, x, 0 -> fmul (uitofp c), x
  if (auto C1 = dyn_cast<ConstantFP>(SI->getFalseValue())) {
  if (C1->isZero()) {
  auto val = pushcse(B.CreateUIToFP(SI->getCondition(), SI->getType()));
  auto res = pushcse(B.CreateFMul(val, SI->getTrueValue()));
  if (auto I = dyn_cast<Instruction>(res))
  I->setFast(true);
  replaceAndErase(cur, res);
  return true;
  }
  }
  */
  // select c, (mul x y), 0 -> mul x, (select c, y, 0)
  // i selects which arm holds the multiplication (0: true arm, 1: false
  // arm); the other arm must be a zero constant.
  for (int i = 0; i < 2; i++)
    if (auto inst = dyn_cast<Instruction>(SI->getOperand(1 + i)))
      if (inst->getOpcode() == Instruction::Mul)
        // inst->getOpcode() == Instruction::FMul)
        if (auto C = dyn_cast<Constant>(SI->getOperand(1 + (1 - i))))
          if ((isa<ConstantInt>(C) && cast<ConstantInt>(C)->isZero()) ||
              (isa<ConstantFP>(C) && cast<ConstantFP>(C)->isZero()))
            for (int j = 0; j < 2; j++)
              if (shouldMove(inst->getOperand(j))) {
                auto x = inst->getOperand(j);
                auto y = inst->getOperand(1 - j);
                // Keep y in the arm the multiplication came from.
                auto isel = pushcse(B.CreateSelect(
                    SI->getCondition(), (i == 0) ? y : C, (i == 0) ? C : y,
                    "smulmove." + SI->getName()));
                Value *imul;
                if (cur->getType()->isIntegerTy())
                  imul = pushcse(B.CreateMul(isel, x, "",
                                             inst->hasNoUnsignedWrap(),
                                             inst->hasNoSignedWrap()));
                else
                  imul = pushcse(B.CreateFMulFMF(isel, x, inst, ""));
                replaceAndErase(cur, imul);
                return "SelMulMove";
              }
  // select c, (sitofp x), (sitofp y) -> sitofp (select c, x, y)
  // select c, 0.0, (sitofp y)        -> sitofp (select c, 0, y)
  {
    // Integer source of each arm; nullptr marks a zero-constant arm that is
    // filled in with an integer zero below.
    Value *ops[2] = {nullptr, nullptr};
    bool legal = true;
    for (int i = 0; i < 2; i++) {
      // Only a *zero* FP constant may be replaced by the integer 0; any
      // other constant would change the selected value, so it makes the
      // rewrite illegal (it falls through to the failure path below).
      if (auto CFP = dyn_cast<ConstantFP>(SI->getOperand(1 + i))) {
        if (CFP->isZero()) {
          ops[i] = nullptr;
          continue;
        }
      }
      if (auto CI = dyn_cast<CastInst>(SI->getOperand(1 + i))) {
        if (CI->getOpcode() == Instruction::SIToFP) {
          ops[i] = CI->getOperand(0);
          continue;
        }
      }
      legal = false;
      break;
    }
    // Give a constant arm the integer type of the other arm.
    for (int i = 0; i < 2; i++) {
      if (!ops[i] && ops[1 - i])
        ops[i] = ConstantInt::get(ops[1 - i]->getType(), 0);
    }
    // Both arms must now exist and agree on their integer type.
    for (int i = 0; i < 2; i++) {
      if (ops[i] == nullptr || ops[i]->getType() != ops[0]->getType()) {
        legal = false;
        break;
      }
    }
    if (legal) {
      auto isel = pushcse(B.CreateSelect(SI->getCondition(), ops[0], ops[1],
                                         "seltofp." + SI->getName()));
      auto res = pushcse(B.CreateSIToFP(isel, SI->getType()));
      replaceAndErase(cur, res);
      return "SelSIMerge";
    }
  }
}
// Rewrites rooted at an integer multiplication.
if (cur->getOpcode() == Instruction::Mul) {
  for (int i = 0; i < 2; i++) {
    // mul (x, 1) -> x
    if (auto C = dyn_cast<ConstantInt>(cur->getOperand(i)))
      if (C->isOne()) {
        replaceAndErase(cur, cur->getOperand(1 - i));
        return "MulIdent";
      }
    // mul (zext i1 x), y -> mul (zext i1 x) y[x->1]
    // The product is only non-zero when x is true, so y may be simplified
    // under that assumption.
    if (auto Z = dyn_cast<ZExtInst>(cur->getOperand(i)))
      if (Z->getOperand(0)->getType()->isIntegerTy(1)) {
        auto prev = cur->getOperand(1 - i);
        auto next = replace(prev, Z->getOperand(0),
                            ConstantInt::getTrue(cur->getContext()));
        if (next != prev) {
          auto res = pushcse(B.CreateMul(Z, next, "postmul." + cur->getName(),
                                         cur->hasNoUnsignedWrap(),
                                         cur->hasNoSignedWrap()));
          replaceAndErase(cur, res);
          return "MulReplaceZExt";
        }
      }
  }
  /*
  // mul x, (select c, 0, y) -> select c (mul x 0), (mul x y)
  for (int i=0; i<2; i++)
  if (auto SI = dyn_cast<SelectInst>(cur->getOperand(i)))
  for (int j=0; j<2; j++)
  if (auto CI = dyn_cast<ConstantInt>(SI->getOperand(1+j)))
  if (CI->isZero()) {
  auto tval = (j == 0) ? CI :
  pushcse(B.CreateMul(SI->getTrueValue(), cur->getOperand(1-i), "tval." +
  cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); auto
  fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(),
  cur->getOperand(1-i), "fval." + cur->getName(), cur->hasNoUnsignedWrap(),
  cur->hasNoSignedWrap()));
  auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval));
  replaceAndErase(cur, res);
  return true;
  }
  */
  // mul (sub x, y), -c -> mul (sub y, x), c
  for (int i = 0; i < 2; i++)
    if (auto inst = dyn_cast<Instruction>(cur->getOperand(i)))
      if (inst->getOpcode() == Instruction::Sub)
        if (auto CI = dyn_cast<ConstantInt>(cur->getOperand(1 - i)))
          // The minimum signed value is its own negation; rewriting it
          // would produce another negative constant and match again
          // forever.
          if (CI->isNegative() && !CI->isMinValue(/*IsSigned=*/true)) {
            // `sub x, y` not wrapping says nothing about `sub y, x`, so the
            // reversed subtraction carries no wrap flags.
            auto sub2 = pushcse(
                B.CreateSub(inst->getOperand(1), inst->getOperand(0)));
            // Both factors change sign, so the signed product is unchanged
            // and nsw carries over; nuw does not.
            auto mul2 = pushcse(B.CreateMul(
                sub2, ConstantInt::get(CI->getType(), -CI->getValue()), "",
                /*HasNUW=*/false, cur->hasNoSignedWrap()));
            replaceAndErase(cur, mul2);
            return "MulSubNegConst";
          }
}
// Rewrites for a negation written as `sub 0, v`. Despite its name, `zext` is
// whatever instruction is being negated.
if (cur->getOpcode() == Instruction::Sub)
  if (auto CI = dyn_cast<ConstantInt>(cur->getOperand(0)))
    if (CI->isZero())
      if (auto zext = dyn_cast<Instruction>(cur->getOperand(1))) {
        // sub 0, (zext i1 x) -> sext x
        if (zext->getOpcode() == Instruction::ZExt &&
            zext->getOperand(0)->getType()->isIntegerTy(1)) {
          auto res =
              pushcse(B.CreateSExt(zext->getOperand(0), cur->getType()));
          replaceAndErase(cur, res);
          return "SubZExt";
        }
        // sub 0, (mul nsw nuw constant, x) -> mul nsw -constant, x
        if (zext->getOpcode() == Instruction::Mul &&
            zext->hasNoUnsignedWrap() && zext->hasNoSignedWrap()) {
          for (int i = 0; i < 2; i++)
            if (auto factor = dyn_cast<ConstantInt>(zext->getOperand(i))) {
              // Negating the constant makes the unsigned product wrap for
              // ordinary inputs, so nuw must not be claimed. Given that the
              // original product was both nuw and nsw, the negated product
              // can only overflow in the signed sense when the constant is 1
              // (and x is the minimum signed value), so nsw is kept
              // otherwise.
              auto res = pushcse(B.CreateMul(
                  zext->getOperand(1 - i),
                  ConstantInt::get(factor->getType(), -factor->getValue()),
                  "neg." + zext->getName(), /*HasNUW=*/false,
                  /*HasNSW=*/!factor->isOne()));
              replaceAndErase(cur, res);
              return "SubMulConstant";
            }
        }
      }
// add (zext (and c1, x) ), (zext (and c1, y)) -> select c1, (add (zext x),
// (zext y)), 0
/*
if (cur->getOpcode() == Instruction::Add ||
cur->getOpcode() == Instruction::Sub ||
cur->getOpcode() == Instruction::Mul)
if (auto inst1 = dyn_cast<Instruction>(cur->getOperand(0)))
if (auto inst2 = dyn_cast<Instruction>(cur->getOperand(1)))
if (inst1->getOpcode() == Instruction::ZExt && inst2->getOpcode() ==
Instruction::ZExt) if (auto and1 =
dyn_cast<Instruction>(inst1->getOperand(0))) if (auto and2 =
dyn_cast<Instruction>(inst2->getOperand(0))) if
(and1->getType()->isIntegerTy(1) && and2->getType()->isIntegerTy(1) &&
and1->getOpcode() == Instruction::And && and2->getOpcode() ==
Instruction::And) { bool done = false; for (int i1=0; i1<2; i1++) for (int
i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto
c1 = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x =
pushcse(B.CreateZExt(x, inst1->getType())); auto y =
and2->getOperand(1-i2);
y = pushcse(B.CreateZExt(y, inst2->getType()));
Value *res = nullptr;
switch (cur->getOpcode()) {
case Instruction::Add:
res = pushcse(B.CreateAdd(x, y, "", cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap())); break; case Instruction::Sub: res = B.CreateSub(x,
y,
"", cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap()); break; case
Instruction::Mul: res = B.CreateMul(x, y, "", cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap()); break; default: llvm_unreachable("Illegal opcode");
}
res = pushcse(B.CreateSelect(c1, res,
Constant::getNullValue(cur->getType())));
replaceAndErase(cur, res);
return;
}
}
*/
// add (select %c c0, x), (select %c, c1, y) -> select %c, (add c0, c1),
// (add x, y) and for sub/mul/cmp
if (cur->getOpcode() == Instruction::Add ||
cur->getOpcode() == Instruction::Sub ||
cur->getOpcode() == Instruction::Mul ||
cur->getOpcode() == Instruction::FAdd ||
cur->getOpcode() == Instruction::FSub ||
cur->getOpcode() == Instruction::FMul ||
// cur->getOpcode() == Instruction::SIToFP ||
// cur->getOpcode() == Instruction::UIToFP ||
cur->getOpcode() == Instruction::ICmp ||
cur->getOpcode() == Instruction::FCmp) {
Value *SI1cond = nullptr;
Value *SI1tval = nullptr;
Value *SI1fval = nullptr;
if (auto SI1 = dyn_cast<SelectInst>(cur->getOperand(0))) {
SI1cond = SI1->getCondition();
SI1tval = SI1->getTrueValue();
SI1fval = SI1->getFalseValue();
}
if (auto SI1 = dyn_cast<ZExtInst>(cur->getOperand(0)))
if (SI1->getOperand(0)->getType()->isIntegerTy(1)) {
SI1cond = SI1->getOperand(0);
SI1tval = SI1;
SI1fval = ConstantInt::get(SI1->getType(), 0);
}
if (auto SI1 = dyn_cast<SExtInst>(cur->getOperand(0)))
if (SI1->getOperand(0)->getType()->isIntegerTy(1)) {
SI1cond = SI1->getOperand(0);
SI1tval = SI1;
SI1fval = ConstantInt::get(SI1->getType(), 0);
}
Value *SI2cond = nullptr;
Value *SI2tval = nullptr;
Value *SI2fval = nullptr;
auto op2 = cur->getOperand((cur->getOpcode() == Instruction::SIToFP ||
cur->getOpcode() == Instruction::UIToFP)
? 0
: 1);
if (auto SI2 = dyn_cast<SelectInst>(op2)) {
SI2cond = SI2->getCondition();
SI2tval = SI2->getTrueValue();
SI2fval = SI2->getFalseValue();
}
if (auto SI2 = dyn_cast<ZExtInst>(op2))
if (SI2->getOperand(0)->getType()->isIntegerTy(1)) {
SI2cond = SI2->getOperand(0);
SI2tval = SI2;
SI2fval = ConstantInt::get(SI2->getType(), 0);
}
if (auto SI2 = dyn_cast<SExtInst>(op2))
if (SI2->getOperand(0)->getType()->isIntegerTy(1)) {
SI2cond = SI2->getOperand(0);
SI2tval = SI2;
SI2fval = ConstantInt::get(SI2->getType(), 0);
}
if (SI1cond && SI2cond && (SI1cond == SI2cond || isNot(SI1cond, SI2cond)))
if ((SI1cond == SI2cond &&
((isa<Constant>(SI1tval) && isa<Constant>(SI2tval)) ||
(isa<Constant>(SI1fval) && isa<Constant>(SI2fval)))) ||
(SI1cond != SI2cond &&
((isa<Constant>(SI1tval) && isa<Constant>(SI2fval)) ||
(isa<Constant>(SI1fval) && isa<Constant>(SI2tval))))
) {
Value *tval = nullptr;
Value *fval = nullptr;
bool inverted = SI1cond != SI2cond;
switch (cur->getOpcode()) {
case Instruction::SIToFP:
tval =
B.CreateSIToFP(SI1tval, cur->getType(), "tval." + cur->getName());
fval =
B.CreateSIToFP(SI1fval, cur->getType(), "fval." + cur->getName());
break;
case Instruction::UIToFP:
tval =
B.CreateUIToFP(SI1tval, cur->getType(), "tval." + cur->getName());
fval =
B.CreateUIToFP(SI1fval, cur->getType(), "fval." + cur->getName());
break;
case Instruction::FAdd:
tval = B.CreateFAddFMF(SI1tval, inverted ? SI2fval : SI2tval, cur,
"tval." + cur->getName());
fval = B.CreateFAddFMF(SI1fval, inverted ? SI2tval : SI2fval, cur,
"fval." + cur->getName());
break;
case Instruction::FSub:
tval = B.CreateFSubFMF(SI1tval, inverted ? SI2fval : SI2tval, cur,
"tval." + cur->getName());
fval = B.CreateFSubFMF(SI1fval, inverted ? SI2tval : SI2fval, cur,
"fval." + cur->getName());
break;
case Instruction::FMul:
tval = B.CreateFMulFMF(SI1tval, inverted ? SI2fval : SI2tval, cur,
"tval." + cur->getName());
fval = B.CreateFMulFMF(SI1fval, inverted ? SI2tval : SI2fval, cur,
"fval." + cur->getName());
break;
case Instruction::Add:
tval = B.CreateAdd(SI1tval, inverted ? SI2fval : SI2tval,
"tval." + cur->getName(), cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap());
fval = B.CreateAdd(SI1fval, inverted ? SI2tval : SI2fval,
"fval." + cur->getName(), cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap());
break;
case Instruction::Sub:
tval = B.CreateSub(SI1tval, inverted ? SI2fval : SI2tval,
"tval." + cur->getName(), cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap());
fval = B.CreateSub(SI1fval, inverted ? SI2tval : SI2fval,
"fval." + cur->getName(), cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap());
break;
case Instruction::Mul:
tval = B.CreateMul(SI1tval, inverted ? SI2fval : SI2tval,
"tval." + cur->getName(), cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap());
fval = B.CreateMul(SI1fval, inverted ? SI2tval : SI2fval,
"fval." + cur->getName(), cur->hasNoUnsignedWrap(),
cur->hasNoSignedWrap());
break;
case Instruction::ICmp:
case Instruction::FCmp:
tval = B.CreateCmp(cast<CmpInst>(cur)->getPredicate(), SI1tval,
inverted ? SI2fval : SI2tval,
"tval." + cur->getName());
fval = B.CreateCmp(cast<CmpInst>(cur)->getPredicate(), SI1fval,
inverted ? SI2tval : SI2fval,
"fval." + cur->getName());
break;
default:
llvm_unreachable("illegal opcode");
}
tval = pushcse(tval);
fval = pushcse(fval);
auto res = pushcse(
B.CreateSelect(SI1cond, tval, fval, "selmerge." + cur->getName()));
push(cur->getOperand(0));
push(cur->getOperand(1));
replaceAndErase(cur, res);
return "BinopSelFuse";
}
}
/*
// and (i == c), (i != d) -> and (i == c) && (c != d)
if (cur->getOpcode() == Instruction::And) {
auto lhs = replace(cur->getOperand(0), cur->getOperand(1),
ConstantInt::getTrue(cur->getContext()));
auto rhs = replace(cur->getOperand(1), cur->getOperand(0),
ConstantInt::getTrue(cur->getContext()));
if (lhs != cur->getOperand(0) || rhs != cur->getOperand(1)) {
auto res = pushcse(B.CreateAnd(lhs, rhs, "postand." + cur->getName()));
replaceAndErase(cur, res);
return "AndReplace2";
}
}
*/
// and a, (or q, (not a)) -> and a q
// Within the `and`, a must hold, so the (not a) alternative of the inner `or`
// can never contribute and only q remains.
if (cur->getOpcode() == Instruction::And) {
  for (size_t aIdx = 0; aIdx < 2; aIdx++) {
    const size_t orIdx = 1 - aIdx;
    auto orInst = dyn_cast<Instruction>(cur->getOperand(orIdx));
    if (!orInst || orInst->getOpcode() != Instruction::Or)
      continue;
    for (size_t notIdx = 0; notIdx < 2; notIdx++) {
      if (!isNot(cur->getOperand(aIdx), orInst->getOperand(notIdx)))
        continue;
      // cur is patched in place: the `or` operand is swapped for q.
      auto q = orInst->getOperand(1 - notIdx);
      cur->setOperand(orIdx, q);
      push(cur);
      push(q);
      push(orInst);
      push(cur->getOperand(aIdx));
      push(orInst->getOperand(notIdx));
      Q.insert(cur);
      for (auto user : cur->users())
        push(user);
      return "AndOrProp";
    }
  }
}
// and (and a, b), a) -> and a, b
// If one operand is itself an `and` that already contains the other operand,
// the outer `and` is redundant and the inner one is used in its place.
if (cur->getOpcode() == Instruction::And) {
  for (size_t innerIdx = 0; innerIdx < 2; innerIdx++) {
    auto innerAnd = dyn_cast<Instruction>(cur->getOperand(innerIdx));
    if (!innerAnd || innerAnd->getOpcode() != Instruction::And)
      continue;
    Value *repeated = cur->getOperand(1 - innerIdx);
    for (size_t k = 0; k < 2; k++) {
      if (innerAnd->getOperand(k) != repeated)
        continue;
      replaceAndErase(cur, innerAnd);
      return "AndAndProp";
    }
  }
}
// or a, (and q, (not a)) -> or a q
// Where a is true the `or` is already true; where a is false, (not a) is true
// and the inner `and` reduces to q. Either way the (not a) factor is
// redundant. This is the dual of the `and a, (or q, (not a))` rule, so the
// outer opcode must be Or and the inner one And.
if (cur->getOpcode() == Instruction::Or) {
  for (size_t i1 = 0; i1 < 2; i1++)
    if (auto inst2 = dyn_cast<Instruction>(cur->getOperand(1 - i1)))
      if (inst2->getOpcode() == Instruction::And)
        for (size_t i2 = 0; i2 < 2; i2++)
          if (isNot(cur->getOperand(i1), inst2->getOperand(i2))) {
            auto q = inst2->getOperand(1 - i2);
            // cur is patched in place: the `and` operand is swapped for q.
            cur->setOperand(1 - i1, q);
            // Hand every touched value to push()/Q (presumably so that it
            // gets revisited by the surrounding worklist).
            push(cur);
            push(q);
            push(inst2);
            push(cur->getOperand(i1));
            push(inst2->getOperand(i2));
            Q.insert(cur);
            for (auto U : cur->users())
              push(U);
            return "OrAndProp";
          }
}
// and ( (a +/- b) != c ), ( (d +/- b) != c ) -> and ( a != c' ), ( d != c' )
// where c' is c with the shared term b moved to the other side of the
// comparison. The same rewrite is applied to `or` and to `==` comparisons.
//
// cmp1/cmp2 must share the compared value c, and add1/add2 must share the
// operand b at the same position ia. The shifted constant is then:
//   (a + b) == c  or  (b + a) == c   <=>   a == c - b
//   (a - b) == c                      <=>   a == c + b
//   (b - a) == c                      <=>   a == b - c
// These identities are exact in wrapping integer arithmetic.
if (cur->getOpcode() == Instruction::And ||
    cur->getOpcode() == Instruction::Or) {
  for (auto cmpOp : {ICmpInst::ICMP_EQ, ICmpInst::ICMP_NE})
    for (auto interOp : {Instruction::Add, Instruction::Sub})
      if (auto cmp1 = dyn_cast<ICmpInst>(cur->getOperand(0)))
        if (auto cmp2 = dyn_cast<ICmpInst>(cur->getOperand(1)))
          for (size_t i1 = 0; i1 < 2; i1++)
            for (size_t i2 = 0; i2 < 2; i2++)
              if (cmp1->getOperand(1 - i1) == cmp2->getOperand(1 - i2) &&
                  cmp1->getPredicate() == cmpOp &&
                  cmp2->getPredicate() == cmpOp)
                if (auto add1 = dyn_cast<Instruction>(cmp1->getOperand(i1)))
                  if (auto add2 = dyn_cast<Instruction>(cmp2->getOperand(i2)))
                    if (add1->getOpcode() == interOp &&
                        add2->getOpcode() == interOp)
                      for (size_t ia = 0; ia < 2; ia++)
                        if (add1->getOperand(ia) == add2->getOperand(ia)) {
                          auto b = add1->getOperand(ia);
                          auto c = cmp1->getOperand(1 - i1);
                          auto a = add1->getOperand(1 - ia);
                          auto d = add2->getOperand(1 - ia);
                          Value *res = nullptr;
                          if (interOp == Instruction::Add) {
                            // Addition commutes, so the position of b within
                            // the add does not matter.
                            res = pushcse(B.CreateSub(c, b));
                          } else if (ia == 0) {
                            // Shared term is the minuend: b - a == c.
                            res = pushcse(B.CreateSub(b, c));
                          } else {
                            // Shared term is the subtrahend: a - b == c.
                            res = pushcse(B.CreateAdd(c, b));
                          }
                          auto lhs = pushcse(B.CreateCmp(cmpOp, a, res));
                          auto rhs = pushcse(B.CreateCmp(cmpOp, d, res));
                          Value *fres = nullptr;
                          if (cur->getOpcode() == Instruction::And)
                            fres = pushcse(B.CreateAnd(lhs, rhs));
                          else
                            fres = pushcse(B.CreateOr(lhs, rhs));
                          replaceAndErase(cur, fres);
                          return "AndLinearShift";
                        }
}
// and ( expr == c1 ), ( expr == c2 ) and c1 != c2 -> false
// (and when c1 == c2 the two comparisons are equivalent, so one suffices)
if (cur->getOpcode() == Instruction::And) {
  auto lhsCmp = dyn_cast<ICmpInst>(cur->getOperand(0));
  auto rhsCmp = lhsCmp ? dyn_cast<ICmpInst>(cur->getOperand(1)) : nullptr;
  if (lhsCmp && rhsCmp) {
    for (size_t lhsSide = 0; lhsSide < 2; lhsSide++) {
      for (size_t rhsSide = 0; rhsSide < 2; rhsSide++) {
        // Both compares must be equalities against the same expression.
        if (lhsCmp->getOperand(1 - lhsSide) != rhsCmp->getOperand(1 - rhsSide))
          continue;
        if (lhsCmp->getPredicate() != ICmpInst::ICMP_EQ ||
            rhsCmp->getPredicate() != ICmpInst::ICMP_EQ)
          continue;
        auto lhsScev = SE.getSCEV(lhsCmp->getOperand(lhsSide));
        auto rhsScev = SE.getSCEV(rhsCmp->getOperand(rhsSide));
        auto delta = SE.getMinusSCEV(lhsScev, rhsScev, SCEV::NoWrapMask);
        auto deltaC = dyn_cast<SCEVConstant>(delta);
        // Only a compile-time-constant difference decides the outcome.
        if (!deltaC)
          continue;
        if (deltaC->getValue()->isZero()) {
          // Zero difference: both sides compare against the same value.
          push(lhsCmp);
          push(rhsCmp);
          replaceAndErase(cur, lhsCmp);
          return "AndEQExpr";
        }
        // Non-zero constant difference: expr cannot equal both values.
        replaceAndErase(cur, ConstantInt::getFalse(cur->getContext()));
        return "AndNEExpr";
      }
    }
  }
}
// (a | b) == 0 -> a == 0 & b == 0
// A bitwise or is zero exactly when both of its inputs are zero.
if (auto icmp = dyn_cast<ICmpInst>(cur)) {
  if (icmp->getPredicate() == ICmpInst::ICMP_EQ &&
      cur->getType()->isIntegerTy(1)) {
    for (int zeroSide = 0; zeroSide < 2; zeroSide++) {
      auto zero = dyn_cast<ConstantInt>(icmp->getOperand(zeroSide));
      if (!zero || !zero->isZero())
        continue;
      auto orInst = dyn_cast<BinaryOperator>(icmp->getOperand(1 - zeroSide));
      if (!orInst || orInst->getOpcode() != BinaryOperator::Or)
        continue;
      auto lhsIsZero = pushcse(B.CreateICmpEQ(orInst->getOperand(0), zero));
      auto rhsIsZero = pushcse(B.CreateICmpEQ(orInst->getOperand(1), zero));
      auto bothZero = pushcse(B.CreateAnd(lhsIsZero, rhsIsZero));
      push(orInst);
      push(icmp);
      replaceAndErase(cur, bothZero);
      return "OrEQZero";
    }
  }
}
// add (mul a b), (mul c, b) -> mul (add a, c), b
// sub (mul a b), (mul c, b) -> mul (sub a, c), b
// Factors a multiplicand shared by both products out of the sum/difference.
// NOTE(review): the outer guard only admits integer Add/Sub, so the FMul
// clause of the inner condition and the FAdd/FSub cases of the switch look
// unreachable from here -- confirm whether FAdd/FSub were meant to be admitted.
if (cur->getOpcode() == Instruction::Sub ||
    cur->getOpcode() == Instruction::Add) {
  if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(0)))
    if (auto mul2 = dyn_cast<Instruction>(cur->getOperand(1)))
      if ((mul1->getOpcode() == Instruction::Mul &&
           mul2->getOpcode() == Instruction::Mul) ||
          (mul1->getOpcode() == Instruction::FMul &&
           mul2->getOpcode() == Instruction::FMul && mul1->isFast() &&
           mul2->isFast() && cur->isFast())) {
        // Try every pairing of operands to find the shared factor.
        for (int i1 = 0; i1 < 2; i1++)
          for (int i2 = 0; i2 < 2; i2++) {
            if (mul1->getOperand(i1) == mul2->getOperand(i2)) {
              // Combine the two non-shared factors with cur's own operation.
              Value *res = nullptr;
              switch (cur->getOpcode()) {
              case Instruction::Add:
                res = B.CreateAdd(mul1->getOperand(1 - i1),
                                  mul2->getOperand(1 - i2));
                break;
              case Instruction::Sub:
                res = B.CreateSub(mul1->getOperand(1 - i1),
                                  mul2->getOperand(1 - i2));
                break;
              case Instruction::FAdd:
                res = B.CreateFAddFMF(mul1->getOperand(1 - i1),
                                      mul2->getOperand(1 - i2), cur);
                break;
              case Instruction::FSub:
                res = B.CreateFSubFMF(mul1->getOperand(1 - i1),
                                      mul2->getOperand(1 - i2), cur);
                break;
              default:
                llvm_unreachable("Illegal opcode");
              }
              res = pushcse(res);
              Value *res2 = nullptr;
              if (cur->getType()->isIntegerTy()) {
                // A wrap flag is only carried over when BOTH original
                // multiplies had it.
                // NOTE(review): even then the flag may be stronger than what
                // the factored form guarantees -- confirm this is acceptable.
                const bool nuw =
                    mul1->hasNoUnsignedWrap() && mul2->hasNoUnsignedWrap();
                const bool nsw =
                    mul1->hasNoSignedWrap() && mul2->hasNoSignedWrap();
                res2 = B.CreateMul(res, mul1->getOperand(i1), "", nuw, nsw);
              } else {
                res2 = B.CreateFMulFMF(res, mul1->getOperand(i1), cur);
              }
              res2 = pushcse(res2);
              replaceAndErase(cur, res2);
              return "InvDistributive";
            }
          }
      }
}
// fadd (ext a), (ext b) -> ext (a + b)
// fsub (ext a), (ext b) -> ext (a - b)
// fmul (ext a), (ext b) -> ext (a * b)
if (cur->getOpcode() == Instruction::FSub ||
cur->getOpcode() == Instruction::FAdd ||
cur->getOpcode() == Instruction::FMul ||
cur->getOpcode() == Instruction::FNeg ||
(isSum(cur) && callOperands(cast<CallBase>(cur)).size() == 2)) {
auto opcode = cur->getOpcode();
if (isSum(cur))
opcode = Instruction::FAdd;
auto Ty = B.getInt64Ty();
SmallPtrSet<Instruction *, 1> temporaries;
SmallVector<Instruction *, 1> precasts;
Value *lhs = nullptr;
Value *prelhs = (cur->getOpcode() == Instruction::FNeg)
? ConstantFP::get(cur->getType(), 0.0)
: cur->getOperand(0);
Value *prerhs = (cur->getOpcode() == Instruction::FNeg)
? cur->getOperand(0)
: cur->getOperand(1);
APInt minval(64, 0);
APInt maxval(64, 0);
if (auto C = dyn_cast<ConstantFP>(prelhs)) {
APSInt Tmp(64);
bool isExact = false;
C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero,
&isExact);
if (isExact || C->isZero()) {
minval = maxval = Tmp;
lhs = ConstantInt::get(Ty, Tmp);
}
}
if (auto ext = dyn_cast<CastInst>(prelhs)) {
if (ext->getOpcode() == Instruction::UIToFP ||
ext->getOpcode() == Instruction::SIToFP) {
precasts.push_back(ext);
auto ity = cast<IntegerType>(ext->getOperand(0)->getType());
bool md = false;
if (auto I = dyn_cast<Instruction>(ext->getOperand(0)))
if (auto MD = hasMetadata(I, LLVMContext::MD_range)) {
md = true;
minval =
cast<ConstantInt>(
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
->getValue()
.zextOrTrunc(64);
maxval =
cast<ConstantInt>(
cast<ConstantAsMetadata>(MD->getOperand(1))->getValue())
->getValue()
.zextOrTrunc(64);
}
if (!md) {
if (ext->getOpcode() == Instruction::UIToFP)
maxval = APInt::getMaxValue(ity->getBitWidth()).zextOrTrunc(64);
else {
maxval =
APInt::getSignedMaxValue(ity->getBitWidth()).zextOrTrunc(64);
minval =
APInt::getSignedMinValue(ity->getBitWidth()).zextOrTrunc(64);
}
}
if (ext->getOperand(0)->getType() == Ty)
lhs = ext->getOperand(0);
else if (ity->getBitWidth() < Ty->getBitWidth()) {
if (ext->getOpcode() == Instruction::UIToFP)
lhs = B.CreateZExt(ext->getOperand(0), Ty);
else
lhs = B.CreateSExt(ext->getOperand(0), Ty);
if (auto I = dyn_cast<Instruction>(lhs))
if (I != ext->getOperand(0))
temporaries.insert(I);
}
}
}
Value *rhs = nullptr;
if (auto C = dyn_cast<ConstantFP>(prerhs)) {
APSInt Tmp(64);
bool isExact = false;
C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero,
&isExact);
if (isExact || C->isZero()) {
rhs = ConstantInt::get(Ty, Tmp);
switch (opcode) {
case Instruction::FAdd:
minval += Tmp;
maxval += Tmp;
break;
case Instruction::FSub:
case Instruction::FNeg:
minval -= Tmp;
maxval -= Tmp;
break;
case Instruction::FMul:
minval *= Tmp;
maxval *= Tmp;
break;
default:
llvm_unreachable("Illegal opcode");
}
}
}
// Right-hand operand is an integer -> floating point cast: remember the cast,
// fold the integer operand's value range into [minval, maxval], and produce a
// 64-bit integer form of the operand in `rhs` (left null if it cannot be
// represented in 64 bits).
if (auto ext = dyn_cast<CastInst>(prerhs)) {
  if (ext->getOpcode() == Instruction::UIToFP ||
      ext->getOpcode() == Instruction::SIToFP) {
    precasts.push_back(ext);
    auto ity = cast<IntegerType>(ext->getOperand(0)->getType());
    bool md = false;
    APInt rhsMin(64, 0);
    APInt rhsMax(64, 0);
    // Prefer explicit !range metadata on the integer operand when present.
    // NOTE(review): only the first (lo, hi) pair is read and hi is used as-is;
    // confirm against the LangRef definition of !range bounds.
    if (auto I = dyn_cast<Instruction>(ext->getOperand(0)))
      if (auto MD = hasMetadata(I, LLVMContext::MD_range)) {
        md = true;
        rhsMin =
            cast<ConstantInt>(
                cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
                ->getValue()
                .zextOrTrunc(64);
        rhsMax =
            cast<ConstantInt>(
                cast<ConstantAsMetadata>(MD->getOperand(1))->getValue())
                ->getValue()
                .zextOrTrunc(64);
      }
    // Otherwise fall back to the full range of the source integer type.
    if (!md) {
      if (ext->getOpcode() == Instruction::UIToFP) {
        rhsMax = APInt::getMaxValue(ity->getBitWidth()).zextOrTrunc(64);
        rhsMin = APInt(64, 0);
      } else {
        rhsMax =
            APInt::getSignedMaxValue(ity->getBitWidth()).zextOrTrunc(64);
        rhsMin =
            APInt::getSignedMinValue(ity->getBitWidth()).zextOrTrunc(64);
      }
    }
    // Interval arithmetic for the result of the operation.
    switch (opcode) {
    case Instruction::FAdd:
      minval += rhsMin;
      maxval += rhsMax;
      break;
    case Instruction::FSub:
    case Instruction::FNeg:
      minval -= rhsMax;
      maxval -= rhsMin;
      break;
    case Instruction::FMul: {
      auto minf = [&](APInt a, APInt b) { return a.sle(b) ? a : b; };
      auto maxf = [&](APInt a, APInt b) { return a.sle(b) ? b : a; };
      // All four corner products must be formed from the bounds as they were
      // BEFORE this update; compute them up front so that assigning minval
      // does not disturb the computation of maxval.
      const APInt loLo = minval * rhsMin;
      const APInt loHi = minval * rhsMax;
      const APInt hiLo = maxval * rhsMin;
      const APInt hiHi = maxval * rhsMax;
      minval = minf(loLo, minf(loHi, minf(hiLo, hiHi)));
      maxval = maxf(loLo, maxf(loHi, maxf(hiLo, hiHi)));
      break;
    }
    default:
      llvm_unreachable("Illegal opcode");
    }
    // Widen the integer operand to the common 64-bit type, matching the
    // signedness of the original cast. Newly created extensions are tracked
    // in `temporaries`.
    if (ext->getOperand(0)->getType() == Ty)
      rhs = ext->getOperand(0);
    else if (ity->getBitWidth() < Ty->getBitWidth()) {
      if (ext->getOpcode() == Instruction::UIToFP)
        rhs = B.CreateZExt(ext->getOperand(0), Ty);
      else
        rhs = B.CreateSExt(ext->getOperand(0), Ty);
      if (auto I = dyn_cast<Instruction>(rhs))
        if (I != ext->getOperand(0))
          temporaries.insert(I);
    }
  }
}
if (lhs && rhs) {
Value *res = nullptr;
if (temporaries.count(dyn_cast<Instruction>(lhs)))
lhs = pushcse(lhs);
if (temporaries.count(dyn_cast<Instruction>(rhs)))
rhs = pushcse(rhs);
switch (opcode) {
case Instruction::FAdd:
res = B.CreateAdd(lhs, rhs, "", false, true);
break;
case Instruction::FSub:
case Instruction::FNeg:
res = B.CreateSub(lhs, rhs, "", false, true);
break;
case Instruction::FMul:
res = B.CreateMul(lhs, rhs, "", false, true);
break;
default:
llvm_unreachable("Illegal opcode");
}
res = pushcse(res);
for (auto I : precasts)
push(I);
/*
if (auto I = dyn_cast<Instruction>(res)) {
Q.insert(I);
Metadata *vals[] = {(Metadata *)ConstantAsMetadata::get(
ConstantInt::get(Ty, minval)),
(Metadata *)ConstantAsMetadata::get(
ConstantInt::get(Ty, maxval))};
I->setMetadata(LLVMContext::MD_range,
MDNode::get(I->getContext(), vals));
}
*/
auto ext = pushcse(B.CreateSIToFP(res, cur->getType()));
replaceAndErase(cur, ext);
return "BinopExtToExtBinop";
} else {
for (auto I : temporaries)
I->eraseFromParent();
}
}
// select(cond, const1, b) ?= const2 -> select(cond, const1 ?= const2, b ?=
// const2)
// Pushes a floating point comparison against a constant into the arms of a
// select that has at least one constant arm.
if (auto fcmp = dyn_cast<FCmpInst>(cur))
  for (int i = 0; i < 2; i++)
    if (auto const2 = dyn_cast<Constant>(fcmp->getOperand(i)))
      if (auto sel = dyn_cast<SelectInst>(fcmp->getOperand(1 - i)))
        if (isa<Constant>(sel->getTrueValue()) ||
            isa<Constant>(sel->getFalseValue())) {
          // The rebuilt compares always place the select arm on the left and
          // const2 on the right. If const2 was originally the left operand
          // (i == 0) the predicate must be swapped to keep the meaning of
          // ordered comparisons such as olt/ogt.
          const auto pred = (i == 0) ? fcmp->getSwappedPredicate()
                                     : fcmp->getPredicate();
          auto tval = pushcse(B.CreateFCmp(pred, sel->getTrueValue(), const2));
          auto fval = pushcse(B.CreateFCmp(pred, sel->getFalseValue(), const2));
          auto res = pushcse(B.CreateSelect(sel->getCondition(), tval, fval));
          replaceAndErase(cur, res);
          return "FCmpSelectConst";
        }
// mul (mul a, const), b:not_sparse_or_const -> mul (mul a, b), const
// The two operands of cur must differ, since otherwise the rewrite would
// recurse forever. b must be neither constant nor sparse: sparse values take
// first precedence for pushing, then constants, then everything else.
if (cur->getOpcode() == Instruction::FMul && cur->isFast() &&
    cur->getOperand(0) != cur->getOperand(1)) {
  for (auto mulIdx = 0; mulIdx < 2; mulIdx++) {
    auto mul = dyn_cast<Instruction>(cur->getOperand(mulIdx));
    if (!mul)
      continue;
    if (mul->getOpcode() != Instruction::FMul || !mul->isFast())
      continue;
    auto b = cur->getOperand(1 - mulIdx);
    if (isa<Constant>(b) || directlySparse(b))
      continue;
    for (int constIdx = 0; constIdx < 2; constIdx++) {
      auto C = dyn_cast<Constant>(mul->getOperand(constIdx));
      if (!C)
        continue;
      // Multiply the non-constant factors first, then apply the constant.
      auto varProduct =
          pushcse(B.CreateFMulFMF(mul->getOperand(1 - constIdx), b, mul));
      auto scaled = pushcse(B.CreateFMulFMF(varProduct, C, cur));
      push(mul);
      replaceAndErase(cur, scaled);
      return "MulMulConst";
    }
  }
}
// (mul c, a) +/- (mul c, b) -> mul c, (a +/- b)
// Factors the multiplicand shared by two fast-math products out of their
// sum or difference.
if (cur->getOpcode() == Instruction::FAdd ||
    cur->getOpcode() == Instruction::FSub) {
  auto isFastFMul = [](BinaryOperator *op) {
    return op && op->getOpcode() == Instruction::FMul && op->isFast();
  };
  auto mul1 = dyn_cast<BinaryOperator>(cur->getOperand(0));
  auto mul2 = dyn_cast<BinaryOperator>(cur->getOperand(1));
  if (isFastFMul(mul1) && isFastFMul(mul2)) {
    for (int i = 0; i < 2; i++) {
      for (int j = 0; j < 2; j++) {
        if (mul1->getOperand(i) != mul2->getOperand(j))
          continue;
        auto c = mul1->getOperand(i);
        auto a = mul1->getOperand(1 - i);
        auto b = mul2->getOperand(1 - j);
        const bool isAdd = cur->getOpcode() == Instruction::FAdd;
        Value *intermediate = isAdd ? pushcse(B.CreateFAddFMF(a, b, cur))
                                    : pushcse(B.CreateFSubFMF(a, b, cur));
        auto res = pushcse(B.CreateFMulFMF(c, intermediate, cur));
        push(mul1);
        push(mul2);
        replaceAndErase(cur, res);
        return "FAddMulConstMulConst";
      }
    }
  }
}
// fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))),
// (sitofp b)
// and likewise for uitofp. Splits an integer multiply by a constant that
// feeds an int->fp cast into two floating point multiplies.
if (cur->getOpcode() == Instruction::FMul && cur->isFast()) {
  for (int i = 0; i < 2; i++)
    if (auto z = dyn_cast<Instruction>(cur->getOperand(i)))
      if (isa<SIToFPInst>(z) || isa<UIToFPInst>(z))
        if (auto imul = dyn_cast<BinaryOperator>(z->getOperand(0)))
          if (imul->getOpcode() == Instruction::Mul)
            for (int j = 0; j < 2; j++)
              if (auto c = dyn_cast<Constant>(imul->getOperand(j))) {
                auto b = imul->getOperand(1 - j);
                auto a = cur->getOperand(1 - i);
                // Convert both factors with the same signedness as the cast
                // being split; a signed conversion of an unsigned operand
                // would misread values with the top bit set.
                const bool isUnsigned = isa<UIToFPInst>(z);
                auto c_fp =
                    pushcse(isUnsigned ? B.CreateUIToFP(c, cur->getType())
                                       : B.CreateSIToFP(c, cur->getType()));
                auto b_fp =
                    pushcse(isUnsigned ? B.CreateUIToFP(b, cur->getType())
                                       : B.CreateSIToFP(b, cur->getType()));
                auto n_mul = pushcse(B.CreateFMulFMF(a, c_fp, cur));
                auto res = pushcse(
                    B.CreateFMulFMF(n_mul, b_fp, cur, cur->getName()));
                push(imul);
                push(z);
                replaceAndErase(cur, res);
                return "FMulIMulConstRotate";
              }
}
// Rewrites of a floating point division based on the shape of its numerator.
if (cur->getOpcode() == Instruction::FDiv) {
  Value *numerator = cur->getOperand(0);
  Value *b = cur->getOperand(1);
  // fdiv (sitofp a), b -> select (a == 0), 0 [ (fdiv 1 / b) * sitofp a]
  if (auto ext = dyn_cast<CastInst>(numerator)) {
    const bool fromInteger = ext->getOpcode() == Instruction::UIToFP ||
                             ext->getOpcode() == Instruction::SIToFP;
    if (fromInteger) {
      push(ext);
      Value *isZero = pushcse(
          B.CreateICmpEQ(ext->getOperand(0),
                         ConstantInt::get(ext->getOperand(0)->getType(), 0),
                         "sdivcmp." + cur->getName()));
      // The reciprocal is built first, then scaled by the converted integer.
      Value *reciprocal = pushcse(
          B.CreateFDivFMF(ConstantFP::get(cur->getType(), 1.0), b, cur));
      Value *scaled = pushcse(B.CreateFMulFMF(reciprocal, ext, cur));
      Value *guarded = pushcse(
          B.CreateSelect(isZero, ConstantFP::get(cur->getType(), 0.0), scaled,
                         "sfdiv." + cur->getName()));
      replaceAndErase(cur, guarded);
      return "FDivSIToFPProp";
    }
  }
  // fdiv (select c, 0, a), b -> select c, 0 (fdiv a, b)
  if (auto SI = dyn_cast<SelectInst>(numerator)) {
    auto tvalC = dyn_cast<ConstantFP>(SI->getTrueValue());
    auto fvalC = dyn_cast<ConstantFP>(SI->getFalseValue());
    const bool trueIsZero = tvalC && tvalC->isZero();
    const bool falseIsZero = fvalC && fvalC->isZero();
    if (trueIsZero || falseIsZero) {
      push(SI);
      // A zero arm is kept as is; any other arm is divided by b.
      Value *ntval = tvalC;
      if (!trueIsZero)
        ntval = pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur,
                                        "sfdiv2_t." + cur->getName()));
      Value *nfval = fvalC;
      if (!falseIsZero)
        nfval = pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur,
                                        "sfdiv2_f." + cur->getName()));
      // Work around bad fdivfmf, fixed in LLVM 16+
      // https://github.com/llvm/llvm-project/commit/4f3b1c6dd6ef6c7b5bb79f058e3b7ba4bcdf4566
#if LLVM_VERSION_MAJOR < 16
      for (auto v : {ntval, nfval})
        if (auto I = dyn_cast<Instruction>(v))
          I->setFastMathFlags(cur->getFastMathFlags());
#endif
      auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval,
                                        "sfdiv2." + cur->getName()));
      replaceAndErase(cur, res);
      return "FDivSelectProp";
    }
  }
}
// div (mul a:not_sparse, b:is_sparse), c -> mul (div, a, c), b:is_sparse
// Moves the division onto the non-sparse factor so that the sparse factor
// ends up as an operand of the outermost multiply.
if (cur->getOpcode() == Instruction::FDiv) {
  auto z = dyn_cast<BinaryOperator>(cur->getOperand(0));
  if (z && z->getOpcode() == Instruction::FMul) {
    auto c = cur->getOperand(1);
    for (int denseIdx = 0; denseIdx < 2; denseIdx++) {
      Value *a = z->getOperand(denseIdx);
      Value *b = z->getOperand(1 - denseIdx);
      // Exactly this orientation is required: a dense, b sparse.
      if (directlySparse(a) || !directlySparse(b))
        continue;
      Value *quotient = pushcse(B.CreateFDivFMF(a, c, cur));
      Value *product = pushcse(B.CreateFMulFMF(quotient, b, z));
      push(z);
      replaceAndErase(cur, product);
      return "FDivFMulSparseProp";
    }
  }
}
if (cur->getOpcode() == Instruction::FMul)
for (int i = 0; i < 2; i++) {
Value *prelhs = cur->getOperand(i);
Value *b = cur->getOperand(1 - i);
// fmul (fmul x:constant, y):z, b:constant -> fmul (fmul x, b), y
// Brings the two constants together in the inner multiply.
if (isa<Constant>(b))
  if (auto z = dyn_cast<BinaryOperator>(prelhs)) {
    if (z->getOpcode() == Instruction::FMul) {
      // Try both operands of z as the constant x. The index must be the
      // inner loop variable j; indexing with the outer i would inspect the
      // same single operand of z on every iteration.
      for (int j = 0; j < 2; j++) {
        auto x = z->getOperand(j);
        if (!isa<Constant>(x))
          continue;
        auto y = z->getOperand(1 - j);
        Value *inner_fmul = pushcse(B.CreateFMulFMF(x, b, cur));
        Value *outer_fmul = pushcse(B.CreateFMulFMF(inner_fmul, y, z));
        push(z);
        replaceAndErase(cur, outer_fmul);
        return "FMulFMulConstantReorder";
      }
    }
  }
auto integralFloat = [](Value *z) {
if (auto C = dyn_cast<ConstantFP>(z)) {
APSInt Tmp(64);
bool isExact = false;
C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero,
&isExact);
if (isExact || C->isZero()) {
return true;
}
}
return false;
};
// fmul (fmul x:sparse, y):z, b
// 1) If x and y are both sparse, nothing is done here; the inner fmul is
//    left to be simplified into a single sparse instruction. Hence y may be
//    assumed not sparse below.
// 2) If b is sparse (or an integral constant), rotate to
//    fmul (fmul x, b), y so the sparsity is collected in the inner multiply.
// 3) Otherwise b is not sparse and the sparse x is pushed outward:
//    fmul (fmul y, b), x.
if (auto z = dyn_cast<BinaryOperator>(prelhs)) {
  if (z->getOpcode() == Instruction::FMul) {
    for (int sparseIdx = 0; sparseIdx < 2; sparseIdx++) {
      auto x = z->getOperand(sparseIdx);
      if (!directlySparse(x))
        continue;
      auto y = z->getOperand(1 - sparseIdx);
      if (directlySparse(y))
        continue;
      const bool gatherInner = directlySparse(b) || integralFloat(b);
      push(z);
      // Only the operand roles and the name prefix differ between the cases.
      const char *prefix = gatherInner ? "mulisr." : "mulisp.";
      Value *innerLhs = gatherInner ? x : y;
      Value *outerRhs = gatherInner ? y : x;
      Value *inner_fmul =
          pushcse(B.CreateFMulFMF(innerLhs, b, cur, prefix + cur->getName()));
      Value *outer_fmul = pushcse(
          B.CreateFMulFMF(inner_fmul, outerRhs, z, prefix + z->getName()));
      replaceAndErase(cur, outer_fmul);
      if (gatherInner)
        return "FMulFMulSparseReorder";
      return "FMulFMulSparsePush";
    }
  }
}
/*
auto contains = [](MDNode *MD, Value *V) {
if (!MD)
return false;
for (auto &op : MD->operands()) {
auto V2 = cast<ValueAsMetadata>(op)->getValue();
if (V == V2)
return true;
}
return false;
};
// fmul (sitofp a), b -> select (a == 0), 0 [noprop fmul ( sitofp a), b]
if (true || !contains(hasMetadata(cur, "enzyme_fmulnoprop"), prelhs))
if (auto ext = dyn_cast<CastInst>(prelhs)) {
if (ext->getOpcode() == Instruction::UIToFP ||
ext->getOpcode() == Instruction::SIToFP) {
push(ext);
Value *condition = pushcse(B.CreateICmpEQ(
ext->getOperand(0),
ConstantInt::get(ext->getOperand(0)->getType(), 0),
"mulcsicmp." + cur->getName()));
Value *fmul = pushcse(B.CreateFMulFMF(ext, b, cur));
if (auto I = dyn_cast<Instruction>(fmul)) {
SmallVector<Metadata *, 1> nodes;
if (auto MD = hasMetadata(cur, "enzyme_fmulnoprop")) {
for (auto &M : MD->operands()) {
nodes.push_back(M.get());
}
}
nodes.push_back(ValueAsMetadata::get(ext));
I->setMetadata("enzyme_fmulnoprop",
MDNode::get(I->getContext(), nodes));
}
Value *sel = pushcse(
B.CreateSelect(condition, ConstantFP::get(cur->getType(),
0.0), fmul, "mulcsi." + cur->getName()));
replaceAndErase(cur, sel);
return "FMulSIToFPProp";
}
}
*/
// fmul (select c, 0, a), b -> select c, 0 (fmul a, b)
// Pushes the multiply into the select so that a zero arm stays a literal 0.
if (auto SI = dyn_cast<SelectInst>(prelhs)) {
  auto tvalC = dyn_cast<ConstantFP>(SI->getTrueValue());
  auto fvalC = dyn_cast<ConstantFP>(SI->getFalseValue());
  const bool trueIsZero = tvalC && tvalC->isZero();
  const bool falseIsZero = fvalC && fvalC->isZero();
  if (trueIsZero || falseIsZero) {
    push(SI);
    // A zero arm is reused unchanged; any other arm is multiplied by b.
    Value *ntval = tvalC;
    if (!trueIsZero)
      ntval = pushcse(B.CreateFMulFMF(SI->getTrueValue(), b, cur));
    Value *nfval = fvalC;
    if (!falseIsZero)
      nfval = pushcse(B.CreateFMulFMF(SI->getFalseValue(), b, cur));
    auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval,
                                      "mulsi." + cur->getName()));
    replaceAndErase(cur, res);
    return "FMulSelectProp";
  }
}
}
// !(cmp a, b) -> inverse(cmp), a, b
// An xor of a comparison with the constant one is replaced by the comparison
// with the inverse predicate.
if (auto xorOp = dyn_cast<BinaryOperator>(cur)) {
  if (xorOp->getOpcode() == Instruction::Xor) {
    for (int oneIdx = 0; oneIdx < 2; oneIdx++) {
      auto one = dyn_cast<ConstantInt>(xorOp->getOperand(oneIdx));
      if (!one || !one->isOne())
        continue;
      auto scmp = dyn_cast<CmpInst>(xorOp->getOperand(1 - oneIdx));
      if (!scmp)
        continue;
      auto inverted = pushcse(
          B.CreateCmp(scmp->getInversePredicate(), scmp->getOperand(0),
                      scmp->getOperand(1), "not." + scmp->getName()));
      replaceAndErase(cur, inverted);
      return "NotCmp";
    }
  }
}
// select cmp, (uitofp tval), (uitofp fval)
//   -> uitofp ((cmp & tval) | (!cmp & fval))
// Each arm must be the floating point image of an i1: either a uitofp of an
// i1 value, or the constant 0.0 (false) / 1.0 (true).
if (auto SI = dyn_cast<SelectInst>(cur)) {
  // Returns the i1 value whose uitofp equals `arm`, or nullptr if there is
  // none. Only uitofp is accepted: the result below is rebuilt with uitofp,
  // which would be invalid for integer casts (zext/sext) and would give the
  // wrong value for sitofp, where true maps to -1.0.
  auto asI1 = [&](Value *arm) -> Value * {
    Value *bit = nullptr;
    if (auto C = dyn_cast<ConstantFP>(arm)) {
      if (C->isZero())
        bit = ConstantInt::getFalse(SI->getContext());
      if (C->isExactlyValue(1.0))
        bit = ConstantInt::getTrue(SI->getContext());
    }
    if (auto ext = dyn_cast<CastInst>(arm)) {
      if (ext->getOpcode() == Instruction::UIToFP &&
          ext->getOperand(0)->getType()->isIntegerTy(1))
        bit = ext->getOperand(0);
    }
    return bit;
  };
  Value *trueVal = asI1(SI->getTrueValue());
  Value *falseVal = asI1(SI->getFalseValue());
  if (trueVal && falseVal) {
    auto ncmp1 = pushcse(B.CreateAnd(SI->getCondition(), trueVal));
    auto notV = pushcse(B.CreateNot(SI->getCondition()));
    auto ncmp2 = pushcse(B.CreateAnd(notV, falseVal));
    auto ori = pushcse(B.CreateOr(ncmp1, ncmp2));
    auto ext = pushcse(B.CreateUIToFP(ori, SI->getType()));
    replaceAndErase(cur, ext);
    return "SelectI1Ext";
  }
}
// select cmp, (i1 tval), (i1 fval) -> (cmp & tval) | (!cmp & fval)
// A select between two booleans is expressed with plain and/or/not.
if (auto SI = dyn_cast<SelectInst>(cur)) {
  if (cur->getType()->isIntegerTy(1)) {
    auto whenTrue =
        pushcse(B.CreateAnd(SI->getCondition(), SI->getTrueValue()));
    auto negatedCond = pushcse(B.CreateNot(SI->getCondition()));
    auto whenFalse = pushcse(B.CreateAnd(negatedCond, SI->getFalseValue()));
    auto merged = pushcse(B.CreateOr(whenTrue, whenFalse));
    replaceAndErase(cur, merged);
    return "SelectI1";
  }
}
if (auto PN = dyn_cast<PHINode>(cur)) {
B.SetInsertPoint(PN->getParent()->getFirstNonPHI());
// Replace a phi whose value ScalarEvolution can describe by code expanded
// from that SCEV expression (returns "InductVarSCEV" when it fires).
if (SE.isSCEVable(PN->getType())) {
auto S = SE.getSCEV(PN);
bool legal = false;
// Case 1: the SCEV is an opaque value. Accept constants and arguments, and
// instructions other than PN itself that are not the canonical induction
// variable of their loop.
if (auto SV = dyn_cast<SCEVUnknown>(S)) {
auto val = SV->getValue();
legal |= isa<Constant>(val) || isa<Argument>(val);
if (auto I = dyn_cast<Instruction>(val)) {
auto L = LI.getLoopFor(I->getParent());
if ((!L || L->getCanonicalInductionVariable() != I) && I != PN)
legal = true;
}
}
// Case 2: the SCEV is an add-recurrence. Accept it unless PN already is the
// canonical induction variable of its loop.
if (isa<SCEVAddRecExpr>(S)) {
auto L = LI.getLoopFor(PN->getParent());
assert(L);
if (L->getCanonicalInductionVariable() != PN)
legal = true;
}
if (legal) {
for (auto U : cur->users()) {
push(U);
}
auto point = PN->getParent()->getFirstNonPHI();
// A placeholder phi takes over all uses of cur so that cur can be erased
// before the SCEV is expanded; the placeholder is replaced and erased below.
auto tmp = cast<PHINode>(pushcse(B.CreatePHI(cur->getType(), 1)));
cur->replaceAllUsesWith(tmp);
cur->eraseFromParent();
Value *newIV = nullptr;
{
SCEVExpander Exp(SE, DL, "sparseenzyme");
// We place that at first non phi as it may produce a non-phi
// instruction and must thus be expanded after all phi's
newIV = Exp.expandCodeFor(S, tmp->getType(), point);
// sadly this doesn't exist on 11
// Every instruction the expander created is added to Q.
for (auto I : Exp.getAllInsertedInstructions())
Q.insert(I);
}
tmp->replaceAllUsesWith(newIV);
tmp->eraseFromParent();
return "InductVarSCEV";
}
}
// phi a, a -> a
// A phi whose incoming values are all the same value is that value.
{
  const size_t numIncoming = PN->getNumIncomingValues();
  size_t firstMismatch = 1;
  while (firstMismatch < numIncoming &&
         PN->getIncomingValue(firstMismatch) == PN->getIncomingValue(0))
    firstMismatch++;
  // The scan ran off the end: no incoming value differs from the first.
  if (firstMismatch >= numIncoming) {
    auto common = PN->getIncomingValue(0);
    replaceAndErase(cur, common);
    return "PhiMerge";
  }
}
// phi (idx=0) ? b, a, a -> select (idx == 0), b, a
// For a phi in a loop header that receives b from the preheader and the same
// value a along every other incoming edge, the phi equals b on the first
// iteration (canonical induction variable == 0) and a afterwards.
if (auto L = LI.getLoopFor(PN->getParent()))
  if (L->getHeader() == PN->getParent())
    if (auto idx = L->getCanonicalInductionVariable())
      if (auto PH = L->getLoopPreheader()) {
        // The canonical induction variable itself is never rewritten.
        bool legal = idx != PN;
        auto ph_idx = PN->getBasicBlockIndex(PH);
        assert(ph_idx >= 0);
        // Index of the first incoming edge that is not the preheader. Using
        // `1 - ph_idx` here is only valid when the preheader sits at index 0
        // or 1 and would go negative for a preheader at index 2 or higher.
        const int loop_idx = (ph_idx == 0) ? 1 : 0;
        for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
          if ((int)i == ph_idx)
            continue;
          auto v = PN->getIncomingValue(i);
          // All non-preheader edges must carry the same value
          if (v != PN->getIncomingValue(loop_idx)) {
            legal = false;
            break;
          }
          // The given var must dominate the loop
          if (isa<Constant>(v))
            continue;
          if (isa<Argument>(v))
            continue;
          // exception for the induction itself, which we handle specially
          if (v == idx)
            continue;
          auto I = cast<Instruction>(v);
          if (!DT.dominates(I, PN)) {
            legal = false;
            break;
          }
        }
        if (legal) {
          auto val = PN->getIncomingValue(loop_idx);
          push(val);
          // The phi observes the induction variable of the previous
          // iteration, i.e. idx - 1 in terms of the current one.
          if (val == idx) {
            val = pushcse(
                B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
          }
          auto val2 = PN->getIncomingValue(ph_idx);
          push(val2);
          auto c0 = ConstantInt::get(idx->getType(), 0);
          // if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
          //  val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
          //} else {
          auto eq = pushcse(B.CreateICmpEQ(idx, c0));
          val = pushcse(
              B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
          //}
          replaceAndErase(cur, val);
          return "PhiLoop0Sel";
        }
      }
// phi (sitofp a), (sitofp b) -> sitofp (phi a, b)
// Applies when every incoming value is either a sitofp of a 64-bit integer or
// a floating point constant with an exact integer value (or zero).
{
  SmallVector<Value *, 1> intIncoming;
  SmallVector<Instruction *, 1> oldCasts;
  bool convertible = true;
  for (size_t k = 0; k < PN->getNumIncomingValues(); k++) {
    auto incoming = PN->getIncomingValue(k);
    if (auto fpConst = dyn_cast<ConstantFP>(incoming)) {
      APSInt asInt(64);
      bool exact = false;
      fpConst->getValue().convertToInteger(
          asInt, llvm::RoundingMode::TowardZero, &exact);
      if (exact || fpConst->isZero()) {
        intIncoming.push_back(ConstantInt::get(B.getInt64Ty(), asInt));
        continue;
      }
    }
    auto conv = dyn_cast<Instruction>(incoming);
    if (conv && conv->getOpcode() == Instruction::SIToFP &&
        cast<IntegerType>(conv->getOperand(0)->getType())->getBitWidth() ==
            64) {
      intIncoming.push_back(conv->getOperand(0));
      oldCasts.push_back(conv);
      continue;
    }
    // Anything else blocks the rewrite; scanning continues regardless.
    convertible = false;
  }
  if (convertible) {
    // Build the integer phi edge by edge, then convert its result once.
    auto intPhi = cast<PHINode>(
        pushcse(B.CreatePHI(B.getInt64Ty(), PN->getNumIncomingValues())));
    intPhi->takeName(PN);
    for (size_t k = 0; k < intIncoming.size(); k++)
      intPhi->addIncoming(intIncoming[k], PN->getIncomingBlock(k));
    push(intPhi);
    auto converted = pushcse(B.CreateSIToFP(intPhi, PN->getType()));
    for (auto oldCast : oldCasts)
      push(oldCast);
    replaceAndErase(cur, converted);
    return "PhiSIToFP";
  }
}
// phi (fneg a), (fneg b) -> fneg (phi a, b)
// Every incoming value must be an fneg or a floating point constant, and at
// least one must be an actual fneg (otherwise nothing would be gained).
{
  SmallVector<Value *, 1> negOps;
  SmallVector<Instruction *, 1> prevNegOps;
  bool legal = true;
  bool hasNeg = false;
  for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
    auto v = PN->getIncomingValue(i);
    if (auto C = dyn_cast<ConstantFP>(v)) {
      // Constants are pre-negated so the outer fneg restores them; a zero
      // constant is kept unchanged.
      negOps.push_back(C->isZero() ? C : pushcse(B.CreateFNeg(C)));
      continue;
    }
    if (auto fneg = dyn_cast<Instruction>(v)) {
      if (fneg->getOpcode() == Instruction::FNeg) {
        negOps.push_back(fneg->getOperand(0));
        prevNegOps.push_back(fneg);
        // Record that a real negation was found, as the integer variant of
        // this rule does; without it the rewrite below can never fire.
        hasNeg = true;
        continue;
      }
    }
    legal = false;
  }
  if (legal && hasNeg) {
    // The phi is updated in place to carry the un-negated values...
    for (auto val : llvm::enumerate(negOps))
      PN->setIncomingValue(val.index(), val.value());
    push(PN);
    // ...and every other use of it is redirected to a single fneg of it.
    auto fneg = pushcse(B.CreateFNeg(PN));
    // Use::set() unlinks the use from cur's use list, so the iterator must
    // be advanced before the body runs; otherwise iteration stops after the
    // first redirected use.
    for (auto &U : llvm::make_early_inc_range(cur->uses())) {
      if (U.getUser() == fneg)
        continue;
      push(U.getUser());
      U.set(fneg);
    }
    for (auto I : prevNegOps)
      push(I);
    return "PhiFNeg";
  }
}
// phi (neg a), (neg b) -> neg (phi a, b)
// Every incoming value must be an integer negation (sub 0, x) or an integer
// constant, and at least one must be an actual negation.
{
  SmallVector<Value *, 1> negOps;
  SmallVector<Instruction *, 1> prevNegOps;
  bool legal = true;
  bool hasNeg = false;
  for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
    auto v = PN->getIncomingValue(i);
    if (auto C = dyn_cast<ConstantInt>(v)) {
      // Constants are pre-negated so the outer neg restores them.
      negOps.push_back(pushcse(B.CreateNeg(C)));
      continue;
    }
    if (auto fneg = dyn_cast<BinaryOperator>(v)) {
      if (auto CI = dyn_cast<ConstantInt>(fneg->getOperand(0)))
        if (fneg->getOpcode() == Instruction::Sub && CI->isZero()) {
          negOps.push_back(fneg->getOperand(1));
          prevNegOps.push_back(fneg);
          hasNeg = true;
          continue;
        }
    }
    legal = false;
  }
  if (legal && hasNeg) {
    // The phi is updated in place to carry the un-negated values...
    for (auto val : llvm::enumerate(negOps))
      PN->setIncomingValue(val.index(), val.value());
    push(PN);
    // ...and every other use of it is redirected to a single neg of it.
    auto fneg = pushcse(B.CreateNeg(PN));
    // Use::set() unlinks the use from cur's use list, so the iterator must
    // be advanced before the body runs; otherwise iteration stops after the
    // first redirected use and the remaining users would observe the
    // un-negated phi.
    for (auto &U : llvm::make_early_inc_range(cur->uses())) {
      if (U.getUser() == fneg)
        continue;
      push(U.getUser());
      U.set(fneg);
    }
    for (auto I : prevNegOps)
      push(I);
    return "PHINeg";
  }
}
// p = phi (mul a, c), (mul b, d) -> mul (phi a, b), (phi c, d) if
// a,b,c != p
{
for (auto code :
{(unsigned)Instruction::Mul, (unsigned)Instruction::Sub,
(unsigned)Instruction::Add, (unsigned)Instruction::ZExt,
(unsigned)Instruction::UIToFP, (unsigned)Instruction::ICmp,
(unsigned)Instruction::FMul, (unsigned)Instruction::Or,
(unsigned)Instruction::And}) {
SmallVector<Value *, 1> lhsOps;
SmallVector<Value *, 1> rhsOps;
SmallVector<Instruction *, 1> prevOps;
bool legal = true;
bool fast = false;
bool NUW = false;
bool NSW = false;
size_t numOps = 0;
std::optional<llvm::CmpInst::Predicate> cmpPredicate;
switch (code) {
case Instruction::FMul:
case Instruction::FSub:
case Instruction::FAdd:
fast = true;
numOps = 2;
break;
case Instruction::Mul:
case Instruction::Add:
NUW = NSW = true;
numOps = 2;
break;
case Instruction::Sub:
NSW = true;
numOps = 2;
break;
case Instruction::ICmp:
case Instruction::FCmp:
case Instruction::Or:
case Instruction::And:
numOps = 2;
break;
case Instruction::ZExt:
case Instruction::UIToFP:
numOps = 1;
break;
default:;
llvm_unreachable("unknown opcode");
}
bool changed = false;
for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
auto v = PN->getIncomingValue(i);
if (auto C = dyn_cast<ConstantInt>(v)) {
if (code == Instruction::ZExt) {
lhsOps.push_back(ConstantInt::getFalse(C->getContext()));
continue;
} else if (C->isZero()) {
rhsOps.push_back(C);
lhsOps.push_back(C);
continue;
}
}
if (auto C = dyn_cast<ConstantFP>(v)) {
if (code == Instruction::UIToFP) {
if (C->isZero()) {
lhsOps.push_back(ConstantInt::getFalse(C->getContext()));
}
} else if (code == Instruction::FMul || code == Instruction::FSub ||
code == Instruction::FAdd) {
if (C->isZero()) {
rhsOps.push_back(C);
lhsOps.push_back(C);
continue;
}
}
}
if (auto fneg = dyn_cast<Instruction>(v)) {
if (fneg->getOpcode() == code) {
switch (code) {
case Instruction::FMul:
case Instruction::FSub:
case Instruction::FAdd:
fast &= fneg->isFast();
if (fneg->getOperand(0) == PN)
legal = false;
if (fneg->getOperand(1) == PN)
legal = false;
lhsOps.push_back(fneg->getOperand(0));
rhsOps.push_back(fneg->getOperand(1));
break;
case Instruction::Mul:
case Instruction::Sub:
case Instruction::Add:
NUW &= fneg->hasNoUnsignedWrap();
NSW &= fneg->hasNoSignedWrap();
if (fneg->getOperand(0) == PN)
legal = false;
if (fneg->getOperand(1) == PN)
legal = false;
lhsOps.push_back(fneg->getOperand(0));
rhsOps.push_back(fneg->getOperand(1));
break;
case Instruction::Or:
case Instruction::And:
if (fneg->getOperand(0) == PN)
legal = false;
if (fneg->getOperand(1) == PN)
legal = false;
lhsOps.push_back(fneg->getOperand(0));
rhsOps.push_back(fneg->getOperand(1));
break;
case Instruction::ICmp:
case Instruction::FCmp:
if (fneg->getOperand(0) == PN)
legal = false;
if (fneg->getOperand(1) == PN)
legal = false;
if (cmpPredicate) {
if (*cmpPredicate != cast<CmpInst>(fneg)->getPredicate())
legal = false;
} else {
cmpPredicate = cast<CmpInst>(fneg)->getPredicate();
}
lhsOps.push_back(fneg->getOperand(0));
rhsOps.push_back(fneg->getOperand(1));
break;
case Instruction::ZExt:
case Instruction::UIToFP:
if (cast<IntegerType>(fneg->getOperand(0)->getType())
->getBitWidth() != 1)
legal = false;
lhsOps.push_back(fneg->getOperand(0));
break;
default:
llvm_unreachable("unhandled opcode");
}
prevOps.push_back(fneg);
changed = true;
continue;
}
}
legal = false;
}
int preheader_fix = -1;
if (code == Instruction::ICmp || code == Instruction::FCmp) {
if (!cmpPredicate)
legal = false;
auto L = LI.getLoopFor(PN->getParent());
if (legal && L && L->getLoopPreheader() &&
L->getCanonicalInductionVariable() &&
L->getHeader() == PN->getParent()) {
auto ph_idx = PN->getBasicBlockIndex(L->getLoopPreheader());
if (isa<ConstantInt>(PN->getIncomingValue(ph_idx))) {
lhsOps[ph_idx] =
Constant::getNullValue(lhsOps[1 - ph_idx]->getType());
rhsOps[ph_idx] =
Constant::getNullValue(rhsOps[1 - ph_idx]->getType());
preheader_fix = ph_idx;
}
}
for (auto v : lhsOps)
if (v->getType() != lhsOps[0]->getType())
legal = false;
for (auto v : rhsOps)
if (v->getType() != rhsOps[0]->getType())
legal = false;
}
if (legal && changed) {
auto lhsPN = cast<PHINode>(pushcse(
B.CreatePHI(lhsOps[0]->getType(), PN->getNumIncomingValues())));
PHINode *rhsPN = nullptr;
if (numOps == 2)
rhsPN = cast<PHINode>(pushcse(
B.CreatePHI(rhsOps[0]->getType(), PN->getNumIncomingValues())));
for (auto val : llvm::enumerate(lhsOps))
lhsPN->addIncoming(val.value(), PN->getIncomingBlock(val.index()));
if (numOps == 2) {
for (auto val : llvm::enumerate(rhsOps))
rhsPN->addIncoming(val.value(),
PN->getIncomingBlock(val.index()));
}
Value *fneg = nullptr;
switch (code) {
case Instruction::FMul:
fneg = B.CreateFMul(lhsPN, rhsPN);
if (auto I = dyn_cast<Instruction>(fneg))
I->setFast(fast);
break;
case Instruction::FAdd:
fneg = B.CreateFAdd(lhsPN, rhsPN);
if (auto I = dyn_cast<Instruction>(fneg))
I->setFast(fast);
break;
case Instruction::FSub:
fneg = B.CreateFSub(lhsPN, rhsPN);
if (auto I = dyn_cast<Instruction>(fneg))
I->setFast(fast);
break;
case Instruction::Mul:
fneg = B.CreateMul(lhsPN, rhsPN, "", NUW, NSW);
break;
case Instruction::Add:
fneg = B.CreateAdd(lhsPN, rhsPN, "", NUW, NSW);
break;
case Instruction::Sub:
fneg = B.CreateSub(lhsPN, rhsPN, "", NUW, NSW);
break;
case Instruction::ZExt:
fneg = B.CreateZExt(lhsPN, PN->getType());
break;
case Instruction::FCmp:
case Instruction::ICmp:
fneg = B.CreateCmp(*cmpPredicate, lhsPN, rhsPN);
break;
case Instruction::UIToFP:
fneg = B.CreateUIToFP(lhsPN, PN->getType());
break;
case Instruction::Or:
fneg = B.CreateOr(lhsPN, rhsPN);
break;
case Instruction::And:
fneg = B.CreateAnd(lhsPN, rhsPN);
break;
default:
llvm_unreachable("unhandled opcode");
}
push(fneg);
if (preheader_fix != -1) {
auto L = LI.getLoopFor(PN->getParent());
auto idx = L->getCanonicalInductionVariable();
auto eq = pushcse(
B.CreateICmpEQ(idx, ConstantInt::get(idx->getType(), 0)));
fneg =
pushcse(B.CreateSelect(eq, PN->getIncomingValue(preheader_fix),
fneg, "phphisel." + cur->getName()));
}
replaceAndErase(cur, fneg);
return "PHIBinop";
}
}
}
// phi -> select
if (PN->getNumIncomingValues() == 2) {
for (int i = 0; i < 2; i++) {
auto prev = PN->getIncomingBlock(i);
if (!DT.dominates(prev, PN->getParent())) {
continue;
}
auto br = dyn_cast<BranchInst>(prev->getTerminator());
if (!br) {
continue;
}
if (!br->isConditional()) {
continue;
}
if (br->getSuccessor(0) != PN->getParent()) {
continue;
}
if (br->getSuccessor(1) != PN->getIncomingBlock(1 - i)) {
continue;
}
Value *specVal = PN->getIncomingValue(1 - i);
SetVector<Value *, std::deque<Value *>> todo;
todo.insert(specVal);
SetVector<Instruction *> toMove;
bool legal = true;
while (!todo.empty()) {
auto cur = *todo.begin();
todo.erase(todo.begin());
auto I = dyn_cast<Instruction>(cur);
if (!I)
continue;
if (I->mayReadOrWriteMemory()) {
legal = false;
break;
}
if (DT.dominates(I, PN))
continue;
for (size_t i = 0; i < I->getNumOperands(); i++)
todo.insert(I->getOperand(i));
toMove.insert(I);
}
if (!legal)
continue;
for (auto iter = toMove.rbegin(), end = toMove.rend(); iter != end;
iter++) {
(*iter)->moveBefore(br);
}
auto sel = pushcse(B.CreateSelect(
br->getCondition(), PN->getIncomingValueForBlock(prev),
PN->getIncomingValueForBlock(br->getSuccessor(1)),
"tphisel." + cur->getName()));
replaceAndErase(cur, sel);
return "TPhiSel";
}
}
}
if (auto SI = dyn_cast<SelectInst>(cur)) {
auto tval = replace(SI->getTrueValue(), SI->getCondition(),
ConstantInt::getTrue(SI->getContext()));
auto fval = replace(SI->getFalseValue(), SI->getCondition(),
ConstantInt::getFalse(SI->getContext()));
if (tval != SI->getTrueValue() || fval != SI->getFalseValue()) {
auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval,
"postsel." + SI->getName()));
replaceAndErase(cur, res);
return "SelectReplace";
}
}
// and a, b -> and a b[with a true]
if (cur->getOpcode() == Instruction::And) {
auto lhs = replace(cur->getOperand(0), cur->getOperand(1),
ConstantInt::getTrue(cur->getContext()));
if (lhs != cur->getOperand(0)) {
auto res = pushcse(
B.CreateAnd(lhs, cur->getOperand(1), "postand." + cur->getName()));
replaceAndErase(cur, res);
return "AndReplaceLHS";
}
auto rhs = replace(cur->getOperand(1), cur->getOperand(0),
ConstantInt::getTrue(cur->getContext()));
if (rhs != cur->getOperand(1)) {
auto res = pushcse(
B.CreateAnd(cur->getOperand(0), rhs, "postand." + cur->getName()));
replaceAndErase(cur, res);
return "AndReplaceRHS";
}
}
// or a, b -> or a b[with a false]
if (cur->getOpcode() == Instruction::Or) {
auto lhs = replace(cur->getOperand(0), cur->getOperand(1),
ConstantInt::getFalse(cur->getContext()));
if (lhs != cur->getOperand(0)) {
auto res = pushcse(
B.CreateOr(lhs, cur->getOperand(1), "postor." + cur->getName()));
replaceAndErase(cur, res);
return "OrReplaceLHS";
}
auto rhs = replace(cur->getOperand(1), cur->getOperand(0),
ConstantInt::getFalse(cur->getContext()));
if (rhs != cur->getOperand(1)) {
auto res = pushcse(
B.CreateOr(cur->getOperand(0), rhs, "postor." + cur->getName()));
replaceAndErase(cur, res);
return "OrReplaceRHS";
}
}
return {};
}
// Forward declaration; the class itself is defined further down in this file.
class Constraints;
// Stream-insertion overload for Constraints (declared early so it can be used
// in diagnostics inside the class body).
raw_ostream &operator<<(raw_ostream &os, const Constraints &c);
// Ordering functor over shared pointers to Constraints. It is the comparator
// of the std::set-based `SetTy` aliases used below; its call operator is only
// declared here and is defined out of line after the Constraints class, where
// it compares the pointed-to objects with Constraints::operator<.
struct ConstraintComparator {
bool operator()(std::shared_ptr<const Constraints> lhs,
std::shared_ptr<const Constraints> rhs) const;
};
// State threaded through constraint simplification: the analyses to query, the
// loop whose induction variable is being solved for, and the set of
// constraints that are already being processed further up the call chain
// (Constraints::andB consults `contains` to stop recursing).
struct ConstraintContext {
  using InnerTy = std::shared_ptr<const Constraints>;
  using SetTy = std::set<InnerTy, ConstraintComparator>;

  ScalarEvolution &SE;
  const Loop *loopToSolve;
  const SmallVectorImpl<Instruction *> &Assumptions;
  DominatorTree &DT;
  // Constraints currently "in flight"; empty for a root context.
  SetTy seen;

  // Root context: nothing has been visited yet. `loopToSolve` must be non-null.
  ConstraintContext(ScalarEvolution &SE, const Loop *loopToSolve,
                    const SmallVectorImpl<Instruction *> &Assumptions,
                    DominatorTree &DT)
      : SE(SE), loopToSolve(loopToSolve), Assumptions(Assumptions), DT(DT) {
    assert(loopToSolve);
  }

  // Plain copies are disallowed; derive a child context with the constructors
  // below instead.
  ConstraintContext(const ConstraintContext &) = delete;

  // Child context: everything `parent` has seen, plus `lhs`.
  ConstraintContext(const ConstraintContext &parent, InnerTy lhs)
      : SE(parent.SE), loopToSolve(parent.loopToSolve),
        Assumptions(parent.Assumptions), DT(parent.DT), seen(parent.seen) {
    seen.insert(lhs);
  }

  // Child context: everything `parent` has seen, plus `lhs` and `rhs`.
  ConstraintContext(const ConstraintContext &parent, InnerTy lhs, InnerTy rhs)
      : ConstraintContext(parent, lhs) {
    seen.insert(rhs);
  }

  // True when a constraint equivalent to `x` (per ConstraintComparator) has
  // already been recorded in this context.
  bool contains(InnerTy x) const { return seen.find(x) != seen.end(); }
};
bool cannotDependOnLoopIV(const SCEV *S, const Loop *L) {
assert(L);
if (isa<SCEVConstant>(S))
return true;
if (auto M = dyn_cast<SCEVAddExpr>(S)) {
for (auto o : M->operands())
if (!cannotDependOnLoopIV(o, L))
return false;
return true;
}
if (auto M = dyn_cast<SCEVMulExpr>(S)) {
for (auto o : M->operands())
if (!cannotDependOnLoopIV(o, L))
return false;
return true;
}
if (auto M = dyn_cast<SCEVUDivExpr>(S)) {
for (auto o : {M->getLHS(), M->getRHS()})
if (!cannotDependOnLoopIV(o, L))
return false;
return true;
}
if (auto UV = dyn_cast<SCEVUnknown>(S)) {
auto U = UV->getValue();
if (isa<Argument>(U))
return true;
if (isa<Constant>(U))
return true;
auto I = cast<Instruction>(U);
return !L->contains(I->getParent());
}
if (auto addrec = dyn_cast<SCEVAddRecExpr>(S)) {
if (addrec->getLoop() == L)
return false;
for (auto o : addrec->operands())
if (!cannotDependOnLoopIV(o, L))
return false;
return true;
}
if (auto expr = dyn_cast<SCEVSignExtendExpr>(S)) {
return cannotDependOnLoopIV(expr->getOperand(), L);
}
llvm::errs() << " cannot tell if depends on loop iv: " << *S << "\n";
return false;
}
// Rewrites `V` so that the induction variable of loop `find` is replaced by
// the iteration value `replace`.
//
// Returns `V` unchanged when it cannot depend on that induction variable, the
// substituted expression when `V` is an add-recurrence of `find` or an
// unsigned division whose operands can both be rewritten, and nullptr when no
// rewrite is known.
const SCEV *evaluateAtLoopIter(const SCEV *V, ScalarEvolution &SE,
                               const Loop *find, const SCEV *replace) {
  assert(find);

  // Nothing to substitute.
  if (cannotDependOnLoopIV(V, find))
    return V;

  // Recurrence over the loop of interest: evaluate it at `replace`, then keep
  // rewriting whatever that produced.
  if (auto rec = dyn_cast<SCEVAddRecExpr>(V); rec && rec->getLoop() == find)
    return evaluateAtLoopIter(rec->evaluateAtIteration(replace, SE), SE, find,
                              replace);

  // The only other shape handled is an unsigned division.
  auto quotient = dyn_cast<SCEVUDivExpr>(V);
  if (!quotient)
    return nullptr;

  const SCEV *numerator =
      evaluateAtLoopIter(quotient->getLHS(), SE, find, replace);
  if (!numerator)
    return nullptr;

  const SCEV *denominator =
      evaluateAtLoopIter(quotient->getRHS(), SE, find, replace);
  if (!denominator)
    return nullptr;

  return SE.getUDivExpr(numerator, denominator);
}
// Immutable boolean formula over comparisons involving loop induction
// variables.
//
// A node is one of:
//   * All / None  - the identity and absorbing elements of andB respectively
//                   (andB with All returns the other operand, andB with None
//                   returns None; notB maps one to the other),
//   * Compare     - a single (in)equality described by `node`, `isEqual` and
//                   `Loop`,
//   * Union       - the "or" of the constraints in `values`,
//   * Intersect   - the "and" of the constraints in `values`.
//
// Instances are shared through std::shared_ptr<const Constraints>; the
// combinators (notB / orB / andB) build new nodes rather than mutating
// existing ones. orB and andB return nullptr when the recursion guard held in
// the ConstraintContext trips, so callers must check their result.
class Constraints : public std::enable_shared_from_this<Constraints> {
public:
  const enum class Type {
    Union = 0,
    Intersect = 1,
    Compare = 2,
    All = 3,
    None = 4
  } ty;

  using InnerTy = std::shared_ptr<const Constraints>;
  using SetTy = std::set<InnerTy, ConstraintComparator>;

  // Children of a Union / Intersect node; empty for every other kind.
  const SetTy values;

  // Expression compared against for a Compare node; nullptr otherwise.
  const SCEV *const node;
  // whether equal to the node, or not equal to the node
  bool isEqual;
  // the loop of the iv comparing against.
  // When null, the comparison is of `node` against zero (see allSolutions).
  const llvm::Loop *const Loop;

  // Default construction yields an empty Union.
  Constraints()
      : ty(Type::Union), values(), node(nullptr), isEqual(false),
        Loop(nullptr) {}

private:
  // Compare-node constructor. The trailing unnamed bool only distinguishes
  // this overload; Compare nodes are meant to be created via make_compare.
  Constraints(const SCEV *v, bool isEqual, const llvm::Loop *Loop, bool)
      : ty(Type::Compare), values(), node(v), isEqual(isEqual), Loop(Loop) {}

public:
  // Factory for Compare nodes (defined out of line).
  static InnerTy make_compare(const SCEV *v, bool isEqual,
                              const llvm::Loop *Loop,
                              const ConstraintContext &ctx);

  // Constructs the All or None singleton kinds; any other kind asserts.
  Constraints(Type t)
      : ty(t), values(), node(nullptr), isEqual(false), Loop(nullptr) {
    assert(t == Type::All || t == Type::None);
  }

  // Constructs a Union or Intersect over `c`.
  //
  // In assert-enabled builds this verifies that `c` has at least two pairwise
  // distinct elements, that no child has the same kind as the parent, and -
  // when `check` is set - that an Intersect does not pair an equality on some
  // loop's IV with a Compare whose node is an add-recurrence of that same loop.
  // Pass `check = false` for throw-away nodes that are only built to be
  // compared against a simplified result.
  Constraints(Type t, const SetTy &c, bool check = true)
      : ty(t), values(c), node(nullptr), isEqual(false), Loop(nullptr) {
    assert(t != Type::All);
    assert(t != Type::None);
    assert(c.size() != 0);
    assert(c.size() != 1);
#ifndef NDEBUG
    SmallVector<InnerTy, 1> tmp(c.begin(), c.end());
    for (unsigned i = 0; i < tmp.size(); i++)
      for (unsigned j = 0; j < i; j++)
        assert(*tmp[i] != *tmp[j]);
    if (t == Type::Intersect) {
      for (auto &v : c) {
        assert(v->ty != Type::Intersect);
      }
    }
    if (t == Type::Union) {
      for (auto &v : c) {
        assert(v->ty != Type::Union);
      }
    }
    if (t == Type::Intersect && check) {
      for (unsigned i = 0; i < tmp.size(); i++)
        if (tmp[i]->ty == Type::Compare && tmp[i]->isEqual && tmp[i]->Loop)
          for (unsigned j = 0; j < tmp.size(); j++)
            if (tmp[j]->ty == Type::Compare)
              if (auto s = dyn_cast<SCEVAddRecExpr>(tmp[j]->node))
                assert(s->getLoop() != tmp[i]->Loop);
    }
#endif
  }

  // Deep structural equality: same kind, same compare payload, and pairwise
  // equal children (both sets iterate in comparator order).
  bool operator==(const Constraints &rhs) const {
    if (ty != rhs.ty) {
      return false;
    }
    if (node != rhs.node) {
      return false;
    }
    if (isEqual != rhs.isEqual) {
      return false;
    }
    if (Loop != rhs.Loop) {
      return false;
    }
    if (values.size() != rhs.values.size()) {
      return false;
    }
    for (auto pair : llvm::zip(values, rhs.values)) {
      if (*std::get<0>(pair) != *std::get<1>(pair))
        return false;
    }
    return true;
  }

  bool operator>(const Constraints &rhs) const { return rhs < *this; }

  // Lexicographic order over (ty, node, isEqual, Loop, #children, children).
  // This is the order ConstraintComparator exposes to std::set.
  bool operator<(const Constraints &rhs) const {
    if (ty < rhs.ty) {
      return true;
    }
    if (ty > rhs.ty) {
      return false;
    }
    if (node < rhs.node) {
      return true;
    }
    if (node > rhs.node) {
      return false;
    }
    if (isEqual < rhs.isEqual) {
      return true;
    }
    if (isEqual > rhs.isEqual) {
      return false;
    }
    if (Loop < rhs.Loop) {
      return true;
    }
    if (Loop > rhs.Loop) {
      return false;
    }
    if (values.size() < rhs.values.size()) {
      return true;
    }
    if (values.size() > rhs.values.size()) {
      return false;
    }
    for (auto pair : llvm::zip(values, rhs.values)) {
      if (*std::get<0>(pair) < *std::get<1>(pair))
        return true;
      if (*std::get<0>(pair) > *std::get<1>(pair))
        return false;
    }
    return false;
  }

  // Structural hash combining the kind, compare payload and child hashes.
  unsigned hash() const {
    unsigned res = 5 * (unsigned)ty +
                   DenseMapInfo<const SCEV *>::getHashValue(node) + isEqual;
    res = llvm::detail::combineHashValue(res, (unsigned)(size_t)Loop);
    for (auto v : values)
      res = llvm::detail::combineHashValue(res, v->hash());
    return res;
  }

  bool operator!=(const Constraints &rhs) const { return !(*this == rhs); }

  // Shared singleton for the All kind.
  static InnerTy all() {
    static auto allv = std::make_shared<Constraints>(Type::All);
    return allv;
  }

  // Shared singleton for the None kind.
  static InnerTy none() {
    static auto nonev = std::make_shared<Constraints>(Type::None);
    return nonev;
  }

  bool isNone() const { return ty == Type::None; }
  bool isAll() const { return ty == Type::All; }

  // Adds `ty` to `set`. In assert-enabled builds, additionally verifies that
  // exactly one element of the set is structurally equal to it; the scan is
  // compiled out of release builds because it only feeds the assertion.
  static void insert(SetTy &set, InnerTy ty) {
    set.insert(ty);
#ifndef NDEBUG
    int mcount = 0;
    for (auto &v : set)
      if (*v == *ty)
        mcount++;
    assert(mcount == 1);
#endif
  }

  // Elements of `lhs` that also occur in `rhs`.
  static SetTy intersect(const SetTy &lhs, const SetTy &rhs) {
    SetTy res;
    for (auto &v : lhs)
      if (rhs.count(v))
        res.insert(v);
    return res;
  }

  // Removes every element of `rhs` from `set`. Erasing by key is a no-op for
  // absent elements, so no membership test is needed first.
  static void set_subtract(SetTy &set, const SetTy &rhs) {
    for (auto &v : rhs)
      set.erase(v);
  }

  // Prints this constraint to llvm::errs(); kept out of line via noinline.
  __attribute__((noinline)) void dump() const { llvm::errs() << *this << "\n"; }

  // Logical negation, via De Morgan for Union / Intersect.
  InnerTy notB(const ConstraintContext &ctx) const {
    switch (ty) {
    case Type::None:
      return Constraints::all();
    case Type::All:
      return Constraints::none();
    case Type::Compare:
      return make_compare(node, !isEqual, Loop, ctx);
    case Type::Union: {
      // not of or's is and of not's
      SetTy next;
      for (const auto &v : values)
        insert(next, v->notB(ctx));
      // A single survivor violates the constructor's size precondition; log
      // the offending input before construction.
      if (next.size() == 1)
        llvm::errs() << " uold : " << *this << "\n";
      return std::make_shared<Constraints>(Type::Intersect, next);
    }
    case Type::Intersect: {
      // not of and's is or of not's
      SetTy next;
      for (const auto &v : values)
        insert(next, v->notB(ctx));
      if (next.size() == 1)
        llvm::errs() << " old : " << *this << "\n";
      return std::make_shared<Constraints>(Type::Union, next);
    }
    }
    return Constraints::none();
  }

  // Logical or, computed as not(not(this) and not(rhs)). `rhs` must be
  // non-null. Returns nullptr when the underlying andB gives up.
  InnerTy orB(InnerTy rhs, const ConstraintContext &ctx) const {
    assert(rhs);
    auto notLHS = notB(ctx);
    if (!notLHS)
      return nullptr;
    auto notRHS = rhs->notB(ctx);
    if (!notRHS)
      return nullptr;
    auto andV = notLHS->andB(notRHS, ctx);
    if (!andV)
      return nullptr;
    auto res = andV->notB(ctx);
    return res;
  }

  // Logical and of this constraint with `rhs` (which must be non-null),
  // simplified where a rule below applies.
  //
  // Returns nullptr when either operand is already recorded in `ctx` (the
  // recursion guard) or when a nested simplification hit that guard and the
  // branch has no fallback.
  InnerTy andB(const InnerTy rhs, const ConstraintContext &ctx) const {
    assert(rhs);
    if (*rhs == *this)
      return shared_from_this();
    if (rhs->isNone())
      return rhs;
    if (rhs->isAll())
      return shared_from_this();
    if (isNone())
      return shared_from_this();
    if (isAll())
      return rhs;

    // llvm::errs() << " anding: " << *this << " with " << *rhs << "\n";
    if (ctx.contains(shared_from_this()) || ctx.contains(rhs)) {
      // llvm::errs() << " %%% stopping recursion\n";
      return nullptr;
    }

    if (ty == Type::Compare && rhs->ty == Type::Compare) {
      auto sub = ctx.SE.getMinusSCEV(node, rhs->node);
      if (Loop == rhs->Loop) {
        // llvm::errs() << " + sameloop, sub=" << *sub << "\n";
        if (auto cst = dyn_cast<SCEVConstant>(sub)) {
          // the two solves are equivalent to each other
          if (cst->getValue()->isZero()) {
            // iv = a and iv = a
            // also iv != a and iv != a
            if (isEqual == rhs->isEqual)
              return shared_from_this();
            else {
              // iv = a and iv != a
              return Constraints::none();
            }
          } else {
            // the two solves are guaranteed to be distinct
            // iv == 0 and iv == 1
            if (isEqual && rhs->isEqual) {
              return Constraints::none();
            } else if (!isEqual && !rhs->isEqual) {
              // iv != 0 and iv != 1
              SetTy vals;
              insert(vals, shared_from_this());
              insert(vals, rhs);
              return std::make_shared<Constraints>(Type::Intersect, vals);
            } else if (!isEqual) {
              assert(rhs->isEqual);
              // iv != 0 and iv == 1
              return rhs;
            } else {
              // iv == 0 and iv != 1
              assert(isEqual);
              assert(!rhs->isEqual);
              return shared_from_this();
            }
          }
        } else if (isEqual || rhs->isEqual) {
          // llvm::errs() << " + botheq\n";
          // eq(i, a) & i ?= b -> eq(i, a) & (a ?= b)
          if (auto addrec = dyn_cast<SCEVAddRecExpr>(sub)) {
            // we want a ?= b, but we can only represent loopvar ?= something
            // so suppose a-b is of the form X + Y * lv then a-b ?= 0 is
            // X + Y * lv ?= 0 -> lv ?= - X / Y
            if (addrec->isAffine()) {
              auto X = addrec->getStart();
              auto Y = addrec->getStepRecurrence(ctx.SE);
              auto MinusX = X;
              if (isa<SCEVConstant>(Y) &&
                  cast<SCEVConstant>(Y)->getAPInt().isNegative())
                Y = ctx.SE.getNegativeSCEV(Y);
              else
                MinusX = ctx.SE.getNegativeSCEV(X);
              auto div = ctx.SE.getUDivExpr(MinusX, Y);
              auto div_e = ctx.SE.getUDivExactExpr(MinusX, Y);
              // in case of inexact division, check that these exactly equal
              // for replacement
              if (div == div_e) {
                if (isEqual) {
                  auto res = make_compare(div, /*isEqual*/ rhs->isEqual,
                                          addrec->getLoop(), ctx);
                  // llvm::errs() << " simplified rhs to: " << *res << "\n";
                  return andB(res, ctx);
                } else {
                  assert(rhs->isEqual);
                  auto res = make_compare(div, /*isEqual*/ isEqual,
                                          addrec->getLoop(), ctx);
                  // llvm::errs() << " simplified lhs to: " << *res << "\n";
                  return rhs->andB(res, ctx);
                }
              }
            }
          }
          if (isEqual && rhs->Loop &&
              cannotDependOnLoopIV(sub, ctx.loopToSolve)) {
            auto res = make_compare(sub, /*isEqual*/ rhs->isEqual,
                                    /*loop*/ nullptr, ctx);
            // llvm::errs() << " simplified(noloop) rhs from " << *rhs
            //              << " to: " << *res << "\n";
            return andB(res, ctx);
          }
          if (rhs->isEqual && Loop &&
              cannotDependOnLoopIV(sub, ctx.loopToSolve)) {
            auto res =
                make_compare(sub, /*isEqual*/ isEqual, /*loop*/ nullptr, ctx);
            // llvm::errs() << " simplified(noloop) lhs from " << *rhs
            //              << " to: " << *res << "\n";
            return rhs->andB(res, ctx);
          }
          llvm::errs() << " warning: potential but unhandled simplification of "
                          "equalities: "
                       << *this << " and " << *rhs << " sub: " << *sub << "\n";
        }
      }

      if (isEqual) {
        // Substitute this equality into the other comparison.
        if (Loop)
          if (auto rep = evaluateAtLoopIter(rhs->node, ctx.SE, Loop, node))
            if (rep != rhs->node) {
              auto newrhs = make_compare(rep, rhs->isEqual, rhs->Loop, ctx);
              return andB(newrhs, ctx);
            }
        // not loop -> node == 0
        if (!Loop) {
          for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node),
                            ctx.SE.getMinusSCEV(rhs->node, node)}) {
            // llvm::errs() << " maybe replace lhs: " << *this << " rhs: " <<
            // *rhs
            //              << " sub1: " << *sub1 << "\n";
            auto newrhs = make_compare(sub1, rhs->isEqual, rhs->Loop, ctx);
            if (*newrhs == *this)
              return shared_from_this();
            if (!isa<SCEVConstant>(rhs->node) && isa<SCEVConstant>(sub1)) {
              return andB(newrhs, ctx);
            }
          }
        }
      }

      if (rhs->isEqual) {
        // Mirror image of the block above with the operands swapped.
        if (rhs->Loop)
          if (auto rep = evaluateAtLoopIter(node, ctx.SE, rhs->Loop, rhs->node))
            if (rep != node) {
              auto newlhs = make_compare(rep, isEqual, Loop, ctx);
              return newlhs->andB(rhs, ctx);
            }
        // not loop -> node == 0
        if (!rhs->Loop) {
          for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node),
                            ctx.SE.getMinusSCEV(rhs->node, node)}) {
            // llvm::errs() << " maybe replace lhs2: " << *this << " rhs: " <<
            // *rhs
            //              << " sub1: " << *sub1 << "\n";
            auto newlhs = make_compare(sub1, isEqual, Loop, ctx);
            if (*newlhs == *this)
              return shared_from_this();
            if (!isa<SCEVConstant>(node) && isa<SCEVConstant>(sub1)) {
              return newlhs->andB(rhs, ctx);
            }
          }
        }
      }

      // x ?= 0 and -x ?= 0 are the same constraint.
      if (!Loop && !rhs->Loop && isEqual == rhs->isEqual) {
        if (node == ctx.SE.getNegativeSCEV(rhs->node))
          return shared_from_this();
      }

      SetTy vals;
      insert(vals, shared_from_this());
      insert(vals, rhs);
      if (vals.size() == 1) {
        llvm::errs() << "this: " << *this << " rhs: " << *rhs << "\n";
      }
      auto res = std::make_shared<Constraints>(Type::Intersect, vals);
      // llvm::errs() << " naiive comp merge: " << *res << "\n";
      return res;
    }

    if (ty == Type::Intersect && rhs->ty == Type::Intersect) {
      // Fold the members of rhs into this one at a time.
      auto tmp = shared_from_this();
      for (const auto &v : rhs->values) {
        auto tmp2 = tmp->andB(v, ctx);
        if (!tmp2)
          return nullptr;
        tmp = std::move(tmp2);
      }
      return tmp;
    }

    if (ty == Type::Intersect && rhs->ty == Type::Compare) {
      SetTy vals;
      // Force internal merging to do individual compares
      bool foldedIn = false;
      for (auto en : llvm::enumerate(values)) {
        auto i = en.index();
        auto v = en.value();
        assert(v->ty != Type::Intersect);
        assert(v->ty != Type::All);
        assert(v->ty != Type::None);
        assert(v->ty == Type::Compare || v->ty == Type::Union);
        if (foldedIn) {
          insert(vals, v);
          continue;
        }
        // this is either a compare or a union
        auto tmp = rhs->andB(v, ctx);
        if (!tmp)
          return nullptr;
        switch (tmp->ty) {
        case Type::Union:
        case Type::All:
          llvm_unreachable("Impossible");
        case Type::None:
          return Constraints::none();
        case Type::Compare:
          insert(vals, tmp);
          foldedIn = true;
          break;
        // if intersected, these two were not foldable, try folding into later
        case Type::Intersect: {
          SetTy fuse;
          insert(fuse, rhs);
          insert(fuse, v);
          Constraints trivialFuse(Type::Intersect, fuse, false);
          // If this is not just making an intersect of the two operands,
          // remerge.
          if (trivialFuse != *tmp) {
            InnerTy newlhs = Constraints::all();
            bool legal = true;
            for (auto en2 : llvm::enumerate(values)) {
              auto i2 = en2.index();
              auto v2 = en2.value();
              if (i2 == i)
                continue;
              auto newlhs2 = newlhs->andB(v2, ctx);
              if (!newlhs2) {
                legal = false;
                break;
              }
              newlhs = std::move(newlhs2);
            }
            if (legal) {
              return newlhs->andB(tmp, ctx);
            }
          }
          insert(vals, v);
        }
        }
      }
      if (!foldedIn) {
        insert(vals, rhs);
        return std::make_shared<Constraints>(Type::Intersect, vals);
      } else {
        auto cur = Constraints::all();
        for (auto &iv : vals) {
          auto cur2 = cur->andB(iv, ctx);
          if (!cur2)
            return nullptr;
          cur = std::move(cur2);
        }
        return cur;
      }
    }

    if ((ty == Type::Intersect || ty == Type::Compare) &&
        rhs->ty == Type::Union) {
      SetTy unionVals = rhs->values;
      bool changed = false;
      SetTy ivVals;
      if (ty == Type::Intersect)
        ivVals = values;
      else
        insert(ivVals, shared_from_this());

      // Nested simplifications must not revisit either operand.
      ConstraintContext ctxd(ctx, shared_from_this(), rhs);

      for (const auto &iv : ivVals) {
        SetTy nextunionVals;
        bool midchanged = false;
        for (auto &uv : unionVals) {
          auto tmp = iv->andB(uv, ctxd);
          if (!tmp) {
            midchanged = false;
            nextunionVals = unionVals;
            break;
          }
          switch (tmp->ty) {
          case Type::None:
          case Type::Compare:
          case Type::Union:
            insert(nextunionVals, tmp);
            changed |= tmp != uv;
            break;
          case Type::Intersect: {
            SetTy fuse;
            if (uv->ty == Type::Intersect)
              fuse = uv->values;
            else {
              assert(uv->ty == Type::Compare);
              insert(fuse, uv);
            }
            insert(fuse, iv);
            Constraints trivialFuse(Type::Intersect, fuse, false);
            if (trivialFuse != *tmp) {
              insert(nextunionVals, tmp);
              midchanged = true;
              break;
            }
            insert(nextunionVals, uv);
            break;
          }
          case Type::All:
            llvm_unreachable("Impossible");
          }
        }
        if (midchanged) {
          unionVals = nextunionVals;
          changed = true;
        }
      }

      if (changed) {
        auto cur = Constraints::none();
        for (auto uv : unionVals) {
          cur = cur->orB(uv, ctxd);
          if (!cur)
            break;
        }
        // orB yields nullptr when the recursion guard trips. In that case the
        // rebuilt union is unusable, so fall through to the plain intersection
        // below instead of dereferencing a null pointer.
        if (cur && *cur != *rhs)
          return andB(cur, ctx);
      }

      SetTy vals = ivVals;
      insert(vals, rhs);
      return std::make_shared<Constraints>(Type::Intersect, vals);
    }

    // Handled above via symmetry
    if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) {
      return rhs->andB(shared_from_this(), ctx);
    }

    // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d)
    // and (c or e))
    if (ty == Type::Union && rhs->ty == Type::Union) {
      if (*this == *rhs->notB(ctx)) {
        return Constraints::none();
      }
      SetTy intersection = intersect(values, rhs->values);
      if (intersection.size() != 0) {
        InnerTy other_lhs = remove(intersection);
        InnerTy other_rhs = rhs->remove(intersection);
        InnerTy remainder;
        if (intersection.size() == 1)
          remainder = *intersection.begin();
        else {
          remainder = std::make_shared<Constraints>(Type::Union, intersection);
        }
        // andB may give up (nullptr); orB requires a non-null operand, so
        // propagate the failure rather than passing the null along.
        auto distributed = other_lhs->andB(other_rhs, ctx);
        if (!distributed)
          return nullptr;
        return remainder->orB(distributed, ctx);
      }

      bool changed = false;
      SetTy lhsVals = values;
      SetTy rhsVals = rhs->values;
      ConstraintContext ctxd(ctx, shared_from_this(), rhs);
      SetTy distributedVals;
      for (const auto &l1 : lhsVals) {
        bool subchanged = false;
        SetTy subDistributedVals;
        for (auto &r1 : rhsVals) {
          auto tmp = l1->andB(r1, ctxd);
          if (!tmp) {
            subchanged = false;
            break;
          }
          if (l1->ty == Type::Intersect || r1->ty == Type::Intersect) {
            subchanged = true;
            insert(subDistributedVals, tmp);
          } else {
            SetTy fuse;
            insert(fuse, l1);
            insert(fuse, r1);
            assert(fuse.size() == 2);
            // NOTE(review): the other throw-away "trivialFuse" nodes in this
            // function pass check=false; confirm whether the stricter
            // debug-only check is intended here.
            Constraints trivialFuse(Type::Intersect, fuse);
            if ((trivialFuse != *tmp) || distributedVals.count(tmp)) {
              subchanged = true;
            }
            insert(subDistributedVals, tmp);
          }
        }
        if (subchanged) {
          for (auto sub : subDistributedVals)
            insert(distributedVals, sub);
          changed = true;
        } else {
          auto midand = l1->andB(rhs, ctxd);
          if (!midand) {
            changed = false;
            break;
          }
          insert(distributedVals, midand);
        }
      }

      if (changed) {
        auto cur = Constraints::none();
        bool legal = true;
        for (auto &uv : distributedVals) {
          auto cur2 = cur->orB(uv, ctxd);
          if (!cur2) {
            legal = false;
            break;
          }
          cur = std::move(cur2);
        }
        if (legal) {
          return cur;
        }
      }

      SetTy vals;
      insert(vals, shared_from_this());
      insert(vals, rhs);
      auto res = std::make_shared<Constraints>(Type::Intersect, vals);
      return res;
    }

    llvm::errs() << " andB this: " << *this << " rhs: " << *rhs << "\n";
    llvm_unreachable("Illegal predicate state");
  }

  // what this would be like when removing the following list of constraints
  // An emptied Union collapses to None and an emptied Intersect to All; a
  // single remaining child is returned directly.
  InnerTy remove(const SetTy &sub) const {
    assert(ty == Type::Union || ty == Type::Intersect);
    SetTy res = values;
    set_subtract(res, sub);
    if (res.size() == 0) {
      if (ty == Type::Union)
        return Constraints::none();
      else
        return Constraints::all();
    } else if (res.size() == 1) {
      return *res.begin();
    } else {
      return std::make_shared<Constraints>(ty, res);
    }
  }

  // Materializes the solutions of this constraint as (value, condition)
  // pairs; defined out of line below the class.
  SmallVector<std::pair<Value *, Value *>, 1>
  allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
               const ConstraintContext &ctx, IRBuilder<> &B) const;
};
// Free-function wrappers around Constraints::dump(), which prints the
// constraint to llvm::errs().
void dump(const Constraints &c) {
  c.dump();
}
// Same as above for a shared pointer; `c` must be non-null.
void dump(std::shared_ptr<const Constraints> c) {
  const Constraints &pointee = *c;
  pointee.dump();
}
// Strict weak ordering used by the constraint sets: orders two shared
// constraints by the structural Constraints::operator< of the objects they
// point to (not by pointer identity). Both pointers must be non-null.
bool ConstraintComparator::operator()(
    std::shared_ptr<const Constraints> lhs,
    std::shared_ptr<const Constraints> rhs) const {
  const Constraints &left = *lhs;
  const Constraints &right = *rhs;
  return left < right;
}
// Pretty-prints a constraint tree:
//   All / None            -> "All" / "None"
//   Union / Intersect     -> "(Union a, b, )" / "(Intersect a, b, )"
//   Compare               -> "(eq <scev>, L=<header>)" or "(ne ...)", with
//                            "nullptr" in place of the header name when the
//                            comparison has no loop.
raw_ostream &operator<<(raw_ostream &os, const Constraints &c) {
  using Kind = Constraints::Type;

  // Shared printer for the two n-ary kinds; every child is followed by ", ".
  auto printChildren = [&os, &c](const char *opening) -> raw_ostream & {
    os << opening;
    for (auto child : c.values)
      os << *child << ", ";
    os << ")";
    return os;
  };

  switch (c.ty) {
  case Kind::All:
    return os << "All";
  case Kind::None:
    return os << "None";
  case Kind::Union:
    return printChildren("(Union ");
  case Kind::Intersect:
    return printChildren("(Intersect ");
  case Kind::Compare: {
    os << (c.isEqual ? "(eq " : "(ne ");
    os << *c.node << ", L=";
    if (c.Loop)
      os << c.Loop->getHeader()->getName();
    else
      os << "nullptr";
    return os << ")";
  }
  }
  return os;
}
// Materializes the solutions of this constraint as IR.
//
// Each returned pair is (solution, condition):
//   * `solution` is the value expanded from a Compare's SCEV at `IP` (with
//     type `T`) when the comparison is an equality on the loop being solved
//     (ctx.loopToSolve); it is nullptr for comparisons on any other loop or on
//     no loop, which only contribute a condition.
//   * `condition` is an i1 value built with `B` that guards the solution.
//
// None yields no solutions and a Union concatenates the solutions of its
// members. An Intersect containing a Union is first distributed over that
// union; otherwise its members must each yield exactly one pair, at most one
// of which may carry a solution value, and their conditions are and-ed.
// Unsupported shapes are reported through EmitFailure / llvm_unreachable.
SmallVector<std::pair<Value *, Value *>, 1>
Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
                          const ConstraintContext &ctx, IRBuilder<> &B) const {
  switch (ty) {
  case Type::None:
    return {};
  case Type::All:
    llvm::errs() << *this << "\n";
    llvm_unreachable("All not handled");
  case Type::Compare: {
    Value *cond = ConstantInt::getTrue(T->getContext());
    if (ctx.loopToSolve != Loop) {
      // Comparison on a different loop (or on no loop at all): it cannot
      // produce a solution for the loop being solved, only a guard.
      assert(ctx.loopToSolve);
      Value *ivVal = Exp.expandCodeFor(node, T, IP);
      Value *iv = nullptr;
      if (Loop) {
        iv = Loop->getCanonicalInductionVariable();
        assert(iv);
      } else {
        iv = ConstantInt::getNullValue(ivVal->getType());
      }
      if (isEqual)
        cond = B.CreateICmpEQ(ivVal, iv);
      else
        cond = B.CreateICmpNE(ivVal, iv);
      return {std::make_pair((Value *)nullptr, cond)};
    }
    if (isEqual) {
      return {std::make_pair(Exp.expandCodeFor(node, T, IP), cond)};
    }
    EmitFailure("NoSparsification", IP->getDebugLoc(), IP,
                "Negated solution not handled: ", *this);
    assert(0);
    return {};
  }
  case Type::Union: {
    SmallVector<std::pair<Value *, Value *>, 1> vals;
    for (auto v : values)
      for (auto sol : v->allSolutions(Exp, T, IP, ctx, B))
        vals.push_back(sol);
    return vals;
  }
  case Type::Intersect: {
    {
      SmallVector<InnerTy, 1> vals(values.begin(), values.end());
      // Pick the union to distribute over: the first one made up solely of
      // "not equal" comparisons if there is one, otherwise the last union.
      std::optional<unsigned> unionidx;
      for (unsigned i = 0; i < vals.size(); i++) {
        if (vals[i]->ty == Type::Union) {
          unionidx = i;
          bool allne = true;
          for (auto &v : vals[i]->values) {
            if (v->ty != Type::Compare || v->isEqual) {
              allne = false;
              break;
            }
          }
          if (allne)
            break;
        }
      }
      if (unionidx) {
        // andB returns nullptr when it gives up on a simplification; report
        // that instead of dereferencing the null result.
        auto others = Constraints::all();
        for (unsigned j = 0; j < vals.size(); j++) {
          if (*unionidx == j)
            continue;
          others = others->andB(vals[j], ctx);
          if (!others) {
            EmitFailure("NoSparsification", IP->getDebugLoc(), IP,
                        "Could not simplify constraints: ", *this);
            assert(0);
            return {};
          }
        }
        SmallVector<std::pair<Value *, Value *>, 1> resvals;
        for (auto &v : vals[*unionidx]->values) {
          auto tmp = v->andB(others, ctx);
          if (!tmp) {
            EmitFailure("NoSparsification", IP->getDebugLoc(), IP,
                        "Could not simplify constraints: ", *this);
            assert(0);
            return {};
          }
          for (const auto &sol : tmp->allSolutions(Exp, T, IP, ctx, B))
            resvals.push_back(sol);
        }
        return resvals;
      }
    }
    Value *solVal = nullptr;
    Value *cond = ConstantInt::getTrue(T->getContext());
    for (auto v : values) {
      auto sols = v->allSolutions(Exp, T, IP, ctx, B);
      if (sols.size() != 1) {
        llvm::errs() << *this << "\n";
        for (auto s : sols)
          if (s.first)
            llvm::errs() << " + sol: " << *s.first << " " << *s.second << "\n";
          else
            llvm::errs() << " + sol: " << s.first << " " << *s.second << "\n";
        llvm::errs() << " v: " << *v << " this: " << *this << "\n";
        llvm_unreachable("Intersect not handled (solsize>1)");
      }
      auto sol = sols[0];
      if (sol.first) {
        // Only one member of the intersection may pin down the solution.
        if (solVal != nullptr) {
          llvm::errs() << *this << "\n";
          llvm::errs() << " prevsolVal: " << *solVal << "\n";
          llvm_unreachable("Intersect not handled (prevsolval)");
        }
        assert(solVal == nullptr);
        solVal = sol.first;
      }
      cond = B.CreateAnd(cond, sol.second);
    }
    return {std::make_pair(solVal, cond)};
  }
  }
  return {};
}
// Compile-time switch for the verbose `getSparse(...)` tracing that
// getSparseConditions writes to llvm::errs(); set to true while debugging.
constexpr bool SparseDebug = false;
/// Translate the i1 value \p val into a symbolic Constraints set.
///
/// Handled shapes:
///  * `and` / `or`  -> andB / orB of the recursively parsed operands,
///  * `xor x, 1`    -> notB of the recursively parsed operand,
///  * `icmp eq/ne`  -> a Compare constraint, either solved for the affine
///                     add-recurrence of the difference of both sides, or a
///                     loop-independent comparison when the difference cannot
///                     depend on the induction variable of `ctx.loopToSolve`,
///  * `fcmp`        -> \p defaultFloat.
///
/// \param legal        set to false when \p val has none of the shapes above.
/// \param val          the condition to parse.
/// \param defaultFloat constraint returned for floating-point comparisons and
///                     for integer comparisons that cannot be solved.
/// \param scope        instruction diagnostics are attached to; when nullptr
///                     no warning/failure is emitted.
/// \param ctx          scalar-evolution / loop context used for solving.
/// \returns the parsed constraint set (never null for the handled shapes).
std::shared_ptr<const Constraints>
getSparseConditions(bool &legal, Value *val,
                    std::shared_ptr<const Constraints> defaultFloat,
                    Instruction *scope, const ConstraintContext &ctx) {
  if (auto I = dyn_cast<Instruction>(val)) {
    // Binary `and` is a bit-wise `umin`.
    if (I->getOpcode() == Instruction::And) {
      auto lhs = getSparseConditions(legal, I->getOperand(0),
                                     Constraints::all(), I, ctx);
      auto rhs = getSparseConditions(legal, I->getOperand(1),
                                     Constraints::all(), I, ctx);
      auto res = lhs->andB(rhs, ctx);
      assert(res);
      assert(ctx.seen.size() == 0);
      if (SparseDebug) {
        llvm::errs() << " getSparse(and, " << *I << "), lhs("
                     << *I->getOperand(0) << ") = " << *lhs << "\n";
        llvm::errs() << " getSparse(and, " << *I << "), rhs("
                     << *I->getOperand(1) << ") = " << *rhs << "\n";
        llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n";
      }
      return res;
    }

    // Binary `or` is a bit-wise `umax`.
    if (I->getOpcode() == Instruction::Or) {
      auto lhs = getSparseConditions(legal, I->getOperand(0),
                                     Constraints::none(), I, ctx);
      auto rhs = getSparseConditions(legal, I->getOperand(1),
                                     Constraints::none(), I, ctx);
      auto res = lhs->orB(rhs, ctx);
      if (SparseDebug) {
        llvm::errs() << " getSparse(or, " << *I << "), lhs("
                     << *I->getOperand(0) << ") = " << *lhs << "\n";
        llvm::errs() << " getSparse(or, " << *I << "), rhs("
                     << *I->getOperand(1) << ") = " << *rhs << "\n";
        llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n";
      }
      return res;
    }

    // `xor x, 1` is a logical not: parse `x` with the negated default and
    // negate the result.
    if (I->getOpcode() == Instruction::Xor) {
      for (int i = 0; i < 2; i++) {
        if (auto C = dyn_cast<ConstantInt>(I->getOperand(i)))
          if (C->isOne()) {
            auto pres =
                getSparseConditions(legal, I->getOperand(1 - i),
                                    defaultFloat->notB(ctx), scope, ctx);
            auto res = pres->notB(ctx);
            if (SparseDebug) {
              // Print the operand that was actually parsed, i.e. the
              // non-constant one.
              llvm::errs() << " getSparse(not, " << *I << "), prev ("
                           << *I->getOperand(1 - i) << ") = " << *pres << "\n";
              llvm::errs() << " getSparse(not, " << *I << ") = " << *res
                           << "\n";
            }
            return res;
          }
      }
    }

    if (auto icmp = dyn_cast<ICmpInst>(I)) {
      auto L = ctx.loopToSolve;
      auto lhs = ctx.SE.getSCEVAtScope(icmp->getOperand(0), L);
      auto rhs = ctx.SE.getSCEVAtScope(icmp->getOperand(1), L);
      if (SparseDebug) {
        llvm::errs() << " lhs: " << *lhs << "\n";
        llvm::errs() << " rhs: " << *rhs << "\n";
      }
      auto sub1 = ctx.SE.getMinusSCEV(lhs, rhs);
      if (icmp->getPredicate() == ICmpInst::ICMP_EQ ||
          icmp->getPredicate() == ICmpInst::ICMP_NE) {
        if (auto add = dyn_cast<SCEVAddRecExpr>(sub1)) {
          if (add->isAffine()) {
            // 0 === A + B * inc -> -A / B = inc
            auto A = add->getStart();
            if (auto B =
                    dyn_cast<SCEVConstant>(add->getStepRecurrence(ctx.SE))) {
              // Keep the divisor non-negative: negate either the step or the
              // start, never both.
              auto MA = A;
              if (B->getAPInt().isNegative())
                B = cast<SCEVConstant>(ctx.SE.getNegativeSCEV(B));
              else
                MA = ctx.SE.getNegativeSCEV(A);
              auto div = ctx.SE.getUDivExpr(MA, B);
              auto div_e = ctx.SE.getUDivExactExpr(MA, B);
              // Only accept the solution when the plain and the exact
              // division produce the same SCEV.
              if (div == div_e) {
                auto res = Constraints::make_compare(
                    div, icmp->getPredicate() == ICmpInst::ICMP_EQ,
                    add->getLoop(), ctx);
                if (SparseDebug) {
                  llvm::errs()
                      << " getSparse(icmp, " << *I << ") = " << *res << "\n";
                }
                return res;
              }
            }
          }
        }
        if (cannotDependOnLoopIV(sub1, ctx.loopToSolve)) {
          auto res = Constraints::make_compare(
              sub1, icmp->getPredicate() == ICmpInst::ICMP_EQ, nullptr, ctx);
          // Tracing only; gated like every other trace in this function.
          if (SparseDebug) {
            llvm::errs() << " getSparse(icmp_noloop, " << *I << ") = " << *res
                         << "\n";
          }
          return res;
        }
      }
      if (scope)
        EmitWarning("NoSparsification", *I,
                    " No sparsification: not sparse solvable(icmp): ", *I,
                    " via ", *sub1);
      if (SparseDebug) {
        llvm::errs() << " getSparse(icmp_dflt, " << *I
                     << ") = " << *defaultFloat << "\n";
      }
      return defaultFloat;
    }

    // Floating-point comparisons are not solved symbolically; they resolve to
    // the default supplied by the caller.
    if (isa<FCmpInst>(I)) {
      if (SparseDebug) {
        llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *defaultFloat
                     << "\n";
      }
      return defaultFloat;
    }
  }

  // Anything else cannot be handled.
  if (scope) {
    EmitFailure("NoSparsification", scope->getDebugLoc(), scope,
                " No sparsification: not sparse solvable: ", *val);
  }
  legal = false;
  return defaultFloat;
}
/// Build the constraint `v == iv` / `v != iv` (depending on \p isEqual) for
/// the induction variable of \p Loop, simplifying it where possible.
///
/// * Without a loop, every `llvm.assume` in `ctx.Assumptions` that lives in
///   the function of `ctx.loopToSolve` and dominates its header is parsed; if
///   it is a loop-free comparison on \p v (or on its negation) the result
///   collapses to all()/none() depending on whether the polarity matches.
/// * With a loop, a negative constant \p v can never equal the induction
///   variable, so the result collapses to none() (eq) or all() (ne).
/// * Otherwise a fresh Compare constraint is allocated.
Constraints::InnerTy Constraints::make_compare(const SCEV *v, bool isEqual,
                                               const llvm::Loop *Loop,
                                               const ConstraintContext &ctx) {
  if (Loop) {
    // cannot have negative loop canonical induction var
    auto constant = dyn_cast<SCEVConstant>(v);
    if (constant && constant->getAPInt().isNegative())
      return isEqual ? Constraints::none() : Constraints::all();
  } else {
    assert(!isa<SCEVAddRecExpr>(v));

    // Parse assumptions with an empty assumption list so that parsing them
    // does not recurse back into this lookup.
    SmallVector<Instruction *, 1> emptyAssumptions;
    ConstraintContext bareCtx(ctx.SE, ctx.loopToSolve, emptyAssumptions,
                              ctx.DT);

    for (auto assume : ctx.Assumptions) {
      bool parsedOk = true;
      auto assumeFn = assume->getParent()->getParent();
      if (assumeFn != ctx.loopToSolve->getHeader()->getParent())
        continue;

      auto known = getSparseConditions(parsedOk, assume->getOperand(0),
                                       Constraints::none(), nullptr, bareCtx);
      bool dominatesLoop =
          ctx.DT.dominates(assume, ctx.loopToSolve->getHeader());
      if (!parsedOk || !dominatesLoop)
        continue;
      if (known->ty != Type::Compare || known->Loop)
        continue;

      bool sameExpr =
          known->node == v || known->node == ctx.SE.getNegativeSCEV(v);
      if (!sameExpr)
        continue;

      InnerTy folded;
      if (known->isEqual == isEqual)
        folded = Constraints::all();
      else
        folded = Constraints::none();
      return folded;
    }
  }

  return InnerTy(new Constraints(v, isEqual, Loop, false));
}
/// Rewrite loops that guard an `enzyme_sparse_accumulate` call so that the
/// accumulate is executed only for the statically solved index values instead
/// of iterating over the whole loop.
///
/// Steps performed, in order:
///  1. Run `fixSparse_inner` to a fixed point over every non-void instruction
///     of the loops that contain a block from \p toDenseBlocks.
///  2. Find "sparse blocks": a successor of a conditional branch that
///     unconditionally jumps to the branch's other successor and contains
///     exactly one call to a function carrying the
///     `enzyme_sparse_accumulate` attribute and no other memory write.
///  3. Parse each guarding condition into Constraints and replace it by a
///     comparison against a per-loop placeholder ("ph.idx").
///  4. Outline every such loop with CodeExtractor, turn the outlined function
///     into a single-iteration body, and call it from the old preheader once
///     per solution under the solution's guard.
///  5. Lower calls to functions named `__enzyme_product*` / `__enzyme_sum*`
///     (in the whole module) into chains of fast-math fmul / fadd.
///
/// Diagnostics are reported through EmitFailure; on failure the function
/// returns early.
///
/// \param F             function to transform.
/// \param FAM           analysis manager providing DominatorTree,
///                      ScalarEvolution and LoopInfo for \p F.
/// \param toDenseBlocks blocks that contained a todense call; their loops
///                      seed the simplification queue.
void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
                      SetVector<BasicBlock *> &toDenseBlocks) {
  auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
  auto &LI = FAM.getResult<LoopAnalysis>(F);
  auto &DL = F.getParent()->getDataLayout();

  // --- 1. Simplify all value-producing instructions of the affected loops.
  QueueType Q(DT, LI);
  {
    llvm::SetVector<BasicBlock *> todoBlocks;
    for (auto b : toDenseBlocks) {
      auto L = LI.getLoopFor(b);
      if (L) {
        for (auto B : L->getBlocks())
          todoBlocks.insert(B);
      }
    }
    for (auto BB : todoBlocks)
      for (auto &I : *BB)
        if (!I.getType()->isVoidTy()) {
          Q.insert(&I);
          assert(Q.contains(&I));
        }
  }

  // Full simplification
  while (!Q.empty()) {
    auto cur = Q.pop_back_val();
    auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL);
    (void)changed;
  }

  // --- 2. Discover guarded blocks that contain a sparse accumulate.
  SmallVector<std::pair<BasicBlock *, BranchInst *>, 1> sparseBlocks;
  bool legalToSparse = true;
  for (auto &B : F)
    if (auto br = dyn_cast<BranchInst>(B.getTerminator()))
      if (br->isConditional())
        for (int bidx = 0; bidx < 2; bidx++)
          if (auto uncond_br =
                  dyn_cast<BranchInst>(br->getSuccessor(bidx)->getTerminator()))
            if (!uncond_br->isConditional())
              if (uncond_br->getSuccessor(0) == br->getSuccessor(1 - bidx)) {
                auto blk = br->getSuccessor(bidx);

                int countSparse = 0;
                for (auto &I : *blk) {
                  if (auto CI = dyn_cast<CallInst>(&I)) {
                    if (auto callee = CI->getCalledFunction()) {
                      if (callee->hasFnAttribute("enzyme_sparse_accumulate")) {
                        countSparse++;
                      }
                    }
                  }
                }
                if (countSparse == 0)
                  continue;
                if (countSparse > 1) {
                  legalToSparse = false;
                  EmitFailure(
                      "NoSparsification", br->getDebugLoc(), br, "F: ", F,
                      "\nMultiple distinct sparse stores in same block: ",
                      *blk);
                  break;
                }

                // Besides the accumulate itself the block may only contain
                // read-only calls and non-writing instructions.
                for (auto &I : *blk) {
                  if (auto CI = dyn_cast<CallInst>(&I)) {
                    if (auto callee = CI->getCalledFunction()) {
                      if (callee->hasFnAttribute("enzyme_sparse_accumulate")) {
                        continue;
                      }
                    }
                    if (isReadOnly(CI))
                      continue;
                  }
                  if (!I.mayWriteToMemory())
                    continue;
                  legalToSparse = false;
                  EmitFailure(
                      "NoSparsification", br->getDebugLoc(), br, "F: ", F,
                      "\nIllegal writing instruction in sparse block: ", I);
                  break;
                }
                if (!legalToSparse) {
                  break;
                }

                auto L = LI.getLoopFor(blk);
                if (!L) {
                  legalToSparse = false;
                  EmitFailure("NoSparsification", br->getDebugLoc(), br,
                              "F: ", F, "\nCould not find loop for: ", *blk);
                  break;
                }
                auto idx = L->getCanonicalInductionVariable();
                if (!idx) {
                  legalToSparse = false;
                  EmitFailure("NoSparsification", br->getDebugLoc(), br,
                              "F: ", F, "\nL:", *L,
                              "\nCould not find loop index: ", *L->getHeader());
                  break;
                }
                assert(idx);
                auto preheader = L->getLoopPreheader();
                if (!preheader) {
                  legalToSparse = false;
                  EmitFailure("NoSparsification", br->getDebugLoc(), br,
                              "F: ", F, "\nL:", *L,
                              "\nCould not find loop preheader");
                  break;
                }
                sparseBlocks.emplace_back(blk, br);
              }

  if (!legalToSparse) {
    return;
  }

  // --- 3. Solve each guard and redirect it to a placeholder comparison.
  // Per loop: the two placeholder phis (offset selector, induction value) and
  // the list of (sparse block, solved constraints).
  std::map<Loop *,
           std::pair<std::pair<PHINode *, PHINode *>,
                     SmallVector<std::pair<BasicBlock *,
                                           std::shared_ptr<const Constraints>>,
                                 1>>>
      forSparsification;

  SmallVector<Instruction *, 1> Assumptions;
  for (auto &BB : F)
    for (auto &I : BB)
      if (auto II = dyn_cast<IntrinsicInst>(&I))
        if (II->getIntrinsicID() == Intrinsic::assume)
          Assumptions.push_back(II);

  bool sawError = false;
  for (auto [blk, br] : sparseBlocks) {
    auto L = LI.getLoopFor(blk);
    assert(L);
    auto idx = L->getCanonicalInductionVariable();
    assert(idx);
    auto preheader = L->getLoopPreheader();
    assert(preheader);

    // default is condition avoids sparse, negated is condition goes
    // to sparse
    auto cond = br->getCondition();
    bool negated = br->getSuccessor(0) == blk;
    bool legal = true;

    // Simplify variable val which is known to branch away from the
    // actual store (if not negated) or to the store (if negated)
    // if! negated the result may become more false if negated the
    // result may become more true
    Instruction *context =
        isa<Instruction>(cond) ? cast<Instruction>(cond) : idx;
    ConstraintContext cctx(SE, L, Assumptions, DT);
    auto solutions = getSparseConditions(
        legal, cond, negated ? Constraints::all() : Constraints::none(),
        context, cctx);
    if (!negated) {
      solutions = solutions->notB(cctx);
    }
    if (!legal) {
      sawError = true;
      continue;
    }
    if (solutions == Constraints::none() || solutions == Constraints::all()) {
      EmitFailure(
          "NoSparsification", context->getDebugLoc(), context, "F: ", F,
          "\nL: ", *L, "\ncond: ", *cond, " negated:", negated,
          "\n No sparsification: not sparse solvable(nosoltn): solutions:",
          *solutions);
      sawError = true;
    }

    if (forSparsification.count(L) == 0) {
      {
        // Placeholder values created in the preheader; they become inputs of
        // the outlined loop function below.
        IRBuilder<> PB(preheader->getTerminator());
        forSparsification[L].first =
            std::make_pair(PB.CreatePHI(idx->getType(), 0, "ph.idx"),
                           PB.CreatePHI(idx->getType(), 0, "loop.idx"));
      }
      Value *LoopCount = nullptr;
      IRBuilder<> B(L->getHeader()->getFirstNonPHI());
      {
        SCEVExpander Exp(SE, DL, "sparseenzyme");
        auto LoopCountS = SE.getBackedgeTakenCount(L);
        LoopCount = B.CreateAdd(
            ConstantInt::get(idx->getType(), 1),
            Exp.expandCodeFor(LoopCountS, idx->getType(), &blk->front()));
      }
      // Marker call recording the in-bounds predicate of the index; it is
      // turned into a real branch after outlining.
      Value *inbounds = B.CreateAnd(
          B.CreateICmpSLT(idx, LoopCount),
          B.CreateICmpSGE(idx, ConstantInt::get(idx->getType(), 0)));
      Value *args[] = {inbounds, forSparsification[L].first.second};
      B.CreateCall(F.getParent()->getOrInsertFunction(
                       "enzyme.sparse.inbounds", B.getVoidTy(),
                       inbounds->getType(), idx->getType()),
                   args);
    }

    // Replace the guard by "placeholder == ordinal of this sparse block".
    IRBuilder<> B(br);
    auto nidx = B.CreateICmpEQ(
        forSparsification[L].first.first,
        ConstantInt::get(idx->getType(), forSparsification[L].second.size()));
    // TODO check direction
    if (!negated)
      nidx = B.CreateNot(nidx);
    br->setCondition(nidx);
    forSparsification[L].second.emplace_back(blk, solutions);
  }

  if (sawError) {
    // Drop the placeholders again so the function stays well formed.
    for (auto &pair : forSparsification) {
      for (auto PN : {pair.second.first.first, pair.second.first.second}) {
        PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
        PN->eraseFromParent();
      }
    }
    if (llvm::verifyFunction(F, &llvm::errs())) {
      llvm::errs() << F << "\n";
      report_fatal_error("function failed verification (6)");
    }
    return;
  }

  if (forSparsification.size() == 0) {
    auto context = &F.getEntryBlock().front();
    EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F,
                "\n Found no stores for sparsification");
    return;
  }

  // --- 4. Outline each loop and call it once per solved index.
  for (const auto &pair : forSparsification) {
    auto L = pair.first;
    auto [PN, inductPN] = pair.second.first;
    auto ph = L->getLoopPreheader();
#if LLVM_VERSION_MAJOR >= 20
    CodeExtractor ext(L->getBlocks(), &DT);
#else
    CodeExtractor ext(DT, *L);
#endif
    CodeExtractorAnalysisCache cache(F);
    SetVector<Value *> Inputs, Outputs;
    auto F2 = ext.extractCodeRegion(cache, Inputs, Outputs);
    assert(F2);
    F2->addFnAttr(Attribute::AlwaysInline);

    // Remove the call(s) the extractor inserted. Erasing a user while
    // iterating F2->users() would invalidate the use-list iterator, so gather
    // the users first.
    {
      SetVector<Instruction *> extractedCalls;
      for (auto U : F2->users())
        extractedCalls.insert(cast<Instruction>(U));
      for (auto call : extractedCalls)
        call->eraseFromParent();
    }

    // Locate the argument positions of the two placeholders.
    ssize_t induct_idx = -1;
    ssize_t off_idx = -1;
    for (auto en : llvm::enumerate(Inputs)) {
      if (en.value() == inductPN)
        induct_idx = en.index();
      if (en.value() == PN)
        off_idx = en.index();
    }
    assert(induct_idx != -1);
    assert(off_idx != -1);

    auto L2 = LI.getLoopFor(F2->getEntryBlock().getSingleSuccessor());
    auto new_idx = F2->getArg(induct_idx);
    auto L2Header = L2->getHeader();
    auto new_lidx = L2->getCanonicalInductionVariable();
    auto idxty = new_idx->getType();
    auto new_pn = F2->getArg(off_idx);

    // Find all sparse accumulates we weren't meant to handle
    {
      SmallVector<CallInst *, 1> toErase;

      // First delete any accumulates in sub loops
      for (auto SL : L2->getSubLoops())
        for (auto B : SL->getBlocks())
          for (auto &I : *B)
            if (auto CI = dyn_cast<CallInst>(&I))
              if (auto callee = CI->getCalledFunction()) {
                if (callee->hasFnAttribute("enzyme_sparse_accumulate")) {
                  toErase.push_back(CI);
                  continue;
                }
              }
      for (auto C : toErase)
        C->eraseFromParent();
      toErase.clear();

      // Next delete any accumulates in blocks that are not guarded by a
      // comparison against the offset placeholder argument.
      for (auto B : L2->getBlocks()) {
        bool guarded = false;
        if (auto P = B->getSinglePredecessor())
          if (auto S = B->getSingleSuccessor())
            if (auto BI = dyn_cast<BranchInst>(P->getTerminator()))
              if (BI->isConditional())
                for (size_t i = 0; i < 2; i++)
                  if (BI->getSuccessor(i) == B &&
                      BI->getSuccessor(1 - i) == S) {
                    auto val = BI->getCondition();
                    // Look through the `not` created for non-negated guards.
                    if (auto xori = dyn_cast<Instruction>(val))
                      if (xori->getOpcode() == Instruction::Xor)
                        val = xori->getOperand(0);
                    if (auto cmp = dyn_cast<ICmpInst>(val))
                      if (cmp->getOperand(0) == new_pn ||
                          cmp->getOperand(1) == new_pn)
                        guarded = true;
                  }
        if (guarded)
          continue;
        for (auto &I : *B)
          if (auto CI = dyn_cast<CallInst>(&I))
            if (auto callee = CI->getCalledFunction()) {
              if (callee->hasFnAttribute("enzyme_sparse_accumulate")) {
                toErase.push_back(CI);
                continue;
              }
            }
      }
      for (auto C : toErase)
        C->eraseFromParent();
      toErase.clear();
    }

    // Turn the loop into a single pass: the latch returns instead of
    // branching, and the induction variable is the function argument.
    auto guard = L2->getLoopLatch()->getTerminator();
    assert(guard);
    IRBuilder<> G(guard);
    G.CreateRetVoid();
    guard->eraseFromParent();
    new_lidx->replaceAllUsesWith(new_idx);
    new_lidx->eraseFromParent();

    auto phterm = ph->getTerminator();
    IRBuilder<> B(phterm);
    // We extracted code, reset analyses.
    /*
    DT.reset();
    SE.forgetAllLoops();
    */
    for (auto en : llvm::enumerate(pair.second.second)) {
      auto off = en.index();
      auto &solutions = en.value().second;
      ConstraintContext ctx(SE, L, Assumptions, DT);
      SCEVExpander Exp(SE, DL, "sparseenzyme", /*preservelcssa*/ false);
      auto sols = solutions->allSolutions(Exp, idxty, phterm, ctx, B);
      SmallVector<Value *, 1> prevSols;
      for (auto [sol, condition] : sols) {
        SmallVector<Value *, 1> args(Inputs.begin(), Inputs.end());
        args[off_idx] = ConstantInt::get(idxty, off);
        args[induct_idx] = sol;
        // Do not run the same index twice: exclude earlier solutions.
        for (auto sol2 : prevSols)
          condition = B.CreateAnd(condition, B.CreateICmpNE(sol, sol2));
        prevSols.push_back(sol);

        // if (condition) call F2(args...)
        auto BB = B.GetInsertBlock();
        auto B2 = BB->splitBasicBlock(B.GetInsertPoint(), "poststore");
        B2->moveAfter(BB);
        BB->getTerminator()->eraseFromParent();
        B.SetInsertPoint(BB);
        auto callB = BasicBlock::Create(BB->getContext(), "tostore",
                                        BB->getParent(), B2);
        B.CreateCondBr(condition, callB, B2);
        B.SetInsertPoint(callB);
        B.CreateCall(F2, args);
        B.CreateBr(B2);
        B.SetInsertPoint(B2->getTerminator());
      }

      // The sparse block returns right after its accumulate.
      auto blk = en.value().first;
      auto term = blk->getTerminator();
      IRBuilder<> B2(blk);
      B2.CreateRetVoid();
      term->eraseFromParent();
    }
    PN->eraseFromParent();

    // Convert the `enzyme.sparse.inbounds` marker into an early exit.
    for (auto &I : *L2Header) {
      auto boundsCheck = dyn_cast<CallInst>(&I);
      if (!boundsCheck)
        continue;
      auto BF = boundsCheck->getCalledFunction();
      if (!BF)
        continue;
      if (BF->getName() != "enzyme.sparse.inbounds")
        continue;
      auto boundsCond = boundsCheck->getArgOperand(0);
      auto next = L2Header->splitBasicBlock(boundsCheck);
      auto exit = BasicBlock::Create(F2->getContext(), "bounds.exit", F2,
                                     L2Header->getNextNode());
      {
        IRBuilder B(exit);
        B.CreateRetVoid();
      }
      L2Header->getTerminator()->eraseFromParent();
      {
        IRBuilder B(L2Header);
        B.CreateCondBr(boundsCond, next, exit);
      }
      boundsCheck->eraseFromParent();
      inductPN->eraseFromParent();
      break;
    }
  }

  // --- 5. Lower the variadic product / sum helpers to fast-math chains.
  // Replaces every call of `Fn` by the left fold of its operands with fmul
  // (product) or fadd (sum) and erases the call.
  auto lowerReduction = [](llvm::Function &Fn, bool isProduct) {
    SmallVector<Instruction *, 1> toErase;
    for (llvm::User *I : Fn.users()) {
      auto CB = cast<CallBase>(I);
      IRBuilder<> B(CB);
      B.setFastMathFlags(getFast());
      Value *res = nullptr;
      for (auto v : callOperands(CB)) {
        if (res == nullptr)
          res = v;
        else if (isProduct)
          res = B.CreateFMul(res, v);
        else
          res = B.CreateFAdd(res, v);
      }
      CB->replaceAllUsesWith(res);
      toErase.push_back(CB);
    }
    for (auto CB : toErase)
      CB->eraseFromParent();
  };

  for (auto &F2 : F.getParent()->functions()) {
    if (startsWith(F2.getName(), "__enzyme_product")) {
      lowerReduction(F2, /*isProduct*/ true);
    } else if (startsWith(F2.getName(), "__enzyme_sum")) {
      lowerReduction(F2, /*isProduct*/ false);
    }
  }
}
/// Lower one todense call \p CI.
///
/// Operand 0 of \p CI is the loader function and operand 1 the store
/// function; all remaining operands are forwarded unchanged to both. The
/// value produced by \p CI is replaced by null, and every cast, select,
/// `julia.pointer_from_objref` call and GEP reachable through its users is
/// recomputed on top of that null base. Afterwards:
///  * a load through such a pointer becomes `load_fn(offset, extra...)`,
///  * a store becomes `store_fn(value, offset, extra...)`,
/// where `offset` is the recomputed pointer converted to i64. Calls to
/// functions marked always-inline are inlined immediately.
///
/// Signature mismatches of the loader/store functions are reported through
/// EmitFailure and the offending access is left in place.
///
/// \param CI         the todense call to lower.
/// \param replaceAll when true, any user that is not one of the shapes above
///                   is reported as an error.
/// \param F          function containing \p CI (only used in diagnostics).
/// \param DL         data layout used for element sizes and GEP offsets.
void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
                    const llvm::DataLayout &DL) {
  auto load_fn = cast<Function>(getBaseObject(CI->getArgOperand(0)));
  auto store_fn = cast<Function>(getBaseObject(CI->getArgOperand(1)));
  size_t argstart = 2;
  size_t num_args = CI->arg_size();

  // Worklist of (user, value through which the user was reached).
  SmallVector<std::pair<Instruction *, Value *>, 1> users;
  auto enqueueUsers = [&](Value *from) {
    for (auto next : from->users()) {
      users.push_back(std::make_pair(cast<Instruction>(next), from));
    }
  };
  enqueueUsers(CI);

  IntegerType *intTy = IntegerType::get(CI->getContext(), 64);

  // Convert a pointer (of any address space) or integer to i64.
  auto toInt = [&](IRBuilder<> &B, llvm::Value *V) {
    if (auto PT = dyn_cast<PointerType>(V->getType())) {
      if (PT->getAddressSpace() != 0) {
#if LLVM_VERSION_MAJOR < 17
        if (CI->getContext().supportsTypedPointers()) {
          V = B.CreateAddrSpaceCast(
              V, PointerType::getUnqual(PT->getPointerElementType()));
        } else {
          V = B.CreateAddrSpaceCast(V,
                                    PointerType::getUnqual(PT->getContext()));
        }
#else
        V = B.CreateAddrSpaceCast(V, PointerType::getUnqual(PT->getContext()));
#endif
      }
      return B.CreatePtrToInt(V, intTy);
    }
    auto IT = cast<IntegerType>(V->getType());
    if (IT == intTy)
      return V;
    return B.CreateZExtOrTrunc(V, intTy);
  };

  SmallVector<Instruction *, 1> toErase;
  ValueToValueMapTy replacements;
  // The dense pointer itself is the zero offset.
  replacements[CI] = Constant::getNullValue(CI->getType());
  // Last user that could not be rewritten, if any.
  Instruction *remaining = nullptr;

  while (users.size()) {
    auto pair = users.back();
    users.pop_back();
    auto U = pair.first;
    auto val = pair.second;
    if (replacements.count(U))
      continue;

    IRBuilder B(U);

    if (auto castI = dyn_cast<CastInst>(U)) {
      enqueueUsers(castI);
      auto rep = B.CreateCast(castI->getOpcode(), replacements[val],
                              castI->getDestTy());
      if (auto I = dyn_cast<Instruction>(rep))
        I->setDebugLoc(castI->getDebugLoc());
      replacements[castI] = rep;
      continue;
    }

    if (auto SI = dyn_cast<SelectInst>(U)) {
      enqueueUsers(SI);
      auto tval = SI->getTrueValue();
      auto fval = SI->getFalseValue();
      // Operands without a replacement are used unchanged.
      auto rep = B.CreateSelect(
          SI->getCondition(),
          replacements.count(tval) ? (Value *)replacements[tval] : tval,
          replacements.count(fval) ? (Value *)replacements[fval] : fval);
      if (auto I = dyn_cast<Instruction>(rep))
        I->setDebugLoc(SI->getDebugLoc());
      replacements[SI] = rep;
      continue;
    }

    /*
    if (auto CI = dyn_cast<PHINode>(U)) {
      for (auto U : CI->users()) {
        users.push_back(std::make_pair(cast<Instruction>(U), CI));
      }
      continue;
    }
    */

    if (auto inner = dyn_cast<CallInst>(U)) {
      auto funcName = getFuncNameFromCall(inner);
      if (funcName == "julia.pointer_from_objref") {
        enqueueUsers(inner);
        auto *callee = inner->getCalledOperand();
        SmallVector<Value *, 1> args;
        for (auto &arg : inner->args())
          args.push_back(replacements[arg]);
        auto FT = inner->getFunctionType();
        auto cal = cast<CallInst>(B.CreateCall(FT, callee, args));
        cal->setCallingConv(inner->getCallingConv());
        cal->setDebugLoc(inner->getDebugLoc());
        replacements[inner] = cal;
        continue;
      }
      // Any other call falls through and is recorded as `remaining`.
    }

    if (auto GEP = dyn_cast<GetElementPtrInst>(U)) {
      enqueueUsers(GEP);
      SmallVector<Value *, 1> inds;
      bool allconst = true;
      for (auto &ind : GEP->indices()) {
        if (!isa<ConstantInt>(ind)) {
          allconst = false;
        }
        inds.push_back(ind);
      }
      Value *gep;
      if (inds.size() == 1) {
        // Single index: base + index * sizeof(element), done in i64.
        gep = ConstantInt::get(
            intTy, (DL.getTypeSizeInBits(GEP->getSourceElementType()) + 7) / 8);
        gep = B.CreateMul(intTy == inds[0]->getType()
                              ? inds[0]
                              : B.CreateZExtOrTrunc(inds[0], intTy),
                          gep, "", true, true);
        gep = B.CreateAdd(B.CreatePtrToInt(replacements[val], intTy), gep);
        gep = B.CreateIntToPtr(gep, GEP->getType());
      } else if (!allconst) {
        gep = B.CreateGEP(GEP->getSourceElementType(), replacements[val], inds);
        // The builder may fold the GEP into a constant, so test rather than
        // assert that an instruction was produced.
        if (auto ge = dyn_cast<GetElementPtrInst>(gep))
          ge->setIsInBounds(GEP->isInBounds());
      } else {
        // All-constant multi-index GEP: use its accumulated byte offset.
        // NOTE(review): the replaced base is not added here — presumably the
        // base is always the null root in this case; confirm.
        APInt ai(64, 0);
        GEP->accumulateConstantOffset(DL, ai);
        gep = B.CreateIntToPtr(ConstantInt::get(intTy, ai), GEP->getType());
      }
      if (auto I = dyn_cast<Instruction>(gep))
        I->setDebugLoc(GEP->getDebugLoc());
      replacements[GEP] = gep;
      continue;
    }

    if (auto load = dyn_cast<LoadInst>(U)) {
      auto diff = toInt(B, replacements[load->getPointerOperand()]);
      SmallVector<Value *, 2> args;
      args.push_back(diff);
      for (size_t i = argstart; i < num_args; i++)
        args.push_back(CI->getArgOperand(i));

      if (load_fn->getFunctionType()->getNumParams() != args.size()) {
        auto fnName = load_fn->getName();
        auto found_numargs = load_fn->getFunctionType()->getNumParams();
        auto expected_numargs = args.size();
        EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
                    " incorrect number of arguments to loader function ",
                    fnName, " expected ", expected_numargs, " found ",
                    found_numargs, " - ", *load_fn->getFunctionType());
        continue;
      } else {
        bool tocontinue = false;
        for (size_t i = 0; i < args.size(); i++) {
          if (load_fn->getFunctionType()->getParamType(i) !=
              args[i]->getType()) {
            auto fnName = load_fn->getName();
            // Print the parameter type itself, not the address of the
            // llvm::Type object.
            EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
                        " incorrect type of argument ", i,
                        " to loader function ", fnName, " expected ",
                        *args[i]->getType(), " found ",
                        *load_fn->getFunctionType()->params()[i]);
            tocontinue = true;
            args[i] = UndefValue::get(args[i]->getType());
          }
        }
        if (tocontinue)
          continue;
      }

      CallInst *call = B.CreateCall(load_fn, args);
      call->setDebugLoc(load->getDebugLoc());
      Value *tmp = call;
      if (tmp->getType() != load->getType()) {
        if (CastInst::castIsValid(Instruction::BitCast, tmp, load->getType()))
          tmp = B.CreateBitCast(tmp, load->getType());
        else {
          auto fnName = load_fn->getName();
          EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
                      " incorrect return type of loader function ", fnName,
                      " expected ", *load->getType(), " found ",
                      *call->getType());
          tmp = UndefValue::get(load->getType());
        }
      }
      load->replaceAllUsesWith(tmp);

      if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) {
        InlineFunctionInfo IFI;
        InlineFunction(*call, IFI);
      }
      toErase.push_back(load);
      continue;
    }

    if (auto SI = dyn_cast<StoreInst>(U)) {
      // The dense pointer must be the address, never the stored value.
      assert(SI->getValueOperand() != val);
      auto diff = toInt(B, replacements[SI->getPointerOperand()]);
      SmallVector<Value *, 2> args;
      args.push_back(SI->getValueOperand());
      auto sty = store_fn->getFunctionType()->getParamType(0);
      if (args[0]->getType() != store_fn->getFunctionType()->getParamType(0)) {
        if (CastInst::castIsValid(Instruction::BitCast, args[0], sty))
          args[0] = B.CreateBitCast(args[0], sty);
        else {
          auto args0ty = args[0]->getType();
          EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
                      " first argument of store function must be the type of "
                      "the store found fn arg type ",
                      *sty, " expected ", *args0ty);
          args[0] = UndefValue::get(sty);
        }
      }
      args.push_back(diff);
      for (size_t i = argstart; i < num_args; i++)
        args.push_back(CI->getArgOperand(i));

      if (store_fn->getFunctionType()->getNumParams() != args.size()) {
        auto fnName = store_fn->getName();
        auto found_numargs = store_fn->getFunctionType()->getNumParams();
        auto expected_numargs = args.size();
        EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
                    " incorrect number of arguments to store function ", fnName,
                    " expected ", expected_numargs, " found ", found_numargs,
                    " - ", *store_fn->getFunctionType());
        continue;
      } else {
        bool tocontinue = false;
        for (size_t i = 0; i < args.size(); i++) {
          if (store_fn->getFunctionType()->getParamType(i) !=
              args[i]->getType()) {
            auto fnName = store_fn->getName();
            // Print the parameter type itself, not the address of the
            // llvm::Type object.
            EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
                        " incorrect type of argument ", i,
                        " to storeer function ", fnName, " expected ",
                        *args[i]->getType(), " found ",
                        *store_fn->getFunctionType()->params()[i]);
            tocontinue = true;
            args[i] = UndefValue::get(args[i]->getType());
          }
        }
        if (tocontinue)
          continue;
      }

      auto call = B.CreateCall(store_fn, args);
      call->setDebugLoc(SI->getDebugLoc());
      if (store_fn->hasFnAttribute(Attribute::AlwaysInline)) {
        InlineFunctionInfo IFI;
        InlineFunction(*call, IFI);
      }
      toErase.push_back(SI);
      continue;
    }

    // Unsupported kind of user.
    remaining = U;
  }

  for (auto U : toErase)
    U->eraseFromParent();

  if (!remaining) {
    CI->replaceAllUsesWith(Constant::getNullValue(CI->getType()));
    CI->eraseFromParent();
  } else if (replaceAll) {
    EmitFailure("IllegalSparse", remaining->getDebugLoc(), remaining,
                " Illegal remaining use (", *remaining, ") of todense (", *CI,
                ") in function ", *F);
  }
}
/// Lower every todense call in \p F.
///
/// First all calls whose callee name contains `__enzyme_todense` are
/// rewritten via replaceToDense. If any were found and EnzymeAutoSparsity is
/// set, the function is cleaned up (SimplifyCFG, InstCombine, LoopSimplify)
/// and fixSparseIndices is run on it. Finally all calls whose callee name
/// contains `__enzyme_post_sparse_todense` are rewritten the same way.
///
/// \param F          function to transform.
/// \param replaceAll forwarded to replaceToDense.
/// \returns true if at least one call was rewritten.
bool LowerSparsification(llvm::Function *F, bool replaceAll) {
  auto &DL = F->getParent()->getDataLayout();

  // Gather the calls in F whose callee name contains `needle`; when `blocks`
  // is given, also record the blocks those calls live in.
  auto gatherCalls = [&](const char *needle, SetVector<BasicBlock *> *blocks) {
    SmallVector<CallBase *, 1> found;
    for (auto &BB : *F) {
      for (auto &I : BB) {
        auto call = dyn_cast<CallInst>(&I);
        if (!call)
          continue;
        if (!getFuncNameFromCall(call).contains(needle))
          continue;
        found.push_back(call);
        if (blocks)
          blocks->insert(&BB);
      }
    }
    return found;
  };

  bool changed = false;

  SetVector<BasicBlock *> toDenseBlocks;
  for (auto call : gatherCalls("__enzyme_todense", &toDenseBlocks)) {
    changed = true;
    replaceToDense(call, replaceAll, F, DL);
  }

  if (changed && EnzymeAutoSparsity) {
    PassBuilder PB;
    LoopAnalysisManager LAM;
    FunctionAnalysisManager FAM;
    CGSCCAnalysisManager CGAM;
    ModuleAnalysisManager MAM;
    PB.registerModuleAnalyses(MAM);
    PB.registerFunctionAnalyses(FAM);
    PB.registerLoopAnalyses(LAM);
    PB.registerCGSCCAnalyses(CGAM);
    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

    SimplifyCFGPass(SimplifyCFGOptions()).run(*F, FAM);
    InstCombinePass().run(*F, FAM);
    // required to make preheaders
    LoopSimplifyPass().run(*F, FAM);
    fixSparseIndices(*F, FAM, toDenseBlocks);
  }

  for (auto call : gatherCalls("__enzyme_post_sparse_todense", nullptr)) {
    changed = true;
    replaceToDense(call, replaceAll, F, DL);
  }

  return changed;
}