| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| #      http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """SQL front-end to query.py""" |
| |
| # TODO(dancol): we should support partitioning by multiple columns; |
| # right now, it's zero or one. |
| |
| import logging |
| import operator as oper |
| import sys |
| import weakref |
| from io import StringIO |
| from collections import OrderedDict, defaultdict |
| from collections.abc import Mapping |
| from functools import partial, reduce |
| from itertools import chain, count as xcount, starmap |
| from textwrap import dedent |
| |
| from cytoolz import ( |
| first, |
| second, |
| ) |
| |
| from modernmp.util import ( |
| assert_seq_type, |
| once, |
| the, |
| ) |
| from .util import ( |
| BOOL, |
| CaseInsensitiveCasePreservingDict, |
| EqImmutable, |
| ExplicitInheritance, |
| Immutable, |
| NoneType, |
| UsageTrackingDictionary, |
| abstract, |
| all_same, |
| cached_property, |
| final, |
| iattr, |
| override, |
| tattr, |
| ureg, |
| ) |
| from .iddict import ( |
| IdentityDictionary, |
| WeakKeyIdentityDictionary, |
| ) |
| |
| from .query import ( |
| ArgSortQuery, |
| BadUnitsException, |
| BinaryOperationQuery, |
| C_UNIQUE, |
| CoalesceQuery, |
| CollateQuery, |
| DURATION_SCHEMA, |
| FlsQuery, |
| GenericQueryTable, |
| InvalidQueryException, |
| QueryNode, |
| QuerySchema, |
| QueryTable, |
| REGULAR_TABLE, |
| SPAN_UNPARTITIONED_TIME_MAJOR, |
| SimpleQueryNode, |
| SqlAttributeLookup, |
| TS_SCHEMA, |
| TableKind, |
| TableSchema, |
| TableSorting, |
| UNIQUE, |
| UnaryOperationQuery, |
| iattr_query_node_int, |
| passthrough, |
| ) |
| |
| from .queryop import ( |
| Backfill, |
| DropDuplicatesIndexerQuery, |
| DropDuplicatesQuery, |
| EventJoin, |
| EventTableConfig, |
| GroupCountQuery, |
| GroupLabelsQuery, |
| GroupSizesQuery, |
| JoinMeta, |
| JoinMetaQuery, |
| NativeGroupedAggregationQuery, |
| NativeUngroupedAggregationQuery, |
| SpanFixup, |
| SpanGroup, |
| SpanJoin, |
| SpanPivot, |
| SpanTableConfig, |
| TimeSeriesQueryTable, |
| WELL_KNOWN_AGGREGATIONS, |
| ) |
| from .sql_util import identifier_quote |
| |
# Module logger.
log = logging.getLogger(__name__)

# Dummy mapping used as the default ARGS when compilation has no
# substitution arguments.  The same object is shared by every such default
# (hence the dangerous-default-value suppressions elsewhere), so it must
# never be mutated.
NO_ARGS = {}

# Documentation attached to table-namespace objects.  Namespace.__getitem__
# copies an entry's documentation onto the value that a lazy entry
# resolves to.
_documentation = WeakKeyIdentityDictionary()
| |
| |
| |
| def _munge_columnwise_path(path): |
| # SQL columnwise functions live in a namespace distinct from |
| # that of tables, which we simulate by munging columnwise |
| # function names and looking them up in our single namespace. |
| assert not isinstance(path, str) |
| return tuple(path[:-1]) + (("__colfunc_" + path[-1]),) |
| |
| |
| |
class LazyResolverFailedError(RuntimeError):
    """Signal that a lazy namespace entry's resolver raised KeyError.

    Namespace lookups report missing names with KeyError.  A KeyError that
    escapes from a resolver is therefore re-raised as this type instead,
    so a broken resolver is never mistaken for an absent attribute.
    """
| |
class Namespace(CaseInsensitiveCasePreservingDict, SqlAttributeLookup):
    """SQL object namespace

    SQL conceptually has at least two namespaces: one for tables and one
    for aggregate functions. Rather than complicate lookup by actually
    tracking these various namespaces, we just transform
    non-table-namespace entry names before lookup.

    SQL namespaces themselves are case-sensitive and case-preserving.
    SQL language case insensitivity is on the lookup side.

    A stored value may be a LazyNsEntry.  The first lookup of such an
    entry replaces it with the result of calling its resolver.
    """

    disable_autocomplete = False
    """Set to true to prevent autocomplete resolution"""

    class KeyExistsError(InvalidQueryException):
        """Exception raised on assignment if a value already exists"""

    class NotNamespaceError(InvalidQueryException):
        """Exception raised on traversal through a non-namespace value"""

    @final
    def assign_by_path(self, path, item, *,
                       overwrite=False,
                       make_namespaces=False):
        """Assign a value in the namespace.

        PATH is a sequence of names. ITEM is the value to assign.
        If OVERWRITE is true and an item already exists at leaf position,
        overwrite that item. If the item already exists and OVERWRITE is
        false, raise KeyExistsError.

        All non-final entries in PATH must refer to existing namespaces if
        MAKE_NAMESPACES is false. If this condition is violated, raise
        KeyError if a name is absent or NotNamespaceError if a value is
        present but not a namespace. If MAKE_NAMESPACES is true, create
        empty namespace objects along path as needed, but never replacing
        existing values.
        """
        assert not isinstance(path, str)
        last_idx = len(path) - 1
        assert last_idx >= 0
        ns = self
        for i, name in enumerate(path):
            if i == last_idx:
                if not overwrite and name in ns:
                    raise self.KeyExistsError(
                        "key {} already exists".format(identifier_quote(name)))
                ns[name] = item
            else:
                try:
                    ns = ns[name]
                except KeyError:
                    if not make_namespaces:
                        raise
                    # Chained assignment binds left to right: store the new
                    # namespace in the parent (ns is still the parent here),
                    # then descend into it.
                    ns[name] = ns = Namespace()
                else:
                    if not isinstance(ns, Namespace):
                        raise self.NotNamespaceError(
                            "{} is not a namespace".format(
                                ".".join(map(identifier_quote,
                                             path[:i + 1]))))

    @final
    def delete_by_path(self, path):
        """Delete the thing at PATH.

        If PATH refers to a namespace, all values under that namespace are
        dropped.

        All non-final entries in PATH must refer to existing namespaces.
        If this constraint is violated, raise KeyError if a name is absent
        or NotNamespaceError if a value is present but not a namespace.
        """
        assert not isinstance(path, str)
        last_idx = len(path) - 1
        assert last_idx >= 0
        ns = self
        for i, name in enumerate(path):
            if i == last_idx:
                del ns[name]
            else:
                ns = ns[name]
                if not isinstance(ns, Namespace):
                    raise self.NotNamespaceError(
                        "{} is not a namespace".format(
                            ".".join(map(identifier_quote, path[:i + 1]))))

    @final
    @override
    def __getitem__(self, key):
        value = super().__getitem__(key)
        if isinstance(value, LazyNsEntry):
            # Make sure to propagate any documentation to the value
            # being created.
            doc = _documentation.get(value)
            try:
                self[key] = value = value.resolver()
            except KeyError as ex:
                # Wrap the exception so a failure in resolver doesn't look
                # like a missing key.
                raise LazyResolverFailedError(
                    "lazy resolver for {!r} failed".format(key)) from ex
            if doc is not None:
                _documentation[value] = doc
        return value

    lookup_sql_attribute = __getitem__

    @final
    @override
    def enumerate_sql_attributes(self):
        return tuple(self)

    @final
    def walk(self):
        """Generate the leaf contents of this namespace, recursively.

        Yield (PATH, VALUE) tuples in sorted name order.  PATH is a dotted
        string (e.g. "outer.inner.leaf") naming the item, and VALUE is the
        item stored.  Sub-namespaces are descended into rather than
        yielded.  Values are fetched through __getitem__, so lazy entries
        are resolved as they are visited.
        """
        for name in sorted(self):
            value = self[name]
            if isinstance(value, Namespace):
                for sub_name, sub_value in value.walk():
                    yield name + "." + sub_name, sub_value
            else:
                yield name, value

    def sql_qt(self, sql, args=NO_ARGS):  # pylint: disable=dangerous-default-value
        """Make a QueryTable out of SQL evaluated in this namespace"""
        return _parse_select(sql).make_qt(TvfContext.from_ns(self, args))

    __inherit__ = dict(lookup_sql_attribute=override)


# Shared, module-level empty namespace instance.
EMPTY_NS = Namespace()
| |
| |
| |
@final
class LexicalEnvironment(SqlAttributeLookup):
    """Chain of namespaces that supplies symbols during query compilation

    Each environment wraps one Namespace plus an optional parent
    environment.  A lookup that misses locally falls back to the parent.
    """

    @override
    def __init__(self, parent, ns):
        assert isinstance(parent, (NoneType, LexicalEnvironment))
        assert isinstance(ns, Namespace)
        self.__ns = ns
        self.__parent = parent

    def chain(self, ns):
        """Return a child environment that binds NS in front of this one"""
        return LexicalEnvironment(self, ns)

    @override
    def lookup_sql_attribute(self, name):
        try:
            return self.__ns.lookup_sql_attribute(name)
        except KeyError:
            if self.__parent is None:
                raise
            return self.__parent.lookup_sql_attribute(name)

    @override
    def enumerate_sql_attributes(self):
        names = set(super().enumerate_sql_attributes())
        names |= set(self.__ns.enumerate_sql_attributes())
        if self.__parent is not None:
            names |= set(self.__parent.enumerate_sql_attributes())
        return names
| |
@final
class TvfContext(ExplicitInheritance):
    """State needed to evaluate table-valued functions

    Pairs a LexicalEnvironment (for name lookup) with a mapping of bind
    arguments.
    """

    max_rce_number = 0
    """Bookkeeping for recursive CTE transformation"""

    @override
    def __init__(self, lexenv, args):
        assert isinstance(lexenv, LexicalEnvironment)
        assert isinstance(args, Mapping)
        self.__lexenv = lexenv
        self.__args = args
        # Expose the environment's path lookup directly on the context.
        self.lookup_sql_attribute_by_path = (
            lexenv.lookup_sql_attribute_by_path)

    @staticmethod
    def from_ns(ns=None, args=NO_ARGS):  # pylint: disable=dangerous-default-value
        """Build a TvfContext over the standard environment.

        If NS is given and non-empty, it is a SQL namespace layered on
        top of the standard one.  ARGS supplies bind values.
        """
        base = make_standard_lexical_environment()
        return TvfContext(base.chain(ns) if ns else base, args)

    def get_bind_value(self, ref):
        """Return the bind argument named REF"""
        return self.__args[ref]

    def let(self, bindings):
        """Return a new TvfContext with BINDINGS layered on top of this one.

        BINDINGS is a mapping of extra names to bind.  This context is not
        modified.  The bind arguments and any nonzero max_rce_number carry
        over to the new context.
        """
        overlay = Namespace()
        overlay.update(bindings)
        child = TvfContext(self.__lexenv.chain(overlay), self.__args)
        rce_number = self.max_rce_number
        if rce_number:
            child.max_rce_number = rce_number
        return child

    @property
    def lexenv(self):
        """The lexical environment for this TVF"""
        return self.__lexenv
| |
@final
class QueryCompilationContext(ExplicitInheritance):
    """Mutable state kept while compiling a query

    TCTX is the TvfContext supplying names and bind arguments.  The
    context also holds IdentityDictionary caches keyed by AST node:
    resolved column references, join information, per-table-expression
    QueryTables, and resolved functions.
    """

    regenerate_column_names = False
    prohibit_window_functions = False

    @override
    def __init__(self, tctx):
        assert isinstance(tctx, TvfContext)
        self.tctx = tctx
        self.colrefs = IdentityDictionary()
        self.join_info = IdentityDictionary()
        self.__te_to_qt = IdentityDictionary()
        self.__functions_by_funcall = IdentityDictionary()

    aggregate_functions_used = False
    simplified_gb_expressions = None

    def te_to_qt(self, te):
        """Return the QueryTable for table expression TE, building it once.

        The result is cached per TE, so repeated calls return the same
        QueryTable object.
        """
        assert isinstance(te, QueryTableExpression), \
            "expected a QTE: got {!r}".format(te)
        qt = self.__te_to_qt.get(te)
        # Compare with None rather than testing truthiness.  A cached
        # QueryTable that happens to be falsy must still count as a hit;
        # otherwise it would be rebuilt and the cache overwritten.
        if qt is None:
            qt = te.uncached_make_qt(self)
            self.__te_to_qt[te] = qt
        assert isinstance(qt, QueryTable), \
            "wanted a QueryTable: got {!r}".format(qt)
        return qt

    def resolve_funcall(self, funcall):
        """Find and cache the function object for AST node FUNCALL"""
        assert isinstance(funcall, FunctionCall)
        function = self.__functions_by_funcall.get(funcall)
        # As above: a falsy function object is still a valid cache entry.
        if function is None:
            function = self.tctx.lookup_sql_attribute_by_path(
                _munge_columnwise_path(funcall.path))
            self.__functions_by_funcall[funcall] = function
        return function
| |
| |
| |
class UnboundReferenceException(InvalidQueryException):
    """Raised when a query refers to a name that has no binding"""
| |
| |
| |
class AstNode(EqImmutable):
    """Root node of a SQL parse

    Child nodes are found by scanning the instance ``__dict__`` for
    AstNode values and for tuples that contain AstNode values.
    """

    def get_children(self):
        """Yield the direct child nodes of this node"""
        for value in self.__dict__.values():
            if isinstance(value, AstNode):
                yield value
            elif isinstance(value, tuple):
                for subvalue in value:
                    if isinstance(subvalue, AstNode):
                        yield subvalue

    def walk(self):
        """Yield self and children recursively"""
        yield self
        for child in self.get_children():
            yield from child.walk()

    def simplify(self, subst=None):
        """Apply local simplifications

        If SUBST is non-None, it is a function of one argument, an
        AstNode. It returns a node to substitute. Any node that subst
        changes will not be re-simplified by this routine. SUBST should
        call simplify() explicitly to continue simplification.

        Return self if unchanged or a new node if something changed.
        """
        if subst:
            new_self = subst(self)
            if self is not new_self:
                return new_self
        changes = {}
        for key, value in tuple(self.__dict__.items()):
            if isinstance(value, AstNode):
                simplified_value = value.simplify(subst)
            elif isinstance(value, tuple):
                # Only AstNode elements are simplified.  Anything else (such
                # as the str components of a path) passes through unchanged,
                # matching what get_children() treats as a child.
                simplified_parts = tuple(
                    part.simplify(subst) if isinstance(part, AstNode) else part
                    for part in value)
                if any(orig is not new
                       for orig, new in zip(value, simplified_parts)):
                    simplified_value = simplified_parts
                else:
                    simplified_value = value
            else:
                continue
            if simplified_value is not value:
                changes[key] = simplified_value
        if not changes:
            return self
        return self.evolve(**changes)
| |
class NoWalkAstNode(AstNode):
    """Mixin that hides a node's children from traversal and simplification

    Intended for nodes that wrap subqueries, which the enclosing query
    must neither walk into nor simplify.
    """

    @final
    @override
    def get_children(self):
        # Subqueries are opaque to the enclosing query's traversal.
        return ()

    @final
    @override
    def simplify(self, subst=None):
        # Leave subqueries exactly as they are; SUBST is not consulted.
        return self
| |
class Expression(AstNode):
    """AST node for a computed value"""

    def self_resolve(self, ctx, resolver):
        """Resolve the column references beneath this node, returning self"""
        children = self.get_children()
        for child in children:
            child.self_resolve(ctx, resolver)
        return self

    @abstract
    def to_query(self, ctx, te):
        """Return the QueryNode computing this expression over TE"""
        raise NotImplementedError("abstract")

    @final
    def to_query_uncoordinated(self, tctx):
        """Build a QueryNode that cannot see any surrounding query context.

        Contexts such as VALUES need to compile expressions without
        letting them refer to an enclosing query.  We do that by compiling
        in a fresh QueryCompilationContext against a DummyTableExpression.
        """
        assert isinstance(tctx, TvfContext)
        isolated_ctx = QueryCompilationContext(tctx)
        return self.to_query(isolated_ctx, DummyTableExpression())

    def flatten_conjunctions(self):
        """Return this expression as a tuple of AND-ed terms"""
        return (self,)

    def decompose_into_equi_join(self, _ctx, _left_te_id, _right_te_id):  # pylint: disable=no-self-use
        """Extract equi-join conditions from this expression.

        Return a false value when the expression is not an equi-join.
        Otherwise return (LEFT_EXPR, RIGHT_EXPR, OP): the expressions to
        match on each side of the join, and OP, which is "=" (NULLs never
        match) or "<=>" (NULLs match).  LEFT_TE_ID and RIGHT_TE_ID are the
        sets of table expression IDs on each side of the join.
        """
        return None

    def te_ids_used(self, ctx):
        """Return the ids of all table expressions this expression references"""
        return frozenset(
            id(ctx.colrefs[node][0])
            for node in self.walk()
            if isinstance(node, ColumnReference))

    def unwrap_collation(self):
        """Split off any CollationTag: return (COLLATION, AST_NODE)"""
        return (None, self)

    def to_sort_query(self, ctx, te):
        """Return a QueryNode for this expression that is suitable for sorting.

        String-valued queries are given binary collation.
        """
        query = self.to_query(ctx, te)
        return CollateQuery(query, "binary") if query.schema.is_string else query

    @abstract
    def dump(self, ctx, out):
        """Write a short human-readable rendering of this expression to OUT.

        Used to form column names automatically.  The text need not fully
        describe the expression.
        """
        raise NotImplementedError("abstract")

    def get_auto_name(self, ctx):
        """Return an automatically generated column name for this expression.

        Code may set ctx.regenerate_column_names to request another naming
        pass after the query is built, when more information might give a
        better name.
        """
        buf = StringIO()
        self.dump(ctx, buf)
        return buf.getvalue()
| |
@final
class Cast(Expression):
    """AST node converting an expression to another data type"""
    expr = iattr(Expression)
    dtype = iattr(str)
    safe = iattr(bool)

    @override
    def to_query(self, ctx, te):
        inner = self.expr.to_query(ctx, te)
        return inner.to_dtype(self.dtype, self.safe)

    @override
    def dump(self, ctx, out):
        out.write("CAST(")
        self.expr.dump(ctx, out)
        out.write(", {})".format(identifier_quote(self.dtype)))
| |
@abstract
class CteBindingValue(AstNode):
    """Base for AST nodes whose value can be bound to a CTE name"""

    @abstract
    def make_cte_value(self, tctx, cte_name):
        """Compute the object that a common table expression name is bound to.

        TCTX is the TvfContext to evaluate in.

        CTE_NAME is the CteBindingName being bound.  Recursive CTEs need it
        so that references to that name inside the definition resolve to
        the table under construction.  Other implementations normally
        ignore it.

        The return value is embedded in the SQL table namespace.  It is
        usually a QueryTable, but any object is allowed, which is
        occasionally useful.
        """
        raise NotImplementedError("abstract")
| |
@abstract
class FunExpr(CteBindingValue):
    """AST node for an expression inside a table-valued function call"""

    @abstract
    def evaluate_tvf(self, tctx):
        """Evaluate this table-valued-function expression at parse time.

        TCTX is the TvfContext to evaluate in.
        """
        raise NotImplementedError("abstract")

    @final
    @override
    def make_cte_value(self, tctx, cte_name):
        assert isinstance(cte_name, CteBindingName)
        value = self.evaluate_tvf(tctx)
        if not cte_name.do_rename:
            return value
        # Column renaming only makes sense for table-valued results.
        if not isinstance(value, QueryTable):
            raise InvalidQueryException(
                "can rename columns in a CTE only if "
                "the CTE evaluates to a table!")
        return GenericQueryTable.rename_columns(value, cte_name.renaming)
| |
@final
class UnaryFunExprOperation(FunExpr):
    """Unary operation in table-valued funcall

    The operator is applied, with Python semantics, to the evaluated
    argument.
    """
    operator = iattr(str)
    argument = iattr(FunExpr)

    OPS = {
        "-": oper.neg,
        "+": oper.pos,
        "~": oper.inv,
        "!": oper.not_,
        # Keyword spelling.  UnaryOperation.EQUIV treats "not" as "!", so
        # accept it here as well.
        "not": oper.not_,
    }

    @override
    def evaluate_tvf(self, tctx):
        return self.OPS[self.operator](self.argument.evaluate_tvf(tctx))
