| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Tests for recursive table queries""" |
| |
| # pylint: disable=missing-docstring |
| |
| import logging |
| import pytest |
| from .query import InvalidQueryException |
| from .test_query import _execute_qt |
| from .test_sql import ps, ps2qt |
| |
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
| |
def test_recursive_cte_parse_union_all():
    """Feed a UNION ALL recursive CTE to ps(); pass if nothing is raised."""
    sql = "WITH RECURSIVE foo AS (SELECT 1 UNION ALL SELECT 2) SELECT * FROM foo"
    ps(sql)
| |
def test_basic():
    """Single-column recursive CTE: start at 1, add 1 while n < 8+2."""
    sql = """
    WITH RECURSIVE t(n) AS (
      SELECT * FROM (VALUES (1))
      UNION ALL
      SELECT n+1 FROM t WHERE n < 8+2
    )
    SELECT n FROM t
    """
    expected = {"n": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
    assert _execute_qt(ps2qt(sql)) == expected
| |
def test_schema_mismatch():
    """Check that mismatched arm schemas in a recursive CTE are rejected.

    The base arm selects the literal 1.0 while the recursive arm selects
    the literal 1. Executing the query must raise InvalidQueryException
    with a message containing "schema mismatch in recursive CTE".
    """
    q = """
    WITH RECURSIVE t(n) AS (
      SELECT * FROM (VALUES (1.0))
      UNION ALL
      SELECT 1 FROM t WHERE n < 8+2
    )
    SELECT n FROM t
    """
    with pytest.raises(InvalidQueryException) as ex:
        _execute_qt(ps2qt(q))
    # `ex` is pytest's ExceptionInfo wrapper, not the exception: its str()
    # is not guaranteed to be the exception message. Inspect the raised
    # exception itself via `ex.value`. The assertion sits after the `with`
    # block so that it actually runs once the exception has been caught.
    assert "schema mismatch in recursive CTE" in str(ex.value)
| |
def test_multiple():
    """Two-column recursive CTE: n steps by 1 and m steps by 3 per row."""
    sql = """
    WITH RECURSIVE t(n, m) AS (
      SELECT * FROM (VALUES (1, 2))
      UNION ALL
      SELECT n+1, m+3 FROM t WHERE n < 8+2
    )
    SELECT n, m FROM t
    """
    expected = {
        "n": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "m": [2, 5, 8, 11, 14, 17, 20, 23, 26, 29],
    }
    result = _execute_qt(ps2qt(sql))
    assert result == expected
| |
def test_prune():
    """Select one column out of a three-column recursive CTE."""
    # N.B. three columns are used here on purpose so that at least one
    # of them goes unused.  The SQL table generation code is currently
    # naive enough to grab an otherwise-unused column to produce its
    # broadcast count, so with only two columns both could end up
    # "used" internally.  With three columns, one is guaranteed to be
    # unused, which exercises the query-pruning path.
    #
    # TODO(dancol): go back to two columns when we do AND-OR tree stuff.
    sql = """
    WITH RECURSIVE t(n, m, o) AS (
      SELECT * FROM (VALUES (1, 2, 3))
      UNION ALL
      SELECT n+1, m+3, o+4 FROM t WHERE n < 8+2
    )
    SELECT n FROM t
    """
    expected = {"n": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
    assert _execute_qt(ps2qt(sql)) == expected
| |
def test_nested_cte_uncoordinated():
    """Nested recursive CTE whose inner query never references the outer t."""
    sql = """
    WITH RECURSIVE t(n) AS (
      SELECT * FROM (VALUES (1))
      UNION ALL
      SELECT n+(WITH RECURSIVE s(m) AS (
        SELECT * FROM (VALUES (2))
        UNION ALL
        SELECT m+1 FROM s WHERE m < 2
      )
      SELECT SUM(m) FROM s)
      FROM t
      WHERE n < 8+2
    )
    SELECT n FROM t
    """
    expected = {"n": [1, 3, 5, 7, 9, 11]}
    result = _execute_qt(ps2qt(sql))
    assert result == expected
| |
def test_nested_cte_coordinated():
    """Nested recursive CTE whose inner query reads SUM(n) from the outer t."""
    sql = """
    WITH RECURSIVE t(n) AS (
      SELECT * FROM (VALUES (1))
      UNION ALL
      SELECT n+(WITH RECURSIVE s(m) AS (
        SELECT * FROM (VALUES (2))
        UNION ALL
        SELECT m+(SELECT SUM(n) FROM t) FROM s WHERE m < 3
      )
      SELECT SUM(m) FROM s)
      FROM t
      WHERE n < 8+2
    )
    SELECT n FROM t
    """
    expected = {"n": [1, 6, 16]}
    result = _execute_qt(ps2qt(sql))
    assert result == expected
| |
def test_no_recursive_references():
    """Recursive arm is a bare SELECT 1 with no reference to t; LIMIT 5."""
    sql = """
    WITH RECURSIVE t(n) AS (
      SELECT * FROM (VALUES (1))
      UNION ALL
      SELECT 1
    )
    SELECT n FROM t LIMIT 5
    """
    expected = {"n": [1, 1, 1, 1, 1]}
    assert _execute_qt(ps2qt(sql)) == expected