| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http:#www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Tests for basic query machinery""" |
| |
| # pylint: disable=missing-docstring |
| |
| import os |
| import logging |
| from collections import OrderedDict |
| from cytoolz import valmap, identity |
| import numpy as np |
| import pytest |
| |
| from modernmp.util import the |
| |
| from .query import ( |
| ArgSortQuery, |
| BinaryOperationQuery, |
| CollateQuery, |
| CountQuery, |
| DURATION_SCHEMA, |
| EVENT_UNPARTITIONED_TIME_MAJOR, |
| InvalidQueryException, |
| LiteralQuery, |
| LockstepFilledArrayQuery, |
| NULL_FILL_VALUE, |
| NULL_SCHEMA, |
| QueryNode, |
| QuerySchema, |
| QueryTable, |
| REGULAR_TABLE, |
| SPAN_UNPARTITIONED_TIME_MAJOR, |
| STRING_SCHEMA, |
| SequentialQuery, |
| SimpleQueryNode, |
| TS_SCHEMA, |
| TableKind, |
| TableSchema, |
| TableSorting, |
| TakeQuery, |
| UNIQUE, |
| UnaryOperationQuery, |
| masked_broadcast_to, |
| ) |
| |
| from .queryengine import ( |
| Delivery, |
| QueryCache, |
| QueryEngine, |
| ) |
| |
| from .test_util import ( |
| assertrepr_compare_hooks, |
| pict_parameterize, |
| ) |
| |
| from .util import ( |
| BOOL, |
| ExplicitInheritance, |
| FLOAT32, |
| FLOAT64, |
| INT16, |
| INT32, |
| INT64, |
| INT8, |
| UINT16, |
| UINT32, |
| UINT64, |
| UINT8, |
| all_unique, |
| final, |
| iattr, |
| lmap, |
| load_pandas, |
| override, |
| tattr, |
| ureg, |
| ) |
| |
| from ._native import ( |
| SPAN_INVARIANT_CHECKING, |
| npy_broadcast_to, |
| npy_explode, |
| npy_get_broadcaster, |
| npy_get_data, |
| npy_get_mask, |
| npy_has_mask, |
| npy_memory_location, |
| ) |
| |
log = logging.getLogger(__name__)

# Short aliases keeping the expected-value tables below compact.
T = True
F = False
N = None  # stands in for a masked/null cell in expected outputs
ma = np.ma.masked_array
nomask = np.ma.nomask
| |
# Exhaustive schema/fill-value pairs used to test how query operators
# behave across every supported dtype, plus domain-, unit-, string-,
# and null-schema variants.
ALL_SCHEMAS = {
    "str": (STRING_SCHEMA, b""),
    "domain_a": (QuerySchema(INT64, domain="a"), np.int64(3)),
    "domain_b": (QuerySchema(INT64, domain="b"), np.int64(4)),
    "bytes": (QuerySchema(INT64, unit=ureg().bytes), np.int64(5)),
    "cm": (QuerySchema(INT64, unit=ureg().cm), np.int64(6)),
    "NULL": (NULL_SCHEMA, NULL_FILL_VALUE),
}
ALL_SCHEMAS.update(
    (str(dtype), (QuerySchema(dtype), dtype.type()))
    for dtype in (INT8, UINT8, INT16, UINT16, INT32, UINT32,
                  INT64, UINT64, BOOL, FLOAT32, FLOAT64))
| |
@final
class TestQuery(SimpleQueryNode):
    """Dummy query for test

    Emits a fixed sequence of values.  ``None`` and NaN entries are
    delivered as masked (null) cells; bytes/str entries imply the
    string schema unless an explicit schema is supplied.
    """

    # Tuple of Python values this query yields, in order.
    values = tattr()
    # Column schema; defaults to INT64 and is exposed as `schema'.
    _schema = iattr(QuerySchema,
                    default=QuerySchema(INT64),
                    name="schema")

    @override
    def __new__(cls, values, schema=None):
        values = tuple(values)
        # Textual values require the string schema.
        have_string_values = any(isinstance(v, (bytes, str))
                                 for v in values)
        if schema is None:
            schema = (STRING_SCHEMA if have_string_values
                      else QuerySchema(INT64))
        else:
            # An explicit schema must agree with the value kinds unless
            # there are no values (or only nulls).
            assert (schema.is_string == have_string_values
                    or not values
                    or all(v is None for v in values)), \
                ("string value-schema mismatch: schema={!r} values={!r}"
                 .format(schema, values))
        return cls._do_new(cls, values, schema)

    @override
    def _compute_schema(self):
        # Schema was fixed at construction time.
        return self._schema

    @override
    def _compute_inputs(self):
        # Leaf node: no upstream queries.
        return ()

    @property
    def __array(self):
        """Our values as a numpy (possibly masked) array in our dtype."""
        values = self.values
        try:
            # None/NaN entries need a mask; np.isnan raises TypeError
            # for non-numeric sequences, which means "no mask needed".
            need_mask = None in values or np.isnan(values).any()
        except TypeError:
            need_mask = False
        if need_mask:
            mask = tuple(value is None or np.isnan(value)
                         for value in values)
            # Zero-fill masked slots so the array gets a proper dtype.
            values = tuple(0 if value is None or np.isnan(value)
                           else value for value in values)
            array = np.ma.masked_array(values, mask)
        else:
            array = np.array(values)
        assert array.dtype != np.dtype("O")
        return np.asanyarray(array, dtype=self._schema.dtype)

    @override
    async def run_async(self, qe):
        # pylint: disable=comparison-with-itself
        [], [oc] = await qe.async_setup((), (self,))
        for value in self.values:
            if value is None or value != value:  # nan test
                # Deliver a single masked cell for null/NaN values.
                await oc.write(
                    np.ma.masked_array([0], [True],
                                       dtype=self.schema.dtype))
                continue
            if self.schema.is_string:
                # Strings are delivered as interned string codes.
                value = qe.st.intern(value.encode("UTF-8")
                                     if isinstance(value, str) else value)
            await oc.write([value])

    @override
    def __repr__(self):
        return "<TestQuery {!r}>".format(self.values)
