blob: db37fe4ab6fd9549fe5612fa50fa5dbe2bc5687a [file] [log] [blame]
// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"os"
"regexp"
"strconv"
"strings"
"text/template"
"github.com/google/syzkaller/pkg/compiler"
"github.com/google/syzkaller/pkg/osutil"
)
// undeclaredIdentRes match compiler diagnostics (GCC with Unicode or ASCII
// quotes, and Clang) that name an identifier the compiler could not resolve.
// The first submatch is the identifier. They are compiled once at package
// initialization instead of on every extract call.
var undeclaredIdentRes = []*regexp.Regexp{
	regexp.MustCompile("error: ‘([a-zA-Z0-9_]+)’ undeclared"),
	regexp.MustCompile("error: '([a-zA-Z0-9_]+)' undeclared"),
	regexp.MustCompile("note: in expansion of macro ‘([a-zA-Z0-9_]+)’"),
	regexp.MustCompile("error: use of undeclared identifier '([a-zA-Z0-9_]+)'"),
}

// extract compiles and runs a C program that prints the numeric values of
// info.Consts, and returns those values keyed by const name.
//
// cc is the compiler to invoke and args are extra compiler flags. addSource is
// pasted verbatim into the generated program, and declarePrintf makes the
// program carry an explicit printf prototype.
//
// If the first compilation fails, consts that the compiler reports as
// undeclared are dropped and compilation is retried once. The second return
// value is the set of dropped consts; they do not appear in the value map.
func extract(info *compiler.ConstInfo, cc string, args []string, addSource string, declarePrintf bool) (
	map[string]uint64, map[string]bool, error) {
	data := &CompileData{
		AddSource:     addSource,
		Defines:       info.Defines,
		Includes:      info.Includes,
		Values:        info.Consts,
		DeclarePrintf: declarePrintf,
	}
	undeclared := make(map[string]bool)
	bin, out, err := compile(cc, args, data)
	if err != nil {
		// Some consts and syscall numbers are not defined on some archs.
		// Figure out from compiler output undefined consts,
		// and try to compile again without them.
		valMap := make(map[string]bool, len(info.Consts))
		for _, val := range info.Consts {
			valMap[val] = true
		}
		for _, re := range undeclaredIdentRes {
			for _, match := range re.FindAllSubmatch(out, -1) {
				// Only consts we asked for count; other identifiers in the
				// diagnostics (e.g. from headers) are not ours to drop.
				if val := string(match[1]); valMap[val] {
					undeclared[val] = true
				}
			}
		}
		if len(undeclared) == 0 {
			// Nothing recognized: a retry would compile exactly the same
			// program and fail the same way, so report the failure now.
			return nil, nil, fmt.Errorf("failed to run compiler: %v\n%v", err, string(out))
		}
		// Build a fresh slice so info.Consts is never modified.
		values := make([]string, 0, len(info.Consts))
		for _, v := range info.Consts {
			if !undeclared[v] {
				values = append(values, v)
			}
		}
		data.Values = values
		bin, out, err = compile(cc, args, data)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to run compiler: %v\n%v", err, string(out))
		}
	}
	defer os.Remove(bin)
	out, err = osutil.Command(bin).CombinedOutput()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run flags binary: %v\n%v", err, string(out))
	}
	// The program prints values separated by single spaces with no trailing
	// newline; Fields also yields nil for empty output (no values requested).
	flagVals := strings.Fields(string(out))
	if len(flagVals) != len(data.Values) {
		return nil, nil, fmt.Errorf("fetched wrong number of values %v, want %v",
			len(flagVals), len(data.Values))
	}
	res := make(map[string]uint64, len(data.Values))
	for i, name := range data.Values {
		val := flagVals[i]
		n, err := strconv.ParseUint(val, 10, 64)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to parse value: %v (%v)", err, val)
		}
		res[name] = n
	}
	return res, undeclared, nil
}
// CompileData is the input rendered by srcTemplate. It describes the C
// program that extract builds in order to print constant values.
type CompileData struct {
	// AddSource is inserted verbatim after the fallback defines.
	AddSource string
	// Defines maps macro names to values. Each one is emitted as a
	// #define guarded by #ifndef, so headers take precedence.
	Defines map[string]string
	// Includes are header names emitted as #include <...>. Values are the
	// C expressions (const names) whose numeric values the program prints.
	Includes, Values []string
	// DeclarePrintf adds an explicit printf prototype to the program.
	DeclarePrintf bool
}
// compile renders srcTemplate with data, feeds the resulting C source to cc
// on stdin, and writes the executable to a new temporary file.
//
// args are extra compiler flags and are not modified. On success it returns
// the path of the built binary (the caller is responsible for removing it)
// and a nil out. If the compiler fails, the temporary file is removed and out
// holds the compiler's combined stdout/stderr so the caller can inspect the
// diagnostics.
func compile(cc string, args []string, data *CompileData) (bin string, out []byte, err error) {
	src := new(bytes.Buffer)
	if err := srcTemplate.Execute(src, data); err != nil {
		return "", nil, fmt.Errorf("failed to generate source: %v", err)
	}
	binFile, err := osutil.TempFile("syz-extract-bin")
	if err != nil {
		return "", nil, err
	}
	// Build the command line in a fresh slice. Appending to args directly could
	// write into the caller's backing array when it has spare capacity, and
	// extract reuses the same args for its retry.
	cmdArgs := make([]string, 0, len(args)+6)
	cmdArgs = append(cmdArgs, args...)
	cmdArgs = append(cmdArgs,
		"-x", "c", "-", // source language is C, read from stdin
		"-o", binFile,
		"-w", // suppress warnings
	)
	cmd := osutil.Command(cc, cmdArgs...)
	cmd.Stdin = src
	if out, err := cmd.CombinedOutput(); err != nil {
		os.Remove(binFile)
		return "", out, err
	}
	return binFile, nil, nil
}
// srcTemplateText is the C program that extract compiles from a CompileData.
// It defines __asm__(...) to expand to nothing, includes every header in
// Includes, and supplies each entry of Defines as a macro only when the
// headers did not already define it. It then appends AddSource and, when
// DeclarePrintf is set, a printf prototype. main prints every expression in
// Values with %llu, separated by single spaces and without a trailing newline.
const srcTemplateText = `
#define __asm__(...)
{{range $incl := $.Includes}}
#include <{{$incl}}>
{{end}}
{{range $name, $val := $.Defines}}
#ifndef {{$name}}
# define {{$name}} {{$val}}
#endif
{{end}}
{{.AddSource}}
{{if .DeclarePrintf}}
int printf(const char *format, ...);
{{end}}
int main() {
int i;
unsigned long long vals[] = {
{{range $val := $.Values}}(unsigned long long){{$val}},
{{end}}
};
for (i = 0; i < sizeof(vals)/sizeof(vals[0]); i++) {
if (i != 0)
printf(" ");
printf("%llu", vals[i]);
}
return 0;
}
`

// srcTemplate is parsed once at package initialization. template.Must panics
// if srcTemplateText is malformed.
var srcTemplate = template.Must(template.New("").Parse(srcTemplateText))