| // Copyright (C) 2020 The Android Open Source Project |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| #pragma once |
#include "dctv.h"

#include <string_view>
#include <type_traits>

#include "command_stream.h"
#include "hash_table.h"
#include "npy.h"
#include "npyiter.h"
#include "optional.h"
#include "pyerrfmt.h"
#include "pyparsetuple.h"
#include "pyparsetuplenpy.h"
#include "pyutil.h"
#include "string_table.h"
| |
| namespace dctv { |
| |
// Describes how an aggregation's output relates to its input.  Every
// enumerator is backed by a one-character code (the underlying type
// is char).
enum class AggregationClass : char {
  PASSTHROUGH = '=',     // result is one of the input values
  INDEPENDENT = 'i',     // result type does not depend on input type
  ADDITION = '+',        // operation behaves like addition
  MULTIPLICATION = '*',  // operation behaves like multiplication
};
| |
// Type-level map from an element type to the wider type that
// arithmetic aggregations accumulate into (ArithmeticAggregation
// uses PromotionOf<Value> as its output type).
//
// Primary template: any type without an explicit entry below
// promotes to itself.
template<typename T>
struct PromotionHelper final {
  using type = T;
};

// bool and the narrow signed integers widen to int64_t.
template<>
struct PromotionHelper<bool> final {
  using type = int64_t;
};

template<>
struct PromotionHelper<int8_t> final {
  using type = int64_t;
};

template<>
struct PromotionHelper<int16_t> final {
  using type = int64_t;
};

template<>
struct PromotionHelper<int32_t> final {
  using type = int64_t;
};

// Narrow unsigned integers widen to uint64_t.
template<>
struct PromotionHelper<uint8_t> final {
  using type = uint64_t;
};

template<>
struct PromotionHelper<uint16_t> final {
  using type = uint64_t;
};

template<>
struct PromotionHelper<uint32_t> final {
  using type = uint64_t;
};

// Single precision widens to double precision.
template<>
struct PromotionHelper<float> final {
  using type = double;
};

// Convenience alias: PromotionOf<T> is the promoted form of T.
template<typename T>
using PromotionOf = typename PromotionHelper<T>::type;
| |
| struct NaContext final { |
| inline NaContext(StringTable* st, String collation); |
| inline int compare_strings(StringTable::id_type left, |
| StringTable::id_type right); |
| inline bool is_string() const noexcept; |
| private: |
| StringTable* st; |
| String collation; |
| StringTable::SequenceNumber st_seq = 0; |
| unique_pyarray rank; |
| }; |
| |
// Selects which extreme MinMaxAggregation / MinMaxStringAggregation
// keeps.  BIGGEST is the one mode that makes the numeric kernel order
// dependent.
// NOTE(review): exact BIGGEST semantics live in the -inl header.
enum class MinMaxMode { MIN, MAX, BIGGEST };
| |
// Empty tag base: a kernel deriving from this produces a result that
// depends on the order in which its inputs arrive.
struct NaOrderDependent {};

// Empty tag base for kernels that carry no such ordering constraint.
struct NaNoConstraints {};
| |
| struct InvalidAggregation final { |
| static constexpr auto cls = AggregationClass::INDEPENDENT; |
| using StringVersion = InvalidAggregation; |
| template<typename Value> |
| struct Kernel final { |
| using Out = int64_t; |
| static constexpr optional<Out> empty_value = {}; |
| explicit Kernel(Value) { die(); } |
| void accumulate(Value, NaContext*) { |
| die(); |
| } |
| Out get() const { |
| die(); |
| } |
| private: |
| DCTV_NORETURN_ERROR |
| static void die() { |
| throw_pyerr_msg(PyExc_AssertionError, "invalid aggregation"); |
| } |
| }; |
| }; |
| |
| // MinMaxStringAggregation is order dependent because different |
| // strings can compare equal under some collations (e.g., the |
| // case-insensitive one) but nevertheless have different values ("foo" |
| // vs "fOO"). We want to make sure to pick a consistent value. |
| template<MinMaxMode Mode> |
| struct MinMaxStringAggregation final { |
| static constexpr auto cls = AggregationClass::PASSTHROUGH; |
| using StringVersion = MinMaxStringAggregation; |
| template<typename Value> |
| struct Kernel final : NaOrderDependent { |
| static_assert(std::is_same_v<Value, StringTable::id_type>); |
| using Out = Value; |
| static constexpr optional<Out> empty_value = {}; |
| explicit inline Kernel(Value value); |
| inline void accumulate(Value value, NaContext* nac); |
| inline Out get() const; |
| private: |
| Out state; |
| }; |
| }; |
| |
| template<MinMaxMode Mode> |
| struct MinMaxAggregation final { |
| static constexpr auto cls = AggregationClass::PASSTHROUGH; |
| using StringVersion = MinMaxStringAggregation<Mode>; |
| |
| template<typename Value> |
| struct Kernel final : std::conditional_t<Mode==MinMaxMode::BIGGEST, |
| NaOrderDependent, |
| NaNoConstraints> |
| { |
| using Out = Value; |
| static constexpr optional<Out> empty_value = {}; |
| explicit inline Kernel(Value value); |
| inline void accumulate(Value value, NaContext*); |
| inline Out get() const; |
| private: |
| Out state; |
| }; |
| }; |
| |
| using MinAggregation = MinMaxAggregation<MinMaxMode::MIN>; |
| using MaxAggregation = MinMaxAggregation<MinMaxMode::MAX>; |
| using BiggestAggregation = MinMaxAggregation<MinMaxMode::BIGGEST>; |
| |
| struct FirstAggregation final { |
| static constexpr auto cls = AggregationClass::PASSTHROUGH; |
| using StringVersion = FirstAggregation; |
| |
| template<typename Value> |
| struct Kernel final : NaOrderDependent { |
| using Out = Value; |
| static constexpr optional<Out> empty_value = {}; |
| explicit Kernel(Value value) : state(value) {} |
| void accumulate(Value, NaContext*) {} |
| Out get() const { |
| return this->state; |
| } |
| private: |
| Out state; |
| }; |
| }; |
| |
| struct UniqueAggregation final { |
| static constexpr auto cls = AggregationClass::INDEPENDENT; |
| using StringVersion = UniqueAggregation; |
| |
| template<typename Value> |
| struct Kernel final { |
| using Out = bool; |
| static constexpr optional<Out> empty_value = {}; |
| explicit Kernel(Value value) : value(value), is_unique(true) {} |
| void accumulate(Value value, NaContext*) { |
| this->is_unique = this->is_unique && value == this->value; |
| } |
| Out get() const { |
| return this->is_unique; |
| } |
| private: |
| Value value; |
| bool is_unique; |
| }; |
| }; |
| |
| struct CountAggregation final { |
| static constexpr auto cls = AggregationClass::INDEPENDENT; |
| using StringVersion = CountAggregation; |
| |
| template<typename Value> |
| struct Kernel final { |
| using Out = int64_t; |
| static constexpr optional<Out> empty_value = 0; |
| explicit Kernel(Value) : count(1) {} |
| void accumulate(Value, NaContext*) { |
| this->count += 1; |
| } |
| Out get() const { |
| return this->count; |
| } |
| private: |
| Out count; |
| }; |
| }; |
| |
| template<typename ArithmeticOperation> |
| struct ArithmeticAggregation final { |
| static constexpr auto cls = ArithmeticOperation::cls; |
| using StringVersion = InvalidAggregation; |
| |
| template<typename Value> |
| struct Kernel final { |
| using Out = PromotionOf<Value>; |
| static constexpr optional<Out> empty_value = {}; |
| explicit Kernel(Value value) : state(value) {} |
| void accumulate(Value value, NaContext*) { |
| this->state = ArithmeticOperation::template perform<Out>( |
| this->state, value); |
| } |
| Out get() const { |
| return this->state; |
| } |
| private: |
| Out state; |
| }; |
| }; |
| |
| struct SumOperation final { |
| static constexpr auto cls = AggregationClass::ADDITION; |
| |
| template<typename Value> |
| static Value perform(Value left, Value right) { |
| // TODO(dancol): detect overflow? |
| return left + right; |
| } |
| }; |
| |
| struct ProdOperation final { |
| static constexpr auto cls = AggregationClass::MULTIPLICATION; |
| |
| template<typename Value> |
| static Value perform(Value left, Value right) { |
| // TODO(dancol): detect overflow? |
| return left * right; |
| } |
| }; |
| |
| using SumAggregation = ArithmeticAggregation<SumOperation>; |
| using ProdAggregation = ArithmeticAggregation<ProdOperation>; |
| |
| template<typename Functor> |
| auto agg_dispatch(std::string_view aggfunc, Functor&& functor); |
| |
| template<typename Functor> |
| auto agg_and_dtype_dispatch(std::string_view aggfunc, |
| dtype_ref dtype, |
| bool is_string, |
| Functor&& functor); |
| |
| template<typename Matcher, typename Functor> |
| auto for_each_native_aggregation(Matcher&& matcher, |
| Functor&& functor); |
| |
| template<typename NaImpl> |
| unique_pyref na_agg_impl(PyObject*, pyref args); |
| |
| void init_native_aggregation(pyref m); |
| |
| } // namespace dctv |
| |
| #include "native_aggregation-inl.h" |