// blob: 16626de59ab28a09c7bc3c40f84c49dfc171eb16 [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "dctv.h"
#include <string_view>
#include <type_traits>
#include "command_stream.h"
#include "hash_table.h"
#include "npy.h"
#include "npyiter.h"
#include "optional.h"
#include "pyerrfmt.h"
#include "pyparsetuple.h"
#include "pyparsetuplenpy.h"
#include "pyutil.h"
#include "string_table.h"
namespace dctv {
// Category tag attached to each aggregation (see the `cls` member of
// the aggregation structs).  Every enumerator is stored as a
// one-character code.
enum class AggregationClass : char {
  PASSTHROUGH = '=',     // Output is one of the input values
  INDEPENDENT = 'i',     // Output type doesn't depend on input type
  ADDITION = '+',        // Operation is addition-like
  MULTIPLICATION = '*',  // Operation is multiplication-like
};
// Maps an input element type to the wider type used for accumulated
// results.  The primary template is the identity mapping; the
// explicit specializations below widen the narrow types.
template<typename T>
struct PromotionHelper final {
  using type = T;
};

// bool and the narrow signed integers widen to int64_t.
template<>
struct PromotionHelper<bool> final {
  using type = int64_t;
};
template<>
struct PromotionHelper<int8_t> final {
  using type = int64_t;
};
template<>
struct PromotionHelper<int16_t> final {
  using type = int64_t;
};
template<>
struct PromotionHelper<int32_t> final {
  using type = int64_t;
};

// Narrow unsigned integers widen to uint64_t.
template<>
struct PromotionHelper<uint8_t> final {
  using type = uint64_t;
};
template<>
struct PromotionHelper<uint16_t> final {
  using type = uint64_t;
};
template<>
struct PromotionHelper<uint32_t> final {
  using type = uint64_t;
};

// Single-precision floats widen to double.
template<>
struct PromotionHelper<float> final {
  using type = double;
};

// Convenience alias: PromotionOf<T> is the promoted type for T.
template<typename T>
using PromotionOf = typename PromotionHelper<T>::type;
// Context object passed (by pointer) to every kernel's accumulate()
// call.  It is built from a StringTable pointer and a collation; all
// member functions are declared inline here and defined out of line.
// NOTE(review): presumably `rank` caches a collation-specific ordering
// of string ids and `st_seq` remembers the StringTable sequence number
// that cache was built from -- confirm against the -inl.h definitions.
struct NaContext final {
inline NaContext(StringTable* st, String collation);
// Compares two string ids; the int result is presumably
// negative/zero/positive in the usual three-way style -- TODO confirm.
inline int compare_strings(StringTable::id_type left,
StringTable::id_type right);
// Non-throwing query; NOTE(review): looks like it reports whether this
// context was set up for string data -- verify in the definition.
inline bool is_string() const noexcept;
private:
StringTable* st;
String collation;
// Starts at 0; updated (if at all) by code outside this header.
StringTable::SequenceNumber st_seq = 0;
unique_pyarray rank;
};
// Mode selector used as the template argument of MinMaxAggregation
// and MinMaxStringAggregation.  Enumerators keep their implicit values
// (MIN = 0, MAX = 1, BIGGEST = 2).
// NOTE(review): the precise ordering BIGGEST applies is decided by the
// out-of-line kernel code -- confirm there.
enum class MinMaxMode { MIN, MAX, BIGGEST };
// Kernels tagged with this base produce output that depends on the
// order of the inputs.  It is an empty marker type: kernels opt in by
// inheriting from it.
struct NaOrderDependent {};
// Empty marker base used in place of NaOrderDependent by kernels that
// do not carry the order-dependence tag.
struct NaNoConstraints {};
struct InvalidAggregation final {
static constexpr auto cls = AggregationClass::INDEPENDENT;
using StringVersion = InvalidAggregation;
template<typename Value>
struct Kernel final {
using Out = int64_t;
static constexpr optional<Out> empty_value = {};
explicit Kernel(Value) { die(); }
void accumulate(Value, NaContext*) {
die();
}
Out get() const {
die();
}
private:
DCTV_NORETURN_ERROR
static void die() {
throw_pyerr_msg(PyExc_AssertionError, "invalid aggregation");
}
};
};
// MinMaxStringAggregation is order dependent because different
// strings can compare equal under some collations (e.g., the
// case-insensitive one) but nevertheless have different values ("foo"
// vs "fOO"). We want to make sure to pick a consistent value.
//
// The aggregation is its own StringVersion, and its kernel only
// accepts StringTable::id_type values (enforced by the static_assert).
template<MinMaxMode Mode>
struct MinMaxStringAggregation final {
static constexpr auto cls = AggregationClass::PASSTHROUGH;
using StringVersion = MinMaxStringAggregation;
template<typename Value>
struct Kernel final : NaOrderDependent {
// Instantiating the kernel with any other value type is a compile error.
static_assert(std::is_same_v<Value, StringTable::id_type>);
using Out = Value;
// Disengaged optional.  NOTE(review): presumably means "no result for
// an empty group" -- confirm with the code that reads empty_value.
static constexpr optional<Out> empty_value = {};
// The three members below are only declared here; their definitions
// are out of line.
explicit inline Kernel(Value value);
// Unlike the numeric kernel, this one names its NaContext parameter;
// presumably it uses it for collation-aware comparison -- TODO confirm.
inline void accumulate(Value value, NaContext* nac);
inline Out get() const;
private:
// String id currently held as the running result.
Out state;
};
};
// Min/max-style aggregation over non-string values, parameterized by
// MinMaxMode.  Its string counterpart is MinMaxStringAggregation with
// the same Mode.
template<MinMaxMode Mode>
struct MinMaxAggregation final {
static constexpr auto cls = AggregationClass::PASSTHROUGH;
using StringVersion = MinMaxStringAggregation<Mode>;
// The kernel is tagged order dependent only for Mode == BIGGEST; MIN
// and MAX inherit from NaNoConstraints instead.
template<typename Value>
struct Kernel final : std::conditional_t<Mode==MinMaxMode::BIGGEST,
NaOrderDependent,
NaNoConstraints>
{
// Output has the same type as the input values.
using Out = Value;
// Disengaged optional.  NOTE(review): presumably means "no result for
// an empty group" -- confirm with the code that reads empty_value.
static constexpr optional<Out> empty_value = {};
// Declared here, defined out of line.
explicit inline Kernel(Value value);
inline void accumulate(Value value, NaContext*);
inline Out get() const;
private:
// Running result.
Out state;
};
};
// Convenience names for the three MinMaxMode instantiations of
// MinMaxAggregation.
using MinAggregation = MinMaxAggregation<MinMaxMode::MIN>;
using MaxAggregation = MinMaxAggregation<MinMaxMode::MAX>;
using BiggestAggregation = MinMaxAggregation<MinMaxMode::BIGGEST>;
// Aggregation that reports the value its kernel was constructed with
// and ignores everything accumulated afterwards.  Works unchanged for
// strings (it is its own StringVersion).
struct FirstAggregation final {
  using StringVersion = FirstAggregation;
  static constexpr auto cls = AggregationClass::PASSTHROUGH;

