blob: a8fd50d7b91843e709701cbab6ebf84f24eef71e [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/runtime/platform/system.h>
#include <cinttypes>
#include <executorch/runtime/platform/assert.h>
namespace executorch {
namespace runtime {
OperatorRegistry& getOperatorRegistry();

/// Returns the single OperatorRegistry instance used by the free functions in
/// this file.
///
/// The instance is a function-local static, so it is constructed on first call
/// rather than at a fixed point during static initialization. Kernel
/// registration can itself run during static initialization, so construct-on-
/// first-use avoids depending on initialization order.
OperatorRegistry& getOperatorRegistry() {
  static OperatorRegistry registry_instance;
  return registry_instance;
}
/**
 * Registers `kernels` with the global operator registry.
 *
 * If the registry reports Error::InvalidArgument or Error::Internal, this fails
 * via ET_CHECK_MSG (the registry has already logged the details). Any other
 * result is returned to the caller unchanged.
 */
Error register_kernels(const ArrayRef<Kernel>& kernels) {
  const Error status = getOperatorRegistry().register_kernels(kernels);
  const bool registration_failed =
      status == Error::InvalidArgument || status == Error::Internal;
  ET_CHECK_MSG(
      !registration_failed,
      "Kernel registration failed with error %" PRIu32
      ", see error log for details.",
      static_cast<uint32_t>(status));
  return status;
}
/**
 * Appends `kernels` to this registry.
 *
 * @param kernels the batch of kernels to add.
 * @return Error::Ok if every kernel in the batch was added.
 *     Error::Internal if the batch would push the registry past
 *     kMaxNumOfKernels. In that case nothing is added, and both the current
 *     registry contents and the incoming batch are logged.
 *     Error::InvalidArgument if a kernel with the same name and kernel key is
 *     already registered, including one added earlier from this same batch.
 *
 * NOTE: kernels are appended one at a time. On Error::InvalidArgument, the
 * entries of the batch that precede the duplicate remain registered.
 */
Error OperatorRegistry::register_kernels(const ArrayRef<Kernel>& kernels) {
  // Operator registration happens during static initialization, when PAL init
  // may or may not have run already. This assumes et_pal_init() has no side
  // effect when called multiple times.
  ::et_pal_init();
  if (kernels.size() + this->num_kernels_ > kMaxNumOfKernels) {
    // Every argument is cast to uint32_t so it matches its PRIu32 conversion
    // regardless of the underlying integer type.
    ET_LOG(
        Error,
        "The total number of kernels to be registered is larger than the limit %" PRIu32
        ". %" PRIu32
        " kernels are already registered and we're trying to register another %" PRIu32
        " kernels.",
        static_cast<uint32_t>(kMaxNumOfKernels),
        static_cast<uint32_t>(this->num_kernels_),
        static_cast<uint32_t>(kernels.size()));
    ET_LOG(Error, "======== Kernels already in the registry: ========");
    for (size_t i = 0; i < this->num_kernels_; i++) {
      ET_LOG(Error, "%s", this->kernels_[i].name_);
      ET_LOG_KERNEL_KEY(this->kernels_[i].kernel_key_);
    }
    ET_LOG(Error, "======== Kernels being registered: ========");
    for (size_t i = 0; i < kernels.size(); i++) {
      ET_LOG(Error, "%s", kernels[i].name_);
      ET_LOG_KERNEL_KEY(kernels[i].kernel_key_);
    }
    return Error::Internal;
  }
  // For debugging purposes only: this name is used solely in log messages.
  const char* lib_name = et_pal_get_shared_library_name(kernels.data());
  for (const auto& kernel : kernels) {
    // Linear search. This is fine as long as the number of kernels is small.
    // The scan also covers kernels appended earlier from this same batch.
    for (size_t i = 0; i < this->num_kernels_; i++) {
      const Kernel& k = this->kernels_[i];
      if (strcmp(kernel.name_, k.name_) == 0 &&
          kernel.kernel_key_ == k.kernel_key_) {
        ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
        ET_LOG_KERNEL_KEY(k.kernel_key_);
        return Error::InvalidArgument;
      }
    }
    this->kernels_[this->num_kernels_++] = kernel;
  }
  ET_LOG(
      Debug,
      "Successfully registered all kernels from shared library: %s",
      lib_name);
  return Error::Ok;
}
bool hasOpsFn(const char* name, ArrayRef<TensorMeta> kernel_key) {
return getOperatorRegistry().hasOpsFn(name, kernel_key);
}
/**
 * Writes the decimal representation of `num` to `buf`, without a NUL
 * terminator.
 *
 * @param num value to print. It is reinterpreted as an unsigned byte (0-255),
 *     so a uint8_t above 127 that was narrowed to a signed char still prints
 *     its original value.
 * @param buf destination; must have room for up to 3 characters.
 * @return the number of characters written (1, 2, or 3).
 */
static int copy_char_as_number_to_buf(char num, char* buf) {
  const unsigned value = static_cast<unsigned char>(num);
  int written = 0;
  if (value >= 100) {
    buf[written++] = static_cast<char>('0' + value / 100);
  }
  if (value >= 10) {
    // Tens digit; modulo handles the three-digit case.
    buf[written++] = static_cast<char>('0' + (value / 10) % 10);
  }
  buf[written++] = static_cast<char>('0' + value % 10);
  return written;
}
void make_kernel_key_string(ArrayRef<TensorMeta> key, char* buf);

/**
 * Serializes tensor metadata into the "v1/" kernel-key string format:
 *   v1/<dtype>;<dim>,<dim>,...|<dtype>;<dim>,...
 * with every number written in decimal. A NUL is written after the last
 * tensor's entry.
 *
 * @param key per-tensor metadata. If empty, nothing is written: a kernel key
 *     does not apply to an op without tensors. The callers in this file
 *     zero-initialize `buf`, so that case yields an empty string.
 * @param buf destination. No bounds checking is done here; the caller must
 *     supply a buffer large enough for the whole string.
 */
void make_kernel_key_string(ArrayRef<TensorMeta> key, char* buf) {
  if (key.empty()) {
    // If no tensor is present in an op, kernel key does not apply.
    return;
  }
  strncpy(buf, "v1/", 3);
  buf += 3;
  for (size_t i = 0; i < key.size(); i++) {
    auto& meta = key[i];
    buf += copy_char_as_number_to_buf(static_cast<char>(meta.dtype_), buf);
    *buf = ';';
    buf += 1;
    for (size_t j = 0; j < meta.dim_order_.size(); j++) {
      // Separator goes before every entry except the first.
      if (j != 0) {
        *buf = ',';
        buf += 1;
      }
      buf += copy_char_as_number_to_buf(
          static_cast<char>(meta.dim_order_[j]), buf);
    }
    // '|' between tensors, NUL terminator after the last one.
    *buf = (i + 1 < key.size()) ? '|' : '\0';
    buf += 1;
  }
}
bool OperatorRegistry::hasOpsFn(
const char* name,
ArrayRef<TensorMeta> meta_list) {
char buf[KernelKey::MAX_SIZE] = {0};
make_kernel_key_string(meta_list, buf);
KernelKey kernel_key = KernelKey(buf);
for (size_t idx = 0; idx < this->num_kernels_; idx++) {
if (strcmp(this->kernels_[idx].name_, name) == 0) {
if (this->kernels_[idx].kernel_key_.is_fallback() ||
this->kernels_[idx].kernel_key_ == kernel_key) {
return true;
}
}
}
return false;
}
/// Looks up operator `name` for the given tensor metadata in the global
/// registry. Forwards to OperatorRegistry::getOpsFn.
const OpFunction& getOpsFn(const char* name, ArrayRef<TensorMeta> kernel_key) {
  OperatorRegistry& registry = getOperatorRegistry();
  return registry.getOpsFn(name, kernel_key);
}
/**
 * Returns the kernel function registered under `name` whose kernel key equals
 * the key serialized from `meta_list`.
 *
 * If there is no exact match, returns the last kernel registered under `name`
 * that has a fallback key. If neither exists, the tensor metadata is logged
 * and the lookup fails via ET_CHECK_MSG.
 */
const OpFunction& OperatorRegistry::getOpsFn(
    const char* name,
    ArrayRef<TensorMeta> meta_list) {
  char buf[KernelKey::MAX_SIZE] = {0};
  make_kernel_key_string(meta_list, buf);
  KernelKey kernel_key = KernelKey(buf);
  // The fallback is only remembered, not returned immediately: an exact key
  // match found later in the registry takes precedence.
  const OpFunction* fallback_op = nullptr;
  for (size_t idx = 0; idx < this->num_kernels_; idx++) {
    if (strcmp(this->kernels_[idx].name_, name) == 0) {
      if (this->kernels_[idx].kernel_key_ == kernel_key) {
        return this->kernels_[idx].op_;
      }
      if (this->kernels_[idx].kernel_key_.is_fallback()) {
        fallback_op = &this->kernels_[idx].op_;
      }
    }
  }
  if (fallback_op != nullptr) {
    return *fallback_op;
  }
  // Log the metadata first. ET_CHECK_MSG(false, ...) is relied on not to
  // return (nothing is returned after it), so anything placed after it would
  // never run.
  ET_LOG_TENSOR_META(meta_list);
  ET_CHECK_MSG(false, "kernel '%s' not found.", name);
}
/// Returns the kernels registered in the global registry.
/// Forwards to OperatorRegistry::get_kernels.
ArrayRef<Kernel> get_kernels() {
  OperatorRegistry& registry = getOperatorRegistry();
  return registry.get_kernels();
}
/// Returns the first `num_kernels_` entries of `kernels_`, i.e. every kernel
/// registered so far, in registration order.
/// NOTE(review): ArrayRef is presumably a non-owning view over `kernels_`;
/// confirm before relying on it outliving the registry.
ArrayRef<Kernel> OperatorRegistry::get_kernels() {
  ArrayRef<Kernel> registered(this->kernels_, this->num_kernels_);
  return registered;
}
} // namespace runtime
} // namespace executorch