| /* |
| * Copyright (c) Meta Platforms, Inc. and affiliates. |
| * All rights reserved. |
| * |
| * This source code is licensed under the BSD-style license found in the |
| * LICENSE file in the root directory of this source tree. |
| */ |
| |
| #include <executorch/backends/cadence/reference/kernels/kernels.h> |
| #include <math.h> |
| #include <algorithm> |
| #include <cstring> |
| #include <limits> |
| #include <numeric> |
| |
| namespace impl { |
| namespace reference { |
| namespace kernels { |
| |
| // Quantize a fp32 value to an int8_t/uint8_t value |
| template <typename T> |
| T quantize(const float x, float scale, int32_t zero_point) { |
| constexpr float min_val = std::numeric_limits<T>::min(); |
| constexpr float max_val = std::numeric_limits<T>::max(); |
| float tmp = roundf(x * scale + zero_point); |
| return std::max(std::min(tmp, max_val), min_val); |
| } |
| |
| // Quantize an fp32 array to an int8_t/uint8_t array |
| template <typename T> |
| void quantize( |
| T* __restrict__ y, |
| const float* __restrict__ x, |
| float inv_scale, |
| int32_t zero_point, |
| size_t size) { |
| for (size_t i = 0; i < size; ++i) { |
| y[i] = quantize<T>(x[i], inv_scale, zero_point); |
| } |
| } |
| |
| // Dequantize an int8_t/uint8_t value to an fp32 value |
| template <typename T> |
| float dequantize(const T x, float scale, int32_t zero_point) { |
| return scale * (x - zero_point); |
| } |
| |
| // Dequantize an int8_t/uint8_t/int16_t array to an fp32 array |
| template <typename T> |
| void dequantize( |
| float* __restrict__ y, |
| const T* __restrict__ x, |
| float scale, |
| int32_t zero_point, |
| size_t size) { |
| for (size_t i = 0; i < size; ++i) { |
| y[i] = dequantize<T>(x[i], scale, zero_point); |
| } |
| } |
| |
| // explicit template instantiation |
| |
| #define typed_quantize_val(dtype) \ |
| template dtype quantize(const float x, float inv_scale, int32_t zero_point); |
| typed_quantize_val(int8_t); |
| typed_quantize_val(uint8_t); |
| typed_quantize_val(int16_t); |
| typed_quantize_val(uint16_t); |
| typed_quantize_val(int32_t); |
| #undef typed_quantize_val |
| |
| #define typed_quantize_vec(dtype) \ |
| template void quantize( \ |
| dtype* __restrict__ y, \ |
| const float* __restrict__ x, \ |
| float inv_scale, \ |
| int32_t zero_point, \ |
| size_t size); |
| typed_quantize_vec(int8_t); |
| typed_quantize_vec(uint8_t); |
| typed_quantize_vec(int16_t); |
| typed_quantize_vec(uint16_t); |
| typed_quantize_vec(int32_t); |
| #undef typed_quantize_vec |
| |
| #define typed_dequantize_val(dtype) \ |
| template float dequantize(const dtype x, float scale, int32_t zero_point); |
| typed_dequantize_val(int8_t); |
| typed_dequantize_val(uint8_t); |
| typed_dequantize_val(int16_t); |
| typed_dequantize_val(uint16_t); |
| typed_dequantize_val(int32_t); |
| #undef typed_dequantize_val |
| |
| #define typed_dequantize_vec(dtype) \ |
| template void dequantize( \ |
| float* __restrict__ y, \ |
| const dtype* __restrict__ x, \ |
| float scale, \ |
| int32_t zero_point, \ |
| size_t size); |
| typed_dequantize_vec(int8_t); |
| typed_dequantize_vec(uint8_t); |
| typed_dequantize_vec(int16_t); |
| typed_dequantize_vec(uint16_t); |
| typed_dequantize_vec(int32_t); |
| #undef typed_dequantize_vec |
| |
| }; // namespace kernels |
| }; // namespace reference |
| }; // namespace impl |