blob: 963f87901752a702ddbfa6378a31f0142c50edbf [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "query.h"
#include "hunk.h"
#include "io_spec.h"
#include "operator_context.h"
#include "pyerrfmt.h"
#include "pyobj.h"
#include "pyseq.h"
#include "qe_call_resize_buffers.h"
namespace dctv {
static_assert(sizeof (npy_intp) == sizeof (Py_ssize_t));
_Py_IDENTIFIER(_data);
_Py_IDENTIFIER(_mask);
// Base-class fallback: unconditionally raises a Python RuntimeError whose
// message names the operator's query action (via repr()).
void
QeCall::do_operator_setup(QueryExecution* /*qe*/,
                          OperatorContext* oc,
                          QueryDb* /*query_db*/)
{
  throw_pyerr_fmt(
      PyExc_RuntimeError,
      "operator %s did not call setup",
      repr(oc->get_query_action()));
}
// Base-class fallback for calls that are not IO operations: raises a
// Python RuntimeError instead of producing an IoSpec.
IoSpec
QeCall::extract_io_spec()
{
  throw_pyerr_fmt(PyExc_RuntimeError,
                  "not an IO operation");
}
// Wrap this call in a QeCallResizeBuffers.
//
// *SELF_PTR must be the owning pointer to this object (checked with
// assume()). The wrapper is constructed from OC, SELF_PTR and the moved
// current owner, and is then installed in *SELF_PTR in place of this
// object.
void
QeCall::refresh_resize_buffers(const OperatorContext* oc,
std::unique_ptr<QeCall>* self_ptr)
{
assume(self_ptr->get() == this);
// If we fail here, we'll leave *self_ptr empty, which will make
// QueryExecution complain if someone tries to continue execution.
// NOTE(review): std::move(*self_ptr) is only a cast. The pointer is
// actually taken when QeCallResizeBuffers's constructor parameter is
// initialized. An allocation failure inside make_unique happens before
// that point and so presumably leaves *self_ptr intact. Confirm this
// against the constructor's signature.
std::unique_ptr<QeCall> new_qe_call =
std::make_unique<QeCallResizeBuffers>(
oc, self_ptr, std::move(*self_ptr));
// Install the wrapper. `this` is presumably kept alive by the wrapper,
// which was handed the old owning pointer above.
*self_ptr = std::move(new_qe_call);
}
QueryClasses QueryClasses::instance;
void
QueryClasses::ensure()
{
// TODO(dancol): use less centralized class interning
if (!instance.initialized()) {
instance.cls_tmpdir =
find_pyclass("dctv.util", "TmpDir");
instance.cls_query_node =
find_pyclass("dctv.query", "QueryNode");
instance.cls_query_action =
find_pyclass("dctv.query", "QueryAction");
instance.cls_invalid_query_exception =
find_pyclass("dctv.query", "InvalidQueryException");
instance.np_ma_nomask =
getattr(import_pymodule("numpy.ma"), "nomask");
instance. cls_masked_array =
find_pyclass("numpy.ma", "masked_array");
assert(instance.initialized());
}
}
// Raise dctv.query.InvalidQueryException with MSG as its message.
void
throw_invalid_query(String msg) // NOLINT: min call site size
{
  const auto exception_class =
      QueryClasses::get().cls_invalid_query_exception.get();
  throw_pyerr_msg(exception_class, msg.c_str());
}
// Strip trivial views off ARRAY.
//
// ARRAY is returned unchanged unless it is 1-D. Otherwise, while ARRAY's
// base is a plain ndarray that describes exactly the same elements, ARRAY
// is replaced by that base. "Exactly the same elements" means: 1-D, same
// element count, same dtype, same starting data pointer, and the same
// stride when more than one element is present. Returns the last array
// reached.
unique_pyarray
collapse_useless_views(unique_pyarray array)
{
  if (npy_ndim(array) != 1)
    return array;
  for (;;) {
    obj_pyref base = npy_base(array);
    if (!base || !is_base_ndarray(base))
      return array;
    pyarray_ref abase = base.as_unsafe<PyArrayObject>();
    if (npy_ndim(abase) != 1)
      return array;
    if (npy_size1d(abase) != npy_size1d(array))
      return array;
    if (npy_dtype(abase) != npy_dtype(array))
      return array;
    if (npy_data_raw(abase) != npy_data_raw(array))
      return array;
    // Equal start, count and dtype do not imply equal contents. For
    // example, a zero-stride view built over an equal-length base shares
    // its data pointer but repeats element 0. Replacing that view with
    // its base would change the values the caller sees. With at most one
    // element the stride is irrelevant (NumPy may report any stride for a
    // length-1 axis), so skip the check there.
    if (npy_size1d(array) > 1 &&
        npy_strides(abase)[0] != npy_strides(array)[0])
      return array;
    array = abase.addref();
  }
}
// Return the `_data` ndarray of masked array MA (a new reference), with
// trivial wrapper views removed by collapse_useless_views().
unique_pyarray
npy_get_data_of_ma(pyarray_ref ma)
{
  assert(maybe_ma(ma));
  unique_pyref data_obj = getattr(ma, &PyId__data);
  unique_pyarray data =
      std::move(data_obj).addref_as_unsafe<PyArrayObject>();
  assert(is_base_ndarray(data));
  return collapse_useless_views(std::move(data));
}
// Return the `_mask` attribute of masked array MA.
//
// The nomask() sentinel is returned as-is. Any other mask is expected to
// be a plain ndarray and is passed through collapse_useless_views().
unique_pyref
npy_get_mask_or_nomask_of_ma(pyarray_ref ma)
{
  assert(maybe_ma(ma));
  unique_pyref mask = getattr(ma, &PyId__mask);
  if (mask == nomask())
    return mask;
  assert(is_base_ndarray(mask));
  unique_pyarray mask_array =
      std::move(mask).addref_as_unsafe<PyArrayObject>();
  return collapse_useless_views(std::move(mask_array));
}
// Like npy_get_mask_or_nomask_of_ma(), but reports the nomask() sentinel
// as an empty result. Callers can therefore test the returned array for
// truthiness.
unique_pyarray
npy_get_mask_of_ma(pyarray_ref ma)
{
  unique_pyref mask_or_nomask = npy_get_mask_or_nomask_of_ma(ma);
  const bool unmasked = (mask_or_nomask == nomask());
  if (unmasked)
    return {};
  assert(is_base_ndarray(mask_or_nomask));
  return std::move(mask_or_nomask).addref_as_unsafe<PyArrayObject>();
}
// True iff ARRAY is a masked array whose mask is not the nomask()
// sentinel.
bool
npy_has_mask(pyarray_ref array)
{
  pyarray_ref ma = maybe_ma(array);
  if (!ma)
    return false;
  return static_cast<bool>(npy_get_mask_of_ma(ma));
}
// Python entry point for npy_has_mask(). OBJ is converted with
// as<PyArrayObject>() before the check.
static
unique_pyref
py_npy_has_mask(PyObject* /*self*/, pyref obj)
{
  pyarray_ref array = obj.as<PyArrayObject>();
  const bool has_mask = npy_has_mask(array);
  return make_pybool(has_mask);
}
// Python entry point: the data of a masked array, or OBJ itself (with a
// new reference) when OBJ is not a masked array.
static
unique_pyref
py_npy_get_data(PyObject* /*self*/, pyref obj)
{
  if (pyarray_ref ma = maybe_ma(obj))
    return npy_get_data_of_ma(ma);
  return obj.addref();
}
// Python entry point: the mask of a masked array, or the nomask()
// sentinel when OBJ is not a masked array (or has no mask).
static
unique_pyref
py_npy_get_mask(PyObject* /*self*/, pyref obj)
{
  pyarray_ref ma = maybe_ma(obj);
  if (!ma)
    return nomask().addref();
  return npy_get_mask_or_nomask_of_ma(ma);
}
// Python entry point: return a (data, mask) tuple. A non-masked OBJ
// yields (OBJ, nomask()).
static
unique_pyref
py_npy_explode(PyObject* /*self*/, pyref obj)
{
  pyarray_ref ma = maybe_ma(obj);
  if (ma)
    return pytuple::of(npy_get_data_of_ma(ma),
                       npy_get_mask_or_nomask_of_ma(ma));
  return pytuple::of(obj, nomask());
}
// Forward declaration: npy_get_broadcaster_1 and npy_get_broadcaster_ma_1
// call each other. With JustCheck the result is a bool; otherwise it is
// the broadcaster array (empty if none).
template<bool JustCheck>
static
std::conditional_t<JustCheck, bool, unique_pyarray>
npy_get_broadcaster_1(pyarray_ref array) noexcept(JustCheck);
// Broadcaster lookup for a masked array MA.
//
// MA counts as broadcasted only if its data is broadcasted and, when a
// mask is present, the mask is too. In full mode the result is a new
// masked array built from the two broadcasters. In JustCheck mode the
// result is just true/false.
template<bool JustCheck>
static
std::conditional_t<JustCheck, bool, unique_pyarray>
npy_get_broadcaster_ma_1(pyarray_ref ma)
{
  using Result = std::conditional_t<JustCheck, bool, unique_pyarray>;

  unique_pyarray ma_data = npy_get_data_of_ma(ma);
  Result data_bc = npy_get_broadcaster_1<JustCheck>(ma_data);
  if (!data_bc)
    return {};

  unique_pyarray ma_mask = npy_get_mask_of_ma(ma);
  if (!ma_mask)
    return data_bc;  // Unmasked: the data broadcaster is the answer.

  Result mask_bc = npy_get_broadcaster_1<JustCheck>(ma_mask);
  if (!mask_bc)
    return {};

  if constexpr (JustCheck) {
    return true;
  } else {  // NOLINT
    return make_masked_array(data_bc, mask_bc);
  }
}
// Return the single-element array that ARRAY broadcasts, or an empty ref.
//
// ARRAY qualifies when all of the following hold:
//   - it is a plain 1-D ndarray with stride 0;
//   - its base is a plain ndarray with the same dtype;
//   - that base starts at the same data address;
//   - that base holds one element (rank 0, or rank 1 with length 1).
// The returned ref points at ARRAY's base; no reference is added.
pyarray_ref
npy_get_broadcaster_raw(pyarray_ref array) noexcept
{
  const bool zero_stride_vector =
      is_base_ndarray(array) &&
      npy_ndim(array) == 1 &&
      npy_strides(array)[0] == 0;
  if (!zero_stride_vector)
    return {};

  pyref base = npy_base(array);
  if (!base || !is_base_ndarray(base))
    return {};
  pyarray_ref source = base.as_unsafe<PyArrayObject>();

  // The view must alias the base's storage with the same element type.
  if (npy_dtype(array) != npy_dtype(source) ||
      npy_data_raw(array) != npy_data_raw(source))
    return {};

  const bool single_element =
      npy_ndim(source) == 0 ||
      (npy_ndim(source) == 1 && npy_dims(source)[0] == 1);
  if (!single_element)
    return {};
  return source;
}
// Shared implementation of npy_get_broadcaster / npy_is_broadcasted.
//
// Masked arrays are delegated to npy_get_broadcaster_ma_1(). Other arrays
// use npy_get_broadcaster_raw(). In full mode a 0-d broadcaster is passed
// through npy_ravel() before being returned.
//
// NOTE(review): with JustCheck this function is noexcept, yet the
// masked-array path reaches getattr() through npy_get_data_of_ma(). If
// that can throw, the result is std::terminate. Confirm it cannot fail
// here.
template<bool JustCheck>
static
std::conditional_t<JustCheck, bool, unique_pyarray>
npy_get_broadcaster_1(pyarray_ref array) noexcept(JustCheck)
{
  if (pyarray_ref ma = maybe_ma(array))
    return npy_get_broadcaster_ma_1<JustCheck>(ma);

  pyarray_ref raw = npy_get_broadcaster_raw(array);
  if (!raw)
    return {};

  if constexpr (JustCheck) {
    return true;
  } else {  // NOLINT
    if (npy_ndim(raw) == 0)
      return npy_ravel(raw);
    return raw.addref();
  }
}
// Return ARRAY's broadcaster (a new reference), or an empty result if
// ARRAY is not broadcasted.
unique_pyarray
npy_get_broadcaster(pyarray_ref array)
{
  constexpr bool kJustCheck = false;
  return npy_get_broadcaster_1<kJustCheck>(array);
}
// Return whether ARRAY is broadcasted. No broadcaster object is built.
bool
npy_is_broadcasted(pyarray_ref array) noexcept
{
  constexpr bool kJustCheck = true;
  return npy_get_broadcaster_1<kJustCheck>(array);
}
// Python entry point: OBJ's broadcaster array, or None when OBJ is not
// broadcasted.
static
unique_pyref
py_npy_get_broadcaster(PyObject* /*self*/, pyref obj)
{
  unique_pyarray broadcaster =
      npy_get_broadcaster(obj.as<PyArrayObject>());
  if (broadcaster)
    return broadcaster;
  return addref(Py_None);
}
// Python entry point: bool telling whether OBJ is broadcasted.
static
unique_pyref
py_npy_is_broadcasted(PyObject* /*self*/, pyref obj)
{
  const bool broadcasted = npy_is_broadcasted(obj.as<PyArrayObject>());
  return make_pybool(broadcasted);
}
// Build numpy.ma.masked_array(DATA, MASK) and return it as an array
// reference.
unique_pyarray
make_masked_array(pyref data, pyref mask)
{
  auto masked = call(QueryClasses::get().cls_masked_array, data, mask);
  return std::move(masked).addref_as_unsafe<PyArrayObject>();
}
// Reset *MASK when it masks nothing.
//
// A broadcasted mask is decided by the single element of its source
// array. Any other mask is decided by a full npy_any() scan. An empty
// *MASK is left alone.
void
maybe_drop_mask(unique_pyarray* mask)
{
  if (!*mask)
    return;
  if (pyarray_ref bcaster = npy_get_broadcaster_raw(*mask)) {
    assume(npy_ndim(bcaster) <= 1);
    assert(npy_ndim(bcaster) == 0 || npy_size1d(bcaster) == 1);
    assume(npy_dtype(bcaster) == type_descr<bool>());
    // Read before reset(): bcaster is owned via *mask's base.
    const bool masks_anything = npy_data<bool>(bcaster)[0];
    if (!masks_anything)
      mask->reset();
    return;
  }
  // TODO(dancol): is the full scan worth it?
  if (!npy_any(*mask))
    mask->reset();
}
// Python-visible helpers registered by init_query(). Each takes exactly
// one positional argument (METH_O). Docstrings must match what the
// py_npy_* implementations actually return:
//   - py_npy_get_mask returns nomask(), not None, when no mask is present;
//   - py_npy_is_broadcasted returns a bool.
static PyMethodDef functions[] = {
  make_methoddef("npy_has_mask",
                 wraperr<py_npy_has_mask>(),
                 METH_O,
                 "Return whether an object has a mask"),
  make_methoddef("npy_get_mask",
                 wraperr<py_npy_get_mask>(),
                 METH_O,
                 "Return an object's mask if present or numpy.ma.nomask"),
  make_methoddef("npy_get_data",
                 wraperr<py_npy_get_data>(),
                 METH_O,
                 "Return data of masked array or array"),
  make_methoddef("npy_explode",
                 wraperr<py_npy_explode>(),
                 METH_O,
                 "Split an array into value and mask"),
  make_methoddef("npy_get_broadcaster",
                 wraperr<py_npy_get_broadcaster>(),
                 METH_O,
                 "Return the underlying broadcasted array "
                 "or None if array is not broadcasted"),
  make_methoddef("npy_is_broadcasted",
                 wraperr<py_npy_is_broadcasted>(),
                 METH_O,
                 "Return whether an array is broadcasted"),
  { 0 }
};
// Attach this file's Python helpers (the `functions` table) to module M.
void
init_query(pyref m)
{
  register_functions(m, functions);
}
} // namespace dctv