blob: 3bf3e25e9b23d7c7c01532ce117ab7fdec0bb9ea [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basic query machinery"""
# pylint: disable=missing-docstring
import os
import logging
from collections import OrderedDict
from cytoolz import valmap, identity
import numpy as np
import pytest
from modernmp.util import the
from .query import (
ArgSortQuery,
BinaryOperationQuery,
CollateQuery,
CountQuery,
DURATION_SCHEMA,
EVENT_UNPARTITIONED_TIME_MAJOR,
InvalidQueryException,
LiteralQuery,
LockstepFilledArrayQuery,
NULL_FILL_VALUE,
NULL_SCHEMA,
QueryNode,
QuerySchema,
QueryTable,
REGULAR_TABLE,
SPAN_UNPARTITIONED_TIME_MAJOR,
STRING_SCHEMA,
SequentialQuery,
SimpleQueryNode,
TS_SCHEMA,
TableKind,
TableSchema,
TableSorting,
TakeQuery,
UNIQUE,
UnaryOperationQuery,
masked_broadcast_to,
)
from .queryengine import (
Delivery,
QueryCache,
QueryEngine,
)
from .test_util import (
assertrepr_compare_hooks,
pict_parameterize,
)
from .util import (
BOOL,
ExplicitInheritance,
FLOAT32,
FLOAT64,
INT16,
INT32,
INT64,
INT8,
UINT16,
UINT32,
UINT64,
UINT8,
all_unique,
final,
iattr,
lmap,
load_pandas,
override,
tattr,
ureg,
)
from ._native import (
SPAN_INVARIANT_CHECKING,
npy_broadcast_to,
npy_explode,
npy_get_broadcaster,
npy_get_data,
npy_get_mask,
npy_has_mask,
npy_memory_location,
)
log = logging.getLogger(__name__)

# Terse aliases that keep the expected-value tables below readable.
T = True
F = False
N = None  # stands in for a masked/null cell in expected outputs
ma = np.ma.masked_array
nomask = np.ma.nomask
# Various data types we use to exhaustively test how various query
# operators work across different dtypes.  Each entry maps a short name
# to (schema, representative fill value).
ALL_SCHEMAS = {
    "str": (STRING_SCHEMA, b""),
    "domain_a": (QuerySchema(INT64, domain="a"), np.int64(3)),
    "domain_b": (QuerySchema(INT64, domain="b"), np.int64(4)),
    "bytes": (QuerySchema(INT64, unit=ureg().bytes), np.int64(5)),
    "cm": (QuerySchema(INT64, unit=ureg().cm), np.int64(6)),
    "NULL": (NULL_SCHEMA, NULL_FILL_VALUE),
}
ALL_SCHEMAS.update(
    (str(dtype), (QuerySchema(dtype), dtype.type()))
    for dtype in (INT8, UINT8, INT16, UINT16, INT32, UINT32,
                  INT64, UINT64, BOOL, FLOAT32, FLOAT64))
@final
class TestQuery(SimpleQueryNode):
    """Dummy query for test

    Produces a fixed tuple of values.  A string or INT64 schema is
    inferred from the values unless one is given explicitly; None and
    NaN values are delivered as masked (null) cells.
    """
    # Tuple of Python values this query produces, in order.
    values = tattr()
    # Column schema; defaults to INT64.  Exposed to callers as "schema".
    _schema = iattr(QuerySchema,
                    default=QuerySchema(INT64),
                    name="schema")

    @override
    def __new__(cls, values, schema=None):
        values = tuple(values)
        have_string_values = any(isinstance(v, (bytes, str))
                                 for v in values)
        if schema is None:
            # Infer: any string-ish value makes the whole column a string.
            schema = (STRING_SCHEMA if have_string_values
                      else QuerySchema(INT64))
        else:
            # An explicit schema must agree with the values, unless the
            # column is empty or all-null.
            assert (schema.is_string == have_string_values
                    or not values
                    or all(v is None for v in values)), \
                ("string value-schema mismatch: schema={!r} values={!r}"
                 .format(schema, values))
        return cls._do_new(cls, values, schema)

    @override
    def _compute_schema(self):
        return self._schema

    @override
    def _compute_inputs(self):
        # Leaf query: no upstream query inputs.
        return ()

    @property
    def __array(self):
        # Materialize self.values as a (possibly masked) numpy array;
        # None/NaN entries become masked cells backed by 0.
        values = self.values
        try:
            need_mask = None in values or np.isnan(values).any()
        except TypeError:
            # Non-numeric values (e.g. strings) cannot be NaN-tested.
            need_mask = False
        if need_mask:
            mask = tuple(value is None or np.isnan(value)
                         for value in values)
            values = tuple(0 if value is None or np.isnan(value)
                           else value for value in values)
            array = np.ma.masked_array(values, mask)
        else:
            array = np.array(values)
        assert array.dtype != np.dtype("O")  # no object arrays allowed
        return np.asanyarray(array, dtype=self._schema.dtype)

    @override
    async def run_async(self, qe):
        # pylint: disable=comparison-with-itself
        [], [oc] = await qe.async_setup((), (self,))
        for value in self.values:
            if value is None or value != value:  # nan test
                # Deliver a single masked (null) cell.
                await oc.write(
                    np.ma.masked_array([0], [True],
                                       dtype=self.schema.dtype))
                continue
            if self.schema.is_string:
                # Strings are delivered as interned string codes.
                value = qe.st.intern(value.encode("UTF-8")
                                     if isinstance(value, str) else value)
            await oc.write([value])

    @override
    def __repr__(self):
        return "<TestQuery {!r}>".format(self.values)
# Sentinel distinguishing "argument omitted" from an explicit None.
_UNSPECIFIED = object()
class TestQueryTable(QueryTable):
    """QueryTable for test"""
    # TestQuery has a passing similarity to a LiteralQuery over in the
    # serious-business side of the query system, query.py, but TestQuery
    # and TestQueryTable have some ergonomic conveniences that wouldn't
    # be right for production code, e.g., `_end' conversion.

    # OrderedDict mapping column name -> tuple of cell values.
    _columns = iattr()
    # Per-column QuerySchema (or None when left unspecified).
    _schemas = tattr()

    @override
    def __new__(cls, *, columns=None, rows=None, names=None,
                schema=None,
                schemas=None,
                table_schema=_UNSPECIFIED,
                span_partition=_UNSPECIFIED,
                event_partition=_UNSPECIFIED):
        """Build a test table from columns OR rows (not both).

        schema applies one schema to every column; schemas gives one per
        column.  table_schema, span_partition, and event_partition are
        mutually exclusive ways of picking the table shape; the default
        is a regular table.
        """
        # pylint: disable=redefined-variable-type
        assert columns is None or rows is None
        if span_partition is not _UNSPECIFIED:
            assert table_schema is _UNSPECIFIED
            assert event_partition is _UNSPECIFIED
            table_schema = TableSchema(TableKind.SPAN, span_partition,
                                       TableSorting.TIME_MAJOR)
        elif event_partition is not _UNSPECIFIED:
            assert table_schema is _UNSPECIFIED
            assert span_partition is _UNSPECIFIED
            table_schema = TableSchema(TableKind.EVENT,
                                       event_partition,
                                       TableSorting.TIME_MAJOR)
        elif table_schema is not _UNSPECIFIED:
            assert event_partition is _UNSPECIFIED
            assert span_partition is _UNSPECIFIED
            assert isinstance(table_schema, TableSchema)
        else:
            table_schema = REGULAR_TABLE
        if names is not None:
            names = list(names)
        if rows is not None:
            def _fixup_row(row):
                # Normalize one row to a plain value list, learning the
                # column names from dict rows as we go.
                nonlocal names
                if isinstance(row, dict):
                    row_names = list(row)
                    row_values = list(row.values())
                    # Gross hack, putting this special case here, but it makes
                    # the test tables a lot less annoying to write.
                    if "_live_end" in row_names:
                        le_idx = row_names.index("_live_end")
                        row_ts = row["_ts"]
                        assert row_ts <= row_values[le_idx]
                        row_names[le_idx] = "live_duration"
                        row_values[le_idx] -= row_ts
                    if names is None:
                        names = row_names
                    else:
                        assert names == row_names, (
                            "dict rows must have dict keys matching the "
                            "column names")
                    row = row_values
                assert isinstance(row, (tuple, list))
                return row
            rows = lmap(_fixup_row, rows)
            # All rows must be the same width before transposing.
            assert all(len(r) == len(rows[0]) for r in rows)
            columns = _flip_table(rows)
        if names is None:
            names = ["col{}".format(i) for i in range(len(columns))]
        if not columns:
            # Zero-row table: one empty column per name.
            columns = [[]] * len(names)
        assert len(columns) == len(names)
        if schema:
            assert schemas is None
            schemas = (schema,) * len(columns)
        elif schemas is not None:
            assert not schema
            schemas = tuple(
                (s if isinstance(s, QuerySchema) else
                 QuerySchema(s)) for s in schemas)
            assert len(schemas) == len(columns)
        else:
            schemas = [None] * len(columns)
        if table_schema.kind == TableKind.SPAN:
            names = list(names)
            assert names[0] == "_ts"
            assert names[1] in ("_duration", "_end")
            if names[1] == "_end":
                # Convenience: accept spans written as (_ts, _end) and
                # convert the second column to durations.
                names[1] = "_duration"
                def _dur(end, ts):
                    # NOTE: requires strictly positive durations here,
                    # while the `_live_end' hack above allows ts == end.
                    assert ts < end
                    return end - ts
                columns[1] = [_dur(end, ts)
                              for ts, end in zip(columns[0], columns[1])]
            schemas = list(schemas)
            schemas[0] = TS_SCHEMA
            schemas[1] = DURATION_SCHEMA
        elif table_schema.kind == TableKind.EVENT:
            assert names[0] == "_ts"
            schemas = list(schemas)
            schemas[0] = TS_SCHEMA
        columns = OrderedDict(
            (the(str, name), tuple(data))
            for name, data in zip(names, columns))
        return cls._do_new(cls,
                           _columns=columns,
                           _schemas=schemas,
                           table_schema=table_schema)

    @override
    def _make_column_tuple(self):
        # Column names in insertion order.
        return tuple(self._columns)

    @override
    def _make_column_query(self, column):
        """Return a query for a specific named column"""
        idx = self.columns.index(column)
        return TestQuery(self._columns[column], self._schemas[idx])

    def as_dataframe(self):
        """Convert the test query table to a Pandas DataFrame"""
        pd = load_pandas()
        return pd.DataFrame(self.as_dict())

    def as_dict(self):
        """Dict view"""
        return valmap(list, self._columns)
def _make_rng():
from random import Random
rng = Random()
rng.seed("the mome raths outgrabe")
return rng
def _flip_table(table):
    """Transpose a row-major table into a list of column lists."""
    return [list(column) for column in zip(*table)]
@final
class TestQueryTableResult(dict, ExplicitInheritance):
    """Executed-table result: maps column name -> column data

    Compares like a plain dict, except against a TestQueryTable, where
    table schema and column schemas are checked too (equals_test_qt).
    """
    @override
    def __new__(cls, data, *, table_schema, column_queries):  # pylint: disable=unused-argument
        return super().__new__(cls)

    @override
    def __init__(self, data, *, table_schema, column_queries):
        super().__init__()
        self.update(data)
        self.__table_schema = the(TableSchema, table_schema)
        self.__column_queries = dict(column_queries)
        # Data keys and query keys must line up exactly, in order.
        assert list(self.keys()) == list(self.__column_queries.keys())
        assert all(isinstance(q, QueryNode)
                   for q in self.__column_queries.values())

    @property
    def table_schema(self):
        """The table schema of the result"""
        return self.__table_schema

    @property
    def column_queries(self):
        """A mapping of column queries

        Maps column name to QueryNode. Order matches the data dict.
        """
        return self.__column_queries

    @property
    def column_schemas(self):
        """Sequence of column schemas"""
        return [q.schema for q in self.column_queries.values()]

    def equals_test_qt(self, qt):
        """Dedicated equality comparison with TestQueryTable

        Verifies column schemas, table schema as well
        as column data."""
        assert isinstance(qt, TestQueryTable)
        return (
            list(self) == list(qt.columns) and
            self.table_schema == qt.table_schema and
            len(self.column_schemas) == len(qt.column_schemas_for_test) and
            all(l_schema.is_a(r_schema) for l_schema, r_schema
                in zip(self.column_schemas, qt.column_schemas_for_test)) and
            self == qt.as_dict())

    def assert_eq(self, qt):
        """Assert that result matches in test"""
        # Same checks as equals_test_qt, but as individual asserts so a
        # failure points at the specific mismatch.
        assert list(self) == list(qt.columns)
        assert self.table_schema == qt.table_schema
        assert len(self.column_schemas) == len(qt.column_schemas_for_test)
        for l_schema, r_schema in zip(self.column_schemas,
                                      qt.column_schemas_for_test):
            assert l_schema.is_a(r_schema)
        assert self == qt.as_dict()
        assert self == qt
        return True

    @override
    def __eq__(self, other):
        if isinstance(other, TestQueryTable):
            return self.equals_test_qt(other)
        return super().__eq__(other)

    @override
    def __ne__(self, other):
        if isinstance(other, TestQueryTable):
            return not self.equals_test_qt(other)
        return super().__ne__(other)
def _qt_failure_hook(config, op, left, right):
    """Pytest assertion-repr hook pinpointing result-vs-table mismatches."""
    applicable = (op == "=="
                  and isinstance(left, TestQueryTableResult)
                  and isinstance(right, TestQueryTable)
                  and left != right)
    if not applicable:
        return None
    from _pytest.assertion.util import assertrepr_compare

    def _explain(lhs, rhs):
        # Delegate the detailed diff rendering to pytest itself.
        return assertrepr_compare(config, op, lhs, rhs)

    if left.table_schema != right.table_schema:
        return _explain(left.table_schema, right.table_schema)
    if list(left) != list(right.columns):
        return _explain(list(left), list(right.columns))
    left_schemas = left.column_schemas
    right_schemas = right.column_schemas_for_test
    schemas_match = (len(left_schemas) == len(right_schemas)
                     and all(ls.is_a(rs)
                             for ls, rs in zip(left_schemas, right_schemas)))
    if not schemas_match:
        return _explain(left_schemas, right_schemas)
    if dict(left) != right.as_dict():
        return _explain(dict(left), right.as_dict())
    log.warning("unknown difference?!")
    return None
assertrepr_compare_hooks.append(_qt_failure_hook)
def _test_execute(wanted, *,
                  as_dtypes=False,
                  strings=True,
                  env=None,
                  qe=None,
                  block_size=None,
                  as_raw_array=False):
    """Common functionality for _execute_q and _execute_qt

    Args:
      wanted: sequence of QueryNodes to evaluate.
      as_dtypes: if true, return each column's numpy dtype instead of data.
      strings: decode string columns; "bytes" skips the UTF-8 decode and
        yields raw bytes.
      env: optional environment forwarded to the QueryEngine.
      qe: reuse an existing engine (block_size and env are then unused).
      block_size: QueryCache block size; the DCTV_TEST_BLOCK_SIZE
        environment variable supplies a default when unset.
      as_raw_array: return the raw (possibly masked) numpy arrays rather
        than Python lists with None for masked cells.

    Returns a dict mapping each query in `wanted` to its fixed-up column.
    """
    envvar = "DCTV_TEST_BLOCK_SIZE"
    if block_size is None and envvar in os.environ:
        block_size = int(os.environ[envvar])
    qc_kwargs = {}
    if block_size is not None:
        qc_kwargs["block_size"] = block_size
    if qe is None:
        qe = QueryEngine(QueryCache(**qc_kwargs), env=env)
    def _fix(schema, array):
        # Post-process one result column per the keyword options above.
        assert isinstance(array, np.ndarray)
        if as_dtypes:
            return array.dtype
        if strings and schema.is_string:
            # Translate interned string codes back into bytes/str,
            # keeping masked cells as None.
            decoded = qe.st.vlookup(np.ma.filled(array, 0))
            mask = np.ma.getmaskarray(array)
            decoded[mask] = None
            if strings != "bytes":
                decoded[~mask] = [s.decode("UTF-8") for s in decoded[~mask]]
            array = decoded
        if not as_raw_array:
            # Flatten to a plain list, substituting None for masked cells.
            array = [
                None if masked else value
                for value, masked in zip(array, np.ma.getmaskarray(array))
            ]
        return array
    return {
        query: _fix(query.schema, data)
        for query, data in
        qe.execute_for_columns(wanted)
    }
def _execute_q(q, **kwargs):
    """Run a single query and return just its column data."""
    results = _test_execute([q], **kwargs)
    return results[q]
def _execute_qt(qt, **kwargs):
    """Run every column of `qt` and wrap the output for comparison."""
    assert isinstance(qt, QueryTable)
    column_queries = [qt[name] for name in qt.columns]
    results = _test_execute(column_queries, **kwargs)
    data = {name: results[query]
            for name, query in zip(qt.columns, column_queries)}
    return TestQueryTableResult(
        data,
        table_schema=qt.table_schema,
        column_queries=zip(qt.columns, column_queries),
    )
def test_query_basic():
    """One query executes and delivers all its values with EOF set."""
    engine = QueryEngine(QueryCache(block_size=128))
    query = TestQuery([1, 2, 3])
    [(delivered_query, delivered_array, at_eof)] = list(
        engine.execute([query]))
    assert delivered_query is query
    assert list(delivered_array) == list(query.values)
    assert at_eof
def test_query_small_blocks():
    """With block_size=1 each value arrives alone, then an empty EOF block."""
    # pylint: disable=len-as-condition,compare-to-zero
    engine = QueryEngine(QueryCache(block_size=1))
    query = TestQuery([1, 2, 3])
    *value_blocks, (last_query, last_array, last_eof) = list(
        engine.execute([query]))
    assert len(value_blocks) == 3
    for expected, (dquery, darray, deof) in zip((1, 2, 3), value_blocks):
        assert dquery is query
        assert list(darray) == [expected]
        assert deof is False
    assert last_query is query
    assert len(last_array) == 0
    assert last_eof
# Maps each parametrized hint name to the Delivery mode it exercises.
_HINTS = {
    "arbitrary": Delivery.ARBITRARY,
    "columns": Delivery.COLUMNWISE,
    "rows": Delivery.ROWWISE,
}
@pytest.mark.parametrize("delivery_hint", _HINTS.keys())
@pytest.mark.parametrize("masked", ["mask", "no_mask"])
@pytest.mark.parametrize("block_size", [1, 3, 128])
def test_query_for_columns(delivery_hint, masked, block_size):
    """execute_for_columns yields a plain or masked array as appropriate."""
    # pylint: disable=unidiomatic-typecheck
    engine = QueryEngine(QueryCache(block_size=block_size))
    middle = None if masked == "mask" else 2
    query = TestQuery([1, middle, 3])
    [(result_query, result_array)] = list(engine.execute_for_columns(
        [query], delivery_hint=_HINTS[delivery_hint]))
    assert result_query is query
    expected_type = (np.ma.masked_array if None in query.values
                     else np.ndarray)
    assert type(result_array) is expected_type
    assert result_array.tolist() == list(query.values)
@pytest.mark.parametrize("block_size", [1, 3, 128])
@pytest.mark.parametrize("delivery_hint", _HINTS.keys())
def test_query_multiple_rows(block_size, delivery_hint):
    """Two equal-length queries each come back with their own data."""
    q1 = TestQuery([1, 2, 3])
    q2 = TestQuery([4, 5, 6])
    engine = QueryEngine(QueryCache(block_size=block_size))
    pairs = list(engine.execute_for_columns(
        [q1, q2], delivery_hint=_HINTS[delivery_hint]))
    # Result order is not guaranteed; normalize so q1 comes first.
    [(rq1, ra1), (rq2, ra2)] = pairs
    if rq1 is not q1:
        (rq1, ra1), (rq2, ra2) = (rq2, ra2), (rq1, ra1)
    assert rq1 is q1
    assert rq2 is q2
    assert list(ra1) == list(q1.values)
    assert list(ra2) == list(q2.values)
@pytest.mark.parametrize("block_size", [1, 3, 128])
@pytest.mark.parametrize("delivery_hint", _HINTS.keys())
def test_query_multiple_rows_unequal(block_size, delivery_hint):
    """Queries of different lengths are still delivered independently."""
    q1 = TestQuery([1, 2, 3, 4, 5, 6, 7, 8, 9])
    q2 = TestQuery([4, 5, 6])
    engine = QueryEngine(QueryCache(block_size=block_size))
    pairs = list(engine.execute_for_columns(
        [q1, q2], delivery_hint=_HINTS[delivery_hint]))
    # Result order is not guaranteed; normalize so q1 comes first.
    [(rq1, ra1), (rq2, ra2)] = pairs
    if rq1 is not q1:
        (rq1, ra1), (rq2, ra2) = (rq2, ra2), (rq1, ra1)
    assert rq1 is q1
    assert rq2 is q2
    assert list(ra1) == list(q1.values)
    assert list(ra2) == list(q2.values)
def test_query_table():
    """Test whether QueryTable basically works"""
    data = [[1, 2, 3], [4, 5, 6]]
    table = TestQueryTable(columns=data, names=["foo", "bar"])
    expected = {"foo": data[0], "bar": data[1]}
    assert _execute_qt(table) == expected
@pytest.mark.parametrize("partitioned", ["partitioned", "unpartitioned"])
def test_query_table_order_conversion(partitioned):
    """Round-trip a span table through the three sort orders."""
    do_part = partitioned == "partitioned"
    def _mkqt(sorting, rows):
        # Span table over (_ts, _duration, part, foo); "part" is the
        # partition column only in the partitioned variant.
        return TestQueryTable(
            names=["_ts", "_duration", "part", "foo"],
            rows=rows,
            table_schema=TableSchema(
                TableKind.SPAN,
                "part" if do_part else None,
                sorting))
    base_rows = [
        [3, 1, 1, 33],
        [3, 1, 0, -3],
        [4, 1, 1, -4],
        [7, 1, 0, -7],
        [1, 2, 0, -1],
        [5, 1, 1, -5],
    ]
    # The same data in unsorted, time-major, and partition-major order.
    foo_none = _mkqt(TableSorting.NONE, base_rows)
    foo_time_major = _mkqt(
        TableSorting.TIME_MAJOR,
        (sorted(base_rows,
                key=(lambda row: (row[0], row[2]) if do_part else row[0]))))
    foo_partition_major = _mkqt(
        TableSorting.PARTITION_MAJOR,
        sorted(base_rows,
               key=(lambda row: (row[2], row[0]) if do_part else row[0])))
    # Identity: no conversion and an identity transform both round-trip.
    assert _execute_qt(foo_none) == foo_none
    assert _execute_qt(foo_none.transform(identity)) == foo_none
    # NONE -> each sorted order.
    assert _execute_qt(
        foo_none.to_schema(sorting=TableSorting.PARTITION_MAJOR)
    ) == foo_partition_major
    assert _execute_qt(
        foo_none.to_schema(sorting=TableSorting.TIME_MAJOR)
    ) == foo_time_major
    # Chained conversions land on the final requested order.
    assert _execute_qt(
        foo_none
        .to_schema(sorting=TableSorting.TIME_MAJOR)
        .to_schema(sorting=TableSorting.PARTITION_MAJOR)
    ) == foo_partition_major
    # Dropping back to NONE preserves the last physical order.
    assert _execute_qt(
        foo_none
        .to_schema(sorting=TableSorting.TIME_MAJOR)
        .to_schema(sorting=TableSorting.PARTITION_MAJOR)
        .to_schema(sorting=TableSorting.NONE)
    ) == _execute_qt(foo_partition_major.with_table_schema(
        foo_partition_major.table_schema.evolve(
            sorting=TableSorting.NONE)))
def test_match():
    """Test that we can match a condition"""
    values = TestQuery([1, 2, 2, 3])
    twos = QueryNode.filled(2, values.countq())
    matches = BinaryOperationQuery(values, "=", twos)
    assert _execute_q(matches) == [False, True, True, False]
def test_match_backward():
    """Test that we can match a condition the other way around"""
    values = TestQuery([1, 2, 2, 3])
    twos = QueryNode.filled(2, values.countq())
    matches = BinaryOperationQuery(twos, "=", values)
    assert _execute_q(matches) == [False, True, True, False]
def test_unop():
    """Test that basic unary operations work"""
    negated = UnaryOperationQuery("-", TestQuery([1, 2, 2, 3]))
    assert list(_execute_q(negated)) == [-1, -2, -2, -3]
def test_scalar():
    """Scalar queries yield single-element columns."""
    assert _execute_q(QueryNode.scalar(5)) == [5]
    # With string decoding disabled we see the raw interned string code.
    raw = _execute_q(QueryNode.scalar(b"foo"), strings=False)
    assert raw == [1]  # String code
    assert _execute_q(QueryNode.scalar(None)) == [None]
@pytest.mark.parametrize("schema_name", ALL_SCHEMAS)
def test_typed_scalar(schema_name):
    """A scalar round-trips with exactly its schema's Python type."""
    # pylint: disable=unidiomatic-typecheck
    _schema, fill_value = ALL_SCHEMAS[schema_name]
    [result] = _execute_q(QueryNode.scalar(fill_value), strings="bytes")
    assert type(result) == type(fill_value)
    assert result == fill_value
def test_count_count():
    """Test count of count"""
    filled = QueryNode.filled(4, 1)
    assert _execute_q(filled) == [4]
    # The count of a one-row column is itself a one-row column: [1].
    assert _execute_q(filled.countq()) == [1]
def test_filled_query():
    """QueryNode.filled broadcasts scalars, nulls, and one-row queries."""
    filled = QueryNode.filled(QueryNode.scalar(2), QueryNode.scalar(3))
    assert _execute_q(filled) == [2, 2, 2]
    filled = QueryNode.filled(N, 4)
    assert _execute_q(filled) == [N] * 4
    filled = QueryNode.filled(TestQuery([np.nan]), 5)
    assert _execute_q(filled) == [N] * 5
    # A multi-row fill value is invalid.
    filled = QueryNode.filled(TestQuery([np.nan, np.nan]), 5)
    with pytest.raises(InvalidQueryException):
        assert _execute_q(filled)
def test_filled_broadcast_shared_structure():
    """A broadcast fill shares storage with its base scalar column."""
    engine = QueryEngine(QueryCache())
    base = QueryNode.scalar(5)
    broadcast = QueryNode.filled(base, 3)
    arrays = dict(engine.execute_for_columns([base, broadcast]))
    assert list(arrays[base]) == [5]
    assert list(arrays[broadcast]) == [5, 5, 5]
    # Same backing memory: the fill is a broadcast view, not a copy.
    assert (npy_memory_location(arrays[broadcast]) ==
            npy_memory_location(arrays[base]))
def test_filled_lockstep_passthrough(monkeypatch):
    """With broadcasting disabled, filled() passes arrays through lockstep."""
    monkeypatch.setattr(LockstepFilledArrayQuery,
                        "_do_broadcast", None)
    values = list(range(20))
    count_query = TestQuery(values).countq()
    assert isinstance(count_query, CountQuery)
    shifted = [value + 1 for value in values]
    assert _execute_q(QueryNode.filled(TestQuery(shifted), count_query),
                      block_size=3) == shifted
    # A count that disagrees with the fill column's length is rejected.
    with pytest.raises(InvalidQueryException):
        _execute_q(QueryNode.filled(TestQuery(shifted),
                                    TestQuery([1, 2, 3]).countq()),
                   block_size=3)
def test_filled_lockstep_broadcast(monkeypatch):
    """With passthrough disabled, filled() broadcasts a scalar to the count."""
    monkeypatch.setattr(LockstepFilledArrayQuery,
                        "_do_passthrough", None)
    values = list(range(20))
    count_query = TestQuery(values).countq()
    assert isinstance(count_query, CountQuery)
    assert _execute_q(QueryNode.filled(7, count_query),
                      block_size=3) == [7] * len(values)
def test_sequential_query():
    """Test sequential query"""
    sequence = SequentialQuery(QueryNode.scalar(3),
                               QueryNode.scalar(4),
                               QueryNode.scalar(2))
    assert _execute_q(sequence) == [3, 5, 7, 9]
def _pyslice(start, count, step):
return list(range(start, start+step*(count), step))
@pytest.mark.parametrize("start", [0, 3])
@pytest.mark.parametrize("count", [20, 31])
@pytest.mark.parametrize("step", [2, 3, 4])
@pytest.mark.parametrize("block_size", [2, 3, 4])
def test_sequential_query_small_blocks(start, count, step, block_size):
    """Test sequential query across many block sizes.

    BUG FIX: the body previously overwrote the parametrized start/count/
    step arguments with hard-coded constants (debug leftover), so every
    parameter combination ran the identical (3, 4, 2) case.  Use the
    parameters so each combination is actually exercised.
    """
    assert _execute_q(
        SequentialQuery(
            QueryNode.scalar(start),
            QueryNode.scalar(count),
            QueryNode.scalar(step)),
        block_size=block_size,
    ) == _pyslice(start, count, step)
def test_take_mask():
    """TakeQuery: negative indexes produce nulls; masks propagate."""
    # pylint: disable=compare-to-zero
    def _test(*, v, i, **kwargs):
        # Take elements of column v at positions i.
        q_values = TestQuery(v)
        q_index = TestQuery(i)
        return _execute_q(TakeQuery(q_values, q_index), **kwargs)
    assert (_test(v=[1, 2, 3, 4],
                  i=[0, 2, -1, 3, 1])
            == [1, 3, N, 4, 2])
    # A -1 (null) index works even against an empty value column...
    assert _test(v=[], i=[-1, -1, -1]) == [N, N, N]
    # ...but a real index into an empty column is invalid.
    with pytest.raises(InvalidQueryException):
        _test(v=[], i=[0, 1])
    assert _test(v=[], i=[]) == []
    assert _test(v=[1, 2, 3], i=[]) == []
    # Nulls in the value column pass through the take.
    assert _test(v=[1, 2, N, 4, 5], i=[1, 2, 3]) == [2, N, 4]
    assert _test(v=[1, N, 3, 4, N], i=[-1, 1, 2]) == [N, N, 3]
    assert _test(v=[1, N, 3, 4, 5], i=[-1, 1, 2]) == [N, N, 3]
    # Masked-out slots are zero-filled in the backing data array.
    arr = _test(v=[7], i=[-1], as_raw_array=True).data
    assert arr[0] == 0
def test_take_identity():
    """Taking with the identity index shares the original buffer."""
    values = QueryNode.literals(2, 4, 6, 8)
    taken = values.take(QueryNode.literals(0, 1, 2, 3))
    arrays = _test_execute([values, taken],
                           as_raw_array=True,
                           block_size=2)
    values_array = arrays[values]
    taken_array = arrays[taken]
    assert list(values_array) == list(taken_array)
    assert values_array.base is not None
    assert taken_array.base is not None
    # Identity take reuses the underlying storage rather than copying.
    assert values_array.base is taken_array.base
def test_take_values_broadcast():
    """Taking from a broadcast column keeps the broadcast structure."""
    taken = QueryNode.filled(count=10, fill_value=3).take(
        QueryNode.literals(2, 4, 6))
    result = _execute_q(taken, as_raw_array=True)
    assert result.tolist() == [3, 3, 3]
    assert npy_get_broadcaster(result).tolist() == [3]
def test_take_index_is_broadcast():
    """A broadcast index yields a broadcast result."""
    taken = QueryNode.literals(-1, -2, -3, -4).take(
        QueryNode.filled(count=3, fill_value=2))
    result = _execute_q(taken, as_raw_array=True)
    assert result.tolist() == [-3, -3, -3]
    assert npy_get_broadcaster(result).tolist() == [-3]
def test_take_values_broadcast_mask_index():
    """A null in the index forces a real masked array, not a broadcast."""
    taken = QueryNode.filled(count=10, fill_value=3).take(
        QueryNode.literals(2, -1, 6))
    result = _execute_q(taken, as_raw_array=True)
    assert result.tolist() == [3, N, 3]
    assert npy_get_broadcaster(result) is None
    assert isinstance(result, ma)
    # Masked slots carry zero in the backing data.
    assert result.data.tolist() == [3, 0, 3]
def test_take_values_broadcast_mask_values():
    """Taking from an all-null broadcast stays an all-null broadcast."""
    taken = QueryNode.filled(count=10, fill_value=N).take(
        QueryNode.literals(2, 4, 6))
    result = _execute_q(taken, as_raw_array=True)
    assert result.tolist() == [N, N, N]
    assert npy_get_broadcaster(result).tolist() == [N]
@pytest.mark.parametrize("src", ["numpy", "array"])
def test_mask_functions(src):
    """npy_* mask helpers agree with numpy.ma across array flavors."""
    # pylint: disable=protected-access,compare-to-zero
    from array import array as _array
    # Same data, backed either by numpy directly or by a stdlib array.
    array = (np.asarray([1, 2, 3])
             if src == "numpy"
             else np.asarray(_array("l", [1, 2, 3])))
    ma_nomask = np.ma.masked_array(array)
    ma_mask = np.ma.masked_array(array, mask=[T, F, T])
    for value in (array, ma_nomask, ma_mask):
        real_mask = np.ma.getmask(value)
        is_masked_array = isinstance(value, np.ma.masked_array)
        is_really_masked = real_mask is not nomask
        # npy_has_mask reports an actual mask, not just the ma type.
        assert npy_has_mask(value) is is_really_masked
        if is_really_masked:
            # Data and mask accessors return views over the original
            # storage, not copies.
            assert npy_get_data(value).base is array.base
            assert npy_get_mask(value) is real_mask
            x1, x2 = npy_explode(value)
            assert x1.base is array.base
            assert list(x1) == list(array)
            assert x2 is value._mask
        else:
            if is_masked_array:
                assert npy_get_data(value).base is array.base
            else:
                # Plain ndarray: get_data is the identity.
                assert npy_get_data(value) is value
            assert npy_get_mask(value) is nomask
            x1, x2 = npy_explode(value)
            if is_masked_array:
                assert x1.base is array.base
            else:
                assert x1 is value
            assert list(x1) == list(array)
            assert x2 is nomask
@pytest.mark.parametrize("block_size", [1, 2, 3, 1024])
def test_sequential_take_basic(block_size):
    """Monotonic indexes stream straight through take_sequential."""
    taken = QueryNode.literals(2, 4, 6, 8).take_sequential(
        QueryNode.literals(1, 2, 3))
    result = _execute_q(taken, as_raw_array=True, block_size=block_size)
    assert result.tolist() == [4, 6, 8]
def test_sequential_take_repeat():
    """take_sequential may revisit the same index repeatedly."""
    taken = QueryNode.literals(2, 4, 6, 8).take_sequential(
        QueryNode.literals(1, 1, 1, 3))
    result = _execute_q(taken, as_raw_array=True, block_size=2)
    assert result.tolist() == [4, 4, 4, 8]
def test_sequential_take_early_eof_error():
    """Indexing past the end of the value stream is an error."""
    taken = QueryNode.literals(2, 4, 6).take_sequential(
        QueryNode.literals(9))
    with pytest.raises(InvalidQueryException) as ex:
        _execute_q(taken, as_raw_array=True, block_size=2)
    assert "early value EOF at index 3" in str(ex)
def test_sequential_take_mask():
    """A -1 index produces a null cell in the sequential take."""
    taken = QueryNode.literals(2, 4, 6, 8).take_sequential(
        QueryNode.literals(1, -1, 3))
    result = _execute_q(taken, as_raw_array=True, block_size=2)
    assert result.tolist() == [4, N, 8]
def test_sequential_take_nulls_at_end():
    """Trailing -1 indexes after real takes yield trailing nulls."""
    taken = QueryNode.literals(2, 4, 6, 8).take_sequential(
        QueryNode.literals(1, 3, -1, -1, -1, -1, -1))
    result = _execute_q(taken, as_raw_array=True, block_size=2)
    assert result.tolist() == [4, 8] + [N] * 5
def test_sequential_take_null_only():
    """All-null indexes against an empty value column are all nulls."""
    taken = QueryNode.literals().take_sequential(
        QueryNode.literals(-1, -1, -1, -1, -1))
    result = _execute_q(taken, as_raw_array=True, block_size=2)
    assert result.tolist() == [N] * 5
def test_sequential_take_empty():
    """Empty values and empty index produce an empty result."""
    taken = QueryNode.literals().take_sequential(QueryNode.literals())
    result = _execute_q(taken, as_raw_array=True, block_size=2)
    assert result.tolist() == []
def test_sequential_take_sequential():
    """An exactly-sequential index reproduces the value column."""
    taken = QueryNode.literals(2, 4, 6, 8, 10).take_sequential(
        QueryNode.literals(0, 1, 2, 3, 4))
    result = _execute_q(taken, as_raw_array=True, block_size=3)
    assert result.tolist() == [2, 4, 6, 8, 10]
def test_sequential_take_sequential_then_normal():
    """A sequential run followed by a repeated index still works."""
    taken = QueryNode.literals(2, 4, 6, 8, 10).take_sequential(
        QueryNode.literals(0, 1, 2, 2, 4))
    result = _execute_q(taken, as_raw_array=True, block_size=3)
    assert result.tolist() == [2, 4, 6, 6, 10]
def test_sequential_take_sequential_rollback():
    """A repeat mid-run (1, 1) is handled within a single large block."""
    taken = QueryNode.literals(2, 4, 6, 8, 10).take_sequential(
        QueryNode.literals(0, 1, 1, 3, 4))
    result = _execute_q(taken, as_raw_array=True, block_size=10)
    assert result.tolist() == [2, 4, 4, 8, 10]
@pytest.mark.parametrize("block_size", [1, 2, 3, 10])
@pytest.mark.parametrize("mode", ["full", "conservative"])
@pytest.mark.parametrize("null", ["null", "no_null"])
def test_sequential_take_broadcast_value(block_size, mode, null):
    """take_sequential over a broadcast value column."""
    bcast_conservative = mode == "conservative"
    use_null = null == "null"
    q_values = QueryNode.filled(count=10, fill_value=3)
    q_index = QueryNode.literals(0, 3, -1 if use_null else 4, 4, 4)
    r = _execute_q(
        q_values.take_sequential(q_index,
                                 bcast_conservative=bcast_conservative),
        as_raw_array=True,
        block_size=block_size)
    assert r.tolist() == [3, 3, N if use_null else 3, 3, 3]
    bcast = npy_get_broadcaster(r)
    # A null cell always defeats the broadcast result; in conservative
    # mode, block sizes 2 and 3 do too — presumably because those block
    # boundaries split the broadcastable run.  NOTE(review): confirm the
    # conservative-mode rationale against the take implementation.
    if use_null or (block_size in (2, 3) and bcast_conservative):
        assert bcast is None
    else:
        assert bcast.tolist() == [3]
@pytest.mark.skipif(
    not SPAN_INVARIANT_CHECKING,
    reason="optimized build doesn't check for index errors")
def test_sequential_take_invalid():
    """Backwards-moving indexes are rejected (checked builds only)."""
    taken = QueryNode.literals(1, 2, 3).take_sequential(
        QueryNode.literals(1, 2, 1))
    with pytest.raises(InvalidQueryException):
        _execute_q(taken)
def test_literal_query():
    """Test basic literal-sequence QueryNode support

    BUG FIX: three of these checks were bare `==` expression statements
    and never actually asserted anything.  The string-codes expectation
    additionally needs strings=False, matching test_scalar, since the
    default decodes the interned codes back to strings.
    """
    assert _execute_q(LiteralQuery(())) == []
    assert _execute_q(LiteralQuery((1, 2, 3))) == [1, 2, 3]
    assert LiteralQuery((b"foo", b"bar")).schema.is_string
    # Undecoded, string literals come back as their interned codes.
    assert _execute_q(LiteralQuery((b"foo", b"bar")),
                      strings=False) == [1, 2]
    # Mixing strings and numbers in one literal column is invalid.
    with pytest.raises(InvalidQueryException):
        _execute_q(LiteralQuery((1, b"foo", 3)))
def test_similar_values_are_distinct():
    """Test distinctification"""
    # Make sure that query interning isn't eagerly combining queries
    # that test for equality on value (e.g., 0 == False), but not on
    # type (int is not bool).
    literals = [1.0, 1, True, False, 0]
    literal_queries = [LiteralQuery([lv]) for lv in literals]
    assert len(set(literal_queries)) == len(literal_queries)
    scalar_queries = [QueryNode.scalar(lv) for lv in literals]
    assert len(set(scalar_queries)) == len(scalar_queries)
# Schema names excluding the null and string special cases.
ALL_REGULAR_SCHEMAS = [
    name for name, (schema, _fill) in ALL_SCHEMAS.items()
    if schema not in (NULL_SCHEMA, STRING_SCHEMA)]
@pict_parameterize(dict(schema_name_left=ALL_REGULAR_SCHEMAS,
                        schema_name_right=ALL_REGULAR_SCHEMAS))
def test_dtype_value_preserving(schema_name_left, schema_name_right):
    """Mixing two dtypes in one literal column preserves extreme values."""
    left_schema = ALL_SCHEMAS[schema_name_left][0]
    right_schema = ALL_SCHEMAS[schema_name_right][0]
    ldtype = left_schema.dtype
    ltype = ldtype.type
    rdtype = right_schema.dtype
    rtype = rdtype.type
    def _extreme_values(dtype):
        # Smallest and largest representable values for dtype.
        if dtype.kind == "b":
            return [True, False]
        if dtype.kind == "f":
            info = np.finfo(dtype)
        else:
            info = np.iinfo(dtype)
        return [dtype.type(info.min), dtype.type(info.max)]
    extreme_left = _extreme_values(ldtype)
    extreme_right = _extreme_values(rdtype)
    q = LiteralQuery(extreme_left + extreme_right)
    def _i64_cast(value):
        # Reinterpret the value's bits as int64 ("l").
        return np.asarray(value).view("l").item()
    expected_left = extreme_left
    expected_right = extreme_right
    # We're not value-preserving when mixing int64 and uint64
    if ldtype.char == "L" and q.schema.dtype.char == "l":
        expected_left = list(map(_i64_cast, expected_left))
        ltype = np.int64
    if rdtype.char == "L" and q.schema.dtype.char == "l":
        expected_right = list(map(_i64_cast, expected_right))
        rtype = np.int64
    # When only one side is floating point, the other side's expected
    # values promote to that float type.
    if rdtype.kind == "f" and ldtype.kind != "f":
        ltype = rtype
        expected_left = list(map(ltype, expected_left))
    if ldtype.kind == "f" and rdtype.kind != "f":
        rtype = ltype
        expected_right = list(map(rtype, expected_right))
    assert all_unique(expected_left)
    assert all_unique(expected_right)
    result = _execute_q(q)
    assert len(result) == len(expected_left) + len(expected_right)
    result_left = result[:len(expected_left)]
    result_right = result[len(expected_left):]
    # Values survive the round trip and re-casting is a no-op.
    assert result_left == expected_left
    assert all_unique(result_left)
    assert list(map(ltype, result_left)) == expected_left
    assert result_right == expected_right
    assert all_unique(result_right)
    assert list(map(rtype, result_right)) == expected_right
def test_query_table_result():
    """TestQueryTableResult equality against tables and plain dicts."""
    # pylint: disable=unneeded-not,superfluous-parens
    qt = TestQueryTable(
        names=["foo", "bar"],
        rows=[
            [1, 2],
            [3, 4],
        ])
    r = _execute_qt(qt)
    assert isinstance(r, TestQueryTableResult)
    # Equal to a plain dict of the columns and to the source table.
    assert r == dict(foo=[1, 3], bar=[2, 4])
    assert r == qt
    assert not (r != qt)
    # Different cell data: unequal.
    qt = TestQueryTable(
        names=["foo", "bar"],
        rows=[
            [3, 2],
            [3, 4],
        ])
    assert not (r == qt)
    assert r != qt
    # Same data but narrower (INT32) column schemas: unequal.
    qt = TestQueryTable(
        names=["foo", "bar"],
        schemas=[QuerySchema(INT32), QuerySchema(INT32)],
        rows=[
            [1, 2],
            [3, 4],
        ])
    assert not (r == qt)
    assert r != qt
    # Different table schema (span): unequal until re-executed.
    qt = TestQueryTable(
        names=["_ts", "_duration"],
        rows=[
            [1, 2],
            [3, 4],
        ],
        table_schema=SPAN_UNPARTITIONED_TIME_MAJOR)
    assert not (r == qt)
    assert r != qt
    r = _execute_qt(qt)
    assert r == qt
    # A span result does not equal an event table with the same cells.
    qt = TestQueryTable(
        names=["_ts", "_duration"],
        rows=[
            [1, 2],
            [3, 4],
        ],
        table_schema=EVENT_UNPARTITIONED_TIME_MAJOR)
    assert r != qt
def test_schema_constraints():
    """Constraint operations on QuerySchema."""
    base = QuerySchema(INT64)
    assert not base.constraints
    assert base.unconstrain() is base
    unique_schema = base.constrain({UNIQUE})
    assert unique_schema is not base
    assert unique_schema != base
    assert unique_schema.is_a(base)
    # Re-applying the same constraint set is an interned no-op.
    assert unique_schema.constrain({UNIQUE}) is unique_schema
    assert base.constrain({UNIQUE}) == unique_schema
    # Concatenating constrained and unconstrained schemas here yields
    # no constraints.
    combined = QuerySchema.concat([base, unique_schema])
    assert not combined.constraints
# Sorting
def _munge_sorts(sorts):
munged_sorts = OrderedDict()
def _add_sort(column, ascending=True, collation="binary"):
assert isinstance(collation, str)
munged_sorts[column] = (bool(ascending), collation)
for sort in sorts:
_add_sort(*sort)
return tuple(munged_sorts.items())
@final
class SortingQueryTable(QueryTable):  # TODO(dancol): remove me!
    """SortingQueryTable sorts another QueryTable"""
    # The table whose rows we re-order.
    base_table = iattr(QueryTable)
    # Normalized ((column, (ascending, collation)), ...); see _munge_sorts.
    sorts = tattr(converter=_munge_sorts)

    @override
    def _make_column_tuple(self):
        return self.base_table.columns

    @override
    def countq(self):
        # Sorting only permutes rows, so the count is unchanged.
        return self.base_table.countq()

    def group_sizes(self):
        """Group sizes for tests"""
        if not self.sorts:
            return self.base_table.group_sizes()
        return TakeQuery(self.base_table.group_sizes(),
                         self.__index_query)

    @property
    def __index_query(self):
        """Query that gives indexes to take from base table for sort"""
        def _base_query(sort_column, collation):
            query = self.base_table.column(sort_column)
            if query.schema.is_string:
                # String columns sort through their collation, not by
                # raw interned string code.
                query = CollateQuery(query, collation)
            return query
        return ArgSortQuery(
            ((_base_query(sort_column, collation), ascending)
             for sort_column, (ascending, collation) in self.sorts))

    @override
    def _make_column_query(self, column):
        # Every column is permuted by the same argsort index query.
        if not self.sorts:
            return self.base_table.column(column)
        return TakeQuery(self.base_table.column(column),
                         self.__index_query)
@pytest.mark.parametrize("ascending", [True, False])
def test_sorting_simple(ascending):
  """Test single-column sorting"""
  unsorted = TestQueryTable(
      rows=[
          [3, 2, 7, 4],
          [6, 7, 2, 8],
          [1, 4, 0, 3],
          [9, 3, 1, 2],
      ])
  # The same rows ordered by col0, reversed for a descending sort.
  expected_rows = [
      [1, 4, 0, 3],
      [3, 2, 7, 4],
      [6, 7, 2, 8],
      [9, 3, 1, 2],
  ]
  if not ascending:
    expected_rows = expected_rows[::-1]
  expected = TestQueryTable(rows=expected_rows)
  result = _execute_qt(SortingQueryTable(unsorted, [("col0", ascending)]))
  assert result == expected.as_dict()
@pytest.mark.parametrize("col1_ascending", [True, False])
def test_sorting_lexicographical(col1_ascending):
  """Test that sorting works for multiple columns"""
  table = [
      [3, 2, 7, 4],
      [5, 7, 2, 8],
      [5, 4, 0, 3],
      [9, 3, 1, 2],
  ]
  # Expected order: col0 ascending, col1 in the requested direction;
  # negating col1 flips its comparison direction.
  inv = 1 if col1_ascending else -1
  sorted_table = sorted(table, key=lambda row: (row[0], row[1] * inv))
  sorts = (
      ("col0", True),
      ("col1", col1_ascending),
  )
  qt = SortingQueryTable(TestQueryTable(rows=table), sorts)
  assert _execute_qt(qt) == TestQueryTable(rows=sorted_table).as_dict()
def _do_test_sorting_strings(strings, ascending, case_sensitive):
  """Sort STRINGS through the query machinery and compare to Python's sort."""
  engine = QueryEngine(QueryCache())
  scrambled = list(strings)
  _make_rng().shuffle(scrambled)
  # Intern the strings so the string table knows about them.
  engine.st.vintern(tuple(s.encode("UTF-8") for s in scrambled))
  table = TestQueryTable(columns=[scrambled],
                         schema=STRING_SCHEMA)
  collation = "binary" if case_sensitive else "nocase"
  result = _execute_qt(
      SortingQueryTable(table, (("col0", ascending, collation),)),
      strings=True)["col0"]
  expected = sorted(
      strings,
      key=lambda s: s if case_sensitive else s.lower(),
      reverse=not ascending)
  assert result == expected
@pytest.mark.parametrize("ascending", [True, False])
def test_sorting_strings(ascending):
  """Test string sorting query interface"""
  _do_test_sorting_strings(
      ["bar", "qux", "foo", "1", "", "341"],
      ascending=ascending,
      case_sensitive=True)
@pytest.mark.parametrize("ascending", [True, False])
def test_sorting_strings_case_insensitive(ascending):
  """Test string sorting query interface in case insensitive mode"""
  _do_test_sorting_strings(
      ["bar", "Qux", "foo", "1", "", "341"],
      ascending=ascending,
      case_sensitive=False)
def test_masked_broadcast_to():
  """Broadcasting length-1 masked and unmasked arrays."""
  assert masked_broadcast_to(np.asarray([1]), 3).tolist() == [1, 1, 1]
  # A fully-masked input broadcasts to a fully-masked output.
  assert masked_broadcast_to(ma([1], [T]), 3).tolist() == [N, N, N]
  # Inputs with no masked elements come back as plain ndarrays.
  for src in (ma([1], [F]), ma([1])):
    out = masked_broadcast_to(src, 3)
    assert out.tolist() == [1, 1, 1]
    assert type(out) is np.ndarray  # pylint: disable=unidiomatic-typecheck
@pytest.mark.parametrize("mode", ["npy", "array"])
def test_get_broadcaster(mode):
  """npy_get_broadcaster finds the base of broadcast views only."""
  def _make(values):
    # In "array" mode, route through a stdlib array to exercise the
    # buffer-protocol path of np.asarray.
    if mode == "array":
      from array import array as _array
      values = _array("l", values)
    return np.asarray(values)
  # Ordinary (non-broadcast) arrays have no broadcaster.
  assert npy_get_broadcaster(_make([1])) is None
  assert npy_get_broadcaster(_make([1, 2, 3])) is None
  one = _make([1])
  assert npy_get_broadcaster(one) is None
  spread = np.broadcast_to(one, 5)
  assert list(spread) == [1, 1, 1, 1, 1]
  # The broadcaster of a broadcast view is the original base array.
  recovered = npy_get_broadcaster(spread)
  assert list(recovered) == [1]
  assert recovered is one
  # Non-array arguments are rejected.
  with pytest.raises(TypeError):
    npy_get_broadcaster(1)
def test_get_broadcaster_masked():
  """npy_get_broadcaster round-trips masked broadcast arrays."""
  assert npy_get_broadcaster(ma([1])) is None
  assert not npy_get_broadcaster(ma([1, 2, 3], [T, F, T]))
  src = ma([1], [T])
  assert npy_get_broadcaster(src) is None
  spread = masked_broadcast_to(src, 5)
  # The data broadcasts normally; the mask hides every element.
  assert spread.data.tolist() == [1, 1, 1, 1, 1]
  assert spread.tolist() == [N, N, N, N, N]
  recovered = npy_get_broadcaster(spread)
  assert recovered.tolist() == src.tolist()
  # Data and mask share storage with the original, not copies.
  assert recovered.data.base is src.data.base
  assert recovered.mask.base is src.mask.base
def test_get_broadcaster_nomask():
  """Broadcaster recovery when the masked array carries nomask."""
  unmasked = ma([1])
  assert unmasked.mask is nomask
  spread = ma(np.broadcast_to(unmasked.data.base, 3))
  assert spread.tolist() == [1, 1, 1]
  assert spread.mask is nomask
  recovered = npy_get_broadcaster(spread)
  # With nomask, recovery yields the plain backing ndarray.
  assert type(recovered) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  assert recovered is unmasked.data.base
def test_get_broadcaster_scalar():
  """Broadcasting a numpy scalar yields a one-element broadcaster."""
  spread = npy_broadcast_to(np.int32(7), 3)
  assert spread.tolist() == [7, 7, 7]
  recovered = npy_get_broadcaster(spread)
  # The broadcaster is a plain 1-d, single-element ndarray.
  assert type(recovered) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  assert (recovered.size, recovered.ndim) == (1, 1)
  assert recovered.tolist() == [7]
  assert recovered.dtype == np.dtype("i")
def test_native_broadcast_to():
  """npy_broadcast_to views its input rather than copying it."""
  src = np.asarray([1])
  view = npy_broadcast_to(src, 5)
  assert view.tolist() == [1] * 5
  # The broadcast result is a view whose base is the input array.
  assert view.base is src
  # Plain lists are accepted too; matching lengths pass through.
  assert npy_broadcast_to([5], 3).tolist() == [5] * 3
  assert npy_broadcast_to([1, 2, 3], 3).tolist() == [1, 2, 3]
def test_native_broadcast_to_scalar():
  """Broadcasting a numpy scalar produces a view over a 0-d base."""
  spread = npy_broadcast_to(np.int32(7), 3)
  assert spread.tolist() == [7, 7, 7]
  assert type(spread) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  backing = spread.base
  # The base is a 0-d single-element ndarray; tolist() on a 0-d array
  # gives the bare Python scalar, not a list.
  assert type(backing) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  assert backing.ndim == 0  # pylint: disable=compare-to-zero
  assert backing.size == 1
  assert backing.tolist() == 7
  assert type(backing.tolist()) is int  # pylint: disable=unidiomatic-typecheck
def test_native_broadcast_to_invalid():
  """Inputs that cannot legally broadcast to the target length raise."""
  # Wrong length and wrong rank are both rejected.
  for bad in ([1, 2], [[1]]):
    with pytest.raises(Exception):
      npy_broadcast_to(bad, 3)
def test_native_broadcast_to_masked_array():
  """npy_broadcast_to strips the mask and views the raw data."""
  masked = ma([1, 2, 3], [T, F, T])
  assert masked.tolist() == [N, 2, N]
  plain = npy_broadcast_to(masked, 3)
  # The result is an unmasked ndarray exposing the underlying data,
  # sharing storage with the masked input's data buffer.
  assert type(plain) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  assert plain.tolist() == [1, 2, 3]
  assert plain.base is masked.data.base
class TestReadAllQuery(SimpleQueryNode):
  """Query node that drains BASE and emits the total element count.

  While draining, asserts that every multi-element block read from
  BASE is still a broadcast view (carries a broadcaster whose repeated
  contents reproduce the block) — used by test_broadcast_read_all.
  """
  # Query whose output channel we drain.
  base = iattr(QueryNode)
  # Elements requested per read after the first.
  read_size = iattr(int)
  # Size of the first read; 0 means "use read_size".
  priming_read = iattr(int)

  @override
  def _compute_schema(self):
    # This node emits a single INT64 value: the observed size.
    return QuerySchema(INT64)

  @override
  def _compute_inputs(self):
    return [self.base]

  @override
  async def run_async(self, qe):
    # One input channel for BASE, one output channel for this node.
    [ic], [oc] = await qe.async_setup([self.base], [self])
    obs_size = 0
    block = await ic.read(self.priming_read or self.read_size)
    # NOTE(review): presumably ic.read returns a falsy block at end of
    # stream — confirm against the channel implementation.
    while block:
      array = block.as_array()
      broadcaster = npy_get_broadcaster(array)
      if len(array) > 1:
        # Any multi-element block from a filled query must remain a
        # broadcast view; tiling its base must reproduce the array.
        assert broadcaster is not None
        assert broadcaster.tolist() * len(block) == array.tolist()
      obs_size += len(array)
      block = await ic.read(self.read_size)
    # Write the total as the final (True) output block.
    await oc.write([obs_size], True)
@pytest.mark.parametrize("block_size", [1, 3, 4, 5])
@pytest.mark.parametrize("read_size", [2, 3, 4])
@pytest.mark.parametrize("priming_read", [0, 1])
@pytest.mark.parametrize("null", [True, False])
def test_broadcast_read_all(block_size, read_size, priming_read, null):
  """Drain a filled query at various block/read sizes; check the total."""
  total_size = 16
  fill = None if null else 3
  query = TestReadAllQuery(QueryNode.filled(fill, total_size),
                           read_size,
                           priming_read)
  assert _execute_q(query, block_size=block_size) == [total_size]