| |
# Sentinel distinguishing "argument omitted" from an explicit None.
_UNSPECIFIED = object()
| |
class TestQueryTable(QueryTable):
    """QueryTable for test

    Builds an in-memory table from `columns' (column-major) or `rows'
    (row-major) data, with test-ergonomic conveniences such as `_end'
    to `_duration' conversion for span tables.
    """

    # TestQuery has a passing similarity to a LiteralQuery over in the
    # serious-business side of the query system, query.py, but TestQuery
    # and TestQueryTable have some ergonomic conveniences that wouldn't
    # be right for production code, e.g., `_end' conversion.

    # OrderedDict: column name -> tuple of cell values.
    _columns = iattr()
    # Per-column schemas (QuerySchema or None for "infer"), parallel
    # to _columns.
    _schemas = tattr()

    @override
    def __new__(cls, *, columns=None, rows=None, names=None,
                schema=None,
                schemas=None,
                table_schema=_UNSPECIFIED,
                span_partition=_UNSPECIFIED,
                event_partition=_UNSPECIFIED):
        # pylint: disable=redefined-variable-type
        # At most one of `columns' and `rows' may be given.
        assert columns is None or rows is None
        # span_partition/event_partition are mutually-exclusive
        # shortcuts that imply a time-major SPAN/EVENT table schema.
        if span_partition is not _UNSPECIFIED:
            assert table_schema is _UNSPECIFIED
            assert event_partition is _UNSPECIFIED
            table_schema = TableSchema(TableKind.SPAN, span_partition,
                                       TableSorting.TIME_MAJOR)
        elif event_partition is not _UNSPECIFIED:
            assert table_schema is _UNSPECIFIED
            assert span_partition is _UNSPECIFIED
            table_schema = TableSchema(TableKind.EVENT,
                                       event_partition,
                                       TableSorting.TIME_MAJOR)
        elif table_schema is not _UNSPECIFIED:
            assert event_partition is _UNSPECIFIED
            assert span_partition is _UNSPECIFIED
            assert isinstance(table_schema, TableSchema)
        else:
            table_schema = REGULAR_TABLE
        if names is not None:
            names = list(names)
        if rows is not None:
            def _fixup_row(row):
                # Rows may be dicts (keys become the column names) or
                # plain sequences.
                nonlocal names
                if isinstance(row, dict):
                    row_names = list(row)
                    row_values = list(row.values())

                    # Gross hack, putting this special case here, but it makes
                    # the test tables a lot less annoying to write.
                    if "_live_end" in row_names:
                        le_idx = row_names.index("_live_end")
                        row_ts = row["_ts"]
                        # NOTE(review): `<=' here permits ts == live_end,
                        # while the `_end' conversion below requires
                        # strictly ts < end -- confirm intentional.
                        assert row_ts <= row_values[le_idx]
                        row_names[le_idx] = "live_duration"
                        row_values[le_idx] -= row_ts

                    if names is None:
                        names = row_names
                    else:
                        assert names == row_names, (
                            "dict rows must have dict keys matching the "
                            "column names")
                    row = row_values
                assert isinstance(row, (tuple, list))
                return row
            rows = lmap(_fixup_row, rows)
            # All rows must have the same width.
            assert all(len(r) == len(rows[0]) for r in rows)
            columns = _flip_table(rows)
        if names is None:
            names = ["col{}".format(i) for i in range(len(columns))]
        if not columns:
            # No data at all: one empty column per name.
            columns = [[]] * len(names)
        assert len(columns) == len(names)
        if schema:
            # One schema applies to every column.
            assert schemas is None
            schemas = (schema,) * len(columns)
        elif schemas is not None:
            assert not schema
            # Coerce plain dtypes to QuerySchema.
            schemas = tuple(
                (s if isinstance(s, QuerySchema) else
                 QuerySchema(s)) for s in schemas)
            assert len(schemas) == len(columns)
        else:
            schemas = [None] * len(columns)
        if table_schema.kind == TableKind.SPAN:
            # Span tables start with _ts and _duration; a `_end'
            # column is converted to durations for convenience.
            names = list(names)
            assert names[0] == "_ts"
            assert names[1] in ("_duration", "_end")
            if names[1] == "_end":
                names[1] = "_duration"
                def _dur(end, ts):
                    assert ts < end
                    return end - ts
                columns[1] = [_dur(end, ts)
                              for ts, end in zip(columns[0], columns[1])]
            schemas = list(schemas)
            schemas[0] = TS_SCHEMA
            schemas[1] = DURATION_SCHEMA
        elif table_schema.kind == TableKind.EVENT:
            # Event tables start with a _ts column.
            assert names[0] == "_ts"
            schemas = list(schemas)
            schemas[0] = TS_SCHEMA
        columns = OrderedDict(
            (the(str, name), tuple(data))
            for name, data in zip(names, columns))

        return cls._do_new(cls,
                           _columns=columns,
                           _schemas=schemas,
                           table_schema=table_schema)

    @override
    def _make_column_tuple(self):
        # Column names, in table order.
        return tuple(self._columns)

    @override
    def _make_column_query(self, column):
        """Return a query for a specific named column"""
        idx = self.columns.index(column)
        return TestQuery(self._columns[column], self._schemas[idx])

    def as_dataframe(self):
        """Convert the test query table to a Pandas DataFrame"""
        pd = load_pandas()
        return pd.DataFrame(self.as_dict())

    def as_dict(self):
        """Dict view mapping column name to list of cell values"""
        return valmap(list, self._columns)
| |
| def _make_rng(): |
| from random import Random |
| rng = Random() |
| rng.seed("the mome raths outgrabe") |
| return rng |
| |
def _flip_table(table):
    """Transpose row-major `table' into a list of column lists."""
    return [list(column) for column in zip(*table)]
| |
@final
class TestQueryTableResult(dict, ExplicitInheritance):
    """Result of executing a TestQueryTable: column name -> data.

    Also records the table schema and the per-column queries so the
    result can be compared, schemas and all, against a TestQueryTable.
    """

    @override
    def __new__(cls, data, *, table_schema, column_queries):  # pylint: disable=unused-argument
        return super().__new__(cls)

    @override
    def __init__(self, data, *, table_schema, column_queries):
        super().__init__()
        self.update(data)
        self.__table_schema = the(TableSchema, table_schema)
        self.__column_queries = dict(column_queries)
        # Data keys and query keys must line up one-to-one, in order.
        assert list(self.keys()) == list(self.__column_queries.keys())
        assert all(isinstance(q, QueryNode)
                   for q in self.__column_queries.values())

    @property
    def table_schema(self):
        """The table schema of the result"""
        return self.__table_schema

    @property
    def column_queries(self):
        """A mapping of column queries

        Maps column name to QueryNode. Order matches the data dict.
        """
        return self.__column_queries

    @property
    def column_schemas(self):
        """Sequence of column schemas, in column order"""
        return [q.schema for q in self.column_queries.values()]

    def equals_test_qt(self, qt):
        """Dedicated equality comparison with TestQueryTable

        Verifies column schemas, table schema as well
        as column data."""
        assert isinstance(qt, TestQueryTable)
        return (
            list(self) == list(qt.columns) and
            self.table_schema == qt.table_schema and
            len(self.column_schemas) == len(qt.column_schemas_for_test) and
            all(l_schema.is_a(r_schema) for l_schema, r_schema
                in zip(self.column_schemas, qt.column_schemas_for_test)) and
            self == qt.as_dict())

    def assert_eq(self, qt):
        """Assert that result matches in test

        Like equals_test_qt, but as individual asserts so a failure
        pinpoints the differing aspect.
        """
        assert list(self) == list(qt.columns)
        assert self.table_schema == qt.table_schema
        assert len(self.column_schemas) == len(qt.column_schemas_for_test)
        for l_schema, r_schema in zip(self.column_schemas,
                                      qt.column_schemas_for_test):
            assert l_schema.is_a(r_schema)
        assert self == qt.as_dict()
        assert self == qt
        return True

    @override
    def __eq__(self, other):
        # Comparisons against TestQueryTable use the schema-aware
        # comparison; anything else falls back to plain dict equality.
        if isinstance(other, TestQueryTable):
            return self.equals_test_qt(other)
        return super().__eq__(other)

    @override
    def __ne__(self, other):
        if isinstance(other, TestQueryTable):
            return not self.equals_test_qt(other)
        return super().__ne__(other)
| |
def _qt_failure_hook(config, op, left, right):
    """Explain failing TestQueryTableResult == TestQueryTable asserts.

    Reports the first aspect (table schema, column names, column
    schemas, data) on which the two sides differ; returns None when
    the comparison is not ours to explain.
    """
    applicable = (op == "=="
                  and isinstance(left, TestQueryTableResult)
                  and isinstance(right, TestQueryTable)
                  and left != right)
    if not applicable:
        return None
    from _pytest.assertion.util import assertrepr_compare
    def _explain(lhs, rhs):
        return assertrepr_compare(config, op, lhs, rhs)
    if left.table_schema != right.table_schema:
        return _explain(left.table_schema, right.table_schema)
    lhs_names = list(left)
    rhs_names = list(right.columns)
    if lhs_names != rhs_names:
        return _explain(lhs_names, rhs_names)
    lhs_schemas = left.column_schemas
    rhs_schemas = right.column_schemas_for_test
    schemas_match = (len(lhs_schemas) == len(rhs_schemas)
                     and all(ls.is_a(rs)
                             for ls, rs in zip(lhs_schemas, rhs_schemas)))
    if not schemas_match:
        return _explain(lhs_schemas, rhs_schemas)
    if dict(left) != right.as_dict():
        return _explain(dict(left), right.as_dict())
    log.warning("unknown difference?!")
    return None
