Revert "add test cases for splitting integer comparisons"
This reverts commit e0aa411647e1a525a3a0488d929ec71611388d54.
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc
index 3dbf787..b02a89f 100644
--- a/instrumentation/split-compares-pass.so.cc
+++ b/instrumentation/split-compares-pass.so.cc
@@ -47,99 +47,50 @@
using namespace llvm;
#include "afl-llvm-common.h"
-// uncomment this toggle function verification at each step. horribly slow, but
-// helps to pinpoint a potential problem in the splitting code.
-//#define VERIFY_TOO_MUCH 1
-
namespace {
class SplitComparesTransform : public ModulePass {
+
public:
static char ID;
SplitComparesTransform() : ModulePass(ID), enableFPSplit(0) {
+
initInstrumentList();
+
}
bool runOnModule(Module &M) override;
#if LLVM_VERSION_MAJOR >= 4
StringRef getPassName() const override {
+
#else
const char *getPassName() const override {
#endif
- return "AFL_SplitComparesTransform";
+ return "simplifies and splits ICMP instructions";
+
}
private:
int enableFPSplit;
- unsigned target_bitwidth = 8;
-
- size_t count = 0;
-
+ size_t splitIntCompares(Module &M, unsigned bitw);
size_t splitFPCompares(Module &M);
+ bool simplifyCompares(Module &M);
bool simplifyFPCompares(Module &M);
+ bool simplifyIntSignedness(Module &M);
size_t nextPowerOfTwo(size_t in);
- /// simplify the comparison and then split the comparison until the
- /// target_bitwidth is reached.
- bool simplifyAndSplit(CmpInst *I, Module &M);
- /// simplify a non-strict comparison (e.g., less than or equals)
- bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M,
- std::vector<CmpInst *> &worklist);
- /// simplify a signed comparison (signed less or greater than)
- bool simplifySignedCompare(CmpInst *IcmpInst, Module &M,
- std::vector<CmpInst *> &worklist);
- /// splits an icmp into nested icmps recursivly until target_bitwidth is
- /// reached
- bool splitCompare(CmpInst *I, Module &M);
-
- /// print an error to llvm's errs stream, but only if not ordered to be quiet
- void reportError(const StringRef msg, Instruction *I, Module &M) {
- if (!be_quiet) {
- errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n";
- if (debug) {
- if (I) {
- errs() << "Instruction = " << *I << "\n";
- if (auto BB = I->getParent()) {
- if (auto F = BB->getParent()) {
- if (F->hasName()) {
- errs() << "|-> in function " << F->getName() << " ";
- }
- }
- }
- }
- auto n = M.getName();
- if (n.size() > 0) { errs() << "in module " << n << "\n"; }
- }
- }
- }
-
- bool isSupportedBitWidth(unsigned bitw) {
- // IDK whether the icmp code works on other bitwidths. I guess not? So we
- // try to avoid dealing with other weird icmp's that llvm might use (looking
- // at you `icmp i0`).
- switch (bitw) {
- case 8:
- case 16:
- case 32:
- case 64:
- case 128:
- case 256:
- return true;
- default:
- return false;
- }
- }
};
} // namespace
char SplitComparesTransform::ID = 0;
-/// This function splits FCMP instructions with xGE or xLE predicates into two
-/// FCMP instructions with predicate xGT or xLT and EQ
+/* This function splits FCMP instructions with xGE or xLE predicates into two
+ * FCMP instructions with predicate xGT or xLT and EQ */
bool SplitComparesTransform::simplifyFPCompares(Module &M) {
+
LLVMContext & C = M.getContext();
std::vector<Instruction *> fcomps;
IntegerType * Int1Ty = IntegerType::getInt1Ty(C);
@@ -147,18 +98,23 @@
/* iterate over all functions, bbs and instruction and add
* all integer comparisons with >= and <= predicates to the icomps vector */
for (auto &F : M) {
+
if (!isInInstrumentList(&F)) continue;
for (auto &BB : F) {
+
for (auto &IN : BB) {
+
CmpInst *selectcmpInst = nullptr;
if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+
if (enableFPSplit &&
(selectcmpInst->getPredicate() == CmpInst::FCMP_OGE ||
selectcmpInst->getPredicate() == CmpInst::FCMP_UGE ||
selectcmpInst->getPredicate() == CmpInst::FCMP_OLE ||
selectcmpInst->getPredicate() == CmpInst::FCMP_ULE)) {
+
auto op0 = selectcmpInst->getOperand(0);
auto op1 = selectcmpInst->getOperand(1);
@@ -171,16 +127,22 @@
if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
fcomps.push_back(selectcmpInst);
+
}
+
}
+
}
+
}
+
}
if (!fcomps.size()) { return false; }
/* transform for floating point */
for (auto &FcmpInst : fcomps) {
+
BasicBlock *bb = FcmpInst->getParent();
auto op0 = FcmpInst->getOperand(0);
@@ -193,6 +155,7 @@
CmpInst::Predicate new_pred;
switch (pred) {
+
case CmpInst::FCMP_UGE:
new_pred = CmpInst::FCMP_UGT;
break;
@@ -207,6 +170,7 @@
break;
default: // keep the compiler happy
continue;
+
}
/* split before the fcmp instruction */
@@ -250,428 +214,305 @@
/* replace the old FcmpInst with our new and shiny PHI inst */
BasicBlock::iterator ii(FcmpInst);
ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
+
}
return true;
+
}
-/// This function splits ICMP instructions with xGE or xLE predicates into two
-/// ICMP instructions with predicate xGT or xLT and EQ
-bool SplitComparesTransform::simplifyOrEqualsCompare(
- CmpInst *IcmpInst, Module &M, std::vector<CmpInst *> &worklist) {
- LLVMContext &C = M.getContext();
- IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
+/* This function splits ICMP instructions with xGE or xLE predicates into two
+ * ICMP instructions with predicate xGT or xLT and EQ */
+bool SplitComparesTransform::simplifyCompares(Module &M) {
- /* find out what the new predicate is going to be */
- auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
- if (!cmp_inst) { return false; }
-
- BasicBlock *bb = IcmpInst->getParent();
+ LLVMContext & C = M.getContext();
+ std::vector<Instruction *> icomps;
+ IntegerType * Int1Ty = IntegerType::getInt1Ty(C);
- auto op0 = IcmpInst->getOperand(0);
- auto op1 = IcmpInst->getOperand(1);
+ /* iterate over all functions, bbs and instruction and add
+ * all integer comparisons with >= and <= predicates to the icomps vector */
+ for (auto &F : M) {
- CmpInst::Predicate pred = cmp_inst->getPredicate();
- CmpInst::Predicate new_pred;
+ if (!isInInstrumentList(&F)) continue;
- switch (pred) {
- case CmpInst::ICMP_UGE:
- new_pred = CmpInst::ICMP_UGT;
- break;
- case CmpInst::ICMP_SGE:
- new_pred = CmpInst::ICMP_SGT;
- break;
- case CmpInst::ICMP_ULE:
- new_pred = CmpInst::ICMP_ULT;
- break;
- case CmpInst::ICMP_SLE:
- new_pred = CmpInst::ICMP_SLT;
- break;
- default: // keep the compiler happy
- return false;
- }
+ for (auto &BB : F) {
- /* split before the icmp instruction */
- BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+ for (auto &IN : BB) {
- /* the old bb now contains a unconditional jump to the new one (end_bb)
- * we need to delete it later */
+ CmpInst *selectcmpInst = nullptr;
- /* create the ICMP instruction with new_pred and add it to the old basic
- * block bb it is now at the position where the old IcmpInst was */
- CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np);
+ if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
- /* create a new basic block which holds the new EQ icmp */
- CmpInst *icmp_eq;
- /* insert middle_bb before end_bb */
- BasicBlock *middle_bb =
- BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
- icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
- middle_bb->getInstList().push_back(icmp_eq);
- /* add an unconditional branch to the end of middle_bb with destination
- * end_bb */
- BranchInst::Create(end_bb, middle_bb);
+ if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE ||
+ selectcmpInst->getPredicate() == CmpInst::ICMP_SGE ||
+ selectcmpInst->getPredicate() == CmpInst::ICMP_ULE ||
+ selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) {
- /* replace the uncond branch with a conditional one, which depends on the
- * new_pred icmp. True goes to end, false to the middle (injected) bb */
- auto term = bb->getTerminator();
- BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
- term->eraseFromParent();
+ auto op0 = selectcmpInst->getOperand(0);
+ auto op1 = selectcmpInst->getOperand(1);
- /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
- * inst to wire up the loose ends */
- PHINode *PN = PHINode::Create(Int1Ty, 2, "");
- /* the first result depends on the outcome of icmp_eq */
- PN->addIncoming(icmp_eq, middle_bb);
- /* if the source was the original bb we know that the icmp_np yielded true
- * hence we can hardcode this value */
- PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
- /* replace the old IcmpInst with our new and shiny PHI inst */
- BasicBlock::iterator ii(IcmpInst);
- ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+ IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+ IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
- worklist.push_back(icmp_np);
- worklist.push_back(icmp_eq);
+ /* this is probably not needed but we do it anyway */
+ if (!intTyOp0 || !intTyOp1) { continue; }
- return true;
-}
+ icomps.push_back(selectcmpInst);
-/// Simplify a signed comparison operator by splitting it into a unsigned and
-/// bit comparison. add all resulting comparisons to
-/// the worklist passed as a reference.
-bool SplitComparesTransform::simplifySignedCompare(
- CmpInst *IcmpInst, Module &M, std::vector<CmpInst *> &worklist) {
- LLVMContext &C = M.getContext();
- IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
+ }
- BasicBlock *bb = IcmpInst->getParent();
+ }
- auto op0 = IcmpInst->getOperand(0);
- auto op1 = IcmpInst->getOperand(1);
+ }
- IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
- if (!intTyOp0) { return false; }
- unsigned bitw = intTyOp0->getBitWidth();
- IntegerType *IntType = IntegerType::get(C, bitw);
-
- /* get the new predicate */
- auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
- if (!cmp_inst) { return false; }
- auto pred = cmp_inst->getPredicate();
- CmpInst::Predicate new_pred;
-
- if (pred == CmpInst::ICMP_SGT) {
- new_pred = CmpInst::ICMP_UGT;
-
- } else {
- new_pred = CmpInst::ICMP_ULT;
- }
-
- BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
-
- /* create a 1 bit compare for the sign bit. to do this shift and trunc
- * the original operands so only the first bit remains.*/
- Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
-
- IRBuilder<> IRB(bb->getTerminator());
- s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1));
- t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty);
- s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1));
- t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty);
- /* compare of the sign bits */
- icmp_sign_bit = IRB.CreateCmp(CmpInst::ICMP_EQ, t_op0, t_op1);
-
- /* create a new basic block which is executed if the signedness bit is
- * different */
- CmpInst * icmp_inv_sig_cmp;
- BasicBlock *sign_bb =
- BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
- if (pred == CmpInst::ICMP_SGT) {
- /* if we check for > and the op0 positive and op1 negative then the final
- * result is true. if op0 negative and op1 pos, the cmp must result
- * in false
- */
- icmp_inv_sig_cmp =
- CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
-
- } else {
- /* just the inverse of the above statement */
- icmp_inv_sig_cmp =
- CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
- }
-
- sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
- BranchInst::Create(end_bb, sign_bb);
-
- /* create a new bb which is executed if signedness is equal */
- CmpInst * icmp_usign_cmp;
- BasicBlock *middle_bb =
- BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
- /* we can do a normal unsigned compare now */
- icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
-
- middle_bb->getInstList().push_back(icmp_usign_cmp);
- BranchInst::Create(end_bb, middle_bb);
-
- auto term = bb->getTerminator();
- /* if the sign is eq do a normal unsigned cmp, else we have to check the
- * signedness bit */
- BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
- term->eraseFromParent();
-
- PHINode *PN = PHINode::Create(Int1Ty, 2, "");
-
- PN->addIncoming(icmp_usign_cmp, middle_bb);
- PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
-
- BasicBlock::iterator ii(IcmpInst);
- ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
-
- // save for later
- worklist.push_back(icmp_usign_cmp);
-
- // signed comparisons are not supported by the splitting code, so we must not
- // add it to the worklist.
- // worklist.push_back(icmp_inv_sig_cmp);
-
- return true;
-}
-
-bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M) {
- auto pred = cmp_inst->getPredicate();
- switch (pred) {
- case CmpInst::ICMP_EQ:
- case CmpInst::ICMP_NE:
- case CmpInst::ICMP_UGT:
- case CmpInst::ICMP_ULT:
- break;
- default:
- // unsupported predicate!
- return false;
- }
-
- auto op0 = cmp_inst->getOperand(0);
- auto op1 = cmp_inst->getOperand(1);
-
- // get bitwidth by checking the bitwidth of the first operator
- IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
- if (!intTyOp0) {
- // not an integer type
- return false;
- }
-
- unsigned bitw = intTyOp0->getBitWidth();
- if (bitw == target_bitwidth) {
- // already the target bitwidth so we have to do nothing here.
- return true;
- }
-
- LLVMContext &C = M.getContext();
- IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
- BasicBlock *bb = cmp_inst->getParent();
- IntegerType *OldIntType = IntegerType::get(C, bitw);
- IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
- BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst));
- CmpInst *icmp_high, *icmp_low;
-
- /* create the comparison of the top halves of the original operands */
- Instruction *s_op0, *op0_high, *s_op1, *op1_high;
-
- s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
- ConstantInt::get(OldIntType, bitw / 2));
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
- op0_high = new TruncInst(s_op0, NewIntType);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), op0_high);
-
- s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
- ConstantInt::get(OldIntType, bitw / 2));
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
- op1_high = new TruncInst(s_op1, NewIntType);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), op1_high);
-
- icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
- icmp_high);
-
- PHINode *PN = nullptr;
-
- /* now we have to destinguish between == != and > < */
- if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) {
- /* transformation for == and != icmps */
-
- /* create a compare for the lower half of the original operands */
- BasicBlock *cmp_low_bb =
- BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
-
- Value *op0_low, *op1_low;
-
- IRBuilder<> Builder(cmp_low_bb);
-
- op0_low = Builder.CreateTrunc(op0, NewIntType);
- op1_low = Builder.CreateTrunc(op1, NewIntType);
-
- icmp_low = dyn_cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low));
- // icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
- // cmp_low_bb->getInstList().push_back(icmp_low);
-
- BranchInst::Create(end_bb, cmp_low_bb);
-
- /* dependent on the cmp of the high parts go to the end or go on with
- * the comparison */
- auto term = bb->getTerminator();
- BranchInst *br = nullptr;
- if (pred == CmpInst::ICMP_EQ) {
- br = BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
- } else {
- /* CmpInst::ICMP_NE */
- br = BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
}
+
+ }
+
+ if (!icomps.size()) { return false; }
+
+ for (auto &IcmpInst : icomps) {
+
+ BasicBlock *bb = IcmpInst->getParent();
+
+ auto op0 = IcmpInst->getOperand(0);
+ auto op1 = IcmpInst->getOperand(1);
+
+ /* find out what the new predicate is going to be */
+ auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
+ if (!cmp_inst) { continue; }
+ auto pred = cmp_inst->getPredicate();
+ CmpInst::Predicate new_pred;
+
+ switch (pred) {
+
+ case CmpInst::ICMP_UGE:
+ new_pred = CmpInst::ICMP_UGT;
+ break;
+ case CmpInst::ICMP_SGE:
+ new_pred = CmpInst::ICMP_SGT;
+ break;
+ case CmpInst::ICMP_ULE:
+ new_pred = CmpInst::ICMP_ULT;
+ break;
+ case CmpInst::ICMP_SLE:
+ new_pred = CmpInst::ICMP_SLT;
+ break;
+ default: // keep the compiler happy
+ continue;
+
+ }
+
+ /* split before the icmp instruction */
+ BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+
+ /* the old bb now contains a unconditional jump to the new one (end_bb)
+ * we need to delete it later */
+
+ /* create the ICMP instruction with new_pred and add it to the old basic
+ * block bb it is now at the position where the old IcmpInst was */
+ Instruction *icmp_np;
+ icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
+ icmp_np);
+
+ /* create a new basic block which holds the new EQ icmp */
+ Instruction *icmp_eq;
+ /* insert middle_bb before end_bb */
+ BasicBlock *middle_bb =
+ BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+ icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
+ middle_bb->getInstList().push_back(icmp_eq);
+ /* add an unconditional branch to the end of middle_bb with destination
+ * end_bb */
+ BranchInst::Create(end_bb, middle_bb);
+
+ /* replace the uncond branch with a conditional one, which depends on the
+ * new_pred icmp. True goes to end, false to the middle (injected) bb */
+ auto term = bb->getTerminator();
+ BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
term->eraseFromParent();
- /* create the PHI and connect the edges accordingly */
- PN = PHINode::Create(Int1Ty, 2, "");
- PN->addIncoming(icmp_low, cmp_low_bb);
- Value *val = nullptr;
- if (pred == CmpInst::ICMP_EQ) {
- val = ConstantInt::get(Int1Ty, 0);
- } else {
- /* CmpInst::ICMP_NE */
- val = ConstantInt::get(Int1Ty, 1);
- }
- PN->addIncoming(val, icmp_high->getParent());
+ /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
+ * inst to wire up the loose ends */
+ PHINode *PN = PHINode::Create(Int1Ty, 2, "");
+ /* the first result depends on the outcome of icmp_eq */
+ PN->addIncoming(icmp_eq, middle_bb);
+ /* if the source was the original bb we know that the icmp_np yielded true
+ * hence we can hardcode this value */
+ PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+ /* replace the old IcmpInst with our new and shiny PHI inst */
+ BasicBlock::iterator ii(IcmpInst);
+ ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
- } else {
- /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */
- /* transformations for < and > */
+ }
- /* create a basic block which checks for the inverse predicate.
- * if this is true we can go to the end if not we have to go to the
- * bb which checks the lower half of the operands */
- Instruction *icmp_inv_cmp, *op0_low, *op1_low;
- BasicBlock * inv_cmp_bb =
- BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
- if (pred == CmpInst::ICMP_UGT) {
- icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
- op0_high, op1_high);
+ return true;
- } else {
- icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
- op0_high, op1_high);
+}
+
+/* this function transforms signed compares to equivalent unsigned compares */
+bool SplitComparesTransform::simplifyIntSignedness(Module &M) {
+
+ LLVMContext & C = M.getContext();
+ std::vector<Instruction *> icomps;
+ IntegerType * Int1Ty = IntegerType::getInt1Ty(C);
+
+ /* iterate over all functions, bbs and instructions and add
+ * all signed compares to icomps vector */
+ for (auto &F : M) {
+
+ if (!isInInstrumentList(&F)) continue;
+
+ for (auto &BB : F) {
+
+ for (auto &IN : BB) {
+
+ CmpInst *selectcmpInst = nullptr;
+
+ if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+
+ if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT ||
+ selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) {
+
+ auto op0 = selectcmpInst->getOperand(0);
+ auto op1 = selectcmpInst->getOperand(1);
+
+ IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+ IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+
+ /* see above */
+ if (!intTyOp0 || !intTyOp1) { continue; }
+
+ /* i think this is not possible but to lazy to look it up */
+ if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) {
+
+ continue;
+
+ }
+
+ icomps.push_back(selectcmpInst);
+
+ }
+
+ }
+
+ }
+
}
- inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
+ }
+
+ if (!icomps.size()) { return false; }
+
+ for (auto &IcmpInst : icomps) {
+
+ BasicBlock *bb = IcmpInst->getParent();
+
+ auto op0 = IcmpInst->getOperand(0);
+ auto op1 = IcmpInst->getOperand(1);
+
+ IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+ if (!intTyOp0) { continue; }
+ unsigned bitw = intTyOp0->getBitWidth();
+ IntegerType *IntType = IntegerType::get(C, bitw);
+
+ /* get the new predicate */
+ auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
+ if (!cmp_inst) { continue; }
+ auto pred = cmp_inst->getPredicate();
+ CmpInst::Predicate new_pred;
+
+ if (pred == CmpInst::ICMP_SGT) {
+
+ new_pred = CmpInst::ICMP_UGT;
+
+ } else {
+
+ new_pred = CmpInst::ICMP_ULT;
+
+ }
+
+ BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+
+ /* create a 1 bit compare for the sign bit. to do this shift and trunc
+ * the original operands so only the first bit remains.*/
+ Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
+
+ s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
+ ConstantInt::get(IntType, bitw - 1));
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
+ t_op0 = new TruncInst(s_op0, Int1Ty);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op0);
+
+ s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
+ ConstantInt::get(IntType, bitw - 1));
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
+ t_op1 = new TruncInst(s_op1, Int1Ty);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op1);
+
+ /* compare of the sign bits */
+ icmp_sign_bit =
+ CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
+ icmp_sign_bit);
+
+ /* create a new basic block which is executed if the signedness bit is
+ * different */
+ Instruction *icmp_inv_sig_cmp;
+ BasicBlock * sign_bb =
+ BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
+ if (pred == CmpInst::ICMP_SGT) {
+
+ /* if we check for > and the op0 positive and op1 negative then the final
+ * result is true. if op0 negative and op1 pos, the cmp must result
+ * in false
+ */
+ icmp_inv_sig_cmp =
+ CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
+
+ } else {
+
+ /* just the inverse of the above statement */
+ icmp_inv_sig_cmp =
+ CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
+
+ }
+
+ sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
+ BranchInst::Create(end_bb, sign_bb);
+
+ /* create a new bb which is executed if signedness is equal */
+ Instruction *icmp_usign_cmp;
+ BasicBlock * middle_bb =
+ BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+ /* we can do a normal unsigned compare now */
+ icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
+ middle_bb->getInstList().push_back(icmp_usign_cmp);
+ BranchInst::Create(end_bb, middle_bb);
auto term = bb->getTerminator();
+ /* if the sign is eq do a normal unsigned cmp, else we have to check the
+ * signedness bit */
+ BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
term->eraseFromParent();
- BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
- /* create a bb which handles the cmp of the lower halves */
- BasicBlock *cmp_low_bb =
- BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
- op0_low = new TruncInst(op0, NewIntType);
- cmp_low_bb->getInstList().push_back(op0_low);
- op1_low = new TruncInst(op1, NewIntType);
- cmp_low_bb->getInstList().push_back(op1_low);
+ PHINode *PN = PHINode::Create(Int1Ty, 2, "");
- icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
- cmp_low_bb->getInstList().push_back(icmp_low);
- BranchInst::Create(end_bb, cmp_low_bb);
+ PN->addIncoming(icmp_usign_cmp, middle_bb);
+ PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
- BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
+ BasicBlock::iterator ii(IcmpInst);
+ ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
- PN = PHINode::Create(Int1Ty, 3);
- PN->addIncoming(icmp_low, cmp_low_bb);
- PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
- PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
- }
-
- BasicBlock::iterator ii(cmp_inst);
- ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN);
-
- // We split the comparison into low and high. If this isn't our target
- // bitwidth we recursivly split the low and high parts again until we have
- // target bitwidth.
- if ((bitw / 2) > target_bitwidth) {
- if (!splitCompare(icmp_high, M)) {
- reportError("Failed to split high comparison", icmp_high, M);
- return false;
- }
- if (!splitCompare(icmp_low, M)) {
- reportError("Failed to split low comparison", icmp_low, M);
- return false;
- }
}
return true;
-}
-bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) {
- std::vector<CmpInst *> worklist;
-
- auto op0 = I->getOperand(0);
- auto op1 = I->getOperand(1);
- if (!op0 || !op1) { return false; }
- auto op0Ty = dyn_cast<IntegerType>(op0->getType());
- if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; }
-
- unsigned bitw = op0Ty->getBitWidth();
-
-#ifdef VERIFY_TOO_MUCH
- auto F = I->getParent()->getParent();
-#endif
-
- // we run the comparison simplification on all compares regardless of their
- // bitwidth.
- if (I->getPredicate() == CmpInst::ICMP_UGE ||
- I->getPredicate() == CmpInst::ICMP_SGE ||
- I->getPredicate() == CmpInst::ICMP_ULE ||
- I->getPredicate() == CmpInst::ICMP_SLE) {
- if (!simplifyOrEqualsCompare(I, M, worklist)) {
- reportError(
- "Failed to simplify inequality or equals comparison "
- "(UGE,SGE,ULE,SLE)",
- I, M);
- }
- } else if (I->getPredicate() == CmpInst::ICMP_SGT ||
- I->getPredicate() == CmpInst::ICMP_SLT) {
- if (!simplifySignedCompare(I, M, worklist)) {
- reportError("Failed to simplify signed comparison (SGT,SLT)", I, M);
- }
- }
-
-#ifdef VERIFY_TOO_MUCH
- if (verifyFunction(*F, &errs())) {
- reportError("simpliyfing compare lead to broken function", nullptr, M);
- }
-#endif
-
- // the simplification methods replace the original CmpInst and push the
- // resulting new CmpInst into the worklist. If the worklist is empty then
- // we only have to split the original CmpInst.
- if (worklist.size() == 0) { worklist.push_back(I); }
-
- for (auto cmp : worklist) {
- // we split the simplified compares into comparisons with smaller bitwidths
- // if they are larger than our target_bitwidth.
- if (bitw > target_bitwidth) {
- if (!splitCompare(cmp, M)) {
- reportError("Failed to split comparison", cmp, M);
- }
-
-#ifdef VERIFY_TOO_MUCH
- if (verifyFunction(*F, &errs())) {
- reportError("splitting compare lead to broken function", nullptr, M);
- }
-#endif
- }
- }
-
- count++;
- return true;
}
size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
+
--in;
in |= in >> 1;
in |= in >> 2;
@@ -679,10 +520,12 @@
// in |= in >> 8;
// in |= in >> 16;
return in + 1;
+
}
/* splits fcmps into two nested fcmps with sign compare and the rest */
size_t SplitComparesTransform::splitFPCompares(Module &M) {
+
size_t count = 0;
LLVMContext &C = M.getContext();
@@ -694,9 +537,13 @@
/* define unions with floating point and (sign, exponent, mantissa) triples
*/
if (dl.isLittleEndian()) {
+
} else if (dl.isBigEndian()) {
+
} else {
+
return count;
+
}
#endif
@@ -706,13 +553,17 @@
/* get all EQ, NE, GT, and LT fcmps. if the other two
* functions were executed only these four predicates should exist */
for (auto &F : M) {
+
if (!isInInstrumentList(&F)) continue;
for (auto &BB : F) {
+
for (auto &IN : BB) {
+
CmpInst *selectcmpInst = nullptr;
if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+
if (selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ ||
selectcmpInst->getPredicate() == CmpInst::FCMP_UEQ ||
selectcmpInst->getPredicate() == CmpInst::FCMP_ONE ||
@@ -721,6 +572,7 @@
selectcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
selectcmpInst->getPredicate() == CmpInst::FCMP_ULT ||
selectcmpInst->getPredicate() == CmpInst::FCMP_OLT) {
+
auto op0 = selectcmpInst->getOperand(0);
auto op1 = selectcmpInst->getOperand(1);
@@ -732,10 +584,15 @@
if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
fcomps.push_back(selectcmpInst);
+
}
+
}
+
}
+
}
+
}
if (!fcomps.size()) { return count; }
@@ -743,6 +600,7 @@
IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
for (auto &FcmpInst : fcomps) {
+
BasicBlock *bb = FcmpInst->getParent();
auto op0 = FcmpInst->getOperand(0);
@@ -867,6 +725,7 @@
BasicBlock::iterator(signequal_bb->getTerminator()), t_e1);
if (sizeInBits - precision < exTySizeBytes * 8) {
+
m_e0 = BinaryOperator::Create(
Instruction::And, t_e0,
ConstantInt::get(t_e0->getType(), mask_exponent));
@@ -879,8 +738,10 @@
BasicBlock::iterator(signequal_bb->getTerminator()), m_e1);
} else {
+
m_e0 = t_e0;
m_e1 = t_e1;
+
}
/* compare the exponents of the operands */
@@ -888,6 +749,7 @@
Instruction *icmp_exponent_result;
BasicBlock * signequal2_bb = signequal_bb;
switch (FcmpInst->getPredicate()) {
+
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
icmp_exponent_result =
@@ -957,6 +819,7 @@
break;
default:
continue;
+
}
signequal2_bb->getInstList().insert(
@@ -964,9 +827,11 @@
icmp_exponent_result);
{
+
term = signequal2_bb->getTerminator();
switch (FcmpInst->getPredicate()) {
+
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
/* if the exponents are satifying the compare do a fraction cmp in
@@ -989,9 +854,11 @@
break;
default:
continue;
+
}
term->eraseFromParent();
+
}
/* isolate the mantissa aka fraction */
@@ -999,6 +866,7 @@
bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op_size;
if (precision - 1 < frTySizeBytes * 8) {
+
Instruction *m_f0, *m_f1;
m_f0 = BinaryOperator::Create(
Instruction::And, b_op0,
@@ -1012,6 +880,7 @@
BasicBlock::iterator(middle_bb->getTerminator()), m_f1);
if (needTrunc) {
+
t_f0 = new TruncInst(m_f0, IntFractionTy);
t_f1 = new TruncInst(m_f1, IntFractionTy);
middle_bb->getInstList().insert(
@@ -1020,12 +889,16 @@
BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
} else {
+
t_f0 = m_f0;
t_f1 = m_f1;
+
}
} else {
+
if (needTrunc) {
+
t_f0 = new TruncInst(b_op0, IntFractionTy);
t_f1 = new TruncInst(b_op1, IntFractionTy);
middle_bb->getInstList().insert(
@@ -1034,9 +907,12 @@
BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
} else {
+
t_f0 = b_op0;
t_f1 = b_op1;
+
}
+
}
/* compare the fractions of the operands */
@@ -1044,6 +920,7 @@
BasicBlock * middle2_bb = middle_bb;
PHINode * PN2 = nullptr;
switch (FcmpInst->getPredicate()) {
+
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
icmp_fraction_result =
@@ -1066,6 +943,7 @@
case CmpInst::FCMP_UGT:
case CmpInst::FCMP_OLT:
case CmpInst::FCMP_ULT: {
+
Instruction *icmp_fraction_result2;
middle2_bb = middle_bb->splitBasicBlock(
@@ -1078,6 +956,7 @@
if (FcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
FcmpInst->getPredicate() == CmpInst::FCMP_UGT) {
+
negative_bb->getInstList().push_back(
icmp_fraction_result = CmpInst::Create(
Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
@@ -1086,12 +965,14 @@
Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
} else {
+
negative_bb->getInstList().push_back(
icmp_fraction_result = CmpInst::Create(
Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
positive_bb->getInstList().push_back(
icmp_fraction_result2 = CmpInst::Create(
Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
+
}
BranchInst::Create(middle2_bb, negative_bb);
@@ -1111,11 +992,13 @@
default:
continue;
+
}
PHINode *PN = PHINode::Create(Int1Ty, 3, "");
switch (FcmpInst->getPredicate()) {
+
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
/* unequal signs cannot be equal values */
@@ -1154,94 +1037,328 @@
break;
default:
continue;
+
}
BasicBlock::iterator ii(FcmpInst);
ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
++count;
+
}
return count;
+
+}
+
+/* splits icmps of size bitw into two nested icmps with bitw/2 size each */
+size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
+
+ size_t count = 0;
+
+ LLVMContext &C = M.getContext();
+
+ IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
+ IntegerType *OldIntType = IntegerType::get(C, bitw);
+ IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
+
+ std::vector<Instruction *> icomps;
+
+ if (bitw % 2) { return 0; }
+
+ /* not supported yet */
+ if (bitw > 64) { return 0; }
+
+ /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the
+ * functions simplifyCompares() and simplifyIntSignedness()
+ * were executed only these four predicates should exist */
+ for (auto &F : M) {
+
+ if (!isInInstrumentList(&F)) continue;
+
+ for (auto &BB : F) {
+
+ for (auto &IN : BB) {
+
+ CmpInst *selectcmpInst = nullptr;
+
+ if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+
+ if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ ||
+ selectcmpInst->getPredicate() == CmpInst::ICMP_NE ||
+ selectcmpInst->getPredicate() == CmpInst::ICMP_UGT ||
+ selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) {
+
+ auto op0 = selectcmpInst->getOperand(0);
+ auto op1 = selectcmpInst->getOperand(1);
+
+ IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+ IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+
+ if (!intTyOp0 || !intTyOp1) { continue; }
+
+ /* check if the bitwidths are the one we are looking for */
+ if (intTyOp0->getBitWidth() != bitw ||
+ intTyOp1->getBitWidth() != bitw) {
+
+ continue;
+
+ }
+
+ icomps.push_back(selectcmpInst);
+
+ }
+
+ }
+
+ }
+
+ }
+
+ }
+
+ if (!icomps.size()) { return 0; }
+
+ for (auto &IcmpInst : icomps) {
+
+ BasicBlock *bb = IcmpInst->getParent();
+
+ auto op0 = IcmpInst->getOperand(0);
+ auto op1 = IcmpInst->getOperand(1);
+
+ auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
+ if (!cmp_inst) { continue; }
+ auto pred = cmp_inst->getPredicate();
+
+ BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+
+ /* create the comparison of the top halves of the original operands */
+ Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high;
+
+ s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
+ ConstantInt::get(OldIntType, bitw / 2));
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
+ op0_high = new TruncInst(s_op0, NewIntType);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
+ op0_high);
+
+ s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
+ ConstantInt::get(OldIntType, bitw / 2));
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
+ op1_high = new TruncInst(s_op1, NewIntType);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
+ op1_high);
+
+ icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
+ icmp_high);
+
+ /* now we have to destinguish between == != and > < */
+ if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) {
+
+ /* transformation for == and != icmps */
+
+ /* create a compare for the lower half of the original operands */
+ Instruction *op0_low, *op1_low, *icmp_low;
+ BasicBlock * cmp_low_bb =
+ BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+
+ op0_low = new TruncInst(op0, NewIntType);
+ cmp_low_bb->getInstList().push_back(op0_low);
+
+ op1_low = new TruncInst(op1, NewIntType);
+ cmp_low_bb->getInstList().push_back(op1_low);
+
+ icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
+ cmp_low_bb->getInstList().push_back(icmp_low);
+ BranchInst::Create(end_bb, cmp_low_bb);
+
+ /* dependent on the cmp of the high parts go to the end or go on with
+ * the comparison */
+ auto term = bb->getTerminator();
+ if (pred == CmpInst::ICMP_EQ) {
+
+ BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
+
+ } else {
+
+ /* CmpInst::ICMP_NE */
+ BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
+
+ }
+
+ term->eraseFromParent();
+
+ /* create the PHI and connect the edges accordingly */
+ PHINode *PN = PHINode::Create(Int1Ty, 2, "");
+ PN->addIncoming(icmp_low, cmp_low_bb);
+ if (pred == CmpInst::ICMP_EQ) {
+
+ PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb);
+
+ } else {
+
+ /* CmpInst::ICMP_NE */
+ PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+
+ }
+
+ /* replace the old icmp with the new PHI */
+ BasicBlock::iterator ii(IcmpInst);
+ ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+
+ } else {
+
+ /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */
+ /* transformations for < and > */
+
+ /* create a basic block which checks for the inverse predicate.
+ * if this is true we can go to the end if not we have to go to the
+ * bb which checks the lower half of the operands */
+ Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low;
+ BasicBlock * inv_cmp_bb =
+ BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
+ if (pred == CmpInst::ICMP_UGT) {
+
+ icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
+ op0_high, op1_high);
+
+ } else {
+
+ icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
+ op0_high, op1_high);
+
+ }
+
+ inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
+
+ auto term = bb->getTerminator();
+ term->eraseFromParent();
+ BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
+
+ /* create a bb which handles the cmp of the lower halves */
+ BasicBlock *cmp_low_bb =
+ BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+ op0_low = new TruncInst(op0, NewIntType);
+ cmp_low_bb->getInstList().push_back(op0_low);
+ op1_low = new TruncInst(op1, NewIntType);
+ cmp_low_bb->getInstList().push_back(op1_low);
+
+ icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
+ cmp_low_bb->getInstList().push_back(icmp_low);
+ BranchInst::Create(end_bb, cmp_low_bb);
+
+ BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
+
+ PHINode *PN = PHINode::Create(Int1Ty, 3);
+ PN->addIncoming(icmp_low, cmp_low_bb);
+ PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+ PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
+
+ BasicBlock::iterator ii(IcmpInst);
+ ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+
+ }
+
+ ++count;
+
+ }
+
+ return count;
+
}
bool SplitComparesTransform::runOnModule(Module &M) {
+
+ int bitw = 64;
+ size_t count = 0;
+
char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
- if (bitw_env) { target_bitwidth = atoi(bitw_env); }
+ if (bitw_env) { bitw = atoi(bitw_env); }
enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL;
if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
getenv("AFL_DEBUG") != NULL) {
- errs() << "Split-compare-pass by laf.intel@gmail.com, extended by "
- "heiko@hexco.de (splitting icmp to "
- << target_bitwidth << " bit)\n";
- if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; }
+ printf(
+ "Split-compare-pass by laf.intel@gmail.com, extended by "
+ "heiko@hexco.de\n");
} else {
+
be_quiet = 1;
+
}
if (enableFPSplit) {
+
count = splitFPCompares(M);
/*
if (!be_quiet) {
+
errs() << "Split-floatingpoint-compare-pass: " << count
<< " FP comparisons split\n";
+
}
+
*/
simplifyFPCompares(M);
+
}
- std::vector<CmpInst *> worklist;
- /* iterate over all functions, bbs and instruction search for all integer
- * compare instructions. Save them into the worklist for later. */
- for (auto &F : M) {
- if (!isInInstrumentList(&F)) continue;
+ simplifyCompares(M);
- for (auto &BB : F) {
- for (auto &IN : BB) {
- if (auto CI = dyn_cast<CmpInst>(&IN)) {
- auto op0 = CI->getOperand(0);
- auto op1 = CI->getOperand(1);
- if (!op0 || !op1) { return false; }
- auto iTy1 = dyn_cast<IntegerType>(op0->getType());
- if (iTy1 && isa<IntegerType>(op1->getType())) {
- unsigned bitw = iTy1->getBitWidth();
- if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); }
- }
- }
- }
- }
+ simplifyIntSignedness(M);
+
+ switch (bitw) {
+
+ case 64:
+ count += splitIntCompares(M, bitw);
+ if (debug)
+ errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
+ << " split\n";
+ bitw >>= 1;
+#if LLVM_VERSION_MAJOR > 3 || \
+ (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
+ [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */
+#endif
+ case 32:
+ count += splitIntCompares(M, bitw);
+ if (debug)
+ errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
+ << " split\n";
+ bitw >>= 1;
+#if LLVM_VERSION_MAJOR > 3 || \
+ (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
+ [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */
+#endif
+ case 16:
+ count += splitIntCompares(M, bitw);
+ if (debug)
+ errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
+ << " split\n";
+ // bitw >>= 1;
+ break;
+
+ default:
+ // if (!be_quiet) errs() << "NOT Running split-compare-pass \n";
+ return false;
+ break;
+
}
- // now that we have a list of all integer comparisons we can start replacing
- // them with the splitted alternatives.
- for (auto CI : worklist) {
- simplifyAndSplit(CI, M);
- }
-
- bool brokenDebug = false;
- if (verifyModule(M, &errs(), &brokenDebug)) {
- reportError(
- "Module Verifier failed! Consider reporting a bug with the AFL++ "
- "project.",
- nullptr, M);
- }
-
- if (brokenDebug) {
- reportError("Module Verifier reported broken Debug Infos - Stripping!",
- nullptr, M);
- StripDebugInfo(M);
- }
+ verifyModule(M);
return true;
+
}
static void registerSplitComparesPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) {
+
PM.add(new SplitComparesTransform());
+
}
static RegisterStandardPasses RegisterSplitComparesPass(
@@ -1256,7 +1373,3 @@
registerSplitComparesPass);
#endif
-static RegisterPass<SplitComparesTransform> X("splitcompares",
- "AFL++ split compares",
- true /* Only looks at CFG */,
- true /* Analysis Pass */);
diff --git a/test/test-int_cases.c b/test/test-int_cases.c
deleted file mode 100644
index c76206c..0000000
--- a/test/test-int_cases.c
+++ /dev/null
@@ -1,424 +0,0 @@
-/* test cases for integer comparison transformations
- * compile with -DINT_TYPE="signed char"
- * or -DINT_TYPE="short"
- * or -DINT_TYPE="int"
- * or -DINT_TYPE="long"
- * or -DINT_TYPE="long long"
- */
-
-#include <assert.h>
-
-int main() {
-
- volatile INT_TYPE a, b;
- /* different values */
- a = -21;
- b = -2; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 1;
- b = 8; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 255) { /* short or bigger */
- volatile short a, b;
- a = 2;
- b = 256+1; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1 - 256;
- b = -8; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 65535) { /* int or bigger */
- volatile int a, b;
- a = 2;
- b = 65536+1; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1 - 65536;
- b = -8; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 4294967295) { /* long or bigger */
- volatile long a, b;
- a = 2;
- b = 4294967296+1; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1 - 4294967296;
- b = -8; /* signs equal */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- }
- }
- }
-
- a = -1;
- b = 1; /* signs differ */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = 0; /* signs differ */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -2;
- b = 8; /* signs differ */
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = -2; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 8;
- b = 1; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 255) {
- volatile short a, b;
- a = 1 + 256;
- b = 3; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = -256; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 65535) {
- volatile int a, b;
- a = 1 + 65536;
- b = 3; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = -65536; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 4294967295) {
- volatile long a, b;
- a = 1 + 4294967296;
- b = 3; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = -4294967296; /* signs equal */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
- }
- }
- }
-
- a = 1;
- b = -1; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 0;
- b = -1; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 8;
- b = -2; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 1;
- b = -2; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 255) {
- volatile short a, b;
- a = 1 + 256;
- b = -2; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = -2 - 256; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 65535) {
- volatile int a, b;
- a = 1 + 65536;
- b = -2; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = -2 - 65536; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 4294967295) {
- volatile long a, b;
- a = 1 + 4294967296;
- b = -2; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = -1;
- b = -2 - 4294967296; /* signs differ */
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- }
- }
- }
-
- /* equal values */
- a = 0;
- b = 0;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = -0;
- b = 0;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = 1;
- b = 1;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = 5;
- b = 5;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = -1;
- b = -1;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = -5;
- b = -5;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 255) {
- volatile short a, b;
- a = 1 + 256;
- b = 1 + 256;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = -2 - 256;
- b = -2 - 256;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 65535) {
- volatile int a, b;
- a = 1 + 65536;
- b = 1 + 65536;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = -2 - 65536;
- b = -2 - 65536;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- if ((unsigned)(INT_TYPE)(~0) > 4294967295) {
- volatile long a, b;
- a = 1 + 4294967296;
- b = 1 + 4294967296;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = -2 - 4294967296;
- b = -2 - 4294967296;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- }
- }
- }
-}
-
diff --git a/test/test-uint_cases.c b/test/test-uint_cases.c
deleted file mode 100644
index 8496cff..0000000
--- a/test/test-uint_cases.c
+++ /dev/null
@@ -1,217 +0,0 @@
-/*
- * compile with -DUINT_TYPE="unsigned char"
- * or -DUINT_TYPE="unsigned short"
- * or -DUINT_TYPE="unsigned int"
- * or -DUINT_TYPE="unsigned long"
- * or -DUINT_TYPE="unsigned long long"
- */
-
-#include <assert.h>
-
-int main() {
-
- volatile UINT_TYPE a, b;
-
- a = 1;
- b = 8;
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((UINT_TYPE)(~0) > 255) {
- volatile unsigned short a, b;
- a = 256+2;
- b = 256+21;
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 21;
- b = 256+1;
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((UINT_TYPE)(~0) > 65535) {
- volatile unsigned int a, b;
- a = 65536+2;
- b = 65536+21;
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 21;
- b = 65536+1;
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
- }
-
- if ((UINT_TYPE)(~0) > 4294967295) {
- volatile unsigned long a, b;
- a = 4294967296+2;
- b = 4294967296+21;
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 21;
- b = 4294967296+1;
- assert((a < b));
- assert((a <= b));
- assert(!(a > b));
- assert(!(a >= b));
- assert((a != b));
- assert(!(a == b));
- }
- }
-
- a = 8;
- b = 1;
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((UINT_TYPE)(~0) > 255) {
- volatile unsigned short a, b;
- a = 256+2;
- b = 256+1;
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 256+2;
- b = 6;
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((UINT_TYPE)(~0) > 65535) {
- volatile unsigned int a, b;
- a = 65536+2;
- b = 65536+1;
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 65536+2;
- b = 6;
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- if ((UINT_TYPE)(~0) > 4294967295) {
- volatile unsigned long a, b;
- a = 4294967296+2;
- b = 4294967296+1;
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- a = 4294967296+2;
- b = 6;
- assert((a > b));
- assert((a >= b));
- assert(!(a < b));
- assert(!(a <= b));
- assert((a != b));
- assert(!(a == b));
-
- }
- }
- }
-
-
- a = 0;
- b = 0;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- a = 1;
- b = 1;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- if ((UINT_TYPE)(~0) > 255) {
- volatile unsigned short a, b;
- a = 256+5;
- b = 256+5;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- if ((UINT_TYPE)(~0) > 65535) {
- volatile unsigned int a, b;
- a = 65536+5;
- b = 65536+5;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
-
- if ((UINT_TYPE)(~0) > 4294967295) {
- volatile unsigned long a, b;
- a = 4294967296+5;
- b = 4294967296+5;
- assert(!(a < b));
- assert((a <= b));
- assert(!(a > b));
- assert((a >= b));
- assert(!(a != b));
- assert((a == b));
- }
- }
-
- }
-
-}
-
diff --git a/utils/crash_triage/triage_crashes.sh b/utils/crash_triage/triage_crashes.sh
index 9ca1d5f..4d75430 100755
--- a/utils/crash_triage/triage_crashes.sh
+++ b/utils/crash_triage/triage_crashes.sh
@@ -65,11 +65,7 @@
fi
if [ ! -d "$DIR/queue" ]; then
-<<<<<<< Updated upstream
echo "[-] Error: directory '$DIR' not found or not created by afl-fuzz." 1>&2
-=======
- echo "[-] Error: directory '$DIR/queue' not found or not created by afl-fuzz." 1>&2
->>>>>>> Stashed changes
exit 1
fi