blob: 7934d03e7badcdbdfb25d9636805b4e1c7f1fc37 [file] [log] [blame]
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bcache implements a GC-friendly cache (see [Cache]) for BoringCrypto.
package bcache
import (
"sync/atomic"
"unsafe"
)
// A Cache is a GC-friendly concurrent map from *K to *V.
// Keys are matched by pointer identity (the *K itself), not by the
// value they point to. It is meant to be used for maintaining shadow
// BoringCrypto state associated with certain allocated structs, in
// particular public and private RSA and ECDSA keys.
//
// The cache is GC-friendly in the sense that the keys do not
// indefinitely prevent the garbage collector from collecting them.
// Instead, at the start of each GC, the cache is cleared entirely. That
// is, the cache is lossy, and the loss happens at the start of each GC.
// This means that clients need to be able to cope with cache entries
// disappearing, but it also means that clients don't need to worry about
// cache entries keeping the keys from being collected.
//
// A Cache must be registered with the runtime (see Register) for the
// per-GC clearing to take place.
type Cache[K, V any] struct {
// ptable points to the current hash table. It is nil until the first
// use after a clear; the table method allocates a new table on demand.
// The runtime atomically stores nil to ptable at the start of each GC.
ptable atomic.Pointer[cacheTable[K, V]]
}
// A cacheTable is the hash table behind a Cache: a fixed array of
// cacheSize slots, each an atomically updated pointer to the first
// cacheEntry of that slot's linked list (nil when the slot is empty).
type cacheTable[K, V any] [cacheSize]atomic.Pointer[cacheEntry[K, V]]
// A cacheEntry is a single entry in the linked list for a given hash table entry.
// Entries are only ever added at the front of a list; k and next are never
// modified after the entry is published, so readers can walk a list
// without locking. Only v changes, and only through atomic operations.
type cacheEntry[K, V any] struct {
// k is the key, compared by pointer identity.
k *K // immutable once created
// v is the current value for k; Put overwrites it in place.
v atomic.Pointer[V] // read and written atomically to allow updates
// next is the following entry in the same hash slot, or nil at the end.
next *cacheEntry[K, V] // immutable once linked into table
}
// registerCache has no body here; its implementation is supplied by the
// runtime. Register passes it the address of a Cache's ptable field so
// that the pointer stored there can be cleared at the start of each GC.
func registerCache(unsafe.Pointer) // provided by runtime
// Register tells the runtime about this cache by handing it the address
// of c.ptable, which lets the runtime reset that pointer to nil at the
// start of every garbage collection.
// Register must be called during package initialization.
func (c *Cache[K, V]) Register() {
	tableAddr := unsafe.Pointer(&c.ptable)
	registerCache(tableAddr)
}
// cacheSize is the number of slots in the hash table.
// The hash is the pointer value mod cacheSize, a prime.
// Collisions are resolved by maintaining a linked list in each hash slot.
const cacheSize = 1021
// table returns the hash table currently installed in c, installing a
// freshly allocated one if there is none. Because the GC (or Clear) may
// reset c.ptable to nil at any moment, and other goroutines may be
// installing a table at the same time, the install is done with a
// compare-and-swap and retried until some non-nil table is observed.
func (c *Cache[K, V]) table() *cacheTable[K, V] {
	for {
		if current := c.ptable.Load(); current != nil {
			return current
		}
		// No table right now: try to publish a new, empty one.
		fresh := new(cacheTable[K, V])
		if c.ptable.CompareAndSwap(nil, fresh) {
			return fresh
		}
		// Someone else installed a table first; loop and pick it up.
	}
}
// Clear empties the cache by discarding the current hash table.
// The runtime performs the same reset on its own at the start of every
// garbage collection (directly, without going through this method);
// Clear is exposed only so that tests can trigger it.
func (c *Cache[K, V]) Clear() {
	var none *cacheTable[K, V]
	c.ptable.Store(none)
}
// Get returns the cached value associated with k:
// the v passed to the most recent call to Put(k, v),
// or nil if there is no such entry or it has since been dropped.
func (c *Cache[K, V]) Get(k *K) *V {
	// The key's address, reduced mod cacheSize, selects the hash slot.
	slot := uintptr(unsafe.Pointer(k)) % cacheSize
	bucket := &c.table()[slot]

	// Walk the slot's list looking for the identical key pointer.
	for entry := bucket.Load(); entry != nil; entry = entry.next {
		if entry.k != k {
			continue
		}
		return entry.v.Load()
	}
	return nil
}
// Put sets the cached value associated with k to v, overwriting any
// value already cached for k in the current table.
func (c *Cache[K, V]) Put(k *K, v *V) {
	bucket := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]

	// Lock-free insert-or-update. Each attempt:
	//
	//   - snapshots the first entry of the bucket's list (first),
	//   - scans the list, as Get does, for an existing entry for k
	//     and, if one is found, stores v into it atomically,
	//   - otherwise tries to compare-and-swap a new entry for (k, v)
	//     in front of the snapshot.
	//
	// Losing the compare-and-swap means the bucket changed underneath
	// us, so the whole attempt is repeated, with two savings:
	//
	//  1. New entries are only ever inserted at the front of a list,
	//     so everything from the previous snapshot onward is already
	//     known not to contain k. scanned remembers that point and
	//     the next scan stops there, so entries are not re-examined.
	//
	//  2. The entry to insert is allocated once, kept in fresh, and
	//     reused by later attempts.
	var fresh, scanned *cacheEntry[K, V]
	visited := 0 // entries examined so far, summed over all attempts
	for {
		first := bucket.Load()
		for cur := first; cur != nil && cur != scanned; cur = cur.next {
			if cur.k == k {
				cur.v.Store(v)
				return
			}
			visited++
		}

		if fresh == nil {
			fresh = &cacheEntry[K, V]{k: k}
			fresh.v.Store(v)
		}
		// fresh is still private to this goroutine here, so its next
		// pointer can be set freely before it is published below.
		if visited >= 1000 {
			// If an individual list gets too long, which shouldn't happen,
			// throw it away to avoid quadratic lookup behavior.
			fresh.next = nil
		} else {
			fresh.next = first
		}

		if bucket.CompareAndSwap(first, fresh) {
			return
		}
		scanned = first
	}
}