| |
# Register the explainer so pytest renders rich diffs for
# TestQueryTableResult vs TestQueryTable comparison failures.
assertrepr_compare_hooks.append(_qt_failure_hook)
| |
def _test_execute(wanted, *,
                  as_dtypes=False,
                  strings=True,
                  env=None,
                  qe=None,
                  block_size=None,
                  as_raw_array=False):
    """Common functionality for _execute_q and _execute_qt

    Executes the `wanted' queries and post-processes each column:
    optionally reporting dtypes only, decoding interned strings, and
    (unless as_raw_array) converting to a Python list with None for
    masked cells.  The DCTV_TEST_BLOCK_SIZE environment variable
    supplies a default block size.
    """
    if block_size is None:
        env_block_size = os.environ.get("DCTV_TEST_BLOCK_SIZE")
        if env_block_size is not None:
            block_size = int(env_block_size)
    if qe is None:
        cache_kwargs = ({} if block_size is None
                        else {"block_size": block_size})
        qe = QueryEngine(QueryCache(**cache_kwargs), env=env)
    def _postprocess(schema, array):
        assert isinstance(array, np.ndarray)
        if as_dtypes:
            return array.dtype
        if strings and schema.is_string:
            # Map interned string codes back to their values; masked
            # slots become None.
            decoded = qe.st.vlookup(np.ma.filled(array, 0))
            mask = np.ma.getmaskarray(array)
            decoded[mask] = None
            if strings != "bytes":
                decoded[~mask] = [s.decode("UTF-8")
                                  for s in decoded[~mask]]
            array = decoded
        if as_raw_array:
            return array
        # Plain Python list with None standing in for masked cells.
        return [None if masked else value
                for value, masked in zip(array, np.ma.getmaskarray(array))]
    return {query: _postprocess(query.schema, data)
            for query, data in qe.execute_for_columns(wanted)}
| |
def _execute_q(q, **kwargs):
    """Execute a single query for test and return its column"""
    results = _test_execute([q], **kwargs)
    return results[q]
| |
def _execute_qt(qt, **kwargs):
    """Execute a query table for test"""
    assert isinstance(qt, QueryTable)
    column_queries = [qt[name] for name in qt.columns]
    results = _test_execute(column_queries, **kwargs)
    data = {name: results[qt[name]] for name in qt.columns}
    return TestQueryTableResult(
        data,
        table_schema=qt.table_schema,
        column_queries=zip(qt.columns, column_queries),
    )
| |
def test_query_basic():
    """A trivial query yields all its values in one EOF-marked block"""
    query = TestQuery([1, 2, 3])
    engine = QueryEngine(QueryCache(block_size=128))
    [(result_query, result_array, at_eof)] = list(engine.execute([query]))
    assert result_query is query
    assert list(result_array) == list(query.values)
    assert at_eof
| |
def test_query_small_blocks():
    """With block_size=1, values arrive one per block, then an EOF block"""
    # pylint: disable=len-as-condition,compare-to-zero
    query = TestQuery([1, 2, 3])
    engine = QueryEngine(QueryCache(block_size=1))
    deliveries = list(engine.execute([query]))
    for expected in (1, 2, 3):
        result_query, result_array, at_eof = deliveries.pop(0)
        assert result_query is query
        assert list(result_array) == [expected]
        assert at_eof is False
    # The final delivery is an empty EOF marker.
    [(result_query, result_array, at_eof)] = deliveries
    assert result_query is query
    assert len(result_array) == 0
    assert at_eof
| |
# Maps human-readable test parameter names to Delivery hints for
# QueryEngine.execute_for_columns.
_HINTS = {
    "arbitrary": Delivery.ARBITRARY,
    "columns": Delivery.COLUMNWISE,
    "rows": Delivery.ROWWISE,
}
| |
@pytest.mark.parametrize("delivery_hint", _HINTS.keys())
@pytest.mark.parametrize("masked", ["mask", "no_mask"])
@pytest.mark.parametrize("block_size", [1, 3, 128])
def test_query_for_columns(delivery_hint, masked, block_size):
    """execute_for_columns gives a plain array unless a mask is needed"""
    # pylint: disable=unidiomatic-typecheck
    engine = QueryEngine(QueryCache(block_size=block_size))
    middle_value = 2 if masked == "no_mask" else None
    query = TestQuery([1, middle_value, 3])
    [(result_query, result_array)] = list(engine.execute_for_columns(
        [query], delivery_hint=_HINTS[delivery_hint]))
    assert result_query is query
    # A null value forces delivery as a masked array; otherwise the
    # result must be a plain ndarray.
    expected_type = (np.ma.masked_array if None in query.values
                     else np.ndarray)
    assert type(result_array) is expected_type
    assert result_array.tolist() == list(query.values)
| |
@pytest.mark.parametrize("block_size", [1, 3, 128])
@pytest.mark.parametrize("delivery_hint", _HINTS.keys())
def test_query_multiple_rows(block_size, delivery_hint):
    """Two queries can be fetched together under any delivery hint"""
    first = TestQuery([1, 2, 3])
    second = TestQuery([4, 5, 6])
    engine = QueryEngine(QueryCache(block_size=block_size))
    [(rq_a, ra_a), (rq_b, ra_b)] = list(engine.execute_for_columns(
        [first, second], delivery_hint=_HINTS[delivery_hint]))
    # Result order is unspecified; normalize so `first' comes first.
    if rq_a is not first:
        (rq_a, ra_a), (rq_b, ra_b) = (rq_b, ra_b), (rq_a, ra_a)
    assert rq_a is first
    assert rq_b is second
    assert list(ra_a) == list(first.values)
    assert list(ra_b) == list(second.values)
| |
@pytest.mark.parametrize("block_size", [1, 3, 128])
@pytest.mark.parametrize("delivery_hint", _HINTS.keys())
def test_query_multiple_rows_unequal(block_size, delivery_hint):
    """Queries of different lengths can be fetched together"""
    long_query = TestQuery([1, 2, 3, 4, 5, 6, 7, 8, 9])
    short_query = TestQuery([4, 5, 6])
    engine = QueryEngine(QueryCache(block_size=block_size))
    [(rq_a, ra_a), (rq_b, ra_b)] = list(engine.execute_for_columns(
        [long_query, short_query], delivery_hint=_HINTS[delivery_hint]))
    # Result order is unspecified; normalize so the long query is first.
    if rq_a is not long_query:
        (rq_a, ra_a), (rq_b, ra_b) = (rq_b, ra_b), (rq_a, ra_a)
    assert rq_a is long_query
    assert rq_b is short_query
    assert list(ra_a) == list(long_query.values)
    assert list(ra_b) == list(short_query.values)
| |
def test_query_table():
    """Test whether QueryTable basically works"""
    foo_values = [1, 2, 3]
    bar_values = [4, 5, 6]
    qt = TestQueryTable(columns=[foo_values, bar_values],
                        names=["foo", "bar"])
    assert _execute_qt(qt) == dict(foo=foo_values, bar=bar_values)
