| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Utilities for working with SQL""" |
| import logging |
| import re |
| import numpy as np |
| from .util import ( |
| FrozenDict, |
| Interned, |
| cached_property, |
| final, |
| iattr, |
| override, |
| tattr, |
| ) |
| |
# Logger for this module.
log = logging.getLogger(name=__name__)

# Identifiers matching this pattern in full (an ASCII letter or underscore,
# then ASCII letters, digits or underscores) are emitted unquoted by
# identifier_quote().
SAFE_ID_REGEX = re.compile("[a-zA-Z_][a-zA-Z_0-9]*")


def identifier_quote(s):
    """Render identifier S so it can be embedded in a SQL expression.

    A name that SAFE_ID_REGEX matches in full is returned unchanged.
    Any other name (including the empty string) is wrapped in backticks,
    and each backtick inside it is doubled to escape it.
    """
    if SAFE_ID_REGEX.fullmatch(s) is None:
        # Doubling embedded backticks escapes them inside a quoted name.
        return "`{}`".format("``".join(s.split("`")))
    return s
| |
# A named bind-parameter reference: a colon followed by an identifier (an
# ASCII letter or underscore, then any number of ASCII letters, digits or
# underscores).  Single-character names such as ":x" are accepted, matching
# SQLite's ":AAAA" parameter syntax; previously the quantifier required at
# least two characters, so ":x" was silently treated as plain SQL text.
NAMED_BIND_RE = re.compile(r":[a-zA-Z_][a-zA-Z0-9_]*")
| |
@final
class SqlBundle(Interned):
    """Immutable combination of SQL text and arguments.

    Can be combined into larger queries with embedded query parameters
    automatically renamed to avoid conflict.

    The SQL text is held as a sequence of PARTS.  A part beginning with
    ':' is a reference to a named bind parameter whose value is stored in
    ARGS under the name without the colon; every other part is plain SQL
    text.  Joining the parts yields the SQL string (see `sql`).
    """
    # SQL fragments; fragments starting with ":" name bind parameters.
    parts = tattr(str, kwonly=True)
    # Bind parameter name (without the leading ":") -> bound value.
    args = iattr(FrozenDict, kwonly=True)

    @staticmethod
    def lex(sql):
        """Split a SQL expression into parts.

        Each part is either a maximal run of plain SQL text or a single
        named bind reference (a match of NAMED_BIND_RE).  If a part begins
        with ':', it refers to a bind parameter.  Concatenating the parts
        reproduces SQL.  Returns an empty list for empty SQL.

        NOTE(review): a plain-text part can itself begin with ':' when a
        colon is not followed by an identifier (e.g. the ':1' in ':ab:1');
        under the ':'-prefix convention such a part is indistinguishable
        from a bind reference.
        """
        if not sql:
            return []
        matches = tuple(NAMED_BIND_RE.finditer(sql))
        if not matches:
            return [sql]
        parts = []
        prev_end = 0
        for match in matches:
            match_start = match.start()
            match_end = match.end()
            # Plain text between the previous bind (or start) and this one.
            if prev_end < match_start:
                parts.append(sql[prev_end:match_start])
            parts.append(sql[match_start:match_end])
            prev_end = match_end
        # Trailing plain text after the last bind reference.
        if prev_end < len(sql):
            parts.append(sql[prev_end:])
        return parts

    @cached_property
    def sql(self):
        """SQL content of the bundle"""
        return "".join(self.parts)

    @override
    def __str__(self):
        """Show the SQL text followed by each argument, sorted by name."""
        return "<SQL {!r}{}>".format(
            self.sql,
            "" if not self.args
            else " " + " ".join("{}={!r}".format(arg, self.args[arg])
                                for arg in sorted(self.args)))

    @override
    def __lt__(self, other):
        """Order bundles by their parts, then by their arguments.

        NOTE(review): when the parts are equal this compares the two
        FrozenDict ARGS values, which only works if FrozenDict defines an
        ordering -- confirm against .util.FrozenDict.
        """
        if not isinstance(other, SqlBundle):
            return NotImplemented
        return (self.parts, self.args) < (other.parts, other.args)

    @staticmethod
    def format(fmt, **kwargs):
        """Format a query, yielding a SqlBundle.

        FMT is a format string and part of a SQL query. The substitutions
        in the format string are spelled using SQL named bind parameter
        syntax, and each one must have a corresponding keyword argument.

        Each substitution value is either a SqlBundle itself or a
        primitive SQL type:

          - int, float, bool and bytes values are bound as-is;
          - numpy bool, integer and floating scalars are converted to
            bool, int and float respectively;
          - str values are encoded to UTF-8 bytes before binding;
          - a SqlBundle value is spliced in place, and any of its bind
            names that collide with names already bound are renamed by
            appending "_" until unique.

        All named substitutions must be provided via KWARGS, and all
        KWARGS must be used for substitutions.

        Raises KeyError if FMT names a substitution missing from KWARGS,
        TypeError for an unsupported value type, and AssertionError if
        some KWARGS are never used.
        """
        parts = []
        args = {}
        used = set()

        def _add_bind(bind_part, value):
            # Append a bind reference, renaming it (appending "_") until
            # its name no longer collides with an earlier bind.
            assert bind_part[0] == ":"
            bind_name = bind_part[1:]
            while bind_name in args:
                bind_name += "_"
            parts.append(":" + bind_name)
            args[bind_name] = value

        def _add_plain_part(part):
            # Append plain SQL text, merging it into the previous part when
            # that part is also plain text so plain runs stay contiguous.
            if not part:
                return
            assert part[0] != ":"
            if parts and parts[-1][0] != ":":
                parts[-1] += part
            else:
                parts.append(part)

        for part in SqlBundle.lex(fmt):
            if part[0] != ":":
                _add_plain_part(part)
                continue
            arg_name = part[1:]
            used.add(arg_name)
            value = kwargs[arg_name]
            # Normalize numpy scalars and text to the primitive types
            # accepted below.  np.bool_ is not an np.integer, so it needs
            # its own case.
            if isinstance(value, np.bool_):
                value = bool(value)
            elif isinstance(value, np.integer):
                value = int(value)
            elif isinstance(value, np.floating):
                value = float(value)
            elif isinstance(value, str):
                value = value.encode("UTF-8")
            if isinstance(value, (int, bytes, bool, float)):
                _add_bind(part, value)
            elif isinstance(value, SqlBundle):
                # Splice the sub-bundle in; its binds are renamed on
                # collision but keep the values from the sub-bundle's args.
                for sub_bundle_part in value.parts:
                    if sub_bundle_part[0] != ":":
                        _add_plain_part(sub_bundle_part)
                    else:
                        _add_bind(sub_bundle_part,
                                  value.args[sub_bundle_part[1:]])
            else:
                raise TypeError(
                    "value of type {} not allowed in SqlBundle: {!r}"
                    .format(type(value), value))
        # Explicit check (not a bare assert) so it still runs under -O.
        unused = set(kwargs) - used
        if unused:
            raise AssertionError(
                "unused SqlBundle.format arguments: {}"
                .format(", ".join(sorted(unused))))
        return SqlBundle(parts=parts, args=FrozenDict(args))

    @staticmethod
    def coerce_(value):
        """Make VALUE a SqlBundle.

        A SqlBundle is returned unchanged; any other value is passed to
        SqlBundle.format with no substitutions, so it must not contain
        bind references.
        """
        if isinstance(value, SqlBundle):
            return value
        return SqlBundle.format(value)

    # Short alias for coerce_.
    of = coerce_

    def paginate(self, page_size, page_number):
        """Return a new SqlBundle limited to the given page

        PAGE_SIZE is an integer page size, in rows; PAGE_NUMBER is the
        zero-based number of the page to return (the OFFSET is
        PAGE_SIZE * PAGE_NUMBER rows).
        """
        return SqlBundle.format(
            ":base_query LIMIT :qb_page_size OFFSET :qb_offset",
            base_query=self,
            qb_page_size=page_size,
            qb_offset=page_size * page_number)