blob: 62015d07d2344a0a9d7037775560dd47f59983ba [file] [log] [blame]
/*!
This is the Teddy searcher, but ported to AVX2.
See the module comments in the SSSE3 Teddy searcher for a more in depth
explanation of how this algorithm works. For the most part, this port is
basically the same as the SSSE3 version, but using 256-bit vectors instead of
128-bit vectors, which increases throughput.
*/
use std::cmp;
use aho_corasick::{Automaton, AcAutomaton, FullAcAutomaton};
use syntax::hir::literal::Literals;
use vector::avx2::{AVX2VectorBuilder, u8x32};
/// The number of haystack bytes consumed per iteration by the vectorized
/// search loops in this module (`find1`, `find2` and `find3` each advance
/// their position by this amount).
const BLOCK_SIZE: usize = 32;
/// Describes a single match: which pattern matched and the byte range of the
/// haystack that it occupies.
#[derive(Debug, Clone)]
pub struct Match {
/// The index of the pattern that matched. The index is in correspondence
/// with the order of the patterns given at construction.
pub pat: usize,
/// The start byte offset of the match, relative to the beginning of the
/// haystack that was searched.
pub start: usize,
/// The end byte offset of the match. This is always `start + pat.len()`,
/// where `pat` is the matched pattern.
pub end: usize,
}
/// A SIMD accelerated multi substring searcher.
///
/// A value of this type is built with `Teddy::new` and queried with
/// `Teddy::find`.
#[derive(Debug, Clone)]
pub struct Teddy {
/// A builder for AVX2 empowered vectors.
vb: AVX2VectorBuilder,
/// A list of substrings to match, in the order given at construction.
pats: Vec<Vec<u8>>,
/// An Aho-Corasick automaton of the patterns. We use this when the
/// haystack is too short for the block based search, and to search the
/// tail of the haystack left over after the last full block.
ac: FullAcAutomaton<Vec<u8>>,
/// A set of 8 buckets. Each bucket corresponds to a single member of a
/// bitset. A bucket contains zero or more substrings (stored as indices
/// into `pats`). This is useful when the number of substrings exceeds 8,
/// since our bitsets cannot have more than 8 members.
buckets: Vec<Vec<usize>>,
/// Our set of masks. There's one mask for each byte in the fingerprint.
masks: Masks,
}
impl Teddy {
/// Returns true if and only if Teddy is supported on this platform.
///
/// Support is decided by whether an `AVX2VectorBuilder` can be created.
///
/// If this returns `false`, then `Teddy::new(...)` is guaranteed to
/// return `None`.
pub fn available() -> bool {
AVX2VectorBuilder::new().is_some()
}
/// Create a new `Teddy` multi substring matcher.
///
/// Returns `None` if a matcher could not be created, which happens when:
///
/// * an `AVX2VectorBuilder` cannot be obtained on this platform,
/// * `pats` contains no literals, or
/// * `pats` contains an empty literal.
///
/// Patterns are distributed round-robin over 8 buckets (pattern `i` goes to
/// bucket `i % 8`), and the fingerprint length is the length of the shortest
/// pattern, capped at 3.
pub fn new(pats: &Literals) -> Option<Teddy> {
    // Successfully obtaining a builder *is* the availability check (see
    // `Teddy::available`), so there is no need to probe a second time.
    let vb = AVX2VectorBuilder::new()?;
    let pats: Vec<_> = pats.literals().iter().map(|p| p.to_vec()).collect();
    // An empty pattern list yields no minimum; mapping that to 0 lets a
    // single test reject both "no patterns" and "some pattern is empty".
    let min_len = pats.iter().map(|p| p.len()).min().unwrap_or(0);
    if min_len < 1 {
        return None;
    }
    // Pick the largest mask possible, but no larger than 3.
    let nmasks = cmp::min(3, min_len);
    let mut masks = Masks::new(vb, nmasks);
    let mut buckets = vec![vec![]; 8];
    // Assign each substring to a bucket, and add the bucket's bit to the
    // appropriate position in every mask.
    for (pati, pat) in pats.iter().enumerate() {
        let bucket = pati % 8;
        buckets[bucket].push(pati);
        masks.add(bucket as u8, pat);
    }
    // The automaton needs its own copy of the patterns; the original list
    // is then moved into the matcher rather than being cloned again.
    let ac = AcAutomaton::new(pats.clone()).into_full();
    Some(Teddy {
        vb: vb,
        pats: pats,
        ac: ac,
        buckets: buckets,
        masks: masks,
    })
}
/// Returns every substring this `Teddy` searches for, in the order the
/// substrings were supplied at construction.
pub fn patterns(&self) -> &[Vec<u8>] {
    self.pats.as_slice()
}
/// Returns the number of substrings in this matcher, i.e., the length of
/// the slice returned by `patterns`.
pub fn len(&self) -> usize {
self.pats.len()
}
/// Returns the approximate size on the heap used by this matcher.
///
/// The estimate is the total number of bytes across all patterns; other
/// fields of the matcher are not included in the figure.
pub fn approximate_size(&self) -> usize {
    self.pats.iter().map(Vec::len).sum()
}
/// Searches `haystack` for the substrings in this `Teddy`. If a match was
/// found, then it is returned. Otherwise, `None` is returned.
pub fn find(&self, haystack: &[u8]) -> Option<Match> {
// SAFETY: `find_impl` is compiled with the `avx2` target feature enabled,
// so it must only run where AVX2 is supported. A `Teddy` can only be
// obtained from `Teddy::new`, which returns `None` unless an
// `AVX2VectorBuilder` could be created.
// NOTE(review): this relies on `AVX2VectorBuilder::new` succeeding only
// when AVX2 is really available -- confirm in the vector module.
unsafe { self.find_impl(haystack) }
}
/// Picks the search routine matching the fingerprint length and runs it.
///
/// # Safety
///
/// This function is compiled with the `avx2` target feature enabled, so the
/// caller must ensure it only runs where AVX2 is supported.
#[allow(unused_attributes)]
#[target_feature(enable = "avx2")]
unsafe fn find_impl(&self, haystack: &[u8]) -> Option<Match> {
    // The block based routines need at least one full block after their
    // starting offset, and `find3` starts at offset 2. Anything shorter
    // (the empty haystack included) is searched by `slow` from offset 0.
    let too_short = haystack.len() < BLOCK_SIZE + 2;
    if too_short {
        return self.slow(haystack, 0);
    }
    match self.masks.len() {
        0 => None,
        1 => self.find1(haystack),
        2 => self.find2(haystack),
        3 => self.find3(haystack),
        _ => unreachable!(),
    }
}
/// `find1` handles the case of a single mask, i.e., a one byte fingerprint.
///
/// The haystack is scanned one block at a time starting at offset 0. Every
/// block with at least one candidate lane is passed to `verify`. Whatever
/// remains after the last full block is searched by `slow`.
#[inline(always)]
fn find1(&self, haystack: &[u8]) -> Option<Match> {
    let len = haystack.len();
    debug_assert!(len >= BLOCK_SIZE);
    let zero = self.vb.u8x32_splat(0);
    // The largest offset at which a whole block still fits in the haystack.
    let last_block = len - BLOCK_SIZE;
    let mut pos = 0;
    while pos <= last_block {
        // SAFETY: the loop condition gives `pos + BLOCK_SIZE <= len`, so
        // `pos..` is in bounds and spans at least `BLOCK_SIZE` bytes, which
        // is presumably what the unchecked load requires.
        let block = unsafe {
            let tail = haystack.get_unchecked(pos..);
            self.vb.u8x32_load_unchecked_unaligned(tail)
        };
        // N.B. `candidates` is our `C` in the module documentation.
        let candidates = self.masks.members1(block);
        // Verification is expensive, so only do it when some lane is
        // non-zero.
        let bitfield = candidates.ne(zero).movemask();
        if bitfield != 0 {
            let found = self.verify(haystack, pos, candidates, bitfield);
            if found.is_some() {
                return found;
            }
        }
        pos += BLOCK_SIZE;
    }
    self.slow(haystack, pos)
}
/// `find2` handles the case of 2 masks, i.e., a two byte fingerprint.
///
/// Blocks are loaded starting at offset 1. For each block, `C` for the first
/// fingerprint byte is realigned by one lane (pulling in the last lane of
/// the previous block's `C`) so that it lines up with `C` for the second
/// fingerprint byte, and the two are `AND`'d together.
#[inline(always)]
fn find2(&self, haystack: &[u8]) -> Option<Match> {
    let len = haystack.len();
    debug_assert!(len >= BLOCK_SIZE);
    let zero = self.vb.u8x32_splat(0);
    // The largest offset at which a whole block still fits in the haystack.
    let last_block = len - BLOCK_SIZE;
    // `C` for the *first* fingerprint byte from the previous iteration. It
    // starts out with every bit set, so the lane carried into the first
    // block never rules out a candidate by itself.
    let mut carry0 = self.vb.u8x32_splat(0xFF);
    let mut pos = 1;
    while pos <= last_block {
        // SAFETY: the loop condition gives `pos + BLOCK_SIZE <= len`, so
        // `pos..` is in bounds and spans at least `BLOCK_SIZE` bytes, which
        // is presumably what the unchecked load requires.
        let block = unsafe {
            let tail = haystack.get_unchecked(pos..);
            self.vb.u8x32_load_unchecked_unaligned(tail)
        };
        let (first, second) = self.masks.members2(block);
        // Line up `C` for the first byte with `C` for the second byte:
        //
        //     (carry0 << 15) | (first >> 1)
        //
        // and then `AND` the two `C` values together.
        let candidates = first.alignr_15(carry0).and(second);
        carry0 = first;
        let bitfield = candidates.ne(zero).movemask();
        if bitfield != 0 {
            // After realignment, lane 0 corresponds to the byte just before
            // the loaded block.
            let block_start = pos.checked_sub(1).unwrap();
            let found =
                self.verify(haystack, block_start, candidates, bitfield);
            if found.is_some() {
                return found;
            }
        }
        pos += BLOCK_SIZE;
    }
    // The windowing above doesn't check the last byte in the last window,
    // so the slow search resumes at the last byte of the last window.
    let resume = pos.checked_sub(1).unwrap();
    self.slow(haystack, resume)
}
/// `find3` is used when there are 3 masks, e.g., the fingerprint is 3 bytes
/// long.
///
/// N.B. This is a straight-forward extrapolation of `find2`. The only
/// difference is that we need to keep track of two previous values of `C`,
/// since we now need to align for three bytes. Blocks are therefore loaded
/// starting at offset 2.
#[inline(always)]
fn find3(&self, haystack: &[u8]) -> Option<Match> {
    let zero = self.vb.u8x32_splat(0);
    let len = haystack.len();
    // Previous values of `C` for the first and second fingerprint bytes.
    // They start out with every bit set, so the lanes carried into the
    // first block never rule out a candidate by themselves.
    let mut prev0 = self.vb.u8x32_splat(0xFF);
    let mut prev1 = self.vb.u8x32_splat(0xFF);
    let mut pos = 2;
    // Same guard as `find1` and `find2`: `len - BLOCK_SIZE` below must not
    // underflow.
    debug_assert!(len >= BLOCK_SIZE);
    while pos <= len - BLOCK_SIZE {
        // SAFETY: the loop condition gives `pos + BLOCK_SIZE <= len`, so
        // `pos..` is in bounds and spans at least `BLOCK_SIZE` bytes, which
        // is presumably what the unchecked load requires.
        let h = unsafe {
            let p = haystack.get_unchecked(pos..);
            self.vb.u8x32_load_unchecked_unaligned(p)
        };
        let (res0, res1, res2) = self.masks.members3(h);
        // Realign `C` for the first byte by two lanes and `C` for the
        // second byte by one lane, so that all three line up on the lane
        // of the third fingerprint byte.
        let res0prev0 = res0.alignr_14(prev0);
        let res1prev1 = res1.alignr_15(prev1);
        let res = res0prev0.and(res1prev1).and(res2);
        prev0 = res0;
        prev1 = res1;
        let bitfield = res.ne(zero).movemask();
        if bitfield != 0 {
            // After realignment, lane 0 corresponds to the byte two
            // positions before the loaded block.
            let pos = pos.checked_sub(2).unwrap();
            if let Some(m) = self.verify(haystack, pos, res, bitfield) {
                return Some(m);
            }
        }
        pos += BLOCK_SIZE;
    }
    // The windowing above doesn't check the last two bytes in the last
    // window, so start the slow search at the penultimate byte of the
    // last window.
    self.slow(haystack, pos.checked_sub(2).unwrap())
}
/// Confirms the candidates recorded in `res` (i.e., `C` from the module
/// documentation) for the haystack block that starts at `pos` in
/// `haystack`. `bitfield` has a one in every bit position whose lane in
/// `res` is non-zero.
///
/// Candidates are tried in order of increasing lane, and within a lane in
/// order of increasing bucket. The first confirmed match is returned, or
/// `None` if no candidate is a real match.
#[inline(always)]
fn verify(
    &self,
    haystack: &[u8],
    pos: usize,
    res: u8x32,
    mut bitfield: u32,
) -> Option<Match> {
    while bitfield != 0 {
        // The next lane, relative to `pos`, where some fingerprint matched.
        let lane = bitfield.trailing_zeros() as usize;
        // Clear the lowest set bit, which is the one just consumed.
        bitfield &= bitfield - 1;
        // Offset relative to the beginning of the haystack.
        let start = pos + lane;
        // One bit per bucket whose fingerprint matched at this position.
        let mut bucket_bits = res.extract(lane);
        while bucket_bits != 0 {
            let bucket = bucket_bits.trailing_zeros() as usize;
            bucket_bits &= bucket_bits - 1;
            // Actual substring search verification.
            let found = self.verify_bucket(haystack, bucket, start);
            if found.is_some() {
                return found;
            }
        }
    }
    None
}
/// Checks whether any substring in bucket `bucket` occurs in `haystack`
/// beginning exactly at `start`, returning the first one that does.
///
/// Patterns that would run past the end of the haystack are skipped.
#[inline(always)]
fn verify_bucket(
    &self,
    haystack: &[u8],
    bucket: usize,
    start: usize,
) -> Option<Match> {
    // Patterns within a bucket are visited in the order they were given,
    // which is what guarantees leftmost-first semantics.
    for &pati in &self.buckets[bucket] {
        let pat = self.pats[pati].as_slice();
        let end = start + pat.len();
        // `get` yields `None` when the pattern does not fit in the haystack.
        match haystack.get(start..end) {
            Some(window) if window == pat => {
                return Some(Match {
                    pat: pati,
                    start: start,
                    end: end,
                });
            }
            _ => {}
        }
    }
    None
}
/// Searches `haystack[pos..]` with the Aho-Corasick automaton built from all
/// patterns in this matcher.
///
/// This is used when there aren't enough bytes in the haystack (or left in
/// the haystack) for the block based approach. Offsets in the returned match
/// are relative to the beginning of `haystack`, not to `pos`.
#[inline(never)]
fn slow(&self, haystack: &[u8], pos: usize) -> Option<Match> {
    let tail = &haystack[pos..];
    let m = self.ac.find(tail).next()?;
    Some(Match {
        pat: m.pati,
        start: pos + m.start,
        end: pos + m.end,
    })
}
}
/// A list of masks. This has length equal to the length of the fingerprint.
/// The length of the fingerprint is always `min(3, len(smallest_substring))`.
#[derive(Debug, Clone)]
struct Masks {
/// The builder used to create the constant vectors needed during lookup.
vb: AVX2VectorBuilder,
/// Storage for up to three masks. All three are always allocated; only the
/// first `size` of them are populated and consulted.
masks: [Mask; 3],
/// The number of masks in use, i.e., the fingerprint length.
size: usize,
}
impl Masks {
    /// Create a new set of masks of size `n`, where `n` corresponds to the
    /// number of bytes in a fingerprint.
    ///
    /// Storage for three masks is always allocated; only the first `n` are
    /// ever populated.
    fn new(vb: AVX2VectorBuilder, n: usize) -> Masks {
        Masks {
            vb: vb,
            masks: [Mask::new(vb); 3],
            size: n,
        }
    }

