blob: 2ae1e8e7a1b2ab358e55ac66877bf648c67be120 [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with SQL"""
import logging
import re
import numpy as np
from .util import (
FrozenDict,
Interned,
cached_property,
final,
iattr,
override,
tattr,
)
# Logger for this module.
log = logging.getLogger(name=__name__)
# Identifiers that fully match this pattern (ASCII letter or underscore,
# followed by letters, digits, or underscores) are emitted without quoting.
SAFE_ID_REGEX = re.compile(r"[a-zA-Z_][a-zA-Z_0-9]*")
def identifier_quote(s):
    """Quote an identifier name as a SQL expression.

    Names that fully match SAFE_ID_REGEX are returned unchanged.  Any other
    name is wrapped in backticks, and each backtick inside it is doubled so
    the quoted form cannot be terminated early.
    """
    is_plain = SAFE_ID_REGEX.fullmatch(s) is not None
    if is_plain:
        return s
    escaped = s.replace("`", "``")
    return "`{}`".format(escaped)
# Matches a named bind parameter such as ":name".  The name must start with
# an ASCII letter or underscore; single-character names like ":x" are
# allowed (the previous `+` quantifier silently rejected them, leaving ":x"
# as plain text that SqlBundle.format then mis-handled).
NAMED_BIND_RE = re.compile(r":[a-zA-Z_][a-zA-Z0-9_]*")
@final
class SqlBundle(Interned):
    """Immutable combination of SQL text and arguments.

    The SQL is stored as a sequence of PARTS.  A part that begins with ':'
    is a reference to a named bind parameter whose value is stored in ARGS
    under the name without the colon; every other part is literal SQL text.
    Bundles can be combined into larger queries with SqlBundle.format(),
    which automatically renames embedded bind parameters to avoid conflict.
    """

    # Sequence of str parts: literal SQL text and ":name" bind references.
    parts = tattr(str, kwonly=True)
    # Mapping from bind parameter name (without the ':') to its value.
    args = iattr(FrozenDict, kwonly=True)

    @staticmethod
    def lex(sql):
        """Split a SQL expression into parts.

        Each named bind parameter (as matched by NAMED_BIND_RE) becomes its
        own part beginning with ':'.  The text between bind parameters is
        returned as separate plain parts, and empty text between adjacent
        binds is omitted.  An empty SQL string yields an empty list.

        NOTE(review): a plain-text part that itself starts with ':' (e.g.
        "::" or ":1") cannot be told apart from a bind part by callers that
        test only the first character.
        """
        if not sql:
            return []
        parts = []
        prev_end = 0
        for match in NAMED_BIND_RE.finditer(sql):
            match_start, match_end = match.span()
            if prev_end < match_start:
                parts.append(sql[prev_end:match_start])
            parts.append(sql[match_start:match_end])
            prev_end = match_end
        # Trailing text after the last bind (or the whole string if no
        # binds were found).
        if prev_end < len(sql):
            parts.append(sql[prev_end:])
        return parts

    @cached_property
    def sql(self):
        """SQL content of the bundle, with bind references left as ':name'."""
        return "".join(self.parts)

    @override
    def __str__(self):
        return "<SQL {!r}{}>".format(
            self.sql,
            "" if not self.args
            else " " + " ".join("{}={!r}".format(arg, self.args[arg])
                                for arg in sorted(self.args)))

    @override
    def __lt__(self, other):
        if not isinstance(other, SqlBundle):
            return NotImplemented
        # NOTE(review): when the parts are equal but the args differ, this
        # compares two FrozenDict objects with '<'; confirm FrozenDict
        # defines an ordering, otherwise this raises TypeError.
        return (self.parts, self.args) < (other.parts, other.args)

    @staticmethod
    def format(fmt, **kwargs):
        """Format a query, yielding a SqlBundle.

        FMT is a format string and part of a SQL query.  The substitutions
        in the format string are spelled using SQL named bind parameter
        syntax, and each one must have a corresponding keyword argument.
        Each substitution value is either a SqlBundle itself (whose parts
        are spliced in, with its bind parameters renamed as needed) or a
        primitive value: int, float, bool, bytes, or str.  A str value is
        bound as UTF-8 encoded bytes.  NumPy integer, floating, and bool
        scalars are converted to the corresponding Python type.

        All named substitutions must be provided via KWARGS, and all
        KWARGS must be used for substitutions.

        Raises:
          KeyError: if FMT references a name missing from KWARGS.
          ValueError: if some KWARGS are not referenced by FMT.
          TypeError: if a substitution value has an unsupported type.
        """
        parts = []
        args = {}
        used = set()

        def _add_bind(bind_part, value):
            # Append a bind reference; if its name is already taken by an
            # earlier bind in this bundle, append '_' until it is unique.
            assert bind_part[0] == ":"
            bind_name = bind_part[1:]
            while bind_name in args:
                bind_name += "_"
            parts.append(":" + bind_name)
            args[bind_name] = value

        def _add_plain_part(part):
            # Append literal SQL text, merging it into a preceding literal
            # part so that two plain-text parts are never adjacent.
            if part:
                assert part[0] != ":"
                if parts and parts[-1][0] != ":":
                    parts[-1] += part
                else:
                    parts.append(part)

        for part in SqlBundle.lex(fmt):
            if part[0] != ":":
                _add_plain_part(part)
                continue
            arg_name = part[1:]
            used.add(arg_name)
            value = kwargs[arg_name]
            # Normalize NumPy scalars and text to plain Python primitives.
            if isinstance(value, np.integer):
                value = int(value)
            elif isinstance(value, np.floating):
                value = float(value)
            elif isinstance(value, np.bool_):
                value = bool(value)
            elif isinstance(value, str):
                value = value.encode("UTF-8")
            if isinstance(value, (int, bytes, bool, float)):
                _add_bind(part, value)
            elif isinstance(value, SqlBundle):
                # Splice the sub-bundle in, re-binding each of its args
                # under a name that is unique within this bundle.
                for sub_bundle_part in value.parts:
                    if sub_bundle_part[0] != ":":
                        _add_plain_part(sub_bundle_part)
                    else:
                        _add_bind(sub_bundle_part,
                                  value.args[sub_bundle_part[1:]])
            else:
                raise TypeError(
                    "value of type {} not allowed in SqlBundle: {!r}"
                    .format(type(value), value))

        # Explicit check instead of `assert`, which is stripped under -O.
        unused = set(kwargs) - used
        if unused:
            raise ValueError(
                "unused SqlBundle.format arguments: {}".format(
                    ", ".join(sorted(unused))))
        return SqlBundle(parts=parts, args=FrozenDict(args))

    @staticmethod
    def coerce_(value):
        """Make VALUE a SqlBundle.

        A SqlBundle is returned unchanged; anything else is passed to
        SqlBundle.format() as a format string with no substitutions.
        """
        if isinstance(value, SqlBundle):
            return value
        return SqlBundle.format(value)

    of = coerce_

    def paginate(self, page_size, page_number):
        """Return a new SqlBundle limited to the given page.

        PAGE_SIZE is an integer page size, in rows; PAGE_NUMBER is the
        number of the page to return.  The query is wrapped as
        "<self> LIMIT :qb_page_size OFFSET :qb_offset", where the offset is
        PAGE_SIZE * PAGE_NUMBER.
        """
        return SqlBundle.format(
            ":base_query LIMIT :qb_page_size OFFSET :qb_offset",
            base_query=self,
            qb_page_size=page_size,
            qb_offset=page_size * page_number)