| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Test for native string table code""" |
| |
| import logging |
| import pytest |
| import numpy as np |
| from ._native import StringTable |
| |
| # pylint: disable=missing-docstring,compare-to-zero,len-as-condition |
| # pylint: disable=consider-using-enumerate |
| |
# Module-level logger named after this module.
log = logging.getLogger(__name__)
| |
| def _get_c2s(st, *, mode=bytes): |
| c2s = st.vlookup(np.arange(len(st), dtype="i"), mode, subst=True) |
| xc2s = st.c2s |
| assert list(c2s) == list(xc2s) |
| return tuple(xc2s) |
| |
def test_intern():
    """Check ids, membership, length and seqno as strings are interned."""
    table = StringTable()
    # A new table is expected to contain exactly one entry: the empty string.
    assert len(table) == 1
    assert b"" in table
    assert 3412 not in table

    # First new string: gets id 1 and changes seqno.
    last_seqno = table.seqno
    assert b"hello" not in table
    hello_id = table.intern(b"hello")
    assert b"hello" in table
    assert table.seqno != last_seqno
    assert hello_id == 1
    last_seqno = table.seqno
    assert len(table) == 2

    # Second new string: gets id 2 and changes seqno again.
    world_id = table.intern(b"world")
    assert world_id == 2
    assert table.seqno != last_seqno
    assert hello_id == 1
    last_seqno = table.seqno
    assert len(table) == 3

    # Re-interning an existing string: same id, seqno and length unchanged.
    hello_id_again = table.intern(b"hello")
    assert hello_id_again == hello_id
    assert last_seqno == table.seqno
    assert len(table) == 3
| |
def test_get_c2s():
    """Check that vlookup over every code yields strings in interning order."""
    table = StringTable()
    for word in (b"hello", b"world", b"blarg"):
        table.intern(word)
    assert len(table) == 4

    every_code = np.arange(len(table), dtype="i")
    expected = [b"", b"hello", b"world", b"blarg"]
    assert list(table.vlookup(every_code)) == expected
    # A scalar code is accepted as well.
    assert table.vlookup(np.int32(2)) == [b"world"]
| |
def test_item_out_of_range():
    """Check vlookup on codes outside the table, with and without subst."""
    # Without subst, an unknown code (too large or negative) raises KeyError.
    for bad_code in (451, -1):
        with pytest.raises(KeyError):
            StringTable().vlookup(np.int32(bad_code).ravel())

    # With subst=True, unknown codes come back as None instead.
    too_large = StringTable().vlookup(np.int32(451), subst=True)
    assert list(too_large) == [None]

    negative = StringTable().vlookup(np.asarray([-1], dtype="i"), subst=True)
    assert list(negative) == [None]

    mixed = StringTable().vlookup(
        np.asarray([-1, 0, -1], dtype="i"), subst=True)
    assert list(mixed) == [None, b"", None]
| |
def test_vintern():
    """Check that vintern returns a code remapping from one table to another."""
    target = StringTable()
    for word in (b"hello", b"world", b"blarg"):
        target.intern(word)

    source = StringTable()
    for word in (b"qux", b"blarg"):
        source.intern(word)

    assert _get_c2s(target) == (b"", b"hello", b"world", b"blarg")
    assert _get_c2s(source) == (b"", b"qux", b"blarg")

    # Intern all of source's strings into target; the result is indexed by
    # source code and holds the corresponding target code.
    source_strings = _get_c2s(source)
    remap = target.vintern(source_strings)
    assert type(remap) is np.ndarray  # pylint: disable=unidiomatic-typecheck
    assert len(remap) == len(source)
    assert list(remap) == [0, 4, 3]

    # Translating source codes through the remap must agree with interning
    # the same strings directly into target.
    source_codes = np.array([source.intern(b"blarg"), source.intern(b"qux")])
    translated = remap.take(source_codes)
    assert list(translated) == \
        [target.intern(b"blarg"), target.intern(b"qux")]
| |
def test_rank_invalid():
    """Check that rank raises ValueError for an unrecognized collation."""
    table = StringTable()
    bogus_collation = "sdfasfdafs"
    with pytest.raises(ValueError):
        table.rank(bogus_collation)
| |
@pytest.mark.parametrize("collation", ["binary", "nocase", "length"])
def test_rank(collation):
    """Check that sorting by rank(collation) matches a Python-side sort."""
    strings = [b"foo", b"FOO", b"qux", b"asdfafafdasfa", b"qwerq", b""]
    table = StringTable()
    for item in strings:
        table.intern(item)

    # Python sort keys that the native rank order is compared against.
    reference_keys = {
        "binary": lambda s: s,
        "nocase": bytes.lower,
        "length": len,
    }
    expected_order = sorted(strings, key=reference_keys[collation])

    ranks = table.rank(collation)
    # A second request must hand back the very same array object.
    assert ranks is table.rank(collation), "cache should work"
    assert type(ranks) is np.ndarray  # pylint: disable=unidiomatic-typecheck

    # ranks is indexed by string id, so look each string's id up via intern.
    actual_order = sorted(strings, key=lambda s: ranks[table.intern(s)])

    assert actual_order == expected_order
| |
def test_vlookup():
    """Check vlookup's result for int32 codes and its rejection of int64.

    For int32 input the result must be an object array of the interned
    byte strings, with repeated codes mapping to the same object.  For
    int64 input, vlookup must raise TypeError.
    """
    st = StringTable()
    st.intern(b"hello")
    st.intern(b"world")
    st.intern(b"blarg")

    ids_i32 = np.array([3, 2, 3], dtype="i")
    result = st.vlookup(ids_i32)
    assert result.dtype == np.dtype("O")
    assert result.tolist() == [b"blarg", b"world", b"blarg"]
    # Codes 3 and 3 must resolve to one shared object, not two equal copies.
    assert result[0] is result[2]

    # We reject unsafe casts (e.g., i64->i32) now that we have
    # consistent dtypes in the query engine proper.
    # Use the explicitly sized np.int64 rather than the "l" (C long) type
    # code: C long is only 32 bits on some platforms (e.g. Windows), where
    # the array would not be an i64 array at all.
    ids_i64 = np.array([3, 2, 3], dtype=np.int64)
    with pytest.raises(TypeError):
        st.vlookup(ids_i64)
| |
def test_lookup_cache():
    """Check that lookups through one cache return identical objects."""
    table = StringTable()
    for word in (b"hello", b"world", b"blarg"):
        table.intern(word)

    codes = np.array([3, 2, 3], dtype="i")

    # First with a default cache (bytes results), then with a str cache.
    scenarios = (
        ((), [b"blarg", b"world", b"blarg"]),
        ((str,), ["blarg", "world", "blarg"]),
    )
    for cache_args, expected in scenarios:
        cache = table.make_lookup_cache(*cache_args)
        first = table.vlookup(codes, cache)
        second = table.vlookup(codes, cache)
        assert list(first) == expected
        assert len(first) == len(second)
        # Both lookups must yield the same objects, not merely equal ones.
        for from_first, from_second in zip(first, second):
            assert from_first is from_second
| |
def test_unicode():
    """Check the decode argument of vlookup on a non-ASCII string.

    With no decode argument, ``bytes`` or ``None``, the lookup must return
    the UTF-8 encoded bytes.  With ``str`` (positional or as ``decode=``)
    or ``"strict"`` it must return the decoded text.  A decode argument of
    ``4`` must raise TypeError.
    """
    st = StringTable()
    unicode_str = "This Is Spın̈al Tap"
    encoded_str = unicode_str.encode("UTF-8")
    st.intern(encoded_str)
    one = np.int32(1)
    assert st.vlookup([one])[0] == encoded_str
    assert st.vlookup([one], bytes)[0] == encoded_str
    assert st.vlookup([one], None)[0] == encoded_str
    # Index element 0 here as well, so the decoded string itself is
    # compared rather than relying on the truthiness of a one-element
    # elementwise comparison of the whole result against a str.
    assert st.vlookup([one], str)[0] == unicode_str
    assert st.vlookup([one], decode=str)[0] == unicode_str
    assert st.vlookup([one], "strict")[0] == unicode_str
    with pytest.raises(TypeError):
        st.vlookup([one], 4)
| |
def test_unicode_invalid():
    """Check vlookup's decode modes on bytes that are not valid UTF-8."""
    bad_bytes = b'\x80abc'
    table = StringTable()
    table.intern(bad_bytes)
    code = np.int32(1)

    # Without decoding, the raw bytes come back untouched.
    assert table.vlookup(code)[0] == bad_bytes

    # Sanity check: Python itself refuses to decode these bytes.
    with pytest.raises(UnicodeDecodeError):
        bad_bytes.decode("UTF-8")

    # "strict" must fail the same way; "replace" must match Python's
    # own replacement decoding.
    with pytest.raises(UnicodeDecodeError):
        assert table.vlookup(code, "strict")[0]
    python_replaced = bad_bytes.decode("UTF-8", "replace")
    assert table.vlookup(code, "replace")[0] == python_replaced