| // Copyright (C) 2020 The Android Open Source Project |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| #include "io_spec.h" |
| |
| #include <type_traits> |
| |
| #include "block.h" |
| #include "fmt.h" |
| #include "input_channel.h" |
| #include "output_channel.h" |
| #include "pyerrfmt.h" |
| |
| namespace dctv { |
| |
| struct IoSummary { |
| npy_intp nr_pending_writes = 0; |
| npy_intp nr_backpressure_writes = 0; |
| npy_intp nr_unsatisfied_inputs = 0; |
| }; |
| |
| int |
| IoInput::py_traverse(visitproc visit, void* arg) const noexcept |
| { |
| if (int ret = this->IoRequest::py_traverse(visit, arg); ret) |
| return ret; |
| Py_VISIT(this->channel.get()); // NOLINT |
| return 0; |
| } |
| |
| void |
| IoInput::setup(const OperatorContext* oc) const |
| { |
| this->IoRequest::setup(oc); |
| if (this->channel->get_owner() != oc) |
| throw_pyerr_fmt(PyExc_ValueError, "channel mismatch"); |
| if (this->channel->get_read_wanted_hint() >= 0) |
| throw_pyerr_fmt(PyExc_RuntimeError, "duplicate channel read"); |
| assume(this->max_wanted_rows >= 0); |
| this->channel->set_read_wanted_hint(this->max_wanted_rows); |
| } |
| |
| void |
| IoInput::undo_setup() const noexcept |
| { |
| assume(this->channel->get_read_wanted_hint() >= 0); |
| this->channel->set_read_wanted_hint(-1); |
| this->IoRequest::undo_setup(); |
| } |
| |
| void |
| IoInput::summarize(IoSummary* summary) const noexcept |
| { |
| this->IoRequest::summarize(summary); |
| bool satisfied = |
| this->channel->is_eof() || |
| this->min_wanted_rows <= this->channel->get_nr_available_rows(); |
| if (!satisfied) |
| summary->nr_unsatisfied_inputs += 1; |
| } |
| |
| unique_pyref |
| IoInput::do_it() |
| { |
| this->undo_setup(); |
| assume(this->min_wanted_rows <= this->channel->get_nr_available_rows() || |
| this->channel->is_eof()); |
| return this->channel->read( |
| std::min(this->max_wanted_rows, |
| this->channel->get_nr_available_rows())); |
| } |
| |
| unique_pyref |
| IoInputScalar::do_it() |
| { |
| unique_obj_pyref<Block> block = this->IoInput::do_it() |
| .notnull().addref_as_unsafe<Block>(); |
| if (block->get_size() != 1) |
| throw_invalid_query( |
| fmt("wanted scalar result, but found %s rows", block->get_size())); |
| if (block->has_mask()) { |
| // We shouldn't have a mask at all if it's just empty, but let's |
| // cope just in case. Use assert, not assume, here so that we |
| // don't call into Python in optimized builds. (Assume evaluates |
| // its argument, unfortunately.) |
| assert(true_(get_item(block->get_mask(), 0))); |
| return addref(Py_None); |
| } |
| unique_pyarray data = block->get_data(); |
| assume(npy_ndim(data) == 1); |
| const char* dataptr = npy_data<const char>(data); |
| return adopt_check(PyArray_GETITEM(data.get(), dataptr)); |
| } |
| |
| unique_pyref |
| IoInputInt::do_it() |
| { |
| unique_pyref scalar = this->IoInputScalar::do_it(); |
| check_pytype_exact(&PyLong_Type, scalar); |
| return scalar; |
| } |
| |
| int |
| IoOutput::py_traverse(visitproc visit, void* arg) const noexcept |
| { |
| if (int ret = this->IoRequest::py_traverse(visit, arg); ret) |
| return ret; |
| Py_VISIT(this->channel.get()); // NOLINT |
| Py_VISIT(this->data.get()); // NOLINT |
| return 0; |
| } |
| |
| void |
| IoOutput::summarize(IoSummary* summary) const noexcept |
| { |
| this->IoRequest::summarize(summary); |
| summary->nr_pending_writes += 1; |
| if (this->channel->has_backpressure()) |
| summary->nr_backpressure_writes += 1; |
| } |
| |
| bool |
| IoOutput::try_write_1(bool force, bool is_eof) |
| { |
| bool did_work = false; |
| if (force || !this->channel->has_backpressure()) { |
| this->channel->add_data(std::move(this->data), is_eof); |
| did_work = true; |
| } |
| return did_work; |
| } |
| |
| bool |
| IoOutput::try_write(bool force) |
| { |
| return this->try_write_1(force, /*is_eof=*/false); |
| } |
| |
| bool |
| IoOutputEof::try_write(bool force) |
| { |
| return this->try_write_1(force, /*is_eof=*/true); |
| } |
| |
| bool |
| IoResizeBuffer::try_write(bool force) |
| { |
| if (!force && this->channel->has_buffer_resize_backpressure()) |
| return false; |
| this->channel->resize_buffer_now(); |
| return true; |
| } |
| |
| void |
| IoResizeBuffer::summarize(IoSummary* summary) const noexcept |
| { |
| this->IoRequest::summarize(summary); |
| summary->nr_pending_writes += 1; |
| if (this->channel->has_buffer_resize_backpressure()) |
| summary->nr_backpressure_writes += 1; |
| } |
| |
| bool |
| IoTerminalFlush::try_write(bool force) |
| { |
| if (!force && this->channel->has_flush_backpressure()) |
| return false; |
| this->channel->flush(); |
| return true; |
| } |
| |
| void |
| IoTerminalFlush::summarize(IoSummary* summary) const noexcept |
| { |
| this->IoRequest::summarize(summary); |
| summary->nr_pending_writes += 1; |
| if (this->channel->has_flush_backpressure()) |
| summary->nr_backpressure_writes += 1; |
| } |
| |
| Score |
| compute_io_score(Score score, |
| const IoSpec* io_specs, |
| size_t nr_specs) |
| { |
| // TODO(dancol): bump score when we have a _lot_ of queued data |
| // TODO(dancol): backpressure input into read side? |
| |
| IoSummary summary; |
| for (size_t i = 0; i < nr_specs; ++i) |
| io_specs[i].summarize(&summary); |
| |
| assume(summary.nr_backpressure_writes <= summary.nr_pending_writes); |
| assume(static_cast<size_t>(summary.nr_pending_writes + |
| summary.nr_unsatisfied_inputs) <= nr_specs); |
| |
| if (summary.nr_pending_writes) { |
| score.state = summary.nr_backpressure_writes < summary.nr_pending_writes |
| ? OperatorState::RUNNABLE |
| : OperatorState::RUNNABLE_UNDER_DURESS; |
| } else { |
| score.state = summary.nr_unsatisfied_inputs |
| ? OperatorState::NOT_RUNNABLE |
| : OperatorState::RUNNABLE; |
| } |
| return score; |
| } |
| |
| static |
| bool |
| try_writes(IoSpec* io_specs, size_t nr_specs, bool force) |
| { |
| bool did_work = false; // Make sure not to bail early |
| for (size_t i = 0; i < nr_specs; ++i) |
| if (io_specs[i].try_write(force)) { |
| did_work = true; |
| io_specs[i] = IoSpec(IoDummy()); |
| } |
| return did_work; |
| } |
| |
| bool |
| do_writes_incremental(IoSpec* io_specs, size_t nr_specs) |
| { |
| return (try_writes(io_specs, nr_specs, /*force=*/false) || |
| try_writes(io_specs, nr_specs, /*force=*/true)); |
| } |
| |
| void |
| io_setup(OperatorContext* oc, IoSpec* io_specs, size_t nr_specs) |
| { |
| size_t nr_initialized = 0; |
| auto undo_setup = [&] { |
| for (size_t i = 0; i < nr_initialized; ++i) |
| io_specs[i].undo_setup(); |
| }; |
| FINALLY(if (nr_initialized < nr_specs) undo_setup()); |
| for (; nr_initialized < nr_specs; ++nr_initialized) |
| io_specs[nr_initialized].setup(oc); |
| } |
| |
| } // namespace dctv |