    /// Returns the number of masks in use, i.e., the fingerprint length.
    fn len(&self) -> usize {
        self.size
    }

    /// Adds the fingerprint of the given pattern to the given bucket.
    ///
    /// `bucket` is a bucket *index* in `0..8` (it is turned into a bit by
    /// `Mask::add`), not a bit mask. `pat` must contain at least `len()`
    /// bytes; byte `i` of `pat` is added to mask `i`.
    fn add(&mut self, bucket: u8, pat: &[u8]) {
        debug_assert!(bucket < 8);
        let n = self.size;
        for (mask, &byte) in self.masks[..n].iter_mut().zip(&pat[..n]) {
            mask.add(bucket, byte);
        }
    }

    /// Splits every byte of `haystack_block` into its low and high nybble,
    /// returned as `(low, high)` vectors.
    #[inline(always)]
    fn nybbles(&self, haystack_block: u8x32) -> (u8x32, u8x32) {
        let masklo = self.vb.u8x32_splat(0xF);
        let hlo = haystack_block.and(masklo);
        let hhi = haystack_block.bit_shift_right_4().and(masklo);
        (hlo, hhi)
    }

    /// Looks up the given nybble vectors in mask `i` and `AND`s the low and
    /// high results together, yielding `C` for fingerprint byte `i`.
    #[inline(always)]
    fn candidates(&self, i: usize, hlo: u8x32, hhi: u8x32) -> u8x32 {
        let mask = &self.masks[i];
        mask.lo.shuffle(hlo).and(mask.hi.shuffle(hhi))
    }

