blob: 0945af0552696be8a9394a9a641b6e5f42a5de75 [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for recursive table queries"""
# pylint: disable=missing-docstring
import logging
import pytest
from .query import InvalidQueryException
from .test_query import _execute_qt
from .test_sql import ps, ps2qt
# Logger named after this module.
log = logging.getLogger(name=__name__)
def test_recursive_cte_parse_union_all():
    """A RECURSIVE CTE whose body is a UNION ALL goes through ``ps`` without raising."""
    sql = ("WITH RECURSIVE foo AS (SELECT 1 UNION ALL SELECT 2) "
           "SELECT * FROM foo")
    ps(sql)
def test_basic():
    """Count upward with a single-column recursive CTE.

    The anchor row is 1 and each recursive step adds 1 while n < 8+2, so
    the CTE yields every integer from 1 through 10.
    """
    query = """
WITH RECURSIVE t(n) AS (
SELECT * FROM (VALUES (1))
UNION ALL
SELECT n+1 FROM t WHERE n < 8+2
)
SELECT n FROM t
"""
    result = _execute_qt(ps2qt(query))
    assert result == {"n": list(range(1, 11))}
def test_schema_mismatch():
    """Reject a recursive CTE whose recursive term changes the column type.

    The anchor row is the float literal 1.0 while the recursive term
    selects the integer literal 1. Building and executing the query must
    raise InvalidQueryException whose message mentions the schema
    mismatch.
    """
    q = """
WITH RECURSIVE t(n) AS (
SELECT * FROM (VALUES (1.0))
UNION ALL
SELECT 1 FROM t WHERE n < 8+2
)
SELECT n FROM t
"""
    # ``match`` is re.search()ed against str(excinfo.value), i.e. the
    # exception's own message. str() of the ExceptionInfo wrapper itself
    # is only its repr (pytest >= 5), which is not a reliable place to
    # look for the message text.
    with pytest.raises(InvalidQueryException,
                       match="schema mismatch in recursive CTE"):
        _execute_qt(ps2qt(q))
def test_multiple():
    """Advance two columns together in one recursive CTE.

    n starts at 1 and grows by 1, while m starts at 2 and grows by 3.
    Only n drives termination (n < 8+2), so both columns get ten rows,
    with m == 3 * n - 1 on every row.
    """
    query = """
WITH RECURSIVE t(n, m) AS (
SELECT * FROM (VALUES (1, 2))
UNION ALL
SELECT n+1, m+3 FROM t WHERE n < 8+2
)
SELECT n, m FROM t
"""
    outcome = _execute_qt(ps2qt(query))
    n_values = list(range(1, 11))
    m_values = [3 * k - 1 for k in n_values]
    assert outcome == {"n": n_values, "m": m_values}
def test_prune():
    """Select only n from a three-column recursive CTE to exercise pruning.

    Three columns are used on purpose. The SQL table generation currently
    borrows an otherwise-unused column to produce its broadcast count, so
    with only two columns both could end up "used" internally. A third
    column guarantees at least one really is unused, which forces the
    query-pruning path to run.
    """
    # TODO(dancol): go back to two columns when we do AND-OR tree stuff.
    query = """
WITH RECURSIVE t(n, m, o) AS (
SELECT * FROM (VALUES (1, 2, 3))
UNION ALL
SELECT n+1, m+3, o+4 FROM t WHERE n < 8+2
)
SELECT n FROM t
"""
    result = _execute_qt(ps2qt(query))
    assert result == {"n": list(range(1, 11))}
def test_nested_cte_uncoordinated():
    """Nest a recursive CTE that does not reference the outer one.

    The inner CTE s has anchor 2, and its recursive term needs m < 2, so
    it never produces more rows: SUM(m) is always 2. Each outer step
    therefore adds 2 to n, starting from 1 and stopping once n reaches
    at least 8+2.
    """
    query = """
WITH RECURSIVE t(n) AS (
SELECT * FROM (VALUES (1))
UNION ALL
SELECT n+(WITH RECURSIVE s(m) AS (
SELECT * FROM (VALUES (2))
UNION ALL
SELECT m+1 FROM s WHERE m < 2
)
SELECT SUM(m) FROM s)
FROM t
WHERE n < 8+2
)
SELECT n FROM t
"""
    result = _execute_qt(ps2qt(query))
    assert result == {"n": list(range(1, 12, 2))}
def test_nested_cte_coordinated():
    """Nest a recursive CTE whose recursive term reads the outer CTE t.

    The expected values assume that SUM(n) FROM t inside the inner CTE
    sees the outer row currently being extended:
      * n=1: s = {2, 2+1=3}, SUM = 5, next n = 6
      * n=6: s = {2, 2+6=8}, SUM = 10, next n = 16
      * n=16 fails n < 8+2, so recursion stops.
    """
    query = """
WITH RECURSIVE t(n) AS (
SELECT * FROM (VALUES (1))
UNION ALL
SELECT n+(WITH RECURSIVE s(m) AS (
SELECT * FROM (VALUES (2))
UNION ALL
SELECT m+(SELECT SUM(n) FROM t) FROM s WHERE m < 3
)
SELECT SUM(m) FROM s)
FROM t
WHERE n < 8+2
)
SELECT n FROM t
"""
    expected = {"n": [1, 6, 16]}
    assert _execute_qt(ps2qt(query)) == expected
def test_no_recursive_references():
    """Allow a recursive term that never references the CTE itself.

    The recursive term is a bare SELECT 1 with no WHERE clause, so nothing
    in the CTE stops it. The test relies on the outer LIMIT 5 to cut the
    result down to five rows of 1.
    """
    query = """
WITH RECURSIVE t(n) AS (
SELECT * FROM (VALUES (1))
UNION ALL
SELECT 1
)
SELECT n FROM t LIMIT 5
"""
    rows = _execute_qt(ps2qt(query))
    assert rows == {"n": [1] * 5}