blob: cb29175b1b2521c7b72e9320bd198e7514bc70e5 [file] [log] [blame]
// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// This is a very rough prototype of a utility that extracts syscall descriptions from header files.
// It still needs to extract struct/union descriptions, analyze types more precisely,
// analyze pointer directions (in, out), and figure out len types (usually marked with SAL annotations).
// The easiest way to build it is to build it as part of clang. Add the following lines to CMakeLists.txt:
// +add_clang_executable(syz-declextract syz-declextract/syz-declextract.cpp)
// +target_link_libraries(syz-declextract clangTooling)
// It was used to extract Windows descriptions:
// syz-declextract -extra-arg="--driver-mode=cl" -extra-arg="-I/path/to/windows/headers" Windows.h
#include "clang/AST/AST.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Driver/Options.h"
#include "clang/Frontend/ASTConsumers.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
#include <cstdio>
#include <cstring>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
using namespace clang;
using namespace clang::tooling;
// Converts a clang parameter/return type into a syzkaller type expression.
//
// Mapping, checked in this order:
//   - a type whose printed name is exactly "HANDLE" is passed through as-is;
//   - integral and enumeration types become "int<bits>" (e.g. "int32"),
//     using the type size reported by the ASTContext;
//   - void pointers become an opaque buffer "ptr[inout, array[int8]]";
//   - other pointers become "ptr[inout, <pointee>]", converted recursively;
//   - anything else falls back to "intptr".
// Every pointer is marked inout because pointer-direction analysis is not
// implemented yet (see the TODO list in the file header).
// With the current mapping the result is never empty; callers still guard
// against "" in case new cases are added.
std::string convertType(ASTContext &C, QualType T) {
  std::string name = T.getAsString();
  if (name == "HANDLE")
    return name;
  if (T->isIntegralOrEnumerationType())
    return "int" + std::to_string(C.getTypeSize(T));
  if (T->isVoidPointerType())
    return "ptr[inout, array[int8]]";
  if (T->isPointerType()) {
    const std::string inner = convertType(C, T->getPointeeType());
    if (inner.empty())
      return "ptr[inout, array[int8]]";
    // Built with std::string rather than sprintf into a fixed-size buffer:
    // the pointee description is recursive and has no length bound.
    return "ptr[inout, " + inner + "]";
  }
  return "intptr";
}
// Recursive AST visitor that prints a syzkaller call description for each
// eligible function declaration.
//
// A function is printed when all of the following hold:
//   - the visited declaration has no body;
//   - its name passes the ignore lists in isIgnoredName();
//   - its source location passes the filters in isIgnoredLocation();
//   - every parameter type converts to a non-empty syzkaller type;
//   - a function with the same name has not already been printed.
// Output goes to llvm::outs(), one line per function:
//   Name(p0 type0, p1 type1, ...)[ HANDLE]
class DeclExtractCallVisitor : public RecursiveASTVisitor<DeclExtractCallVisitor> {
public:
  explicit DeclExtractCallVisitor(ASTContext *Context)
      : Context(*Context) {}

