| // Copyright 2017 syzkaller project authors. All rights reserved. |
| // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. |
| |
| // Package db implements a simple key-value database. |
| // The database is cached in memory and mirrored on disk. |
| // It is used to store corpus in syz-manager and syz-hub. |
| // The database strives to minimize number of disk accesses |
| // as they can be slow in virtualized environments (GCE). |
| package db |
| |
| import ( |
| "bufio" |
| "bytes" |
| "compress/flate" |
| "encoding/binary" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "os" |
| |
| "github.com/google/syzkaller/pkg/log" |
| "github.com/google/syzkaller/pkg/osutil" |
| ) |
| |
| type DB struct { |
| Version uint64 // arbitrary user version (0 for new database) |
| Records map[string]Record // in-memory cache, must not be modified directly |
| |
| filename string |
| uncompacted int // number of records in the file |
| pending *bytes.Buffer // pending writes to the file |
| } |
| |
| type Record struct { |
| Val []byte |
| Seq uint64 |
| } |
| |
| func Open(filename string) (*DB, error) { |
| db := &DB{ |
| filename: filename, |
| } |
| f, err := os.OpenFile(db.filename, os.O_RDONLY|os.O_CREATE, osutil.DefaultFilePerm) |
| if err != nil { |
| return nil, err |
| } |
| db.Version, db.Records, db.uncompacted = deserializeDB(bufio.NewReader(f)) |
| f.Close() |
| if len(db.Records) == 0 || db.uncompacted/10*9 > len(db.Records) { |
| if err := db.compact(); err != nil { |
| return nil, err |
| } |
| } |
| return db, nil |
| } |
| |
| func (db *DB) Save(key string, val []byte, seq uint64) { |
| if seq == seqDeleted { |
| panic("reserved seq") |
| } |
| if rec, ok := db.Records[key]; ok && seq == rec.Seq && bytes.Equal(val, rec.Val) { |
| return |
| } |
| db.Records[key] = Record{val, seq} |
| db.serialize(key, val, seq) |
| db.uncompacted++ |
| } |
| |
| func (db *DB) Delete(key string) { |
| if _, ok := db.Records[key]; !ok { |
| return |
| } |
| delete(db.Records, key) |
| db.serialize(key, nil, seqDeleted) |
| db.uncompacted++ |
| } |
| |
| func (db *DB) Flush() error { |
| if db.uncompacted/10*9 > len(db.Records) { |
| return db.compact() |
| } |
| if db.pending == nil { |
| return nil |
| } |
| f, err := os.OpenFile(db.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, osutil.DefaultFilePerm) |
| if err != nil { |
| return err |
| } |
| defer f.Close() |
| if _, err := f.Write(db.pending.Bytes()); err != nil { |
| return err |
| } |
| db.pending = nil |
| return nil |
| } |
| |
| func (db *DB) BumpVersion(version uint64) error { |
| if db.Version == version { |
| return db.Flush() |
| } |
| db.Version = version |
| return db.compact() |
| } |
| |
| func (db *DB) compact() error { |
| buf := new(bytes.Buffer) |
| serializeHeader(buf, db.Version) |
| for key, rec := range db.Records { |
| serializeRecord(buf, key, rec.Val, rec.Seq) |
| } |
| f, err := os.Create(db.filename + ".tmp") |
| if err != nil { |
| return err |
| } |
| defer f.Close() |
| if _, err := f.Write(buf.Bytes()); err != nil { |
| return err |
| } |
| f.Close() |
| if err := os.Rename(f.Name(), db.filename); err != nil { |
| return err |
| } |
| db.uncompacted = len(db.Records) |
| db.pending = nil |
| return nil |
| } |
| |
| func (db *DB) serialize(key string, val []byte, seq uint64) { |
| if db.pending == nil { |
| db.pending = new(bytes.Buffer) |
| } |
| serializeRecord(db.pending, key, val, seq) |
| } |
| |
| const ( |
| dbMagic = uint32(0xbaddb) |
| recMagic = uint32(0xfee1bad) |
| curVersion = uint32(2) |
| seqDeleted = ^uint64(0) |
| ) |
| |
| func serializeHeader(w *bytes.Buffer, version uint64) { |
| binary.Write(w, binary.LittleEndian, dbMagic) |
| binary.Write(w, binary.LittleEndian, curVersion) |
| binary.Write(w, binary.LittleEndian, version) |
| } |
| |
| func serializeRecord(w *bytes.Buffer, key string, val []byte, seq uint64) { |
| binary.Write(w, binary.LittleEndian, recMagic) |
| binary.Write(w, binary.LittleEndian, uint32(len(key))) |
| w.WriteString(key) |
| binary.Write(w, binary.LittleEndian, seq) |
| if seq == seqDeleted { |
| if len(val) != 0 { |
| panic("deleting record with value") |
| } |
| return |
| } |
| if len(val) == 0 { |
| binary.Write(w, binary.LittleEndian, uint32(len(val))) |
| } else { |
| lenPos := len(w.Bytes()) |
| binary.Write(w, binary.LittleEndian, uint32(0)) |
| startPos := len(w.Bytes()) |
| fw, err := flate.NewWriter(w, flate.BestCompression) |
| if err != nil { |
| panic(err) |
| } |
| if _, err := fw.Write(val); err != nil { |
| panic(err) |
| } |
| fw.Close() |
| binary.Write(bytes.NewBuffer(w.Bytes()[lenPos:lenPos:lenPos+8]), binary.LittleEndian, uint32(len(w.Bytes())-startPos)) |
| } |
| } |
| |
| func deserializeDB(r *bufio.Reader) (version uint64, records map[string]Record, uncompacted int) { |
| records = make(map[string]Record) |
| ver, err := deserializeHeader(r) |
| if err != nil { |
| log.Logf(0, "failed to deserialize database header: %v", err) |
| return |
| } |
| version = ver |
| for { |
| key, val, seq, err := deserializeRecord(r) |
| if err == io.EOF { |
| return |
| } |
| if err != nil { |
| log.Logf(0, "failed to deserialize database record: %v", err) |
| return |
| } |
| uncompacted++ |
| if seq == seqDeleted { |
| delete(records, key) |
| } else { |
| records[key] = Record{val, seq} |
| } |
| } |
| } |
| |
| func deserializeHeader(r *bufio.Reader) (uint64, error) { |
| var magic, ver uint32 |
| if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { |
| if err == io.EOF { |
| return 0, nil |
| } |
| return 0, err |
| } |
| if magic != dbMagic { |
| return 0, fmt.Errorf("bad db header: 0x%x", magic) |
| } |
| if err := binary.Read(r, binary.LittleEndian, &ver); err != nil { |
| return 0, err |
| } |
| if ver == 0 || ver > curVersion { |
| return 0, fmt.Errorf("bad db version: %v", ver) |
| } |
| var userVer uint64 |
| if ver >= 2 { |
| if err := binary.Read(r, binary.LittleEndian, &userVer); err != nil { |
| return 0, err |
| } |
| } |
| return userVer, nil |
| } |
| |
| func deserializeRecord(r *bufio.Reader) (key string, val []byte, seq uint64, err error) { |
| var magic uint32 |
| if err = binary.Read(r, binary.LittleEndian, &magic); err != nil { |
| return |
| } |
| if magic != recMagic { |
| err = fmt.Errorf("bad record header: 0x%x", magic) |
| return |
| } |
| var keyLen uint32 |
| if err = binary.Read(r, binary.LittleEndian, &keyLen); err != nil { |
| return |
| } |
| keyBuf := make([]byte, keyLen) |
| if _, err = io.ReadFull(r, keyBuf); err != nil { |
| return |
| } |
| key = string(keyBuf) |
| if err = binary.Read(r, binary.LittleEndian, &seq); err != nil { |
| return |
| } |
| if seq == seqDeleted { |
| return |
| } |
| var valLen uint32 |
| if err = binary.Read(r, binary.LittleEndian, &valLen); err != nil { |
| return |
| } |
| if valLen != 0 { |
| fr := flate.NewReader(&io.LimitedReader{R: r, N: int64(valLen)}) |
| if val, err = ioutil.ReadAll(fr); err != nil { |
| return |
| } |
| fr.Close() |
| } |
| return |
| } |