| |
@final
class BinaryFunExprOperation(FunExpr):
    """Binary operation in table-valued funcall

    "and" and "or" short-circuit like Python's operators.  Every other
    operator evaluates both sides and then applies the OPS entry.
    """
    left = iattr(FunExpr)
    operator = iattr(str)
    right = iattr(FunExpr)

    OPS = {
        "*": oper.mul,
        "/": oper.truediv,
        "%": oper.mod,
        "//": oper.floordiv,
        "+": oper.add,
        "-": oper.sub,
        "<<": oper.lshift,
        ">>": oper.rshift,
        "&": oper.and_,
        "|": oper.or_,
        "=": oper.eq,
        "<=>": oper.eq,
        "<!=>": oper.ne,
        ">=": oper.ge,
        ">": oper.gt,
        "<=": oper.le,
        "<": oper.lt,
        "<>": oper.ne,
        "!=": oper.ne,
        "==": oper.eq,
        "is": oper.is_,
        "is not": oper.is_not,
        # and/or special-cased
    }

    @override
    def evaluate_tvf(self, tctx):
        op = self.operator
        lhs = self.left.evaluate_tvf(tctx)
        if op == "and":
            return lhs and self.right.evaluate_tvf(tctx)
        if op == "or":
            return lhs or self.right.evaluate_tvf(tctx)
        rhs = self.right.evaluate_tvf(tctx)
        return self.OPS[op](lhs, rhs)
| |
class AstLiteral(Expression, FunExpr):
    """Common base for literal AST nodes

    A literal works both as a TVF value and as a column expression.  In
    column form it is filled to the row count of the table expression.
    """
    __abstract__ = True

    @final
    @override
    def to_query(self, ctx, te):
        # TODO(dancol): use broadcasting instead of count
        scalar = QueryNode.coerce_(self.evaluate_tvf(ctx.tctx))
        row_count = te.make_count_query(ctx)
        return QueryNode.filled(scalar, row_count)
| |
@final
class IntegerLiteral(AstLiteral):
    """AST node for a literal integer, optionally carrying a unit"""
    value = iattr(int, converter=int)
    unit = iattr(default=None)

    @override
    def evaluate_tvf(self, tctx):
        if not self.unit:
            return self.value
        try:
            unit = ureg().parse_units(self.unit)
        except Exception as ex:  # pylint: disable=broad-except
            # Report every unit-parsing failure the same way, keeping the
            # original error as the cause.  A bare except here would also
            # trap KeyboardInterrupt and SystemExit.
            raise BadUnitsException(
                "Invalid unit {!r}".format(self.unit)) from ex
        return self.value * unit

    @override
    def dump(self, ctx, out):
        # We leave out the unit because it's frequently repeated in column
        # headings
        out.write("{}".format(self.value))
| |
@final
class FloatLiteral(AstLiteral):
    """AST node for a literal float, optionally carrying a unit"""
    value = iattr(float, converter=float)
    unit = iattr(default=None)

    @override
    def evaluate_tvf(self, tctx):
        if not self.unit:
            return self.value
        # Same failure reporting as IntegerLiteral: a bad unit becomes a
        # BadUnitsException instead of leaking the unit parser's own error.
        try:
            unit = ureg().parse_units(self.unit)
        except Exception as ex:  # pylint: disable=broad-except
            raise BadUnitsException(
                "Invalid unit {!r}".format(self.unit)) from ex
        return self.value * unit

    @override
    def dump(self, ctx, out):
        # We leave out the unit because it's frequently repeated in column
        # headings
        out.write("{}".format(self.value))
| |
def quote(s):
    """Return S as a single-quoted SQL string literal.

    Embedded single quotes are escaped by doubling them.
    """
    escaped = s.replace("'", "''")
    return "'" + escaped + "'"
| |
def path2str(path):
    """Join the components of PATH into a dotted, identifier-quoted string"""
    return ".".join(identifier_quote(component) for component in path)
| |
@final
class StringLiteral(AstLiteral):
    """AST node for a literal string"""
    value = iattr(str, converter=str)

    @staticmethod
    def decode(unescaped):
        """Parse a quoted SQL string into its unquoted payload

        UNESCAPED is the raw token including its delimiters.  The first
        character is the delimiter and must also end the token.  A doubled
        delimiter inside the string stands for one literal delimiter.
        """
        # Require at least two characters.  A lone delimiter would
        # otherwise pass the end check and silently decode to "", and an
        # empty token would fail with an IndexError.
        assert len(unescaped) >= 2 and unescaped[-1] == unescaped[0], \
            "invalid SQL string {!r}".format(unescaped)
        delim = unescaped[0]
        return unescaped[1:-1].replace(delim * 2, delim)

    @override
    def evaluate_tvf(self, tctx):
        return self.value

    @override
    def dump(self, ctx, out):
        out.write(quote(self.value))
| |
@final
class TrueLiteral(AstLiteral):
    """AST node for the boolean literal TRUE"""

    @override
    def evaluate_tvf(self, tctx):
        # TRUE needs no context to evaluate.
        return True

    @override
    def dump(self, ctx, out):
        out.write("true")
| |
@final
class FalseLiteral(AstLiteral):
    """AST node for the boolean literal FALSE"""

    @override
    def evaluate_tvf(self, tctx):
        return False

    @override
    def dump(self, ctx, out):
        # Previously wrote "true", which mislabeled auto-generated column
        # names for FALSE literals.
        out.write("false")
| |
@final
class NullLiteral(AstLiteral):
    """AST node for NULL"""

    @override
    def evaluate_tvf(self, tctx):
        return None

    @override
    def dump(self, ctx, out):
        # Previously wrote "true", which mislabeled auto-generated column
        # names for NULL literals.
        out.write("null")
| |
@final
class KeywordArgument(AstNode):
    """AST node for a NAME => EXPR keyword argument in a function call"""
    name = iattr(str)
    expr = iattr(Expression)

    def dump(self, ctx, out):
        """Write NAME=>EXPR to OUT, for column naming"""
        out.write("{}=>".format(identifier_quote(self.name)))
        self.expr.dump(ctx, out)


@final
class FunExprKeywordArgument(AstNode):
    """AST node for a keyword argument in TVF context

    NAME is the keyword and EXPR is the FunExpr giving its value.
    """
    name = iattr(str)
    expr = iattr(FunExpr)
| |
@final
class FunExprList(FunExpr):
    """AST node for a list literal in TVF context"""
    parts = tattr(FunExpr)

    @override
    def evaluate_tvf(self, tctx):
        # Evaluate each element in order and collect the results in a list.
        return [element.evaluate_tvf(tctx) for element in self.parts]
| |
@final
class FunExprDictItem(AstNode):
    """One KEY: VALUE entry of a dictionary literal in TVF context

    KW is either a plain string key or a FunExpr that evaluates to the key.
    """
    kw = iattr((FunExpr, str))
    expr = iattr(FunExpr)

    def evaluate(self, tctx):
        """Return the (key, value) pair this item contributes to a dict"""
        raw_key = self.kw
        key = (raw_key.evaluate_tvf(tctx) if isinstance(raw_key, FunExpr)
               else raw_key)
        return key, self.expr.evaluate_tvf(tctx)
| |
@final
class FunExprDict(FunExpr):
    """AST node for a dictionary literal in TVF context"""
    items = tattr(FunExprDictItem)

    @override
    def evaluate_tvf(self, tctx):
        # Later items win when keys repeat.
        return {key: value
                for key, value in (entry.evaluate(tctx)
                                   for entry in self.items)}
| |
@final
class FunctionCall(Expression):
    """AST node for a call of a (possibly namespace-qualified) function"""
    path = tattr(str, nonempty=True)
    arguments = tattr((Expression, KeywordArgument), default=())
    distinct = iattr(bool, default=False)

    @override
    def self_resolve(self, ctx, resolver):
        function = ctx.resolve_funcall(self)
        if isinstance(function, AggregationFunction):
            # Record that the query aggregates, and resolve this call's
            # arguments with the aggregate-argument resolver.
            ctx.aggregate_functions_used = True
            resolver = resolver.get_aggregate_argument_resolver()
        # Normal functions may take over resolution entirely.
        handled = (isinstance(function, NormalFunction) and
                   function.hook_self_resolve(ctx, self, resolver))
        if handled:
            return self
        return super().self_resolve(ctx, resolver)

    @override
    def to_query(self, ctx, te):
        return te.make_funcall_query(ctx, self)

    @override
    def dump(self, ctx, out):
        out.write(".".join(map(identifier_quote, self.path)))
        out.write("(")
        if self.distinct:
            out.write("DISTINCT")
            if self.arguments:
                out.write(" ")
        for index, argument in enumerate(self.arguments):
            if index:
                out.write(", ")
            argument.dump(ctx, out)
        out.write(")")
| |
@final
class ColumnReference(Expression):
    """AST node naming a column, optionally qualified by a table path"""
    column = iattr(str)
    path = tattr(str, default=())

    @override
    def __str__(self):
        parts = self.path + (self.column,)
        return ".".join(identifier_quote(part) for part in parts)

    @override
    def self_resolve(self, ctx, resolver):
        # Resolve at most once; later calls keep the first match.
        if self in ctx.colrefs:
            return self
        ctx.colrefs[self] = resolver.match_column_unique(ctx, self)
        return self

    @override
    def to_query(self, ctx, te):
        table_ref, name = ctx.colrefs[self]
        return te.make_query(ctx, table_ref, name)

    def is_bare_match(self, column):
        """Return whether this is an unqualified reference to column COLUMN"""
        if self.path:
            return False
        return self.column == column

    def as_bare_reference(self):
        """Return a new ColumnReference holding only this one's column name"""
        return ColumnReference(self.column)

    @override
    def dump(self, ctx, out):
        out.write(str(self))
| |
@final
class SubqueryExpression(Expression, NoWalkAstNode):
    """AST node for a subquery in column-expression context

    We do not yet support coordinated subqueries, so the subquery is
    just evaluated on its own and glommed onto the result set.
    The query must have a single column and length of one or the table
    length, of course.
    """

    # We can't specify the type we really want: it'd make a
    # circular reference.
    subquery = iattr(AstNode)

    @override
    def to_query(self, ctx, te):
        assert isinstance(self.subquery, Select)
        qt = ctx.join_info.get(self)
        # Compare with None so a falsy QueryTable still counts as cached
        # and is not rebuilt.
        if qt is None:
            # pylint: disable=no-member
            qt = self.subquery.make_qt(ctx.tctx)
            ctx.join_info[self] = qt
        return QueryNode.filled(QueryNode.coerce_(qt), te.make_count_query(ctx))

    @override
    def dump(self, ctx, out):
        qt = ctx.join_info.get(self)
        if qt is not None and len(qt.columns) == 1:
            out.write(first(qt.columns))
        else:
            # Either the subquery hasn't been compiled yet or it has no
            # single column to name it after.  Use a placeholder and ask
            # for another naming pass.
            ctx.regenerate_column_names = True
            out.write("{subquery}")
| |
@final
class BinaryOperation(Expression):
    """AST node for an infix operation LEFT OPERATOR RIGHT"""
    left = iattr(Expression)
    operator = iattr(str)
    right = iattr(Expression)

    # Alternative operator spellings, mapped to the canonical form that
    # simplify() rewrites them to.
    EQUIV = {
        "<>": "!=",
        "==": "=",
        "is": "<=>",
        "is not": "<!=>",
        "is distinct from": "<!=>",
        "is not distinct from": "<=>",
    }

    @staticmethod
    def and_(left, right):
        """Build the conjunction LEFT AND RIGHT"""
        return BinaryOperation(left, "and", right)

    @staticmethod
    def eq(left, right):
        """Build the equality test LEFT = RIGHT"""
        return BinaryOperation(left, "=", right)

    @override
    def to_query(self, ctx, te):
        # A COLLATE tag on either operand selects the collation.  The left
        # operand's tag wins; with no tag we use "binary".
        lhs_collation, lhs = self.left.unwrap_collation()
        rhs_collation, rhs = self.right.unwrap_collation()
        collation = lhs_collation or rhs_collation or "binary"
        return BinaryOperationQuery(lhs.to_query(ctx, te),
                                    self.operator,
                                    rhs.to_query(ctx, te),
                                    collation)

    def __decompose_into_equi_join_1(self, ctx, expr_left, expr_right,
                                     left_te_id, right_te_id):
        """Try one orientation of the equi-join decomposition.

        Succeed only if EXPR_LEFT references table expressions solely from
        LEFT_TE_ID, EXPR_RIGHT references them solely from RIGHT_TE_ID,
        and each side references at least one.
        """
        lhs_ids = expr_left.te_ids_used(ctx)
        rhs_ids = expr_right.te_ids_used(ctx)
        if not (lhs_ids and rhs_ids):
            return None
        if lhs_ids <= left_te_id and rhs_ids <= right_te_id:
            return (expr_left, expr_right, self.operator)
        return None

    @override
    def decompose_into_equi_join(self, ctx, left_te_id, right_te_id):
        if self.operator not in ("=", "<=>"):
            return False
        # Try the operands as written first, then swapped.
        return (self.__decompose_into_equi_join_1(
                    ctx, self.left, self.right, left_te_id, right_te_id)
                or self.__decompose_into_equi_join_1(
                    ctx, self.right, self.left, left_te_id, right_te_id))

    @override
    def simplify(self, subst=None):
        canonical = self.EQUIV.get(self.operator)
        if canonical:
            return BinaryOperation(
                self.left, canonical, self.right).simplify(subst)
        # Drop TRUE operands of AND here, at the AST level.  Code that
        # splits WHERE conditions into conjunct lists then works on the
        # shortest possible list.
        if self.operator == "and":
            left_true = isinstance(self.left, TrueLiteral)
            right_true = isinstance(self.right, TrueLiteral)
            if left_true and right_true:
                return TrueLiteral().simplify(subst)
            if left_true:
                return self.right.simplify(subst)
            if right_true:
                return self.left.simplify(subst)
        return super().simplify(subst)

    @override
    def flatten_conjunctions(self):
        if self.operator != "and":
            return (self,)
        return (self.left.flatten_conjunctions()
                + self.right.flatten_conjunctions())

    @override
    def dump(self, ctx, out):
        out.write("(")
        self.left.dump(ctx, out)
        out.write(" {} ".format(self.operator))
        self.right.dump(ctx, out)
        out.write(")")
| |
@final
class UnaryOperation(Expression):
    """AST node for a prefix operation OPERATOR ARGUMENT"""
    operator = iattr(str)
    argument = iattr(Expression)

    # Keyword spellings mapped to the symbolic operator simplify() uses.
    EQUIV = {
        "not": "!",
    }

    @override
    def to_query(self, ctx, te):
        operand = self.argument.to_query(ctx, te)
        return UnaryOperationQuery(self.operator, operand)

    @override
    def simplify(self, subst=None):
        canonical = self.EQUIV.get(self.operator)
        if not canonical:
            return super().simplify(subst)
        return UnaryOperation(canonical, self.argument).simplify(subst)

    @override
    def dump(self, ctx, out):
        out.write("( {} ".format(self.operator))
        self.argument.dump(ctx, out)
        out.write(")")
| |
@final
class CollationTag(Expression):
    """AST node for EXPR COLLATE COLLATION

    Only meaningful where a consumer unwraps it: in comparisons (via
    unwrap_collation) and in sorting.  Compiling it directly is an error.
    """
    expr = iattr(Expression)
    collation = iattr(str)

    @override
    def unwrap_collation(self):
        return self.collation, self.expr

    @override
    def to_query(self, ctx, te):
        raise InvalidQueryException("invalid use of COLLATE")

    @override
    def to_sort_query(self, ctx, te):
        base = self.expr.to_query(ctx, te)
        return CollateQuery(base, self.collation)

    @override
    def dump(self, ctx, out):
        self.expr.dump(ctx, out)
        out.write(" COLLATE " + self.collation)
| |
@final
class BetweenAndOperation(Expression):
    """AST node for VALUE BETWEEN LOWER AND UPPER

    simplify() always rewrites this node as a conjunction, so it is never
    compiled directly.
    """
    value = iattr(Expression)
    lower = iattr(Expression)
    upper = iattr(Expression)

    @override
    def to_query(self, ctx, te):
        raise AssertionError("should have been syntactically simplified")

    @override
    def simplify(self, subst=None):
        # Rewrite as LOWER <= VALUE AND VALUE <= UPPER.
        at_least_lower = BinaryOperation(self.lower, "<=", self.value)
        at_most_upper = BinaryOperation(self.value, "<=", self.upper)
        conjunction = BinaryOperation(at_least_lower, "and", at_most_upper)
        return conjunction.simplify(subst)

    @override
    def dump(self, ctx, out):
        out.write("(")
        self.value.dump(ctx, out)
        for keyword, bound in ((" BETWEEN ", self.lower),
                               (" AND ", self.upper)):
            out.write(keyword)
            bound.dump(ctx, out)
        out.write(")")
| |
@final
class UnitConversion(Expression):
    """AST node converting an expression's value to another unit.

    A falsy ``unit`` strips the unit instead; it is dumped as ``IN NULL``.
    """
    value = iattr(Expression)
    unit = iattr()
    allow_unitless = iattr(bool, default=False)

    @override
    def to_query(self, ctx, te):
        query = self.value.to_query(ctx, te)
        if query.schema.unit == self.unit:
            # Already matches the requested unit: nothing to do.
            return query
        if not self.unit:
            return query.strip_unit()
        target_unit = ureg().parse_units(self.unit)
        return query.to_unit(target_unit,
                             allow_unitless=self.allow_unitless)

    @override
    def dump(self, ctx, out):
        self.value.dump(ctx, out)
        target = identifier_quote(self.unit) if self.unit else "NULL"
        out.write(" IN {}".format(target))
| |
| |
| |
class Column(AstNode):
    """Abstract base for one entry in a SELECT statement's column list"""
    @abstract
    def get_column_specs(self, ctx, source, is_gen_special, table_schema):
        """Produce the column specs this entry contributes.

        Subclasses return (or yield) a sequence of (NAME, EXPRESSION)
        pairs, one per output column.
        """
        raise NotImplementedError("abstract")
| |
@final
class WildcardColumn(Column):
    """A column specification with a wildcard (``*`` or ``path.*``)"""
    path = tattr(str, default=())

    @override
    def get_column_specs(self, ctx, source, is_gen_special, table_schema):
        assert (not is_gen_special) or table_schema.kind != TableKind.REGULAR, \
            "if we're generating non-regular output, we must have one on input"
        matched_any = False
        for column_name, te in source.gen_wildcard_matches(ctx, self.path):
            matched_any = True
            # When generating special (non-regular) output, higher-level
            # code always emits the meta columns itself, so the wildcard
            # must not produce them a second time.
            if is_gen_special and column_name in table_schema.meta_columns:
                continue
            colref = ColumnReference(column_name)
            # Pre-bind the fresh reference to TE so name resolution can be
            # skipped. ColumnReference instances are looked up by object
            # identity, so this new instance is a unique key.
            ctx.colrefs[colref] = (te, column_name)
            yield column_name, colref
        if not matched_any and self.path:
            raise InvalidQueryException(
                "column specification '{}{}*' matched nothing".format(
                    ".".join(map(identifier_quote, self.path)),
                    "." if self.path else ""))
| |
@final
class ExpressionColumn(Column):
    """A column specification computed from an expression, optionally named"""
    expr = iattr(Expression)
    name = iattr(str, nullable=True, default=None)

    @override
    def get_column_specs(self, ctx, _source, _is_gen_special, _table_schema):
        column_name = self.name
        if not column_name:
            # No explicit name: derive one from the expression.
            column_name = self.expr.get_auto_name(ctx)
        return ((column_name, self.expr),)
