| // Copyright (C) 2020 The Android Open Source Project |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| namespace dctv { |
| |
| template<typename Matcher, typename Functor> |
| auto |
| for_each_native_aggregation(Matcher&& matcher, Functor&& functor) |
| { |
| // We use this somewhat awkward function so that the list of |
| // aggregation function names is in exactly one place in DCTV. |
| |
| // N.B. COUNT(*) is special cased in Python |
| if (matcher("count")) |
| return functor(CountAggregation()); |
| if (matcher("first")) |
| return functor(FirstAggregation()); |
| if (matcher("max")) |
| return functor(MaxAggregation()); |
| if (matcher("min")) |
| return functor(MinAggregation()); |
| if (matcher("prod")) |
| return functor(ProdAggregation()); |
| if (matcher("sum")) |
| return functor(SumAggregation()); |
| if (matcher("unique_mask")) |
| return functor(UniqueAggregation()); |
| if (matcher("biggest")) |
| return functor(BiggestAggregation()); |
| return functor(0); |
| } |
| |
| template<typename Functor> |
| auto |
| agg_dispatch(std::string_view aggfunc, Functor&& functor) |
| { |
| return for_each_native_aggregation( |
| [&](std::string_view name) -> bool { |
| return aggfunc == name; |
| }, |
| [&](auto aggregation) { |
| if constexpr (std::is_same_v<decltype(aggregation), int>) { |
| _throw_pyerr_fmt(PyExc_ValueError, |
| "unknown built-in aggregation function %R", |
| make_pystr(aggfunc).get()); |
| // Not reached: just here for the type |
| return AUTOFWD(functor)(SumAggregation()); |
| } else { |
| return AUTOFWD(functor)(AUTOFWD(aggregation)); |
| } |
| }); |
| } |
| |
| template<typename Functor> |
| auto |
| agg_and_dtype_dispatch(std::string_view aggfunc, |
| dtype_ref dtype, |
| bool is_string, |
| Functor&& functor) |
| { |
| return agg_dispatch(aggfunc, [&](auto agg_dummy) { |
| using Aggregation = decltype(agg_dummy); |
| if (is_string) { |
| if (dtype != type_descr<StringTable::id_type>()) |
| throw_pyerr_msg(PyExc_AssertionError, "bad string dtype"); |
| using StringAggregation = typename Aggregation::StringVersion; |
| return AUTOFWD(functor)(StringTable::id_type(), |
| StringAggregation()); |
| } |
| return npy_type_dispatch( |
| dtype, |
| [&](auto value_dummy) { |
| return AUTOFWD(functor)(value_dummy, agg_dummy); |
| }); |
| }); |
| } |
| |
| NaContext::NaContext(StringTable* st, String collation) |
| : st(st), collation(std::move(collation)) |
| {} |
| |
| int |
| NaContext::compare_strings(StringTable::id_type left, |
| StringTable::id_type right) |
| { |
| if (StringTable::SequenceNumber new_st_seq |
| = this->st->get_seqno(); new_st_seq != this->st_seq) { |
| this->rank = this->st->rank(this->collation); |
| this->st_seq = new_st_seq; |
| } |
| |
| assume(left < npy_size1d(this->rank)); |
| assume(right < npy_size1d(this->rank)); |
| StringTable::Rank* rank_data = npy_data<StringTable::Rank>(this->rank); |
| if (rank_data[left] < rank_data[right]) |
| return -1; |
| if (rank_data[left] > rank_data[right]) |
| return 1; |
| return 0; |
| } |
| |
| bool |
| NaContext::is_string() const noexcept |
| { |
| return this->st != nullptr; |
| } |
| |
| template<typename NaImpl> |
| unique_pyref |
| na_agg_impl(PyObject*, pyref py_args) |
| { |
| PARSEPYARGS( |
| (pyref, data) |
| (pyref, command) |
| (pyref, out) |
| (unique_dtype, dtype, no_default, convert_dtype) |
| (std::string_view, aggfunc) |
| (QueryExecution*, qe) |
| (OPTIONAL_ARGS_FOLLOW) |
| (obj_pyref<StringTable>, st) |
| (pyref, collation) |
| )(py_args); |
| |
| String the_collation; |
| if (collation && collation != Py_None) |
| the_collation = str(collation); |
| else |
| the_collation = "binary"; |
| NaContext nac(st.get(), std::move(the_collation)); |
| CommandIn ci(command.notnull().addref()); |
| |
| agg_and_dtype_dispatch( |
| aggfunc, dtype, nac.is_string(), |
| [&](auto value_dummy, auto agg_dummy) { |
| using Value = decltype(value_dummy); |
| using Aggregation = decltype(agg_dummy); |
| NaImpl::template agg_column<Value, Aggregation>( |
| std::move(ci), |
| data, |
| out, |
| std::move(nac), |
| qe); |
| }); |
| |
| return addref(Py_None); |
| } |
| |
| template<MinMaxMode Mode> |
| template<typename Value> |
| MinMaxStringAggregation<Mode>::Kernel<Value>::Kernel(Value value) |
| : state(value) |
| {} |
| |
| template<MinMaxMode Mode> |
| template<typename Value> |
| void |
| MinMaxStringAggregation<Mode>::Kernel<Value>::accumulate( |
| Value value, |
| NaContext* nac) |
| { |
| static_assert(Mode == MinMaxMode::MIN || |
| Mode == MinMaxMode::MAX || |
| Mode == MinMaxMode::BIGGEST); |
| // We don't need special support for MinMaxMode::BIGGEST here |
| // because we always use the length collation, which DTRT, in |
| // this case. |
| int cmp = nac->compare_strings(this->state, value); |
| if (Mode == MinMaxMode::MIN) |
| cmp = -cmp; |
| if (cmp < 0) |
| this->state = value; |
| } |
| |
| template<MinMaxMode Mode> |
| template<typename Value> |
| typename MinMaxStringAggregation<Mode>::template Kernel<Value>::Out |
| MinMaxStringAggregation<Mode>::Kernel<Value>::get() const |
| { |
| return this->state; |
| } |
| |
template<typename Value>
int
_integer_log10(Value value)
{
  // Despite the name, this returns the number of decimal digits in a
  // non-negative `value`, i.e. floor(log10(value)) + 1 for value > 0.
  // Zero counts as having zero digits.  The value is widened to
  // unsigned long long, so the result is at most 20.
  //
  // TODO(dancol): we can optimize this function in numerous ways
  // (e.g., lookup by bit count) in the rare event that its
  // performance actually matters
  assume(value >= 0);
  int digits = 0;
  for (auto rest = static_cast<unsigned long long>(value);
       rest != 0;
       rest /= 10)
    ++digits;
  return digits;
}
| |
| template<MinMaxMode Mode> |
| template<typename Value> |
| MinMaxAggregation<Mode>::Kernel<Value>::Kernel(Value value) |
| : state(value) |
| {} |
| |
| template<MinMaxMode Mode> |
| template<typename Value> |
| void |
| MinMaxAggregation<Mode>::Kernel<Value>::accumulate(Value value, NaContext*) |
| { |
| bool should_update; |
| if constexpr (Mode == MinMaxMode::MIN || Mode == MinMaxMode::MAX) { |
| should_update = Mode == MinMaxMode::MIN |
| ? (value < this->state) |
| : (value > this->state); |
| } else if constexpr (Mode == MinMaxMode::BIGGEST) { |
| if constexpr (std::is_same_v<Value, bool>) { |
| // false is "bigger" than true |
| should_update = this->state && !value; |
| } else if constexpr (std::is_integral_v<Value>) { |
| // TODO(dancol): avoid repeated computation of log10 |
| int log10_state = _integer_log10( |
| this->state < 0 ? -this->state : this->state); |
| int log10_nv = _integer_log10(value < 0 ? -value : value); |
| // Account for minus sign |
| if (this->state < 0) |
| log10_state += 1; |
| if (value < 0 ) |
| log10_nv += 1; |
| should_update = log10_state < log10_nv; |
| } else if constexpr (std::is_floating_point_v<Value>) { |
| // Eh. Just use the bigger number. |
| should_update = this->state < value; |
| } else { |
| return errhack<Value>::cannot_compute_biggest(); |
| } |
| } else { |
| errhack<void, Mode>::bad_min_max_mode(); |
| } |
| if (should_update) |
| this->state = value; |
| } |
| |
| template<MinMaxMode Mode> |
| template<typename Value> |
| typename MinMaxAggregation<Mode>::template Kernel<Value>::Out |
| MinMaxAggregation<Mode>::Kernel<Value>::get() const |
| { |
| return this->state; |
| } |
| |
| } // namespace dctv |