# blob: c10b36e01c784298b22db84084111fc5b0de34fb [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for group-by operations"""
# pylint: disable=bad-whitespace,missing-docstring
import logging
import pytest
import numpy as np
from .test_query import (
ALL_SCHEMAS,
TestQueryTable,
_execute_qt,
)
from .sql import (
WELL_KNOWN_AGGREGATIONS,
)
from ._native import (
native_aggregation_info,
)
from .test_sql import (
ps2qt,
)
from .query import (
InvalidQueryException,
QuerySchema,
)
from .util import (
BOOL,
INT64,
UINT64,
all_unique,
)
# Module-level logger named after this test module.
log = logging.getLogger(__name__)

# One-letter aliases used inside the table literals below so that row
# columns stay visually aligned (N appears where a cell is null, T where a
# cell is boolean true).
N = None
T = True

# Shared fixture table: five (fn, time, cost) rows, two with fn == 0 and
# three with fn == 1.
FUNCTIMES = TestQueryTable(
    names=("fn", "time", "cost"),
    rows=[
        [0, 5, 1],
        [1, 5, 9],
        [1, 20, 0],
        [1, 1, 3],
        [0, 1, 1],
    ],
)
def test_group_by_basic():
    """SUM(time) grouped by fn yields one total per distinct fn value."""
    sql = ("SELECT fn, SUM(time) AS cumtime "
           "FROM functimes GROUP BY fn")
    actual = _execute_qt(ps2qt(sql, functimes=FUNCTIMES))
    expected = TestQueryTable(
        names=("fn", "cumtime"),
        rows=[
            [0, 6],
            [1, 26],
        ],
    ).as_dict()
    assert actual == expected
@pytest.mark.parametrize("add_null", ["add_null", "no_null"])
@pytest.mark.parametrize("schema_name", ALL_SCHEMAS)
def test_group_by_type(schema_name, add_null):
    """Grouping on a key column of any schema must match INT64 grouping.

    A set of distinct key values is chosen for the schema under test and
    each value is mapped to a small positive integer.  The same data is
    then grouped twice -- once keyed by the typed values and once keyed by
    the integer stand-ins -- and the per-group sums must agree.
    """
    schema, _examplar = ALL_SCHEMAS[schema_name]
    dtype = schema.dtype

    # Pick distinct key values, including the type's extremes where the
    # type has any.
    if schema.is_string:
        values = ["", "foo", "bar"]
    elif schema.is_null:
        values = [None]
    elif dtype.kind == "b":
        values = [True, False]
    elif dtype.kind in "ui":
        limits = np.iinfo(dtype)
        raw_ints = [limits.min, limits.max, 1]
        # For unsigned types min is already zero; adding it again would
        # break uniqueness.
        if limits.min != 0:  # pylint: disable=compare-to-zero
            raw_ints.append(0)
        values = list(map(dtype.type, raw_ints))
    elif dtype.kind == "f":
        limits = np.finfo(dtype)
        raw_floats = (
            limits.min,
            limits.max,
            limits.tiny,  # pylint: disable=no-member
            limits.eps,
        )
        values = list(map(dtype.type, raw_floats))
    else:
        assert False, "unrecognized schema {}".format(schema)

    if add_null == "add_null" and not schema.is_null:
        values.append(None)
    assert all_unique(values)

    # Typed key -> 1-based integer stand-in used by the control query.
    mapping = {value: index for index, value in enumerate(values, start=1)}
    repeat = 3
    test_labels = list(mapping) * repeat
    control_labels = list(mapping.values()) * repeat
    grouped_data = np.arange(len(control_labels))

    typed_table = TestQueryTable(
        columns=[test_labels, grouped_data],
        schemas=[schema, QuerySchema(INT64)],
    )
    control_table = TestQueryTable(
        columns=[control_labels, grouped_data],
        schemas=[QuerySchema(INT64), QuerySchema(INT64)],
    )
    sql = "SELECT col0, SUM(col1) AS s FROM qt GROUP BY col0"

    # Control run: integer keys.
    control_result = _execute_qt(ps2qt(sql, qt=control_table))
    expected = dict(zip(control_result["col0"], control_result["s"]))

    # Run under test: typed keys, translated back to their integer
    # stand-ins so the two results are comparable.
    typed_result = _execute_qt(ps2qt(sql, qt=typed_table))
    found = {
        mapping[key]: total
        for key, total in zip(typed_result["col0"], typed_result["s"])
    }
    assert found == expected
