blob: 0a03ebda7fd489b4a8dde8cb8b957216076bd60b [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "dctv.h"
#include <boost/operators.hpp>
#include <functional>
#include <memory>
#include <tuple>
#include "fmt.h"
#include "npy.h"
#include "pyseq.h"
#include "pyutil.h"
namespace dctv {
// Types defined in other headers. Declaring them here lets this
// header name them without pulling in their full definitions.
struct OperatorContext;
struct QueryExecution;
struct QueryCache;
struct QueryKey;
struct IoSpec;
struct InputChannel;
struct OutputChannel;
struct Hunk;
struct StringTable;
// Python-object reference wrappers around an OperatorContext. By the
// pyutil naming convention, the unique_ variant presumably owns its
// reference and the plain one does not — confirm in pyutil.h.
using unique_op_ref = unique_obj_pyref<OperatorContext>;
using op_ref = obj_pyref<OperatorContext>;
// Cache of Python objects (mostly type objects) that the query code
// needs at runtime. A single instance lives in the private static
// member `instance`.
// NOTE(review): get(), initialized() and ensure() are defined out of
// line. get() presumably returns `instance`, and ensure() presumably
// fills in the fields on first use — confirm in the definitions.
struct QueryClasses final : NoCopy, NoMove {
// References to Python type objects used by the query machinery.
unique_obj_pyref<PyTypeObject> cls_tmpdir;
unique_obj_pyref<PyTypeObject> cls_query_node;
unique_obj_pyref<PyTypeObject> cls_query_action;
unique_obj_pyref<PyTypeObject> cls_invalid_query_exception;
// Judging by the name, numpy's `np.ma.nomask` singleton — confirm.
unique_pyref np_ma_nomask;
// Presumably numpy's masked-array type (np.ma.MaskedArray) — confirm.
unique_obj_pyref<PyTypeObject> cls_masked_array;
inline static const QueryClasses& get();
inline bool initialized() const;
static void ensure();
private:
static QueryClasses instance;
};
// Run state of an operator; it is the first component of a Score's
// sort key. Enumerators are in priority order, with the most
// important last, so the numeric values below increase with priority.
enum class OperatorState : int {
  NOT_RUNNABLE = 0,
  RUNNABLE_UNDER_DURESS = 1,
  RUNNABLE = 2,
  YIELDING = 3,
  TERMINATED = 4,
};
// The value we use for sorting actions to be run. At each turn of
// the query crank, the operator with the highest score runs.
//
// The remaining comparison operators (>, <=, >=, !=) are synthesized
// by boost::totally_ordered from operator< and operator==, which are
// defined out of line.
struct Score final : boost::totally_ordered<Score> {
  OperatorState state = OperatorState::NOT_RUNNABLE;
  int negated_ordinal = 0;
  inline bool operator<(const Score& other) const noexcept;
  inline bool operator==(const Score& other) const noexcept;
  // Return a tuple of references to (state, negated_ordinal). Since
  // std::tuple compares lexicographically, `state` dominates and
  // `negated_ordinal` breaks ties.
  auto get_sort_key() const noexcept {
    return std::tie(this->state, this->negated_ordinal);
  }
  // A Score is valid iff negated_ordinal is nonzero. A
  // default-constructed Score (negated_ordinal == 0) is therefore
  // invalid.
  bool is_valid() const noexcept { return this->negated_ordinal != 0; }
};
struct QueryDb;
// Abstract base for an operation that an operator hands to the query
// engine (see do_it()). Subclasses must implement compute_score(),
// do_it() and py_traverse(). setup() defaults to a no-op, and the
// other virtuals have out-of-line default implementations.
// Instances are owned through std::unique_ptr<QeCall> (see
// make_qe_call and refresh_resize_buffers), hence the virtual
// destructor.
struct QeCall : NoCopy {
// Called when a QeCall becomes current for an operator.
// The default implementation does nothing.
virtual void setup(OperatorContext* oc) {}
// Called on demand to recompute the operator's score.
virtual Score compute_score(const OperatorContext* oc) const = 0;
// Called to perform the operation indicated by the QeCall; should
// call oc->communicate() with the value to pass to the coroutine
// and undo any state changes done in setup(). If the return value
// is not nullptr, yield that value to the query runner.
virtual unique_pyref do_it(OperatorContext* oc) = 0;
// Do first-time operator setup. Note that oc->qe is _not_ set up
// at this point and that it shouldn't be, since oc->qe is the
// _link_ owner and at the time we call do_operator_setup, the
// operator isn't on any of qe's queues.
virtual void do_operator_setup(QueryExecution* qe,
OperatorContext* oc,
QueryDb* query_db);
// Special call for consolidating IO.
virtual IoSpec extract_io_spec();
// Implement Python GC traversal.
virtual int py_traverse(visitproc visit, void* arg) const noexcept = 0;
// Refresh a resize-buffers pending call.
// NOTE(review): self_ptr presumably points at the unique_ptr that
// owns this call, so an implementation can replace itself — confirm.
virtual void refresh_resize_buffers(
const OperatorContext* oc,
std::unique_ptr<QeCall>* self_ptr);
virtual ~QeCall() = default;
};
// Factories that return a QeCall packaged as an owned Python
// reference. make_qe_call takes ownership of an arbitrary QeCall. The
// others presumably build a specific QeCall subclass from their
// arguments — see the out-of-line definitions.
unique_pyref make_qe_call(std::unique_ptr<QeCall> qe_call);
// IO request described by a single IoSpec.
unique_pyref make_qe_call_io_single(IoSpec&& io_spec);
// IO request described by a Python tuple of IO specs.
unique_pyref make_qe_call_io_multi(pytuple::ref py_io_specs);
// Setup request; ARGS is a Python tuple whose layout is defined by
// the implementation.
unique_pyref make_qe_call_setup(pytuple::ref args);
// Request that WHAT be yielded to the query runner.
unique_pyref make_qe_call_yield_to_query_runner(unique_pyref what);
// Report an invalid query by throwing with message MSG. The function
// is marked DCTV_NORETURN_ERROR (presumably [[noreturn]] plus cold-path
// hints). NOTE(review): the thrown error presumably surfaces in Python
// as QueryClasses::cls_invalid_query_exception — confirm.
DCTV_NORETURN_ERROR
void throw_invalid_query(String msg);
// Like throw_invalid_query, but builds the message from FORMAT and
// ARGS first (presumably fmt-style formatting, given fmt.h).
template<typename... Args>
DCTV_NORETURN_ERROR
void throw_invalid_query_fmt(const char* format, Args&&... args);
// Check whether OBJ is a masked_array. If it is, return a non-null
// reference to it as an ndarray; otherwise return a null reference.
inline pyarray_ref maybe_ma(pyref obj);
// The nomask singleton (presumably the object cached in
// QueryClasses::np_ma_nomask — confirm).
inline pyref nomask();
// Return whether the given array has a mask. An array has a mask if
// it's a masked_array and the mask field is not nomask.
bool npy_has_mask(pyarray_ref array);
// Retrieve the data of an array that we know is definitely a
// masked_array. It is a programming error to call this function on
// an array that is not a masked array.
unique_pyarray npy_get_data_of_ma(pyarray_ref ma);
// Retrieve the mask of an array that we know is definitely a
// masked_array. If the mask is nomask, return a null reference.
unique_pyarray npy_get_mask_of_ma(pyarray_ref ma);
// Return the underlying broadcasted singleton array that view ARRAY
// refers to. Works for both regular and masked arrays. Return the
// empty reference if ARRAY isn't actually broadcast.
unique_pyarray npy_get_broadcaster(pyarray_ref array);
// Return whether npy_get_broadcaster would have returned non-null.
// Declared noexcept.
bool npy_is_broadcasted(pyarray_ref array) noexcept;
// Lower-level version of npy_get_broadcaster. Works only for
// base-type ndarrays. Not exposed to Python. Does not massage array
// dimensions. Return the empty reference if ARRAY is not a
// broadcasted array. Unlike npy_get_broadcaster, the result is a
// non-owning pyarray_ref rather than a unique_pyarray.
pyarray_ref npy_get_broadcaster_raw(pyarray_ref array) noexcept;
// Make a masked_array from DATA and MASK.
unique_pyarray make_masked_array(pyref data, pyref mask);
// Make an uninitialized array of NELEM elements of dtype TYPE, backed
// by a query cache hunk. The QueryExecution overload presumably uses
// the execution's query cache — confirm in the definition.
unique_pyarray make_uninit_hunk_array(obj_pyref<QueryCache> qc,
unique_dtype type,
npy_intp nelem);
unique_pyarray make_uninit_hunk_array(obj_pyref<QueryExecution> qe,
unique_dtype type,
npy_intp nelem);
// Collapse views that do not change the viewed data.
unique_pyarray collapse_useless_views(unique_pyarray array);
// Helper functions that avoid having to include whole headers to
// access some common fields of central objects.
// Return the StringTable reachable from QE.
StringTable* st_from_qe(QueryExecution* qe);
// Return the block size of query cache QC.
npy_intp block_size_from_qc(QueryCache* qc);
// NOTE(review): undocumented. Judging by the name, presumably resets
// *MASK to a null reference when the mask is unnecessary (e.g. no
// element is masked) — confirm in the definition.
void maybe_drop_mask(unique_pyarray* mask);
// Initialization hook for this component's Python bindings; M is
// presumably the extension module object — confirm.
void init_query(pyref m);
} // namespace dctv
#include "query-inl.h"