blob: a6213784e205cf846d889e0829376afb0d721f56 [file] [log] [blame]
use std::borrow::Borrow;
use std::collections::HashMap;
use std::error;
use std::fmt;
use std::io;
use std::result;
use super::{TrieSetSlice, CHUNK_SIZE};
// This implementation was pretty much cribbed from raphlinus' contribution
// to the standard library: https://github.com/rust-lang/rust/pull/33098/files
//
// The fundamental principle guiding this implementation is to take advantage
// of the fact that similar Unicode codepoints are often grouped together, and
// that most boolean Unicode properties are quite sparse over the entire space
// of Unicode codepoints.
//
// To do this, we represent sets using something like a trie (which gives us
// prefix compression). The "final" states of the trie are embedded in leaves
// or "chunks," where each chunk is a 64 bit integer. Each bit position of the
// integer corresponds to whether a particular codepoint is in the set or not.
// These chunks are not just a compact representation of the final states of
// the trie, but are also a form of suffix compression. In particular, if
// multiple ranges of 64 contiguous codepoints map have the same set membership
// ordering, then they all map to the exact same chunk in the trie.
//
// We organize this structure by partitioning the space of Unicode codepoints
// into three disjoint sets. The first set corresponds to codepoints
// [0, 0x800), the second [0x800, 0x1000) and the third [0x10000, 0x110000).
// These partitions conveniently correspond to the space of 1 or 2 byte UTF-8
// encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded
// codepoints, respectively.
//
// Each partition has its own tree with its own root. The first partition is
// the simplest, since the tree is completely flat. In particular, to determine
// the set membership of a Unicode codepoint (that is less than `0x800`), we
// do the following (where `cp` is the codepoint we're testing):
//
// let chunk_address = cp >> 6;
// let chunk_bit = cp & 0b111111;
// let chunk = tree1[cp >> 6];
// let is_member = 1 == ((chunk >> chunk_bit) & 1);
//
// We do something similar for the second partition:
//
// // we subtract 0x20 since (0x800 >> 6) == 0x20.
// let child_address = (cp >> 6) - 0x20;
// let chunk_address = tree2_level1[child_address];
// let chunk_bit = cp & 0b111111;
// let chunk = tree2_level2[chunk_address];
// let is_member = 1 == ((chunk >> chunk_bit) & 1);
//
// And so on for the third partition.
//
// Note that as a special case, if the second or third partitions are empty,
// then the trie will store empty slices for those levels. The `contains`
// check knows to return `false` in those cases.
const CHUNKS: usize = 0x110000 / CHUNK_SIZE;
/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
pub type Result<T> = result::Result<T, Error>;
/// An error that can occur during construction of a trie.
#[derive(Clone, Debug)]
pub enum Error {
/// This error is returned when an invalid codepoint is given to
/// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
/// is greater than `0x10FFFF`.
InvalidCodepoint(u32),
/// This error is returned when a set of Unicode codepoints could not be
/// sufficiently compressed into the trie provided by this crate. There is
/// no work-around for this error at this time.
GaveUp,
}
impl error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Error::InvalidCodepoint(cp) => write!(
f,
"could not construct trie set containing an \
invalid Unicode codepoint: 0x{:X}",
cp
),
Error::GaveUp => {
write!(f, "could not compress codepoint set into a trie")
}
}
}
}
impl From<Error> for io::Error {
fn from(err: Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, err)
}
}
/// An owned trie set.
#[derive(Clone)]
pub struct TrieSetOwned {
tree1_level1: Vec<u64>,
tree2_level1: Vec<u8>,
tree2_level2: Vec<u64>,
tree3_level1: Vec<u8>,
tree3_level2: Vec<u8>,
tree3_level3: Vec<u64>,
}
impl fmt::Debug for TrieSetOwned {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TrieSetOwned(...)")
}
}
impl TrieSetOwned {
fn new(all: &[bool]) -> Result<TrieSetOwned> {
let mut bitvectors = Vec::with_capacity(CHUNKS);
for i in 0..CHUNKS {
let mut bitvector = 0u64;
for j in 0..CHUNK_SIZE {
if all[i * CHUNK_SIZE + j] {
bitvector |= 1 << j;
}
}
bitvectors.push(bitvector);
}
let tree1_level1 =
bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();
let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
&bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
)?;
if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
tree2_level1.clear();
tree2_level2.clear();
}
let (mid, mut tree3_level3) = compress_postfix_leaves(
&bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
)?;
let (mut tree3_level1, mut tree3_level2) =
compress_postfix_mid(&mid, 64)?;
if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
tree3_level1.clear();
tree3_level2.clear();
tree3_level3.clear();
}
Ok(TrieSetOwned {
tree1_level1,
tree2_level1,
tree2_level2,
tree3_level1,
tree3_level2,
tree3_level3,
})
}
/// Create a new trie set from a set of Unicode scalar values.
///
/// This returns an error if a set could not be sufficiently compressed to
/// fit into a trie.
pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
where
I: IntoIterator<Item = C>,
C: Borrow<char>,
{
let mut all = vec![false; 0x110000];
for s in scalars {
all[*s.borrow() as usize] = true;
}
TrieSetOwned::new(&all)
}
/// Create a new trie set from a set of Unicode scalar values.
///
/// This returns an error if a set could not be sufficiently compressed to
/// fit into a trie. This also returns an error if any of the given
/// codepoints are greater than `0x10FFFF`.
pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
where
I: IntoIterator<Item = C>,
C: Borrow<u32>,
{
let mut all = vec![false; 0x110000];
for cp in codepoints {
let cp = *cp.borrow();
if cp > 0x10FFFF {
return Err(Error::InvalidCodepoint(cp));
}
all[cp as usize] = true;
}
TrieSetOwned::new(&all)
}
/// Return this set as a slice.
#[inline(always)]
pub fn as_slice(&self) -> TrieSetSlice<'_> {
TrieSetSlice {
tree1_level1: &self.tree1_level1,
tree2_level1: &self.tree2_level1,
tree2_level2: &self.tree2_level2,
tree3_level1: &self.tree3_level1,
tree3_level2: &self.tree3_level2,
tree3_level3: &self.tree3_level3,
}
}
/// Returns true if and only if the given Unicode scalar value is in this
/// set.
pub fn contains_char(&self, c: char) -> bool {
self.as_slice().contains_char(c)
}
/// Returns true if and only if the given codepoint is in this set.
///
/// If the given value exceeds the codepoint range (i.e., it's greater
/// than `0x10FFFF`), then this returns false.
pub fn contains_u32(&self, cp: u32) -> bool {
self.as_slice().contains_u32(cp)
}
}
fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> {
let mut root = vec![];
let mut children = vec![];
let mut bychild = HashMap::new();
for &chunk in chunks {
if !bychild.contains_key(&chunk) {
let start = bychild.len();
if start > ::std::u8::MAX as usize {
return Err(Error::GaveUp);
}
bychild.insert(chunk, start as u8);
children.push(chunk);
}
root.push(bychild[&chunk]);
}
Ok((root, children))
}
fn compress_postfix_mid(
chunks: &[u8],
chunk_size: usize,
) -> Result<(Vec<u8>, Vec<u8>)> {
let mut root = vec![];
let mut children = vec![];
let mut bychild = HashMap::new();
for i in 0..(chunks.len() / chunk_size) {
let chunk = &chunks[i * chunk_size..(i + 1) * chunk_size];
if !bychild.contains_key(chunk) {
let start = bychild.len();
if start > ::std::u8::MAX as usize {
return Err(Error::GaveUp);
}
bychild.insert(chunk, start as u8);
children.extend(chunk);
}
root.push(bychild[chunk]);
}
Ok((root, children))
}
#[cfg(test)]
mod tests {
use super::TrieSetOwned;
use crate::general_category;
use std::collections::HashSet;
fn mk(scalars: &[char]) -> TrieSetOwned {
TrieSetOwned::from_scalars(scalars).unwrap()
}
fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
let mut set = vec![];
for &(start, end) in ranges {
for cp in start..end + 1 {
set.push(cp);
}
}
set
}
#[test]
fn set1() {
let set = mk(&['a']);
assert!(set.contains_char('a'));
assert!(!set.contains_char('b'));
assert!(!set.contains_char('β'));
assert!(!set.contains_char('☃'));
assert!(!set.contains_char('😼'));
}
#[test]
fn set_combined() {
let set = mk(&['a', 'b', 'β', '☃', '😼']);
assert!(set.contains_char('a'));
assert!(set.contains_char('b'));
assert!(set.contains_char('β'));
assert!(set.contains_char('☃'));
assert!(set.contains_char('😼'));
assert!(!set.contains_char('c'));
assert!(!set.contains_char('θ'));
assert!(!set.contains_char('⛇'));
assert!(!set.contains_char('🐲'));
}
// Basic tests on all of the general category sets. We check that
// membership is correct on every Unicode codepoint... because we can.
macro_rules! category_test {
($name:ident, $ranges:ident) => {
#[test]
fn $name() {
let set = ranges_to_set(general_category::$ranges);
let hashset: HashSet<u32> = set.iter().cloned().collect();
let trie = TrieSetOwned::from_codepoints(&set).unwrap();
for cp in 0..0x110000 {
assert!(trie.contains_u32(cp) == hashset.contains(&cp));
}
// Test that an invalid codepoint is treated correctly.
assert!(!trie.contains_u32(0x110000));
assert!(!hashset.contains(&0x110000));
}
};
}
category_test!(gencat_cased_letter, CASED_LETTER);
category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
category_test!(gencat_control, CONTROL);
category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
category_test!(gencat_decimal_number, DECIMAL_NUMBER);
category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
category_test!(gencat_format, FORMAT);
category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
category_test!(gencat_letter, LETTER);
category_test!(gencat_letter_number, LETTER_NUMBER);
category_test!(gencat_line_separator, LINE_SEPARATOR);
category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
category_test!(gencat_math_symbol, MATH_SYMBOL);
category_test!(gencat_mark, MARK);
category_test!(gencat_modifier_letter, MODIFIER_LETTER);
category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
category_test!(gencat_number, NUMBER);
category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
category_test!(gencat_other, OTHER);
category_test!(gencat_other_letter, OTHER_LETTER);
category_test!(gencat_other_number, OTHER_NUMBER);
category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
category_test!(gencat_other_symbol, OTHER_SYMBOL);
category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
category_test!(gencat_private_use, PRIVATE_USE);
category_test!(gencat_punctuation, PUNCTUATION);
category_test!(gencat_separator, SEPARATOR);
category_test!(gencat_space_separator, SPACE_SEPARATOR);
category_test!(gencat_spacing_mark, SPACING_MARK);
category_test!(gencat_surrogate, SURROGATE);
category_test!(gencat_symbol, SYMBOL);
category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
category_test!(gencat_unassigned, UNASSIGNED);
category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}