blob: 2163805c43fca33f060ad8853a756db7c7b599bd [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "native_aggregation.h"
#include <tuple>
#include "automethod.h"
#include "pyseq.h"
#include "vector.h"
namespace dctv {
// Lets the AUTOFUNCTION argument lists in the method table below spell the
// type as plain string_view.
using std::string_view;
// Compute the dtype produced by running the native aggregation named
// AGGREGATION over values whose dtype is DTYPE.
//
// Both runtime arguments are mapped to compile-time types:
// npy_type_dispatch picks the C++ value type for DTYPE, and agg_dispatch
// picks the aggregation implementation for AGGREGATION.  The result is the
// type descriptor of that aggregation's Kernel<Value>::Out type.
// NOTE(review): behavior for an unsupported dtype or an unknown
// aggregation name is whatever the two dispatch helpers do; it is not
// visible here.
static
unique_dtype
native_aggregation_result_dtype(
    std::string_view aggregation,
    dtype_ref dtype)
{
  return npy_type_dispatch(
      dtype,
      [aggregation](auto value_tag) {
        // The tag argument exists only to carry the selected value type.
        using Value = decltype(value_tag);
        return agg_dispatch(
            aggregation,
            [](auto agg_tag) {
              using Aggregation = decltype(agg_tag);
              using Out =
                  typename Aggregation::template Kernel<Value>::Out;
              return type_descr<Out>();
            });
      });
}
// Describe the native aggregation named AGGREGATION.
//
// Returns a Python 2-tuple:
//   [0] a one-character str holding the aggregation's AggregationClass code
//       (the enum value cast to char);
//   [1] a bool that is true when the aggregation's kernel reports an
//       empty value (Kernel::empty_value.has_value()).
//
// The kernel is instantiated with int8_t only to reach empty_value.
// NOTE(review): this presumes has_value() does not vary with the value
// type; confirm against the aggregation definitions.
static
unique_pyref
native_aggregation_info(std::string_view aggregation)
{
  const auto [agg_class, has_empty] = agg_dispatch(
      aggregation,
      [](auto agg_tag) {
        using Agg = decltype(agg_tag);
        using ProbeKernel = typename Agg::template Kernel<int8_t>;
        return std::tuple(Agg::cls,
                          ProbeKernel::empty_value.has_value());
      });
  const AggregationClass class_code = agg_class;
  return pytuple::of(make_pystr(static_cast<char>(class_code)),
                     make_pybool(has_empty));
}
// Return a Python tuple of str, one for each native aggregation name that
// for_each_native_aggregation reports to its name callback.
static
unique_pyref
get_supported_native_aggregations()
{
  Vector<String> aggfunc_names;
  for_each_native_aggregation(
      // Name callback: record every name it is given, then return false.
      // NOTE(review): confirm what the bool controls in
      // for_each_native_aggregation (for example, whether the second
      // callback runs for this name).
      [&](std::string_view name) -> bool {
        aggfunc_names.push_back(String(name));
        return false;
      },
      // Per-aggregation callback: intentionally a no-op, because only the
      // names are wanted.  The parameter is left unnamed so
      // -Wunused-parameter stays quiet.
      [](auto /*aggregation*/) {
        /* Do nothing */
      });
  // The mapping lambda needs no captures.
  return pytuple::from(aggfunc_names,
                       [](const String& s) {
                         return make_pystr(s);
                       });
}
// Method table registered on the Python module by init_native_aggregation.
// Each AUTOFUNCTION entry gives a C++ function defined above and its Python
// docstring.  It is followed by a parenthesized (type, name) list that
// presumably describes the Python-visible arguments; see automethod.h to
// confirm.
static PyMethodDef functions[] = {
AUTOFUNCTION(native_aggregation_result_dtype,
"Find dtype for a native aggregation",
(string_view, aggregation)
(dtype_ref, dtype)
),
AUTOFUNCTION(native_aggregation_info,
"Get information about a native aggregation",
(string_view, aggregation)
),
// Takes no arguments, so the argument list after the docstring is empty.
AUTOFUNCTION(get_supported_native_aggregations,
"Get a tuple of supported native aggregation names",
),
// All-zero sentinel entry; a PyMethodDef array must end with one.
{ 0 }
};
// Register the native-aggregation helpers in this file's `functions` method
// table on the Python module M.
// NOTE(review): presumably invoked from the extension's module-init code;
// the caller is not visible here.
void
init_native_aggregation(pyref m)
{
register_functions(m, functions);
}
} // namespace dctv