blob: ea41529d8186d3047f6e6f2750bc4293fee58ff1 [file] [log] [blame]
/*
* Copyright (c) 2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#pragma once
#include "src/core/NEON/kernels/assembly/depthwise.hpp"
#include <cstddef>
#include <functional>
using arm_gemm::Nothing;
namespace arm_conv {
namespace depthwise {
template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
struct DepthwiseImplementation
{
const DepthwiseMethod method;
const char *name;
std::function<bool(const DepthwiseArgs &, const OutputStage &)> is_supported;
std::function<uint64_t(const DepthwiseArgs &, const OutputStage &)> cycle_estimate;
std::function<DepthwiseCommon<TInput, TWeight, TOutput> *(const DepthwiseArgs &, const OutputStage &)> initialise;
bool get_is_supported(const DepthwiseArgs &args, const OutputStage &os) const
{
return (is_supported == nullptr) ? true : is_supported(args, os);
}
uint64_t get_cycle_estimate(const DepthwiseArgs &args, const OutputStage &os) const
{
return (cycle_estimate == nullptr) ? 0 : cycle_estimate(args, os);
}
DepthwiseCommon<TInput, TWeight, TOutput> *get_instance(const DepthwiseArgs &args, const OutputStage &os) const
{
return initialise(args, os);
}
};
template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
const DepthwiseImplementation<TInput, TWeight, TOutput, OutputStage> *depthwise_implementation_list();
template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
bool find_implementation(
const DepthwiseArgs &args,
const OutputStage &os,
const DepthwiseImplementation<TInput, TWeight, TOutput, OutputStage> * &selected
)
{
selected = nullptr;
uint64_t best_cycle_estimate = UINT64_MAX;
const auto *impl = depthwise_implementation_list<TInput, TWeight, TOutput, OutputStage>();
for (; impl->method != DepthwiseMethod::DEFAULT; impl++)
{
const bool has_cfg = (args.config != nullptr);
const auto &cfg = args.config;
if (
!impl->get_is_supported(args, os) || // Problem is unsupported
(has_cfg && cfg->method != DepthwiseMethod::DEFAULT && cfg->method != impl->method) ||
(has_cfg && cfg->filter != "" && !std::strstr(impl->name, cfg->filter.c_str()))
)
{
continue;
}
const auto cycle_estimate = impl->get_cycle_estimate(args, os);
if (cycle_estimate == 0)
{
selected = impl;
break;
}
if (selected == nullptr || cycle_estimate < best_cycle_estimate)
{
selected = impl;
best_cycle_estimate = cycle_estimate;
}
}
return (selected != nullptr);
}
template <typename TInput, typename TWeight, typename TOutput, class OutputStage>
std::vector<KernelDescription> get_compatible_kernels(const DepthwiseArgs &args, const OutputStage &os)
{
std::vector<KernelDescription> kerns;
// Find the default implementation so we can flag it accordingly
const DepthwiseImplementation<TInput, TWeight, TOutput, OutputStage> *default_impl;
find_implementation<TInput, TWeight, TOutput, OutputStage>(args, os, default_impl);
for (auto impl = depthwise_implementation_list<TInput, TWeight, TOutput, OutputStage>();
impl->method != DepthwiseMethod::DEFAULT; impl++)
{
if (!impl->get_is_supported(args, os))
{
continue;
}
kerns.emplace_back(
impl->method, impl->name, impl == default_impl,
impl->get_cycle_estimate(args, os)
);
}
return kerns;
}
template <typename TInput, typename TWeight, typename TOutput, class OutputStage>
UniqueDepthwiseCommon<TInput, TWeight, TOutput> depthwise(const DepthwiseArgs &args, const OutputStage &os)
{
const DepthwiseImplementation<TInput, TWeight, TOutput, OutputStage> *impl = nullptr;
const bool success = find_implementation<TInput, TWeight, TOutput, OutputStage>(args, os, impl);
if(success)
{
auto i = impl->get_instance(args, os);
i->set_name(impl->name);
return UniqueDepthwiseCommon<TInput, TWeight, TOutput>(i);
}
return nullptr;
}
} // namespace depthwise
} // namespace arm_conv