blob: 1edd311c0e04d5632d56115a51fc375fabe1c22c [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace dctv {
// Select a built-in aggregation by name and hand it to `functor`.
//
// `matcher` is called with each built-in aggregation name, in the
// order listed below, until one call returns true; `functor` is then
// invoked with a default-constructed instance of the corresponding
// aggregation type and its result is returned.  If no name matches,
// `functor` is invoked with the int literal 0 as a "no match" marker.
template<typename Matcher, typename Functor>
auto
for_each_native_aggregation(Matcher&& matcher, Functor&& functor)
{
  // This slightly awkward ladder exists so that the list of
  // aggregation function names lives in exactly one place in DCTV.
  // N.B. COUNT(*) is special cased in Python
  if (matcher("count")) {
    return functor(CountAggregation());
  } else if (matcher("first")) {
    return functor(FirstAggregation());
  } else if (matcher("max")) {
    return functor(MaxAggregation());
  } else if (matcher("min")) {
    return functor(MinAggregation());
  } else if (matcher("prod")) {
    return functor(ProdAggregation());
  } else if (matcher("sum")) {
    return functor(SumAggregation());
  } else if (matcher("unique_mask")) {
    return functor(UniqueAggregation());
  } else if (matcher("biggest")) {
    return functor(BiggestAggregation());
  } else {
    // Nothing matched: pass the int sentinel.
    return functor(0);
  }
}
// Resolve the aggregation named `aggfunc` and call `functor` with an
// instance of it, returning whatever `functor` returns.
//
// When `aggfunc` is not a built-in aggregation name, an error is
// raised through _throw_pyerr_fmt with PyExc_ValueError.
template<typename Functor>
auto
agg_dispatch(std::string_view aggfunc, Functor&& functor)
{
  auto name_matches = [&](std::string_view candidate) -> bool {
    return aggfunc == candidate;
  };

  auto on_selected = [&](auto aggregation) {
    using Selected = decltype(aggregation);
    if constexpr (!std::is_same_v<Selected, int>) {
      return AUTOFWD(functor)(AUTOFWD(aggregation));
    } else {
      // An int argument is the "no match" marker.
      _throw_pyerr_fmt(PyExc_ValueError,
                       "unknown built-in aggregation function %R",
                       make_pystr(aggfunc).get());
      // Not reached: present only so that this branch yields the
      // same return type as the other one.
      return AUTOFWD(functor)(SumAggregation());
    }
  };

  return for_each_native_aggregation(name_matches, on_selected);
}
// Dispatch on both the aggregation name and the value type.
//
// `functor` is called as functor(value_dummy, aggregation_dummy):
//   - for string data (`is_string` true), with a StringTable::id_type
//     and the aggregation's nested StringVersion type; `dtype` must
//     equal type_descr<StringTable::id_type>(), otherwise an error is
//     raised with PyExc_AssertionError;
//   - otherwise, with the value dummy chosen by npy_type_dispatch for
//     `dtype` and the aggregation selected for `aggfunc`.
// The functor's result is returned.
template<typename Functor>
auto
agg_and_dtype_dispatch(std::string_view aggfunc,
                       dtype_ref dtype,
                       bool is_string,
                       Functor&& functor)
{
  auto with_aggregation = [&](auto agg_dummy) {
    using Aggregation = decltype(agg_dummy);

    if (!is_string) {
      return npy_type_dispatch(dtype, [&](auto value_dummy) {
        return AUTOFWD(functor)(value_dummy, agg_dummy);
      });
    }

    // String columns hold string-table ids, so only that dtype is valid.
    if (dtype != type_descr<StringTable::id_type>())
      throw_pyerr_msg(PyExc_AssertionError, "bad string dtype");
    using StringAggregation = typename Aggregation::StringVersion;
    return AUTOFWD(functor)(StringTable::id_type(),
                            StringAggregation());
  };

  return agg_dispatch(aggfunc, with_aggregation);
}
// Build an aggregation context from a string table pointer and a
// collation name.  `st` is stored as given (is_string() reports
// whether it is non-null); `collation` is moved into the context and
// later passed to the string table's rank() by compare_strings().
NaContext::NaContext(StringTable* st, String collation)
: st(st), collation(std::move(collation))
{}
// Three-way compare two string-table ids under this context's
// collation.  Returns -1, 0 or 1 according to whether the rank of
// `left` is below, equal to or above the rank of `right`.
//
// The rank array is cached and rebuilt whenever the string table's
// sequence number differs from the one recorded at the last rebuild.
int
NaContext::compare_strings(StringTable::id_type left,
                           StringTable::id_type right)
{
  const StringTable::SequenceNumber current_seq = this->st->get_seqno();
  if (current_seq != this->st_seq) {
    this->rank = this->st->rank(this->collation);
    this->st_seq = current_seq;
  }

  assume(left < npy_size1d(this->rank));
  assume(right < npy_size1d(this->rank));

  StringTable::Rank* ranks = npy_data<StringTable::Rank>(this->rank);
  const StringTable::Rank& left_rank = ranks[left];
  const StringTable::Rank& right_rank = ranks[right];
  return left_rank < right_rank ? -1
       : left_rank > right_rank ? 1
       : 0;
}
// True when this context was given a string table, i.e. when `st` is
// not null.
bool
NaContext::is_string() const noexcept
{
  const bool has_string_table = (this->st != nullptr);
  return has_string_table;
}
// Python entry point shared by the native aggregation implementations.
//
// Parses the Python arguments, builds the NaContext and CommandIn,
// then dispatches on aggregation name and dtype to
// NaImpl::agg_column<Value, Aggregation>.  `st` and `collation` are
// optional; a missing or None collation falls back to "binary".
// Returns addref(Py_None).
template<typename NaImpl>
unique_pyref
na_agg_impl(PyObject*, pyref py_args)
{
  PARSEPYARGS(
    (pyref, data)
    (pyref, command)
    (pyref, out)
    (unique_dtype, dtype, no_default, convert_dtype)
    (std::string_view, aggfunc)
    (QueryExecution*, qe)
    (OPTIONAL_ARGS_FOLLOW)
    (obj_pyref<StringTable>, st)
    (pyref, collation)
  )(py_args);

  const bool has_collation = collation && collation != Py_None;
  String collation_name;
  if (has_collation)
    collation_name = str(collation);
  else
    collation_name = "binary";

  NaContext context(st.get(), std::move(collation_name));
  CommandIn command_in(command.notnull().addref());

  // Invoked once by the dispatcher, so moving the command stream and
  // the context out of this frame is safe.
  auto run_aggregation = [&](auto value_dummy, auto agg_dummy) {
    using Value = decltype(value_dummy);
    using Aggregation = decltype(agg_dummy);
    NaImpl::template agg_column<Value, Aggregation>(
        std::move(command_in),
        data,
        out,
        std::move(context),
        qe);
  };

  agg_and_dtype_dispatch(aggfunc, dtype, context.is_string(),
                         run_aggregation);
  return addref(Py_None);
}
// Initialize the kernel's running state with `value`.
template<MinMaxMode Mode>
template<typename Value>
MinMaxStringAggregation<Mode>::Kernel<Value>::Kernel(Value value)
: state(value)
{}
// Fold `value` into the running state: the state is replaced when
// `value` orders strictly before it (MIN) or strictly after it
// (MAX, BIGGEST) according to nac->compare_strings().
template<MinMaxMode Mode>
template<typename Value>
void
MinMaxStringAggregation<Mode>::Kernel<Value>::accumulate(
    Value value,
    NaContext* nac)
{
  static_assert(Mode == MinMaxMode::MIN ||
                Mode == MinMaxMode::MAX ||
                Mode == MinMaxMode::BIGGEST);

  // MinMaxMode::BIGGEST needs no special handling here: it is always
  // used with the length collation, which already orders strings the
  // right way for this case.
  const int order = nac->compare_strings(this->state, value);

  bool replace_state;
  if constexpr (Mode == MinMaxMode::MIN)
    replace_state = order > 0;   // state sorts after value
  else
    replace_state = order < 0;   // state sorts before value

  if (replace_state)
    this->state = value;
}
// Return the kernel's current state, i.e. the value retained by the
// constructor and accumulate() calls so far.
template<MinMaxMode Mode>
template<typename Value>
typename MinMaxStringAggregation<Mode>::template Kernel<Value>::Out
MinMaxStringAggregation<Mode>::Kernel<Value>::get() const
{
return this->state;
}
// Return the number of decimal digits needed to write the
// non-negative integer `value`: 1 for 1..9, 2 for 10..99, and so on
// up to 20.  Note that 0 yields 0, not 1.
template<typename Value>
int
_integer_log10(Value value)
{
  // TODO(dancol): we can optimize this function in numerous ways
  // (e.g., lookup by bit count) in the rare event that its
  // performance actually matters
  assume(value >= 0);
  const auto magnitude = static_cast<unsigned long long>(value);

  // Start from the largest 19-digit number and strip one digit per
  // step until `magnitude` no longer fits under the ceiling.
  unsigned long long ceiling = 9999999999999999999LLU;
  int digits = 20;
  for (; digits != 0 && magnitude <= ceiling; ceiling /= 10)
    --digits;
  return digits;
}
// Initialize the kernel's running state with `value`.
template<MinMaxMode Mode>
template<typename Value>
MinMaxAggregation<Mode>::Kernel<Value>::Kernel(Value value)
: state(value)
{}
// Fold `value` into the running state.
//
//   MIN / MAX: keep the smaller / larger value (strict comparison, so
//              the earlier of two equal values is retained).
//   BIGGEST:   keep the value judged "bigger":
//                - bool: false beats true;
//                - integers: more decimal digits wins, with one extra
//                  position counted for a minus sign;
//                - floating point: the numerically larger value wins.
//
// The NaContext argument is unused for non-string values.
template<MinMaxMode Mode>
template<typename Value>
void
MinMaxAggregation<Mode>::Kernel<Value>::accumulate(Value value, NaContext*)
{
  bool should_update;
  if constexpr (Mode == MinMaxMode::MIN || Mode == MinMaxMode::MAX) {
    should_update = Mode == MinMaxMode::MIN
      ? (value < this->state)
      : (value > this->state);
  } else if constexpr (Mode == MinMaxMode::BIGGEST) {
    if constexpr (std::is_same_v<Value, bool>) {
      // false is "bigger" than true
      should_update = this->state && !value;
    } else if constexpr (std::is_integral_v<Value>) {
      // Take the magnitude in unsigned arithmetic.  Negating the most
      // negative value of a signed type (e.g. INT64_MIN) directly is
      // signed overflow, i.e. undefined behavior; converting to
      // unsigned first and subtracting from zero is well defined and
      // yields the same magnitude for every other input.
      auto magnitude = [](Value v) -> unsigned long long {
        const auto widened = static_cast<unsigned long long>(v);
        return v < 0 ? 0ULL - widened : widened;
      };
      // TODO(dancol): avoid repeated computation of log10
      int log10_state = _integer_log10(magnitude(this->state));
      int log10_nv = _integer_log10(magnitude(value));
      // Account for minus sign
      if (this->state < 0)
        log10_state += 1;
      if (value < 0)
        log10_nv += 1;
      should_update = log10_state < log10_nv;
    } else if constexpr (std::is_floating_point_v<Value>) {
      // Eh.  Just use the bigger number.
      should_update = this->state < value;
    } else {
      return errhack<Value>::cannot_compute_biggest();
    }
  } else {
    errhack<void, Mode>::bad_min_max_mode();
  }
  if (should_update)
    this->state = value;
}
// Return the kernel's current state, i.e. the value retained by the
// constructor and accumulate() calls so far.
template<MinMaxMode Mode>
template<typename Value>
typename MinMaxAggregation<Mode>::template Kernel<Value>::Out
MinMaxAggregation<Mode>::Kernel<Value>::get() const
{
return this->state;
}
} // namespace dctv