| |
@pytest.mark.parametrize("partitioned", ["partitioned", "unpartitioned"])
def test_query_table_order_conversion(partitioned):
    """Sorting conversions between NONE/TIME_MAJOR/PARTITION_MAJOR

    Builds the same span rows in each sort order and checks that
    to_schema() conversions produce the expected reordering, with and
    without a partition column.
    """
    do_part = partitioned == "partitioned"
    def _mkqt(sorting, rows):
        # A span table with a `part' partition column (when enabled)
        # and one payload column, `foo'.
        return TestQueryTable(
            names=["_ts", "_duration", "part", "foo"],
            rows=rows,
            table_schema=TableSchema(
                TableKind.SPAN,
                "part" if do_part else None,
                sorting))
    base_rows = [
        [3, 1, 1, 33],
        [3, 1, 0, -3],
        [4, 1, 1, -4],
        [7, 1, 0, -7],
        [1, 2, 0, -1],
        [5, 1, 1, -5],
    ]
    foo_none = _mkqt(TableSorting.NONE, base_rows)
    # Expected orders: time-major sorts by (_ts, part), partition-major
    # by (part, _ts); without a partition both sort by _ts alone.
    foo_time_major = _mkqt(
        TableSorting.TIME_MAJOR,
        (sorted(base_rows,
                key=(lambda row: (row[0], row[2]) if do_part else row[0]))))
    foo_partition_major = _mkqt(
        TableSorting.PARTITION_MAJOR,
        sorted(base_rows,
               key=(lambda row: (row[2], row[0]) if do_part else row[0])))
    assert _execute_qt(foo_none) == foo_none
    assert _execute_qt(foo_none.transform(identity)) == foo_none
    assert _execute_qt(
        foo_none.to_schema(sorting=TableSorting.PARTITION_MAJOR)
    ) == foo_partition_major
    assert _execute_qt(
        foo_none.to_schema(sorting=TableSorting.TIME_MAJOR)
    ) == foo_time_major
    # Chained conversions behave like a single conversion to the
    # final sorting.
    assert _execute_qt(
        foo_none
        .to_schema(sorting=TableSorting.TIME_MAJOR)
        .to_schema(sorting=TableSorting.PARTITION_MAJOR)
    ) == foo_partition_major
    assert _execute_qt(
        foo_none
        .to_schema(sorting=TableSorting.TIME_MAJOR)
        .to_schema(sorting=TableSorting.PARTITION_MAJOR)
        .to_schema(sorting=TableSorting.NONE)
    ) == _execute_qt(foo_partition_major.with_table_schema(
        foo_partition_major.table_schema.evolve(
            sorting=TableSorting.NONE)))
| |
def test_match():
    """Test that we can match a condition"""
    values = TestQuery([1, 2, 2, 3])
    match_two = BinaryOperationQuery(
        values, "=", QueryNode.filled(2, values.countq()))
    assert _execute_q(match_two) == [False, True, True, False]
| |
def test_match_backward():
    """Test that we can match a condition the other way around"""
    values = TestQuery([1, 2, 2, 3])
    match_two = BinaryOperationQuery(
        QueryNode.filled(2, values.countq()), "=", values)
    assert _execute_q(match_two) == [False, True, True, False]
| |
def test_unop():
    """Test that basic unary operations work"""
    values = TestQuery([1, 2, 2, 3])
    negated = UnaryOperationQuery("-", values)
    assert list(_execute_q(negated)) == [-1, -2, -2, -3]
| |
def test_scalar():
    """Scalar queries yield a single-element column"""
    assert _execute_q(QueryNode.scalar(5)) == [5]
    # With string decoding disabled we see the interned string code.
    assert _execute_q(QueryNode.scalar(b"foo"), strings=False) == [1]
    assert _execute_q(QueryNode.scalar(None)) == [None]
| |
@pytest.mark.parametrize("schema_name", ALL_SCHEMAS)
def test_typed_scalar(schema_name):
    """Scalars round-trip with their exact Python/numpy type"""
    # pylint: disable=unidiomatic-typecheck
    _schema, fill_value = ALL_SCHEMAS[schema_name]
    results = _execute_q(QueryNode.scalar(fill_value), strings="bytes")
    result = results[0]
    assert type(result) == type(fill_value)
    assert result == fill_value
| |
def test_count_count():
    """Test count of count"""
    filled = QueryNode.filled(4, 1)
    assert _execute_q(filled) == [4]
    # The count of a single-element column is itself a single 1.
    assert _execute_q(filled.countq()) == [1]
| |
def test_filled_query():
    """QueryNode.filled broadcasts a scalar fill value to a count"""
    assert _execute_q(
        QueryNode.filled(QueryNode.scalar(2), QueryNode.scalar(3))
    ) == [2, 2, 2]
    assert _execute_q(QueryNode.filled(N, 4)) == [N, N, N, N]
    # A single-NaN fill value behaves like null.
    assert _execute_q(QueryNode.filled(TestQuery([np.nan]), 5)) \
        == [N, N, N, N, N]
    # A multi-element fill value is invalid.
    with pytest.raises(InvalidQueryException):
        assert _execute_q(QueryNode.filled(TestQuery([np.nan, np.nan]), 5))
| |
def test_filled_broadcast_shared_structure():
    """A filled query shares storage with its base scalar's array"""
    engine = QueryEngine(QueryCache())
    base = QueryNode.scalar(5)
    broadcast = QueryNode.filled(base, 3)
    columns = dict(engine.execute_for_columns([base, broadcast]))
    assert list(columns[base]) == [5]
    assert list(columns[broadcast]) == [5, 5, 5]
    # Broadcasting must not copy the underlying storage.
    assert (npy_memory_location(columns[broadcast]) ==
            npy_memory_location(columns[base]))
| |
def test_filled_lockstep_passthrough(monkeypatch):
    """Lockstep fill passes values through when broadcast is disabled"""
    monkeypatch.setattr(LockstepFilledArrayQuery,
                        "_do_broadcast", None)
    base_values = list(range(20))
    base_count = TestQuery(base_values).countq()
    assert isinstance(base_count, CountQuery)
    fill_values = [v + 1 for v in base_values]
    fill_query = TestQuery(fill_values)
    assert _execute_q(QueryNode.filled(fill_query, base_count),
                      block_size=3) == fill_values
    # A count that disagrees with the fill source length is invalid.
    with pytest.raises(InvalidQueryException):
        _execute_q(QueryNode.filled(fill_query,
                                    TestQuery([1, 2, 3]).countq()),
                   block_size=3)
| |
def test_filled_lockstep_broadcast(monkeypatch):
    """Lockstep fill broadcasts when passthrough is disabled"""
    monkeypatch.setattr(LockstepFilledArrayQuery,
                        "_do_passthrough", None)
    base_values = list(range(20))
    base_count = TestQuery(base_values).countq()
    assert isinstance(base_count, CountQuery)
    assert _execute_q(QueryNode.filled(7, base_count),
                      block_size=3) == [7] * len(base_values)
| |
def test_sequential_query():
    """Test sequential query"""
    query = SequentialQuery(QueryNode.scalar(3),
                            QueryNode.scalar(4),
                            QueryNode.scalar(2))
    assert _execute_q(query) == [3, 5, 7, 9]
