blob: 94231a518f0fc3eecf46ecead356c41777538ade [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the native StringTable implementation."""
import logging
import pytest
import numpy as np
from ._native import StringTable
# pylint: disable=missing-docstring,compare-to-zero,len-as-condition
# pylint: disable=consider-using-enumerate
# Module-level logger; not referenced by any test visible in this file.
log = logging.getLogger(__name__)
def _get_c2s(st, *, mode=bytes):
    """Return the table's code-to-string mapping as a tuple.

    Reads the whole table two ways and asserts they agree: a vectorized
    ``vlookup`` over every code ``0 .. len(st) - 1`` (with ``subst=True``)
    and the ``c2s`` attribute.  The ``c2s`` contents are what is returned.
    """
    every_code = np.arange(len(st), dtype="i")
    via_lookup = st.vlookup(every_code, mode, subst=True)
    via_attr = st.c2s
    assert list(via_lookup) == list(via_attr)
    return tuple(via_attr)
def test_intern():
    """Interning assigns sequential ids and bumps ``seqno`` only for new strings."""
    st = StringTable()
    # A fresh table already holds the empty string.
    assert len(st) == 1
    assert b"" in st
    assert 3412 not in st

    seqno = st.seqno
    assert b"hello" not in st
    id_hello = st.intern(b"hello")
    assert b"hello" in st
    assert st.seqno != seqno
    assert id_hello == 1
    assert len(st) == 2

    seqno = st.seqno
    id_world = st.intern(b"world")
    assert id_world == 2
    assert st.seqno != seqno
    # Interning a second string must leave the first one in the table.
    # (Re-asserting the local ``id_hello`` here would be a no-op.)
    assert b"hello" in st
    assert len(st) == 3

    # Re-interning an existing string returns the same id and changes
    # neither seqno nor the table size.
    seqno = st.seqno
    id_hello2 = st.intern(b"hello")
    assert id_hello2 == id_hello
    assert seqno == st.seqno
    assert len(st) == 3
def test_get_c2s():
    """Looking up every code returns the strings in interning order."""
    st = StringTable()
    st.intern(b"hello")
    st.intern(b"world")
    st.intern(b"blarg")
    assert len(st) == 4
    expected = [b"", b"hello", b"world", b"blarg"]
    assert list(st.vlookup(np.arange(len(st), dtype="i"))) == expected
    # Exercise the helper this test is named after: ``c2s`` must agree
    # with a full vlookup of every code.
    assert list(_get_c2s(st)) == expected
    # A scalar code yields a one-element result.  Compare as a list so a
    # length mismatch fails as a plain assertion instead of raising an
    # ambiguous array-truth-value error.
    assert list(st.vlookup(np.int32(2))) == [b"world"]
def test_item_out_of_range():
    """Unknown codes raise KeyError, or map to None when ``subst=True``."""
    for bad_code in (451, -1):
        with pytest.raises(KeyError):
            StringTable().vlookup(np.int32(bad_code).ravel())

    def _substituted(codes):
        # Fresh table each time, matching a per-call StringTable().
        return list(StringTable().vlookup(codes, subst=True))

    assert _substituted(np.int32(451)) == [None]
    assert _substituted(np.asarray([-1], dtype="i")) == [None]
    assert _substituted(np.asarray([-1, 0, -1], dtype="i")) == \
        [None, b"", None]
def test_vintern():
    """vintern interns a batch of strings and returns their ids as an ndarray."""
    dst = StringTable()
    for word in (b"hello", b"world", b"blarg"):
        dst.intern(word)
    src = StringTable()
    for word in (b"qux", b"blarg"):
        src.intern(word)
    assert _get_c2s(dst) == (b"", b"hello", b"world", b"blarg")
    assert _get_c2s(src) == (b"", b"qux", b"blarg")

    # Position i of the remap holds the dst id of src's code i.
    remap = dst.vintern(_get_c2s(src))
    assert type(remap) is np.ndarray  # pylint: disable=unidiomatic-typecheck
    assert len(remap) == len(src)
    assert list(remap) == [0, 4, 3]

    # Translate src codes into dst codes through the remap array.
    src_codes = np.array([src.intern(b"blarg"), src.intern(b"qux")])
    dst_codes = np.take(remap, src_codes)
    assert list(dst_codes) == [dst.intern(b"blarg"), dst.intern(b"qux")]
def test_rank_invalid():
    """An unknown collation name is rejected with ValueError."""
    table = StringTable()
    with pytest.raises(ValueError):
        table.rank("sdfasfdafs")
@pytest.mark.parametrize("collation", ["binary", "nocase", "length"])
def test_rank(collation):
    """Sorting strings by ``rank`` matches sorting with an equivalent Python key."""
    strings = [b"foo", b"FOO", b"qux", b"asdfafafdasfa", b"qwerq", b""]
    st = StringTable()
    for string in strings:
        st.intern(string)

    # Python reference sort key for each collation name.
    reference_keys = dict(
        binary=lambda s: s,
        nocase=bytes.lower,
        length=len,
    )
    expected = sorted(strings, key=reference_keys[collation])

    ranks = st.rank(collation)
    assert st.rank(collation) is ranks, "cache should work"
    assert type(ranks) is np.ndarray  # pylint: disable=unidiomatic-typecheck

    # ranks is indexed by string id.
    actual = sorted(strings, key=lambda s: ranks[st.intern(s)])
    assert actual == expected
def test_vlookup():
    """vlookup maps int32 codes to shared bytes objects and rejects int64 codes."""
    st = StringTable()
    st.intern(b"hello")
    st.intern(b"world")
    st.intern(b"blarg")
    ids_i32 = np.array([3, 2, 3], dtype="i")
    result = st.vlookup(ids_i32)
    assert result.dtype == np.dtype("O")
    assert result.tolist() == [b"blarg", b"world", b"blarg"]
    # Repeated codes yield the very same object, not equal copies.
    assert result[0] is result[2]
    # We reject unsafe casts (e.g., i64->i32) now that we have
    # consistent dtypes in the query engine proper.  Spell the dtype as a
    # fixed-width int64: the "l" code is C long, which is only 32 bits on
    # Windows, where it would not exercise the i64->i32 rejection at all.
    ids_i64 = np.array([3, 2, 3], dtype=np.int64)
    with pytest.raises(TypeError):
        st.vlookup(ids_i64)
def test_lookup_cache():
    """Lookups through the same cache return identical objects on repeat calls."""
    st = StringTable()
    for word in (b"hello", b"world", b"blarg"):
        st.intern(word)
    codes = np.array([3, 2, 3], dtype="i")

    # (extra make_lookup_cache args, expected looked-up values)
    scenarios = (
        ((), [b"blarg", b"world", b"blarg"]),
        ((str,), ["blarg", "world", "blarg"]),
    )
    for cache_args, expected in scenarios:
        cache = st.make_lookup_cache(*cache_args)
        first = st.vlookup(codes, cache)
        second = st.vlookup(codes, cache)
        assert list(first) == expected
        assert len(first) == len(second)
        for lhs, rhs in zip(first, second):
            assert lhs is rhs
def test_unicode():
    """UTF-8 values come back as bytes by default and decode to str on request."""
    st = StringTable()
    text = "This Is Spın̈al Tap"
    raw = text.encode("UTF-8")
    st.intern(raw)
    code = np.int32(1)

    # No mode, bytes, and None all leave the value undecoded.
    for extra in ((), (bytes,), (None,)):
        assert st.vlookup([code], *extra)[0] == raw

    # str (positionally or as decode=) and "strict" all decode.
    for args, kwargs in (((str,), {}), ((), {"decode": str}),
                         (("strict",), {})):
        assert st.vlookup([code], *args, **kwargs) == text

    # Any other kind of mode argument is rejected.
    with pytest.raises(TypeError):
        assert st.vlookup([code], 4)
def test_unicode_invalid():
    """Invalid UTF-8 survives as bytes; decoding follows the named error handler."""
    bad = b'\x80abc'
    st = StringTable()
    st.intern(bad)
    code = np.int32(1)

    assert st.vlookup(code)[0] == bad
    # Sanity check: the fixture really is not valid UTF-8.
    with pytest.raises(UnicodeDecodeError):
        bad.decode("UTF-8")
    with pytest.raises(UnicodeDecodeError):
        assert st.vlookup(code, "strict")[0]
    assert st.vlookup(code, "replace")[0] == bad.decode("UTF-8", "replace")