| |
class ExpressionQueryMaker(ExplicitInheritance):
    """Interface through which expressions build queries while traversing"""

    @abstract
    def make_count_query(self, ctx):
        """Return a query yielding the number of rows in this result set"""
        raise NotImplementedError("abstract")

    @abstract
    def make_query(self, ctx, table_reference, column_name):
        """Build a QueryNode for COLUMN_NAME of TABLE_REFERENCE as seen
        through this query maker.

        The result can differ from the underlying raw QueryNode, for
        example when a join re-indexes rows.
        """
        raise NotImplementedError("abstract")

    def make_funcall_query(self, ctx, funcall):  # pylint: disable=no-self-use
        """Translate the function call FUNCALL into a QueryNode.

        This default accepts only normal (non-aggregate) functions; the
        aggregating and grouping table expressions override it.

        Raises InvalidQueryException for aggregates, DISTINCT/ALL, or a
        callee that is not a function.
        """
        function = ctx.resolve_funcall(funcall)
        if isinstance(function, AggregationFunction):
            raise InvalidQueryException("illegal use of aggregate function")
        if funcall.distinct:
            raise InvalidQueryException(
                "DISTINCT/ALL is only for aggregate functions")
        if not isinstance(function, NormalFunction):
            raise InvalidQueryException(
                "illegal call to non-function {!r}".format(function))
        return function.make_query(ctx, self, funcall.arguments)
| |
class TableExpression(AstNode, ExpressionQueryMaker):
    """A table-valued expression"""

    name = iattr(str, nullable=True, default=None, kwonly=True)
    """Explicitly-specified name given using AS"""

    def te_walk(self):
        """Enumerate table expressions

        This base implementation returns only this node; composite table
        expressions presumably extend it with their children.
        """
        return self,

    @cached_property
    def te_ids(self):
        """frozenset of TE IDs used by this TE and its children"""
        return frozenset(id(te) for te in self.te_walk())

    def rename_columns(self, _name, _renaming):
        """Clone with new name and renamed columns

        Unsupported by default; raises InvalidQueryException unless a
        subclass overrides it.
        """
        raise InvalidQueryException(
            "renaming columns not supported on node of type {}".format(type(self)))

    @abstract
    def push_conditions(self, ctx, conditions):
        """Apply conditions and configure joins

        Return a sequence of "excess" conjunctions we couldn't apply at
        levels below.
        """
        raise NotImplementedError("abstract")

    def self_resolve_te(self, ctx):
        """Note which tables our table references refer to"""
        pass

    @abstract
    def gen_wildcard_matches(self, ctx, path):
        """Generate columns matching a column selection wildcard prefix

        CTX is the query compilation context; PATH is the tuple of name
        components written before the ``*`` (empty for a bare ``*``).

        Each value yielded is a pair (COLUMN_NAME, TABLE_EXPRESSION) where
        COLUMN_NAME is a string providing the name of the column and
        TABLE_EXPRESSION is the TableExpression object providing that
        column.
        """
        # The signature includes CTX to match every implementation and
        # the caller, which invokes gen_wildcard_matches(ctx, path).
        raise NotImplementedError("abstract")

    def match_cr_path(self, path):
        """Match columns against path component of a column reference

        Used when generating wildcard columns and resolving column references"""
        return not path or (len(path) == 1 and path[0] == self.name)

    @abstract
    def get_table_schema(self, ctx):
        """The span table schema"""
        raise NotImplementedError("abstract")

    def get_aggregate_argument_resolver(self):
        """Get the resolver for the arguments of an aggregate function.

        Used for "magic grouping" operators to make the _ts in SUM(_ts)
        resolve against the underlying table and not the magic grouping
        placeholder table source.
        """
        return self

    def match_column(self, ctx, cr):
        """Yield a sequence of matches for column reference CR.

        Each successful match is a (TABLE_EXPRESSION, COLUMN_NAME) pair.
        """
        if self.match_cr_path(cr.path) and cr.column in ctx.te_to_qt(self):
            return (self, cr.column),
        return ()

    def gen_table_paths(self, ctx):  # pylint: disable=unused-argument
        """Yield the table names that are in scope for column references

        Duplicates are allowed. Used for completion.
        """
        yield ()  # Unqualified table references are always legal
        if self.name:
            yield (self.name,)

    @final
    def match_column_unique(self, ctx, cr):
        """Solve column reference CR to one column or raise

        Raises UnboundReferenceException when nothing matches and
        InvalidQueryException when more than one column matches.
        """
        matches = list(self.match_column(ctx, cr))
        if not matches:
            raise UnboundReferenceException(
                "column reference {} is unbound".format(cr))
        if len(matches) > 1:
            raise InvalidQueryException(
                "column reference {} is ambiguous".format(cr))
        return matches[0]

    def to_schema(self, ctx, *, kind=None, sorting=None):
        """Return a query maker presenting this TE with KIND and SORTING

        Missing KIND/SORTING default to this TE's own. Returns self when
        nothing changes, otherwise a ReSortingQueryMaker wrapping self.
        """
        table_schema = self.get_table_schema(ctx)
        if not kind:
            kind = table_schema.kind
        if not sorting:
            sorting = table_schema.sorting
        if table_schema.kind == kind and \
           table_schema.sorting == sorting:
            return self
        return ReSortingQueryMaker(self, table_schema, kind, sorting)
| |
@final
class ReSortingQueryMaker(ExpressionQueryMaker):
    """Query maker that transparently restores span invariants

    Wraps BASE so that column and function-call queries come out in the
    wanted sorting, re-indexing rows with an argsort when the base
    table's existing order does not already satisfy it.
    """

    # Lazily-built row indexer: None means "not built yet" and False
    # means "no reordering needed".
    __indexer = None

    @override
    def __init__(self,
                 base,
                 base_table_schema,
                 wanted_kind,
                 wanted_sorting):
        self.__base = the(ExpressionQueryMaker, base)
        self.__base_table_schema = the(TableSchema, base_table_schema)
        self.__wanted_sorting = the(TableSorting, wanted_sorting)
        if base_table_schema.kind != wanted_kind:
            raise InvalidQueryException(
                "wanted a table kind={} found a {}".format(
                    TableKind.label_of(wanted_kind),
                    base_table_schema))

    def __make_indexer(self, ctx):
        """Build an ArgSortQuery indexer, or return False if none is needed"""
        wanted = self.__wanted_sorting
        if wanted == TableSorting.NONE:
            return False  # Any order is acceptable
        schema = self.__base_table_schema
        assert schema.kind in (TableKind.EVENT, TableKind.SPAN)
        partition = schema.partition
        if partition:
            if wanted == TableSorting.TIME_MAJOR:
                sort_columns = ("_ts", partition)
            else:
                assert wanted == TableSorting.PARTITION_MAJOR
                sort_columns = (partition, "_ts")
        elif schema.sorting != TableSorting.NONE:
            # Without a partition, time-major and partition-major sorting
            # are the same thing, so any existing sort already suffices.
            return False
        else:
            sort_columns = ("_ts",)
        base = self.__base
        sort_keys = []
        for column in sort_columns:
            colref = ColumnReference(column).self_resolve(ctx, base)
            sort_keys.append((colref.to_query(ctx, base), True))
        return ArgSortQuery(sort_keys)

    def __get_indexer(self, ctx):
        if self.__indexer is None:
            self.__indexer = self.__make_indexer(ctx)
        return self.__indexer

    def __munge(self, ctx, query):
        indexer = self.__get_indexer(ctx)
        if not indexer:
            return query
        return query.take(indexer)

    @override
    def make_count_query(self, ctx):
        # Reordering rows never changes how many there are.
        base = self.__base
        return base.make_count_query(ctx)

    @override
    def make_query(self, ctx, table_reference, column_name):
        raw = self.__base.make_query(ctx, table_reference, column_name)
        return self.__munge(ctx, raw)

    @override
    def make_funcall_query(self, ctx, funcall):
        raw = self.__base.make_funcall_query(ctx, funcall)
        return self.__munge(ctx, raw)
| |
class QueryTableExpression(TableExpression, FunExpr):
    """TableExpression backed by the QueryTable that evaluate_tvf() produces"""

    __abstract__ = True

    @override
    def rename_columns(self, new_name, renaming):
        return RenamingQueryTableExpression(self, renaming, name=new_name)

    @final
    def uncached_make_qt(self, ctx):
        """Evaluate this expression and coerce the result to a QueryTable"""
        value = self.evaluate_tvf(ctx.tctx)
        return QueryTable.coerce_(value)

    @final
    @override
    def make_count_query(self, ctx):
        table = ctx.te_to_qt(self)
        return table.countq()

    @final
    @override
    def push_conditions(self, ctx, conjunctions):
        # TODO(dancol): push conditions, when safe, down to
        # leaf nodes in the join tree
        # Nothing is applied here: every conjunction is handed back so
        # it is applied at a higher level.
        return conjunctions

    @final
    @override
    def make_query(self, ctx, table_reference, column_name):
        assert table_reference is self
        table = ctx.te_to_qt(self)
        return table[column_name]

    @final
    @override
    def get_table_schema(self, ctx):
        table = ctx.te_to_qt(self)
        return table.table_schema

    @final
    @override
    def gen_wildcard_matches(self, ctx, path):
        if not self.match_cr_path(path):
            return
        yield from ((column, self) for column in ctx.te_to_qt(self).columns)
| |
@final
class RenamingQueryTableExpression(QueryTableExpression):
    """QueryTableExpression exposing BASE's table with columns renamed"""
    base = iattr(QueryTableExpression)
    renaming = iattr()

    @override
    def evaluate_tvf(self, tctx):
        source_table = QueryTable.coerce_(self.base.evaluate_tvf(tctx))
        return GenericQueryTable.rename_columns(source_table, self.renaming)
| |
@final
class BindParameter(AstLiteral, QueryTableExpression):
    """Bind parameter, referenced by position (``?N``) or name (``:name``)"""
    ref = iattr((int, str))

    def __get_value(self, tctx):
        try:
            value = tctx.get_bind_value(self.ref)
        except KeyError:
            raise InvalidQueryException(
                "no bind parameter given for {}".format(
                    self.__bind_name)) from None
        return value

    @override
    def evaluate_tvf(self, tctx):
        return self.__get_value(tctx)

    @property
    def __bind_name(self):
        """Source spelling of this parameter, for dumps and error messages"""
        ref = self.ref
        return ("?{}".format(ref) if isinstance(ref, int)
                else ":{}".format(identifier_quote(ref)))

    @override
    def dump(self, ctx, out):
        out.write(self.__bind_name)
| |
@final
class TableValuedFunction(SqlAttributeLookup, ExplicitInheritance):
    """Provides a parameterized table to queries

    Wraps a Python callable; calling the TableValuedFunction forwards all
    arguments to that callable.
    """

    @cached_property
    def sql_attributes(self):  # pylint: disable=no-self-use
        """Dictionary of SQL attributes"""
        return {}

    @override
    def lookup_sql_attribute(self, key):
        if key in self.sql_attributes:
            return self.sql_attributes[key]
        return super().lookup_sql_attribute(key)

    @override
    def __init__(self, function):
        assert callable(function)
        self.__function = function
        # Tolerate callables that lack a __doc__ attribute entirely.
        doc = getattr(function, "__doc__", None)
        if doc is not None:
            add_sql_documentation(self, doc)

    def __call__(self, *args, **kwargs):
        return self.__function(*args, **kwargs)

    @staticmethod
    def from_select_ast(select_ast, ns=None):
        """Create a TVF from a query with bind parameters

        SELECT_AST is the query to run and NS an optional Namespace used
        for name lookup. Call arguments become the query's bind values;
        calling the TVF raises InvalidQueryException if any supplied
        argument goes unused by the query.
        """
        assert isinstance(select_ast, Select)
        assert isinstance(ns, (NoneType, Namespace))
        # Hold the namespace weakly so the TVF does not keep it alive.
        # Compare against None rather than testing truthiness so that a
        # namespace object that happens to be falsy is still honored.
        ns_wr = weakref.ref(ns) if ns is not None else None
        def _function(*args, **kwargs):
            args = UsageTrackingDictionary(qargs(*args, **kwargs))
            if ns_wr is not None:
                ns = ns_wr()
                assert ns is not None, "namespace was garbage-collected"
            else:
                ns = None
            qt = select_ast.make_qt(TvfContext.from_ns(ns, args))
            unused_arguments = set(args.keys()) - args.used
            if unused_arguments:
                raise InvalidQueryException(
                    "unused arguments: {}".format(unused_arguments))
            return qt
        return TableValuedFunction(_function)
| |
def _do_tvf_call(tctx, tvf, arguments):
    """Common dispatch code for TvfFunctionCall and TableFunctionCall

    Evaluates each argument node of ARGUMENTS in TVF context TCTX and
    calls TVF with the results. As in Python calls, positional arguments
    must precede keyword arguments and each keyword may appear only once.

    Raises InvalidQueryException on a positional argument after a keyword
    argument, or on a repeated keyword.
    """
    args = []
    kwargs = {}
    for arg_node in arguments:
        if isinstance(arg_node, FunExprKeywordArgument):
            if arg_node.name in kwargs:
                # Previously a repeated keyword silently replaced the
                # earlier value; reject it instead.
                raise InvalidQueryException(
                    "duplicate keyword argument {!r}".format(arg_node.name))
            kwargs[arg_node.name] = arg_node.expr.evaluate_tvf(tctx)
        else:
            if kwargs:
                raise InvalidQueryException(
                    "non-keyword arguments after keyword arguments")
            args.append(arg_node.evaluate_tvf(tctx))
    return tvf(*args, **kwargs)
| |
@final
class TvfFunctionCall(FunExpr):
    """AST node for calling a function-valued expression in TVF context"""
    fn_expr = iattr(FunExpr)
    arguments = tattr((FunExpr, FunExprKeywordArgument), default=())

    @override
    def evaluate_tvf(self, tctx):
        callee = self.fn_expr.evaluate_tvf(tctx)
        return _do_tvf_call(tctx, callee, self.arguments)
| |
@final
class TableFunctionCall(QueryTableExpression):
    """AST node for a call to a table-valued function as a query source"""
    # We can't just reuse TvfFunctionCall because in the table-source
    # case, we need to store a path naming the function, not an
    # expression yielding a function argument.
    path = tattr(str, nonempty=True)
    arguments = tattr((FunExpr, FunExprKeywordArgument), default=())

    @override
    def evaluate_tvf(self, tctx):
        """Look up the function named by PATH and call it with ARGUMENTS

        Raises UnboundReferenceException if PATH names nothing, and
        InvalidQueryException if it names something other than a
        TableValuedFunction.
        """
        try:
            tvf = tctx.lookup_sql_attribute_by_path(self.path)
        except KeyError:
            # Report a missing name the same way TableReference does
            # instead of leaking a raw KeyError.
            raise UnboundReferenceException(
                "No table-namespace value {}".format(path2str(self.path))) \
                from None
        if not isinstance(tvf, TableValuedFunction):
            raise InvalidQueryException(
                "{} does not refer to a table-valued function".format(
                    path2str(self.path)))
        return _do_tvf_call(tctx, tvf, self.arguments)

    @override
    def gen_table_paths(self, ctx):
        yield from super().gen_table_paths(ctx)
        yield self.path

    @final
    @override
    def match_cr_path(self, path):
        return super().match_cr_path(path) or self.path == path
| |
@final
class TableReference(QueryTableExpression):
    """Reference, by path, to a named value in the table namespace"""
    path = tattr(str, nonempty=True)

    @override
    def evaluate_tvf(self, tctx):
        path = self.path
        try:
            value = tctx.lookup_sql_attribute_by_path(path)
        except KeyError:
            raise UnboundReferenceException(
                "No table-namespace value {}".format(path2str(path))) \
                from None
        return value

    @override
    def gen_table_paths(self, ctx):
        # Besides the inherited paths, the full path that named this
        # table may also qualify its columns.
        yield from super().gen_table_paths(ctx)
        yield self.path

    @override
    def match_cr_path(self, path):
        if super().match_cr_path(path):
            return True
        return self.path == path
| |
@final
class TvfDot(FunExpr):
    """Attribute or sub-namespace access (``EXPR.NAME``) in TVF context"""
    expr = iattr(FunExpr)
    name = iattr(str)

    @override
    def evaluate_tvf(self, tctx):
        target = self.expr.evaluate_tvf(tctx)
        if isinstance(target, SqlAttributeLookup):
            return target.lookup_sql_attribute(self.name)
        raise InvalidQueryException(
            "object of type {} does not support attribute access".format(
                type(target)))
| |
@final
class TableSubquery(QueryTableExpression, NoWalkAstNode):
    """A SELECT subquery used as a table expression"""
    subquery = iattr(AstNode)

    @override
    def evaluate_tvf(self, tctx):
        select = self.subquery
        assert isinstance(select, Select)
        return select.make_qt(tctx)  # pylint: disable=no-member
| |
| |
| # Joins |
| |
# Indices of the two sides of a join. They index (left, right) pairs such
# as per-side key-expression and indexer lists; because RIGHT == 1, a
# boolean "is_right" flag can be used directly as the same index.
LEFT, RIGHT = 0, 1
| |
@final
class ColumnNameList(AstNode):
    """List of column names used as a join condition

    Each named column is expected to compare equal between the two
    sides of the join.
    """
    columns = tattr(str, default=())
| |
| def _drop_duplicates_by_id(sequence): |
| seen = set() |
| for item in sequence: |
| if id(item) not in seen: |
| seen.add(id(item)) |
| yield item |
| |
| def _drop_duplicates(sequence): |
| seen = set() |
| for item in sequence: |
| if item not in seen: |
| seen.add(item) |
| yield item |
| |
class JoinInfo(ExplicitInheritance):
    """Behavioral guts of a join operation

    Subclasses implement one kind of join (classic, span, or event) on
    behalf of a Join AST node: the output schema, pushing conditions
    down, and re-indexing each side's columns into joined-row order.
    """

    @override
    def __init__(self, ctx, join):
        self._join = the(Join, join)
        # Hold the compilation context weakly so this object does not
        # keep it alive; _ctx dereferences it.
        self.__ctx_wr = weakref.ref(the(QueryCompilationContext, ctx))

    @property
    def _ctx(self):
        """The QueryCompilationContext given at construction (must be alive)"""
        ctx = self.__ctx_wr()
        assert ctx
        return ctx

    @abstract
    def _take_for_join(self, is_right, base_query):
        """Do the take operation for a query side

        IS_RIGHT is true iff BASE_QUERY comes from the RHS
        of the join.

        BASE_QUERY is the query to index into.
        """
        raise NotImplementedError("abstract")

    @abstract
    def get_table_schema(self):
        """The table schema for this join"""
        raise NotImplementedError("abstract")

    def get_self_columns(self):
        """Sequence of columns provided by the join itself"""
        return self.get_table_schema().meta_columns

    @abstract
    def push_conditions(self, ctx, conjunctions):
        """Push predicates down into join conditions if needed

        See TableExpression.push_conditions().
        """
        raise NotImplementedError("abstract")

    @abstract
    def make_count_query(self, ctx):
        """Make QueryNode yielding result set size"""
        raise NotImplementedError("abstract")

    def _make_query_for_self(self, _ctx, column_name):  # pylint: disable=no-self-use
        """Make a query for a column the join itself provides

        The default provides none and raises InvalidQueryException.
        """
        raise InvalidQueryException("invalid join reference "
                                    + repr(column_name))

    def __get_side(self, te):
        """Return True if TE lives on the join's right side, else False"""
        is_right = id(te) in self._join.right.te_ids
        assert is_right or id(te) in self._join.left.te_ids
        return is_right

    @abstract
    def _get_sub_sources(self, ctx):
        """Return the (left, right) query makers to read columns from"""
        raise NotImplementedError("abstract")

    @final
    def make_query(self, ctx, table_reference, column_name):
        """Make a query for a column reference viewed through this join"""
        if table_reference is self._join:
            return self._make_query_for_self(ctx, column_name)
        side = self.__get_side(table_reference)
        # SIDE is a bool; it doubles as the LEFT/RIGHT index.
        sub_source = self._get_sub_sources(ctx)[side]
        assert isinstance(sub_source, ExpressionQueryMaker)
        return self._take_for_join(
            side,
            sub_source.make_query(ctx, table_reference, column_name))

    def match_column(self, ctx, cr):
        """Take over column-matching duties from Join"""
        # Special case: a SPAN JOIN "generates" its own span metadata
        # columns ex nihilo instead of forwarding them to the
        # joined tables.
        join = self._join
        if len(cr.path) == 1 and cr.path[0] == join.name:
            # If we've named a join `foo` and we're matching `foo`.`bar`,
            # then pretend we're matching a plain `bar`, but only in the
            # context of `foo`. We end up doing the right thing for named
            # joins this way, since we still detect ambiguity.
            yield from self.match_column(ctx, cr.as_bare_reference())
            return
        for self_column in self.get_self_columns():
            if cr.is_bare_match(self_column):
                yield join, self_column
                return
        yield from join.left.match_column(ctx, cr)
        yield from join.right.match_column(ctx, cr)
