| // Copyright (C) 2020 The Android Open Source Project |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| #include "span_group_column.h" |
| |
| #include <string_view> |
| |
| #include "command_stream.h" |
| #include "hash_table.h" |
| #include "native_aggregation.h" |
| #include "npyiter.h" |
| #include "optional.h" |
| #include "pyerr.h" |
| #include "pyobj.h" |
| #include "pyparsetuple.h" |
| #include "pyparsetuplenpy.h" |
| #include "result_buffer.h" |
| #include "span_group.h" |
| |
| namespace dctv { |
| namespace { |
| |
// Per-partition aggregation state: an optional aggregation Kernel plus
// the most recent unmasked value seen.  Keeping last_value separately
// lets a partition be reset while optionally carrying its last value
// forward into a fresh kernel (RESET_*_KEEP_LAST commands).
template<typename Kernel>
struct PartitionState final {
  using Out = typename Kernel::Out;
  // Discard the kernel; if keep_last, reseed it from last_value,
  // otherwise forget last_value too.
  inline void reset(bool keep_last);
  // Feed this state's last value (or an NA if there is none) into
  // `other` via its accumulate().
  inline void copy_last_to(PartitionState* other, NaContext* nac);
  // Consume one (value, mask) pair; mask=true means NA.
  inline void accumulate(Out value, bool mask, NaContext* nac);
  // Current aggregate, or Kernel::empty_value if nothing accumulated.
  inline optional<Out> sample();

 private:
  optional<Out> last_value;  // most recent unmasked input, if any
  optional<Kernel> kernel;   // engaged once a value has been accumulated
};
| |
| template<typename Kernel> |
| void |
| PartitionState<Kernel>::reset(bool keep_last) |
| { |
| this->kernel.reset(); |
| if (keep_last) { |
| if (this->last_value) |
| this->kernel = Kernel(*this->last_value); |
| } else { |
| this->last_value.reset(); |
| } |
| } |
| |
| template<typename Kernel> |
| void |
| PartitionState<Kernel>::copy_last_to(PartitionState* other, NaContext* nac) |
| { |
| if (this->last_value) |
| other->accumulate(*this->last_value, /*mask=*/false, nac); |
| else |
| other->accumulate(Out{}, /*mask=*/true, nac); |
| } |
| |
| template<typename Kernel> |
| void |
| PartitionState<Kernel>::accumulate(Out value, bool mask, NaContext* nac) |
| { |
| if (!mask) { |
| this->last_value = value; |
| if (this->kernel) |
| this->kernel->accumulate(value, nac); |
| else |
| this->kernel = Kernel(value); |
| } else { |
| this->last_value.reset(); |
| } |
| } |
| |
| template<typename Kernel> |
| optional<typename PartitionState<Kernel>::Out> |
| PartitionState<Kernel>::sample() |
| { |
| return this->kernel ? this->kernel->get() : Kernel::empty_value; |
| } |
| |
// Policy type instantiated by na_agg_impl<SgGroupNa> (see the method
// table below): provides the per-(Value, Aggregation) column worker.
struct SgGroupNa {
  // Run the span-group command stream `ci` over the input column
  // `py_data`, writing aggregated results into `py_out`.
  template<typename Value, typename Aggregation>
  static void agg_column(CommandIn ci,
                         pyref py_data,
                         pyref py_out,
                         NaContext nac,
                         QueryExecution* qe);

};
| |
// Interpret a span-group command stream against one column: maintain a
// PartitionState per partition (plus one "global" state), consume input
// (value, mask) pairs on SLURP commands, and append aggregate samples
// to the output buffer on EMIT commands.
template<typename Value, typename Aggregation>
void
SgGroupNa::agg_column(CommandIn ci,
                      pyref py_data,
                      pyref py_out,
                      NaContext nac,
                      QueryExecution* qe)
{
  using Kernel = typename Aggregation::template Kernel<Value>;
  using Ps = PartitionState<Kernel>;
  using Out = typename Ps::Out;

  Ps global_ps;                         // target of the *_GLOBAL commands
  HashTable<Partition, Ps> partitions;  // per-partition states, default-created on first touch

  // NPY_ITER_NO_BROADCAST for the first operand, 0 for the second.
  // NOTE(review): presumably one flag word per iterator operand of the
  // (Value, bool) pair — confirm against NpyIterConfig/PythonChunkIterator.
  npy_uint32 extra_op_flags[2] = { NPY_ITER_NO_BROADCAST, 0 };
  NpyIterConfig config;
  config.extra_op_flags = &extra_op_flags[0];
  PythonChunkIterator<Value, bool> data(py_data.notnull().addref(), config);
  ResultBuffer<Out, bool> out(py_out.notnull().addref(), qe);

  // Pull the next (value, mask) pair from the input and feed it to `ps`.
  // Running out of input mid-stream indicates a malformed command stream.
  auto slurp_and_accumulate = [&](Ps* ps) {
    if (data.is_at_eof())
      throw_pyerr_msg(PyExc_AssertionError, "data underflow");
    auto [value, mask] = data.get();
    ps->accumulate(value, mask, &nac);
    data.advance();
  };

  // Append the current aggregate of `ps` to the output; an empty sample
  // becomes a masked (NA) output row.
  auto emit = [&](Ps* ps) {
    optional<Out> sample = ps->sample();
    if (sample.has_value())
      out.add(*sample, /*masked=*/false);
    else
      out.add(Out(), /*masked=*/true);
  };

  // Dispatch loop: commands that target a specific partition carry the
  // Partition id as their next operand in the stream.  Unknown command
  // values are silently ignored.
  while (!ci.is_at_eof()) {
    switch (ci.next<SgCommand>()) {
      case SgCommand::RESET: {
        Partition p = ci.next<Partition>();
        partitions[p].reset(/*keep_last=*/false);
        break;
      }
      case SgCommand::RESET_GLOBAL: {
        global_ps.reset(/*keep_last=*/false);
        break;
      }
      case SgCommand::RESET_KEEP_LAST: {
        Partition p = ci.next<Partition>();
        partitions[p].reset(/*keep_last=*/true);
        break;
      }
      case SgCommand::RESET_GLOBAL_KEEP_LAST: {
        global_ps.reset(/*keep_last=*/true);
        break;
      }
      case SgCommand::COPY_LAST_GLOBAL_TO: {
        // Seed partition `to` with the global state's last value (or NA).
        Partition to = ci.next<Partition>();
        global_ps.copy_last_to(&partitions[to], &nac);
        break;
      }
      case SgCommand::SLURP_AND_ACCUMULATE: {
        Partition p = ci.next<Partition>();
        slurp_and_accumulate(&partitions[p]);
        break;
      }
      case SgCommand::SLURP_AND_ACCUMULATE_GLOBAL: {
        slurp_and_accumulate(&global_ps);
        break;
      }
      case SgCommand::EMIT: {
        Partition p = ci.next<Partition>();
        emit(&partitions[p]);
        break;
      }
      case SgCommand::EMIT_GLOBAL: {
        emit(&global_ps);
        break;
      }
      case SgCommand::FORGET_PARTITION: {
        // Drop the partition's state entirely (no-op if never created).
        Partition p = ci.next<Partition>();
        auto it = partitions.find(p);
        if (it != partitions.end())
          partitions.erase(it);
        break;
      }
    }
  }

  out.flush();
}
| |
// Python method table for this module, terminated by the zero sentinel
// entry required by the CPython API.
PyMethodDef functions[] = {
  make_methoddef("span_group_column",
                 wraperr<na_agg_impl<SgGroupNa>>(),
                 METH_VARARGS,
                 "Compute a column of a span aggregation"),
  { 0 }
};
| |
| } // anonymous namespace |
| |
// Module-init hook: register this file's method table on module `m`.
void
init_span_group_column(pyref m)
{
  register_functions(m, functions);
}
| |
| } // namespace dctv |