| |
| def _pyslice(start, count, step): |
| return list(range(start, start+step*(count), step)) |
| |
@pytest.mark.parametrize("start", [0, 3])
@pytest.mark.parametrize("count", [20, 31])
@pytest.mark.parametrize("step", [2, 3, 4])
@pytest.mark.parametrize("block_size", [2, 3, 4])
def test_sequential_query_small_blocks(start, count, step, block_size):
    """SequentialQuery agrees with the Python reference at any block size"""
    # Bug fix: the body used to overwrite `start', `count', and `step'
    # with hard-coded constants, silently defeating every parameter
    # combination supplied by the decorators above.
    assert _execute_q(
        SequentialQuery(
            QueryNode.scalar(start),
            QueryNode.scalar(count),
            QueryNode.scalar(step)),
        block_size=block_size,
    ) == _pyslice(start, count, step)
| |
def test_take_mask():
    """TakeQuery gathers by index; index -1 yields a masked null"""
    # pylint: disable=compare-to-zero
    def _take(*, v, i, **kwargs):
        return _execute_q(TakeQuery(TestQuery(v), TestQuery(i)), **kwargs)
    assert _take(v=[1, 2, 3, 4], i=[0, 2, -1, 3, 1]) == [1, 3, N, 4, 2]

    # Taking from an empty column: -1 is fine, real indexes are not.
    assert _take(v=[], i=[-1, -1, -1]) == [N, N, N]
    with pytest.raises(InvalidQueryException):
        _take(v=[], i=[0, 1])
    assert _take(v=[], i=[]) == []
    assert _take(v=[1, 2, 3], i=[]) == []
    # Nulls propagate from the values or the index alike.
    assert _take(v=[1, 2, N, 4, 5], i=[1, 2, 3]) == [2, N, 4]
    assert _take(v=[1, N, 3, 4, N], i=[-1, 1, 2]) == [N, N, 3]
    assert _take(v=[1, N, 3, 4, 5], i=[-1, 1, 2]) == [N, N, 3]

    # Masked output cells are zero-filled in the raw array.
    raw = _take(v=[7], i=[-1], as_raw_array=True).data
    assert raw[0] == 0
| |
def test_take_identity():
    """An identity take shares its base storage with the source"""
    values = QueryNode.literals(2, 4, 6, 8)
    identity_index = QueryNode.literals(0, 1, 2, 3)
    taken = values.take(identity_index)
    results = _test_execute([values, taken],
                            as_raw_array=True,
                            block_size=2)
    values_array = results[values]
    taken_array = results[taken]
    assert list(values_array) == list(taken_array)
    assert values_array.base is not None
    assert taken_array.base is not None
    # Identity takes must not copy.
    assert values_array.base is taken_array.base
| |
def test_take_values_broadcast():
    """Taking from a broadcast fill keeps the result broadcast"""
    filled = QueryNode.filled(count=10, fill_value=3)
    taken = filled.take(QueryNode.literals(2, 4, 6))
    result = _execute_q(taken, as_raw_array=True)
    assert result.tolist() == [3, 3, 3]
    assert npy_get_broadcaster(result).tolist() == [3]
| |
def test_take_index_is_broadcast():
    """A broadcast index produces a broadcast result"""
    values = QueryNode.literals(-1, -2, -3, -4)
    broadcast_index = QueryNode.filled(count=3, fill_value=2)
    result = _execute_q(values.take(broadcast_index), as_raw_array=True)
    assert result.tolist() == [-3, -3, -3]
    assert npy_get_broadcaster(result).tolist() == [-3]
| |
def test_take_values_broadcast_mask_index():
    """A masked index cell breaks the broadcast and masks the output"""
    filled = QueryNode.filled(count=10, fill_value=3)
    taken = filled.take(QueryNode.literals(2, -1, 6))
    result = _execute_q(taken, as_raw_array=True)
    assert result.tolist() == [3, N, 3]
    assert npy_get_broadcaster(result) is None
    assert isinstance(result, ma)
    # Masked cells are zero-filled in the raw data.
    assert result.data.tolist() == [3, 0, 3]
| |
def test_take_values_broadcast_mask_values():
    """Taking from an all-null broadcast fill stays broadcast"""
    filled = QueryNode.filled(count=10, fill_value=N)
    taken = filled.take(QueryNode.literals(2, 4, 6))
    result = _execute_q(taken, as_raw_array=True)
    assert result.tolist() == [N, N, N]
    assert npy_get_broadcaster(result).tolist() == [N]
| |
@pytest.mark.parametrize("src", ["numpy", "array"])
def test_mask_functions(src):
    """Identity contracts of the native mask accessor functions

    Checks that npy_has_mask/npy_get_data/npy_get_mask/npy_explode
    neither copy data nor fabricate masks, across a plain ndarray, a
    maskless masked_array, and a genuinely masked array, for both
    numpy- and array-module-backed storage.
    """
    # pylint: disable=protected-access,compare-to-zero
    from array import array as _array
    array = (np.asarray([1, 2, 3])
             if src == "numpy"
             else np.asarray(_array("l", [1, 2, 3])))
    ma_nomask = np.ma.masked_array(array)
    ma_mask = np.ma.masked_array(array, mask=[T, F, T])
    for value in (array, ma_nomask, ma_mask):
        real_mask = np.ma.getmask(value)
        is_masked_array = isinstance(value, np.ma.masked_array)
        # A masked_array constructed without a mask reports nomask.
        is_really_masked = real_mask is not nomask
        assert npy_has_mask(value) is is_really_masked
        if is_really_masked:
            # Accessors must return views of the original storage and
            # the identical mask object -- no copies.
            assert npy_get_data(value).base is array.base
            assert npy_get_mask(value) is real_mask
            x1, x2 = npy_explode(value)
            assert x1.base is array.base
            assert list(x1) == list(array)
            assert x2 is value._mask
        else:
            if is_masked_array:
                assert npy_get_data(value).base is array.base
            else:
                # Plain ndarray: get_data is the identity.
                assert npy_get_data(value) is value
            assert npy_get_mask(value) is nomask
            x1, x2 = npy_explode(value)
            if is_masked_array:
                assert x1.base is array.base
            else:
                assert x1 is value
            assert list(x1) == list(array)
            assert x2 is nomask
| |
@pytest.mark.parametrize("block_size", [1, 2, 3, 1024])
def test_sequential_take_basic(block_size):
    """take_sequential gathers monotonically increasing indexes"""
    values = QueryNode.literals(2, 4, 6, 8)
    index = QueryNode.literals(1, 2, 3)
    result = _execute_q(values.take_sequential(index),
                        as_raw_array=True,
                        block_size=block_size)
    assert result.tolist() == [4, 6, 8]
| |
def test_sequential_take_repeat():
    """Repeated indexes are allowed in sequential takes"""
    values = QueryNode.literals(2, 4, 6, 8)
    repeating_index = QueryNode.literals(1, 1, 1, 3)
    result = _execute_q(values.take_sequential(repeating_index),
                        as_raw_array=True,
                        block_size=2)
    assert result.tolist() == [4, 4, 4, 8]
| |
def test_sequential_take_early_eof_error():
    """Indexing past the end of the values is an invalid query"""
    q_values = QueryNode.literals(2, 4, 6)
    q_index = QueryNode.literals(9)
    with pytest.raises(InvalidQueryException) as ex:
        _execute_q(q_values.take_sequential(q_index),
                   as_raw_array=True,
                   block_size=2)
    # Bug fix: check the exception message via ex.value; stringifying
    # the ExceptionInfo wrapper itself is deprecated in pytest and
    # includes location noise rather than just the message.
    assert "early value EOF at index 3" in str(ex.value)
| |
def test_sequential_take_mask():
    """A -1 index yields a masked cell in a sequential take"""
    values = QueryNode.literals(2, 4, 6, 8)
    index_with_null = QueryNode.literals(1, -1, 3)
    result = _execute_q(values.take_sequential(index_with_null),
                        as_raw_array=True,
                        block_size=2)
    assert result.tolist() == [4, N, 8]
