# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recursive CTE implementation

Query nodes and actions for recursive common table expressions (CTEs).
While the action graph is being built, a recursive CTE is represented by
placeholder actions (RecursivePlaceholderAction, RecursiveReferenceAction).
When the graph's cycles are condensed, each such cycle becomes a single
RecursiveLoopAction that re-runs the recursive part until an iteration
delivers no rows.
"""
| |
| # Despite its name, a "recursive" query is basically just a loop --- a |
| # loop subject to all the usual loop optimizations, including loop |
| # fusion and fission [1]. I've given a bit of thought to the subject |
| # of auto-merging CTEs, and while I'm convinced it's possible to do so |
| # when the CTEs have provably-equivalent length classes in their |
| # recursive parts, I don't think we know enough to tell whether it's a |
| # good or a bad idea to perform any specific merge --- so, in the |
| # interest of caution, let's just execute CTEs exactly as they appear |
| # in SQL without trying to do anything smart. |
| # |
| # [1] https://en.wikipedia.org/wiki/Loop_fission_and_fusion |
| |
| import logging |
| from itertools import chain, zip_longest, count as xcount |
| from collections import defaultdict |
| |
| from cytoolz import drop, first, keyfilter, valfilter |
| import numpy as np |
| |
| from modernmp.util import the |
| |
| from .query import ( |
| GenericQueryTable, |
| InvalidQueryException, |
| QueryAction, |
| QueryNode, |
| QuerySchema, |
| TableKind, |
| ) |
| |
| from .sql import ( |
| CteBindingName, |
| CteBindingValue, |
| RegularSelect, |
| ) |
| |
| from ._native import ( |
| Block, |
| OutputChannelSpec, |
| QueryExecution, |
| ) |
| |
| from .queryplan import ( |
| add_action_graph_edge, |
| assert_action_graph_valid, |
| build_local_actions_by_output, |
| condense_cycles_inplace, |
| delete_unreachable, |
| plan_query_minstar, |
| ) |
| |
| from .util import ( |
| EqImmutable, |
| Immutable, |
| all_same, |
| cached_property, |
| final, |
| iattr, |
| override, |
| sattr, |
| tattr, |
| ) |
| |
# Module-level logger named after this module.
# NOTE(review): not referenced anywhere else in this file as shown.
log = logging.getLogger(__name__)
| |
| # We don't use AutoMultiQuery here because this query doesn't fit the |
| # metadata-plus-extras structure that AutoMultiQuery is supposed to |
| # generalize and so AMQ wouldn't buy us anything. The outer-inner |
| # class structure is supposed to look a bit like |
| # AutoMultiQuery though. |
| |
def _die_bad_cte_transformation():
    """Abort because a recursive-CTE placeholder action was executed.

    Placeholder actions are supposed to be replaced during planning, so
    reaching one at run time indicates an internal bug.

    Raises:
      AssertionError: always.
    """
    message = "BUG: missing recursive CTE transformation!"
    raise AssertionError(message)
| |
@final
class RecursivePart(EqImmutable):
    """Describes one column of a recursively-computed table"""
    # Query producing this column's initial (non-recursive) rows.
    base_part = iattr(QueryNode)
    # Query producing this column's next rows from the previous step's rows
    # (it may reference the CTE through RecursiveReferenceQuery).
    recursive_part = iattr(QueryNode)
| |
@final
class Recursive(EqImmutable):
    """Recursive query configuration.

    Holds the (base, recursive) query pair for every column of one
    recursive CTE.  Columns are numbered consecutively, in order, starting
    at ``base_number``.  Those numbers are the ``part_number`` values used
    by the CTE's RecursiveReferenceQuery and RecursiveResultQuery nodes.
    """

    # Part number of the first column; later columns count up from here.
    base_number = iattr(int)
    # One RecursivePart per CTE column, in column order.
    parts = tattr(RecursivePart)

    @override
    def _post_init_assert(self):
        super()._post_init_assert()
        assert self.parts
        assert self.base_number >= 1

    @cached_property
    def base_queries(self):
        """Set of base queries we might need (pruned during opt)"""
        return frozenset(part.base_part for part in self.parts)

    @cached_property
    def recursive_queries(self):
        """Set of recursive queries we might need (pruned during opt)"""
        return frozenset(part.recursive_part for part in self.parts)

    @cached_property
    def result_queries(self):
        """Possible result-column queries"""
        return frozenset(self.make_result_query(part_number)
                         for part_number in self.part_by_number)

    @cached_property
    def part_by_number(self):
        """Mapping from part number to part data

        Keys run from ``base_number`` upward in column order.  Dict
        insertion order preserves that column order, and callers rely on
        it when zipping these keys with column names.
        """
        return dict(enumerate(self.parts, self.base_number))

    def make_result_query(self, part_number):
        """Make a query for a recursive part"""
        return RecursiveResultQuery(self, part_number)
| |
@final
class RecursiveLoopAction(QueryAction):
    """Runtime implementation of recursive CTE computation

    Built by RecursivePlaceholderAction.condense().  At run time it first
    reads every input to completion: the lifted outer queries plus the
    base query of each active part.  Base rows are forwarded to the
    matching result channels as they arrive.

    It then executes ``subquery_plan`` repeatedly.  Each execution sees
    the lifted data plus, under each part's RecursiveReferenceQuery, the
    rows from the previous step (initially the base rows).  Rows the
    subquery delivers are written to the result channels and saved for
    the next step.  The loop ends after an execution delivers no rows.
    """
    # Full description of the recursive CTE, including inactive parts.
    config = iattr(Recursive)
    # Outer queries the subquery needs but does not compute itself.
    lifted = sattr(QueryNode)
    # Part numbers whose delivery actions survived pruning in condense().
    active_part_numbers = sattr(int)
    # Plan for one loop iteration (built by plan_query_minstar).
    subquery_plan = iattr()

    @cached_property
    def part_by_number(self):
        """Parts filtered to ones actually needed"""
        active_part_numbers = self.active_part_numbers
        return keyfilter(
            lambda part_number: part_number in active_part_numbers,
            self.config.part_by_number)

    @override
    def _compute_inputs(self):
        # Lifted outer queries plus the base query of every active part.
        return self.lifted | {
            part.base_part for part in self.part_by_number.values()
        }

    @override
    def _compute_outputs(self):
        # One RecursiveResultQuery per active part.
        # NOTE(review): this is a one-shot map iterator; presumably the
        # QueryAction base class materializes it -- confirm.
        return map(self.config.make_result_query,
                   self.part_by_number)

    async def __slurp_input(self, qe, ics,
                            part_by_number,
                            oc_by_part_number):
        """Read every input channel to completion.

        Blocks from a part's base query are also written, as they arrive,
        to that part's result channel.  Each round issues the next reads
        together with the previous round's pending writes in a single
        ``qe.async_io`` call.

        Returns a dict mapping each input query to its list of non-empty
        blocks.
        """
        assert len(set(ic.query for ic in ics)) == len(ics), \
            "we should have no redundant input channels"
        block_size = qe.block_size
        # Several parts may share a base query, so fan out per query.
        ocs_by_base_query = defaultdict(list)
        for part_number, part in part_by_number.items():
            ocs_by_base_query[part.base_part] \
                .append(oc_by_part_number[part_number])
        # Pairs of (input channel, blocks collected so far).
        all_reads = [(ic, []) for ic in ics]
        reads = all_reads
        write_batch = []
        while reads:
            read_batch = [read[0].read(block_size) for read in reads]
            # Reads come first in the batch, so zip(batch, reads) below
            # pairs each read result with its channel; trailing write
            # results are ignored.
            batch = await qe.async_io(*(read_batch + write_batch))
            write_batch = []
            new_reads = []
            for block, read in zip(batch, reads):
                if block:
                    read[1].append(block)
                    for oc in ocs_by_base_query.get(read[0].query, ()):
                        write_batch.append(oc.write(block))
                # Only a full block means more data may follow; a short
                # block ends this channel.
                if len(block) == block_size:
                    new_reads.append(read)
            reads = new_reads
        # Flush writes produced by the final round of reads.
        if write_batch:
            await qe.async_io(*write_batch)
        return {
            read[0].query: read[1]
            for read in all_reads
        }

    @staticmethod
    def __make_master_recursive_data(slurped_data_by_input_query,
                                     lifted,
                                     part_by_number,
                                     reference_query_by_part_number):
        """Split slurped input into loop-invariant and per-step data.

        Returns ``(master_recursive_data, iteration_data)``:

        - ``master_recursive_data`` maps each lifted query to its blocks;
          these are reused unchanged on every iteration.
        - ``iteration_data`` maps each active part's
          RecursiveReferenceQuery to a copy of that part's base-query
          blocks, which seeds the first iteration.
        """
        master_recursive_data = {
            query: slurped_data_by_input_query[query]
            for query in lifted
        }
        iteration_data = {
            reference_query_by_part_number[part_number]:
            slurped_data_by_input_query[part.base_part].copy()
            for part_number, part in part_by_number.items()
        }
        assert not set(master_recursive_data) & set(iteration_data)
        return master_recursive_data, iteration_data

    @override
    async def run_async(self, qe):
        part_by_number = self.part_by_number
        # Use strict backpressure because each iteration of the
        # "recursive" loop is expensive (query execution setup, Python
        # code) and we want to give the user the option of bailing out
        # early via things like "LIMIT 1". If we didn't use strict
        # backpressure, we'd output a whole block before realizing that we
        # could stop, and emitting a whole block could take a very long
        # time if the block size is large and the output per iteration
        # is small.
        ics, ocs = await qe.async_setup(
            self.inputs,
            [OutputChannelSpec(query=query, strict_backpressure=True)
             for query in self.outputs])
        oc_by_part_number = {oc.query.part_number: oc for oc in ocs}
        assert sorted(oc_by_part_number) == sorted(part_by_number)
        # Keys under which the inner execution finds the previous step's
        # rows for each part (see FetchFromOuterLoopAction).
        reference_query_by_part_number = {
            part_number: RecursiveReferenceQuery(
                part_number,
                part.base_part.schema)
            for part_number, part in part_by_number.items()
        }
        master_recursive_data, iteration_data = \
            self.__make_master_recursive_data(
                await self.__slurp_input(
                    qe, ics, part_by_number,
                    oc_by_part_number),
                self.lifted,
                part_by_number,
                reference_query_by_part_number)
        # Private copy so the per-iteration "recursive_data" entry is not
        # stored into qe.env itself.
        env = qe.env.copy()

        def _make_recursive_data():
            # Snapshot for one iteration: lifted data plus the rows
            # collected for each reference query.  The iteration_data lists
            # are then emptied so they can collect the next step's rows.
            recursive_data = {
                query: arrays.copy()
                for query, arrays in master_recursive_data.items()
            }
            for query, arrays in iteration_data.items():
                assert query not in recursive_data
                recursive_data[query] = arrays.copy()
                arrays.clear()
            return recursive_data

        async def _handle_subquery_output(delivery):
            # Save non-empty deliveries for the next step and emit them as
            # results.  Empty ones (DeliverRecursiveAction's final short
            # read can be empty) are ignored.
            part_number = delivery.part_number
            data = delivery.data
            assert isinstance(data, np.ndarray)
            if len(data):  # pylint: disable=len-as-condition
                query = reference_query_by_part_number[part_number]
                iteration_data[query].append(data)
                await oc_by_part_number[part_number].write(data)

        async def _loop_once():
            # Run the subquery once.  Returns True when this step delivered
            # no rows, meaning the loop is finished.
            env["recursive_data"] = _make_recursive_data()
            # The subquery can ask for a query restart by throwing
            # _RestartQueryException. Let's just let the top-level query
            # engine take care of all restarts.
            for output in QueryExecution.make(
                    plan=self.subquery_plan,
                    qc=qe.qc,
                    env=env,
                    progress_callback=qe.progress_callback):
                if isinstance(output, RecursiveDelivery):
                    await _handle_subquery_output(output)
                else:
                    # Anything else is passed up to the outer query runner.
                    await qe.async_yield_to_query_runner(output)
            # Every part must have collected the same number of arrays, so
            # checking the first part's list suffices.
            # NOTE(review): first() on an empty dict would raise if no part
            # were active; presumably condense() never builds such a loop
            # -- confirm.
            assert all_same(map(len, iteration_data.values()))
            return not first(iteration_data.values())
        while True:
            if await _loop_once():
                break
| |
@final
class FetchFromOuterLoopAction(QueryAction):
    """Inner-loop source that replays data captured by the outer loop.

    Before each iteration, RecursiveLoopAction stores a list of arrays in
    ``qe.env["recursive_data"]`` for every query the subquery needs from
    outside: the lifted queries and the recursive references.  This action
    streams the arrays stored for ``outer_query`` into its only output
    channel and flags the last one as EOF.
    """
    outer_query = iattr(QueryNode)

    @override
    def _compute_inputs(self):
        # No in-subquery producer: the data comes from the environment.
        return tuple()

    @override
    def _compute_outputs(self):
        return tuple([self.outer_query])

    @override
    async def run_async(self, qe):
        [], [oc] = await qe.async_setup((), (self.outer_query,))
        stored = qe.env["recursive_data"][self.outer_query]
        last_index = len(stored) - 1
        for index, chunk in enumerate(stored):
            # A Block must not leak from the query execution that made it
            # into another one, so convert it to a plain array first.
            if isinstance(chunk, Block):
                chunk = chunk.as_array()
            assert isinstance(chunk, np.ndarray)
            await oc.write(chunk, index == last_index)
        # NOTE(review): when ``stored`` is empty nothing is written, so EOF
        # is never flagged explicitly; presumably finishing the action
        # closes the channel -- confirm.
| |
@final
class RecursiveDelivery(Immutable):
    """Bundle of data piped from inner to outer query execution"""
    # Recursive part (CTE column) the data belongs to.
    part_number = iattr(int)
    # One chunk of rows produced by the inner execution for that part;
    # may be empty (RecursiveLoopAction ignores empty deliveries).
    data = iattr(np.ndarray)
| |
@final
class DeliverRecursiveAction(QueryAction):
    """Inner-loop sink that hands one recursive column to the outer loop.

    Reads the recursive-part query in ``qe.block_size`` chunks.  Each chunk
    is yielded to the query runner wrapped in a RecursiveDelivery, which
    RecursiveLoopAction picks out of the inner execution's output.
    """
    part_number = iattr(int)
    recursive_part_query = iattr(QueryNode)
    # NOTE(review): presumably keeps this output-less action from being
    # pruned as dead code -- confirm against QueryAction.
    precious = True

    __inherit__ = {"precious": override}

    @override
    def _compute_inputs(self):
        return (self.recursive_part_query,)

    @override
    def _compute_outputs(self):
        # No query outputs: the data leaves through deliveries instead.
        return tuple()

    @override
    async def run_async(self, qe):
        [ic], [] = await qe.async_setup((self.recursive_part_query,), ())
        while True:
            block = await ic.read(qe.block_size)
            # A short (possibly empty) block is the last one.
            short_read = len(block) < qe.block_size
            await qe.async_yield_to_query_runner(
                RecursiveDelivery(self.part_number, block.as_array()))
            if short_read:
                break
| |
@final
class RecursivePlaceholderAction(QueryAction):
    """Fake recursive action for early-stage query optimization

    Stands in for a whole recursive CTE while the action graph is built.
    It claims every base and recursive query as input and every result
    query as output.  RecursiveReferenceAction nodes add edges back to it,
    which places it in a cycle (an SCC).  condense() rewrites that SCC into
    a single RecursiveLoopAction whose per-iteration subquery plan is built
    from the SCC's actions.
    """
    config = iattr(Recursive)

    @override
    def _compute_inputs(self):
        return (self.config.base_queries |
                self.config.recursive_queries)

    @override
    def _compute_outputs(self):
        # We'll prune this set later when we plan the subquery.
        return self.config.result_queries

    @override
    async def run_async(self, qe):
        # We should be replacing this action with the real recursive loop
        # after DAG construction, meaning we should never get here.
        _die_bad_cte_transformation()

    @staticmethod
    def __install_delivery_actions(subgraph,
                                   actions_by_output,
                                   numbered_parts):
        """Install actions that deliver data to the outer loop

        Adds one DeliverRecursiveAction per (part_number, part) pair in
        ``numbered_parts`` to ``subgraph``.  Returns a dict mapping each
        part number to its delivery action.
        """
        deliver_actions = {}
        for part_number, part in numbered_parts:
            recursive_part = part.recursive_part
            deliver_action = DeliverRecursiveAction(part_number, recursive_part)
            assert deliver_action not in subgraph
            recursive_part_action = actions_by_output[recursive_part]
            if recursive_part_action in subgraph:
                # The recursive part is produced inside the SCC, so wire
                # the delivery action straight to its producer.
                add_action_graph_edge(
                    subgraph,
                    deliver_action,
                    recursive_part_action,
                    recursive_part)
            else:
                # No within-SCC producer: add the delivery action
                # unconnected.  We'll suck in the recursive part via
                # FetchFromOuterLoopAction later
                # (__propagate_inner_dependencies).
                subgraph.add_node(deliver_action)
            assert deliver_action in subgraph
            assert part_number not in deliver_actions
            deliver_actions[part_number] = deliver_action
        return deliver_actions

    def __replace_recursive_references(self, nx, subgraph, deliver_actions):
        """Swap recursive-reference nodes for FetchFromOuterLoopAction.

        Every predecessor of ``self`` in ``subgraph`` is expected to be a
        RecursiveReferenceAction.  Each one is relabelled in place to a
        FetchFromOuterLoopAction for the same reference query.  The new
        node gets a back edge to the delivery action of the same part
        number.
        """
        # Replace only those recursive references to the current loop!
        replacements = {
            the(RecursiveReferenceAction, dependee):
            FetchFromOuterLoopAction(dependee.recursive_reference_query)
            for dependee in subgraph.predecessors(self)
        }
        # Remove the original back references.
        subgraph.remove_edges_from((dependee, self)
                                   for dependee in replacements)
        assert not subgraph.in_degree[self]
        # Do the substitution.
        nx.relabel.relabel_nodes(subgraph, replacements, copy=False)
        # Add new back references to the *deliver* action
        for recursive_reference_action, fetch_action in replacements.items():
            assert recursive_reference_action not in subgraph
            assert fetch_action in subgraph
            deliver_action = \
                deliver_actions[recursive_reference_action.part_number]
            assert deliver_action in subgraph
            subgraph.add_edge(fetch_action, deliver_action)

    def __prune_subgraph(self,
                         nx,
                         subgraph,
                         action_graph,
                         deliver_actions):
        """Drop subgraph nodes not needed for externally consumed results.

        The roots are the delivery actions for result queries that some
        predecessor of ``self`` in the outer ``action_graph`` consumes.
        ``delete_unreachable`` is applied from those roots.
        ``deliver_actions`` is then filtered *in place* to the delivery
        actions still present in ``subgraph``.
        """
        exported_deliver_actions = set()
        for edge in action_graph.pred[self].values():
            for query in edge["queries"]:
                assert isinstance(query, RecursiveResultQuery)
                exported_deliver_actions.add(deliver_actions[query.part_number])
        delete_unreachable(nx, subgraph, exported_deliver_actions)
        new_deliver_actions = valfilter(lambda a: a in subgraph,
                                        deliver_actions)
        deliver_actions.clear()
        deliver_actions.update(new_deliver_actions)

    @staticmethod
    def __drop_back_edges(subgraph, deliver_actions):
        """Remove the fetch -> deliver back edges.

        These edges were added by __replace_recursive_references.  Every
        predecessor of a delivery action must be a FetchFromOuterLoopAction
        joined by an edge with no "queries" attribute, i.e. a pure back
        edge.
        """
        edges_to_drop = []
        for dst in deliver_actions.values():
            for src in subgraph.predecessors(dst):
                assert isinstance(src, FetchFromOuterLoopAction)
                assert "queries" not in subgraph.adj[src][dst]
                edges_to_drop.append((src, dst))
        subgraph.remove_edges_from(edges_to_drop)

    @staticmethod
    def __propagate_inner_dependencies(subgraph, actions_by_output):
        """Feed outer-only inputs into the subgraph via fetch actions.

        Returns the set of queries consumed inside ``subgraph`` but
        produced by no action in it.
        """
        # Existing FetchFromOuterLoopAction actions don't have within-
        # subgraph dependencies on the queries they fetch, meaning that
        # needed_outer_queries doesn't include the recursive reference
        # queries. The needed_outer_queries we build here are queries
        # that the subquery needs but that aren't part of the recursive
        # loop, e.g., the number two.
        internal_produced_queries = \
            set(chain.from_iterable(
                action.outputs for action in subgraph))
        needed_queries = \
            set(chain.from_iterable(action.inputs for action in subgraph))
        needed_outer_queries = needed_queries - internal_produced_queries
        assert not needed_outer_queries - set(actions_by_output)
        # Snapshot the nodes: we add fetch actions while iterating.
        subgraph_nodes = tuple(subgraph)
        for query in needed_outer_queries:
            fetch_action = FetchFromOuterLoopAction(query)
            for action in subgraph_nodes:
                if query in action.inputs:
                    add_action_graph_edge(subgraph,
                                          action,
                                          fetch_action,
                                          query)
        return needed_outer_queries

    def __build_inner_query_action_graph(self, nx, action_graph, component):
        """Main function for building the subquery action graph

        Returns ``(subgraph, actions_by_output, needed_outer_queries,
        active_part_numbers)``.  ``subgraph`` is the per-iteration action
        graph, with ``self`` removed.  ``active_part_numbers`` holds the
        part numbers whose delivery actions survived pruning.
        """
        assert self in component
        assert all(action in action_graph for action in component)
        # We want to be consistent about wiring a particular input to a
        # particular output in case we have two actions that each produce
        # the same input, so build a DB mapping query nodes to the actions
        # that produce them, but only in the context of this component.
        actions_by_output = build_local_actions_by_output(
            nx, action_graph, component)
        part_by_number = self.config.part_by_number
        # The subgraph graph starts out just being the SCC.
        subgraph = action_graph.subgraph(component).copy()
        # Now add the delivery actions. We need to add these actions
        # _after_ adding the "nonrecursive" nodes above because the
        # delivery actions depend on the corresponding recursive parts
        # (since that's the whole point) and the "recursive" part isn't
        # automatically part of the SCC if it doesn't depend on previous
        # input. We need to add *all* the potential delivery actions to
        # the subgraph in case we have a recursive reference to a column
        # that we don't need externally.
        deliver_actions = self.__install_delivery_actions(
            subgraph, actions_by_output, part_by_number.items())
        # Now we can replace the dummy recursive reference nodes with
        # actual fetch-from-outer-loop actions. This action doesn't make
        # the graph acyclic: we leave back edges from the new reference
        # nodes to the corresponding delivery actions, preparing for
        # pruning later.
        self.__replace_recursive_references(nx, subgraph, deliver_actions)
        assert self in subgraph
        # Now we can delete the delivery actions we don't actually need.
        # We know we need a delivery action when someone depends on the
        # corresponding query from the placeholder action.
        self.__prune_subgraph(nx,
                              subgraph,
                              action_graph,
                              deliver_actions)
        assert self not in subgraph
        # Finally, break the back references, making the graph acyclic
        # with respect to *us*, although it may still contain cycles if
        # internal computations request nested recursive CTEs.
        self.__drop_back_edges(subgraph, deliver_actions)
        # Make sure the outer graph contains everything we need for the
        # inner graph.
        needed_outer_queries = self.__propagate_inner_dependencies(
            subgraph, actions_by_output)
        active_part_numbers = tuple(deliver_actions)
        return (subgraph, actions_by_output,
                needed_outer_queries, active_part_numbers)

    @override
    def condense(self,
                 nx,
                 action_graph,
                 component):
        """Remove cycles in this SCC

        Replaces the SCC ``component``, which contains ``self``, in
        ``action_graph`` with one RecursiveLoopAction, in place.  The new
        action's outgoing edges are then trimmed to the queries it still
        consumes.
        """
        assert self in component
        assert all(action in action_graph for action in component)
        (subgraph, actions_by_output,
         needed_outer_queries, active_part_numbers) = \
            self.__build_inner_query_action_graph(nx, action_graph, component)
        # We can nest recursive CTEs, so condense cycles recursively until
        # we end up with a DAG.
        condense_cycles_inplace(nx, subgraph)
        # Now we can encapsulate the whole messy loop graph with one
        # single RecursiveLoopAction, which depends on all queries that we
        # might need to perform the inner query.
        assert assert_action_graph_valid(subgraph)
        condensed_action = RecursiveLoopAction(
            self.config,
            needed_outer_queries,
            active_part_numbers,
            plan_query_minstar(subgraph))
        # Drop the rest of the SCC, then relabel self so its existing
        # edges carry over to the loop action.
        action_graph.remove_nodes_from(component - {self})
        nx.relabel.relabel_nodes(action_graph,
                                 {self: condensed_action},
                                 copy=False)
        assert self not in action_graph
        assert not any(action in action_graph for action in component)
        assert condensed_action in action_graph
        assert isinstance(needed_outer_queries, set)
        for query in needed_outer_queries:
            add_action_graph_edge(action_graph,
                                  condensed_action,
                                  actions_by_output[query],
                                  query)
        # Delete any edges corresponding to inputs that we no longer need
        # after pruning.
        edges_to_remove = []
        for dst, edge in action_graph[condensed_action].items():
            queries = edge["queries"]
            # Mutates the edge's query set in place.
            queries.intersection_update(condensed_action.inputs)
            if not queries:
                edges_to_remove.append((condensed_action, dst))
        action_graph.remove_edges_from(edges_to_remove)
| |
@final
class RecursiveResultQuery(QueryNode):
    """Query for one column of a fully evaluated recursive CTE.

    ``part_number`` selects the column within ``config``.  The column's
    schema is the base case's schema with constraints removed.
    """

    config = iattr(Recursive)
    part_number = iattr(int)

    @override
    def _compute_schema(self):
        # Both cases feed the same output column, so they must agree once
        # constraints are stripped.
        part = self.config.part_by_number[self.part_number]
        base_query, recursive_query = part.base_part, part.recursive_part
        result_schema = base_query.schema.unconstrain()
        assert result_schema == recursive_query.schema.unconstrain()
        return result_schema

    @override
    def make_action(self):
        # Placeholder only; RecursivePlaceholderAction.condense() later
        # swaps it for a RecursiveLoopAction.
        return RecursivePlaceholderAction(self.config)
| |
@final
class RecursiveReferenceQuery(QueryNode):
    """Placeholder query for recursive references in CTE"""
    # The part_number field of a RecursiveReferenceQuery refers to the
    # part number in its enclosing Recursive object. The SQL engine just
    # numbers recursive references sequentially when building Recursive
    # queries. You might ask, "doesn't our use of a plain, non-unique
    # index break QueryNode referential transparency?" The answer is no:
    # RecursiveReferenceQuery has special evaluation semantics.
    # Evaluating any RecursiveReferenceQuery *by itself* just yields an
    # AssertionError, thus trivially preserving referential
    # transparency. *Within* a full recursive QueryNode tree, however,
    # the RecursiveReferenceQuery refers to the appropriate partial
    # result vector during query execution. Since the same
    # RecursiveResultQuery always yields the same value and it's not
    # legal to evaluate RecursiveReferenceQuery outside a
    # RecursiveResultQuery, we preserve referential transparency.
    #
    # "But wait!", you might object: "doesn't a RecursiveReferenceQuery
    # refer to a different result vector each time through the
    # recursive-query evaluation loop?" Well, it does, yes. But each
    # iteration of the recursive-query evaluation loop is a separate
    # query execution, and referential transparency of QueryNodes
    # matters only within a query execution (during which
    # RecursiveReferenceQuery actually does satisfy the constraint) and
    # within a cache domain. We never cache RecursiveReferenceQuery or
    # anything that depends on it, up to a RecursiveResultQuery, so
    # we're good with respect to the second requirement too. If you
    # cheat and users can't tell the difference, does it really matter?
    part_number = iattr(int)
    the_schema = iattr(QuerySchema)

    @override
    def _compute_schema(self):
        return self.the_schema

    @override
    def make_action(self):
        raise AssertionError("not used")

    @override
    def special_add_action(self, action_graph, actions_being_added):
        """Special logic for adding recursive references

        Returns the RecursiveReferenceAction that stands for this query.
        If one is already attached to the enclosing
        RecursivePlaceholderAction for the same part number, it is reused.

        Raises InvalidQueryException when no action in
        ``actions_being_added`` is a placeholder whose config defines this
        part number.
        """
        # We maintain a stack of actions as we add them to the graph, and
        # we search this stack, most-recent to least-recent, to find a
        # referent for this recursive reference. (We can't just encode
        # the reference in the QueryNode structure: QueryNode instances
        # are immutable, and you need mutability to create a cycle.)
        # Having found our referent, we then search the inbound edges to
        # see whether we already have an action corresponding to this
        # recursive reference: if we do, we just use that. Otherwise, we
        # make a new action and stick it into the graph. (Merging nodes
        # when possible is important for query optimization.) Normally,
        # the action graph builder knows that it can reuse an action when
        # there's some existing action that produces a query we want, but
        # because RecursiveReferenceQuery is ambiguous except in the
        # context of a recursive evaluation, we need this search to
        # determine safe reuse.
        part_number = self.part_number
        for candidate_referent in reversed(actions_being_added):
            if (isinstance(candidate_referent, RecursivePlaceholderAction)
                    and part_number in candidate_referent.config.part_by_number):
                referent = candidate_referent
                break
        else:
            # Decide "not found" by the search itself, not by the
            # referent's truth value, which an action class could define.
            raise InvalidQueryException("orphaned recursive reference")
        for referrer in action_graph.predecessors(referent):
            if (isinstance(referrer, RecursiveReferenceAction) and
                    referrer.part_number == part_number):
                return referrer
        action = RecursiveReferenceAction(self)
        # The edge carries no queries.  It only ties this reference back to
        # its loop, closing the cycle that condense() later breaks.
        action_graph.add_edge(action, referent, queries=())
        return action
| |
@final
class RecursiveReferenceAction(QueryAction):
    """Stand-in action for a recursive reference during early planning.

    It exists only until RecursivePlaceholderAction.condense() relabels it
    into a FetchFromOuterLoopAction for the same reference query.
    """

    recursive_reference_query = iattr(RecursiveReferenceQuery)

    @cached_property
    def part_number(self):
        """Part number of the recursive reference this action stands for"""
        reference = self.recursive_reference_query
        return reference.part_number

    @override
    def _compute_inputs(self):
        # Deliberately empty, although the data really comes from the
        # enclosing loop.  The dependency is recorded only as a bare edge
        # to the placeholder (see RecursiveReferenceQuery.special_add_action).
        return tuple()

    @override
    def _compute_outputs(self):
        return tuple([self.recursive_reference_query])

    @override
    async def run_async(self, _qe):
        # Must have been replaced before execution.
        _die_bad_cte_transformation()
| |
@final
class RecursiveTableSubquery(CteBindingValue):
    """Binding value provider for a recursive CTE"""

    # Non-recursive (base-case) SELECT.
    base_te = iattr(RegularSelect)
    # SELECT that may refer to the CTE's own name.
    recursive_te = iattr(RegularSelect)

    @override
    def make_cte_value(self, tctx, cte_name):
        """Build the query table that a recursive CTE name binds to.

        The base SELECT is compiled in ``tctx``.  The recursive SELECT is
        compiled in a child context where ``cte_name`` resolves to a table
        of RecursiveReferenceQuery columns, numbered consecutively after
        ``tctx.max_rce_number``.

        Returns a GenericQueryTable mapping each (possibly renamed) base
        column name to its RecursiveResultQuery.

        Raises InvalidQueryException if any of these hold:
        - the CTE has zero columns;
        - either side is not a regular table;
        - the column counts differ;
        - the column schemas differ once constraints are removed.
        """
        assert isinstance(cte_name, CteBindingName)
        base_qt, _ = self.base_te.make_qt(tctx, ())
        if not base_qt.columns:
            raise InvalidQueryException(
                "a recursive CTE with zero columns makes zero sense")
        if cte_name.do_rename:
            base_qt = GenericQueryTable.rename_columns(base_qt, cte_name.renaming)

        # Give every CTE column a fresh part number.  The same numbers are
        # used by the RecursiveReferenceQuery columns below, the Recursive
        # config's parts, and the resulting RecursiveResultQuery nodes.
        # That shared numbering is how a reference finds its enclosing CTE.
        first_rce_number = tctx.max_rce_number + 1
        recursive_reference_qt = GenericQueryTable([
            (column_name,
             RecursiveReferenceQuery(rce_nr,
                                     base_qt[column_name].schema.unconstrain()))
            for column_name, rce_nr
            in zip(base_qt.columns, xcount(first_rce_number))
        ])
        recursive_tctx = tctx.let({cte_name.name: recursive_reference_qt})
        # Record the last number handed out, so recursive CTEs nested in
        # the recursive SELECT start numbering after ours.
        # NOTE(review): only the child context is advanced here; confirm
        # that sibling recursive CTEs bound in the same scope cannot reuse
        # these part numbers.
        recursive_tctx.max_rce_number = \
            first_rce_number + len(base_qt.columns) - 1
        recursive_qt, _ = self.recursive_te.make_qt(recursive_tctx, ())
        if base_qt.table_schema.kind != TableKind.REGULAR or \
           recursive_qt.table_schema.kind != TableKind.REGULAR:
            raise InvalidQueryException(
                "recursive queries must yield regular tables")
        if len(base_qt.columns) != len(recursive_qt.columns):
            raise InvalidQueryException(
                "mismatched column count in recursive CTE")
        # A recursive-reference column takes its values from the base case
        # or from repeated evaluations of the recursive case, depending on
        # whether the recursive case has run yet.  Both feed one result
        # column, and the schema controls the dtype, so the schemas must
        # match exactly.  Constraints are not part of that contract: the
        # reference columns and RecursiveResultQuery both use
        # unconstrained schemas.  We therefore compare the unconstrained
        # schemas on *both* sides, which is exactly what
        # RecursiveResultQuery asserts.
        for base_column_name, recursive_column_name \
                in zip(base_qt.columns, recursive_qt.columns):
            base_column = base_qt[base_column_name]
            recursive_column = recursive_qt[recursive_column_name]
            if (base_column.schema.unconstrain() !=
                    recursive_column.schema.unconstrain()):
                raise InvalidQueryException(
                    "schema mismatch in recursive CTE: "
                    "{!r}/{} in base case vs {!r}/{} in recursive case".format(
                        base_column_name,
                        base_column.schema,
                        recursive_column_name,
                        recursive_column.schema))
        recursive = Recursive(first_rce_number, [
            RecursivePart(base_qt[base_column],
                          recursive_qt[recursive_column])
            for base_column, recursive_column
            in zip(base_qt.columns, recursive_qt.columns)
        ])
        assert len(recursive.parts) == len(recursive_reference_qt.columns)
        # part_by_number iterates in column order, so zipping its keys with
        # the column names pairs each name with its own part.
        return GenericQueryTable([
            (column_name, recursive.make_result_query(part_number))
            for column_name, part_number in zip(base_qt.columns,
                                                recursive.part_by_number)
        ])