blob: a987918744d65ab8675fc8ff44d1e33048161f37 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstring>
#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/function_ref.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/platform.h>
// Debug switch for operator registry
#if defined(ET_OP_REGISTRY_DEBUG)
#include <ostream>
#endif
namespace torch {
namespace executor {
// Only referenced by reference below, so a forward declaration suffices.
class KernelRuntimeContext;

// Signature stored by every Kernel: a runtime context plus an array of EValue
// pointers (presumably the kernel's arguments/outputs — confirm with callers).
using OpFunction = FunctionRef<void(KernelRuntimeContext&, EValue**)>;

// Legacy alias for KernelRuntimeContext. TODO(T147221312): Remove
using RuntimeContext = KernelRuntimeContext;
/**
* Dtype and dim order metadata for a Tensor argument to an operator.
* Used by the Executor to hold the tensor metadata info.
*/
struct TensorMeta {
exec_aten::ScalarType dtype_;
ArrayRef<exec_aten::DimOrderType> dim_order_;
TensorMeta() = default;
TensorMeta(ScalarType dtype, ArrayRef<exec_aten::DimOrderType> order)
: dtype_(dtype), dim_order_(order) {}
bool operator==(const TensorMeta& other) const {
return this->equals(other);
}
bool operator!=(const TensorMeta& other) const {
return !this->equals(other);
}
bool equals(const TensorMeta& other) const {
if (dtype_ != other.dtype_) {
return false;
}
if (dim_order_.size() != other.dim_order_.size()) {
return false;
}
for (int i = 0; i < dim_order_.size(); i++) {
if (dim_order_[i] != other.dim_order_[i]) {
return false;
}
}
return true;
}
#if defined(ET_OP_REGISTRY_DEBUG)
friend std::ostream& operator<<(std::ostream& os, const TensorMeta& meta) {
os << "dtype: " << int(meta.dtype_) << " | dim order: [";
for (int i = 0; i < meta.dim_order_.size(); i++) {
os << static_cast<int32_t>(meta.dim_order_[i]) << ", ";
}
os << "]";
return os;
}
#endif
};
/**
 * Describes which dtype- and dim-order-specialized kernel should be bound to
 * an operator. If `is_fallback_` is true, this kernel can be used as a
 * fallback; if false, this kernel can only be used when all of the
 * `TensorMeta` entries match. A fallback kernel is used for any input tensor
 * dtypes and dim orders for which no specialized kernel is registered.
*
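 * The kernel key data is a string of the form:
 *   "v<version>/<tensor_meta>|<tensor_meta>...\xff"
 * Size: up to 307 bytes = 1 ('v') + 1 (<version>) + 1 ('/') + (18 + 1) * 16,
 * assuming at most 16 tensors. Each tensor contributes a tensor_meta of up
 * to 18 bytes plus 1 byte for the following '|' separator or '\xff'
 * terminator.
 * The version is v0 for now.
 * Example: v0/0x07;0x00 0x01 0x02 0x03 \xff
 * This kernel key describes a single tensor: a double tensor with dim order
 * 0, 1, 2, 3.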
*
* The string is a byte array and contains non-printable characters. It must
* be terminated with a '\xff' so 0xff cannot be a scalar type.
*
 * Each tensor_meta has the format "<dtype>;<dim_order...>".
 * Size: up to 18 bytes = 1 (<dtype>) + 1 (';') + 16 (<dim_order...>),
 * assuming at most 16 dims, one byte each.
 * Example: 0x07;0x00 0x01 0x02 0x03 for [double; 0, 1, 2, 3]
*
* IMPORTANT:
* Users should not construct a kernel key manually. Instead, it should be
* generated from kernel yaml.
*/
struct KernelKey {
 public:
  /**
   * Creates a fallback key. A fallback key holds no key data and compares
   * equal only to other fallback keys.
   */
  KernelKey() : is_fallback_(true) {}

  /**
   * Creates a specialized key referring to `kernel_key_data`, which must be
   * terminated by TERMINATOR. The data is not copied, so it must outlive this
   * key.
   */
  /* implicit */ KernelKey(const char* kernel_key_data)
      : kernel_key_data_(kernel_key_data), is_fallback_(false) {}

  // Byte that ends the key data. 0xff does not fit in a signed `char`; the
  // explicit cast makes the intended conversion visible (and avoids a
  // -Wconstant-conversion warning) while yielding the same value as before.
  constexpr static char TERMINATOR = static_cast<char>(0xff);

  bool operator==(const KernelKey& other) const {
    return this->equals(other);
  }

  bool operator!=(const KernelKey& other) const {
    return !this->equals(other);
  }

  /**
   * Two fallback keys are equal and a fallback key never equals a specialized
   * key. Two specialized keys are equal iff their data matches byte-for-byte
   * up to and including TERMINATOR.
   */
  bool equals(const KernelKey& other) const {
    if (is_fallback_ != other.is_fallback_) {
      return false;
    }
    if (is_fallback_) {
      return true;
    }
    // The same buffer (including two null pointers) trivially matches.
    if (kernel_key_data_ == other.kernel_key_data_) {
      return true;
    }
    // Never dereference a null key; null only matches null (handled above).
    if (kernel_key_data_ == nullptr || other.kernel_key_data_ == nullptr) {
      return false;
    }
    size_t i = 0;
    for (; kernel_key_data_[i] != TERMINATOR &&
         other.kernel_key_data_[i] != TERMINATOR;
         i++) {
      if (kernel_key_data_[i] != other.kernel_key_data_[i]) {
        return false;
      }
    }
    // Equal only if both strings end at the same position.
    return kernel_key_data_[i] == TERMINATOR &&
        other.kernel_key_data_[i] == TERMINATOR;
  }

  bool is_fallback() const {
    return is_fallback_;
  }

#if defined(ET_OP_REGISTRY_DEBUG)
  friend std::ostream& operator<<(std::ostream& os, const KernelKey& key) {
    if (key.kernel_key_data_ == nullptr) {
      // Streaming a null `const char*` is undefined behavior, and fallback
      // keys always hold nullptr, so print a placeholder instead.
      os << (key.is_fallback_ ? "<fallback>" : "<null>") << std::endl;
    } else {
      os << key.kernel_key_data_ << std::endl;
    }
    return os;
  }
#endif

 private:
  // Not owned; nullptr for fallback keys.
  const char* kernel_key_data_ = nullptr;
  bool is_fallback_;
};
/**
* Struct that bundles a kernel key, a function and an op name together. An
* `Operator` may have more than one `Kernel` (maximum kMaxNumOfKernelPerOp) and
* they should have the same op name and different kernel key. A "fallback"
* kernel may or may not live in an `Operator`.
*/
struct Kernel {
  // Operator name. Not owned; see the constructor comment for lifetime rules.
  // Defaults to nullptr so a default-constructed Kernel has no dangling name.
  const char* name_ = nullptr;
  // Dtype / dim order key this kernel is specialized for. A default-
  // constructed KernelKey is a fallback key, meaning this kernel is used for
  // inputs that have no specialized kernel. The key data is not owned by the
  // Kernel struct.
  KernelKey kernel_key_;
  // The kernel implementation.
  OpFunction op_;

  /**
   * We are doing a copy of the string pointer instead of duplicating the string
   * itself, we require the lifetime of the operator name to be at least as long
   * as the operator registry.
   */
  // Creates a kernel with a fallback (default-constructed) key.
  explicit Kernel(const char* name, OpFunction func) : name_(name), op_(func) {}

  // Creates a kernel bound to a specific kernel key.
  explicit Kernel(const char* name, KernelKey key, OpFunction func)
      : name_(name), kernel_key_(key), op_(func) {}

  Kernel() {}
};
// Capacity limits for the fixed-size operator registry.

// Maximum number of operators (sizes OperatorRegistry's operator table).
constexpr uint32_t kOperatorTableMaxSize = 250;
// Maximum number of kernels, each with a distinct KernelKey, per operator.
constexpr uint32_t kMaxNumOfKernelPerOp = 8;
// Upper bound on kernels across all operators: per-op limit times table size.
constexpr uint32_t kMaxNumOfKernels =
    kMaxNumOfKernelPerOp * kOperatorTableMaxSize;
/**
* Struct that represents an operator at runtime. This object and the `Operator`
* field in the program should be 1-to-1 mapping. During static initialization,
* all kernels will be registered from the generated C++ code. Then during the
* kernel resolution step in runtime initialization, the target kernel will be
* looked up and stored along with `Chain`.
*/
struct Operator {
 public:
  // Operator name (not owned). nullptr for a default-constructed Operator.
  const char* name_ = nullptr;

  // Creates an operator with no kernels.
  explicit Operator(const char* name) : name_(name), num_kernels_(0) {}

  // Creates an operator whose single kernel is bound to `key`.
  explicit Operator(const char* name, KernelKey key, OpFunction func)
      : name_(name), num_kernels_(1) {
    kernels_[0] = Kernel(name, key, func);
  }

  // Creates an operator whose single kernel uses a fallback key.
  explicit Operator(const char* name, OpFunction func)
      : name_(name), num_kernels_(1) {
    kernels_[0] = Kernel(name, {}, func);
  }

  // Default member initializers make this an operator with no name and no
  // kernels, so num_kernels_ is never read uninitialized.
  Operator() {}

  // Returns true if any registered kernel's key equals `key`.
  bool contains(KernelKey key) const {
    for (uint32_t i = 0; i < num_kernels_; i++) {
      if (kernels_[i].kernel_key_ == key) {
        return true;
      }
    }
    return false;
  }

  /**
   * Returns the function of the first kernel whose key equals `key`.
   * Otherwise, returns the function of the last-registered fallback kernel.
   * If neither exists, fails via ET_CHECK_MSG (presumably aborting; the macro
   * is defined elsewhere).
   */
  const OpFunction& find_or_fallback(KernelKey key) const {
    int32_t fallback_index = -1;
    for (uint32_t i = 0; i < num_kernels_; i++) {
      if (kernels_[i].kernel_key_ == key) {
        return kernels_[i].op_;
      }
      if (kernels_[i].kernel_key_.is_fallback()) {
        // i < kMaxNumOfKernelPerOp, so the cast cannot overflow.
        fallback_index = static_cast<int32_t>(i);
      }
    }
    if (fallback_index != -1) {
      return kernels_[fallback_index].op_;
    }
    ET_CHECK_MSG(false, "kernel key not found.");
  }

  // Returns true if a kernel with a fallback (default-constructed) key exists.
  bool has_fallback() const {
    return contains({});
  }

  /**
   * Appends `kernel`. Returns false, leaving the operator unchanged, when
   * kMaxNumOfKernelPerOp kernels are already registered. Duplicate keys are
   * not rejected here.
   */
  bool register_kernel(Kernel kernel) {
    if (num_kernels_ >= kMaxNumOfKernelPerOp) {
      return false;
    }
    kernels_[num_kernels_++] = kernel;
    return true;
  }

 private:
  Kernel kernels_[kMaxNumOfKernelPerOp];
  // Number of leading entries of kernels_ that are in use.
  uint32_t num_kernels_ = 0;
};
/**
 * Free-function counterpart of OperatorRegistry::hasOpsFn(); see that method
 * for semantics.
 */
bool hasOpsFn(const char* name, ArrayRef<TensorMeta> meta_list = {});

/**
 * Free-function counterpart of OperatorRegistry::getOpsFn(); see that method
 * for semantics.
 */
const OpFunction& getOpsFn(const char* name,
                           ArrayRef<TensorMeta> meta_list = {});

/**
 * Free-function counterpart of OperatorRegistry::getOpsArray().
 */
ArrayRef<Operator> getOpsArray();

/**
 * DEPRECATED: Use register_kernels() instead.
 * Free-function counterpart of OperatorRegistry::register_operators(). The
 * returned Error is expected to be handled internally; a value is still
 * returned so that the call can be made during static initialization.
 */
__ET_NODISCARD Error register_operators(const ArrayRef<Operator>&);

/**
 * Free-function counterpart of OperatorRegistry::register_kernels(). The
 * returned Error is expected to be handled internally; a value is still
 * returned so that the call can be made during static initialization.
 */
__ET_NODISCARD Error register_kernels(const ArrayRef<Kernel>&);
struct OperatorRegistry {
 public:
  // operatorRegSize_ starts at zero via its default member initializer.
  OperatorRegistry() {}

  /**
   * DEPRECATED: Use register_kernels() instead. TODO: (larryliu) Remove.
   * Registers Operator objects, each of which may hold one or more function
   * references, so they can be looked up by name at runtime.
   * WARNING: only use this when you are confident that no Operator names are
   * duplicated.
   * @param[in] operators Operator objects to register.
   * @retval Error code indicating whether registration succeeded.
   */
  __ET_NODISCARD Error register_operators(const ArrayRef<Operator>&);

  /**
   * Registers Kernel objects (each pairing a name with a function reference).
   * Kernels are merged into Operators based on their op name.
   *
   * @param[in] kernels Kernel objects to register.
   * @retval Error code indicating whether registration succeeded.
   */
  __ET_NODISCARD Error register_kernels(const ArrayRef<Kernel>&);

  /**
   * Checks whether an operator with the given name and TensorMeta list is
   * registered. An empty TensorMeta list means the op has no specialized
   * kernels, so this checks whether it has a fallback kernel.
   */
  bool hasOpsFn(const char* name, ArrayRef<TensorMeta> meta_list);

  /**
   * Returns the function for the operator with the given name and TensorMeta
   * list.
   */
  const OpFunction& getOpsFn(const char* name, ArrayRef<TensorMeta> meta_list);

  /**
   * Returns all registered operators.
   */
  ArrayRef<Operator> getOpsArray();

 private:
  // Fixed-capacity operator storage (kOperatorTableMaxSize entries).
  Operator operators_table_[kOperatorTableMaxSize];
  // Presumably the number of populated entries in operators_table_; the
  // registration logic lives elsewhere, so confirm there.
  uint32_t operatorRegSize_ = 0;
};
} // namespace executor
} // namespace torch