blob: 861958bd189aaacbf46f04c34eb38a617f7530cd [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cases for SQL binary operations"""
# pylint: disable=missing-docstring
import logging
from .query import (
ArithmeticOperator,
ArithmeticOperatorUnitsAdd,
ArithmeticOperatorUnitsMul,
ComparisonOperator,
InvalidQueryException,
_BINOP,
)
from .sql import (
BinaryOperation,
)
from .test_query import (
ALL_SCHEMAS,
TestQueryTable,
_execute_qt,
)
from .test_sql import (
ps2qt,
)
from .test_util import (
parameterize_dwim,
)
log = logging.getLogger(__name__)

# This test lives in its own file so that its combinatorial
# parameterization does not slow down test_sql.py collection.

# Every operator spelling that test_expression_binop places between
# its two operand subqueries.
_TEST_BINOPS = [
    "*", "/", "%", "+", "-",
    "<<", ">>", "&", "|",
    "=", "<=>", "<!=>", ">=", ">", "<=", "<", "<>", "!=", "==",
    "and", "or", "//",
    "is", "is not", "is distinct from", "is not distinct from",
    "||",
]
def get_example_tbl(schema_name, value_name):
  """Get a single-value example table for an operand test.

  The generated table has a single column called "vl" and one row.

  Args:
    schema_name: Key into ALL_SCHEMAS selecting the column schema and
      its exemplar value.  The schema named "NULL" always yields a
      None value, whatever value_name is.
    value_name: Which value to put in the row:
      "E" -- the schema's exemplar value;
      "N" -- None;
      "2" -- None for null schemas, otherwise "foo" for string
             schemas and 2 for everything else.

  Returns:
    A TestQueryTable with column "vl" holding the chosen value.

  Raises:
    ValueError: if value_name is not one of "E", "N", or "2".
  """
  schema, exemplar = ALL_SCHEMAS[schema_name]
  # Validate with a real exception instead of `assert`: asserts are
  # stripped under `python -O`, and the assert also made the old
  # trailing ValueError branch unreachable.
  if value_name not in ("E", "N", "2"):
    raise ValueError("unknown value " + repr(value_name))
  if value_name == "N" or schema_name == "NULL":
    value = None
  elif value_name == "E":
    value = exemplar
  elif schema.is_null:
    # value_name is "2" here; null schemas still get None.
    value = None
  else:
    value = "foo" if schema.is_string else 2
  return TestQueryTable(
      names=["vl"],
      schemas=[schema],
      rows=[
          [value],
      ])
def _expected_failure_reason(operator, operator_name,
                             left_schema, right_schema):
  """Predict whether a binary operation query should be rejected.

  Applies the typing rules below in order and returns the tag of the
  first rule that matches.

  Args:
    operator: The operator object looked up in _BINOP.
    operator_name: The operator spelling used in the SQL text.
    left_schema: Schema of the left operand.
    right_schema: Schema of the right operand.

  Returns:
    A short string naming the violated rule (e.g. "arith", "concat",
    "unit-mismatch"), or False if the query is expected to succeed.
  """
  # pylint: disable=unidiomatic-typecheck
  if isinstance(operator, ArithmeticOperator):
    if left_schema.is_string or right_schema.is_string:
      return "arith"
    if left_schema.dtype.kind not in operator.left_supported_kinds:
      return "kind_left"
    if right_schema.dtype.kind not in operator.right_supported_kinds:
      return "kind_right"
  if (operator_name == "||" and
      ((not left_schema.is_string and not left_schema.is_null) or
       (not right_schema.is_string and not right_schema.is_null))):
    return "concat"
  if (left_schema.domain and right_schema.domain and
      left_schema.domain != right_schema.domain):
    return "domain-mismatch"
  if (not isinstance(operator, ComparisonOperator) and
      (left_schema.domain or right_schema.domain)):
    return "has-domain"
  if isinstance(operator, ComparisonOperator):
    if (left_schema.is_string and
        not right_schema.is_string and
        not right_schema.is_null):
      return "compare1"
    if (right_schema.is_string and
        not left_schema.is_string and
        not left_schema.is_null):
      return "compare2"
  if ((left_schema.unit and right_schema.unit) and
      not isinstance(operator, ArithmeticOperatorUnitsMul) and
      left_schema.unit.dimensionality !=
      right_schema.unit.dimensionality):
    return "unit-mismatch"
  # Exact type checks are deliberate here: only the plain
  # ArithmeticOperator class is matched, not its subclasses.
  if (not left_schema.is_null and
      not right_schema.is_null and
      (left_schema.unit or right_schema.unit) and
      type(operator) is ArithmeticOperator):
    return "has-units"
  if (not left_schema.is_null and
      not right_schema.is_null and
      (bool(left_schema.unit) ^ bool(right_schema.unit)) and
      (isinstance(operator, ComparisonOperator) or
       type(operator) is ArithmeticOperatorUnitsAdd)):
    # We allow multiplying units with dimensionless values, but not
    # adding them. "Two times two feet" makes more sense than "two
    # feet plus three" --- three what?
    return "unit-with-dimensionless"
  return False


@parameterize_dwim(
    left_val=["E", "N", "2"],
    left_schema_name=ALL_SCHEMAS,
    operator_name=_TEST_BINOPS,
    right_val=["E", "N", "2"],
    right_schema_name=ALL_SCHEMAS,
)
def test_expression_binop(left_val, left_schema_name,
                          right_val, right_schema_name,
                          operator_name):
  """Check a binary operator against every pair of example operands.

  The query must raise InvalidQueryException exactly when
  _expected_failure_reason predicts a failure, and succeed otherwise.
  """
  tbl_left = get_example_tbl(left_schema_name, left_val)
  left_schema = tbl_left["vl"].schema
  tbl_right = get_example_tbl(right_schema_name, right_val)
  right_schema = tbl_right["vl"].schema
  # BinaryOperation.EQUIV maps some spellings to the key used in _BINOP;
  # any other spelling is used as the key directly.
  munged_operator_name = BinaryOperation.EQUIV.get(
      operator_name, operator_name)
  operator = _BINOP[munged_operator_name]
  q = ("SELECT (SELECT * FROM tbl_left) "
       "{} "
       "(SELECT * FROM tbl_right) AS x".format(operator_name))
  expected_failure = _expected_failure_reason(
      operator, operator_name, left_schema, right_schema)
  try:
    qt = ps2qt(q, tbl_left=tbl_left, tbl_right=tbl_right)
    _execute_qt(qt)
  except InvalidQueryException:
    if expected_failure:
      return
    raise
  assert not expected_failure, (
      "query {!r} succeeded but was expected to fail ({})".format(
          q, expected_failure))