  // Invoked by RecursiveASTVisitor for every FunctionDecl. It always returns
  // true so that the traversal continues regardless of the outcome.
  bool VisitFunctionDecl(const FunctionDecl *D) {
    if (D->doesThisDeclarationHaveABody())
      return true;
    // TODO(dvyukov): need to select only stdcall (WINAPI) functions.
    // But the following 2 approaches do not work.
    if (false) {
      if (auto *FPT = D->getType()->getAs<FunctionProtoType>()) {
        if (FPT->getExtInfo().getCC() != CC_X86StdCall)
          return true;
      }
    }
    if (false) {
      if (!D->hasAttr<StdCallAttr>())
        return true;
    }
    const std::string fn = D->getNameInfo().getAsString();
    if (isIgnoredName(fn))
      return true;
    const std::string src =
        D->getSourceRange().getBegin().printToString(Context.getSourceManager());
    if (isIgnoredLocation(src))
      return true;
    // Convert all parameter types up front. If any is unknown, report it
    // and skip the whole function. Each type is converted only once.
    std::vector<std::string> types;
    for (const ParmVarDecl *P : D->parameters()) {
      std::string typ = convertType(Context, P->getType());
      if (typ.empty()) {
        llvm::outs() << fn << ": UNKNOWN TYPE: " <<
            QualType(P->getType()).getAsString() << "\n";
        return true;
      }
      types.push_back(std::move(typ));
    }
    // insert().second is false when the name was already printed. The same
    // name may be visited more than once (e.g. redeclarations).
    if (!Generated.insert(fn).second)
      return true;
    emitDescription(D, fn, types);
    return true;
  }

private:
  // Returns true if the function name matches one of the ignore lists.
  // Tons of functions are bulk ignored here because they cause
  // static/dynamic link failures, reboot the machine, etc.
  static bool isIgnoredName(const std::string &fn) {
    if (fn.empty())
      return true;
    if (fn.back() == 'W') // Unicode versions.
      return true;
    // Matched as name prefixes.
    static const char *const ignore_prefixes[] = {
        "_",
        "Rtl",
        "IBind",
        "Ndr",
        "NDR",
        "SCard",
    };
    for (const char *prefix : ignore_prefixes) {
      if (fn.compare(0, std::strlen(prefix), prefix) == 0)
        return true;
    }
    // Matched as substrings anywhere in the name, so e.g. "shutdown" also
    // rejects every function whose name merely contains it.
    static const char *const ignore_functions[] = {
        "IEnum",
        "IStream",
        "IType",
        "IService",
        "IProperty",
        "ISequential",
        "IDispatch",
        "I_RPC",
        "I_Rpc",
        "CLEANLOCAL",
        "WinMain",
        "PropertySheet",
        "LookupAccountNameLocalA",
        "LookupAccountSidLocalA",
        "WTSGetServiceSessionId",
        "WTSIsServerContainer",
        "GetDisplayAutoRotationPreferencesByProcessId",
        "LoadStringByReference",
        "IdnToNameprepUnicode",
        "VerFindFileA",
        "VerInstallFileA",
        "GetFileVersionInfoSizeA",
        "GetFileVersionInfoA",
        "GetFileVersionInfoSizeExA",
        "GetFileVersionInfoExA",
        "VerQueryValueA",
        "sndOpenSound",
        "Netbios",
        "RpcBindingGetTrainingContextHandle",
        "RpcAsyncCleanupThread",
        "ShellMessageBoxA",
        "SHEnumerateUnreadMailAccountsA",
        "SHGetUnreadMailCountA",
        "SHSetUnreadMailCountA",
        "GetEncSChannel",
        "CryptExportPKCS8Ex",
        "FindCertsByIssuer",
        "CryptCancelAsyncRetrieval",
        "CryptGetTimeValidObject",
        "CryptFlushTimeValidObject",
        "CryptProtectDataNoUI",
        "CryptUnprotectDataNoUI",
        "NsServerBindSearch",
        "NsClientBindSearch",
        "NsClientBindDone",
        "GetOpenCardNameA",
        "SubscribeServiceChangeNotifications",
        "UnsubscribeServiceChangeNotifications",
        "GetThreadDescription",
        "SetThreadDescription",
        "DialogControlDpi",
        "SetDialogDpiChangeBehavior",
        "GetDialogDpiChangeBehavior",
        "RpcServer",
        "DecodePointer",
        "DecodeRemotePointer",
        "DecodeSystemPointer",
        "EncodePointer",
        "EncodeRemotePointer",
        "EncodeSystemPointer",
        "UnmapViewOfFile2",
        "MapViewOfFileNuma2",
        "DeriveCapabilitySidsFromName",
        "QueryAuxiliaryCounterFrequency",
        "ConvertPerformanceCounterToAuxiliaryCounter",
        "ConvertAuxiliaryCounterToPerformanceCounter",
        "FreePropVariantArray",
        "PropVariantCopy",
        "PropVariantClear",
        "InitiateShutdown",
        "ExitWindowsEx",
        "LockWorkStation",
        "InitiateSystemShutdown",
        "InitiateSystemShutdownEx",
        "shutdown",
    };
    for (const char *func : ignore_functions) {
      if (fn.find(func) != std::string::npos)
        return true;
    }
    // These are already described (exact name match):
    static const char *const ignore_exact[] = {
        "CreateFileA",
        "CloseHandle",
        "VirtualAlloc",
    };
    for (const char *func : ignore_exact) {
      if (fn == func)
        return true;
    }
    return false;
  }