@pytest.mark.parametrize("add_null", ["add_null", "no_null"])
@pytest.mark.parametrize("distinct", ["distinct", "no_distinct"])
@pytest.mark.parametrize("aggfunc", WELL_KNOWN_AGGREGATIONS)
@pytest.mark.parametrize("schema_name", ALL_SCHEMAS)
def test_group_aggregate_type(schema_name, add_null, aggfunc, distinct):
    """Run every well-known aggregation over every schema inside a GROUP BY.

    The query must either execute cleanly or raise InvalidQueryException;
    the latter is accepted only for schemas carrying a domain, or carrying
    a unit with an aggregation outside the allow-list below.
    """
    # pylint: disable=redefined-variable-type
    schema, examplar = ALL_SCHEMAS[schema_name]
    acls, _has_empty_value = native_aggregation_info(aggfunc)
    # String inputs are only exercised for aggregation classes "=" and "i".
    if schema.is_string and acls not in "=i":
        return

    last_value = N if add_null == "add_null" else examplar
    grptbl = TestQueryTable(
        names=["grp", "val"],
        schemas=[INT64, schema],
        rows=[
            [0, examplar],
            [1, examplar],
            [0, last_value],
        ],
    )
    distinct_keyword = "distinct" if distinct == "distinct" else ""
    q = ("SELECT grp, dctv.{aggfunc}({distinct} val) "
         "FROM grptbl GROUP BY grp").format(
             aggfunc=aggfunc,
             distinct=distinct_keyword)

    # Either False or a short string naming why a failure is expected.
    unit_compatible = (
        "biggest", "count", "first", "max", "min", "sum", "unique_mask")
    if schema.domain:
        should_throw = "domain"
    elif schema.unit and aggfunc not in unit_compatible:
        should_throw = "unit"
    else:
        should_throw = False

    try:
        _execute_qt(ps2qt(q, grptbl=grptbl))
    except InvalidQueryException:
        # An unexpected rejection is a genuine failure; propagate it.
        if not should_throw:
            raise
    else:
        assert not should_throw, (
            "should have raised an error because {!r}".format(should_throw))
def test_group_by_literal():
    """Grouping by an alias of a constant collapses the input to one row."""
    sql = "SELECT 1 AS x FROM dctv.filled(5) GROUP BY x"
    actual = _execute_qt(ps2qt(sql))
    assert actual == TestQueryTable(
        names=["x"],
        columns=[[1]],
    )
def test_group_by_uint64_max():
    """Grouping on the largest UINT64 value works with and without FROM.

    Both queries cast -1 to UINT64 with UNSAFE and group by the result; the
    expected single-row output holds the maximum unsigned 64-bit value.
    """
    # Build the expected value from iinfo rather than np.uint64(-1): newer
    # NumPy releases refuse to convert an out-of-range Python integer
    # instead of wrapping it around.
    uint64_max = np.uint64(np.iinfo(np.uint64).max)
    queries = (
        "SELECT CAST(-1 AS UINT64 UNSAFE) AS x FROM dctv.filled(5) GROUP BY x",
        "SELECT CAST(-1 AS UINT64 UNSAFE) AS x GROUP BY x",
    )
    for q in queries:
        assert _execute_qt(ps2qt(q)) == TestQueryTable(
            names=["x"],
            schemas=[UINT64],
            columns=[[uint64_max]])