| |
class JoinInfoClassic(JoinInfo):
    """Join control for a normal SQL join

    The join's own condition (an ON expression or a column-name list)
    must decompose into equi-join key pairs. Other conjunctions that
    reference only one side are pushed down to that side.
    """

    # (left_key_exprs, right_key_exprs, null_allowed) once
    # push_conditions() has run; None before that.
    __key_exprs = None

    @override
    def __init__(self, ctx, join, table_schema=REGULAR_TABLE):
        super().__init__(ctx, join)
        self.__table_schema = the(TableSchema, table_schema)

    @override
    def match_column(self, ctx, cr):
        # In the semi-join case of a special table with a regular RHS,
        # punt any meta columns to the left table directly, since we know
        # the right table won't have them.
        if not cr.path and cr.column in self.__table_schema.meta_columns:
            assert cr.column in \
                self._join.left.get_table_schema(ctx).meta_columns
            assert cr.column not in \
                self._join.right.get_table_schema(ctx).meta_columns
            yield from self._join.left.match_column(ctx, cr)
            return
        yield from super().match_column(ctx, cr)

    @override
    def get_self_columns(self):
        on = self._join.on
        if isinstance(on, ColumnNameList):
            return on.columns
        return ()

    @override
    def get_table_schema(self):
        return self.__table_schema

    def __get_join_conjunctions(self, ctx):
        """Return the join's own condition as a sequence of conjunctions"""
        on = self._join.on
        if isinstance(on, Expression):
            return on.flatten_conjunctions()
        if isinstance(on, ColumnNameList):
            # Each listed column means LEFT.col = RIGHT.col; resolve each
            # reference against its own side explicitly.
            left, right = self._join.left, self._join.right
            def _q1(column_name, resolver):
                expr = ColumnReference(column_name)
                expr.self_resolve(ctx, resolver)
                return expr
            return [
                BinaryOperation.eq(_q1(column_name, left),
                                   _q1(column_name, right))
                for column_name in on.columns
            ]
        assert on is None
        return ()

    @override
    def push_conditions(self, ctx, conjunctions):
        """Split conditions into equi-join keys and per-side filters

        Conjunctions from the join's own condition must decompose into
        equi-join key pairs. Other conjunctions become keys only for
        inner joins; otherwise they are pushed down to the side they
        reference, or returned as excess if they span both sides.
        """
        assert not self.__key_exprs
        left_key_exprs = []
        right_key_exprs = []
        null_allowed = []
        conjunctions_left = []
        conjunctions_right = []
        extra_conjunctions = []
        join = self._join
        left_te_id = join.left.te_ids
        right_te_id = join.right.te_ids

        seen = set()

        def _handle_conjunction(expr, must_match):
            if expr in seen:
                return
            seen.add(expr)
            if must_match or join.op == "inner":
                match = expr.decompose_into_equi_join(
                    ctx, left_te_id, right_te_id)
            else:
                match = None
            if match:
                left_expr, right_expr, op = match
                left_key_exprs.append(left_expr)
                right_key_exprs.append(right_expr)
                assert op in ("=", "<=>")
                null_allowed.append(op == "<=>")
            elif must_match:
                raise InvalidQueryException("not an equi-join: {}".format(expr))
            elif expr.te_ids_used(ctx) <= left_te_id:
                conjunctions_left.append(expr)
            elif expr.te_ids_used(ctx) <= right_te_id:
                conjunctions_right.append(expr)
            else:
                extra_conjunctions.append(expr)

        for expr in self.__get_join_conjunctions(ctx):
            _handle_conjunction(expr, True)
        for expr in conjunctions:
            _handle_conjunction(expr, False)

        assert len(left_key_exprs) == len(right_key_exprs)
        if not left_key_exprs and not join.kind:
            raise InvalidQueryException("no join condition specified")

        extra_conjunctions += \
            join.left.push_conditions(ctx, conjunctions_left)
        extra_conjunctions += \
            join.right.push_conditions(ctx, conjunctions_right)

        self.__key_exprs = (left_key_exprs, right_key_exprs, null_allowed)
        return extra_conjunctions

    @cached_property
    def __indexer_array(self):
        """(left_indexer, right_indexer); left is None for a LEFT semi-join"""
        join = self._join
        ctx = self._ctx
        key_exprs = self.__key_exprs
        assert key_exprs, \
            "push_conditions must be called before using indexers"
        sub_sources = self._get_sub_sources(ctx)
        keys = [[expr.to_query(ctx, sub_sources[side])
                 for expr in key_exprs[side]]
                for side in (LEFT, RIGHT)]
        # A UNIQUE key on the right means each left row matches at most
        # one right row.
        semi_join = any(UNIQUE in q.schema.constraints for q in keys[RIGHT])
        if self.__table_schema.kind != TableKind.REGULAR:
            assert join.left.get_table_schema(ctx).kind != TableKind.REGULAR
            assert join.right.get_table_schema(ctx).kind == TableKind.REGULAR
            if not semi_join:
                raise InvalidQueryException(
                    "a table of kind {} can be joined to a table of kind {} for "
                    "result type {} only if the join condition is unique".format(
                        join.left.get_table_schema(ctx),
                        join.right.get_table_schema(ctx),
                        self.__table_schema))
        op = join.op
        m = partial(JoinMetaQuery,
                    keys[LEFT],
                    keys[RIGHT],
                    key_exprs[2],
                    op)
        left = (None if (semi_join and op == "left")
                else m(JoinMeta.LEFT_INDEXER))
        return left, m(JoinMeta.RIGHT_INDEXER)

    @override
    def _take_for_join(self, is_right, base_query):
        indexer = self.__indexer_array[is_right]
        if indexer is None:
            # Left side of a LEFT semi-join: rows are used as-is.
            return base_query
        if not is_right and self._join.op in ("inner", "left"):
            # NOTE(review): presumably the left indexer is non-decreasing
            # for inner/left joins, which is what permits take_sequential.
            return base_query.take_sequential(indexer)
        return base_query.take(indexer)

    @override
    def _get_sub_sources(self, ctx):
        return self._join.left, self._join.right

    @override
    def make_count_query(self, ctx):
        # TODO(dancol): OR-node that lets us use either the left or right
        # indexer depending on which we need
        indexers = self.__indexer_array
        indexer = (indexers[LEFT] if indexers[LEFT] is not None
                   else indexers[RIGHT])
        return indexer.countq()

    @override
    def _make_query_for_self(self, ctx, column_name):
        join = self._join
        if isinstance(join.on, ColumnNameList):
            try:
                index = join.on.columns.index(column_name)
            except ValueError:
                # tuple.index() reports a missing item with ValueError;
                # fall through to the base class's error.
                index = -1
            if index >= 0:
                op = join.op
                if op == "inner":
                    # TODO(dancol): OR-node: left or right would work
                    return self.__key_exprs[0][index].to_query(ctx, self)
                if op == "left":
                    return self.__key_exprs[0][index].to_query(ctx, self)
                if op == "right":
                    return self.__key_exprs[1][index].to_query(ctx, self)
                assert op == "outer"
                return CoalesceQuery.of(
                    self.__key_exprs[0][index].to_query(ctx, self),
                    self.__key_exprs[1][index].to_query(ctx, self))
        return super()._make_query_for_self(ctx, column_name)
| |
@final
class JoinInfoSpan(JoinInfo):
    """Join control for a span join

    Handles the "span join" and "span broadcast" join kinds. Both inputs
    must be span tables; each side is viewed in time-major order and the
    two are aligned by a SpanJoin query built from their meta columns.
    """

    # TODO(dancol): implement join support for partition-major span
    # tables instead of converting to time-major tables.

    @override
    def __init__(self, ctx, join):
        # Validate both inputs and compute the output table schema.
        # Raises InvalidQueryException for an ON clause, a non-span
        # input, or partitioning that does not fit the join kind.
        super().__init__(ctx, join)
        if join.on:
            raise InvalidQueryException("SPAN JOIN...ON illegal here")
        base_sources = (join.left, join.right)
        schemas = [bs.get_table_schema(ctx) for bs in base_sources]
        for sub_schema in schemas:
            if sub_schema.kind != TableKind.SPAN:
                # NOTE(review): schemas.index() finds the first *equal*
                # schema; if both schemas compare equal (so both are
                # non-span) the left side is the one reported.
                raise InvalidQueryException(
                    ("table expression is not a span table: {!r} "
                     "(only UNIQUE non-span tables may be joined with span tables, "
                     "and then only in INNER or LEFT join modes)"
                    ).format(base_sources[schemas.index(sub_schema)]))
        kind = join.kind
        if kind == "span join":
            # Both sides must be partitioned the same way: both
            # unpartitioned, or both partitioned by the same column.
            if bool(schemas[LEFT].partition) != bool(schemas[RIGHT].partition):
                raise InvalidQueryException(dedent("""\
                    Attempt to SPAN JOIN a partitioned span table with
                    a non-partitioned span table. This operation makes no
                    sense. Either collapse the partition using suitable grouping
                    operators or use BROADCAST to apply a non-partitioned set of
                    spans to partitioned data."""))
            if schemas[LEFT].partition == schemas[RIGHT].partition:
                our_partition = schemas[LEFT].partition
            else:
                raise InvalidQueryException(
                    "left and right partition columns differ ({!r} vs {!r}): "
                    "rename one table's partition".format(
                        schemas[LEFT].partition, schemas[RIGHT].partition))
        elif kind == "span broadcast":
            # A non-partitioned LHS is applied to a partitioned RHS; the
            # output takes the RHS partition.
            if schemas[LEFT].partition:
                raise InvalidQueryException(
                    "SPAN BROADCAST needs non-partioned LHS")
            if not schemas[RIGHT].partition:
                raise InvalidQueryException(
                    "SPAN BROADCAST needs partitioned RHS")
            our_partition = schemas[RIGHT].partition
        else:
            raise AssertionError(
                "impossible span join type {!r}".format(join.kind))
        # We enforce the time-major ordering thing inside
        # indexer construction.
        self.__our_table_schema = (
            TableSchema(TableKind.SPAN, our_partition, TableSorting.NONE)
            if our_partition else SPAN_UNPARTITIONED_TIME_MAJOR)

    @override
    def get_table_schema(self):
        return self.__our_table_schema

    @override
    def push_conditions(self, ctx, conjunctions):
        # A span join has no join condition of its own: route each
        # distinct conjunction to the side whose table expressions it
        # uses exclusively, and return the rest (plus whatever the sides
        # could not apply) as excess.
        conjunctions_left = []
        conjunctions_right = []
        extra_conjunctions = []
        join = self._join
        left_te_id = join.left.te_ids
        right_te_id = join.right.te_ids
        for expr in _drop_duplicates(conjunctions):
            if expr.te_ids_used(ctx) <= left_te_id:
                conjunctions_left.append(expr)
            elif expr.te_ids_used(ctx) <= right_te_id:
                conjunctions_right.append(expr)
            else:
                extra_conjunctions.append(expr)
        extra_conjunctions += \
            join.left.push_conditions(ctx, conjunctions_left)
        extra_conjunctions += \
            join.right.push_conditions(ctx, conjunctions_right)
        return extra_conjunctions

    @cached_property
    def __sub_sources(self):
        # Each input viewed as a time-major span table (wrapped in a
        # re-sorting query maker when its own order differs).
        ctx = self._ctx
        join = self._join
        return [
            bs.to_schema(ctx,
                         kind=TableKind.SPAN,
                         sorting=TableSorting.TIME_MAJOR)
            for bs in (join.left, join.right)]

    @cached_property
    def __span_join(self):
        # Build the SpanJoin from each side's meta columns, read through
        # the time-major sub-sources.
        join = self._join
        ctx = self._ctx
        base_sources = (join.left, join.right)
        crs = [[ColumnReference(meta_column_name).self_resolve(ctx, bs)
                for meta_column_name in bs.get_table_schema(ctx).meta_columns]
               for bs in base_sources]
        kind = join.op
        # Per-side "required" flags: left for inner/left joins, right for
        # inner/right joins.
        required = [kind in ("inner", "left"), kind in ("inner", "right")]
        return SpanJoin([
            SpanJoin.Source(
                SpanTableConfig(
                    *[meta_cr.to_query(ctx, ss) for meta_cr in cr]),
                required=r)
            for cr, ss, r in zip(crs, self.__sub_sources, required)
        ])

    @cached_property
    def __indexer_array(self):
        # (left_indexer, right_indexer) from the span join.
        span_join = self.__span_join
        return span_join.indexerq(0), span_join.indexerq(1)

    @override
    def _take_for_join(self, is_right, base_query):
        # TODO(dancol): we can use take_sequential even in the partitioned
        # case when we support partition-major ordering for span join.
        indexer = self.__indexer_array[is_right]
        if not indexer.config.sources[is_right].span.partition:
            return base_query.take_sequential(indexer)
        return base_query.take(indexer)

    @override
    def _get_sub_sources(self, ctx):
        return self.__sub_sources

    def __metaq(self, meta):
        # Query for one of the span join's generated meta columns.
        return self.__span_join.metaq(meta)

    @override
    def make_count_query(self, ctx):
        # The output has one row per generated timestamp.
        return self.__metaq(SpanJoin.Meta.TIMESTAMP).countq()

    @override
    def _make_query_for_self(self, ctx, column_name):
        # The join itself supplies _ts, _duration, and the partition
        # column (if any); anything else falls back to the base class.
        if column_name == "_ts":
            return self.__metaq(SpanJoin.Meta.TIMESTAMP)
        if column_name == "_duration":
            return self.__metaq(SpanJoin.Meta.DURATION)
        if column_name == self.__our_table_schema.partition:
            return self.__metaq(SpanJoin.Meta.PARTITION)
        return super()._make_query_for_self(ctx, column_name)
| |
@final
class JoinInfoEvent(JoinInfo):
    """Join control for events

    Implements EVENT JOIN and EVENT BROADCAST.  The left input is converted
    to a time-ordered event table and the right input to a time-ordered
    span table; the joined result is an event table.
    """

    @override
    def __init__(self, ctx, join):
        """Validate JOIN and compute the output table schema.

        Raises InvalidQueryException for an ON clause, for joins that are
        not inner/left with respect to the event side, and for
        incompatible partitioning of the two inputs.
        """
        super().__init__(ctx, join)
        if join.on:
            raise InvalidQueryException("EVENT JOIN ... ON illegal here")
        base_sources = (join.left, join.right)
        left_table_schema, right_table_schema = [
            bs.get_table_schema(ctx) for bs in base_sources]
        left_partition = left_table_schema.partition
        right_partition = right_table_schema.partition
        event_is_required = join.op in ("inner", "left")
        if not event_is_required:
            raise InvalidQueryException(
                "EVENT JOINs must be inner or left joins "
                "with respect to events, since it's "
                "not possible to condense a span into a point")
        span_is_required = join.op in ("inner", "right")
        kind = join.kind
        if kind == "event join":
            if not left_partition and right_partition:
                raise InvalidQueryException(
                    "cannot join non-partitioned event table and partitioned "
                    "span table: use BROADCAST instead")
            if left_partition == right_partition:
                our_partition = left_partition
            elif left_partition and not right_partition:
                our_partition = left_partition
            else:
                # Report the partition column names, not the whole schemas,
                # since that is what the user has to rename.
                raise InvalidQueryException(
                    "left and right partition columns differ ({!r} vs {!r}): "
                    "rename one input's partition".format(
                        left_partition, right_partition))
        elif kind == "event broadcast":
            if left_partition:
                raise InvalidQueryException(
                    "EVENT BROADCAST needs non-partitioned LHS")
            if not right_partition:
                raise InvalidQueryException(
                    "EVENT BROADCAST needs partitioned RHS")
            our_partition = right_partition
        else:
            raise AssertionError(
                "impossible event join type {!r}".format(join.kind))

        # TODO(dancol): We could be unconditionally
        # TableSorting.TIME_MAJOR if the event join code would always emit
        # partitions in order. Benchmark it both ways.

        self.__span_is_required = span_is_required
        self.__our_table_schema = TableSchema(
            TableKind.EVENT,
            our_partition,
            TableSorting.NONE if our_partition else TableSorting.TIME_MAJOR)

    @override
    def get_table_schema(self):
        return self.__our_table_schema

    @override
    def push_conditions(self, ctx, conjunctions):
        """Push single-side conjunctions into the matching input.

        Conjunctions referring only to the left (or only to the right)
        input are pushed into that input.  Whatever the inputs decline to
        absorb, plus conjunctions spanning both sides, are returned.
        """
        # NOTE(review): when the span side is optional (span_is_required is
        # false, i.e. a LEFT event join), pushing a span-only predicate
        # below the join filters spans before matching instead of filtering
        # joined rows; confirm this is the intended semantics.
        conjunctions_left = []
        conjunctions_right = []
        extra_conjunctions = []
        join = self._join
        left_te_id = join.left.te_ids
        right_te_id = join.right.te_ids
        for expr in _drop_duplicates(conjunctions):
            te_ids_used = expr.te_ids_used(ctx)
            if te_ids_used <= left_te_id:
                conjunctions_left.append(expr)
            elif te_ids_used <= right_te_id:
                conjunctions_right.append(expr)
            else:
                extra_conjunctions.append(expr)
        extra_conjunctions += \
            join.left.push_conditions(ctx, conjunctions_left)
        extra_conjunctions += \
            join.right.push_conditions(ctx, conjunctions_right)
        return extra_conjunctions

    @cached_property
    def __indexer_array(self):
        # Element 0 gathers rows from the event (left) input, element 1
        # from the span (right) input; see _take_for_join.
        return (self.__event_join.metaq(EventJoin.Meta.EVENT_INDEX),
                self.__event_join.metaq(EventJoin.Meta.SPAN_INDEX))

    @override
    def _take_for_join(self, is_right, base_query):
        indexer = self.__indexer_array[is_right]
        # TODO(dancol): too conservative?
        # Sequential take only when neither side is partitioned.
        config = indexer.config
        if not config.event.partition and not config.span.partition:
            return base_query.take_sequential(indexer)
        return base_query.take(indexer)

    @cached_property
    def __event_join(self):
        join = self._join
        ctx = self._ctx
        base_sources = (join.left, join.right)
        # Resolve each input's meta columns against that input.
        crs = [[ColumnReference(meta_column_name).self_resolve(ctx, bs)
                for meta_column_name in bs.get_table_schema(ctx).meta_columns]
               for bs in base_sources]
        ss = self.__sub_sources
        return EventJoin(
            event=EventTableConfig(
                *[meta_cr.to_query(ctx, ss[LEFT])
                  for meta_cr in crs[LEFT]]),
            span=SpanTableConfig(
                *[meta_cr.to_query(ctx, ss[RIGHT])
                  for meta_cr in crs[RIGHT]]),
            span_is_required=self.__span_is_required)

    @cached_property
    def __sub_sources(self):
        # Left input as a time-ordered event table, right input as a
        # time-ordered span table.
        ctx = self._ctx
        join = self._join
        return [
            join.left.to_schema(ctx,
                                kind=TableKind.EVENT,
                                sorting=TableSorting.TIME_MAJOR),
            join.right.to_schema(ctx,
                                 kind=TableKind.SPAN,
                                 sorting=TableSorting.TIME_MAJOR),
        ]

    @override
    def _get_sub_sources(self, ctx):
        return self.__sub_sources

    @override
    def make_count_query(self, ctx):
        # TODO(dancol): OR-node
        return self.__event_join.metaq(EventJoin.Meta.TIMESTAMP).countq()

    @override
    def _make_query_for_self(self, ctx, column_name):
        if column_name == "_ts":
            return self.__event_join.metaq(EventJoin.Meta.TIMESTAMP)
        if column_name == self.__our_table_schema.partition:
            return self.__event_join.metaq(EventJoin.Meta.PARTITION)
        return super()._make_query_for_self(ctx, column_name)