| |
def test_sequential_take_nulls_at_end():
    """Trailing -1 indexes after value EOF produce nulls"""
    values = QueryNode.literals(2, 4, 6, 8)
    index = QueryNode.literals(1, 3, -1, -1, -1, -1, -1)
    result = _execute_q(values.take_sequential(index),
                        as_raw_array=True,
                        block_size=2)
    assert result.tolist() == [4, 8] + [N] * 5
| |
def test_sequential_take_null_only():
    """An all-null index over an empty source yields all nulls"""
    empty_values = QueryNode.literals()
    all_null_index = QueryNode.literals(-1, -1, -1, -1, -1)
    result = _execute_q(empty_values.take_sequential(all_null_index),
                        as_raw_array=True,
                        block_size=2)
    assert result.tolist() == [N] * 5
| |
def test_sequential_take_empty():
    """Empty values and empty index yield an empty result"""
    empty_values = QueryNode.literals()
    empty_index = QueryNode.literals()
    result = _execute_q(empty_values.take_sequential(empty_index),
                        as_raw_array=True,
                        block_size=2)
    assert result.tolist() == []
| |
def test_sequential_take_sequential():
    """A strictly sequential index reproduces the source"""
    values = QueryNode.literals(2, 4, 6, 8, 10)
    sequential_index = QueryNode.literals(0, 1, 2, 3, 4)
    result = _execute_q(values.take_sequential(sequential_index),
                        as_raw_array=True,
                        block_size=3)
    assert result.tolist() == [2, 4, 6, 8, 10]
| |
def test_sequential_take_sequential_then_normal():
    """A sequential prefix followed by a repeated index still works"""
    values = QueryNode.literals(2, 4, 6, 8, 10)
    index = QueryNode.literals(0, 1, 2, 2, 4)
    result = _execute_q(values.take_sequential(index),
                        as_raw_array=True,
                        block_size=3)
    assert result.tolist() == [2, 4, 6, 6, 10]
| |
def test_sequential_take_sequential_rollback():
    """Stepping back to an earlier index within one block works"""
    values = QueryNode.literals(2, 4, 6, 8, 10)
    rollback_index = QueryNode.literals(0, 1, 1, 3, 4)
    result = _execute_q(values.take_sequential(rollback_index),
                        as_raw_array=True,
                        block_size=10)
    assert result.tolist() == [2, 4, 4, 8, 10]
| |
@pytest.mark.parametrize("block_size", [1, 2, 3, 10])
@pytest.mark.parametrize("mode", ["full", "conservative"])
@pytest.mark.parametrize("null", ["null", "no_null"])
def test_sequential_take_broadcast_value(block_size, mode, null):
    """Sequential takes from a broadcast fill may stay broadcast"""
    conservative = mode == "conservative"
    with_null = null == "null"
    filled = QueryNode.filled(count=10, fill_value=3)
    index = QueryNode.literals(0, 3, -1 if with_null else 4, 4, 4)
    result = _execute_q(
        filled.take_sequential(index,
                               bcast_conservative=conservative),
        as_raw_array=True,
        block_size=block_size)
    assert result.tolist() == [3, 3, N if with_null else 3, 3, 3]
    broadcaster = npy_get_broadcaster(result)
    if with_null or (conservative and block_size in (2, 3)):
        # Null cells (or mid-stream block splits under conservative
        # mode) force the result to be materialized.
        assert broadcaster is None
    else:
        assert broadcaster.tolist() == [3]
| |
@pytest.mark.skipif(
    not SPAN_INVARIANT_CHECKING,
    reason="optimized build doesn't check for index errors")
def test_sequential_take_invalid():
    """A backwards index is rejected when invariant checking is on"""
    values = QueryNode.literals(1, 2, 3)
    backwards_index = QueryNode.literals(1, 2, 1)
    with pytest.raises(InvalidQueryException):
        _execute_q(values.take_sequential(backwards_index))
| |
def test_literal_query():
    """Test basic literal-sequence QueryNode support"""
    # Bug fix: these comparisons were missing `assert' and their
    # results were silently discarded.
    assert _execute_q(LiteralQuery(())) == []
    assert _execute_q(LiteralQuery((1, 2, 3))) == [1, 2, 3]
    assert LiteralQuery((b"foo", b"bar")).schema.is_string
    # String literals come back as interned string codes when decoding
    # is disabled (cf. test_scalar), matching the [1, 2] expectation.
    assert _execute_q(LiteralQuery((b"foo", b"bar")),
                      strings=False) == [1, 2]
    with pytest.raises(InvalidQueryException):
        _execute_q(LiteralQuery((1, b"foo", 3)))
| |
def test_similar_values_are_distinct():
    """Test distinctification"""
    # Query interning must not merge queries whose literal values are
    # ==-equal but differently typed (e.g., 0 == False, 1 == 1.0).
    literals = [1.0, 1, True, False, 0]
    literal_queries = [LiteralQuery([value]) for value in literals]
    assert len(set(literal_queries)) == len(literal_queries)
    scalar_queries = [QueryNode.scalar(value) for value in literals]
    assert len(set(scalar_queries)) == len(scalar_queries)
| |
# Schema names excluding the special null and string schemas; used by
# the dtype-crossing tests below.
ALL_REGULAR_SCHEMAS = [
    name for name, (schema, _fill) in ALL_SCHEMAS.items()
    if schema not in (NULL_SCHEMA, STRING_SCHEMA)
]
| |
@pict_parameterize(dict(schema_name_left=ALL_REGULAR_SCHEMAS,
                        schema_name_right=ALL_REGULAR_SCHEMAS))
def test_dtype_value_preserving(schema_name_left, schema_name_right):
    """Mixed-dtype literals keep their extreme values where possible

    Builds a LiteralQuery holding the min/max of two dtypes and checks
    that each extreme round-trips, modulo the uint64/int64 and
    int/float promotions applied below.
    """
    left_schema = ALL_SCHEMAS[schema_name_left][0]
    right_schema = ALL_SCHEMAS[schema_name_right][0]
    ldtype = left_schema.dtype
    ltype = ldtype.type
    rdtype = right_schema.dtype
    rtype = rdtype.type
    def _extreme_values(dtype):
        # Smallest and largest representable values for the dtype.
        if dtype.kind == "b":
            return [True, False]
        if dtype.kind == "f":
            info = np.finfo(dtype)
        else:
            info = np.iinfo(dtype)
        return [dtype.type(info.min), dtype.type(info.max)]
    extreme_left = _extreme_values(ldtype)
    extreme_right = _extreme_values(rdtype)
    q = LiteralQuery(extreme_left + extreme_right)
    def _i64_cast(value):
        # Reinterpret the value's bits as a signed 64-bit integer,
        # matching the engine's uint64 -> int64 narrowing.
        return np.asarray(value).view("l").item()
    expected_left = extreme_left
    expected_right = extreme_right
    # We're not value-preserving when mixing int64 and uint64
    if ldtype.char == "L" and q.schema.dtype.char == "l":
        expected_left = list(map(_i64_cast, expected_left))
        ltype = np.int64
    if rdtype.char == "L" and q.schema.dtype.char == "l":
        expected_right = list(map(_i64_cast, expected_right))
        rtype = np.int64
    # Mixing int and float promotes the int side to the float type.
    if rdtype.kind == "f" and ldtype.kind != "f":
        ltype = rtype
        expected_left = list(map(ltype, expected_left))
    if ldtype.kind == "f" and rdtype.kind != "f":
        rtype = ltype
        expected_right = list(map(rtype, expected_right))
    assert all_unique(expected_left)
    assert all_unique(expected_right)
    result = _execute_q(q)
    assert len(result) == len(expected_left) + len(expected_right)
    result_left = result[:len(expected_left)]
    result_right = result[len(expected_left):]
    assert result_left == expected_left
    assert all_unique(result_left)
    # Converting back through the (possibly promoted) type must be
    # lossless for the extremes to count as preserved.
    assert list(map(ltype, result_left)) == expected_left
    assert result_right == expected_right
    assert all_unique(result_right)
    assert list(map(rtype, result_right)) == expected_right