    /// Finds the fingerprints that are in the given haystack block. i.e., this
    /// returns `C` as described in the module documentation.
    ///
    /// More specifically, for `i in 0..32` and `j in 0..8`, bit `j` of lane
    /// `i` is set only if both the low and the high nybble of
    /// `haystack_block[i]` were registered for bucket `j` in the first mask.
    /// A set bit therefore marks a candidate that still has to be verified.
    #[inline(always)]
    fn members1(&self, haystack_block: u8x32) -> u8x32 {
        let (hlo, hhi) = self.nybbles(haystack_block);
        self.candidates(0, hlo, hhi)
    }

    /// Like `members1`, but computes `C` for the first and second bytes in
    /// the fingerprint.
    #[inline(always)]
    fn members2(&self, haystack_block: u8x32) -> (u8x32, u8x32) {
        let (hlo, hhi) = self.nybbles(haystack_block);
        let res0 = self.candidates(0, hlo, hhi);
        let res1 = self.candidates(1, hlo, hhi);
        (res0, res1)
    }

    /// Like `members1`, but computes `C` for the first, second and third bytes
    /// in the fingerprint.
    #[inline(always)]
    fn members3(&self, haystack_block: u8x32) -> (u8x32, u8x32, u8x32) {
        let (hlo, hhi) = self.nybbles(haystack_block);
        let res0 = self.candidates(0, hlo, hhi);
        let res1 = self.candidates(1, hlo, hhi);
        let res2 = self.candidates(2, hlo, hhi);
        (res0, res1, res2)
    }
}
/// A single mask, covering one byte position of the fingerprint.
///
/// Each entry of `lo` and `hi` is a bitset over buckets: bit `k` is set when
/// some byte added for bucket `k` has the corresponding nybble value.
#[derive(Debug, Clone, Copy)]
struct Mask {
/// Bitsets for the low nybbles in a fingerprint, indexed by nybble value.
/// Entries `0..16` are mirrored into entries `16..32`.
lo: u8x32,
/// Bitsets for the high nybbles in a fingerprint, indexed by nybble value.
/// Entries `0..16` are mirrored into entries `16..32`.
hi: u8x32,
}
impl Mask {
    /// Builds a mask in which no bucket has any member.
    fn new(vb: AVX2VectorBuilder) -> Mask {
        let empty = vb.u8x32_splat(0);
        Mask { lo: empty, hi: empty }
    }

    /// Registers `byte` as a fingerprint byte of bucket number `bucket`.
    ///
    /// The byte is split into its two nybbles. The bucket's bit is then OR'd
    /// into the entry selected by each nybble, and the updated entry is
    /// written to both halves of the 32 entry table.
    fn add(&mut self, bucket: u8, byte: u8) {
        // The single bit that represents this bucket in every bitset.
        let bit = (1 << bucket) as u8;
        let lo_idx = usize::from(byte & 0xF);
        let hi_idx = usize::from(byte >> 4);

        let lo_bits = self.lo.extract(lo_idx) | bit;
        self.lo.replace(lo_idx, lo_bits);
        self.lo.replace(lo_idx + 16, lo_bits);

        let hi_bits = self.hi.extract(hi_idx) | bit;
        self.hi.replace(hi_idx, hi_bits);
        self.hi.replace(hi_idx + 16, hi_bits);
    }
}