| |
def _span_join_info_constructor(ctx, join):
    """Choose the join-control object for a SPAN JOIN.

    A span-join of a span table to a non-span table is not allowed in
    general, because the join might duplicate span rows and so produce an
    invalid span table.  That concern does not apply to a unique join, so
    when the left side is a span table, the right side is not, and the
    join is inner or left, a regular (classic) join is used instead.
    JoinInfoClassic receives the left schema for that case (presumably it
    enforces the uniqueness requirement -- confirm).  Every other case is
    a true span join.
    """
    left_table_schema = join.left.get_table_schema(ctx)
    right_table_schema = join.right.get_table_schema(ctx)
    left_is_span, right_is_span = (
        schema.kind == TableKind.SPAN
        for schema in (left_table_schema, right_table_schema))
    use_classic = (left_is_span and
                   not right_is_span and
                   join.op in ("inner", "left"))
    if use_classic:
        return JoinInfoClassic(ctx, join, left_table_schema)
    return JoinInfoSpan(ctx, join)
| |
# Factory for the join-control object, keyed by Join.kind.  Join indexes
# this table with its own kind; None (the default kind) selects the
# classic join.
_JOIN_CONSTRUCTORS = {
    None: JoinInfoClassic,
    "span join": _span_join_info_constructor,
    "span broadcast": JoinInfoSpan,
    "event join": JoinInfoEvent,
    "event broadcast": JoinInfoEvent,
}
| |
@final
class Join(TableExpression):
    """A join of two table expressions, itself usable as a table expression

    Nearly all behavior is delegated to a join-control object built from
    _JOIN_CONSTRUCTORS (keyed on ``kind``) and memoized in ctx.join_info.
    """
    left = iattr(TableExpression)
    op = iattr(str)
    right = iattr(TableExpression)
    on = iattr((Expression, ColumnNameList), nullable=True, default=None)
    kind = iattr(str, nullable=True, default=None)

    def __get_join_info(self, ctx):
        """Return the memoized join-control object, building it if absent"""
        info = ctx.join_info.get(self)
        if info:
            return info
        info = _JOIN_CONSTRUCTORS[self.kind](ctx, self)
        ctx.join_info[self] = info
        return info

    @override
    def push_conditions(self, ctx, conjunctions):
        info = self.__get_join_info(ctx)
        return info.push_conditions(ctx, conjunctions)

    @override
    def te_walk(self):
        # Pre-order: this node, then the left subtree, then the right one.
        yield self
        for child in (self.left, self.right):
            yield from child.te_walk()

    @override
    def self_resolve_te(self, ctx):
        for child in (self.left, self.right):
            child.self_resolve_te(ctx)
        on_clause = self.on
        if isinstance(on_clause, Expression):
            on_clause.self_resolve(ctx, self)

    def __get_self_columns(self, ctx):
        """Columns that the join itself provides (per the join-control)"""
        info = self.__get_join_info(ctx)
        return info.get_self_columns()

    @override
    def match_column(self, ctx, cr):
        info = self.__get_join_info(ctx)
        yield from info.match_column(ctx, cr)

    @override
    def gen_table_paths(self, ctx):
        yield from super().gen_table_paths(ctx)
        for child in (self.left, self.right):
            yield from child.gen_table_paths(ctx)

    @override
    def gen_wildcard_matches(self, ctx, path):
        own_columns = self.__get_self_columns(ctx)
        if not own_columns:
            # Nothing join-specific: defer entirely to the two inputs.
            for child in (self.left, self.right):
                yield from child.gen_wildcard_matches(ctx, path)
            return
        if self.match_cr_path(path):
            for column in own_columns:
                yield column, self
        if self.name and len(path) == 1 and path[0] == self.name:
            candidates = self.gen_wildcard_matches(ctx, ())
        else:
            candidates = chain(self.left.gen_wildcard_matches(ctx, path),
                               self.right.gen_wildcard_matches(ctx, path))
        # Columns the join itself provides are only reported through the
        # branch above, never as pass-through matches.
        yield from ((column, te) for column, te in candidates
                    if column not in own_columns)

    @override
    def make_query(self, ctx, table_reference, column_name):
        info = self.__get_join_info(ctx)
        return info.make_query(ctx, table_reference, column_name)

    @override
    def make_count_query(self, ctx):
        info = self.__get_join_info(ctx)
        return info.make_count_query(ctx)

    @override
    def get_table_schema(self, ctx):
        info = self.__get_join_info(ctx)
        return info.get_table_schema()
| |
| |
| |
class BaseAggregatingQueryMaker(ExpressionQueryMaker):
    """Common functionality for aggregation and grouping

    Turns aggregate function calls into queries.  Subclasses decide how an
    aggregation over a base query is built (_make_aggregation_query) and
    what the per-group row count is (make_group_sizes_query).  Bare column
    references are rejected by default.
    """

    @override
    def __init__(self, base_qm):
        assert isinstance(base_qm, ExpressionQueryMaker)
        self._base_qm = base_qm

    @abstract
    def make_group_sizes_query(self, ctx):
        """Make a query yielding the size of each group"""
        raise NotImplementedError("abstract")

    @abstract
    def _make_aggregation_query(self, base_query, distinct, function):
        """Make a query applying aggregate FUNCTION (a name) to BASE_QUERY

        DISTINCT is the call's distinct modifier, or None for ALL.
        """
        raise NotImplementedError("abstract")

    @override
    def make_funcall_query(self, ctx, funcall):
        """Build a query for FUNCALL.

        Non-aggregate functions go to the base implementation.  COUNT with
        no arguments becomes the group-size query.  Any other aggregate
        must take exactly one simple expression.

        Raises InvalidQueryException on a bad aggregate argument list.
        """
        function = ctx.resolve_funcall(funcall)
        if not isinstance(function, AggregationFunction):
            return super().make_funcall_query(ctx, funcall)
        arguments = funcall.arguments
        distinct = funcall.distinct
        if distinct == "all":
            distinct = None
        if not arguments and function.name == "count":
            # Special case for COUNT(*)
            countq = self.make_group_sizes_query(ctx)
            assert countq.schema.is_integral, \
                "bad COUNT(*) query from {!r}".format(self)
            return countq
        # Zero arguments (other than COUNT(*)) is also an error; checking
        # only "> 1" let e.g. SUM() fall through to an IndexError below.
        if len(arguments) != 1:
            raise InvalidQueryException("aggregate functions take one argument")
        argument = arguments[0]
        if not isinstance(argument, Expression):
            raise InvalidQueryException(
                "aggregate function argument must be simple expression, not {}"
                .format(argument))
        return self._make_aggregation_query(
            argument.to_query(ctx, self._base_qm), distinct, function.name)

    @override
    def make_query(self, ctx, table_reference, column_name):
        # TODO(dancol): support bare columns, SQLite-style
        raise InvalidQueryException(
            "column reference {}/{} does not have an aggregate".format(
                table_reference, column_name))
| |
@final
class GroupingQueryMaker(BaseAggregatingQueryMaker):
    """Aggregation within groups defined by GROUP BY queries"""

    @override
    def __init__(self, base_qm, group_by_queries):
        super().__init__(base_qm)
        assert assert_seq_type(tuple, QueryNode, group_by_queries)
        self.__group_by_queries = group_by_queries

    @override
    def make_query(self, ctx, table_reference, column_name):
        # A column that is itself one of the grouping keys may be used
        # directly; it becomes a per-group label query.
        underlying = self._base_qm.make_query(ctx, table_reference, column_name)
        keys = self.__group_by_queries
        if underlying not in keys:
            return super().make_query(ctx, table_reference, column_name)
        return GroupLabelsQuery(keys, underlying)

    @override
    def _make_aggregation_query(self, base_query, distinct, aggfunc):
        keys = self.__group_by_queries
        return NativeGroupedAggregationQuery(group_by=keys,
                                             data=base_query,
                                             aggregation=aggfunc,
                                             distinct=distinct)

    @override
    def make_group_sizes_query(self, ctx):
        keys = self.__group_by_queries
        return GroupSizesQuery(keys)

    @override
    def make_count_query(self, _ctx):
        keys = self.__group_by_queries
        return GroupCountQuery(keys)
| |
@final
class AggregatingQueryMaker(BaseAggregatingQueryMaker):
    """Aggregation that collapses the entire result set into one row"""

    @override
    def _make_aggregation_query(self, base_query, distinct, aggfunc):
        return NativeUngroupedAggregationQuery(data=base_query,
                                               aggregation=aggfunc,
                                               distinct=distinct)

    @override
    def make_group_sizes_query(self, ctx):
        # There is one group: all of the base rows.
        base_qm = self._base_qm
        return base_qm.make_count_query(ctx)

    @override
    def make_count_query(self, _ctx):
        # The aggregated output always has exactly one row.
        single_row = QueryNode.scalar(1)
        return single_row
| |
@final
class DepartitioningQueryMaker(BaseAggregatingQueryMaker):
    """Aggregation across the partitions of a span table (via SpanPivot)"""

    @override
    def __init__(self, base_qm, config):
        super().__init__(base_qm)
        assert isinstance(base_qm, MagicGroupingBaseSource)
        self.__config = the(SpanPivot, config)

    @override
    def _make_aggregation_query(self, base_query, distinct, aggfunc):
        pivot = self.__config
        return pivot.dataq(aggfunc, base_query, distinct)

    @override
    def make_group_sizes_query(self, ctx):
        # Every output row reports the number of distinct partition values.
        distinct_partitions = DropDuplicatesQuery.of1(
            self.__config.grouped.partition)
        nr_partitions_q = distinct_partitions.countq()
        return QueryNode.filled(nr_partitions_q, self.make_count_query(ctx))

    @override
    def make_count_query(self, ctx):
        # TODO(dancol): OR-node, when we get a grown-up optimizer
        pivot = self.__config
        timestamp_q = pivot.metaq(pivot.Meta.TIMESTAMP)
        return timestamp_q.countq()

    @override
    def make_query(self, ctx, table_reference, column_name):
        pivot = self.__config
        if table_reference is not self._base_qm:
            return super().make_query(ctx, table_reference, column_name)
        if column_name == "_ts":
            return pivot.metaq(pivot.Meta.TIMESTAMP)
        if column_name == "_duration":
            return pivot.metaq(pivot.Meta.DURATION)
        is_output_partition = (
            pivot.output_partition and
            column_name == self._base_qm.get_table_schema(ctx).partition)
        if is_output_partition:
            return pivot.metaq(pivot.Meta.OUTPUT_PARTITION)
        return super().make_query(ctx, table_reference, column_name)
| |
class SelectCore(AstNode):
    """Abstract operand of a (possibly compound) SELECT"""

    @abstract
    def make_qt(self, tctx, ob_asts):
        """Build the QueryTable for this core.

        Implementations return a pair: the QueryTable and a sequence of
        queries for the ORDER BY expressions in OB_ASTS.
        """
        raise NotImplementedError("abstract")
| |
@final
class RowValue(AstNode):
    """A single row in a VALUES list"""
    # Expressions of the row, one per output column.
    items = tattr(Expression)
| |
@final
class TableValues(SelectCore):
    """Table built from a VALUES list of literal rows

    Columns are named col0, col1, ... in order.
    """
    row_values = tattr(RowValue, nonempty=True)

    @override
    def make_qt(self, tctx, ob_asts):
        # pylint: disable=protected-access
        if ob_asts:
            raise NotImplementedError(  # Why bother with the complexity?
                "wrap VALUES in a SELECT to sort them")
        simplified = self.simplify()
        return simplified.__make_table(tctx), ()

    @staticmethod
    def __values_to_query(tctx, column_values):
        """Concatenate the queries for one column's VALUES entries.

        Runs of consecutive AstLiteral entries are folded into a single
        QueryNode.literals query; other expressions become individual
        queries.
        """
        # TODO(dancol): remove this code once we have "query facts"
        # Ideally the query engine would coalesce literal VALUES entries on
        # its own.  It can't yet: although the result is a concatenation
        # whose size is known in advance, there is no general way to tell
        # the optimizer that some queries share a size equivalence class.
        # Without this special case we'd end up with a naive concatenation
        # of individual scalar queries, further confusing the "query
        # optimizer".
        parts = []
        pending = []
        for expr in column_values:
            assert isinstance(expr, Expression)
            if isinstance(expr, AstLiteral):
                pending.append(expr.evaluate_tvf(tctx))
                continue
            if pending:
                parts.append(QueryNode.literals(*pending))
                pending = []
            parts.append(expr.to_query_uncoordinated(tctx))
        if pending:
            parts.append(QueryNode.literals(*pending))
        return QueryNode.concat(*parts)

    def __make_table(self, tctx):
        """Transpose the rows into columns and build the query table"""
        rows = self.row_values
        if not all_same(len(row.items) for row in rows):
            raise InvalidQueryException(
                "all literal values in VALUES must have the same length")
        width = len(rows[0].items)
        if width == 0:
            raise InvalidQueryException("VALUES with no columns")
        columns = tuple(zip(*[row.items for row in rows]))
        return GenericQueryTable([
            ("col{}".format(index), self.__values_to_query(tctx, values))
            for index, values in enumerate(columns)
        ])
| |
@final
class Compound(SelectCore):
    """Set operation (UNION, UNION ALL, INTERSECT, EXCEPT) on two selects"""
    left = iattr(SelectCore)
    op = iattr(str)
    right = iattr(SelectCore)

    @override
    def make_qt(self, tctx, ob_asts):
        left_qt, left_ob_queries = self.left.make_qt(tctx, ob_asts)
        right_qt, right_ob_queries = self.right.make_qt(tctx, ob_asts)
        if (left_qt.table_schema.kind != TableKind.REGULAR or
                right_qt.table_schema.kind != TableKind.REGULAR):
            raise InvalidQueryException(
                "compound queries supported only on regular tables")
        if len(left_qt.columns) != len(right_qt.columns):
            raise InvalidQueryException(
                "incompatible sub-selections in compound select")
        assert len(left_ob_queries) == len(right_ob_queries)
        op = self.op
        if op in ("union", "union all"):
            return self.__make_union(op, left_qt, right_qt,
                                     left_ob_queries, right_ob_queries)
        if op == "intersect":
            return self.__make_intersect(left_qt, right_qt, left_ob_queries)
        assert op == "except"  # N.B. Anti-symmetric set difference!
        return self.__make_except(left_qt, right_qt, left_ob_queries)

    @staticmethod
    def __make_union(op, left_qt, right_qt, left_ob_queries, right_ob_queries):
        """UNION [ALL]: concatenate matching columns of both sides"""
        columns = [
            (left_name, QueryNode.concat(left_query, right_query))
            for (left_name, left_query), (_right_name, right_query)
            in zip(left_qt.items(), right_qt.items())
        ]
        ob_queries = [QueryNode.concat(left_ob, right_ob)
                      for left_ob, right_ob
                      in zip(left_ob_queries, right_ob_queries)]
        if op == "union":
            # Plain UNION drops duplicates; ORDER BY queries must be
            # thinned using the not-yet-deduplicated columns.
            ob_queries = _distinctify_ob_queries(columns, ob_queries)
            columns = _distinctify_column_list(columns)
        return GenericQueryTable(columns), ob_queries

    @staticmethod
    def __make_intersect(left_qt, right_qt, left_ob_queries):
        """INTERSECT: keep left rows that also occur on the right"""
        # TODO(dancol): be faster. A join is a stupidly expensive way
        # to implement the intersect operation; we can just use two hash
        # tables, like in drop_duplicates. No, we can't just use a
        # Pandas MultIndex: MultiIndex implements set operations by
        # Python set object operations on tuples!
        left_key = tuple(map(second, left_qt.items()))
        right_key = tuple(map(second, right_qt.items()))
        # Only the left indexer is needed: every value in the
        # intersection exists on both sides, so we take it from the left.
        # TODO(dancol): use an OR-node to pick a side
        indexer = JoinMetaQuery(left_key, right_key,
                                (True,) * len(left_key),
                                "inner", JoinMeta.LEFT_INDEXER)
        # With an inner join, the LHS indices are sequential.
        columns = [(name, query.take_sequential(indexer))
                   for name, query in left_qt.items()]
        ob_queries = [query.take_sequential(indexer)
                      for query in left_ob_queries]
        return (GenericQueryTable(_distinctify_column_list(columns)),
                _distinctify_ob_queries(columns, ob_queries))

    @staticmethod
    def __make_except(left_qt, right_qt, left_ob_queries):
        """EXCEPT: keep left rows that have no match on the right"""
        # TODO(dancol): Don't implement with join.
        left_key = tuple(map(second, left_qt.items()))
        right_key = tuple(map(second, right_qt.items()))
        left_indexer = JoinMetaQuery(left_key, right_key,
                                     (True,) * len(left_key),
                                     "left", JoinMeta.LEFT_INDEXER)
        right_indexer = JoinMetaQuery(left_key, right_key,
                                      (True,) * len(left_key),
                                      "left", JoinMeta.RIGHT_INDEXER)
        # In this left join the right indexer holds, for each joined row,
        # the matching row number in the right table, or -1 when the right
        # side lacked the value.  For the anti-symmetric difference we keep
        # exactly the left rows whose right indexer is negative.

        # TODO(dancol): add optimization for the case where we're doing a
        # binary operation and broadcasting one side to the same size as
        # the other. Also, an OR-node for counting the left or
        # right indexer.
        is_right_missing = right_indexer < 0
        indexer = left_indexer.where(is_right_missing)
        columns = [(name, query.take_sequential(indexer))
                   for name, query in left_qt.items()]
        ob_queries = [query.take(indexer) for query in left_ob_queries]
        return (GenericQueryTable(_distinctify_column_list(columns)),
                _distinctify_ob_queries(columns, ob_queries))
| |
class FilledQueryTable(QueryTable):
    """Query table of COUNT rows that carry no columns

    Column-less rows are occasionally useful in combination with other
    query operations.  For example, this query yields six rows, each
    holding the value 5:

      SELECT 5 AS value FROM dctv.filled(6)
    """

    # Number of rows; a single row unless specified.
    count = iattr_query_node_int(default=QueryNode.scalar(1))

    @override
    def _make_column_tuple(self):
        no_columns = ()
        return no_columns

    @override
    def _make_column_query(self, column):
        # There are no columns, so every lookup fails.
        raise KeyError(column)

    @override
    def countq(self):
        row_count = self.count
        return row_count
| |
@final
class DummyTableExpression(QueryTableExpression):
    """Stand-in table expression for queries without a source table"""

    @override
    def evaluate_tvf(self, tctx):
        # A default FilledQueryTable: no columns, and a row count given by
        # its default ``count`` (scalar 1).
        placeholder = FilledQueryTable()
        return placeholder
| |
@final
class FilteringExpressionQueryMaker(ExpressionQueryMaker):
    """Filters another EQM

    Every column query from the wrapped query maker is restricted to the
    rows where a boolean filter query is true.
    """

    @override
    def __init__(self, unfiltered_qm, q_filter):
        """Wrap UNFILTERED_QM, restricting its rows with Q_FILTER.

        Raises InvalidQueryException unless Q_FILTER's schema has dtype
        BOOL and is normal.
        """
        assert isinstance(unfiltered_qm, ExpressionQueryMaker)
        assert isinstance(q_filter, QueryNode)
        self.__unfiltered_qm = unfiltered_qm
        filter_schema = q_filter.schema
        if filter_schema.dtype != BOOL or not filter_schema.is_normal:
            # The schema belongs in the middle of the message; the old text
            # read "cannot use query of type as filter: <schema>".
            raise InvalidQueryException(
                "cannot use query of type {} as filter".format(filter_schema))
        self.__q_filter = q_filter

    @override
    def make_query(self, ctx, table_reference, column_name):
        return self.__unfiltered_qm.make_query(
            ctx, table_reference, column_name).where(self.__q_filter)

    @override
    def make_count_query(self, _ctx):
        # Summing the boolean filter counts the rows that pass it.
        return NativeUngroupedAggregationQuery(
            aggregation="sum",
            data=self.__q_filter)
