blob: 47651758d76eb999d199d05d0e215a5710175c23 [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recursive CTE implementation"""
# Despite its name, a "recursive" query is basically just a loop --- a
# loop subject to all the usual loop optimizations, including loop
# fusion and fission [1]. I've given a bit of thought to the subject
# of auto-merging CTEs, and while I'm convinced it's possible to do so
# when the CTEs have provably-equivalent length classes in their
# recursive parts, I don't think we know enough to tell whether it's a
# good or a bad idea to perform any specific merge --- so, in the
# interest of caution, let's just execute CTEs exactly as they appear
# in SQL without trying to do anything smart.
#
# [1] https://en.wikipedia.org/wiki/Loop_fission_and_fusion
import logging
from itertools import chain, zip_longest, count as xcount
from collections import defaultdict
from cytoolz import drop, first, keyfilter, valfilter
import numpy as np
from modernmp.util import the
from .query import (
GenericQueryTable,
InvalidQueryException,
QueryAction,
QueryNode,
QuerySchema,
TableKind,
)
from .sql import (
CteBindingName,
CteBindingValue,
RegularSelect,
)
from ._native import (
Block,
OutputChannelSpec,
QueryExecution,
)
from .queryplan import (
add_action_graph_edge,
assert_action_graph_valid,
build_local_actions_by_output,
condense_cycles_inplace,
delete_unreachable,
plan_query_minstar,
)
from .util import (
EqImmutable,
Immutable,
all_same,
cached_property,
final,
iattr,
override,
sattr,
tattr,
)
log = logging.getLogger(__name__)
# We don't use AutoMultiQuery here because this query doesn't fit the
# metadata-plus-extras structure that AutoMultiQuery is supposed to
# generalize and so AMQ wouldn't buy us anything. The outer-inner
# class structure is supposed to look a bit like
# AutoMultiQuery though.
def _die_bad_cte_transformation():
raise AssertionError(
"BUG: missing recursive CTE transformation!")
@final
class RecursivePart(EqImmutable):
    """One column of a recursively-computed table.

    Pairs the query that computes the column in the base case with the
    query that computes it in the recursive case.
    """

    # Query for this column in the base (non-recursive) case.
    base_part = iattr(QueryNode)
    # Query for this column in the recursive case.
    recursive_part = iattr(QueryNode)
@final
class Recursive(EqImmutable):
    """Configuration of one recursive query.

    Holds the per-column parts; parts are numbered consecutively
    starting at `base_number`.
    """

    # Number assigned to the first entry of `parts`.
    base_number = iattr(int)
    # The columns, in order.
    parts = tattr(RecursivePart)

    @override
    def _post_init_assert(self):
        super()._post_init_assert()
        # There must be at least one part, and numbering starts at
        # one or above.
        assert self.parts
        assert self.base_number >= 1

    @cached_property
    def base_queries(self):
        """Set of base queries we might need (pruned during opt)"""
        return frozenset([part.base_part for part in self.parts])

    @cached_property
    def recursive_queries(self):
        """Set of recursive queries we might need (pruned during opt)"""
        return frozenset([part.recursive_part for part in self.parts])

    @cached_property
    def result_queries(self):
        """Set of all possible result-column queries, one per part"""
        return frozenset(map(self.make_result_query, self.part_by_number))

    @cached_property
    def part_by_number(self):
        """Dict mapping each part number to its part, in part order"""
        return dict(enumerate(self.parts, self.base_number))

    def make_result_query(self, part_number):
        """Make the result query for the part numbered `part_number`"""
        return RecursiveResultQuery(self, part_number)
@final
class RecursiveLoopAction(QueryAction):
    """Runtime implementation of recursive CTE computation.

    Slurps its inputs (forwarding base-case blocks to the outputs),
    then repeatedly runs the subquery plan, feeding each iteration the
    data delivered by the previous one, until an iteration delivers no
    non-empty data.
    """

    config = iattr(Recursive)
    # Outer queries whose data is read once up front and handed,
    # unchanged, to every iteration of the subquery.
    lifted = sattr(QueryNode)
    # Numbers of the parts this action actually computes and outputs.
    active_part_numbers = sattr(int)
    # Plan given to QueryExecution.make on every loop iteration.
    subquery_plan = iattr()

    @cached_property
    def part_by_number(self):
        """Parts filtered to ones actually needed"""
        active_part_numbers = self.active_part_numbers
        return keyfilter(
            lambda part_number: part_number in active_part_numbers,
            self.config.part_by_number)

    @override
    def _compute_inputs(self):
        # Inputs are the lifted outer queries plus the base query of
        # every active part.
        return self.lifted | {
            part.base_part for part in self.part_by_number.values()
        }

    @override
    def _compute_outputs(self):
        # One result query per active part.
        return map(self.config.make_result_query,
                   self.part_by_number)

    async def __slurp_input(self, qe, ics,
                            part_by_number,
                            oc_by_part_number):
        """Read all input channels, forwarding base-case blocks.

        Every channel in `ics` is read in `qe.block_size` chunks until
        it yields a block shorter than the block size.  Each non-empty
        block read for a query that is some part's base query is also
        written to the output channel of each such part.

        Returns a dict mapping each input channel's query to the list
        of non-empty blocks read from it.
        """
        assert len(set(ic.query for ic in ics)) == len(ics), \
            "we should have no redundant input channels"
        block_size = qe.block_size
        # Several parts may share one base query, hence a list of
        # output channels per base query.
        ocs_by_base_query = defaultdict(list)
        for part_number, part in part_by_number.items():
            ocs_by_base_query[part.base_part] \
                .append(oc_by_part_number[part_number])
        # Each entry pairs an input channel with the blocks read so far.
        all_reads = [(ic, []) for ic in ics]
        reads = all_reads
        write_batch = []
        while reads:
            read_batch = [read[0].read(block_size) for read in reads]
            # Issue this round's reads together with the writes queued
            # by the previous round.  Reads come first, so zipping the
            # results with `reads` below pairs each read result with
            # its channel and ignores the write results.
            batch = await qe.async_io(*(read_batch + write_batch))
            write_batch = []
            new_reads = []
            for block, read in zip(batch, reads):
                if block:
                    read[1].append(block)
                    for oc in ocs_by_base_query.get(read[0].query, ()):
                        write_batch.append(oc.write(block))
                # A full-size block means the channel may have more
                # data; anything shorter ends reading from it.
                if len(block) == block_size:
                    new_reads.append(read)
            reads = new_reads
        # Flush writes queued by the final round of reads.
        if write_batch:
            await qe.async_io(*write_batch)
        return {
            read[0].query: read[1]
            for read in all_reads
        }

    @staticmethod
    def __make_master_recursive_data(slurped_data_by_input_query,
                                     lifted,
                                     part_by_number,
                                     reference_query_by_part_number):
        """Split slurped input into loop constants and seed data.

        Returns `(master_recursive_data, iteration_data)`.  The first
        maps each lifted query to its slurped blocks.  The second maps
        each part's recursive reference query to a copy of the block
        list slurped for that part's base query, i.e. the data the
        first loop iteration sees as "previous" results.
        """
        master_recursive_data = {
            query: slurped_data_by_input_query[query]
            for query in lifted
        }
        # Copy the lists: iteration_data lists are cleared and appended
        # to during the loop, and two parts may share a base query.
        iteration_data = {
            reference_query_by_part_number[part_number]:
            slurped_data_by_input_query[part.base_part].copy()
            for part_number, part in part_by_number.items()
        }
        assert not set(master_recursive_data) & set(iteration_data)
        return master_recursive_data, iteration_data

    @override
    async def run_async(self, qe):
        """Run the recursive loop until an iteration yields no data."""
        part_by_number = self.part_by_number
        # Use strict backpressure because each iteration of the
        # "recursive" loop is expensive (query execution setup, Python
        # code) and we want to give the user the option of bailing out
        # early via things like "LIMIT 1".  If we didn't use strict
        # backpressure, we'd output a whole block before realizing that we
        # could stop, and emitting a whole block could take a very long
        # time if the block size is large and the output per iteration
        # is small.
        ics, ocs = await qe.async_setup(
            self.inputs,
            [OutputChannelSpec(query=query, strict_backpressure=True)
             for query in self.outputs])
        oc_by_part_number = {oc.query.part_number: oc for oc in ocs}
        assert sorted(oc_by_part_number) == sorted(part_by_number)
        # The reference query for a part is the key under which the
        # previous iteration's data for that part is published.
        reference_query_by_part_number = {
            part_number: RecursiveReferenceQuery(
                part_number,
                part.base_part.schema)
            for part_number, part in part_by_number.items()
        }
        master_recursive_data, iteration_data = \
            self.__make_master_recursive_data(
                await self.__slurp_input(
                    qe, ics, part_by_number,
                    oc_by_part_number),
                self.lifted,
                part_by_number,
                reference_query_by_part_number)
        # Work on a copy so setting "recursive_data" below does not
        # touch the environment object held by `qe`.
        env = qe.env.copy()

        def _make_recursive_data():
            # Build the data visible to the next iteration: copies of
            # the lifted data plus whatever has accumulated in
            # iteration_data, which is then emptied so that this
            # iteration's deliveries start from a clean slate.
            recursive_data = {
                query: arrays.copy()
                for query, arrays in master_recursive_data.items()
            }
            for query, arrays in iteration_data.items():
                assert query not in recursive_data
                recursive_data[query] = arrays.copy()
                arrays.clear()
            return recursive_data

        async def _handle_subquery_output(delivery):
            # Record non-empty delivered data as input for the next
            # iteration and emit it on the part's output channel.
            part_number = delivery.part_number
            data = delivery.data
            assert isinstance(data, np.ndarray)
            if len(data):  # pylint: disable=len-as-condition
                query = reference_query_by_part_number[part_number]
                iteration_data[query].append(data)
                await oc_by_part_number[part_number].write(data)

        async def _loop_once():
            # Run the subquery once; return True when it delivered
            # nothing, meaning the loop is finished.
            env["recursive_data"] = _make_recursive_data()
            # The subquery can ask for a query restart by throwing
            # _RestartQueryException.  Let's just let the top-level query
            # engine take care of all restarts.
            for output in QueryExecution.make(
                    plan=self.subquery_plan,
                    qc=qe.qc,
                    env=env,
                    progress_callback=qe.progress_callback):
                if isinstance(output, RecursiveDelivery):
                    await _handle_subquery_output(output)
                else:
                    # Anything else is passed up to the query runner.
                    await qe.async_yield_to_query_runner(output)
            # Every part is expected to have received the same number
            # of arrays, so checking the first one is enough.
            assert all_same(map(len, iteration_data.values()))
            return not first(iteration_data.values())

        while True:
            if await _loop_once():
                break
@final
class FetchFromOuterLoopAction(QueryAction):
    """Inner-query action that emits data supplied by the outer loop.

    The data for `outer_query` is looked up in the "recursive_data"
    entry of the query execution environment.
    """

    outer_query = iattr(QueryNode)

    @override
    def _compute_inputs(self):
        # The data comes from the environment, not from a channel.
        return ()

    @override
    def _compute_outputs(self):
        return (self.outer_query,)

    @override
    async def run_async(self, qe):
        [], [oc] = await qe.async_setup((), (self.outer_query,))
        arrays = qe.env["recursive_data"][self.outer_query]
        # Pair each chunk with its successor so that the final chunk
        # (whose successor is the None fill value) is flagged as EOF.
        for chunk, successor in zip_longest(arrays, drop(1, arrays)):
            # Be careful not to introduce blocks from one query
            # execution into a different query execution.
            payload = chunk.as_array() if isinstance(chunk, Block) else chunk
            assert isinstance(payload, np.ndarray)
            await oc.write(payload, successor is None)
@final
class RecursiveDelivery(Immutable):
    """Data for one part, passed from the inner query to the outer loop"""

    # Number of the part the data belongs to.
    part_number = iattr(int)
    # The delivered rows.
    data = iattr(np.ndarray)
@final
class DeliverRecursiveAction(QueryAction):
    """Inner-query action that hands a part's rows to the outer loop.

    Reads `recursive_part_query` block by block and yields each block
    to the query runner wrapped in a RecursiveDelivery.
    """

    part_number = iattr(int)
    recursive_part_query = iattr(QueryNode)
    precious = True
    __inherit__ = {"precious": override}

    @override
    def _compute_inputs(self):
        return (self.recursive_part_query,)

    @override
    def _compute_outputs(self):
        # Results leave via the query runner, not an output channel.
        return ()

    @override
    async def run_async(self, qe):
        [ic], [] = await qe.async_setup((self.recursive_part_query,), ())
        while True:
            block = await ic.read(qe.block_size)
            # A short block is the last one; it is still delivered.
            reached_eof = len(block) < qe.block_size
            await qe.async_yield_to_query_runner(
                RecursiveDelivery(self.part_number, block.as_array()))
            if reached_eof:
                break
@final
class RecursivePlaceholderAction(QueryAction):
    """Fake recursive action for early-stage query optimization.

    The placeholder claims every base and recursive query as input and
    every result query as output.  `condense` later replaces it, along
    with the rest of its component, by a single RecursiveLoopAction.
    """

    config = iattr(Recursive)

    @override
    def _compute_inputs(self):
        return (self.config.base_queries |
                self.config.recursive_queries)

    @override
    def _compute_outputs(self):
        # We'll prune this set later when we plan the subquery.
        return self.config.result_queries

    @override
    async def run_async(self, qe):
        # We should be replacing this action with the real recursive loop
        # after DAG construction, meaning we should never get here.
        _die_bad_cte_transformation()

    @staticmethod
    def __install_delivery_actions(subgraph,
                                   actions_by_output,
                                   numbered_parts):
        """Install actions that deliver data to the outer loop.

        Adds one DeliverRecursiveAction per `(part_number, part)` pair
        in `numbered_parts` to `subgraph` and returns a dict mapping
        part number to the new action.
        """
        deliver_actions = {}
        for part_number, part in numbered_parts:
            recursive_part = part.recursive_part
            deliver_action = DeliverRecursiveAction(part_number, recursive_part)
            assert deliver_action not in subgraph
            recursive_part_action = actions_by_output[recursive_part]
            if recursive_part_action in subgraph:
                # We'll suck in the recursive part via
                # FetchFromOuterLoopAction later if we don't have a within-SCC
                # way to produce it.
                add_action_graph_edge(
                    subgraph,
                    deliver_action,
                    recursive_part_action,
                    recursive_part)
            else:
                # No producer inside the subgraph: add the delivery
                # action as a bare node for now.
                subgraph.add_node(deliver_action)
            assert deliver_action in subgraph
            assert part_number not in deliver_actions
            deliver_actions[part_number] = deliver_action
        return deliver_actions

    def __replace_recursive_references(self, nx, subgraph, deliver_actions):
        """Swap recursive-reference placeholders for fetch actions.

        Every predecessor of this action in `subgraph` is replaced by a
        FetchFromOuterLoopAction for the same reference query, and an
        edge is added from the fetch action to the delivery action for
        the same part number.
        """
        # Replace only those recursive references to the current loop!
        # NOTE(review): `the` presumably checks that `dependee` is a
        # RecursiveReferenceAction and returns it -- confirm in
        # modernmp.util.
        replacements = {
            the(RecursiveReferenceAction, dependee):
            FetchFromOuterLoopAction(dependee.recursive_reference_query)
            for dependee in subgraph.predecessors(self)
        }
        # Remove the original back references.
        subgraph.remove_edges_from((dependee, self)
                                   for dependee in replacements)
        assert not subgraph.in_degree[self]
        # Do the substitution.
        nx.relabel.relabel_nodes(subgraph, replacements, copy=False)
        # Add new back references to the *deliver* action
        for recursive_reference_action, fetch_action in replacements.items():
            assert recursive_reference_action not in subgraph
            assert fetch_action in subgraph
            deliver_action = \
                deliver_actions[recursive_reference_action.part_number]
            assert deliver_action in subgraph
            subgraph.add_edge(fetch_action, deliver_action)

    def __prune_subgraph(self,
                         nx,
                         subgraph,
                         action_graph,
                         deliver_actions):
        """Drop from `subgraph` whatever the exported parts don't need.

        The exported delivery actions are those whose result query
        appears on an edge into this action in `action_graph`.
        `deliver_actions` is updated in place to hold only the delivery
        actions that remain in `subgraph`.
        """
        exported_deliver_actions = set()
        for edge in action_graph.pred[self].values():
            for query in edge["queries"]:
                assert isinstance(query, RecursiveResultQuery)
                exported_deliver_actions.add(deliver_actions[query.part_number])
        delete_unreachable(nx, subgraph, exported_deliver_actions)
        # Filter into a new mapping first, then mutate the caller's
        # dict so that the caller sees the pruned set.
        new_deliver_actions = valfilter(lambda a: a in subgraph,
                                        deliver_actions)
        deliver_actions.clear()
        deliver_actions.update(new_deliver_actions)

    @staticmethod
    def __drop_back_edges(subgraph, deliver_actions):
        """Remove the edges that lead from fetch actions to delivery actions."""
        # Collect first and remove afterwards so the graph is not
        # modified while its predecessors are being iterated.
        edges_to_drop = []
        for dst in deliver_actions.values():
            for src in subgraph.predecessors(dst):
                assert isinstance(src, FetchFromOuterLoopAction)
                assert "queries" not in subgraph.adj[src][dst]
                edges_to_drop.append((src, dst))
        subgraph.remove_edges_from(edges_to_drop)

    @staticmethod
    def __propagate_inner_dependencies(subgraph, actions_by_output):
        """Wire in fetch actions for queries produced outside `subgraph`.

        Returns the set of queries that some action in `subgraph`
        takes as input but that no action in `subgraph` produces.
        """
        # Existing FetchFromOuterLoopAction actions don't have within-
        # subgraph dependencies on the queries they fetch, meaning that
        # needed_outer_queries doesn't include the recursive reference
        # queries.  The needed_outer_queries we build here are queries
        # that the subquery needs but that aren't part of the recursive
        # loop, e.g., the number two.
        internal_produced_queries = \
            set(chain.from_iterable(
                action.outputs for action in subgraph))
        needed_queries = \
            set(chain.from_iterable(action.inputs for action in subgraph))
        needed_outer_queries = needed_queries - internal_produced_queries
        assert not needed_outer_queries - set(actions_by_output)
        # Snapshot the nodes: the loop below adds nodes to the graph.
        subgraph_nodes = tuple(subgraph)
        for query in needed_outer_queries:
            fetch_action = FetchFromOuterLoopAction(query)
            for action in subgraph_nodes:
                if query in action.inputs:
                    add_action_graph_edge(subgraph,
                                          action,
                                          fetch_action,
                                          query)
        return needed_outer_queries

    def __build_inner_query_action_graph(self, nx, action_graph, component):
        """Main function for building the subquery action graph.

        Returns a tuple `(subgraph, actions_by_output,
        needed_outer_queries, active_part_numbers)`.
        """
        assert self in component
        assert all(action in action_graph for action in component)
        # We want to be consistent about wiring a particular input to a
        # particular output in case we have two actions that each produce
        # the same input, so build a DB mapping query nodes to the actions
        # that produce them, but only in the context of this component.
        actions_by_output = build_local_actions_by_output(
            nx, action_graph, component)
        part_by_number = self.config.part_by_number
        # The subgraph graph starts out just being the SCC.
        subgraph = action_graph.subgraph(component).copy()
        # Now add the delivery actions.  We need to add these actions
        # _after_ adding the "nonrecursive" nodes above because the
        # delivery actions depend on the corresponding recursive parts
        # (since that's the whole point) and the "recursive" part isn't
        # automatically part of the SCC if it doesn't depend on previous
        # input.  We need to add *all* the potential delivery actions to
        # the subgraph in case we have a recursive reference to a column
        # that we don't need externally.
        deliver_actions = self.__install_delivery_actions(
            subgraph, actions_by_output, part_by_number.items())
        # Now we can replace the dummy recursive reference nodes with
        # actual fetch-from-outer-loop actions.  This action doesn't make
        # the graph acyclic: we leave back edges from the new reference
        # nodes to the corresponding delivery actions, preparing for
        # pruning later.
        self.__replace_recursive_references(nx, subgraph, deliver_actions)
        assert self in subgraph
        # Now we can delete the delivery actions we don't actually need.
        # We know we need a delivery action when someone depends on the
        # corresponding query from the placeholder action.
        self.__prune_subgraph(nx,
                              subgraph,
                              action_graph,
                              deliver_actions)
        assert self not in subgraph
        # Finally, break the back references, making the graph acyclic
        # with respect to *us*, although it may still contain cycles if
        # internal computations request nested recursive CTEs.
        self.__drop_back_edges(subgraph, deliver_actions)
        # Make sure the outer graph contains everything we need for the
        # inner graph.
        needed_outer_queries = self.__propagate_inner_dependencies(
            subgraph, actions_by_output)
        # The keys that survived pruning are the parts we will compute.
        active_part_numbers = tuple(deliver_actions)
        return (subgraph, actions_by_output,
                needed_outer_queries, active_part_numbers)

    @override
    def condense(self,
                 nx,
                 action_graph,
                 component):
        """Remove cycles in this SCC.

        Builds the inner-query action graph for `component`, condenses
        any cycles left in it, and replaces this placeholder and the
        rest of `component` in `action_graph` with one
        RecursiveLoopAction.  `action_graph` is modified in place.

        NOTE(review): `nx` is presumably the networkx module, passed in
        by the caller -- confirm.
        """
        assert self in component
        assert all(action in action_graph for action in component)
        (subgraph, actions_by_output,
         needed_outer_queries, active_part_numbers) = \
            self.__build_inner_query_action_graph(nx, action_graph, component)
        # We can nest recursive CTEs, so condense cycles recursively until
        # we end up with a DAG.
        condense_cycles_inplace(nx, subgraph)
        # Now we can encapsulate the whole messy loop graph with one
        # single RecursiveLoopAction, which depends on all queries that we
        # might need to perform the inner query.
        assert assert_action_graph_valid(subgraph)
        condensed_action = RecursiveLoopAction(
            self.config,
            needed_outer_queries,
            active_part_numbers,
            plan_query_minstar(subgraph))
        # Drop the rest of the component, then rename this node to the
        # condensed action so that its existing edges carry over.
        action_graph.remove_nodes_from(component - {self})
        nx.relabel.relabel_nodes(action_graph,
                                 {self: condensed_action},
                                 copy=False)
        assert self not in action_graph
        assert not any(action in action_graph for action in component)
        assert condensed_action in action_graph
        assert isinstance(needed_outer_queries, set)
        for query in needed_outer_queries:
            add_action_graph_edge(action_graph,
                                  condensed_action,
                                  actions_by_output[query],
                                  query)
        # Delete any edges corresponding to inputs that we no longer need
        # after pruning.
        edges_to_remove = []
        for dst, edge in action_graph[condensed_action].items():
            queries = edge["queries"]
            queries.intersection_update(condensed_action.inputs)
            if not queries:
                edges_to_remove.append((condensed_action, dst))
        action_graph.remove_edges_from(edges_to_remove)
@final
class RecursiveResultQuery(QueryNode):
    """Query for one finished column of a recursive CTE.

    The column is identified by `part_number` within `config`.
    """

    config = iattr(Recursive)
    part_number = iattr(int)

    @override
    def _compute_schema(self):
        selected = self.config.part_by_number[self.part_number]
        schema = selected.base_part.schema.unconstrain()
        # Base and recursive cases must agree once unconstrained.
        assert schema == selected.recursive_part.schema.unconstrain()
        return schema

    @override
    def make_action(self):
        # All result queries of one config share a placeholder action.
        return RecursivePlaceholderAction(self.config)
@final
class RecursiveReferenceQuery(QueryNode):
    """Placeholder query for recursive references in CTE"""

    # The number field of a RecursiveReferenceQuery refers to the part
    # number in its enclosing Recursive object.  The SQL engine just
    # numbers recursive references sequentially when building Recursive
    # queries.  You might ask, "does our use of a plain, non-unique
    # index break QueryNode referential transparency?"  The answer is
    # no: RecursiveReferenceQuery has special evaluation semantics.
    # Evaluating any RecursiveReferenceQuery *by itself* just yields an
    # AssertionError, thus trivially preserving referential
    # transparency.  *Within* a full recursive QueryNode tree, however,
    # the RecursiveReferenceQuery refers to the appropriate partial
    # result vector during query execution.  Since the same
    # RecursiveResultQuery always yields the same value and it's not
    # legal to evaluate RecursiveReferenceQuery outside a
    # RecursiveResultQuery, we preserve referential transparency.
    #
    # "But wait!", you might object: "doesn't a RecursiveResultQuery
    # refer to a different result vector each time through the
    # recursive-query evaluation loop?"  Well, it does, yes.  But each
    # iteration of the recursive-query evaluation loop is a separate
    # query execution, and referential transparency of QueryNodes
    # matters only within a query execution (during which
    # RecursiveResultQuery actually does satisfy the constraint) and
    # within a cache domain.  We never cache RecursiveResultQuery or
    # anything that depends on it, up to a RecursiveResultQuery, so
    # we're good with respect to the second requirement too.  If you
    # cheat and users can't tell the difference, does it really matter?
    #
    # NOTE(review): some of the "RecursiveResultQuery" mentions in the
    # second paragraph may have been meant as "RecursiveReferenceQuery";
    # confirm with the author.
    part_number = iattr(int)
    the_schema = iattr(QuerySchema)

    @override
    def _compute_schema(self):
        return self.the_schema

    @override
    def make_action(self):
        raise AssertionError("not used")

    @override
    def special_add_action(self, action_graph, actions_being_added):
        """Find or create the action for this recursive reference.

        Returns a RecursiveReferenceAction that is a predecessor of the
        referent placeholder in `action_graph`.  Raises
        InvalidQueryException if no enclosing placeholder knows this
        part number.
        """
        # We maintain a stack of actions as we add them to the graph,
        # and we search this stack, most recent first, to find a
        # referent for this recursive reference.  (We can't just encode
        # the reference in the QueryNode structure: QueryNode instances
        # are immutable, and you need mutability to create a cycle.)
        part_number = self.part_number
        referent = next(
            (candidate
             for candidate in reversed(actions_being_added)
             if isinstance(candidate, RecursivePlaceholderAction)
             and part_number in candidate.config.part_by_number),
            None)
        if not referent:
            raise InvalidQueryException("orphaned recursive reference")
        # Having found our referent, search its inbound edges to see
        # whether we already have an action for this recursive
        # reference: if we do, just use that.  (Merging nodes when
        # possible is important for query optimization.)  Normally, the
        # action graph builder knows that it can reuse an action when
        # some existing action produces a query we want, but because
        # RecursiveReferenceQuery is ambiguous except in the context of
        # a recursive evaluation, we need this search to determine safe
        # reuse.
        existing = next(
            (referrer
             for referrer in action_graph.predecessors(referent)
             if isinstance(referrer, RecursiveReferenceAction)
             and referrer.part_number == part_number),
            None)
        if existing is not None:
            return existing
        # Otherwise make a new action and stick it into the graph.
        created = RecursiveReferenceAction(self)
        action_graph.add_edge(created, referent, queries=())
        return created
@final
class RecursiveReferenceAction(QueryAction):
    """Stand-in action for a recursive reference during early opt.

    It must never run: executing it reports a missing recursive CTE
    transformation.
    """

    recursive_reference_query = iattr(RecursiveReferenceQuery)

    @cached_property
    def part_number(self):
        """Part number of the recursive reference this action stands for"""
        reference = self.recursive_reference_query
        return reference.part_number

    @override
    def _compute_inputs(self):
        # Lie blatantly.  We depend on ourselves.
        return ()

    @override
    def _compute_outputs(self):
        return (self.recursive_reference_query,)

    @override
    async def run_async(self, _qe):
        _die_bad_cte_transformation()
@final
class RecursiveTableSubquery(CteBindingValue):
    """CTE binding value that builds the table of a recursive CTE.

    `base_te` is the select for the base case and `recursive_te` the
    select for the recursive case.
    """

    base_te = iattr(RegularSelect)
    recursive_te = iattr(RegularSelect)

    @override
    def make_cte_value(self, tctx, cte_name):
        """Build the query table for this recursive CTE.

        Returns a GenericQueryTable with one RecursiveResultQuery per
        base-case column.  Raises InvalidQueryException when the CTE
        has no columns, when either case is not a regular table, or
        when the two cases disagree on column count or column schema.
        """
        assert isinstance(cte_name, CteBindingName)
        base_qt, _ = self.base_te.make_qt(tctx, ())
        if not base_qt.columns:
            raise InvalidQueryException(
                "a recursive CTE with zero columns makes zero sense")
        if cte_name.do_rename:
            base_qt = GenericQueryTable.rename_columns(
                base_qt, cte_name.renaming)

        # Hand out one fresh reference number per base column.
        first_rce_number = tctx.max_rce_number + 1
        rce_numbers = xcount(first_rce_number)
        reference_columns = []
        for name, rce_number in zip(base_qt.columns, rce_numbers):
            reference_schema = base_qt[name].schema.unconstrain()
            reference_columns.append(
                (name, RecursiveReferenceQuery(rce_number, reference_schema)))
        recursive_reference_qt = GenericQueryTable(reference_columns)

        # Inside the recursive case, the CTE name resolves to the
        # reference table.  The counter's next value is one past the
        # last number handed out above.
        recursive_tctx = tctx.let({cte_name.name: recursive_reference_qt})
        recursive_tctx.max_rce_number = next(rce_numbers) - 1
        recursive_qt, _ = self.recursive_te.make_qt(recursive_tctx, ())

        for qt in (base_qt, recursive_qt):
            if qt.table_schema.kind != TableKind.REGULAR:
                raise InvalidQueryException(
                    "recursive queries must yield regular tables")
        if len(base_qt.columns) != len(recursive_qt.columns):
            raise InvalidQueryException(
                "mismatched column count in recursive CTE")

        # We require an exact schema match because the value of a
        # recursive reference column may come from either the base case
        # or from repeated evaluations of the recursive case, depending
        # on whether we've applied the recursive case before.  The
        # schema controls the dtype, and we can choose only one dtype
        # per column as the overall result of the recursive CTE.
        column_pairs = tuple(zip(base_qt.columns, recursive_qt.columns))
        for base_column_name, recursive_column_name in column_pairs:
            base_schema = base_qt[base_column_name].schema
            recursive_schema = recursive_qt[recursive_column_name].schema
            if base_schema.unconstrain() != recursive_schema:
                raise InvalidQueryException(
                    "schema mismatch in recursive CTE: "
                    "{!r}/{} in base case vs {!r}/{} in recursive case".format(
                        base_column_name,
                        base_schema,
                        recursive_column_name,
                        recursive_schema))

        parts = [
            RecursivePart(base_qt[base_name], recursive_qt[recursive_name])
            for base_name, recursive_name in column_pairs
        ]
        recursive = Recursive(first_rce_number, parts)
        assert len(recursive.parts) == len(recursive_reference_qt.columns)

        result_columns = []
        for name, part_number in zip(base_qt.columns,
                                     recursive.part_by_number):
            result_columns.append(
                (name, recursive.make_result_query(part_number)))
        return GenericQueryTable(result_columns)