def test_count_star_grouped():
    """COUNT(*) grouped by fn reports the number of rows per fn value."""
    sql = ("SELECT fn, COUNT(*) AS nr "
           "FROM functimes GROUP BY fn")
    actual = _execute_qt(ps2qt(sql, functimes=FUNCTIMES))
    expected = TestQueryTable(
        names=("fn", "nr"),
        rows=[
            [0, 2],
            [1, 3],
        ],
    ).as_dict()
    assert actual == expected
def test_group_multiple():
    """Grouping on two columns counts rows per distinct (fn, cost) pair."""
    sql = ("SELECT fn, cost, COUNT(*) AS nr "
           "FROM functimes GROUP BY fn, cost")
    actual = _execute_qt(ps2qt(sql, functimes=FUNCTIMES))
    expected = TestQueryTable(
        names=("fn", "cost", "nr"),
        rows=[
            [0, 1, 2],
            [1, 9, 1],
            [1, 0, 1],
            [1, 3, 1],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null():
    """Null keys form their own groups, alone or paired with a non-null key.

    Also selects a constant column (``0 AS sv``) alongside the group keys
    and the row count.
    """
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, -2],
            [1, 1, -3],
            [N, 1, -4],
            [N, 1, -6],
            [1, N, -7],
            [N, N, -8],
        ],
    )
    sql = ("SELECT key1, key2, COUNT(*) AS nr, 0 AS sv "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "nr", "sv"),
        rows=[
            [1, 2, 1, 0],
            [1, 1, 2, 0],
            [N, 1, 2, 0],
            [1, N, 1, 0],
            [N, N, 1, 0],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null_aggregation_count():
    """COUNT(value) per group counts only the rows whose value is non-null."""
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, -2],
            [1, 1, -3],
            [N, 1, -4],
            [N, 1, N],
            [1, N, -7],
            [N, N, -8],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2, COUNT(value) AS nr "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "nr"),
        rows=[
            [1, 2, 1],
            [1, 1, 2],
            [N, 1, 1],
            [1, N, 1],
            [N, N, 1],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null_aggregation_sum():
    """Integer SUM per group skips null inputs; an all-null group sums to null."""
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, -2],
            [1, 1, -3],
            [N, 1, -4],
            [N, 1, -4],
            [N, 1, N],
            [1, N, -7],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2, SUM(value) AS sum "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "sum"),
        rows=[
            [1, 2, -1],
            [1, 1, -5],
            [N, 1, -8],
            [1, N, -7],
            [N, N, N],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null_aggregation_min():
    """Integer MIN per group skips null inputs; an all-null group yields null."""
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, -2],
            [1, 1, -3],
            [N, 1, -4],
            [N, 1, -4],
            [N, 1, N],
            [1, N, -7],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2, MIN(value) AS min "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "min"),
        rows=[
            [1, 2, -1],
            [1, 1, -3],
            [N, 1, -4],
            [1, N, -7],
            [N, N, N],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null_aggregation_sum_float():
    """Float SUM per group skips null inputs; an all-null group sums to null."""
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1.0],
            [1, 1, -2.0],
            [1, 1, -3.0],
            [N, 1, -4.0],
            [N, 1, -4.0],
            [N, 1, N],
            [1, N, -7.0],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2, SUM(value) AS sum "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "sum"),
        rows=[
            [1, 2, -1.0],
            [1, 1, -5.0],
            [N, 1, -8.0],
            [1, N, -7.0],
            [N, N, N],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null_aggregation_min_float():
    """Float MIN per group skips null inputs; an all-null group yields null."""
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1.0],
            [1, 1, -2.0],
            [1, 1, -3.0],
            [N, 1, -4.0],
            [N, 1, -4.0],
            [N, 1, N],
            [1, N, -7.0],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2, MIN(value) AS min "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "min"),
        rows=[
            [1, 2, -1.0],
            [1, 1, -3.0],
            [N, 1, -4.0],
            [1, N, -7.0],
            [N, N, N],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null_aggregation_distinct_sum():
    """SUM(DISTINCT value) adds each distinct non-null value once per group.

    The (null, 1) group holds -4 twice plus a null and must sum to -4.
    """
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, -2],
            [1, 1, -3],
            [N, 1, -4],
            [N, 1, -4],
            [N, 1, N],
            [1, N, -7],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2, SUM(DISTINCT value) AS sum "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "sum"),
        rows=[
            [1, 2, -1],
            [1, 1, -5],
            [N, 1, -4],
            [1, N, -7],
            [N, N, N],
        ],
    ).as_dict()
    assert actual == expected