| |
def _distinctify_column_list(columns):
    """Wrap each (name, query) column in a DropDuplicatesQuery.

    Every wrapper is given the same group -- the frozenset of all column
    queries -- presumably so duplicates are judged on whole rows.
    """
    whole_row = frozenset(query for _name, query in columns)
    distinct_columns = []
    for name, query in columns:
        distinct_columns.append((name, DropDuplicatesQuery(whole_row, query)))
    return distinct_columns
| |
def _distinctify_ob_queries(columns, ob_queries):
    """Thin OB_QUERIES to the rows that survive deduplicating COLUMNS.

    COLUMNS are the (name, query) pairs before deduplication; the same
    drop-duplicates indexer is applied to every ORDER BY query.
    """
    whole_row = frozenset(query for _name, query in columns)
    survivors = DropDuplicatesIndexerQuery(whole_row)
    return [query.take_sequential(survivors) for query in ob_queries]
| |
def _make_conjunction_filter(ctx, source, conjunctions):
    """AND together the non-empty CONJUNCTIONS; return it as a query on SOURCE"""
    assert conjunctions
    combined = reduce(BinaryOperation.and_, conjunctions)
    filter_query = combined.to_query(ctx, source)
    return filter_query
| |
class GroupBy(AstNode):
    """Controls grouping in a RegularSelect"""

    def simplify_expressions(self, ctx, source, subst_column_references):
        """Simplify any columns to which we refer

        The base implementation has nothing to simplify.
        """

    def override_base_source(self, ctx, source):  # pylint: disable=unused-argument,no-self-use
        """Return an override for the base select column source

        By default SOURCE is returned unchanged.
        """
        return source

    @abstract
    def make_query_maker(self, ctx, query_maker, out_table_schema):
        """Make the column data source for this grouping"""
        raise NotImplementedError("abstract")
| |
@final
class GroupByExpressions(GroupBy):
    """Group by a list of column expressions"""
    expressions = tattr(Expression, nonempty=True)

    @override
    def simplify(self, subst=None):
        # Simplification is deferred to simplify_expressions.
        return self

    @override
    def simplify_expressions(self, ctx, source, subst_column_references):
        """Simplify and resolve each grouping expression against SOURCE.

        Stores the result in ctx.simplified_gb_expressions, which must not
        already be set.
        """
        assert not ctx.simplified_gb_expressions
        simplified = []
        for expr in self.expressions:
            expr = expr.simplify(subst_column_references)
            expr.self_resolve(ctx, source)
            simplified.append(expr)
        ctx.simplified_gb_expressions = tuple(simplified)

    @override
    def make_query_maker(self, ctx, query_maker, out_table_schema):
        if out_table_schema.kind != TableKind.REGULAR:
            raise InvalidQueryException("{} incompatible with GROUP BY"
                                        .format(out_table_schema))
        group_by_queries = tuple(
            expr.to_query(ctx, query_maker)
            for expr in ctx.simplified_gb_expressions)
        return GroupingQueryMaker(query_maker, group_by_queries)
| |
@final
class MagicGroupingBaseSource(TableExpression):
    """Base column source used for magic grouping operations

    "Magic" grouping operations have the JOIN-like property of adding
    columns to the SELECT result set.  GroupBy itself takes no part in
    column selection, so early in select processing the select core wraps
    its original source in this object, which exposes the meta columns of
    the grouping's table schema and forwards everything else to ``base``.
    """
    base = iattr(TableExpression)
    schema = iattr(TableSchema)

    @override
    def gen_wildcard_matches(self, ctx, path):
        table_schema = self.get_table_schema(ctx)
        if self.match_cr_path(path):
            for meta_column in table_schema.meta_columns:
                yield meta_column, self
        if self.name and len(path) == 1 and path[0] == self.name:
            candidates = self.gen_wildcard_matches(ctx, ())
        else:
            candidates = self.base.gen_wildcard_matches(ctx, path)
        # Meta columns are only reported through the branch above.
        for column, te in candidates:
            if column not in table_schema.meta_columns:
                yield column, te

    @override
    def gen_table_paths(self, ctx):
        yield from super().gen_table_paths(ctx)
        yield from self.base.gen_table_paths(ctx)

    @override
    def match_column(self, ctx, cr):
        if len(cr.path) == 1 and cr.path[0] == self.name:
            # Qualified by our own name: retry as a bare reference.
            yield from self.match_column(ctx, cr.as_bare_reference())
            return
        meta_columns = self.get_table_schema(ctx).meta_columns
        matched = next((meta_column for meta_column in meta_columns
                        if cr.is_bare_match(meta_column)), None)
        if matched is None:
            yield from self.base.match_column(ctx, cr)
        else:
            # Will be munged by the higher-level grouping source!
            yield self, matched

    @override
    def make_count_query(self, ctx):
        base_qm = ctx.join_info[self]
        return base_qm.make_count_query(ctx)

    @override
    def get_aggregate_argument_resolver(self):
        return self.base

    @override
    def push_conditions(self, ctx, conjunctions):
        """Push CONJUNCTIONS into the base; filter on whatever remains.

        Stores the resulting query maker in ctx.join_info and absorbs
        every conjunction.  Conjunctions that reference this source's
        meta columns are rejected.
        """
        assert self not in ctx.join_info
        leftover = self.base.push_conditions(ctx, conjunctions)
        base_qm = self.base.to_schema(ctx, sorting=TableSorting.TIME_MAJOR)
        if leftover:
            my_id = id(self)
            if any(my_id in conjunction.te_ids_used(ctx)
                   for conjunction in leftover):
                raise InvalidQueryException(
                    "invalid use of magic group metadata column in WHERE: "
                    "consider naming the base table explicitly instead")
            filter_q = _make_conjunction_filter(ctx, base_qm, leftover)
            base_qm = FilteringExpressionQueryMaker(base_qm, filter_q)
        ctx.join_info[self] = base_qm
        return ()

    @override
    def make_query(self, ctx, table_reference, column_name):
        if table_reference is self:
            raise AssertionError("invalid placeholder column reference {!r}"
                                 .format(column_name))
        base_qm = ctx.join_info[self]
        return base_qm.make_query(ctx, table_reference, column_name)

    @override
    def get_table_schema(self, _ctx):
        return self.schema
| |
@final
class GroupBySpanPartition(GroupBy):
    """Group by the internal partitions of a span table"""
    intersect = iattr(bool, default=True)  # TODO(dancol): change default
    output_partition = iattr(str, nullable=True, default=None)

    @override
    def override_base_source(self, ctx, source):
        schema = source.get_table_schema(ctx)
        if schema.kind != TableKind.SPAN:
            raise InvalidQueryException(
                "GROUP USING PARTITION works only with a span table")
        if not schema.partition:
            raise InvalidQueryException(
                "GROUP USING PARTITION cannot be applied to non-partitioned "
                "span table")
        if self.output_partition:
            grouped_schema = schema.evolve(partition=self.output_partition,
                                           sorting=TableSorting.NONE)
        else:
            grouped_schema = SPAN_UNPARTITIONED_TIME_MAJOR
        return MagicGroupingBaseSource(source, grouped_schema)

    @override
    def make_query_maker(self, ctx, query_maker, out_table_schema):
        if not isinstance(query_maker, MagicGroupingBaseSource):
            raise InvalidQueryException(
                "Unsupported use of GROUP USING PARTITION")
        source = query_maker.base
        source_table_schema = source.get_table_schema(ctx)
        assert source_table_schema.kind == TableKind.SPAN
        assert source_table_schema.partition

        def column_query(column_name):
            """Resolve COLUMN_NAME against the base and build its query"""
            reference = ColumnReference(column_name)
            reference.self_resolve(ctx, source)
            return reference.to_query(ctx, query_maker)

        # TODO(dancol): get rid of the union/intersect stuff entirely:
        # it's not that useful.

        partition_q = column_query(source_table_schema.partition)
        min_npartitions_q = (
            DropDuplicatesQuery.of1(partition_q).countq()
            if self.intersect else QueryNode.scalar(1))
        grouped = SpanTableConfig(column_query("_ts"),
                                  column_query("_duration"),
                                  partition_q)
        output_partition_q = (column_query(self.output_partition)
                              if self.output_partition else None)
        pivot = SpanPivot(grouped=grouped,
                          output_partition=output_partition_q,
                          min_npartitions=min_npartitions_q)
        return DepartitioningQueryMaker(query_maker, pivot)
| |
@final
class GroupBySpanQueryMaker(BaseAggregatingQueryMaker):
    """Builds the queries for grouping by explicit spans (SpanGroup)"""

    @override
    def __init__(self, base_qm, config):
        super().__init__(base_qm)
        self.__config = the(SpanGroup, config)

    @override
    def make_group_sizes_query(self, ctx):
        # Group size: COUNT over the grouped table's timestamps.
        grouped_ts = self.__config.grouped.ts
        return self._make_aggregation_query(grouped_ts, False, "count")

    @override
    def make_count_query(self, ctx):
        # TODO(dancol): opportunity for an OR-node
        timestamp_q = self.__config.metaq(SpanGroup.Meta.TIMESTAMP)
        return timestamp_q.countq()

    @override
    def _make_aggregation_query(self, base_query, distinct, aggfunc):
        group = self.__config
        return group.dataq(aggfunc, base_query, distinct)

    @cached_property
    def ts_q(self):
        """Timestamp query for the result set"""
        group = self.__config
        return group.metaq(SpanGroup.Meta.TIMESTAMP)

    @cached_property
    def partition_q(self):
        """Partition query for the result set"""
        group = self.__config
        return group.metaq(SpanGroup.Meta.PARTITION)

    @override
    def make_query(self, ctx, table_reference, column_name):
        if table_reference is not self._base_qm:
            return super().make_query(ctx, table_reference, column_name)
        if column_name == "_ts":
            return self.ts_q
        if column_name == "_duration":
            return self.__config.metaq(SpanGroup.Meta.DURATION)
        if column_name == self._base_qm.get_table_schema(ctx).partition:
            return self.partition_q
        return super().make_query(ctx, table_reference, column_name)
| |
@final
class GroupBySpanGrouper(GroupBy):
    """Group by an explicit list of spans

    ``grouper`` is a table expression that supplies the grouping spans.
    ``mode`` is "span" or "event" and states whether the grouped (base)
    table must be a span table or an event table.
    """
    grouper = iattr(TableExpression)
    mode = iattr(str)
    intersect = iattr(bool, default=False)

    def __get_grouper_qt(self, ctx):
        """Return the grouper as a time-ordered span query table.

        The table is built once and memoized in ctx.join_info under this
        node.
        """
        grouper_qt = ctx.join_info.get(self)
        # Test for absence explicitly: a cached table that merely
        # evaluates false must not be rebuilt and re-stored.
        if grouper_qt is None:
            grouper_qt_raw = ctx.te_to_qt(self.grouper)
            ctx.join_info[self] = grouper_qt = \
                grouper_qt_raw.to_schema(kind=TableKind.SPAN,
                                         sorting=TableSorting.TIME_MAJOR)
        return grouper_qt

    @override
    def override_base_source(self, ctx, source):
        """Validate SOURCE against ``mode`` and compute the output schema.

        Raises InvalidQueryException if the grouped table has the wrong
        kind or if both tables are partitioned by differently named
        columns.
        """
        grouper_qt = self.__get_grouper_qt(ctx)
        grouper_schema = grouper_qt.table_schema
        grouped_schema = source.get_table_schema(ctx)
        mode = self.mode
        if mode == "span":
            if grouped_schema.kind != TableKind.SPAN:
                raise InvalidQueryException("grouped table must be span table")
        elif mode == "event":
            if grouped_schema.kind != TableKind.EVENT:
                raise InvalidQueryException("grouped table must be event table")
        else:
            raise AssertionError("invalid mode: " + mode)
        assert grouper_schema.kind == TableKind.SPAN
        assert grouper_schema.sorting == TableSorting.TIME_MAJOR
        if grouper_schema.partition and grouped_schema.partition:
            if grouper_schema.partition != grouped_schema.partition:
                raise InvalidQueryException(
                    "partition name mismatch in span group: {!r} vs {!r}".format(
                        grouper_schema.partition, grouped_schema.partition))
        if not grouper_schema.partition and grouped_schema.partition:
            # Unpartitioned grouper over a partitioned table: the output
            # takes the grouped table's partition column.
            table_schema = TableSchema(TableKind.SPAN,
                                       grouped_schema.partition,
                                       TableSorting.NONE)
        else:
            table_schema = grouper_schema.evolve(sorting=TableSorting.NONE)
        if not table_schema.partition:
            table_schema = table_schema.evolve(sorting=TableSorting.TIME_MAJOR)
        return MagicGroupingBaseSource(source, table_schema)

    @override
    def make_query_maker(self, ctx, query_maker, _out_table_schema):
        if not isinstance(query_maker, MagicGroupingBaseSource):
            raise InvalidQueryException(
                "Unsupported use of GROUP USING SPANS")
        source = query_maker.base

        def _q1(column_name):
            """Resolve COLUMN_NAME against the base and build its query"""
            expr = ColumnReference(column_name)
            expr.self_resolve(ctx, source)
            return expr.to_query(ctx, query_maker)

        grouper_qt = self.__get_grouper_qt(ctx)
        grouper_schema = grouper_qt.table_schema
        grouped_schema = source.get_table_schema(ctx)

        # Checked during override_base_source, so we can just assert here.
        assert grouped_schema.kind in (TableKind.SPAN, TableKind.EVENT)
        assert grouper_schema.kind == TableKind.SPAN
        assert grouper_schema.sorting == TableSorting.TIME_MAJOR

        partition_q = (
            _q1(grouped_schema.partition)
            if grouped_schema.partition else None)
        grouper_partition_q = (
            grouper_qt[grouper_qt.table_schema.partition]
            if grouper_schema.partition else None)
        # Unique partition values are only passed when the grouped table
        # is partitioned, the grouper is not, and intersect is off.
        unique_pvals_q = (
            DropDuplicatesQuery.of1(partition_q)
            if (partition_q and
                not grouper_partition_q and
                not self.intersect)
            else None)

        grouped = (SpanTableConfig(_q1("_ts"), _q1("_duration"), partition_q)
                   if grouped_schema.kind == TableKind.SPAN
                   else EventTableConfig(_q1("_ts"), partition_q))
        config = SpanGroup(grouped=grouped,
                           grouper=SpanTableConfig.from_qt(grouper_qt),
                           unique_pvals=unique_pvals_q,
                           intersect=self.intersect)
        assert bool(config.is_output_partitioned) == bool(
            query_maker.get_table_schema(ctx).partition), \
            "{!r} vs {!r}".format(config.is_output_partitioned,
                                  query_maker.get_table_schema(ctx).partition)
        return GroupBySpanQueryMaker(query_maker, config)
| |
@final
class RegularSelect(SelectCore):
  """A normal select as a compound operand"""
  distinct = iattr(str, default="all")
  columns = tattr(Column, default=())
  from_ = iattr(TableExpression, nullable=True, default=None)
  where = iattr(Expression, nullable=True, default=None)
  gb = iattr(GroupBy, nullable=True, default=None)
  having = iattr(Expression, nullable=True, default=None)
  kind = iattr(str, nullable=True, default=None)
  repartition_by = iattr(str, nullable=True, default=None)

  def __make_column_expression_map(self, ctx, source, table_schema):
    """Map each user-specified output column name to its expression.

    Returns an OrderedDict in SELECT-list order.  Repeated names get a
    ":N" suffix so every key is unique.  The partition column and the
    meta columns of TABLE_SCHEMA are never included.

    Raises InvalidQueryException if a column is named after
    TABLE_SCHEMA's partition column or one of its meta columns.
    """
    column_expressions = OrderedDict()
    # Per-name counters used to build "name:N" disambiguated names.
    disambig = defaultdict(partial(xcount, 1))
    for column in self.columns:
      for name, expression in \
          column.get_column_specs(ctx, source, self.kind, table_schema):
        if name == table_schema.partition:
          raise InvalidQueryException(
            "column {!r} is the partition and must not "
            "be explicitly specified".format(name))
        if name in table_schema.meta_columns:
          raise InvalidQueryException(
            "column {!r} is reserved".format(name))
        while name in column_expressions:
          name = "{}:{}".format(name, next(disambig[name]))
        assert isinstance(expression, Expression), \
          "expression {!r} should be expression".format(expression)
        column_expressions[name] = expression
    return column_expressions

  def __make_qt_1(self, ctx, ob_asts):
    """Compile this (already simplified) SELECT core.

    OB_ASTS is a tuple, possibly empty, of ORDER BY expression ASTs.
    Returns (QUERY_TABLE, OB_QUERIES).  OB_QUERIES holds one sort query
    per element of OB_ASTS, filtered the same way as QUERY_TABLE's
    columns.

    Raises InvalidQueryException for semantically invalid queries.
    """
    # pylint: disable=redefined-variable-type

    distinct = self.distinct
    if distinct == "all":
      distinct = None

    where = self.where
    having = self.having
    gb = self.gb

    # Resolve columns references to specific concrete source
    # TableExpression instances, with lexical scoping.
    source = self.from_ or DummyTableExpression()
    if gb:
      source = gb.override_base_source(ctx, source)

    # Enforced by the parser.
    assert self.repartition_by is None or self.kind == "event"

    source.self_resolve_te(ctx)
    source_table_schema = source.get_table_schema(ctx)
    if self.kind == "span":
      if source_table_schema.kind != TableKind.SPAN:
        raise InvalidQueryException(
          "SELECT SPAN of a non-span table {}"
          .format(source_table_schema))
      out_table_schema = source_table_schema
    elif self.kind == "event":
      if source_table_schema.kind != TableKind.EVENT:
        raise InvalidQueryException(
          "SELECT EVENT of non-event table {}"
          .format(source_table_schema))
      out_table_schema = source_table_schema
      if self.repartition_by is not None:
        # TODO(dancol): check if we're actually already sorted
        # according to the new partition and use a less conservative
        # output sorting than TableSorting.NONE if so.
        out_table_schema = out_table_schema.evolve(
          partition=self.repartition_by or None,
          sorting=TableSorting.NONE)
    else:
      assert self.kind is None
      out_table_schema = REGULAR_TABLE

    column_exprs = self.__make_column_expression_map(
      ctx, source, out_table_schema)

    if out_table_schema.meta_columns:
      assert not set(out_table_schema.meta_columns) & set(column_exprs)
      # Unconditionally prepend the span metadata columns to the list of
      # output columns. Iteration is reversed because we prepend at
      # each step.
      for meta_column in out_table_schema.meta_columns[::-1]:
        column_exprs[meta_column] = ColumnReference(meta_column)
        column_exprs.move_to_end(meta_column, last=False)

    for expr in column_exprs.values():
      expr.self_resolve(ctx, source)

    # TODO(dancol): remove this check now that we have reordering for
    # special tables.
    if ob_asts and out_table_schema.kind != TableKind.REGULAR:
      raise InvalidQueryException(
        "SELECT {} incompatible with ORDER BY: "
        "{} tables are specially ordered"
        .format(self.kind.upper(), self.kind.lower()))

    if where or having or gb or ob_asts:
      # The WHERE and HAVING and GROUP BY and ORDER BY expressions can
      # refer to the column expressions, and before we munge things,
      # these column expressions appear to be table references in the
      # expression's AST. Substitute them with the resolved column
      # AST. The substitution isn't recursive, so embedded column
      # references with names matching aliases won't blow the stack.
      # The new expressions share structure with any expressions in
      # the column ASTs, so any column references we resolved above
      # are still resolved.
      explicit_aliases = {
        column.name: column.expr  # pylint: disable=no-member
        for column in self.columns
        if isinstance(column, ExpressionColumn) and column.name  # pylint: disable=no-member
      }
      def _subst_column_references(ast_node):
        if isinstance(ast_node, ColumnReference) and not ast_node.path:
          ast_node = explicit_aliases.get(ast_node.column, ast_node)
        return ast_node
      if where:
        assert not ctx.prohibit_window_functions
        ctx.prohibit_window_functions = "WHERE clause"
        # Clear the flag even if resolution fails so that CTX is not
        # left in a poisoned state.
        try:
          where = where.simplify(_subst_column_references)
          where.self_resolve(ctx, source)
        finally:
          del ctx.prohibit_window_functions
      if having:
        having = having.simplify(_subst_column_references)
        having.self_resolve(ctx, source)
      if gb:
        gb.simplify_expressions(ctx, source, _subst_column_references)
      if ob_asts:
        assert isinstance(ob_asts, tuple)
        ob_asts = tuple(ob_ast.simplify(_subst_column_references)
                        for ob_ast in ob_asts)
        for ob_ast in ob_asts:
          ob_ast.self_resolve(ctx, source)

    # We used to combine HAVING and WHERE here, but the set of
    # conditions under which we couldn't do that became unwieldy, so
    # now we just always keep HAVING and WHERE separate.

    conjunctions = where.flatten_conjunctions() if where else ()
    top_level_conjunctions = source.push_conditions(ctx, conjunctions)
    query_maker = source
    if top_level_conjunctions:
      query_maker = FilteringExpressionQueryMaker(
        source, _make_conjunction_filter(
          ctx,
          query_maker,
          top_level_conjunctions))

    if gb:
      query_maker = gb.make_query_maker(ctx, query_maker, out_table_schema)
    elif ctx.aggregate_functions_used:
      if out_table_schema.kind != TableKind.REGULAR:
        raise InvalidQueryException(
          "SELECT {} incompatible with aggregation functions".format(
            TableKind.label_of(out_table_schema.kind)))
      query_maker = AggregatingQueryMaker(query_maker)

    # Column source configuration complete: actually generate the
    # QueryNode instances for our columns.
    columns = [(name, expr.to_query(ctx, query_maker))
               for name, expr in column_exprs.items()]

    ob_queries = [expr.to_sort_query(ctx, query_maker)
                  for expr in ob_asts]

    # Various kinds of postprocessing.
    if having:
      q_having_filter = having.to_query(ctx, query_maker)
      columns = [(name, query.where(q_having_filter))
                 for name, query in columns]
      ob_queries = [query.where(q_having_filter)
                    for query in ob_queries]
    if distinct:
      ob_queries = _distinctify_ob_queries(columns, ob_queries)
      columns = _distinctify_column_list(columns)

    if ctx.regenerate_column_names:
      # If we have column-expression subqueries, we learn the right
      # auto-generated name only after we do all the above work, so
      # re-run the column name generation algorithm to pick up the
      # new, better name.
      new_column_names = tuple(self.__make_column_expression_map(
        ctx, source, out_table_schema).keys())
      # The regenerated map never contains the meta columns we
      # prepended above, so its names are offset from COLUMNS by the
      # number of meta columns.
      nr_meta = len(out_table_schema.meta_columns)
      for i in range(nr_meta, len(column_exprs)):
        columns[i] = new_column_names[i - nr_meta], columns[i][1]

    return (GenericQueryTable(columns, table_schema=out_table_schema),
            ob_queries)

  @override
  def make_qt(self, tctx, ob_asts):
    """Compile this select in TCTX; returns (QUERY_TABLE, OB_QUERIES).

    The AST is simplified first, and a fresh QueryCompilationContext is
    used for the compilation.
    """
    ctx = QueryCompilationContext(tctx)
    # pylint: disable=protected-access
    return self.simplify().__make_qt_1(ctx, ob_asts)