| |
def test_query_table_result():
  """Execution results compare equal to dicts and to matching tables."""

  # pylint: disable=unneeded-not,superfluous-parens

  reference = TestQueryTable(
      names=["foo", "bar"],
      rows=[
          [1, 2],
          [3, 4],
      ])
  r = _execute_qt(reference)
  assert isinstance(r, TestQueryTableResult)
  assert r == dict(foo=[1, 3], bar=[2, 4])
  assert r == reference
  assert not (r != reference)
  # Different cell values: unequal.
  different_values = TestQueryTable(
      names=["foo", "bar"],
      rows=[
          [3, 2],
          [3, 4],
      ])
  assert not (r == different_values)
  assert r != different_values
  # Same values, different column schemas: unequal.
  different_schemas = TestQueryTable(
      names=["foo", "bar"],
      schemas=[QuerySchema(INT32), QuerySchema(INT32)],
      rows=[
          [1, 2],
          [3, 4],
      ])
  assert not (r == different_schemas)
  assert r != different_schemas
  # Same values under a span table schema: unequal to the regular result.
  span_table = TestQueryTable(
      names=["_ts", "_duration"],
      rows=[
          [1, 2],
          [3, 4],
      ],
      table_schema=SPAN_UNPARTITIONED_TIME_MAJOR)
  assert not (r == span_table)
  assert r != span_table
  # Executing the span table yields a result equal to it...
  r = _execute_qt(span_table)
  assert r == span_table
  # ...but not to an otherwise-identical event table.
  event_table = TestQueryTable(
      names=["_ts", "_duration"],
      rows=[
          [1, 2],
          [3, 4],
      ],
      table_schema=EVENT_UNPARTITIONED_TIME_MAJOR)
  assert r != event_table
| |
def test_schema_constraints():
  """Constraining, unconstraining, and concatenating QuerySchemas."""
  plain = QuerySchema(INT64)
  assert not plain.constraints
  # Unconstraining an unconstrained schema is the identity.
  assert plain.unconstrain() is plain
  unique = plain.constrain({UNIQUE})
  assert unique is not plain
  assert unique != plain
  assert unique.is_a(plain)
  # Re-constraining is idempotent and constrained schemas are interned.
  assert unique.constrain({UNIQUE}) is unique
  assert plain.constrain({UNIQUE}) == unique
  # Concatenation drops constraints.
  combined = QuerySchema.concat([plain, unique])
  assert not combined.constraints
| |
| |
| # Sorting |
| |
| def _munge_sorts(sorts): |
| munged_sorts = OrderedDict() |
| def _add_sort(column, ascending=True, collation="binary"): |
| assert isinstance(collation, str) |
| munged_sorts[column] = (bool(ascending), collation) |
| for sort in sorts: |
| _add_sort(*sort) |
| return tuple(munged_sorts.items()) |
| |
@final
class SortingQueryTable(QueryTable):  # TODO(dancol): remove me!
  """SortingQueryTable sorts another QueryTable"""

  # The table whose rows we present in sorted order.
  base_table = iattr(QueryTable)
  # Normalized sort spec: ((column, (ascending, collation)), ...).
  sorts = tattr(converter=_munge_sorts)

  @override
  def _make_column_tuple(self):
    # Same columns as the base table; only row order changes.
    return self.base_table.columns

  @override
  def countq(self):
    # Sorting permutes rows, so the count is unchanged.
    return self.base_table.countq()

  def group_sizes(self):
    """Group sizes for tests"""
    if not self.sorts:
      return self.base_table.group_sizes()
    return TakeQuery(self.base_table.group_sizes(),
                     self.__index_query)

  @property
  def __index_query(self):
    """Query that gives indexes to take from base table for sort"""
    def _base_query(sort_column, collation):
      query = self.base_table.column(sort_column)
      # String columns sort via their collation keys, not raw values.
      if query.schema.is_string:
        query = CollateQuery(query, collation)
      return query

    return ArgSortQuery(
        ((_base_query(sort_column, collation), ascending)
         for sort_column, (ascending, collation) in self.sorts))

  @override
  def _make_column_query(self, column):
    """Return a query for COLUMN with rows permuted into sort order."""
    if not self.sorts:
      return self.base_table.column(column)
    return TakeQuery(self.base_table.column(column),
                     self.__index_query)
| |
@pytest.mark.parametrize("ascending", [True, False])
def test_sorting_simple(ascending):
  """Test single-column sorting"""
  rows = [
      [3, 2, 7, 4],
      [6, 7, 2, 8],
      [1, 4, 0, 3],
      [9, 3, 1, 2],
  ]
  # col0 values are distinct, so sorting whole rows lexicographically
  # is the same as sorting by col0 alone.
  expected = TestQueryTable(rows=sorted(rows, reverse=not ascending))
  sorting = SortingQueryTable(TestQueryTable(rows=rows),
                              [("col0", ascending)])
  assert _execute_qt(sorting) == expected.as_dict()
| |
@pytest.mark.parametrize("col1_ascending", [True, False])
def test_sorting_lexicographical(col1_ascending):
  """Test that sorting works for multiple columns"""
  table = [
      [3, 2, 7, 4],
      [5, 7, 2, 8],
      [5, 4, 0, 3],
      [9, 3, 1, 2],
  ]

  # col0 always ascends; col1 breaks the col0=5 tie in the requested
  # direction (negating the key inverts the order for numbers).
  inv = 1 if col1_ascending else -1
  sorted_table = sorted(table, key=lambda row: (row[0], row[1] * inv))

  sorts = (
      ("col0", True),
      ("col1", col1_ascending),
  )
  # The previous version also built an unused TestQueryTable here
  # (masked by a pylint disable); it has been removed.
  qt = SortingQueryTable(TestQueryTable(rows=table), sorts)
  assert _execute_qt(qt) == TestQueryTable(rows=sorted_table).as_dict()
| |
def _do_test_sorting_strings(strings, ascending, case_sensitive):
  """Shared driver: sort shuffled STRINGS and compare with sorted()."""
  qe = QueryEngine(QueryCache())
  shuffled = strings[:]
  _make_rng().shuffle(shuffled)
  # Pre-intern the strings so the column can refer to them.
  qe.st.vintern(tuple(s.encode("UTF-8") for s in shuffled))
  table = TestQueryTable(columns=[shuffled],
                         schema=STRING_SCHEMA)
  collation = "binary" if case_sensitive else "nocase"
  sorting = SortingQueryTable(table, (("col0", ascending, collation),))
  result = _execute_qt(sorting, strings=True)["col0"]
  key = (lambda s: s) if case_sensitive else (lambda s: s.lower())
  assert result == sorted(strings, key=key, reverse=not ascending)
| |
@pytest.mark.parametrize("ascending", [True, False])
def test_sorting_strings(ascending):
  """Test string sorting query interface"""
  samples = ["bar", "qux", "foo", "1", "", "341"]
  _do_test_sorting_strings(samples, ascending, case_sensitive=True)
