blob: 7155e22d1059252da3d189b42071dac61f2e269f [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "span_group_column.h"
#include <string_view>
#include "command_stream.h"
#include "hash_table.h"
#include "native_aggregation.h"
#include "npyiter.h"
#include "optional.h"
#include "pyerr.h"
#include "pyobj.h"
#include "pyparsetuple.h"
#include "pyparsetuplenpy.h"
#include "result_buffer.h"
#include "span_group.h"
namespace dctv {
namespace {
// Aggregation state for one partition (or the global scope): a running
// aggregation kernel plus the most recently accumulated unmasked value.
template<typename Kernel>
struct PartitionState final {
  using Out = typename Kernel::Out;
  // Drop the running kernel.  With keep_last, re-seed a fresh kernel from
  // the remembered last value (if any); otherwise forget the last value too.
  inline void reset(bool keep_last);
  // Replay this state's last value into `other` via accumulate(); if there
  // is no last value, feed `other` a masked placeholder instead.
  inline void copy_last_to(PartitionState* other, NaContext* nac);
  // Fold `value` into the kernel (creating it on first use) and remember
  // `value` as the last one seen; a masked input clears the last value.
  inline void accumulate(Out value, bool mask, NaContext* nac);
  // Current aggregate: kernel->get() when a kernel exists, otherwise
  // Kernel::empty_value.
  inline optional<Out> sample();
private:
  optional<Out> last_value;
  optional<Kernel> kernel;
};
// Discard the running kernel; optionally re-seed it from the last value.
template<typename Kernel>
void
PartitionState<Kernel>::reset(bool keep_last)
{
  kernel.reset();
  if (!keep_last) {
    // Full reset: the remembered value goes away as well.
    last_value.reset();
    return;
  }
  // Keep-last reset: restart the aggregation seeded with the most recent
  // unmasked value, if we have one.
  if (last_value)
    kernel = Kernel(*last_value);
}
// Push our last value into `other`; an absent last value becomes a masked
// (NA) input so `other` clears its own last value.
template<typename Kernel>
void
PartitionState<Kernel>::copy_last_to(PartitionState* other, NaContext* nac)
{
  const bool have_last = static_cast<bool>(last_value);
  other->accumulate(have_last ? *last_value : Out{},
                    /*mask=*/!have_last,
                    nac);
}
// Incorporate one (value, mask) pair.  Masked inputs contribute nothing and
// erase the remembered last value; unmasked inputs update both the kernel
// and the last value.
template<typename Kernel>
void
PartitionState<Kernel>::accumulate(Out value, bool mask, NaContext* nac)
{
  if (mask) {
    last_value.reset();
    return;
  }
  last_value = value;
  if (!kernel)
    kernel = Kernel(value);  // first unmasked value seeds the kernel
  else
    kernel->accumulate(value, nac);
}
// Report the current aggregate, or the kernel's designated empty value when
// nothing has been accumulated since the last reset.
template<typename Kernel>
optional<typename PartitionState<Kernel>::Out>
PartitionState<Kernel>::sample()
{
  if (kernel)
    return kernel->get();
  return Kernel::empty_value;
}
// Policy type plugged into na_agg_impl<> (see the method table below): runs
// one column's span aggregation for a concrete Value/Aggregation pair.
struct SgGroupNa {
  template<typename Value, typename Aggregation>
  static void agg_column(CommandIn ci,
                         pyref py_data,
                         pyref py_out,
                         NaContext nac,
                         QueryExecution* qe);
};
// Interpret the span-group command stream `ci`: consume (value, mask) pairs
// from `py_data`, maintain per-partition and global aggregation states, and
// append one output row to `py_out` for every EMIT command.
//
// NOTE(review): command order matters — SLURP commands consume exactly one
// input element each, so the stream and the data column must have been
// produced in lockstep by the planner; verify against the producer of
// CommandIn.
template<typename Value, typename Aggregation>
void
SgGroupNa::agg_column(CommandIn ci,
                      pyref py_data,
                      pyref py_out,
                      NaContext nac,
                      QueryExecution* qe)
{
  using Kernel = typename Aggregation::template Kernel<Value>;
  using Ps = PartitionState<Kernel>;
  using Out = typename Ps::Out;
  // State addressed by the *_GLOBAL commands.
  Ps global_ps;
  // Per-partition states; operator[] presumably default-constructs missing
  // entries like std::unordered_map — confirm against hash_table.h.
  HashTable<Partition, Ps> partitions;
  npy_uint32 extra_op_flags[2] = { NPY_ITER_NO_BROADCAST, 0 };
  NpyIterConfig config;
  config.extra_op_flags = &extra_op_flags[0];
  // Input values paired with their null mask; output likewise masked.
  PythonChunkIterator<Value, bool> data(py_data.notnull().addref(), config);
  ResultBuffer<Out, bool> out(py_out.notnull().addref(), qe);
  // Consume the next (value, mask) pair from the input and fold it into *ps.
  // A SLURP past end-of-input means the command stream and the data column
  // disagree, which is a logic error upstream — hence AssertionError.
  auto slurp_and_accumulate = [&](Ps* ps) {
    if (data.is_at_eof())
      throw_pyerr_msg(PyExc_AssertionError, "data underflow");
    auto [value, mask] = data.get();
    ps->accumulate(value, mask, &nac);
    data.advance();
  };
  // Append ps's current aggregate to the output; an empty sample becomes a
  // masked output row.
  auto emit = [&](Ps* ps) {
    optional<Out> sample = ps->sample();
    if (sample.has_value())
      out.add(*sample, /*masked=*/false);
    else
      out.add(Out(), /*masked=*/true);
  };
  while (!ci.is_at_eof()) {
    switch (ci.next<SgCommand>()) {
      // Partition-addressed commands read a Partition operand from the
      // stream; *_GLOBAL variants act on global_ps instead.
      case SgCommand::RESET: {
        Partition p = ci.next<Partition>();
        partitions[p].reset(/*keep_last=*/false);
        break;
      }
      case SgCommand::RESET_GLOBAL: {
        global_ps.reset(/*keep_last=*/false);
        break;
      }
      case SgCommand::RESET_KEEP_LAST: {
        Partition p = ci.next<Partition>();
        partitions[p].reset(/*keep_last=*/true);
        break;
      }
      case SgCommand::RESET_GLOBAL_KEEP_LAST: {
        global_ps.reset(/*keep_last=*/true);
        break;
      }
      case SgCommand::COPY_LAST_GLOBAL_TO: {
        // Seed partition `to` with the global state's last value.
        Partition to = ci.next<Partition>();
        global_ps.copy_last_to(&partitions[to], &nac);
        break;
      }
      case SgCommand::SLURP_AND_ACCUMULATE: {
        Partition p = ci.next<Partition>();
        slurp_and_accumulate(&partitions[p]);
        break;
      }
      case SgCommand::SLURP_AND_ACCUMULATE_GLOBAL: {
        slurp_and_accumulate(&global_ps);
        break;
      }
      case SgCommand::EMIT: {
        Partition p = ci.next<Partition>();
        emit(&partitions[p]);
        break;
      }
      case SgCommand::EMIT_GLOBAL: {
        emit(&global_ps);
        break;
      }
      case SgCommand::FORGET_PARTITION: {
        // Drop the partition's state entirely (no output, no carry-over).
        Partition p = ci.next<Partition>();
        auto it = partitions.find(p);
        if (it != partitions.end())
          partitions.erase(it);
        break;
      }
    }
  }
  out.flush();
}
// Python method table for this module; registered by init_span_group_column.
// na_agg_impl<SgGroupNa> dispatches to SgGroupNa::agg_column over the
// concrete value/aggregation types, and wraperr converts C++ errors into
// Python exceptions.
PyMethodDef functions[] = {
  make_methoddef("span_group_column",
                 wraperr<na_agg_impl<SgGroupNa>>(),
                 METH_VARARGS,
                 "Compute a column of a span aggregation"),
  { 0 }  // sentinel terminating the PyMethodDef table
};
} // anonymous namespace
// Module-initialization hook: attach this file's method table to module `m`.
void
init_span_group_column(pyref m)
{
  register_functions(m, functions);
}
} // namespace dctv