blob: 6d6c77bd769ea8444cbce9be36670640dd016472 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstring>
#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/platform.h>
// Debug switch for operator registry
#if defined(ET_OP_REGISTRY_DEBUG)
#include <ostream>
#endif
#define ET_LOG_KERNEL_KEY(k) \
ET_LOG( \
Error, \
"key: %s, is_fallback: %s", \
k.data(), \
k.is_fallback() ? "true" : "false");
#define ET_LOG_TENSOR_META(meta_list) \
for (const auto& meta : meta_list) { \
ET_LOG(Error, "dtype: %d | dim order: [", int(meta.dtype_)); \
for (int i = 0; i < meta.dim_order_.size(); i++) { \
ET_LOG(Error, "%d,", static_cast<int32_t>(meta.dim_order_[i])); \
} \
ET_LOG(Error, "]"); \
}
namespace executorch {
namespace runtime {
class KernelRuntimeContext; // Forward declaration
using OpFunction = void (*)(KernelRuntimeContext&, EValue**);
/**
* Dtype and dim order metadata for a Tensor argument to an operator.
* Used by the Executor to hold the tensor metadata info and retrieve kernel.
*/
struct TensorMeta {
exec_aten::ScalarType dtype_;
Span<exec_aten::DimOrderType> dim_order_;
TensorMeta() = default;
TensorMeta(exec_aten::ScalarType dtype, Span<exec_aten::DimOrderType> order)
: dtype_(dtype), dim_order_(order) {}
bool operator==(const TensorMeta& other) const {
return this->equals(other);
}
bool operator!=(const TensorMeta& other) const {
return !this->equals(other);
}
bool equals(const TensorMeta& other) const {
if (dtype_ != other.dtype_) {
return false;
}
if (dim_order_.size() != other.dim_order_.size()) {
return false;
}
for (int i = 0; i < dim_order_.size(); i++) {
if (dim_order_[i] != other.dim_order_[i]) {
return false;
}
}
return true;
}
#if defined(ET_OP_REGISTRY_DEBUG)
friend std::ostream& operator<<(std::ostream& os, const TensorMeta& meta) {
os << "dtype: " << int(meta.dtype_) << " | dim order: [";
for (int i = 0; i < meta.dim_order_.size(); i++) {
os << static_cast<int32_t>(meta.dim_order_[i]) << ", ";
}
os << "]";
return os;
}
#endif
};
/**
* Describes which dtype & dim order specialized kernel to be bound to an
* operator. If `is_fallback_` is true, it means this kernel can be used as a
* fallback, if false, it means this kernel can only be used if all the
* `TensorMeta` are matched. Fallback means this kernel will be used for
* all input tensor dtypes and dim orders, if the specialized kernel is not
* registered.
*
* The format of a kernel key data is a string:
* "v<version>/<tensor_meta>|<tensor_meta>..."
* Size: Up to 691 1 1 1 (42 +1) * 16
* Assuming max number of tensors is 16 ^
* Kernel key version is v1 for now. If the kernel key format changes,
* update the version to avoid breaking pre-existing kernel keys.
* Example: v1/7;0,1,2,3
* The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3
*
* Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
* Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2
* for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example:
* 7;0,1,2,3 for [double; 0, 1, 2, 3]
*
* IMPORTANT:
* Users should not construct a kernel key manually. Instead, it should be
* generated from kernel yaml.
*/
struct KernelKey {
public:
KernelKey() : is_fallback_(true) {}
/* implicit */ KernelKey(const char* kernel_key_data)
: kernel_key_data_(kernel_key_data), is_fallback_(false) {}
constexpr static int MAX_SIZE = 691;
bool operator==(const KernelKey& other) const {
return this->equals(other);
}
bool operator!=(const KernelKey& other) const {
return !this->equals(other);
}
bool equals(const KernelKey& other) const {
if (is_fallback_ != other.is_fallback_) {
return false;
}
if (is_fallback_) {
return true;
}
return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0;
}
bool is_fallback() const {
return is_fallback_;
}
const char* data() const {
return kernel_key_data_;
}
#if defined(ET_OP_REGISTRY_DEBUG)
friend std::ostream& operator<<(std::ostream& os, const KernelKey& key) {
os << key.kernel_key_data_ << std::endl;
return os;
}
#endif
private:
const char* kernel_key_data_ = nullptr;
bool is_fallback_;
};
/**
* Struct that bundles a kernel key, a function and an op name together. An
* `Operator` may have more than one `Kernel` (maximum kMaxNumOfKernelPerOp) and
* they should have the same op name and different kernel key. A "fallback"
* kernel may or may not live in an `Operator`.
*/
struct Kernel {
const char* name_;
// String representation of kernel key, with the same format as
// KernelKey.to_string_representation()
// Data is not owned by the Kernel struct.
KernelKey kernel_key_;
OpFunction op_;
/**
* We are doing a copy of the string pointer instead of duplicating the string
* itself, we require the lifetime of the operator name to be at least as long
* as the operator registry.
*/
explicit Kernel(const char* name, OpFunction func) : name_(name), op_(func) {}
explicit Kernel(const char* name, KernelKey key, OpFunction func)
: name_(name), kernel_key_(key), op_(func) {}
Kernel() {}
};
namespace internal {
void make_kernel_key_string(Span<const TensorMeta> key, char* buf);
} // namespace internal
/**
* Checks whether an operator exists with a given name and TensorMeta list. When
* TensorMeta is empty, it means this op does not have specialized kernels, so
* it checks whether it has any fallback kernels.
*/
bool registry_has_op_function(
const char* name,
Span<const TensorMeta> meta_list = {});
/**
* Returns the operator with a given name and TensorMeta list, if present.
*/
::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
const char* name,
Span<const TensorMeta> meta_list = {});
/**
* Returns all registered kernels.
*/
Span<const Kernel> get_registered_kernels();
/**
* Registers the provided kernels.
*
* @param[in] kernels Kernel objects to register.
* @retval Error::Ok always. Panics on error. This function needs to return a
* non-void type to run at static initialization time.
*/
ET_NODISCARD Error register_kernels(const Span<const Kernel>);
/**
* Registers a single kernel.
*
* @param[in] kernel Kernel object to register.
* @retval Error::Ok always. Panics on error. This function needs to return a
* non-void type to run at static initialization time.
*/
ET_NODISCARD inline Error register_kernel(const Kernel& kernel) {
return register_kernels({&kernel, 1});
};
} // namespace runtime
} // namespace executorch
namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::Kernel;
using ::executorch::runtime::KernelKey;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::OpFunction;
using ::executorch::runtime::TensorMeta;
using KernelRuntimeContext = ::executorch::runtime::KernelRuntimeContext;
inline ::executorch::runtime::Error register_kernels(ArrayRef<Kernel> kernels) {
return ::executorch::runtime::register_kernels(
{kernels.data(), kernels.size()});
}
inline OpFunction getOpsFn(
const char* name,
ArrayRef<TensorMeta> meta_list = {}) {
auto result = ::executorch::runtime::get_op_function_from_registry(
name, {meta_list.data(), meta_list.size()});
ET_CHECK(result.ok()); // get_op_function_from_registry() logs details.
return *result;
}
inline bool hasOpsFn(const char* name, ArrayRef<TensorMeta> meta_list = {}) {
return ::executorch::runtime::registry_has_op_function(
name, {meta_list.data(), meta_list.size()});
}
inline ArrayRef<Kernel> get_kernels() {
Span<const Kernel> kernels = ::executorch::runtime::get_registered_kernels();
return ArrayRef<Kernel>(kernels.data(), kernels.size());
}
} // namespace executor
} // namespace torch