  // Tagged order dependent: which value is seen first decides the result.
  template<typename Value>
  struct Kernel final : NaOrderDependent {
    using Out = Value;
    static constexpr optional<Out> empty_value = {};

    // Remember the initial value.
    explicit Kernel(Value value) : first(value) {}

    // Later values never change the result.
    void accumulate(Value, NaContext*) {}

    Out get() const { return this->first; }

   private:
    Out first;
  };
};
// Aggregation whose boolean result says whether every accumulated
// value compared equal (operator==) to the value the kernel was
// constructed with.
struct UniqueAggregation final {
  using StringVersion = UniqueAggregation;
  static constexpr auto cls = AggregationClass::INDEPENDENT;

  template<typename Value>
  struct Kernel final {
    using Out = bool;
    static constexpr optional<Out> empty_value = {};

    // A kernel that has seen only its initial value reports true.
    explicit Kernel(Value value) : value(value), is_unique(true) {}

    // Once a mismatch has been seen the flag stays false and no further
    // comparisons are made.
    void accumulate(Value value, NaContext*) {
      if (this->is_unique && !(value == this->value)) {
        this->is_unique = false;
      }
    }

    Out get() const { return this->is_unique; }

   private:
    Value value;     // Reference value taken at construction.
    bool is_unique;  // True until a differing value is accumulated.
  };
};
// Aggregation that counts values: one for the value given at kernel
// construction plus one per accumulate() call.  The values themselves
// are never inspected.
struct CountAggregation final {
  using StringVersion = CountAggregation;
  static constexpr auto cls = AggregationClass::INDEPENDENT;

  template<typename Value>
  struct Kernel final {
    using Out = int64_t;
    // Unlike the other aggregations in this header, empty_value is
    // engaged and holds 0.
    static constexpr optional<Out> empty_value = 0;

    // The constructing value is the first one counted.
    explicit Kernel(Value) : seen(1) {}

    void accumulate(Value, NaContext*) { ++this->seen; }

    Out get() const { return this->seen; }

