# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SQL front-end to query.py"""
# TODO(dancol): we should support partitioning by multiple columns;
# right now, it's zero or one.
import logging
import operator as oper
import sys
import weakref
from io import StringIO
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from functools import partial, reduce
from itertools import chain, count as xcount, starmap
from textwrap import dedent
from cytoolz import (
first,
second,
)
from modernmp.util import (
assert_seq_type,
once,
the,
)
from .util import (
BOOL,
CaseInsensitiveCasePreservingDict,
EqImmutable,
ExplicitInheritance,
Immutable,
NoneType,
UsageTrackingDictionary,
abstract,
all_same,
cached_property,
final,
iattr,
override,
tattr,
ureg,
)
from .iddict import (
IdentityDictionary,
WeakKeyIdentityDictionary,
)
from .query import (
ArgSortQuery,
BadUnitsException,
BinaryOperationQuery,
C_UNIQUE,
CoalesceQuery,
CollateQuery,
DURATION_SCHEMA,
FlsQuery,
GenericQueryTable,
InvalidQueryException,
QueryNode,
QuerySchema,
QueryTable,
REGULAR_TABLE,
SPAN_UNPARTITIONED_TIME_MAJOR,
SimpleQueryNode,
SqlAttributeLookup,
TS_SCHEMA,
TableKind,
TableSchema,
TableSorting,
UNIQUE,
UnaryOperationQuery,
iattr_query_node_int,
passthrough,
)
from .queryop import (
Backfill,
DropDuplicatesIndexerQuery,
DropDuplicatesQuery,
EventJoin,
EventTableConfig,
GroupCountQuery,
GroupLabelsQuery,
GroupSizesQuery,
JoinMeta,
JoinMetaQuery,
NativeGroupedAggregationQuery,
NativeUngroupedAggregationQuery,
SpanFixup,
SpanGroup,
SpanJoin,
SpanPivot,
SpanTableConfig,
TimeSeriesQueryTable,
WELL_KNOWN_AGGREGATIONS,
)
from .sql_util import identifier_quote
log = logging.getLogger(__name__)
NO_ARGS = {}
"""Dummy dict for when we have no subst args in compilation"""
_documentation = WeakKeyIdentityDictionary()
"""Holds documentation for various table-namespace things"""
def _munge_columnwise_path(path):
# SQL columnwise functions live in a namespace distinct from
# that of tables, which we simulate by munging columnwise
# function names and looking them up in our single namespace.
assert not isinstance(path, str)
return tuple(path[:-1]) + (("__colfunc_" + path[-1]),)
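# For instance, a columnwise function reached in SQL as agg.sum (path
# hypothetical) is stored and looked up under the munged path:
#   _munge_columnwise_path(("agg", "sum"))  # -> ("agg", "__colfunc_sum")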
class LazyResolverFailedError(RuntimeError):
"""Raised when a lazy resolver fails with a special exception
We wrap KeyError in this exception type so that we don't confuse
resolver failure with a missing attribute.
"""
class Namespace(CaseInsensitiveCasePreservingDict, SqlAttributeLookup):
"""SQL object namespace
SQL conceptually has at least two namespaces: one for tables and one
for aggregate functions. Rather than complicate lookup by actually
tracking these various namespaces, we just transform
non-table-namespace entry names before lookup.
SQL namespaces themselves are case-insensitive and case-preserving;
SQL language case insensitivity is handled on the lookup side.
"""
disable_autocomplete = False
"""Set to true to prevent autocomplete resolution"""
class KeyExistsError(InvalidQueryException):
"""Exception raised on assignment if a value already exists"""
class NotNamespaceError(InvalidQueryException):
"""Excepted raised on traversal through non-namespace"""
@final
def assign_by_path(self, path, item, *,
overwrite=False,
make_namespaces=False):
"""Assign a value in the namespace.
PATH is a sequence of names. ITEM is the value to assign.
If OVERWRITE is true and an item already exists at leaf position,
overwrite that item. If the item already exists and OVERWRITE is
false, raise KeyExistsError.
All non-final entries in PATH must refer to existing namespaces if
MAKE_NAMESPACES is false. If this condition is violated, raise
KeyError if a name is absent or NotNamespaceError if a value is
present but not a namespace. If MAKE_NAMESPACES is true, create
empty namespace objects along the path as needed, but never replace
existing values.
"""
assert not isinstance(path, str)
last_idx = len(path) - 1
assert last_idx >= 0
ns = self
for i, name in enumerate(path):
if i == last_idx:
if not overwrite and name in ns:
raise self.KeyExistsError(
"key {} already exists".format(identifier_quote(name)))
ns[name] = item
else:
try:
ns = ns[name]
except KeyError:
if not make_namespaces:
raise
ns[name] = ns = Namespace()
else:
if not isinstance(ns, Namespace):
raise self.NotNamespaceError
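# A minimal usage sketch (names hypothetical):
#   ns = Namespace()
#   ns.assign_by_path(("stats", "cpu"), table, make_namespaces=True)
# creates the intermediate "stats" Namespace and binds "cpu" under it;
# repeating the call without overwrite=True raises KeyExistsError.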
@final
def delete_by_path(self, path):
"""Delete the thing at PATH.
If PATH refers to a namespace, all values under that namespace are
dropped.
All non-final entries in PATH must refer to existing namespaces.
If this constraint is violated, raise KeyError if a name is absent
or NotNamespaceError if a value is present but not a namespace.
"""
assert not isinstance(path, str)
last_idx = len(path) - 1
assert last_idx >= 0
ns = self
for i, name in enumerate(path):
if i == last_idx:
del ns[name]
else:
ns = ns[name]
if not isinstance(ns, Namespace):
raise self.NotNamespaceError
@final
@override
def __getitem__(self, key):
value = super().__getitem__(key)
if isinstance(value, LazyNsEntry):
# Make sure to propagate any documentation to the value
# being created.
doc = _documentation.get(value)
try:
self[key] = value = value.resolver()
except KeyError as ex:
# Wrap the exception so a failure in resolver doesn't look
# like a missing key.
raise LazyResolverFailedError from ex
if doc is not None:
_documentation[value] = doc
return value
lookup_sql_attribute = __getitem__
@final
@override
def enumerate_sql_attributes(self):
return tuple(self)
@final
def walk(self):
"""Generate left node contents of namespace.
Yield (PATH, VALUE) tuples, where PATH is a tuple giving the path
to an item and VALUE is the item stored.
"""
for name in sorted(self):
value = self[name]
if isinstance(value, Namespace):
for sub_name, sub_value in value.walk():
yield name + "." + sub_name, sub_value
else:
yield name, value
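# Sketch: a namespace holding "c" -> 2 plus a sub-namespace "a" that
# holds "b" -> 1 walks as ("a.b", 1) then ("c", 2), sorted by name at
# each level.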
def sql_qt(self, sql, args=NO_ARGS): # pylint: disable=dangerous-default-value
"""Make a QueryTable out of SQL evaluated in this namespace"""
return _parse_select(sql).make_qt(TvfContext.from_ns(self, args))
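# A minimal sketch of use (table and parameter names hypothetical):
#   qt = ns.sql_qt("SELECT _ts, value FROM counters WHERE value > :lo",
#                  {"lo": 5})
# parses the SELECT, resolves "counters" in this namespace, and
# returns the resulting QueryTable.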
__inherit__ = dict(lookup_sql_attribute=override)
EMPTY_NS = Namespace()
@final
class LexicalEnvironment(SqlAttributeLookup):
"""Lexical environment providing symbols for query compilation"""
@override
def __init__(self, parent, ns):
assert isinstance(parent, (NoneType, LexicalEnvironment))
assert isinstance(ns, Namespace)
self.__parent = parent
self.__ns = ns
def chain(self, ns):
"""Bind NS and return a new lexical environment"""
return LexicalEnvironment(self, ns)
@override
def lookup_sql_attribute(self, name):
try:
return self.__ns.lookup_sql_attribute(name)
except KeyError:
parent = self.__parent
if parent is not None:
return parent.lookup_sql_attribute(name)
raise
@override
def enumerate_sql_attributes(self):
attr = set(super().enumerate_sql_attributes())
attr.update(self.__ns.enumerate_sql_attributes())
if self.__parent is not None:
attr.update(self.__parent.enumerate_sql_attributes())
return attr
@final
class TvfContext(ExplicitInheritance):
"""Everything we need to evaluate functions in TVF context"""
max_rce_number = 0
"""Bookkeeping for recursive CTE transformation"""
@override
def __init__(self, lexenv, args):
assert isinstance(lexenv, LexicalEnvironment)
assert isinstance(args, Mapping)
self.__lexenv = lexenv
self.__args = args
self.lookup_sql_attribute_by_path = \
lexenv.lookup_sql_attribute_by_path
@staticmethod
def from_ns(ns=None, args=NO_ARGS): # pylint: disable=dangerous-default-value
"""Make a TvfContext from a namespace and arguments.
The standard namespace is available automatically. If NS is not
none, it must be a SQL namespace that acts as an overlay on top of
the standard namespace.
"""
lexenv = make_standard_lexical_environment()
if ns:
lexenv = lexenv.chain(ns)
return TvfContext(lexenv, args)
def get_bind_value(self, ref):
"""Resolve query compilation environment option"""
return self.__args[ref]
def let(self, bindings):
"""Prepend a lexical environment to this TvfContext
BINDINGS is a mapping describing the additional bindings that the
new context should have.
Return a new TvfContext incorporating the new bindings, leaving
self unchanged.
"""
ns = Namespace()
ns.update(bindings)
tctx = TvfContext(self.__lexenv.chain(ns), self.__args)
if self.max_rce_number:
tctx.max_rce_number = self.max_rce_number
return tctx
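# Sketch (binding name hypothetical): tctx.let({"tmp": qt}) returns a
# context in which "tmp" resolves to qt ahead of the existing lexical
# environment, leaving tctx itself untouched.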
@property
def lexenv(self):
"""The lexical environment for this TVF"""
return self.__lexenv
@final
class QueryCompilationContext(ExplicitInheritance):
"""Local data kept during tokens"""
regenerate_column_names = False
prohibit_window_functions = False
@override
def __init__(self, tctx):
assert isinstance(tctx, TvfContext)
self.tctx = tctx
self.colrefs = IdentityDictionary()
self.join_info = IdentityDictionary()
self.__te_to_qt = IdentityDictionary()
self.__functions_by_funcall = IdentityDictionary()
aggregate_functions_used = False
simplified_gb_expressions = None
def te_to_qt(self, te):
"""Generate the QueryTable for the given table expression."""
assert isinstance(te, QueryTableExpression), \
"expected a QTE: got {!r}".format(te)
qt = self.__te_to_qt.get(te)
if not qt:
self.__te_to_qt[te] = qt = te.uncached_make_qt(self)
assert isinstance(qt, QueryTable), \
"wanted a QueryTable: got {!r}".format(qt)
return qt
def resolve_funcall(self, funcall):
"""Find and cache the function object for AST node FUNCALL"""
assert isinstance(funcall, FunctionCall)
function = self.__functions_by_funcall.get(funcall)
if not function:
self.__functions_by_funcall[funcall] = function = \
self.tctx.lookup_sql_attribute_by_path(
_munge_columnwise_path(funcall.path))
return function
class UnboundReferenceException(InvalidQueryException):
"""Exception thrown for an unbound reference in a query"""
class AstNode(EqImmutable):
"""Root note of a SQL parse"""
def get_children(self):
"""Return a list of child nodes"""
for value in self.__dict__.values():
if isinstance(value, AstNode):
yield value
if isinstance(value, tuple):
for subvalue in value:
if isinstance(subvalue, AstNode):
yield subvalue
def walk(self):
"""Yield self and children recursively"""
yield self
for child in self.get_children():
yield from child.walk()
def simplify(self, subst=None):
"""Apply local simplifications
If SUBST is non-None, it is a function of one argument, an
AstNode. It returns a node to substitute. Any node that subst
changes will not be re-simplified by this routine. SUBST should
call simplify() explicitly to continue simplification.
Return self if unchanged or a new node if something changed.
"""
if subst:
new_self = subst(self)
if self is not new_self:
return new_self
changes = {}
for key, value in tuple(self.__dict__.items()):
if isinstance(value, AstNode):
simplified_value = value.simplify(subst)
elif isinstance(value, tuple):
simplified_value = tuple((part if isinstance(part, str)
else part.simplify(subst))
for part in value)
if not any(orig is not new
for orig, new in zip(value, simplified_value)):
simplified_value = value
else:
continue
if simplified_value is not value:
changes[key] = simplified_value
if not changes:
return self
return self.evolve(**changes)
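# Sketch of the SUBST protocol (hypothetical rewrite): a substitution
# must itself continue simplification, since this routine skips nodes
# that SUBST replaces:
#   def drop_false(node):
#       if isinstance(node, FalseLiteral):
#           return TrueLiteral().simplify(drop_false)
#       return node
#   simplified = ast.simplify(drop_false)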
class NoWalkAstNode(AstNode):
"""Mixin for not walking children"""
@final
@override
def get_children(self):
return () # Don't walk subqueries
@final
@override
def simplify(self, subst=None):
return self # Don't walk subqueries
class Expression(AstNode):
"""AST node for a value computed"""
def self_resolve(self, ctx, resolver):
"""Resolve column references and return self"""
for child in self.get_children():
child.self_resolve(ctx, resolver)
return self
@abstract
def to_query(self, ctx, te):
"""Build the QueryNode this expression represents"""
raise NotImplementedError("abstract")
@final
def to_query_uncoordinated(self, tctx):
"""Build a QueryNode that can't reference query lexical context
We use this method in contexts like VALUES in which we want to
generate expressions, but can't let expressions refer to
surrounding query context.
"""
assert isinstance(tctx, TvfContext)
return self.to_query(QueryCompilationContext(tctx),
DummyTableExpression())
def flatten_conjunctions(self):
"""Transform expression into tuple of AND-ed expressions"""
return self,
def decompose_into_equi_join(self, _ctx, _left_te_id, _right_te_id): # pylint: disable=no-self-use
"""Extract equi-join conditions from expression.
If this expression doesn't represent an equi-join, return a false
value. Otherwise, return a tuple (LEFT_EXPR, RIGHT_EXPR, OP) of
expressions to match on each side of the join. LEFT_TE_ID and
RIGHT_TE_ID are sets of table expression IDs on the left and right
sides of the join. OP is the equi-join operation, which is either
"=" or "<=>" for no-NULL and NULL-allowed matches, respectively.
"""
return None
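# Sketch (hypothetical TEs a and b): for the AST of "a.x = b.y", the
# BinaryOperation override of this hook returns roughly
#   expr.decompose_into_equi_join(ctx, {id(a)}, {id(b)})
#   # -> (<a.x expression>, <b.y expression>, "=")
# while "a.x > b.y" yields None, since only "=" and "<=>" qualify.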
def te_ids_used(self, ctx):
"""All TableExpressions in this expression"""
return frozenset(id(ctx.colrefs[tr][0])
for tr in self.walk()
if isinstance(tr, ColumnReference))
def unwrap_collation(self):
"""Return COLLATION, AST_NODE decoding CollationTag"""
return None, self
def to_sort_query(self, ctx, te):
"""Make a QueryNode suitable for use in sorting operations"""
query = self.to_query(ctx, te)
if query.schema.is_string:
query = CollateQuery(query, "binary")
return query
@abstract
def dump(self, ctx, out):
"""Write to OUT a simple text description of this operation.
This method is used for automatically forming column names. The resulting
string should be adequate for human inspection, but doesn't have to completely
describe the expression."""
raise NotImplementedError("abstract")
def get_auto_name(self, ctx):
"""Automatic column name for this expression
If we set ctx.regenerate_column_names, we re-run the column naming
pass after building the query, at which time we might have more
information and generate a better name than we did the first time.
"""
out = StringIO()
self.dump(ctx, out)
return out.getvalue()
@final
class Cast(Expression):
"""Data type cast"""
expr = iattr(Expression)
dtype = iattr(str)
safe = iattr(bool)
@override
def to_query(self, ctx, te):
return self.expr.to_query(ctx, te).to_dtype(self.dtype, self.safe)
@override
def dump(self, ctx, out):
out.write("CAST(")
self.expr.dump(ctx, out)
out.write(", ")
out.write(identifier_quote(self.dtype))
out.write(")")
@abstract
class CteBindingValue(AstNode):
"""AST node that provides a value for a CTE binding"""
@abstract
def make_cte_value(self, tctx, cte_name):
"""Make the value to which we should bind a CTE.
TCTX is a TvfContext.
CTE_NAME is the common table expression binding name (a
CteBindingName) for which we're making this table.
Having CTE_NAME available here is important for recursive CTEs,
which need to lexically bind references to CTE_NAME to the table
under construction. We generally ignore CTE_NAME otherwise.
This function returns a value that we embed into the SQL table
namespace. This value is usually a QueryTable, but it can be any
object. The ability to embed arbitrary objects is sometimes
useful.
"""
raise NotImplementedError("abstract")
@abstract
class FunExpr(CteBindingValue):
"""AST node for an expression in a table-valued funcall"""
@abstract
def evaluate_tvf(self, tctx):
"""Parse-time evaluation of a table-valued function
TCTX is a TvfContext for the evaluation.
"""
raise NotImplementedError("abstract")
@final
@override
def make_cte_value(self, tctx, cte_name):
assert isinstance(cte_name, CteBindingName)
value = self.evaluate_tvf(tctx)
if cte_name.do_rename:
if not isinstance(value, QueryTable):
raise InvalidQueryException(
"can rename columns in a CTE only if "
"the CTE evaluates to a table!")
value = GenericQueryTable.rename_columns(value, cte_name.renaming)
return value
@final
class UnaryFunExprOperation(FunExpr):
"""Unary operation in table-valued funcall"""
operator = iattr(str)
argument = iattr(FunExpr)
OPS = {
"-": oper.neg,
"+": oper.pos,
"~": oper.inv,
"!": oper.not_,
}
@override
def evaluate_tvf(self, tctx):
return self.OPS[self.operator](self.argument.evaluate_tvf(tctx))
@final
class BinaryFunExprOperation(FunExpr):
"""Binary operation in table-valued funcall"""
left = iattr(FunExpr)
operator = iattr(str)
right = iattr(FunExpr)
OPS = {
"*": oper.mul,
"/": oper.truediv,
"%": oper.mod,
"//": oper.floordiv,
"+": oper.add,
"-": oper.sub,
"<<": oper.lshift,
">>": oper.rshift,
"&": oper.and_,
"|": oper.or_,
"=": oper.eq,
"<=>": oper.eq,
"<!=>": oper.ne,
">=": oper.ge,
">": oper.gt,
"<=": oper.le,
"<": oper.lt,
"<>": oper.ne,
"!=": oper.ne,
"==": oper.eq,
"is": oper.is_,
"is not": oper.is_not,
# and/or special-cased
}
@override
def evaluate_tvf(self, tctx):
operator = self.operator
if operator == "and":
return self.left.evaluate_tvf(tctx) and self.right.evaluate_tvf(tctx)
if operator == "or":
return self.left.evaluate_tvf(tctx) or self.right.evaluate_tvf(tctx)
left = self.left.evaluate_tvf(tctx)
right = self.right.evaluate_tvf(tctx)
return self.OPS[self.operator](left, right)
class AstLiteral(Expression, FunExpr):
"""Parent of literal values"""
__abstract__ = True
@final
@override
def to_query(self, ctx, te):
# TODO(dancol): use broadcasting instead of count
value = QueryNode.coerce_(self.evaluate_tvf(ctx.tctx))
return QueryNode.filled(value, te.make_count_query(ctx))
@final
class IntegerLiteral(AstLiteral):
"""AST node for a literal integer"""
value = iattr(int, converter=int)
unit = iattr(default=None)
@override
def evaluate_tvf(self, tctx):
if self.unit:
try:
unit = ureg().parse_units(self.unit)
except Exception as ex:
raise BadUnitsException(
"Invalid unit {!r}".format(self.unit)) from ex
return self.value * unit
return self.value
@override
def dump(self, ctx, out):
# We leave out the unit because it's frequently repeated in column
# headings
out.write("{}".format(self.value))
@final
class FloatLiteral(AstLiteral):
"""AST node for a literal float"""
value = iattr(float, converter=float)
unit = iattr(default=None)
@override
def evaluate_tvf(self, tctx):
if self.unit:
return self.value * ureg().parse_units(self.unit)
return self.value
@override
def dump(self, ctx, out):
# We leave out the unit because it's frequently repeated in column
# headings
out.write("{}".format(self.value))
def quote(s):
"""Quote a string as a SQL string"""
return "'{}'".format(s.replace("'", "''"))
def path2str(path):
"""Convert a path sequence to a path string"""
return ".".join(map(identifier_quote, path))
@final
class StringLiteral(AstLiteral):
"""AST node for a literal thing"""
value = iattr(str, converter=str)
@staticmethod
def decode(unescaped):
"""Parse a quoted SQL string into its unquoted payload"""
delim = unescaped[0]
assert unescaped.endswith(delim), \
"invalid SQL string {!r}".format(unescaped)
return unescaped[1:-1].replace(delim*2, delim)
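# Inverse of quote() for either quote character, e.g.
#   StringLiteral.decode("'it''s'")  # -> "it's"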
@override
def evaluate_tvf(self, tctx):
return self.value
@override
def dump(self, ctx, out):
out.write(quote(self.value))
@final
class TrueLiteral(AstLiteral):
"""AST node for true"""
@override
def evaluate_tvf(self, tctx):
return True
@override
def dump(self, ctx, out):
out.write("true")
@final
class FalseLiteral(AstLiteral):
"""AST node for false"""
@override
def evaluate_tvf(self, tctx):
return False
@override
def dump(self, ctx, out):
out.write("true")
@final
class NullLiteral(AstLiteral):
"""AST node for NULL"""
@override
def evaluate_tvf(self, tctx):
return None
@override
def dump(self, ctx, out):
out.write("true")
@final
class KeywordArgument(AstNode):
"""AST node for a function keyword argument"""
name = iattr(str)
expr = iattr(Expression)
def dump(self, ctx, out):
"""Dump for syntax"""
out.write(identifier_quote(self.name))
out.write("=>")
self.expr.dump(ctx, out)
@final
class FunExprKeywordArgument(AstNode):
"""AST node for function keyword argument in TVF context"""
name = iattr(str)
expr = iattr(FunExpr)
@final
class FunExprList(FunExpr):
"""AST node for a list literal in TBF context"""
parts = tattr(FunExpr)
@override
def evaluate_tvf(self, tctx):
return [part.evaluate_tvf(tctx) for part in self.parts]
@final
class FunExprDictItem(AstNode):
"""Root for dict parts of the TVF AST"""
kw = iattr((FunExpr, str))
expr = iattr(FunExpr)
def evaluate(self, tctx):
"""Create a dict pair from this AST node"""
kw = self.kw
if isinstance(kw, FunExpr):
kw = kw.evaluate_tvf(tctx)
return kw, self.expr.evaluate_tvf(tctx)
@final
class FunExprDict(FunExpr):
"""Dictionary literal"""
items = tattr(FunExprDictItem)
@override
def evaluate_tvf(self, tctx):
return dict(item.evaluate(tctx) for item in self.items)
@final
class FunctionCall(Expression):
"""AST node for a function call"""
path = tattr(str, nonempty=True)
arguments = tattr((Expression, KeywordArgument), default=())
distinct = iattr(bool, default=False)
@override
def self_resolve(self, ctx, resolver):
function = ctx.resolve_funcall(self)
if isinstance(function, AggregationFunction):
ctx.aggregate_functions_used = True
resolver = resolver.get_aggregate_argument_resolver()
if (isinstance(function, NormalFunction) and
function.hook_self_resolve(ctx, self, resolver)):
return self
return super().self_resolve(ctx, resolver)
@override
def to_query(self, ctx, te):
return te.make_funcall_query(ctx, self)
@override
def dump(self, ctx, out):
out.write(".".join(map(identifier_quote, self.path)))
out.write("(")
if self.distinct:
out.write("DISTINCT")
if self.arguments:
out.write(" ")
first_argument = True
for argument in self.arguments:
if first_argument:
first_argument = False
else:
out.write(", ")
argument.dump(ctx, out)
out.write(")")
@final
class ColumnReference(Expression):
"""AST node for a "variable" reference"""
column = iattr(str)
path = tattr(str, default=())
@override
def __str__(self):
return ".".join(map(identifier_quote,
self.path + (self.column,)))
@override
def self_resolve(self, ctx, resolver):
if self not in ctx.colrefs:
ctx.colrefs[self] = resolver.match_column_unique(ctx, self)
return self
@override
def to_query(self, ctx, te):
table_reference, column_name = ctx.colrefs[self]
return te.make_query(ctx, table_reference, column_name)
def is_bare_match(self, column):
"""Return whether we're a bare match for column-name COLUMN"""
return not self.path and self.column == column
def as_bare_reference(self):
"""Strip everything but the leaf name of this ColumnReference
Return a new ColumnReference"""
return ColumnReference(self.column)
@override
def dump(self, ctx, out):
out.write(str(self))
@final
class SubqueryExpression(Expression, NoWalkAstNode):
"""AST node for a subquery in column-expression context
We do not yet support coordinated subqueries, so the subquery is
just evaluated on its own and glommed onto the result set.
The query must have a single column, and its length must be either
one or the table length.
"""
# We can't specify the type we really want: it'd make a
# circular reference.
subquery = iattr(AstNode)
@override
def to_query(self, ctx, te):
assert isinstance(self.subquery, Select)
qt = ctx.join_info.get(self)
if not qt:
# pylint: disable=no-member
ctx.join_info[self] = qt = self.subquery.make_qt(ctx.tctx)
return QueryNode.filled(QueryNode.coerce_(qt), te.make_count_query(ctx))
@override
def dump(self, ctx, out):
qt = ctx.join_info.get(self)
if qt and len(qt.columns) == 1:
out.write(first(qt.columns))
else:
ctx.regenerate_column_names = True
out.write("{subquery}")
@final
class BinaryOperation(Expression):
"""AST node for a two-operand operation"""
left = iattr(Expression)
operator = iattr(str)
right = iattr(Expression)
EQUIV = {
"<>": "!=",
"==": "=",
"is": "<=>",
"is not": "<!=>",
"is distinct from": "<!=>",
"is not distinct from": "<=>",
}
@staticmethod
def and_(left, right):
"""Convenience function for generating a conjunction"""
return BinaryOperation(left, "and", right)
@staticmethod
def eq(left, right):
"""Convenience function for equality"""
return BinaryOperation(left, "=", right)
@override
def to_query(self, ctx, te):
left_collation, left = self.left.unwrap_collation()
right_collation, right = self.right.unwrap_collation()
return BinaryOperationQuery(
left.to_query(ctx, te),
self.operator,
right.to_query(ctx, te),
left_collation or right_collation or "binary")
def __decompose_into_equi_join_1(self,
ctx,
expr_left,
expr_right,
left_te_id,
right_te_id):
expr_left_te_id = expr_left.te_ids_used(ctx)
expr_left_contained = expr_left_te_id <= left_te_id
expr_right_te_id = expr_right.te_ids_used(ctx)
expr_right_contained = expr_right_te_id <= right_te_id
if (expr_left_contained and expr_right_contained and
expr_left_te_id and expr_right_te_id):
return (expr_left, expr_right, self.operator)
return None
@override
def decompose_into_equi_join(self, ctx, left_te_id, right_te_id):
return (self.operator in ("=", "<=>") and
((self.__decompose_into_equi_join_1(ctx, self.left, self.right,
left_te_id, right_te_id))
or
(self.__decompose_into_equi_join_1(ctx, self.right, self.left,
left_te_id, right_te_id))))
@override
def simplify(self, subst=None):
equiv = self.EQUIV.get(self.operator)
if equiv:
return BinaryOperation(self.left, equiv, self.right).simplify(subst)
# We simplify AND at the AST level so that code breaking down
# WHERE conditions into conjunction lists works on the simplest
# possible such list.
if self.operator == "and":
if isinstance(self.left, TrueLiteral) and \
isinstance(self.right, TrueLiteral):
return TrueLiteral().simplify(subst)
if isinstance(self.left, TrueLiteral):
return self.right.simplify(subst)
if isinstance(self.right, TrueLiteral):
return self.left.simplify(subst)
return super().simplify(subst)
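# e.g. "a <> b" canonicalizes to "a != b" and "x AND true" reduces to
# just "x", so WHERE clauses reach flatten_conjunctions() in minimal
# form.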
@override
def flatten_conjunctions(self):
if self.operator == "and":
return (self.left.flatten_conjunctions() +
self.right.flatten_conjunctions())
return self,
@override
def dump(self, ctx, out):
out.write("(")
self.left.dump(ctx, out)
out.write(" {} ".format(self.operator))
self.right.dump(ctx, out)
out.write(")")
@final
class UnaryOperation(Expression):
"""AST node for a single-operand expression"""
operator = iattr(str)
argument = iattr(Expression)
EQUIV = {
"not": "!",
}
@override
def to_query(self, ctx, te):
return UnaryOperationQuery(self.operator,
self.argument.to_query(ctx, te))
@override
def simplify(self, subst=None):
equiv = self.EQUIV.get(self.operator)
if equiv:
return UnaryOperation(equiv, self.argument).simplify(subst)
return super().simplify(subst)
@override
def dump(self, ctx, out):
out.write("( {} ".format(self.operator))
self.argument.dump(ctx, out)
out.write(")")
@final
class CollationTag(Expression):
"""AST node for expression tagged with collation"""
expr = iattr(Expression)
collation = iattr(str)
@override
def unwrap_collation(self):
return self.collation, self.expr
@override
def to_query(self, ctx, te):
raise InvalidQueryException("invalid use of COLLATE")
@override
def to_sort_query(self, ctx, te):
return CollateQuery(self.expr.to_query(ctx, te),
self.collation)
@override
def dump(self, ctx, out):
self.expr.dump(ctx, out)
out.write(" COLLATE {}".format(self.collation))
@final
class BetweenAndOperation(Expression):
"""AST node for an X BETWEEN Y AND Z expression"""
value = iattr(Expression)
lower = iattr(Expression)
upper = iattr(Expression)
@override
def to_query(self, ctx, te):
raise AssertionError("should have been syntactically simplified")
@override
def simplify(self, subst=None):
return BinaryOperation(
BinaryOperation(self.lower, "<=", self.value),
"and",
BinaryOperation(self.value, "<=", self.upper)) \
.simplify(subst)
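# e.g. "v BETWEEN lo AND hi" becomes "(lo <= v) and (v <= hi)", which
# is why to_query above is unreachable after simplification.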
@override
def dump(self, ctx, out):
out.write("(")
self.value.dump(ctx, out)
out.write(" BETWEEN ")
self.lower.dump(ctx, out)
out.write(" AND ")
self.upper.dump(ctx, out)
out.write(")")
@final
class UnitConversion(Expression):
"""AST node for a unit conversion"""
value = iattr(Expression)
unit = iattr()
allow_unitless = iattr(bool, default=False)
@override
def to_query(self, ctx, te):
base_query = self.value.to_query(ctx, te)
if base_query.schema.unit == self.unit:
return base_query
if self.unit:
unit = ureg().parse_units(self.unit)
return base_query.to_unit(unit,
allow_unitless=self.allow_unitless)
return base_query.strip_unit()
@override
def dump(self, ctx, out):
self.value.dump(ctx, out)
out.write(" IN {}".format(
identifier_quote(self.unit) if self.unit else "NULL"))
class Column(AstNode):
"""One column specification in a SELECT's column list"""
@abstract
def get_column_specs(self, ctx, source, is_gen_special, table_schema):
"""Return a sequence of column spec pairs. Each column spec pair is
(NAME, EXPRESSION).
"""
raise NotImplementedError("abstract")
@final
class WildcardColumn(Column):
"""A column specification with a wildcard"""
path = tattr(str, default=())
@override
def get_column_specs(self, ctx, source, is_gen_special, table_schema):
matchless = True
assert (not is_gen_special) or table_schema.kind != TableKind.REGULAR, \
"if we're generating non-regular output, we must have one on input"
for column_name, te in source.gen_wildcard_matches(ctx, self.path):
# If we're generating a span, we always include _ts and
# _duration in higher level code, so we should never match these
# columns via wildcard.
if is_gen_special and column_name in table_schema.meta_columns:
matchless = False
continue
cr = ColumnReference(column_name)
# Pre-resolve this new ColumnReference to wire it up to the TE
# we have here and skip name resolution. Recall that we look up
# ColumnReference instances by object identity, so cr here
# is unique.
ctx.colrefs[cr] = (te, column_name)
yield column_name, cr
matchless = False
if matchless and self.path:
raise InvalidQueryException(
"column specification '{}{}*' matched nothing".format(
".".join(map(identifier_quote, self.path)),
"." if self.path else ""))
@final
class ExpressionColumn(Column):
"""A column specification from an expression"""
expr = iattr(Expression)
name = iattr(str, nullable=True, default=None)
@override
def get_column_specs(self, ctx, _source, _is_gen_special, _table_schema):
return (
(self.name or self.expr.get_auto_name(ctx), self.expr),
)
class ExpressionQueryMaker(ExplicitInheritance):
"""Interface that expressions use for making queries in traversal"""
@abstract
def make_count_query(self, ctx):
"""Make a query yielding the length of this result set"""
raise NotImplementedError("abstract")
@abstract
def make_query(self, ctx, table_reference, column_name):
"""Make a QueryNode through lens of this table expression. The result
may differ from the raw QueryNode in the case of joins.
"""
raise NotImplementedError("abstract")
def make_funcall_query(self, ctx, funcall): # pylint: disable=no-self-use
"""Translate a function call to a QueryNode"""
function = ctx.resolve_funcall(funcall)
# The aggregating and grouping TableExpressions override this
# function and do something useful here for aggregating functions.
if isinstance(function, AggregationFunction):
raise InvalidQueryException("illegal use of aggregate function")
if funcall.distinct:
raise InvalidQueryException(
"DISTINCT/ALL is only for aggregate functions")
if isinstance(function, NormalFunction):
return function.make_query(ctx, self, funcall.arguments)
raise InvalidQueryException(
"illegal call to non-function {!r}".format(function))
class TableExpression(AstNode, ExpressionQueryMaker):
"""A table-valued expression"""
name = iattr(str, nullable=True, default=None, kwonly=True)
"""Explicitly-specified name given using AS"""
def te_walk(self):
"""Enumerate table expressions"""
return self,
@cached_property
def te_ids(self):
"""frozenset of TE IDs used by this TE and its children"""
return frozenset(id(te) for te in self.te_walk())
def rename_columns(self, _name, _renaming):
"""Clone with new name and renamed columns"""
raise InvalidQueryException(
"renaming columns not supported on node of type {}".format(type(self)))
@abstract
def push_conditions(self, ctx, conditions):
"""Apply conditions and configure joins
Return a sequence of "excess" conjunctions we couldn't apply at
levels below.
"""
raise NotImplementedError("abstract")
def self_resolve_te(self, ctx):
"""Note which tables our table references refer to"""
pass
@abstract
def gen_wildcard_matches(self, ctx, path):
"""Generate columns matching a column selection wildcard prefix
Each value yielded is a pair (COLUMN_NAME, TABLE_EXPRESSION) where
COLUMN_NAME is a string providing the name of the column and
TABLE_EXPRESSION is the TableExpression object providing that
column.
"""
raise NotImplementedError("abstract")
def match_cr_path(self, path):
"""Match columns against path component of a column reference
Used when generating wildcard columns and resolving column references"""
return not path or (len(path) == 1 and path[0] == self.name)
@abstract
def get_table_schema(self, ctx):
"""The span table schema"""
raise NotImplementedError("abstract")
def get_aggregate_argument_resolver(self):
"""Get the resolver for the arguments of an aggregate function.
Used for "magic grouping" operators to make the _ts in SUM(_ts)
resolve against the underlying table and not the magic grouping
placeholder table source.
"""
return self
def match_column(self, ctx, cr):
"""Yield a sequence of matches for column reference CR.
Each successful match is a (TABLE_EXPRESSION, COLUMN_NAME) pair.
"""
if self.match_cr_path(cr.path) and cr.column in ctx.te_to_qt(self):
return (self, cr.column),
return ()
def gen_table_paths(self, ctx): # pylint: disable=unused-argument
"""Yield a table names that are in scope for column references
Duplicates are allowed. Used for completion.
"""
yield () # Unqualified table references are always legal
if self.name:
yield (self.name,)
@final
def match_column_unique(self, ctx, cr):
"""Solve column reference CR to one column or raise"""
matches = list(self.match_column(ctx, cr))
if not matches:
raise UnboundReferenceException(
"column reference {} is unbound".format(cr))
if len(matches) > 1:
raise InvalidQueryException(
"column reference {} is ambiguous".format(cr))
return matches[0]
def to_schema(self, ctx, *, kind=None, sorting=None):
"""Return a new TE munging this one"""
table_schema = self.get_table_schema(ctx)
if not kind:
kind = table_schema.kind
if not sorting:
sorting = table_schema.sorting
if table_schema.kind == kind and \
table_schema.sorting == sorting:
return self
return ReSortingQueryMaker(self, table_schema, kind, sorting)
@final
class ReSortingQueryMaker(ExpressionQueryMaker):
"""Query maker that transparently restores span invariants"""
__indexer = None
@override
def __init__(self,
base,
base_table_schema,
wanted_kind,
wanted_sorting):
self.__base = the(ExpressionQueryMaker, base)
self.__base_table_schema = the(TableSchema, base_table_schema)
self.__wanted_sorting = the(TableSorting, wanted_sorting)
if base_table_schema.kind != wanted_kind:
raise InvalidQueryException(
"wanted a table kind={} found a {}".format(
TableKind.label_of(wanted_kind),
base_table_schema))
def __make_indexer(self, ctx):
base = self.__base
base_table_schema = self.__base_table_schema
wanted_sorting = self.__wanted_sorting
if wanted_sorting == TableSorting.NONE:
return False # Okay
assert base_table_schema.kind in (TableKind.EVENT, TableKind.SPAN)
partition = base_table_schema.partition
if not partition:
if base_table_schema.sorting != TableSorting.NONE:
# Without a partition, time-major and partition-major sorting
# are the same thing, so we don't need to do any work.
return False
sorts = ("_ts",)
elif wanted_sorting == TableSorting.TIME_MAJOR:
sorts = ("_ts", partition)
else:
assert wanted_sorting == TableSorting.PARTITION_MAJOR
sorts = (partition, "_ts")
return ArgSortQuery([
(ColumnReference(sort)
.self_resolve(ctx, base)
.to_query(ctx, base),
True) for sort in sorts
])
def __get_indexer(self, ctx):
indexer = self.__indexer
if indexer is None:
self.__indexer = indexer = self.__make_indexer(ctx)
return indexer
def __munge(self, ctx, query):
indexer = self.__get_indexer(ctx)
return query.take(indexer) if indexer else query
@override
def make_count_query(self, ctx):
return self.__base.make_count_query(ctx)
@override
def make_query(self, ctx, table_reference, column_name):
return self.__munge(
ctx,
self.__base.make_query(ctx, table_reference, column_name))
@override
def make_funcall_query(self, ctx, funcall):
return self.__munge(
ctx,
self.__base.make_funcall_query(ctx, funcall))
class QueryTableExpression(TableExpression, FunExpr):
"""TableExpression based on a QueryTable"""
__abstract__ = True
@override
def rename_columns(self, new_name, renaming):
return RenamingQueryTableExpression(self, renaming, name=new_name)
@final
def uncached_make_qt(self, ctx):
"""Make a QueryTable"""
return QueryTable.coerce_(self.evaluate_tvf(ctx.tctx))
@final
@override
def make_count_query(self, ctx):
return ctx.te_to_qt(self).countq()
@final
@override
def push_conditions(self, ctx, conjunctions):
# TODO(dancol): push conditions, when safe, down to
# leaf nodes in the join tree
return conjunctions # Will apply at higher level
@final
@override
def make_query(self, ctx, table_reference, column_name):
assert table_reference is self
return ctx.te_to_qt(self)[column_name]
@final
@override
def get_table_schema(self, ctx):
return ctx.te_to_qt(self).table_schema
@final
@override
def gen_wildcard_matches(self, ctx, path):
if self.match_cr_path(path):
for column in ctx.te_to_qt(self).columns:
yield column, self
@final
class RenamingQueryTableExpression(QueryTableExpression):
"""QueryTableExpression that provides column renaming"""
base = iattr(QueryTableExpression)
renaming = iattr()
@override
def evaluate_tvf(self, tctx):
return GenericQueryTable.rename_columns(
QueryTable.coerce_(self.base.evaluate_tvf(tctx)),
self.renaming)
@final
class BindParameter(AstLiteral, QueryTableExpression):
"""Bind parameter"""
ref = iattr((int, str))
def __get_value(self, tctx):
try:
return tctx.get_bind_value(self.ref)
except KeyError:
raise InvalidQueryException(
"no bind parameter given for {}".format(
self.__bind_name)) from None
@override
def evaluate_tvf(self, tctx):
return self.__get_value(tctx)
@property
def __bind_name(self):
if isinstance(self.ref, int):
return "?{}".format(self.ref)
return ":{}".format(identifier_quote(self.ref))
@override
def dump(self, ctx, out):
out.write(self.__bind_name)
@final
class TableValuedFunction(SqlAttributeLookup, ExplicitInheritance):
"""Provides a parameterized table to queries"""
@cached_property
def sql_attributes(self): # pylint: disable=no-self-use
"""Dictionary of SQL attributes"""
return {}
@override
def lookup_sql_attribute(self, key):
if key in self.sql_attributes:
return self.sql_attributes[key]
return super().lookup_sql_attribute(key)
@override
def __init__(self, function):
assert callable(function)
self.__function = function
doc = getattr(function, "__doc__")
if doc is not None:
add_sql_documentation(self, doc)
def __call__(self, *args, **kwargs):
return self.__function(*args, **kwargs)
@staticmethod
def from_select_ast(select_ast, ns=None):
"""Create a TVF from a query with bind parameters"""
assert isinstance(select_ast, Select)
assert isinstance(ns, (NoneType, Namespace))
ns_wr = weakref.ref(ns) if ns else None
def _function(*args, **kwargs):
args = UsageTrackingDictionary(qargs(*args, **kwargs))
if ns_wr:
ns = ns_wr()
assert ns
else:
ns = None
qt = select_ast.make_qt(TvfContext.from_ns(ns, args))
unused_arguments = set(args.keys()) - args.used
if unused_arguments:
raise InvalidQueryException(
"unused arguments: {}".format(unused_arguments))
return qt
return TableValuedFunction(_function)
def _do_tvf_call(tctx, tvf, arguments):
"""Common dispatch code for TvfFunctionCall and TableFunctionCall"""
args = []
kwargs = {}
for arg_node in arguments:
if isinstance(arg_node, FunExprKeywordArgument):
kwargs[arg_node.name] = arg_node.expr.evaluate_tvf(tctx)
else:
if kwargs:
raise InvalidQueryException(
"non-keyword arguments after keyword arguments")
args.append(arg_node.evaluate_tvf(tctx))
return tvf(*args, **kwargs)
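# Mirrors the Python calling convention: positional argument nodes are
# evaluated into args and FunExprKeywordArgument nodes into kwargs; a
# positional argument appearing after any keyword argument raises
# InvalidQueryException.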
@final
class TvfFunctionCall(FunExpr):
"""AST node for a function call in TVF context"""
fn_expr = iattr(FunExpr)
arguments = tattr((FunExpr, FunExprKeywordArgument), default=())
@override
def evaluate_tvf(self, tctx):
return _do_tvf_call(tctx,
self.fn_expr.evaluate_tvf(tctx),
self.arguments)
@final
class TableFunctionCall(QueryTableExpression):
"""AST node for a call to a table-valued function as a query source"""
# We can't just reuse TvfFunctionCall because in the table-source
# case, we need to store a path naming the function, not an
# expression yielding the function itself.
path = tattr(str, nonempty=True)
arguments = tattr((FunExpr, FunExprKeywordArgument), default=())
@override
def evaluate_tvf(self, tctx):
tvf = tctx.lookup_sql_attribute_by_path(self.path)
if not isinstance(tvf, TableValuedFunction):
raise InvalidQueryException(
"{} does not refer to a table-valued function".format(
path2str(self.path)))
return _do_tvf_call(tctx, tvf, self.arguments)
@override
def gen_table_paths(self, ctx):
yield from super().gen_table_paths(ctx)
yield self.path
@final
@override
def match_cr_path(self, path):
return super().match_cr_path(path) or self.path == path
@final
class TableReference(QueryTableExpression):
"""Simple reference to a named value in table namespace"""
path = tattr(str, nonempty=True)
@override
def evaluate_tvf(self, tctx):
try:
return tctx.lookup_sql_attribute_by_path(self.path)
except KeyError:
raise UnboundReferenceException(
"No table-namespace value {}".format(path2str(self.path))) \
from None
@override
def gen_table_paths(self, ctx):
yield from super().gen_table_paths(ctx)
yield self.path
@override
def match_cr_path(self, path):
return super().match_cr_path(path) or self.path == path
@final
class TvfDot(FunExpr):
"""Attribute or sub-namespace access in TVF context"""
expr = iattr(FunExpr)
name = iattr(str)
@override
def evaluate_tvf(self, tctx):
obj = self.expr.evaluate_tvf(tctx)
if not isinstance(obj, SqlAttributeLookup):
raise InvalidQueryException(
"object of type {} does not support attribute access".format(
type(obj)))
return obj.lookup_sql_attribute(self.name)
@final
class TableSubquery(QueryTableExpression, NoWalkAstNode):
"""A subquery used as a table expression"""
subquery = iattr(AstNode)
@override
def evaluate_tvf(self, tctx):
assert isinstance(self.subquery, Select)
return self.subquery.make_qt(tctx) # pylint: disable=no-member
# Joins
LEFT, RIGHT = 0, 1
@final
class ColumnNameList(AstNode):
"""List of comma names used as a join condition"""
columns = tattr(str, default=())
def _drop_duplicates_by_id(sequence):
seen = set()
for item in sequence:
if id(item) not in seen:
seen.add(id(item))
yield item
def _drop_duplicates(sequence):
seen = set()
for item in sequence:
if item not in seen:
seen.add(item)
yield item
class JoinInfo(ExplicitInheritance):
"""Behavioral guts of a join operation"""
__indexers = None
@override
def __init__(self, ctx, join): # pylint: disable=unused-argument
self._join = the(Join, join)
self.__ctx_wr = weakref.ref(the(QueryCompilationContext, ctx))
@property
def _ctx(self):
ctx = self.__ctx_wr()
assert ctx
return ctx
@abstract
def _take_for_join(self, is_right, base_query):
"""Do the take operation for a query side
IS_RIGHT is true iff BASE_QUERY comes from the RHS
of the join.
BASE_QUERY is the query to index into.
"""
raise NotImplementedError("abstract")
@abstract
def get_table_schema(self):
"""The table schema for this join"""
raise NotImplementedError("abstract")
def get_self_columns(self):
"""Sequence of columns provided by the join itself"""
return self.get_table_schema().meta_columns
@abstract
def push_conditions(self, ctx, conjunctions):
"""Push predicates down into join conditions if needed
See TableExpression.push_conditions().
"""
raise NotImplementedError("abstract")
@abstract
def make_count_query(self, ctx):
"""Make QueryNode yielding result set size"""
raise NotImplementedError("abstract")
def _make_query_for_self(self, _ctx, column_name): # pylint: disable=no-self-use
raise InvalidQueryException("invalid join reference "
+ repr(column_name))
def __get_side(self, te):
is_right = id(te) in self._join.right.te_ids
assert is_right or id(te) in self._join.left.te_ids
return is_right
@abstract
def _get_sub_sources(self, ctx):
raise NotImplementedError("abstract")
@final
def make_query(self, ctx, table_reference, column_name):
"""Make a query for a column reference viewed through this join"""
if table_reference is self._join:
return self._make_query_for_self(ctx, column_name)
side = self.__get_side(table_reference)
sub_source = self._get_sub_sources(ctx)[side]
assert isinstance(sub_source, ExpressionQueryMaker)
return self._take_for_join(
side,
sub_source.make_query(ctx, table_reference, column_name))
def match_column(self, ctx, cr):
"""Take over column-matching duties from Join"""
# Special case: a SPAN JOIN "generates" its own span metadata
# columns ex nihilo instead of forwarding them to the
# joined tables.
join = self._join
if len(cr.path) == 1 and cr.path[0] == join.name:
# If we've named a join `foo` and we're matching `foo`.`bar`,
# then pretend we're matching a plain `bar`, but only in the
# context of `foo`. We end up doing the right thing for named
# joins this way, since we still detect ambiguity.
yield from self.match_column(ctx, cr.as_bare_reference())
return
for self_column in self.get_self_columns():
if cr.is_bare_match(self_column):
yield join, self_column
return
yield from join.left.match_column(ctx, cr)
yield from join.right.match_column(ctx, cr)
class JoinInfoClassic(JoinInfo):
"""Join control for a normal SQL join"""
__key_exprs = None
@override
def __init__(self, ctx, join, table_schema=REGULAR_TABLE):
super().__init__(ctx, join)
self.__table_schema = the(TableSchema, table_schema)
@override
def match_column(self, ctx, cr):
# In the semi-join case of a special table with a regular RHS,
# punt any meta columns to the left table directly, since we know
# the right table won't have them.
if not cr.path and cr.column in self.__table_schema.meta_columns:
assert cr.column in \
self._join.left.get_table_schema(ctx).meta_columns
assert cr.column not in \
self._join.right.get_table_schema(ctx).meta_columns
yield from self._join.left.match_column(ctx, cr)
return
yield from super().match_column(ctx, cr)
@override
def get_self_columns(self):
on = self._join.on
if isinstance(on, ColumnNameList):
return on.columns
return ()
@override
def get_table_schema(self):
return self.__table_schema
def __get_join_conjunctions(self, ctx):
on = self._join.on
if isinstance(on, Expression):
return on.flatten_conjunctions()
if isinstance(on, ColumnNameList):
left, right = self._join.left, self._join.right
def _q1(column_name, resolver):
expr = ColumnReference(column_name)
expr.self_resolve(ctx, resolver)
return expr
return [
BinaryOperation.eq(_q1(column_name, left),
_q1(column_name, right))
for column_name in on.columns
]
assert on is None
return ()
@override
def push_conditions(self, ctx, conjunctions):
assert not self.__key_exprs
left_key_exprs = []
right_key_exprs = []
null_allowed = []
conjunctions_left = []
conjunctions_right = []
extra_conjunctions = []
join = self._join
left_te_id = join.left.te_ids
right_te_id = join.right.te_ids
seen = set()
def _handle_conjunction(expr, must_match):
if expr in seen:
return
seen.add(expr)
if must_match or join.op == "inner":
match = expr.decompose_into_equi_join(
ctx, left_te_id, right_te_id)
else:
match = None
if match:
left_expr, right_expr, op = match
left_key_exprs.append(left_expr)
right_key_exprs.append(right_expr)
assert op in ("=", "<=>")
null_allowed.append(op == "<=>")
elif must_match:
raise InvalidQueryException("not an equi-join: {}".format(expr))
elif expr.te_ids_used(ctx) <= left_te_id:
conjunctions_left.append(expr)
elif expr.te_ids_used(ctx) <= right_te_id:
conjunctions_right.append(expr)
else:
extra_conjunctions.append(expr)
for expr in self.__get_join_conjunctions(ctx):
_handle_conjunction(expr, True)
for expr in conjunctions:
_handle_conjunction(expr, False)
assert len(left_key_exprs) == len(right_key_exprs)
if not left_key_exprs and not join.kind:
raise InvalidQueryException("no join condition specified")
extra_conjunctions += \
join.left.push_conditions(ctx, conjunctions_left)
extra_conjunctions += \
join.right.push_conditions(ctx, conjunctions_right)
self.__key_exprs = (left_key_exprs, right_key_exprs, null_allowed)
return extra_conjunctions
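# Sketch (hypothetical columns, inner join): given ON a.x = b.y with
# WHERE a.z > 1 AND b.w < 2 AND a.p = b.q, the ON clause and a.p = b.q
# become key-expression pairs, a.z > 1 is pushed into the left source,
# b.w < 2 into the right, and any remaining cross-side non-equi-join
# conjunction comes back as an excess condition for the caller to
# apply.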
@cached_property
def __indexer_array(self):
join = self._join
ctx = self._ctx
key_exprs = self.__key_exprs
assert key_exprs, \
"push_conditions must be called before using indexers"
sub_sources = self._get_sub_sources(ctx)
keys = [[expr.to_query(ctx, sub_sources[side])
for expr in key_exprs[side]]
for side in (LEFT, RIGHT)]
semi_join = any(UNIQUE in q.schema.constraints for q in keys[RIGHT])
if self.__table_schema.kind != TableKind.REGULAR:
assert join.left.get_table_schema(ctx).kind != TableKind.REGULAR
assert join.right.get_table_schema(ctx).kind == TableKind.REGULAR
if not semi_join:
raise InvalidQueryException(
"a table of kind {} can be joined to a table of kind {} for "
"result type {} only if the join condition is unique".format(
join.left.get_table_schema(ctx),
join.right.get_table_schema(ctx),
self.__table_schema))
op = join.op
m = partial(JoinMetaQuery,
keys[LEFT],
keys[RIGHT],
key_exprs[2],
op)
left = (None if (semi_join and op == "left")
else m(JoinMeta.LEFT_INDEXER))
return left, m(JoinMeta.RIGHT_INDEXER)
@override
def _take_for_join(self, is_right, base_query):
indexer = self.__indexer_array[is_right]
if not indexer:
return base_query
if not is_right and self._join.op in ("inner", "left"):
return base_query.take_sequential(indexer)
return base_query.take(indexer)
@override
def _get_sub_sources(self, ctx):
return self._join.left, self._join.right
@override
def make_count_query(self, ctx):
# TODO(dancol): OR-node that lets us use either the left or right
# indexer depending on which we need
indexers = self.__indexer_array
return (indexers[0] or indexers[1]).countq()
@override
def _make_query_for_self(self, ctx, column_name):
join = self._join
if isinstance(join.on, ColumnNameList):
try:
index = join.on.columns.index(column_name)
except ValueError:
index = -1
if index >= 0:
op = join.op
if op == "inner":
# TODO(dancol): OR-node: left or right would work
return self.__key_exprs[0][index].to_query(ctx, self)
if op == "left":
return self.__key_exprs[0][index].to_query(ctx, self)
if op == "right":
return self.__key_exprs[1][index].to_query(ctx, self)
assert op == "outer"
return CoalesceQuery.of(
self.__key_exprs[0][index].to_query(ctx, self),
self.__key_exprs[1][index].to_query(ctx, self))
return super()._make_query_for_self(ctx, column_name)
@final
class JoinInfoSpan(JoinInfo):
"""Join control for a span join"""
# TODO(dancol): implement join support for partition-major span
# tables instead of converting to time-major tables.
@override
def __init__(self, ctx, join):
super().__init__(ctx, join)
if join.on:
raise InvalidQueryException("SPAN JOIN...ON illegal here")
base_sources = (join.left, join.right)
schemas = [bs.get_table_schema(ctx) for bs in base_sources]
for sub_schema in schemas:
if sub_schema.kind != TableKind.SPAN:
raise InvalidQueryException(
("table expression is not a span table: {!r} "
"(only UNIQUE non-span tables may be joined with span tables, "
"and then only in INNER or LEFT join modes)"
).format(base_sources[schemas.index(sub_schema)]))
kind = join.kind
if kind == "span join":
if bool(schemas[LEFT].partition) != bool(schemas[RIGHT].partition):
raise InvalidQueryException(dedent("""\
Attempt to SPAN JOIN a partitioned span table with
a non-partitioned span table. This operation makes no
sense. Either collapse the partition using suitable grouping
operators or use BROADCAST to apply a non-partitioned set of
spans to partitioned data."""))
if schemas[LEFT].partition == schemas[RIGHT].partition:
our_partition = schemas[LEFT].partition
else:
raise InvalidQueryException(
"left and right partition columns differ ({!r} vs {!r}): "
"rename one table's partition".format(
schemas[LEFT].partition, schemas[RIGHT].partition))
elif kind == "span broadcast":
if schemas[LEFT].partition:
raise InvalidQueryException(
"SPAN BROADCAST needs non-partioned LHS")
if not schemas[RIGHT].partition:
raise InvalidQueryException(
"SPAN BROADCAST needs partitioned RHS")
our_partition = schemas[RIGHT].partition
else:
raise AssertionError(
"impossible span join type {!r}".format(join.kind))
# We enforce the time-major ordering thing inside
# indexer construction.
self.__our_table_schema = (
TableSchema(TableKind.SPAN, our_partition, TableSorting.NONE)
if our_partition else SPAN_UNPARTITIONED_TIME_MAJOR)
@override
def get_table_schema(self):
return self.__our_table_schema
@override
def push_conditions(self, ctx, conjunctions):
conjunctions_left = []
conjunctions_right = []
extra_conjunctions = []
join = self._join
left_te_id = join.left.te_ids
right_te_id = join.right.te_ids
for expr in _drop_duplicates(conjunctions):
if expr.te_ids_used(ctx) <= left_te_id:
conjunctions_left.append(expr)
elif expr.te_ids_used(ctx) <= right_te_id:
conjunctions_right.append(expr)
else:
extra_conjunctions.append(expr)
extra_conjunctions += \
join.left.push_conditions(ctx, conjunctions_left)
extra_conjunctions += \
join.right.push_conditions(ctx, conjunctions_right)
return extra_conjunctions
@cached_property
def __sub_sources(self):
ctx = self._ctx
join = self._join
return [
bs.to_schema(ctx,
kind=TableKind.SPAN,
sorting=TableSorting.TIME_MAJOR)
for bs in (join.left, join.right)]
@cached_property
def __span_join(self):
join = self._join
ctx = self._ctx
base_sources = (join.left, join.right)
crs = [[ColumnReference(meta_column_name).self_resolve(ctx, bs)
for meta_column_name in bs.get_table_schema(ctx).meta_columns]
for bs in base_sources]
kind = join.op
required = [kind in ("inner", "left"), kind in ("inner", "right")]
return SpanJoin([
SpanJoin.Source(
SpanTableConfig(
*[meta_cr.to_query(ctx, ss) for meta_cr in cr]),
required=r)
for cr, ss, r in zip(crs, self.__sub_sources, required)
])
@cached_property
def __indexer_array(self):
span_join = self.__span_join
return span_join.indexerq(0), span_join.indexerq(1)
@override
def _take_for_join(self, is_right, base_query):
# TODO(dancol): we can use take_sequential even in the partitioned
# case when we support partition-major ordering for span join.
indexer = self.__indexer_array[is_right]
if not indexer.config.sources[is_right].span.partition:
return base_query.take_sequential(indexer)
return base_query.take(indexer)
@override
def _get_sub_sources(self, ctx):
return self.__sub_sources
def __metaq(self, meta):
return self.__span_join.metaq(meta)
@override
def make_count_query(self, ctx):
return self.__metaq(SpanJoin.Meta.TIMESTAMP).countq()
@override
def _make_query_for_self(self, ctx, column_name):
if column_name == "_ts":
return self.__metaq(SpanJoin.Meta.TIMESTAMP)
if column_name == "_duration":
return self.__metaq(SpanJoin.Meta.DURATION)
if column_name == self.__our_table_schema.partition:
return self.__metaq(SpanJoin.Meta.PARTITION)
return super()._make_query_for_self(ctx, column_name)
@final
class JoinInfoEvent(JoinInfo):
"""Join control for events"""
@override
def __init__(self, ctx, join):
super().__init__(ctx, join)
if join.on:
raise InvalidQueryException("EVENT JOIN ... ON illegal here")
base_sources = (join.left, join.right)
left_table_schema, right_table_schema = [
bs.get_table_schema(ctx) for bs in base_sources]
event_is_required = join.op in ("inner", "left")
if not event_is_required:
raise InvalidQueryException(
"EVENT JOINs must be inner or left joins "
"with respect to events, since it's "
"not possible to condense a span into a point")
span_is_required = join.op in ("inner", "right")
kind = join.kind
if kind == "event join":
if not left_table_schema.partition and right_table_schema.partition:
raise InvalidQueryException(
"cannot join non-partitioned event table and partitioned "
"span table: use BROADCAST instead")
if left_table_schema.partition == right_table_schema.partition:
our_partition = left_table_schema.partition
elif left_table_schema.partition and not right_table_schema.partition:
our_partition = left_table_schema.partition
else:
raise InvalidQueryException(
"left and right partition columns differ ({!r} vs {!r}): "
"rename one input's partition".format(
left_table_schema.partition, right_table_schema.partition))
elif kind == "event broadcast":
if left_table_schema.partition:
raise InvalidQueryException(
"EVENT BROADCAST needs non-partioned LHS")
if not right_table_schema.partition:
raise InvalidQueryException(
"EVENT BROADCAST needs partitioned RHS")
our_partition = right_table_schema.partition
else:
raise AssertionError(
"impossible event join type {!r}".format(join.kind))
# TODO(dancol): we could use TableSorting.TIME_MAJOR unconditionally
# if the event join code always emitted partitions in order.
# Benchmark it both ways.
self.__span_is_required = span_is_required
self.__our_table_schema = TableSchema(
TableKind.EVENT,
our_partition,
TableSorting.NONE if our_partition else TableSorting.TIME_MAJOR)
@override
def get_table_schema(self):
return self.__our_table_schema
@override
def push_conditions(self, ctx, conjunctions):
conjunctions_left = []
conjunctions_right = []
extra_conjunctions = []
join = self._join
left_te_id = join.left.te_ids
right_te_id = join.right.te_ids
for expr in _drop_duplicates(conjunctions):
if expr.te_ids_used(ctx) <= left_te_id:
conjunctions_left.append(expr)
elif expr.te_ids_used(ctx) <= right_te_id:
conjunctions_right.append(expr)
else:
extra_conjunctions.append(expr)
extra_conjunctions += \
join.left.push_conditions(ctx, conjunctions_left)
extra_conjunctions += \
join.right.push_conditions(ctx, conjunctions_right)
return extra_conjunctions
@cached_property
def __indexer_array(self):
return (self.__event_join.metaq(EventJoin.Meta.EVENT_INDEX),
self.__event_join.metaq(EventJoin.Meta.SPAN_INDEX))
@override
def _take_for_join(self, is_right, base_query):
indexer = self.__indexer_array[is_right]
# TODO(dancol): too conservative?
config = indexer.config
if not config.event.partition and not config.span.partition:
return base_query.take_sequential(indexer)
return base_query.take(indexer)
@cached_property
def __event_join(self):
join = self._join
ctx = self._ctx
base_sources = (join.left, join.right)
crs = [[ColumnReference(meta_column_name).self_resolve(ctx, bs)
for meta_column_name in bs.get_table_schema(ctx).meta_columns]
for bs in base_sources]
ss = self.__sub_sources
return EventJoin(
event=EventTableConfig(
*[meta_cr.to_query(ctx, ss[LEFT])
for meta_cr in crs[LEFT]]),
span=SpanTableConfig(
*[meta_cr.to_query(ctx, ss[RIGHT])
for meta_cr in crs[RIGHT]]),
span_is_required=self.__span_is_required)
@cached_property
def __sub_sources(self):
ctx = self._ctx
join = self._join
return [
join.left.to_schema(ctx,
kind=TableKind.EVENT,
sorting=TableSorting.TIME_MAJOR),
join.right.to_schema(ctx,
kind=TableKind.SPAN,
sorting=TableSorting.TIME_MAJOR),
]
@override
def _get_sub_sources(self, ctx):
return self.__sub_sources
@override
def make_count_query(self, ctx):
# TODO(dancol): OR-node
return self.__event_join.metaq(EventJoin.Meta.TIMESTAMP).countq()
@override
def _make_query_for_self(self, ctx, column_name):
if column_name == "_ts":
return self.__event_join.metaq(EventJoin.Meta.TIMESTAMP)
if column_name == self.__our_table_schema.partition:
return self.__event_join.metaq(EventJoin.Meta.PARTITION)
return super()._make_query_for_self(ctx, column_name)
def _span_join_info_constructor(ctx, join):
# We don't allow span-joins of span tables to non-span tables
# because the join might duplicate the span rows, rendering the
# output span table invalid. This consideration doesn't apply,
# however, when we have a unique join, and in this case, we allow
# regular joins.
left_table_schema = join.left.get_table_schema(ctx)
right_table_schema = join.right.get_table_schema(ctx)
left_span = left_table_schema.kind == TableKind.SPAN
right_span = right_table_schema.kind == TableKind.SPAN
if left_span and not right_span and join.op in ("inner", "left"):
return JoinInfoClassic(ctx, join, left_table_schema)
return JoinInfoSpan(ctx, join)
_JOIN_CONSTRUCTORS = {
None: JoinInfoClassic,
"span join": _span_join_info_constructor,
"span broadcast": JoinInfoSpan,
"event join": JoinInfoEvent,
"event broadcast": JoinInfoEvent,
}
@final
class Join(TableExpression):
"""A join of two tables used as a table expression"""
left = iattr(TableExpression)
op = iattr(str)
right = iattr(TableExpression)
on = iattr((Expression, ColumnNameList), nullable=True, default=None)
kind = iattr(str, nullable=True, default=None)
def __get_join_info(self, ctx):
join_info = ctx.join_info.get(self)
if not join_info:
ctx.join_info[self] = \
join_info = \
_JOIN_CONSTRUCTORS[self.kind](ctx, self)
return join_info
@override
def push_conditions(self, ctx, conjunctions):
return self.__get_join_info(ctx).push_conditions(ctx, conjunctions)
@override
def te_walk(self):
yield self
yield from self.left.te_walk()
yield from self.right.te_walk()
@override
def self_resolve_te(self, ctx):
self.left.self_resolve_te(ctx)
self.right.self_resolve_te(ctx)
if isinstance(self.on, Expression):
self.on.self_resolve(ctx, self)
def __get_self_columns(self, ctx):
return self.__get_join_info(ctx).get_self_columns()
@override
def match_column(self, ctx, cr):
yield from self.__get_join_info(ctx).match_column(ctx, cr)
@override
def gen_table_paths(self, ctx):
yield from super().gen_table_paths(ctx)
yield from self.left.gen_table_paths(ctx)
yield from self.right.gen_table_paths(ctx)
@override
def gen_wildcard_matches(self, ctx, path):
self_columns = self.__get_self_columns(ctx)
if not self_columns:
yield from self.left.gen_wildcard_matches(ctx, path)
yield from self.right.gen_wildcard_matches(ctx, path)
return
if self.match_cr_path(path):
for column in self_columns:
yield column, self
if self.name and len(path) == 1 and path[0] == self.name:
citer = self.gen_wildcard_matches(ctx, ())
else:
citer = chain(self.left.gen_wildcard_matches(ctx, path),
self.right.gen_wildcard_matches(ctx, path))
# Omit special columns from wildcard matches
for column, te in citer:
if column not in self_columns:
yield column, te
@override
def make_query(self, ctx, table_reference, column_name):
return self.__get_join_info(ctx).make_query(
ctx, table_reference, column_name)
@override
def make_count_query(self, ctx):
return self.__get_join_info(ctx).make_count_query(ctx)
@override
def get_table_schema(self, ctx):
return self.__get_join_info(ctx).get_table_schema()
class BaseAggregatingQueryMaker(ExpressionQueryMaker):
"""Common functionality for aggregation and grouping"""
@override
def __init__(self, base_qm):
assert isinstance(base_qm, ExpressionQueryMaker)
self._base_qm = base_qm
@abstract
def make_group_sizes_query(self, ctx):
"""Make a query yielding the size of each group"""
raise NotImplementedError("abstract")
@abstract
def _make_aggregation_query(self, base_query, distinct, function):
raise NotImplementedError("abstract")
@override
def make_funcall_query(self, ctx, funcall):
function = ctx.resolve_funcall(funcall)
if not isinstance(function, AggregationFunction):
return super().make_funcall_query(ctx, funcall)
arguments = funcall.arguments
distinct = funcall.distinct
if distinct == "all":
distinct = None
if not arguments and function.name == "count":
# Special case for COUNT(*)
countq = self.make_group_sizes_query(ctx)
assert countq.schema.is_integral, \
"bad COUNT(*) query from {!r}".format(self)
return countq
if len(arguments) != 1:
raise InvalidQueryException(
"aggregate functions take exactly one argument")
argument = arguments[0]
if not isinstance(argument, Expression):
raise InvalidQueryException(
"aggregate function argument must be simple expression, not {}"
.format(argument))
return self._make_aggregation_query(
argument.to_query(ctx, self._base_qm), distinct, function.name)
@override
def make_query(self, ctx, table_reference, column_name):
# TODO(dancol): support bare columns, SQLite-style
raise InvalidQueryException(
"column reference {}/{} does not have an aggregate".format(
table_reference, column_name))
@final
class GroupingQueryMaker(BaseAggregatingQueryMaker):
"""Aggregation over groups"""
@override
def __init__(self, base_qm, group_by_queries):
super().__init__(base_qm)
assert assert_seq_type(tuple, QueryNode, group_by_queries)
self.__group_by_queries = group_by_queries
@override
def make_query(self, ctx, table_reference, column_name):
base_q = self._base_qm.make_query(ctx, table_reference, column_name)
if base_q in self.__group_by_queries:
return GroupLabelsQuery(self.__group_by_queries, base_q)
return super().make_query(ctx, table_reference, column_name)
@override
def _make_aggregation_query(self, base_query, distinct, aggfunc):
return NativeGroupedAggregationQuery(
group_by=self.__group_by_queries,
data=base_query,
aggregation=aggfunc,
distinct=distinct)
@override
def make_group_sizes_query(self, ctx):
return GroupSizesQuery(self.__group_by_queries)
@override
def make_count_query(self, _ctx):
return GroupCountQuery(self.__group_by_queries)
@final
class AggregatingQueryMaker(BaseAggregatingQueryMaker):
"""Aggregation over entire result sets"""
@override
def _make_aggregation_query(self, base_query, distinct, aggfunc):
return NativeUngroupedAggregationQuery(
aggregation=aggfunc,
data=base_query,
distinct=distinct)
@override
def make_group_sizes_query(self, ctx):
return self._base_qm.make_count_query(ctx)
@override
def make_count_query(self, _ctx):
return QueryNode.scalar(1)
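# An ungrouped aggregation collapses the whole result set to
# exactly one row, hence the constant count query above.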
@final
class DepartitioningQueryMaker(BaseAggregatingQueryMaker):
"""Special span-departitioning aggregation voodoo"""
@override
def __init__(self, base_qm, config):
super().__init__(base_qm)
assert isinstance(base_qm, MagicGroupingBaseSource)
self.__config = the(SpanPivot, config)
@override
def _make_aggregation_query(self, base_query, distinct, aggfunc):
return self.__config.dataq(aggfunc, base_query, distinct)
@override
def make_group_sizes_query(self, ctx):
nr_partitions_q = DropDuplicatesQuery.of1(
self.__config.grouped.partition).countq()
return QueryNode.filled(nr_partitions_q,
self.make_count_query(ctx))
@override
def make_count_query(self, ctx):
# TODO(dancol): OR-node, when we get a grown-up optimizer
config = self.__config
return config.metaq(config.Meta.TIMESTAMP).countq()
@override
def make_query(self, ctx, table_reference, column_name):
config = self.__config
if table_reference is self._base_qm:
if column_name == "_ts":
return config.metaq(config.Meta.TIMESTAMP)
if column_name == "_duration":
return config.metaq(config.Meta.DURATION)
if (config.output_partition and
column_name == self._base_qm.get_table_schema(ctx).partition):
return config.metaq(config.Meta.OUTPUT_PARTITION)
return super().make_query(ctx, table_reference, column_name)
class SelectCore(AstNode):
"""One part of a compound select"""
@abstract
def make_qt(self, tctx, ob_asts):
"""Make a QueryTable for this core"""
raise NotImplementedError("abstract")
@final
class RowValue(AstNode):
"""Literal row value"""
items = tattr(Expression)
@final
class TableValues(SelectCore):
"""Table from a list of literal values"""
row_values = tattr(RowValue, nonempty=True)
@override
def make_qt(self, tctx, ob_asts):
# pylint: disable=protected-access
if ob_asts:
raise NotImplementedError( # Why bother with the complexity?
"wrap VALUES in a SELECT to sort them")
return self.simplify().__make_table(tctx), ()
@staticmethod
def __values_to_query(tctx, column_values):
# TODO(dancol): remove this code once we have "query facts"
# Here, we try to coalesce VALUES entries that happen to be
# literals, but the query engine itself should really be doing it
# automatically. The trouble is that we can't express the idea
# that although we have a ConcatenatingQuery, we know its size in
# advance, because we have no general-purpose way of asserting to
# the optimizer that some queries belong to a size equivalence
# class, so absent some special hack like this, we're going to end
# up with a dumb concatenation of individual scalar queries, the
# overall effect of which is to further stupefy the "query
# optimizer".
queries = []
literals = []
def _flush_literals():
if literals:
queries.append(QueryNode.literals(*literals))
literals.clear()
for value in column_values:
assert isinstance(value, Expression)
if isinstance(value, AstLiteral):
literals.append(value.evaluate_tvf(tctx))
else:
_flush_literals()
queries.append(value.to_query_uncoordinated(tctx))
_flush_literals()
return QueryNode.concat(*queries)
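# Coalescing sketch: for column entries (1), (2), (x + 1), (3), the
# loop above emits QueryNode.literals(1, 2), then the compiled
# query for x + 1, then QueryNode.literals(3), and concatenates the
# three pieces into one column query.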
def __make_table(self, tctx):
row_values = self.row_values
if not all_same(len(rv.items) for rv in row_values):
raise InvalidQueryException(
"all literal values in VALUES must have the same length")
ncol = len(row_values[0].items)
if not ncol:
raise InvalidQueryException("VALUES with no columns")
column_names = ["col{}".format(i) for i in range(ncol)]
values_by_column = tuple(zip(*[rv.items for rv in row_values]))
return GenericQueryTable([
(name, self.__values_to_query(tctx, column_values))
for name, column_values in zip(column_names, values_by_column)
])
@final
class Compound(SelectCore):
"""Set operation on tables"""
left = iattr(SelectCore)
op = iattr(str)
right = iattr(SelectCore)
@override
def make_qt(self, tctx, ob_asts):
left_qt, left_ob_queries = self.left.make_qt(tctx, ob_asts)
right_qt, right_ob_queries = self.right.make_qt(tctx, ob_asts)
if left_qt.table_schema.kind != TableKind.REGULAR or \
right_qt.table_schema.kind != TableKind.REGULAR:
raise InvalidQueryException(
"compound queries supported only on regular tables")
if len(left_qt.columns) != len(right_qt.columns):
raise InvalidQueryException(
"incompatible sub-selections in compound select")
assert len(left_ob_queries) == len(right_ob_queries)
op = self.op
if op in ("union", "union all"):
columns = [
(left_column, QueryNode.concat(left_query, right_query))
for ((left_column, left_query), (_right_column, right_query))
in zip(left_qt.items(), right_qt.items())
]
ob_queries = list(starmap(
QueryNode.concat, zip(left_ob_queries, right_ob_queries)))
if op == "union":
ob_queries = _distinctify_ob_queries(columns, ob_queries)
columns = _distinctify_column_list(columns)
return GenericQueryTable(columns), ob_queries
if op == "intersect":
# TODO(dancol): be faster. A join is a stupidly expensive way
# to implement the intersect operation; we can just use two hash
# tables, like in drop_duplicates. No, we can't just use a
# pandas MultiIndex: MultiIndex implements set operations via
# Python set operations on tuples!
left_key = tuple(map(second, left_qt.items()))
right_key = tuple(map(second, right_qt.items()))
# We only need the left indexer: we're taking the intersection,
# so all result values occur on both the left and right sides.
# Here, we choose to grab the values from the left side.
# TODO(dancol): use an OR-node to pick a side
side = JoinMeta.LEFT_INDEXER
indexer = JoinMetaQuery(left_key, right_key,
(True,) * len(left_key),
"inner", side)
# We're using an inner join, which means that the LHS indices
# are sequential.
columns = [(name, query.take_sequential(indexer))
for name, query in left_qt.items()]
ob_queries = [query.take_sequential(indexer)
for query in left_ob_queries]
return (GenericQueryTable(_distinctify_column_list(columns)),
_distinctify_ob_queries(columns, ob_queries))
assert op == "except" # N.B. Anti-symmetric set difference!
# TODO(dancol): Don't implement with join.
left_key = tuple(map(second, left_qt.items()))
right_key = tuple(map(second, right_qt.items()))
side = JoinMeta.LEFT_INDEXER
left_indexer = JoinMetaQuery(left_key, right_key,
(True,) * len(left_key),
"left", side)
side = JoinMeta.RIGHT_INDEXER
right_indexer = JoinMetaQuery(left_key, right_key,
(True,) * len(left_key),
"left", side)
# We use a left join. The right indexer contains the row number
# in the right table where we can find the right-side value we
# joined, or -1 if the right side didn't have the value. Since we
# want anti-symmetric set difference, pick from the left table all
# rows where the right indexer indicates that we didn't find a
# corresponding value.
# TODO(dancol): add optimization for the case where we're doing a
# binary operation and broadcasting one side to the same size as
# the other. Also, an OR-node for counting the left or
# right indexer.
is_right_missing = right_indexer < 0
# pylint: disable=redefined-variable-type
indexer = left_indexer.where(is_right_missing)
columns = [(name, query.take_sequential(indexer))
for name, query in left_qt.items()]
ob_queries = [query.take(indexer) for query in left_ob_queries]
return (GenericQueryTable(_distinctify_column_list(columns)),
_distinctify_ob_queries(columns, ob_queries))
class FilledQueryTable(QueryTable):
"""Generate value-less rows
This function generates COUNT empty rows, which is occasionally
useful when combined with other query operations. For example,
this query produces six rows, each containing the value 5:
SELECT 5 AS value FROM dctv.filled(6)
"""
count = iattr_query_node_int(default=QueryNode.scalar(1))
@override
def _make_column_tuple(self):
return ()
@override
def _make_column_query(self, column):
raise KeyError(column)
@override
def countq(self):
return self.count
@final
class DummyTableExpression(QueryTableExpression):
"""Used for source-less queries"""
@override
def evaluate_tvf(self, tctx):
return FilledQueryTable()
@final
class FilteringExpressionQueryMaker(ExpressionQueryMaker):
"""Filters another EQM"""
@override
def __init__(self, unfiltered_qm, q_filter):
assert isinstance(unfiltered_qm, ExpressionQueryMaker)
assert isinstance(q_filter, QueryNode)
self.__unfiltered_qm = unfiltered_qm
filter_schema = q_filter.schema
if filter_schema.dtype != BOOL or not filter_schema.is_normal:
raise InvalidQueryException(
"cannot use query of type as filter: {}".format(filter_schema))
self.__q_filter = q_filter
@override
def make_query(self, ctx, table_reference, column_name):
return self.__unfiltered_qm.make_query(
ctx, table_reference, column_name).where(self.__q_filter)
@override
def make_count_query(self, _ctx):
# Summing the boolean filter counts its true entries, which is
# exactly the number of filtered rows.
return NativeUngroupedAggregationQuery(
aggregation="sum",
data=self.__q_filter)
def _distinctify_column_list(columns):
distinct_group = frozenset(map(second, columns))
return [(name, DropDuplicatesQuery(distinct_group, query))
for name, query in columns]
def _distinctify_ob_queries(columns, ob_queries):
distinct_group = frozenset(map(second, columns))
indexer = DropDuplicatesIndexerQuery(distinct_group)
return [ob_query.take_sequential(indexer) for ob_query in ob_queries]
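# DISTINCT sketch: given columns [("a", qa), ("b", qb)], the
# distinct group is {qa, qb}, so a row survives only when its
# (qa, qb) tuple has not occurred before; ORDER BY queries are run
# through the matching indexer to stay aligned with the surviving
# rows.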
def _make_conjunction_filter(ctx, source, conjunctions):
assert conjunctions
expr = reduce(BinaryOperation.and_, conjunctions)
return expr.to_query(ctx, source)
class GroupBy(AstNode):
"""Controls grouping in a RegularSelect"""
def simplify_expressions(self, ctx, source, subst_column_references):
"""Simplify any columns to which we refer"""
pass
def override_base_source(self, ctx, source): # pylint: disable=unused-argument,no-self-use
"""Return an override for the base select column source"""
return source
@abstract
def make_query_maker(self, ctx, query_maker, out_table_schema):
"""Make the column data source for this grouping"""
raise NotImplementedError("abstract")
@final
class GroupByExpressions(GroupBy):
"""Group by a list of column expressions"""
expressions = tattr(Expression, nonempty=True)
@override
def simplify(self, subst=None):
return self # Will do our own simplification later
@override
def simplify_expressions(self, ctx, source, subst_column_references):
def _fix_expr(expr):
expr = expr.simplify(subst_column_references)
expr.self_resolve(ctx, source)
return expr
assert not ctx.simplified_gb_expressions
ctx.simplified_gb_expressions = tuple(map(_fix_expr, self.expressions))
@override
def make_query_maker(self, ctx, query_maker, out_table_schema):
if out_table_schema.kind != TableKind.REGULAR:
raise InvalidQueryException("{} incompatible with GROUP BY"
.format(out_table_schema))
return GroupingQueryMaker(
query_maker,
tuple(expr.to_query(ctx, query_maker)
for expr in ctx.simplified_gb_expressions))
@final
class MagicGroupingBaseSource(TableExpression):
"""Base column source used for magic grouping operations
"Magic" grouping operations are those that have the JOIN-like
quality of adding columns to the SELECT result set. GroupBy itself
doesn't participate in column selection, so, early in select
processing, the select core wraps its original column source with
this thing, which exposes the columns the GroupBy wants to provide.
"""
base = iattr(TableExpression)
schema = iattr(TableSchema)
@override
def gen_wildcard_matches(self, ctx, path):
table_schema = self.get_table_schema(ctx)
if self.match_cr_path(path):
for column in table_schema.meta_columns:
yield column, self
if self.name and len(path) == 1 and path[0] == self.name:
citer = self.gen_wildcard_matches(ctx, ())
else:
citer = iter(self.base.gen_wildcard_matches(ctx, path))
for column, te in citer:
if column not in table_schema.meta_columns:
yield column, te
@override
def gen_table_paths(self, ctx):
yield from super().gen_table_paths(ctx)
yield from self.base.gen_table_paths(ctx)
@override
def match_column(self, ctx, cr):
if len(cr.path) == 1 and cr.path[0] == self.name:
yield from self.match_column(ctx, cr.as_bare_reference())
return
for meta_column in self.get_table_schema(ctx).meta_columns:
if cr.is_bare_match(meta_column):
# Will be munged by the higher-level grouping source!
yield self, meta_column
return
yield from self.base.match_column(ctx, cr)
@override
def make_count_query(self, ctx):
return ctx.join_info[self].make_count_query(ctx)
@override
def get_aggregate_argument_resolver(self):
return self.base
@override
def push_conditions(self, ctx, conjunctions):
# pylint: disable=redefined-variable-type
assert self not in ctx.join_info
extra_conjunctions = self.base.push_conditions(ctx, conjunctions)
base_qm = self.base.to_schema(ctx, sorting=TableSorting.TIME_MAJOR)
if extra_conjunctions:
my_id = id(self)
for conjunction in extra_conjunctions:
for te_id in conjunction.te_ids_used(ctx):
if te_id == my_id:
raise InvalidQueryException(
"invalid use of magic group metadata column in WHERE: "
"consider naming the base table explicitly instead")
base_qm = FilteringExpressionQueryMaker(
base_qm,
_make_conjunction_filter(
ctx,
base_qm,
extra_conjunctions))
ctx.join_info[self] = base_qm
return ()
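# push_conditions above returns () because every conjunction has
# been consumed: pushable ones went to the base table, and the
# leftovers were folded into the FilteringExpressionQueryMaker
# stashed in ctx.join_info.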
@override
def make_query(self, ctx, table_reference, column_name):
if table_reference is self:
raise AssertionError("invalid placeholder column reference {!r}"
.format(column_name))
base_qm = ctx.join_info[self]
return base_qm.make_query(ctx, table_reference, column_name)
@override
def get_table_schema(self, _ctx):
return self.schema
@final
class GroupBySpanPartition(GroupBy):
"""Group by a internal span-table partitions"""
intersect = iattr(bool, default=True) # TODO(dancol): change default
output_partition = iattr(str, nullable=True, default=None)
@override
def override_base_source(self, ctx, source):
source_table_schema = source.get_table_schema(ctx)
if source_table_schema.kind != TableKind.SPAN:
raise InvalidQueryException(
"GROUP USING PARTITION works only with a span table")
if not source_table_schema.partition:
raise InvalidQueryException(
"GROUP USING PARTITION cannot be applied to non-partitioned "
"span table")
table_schema = (
source_table_schema.evolve(partition=self.output_partition,
sorting=TableSorting.NONE)
if self.output_partition else SPAN_UNPARTITIONED_TIME_MAJOR)
return MagicGroupingBaseSource(source, table_schema)
@override
def make_query_maker(self, ctx, query_maker, out_table_schema):
if not isinstance(query_maker, MagicGroupingBaseSource):
raise InvalidQueryException(
"Unsupported use of GROUP USING PARTITION")
source = query_maker.base
source_table_schema = source.get_table_schema(ctx)
assert source_table_schema.kind == TableKind.SPAN
assert source_table_schema.partition
def _q1(column_name):
expr = ColumnReference(column_name)
expr.self_resolve(ctx, source)
return expr.to_query(ctx, query_maker)
# TODO(dancol): get rid of the union/intersect stuff entirely:
# it's not that useful.
partition_q = _q1(source_table_schema.partition)
if self.intersect:
unique_pvals = DropDuplicatesQuery.of1(partition_q)
min_npartitions_q = unique_pvals.countq()
else:
min_npartitions_q = QueryNode.scalar(1)
config = SpanPivot(
grouped=SpanTableConfig(_q1("_ts"),
_q1("_duration"),
partition_q),
output_partition=(_q1(self.output_partition)
if self.output_partition else None),
min_npartitions=min_npartitions_q,
)
return DepartitioningQueryMaker(query_maker, config)
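# Hedged SQL sketch of this grouping (the with_all_partitions TVF
# near the end of this file uses the same form):
# SELECT SPAN FROM ?0 GROUP AND INTERSECT SPANS USING PARTITIONS
# departitions the input; with INTERSECT, output spans cover only
# the regions present in every partition.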
@final
class GroupBySpanQueryMaker(BaseAggregatingQueryMaker):
"""Generates group-by-span queries"""
@override
def __init__(self, base_qm, config):
super().__init__(base_qm)
self.__config = the(SpanGroup, config)
@override
def make_group_sizes_query(self, ctx):
return self._make_aggregation_query(
self.__config.grouped.ts, False, "count")
@override
def make_count_query(self, ctx):
# TODO(dancol): opportunity for an OR-node
return self.__config.metaq(SpanGroup.Meta.TIMESTAMP).countq()
@override
def _make_aggregation_query(self, base_query, distinct, aggfunc):
return self.__config.dataq(aggfunc, base_query, distinct)
@cached_property
def ts_q(self):
"""Timestamp query for the result set"""
return self.__config.metaq(SpanGroup.Meta.TIMESTAMP)
@cached_property
def partition_q(self):
"""Partition query for the result set"""
return self.__config.metaq(SpanGroup.Meta.PARTITION)
@override
def make_query(self, ctx, table_reference, column_name):
if table_reference is self._base_qm:
if column_name == "_ts":
return self.ts_q
if column_name == "_duration":
return self.__config.metaq(SpanGroup.Meta.DURATION)
if column_name == self._base_qm.get_table_schema(ctx).partition:
return self.partition_q
return super().make_query(ctx, table_reference, column_name)
@final
class GroupBySpanGrouper(GroupBy):
"""Group by an explicit list of spans"""
grouper = iattr(TableExpression)
mode = iattr(str)
intersect = iattr(bool, default=False)
def __get_grouper_qt(self, ctx):
grouper_qt = ctx.join_info.get(self)
if not grouper_qt:
grouper_qt_raw = ctx.te_to_qt(self.grouper)
ctx.join_info[self] = grouper_qt = \
grouper_qt_raw.to_schema(kind=TableKind.SPAN,
sorting=TableSorting.TIME_MAJOR)
return grouper_qt
@override
def override_base_source(self, ctx, source):
grouper_qt = self.__get_grouper_qt(ctx)
grouper_schema = grouper_qt.table_schema
grouped_schema = source.get_table_schema(ctx)
mode = self.mode
if mode == "span":
if grouped_schema.kind != TableKind.SPAN:
raise InvalidQueryException("grouped table must be span table")
elif mode == "event":
if grouped_schema.kind != TableKind.EVENT:
raise InvalidQueryException("grouped table must be event table")
else:
raise AssertionError("invalid mode: " + mode)
assert grouper_schema.kind == TableKind.SPAN
assert grouper_schema.sorting == TableSorting.TIME_MAJOR
if grouper_schema.partition and grouped_schema.partition:
if grouper_schema.partition != grouped_schema.partition:
raise InvalidQueryException(
"partition name mismatch in span group: {!r} vs {!r}".format(
grouper_schema.partition, grouped_schema.partition))
if not grouper_schema.partition and grouped_schema.partition:
table_schema = TableSchema(TableKind.SPAN,
grouped_schema.partition,
TableSorting.NONE)
else:
table_schema = grouper_schema.evolve(sorting=TableSorting.NONE)
if not table_schema.partition:
table_schema = table_schema.evolve(sorting=TableSorting.TIME_MAJOR)
return MagicGroupingBaseSource(source, table_schema)
@override
def make_query_maker(self, ctx, query_maker, _out_table_schema):
if not isinstance(query_maker, MagicGroupingBaseSource):
raise InvalidQueryException(
"Unsupported use of GROUP USING SPANS")
source = query_maker.base
def _q1(column_name):
expr = ColumnReference(column_name)
expr.self_resolve(ctx, source)
return expr.to_query(ctx, query_maker)
grouper_qt = self.__get_grouper_qt(ctx)
grouper_schema = grouper_qt.table_schema
grouped_schema = source.get_table_schema(ctx)
# Checked during override_base_source, so we can just assert here.
assert grouped_schema.kind in (TableKind.SPAN, TableKind.EVENT)
assert grouper_schema.kind == TableKind.SPAN
assert grouper_schema.sorting == TableSorting.TIME_MAJOR
partition_q = (
_q1(grouped_schema.partition)
if grouped_schema.partition else None)
grouper_partition_q = (
grouper_qt[grouper_qt.table_schema.partition]
if grouper_schema.partition else None)
unique_pvals_q = (
DropDuplicatesQuery.of1(partition_q)
if (partition_q and
not grouper_partition_q and
not self.intersect)
else None)
grouped = (SpanTableConfig(_q1("_ts"), _q1("_duration"), partition_q)
if grouped_schema.kind == TableKind.SPAN
else EventTableConfig(_q1("_ts"), partition_q))
config = SpanGroup(grouped=grouped,
grouper=SpanTableConfig.from_qt(grouper_qt),
unique_pvals=unique_pvals_q,
intersect=self.intersect)
assert bool(config.is_output_partitioned) == bool(
query_maker.get_table_schema(ctx).partition), \
"{!r} vs {!r}".format(config.is_output_partitioned,
query_maker.get_table_schema(ctx).partition)
return GroupBySpanQueryMaker(query_maker, config)
@final
class RegularSelect(SelectCore):
"""A normal select as a compound operand"""
distinct = iattr(str, default="all")
columns = tattr(Column, default=())
from_ = iattr(TableExpression, nullable=True, default=None)
where = iattr(Expression, nullable=True, default=None)
gb = iattr(GroupBy, nullable=True, default=None)
having = iattr(Expression, nullable=True, default=None)
kind = iattr(str, nullable=True, default=None)
repartition_by = iattr(str, nullable=True, default=None)
def __make_column_expression_map(self, ctx, source, table_schema):
column_expressions = OrderedDict()
disambig = defaultdict(partial(xcount, 1))
for column in self.columns:
for name, expression in \
column.get_column_specs(ctx, source, self.kind, table_schema):
if name == table_schema.partition:
raise InvalidQueryException(
"column {!r} is the partition and must not "
"be explicitly specified".format(name))
if name in table_schema.meta_columns:
raise InvalidQueryException(
"column {!r} is reserved".format(name))
while name in column_expressions:
name = "{}:{}".format(name, next(disambig[name]))
assert isinstance(expression, Expression), \
"expression {!r} should be expression".format(expression)
column_expressions[name] = expression
return column_expressions
def __make_qt_1(self, ctx, ob_asts):
# pylint: disable=redefined-variable-type
distinct = self.distinct
if distinct == "all":
distinct = None
where = self.where
having = self.having
gb = self.gb
# Resolve column references to specific concrete source
# TableExpression instances, with lexical scoping.
source = self.from_ or DummyTableExpression()
if gb:
source = gb.override_base_source(ctx, source)
# Enforced by the parser.
assert self.repartition_by is None or self.kind == "event"
source.self_resolve_te(ctx)
source_table_schema = source.get_table_schema(ctx)
if self.kind == "span":
if source_table_schema.kind != TableKind.SPAN:
raise InvalidQueryException(
"SELECT SPAN of a non-span table {}"
.format(source_table_schema))
out_table_schema = source_table_schema
elif self.kind == "event":
if source_table_schema.kind != TableKind.EVENT:
raise InvalidQueryException(
"SELECT EVENT of non-event table {}"
.format(source_table_schema))
out_table_schema = source_table_schema
if self.repartition_by is not None:
# TODO(dancol): check if we're actually already sorted
# according to the new partition and use a less conservative
# output sorting than TableSorting.NONE if so.
out_table_schema = out_table_schema.evolve(
partition=self.repartition_by or None,
sorting=TableSorting.NONE)
else:
assert self.kind is None
out_table_schema = REGULAR_TABLE
column_exprs = self.__make_column_expression_map(
ctx, source, out_table_schema)
if out_table_schema.meta_columns:
assert not set(out_table_schema.meta_columns) & set(column_exprs)
# Unconditionally prepend the span metadata columns to the list
# of output columns. Iteration is reversed because we prepend at
# each step.
for meta_column in out_table_schema.meta_columns[::-1]:
column_exprs[meta_column] = ColumnReference(meta_column)
column_exprs.move_to_end(meta_column, last=False)
for expr in column_exprs.values():
expr.self_resolve(ctx, source)
# TODO(dancol): remove this check now that we have reordering for
# special tables.
if ob_asts and out_table_schema.kind != TableKind.REGULAR:
raise InvalidQueryException(
"SELECT {} incompatible with ORDER BY: "
"{} tables are specially ordered"
.format(self.kind.upper(), self.kind.lower()))
if where or having or gb or ob_asts:
# The WHERE and HAVING and GROUP BY and ORDER BY expressions can
# refer to the column expressions, and before we munge things,
# these column expressions appear to be table references in the
# expression's AST. Substitute them with the resolved column
# AST. The substitution isn't recursive, so embedded column
# references with names matching aliases won't blow the stack.
# The new expressions share structure with any expressions in
# the column ASTs, so any column references we resolved above
# are still resolved.
explicit_aliases = {
column.name: column.expr # pylint: disable=no-member
for column in self.columns
if isinstance(column, ExpressionColumn) and column.name # pylint: disable=no-member
}
def _subst_column_references(ast_node):
if isinstance(ast_node, ColumnReference) and not ast_node.path:
ast_node = explicit_aliases.get(ast_node.column, ast_node)
return ast_node
if where:
assert not ctx.prohibit_window_functions
ctx.prohibit_window_functions = "WHERE clause"
where = where.simplify(_subst_column_references)
where.self_resolve(ctx, source)
del ctx.prohibit_window_functions
if having:
having = having.simplify(_subst_column_references)
having.self_resolve(ctx, source)
if gb:
gb.simplify_expressions(ctx, source, _subst_column_references)
if ob_asts:
assert isinstance(ob_asts, tuple)
ob_asts = tuple(ob_ast.simplify(_subst_column_references)
for ob_ast in ob_asts)
for ob_ast in ob_asts:
ob_ast.self_resolve(ctx, source)
# We used to combine HAVING and WHERE here, but the set of
# conditions under which we couldn't do that became unwieldy, so
# now we just always keep HAVING and WHERE separate.
conjunctions = where.flatten_conjunctions() if where else ()
top_level_conjunctions = source.push_conditions(ctx, conjunctions)
query_maker = source
if top_level_conjunctions:
query_maker = FilteringExpressionQueryMaker(
source, _make_conjunction_filter(
ctx,
query_maker,
top_level_conjunctions))
if gb:
query_maker = gb.make_query_maker(ctx, query_maker, out_table_schema)
elif ctx.aggregate_functions_used:
if out_table_schema.kind != TableKind.REGULAR:
raise InvalidQueryException(
"SELECT {} incompatible with aggregation functions".format(
TableKind.label_of(out_table_schema.kind)))
query_maker = AggregatingQueryMaker(query_maker)
# Column source configuration complete: actually generate the
# QueryNode instances for our columns.
columns = [(name, expr.to_query(ctx, query_maker))
for name, expr in column_exprs.items()]
ob_queries = [expr.to_sort_query(ctx, query_maker)
for expr in ob_asts]
# Various kinds of postprocessing.
if having:
q_having_filter = having.to_query(ctx, query_maker)
columns = [(name, query.where(q_having_filter))
for name, query in columns]
ob_queries = [query.where(q_having_filter)
for query in ob_queries]
if distinct:
ob_queries = _distinctify_ob_queries(columns, ob_queries)
columns = _distinctify_column_list(columns)
if ctx.regenerate_column_names:
# If we have column-expression subqueries, we learn the right
# auto-generated name only after we do all the above work, so
# re-run the column name generation algorithm to pick up the
# new, better name.
new_column_names = tuple(self.__make_column_expression_map(
ctx, source, out_table_schema).keys())
for i in range(len(out_table_schema.meta_columns),
len(column_exprs)):
columns[i] = new_column_names[i], columns[i][1]
return (GenericQueryTable(columns, table_schema=out_table_schema),
ob_queries)
@override
def make_qt(self, tctx, ob_asts):
ctx = QueryCompilationContext(tctx)
# pylint: disable=protected-access
return self.simplify().__make_qt_1(ctx, ob_asts)
@final
class OrderingTerm(AstNode):
"""Part of an ORDER BY clause"""
expr = iattr(Expression)
direction = iattr(str, default="asc")
def is_ascending(self):
"""Return whether the ordering is for ascending mode"""
assert self.direction in ("asc", "desc")
return self.direction == "asc"
def _filter_to_limit_offset(tctx, base_qt, limit, offset):
ctx = QueryCompilationContext(tctx)
source = DummyTableExpression()
def _resolve_limit_expr(expr):
if not expr:
return None
return (expr
.simplify()
.self_resolve(ctx, source)
.to_query(ctx, source))
limit_q = _resolve_limit_expr(limit)
offset_q = _resolve_limit_expr(offset)
return base_qt.transform(
lambda q: q.limit_offset(limit=limit_q, offset=offset_q))
class DmlContext(ExplicitInheritance):
"""Environment for running queries"""
@override
def __init__(self, ns=None, qe=None):
"""Make a new DMLContext
NS is the namespace to use as the DML "user" namespace, which
namespace-modifying commands affect. If None, create a new
detached namespace.
If QE is supplied, it should be a QueryEngine that allows DML
commands to execute queries as part of their operation, mostly for
IF conditionalizing.
"""
assert not ns or isinstance(ns, Namespace)
self.user_ns = ns or Namespace()
self.qe = qe
def make_tctx(self, args=NO_ARGS): # pylint: disable=dangerous-default-value
"""Make a compilation environment for the normal user NS"""
return TvfContext.from_ns(self.user_ns, args)
@final
def execute(self, commands):
"""Convenience function to execute a batch of DML queries"""
for command in commands:
assert isinstance(command, DmlAction)
command.execute_dml(self)
def mount_trace(self, mount_path, trace_file_name):
"""Mount a trace file in the user namespace"""
raise NotImplementedError("loading traces not supported "
"in basic DmlContext")
class DmlAction(AstNode):
"""AST node corresponding to an action to perform"""
@abstract
def execute_dml(self, dmlctx):
"""Do a thing"""
raise NotImplementedError("abstract")
@abstract
def eval_for_autocomplete(self, dmlctx):
"""Do enough of a thing to fire autocomplete hooks"""
pass # Call into super!
class Select(DmlAction):
"""Query-producing AST node"""
@final
@override
def execute_dml(self, dmlctx):
# If we want to SELECT for real, we'll need to store the
# QueryNode-level QueryContext in the dmlctx. Here, we just make
# the query table for the select, which has the effect of at least
# checking syntax.
return self.make_qt(dmlctx.make_tctx())
@abstract
def make_qt(self, tctx):
"""Make a QueryTable providing results of this query"""
raise NotImplementedError("abstract")
@final
@override
def eval_for_autocomplete(self, dmlctx):
super().eval_for_autocomplete(dmlctx)
self.execute_dml(dmlctx)
@final
class SelectDirect(Select):
"""A SELECT without any lexical bindings"""
core = iattr(SelectCore)
ob = tattr(OrderingTerm, default=())
limit = iattr(Expression, nullable=True, default=None)
offset = iattr(Expression, nullable=True, default=None)
def __check_no_substructure_sharing(self):
# Check that the AST doesn't share substructure. Doing so would
# be super bad, since the SQL-to-QueryTable code uses object
# identity to index AST node attributes.
seen = IdentityDictionary()
for node in self.walk():
assert node not in seen, "should not share structure"
seen[node] = True
@override
def make_qt(self, tctx):
if __debug__:
self.__check_no_substructure_sharing()
core = self.core
ob_asts = tuple(ob.expr for ob in self.ob)
ob_asc = tuple(ob.is_ascending() for ob in self.ob)
base_qt, ob_queries = core.make_qt(tctx, ob_asts)
assert len(ob_queries) == len(ob_asts)
if ob_queries:
indexer = ArgSortQuery(zip(ob_queries, ob_asc))
columns = [(column, query.take(indexer))
for column, query in base_qt.items()]
base_qt = GenericQueryTable(columns)
if self.limit or self.offset:
return _filter_to_limit_offset(
tctx, base_qt, self.limit, self.offset)
return base_qt
@final
class CteBindingName(AstNode):
"""Name of a CTE binding with possible column renaming"""
name = iattr(str)
renaming = tattr(str, default=())
do_rename = iattr(bool, default=False)
@final
class CteBinding(NoWalkAstNode):
"""One lexical binding in a CTE's binding list"""
name = iattr(CteBindingName)
te = iattr(CteBindingValue)
@final
class SelectWithCte(Select):
"""A SELECT with a locally-bound common table expression"""
bindings = tattr(CteBinding)
body = iattr(Select)
@override
def make_qt(self, tctx):
for binding in self.bindings:
table = binding.te.make_cte_value(tctx, binding.name)
tctx = tctx.let({binding.name.name: table})
return self.body.make_qt(tctx)
class CreateDmlAction(DmlAction):
"""DML action that creates something"""
overwrite = iattr(bool, default=False, kwonly=True)
condition = iattr((FunExpr, NoneType), default=None, kwonly=True)
@abstract
def _do_create(self, dmltctx):
raise NotImplementedError("abstract")
def __should_execute(self, dmlctx):
if not self.condition:
return True
# TODO(dancol): skip the query engine invocation if TVF evaluation
# produces an obviously-true or obviously-false value.
q = QueryNode.coerce_(self.condition.evaluate_tvf(dmlctx.make_tctx()))
qe = dmlctx.qe
if not qe:
raise InvalidQueryException("cannot evaluate view conditional here")
(_res_q, res), = qe.execute_for_columns({q})
return len(res) and bool(res[0])
@override
@final
def execute_dml(self, dmlctx):
if self.__should_execute(dmlctx):
self._do_create(dmlctx)
@override
def eval_for_autocomplete(self, dmlctx):
super().eval_for_autocomplete(dmlctx)
if self.condition:
self.__should_execute(dmlctx) # For side effect
@final
class CreateView(CreateDmlAction):
"""View-creation statement"""
path = tattr(str, nonempty=True)
select = iattr((Select, FunExpr))
as_function = iattr(bool, default=False, kwonly=True)
documentation = iattr(str, nullable=True, kwonly=True, default=None)
@override
def _do_create(self, dmlctx):
if self.as_function:
if isinstance(self.select, FunExpr):
raise InvalidQueryException(
"view values must be table in function mode")
value = TableValuedFunction.from_select_ast(
self.select, dmlctx.user_ns)
else:
ns_wr = weakref.ref(dmlctx.user_ns)
def _make_table():
ns = ns_wr()
assert ns
tctx = TvfContext.from_ns(ns)
return (self.select.make_qt(tctx)
if isinstance(self.select, Select)
else self.select.evaluate_tvf(tctx))
value = LazyNsEntry(_make_table)
if self.documentation is not None:
_documentation[value] = self.documentation
dmlctx.user_ns.assign_by_path(self.path,
value,
overwrite=self.overwrite)
@override
def eval_for_autocomplete(self, dmlctx):
super().eval_for_autocomplete(dmlctx)
if not self.as_function and isinstance(self.select, Select):
# TODO(dancol): support autocomplete in the TVF case
self.select.make_qt(TvfContext.from_ns(dmlctx.user_ns))
@final
class Drop(DmlAction):
"""Drop an entire namespace prefix"""
path = tattr(str)
ignore_absent = iattr(bool, default=False)
@override
def execute_dml(self, dmlctx):
try:
dmlctx.user_ns.delete_by_path(self.path)
except KeyError:
if not self.ignore_absent:
raise UnboundReferenceException(
"could not find " + path2str(self.path)) from None
@override
def eval_for_autocomplete(self, dmlctx):
super().eval_for_autocomplete(dmlctx)
# TODO(dancol): implement autocomplete
raise NotImplementedError
@final
class MountTrace(DmlAction):
"""Mount a trace at a location in the query namespace"""
trace_file_name = iattr(str)
mount_path = tattr(str, nonempty=True)
@override
def execute_dml(self, dmlctx):
dmlctx.mount_trace(self.mount_path, self.trace_file_name)
@override
def eval_for_autocomplete(self, dmlctx):
super().eval_for_autocomplete(dmlctx)
# TODO(dancol): implement autocomplete
raise NotImplementedError
@final
class CreateNamespace(CreateDmlAction):
"""Create a SQL namespace"""
path = tattr(str)
@override
def _do_create(self, dmlctx):
dmlctx.user_ns.assign_by_path(self.path,
Namespace(),
overwrite=self.overwrite)
@override
def eval_for_autocomplete(self, dmlctx):
super().eval_for_autocomplete(dmlctx)
# TODO(dancol): implement autocomplete
raise NotImplementedError
# Aggregation and non-aggregation functions
@final
class AggregationFunction(Immutable):
"""Aggregation function: implemented in native core"""
name = iattr(str)
@override
def _post_init_assert(self):
super()._post_init_assert()
assert self.name in WELL_KNOWN_AGGREGATIONS
class NormalFunction(Immutable):
"""Non-aggregate SQL function: lives in function namespace"""
def hook_self_resolve(self, _ctx, _funcall, _resolver): # pylint: disable=no-self-use
"""Override column resolution hook for FunctionCall.
Return whether we actually override anything.
"""
return False
@abstract
def make_query(self, ctx, te, arguments):
"""Make a QueryNode for this function call"""
raise NotImplementedError("abstract")
@final
class FnNormalFunction(NormalFunction):
"""NormalFunction implemented with a call to a normal Python function
All arguments are evaluated into QueryNode for the call to the
underlying function.
"""
fn = iattr()
# Hack: keyword arguments are evaluated in non-table context.
# TODO(dancol): provide the ability to easily and precisely
# specify a function evaluation schema.
@override
def hook_self_resolve(self, ctx, funcall, resolver):
assert isinstance(funcall, FunctionCall)
for argument in funcall.arguments:
if isinstance(argument, KeywordArgument):
argument.expr.self_resolve(ctx, DummyTableExpression())
else:
argument.self_resolve(ctx, resolver)
return True
@override
def make_query(self, ctx, te, arguments):
args = []
kwargs = {}
for arg_node in arguments:
if isinstance(arg_node, KeywordArgument):
kwargs[arg_node.name] = arg_node.expr.to_query(
ctx, DummyTableExpression())
else:
if kwargs:
raise InvalidQueryException(
"non-keyword arguments after keyword arguments")
args.append(arg_node.to_query(ctx, te))
# pylint: disable=not-callable
return self.fn(*args, **kwargs)
# TVF utilities
def add_sql_documentation(thing, doc):
"""Add SQL documentation to THING
DOC is a string. THING is any object we can stick in a table
namespace; if it's a lazy-evaluation object, the documentation is
automatically propagated to the inflated object that the lazy
evaluation produces.
"""
_documentation[thing] = the(str, doc)
def get_sql_documentation(thing):
"""Get documentation for THING or None if no documentation exists"""
doc = _documentation.get(thing)
if doc is not None:
doc = dedent(doc.strip("\n"))
return doc
def qargs(*args, **kwargs):
"""Make an args object suitable for resolving bind arguments"""
for i, arg in enumerate(args):
kwargs[i] = arg
return kwargs
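# Example: qargs(7, dur_ns=8) == {0: 7, "dur_ns": 8}. Positional
# arguments get integer keys so that (presumably) ?0-style and
# :name-style bind parameters, as used by the built-in TVFs below,
# resolve from a single mapping.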
def _parse_select(s):
# Avoid recursive import detection in pylint
fn = sys.modules["dctv.sql_parser"].parse_select
global _parse_select # pylint: disable=global-variable-undefined
_parse_select = fn
return fn(s)
@final
class LazyNsEntry(Immutable):
"""Resolve a blob of SQL to a QueryTable lazily"""
resolver = iattr()
@override
def _post_init_check(self):
super()._post_init_check()
assert callable(self.resolver)
doc = getattr(self.resolver, "__doc__", None)
if doc is not None:
add_sql_documentation(self, doc)
# Built-in TVFs
@once()
def _make_generate_sequential_spans_tvf():
ast = _parse_select("""
SELECT
CAST((ROW_NUMBER() * :duration_ns + :start_ns) AS UNIT ns) AS _ts,
CAST(:duration_ns AS UNIT ns) AS _duration
FROM dctv.filled(((:end_ns - :start_ns) // :duration_ns))
""")
return TableValuedFunction.from_select_ast(ast)
def _generate_sequential_spans_tvf(start, end, span_duration):
"""Generate regular sequential spans
(START, END, SPAN_DURATION)
START is the time (coerced to nanoseconds) of the first span in the
sequence of generated spans. END is the time (coerced to
nanoseconds) of the end of the last span in the sequence.
SPAN_DURATION is the duration of each span in the sequence.
Each generated span in the sequence immediately follows the previous
span. If the last span in the sequence would have a duration
shorter than SPAN_DURATION to end at END, that span is omitted from
the sequence.
"""
def _to_ns(value):
return QueryNode.coerce_(value).to_unit(ureg().ns).strip_unit()
return _internal_cast_as_span_table(
_make_generate_sequential_spans_tvf()(
start_ns=_to_ns(start),
end_ns=_to_ns(end),
duration_ns=_to_ns(span_duration)),
verify=False,
sort=False)
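# Worked example (sketch): _generate_sequential_spans_tvf(0, 10, 3),
# all values in ns, yields (10 - 0) // 3 == 3 rows: spans at _ts 0,
# 3, and 6, each 3 ns long; the partial span [9, 10) is omitted per
# the docstring. This assumes ROW_NUMBER() is zero-based, which the
# AS _ts expression above relies on.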
@once()
def _with_all_partitions_tvf_implicit():
ast = _parse_select("""
SELECT SPAN *
FROM
(SELECT SPAN
FROM ?0
GROUP AND INTERSECT SPANS USING PARTITIONS)
SPAN BROADCAST INTO SPAN PARTITIONS
?0
""")
return TableValuedFunction.from_select_ast(ast)
def _with_all_partitions_tvf(qt):
"""Filter a span table so that all partitions are present
QT is a partitioned span table. The output of this function is a
partitioned span table that contains no spans except for those
regions of the timeline during which _all_ partitions in QT are
present. For example, if QT is partitioned by CPU, the output of
this routine is a set of spans during which all CPUs are present, e.g.,
when none is hotplugged off.
"""
return _with_all_partitions_tvf_implicit()(qt)
def _span_starts_tvf(span_table):
"""Convert span starts to event table
The output of this function is an event table (partitioned like the
input span table) in which each span start is an event.
This function is useful when reinterpreting one kind of span as
another.
"""
qt = QueryTable.coerce_(span_table)
if qt.table_schema.kind != TableKind.SPAN:
raise InvalidQueryException("not a span table: {}".format(qt))
return GenericQueryTable(
{column: qt[column]
for column in qt.columns if column != "_duration"},
table_schema=qt.table_schema.evolve(kind=TableKind.EVENT))
def _span_ends_tvf(span_table):
"""Convert span ends to event table
The output of this function is an event table (partitioned like the
input span table) in which each span end is an event. This function
is useful when reinterpreting one kind of span as another.
"""
qt = QueryTable.coerce_(span_table)
if qt.table_schema.kind != TableKind.SPAN:
raise InvalidQueryException("not a span table: {}".format(qt))
return GenericQueryTable(
{column: (qt[column] + qt["_duration"]
if column == "_ts" else qt[column])
for column in qt.columns if column != "_duration"},
table_schema=qt.table_schema.evolve(kind=TableKind.EVENT))
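# Example: a span row with _ts == 5 and _duration == 3 becomes an
# event at _ts == 5 under span_starts and at _ts == 8 under
# span_ends; non-duration payload columns pass through unchanged.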
@final
class LagLeadFunction(NormalFunction):
"""Evalutes arguments to LAG and LEAD in correct context"""
fn = iattr()
@override
def hook_self_resolve(self, ctx, funcall, resolver):
if ctx.prohibit_window_functions:
raise InvalidQueryException(
"window function not allowed in {}".format(
ctx.prohibit_window_functions))
assert isinstance(funcall, FunctionCall)
arguments = funcall.arguments
if not 1 <= len(arguments) <= 3:
raise InvalidQueryException(
"LAG/LEAD expect between one and three arguments: got {}"
.format(len(arguments)))
for argno, argument in enumerate(arguments):
if not isinstance(argument, Expression):
raise InvalidQueryException("LAG/LEAD expect no keyword arguments")
# Make sure that the first argument of lag/lead is resolved in
# root lexical context, not the query's.
argument.self_resolve(ctx,
resolver if not argno
else DummyTableExpression())
return True
@override
def make_query(self, ctx, te, arguments):
"""Make a QueryNode for this function call"""
# We know self.fn is callable. pylint: disable=not-callable
dummy_te = DummyTableExpression()
return self.fn(arguments[0].to_query(ctx, te),
*[argument.to_query(ctx, dummy_te)
for argument in arguments[1:]])
# TODO(dancol): make LAG and LEAD more efficient. The way we do it
# here requires full copies of the inputs. We need to either make
# dedicated LAG/LEAD operators or teach the optimizer about
# count thresholds.
def _lead(query, offset, default=None):
offset_q = QueryNode.coerce_(offset)
return QueryNode.concat(
query.slice(offset_q, None),
QueryNode.filled(QueryNode.coerce_(default),
offset_q +
QueryNode.least(query.countq() - offset_q, 0)))
def _lag(query, offset=1, default=None):
offset_q = QueryNode.coerce_(offset)
return QueryNode.concat(
QueryNode.filled(QueryNode.coerce_(default),
QueryNode.least(query.countq(), offset_q)),
query.slice(0, -offset_q))
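# Shift semantics sketch: for a query yielding [10, 20, 30] and
# offset == 1, _lead produces [20, 30, default] (slice off the
# front, pad min(offset, count) defaults at the back) and _lag
# produces [default, 10, 20] (pad the front, drop the tail).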
def extend_spans_tvf(qt, addl):
"""Grow each span in a span table
(QT, ADDL)
QT is a span table. ADDL is an amount of time (coerced to
nanoseconds) by which to extend each span in the span table. If, as
a result of this extension, spans would illegally overlap, the spans
are instead merged.
The resulting span table has no payload. To get the payload of each
extended span, SPAN GROUP with the original table.
"""
qt = qt.to_schema(kind=TableKind.SPAN,
sorting=TableSorting.TIME_MAJOR)
ts_q = qt["_ts"]
duration_q = qt["_duration"]
if not qt.table_schema.partition:
partition_q = None
else:
partition_q = qt[qt.table_schema.partition]
fixup = SpanFixup(ts_q, duration_q + addl, partition_q)
columns = [("_ts", fixup.metaq(SpanFixup.Meta.TS)),
("_duration", fixup.metaq(SpanFixup.Meta.DURATION))]
if partition_q:
columns.append(
(qt.table_schema.partition,
fixup.metaq(SpanFixup.Meta.PARTITION)))
table_schema = TableSchema(TableKind.SPAN,
qt.table_schema.partition,
TableSorting.NONE)
else:
table_schema = SPAN_UNPARTITIONED_TIME_MAJOR
return GenericQueryTable(columns, table_schema=table_schema)
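# Worked example (sketch): spans [0, 3) and [4, 7) extended by
# addl == 2 become [0, 5) and [4, 9), which would overlap, so they
# merge into a single [0, 9) span per the docstring; the output
# carries only _ts, _duration, and any partition column.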
def _internal_cast_as_span_table(tbl,
*,
partition=None,
sorting="time_major",
sort=True,
verify=True):
"""Treat a regular table as a span table"""
tbl = QueryTable.coerce_(tbl)
# We just throw here because we haven't implemented these features,
# but we at least make callers explicitly opt out of the planned
# safety checks.
if verify:
# TODO(dancol): implement span verification
raise NotImplementedError("Span verification not implemented")
if sort:
# TODO(dancol): implement sort
raise NotImplementedError("Sorting not implemented")
if tbl.table_schema.kind == TableKind.SPAN:
return tbl
if tbl.table_schema.kind != TableKind.REGULAR:
raise InvalidQueryException("not a regular table: {}".format(tbl))
if "_ts" not in tbl:
raise InvalidQueryException("table missing timestamp column")
ts_q = tbl["_ts"]
if ts_q.schema != TS_SCHEMA:
raise InvalidQueryException("invalid timestamp schema: {}"
.format(ts_q.schema))
if "_duration" not in tbl:
raise InvalidQueryException("table missing duration column")
duration_q = tbl["_duration"]
if duration_q.schema != DURATION_SCHEMA:
raise InvalidQueryException("invalid duration schema: {}"
.format(duration_q.schema))
if partition:
if partition not in tbl:
raise InvalidQueryException(
"partition {!r} not in table".format(partition))
table_schema = TableSchema(TableKind.SPAN,
partition,
TableSorting(sorting))
# TODO(dancol): we want to insert some kind of constraint-checking
# (i.e., ensuring that the output actually looks like a span table)
# in debug builds instead of just forwarding the queries blindly!
return GenericQueryTable(
((column, tbl[column]) for column in tbl.columns),
table_schema=table_schema)
def _internal_cast_as_event_table(tbl,
*,
partition,
sorting,
sort,
verify):
"""Treat a regular table as an event table"""
tbl = QueryTable.coerce_(tbl)
# We just throw here because we haven't implemented these features,
# but we at least make callers explicitly opt out of the planned
# safety checks.
if verify:
# TODO(dancol): implement event verification
raise NotImplementedError("verify not implemented")
if sort:
# TODO(dancol): implement sort
raise NotImplementedError("sort not implemented")
if tbl.table_schema.kind == TableKind.EVENT:
return tbl
if tbl.table_schema.kind != TableKind.REGULAR:
raise InvalidQueryException("not a regular table: {}".format(tbl))
if "_ts" not in tbl:
raise InvalidQueryException("table missing timestamp column")
ts_q = tbl["_ts"]
if ts_q.schema != TS_SCHEMA:
raise InvalidQueryException("invalid timestamp column: {}"
.format(ts_q.schema))
if partition:
if partition not in tbl:
raise InvalidQueryException(
"partition {!r} not in table".format(partition))
table_schema = TableSchema(TableKind.EVENT,
partition,
TableSorting(sorting))
return GenericQueryTable(
((column, tbl[column]) for column in tbl.columns),
table_schema=table_schema)
class RowNumberFunction(NormalFunction):
"""Implement SQL's ROW_NUMBER() function"""
@override
def make_query(self, ctx, te, arguments):
if arguments:
raise InvalidQueryException("ROW_NUMBER takes no arguments")
return QueryNode.sequential(te.make_count_query(ctx))
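# Note: QueryNode.sequential appears to number rows from zero: the
# generate_sequential_spans TVF above multiplies ROW_NUMBER() by
# the span duration and expects the first span to land exactly at
# :start_ns.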
@final
class ReplaceSchemaQuery(SimpleQueryNode):
"""Pass through data while advertising a different schema"""
_query = iattr(QueryNode, name="query")
_schema = iattr(QuerySchema, name="schema")
@override
def countq(self):
return self._query.countq()
@override
def _compute_inputs(self):
return (self._query,)
@override
def _compute_schema(self):
return self._schema
@override
async def run_async(self, qe):
[ic], [oc] = await qe.async_setup([self._query], [self])
await passthrough(qe, ic, oc)
def _assert_unique(query, *, verify=True):
# TODO(dancol): add runtime verification that the constraint holds?
if isinstance(verify, QueryNode):
verify = verify.eager_evaluate_scalar()
if isinstance(verify, QueryNode):
raise InvalidQueryException(
"verify argument to assert_unique must be constant")
if verify:
raise NotImplementedError("ASSERT_UNIQUE VERIFY")
if UNIQUE not in query.schema.constraints:
query = ReplaceSchemaQuery(query, query.schema.constrain(C_UNIQUE))
return query
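# Usage note: verification is unimplemented, so SQL callers must
# explicitly pass a falsy verify argument; with the default of
# verify=True the call raises NotImplementedError.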
def backfill_tvf(qt):
"""Auto-fill NULL values in a span table
For each group of contiguous values in a partition, fill all NULL
values in each column with the first non-NULL value in that column
or NULL if no contiguous group of values has a non-NULL value.
This function is useful in cases where data sources may not line up
exactly and we want to "go back in time" and fill in some initial
region of unknown data with whatever we discover the data to be
later.
"""
qt = (QueryTable
.coerce_(qt)
.to_schema(sorting=TableSorting.PARTITION_MAJOR))
meta = SpanTableConfig.from_qt(qt)
return qt.transform(
lambda q: Backfill(meta, q).metaq(Backfill.Meta.VALUE))
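# Worked example (sketch): within one contiguous group of a
# partition, a column reading [NULL, NULL, 3, NULL, 7] backfills to
# [3, 3, 3, 7, 7]; a group with no non-NULL value stays NULL.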
def _get_thread_analysis():
"""Thread analysis object"""
from .thread_analysis import ThreadAnalysis
return ThreadAnalysis
def _identity_tvf(arg):
"""Return argument unchanged
This function is occasionally useful for forcing TVF evaluation.
"""
return arg
@once()
def make_standard_lexical_environment():
"""Bootstrap the environment that's always available to users"""
tvf = TableValuedFunction
ns = Namespace()
ns["dctv"] = ns_dctv = Namespace()
ns_dctv["internal"] = ns_internal = Namespace()
ns_internal.disable_autocomplete = True
things = (
(("dict",), dict),
(("list",), lambda *args: list(args)),
(("dctv", "time_series_to_spans"), tvf(TimeSeriesQueryTable)),
(("dctv", "thread_analysis"), LazyNsEntry(_get_thread_analysis)),
(("dctv", "generate_sequential_spans"),
tvf(_generate_sequential_spans_tvf)),
(("dctv", "span_starts"), tvf(_span_starts_tvf)),
(("dctv", "span_ends"), tvf(_span_ends_tvf)),
(("dctv", "filled"), tvf(FilledQueryTable)),
(("dctv", "internal", "cast_as_span_table"),
tvf(_internal_cast_as_span_table)),
(("dctv", "internal", "cast_as_event_table"),
tvf(_internal_cast_as_event_table)),
(("dctv", "with_all_partitions"), tvf(_with_all_partitions_tvf)),
(("dctv", "identity"), tvf(_identity_tvf)),
(("dctv", "extend_spans"), tvf(extend_spans_tvf)),
(("dctv", "backfill"), tvf(backfill_tvf)),
)
for path, value in things:
ns.assign_by_path(path, value)
# _unfn is a hack for implementing non-aggregate SQL functions in
# terms of syntactically-inaccessible unary operator invocations.
def _unfn(op):
return FnNormalFunction(partial(UnaryOperationQuery, op))
all_aggfuncs = [
AggregationFunction(aggfunc)
for aggfunc in WELL_KNOWN_AGGREGATIONS
]
columnwise_things = chain(
[(("dctv", fn.name,), fn) for fn in all_aggfuncs],
[((fn.name,), fn) for fn in all_aggfuncs
if fn.name in ("count", "max", "min", "prod", "sum")],
(
(("floor",), _unfn("floor")),
(("ceil",), _unfn("ceiling")),
(("ceiling",), _unfn("ceiling")),
(("round",), _unfn("round")),
(("trunc",), _unfn("trunc")),
(("coalesce",), FnNormalFunction(CoalesceQuery.of)),
(("lag",), LagLeadFunction(_lag)),
(("lead",), LagLeadFunction(_lead)),
(("greatest",), FnNormalFunction(QueryNode.greatest)),
(("least",), FnNormalFunction(partial(QueryNode.least))),
(("row_number",), RowNumberFunction()),
(("if",), FnNormalFunction(QueryNode.choose)),
(("assert_unique",), FnNormalFunction(_assert_unique)),
(("fls",), FnNormalFunction(FlsQuery.of)),
))
for path, value in columnwise_things:
ns.assign_by_path(_munge_columnwise_path(path), value)
return LexicalEnvironment(None, ns)