def test_group_null_aggregation_distinct_count():
    """COUNT and COUNT(DISTINCT) side by side over groups containing nulls.

    Nulls are excluded from both counts, duplicates only from the distinct
    one, and an all-null group reports zero for both.
    """
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, -2],
            [1, 1, -3],
            [N, 1, -4],
            [N, 1, -4],
            [N, 1, N],
            [1, N, -7],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2, "
           "COUNT(value) AS cnt, "
           "COUNT(DISTINCT value) AS cntd "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "cnt", "cntd"),
        rows=[
            [1, 2, 1, 1],
            [1, 1, 2, 2],
            [N, 1, 2, 1],
            [1, N, 1, 1],
            [N, N, 0, 0],
        ],
    ).as_dict()
    assert actual == expected
def test_group_nonnull_aggregation_distinct_count():
    """COUNT versus COUNT(DISTINCT) on null-free data with a duplicate value."""
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, -2],
            [1, 1, -2],
            [1, 1, -3],
        ],
    )
    sql = ("SELECT key1, key2, "
           "COUNT(value) AS cnt, "
           "COUNT(DISTINCT value) AS cntd "
           "FROM tbl GROUP BY key1, key2")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2", "cnt", "cntd"),
        rows=[
            [1, 2, 1, 1],
            [1, 1, 3, 2],
        ],
    ).as_dict()
    assert actual == expected
def test_aggregate_sum_null():
    """SUM without GROUP BY ignores a null input and totals the rest."""
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1],
            [1, 1, N],
            [1, 1, -2],
            [1, 1, -3],
        ],
    )
    sql = "SELECT SUM(value) AS sum FROM tbl"
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("sum",),
        rows=[
            [-6],
        ],
    ).as_dict()
    assert actual == expected
def test_group_by_bool_sum():
    """Aggregating the literal TRUE: SUM comes back INT64, MIN stays BOOL."""
    sql = ("SELECT 1 AS x, SUM(TRUE) AS y, MIN(TRUE) as z "
           "FROM dctv.filled(2) GROUP BY x")
    actual = _execute_qt(ps2qt(sql))
    assert actual == TestQueryTable(
        names=["x", "y", "z"],
        schemas=[INT64, INT64, BOOL],
        rows=[
            [1, 2, T],
        ],
    )
def test_group_having():
    """HAVING on an aggregate keeps only groups whose MIN(value) is <= -3.

    The all-null group, whose minimum is null, is filtered out as well.
    """
    source = TestQueryTable(
        names=("key1", "key2", "value"),
        rows=[
            [1, 2, -1.0],
            [1, 1, -2.0],
            [1, 1, -3.0],
            [N, 1, -4.0],
            [N, 1, -4.0],
            [N, 1, N],
            [1, N, -7.0],
            [N, N, N],
        ],
    )
    sql = ("SELECT key1, key2 "
           "FROM tbl GROUP BY key1, key2 HAVING MIN(value) <= -3")
    actual = _execute_qt(ps2qt(sql, tbl=source))
    expected = TestQueryTable(
        names=("key1", "key2"),
        rows=[
            [1, 1],
            [N, 1],
            [1, N],
        ],
    ).as_dict()
    assert actual == expected