blob: 562de283b9338d23d53fa6769678c915f42d8b2f [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "io_spec.h"
#include <type_traits>
#include "block.h"
#include "fmt.h"
#include "input_channel.h"
#include "output_channel.h"
#include "pyerrfmt.h"
namespace dctv {
// Tally of outstanding I/O work. The summarize() methods in this file
// add to it, and compute_io_score() turns the totals into a
// scheduling state.
struct IoSummary {
  // Write-like requests (data writes, buffer resizes, terminal
  // flushes) that have not yet been performed.
  npy_intp nr_pending_writes{};
  // Those pending writes whose channel currently reports
  // backpressure; always counted together with nr_pending_writes.
  npy_intp nr_backpressure_writes{};
  // Input requests whose channel is not at EOF and does not yet have
  // min_wanted_rows rows available.
  npy_intp nr_unsatisfied_inputs{};
};
int
IoInput::py_traverse(visitproc visit, void* arg) const noexcept
{
  // GC traversal: report the base request's references first, and stop
  // as soon as the visitor returns nonzero (tp_traverse protocol).
  const int base_ret = this->IoRequest::py_traverse(visit, arg);
  if (base_ret != 0)
    return base_ret;
  // The input channel is the only extra Python reference we hold.
  Py_VISIT(this->channel.get());  // NOLINT
  return 0;
}
// Register this read request with its input channel.
//
// Runs the base-class setup, then verifies that the channel belongs to
// operator context OC and that no other request has already claimed a
// read on it. Finally it records max_wanted_rows as the channel's read
// hint.
//
// Throws ValueError on a channel/context mismatch and RuntimeError on a
// duplicate read. The function is exception-atomic: if anything after
// the base-class setup throws, the base-class setup is rolled back
// before the exception propagates. This matters because io_setup() only
// calls undo_setup() on requests whose setup() completed; without the
// rollback, a failed IoInput::setup would leak the IoRequest::setup()
// state, which undo_setup() otherwise tears down.
void
IoInput::setup(const OperatorContext* oc) const
{
  this->IoRequest::setup(oc);
  try {
    if (this->channel->get_owner() != oc)
      throw_pyerr_fmt(PyExc_ValueError, "channel mismatch");
    // A nonnegative hint means some other request already claimed a read
    // on this channel (undo_setup() resets it to -1).
    if (this->channel->get_read_wanted_hint() >= 0)
      throw_pyerr_fmt(PyExc_RuntimeError, "duplicate channel read");
    assume(this->max_wanted_rows >= 0);
    this->channel->set_read_wanted_hint(this->max_wanted_rows);
  } catch (...) {
    // Undo the base-class setup so a failed setup() leaves no partial
    // state behind, then let the original error propagate unchanged.
    this->IoRequest::undo_setup();
    throw;
  }
}
void
IoInput::undo_setup() const noexcept
{
  // Mirror image of setup(): first release the read claim on the
  // channel by resetting its hint to -1 (setup() treats any nonnegative
  // hint as an existing reader)...
  const auto& ch = this->channel;
  assume(ch->get_read_wanted_hint() >= 0);
  ch->set_read_wanted_hint(-1);
  // ...then unwind the base-class setup.
  this->IoRequest::undo_setup();
}
void
IoInput::summarize(IoSummary* summary) const noexcept
{
  this->IoRequest::summarize(summary);
  // The read can proceed once the channel has hit EOF or has at least
  // min_wanted_rows rows buffered; otherwise it is still waiting.
  const auto& ch = this->channel;
  const bool still_waiting =
      !ch->is_eof() &&
      ch->get_nr_available_rows() < this->min_wanted_rows;
  if (still_waiting)
    ++summary->nr_unsatisfied_inputs;
}
// Perform the read: release the channel claim made in setup(), then
// consume up to max_wanted_rows of whatever rows are available.
unique_pyref
IoInput::do_it()
{
  this->undo_setup();
  const auto& ch = this->channel;
  // The scheduler should only get here once the input is satisfied.
  assume(this->min_wanted_rows <= ch->get_nr_available_rows() ||
         ch->is_eof());
  const auto nr_available = ch->get_nr_available_rows();
  const auto nr_to_read = std::min(this->max_wanted_rows, nr_available);
  return ch->read(nr_to_read);
}
// Read exactly one row and return it as a Python object.
//
// Raises an invalid-query error unless the block holds exactly one
// row. If the block carries a mask, returns None.
unique_pyref
IoInputScalar::do_it()
{
  unique_obj_pyref<Block> block =
      this->IoInput::do_it().notnull().addref_as_unsafe<Block>();
  const auto nr_rows = block->get_size();
  if (nr_rows != 1)
    throw_invalid_query(
        fmt("wanted scalar result, but found %s rows", nr_rows));
  if (block->has_mask()) {
    // A mask should only be present when the single row is masked out
    // (presumably a missing value), so report that as None. This uses
    // assert rather than assume on purpose: assume evaluates its
    // argument even in optimized builds, and evaluating this one calls
    // into Python.
    assert(true_(get_item(block->get_mask(), 0)));
    return addref(Py_None);
  }
  unique_pyarray data = block->get_data();
  assume(npy_ndim(data) == 1);
  // Box the first (and only) element using the array's own getitem.
  return adopt_check(PyArray_GETITEM(data.get(),
                                     npy_data<const char>(data)));
}
// Scalar read that additionally requires the result to be exactly a
// Python int (subclasses are rejected by check_pytype_exact).
unique_pyref
IoInputInt::do_it()
{
  unique_pyref value = this->IoInputScalar::do_it();
  check_pytype_exact(&PyLong_Type, value);
  return value;
}
int
IoOutput::py_traverse(visitproc visit, void* arg) const noexcept
{
  // GC traversal: base-class references first; a nonzero visitor result
  // aborts the walk immediately.
  const int base_ret = this->IoRequest::py_traverse(visit, arg);
  if (base_ret != 0)
    return base_ret;
  // Then the output channel and the data waiting to be written to it.
  Py_VISIT(this->channel.get());  // NOLINT
  Py_VISIT(this->data.get());  // NOLINT
  return 0;
}
void
IoOutput::summarize(IoSummary* summary) const noexcept
{
  this->IoRequest::summarize(summary);
  // Every queued output is a pending write; flag it as backpressured
  // when the channel is not currently accepting data.
  summary->nr_pending_writes++;
  if (this->channel->has_backpressure())
    summary->nr_backpressure_writes++;
}
// Hand this->data to the channel, tagging it with IS_EOF.
//
// Unless FORCE is set, the write is skipped while the channel reports
// backpressure. Returns true iff the data was handed off; in that case
// this->data has been moved from.
bool
IoOutput::try_write_1(bool force, bool is_eof)
{
  if (!force && this->channel->has_backpressure())
    return false;
  this->channel->add_data(std::move(this->data), is_eof);
  return true;
}
bool
IoOutput::try_write(bool force)
{
  // A plain data write never terminates the stream.
  constexpr bool is_eof = false;
  return this->try_write_1(force, is_eof);
}
bool
IoOutputEof::try_write(bool force)
{
  // Same as an ordinary output write, except the data is marked as the
  // final chunk on the channel.
  constexpr bool is_eof = true;
  return this->try_write_1(force, is_eof);
}
// Resize the channel buffer now, unless the channel reports resize
// backpressure and FORCE is not set. Returns whether the resize ran.
bool
IoResizeBuffer::try_write(bool force)
{
  const bool held_back =
      !force && this->channel->has_buffer_resize_backpressure();
  if (!held_back)
    this->channel->resize_buffer_now();
  return !held_back;
}
void
IoResizeBuffer::summarize(IoSummary* summary) const noexcept
{
  this->IoRequest::summarize(summary);
  // A pending resize is scored like a write; note whether the channel
  // is currently pushing back on resizes.
  ++summary->nr_pending_writes;
  if (this->channel->has_buffer_resize_backpressure())
    ++summary->nr_backpressure_writes;
}
// Flush the channel, unless it reports flush backpressure and FORCE is
// not set. Returns whether the flush ran.
bool
IoTerminalFlush::try_write(bool force)
{
  const bool held_back =
      !force && this->channel->has_flush_backpressure();
  if (!held_back)
    this->channel->flush();
  return !held_back;
}
void
IoTerminalFlush::summarize(IoSummary* summary) const noexcept
{
  this->IoRequest::summarize(summary);
  // A pending flush is scored like a write; note whether the channel is
  // currently pushing back on flushes.
  ++summary->nr_pending_writes;
  if (this->channel->has_flush_backpressure())
    ++summary->nr_backpressure_writes;
}
// Derive an operator's scheduling state from its outstanding I/O.
//
// Summarizes all NR_SPECS requests in IO_SPECS and sets SCORE.state:
//   - pending writes, at least one not backpressured -> RUNNABLE
//   - pending writes, all backpressured              -> RUNNABLE_UNDER_DURESS
//   - no pending writes, some input unsatisfied      -> NOT_RUNNABLE
//   - no pending writes, every input satisfied       -> RUNNABLE
// All other fields of SCORE are returned unchanged.
Score
compute_io_score(Score score,
                 const IoSpec* io_specs,
                 size_t nr_specs)
{
  // TODO(dancol): bump score when we have a _lot_ of queued data
  // TODO(dancol): backpressure input into read side?
  IoSummary summary;
  const IoSpec* const specs_end = io_specs + nr_specs;
  for (const IoSpec* spec = io_specs; spec != specs_end; ++spec)
    spec->summarize(&summary);
  assume(summary.nr_backpressure_writes <= summary.nr_pending_writes);
  assume(static_cast<size_t>(summary.nr_pending_writes +
                             summary.nr_unsatisfied_inputs) <= nr_specs);
  if (summary.nr_pending_writes == 0) {
    // Nothing to write: readiness depends only on the inputs.
    if (summary.nr_unsatisfied_inputs == 0)
      score.state = OperatorState::RUNNABLE;
    else
      score.state = OperatorState::NOT_RUNNABLE;
  } else if (summary.nr_backpressure_writes < summary.nr_pending_writes) {
    // At least one write can make progress without forcing.
    score.state = OperatorState::RUNNABLE;
  } else {
    // Every pending write is blocked; running means forcing past
    // backpressure.
    score.state = OperatorState::RUNNABLE_UNDER_DURESS;
  }
  return score;
}
// Attempt try_write(FORCE) on every spec. Each spec that completes is
// replaced with an IoDummy so it is not written again. The loop
// deliberately visits all specs even after a success. Returns true if
// any write completed.
static
bool
try_writes(IoSpec* io_specs, size_t nr_specs, bool force)
{
  bool any_written = false;
  IoSpec* const specs_end = io_specs + nr_specs;
  for (IoSpec* spec = io_specs; spec != specs_end; ++spec) {
    if (!spec->try_write(force))
      continue;
    any_written = true;
    *spec = IoSpec(IoDummy());
  }
  return any_written;
}
// Make some write progress. First try every write while honoring
// backpressure. Only if that accomplished nothing, force the writes
// through. Returns whether any write completed.
bool
do_writes_incremental(IoSpec* io_specs, size_t nr_specs)
{
  if (try_writes(io_specs, nr_specs, /*force=*/false))
    return true;
  return try_writes(io_specs, nr_specs, /*force=*/true);
}
// Call setup(OC) on each of the NR_SPECS requests in IO_SPECS, in
// order. If one of them throws, the FINALLY guard calls undo_setup()
// on every request that had already been set up, in the same forward
// order, and the exception propagates.
void
io_setup(OperatorContext* oc, IoSpec* io_specs, size_t nr_specs)
{
  // Number of leading specs whose setup() has completed.
  size_t nr_done = 0;
  auto roll_back = [&] {
    for (size_t j = 0; j < nr_done; ++j)
      io_specs[j].undo_setup();
  };
  FINALLY(if (nr_done < nr_specs) roll_back());
  while (nr_done < nr_specs) {
    io_specs[nr_done].setup(oc);
    ++nr_done;
  }
}
} // namespace dctv