| /* |
| * Copyright 2016 laf-intel |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include <stdio.h> |
| #include <stdlib.h> |
| #include <unistd.h> |
| |
| #include <list> |
| #include <string> |
| #include <fstream> |
| #include <sys/time.h> |
| #include "llvm/Config/llvm-config.h" |
| |
| #include "llvm/ADT/Statistic.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/LegacyPassManager.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "llvm/Transforms/IPO/PassManagerBuilder.h" |
| #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| #include "llvm/Pass.h" |
| #include "llvm/Analysis/ValueTracking.h" |
| |
| #if LLVM_VERSION_MAJOR > 3 || \ |
| (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) |
| #include "llvm/IR/Verifier.h" |
| #include "llvm/IR/DebugInfo.h" |
| #else |
| #include "llvm/Analysis/Verifier.h" |
| #include "llvm/DebugInfo.h" |
| #define nullptr 0 |
| #endif |
| |
| #include <set> |
| #include "afl-llvm-common.h" |
| |
| using namespace llvm; |
| |
| namespace { |
| |
| class CompareTransform : public ModulePass { |
| |
| public: |
| static char ID; |
| CompareTransform() : ModulePass(ID) { |
| |
| initWhitelist(); |
| |
| } |
| |
| bool runOnModule(Module &M) override; |
| |
| #if LLVM_VERSION_MAJOR < 4 |
| const char *getPassName() const override { |
| |
| #else |
| StringRef getPassName() const override { |
| |
| #endif |
| return "transforms compare functions"; |
| |
| } |
| |
| protected: |
| int be_quiet = 0; |
| |
| private: |
| bool transformCmps(Module &M, const bool processStrcmp, |
| const bool processMemcmp, const bool processStrncmp, |
| const bool processStrcasecmp, |
| const bool processStrncasecmp); |
| |
| }; |
| |
| } // namespace |
| |
| char CompareTransform::ID = 0; |
| |
| bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, |
| const bool processMemcmp, |
| const bool processStrncmp, |
| const bool processStrcasecmp, |
| const bool processStrncasecmp) { |
| |
| DenseMap<Value *, std::string *> valueMap; |
| std::vector<CallInst *> calls; |
| LLVMContext & C = M.getContext(); |
| IntegerType * Int8Ty = IntegerType::getInt8Ty(C); |
| IntegerType * Int32Ty = IntegerType::getInt32Ty(C); |
| IntegerType * Int64Ty = IntegerType::getInt64Ty(C); |
| |
| #if LLVM_VERSION_MAJOR < 9 |
| Constant * |
| #else |
| FunctionCallee |
| #endif |
| c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty |
| #if LLVM_VERSION_MAJOR < 5 |
| , |
| NULL |
| #endif |
| ); |
| #if LLVM_VERSION_MAJOR < 9 |
| Function *tolowerFn = cast<Function>(c); |
| #else |
| FunctionCallee tolowerFn = c; |
| #endif |
| |
| /* iterate over all functions, bbs and instruction and add suitable calls to |
| * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ |
| for (auto &F : M) { |
| |
| if (!isInWhitelist(&F)) continue; |
| |
| for (auto &BB : F) { |
| |
| for (auto &IN : BB) { |
| |
| CallInst *callInst = nullptr; |
| |
| if ((callInst = dyn_cast<CallInst>(&IN))) { |
| |
| bool isStrcmp = processStrcmp; |
| bool isMemcmp = processMemcmp; |
| bool isStrncmp = processStrncmp; |
| bool isStrcasecmp = processStrcasecmp; |
| bool isStrncasecmp = processStrncasecmp; |
| bool isIntMemcpy = true; |
| bool indirect = false; |
| |
| Function *Callee = callInst->getCalledFunction(); |
| if (!Callee) continue; |
| if (callInst->getCallingConv() != llvm::CallingConv::C) continue; |
| StringRef FuncName = Callee->getName(); |
| isStrcmp &= !FuncName.compare(StringRef("strcmp")); |
| isMemcmp &= !FuncName.compare(StringRef("memcmp")); |
| isStrncmp &= !FuncName.compare(StringRef("strncmp")); |
| isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp")); |
| isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp")); |
| isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64"); |
| |
| if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && |
| !isStrncasecmp && !isIntMemcpy) |
| continue; |
| |
| /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function |
| * prototype */ |
| FunctionType *FT = Callee->getFunctionType(); |
| |
| isStrcmp &= |
| FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && |
| FT->getParamType(0) == FT->getParamType(1) && |
| FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); |
| isStrcasecmp &= |
| FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && |
| FT->getParamType(0) == FT->getParamType(1) && |
| FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); |
| isMemcmp &= FT->getNumParams() == 3 && |
| FT->getReturnType()->isIntegerTy(32) && |
| FT->getParamType(0)->isPointerTy() && |
| FT->getParamType(1)->isPointerTy() && |
| FT->getParamType(2)->isIntegerTy(); |
| isStrncmp &= FT->getNumParams() == 3 && |
| FT->getReturnType()->isIntegerTy(32) && |
| FT->getParamType(0) == FT->getParamType(1) && |
| FT->getParamType(0) == |
| IntegerType::getInt8PtrTy(M.getContext()) && |
| FT->getParamType(2)->isIntegerTy(); |
| isStrncasecmp &= FT->getNumParams() == 3 && |
| FT->getReturnType()->isIntegerTy(32) && |
| FT->getParamType(0) == FT->getParamType(1) && |
| FT->getParamType(0) == |
| IntegerType::getInt8PtrTy(M.getContext()) && |
| FT->getParamType(2)->isIntegerTy(); |
| |
| if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && |
| !isStrncasecmp && !isIntMemcpy) |
| continue; |
| |
| /* is a str{n,}{case,}cmp/memcmp, check if we have |
| * str{case,}cmp(x, "const") or str{case,}cmp("const", x) |
| * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..) |
| * memcmp(x, "const", ..) or memcmp("const", x, ..) */ |
| Value *Str1P = callInst->getArgOperand(0), |
| *Str2P = callInst->getArgOperand(1); |
| StringRef Str1, Str2; |
| bool HasStr1 = getConstantStringInfo(Str1P, Str1); |
| bool HasStr2 = getConstantStringInfo(Str2P, Str2); |
| |
| if (isIntMemcpy && HasStr2) { |
| |
| valueMap[Str1P] = new std::string(Str2.str()); |
| // fprintf(stderr, "saved %s for %p\n", Str2.str().c_str(), Str1P); |
| continue; |
| |
| } |
| |
| // not literal? maybe global or local variable |
| if (!(HasStr1 || HasStr2)) { |
| |
| auto *Ptr = dyn_cast<ConstantExpr>(Str2P); |
| if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) { |
| |
| if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) { |
| |
| if (Var->hasInitializer()) { |
| |
| if (auto *Array = |
| dyn_cast<ConstantDataArray>(Var->getInitializer())) { |
| |
| HasStr2 = true; |
| Str2 = Array->getAsString(); |
| valueMap[Str2P] = new std::string(Str2.str()); |
| fprintf(stderr, "glo2 %s\n", Str2.str().c_str()); |
| |
| } |
| |
| } |
| |
| } |
| |
| } |
| |
| if (!HasStr2) { |
| |
| auto *Ptr = dyn_cast<ConstantExpr>(Str1P); |
| if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) { |
| |
| if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) { |
| |
| if (Var->hasInitializer()) { |
| |
| if (auto *Array = dyn_cast<ConstantDataArray>( |
| Var->getInitializer())) { |
| |
| HasStr1 = true; |
| Str1 = Array->getAsString(); |
| valueMap[Str1P] = new std::string(Str1.str()); |
| // fprintf(stderr, "glo1 %s\n", Str1.str().c_str()); |
| |
| } |
| |
| } |
| |
| } |
| |
| } |
| |
| } else if (isIntMemcpy) { |
| |
| valueMap[Str1P] = new std::string(Str2.str()); |
| // fprintf(stderr, "saved\n"); |
| |
| } |
| |
| if ((HasStr1 || HasStr2)) indirect = true; |
| |
| } |
| |
| if (isIntMemcpy) continue; |
| |
| if (!(HasStr1 || HasStr2)) { |
| |
| // do we have a saved local variable initialization? |
| std::string *val = valueMap[Str1P]; |
| if (val && !val->empty()) { |
| |
| Str1 = StringRef(*val); |
| HasStr1 = true; |
| indirect = true; |
| // fprintf(stderr, "loaded1 %s\n", Str1.str().c_str()); |
| |
| } else { |
| |
| val = valueMap[Str2P]; |
| if (val && !val->empty()) { |
| |
| Str2 = StringRef(*val); |
| HasStr2 = true; |
| indirect = true; |
| // fprintf(stderr, "loaded2 %s\n", Str2.str().c_str()); |
| |
| } |
| |
| } |
| |
| } |
| |
| /* handle cases of one string is const, one string is variable */ |
| if (!(HasStr1 || HasStr2)) continue; |
| |
| if (isMemcmp || isStrncmp || isStrncasecmp) { |
| |
| /* check if third operand is a constant integer |
| * strlen("constStr") and sizeof() are treated as constant */ |
| Value * op2 = callInst->getArgOperand(2); |
| ConstantInt *ilen = dyn_cast<ConstantInt>(op2); |
| if (ilen) { |
| |
| uint64_t len = ilen->getZExtValue(); |
| // if len is zero this is a pointless call but allow real |
| // implementation to worry about that |
| if (!len) continue; |
| |
| if (isMemcmp) { |
| |
| // if size of compare is larger than constant string this is |
| // likely a bug but allow real implementation to worry about |
| // that |
| uint64_t literalLength = HasStr1 ? Str1.size() : Str2.size(); |
| if (literalLength + 1 < ilen->getZExtValue()) continue; |
| |
| } |
| |
| } else if (isMemcmp) |
| |
| // this *may* supply a len greater than the constant string at |
| // runtime so similarly we don't want to have to handle that |
| continue; |
| |
| } |
| |
| calls.push_back(callInst); |
| |
| } |
| |
| } |
| |
| } |
| |
| } |
| |
| if (!calls.size()) return false; |
| if (!be_quiet) |
| errs() << "Replacing " << calls.size() |
| << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n"; |
| |
| for (auto &callInst : calls) { |
| |
| Value *Str1P = callInst->getArgOperand(0), |
| *Str2P = callInst->getArgOperand(1); |
| StringRef Str1, Str2, ConstStr; |
| std::string TmpConstStr; |
| Value * VarStr; |
| bool HasStr1 = getConstantStringInfo(Str1P, Str1); |
| bool HasStr2 = getConstantStringInfo(Str2P, Str2); |
| uint64_t constStrLen, constSizedLen, unrollLen; |
| bool isMemcmp = |
| !callInst->getCalledFunction()->getName().compare(StringRef("memcmp")); |
| bool isSizedcmp = isMemcmp || |
| !callInst->getCalledFunction()->getName().compare( |
| StringRef("strncmp")) || |
| !callInst->getCalledFunction()->getName().compare( |
| StringRef("strncasecmp")); |
| Value *sizedValue = isSizedcmp ? callInst->getArgOperand(2) : NULL; |
| bool isConstSized = sizedValue && isa<ConstantInt>(sizedValue); |
| bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare( |
| StringRef("strcasecmp")) || |
| !callInst->getCalledFunction()->getName().compare( |
| StringRef("strncasecmp")); |
| |
| if (!(HasStr1 || HasStr2)) { |
| |
| // do we have a saved local or global variable initialization? |
| std::string *val = valueMap[Str1P]; |
| if (val && !val->empty()) { |
| |
| Str1 = StringRef(*val); |
| HasStr1 = true; |
| |
| } else { |
| |
| val = valueMap[Str2P]; |
| if (val && !val->empty()) { |
| |
| Str2 = StringRef(*val); |
| HasStr2 = true; |
| |
| } |
| |
| } |
| |
| } |
| |
| if (isConstSized) { |
| |
| constSizedLen = dyn_cast<ConstantInt>(sizedValue)->getZExtValue(); |
| |
| } |
| |
| if (HasStr1) { |
| |
| TmpConstStr = Str1.str(); |
| VarStr = Str2P; |
| |
| } else { |
| |
| TmpConstStr = Str2.str(); |
| VarStr = Str1P; |
| |
| } |
| |
| // add null termination character implicit in c strings |
| TmpConstStr.append("\0", 1); |
| |
| // in the unusual case the const str has embedded null |
| // characters, the string comparison functions should terminate |
| // at the first null |
| if (!isMemcmp) |
| TmpConstStr.assign(TmpConstStr, 0, TmpConstStr.find('\0') + 1); |
| |
| constStrLen = TmpConstStr.length(); |
| // prefer use of StringRef (in comparison to std::string a StringRef has |
| // built-in runtime bounds checking, which makes debugging easier) |
| ConstStr = StringRef(TmpConstStr); |
| |
| if (isConstSized) |
| unrollLen = constSizedLen < constStrLen ? constSizedLen : constStrLen; |
| else |
| unrollLen = constStrLen; |
| |
| if (!be_quiet) |
| errs() << callInst->getCalledFunction()->getName() << ": unroll len " |
| << unrollLen |
| << ((isSizedcmp && !isConstSized) ? ", variable n" : "") << ": " |
| << ConstStr << "\n"; |
| |
| /* split before the call instruction */ |
| BasicBlock *bb = callInst->getParent(); |
| BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst)); |
| |
| BasicBlock *next_lenchk_bb = NULL; |
| if (isSizedcmp && !isConstSized) { |
| |
| next_lenchk_bb = |
| BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb); |
| BranchInst::Create(end_bb, next_lenchk_bb); |
| |
| } |
| |
| BasicBlock *next_cmp_bb = |
| BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); |
| BranchInst::Create(end_bb, next_cmp_bb); |
| PHINode *PN = PHINode::Create( |
| Int32Ty, (next_lenchk_bb ? 2 : 1) * unrollLen + 1, "cmp_phi"); |
| |
| #if LLVM_VERSION_MAJOR < 8 |
| TerminatorInst *term = bb->getTerminator(); |
| #else |
| Instruction *term = bb->getTerminator(); |
| #endif |
| BranchInst::Create(next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, bb); |
| term->eraseFromParent(); |
| |
| for (uint64_t i = 0; i < unrollLen; i++) { |
| |
| BasicBlock * cur_cmp_bb = next_cmp_bb, *cur_lenchk_bb = next_lenchk_bb; |
| unsigned char c; |
| |
| if (cur_lenchk_bb) { |
| |
| IRBuilder<> cur_lenchk_IRB(&*(cur_lenchk_bb->getFirstInsertionPt())); |
| Value * icmp = cur_lenchk_IRB.CreateICmpEQ(sizedValue, |
| ConstantInt::get(Int64Ty, i)); |
| cur_lenchk_IRB.CreateCondBr(icmp, end_bb, cur_cmp_bb); |
| cur_lenchk_bb->getTerminator()->eraseFromParent(); |
| |
| PN->addIncoming(ConstantInt::get(Int32Ty, 0), cur_lenchk_bb); |
| |
| } |
| |
| if (isCaseInsensitive) |
| c = (unsigned char)(tolower((int)ConstStr[i]) & 0xff); |
| else |
| c = (unsigned char)ConstStr[i]; |
| |
| IRBuilder<> cur_cmp_IRB(&*(cur_cmp_bb->getFirstInsertionPt())); |
| |
| Value *v = ConstantInt::get(Int64Ty, i); |
| Value *ele = cur_cmp_IRB.CreateInBoundsGEP(VarStr, v, "empty"); |
| Value *load = cur_cmp_IRB.CreateLoad(ele); |
| |
| if (isCaseInsensitive) { |
| |
| // load >= 'A' && load <= 'Z' ? load | 0x020 : load |
| load = cur_cmp_IRB.CreateZExt(load, Int32Ty); |
| std::vector<Value *> args; |
| args.push_back(load); |
| load = cur_cmp_IRB.CreateCall(tolowerFn, args); |
| load = cur_cmp_IRB.CreateTrunc(load, Int8Ty); |
| |
| } |
| |
| Value *isub; |
| if (HasStr1) |
| isub = cur_cmp_IRB.CreateSub(ConstantInt::get(Int8Ty, c), load); |
| else |
| isub = cur_cmp_IRB.CreateSub(load, ConstantInt::get(Int8Ty, c)); |
| |
| Value *sext = cur_cmp_IRB.CreateSExt(isub, Int32Ty); |
| PN->addIncoming(sext, cur_cmp_bb); |
| |
| if (i < unrollLen - 1) { |
| |
| if (cur_lenchk_bb) { |
| |
| next_lenchk_bb = |
| BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb); |
| BranchInst::Create(end_bb, next_lenchk_bb); |
| |
| } |
| |
| next_cmp_bb = |
| BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); |
| BranchInst::Create(end_bb, next_cmp_bb); |
| |
| Value *icmp = |
| cur_cmp_IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0)); |
| cur_cmp_IRB.CreateCondBr( |
| icmp, next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, end_bb); |
| cur_cmp_bb->getTerminator()->eraseFromParent(); |
| |
| } else { |
| |
| // IRB.CreateBr(end_bb); |
| |
| } |
| |
| // add offset to varstr |
| // create load |
| // create signed isub |
| // create icmp |
| // create jcc |
| // create next_bb |
| |
| } |
| |
| /* since the call is the first instruction of the bb it is safe to |
| * replace it with a phi instruction */ |
| BasicBlock::iterator ii(callInst); |
| ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN); |
| |
| } |
| |
| return true; |
| |
| } |
| |
| bool CompareTransform::runOnModule(Module &M) { |
| |
| if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL) |
| llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, " |
| "extended by heiko@hexco.de\n"; |
| else |
| be_quiet = 1; |
| transformCmps(M, true, true, true, true, true); |
| verifyModule(M); |
| |
| return true; |
| |
| } |
| |
| static void registerCompTransPass(const PassManagerBuilder &, |
| legacy::PassManagerBase &PM) { |
| |
| auto p = new CompareTransform(); |
| PM.add(p); |
| |
| } |
| |
| static RegisterStandardPasses RegisterCompTransPass( |
| PassManagerBuilder::EP_OptimizerLast, registerCompTransPass); |
| |
| static RegisterStandardPasses RegisterCompTransPass0( |
| PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass); |
| |