   private:
    Out seen;
  };
};
template<typename ArithmeticOperation>
struct ArithmeticAggregation final {
static constexpr auto cls = ArithmeticOperation::cls;
using StringVersion = InvalidAggregation;
template<typename Value>
struct Kernel final {
using Out = PromotionOf<Value>;
static constexpr optional<Out> empty_value = {};
explicit Kernel(Value value) : state(value) {}
void accumulate(Value value, NaContext*) {
this->state = ArithmeticOperation::template perform<Out>(
this->state, value);
}
Out get() const {
return this->state;
}
private:
Out state;
};
};
// Binary addition used by ArithmeticAggregation (SumAggregation).
struct SumOperation final {
  static constexpr auto cls = AggregationClass::ADDITION;

  // Returns left + right.
  //
  // For integral Value the sum is computed in an unsigned type that is
  // at least as wide as int, then converted back to Value.  Overflow
  // therefore wraps around (modulo 2^N) instead of being undefined
  // behaviour for signed types; results that do not overflow are
  // unchanged.  Floating-point and other types use plain operator+.
  template<typename Value>
  static Value perform(Value left, Value right) {
    // TODO(dancol): detect overflow?
    if constexpr (std::is_integral_v<Value>) {
      // common_type with int mirrors the usual integer promotion, so
      // the unsigned type is never narrower than unsigned int.
      using Wrapping = std::make_unsigned_t<std::common_type_t<Value, int>>;
      return static_cast<Value>(
          static_cast<Wrapping>(left) + static_cast<Wrapping>(right));
    } else {
      return left + right;
    }
  }
};
// Binary multiplication used by ArithmeticAggregation (ProdAggregation).
struct ProdOperation final {
  static constexpr auto cls = AggregationClass::MULTIPLICATION;

  // Returns left * right.
  //
  // For integral Value the product is computed in an unsigned type that
  // is at least as wide as int, then converted back to Value.  Overflow
  // therefore wraps around (modulo 2^N) instead of being undefined
  // behaviour -- both for signed types and for narrow unsigned types,
  // whose operands would otherwise be promoted to signed int.  Results
  // that do not overflow are unchanged.  Floating-point and other types
  // use plain operator*.
  template<typename Value>
  static Value perform(Value left, Value right) {
    // TODO(dancol): detect overflow?
    if constexpr (std::is_integral_v<Value>) {
      // common_type with int mirrors the usual integer promotion, so
      // the unsigned type is never narrower than unsigned int.
      using Wrapping = std::make_unsigned_t<std::common_type_t<Value, int>>;
      return static_cast<Value>(
          static_cast<Wrapping>(left) * static_cast<Wrapping>(right));
    } else {
      return left * right;
    }
  }
};
// Convenience names for the two arithmetic aggregations: folding with
// SumOperation and with ProdOperation respectively.
using SumAggregation = ArithmeticAggregation<SumOperation>;
using ProdAggregation = ArithmeticAggregation<ProdOperation>;
// Declarations only: the function templates below have deduced (`auto`)
// return types, so their definitions must be visible before use.
// NOTE(review): presumably they are defined in
// native_aggregation-inl.h, which this header includes at the end --
// confirm.

// Takes an aggregation-function name and a functor.
// NOTE(review): looks like it maps `aggfunc` to one of the aggregation
// types declared in this header and invokes `functor` with it -- verify
// against the definition.
template<typename Functor>
auto agg_dispatch(std::string_view aggfunc, Functor&& functor);
// Like agg_dispatch, but additionally receives a dtype and a flag
// saying whether the data is string data.
// NOTE(review): presumably selects both the aggregation and the element
// type before invoking `functor` -- TODO confirm.
template<typename Functor>
auto agg_and_dtype_dispatch(std::string_view aggfunc,
dtype_ref dtype,
bool is_string,
Functor&& functor);
// Takes a matcher and a functor.
// NOTE(review): presumably visits each supported aggregation accepted
// by `matcher` -- confirm against the definition.
template<typename Matcher, typename Functor>
auto for_each_native_aggregation(Matcher&& matcher,
Functor&& functor);
// Function template with a (PyObject*, args) signature returning an
// owned Python reference.
// NOTE(review): looks like a Python-callable entry point implemented
// in terms of NaImpl -- verify.
template<typename NaImpl>
unique_pyref na_agg_impl(PyObject*, pyref args);
// Non-template function taking a Python object reference `m`.
// NOTE(review): presumably registers this module's functions on the
// Python module `m` -- confirm in the implementation file.
void init_native_aggregation(pyref m);
} // namespace dctv
#include "native_aggregation-inl.h"