  // Returns true if the printed declaration location should be skipped.
  // Only locations containing "/um/" are kept. NOTE(review): presumably the
  // Windows SDK user-mode header directory; confirm against the include paths.
  // The file names below are substring-matched and case-sensitive, which is
  // why both "ole2.h" and "Ole2.h" are listed.
  static bool isIgnoredLocation(const std::string &src) {
    if (src.find("/um/") == std::string::npos)
      return true;
    static const char *const ignore_files[] = {
        "/um/ole",
        "htiface.h",
        "objbase.h",
        "HLink.h",
        "urlmon.h",
        "HlGuids.h",
        "unknwn.h",
        "unknwnbase.h",
        "coguid.h",
        "MsHtmHst.h",
        "msime.h",
        "ComSvcs.h",
        "combaseapi.h",
        "WbemGlue.h",
        "OCIdl.h",
        "mfapi.h",
        "CompPkgSup.h",
        "ole2.h",
        "Ole2.h",
        "oleidl.h",
        "ObjIdl.h",
        "WabDefs.h",
        "objidl.h",
    };
    for (const char *file : ignore_files) {
      if (src.find(file) != std::string::npos)
        return true;
    }
    return false;
  }

  // Prints "fn(name0 type0, name1 type1, ...)" followed by " HANDLE" when the
  // return type converts to "HANDLE", then a newline. Unnamed parameters are
  // printed as "arg<index>". `types[i]` is the converted type of parameter i.
  void emitDescription(const FunctionDecl *D, const std::string &fn,
                       const std::vector<std::string> &types) {
    // Only the first 9 parameters are printed; any further ones are dropped
    // silently, so such a description is truncated. NOTE(review): presumably
    // matches syzkaller's maximum syscall argument count; confirm.
    const size_t MaxArgs = 9;
    const auto params = D->parameters();
    llvm::outs() << fn << "(";
    for (size_t i = 0; i < params.size() && i < MaxArgs; i++) {
      if (i)
        llvm::outs() << ", ";
      std::string name = params[i]->getNameAsString();
      if (name.empty())
        name = "arg" + std::to_string(i);
      llvm::outs() << name << " " << types[i];
    }
    llvm::outs() << ")";
    const std::string ret = convertType(Context, D->getReturnType());
    if (ret == "HANDLE")
      llvm::outs() << " " << ret;
    llvm::outs() << "\n";
  }

  ASTContext &Context;
  // Names of functions that have already been printed.
  std::set<std::string> Generated;
};
// AST consumer that runs DeclExtractCallVisitor over the entire translation
// unit once the frontend hands it over.
class DeclExtractCallConsumer : public clang::ASTConsumer {
public:
  explicit DeclExtractCallConsumer(ASTContext *Context)
      : Visitor(Context) {}
  // `override` turns any signature drift from ASTConsumer into a compile
  // error instead of a silently unused overload.
  void HandleTranslationUnit(clang::ASTContext &Context) override {
    Visitor.TraverseDecl(Context.getTranslationUnitDecl());
  }
private:
  DeclExtractCallVisitor Visitor;
};
// Frontend action that attaches a DeclExtractCallConsumer to every input file
// processed by the ClangTool.
class DeclExtractCallAction : public clang::ASTFrontendAction {
public:
  DeclExtractCallAction() {}
  // `override` makes the compiler verify this matches the
  // FrontendAction::CreateASTConsumer signature. InFile is unused.
  std::unique_ptr<clang::ASTConsumer> CreateASTConsumer(
      clang::CompilerInstance &Compiler, llvm::StringRef InFile) override {
    return std::unique_ptr<clang::ASTConsumer>(
        new DeclExtractCallConsumer(&Compiler.getASTContext()));
  }
};
static llvm::cl::OptionCategory MyToolCategory("my-tool options");
int main(int argc, const char **argv) {
CommonOptionsParser OptionsParser(argc, argv, MyToolCategory);
ClangTool Tool(OptionsParser.getCompilations(),
OptionsParser.getSourcePathList());
return Tool.run(newFrontendActionFactory<DeclExtractCallAction>().get());
}