blob: 78aa0a517327cd1110257c6274ed88381e673f4d [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/runtime/kernel/operator_registry.h>
#include <cinttypes>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/system.h>
namespace executorch {
namespace runtime {
namespace {
// Capacity of the kernel table: the total number of kernels, summed over all
// operators, that can be registered. A build may pin the capacity by defining
// MAX_KERNEL_NUM; without it the capacity is kMaxOperators * kMaxKernelsPerOp.
#ifndef MAX_KERNEL_NUM
constexpr uint32_t kMaxOperators = 250;
constexpr uint32_t kMaxKernelsPerOp = 8;
constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp;
#else
constexpr uint32_t kMaxRegisteredKernels = MAX_KERNEL_NUM;
#endif
// Data that backs the kernel table. Since Kernel has a custom default
// constructor (implicitly, because it contains KernelKey, which has a custom
// ctor), some toolchains don't like having a global array of them: it would
// require constructing them at init time. Since we don't care about the values
// until we add each entry to the table, allocate static zeroed memory instead
// and point the table at it.
//
// The storage is aligned to the alignment requirement of Kernel. (Aligning to
// sizeof(Kernel) would only be valid when that size happens to be a power of
// two, and would over-align the buffer otherwise.)
// @lint-ignore CLANGTIDY facebook-hte-CArray
alignas(Kernel) uint8_t
    registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)];

/// Global table of registered kernels. Only the first
/// `num_registered_kernels` entries hold meaningful values.
Kernel* registered_kernels = reinterpret_cast<Kernel*>(registered_kernels_data);

/// The number of kernels registered in the table.
size_t num_registered_kernels = 0;
// Registers the kernels, but may return an error.
//
// Appends every entry of `kernels` to the global kernel table.
//
// Returns:
//   Error::Ok               all kernels were added.
//   Error::Internal         the table does not have room for all of `kernels`.
//                           Nothing is added; the current table contents and
//                           the rejected kernels are written to the error log.
//   Error::InvalidArgument  a kernel with the same name and kernel key is
//                           already in the table. Entries of `kernels` that
//                           precede the duplicate have already been added.
Error register_kernels_internal(const Span<const Kernel> kernels) {
  // Operator registration happens in static initialization time before or after
  // PAL init, so call it here. It is safe to call multiple times.
  ::et_pal_init();

  if (kernels.size() + num_registered_kernels > kMaxRegisteredKernels) {
    ET_LOG(
        Error,
        "The total number of kernels to be registered is larger than the limit "
        "%" PRIu32 ". %" PRIu32
        " kernels are already registered and we're trying to register another "
        "%" PRIu32 " kernels.",
        kMaxRegisteredKernels,
        static_cast<uint32_t>(num_registered_kernels),
        static_cast<uint32_t>(kernels.size()));
    ET_LOG(Error, "======== Kernels already in the registry: ========");
    for (size_t i = 0; i < num_registered_kernels; i++) {
      ET_LOG(Error, "%s", registered_kernels[i].name_);
      ET_LOG_KERNEL_KEY(registered_kernels[i].kernel_key_);
    }
    ET_LOG(Error, "======== Kernels being registered: ========");
    for (size_t i = 0; i < kernels.size(); i++) {
      ET_LOG(Error, "%s", kernels[i].name_);
      ET_LOG_KERNEL_KEY(kernels[i].kernel_key_);
    }
    return Error::Internal;
  }

  // for debugging purpose
  const char* lib_name = et_pal_get_shared_library_name(kernels.data());

  for (const auto& kernel : kernels) {
    // Linear search. This is fine if the number of kernels is small.
    // The index is a size_t to match num_registered_kernels, and the existing
    // entry is inspected through a reference rather than copied.
    for (size_t i = 0; i < num_registered_kernels; i++) {
      const Kernel& k = registered_kernels[i];
      if (strcmp(kernel.name_, k.name_) == 0 &&
          kernel.kernel_key_ == k.kernel_key_) {
        ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
        ET_LOG_KERNEL_KEY(k.kernel_key_);
        return Error::InvalidArgument;
      }
    }
    // Appending inside the loop means later entries of `kernels` are also
    // checked against the earlier entries of the same batch.
    registered_kernels[num_registered_kernels++] = kernel;
  }

  ET_LOG(
      Debug,
      "Successfully registered all kernels from shared library: %s",
      lib_name);
  return Error::Ok;
}
} // namespace
// Public registration entry point. Forwards to register_kernels_internal()
// and panics (through ET_CHECK_MSG) when that call reports
// Error::InvalidArgument or Error::Internal; any other status is handed back
// to the caller unchanged.
Error register_kernels(const Span<const Kernel> kernels) {
  const Error status = register_kernels_internal(kernels);
  switch (status) {
    case Error::InvalidArgument:
    case Error::Internal:
      ET_CHECK_MSG(
          false,
          "Kernel registration failed with error %" PRIu32
          ", see error log for details.",
          static_cast<uint32_t>(status));
      break;
    default:
      break;
  }
  return status;
}
namespace {
// Writes `num` to `buf` as decimal text and returns how many characters were
// written: one when `num` is below 10, otherwise two (tens digit followed by
// ones digit). No NUL terminator is written and the caller's pointer is not
// advanced; callers add the return value to their own cursor.
//
// Only values in [0, 99] produce digit characters: anything else still writes
// one or two bytes, but they are not guaranteed to be '0'..'9'.
int copy_char_as_number_to_buf(char num, char* buf) {
  int written = 0;
  char ones = num;
  if (num >= 10) {
    buf[written++] = static_cast<char>('0' + num / 10);
    ones = static_cast<char>(num % 10);
  }
  buf[written++] = static_cast<char>('0' + ones);
  return written;
}
} // namespace
namespace internal {
/**
 * Serializes `key` into `buf` as a NUL-terminated kernel key string.
 *
 * The text starts with the prefix "v1/" followed by one section per tensor,
 * sections separated by '|'. A section is the tensor's dtype as a decimal
 * number, a ';', and then the entries of its dim order as decimal numbers
 * separated by ','. For example a single tensor with dtype 6 and dim order
 * {0, 1, 2, 3} is written as "v1/6;0,1,2,3".
 *
 * @param[in] key Per-tensor metadata to encode. When it is empty nothing is
 *     written and `buf` is left untouched.
 * @param[out] buf Destination buffer. No size is passed in and no bounds
 *     checking is performed here, so the caller must provide enough room for
 *     the whole string including the terminator.
 */
void make_kernel_key_string(Span<const TensorMeta> key, char* buf) {
  if (key.empty()) {
    // If no tensor is present in an op, kernel key does not apply
    return;
  }

  // Version prefix. The terminating NUL is written after the last tensor, so
  // the three characters are stored directly.
  *buf++ = 'v';
  *buf++ = '1';
  *buf++ = '/';

  const size_t num_tensors = key.size();
  for (size_t i = 0; i < num_tensors; i++) {
    const auto& meta = key[i];

    buf += copy_char_as_number_to_buf(static_cast<char>(meta.dtype_), buf);
    *buf++ = ';';

    const size_t num_dims = meta.dim_order_.size();
    for (size_t j = 0; j < num_dims; j++) {
      buf += copy_char_as_number_to_buf(
          static_cast<char>(meta.dim_order_[j]), buf);
      // Separate dim order entries, but do not leave a trailing comma.
      if (j + 1 < num_dims) {
        *buf++ = ',';
      }
    }

    // '|' between tensors; the last tensor is followed by the terminator.
    *buf++ = (i + 1 < num_tensors) ? '|' : '\0';
  }
}
} // namespace internal
bool registry_has_op_function(
const char* name,
Span<const TensorMeta> meta_list) {
return get_op_function_from_registry(name, meta_list).ok();
}
/**
 * Finds the kernel function registered for operator `name`.
 *
 * A kernel key string is built from `meta_list` and compared against every
 * registered kernel with a matching name. An exact key match is returned as
 * soon as it is found; if there is none, a same-named kernel whose key is a
 * fallback is returned instead.
 *
 * @param[in] name Operator name to look up.
 * @param[in] meta_list Metadata of the tensors the operator is called with.
 * @return The kernel function on success. Error::InvalidArgument when the
 *     kernel key for `meta_list` would not fit in KernelKey::MAX_SIZE bytes.
 *     Error::OperatorMissing when neither an exact nor a fallback kernel is
 *     registered under `name`.
 */
Result<OpFunction> get_op_function_from_registry(
    const char* name,
    Span<const TensorMeta> meta_list) {
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  char buf[KernelKey::MAX_SIZE] = {0};

  // make_kernel_key_string() writes into `buf` without any bounds checking,
  // so compute the exact number of bytes it is going to emit and reject
  // inputs that would run past the end of the buffer. This mirrors the
  // format it produces: "v1/", then per tensor the dtype, ';', the dim order
  // entries separated by ',', and finally '|' or the terminating NUL. Every
  // number takes one character when its char value is below 10 and two
  // characters otherwise.
  if (!meta_list.empty()) {
    const auto rendered_width = [](char value) -> size_t {
      return value < 10 ? 1 : 2;
    };
    size_t key_len = 3; // "v1/"
    for (size_t i = 0; i < meta_list.size(); i++) {
      const auto& meta = meta_list[i];
      key_len += rendered_width(static_cast<char>(meta.dtype_)) + 1; // ';'
      const size_t num_dims = meta.dim_order_.size();
      for (size_t j = 0; j < num_dims; j++) {
        key_len += rendered_width(static_cast<char>(meta.dim_order_[j]));
      }
      if (num_dims > 0) {
        key_len += num_dims - 1; // ',' between entries
      }
      key_len += 1; // '|' or NUL
    }
    if (key_len > sizeof(buf)) {
      ET_LOG(
          Error,
          "Kernel key for '%s' needs %" PRIu32
          " bytes, which exceeds the maximum of %" PRIu32 ".",
          name,
          static_cast<uint32_t>(key_len),
          static_cast<uint32_t>(sizeof(buf)));
      return Error::InvalidArgument;
    }
  }

  internal::make_kernel_key_string(meta_list, buf);
  KernelKey kernel_key = KernelKey(buf);

  // Remember a same-named fallback kernel while scanning so it can be used
  // when no exact key match turns up.
  const Kernel* fallback = nullptr;
  for (size_t idx = 0; idx < num_registered_kernels; idx++) {
    const Kernel& candidate = registered_kernels[idx];
    if (strcmp(candidate.name_, name) != 0) {
      continue;
    }
    if (candidate.kernel_key_ == kernel_key) {
      return candidate.op_;
    }
    if (candidate.kernel_key_.is_fallback()) {
      fallback = &candidate;
    }
  }
  if (fallback != nullptr) {
    return fallback->op_;
  }

  ET_LOG(Error, "kernel '%s' not found.", name);
  ET_LOG_TENSOR_META(meta_list);
  return Error::OperatorMissing;
}
// Returns a read-only view over the entries of the global kernel table that
// have been registered so far. The view refers to the table's storage; it
// does not copy the kernels.
Span<const Kernel> get_registered_kernels() {
  const Kernel* const first = registered_kernels;
  const size_t count = num_registered_kernels;
  return Span<const Kernel>(first, count);
}
} // namespace runtime
} // namespace executorch