# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for running queries in worker threads
Some query operators (especially the span-related ones) are
implemented in C++, maintain significant internal state between input
blocks, and cannot suspend and resume their execution the way Python
coroutines can. To incorporate these operators into the DCTV query
system, we run them on separate threads and communicate with these
threads using special adapter coroutines.
"""

# Terminology: the "query thread" is the one running QueryEngine.
# The "operator thread" is the one running a threaded operator.

# TODO(dancol): once C++2a coroutines become available, we can use a
# single execution strategy for Python and C++ operators.

# TODO(dancol): we ought to let multiple query threads exist for the
# same query execution, allowing multiple operators to run
# simultaneously on different cores. It's a project for another time.

import logging

import numpy as np
from numpy.ma import nomask
from modernmp.util import (
    assert_seq_type,
    the,
    the_seq,
)

from .util import (
    ExplicitInheritance,
    INT64,
    Immutable,
    UINT64,
    abstract,
    cached_property,
    final,
    iattr,
    override,
)
from ._native import (
    QueryExecution,
    SynchronousCoroutine,
    npy_explode,
)
from .query import (
    InvalidQueryException,
    QueryNode,
    check_same_lengths,
)

log = logging.getLogger(__name__)


class SynchronousIoSpec(ExplicitInheritance):
  """Internal base for synchronous inputs and outputs"""
  _sync_qe = None

  @abstract
  def gather(self, input_queries, output_queries):
    """Used internally to prepare for async setup"""
    raise NotImplementedError("abstract")

  def prepare(self, sync_qe, _ics, _ocs):
    """Extract values from async setup results"""
    assert self._sync_qe is None
    self._sync_qe = the(SynchronousQueryExecution, sync_qe)

  @staticmethod
  def _extract_channels(channel_list, queries):
    # Channels come back in the same order as the queries that
    # produced them; peel this spec's prefix off the shared list.
    channels = channel_list[:len(queries)]
    del channel_list[:len(queries)]
    assert len(channels) == len(queries)
    assert all(c.query is q for c, q in zip(channels, queries))
    return channels


@final
class SynchronousInput(SynchronousIoSpec):
  """Handle for an input to synchronous code

  Corresponds to a PythonChunkIterator in C++ code.
  """
  __ics = None

  @final
  class Column(Immutable):
    """A column specification for a synchronous input object"""
    query = iattr(nullable=True)
    partition = iattr(bool, kwonly=True, default=False)
    masked = iattr(bool, kwonly=True, default=False)

    def to_arrays(self, arrays, block_it):
      """Build a synchronous input batch"""
      if not isinstance(self.query, QueryNode):
        # A constant (non-query) column: pass the value through
        # unchanged.
        arrays.append(self.query)
        return
      array, mask = npy_explode(next(block_it).as_array())
      # The partition column is value-preserving, so if we have a
      # uint64, we can round-trip it through its signed equivalent.
      if self.partition and array.dtype == UINT64:
        array = array.view(dtype=INT64)
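      # (e.g. np.uint64(2**63).view(np.int64) flips the sign but
      # keeps the bit pattern; SynchronousOutput casts the partition
      # column back to uint64 on the way out)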
      if self.masked:
        arrays.extend((array, mask))
      else:
        if mask is not nomask:
          raise InvalidQueryException(
              "mask not supported for {}".format(self))
        arrays.append(array)

  @override
  def __init__(self, *columns):
    Column = self.Column
    self.__columns = [
        c if isinstance(c, Column) else Column(c)
        for c in columns]

  @cached_property
  def __queries(self):
    return [column.query
            for column in self.__columns
            if column.query]

  @override
  def gather(self, input_queries, _output_queries):
    input_queries.extend(self.__queries)

  @override
  def prepare(self, sync_qe, ics, ocs):
    super().prepare(sync_qe, ics, ocs)
    self.__ics = self._extract_channels(ics, self.__queries)

  def __blocks_to_arrays(self, blocks):
    arrays = []
    block_it = iter(blocks)
    for column in self.__columns:
      column.to_arrays(arrays, block_it)
    # pylint: disable=compare-to-zero
    assert next(block_it, False) is False
    return arrays

  def __iter__(self):
    assert self._sync_qe, "not prepared"
    while True:
      blocks = self._sync_qe.sync_io(
          *[ic.read(self._sync_qe.block_size)
            for ic in self.__ics])
      check_same_lengths(self.__ics, blocks)
      if not len(blocks[0]):  # pylint: disable=len-as-condition
        break
      # TODO(dancol): allow plain list
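      # Each batch becomes one tuple: one array per plain column, a
      # (data, mask) pair per masked column, and constant columns
      # passed through as-is.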
      yield tuple(self.__blocks_to_arrays(blocks))


@final
class SynchronousOutput(SynchronousIoSpec):
  """Handle for a synchronous output

  Matches one ResultBuffer in C++ code.
  """
  __ocs = None
  __partition_index = None

  @override
  def __init__(self, *queries, masked=False, partition=None):
    if None in queries:
      queries = tuple(query for query in queries
                      if query is not None)
    self.__queries = the_seq(tuple, QueryNode, queries)
    self.__masked = masked
    if partition:
      # Remember which column is the partition column, but only when
      # its dtype (e.g. uint64) cannot hold int64 directly: that
      # column was viewed as int64 on input and must be cast back on
      # output.
      if not np.can_cast(INT64, partition.schema.dtype):
        self.__partition_index = self.__queries.index(partition)

  @override
  def gather(self, _input_queries, output_queries):
    output_queries.extend(self.__queries)

  @override
  def prepare(self, sync_qe, ics, ocs):
    super().prepare(sync_qe, ics, ocs)
    self.__ocs = self._extract_channels(ocs, self.__queries)

  def __call__(self, *buffers):
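    # BUFFERS carries one array per output query; with masked=True,
    # (data, mask) pairs are interleaved instead, e.g. a two-query
    # masked output is called as output(d1, m1, d2, m2).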
    if self.__masked:
      arrays = [np.ma.masked_array(data_buffer, mask_buffer)
                for data_buffer, mask_buffer
                in zip(buffers[0::2], buffers[1::2])]
    else:
      arrays = list(buffers)
    assert len(self.__ocs) == len(arrays), \
        ("we expected {} arrays but synchronous caller gave us {}: "
         "missing={!r} excess={!r}".format(
             len(self.__ocs),
             len(arrays),
             [oc.query for oc in self.__ocs[len(arrays):]],
             arrays[len(self.__ocs):]))
    pi = self.__partition_index
    if pi is not None:
      # Undo the signed view applied on input: getting back to
      # uint64 requires an "unsafe" cast.
      pdt = self.__ocs[pi].dtype
      arrays[pi] = arrays[pi].astype(
          pdt,
          casting="unsafe" if pdt == UINT64 else "same_kind")
    ret = self._sync_qe.sync_io(*[
        oc.write(array)
        for oc, array in zip(self.__ocs, arrays)
    ])
    assert ret == [None] * len(arrays)


@final
class SynchronousQueryExecution(ExplicitInheritance):
  """Adapts synchronous code to the async query engine"""

  @override
  def __init__(self, qe):
    self.__qe = the(QueryExecution, qe)
    self.__sync_yield = None
    self.block_size = qe.block_size

  def sync_io(self, *io_requests):
    """Multi-IO gateway

    Does the same thing as QueryExecution.async_io(), but in a
    synchronous fashion. Can be used only inside the call frame of
    the operator function passed to async_run().
    """
    return self.__sync_yield(self.__qe.async_io(*io_requests))

  @staticmethod
  async def async_setup(qe,
                        input_queries,
                        output_queries,
                        *,
                        sync_io=()):
    """Set up a synchronous IO query

    Wraps QueryExecution.async_setup() and returns a triple (SYNC_QE,
    ICS, OCS), where ICS and OCS are as for
    QueryExecution.async_setup() and SYNC_QE is an object for actually
    triggering the synchronous query execution.
    """
    inps = []
    outs = []
    sync_io = tuple(sync_io)
    assert assert_seq_type(tuple, SynchronousIoSpec, sync_io)
    # Each spec's queries go first so that prepare() below can peel
    # its channels off the front of the lists; the caller's explicit
    # queries follow, and their channels are returned as ICS and OCS.
    for spec in sync_io:
      spec.gather(inps, outs)
    inps.extend(input_queries)
    outs.extend(output_queries)
    ics, ocs = await qe.async_setup(inps, outs)
    ics = list(ics)
    ocs = list(ocs)
    sync_qe = SynchronousQueryExecution(qe)
    for spec in sync_io:
      spec.prepare(sync_qe, ics, ocs)
    return sync_qe, ics, ocs

  async def async_run(self, operator_function):
    """Actually perform the synchronous computation

    OPERATOR_FUNCTION is called with no arguments in either a fiber or
    a thread and should use pre-configured synchronous input and
    output objects or sync_io() directly.
    """
    def _do_it(sync_yield):
      # SYNC_YIELD lets the operator hand awaitables back to the
      # query engine; sync_io() works only while this frame is live.
      assert callable(sync_yield)
      self.__sync_yield = sync_yield
      operator_function()
    await SynchronousCoroutine(_do_it, on_stack=True)