| |
@pytest.mark.parametrize("ascending", [True, False])
def test_sorting_strings_case_insensitive(ascending):
  """Test string sorting query interface in case insensitive mode"""
  samples = ["bar", "Qux", "foo", "1", "", "341"]
  _do_test_sorting_strings(samples, ascending, case_sensitive=False)
| |
def test_masked_broadcast_to():
  """masked_broadcast_to broadcasts masks; no-mask results are plain."""
  assert masked_broadcast_to(np.asarray([1]), 3).tolist() == [1, 1, 1]
  # A fully masked source broadcasts to all-null.
  assert masked_broadcast_to(ma([1], [T]), 3).tolist() == [N, N, N]
  for source in (ma([1], [F]), ma([1])):
    result = masked_broadcast_to(source, 3)
    assert result.tolist() == [1, 1, 1]
    # An effectively mask-free input degrades to a plain ndarray.
    assert type(result) is np.ndarray  # pylint: disable=unidiomatic-typecheck
| |
| |
@pytest.mark.parametrize("mode", ["npy", "array"])
def test_get_broadcaster(mode):
  """npy_get_broadcaster recovers the base of a broadcast array."""
  def _make(values):
    if mode == "array":
      from array import array as _array
      values = _array("l", values)
    return np.asarray(values)
  # Non-broadcast arrays have no broadcaster, whatever their length.
  assert npy_get_broadcaster(_make([1])) is None
  multi = _make([1, 2, 3])
  assert npy_get_broadcaster(multi) is None
  single = _make([1])
  assert npy_get_broadcaster(single) is None
  expanded = np.broadcast_to(single, 5)
  assert list(expanded) == [1, 1, 1, 1, 1]
  recovered = npy_get_broadcaster(expanded)
  assert list(recovered) == [1]
  # The very same object, not a copy.
  assert recovered is single
  with pytest.raises(TypeError):
    npy_get_broadcaster(1)
| |
def test_get_broadcaster_masked():
  """npy_get_broadcaster works through masked arrays."""
  assert npy_get_broadcaster(ma([1])) is None
  assert not npy_get_broadcaster(ma([1, 2, 3], [T, F, T]))
  source = ma([1], [T])
  assert npy_get_broadcaster(source) is None
  expanded = masked_broadcast_to(source, 5)
  assert expanded.data.tolist() == [1, 1, 1, 1, 1]
  assert expanded.tolist() == [N, N, N, N, N]
  recovered = npy_get_broadcaster(expanded)
  assert recovered.tolist() == source.tolist()
  # Data and mask storage is shared with the original, not copied.
  assert recovered.data.base is source.data.base
  assert recovered.mask.base is source.mask.base
| |
def test_get_broadcaster_nomask():
  """A nomask masked array yields a plain ndarray broadcaster."""
  unmasked = ma([1])
  assert unmasked.mask is nomask
  expanded = ma(np.broadcast_to(unmasked.data.base, 3))
  assert expanded.tolist() == [1, 1, 1]
  assert expanded.mask is nomask
  recovered = npy_get_broadcaster(expanded)
  # With no mask anywhere, the broadcaster is the raw data base itself.
  assert type(recovered) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  assert recovered is unmasked.data.base
| |
def test_get_broadcaster_scalar():
  """A broadcast numpy scalar reports a one-element 1-d broadcaster."""
  scalar = np.int32(7)
  expanded = npy_broadcast_to(scalar, 3)
  assert expanded.tolist() == [7, 7, 7]
  recovered = npy_get_broadcaster(expanded)
  assert type(recovered) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  assert recovered.size == 1
  assert recovered.ndim == 1
  assert recovered.tolist() == [7]
  # dtype is preserved from the scalar.
  assert recovered.dtype == np.dtype("i")
| |
def test_native_broadcast_to():
  """npy_broadcast_to basics for arrays and plain sequences."""
  source = np.asarray([1])
  expanded = npy_broadcast_to(source, 5)
  assert expanded.tolist() == [1] * 5
  # The broadcast view keeps the source as its base.
  assert expanded.base is source
  assert npy_broadcast_to([5], 3).tolist() == [5] * 3
  # Already the requested length: values pass through unchanged.
  assert npy_broadcast_to([1, 2, 3], 3).tolist() == [1, 2, 3]
| |
def test_native_broadcast_to_scalar():
  """Broadcasting a numpy scalar produces an ndarray over a 0-d base."""
  expanded = npy_broadcast_to(np.int32(7), 3)
  assert expanded.tolist() == [7, 7, 7]
  assert type(expanded) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  base = expanded.base
  assert type(base) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  assert base.ndim == 0  # pylint: disable=compare-to-zero
  assert base.size == 1
  # A 0-d array's tolist() is a bare Python scalar, not a list.
  assert base.tolist() == 7
  assert type(base.tolist()) is int  # pylint: disable=unidiomatic-typecheck
| |
def test_native_broadcast_to_invalid():
  """Inputs that cannot broadcast to the requested length must raise."""
  for bad_input in ([1, 2], [[1]]):
    with pytest.raises(Exception):
      npy_broadcast_to(bad_input, 3)
| |
def test_native_broadcast_to_masked_array():
  """npy_broadcast_to drops the mask and exposes the raw data."""
  masked = ma([1, 2, 3], [T, F, T])
  assert masked.tolist() == [N, 2, N]
  plain = npy_broadcast_to(masked, 3)
  assert type(plain) is np.ndarray  # pylint: disable=unidiomatic-typecheck
  # Masked-out slots reappear with their underlying values.
  assert plain.tolist() == [1, 2, 3]
  assert plain.base is masked.data.base
| |
| |
class TestReadAllQuery(SimpleQueryNode):
  """Query that drains BASE and emits the total number of elements read.

  Used to exercise the block-read path: any multi-element block read
  from the input must carry a detectable broadcaster.
  """

  # Query whose output we drain.
  base = iattr(QueryNode)
  # Element count requested on each read after the first.
  read_size = iattr(int)
  # Element count for the first read; 0 falls back to read_size.
  priming_read = iattr(int)

  @override
  def _compute_schema(self):
    # Result is a single integer count.
    return QuerySchema(INT64)

  @override
  def _compute_inputs(self):
    return [self.base]

  @override
  async def run_async(self, qe):
    # One input channel for base, one output channel for this node.
    [ic], [oc] = await qe.async_setup([self.base], [self])
    obs_size = 0
    block = await ic.read(self.priming_read or self.read_size)
    while block:
      array = block.as_array()
      broadcaster = npy_get_broadcaster(array)
      if len(array) > 1:
        # Multi-element blocks must be broadcast views whose repeated
        # single element reproduces the whole array.
        assert broadcaster is not None
        assert broadcaster.tolist() * len(block) == array.tolist()
      obs_size += len(array)
      block = await ic.read(self.read_size)
    # NOTE(review): the True argument presumably marks this as the
    # final write on the channel — confirm against the channel API.
    await oc.write([obs_size], True)
| |
@pytest.mark.parametrize("block_size", [1, 3, 4, 5])
@pytest.mark.parametrize("read_size", [2, 3, 4])
@pytest.mark.parametrize("priming_read", [0, 1])
@pytest.mark.parametrize("null", [True, False])
def test_broadcast_read_all(block_size, read_size, priming_read, null):
  """Drain a filled query through TestReadAllQuery; count must match."""
  total_size = 16
  fill_value = None if null else 3
  reader = TestReadAllQuery(QueryNode.filled(fill_value, total_size),
                            read_size,
                            priming_read)
  assert _execute_q(reader, block_size=block_size) == [total_size]