| |
@final
class OrderingTerm(AstNode):
  """One term of an ORDER BY clause: an expression and a direction"""
  expr = iattr(Expression)
  direction = iattr(str, default="asc")

  def is_ascending(self):
    """Return True for an ascending ("asc") term, False for "desc"."""
    direction = self.direction
    assert direction in ("asc", "desc")
    return direction == "asc"
| |
def _filter_to_limit_offset(tctx, base_qt, limit, offset):
  """Apply LIMIT and OFFSET expressions to every column of BASE_QT.

  LIMIT and OFFSET are expression ASTs; a falsy value means "absent".
  They are resolved against a DummyTableExpression rather than any real
  table in a fresh QueryCompilationContext built from TCTX.
  """
  ctx = QueryCompilationContext(tctx)
  no_table = DummyTableExpression()

  def _compile(expr):
    if not expr:
      return None
    simplified = expr.simplify()
    resolved = simplified.self_resolve(ctx, no_table)
    return resolved.to_query(ctx, no_table)

  limit_q = _compile(limit)
  offset_q = _compile(offset)

  def _limit_column(q):
    return q.limit_offset(limit=limit_q, offset=offset_q)

  return base_qt.transform(_limit_column)
| |
class DmlContext(ExplicitInheritance):
  """Environment in which DML commands run"""

  @override
  def __init__(self, ns=None, qe=None):
    """Make a new DmlContext

    NS is the namespace to use as the DML "user" namespace, which
    namespace-modifying commands affect. If None, create a new
    detached namespace.

    If QE is supplied, it should be a QueryEngine that allows DML
    commands to execute queries as part of their operation, mostly for
    IF conditionalizing.
    """
    assert not ns or isinstance(ns, Namespace)
    self.user_ns = ns if ns else Namespace()
    self.qe = qe

  def make_tctx(self, args=NO_ARGS):  # pylint: disable=dangerous-default-value
    """Return a TVF compilation context over the user namespace"""
    user_ns = self.user_ns
    return TvfContext.from_ns(user_ns, args)

  @final
  def execute(self, commands):
    """Run each DmlAction in COMMANDS, in order, against this context"""
    for cmd in commands:
      assert isinstance(cmd, DmlAction)
      cmd.execute_dml(self)

  def mount_trace(self, mount_path, trace_file_name):
    """Mount a trace file in the user namespace

    This base implementation always raises NotImplementedError;
    subclasses that can load traces override it.
    """
    raise NotImplementedError("loading traces not supported "
                              "in basic DmlContext")
| |
class DmlAction(AstNode):
  """AST node for an action performed against a DmlContext"""

  @abstract
  def execute_dml(self, dmlctx):
    """Perform this action in DMLCTX"""
    raise NotImplementedError("abstract")

  @abstract
  def eval_for_autocomplete(self, dmlctx):
    """Do just enough of this action in DMLCTX to fire autocomplete hooks

    Overrides must call super().
    """
    # Intentionally empty: this is the root of the super() chain.
| |
class Select(DmlAction):
  """AST node that produces a query table"""

  @final
  @override
  def execute_dml(self, dmlctx):
    # Running the SELECT for real would require keeping the
    # QueryNode-level QueryContext in DMLCTX.  For now we only build the
    # query table, which at least checks the query's syntax.
    tctx = dmlctx.make_tctx()
    return self.make_qt(tctx)

  @abstract
  def make_qt(self, tctx):
    """Return a QueryTable holding the results of this query"""
    raise NotImplementedError("abstract")

  @final
  @override
  def eval_for_autocomplete(self, dmlctx):
    super().eval_for_autocomplete(dmlctx)
    # Building the query table is what triggers autocomplete.
    self.execute_dml(dmlctx)
| |
@final
class SelectDirect(Select):
  """A SELECT without any lexical bindings"""
  core = iattr(SelectCore)
  ob = tattr(OrderingTerm, default=())
  limit = iattr(Expression, nullable=True, default=None)
  offset = iattr(Expression, nullable=True, default=None)

  def __check_no_substructure_sharing(self):
    """Assert that no AST node object appears twice in this tree.

    The SQL-to-QueryTable translation keys per-node data on object
    identity, so shared substructure would make nodes collide.
    """
    visited = IdentityDictionary()
    for node in self.walk():
      assert node not in visited, "should not share structure"
      visited[node] = True

  @override
  def make_qt(self, tctx):
    if __debug__:
      self.__check_no_substructure_sharing()
    terms = self.ob
    ob_asts = tuple(term.expr for term in terms)
    ob_asc = tuple(term.is_ascending() for term in terms)
    base_qt, ob_queries = self.core.make_qt(tctx, ob_asts)
    assert len(ob_queries) == len(ob_asts)
    if ob_queries:
      # Reorder every column by the ORDER BY sort keys.
      indexer = ArgSortQuery(zip(ob_queries, ob_asc))
      base_qt = GenericQueryTable(
        [(name, query.take(indexer)) for name, query in base_qt.items()])
    if not (self.limit or self.offset):
      return base_qt
    return _filter_to_limit_offset(tctx, base_qt, self.limit, self.offset)
| |
@final
class CteBindingName(AstNode):
  """Name of a CTE binding with possible column renaming"""
  # Name the CTE is bound under; SelectWithCte uses it as the key in
  # the lexical scope (tctx.let).
  name = iattr(str)
  # Column names listed with the binding, if any.
  renaming = tattr(str, default=())
  # NOTE(review): presumably whether RENAMING is applied to the bound
  # table's columns; consumed by make_cte_value elsewhere -- confirm.
  do_rename = iattr(bool, default=False)
| |
@final
class CteBinding(NoWalkAstNode):
  """One lexical binding in a CTE's binding list"""
  # Binding name (plus optional column renaming).
  name = iattr(CteBindingName)
  # Bound value; SelectWithCte calls te.make_cte_value(tctx, name).
  te = iattr(CteBindingValue)
| |
@final
class SelectWithCte(Select):
  """A SELECT with locally-bound common table expressions

  Bindings are evaluated in order.  Each one is visible to the bindings
  that follow it and to BODY.
  """
  bindings = tattr(CteBinding)
  body = iattr(Select)

  @override
  def make_qt(self, tctx):
    scope = tctx
    for binding in self.bindings:
      bound_table = binding.te.make_cte_value(scope, binding.name)
      scope = scope.let({binding.name.name: bound_table})
    return self.body.make_qt(scope)
| |
class CreateDmlAction(DmlAction):
  """DML action that creates something, optionally under a condition"""

  overwrite = iattr(bool, default=False, kwonly=True)
  condition = iattr((FunExpr, NoneType), default=None, kwonly=True)

  @abstract
  def _do_create(self, dmltctx):
    raise NotImplementedError("abstract")

  def __should_execute(self, dmlctx):
    """Return whether this action should actually create its object.

    With no CONDITION, always True.  Otherwise CONDITION is evaluated
    through DMLCTX's query engine.  The result is True only if it
    produced at least one value and the first value is truthy.

    Raises InvalidQueryException if there is a CONDITION but DMLCTX has
    no query engine.
    """
    if not self.condition:
      return True
    # TODO(dancol): skip the query engine invocation if TVF evaluation
    # produces an obviously-true or obviously-false value.
    # Evaluate the condition before looking for a query engine:
    # eval_for_autocomplete calls this method only for its side effects,
    # which presumably come from this evaluation.
    q = QueryNode.coerce_(self.condition.evaluate_tvf(dmlctx.make_tctx()))
    qe = dmlctx.qe
    if not qe:
      raise InvalidQueryException("cannot evaluate view conditional here")
    (_res_q, res), = qe.execute_for_columns({q})
    # Always return a real bool, not the integer length.
    return len(res) > 0 and bool(res[0])

  @override
  @final
  def execute_dml(self, dmlctx):
    if self.__should_execute(dmlctx):
      self._do_create(dmlctx)

  @override
  def eval_for_autocomplete(self, dmlctx):
    super().eval_for_autocomplete(dmlctx)
    if self.condition:
      self.__should_execute(dmlctx)  # For side effect
| |
@final
class CreateView(CreateDmlAction):
  """View-creation statement

  In function mode the SELECT becomes a TableValuedFunction.  Otherwise
  the view is a LazyNsEntry that compiles SELECT (or evaluates the
  function expression) in the user namespace when resolved.
  """
  path = tattr(str, nonempty=True)
  select = iattr((Select, FunExpr))
  as_function = iattr(bool, default=False, kwonly=True)
  documentation = iattr(str, nullable=True, kwonly=True, default=None)

  @override
  def _do_create(self, dmlctx):
    if not self.as_function:
      # The namespace is referenced weakly, presumably so the entry
      # stored inside it does not keep it alive through a cycle.
      ns_ref = weakref.ref(dmlctx.user_ns)
      def _make_table():
        # Deliberately no docstring: LazyNsEntry would publish a
        # resolver's __doc__ as SQL documentation.
        ns = ns_ref()
        assert ns
        tctx = TvfContext.from_ns(ns)
        if isinstance(self.select, Select):
          return self.select.make_qt(tctx)
        return self.select.evaluate_tvf(tctx)
      value = LazyNsEntry(_make_table)
    elif isinstance(self.select, FunExpr):
      raise InvalidQueryException(
        "view values must be table in function mode")
    else:
      value = TableValuedFunction.from_select_ast(
        self.select, dmlctx.user_ns)
    if self.documentation is not None:
      _documentation[value] = self.documentation
    dmlctx.user_ns.assign_by_path(self.path,
                                  value,
                                  overwrite=self.overwrite)

  @override
  def eval_for_autocomplete(self, dmlctx):
    super().eval_for_autocomplete(dmlctx)
    # TODO(dancol): support autocomplete in the TVF case
    if self.as_function or not isinstance(self.select, Select):
      return
    self.select.make_qt(TvfContext.from_ns(dmlctx.user_ns))
| |
@final
class Drop(DmlAction):
  """Drop an entire namespace prefix"""
  path = tattr(str)
  ignore_absent = iattr(bool, default=False)

  @override
  def execute_dml(self, dmlctx):
    try:
      dmlctx.user_ns.delete_by_path(self.path)
    except KeyError:
      # A missing path is fine only when IGNORE_ABSENT was requested.
      if self.ignore_absent:
        return
      raise UnboundReferenceException(
        "could not find " + path2str(self.path)) from None

  @override
  def eval_for_autocomplete(self, dmlctx):
    super().eval_for_autocomplete(dmlctx)
    # TODO(dancol): implement autocomplete
    raise NotImplementedError
| |
@final
class MountTrace(DmlAction):
  """Mount a trace at a location in the query namespace"""
  trace_file_name = iattr(str)
  mount_path = tattr(str, nonempty=True)

  @override
  def execute_dml(self, dmlctx):
    # Loading is delegated to the context; the basic DmlContext raises
    # NotImplementedError here.
    where, what = self.mount_path, self.trace_file_name
    dmlctx.mount_trace(where, what)

  @override
  def eval_for_autocomplete(self, dmlctx):
    super().eval_for_autocomplete(dmlctx)
    # TODO(dancol): implement autocomplete
    raise NotImplementedError
| |
@final
class CreateNamespace(CreateDmlAction):
  """Create a SQL namespace"""
  path = tattr(str)

  @override
  def _do_create(self, dmlctx):
    # Bind a fresh, empty namespace at PATH; OVERWRITE is forwarded to
    # assign_by_path.
    fresh_ns = Namespace()
    dmlctx.user_ns.assign_by_path(self.path,
                                  fresh_ns,
                                  overwrite=self.overwrite)

  @override
  def eval_for_autocomplete(self, dmlctx):
    super().eval_for_autocomplete(dmlctx)
    # TODO(dancol): implement autocomplete
    raise NotImplementedError
| |
| |
# SQL function implementations (aggregation and non-aggregation)
| |
@final
class AggregationFunction(Immutable):
  """Aggregation function whose implementation lives in the native core"""
  name = iattr(str)

  @override
  def _post_init_assert(self):
    super()._post_init_assert()
    # Only aggregations the core knows about may be named here.
    known_names = WELL_KNOWN_AGGREGATIONS
    assert self.name in known_names
| |
class NormalFunction(Immutable):
  """Non-aggregate SQL function that lives in the function namespace"""

  def hook_self_resolve(self, _ctx, _funcall, _resolver):  # pylint: disable=no-self-use
    """Optionally take over column resolution for a FunctionCall.

    Return True if this hook resolved the call's arguments itself.  The
    default does nothing and returns False.
    """
    handled = False
    return handled

  @abstract
  def make_query(self, ctx, te, arguments):
    """Return the QueryNode implementing a call with ARGUMENTS"""
    raise NotImplementedError("abstract")
| |
@final
class FnNormalFunction(NormalFunction):
  """NormalFunction implemented by calling a plain Python function

  Every argument is compiled to a QueryNode before FN is called.
  """

  fn = iattr()

  # Hack: keyword arguments are evaluated in non-table context.
  # TODO(dancol): provide ability to easily and precisely specify a
  # function evaluation schema.

  @override
  def hook_self_resolve(self, ctx, funcall, resolver):
    assert isinstance(funcall, FunctionCall)
    for arg in funcall.arguments:
      if not isinstance(arg, KeywordArgument):
        arg.self_resolve(ctx, resolver)
        continue
      arg.expr.self_resolve(ctx, DummyTableExpression())
    return True

  @override
  def make_query(self, ctx, te, arguments):
    positional = []
    keywords = {}
    for arg in arguments:
      if isinstance(arg, KeywordArgument):
        keywords[arg.name] = arg.expr.to_query(
          ctx, DummyTableExpression())
        continue
      if keywords:
        raise InvalidQueryException(
          "non-keyword arguments after keyword arguments")
      positional.append(arg.to_query(ctx, te))
    # pylint: disable=not-callable
    return self.fn(*positional, **keywords)
| |
| |
| # TVF utilities |
| |
def add_sql_documentation(thing, doc):
  """Attach SQL documentation string DOC to THING

  THING is anything that can live in a table namespace.  For a
  lazy-evaluation object, the documentation is automatically carried
  over to the object that lazy evaluation produces.  DOC must be a str.
  """
  checked_doc = the(str, doc)
  _documentation[thing] = checked_doc
| |
def get_sql_documentation(thing):
  """Get documentation for THING or None if no documentation exists

  Leading and trailing newlines are removed and the text is dedented.
  Python-docstring-style text, whose summary line sits flush against the
  opening quotes and whose later lines are indented, is dedented too:
  the first line is kept as-is and the remaining lines are dedented on
  their own.
  """
  doc = _documentation.get(thing)
  if doc is None:
    return None
  doc = doc.strip("\n")
  first_line, newline, rest = doc.partition("\n")
  if newline and first_line and not first_line[0].isspace():
    # textwrap.dedent finds no common margin when the first line is
    # unindented, so dedent the body separately.  When the body has no
    # common indent either, this gives the same result as dedent(doc).
    return first_line + newline + dedent(rest)
  return dedent(doc)
| |
def qargs(*args, **kwargs):
  """Build a bind-argument mapping for resolving query parameters.

  Keyword arguments keep their names as keys; positional arguments are
  stored under their integer position (0, 1, ...).
  """
  kwargs.update(enumerate(args))
  return kwargs
| |
def _parse_select(s):
  """Parse S as a SELECT using dctv.sql_parser.parse_select.

  The parser is looked up through sys.modules at call time (so pylint
  doesn't flag a recursive import).  On the first call this module-level
  name is rebound to the real parser, so later calls go straight to it.
  """
  global _parse_select  # pylint: disable=global-variable-undefined
  parser = sys.modules["dctv.sql_parser"].parse_select
  _parse_select = parser
  return parser(s)
| |
@final
class LazyNsEntry(Immutable):
  """Namespace entry whose value RESOLVER produces lazily"""
  resolver = iattr()

  @override
  def _post_init_check(self):
    super()._post_init_check()
    resolver = self.resolver
    assert callable(resolver)
    # Publish the resolver's docstring, if any, as SQL documentation.
    resolver_doc = getattr(resolver, "__doc__", None)
    if resolver_doc is None:
      return
    add_sql_documentation(self, resolver_doc)
| |
| |
| # Built-in TVFs |
| |
@once()
def _make_generate_sequential_spans_tvf():
  # Compile the span-generating SQL into a TVF; @once() presumably
  # makes this happen a single time and reuses the result.
  sql = """
  SELECT
    CAST((ROW_NUMBER() * :duration_ns + :start_ns) AS UNIT ns) AS _ts,
    CAST(:duration_ns AS UNIT ns) AS _duration,
  FROM dctv.filled(((:end_ns - :start_ns) // :duration_ns))
  """
  return TableValuedFunction.from_select_ast(_parse_select(sql))
| |
def _generate_sequential_spans_tvf(start, end, span_duration):
  """Generate regular sequential spans

  (START, END, SPAN_DURATION)

  START is the time (coerced to nanoseconds) of the first span in the
  sequence of generated spans. END is the time (coerced to
  nanoseconds) of the end of the last span in the sequence.
  SPAN_DURATION is the duration of each span in the sequence.
  Each generated span in the sequence immediately follows the previous
  span. If the last span in the sequence would have a duration
  shorter than SPAN_DURATION to end at END, that span is omitted from
  the sequence.
  """
  def _as_ns(value):
    # Convert VALUE to a unitless count of nanoseconds.
    value_q = QueryNode.coerce_(value)
    return value_q.to_unit(ureg().ns).strip_unit()

  generator = _make_generate_sequential_spans_tvf()
  spans = generator(start_ns=_as_ns(start),
                    end_ns=_as_ns(end),
                    duration_ns=_as_ns(span_duration))
  # _internal_cast_as_span_table raises NotImplementedError for
  # verify/sort, so both must be turned off explicitly.
  return _internal_cast_as_span_table(spans, verify=False, sort=False)
| |
@once()
def _with_all_partitions_tvf_implicit():
  # Intersect all partitions of ?0 into unpartitioned spans, then
  # broadcast those spans back into ?0's partitions.
  sql = """
  SELECT SPAN *
  FROM
    (SELECT SPAN
     FROM ?0
     GROUP AND INTERSECT SPANS USING PARTITIONS)
  SPAN BROADCAST INTO SPAN PARTITIONS
    ?0
  """
  return TableValuedFunction.from_select_ast(_parse_select(sql))
| |
| |
def _with_all_partitions_tvf(qt):
  """Filter a span table so that all partitions are present

  QT is a partitioned span table. The output of this function is a
  partitioned span table that contains no spans except for those
  regions of the timeline during which _all_ partitions in QT are
  present. For example, if QT is partitioned by CPU, the output of
  this routine is a set of spans during which all CPUs are present,
  e.g., when none is hotplugged off.
  """
  tvf = _with_all_partitions_tvf_implicit()
  return tvf(qt)
| |
def _span_starts_tvf(span_table):
  """Convert span starts to event table

  The output of this function is an event table (partitioned like the
  input span table) in which each span start is an event.
  This function is useful when reinterpreting one kind of span as
  another.
  """
  qt = QueryTable.coerce_(span_table)
  in_schema = qt.table_schema
  if in_schema.kind != TableKind.SPAN:
    raise InvalidQueryException("not a span table: {}".format(qt))
  # An event is a span start without a duration: keep every other
  # column unchanged.
  event_columns = {name: qt[name]
                   for name in qt.columns
                   if name != "_duration"}
  return GenericQueryTable(
    event_columns,
    table_schema=in_schema.evolve(kind=TableKind.EVENT))
| |
def _span_ends_tvf(span_table):
  """Convert span ends to event table

  The output of this function is an event table (partitioned like the
  input span table) in which each span end is an event. This function
  is useful when reinterpreting one kind of span as another.
  """
  qt = QueryTable.coerce_(span_table)
  if qt.table_schema.kind != TableKind.SPAN:
    raise InvalidQueryException("not a span table: {}".format(qt))
  event_columns = {}
  for name in qt.columns:
    if name == "_duration":
      continue
    if name == "_ts":
      # Move each event to the end of its span.
      event_columns[name] = qt[name] + qt["_duration"]
    else:
      event_columns[name] = qt[name]
  return GenericQueryTable(
    event_columns,
    table_schema=qt.table_schema.evolve(kind=TableKind.EVENT))
| |
@final
class LagLeadFunction(NormalFunction):
  """Evaluates arguments to LAG and LEAD in the correct context

  The first argument (the value being shifted) is resolved and compiled
  against the query's table.  The optional second and third arguments
  (for _lag/_lead, the offset and default) are resolved and compiled
  against a DummyTableExpression instead.
  """

  fn = iattr()

  @override
  def hook_self_resolve(self, ctx, funcall, resolver):
    if ctx.prohibit_window_functions:
      raise InvalidQueryException(
        "window function not allowed in {}".format(
          ctx.prohibit_window_functions))

    assert isinstance(funcall, FunctionCall)
    arguments = funcall.arguments
    if not 1 <= len(arguments) <= 3:
      raise InvalidQueryException(
        "LAG/LEAD expect between one and three arguments: got {}"
        .format(len(arguments)))
    for argno, argument in enumerate(arguments):
      if not isinstance(argument, Expression):
        raise InvalidQueryException("LAG/LEAD expect no keyword arguments")
      # Only the first argument is resolved in the query's lexical
      # context; the remaining arguments are resolved in root context
      # (a DummyTableExpression), matching make_query below.
      argument.self_resolve(ctx,
                            resolver if not argno
                            else DummyTableExpression())
    return True

  @override
  def make_query(self, ctx, te, arguments):
    """Make a QueryNode for this function call"""
    # We know self.fn is callable pylint: disable=not-callable
    dummy_te = DummyTableExpression()
    return self.fn(arguments[0].to_query(ctx, te),
                   *[argument.to_query(ctx, dummy_te)
                     for argument in arguments[1:]])
| |
| # TODO(dancol): make LAG and LEAD more efficient. The way we do it |
| # here requires full copies of the inputs. We need to either make |
| # dedicated LAG/LEAD operators or teach the optimizer about |
| # count thresholds. |
| |
def _lead(query, offset=1, default=None):
  """Shift QUERY's values up by OFFSET rows (SQL LEAD).

  Row i of the result is row i + OFFSET of QUERY.  The trailing rows
  that have no source row are DEFAULT.  OFFSET defaults to 1, as in
  _lag and SQL's one-argument LEAD(expr); LagLeadFunction passes a
  single argument through when the user writes LEAD(x).

  Assuming Python-like slice clamping, the result has as many rows as
  QUERY: when QUERY is shorter than OFFSET the slice is empty, and
  least(count - offset, 0) shrinks the fill to match.
  """
  offset_q = QueryNode.coerce_(offset)
  return QueryNode.concat(
    query.slice(offset_q, None),
    QueryNode.filled(QueryNode.coerce_(default),
                     offset_q +
                     QueryNode.least(query.countq() - offset_q, 0)))
| |
def _lag(query, offset=1, default=None):
  """Shift QUERY's values down by OFFSET rows (SQL LAG).

  The first min(COUNT, OFFSET) rows of the result are DEFAULT.  They
  are followed by the first COUNT - min(COUNT, OFFSET) rows of QUERY,
  so the result has as many rows as QUERY.  OFFSET 0 returns QUERY's
  values unshifted.
  """
  offset_q = QueryNode.coerce_(offset)
  count_q = query.countq()
  nr_fill_q = QueryNode.least(count_q, offset_q)
  # Slice with an explicit non-negative stop.  slice(0, -offset) would
  # keep nothing for OFFSET 0, because -0 == 0.
  return QueryNode.concat(
    QueryNode.filled(QueryNode.coerce_(default), nr_fill_q),
    query.slice(0, count_q - nr_fill_q))
| |
def extend_spans_tvf(qt, addl):
  """Grow each span in a span table

  (QT, ADDL)

  QT is a span table. ADDL is an amount of time (coerced to
  nanoseconds) by which to extend each span in the span table. If, as
  a result of this extension, spans would illegally overlap, the spans
  are instead merged.

  The resulting span table has no payload. To get the payload of each
  extended span, SPAN GROUP with the original table.
  """
  # Coerce like the other TVFs so that non-QueryTable inputs work too.
  qt = (QueryTable
        .coerce_(qt)
        .to_schema(kind=TableKind.SPAN,
                   sorting=TableSorting.TIME_MAJOR))
  partition = qt.table_schema.partition
  ts_q = qt["_ts"]
  duration_q = qt["_duration"]
  partition_q = qt[partition] if partition else None
  # SpanFixup repairs the overlaps that lengthening may introduce.
  fixup = SpanFixup(ts_q, duration_q + addl, partition_q)
  columns = [("_ts", fixup.metaq(SpanFixup.Meta.TS)),
             ("_duration", fixup.metaq(SpanFixup.Meta.DURATION))]
  if partition:
    columns.append((partition, fixup.metaq(SpanFixup.Meta.PARTITION)))
    table_schema = TableSchema(TableKind.SPAN,
                               partition,
                               TableSorting.NONE)
  else:
    table_schema = SPAN_UNPARTITIONED_TIME_MAJOR
  return GenericQueryTable(columns, table_schema=table_schema)
| |
def _internal_cast_as_span_table(tbl,
                                 *,
                                 partition=None,
                                 sorting="time_major",
                                 sort=True,
                                 verify=True):
  """Reinterpret a regular table as a span table.

  TBL must have "_ts" and "_duration" columns with the timestamp and
  duration schemas and, if PARTITION is given, a column by that name.
  A table that is already a span table is returned unchanged.  SORTING
  is a TableSorting value.

  Verification and sorting are not implemented yet, so callers must
  pass verify=False and sort=False; otherwise NotImplementedError is
  raised.  Raises InvalidQueryException for unsuitable tables.
  """
  tbl = QueryTable.coerce_(tbl)
  # Make callers opt out of the planned safety checks explicitly.
  if verify:
    # TODO(dancol): implement span verification
    raise NotImplementedError("Span verification not implemented")
  if sort:
    # TODO(dancol): implement sort
    raise NotImplementedError("Sorting not implemented")
  in_kind = tbl.table_schema.kind
  if in_kind == TableKind.SPAN:
    return tbl
  if in_kind != TableKind.REGULAR:
    raise InvalidQueryException("not a regular table: {}".format(tbl))

  required_columns = (
    ("_ts", TS_SCHEMA,
     "table missing timestamp column",
     "invalid timestamp schema: {}"),
    ("_duration", DURATION_SCHEMA,
     "table missing duration column",
     "invalid duration schema: {}"),
  )
  for column, wanted_schema, missing_msg, bad_schema_msg in required_columns:
    if column not in tbl:
      raise InvalidQueryException(missing_msg)
    actual_schema = tbl[column].schema
    if actual_schema != wanted_schema:
      raise InvalidQueryException(bad_schema_msg.format(actual_schema))

  if partition and partition not in tbl:
    raise InvalidQueryException(
      "partition {!r} not in table".format(partition))
  span_schema = TableSchema(TableKind.SPAN,
                            partition,
                            TableSorting(sorting))
  # TODO(dancol): we want to insert some kind of constraint-checking
  # (i.e., ensuring that the output actually looks like a span table)
  # in debug builds instead of just forwarding the queries blindly!
  return GenericQueryTable(
    ((name, tbl[name]) for name in tbl.columns),
    table_schema=span_schema)
| |
def _internal_cast_as_event_table(tbl,
                                  *,
                                  partition,
                                  sorting,
                                  sort,
                                  verify):
  """Reinterpret a regular table as an event table.

  TBL must have a "_ts" column with the timestamp schema and, if
  PARTITION is given, a column by that name.  A table that is already
  an event table is returned unchanged.  SORTING is a TableSorting
  value.

  Verification and sorting are not implemented yet; passing a true
  VERIFY or SORT raises NotImplementedError.  Raises
  InvalidQueryException for unsuitable tables.
  """
  tbl = QueryTable.coerce_(tbl)
  # Make callers opt out of the planned safety checks explicitly.
  if verify:
    # TODO(dancol): implement event verification
    raise NotImplementedError("verify not implemented")
  if sort:
    # TODO(dancol): implement sort
    raise NotImplementedError("sort not implemented")
  in_kind = tbl.table_schema.kind
  if in_kind == TableKind.EVENT:
    return tbl
  if in_kind != TableKind.REGULAR:
    raise InvalidQueryException("not a regular table: {}".format(tbl))
  if "_ts" not in tbl:
    raise InvalidQueryException("table missing timestamp column")
  ts_schema = tbl["_ts"].schema
  if ts_schema != TS_SCHEMA:
    raise InvalidQueryException("invalid timestamp column: {}"
                                .format(ts_schema))
  if partition and partition not in tbl:
    raise InvalidQueryException(
      "partition {!r} not in table".format(partition))
  event_schema = TableSchema(TableKind.EVENT,
                             partition,
                             TableSorting(sorting))
  return GenericQueryTable(
    ((name, tbl[name]) for name in tbl.columns),
    table_schema=event_schema)
| |
class RowNumberFunction(NormalFunction):
  """SQL ROW_NUMBER(): a sequential number for each row of the table"""
  @override
  def make_query(self, ctx, te, arguments):
    if not arguments:
      row_count_q = te.make_count_query(ctx)
      return QueryNode.sequential(row_count_q)
    raise InvalidQueryException("ROW_NUMBER takes no arguments")
| |
@final
class ReplaceSchemaQuery(SimpleQueryNode):
  """Forward another query's data unchanged under a substitute schema"""
  # The wrapped query that actually supplies the data.
  _query = iattr(QueryNode, name="query")
  # The schema advertised in place of the wrapped query's own schema.
  _schema = iattr(QuerySchema, name="schema")

  @override
  def _compute_inputs(self):
    # The wrapped query is the sole input.
    return (self._query,)

  @override
  def _compute_schema(self):
    return self._schema

  @override
  def countq(self):
    # Data is passed through as-is, so the count is the wrapped query's.
    return self._query.countq()

  @override
  async def run_async(self, qe):
    setup_inputs, setup_outputs = await qe.async_setup([self._query], [self])
    (ic,) = setup_inputs
    (oc,) = setup_outputs
    await passthrough(qe, ic, oc)
| |
def _assert_unique(query, *, verify=True):
  """Declare that QUERY's values are unique by adding C_UNIQUE to its schema.

  Args:
    query: The QueryNode whose values are asserted to be unique.
    verify: Whether to check the constraint at runtime. It may be a
      QueryNode, in which case it must evaluate eagerly to a constant.
      Runtime verification is not implemented, so any true value raises
      NotImplementedError. Callers must pass a false value to opt out.

  Returns:
    QUERY itself if its schema already carries C_UNIQUE. Otherwise, a
    ReplaceSchemaQuery that passes QUERY's data through under a schema
    constrained with C_UNIQUE.

  Raises:
    InvalidQueryException: VERIFY is a query that is not constant.
    NotImplementedError: VERIFY is true.
  """
  # TODO(dancol): add runtime verification that the constraint holds?
  if isinstance(verify, QueryNode):
    verify = verify.eager_evaluate_scalar()
    # Eager evaluation still yielding a query means VERIFY was not constant.
    if isinstance(verify, QueryNode):
      raise InvalidQueryException(
        "verify argument to assert_unique must be constant")
  if verify:
    raise NotImplementedError("ASSERT_UNIQUE VERIFY")
  # Test for the same constraint object that constrain() adds below. The
  # old check used a different name (UNIQUE), so it could never recognize
  # an already-unique schema, or it failed with NameError.
  if C_UNIQUE not in query.schema.constraints:
    query = ReplaceSchemaQuery(query, query.schema.constrain(C_UNIQUE))
  return query
| |
def backfill_tvf(qt):
  """Auto-fill NULL values in a span table

  Within each partition, every contiguous group of values has each
  column's NULLs replaced with the first non-NULL value of that column in
  the group. If the group has no non-NULL value for a column, those NULLs
  stay NULL.

  This is handy when data sources don't line up exactly: an initial region
  of unknown data can be "filled back in time" with whatever the data
  turns out to be later.
  """
  partition_major_qt = QueryTable.coerce_(qt).to_schema(
    sorting=TableSorting.PARTITION_MAJOR)
  span_config = SpanTableConfig.from_qt(partition_major_qt)

  def _backfill_column(column_q):
    return Backfill(span_config, column_q).metaq(Backfill.Meta.VALUE)

  return partition_major_qt.transform(_backfill_column)
| |
def _get_thread_analysis():
  """Return the ThreadAnalysis class, importing its module on demand.

  Serves as the LazyNsEntry factory for dctv.thread_analysis. The import
  is deferred until first use (presumably to avoid an import cycle or
  startup cost; unverified).
  """
  from .thread_analysis import ThreadAnalysis as thread_analysis_cls
  return thread_analysis_cls
| |
| def _identity_tvf(arg): |
| """Return argument unchanged |
| |
| This function is occasionally useful for forcing TVF evaluation. |
| """ |
| return arg |
| |
@once()
def make_standard_lexical_environment():
  """Bootstrap the environment that's always available to users

  Builds a root Namespace with the built-in helpers, table-valued
  functions (TVFs), and SQL functions, and returns it wrapped in a
  LexicalEnvironment. The first LexicalEnvironment argument is None,
  presumably meaning "no parent environment" (TODO confirm).

  Namespace layout built here:
    dict, list           -- top-level helpers
    dctv.*               -- TVFs, thread_analysis, and every well-known
                            aggregation function
    dctv.internal.*      -- internal table casts; autocomplete disabled
    floor, ceil, ...     -- column-wise scalar functions, plus the count,
                            max, min, prod and sum aggregations

  The function is decorated with modernmp.util.once, presumably so the
  environment is constructed a single time and then reused
  (NOTE(review): confirm `once` semantics).
  """
  tvf = TableValuedFunction
  ns = Namespace()
  # Create the nested namespaces up front. The path assignments below
  # then place entries inside them (presumably assign_by_path walks
  # existing sub-namespaces; verify).
  ns["dctv"] = ns_dctv = Namespace()
  ns_dctv["internal"] = ns_internal = Namespace()
  ns_internal.disable_autocomplete = True

  # (path, value) registrations for non-column-wise entries.
  things = (
    (("dict",), dict),
    # list(...) packs its varargs into a Python list.
    (("list",), lambda *args: list(args)),
    (("dctv", "time_series_to_spans"), tvf(TimeSeriesQueryTable)),
    # Wrapped in LazyNsEntry so the thread_analysis module is only
    # imported when this entry is used.
    (("dctv", "thread_analysis"), LazyNsEntry(_get_thread_analysis)),
    (("dctv", "generate_sequential_spans"),
     tvf(_generate_sequential_spans_tvf),),
    (("dctv", "span_starts"), tvf(_span_starts_tvf)),
    (("dctv", "span_ends"), tvf(_span_ends_tvf)),
    (("dctv", "filled"), tvf(FilledQueryTable)),
    (("dctv", "internal", "cast_as_span_table"),
     tvf(_internal_cast_as_span_table)),
    (("dctv", "internal", "cast_as_event_table"),
     tvf(_internal_cast_as_event_table)),
    (("dctv", "with_all_partitions"), tvf(_with_all_partitions_tvf)),
    (("dctv", "identity"), tvf(_identity_tvf)),
    (("dctv", "extend_spans"), tvf(extend_spans_tvf)),
    (("dctv", "backfill"), tvf(backfill_tvf)),
  )
  for path, value in things:
    ns.assign_by_path(path, value)
  # _unfn is a hack for implementing non-aggregate SQL functions in
  # terms of syntactically-inaccessible unary operator invocations.
  def _unfn(op):
    # Binds OP as the first argument of UnaryOperationQuery.
    return FnNormalFunction(partial(UnaryOperationQuery, op))

  # One AggregationFunction per entry in WELL_KNOWN_AGGREGATIONS.
  all_aggfuncs = [
    AggregationFunction(aggfunc)
    for aggfunc in WELL_KNOWN_AGGREGATIONS
  ]

  # Column-wise functions. Every aggregation is reachable as dctv.<name>.
  # Only the common SQL ones are also exposed unqualified.
  columnwise_things = chain(
    [(("dctv", fn.name,), fn) for fn in all_aggfuncs],
    [((fn.name,), fn) for fn in all_aggfuncs
     if fn.name in ("count", "max", "min", "prod", "sum")],
    (
      (("floor",), _unfn("floor")),
      # "ceil" is an alias: both names map to the "ceiling" operator.
      (("ceil",), _unfn("ceiling")),
      (("ceiling",), _unfn("ceiling")),
      (("round",), _unfn("round")),
      (("trunc",), _unfn("trunc")),
      (("coalesce",), FnNormalFunction(CoalesceQuery.of)),
      (("lag",), LagLeadFunction(_lag)),
      (("lead",), LagLeadFunction(_lead)),
      (("greatest",), FnNormalFunction(QueryNode.greatest)),
      # NOTE(review): partial() with no bound arguments; unlike "greatest"
      # above. Likely equivalent to passing QueryNode.least directly.
      (("least",), FnNormalFunction(partial(QueryNode.least))),
      (("row_number",), RowNumberFunction()),
      (("if",), FnNormalFunction(QueryNode.choose)),
      (("assert_unique",), FnNormalFunction(_assert_unique)),
      (("fls",), FnNormalFunction(FlsQuery.of)),
    ))

  # Column-wise paths go through _munge_columnwise_path before assignment
  # (presumably mapping them to where column-wise functions are looked
  # up; TODO confirm).
  for path, value in columnwise_things:
    ns.assign_by_path(_munge_columnwise_path(path), value)
  return LexicalEnvironment(None, ns)