blob: 2acf36e33d86a5fda2ee6ae89cf1f08115773aec [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "kernels.h"
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cmath>
namespace impl {
namespace HiFi {
namespace native {
using Tensor = exec_aten::Tensor;
using RuntimeContext = torch::executor::RuntimeContext;
namespace linear_util {
// Returns the product tensor.size(0) * tensor.size(1) * ... * tensor.size(dim - 1),
// i.e. the number of elements spanned by the dimensions strictly before `dim`
// (`dim` itself is excluded). Returns 1 (the empty product) when `dim <= 0`.
// Precondition: `dim <= tensor.dim()`.
size_t getLeadingDims(const Tensor& tensor, int64_t dim) {
  size_t dims = 1;
  // Use a signed index so the comparison with the signed `dim` is well defined.
  // With a size_t index on platforms where size_t is 64 bits, a negative `dim`
  // (e.g. `tensor.dim() - 1` for a 0-d tensor) is converted to a huge unsigned
  // bound, and the loop reads sizes past the tensor's rank.
  for (int64_t i = 0; i < dim; ++i) {
    dims *= static_cast<size_t>(tensor.size(i));
  }
  return dims;
}
} // namespace linear_util
// Quantized linear layer on asymmetric uint8 tensors, computed with the HiFi
// nnlib-based matmul kernel.
//
// Expected shapes (only sanity-checked in debug builds):
//   src:    [leading_dims..., in_dim]  (all leading dims are flattened into one)
//   weight: [out_dim, in_dim]
//   bias:   int32 values, presumably one per output channel ([out_dim]) - confirm
//   out:    [leading_dims..., out_dim], preallocated by the caller
//
// Conceptually out = src * weight^T + bias. With M = leading_dims, N = in_dim
// and P = out_dim, this is an (M x N) * (N x P) matmul where the (N x P)
// operand is the transposed [P x N] weight.
//
// Zero points are forwarded negated as the kernel's zero-bias arguments.
// Requantization is driven by out_multiplier / out_shift / out_zero_point only.
// `src_scale` and `weight_scale` are not read here; presumably they are already
// folded into out_multiplier / out_shift (confirm with the caller).
//
// The kernel's status is checked with ET_DCHECK_MSG, i.e. only in debug builds.
void quantized_linear_pt2_out(
    RuntimeContext& ctx,
    const Tensor& src,
    const Tensor& weight,
    const Tensor& bias,
    double src_scale,
    int64_t src_zero_point,
    double weight_scale,
    int64_t weight_zero_point,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    int64_t out_zero_point,
    Tensor& out) {
  // These parameters are part of the operator schema but unused by this
  // implementation; silence unused-parameter warnings explicitly.
  (void)ctx;
  (void)src_scale;
  (void)weight_scale;

  const int64_t out_dim = weight.size(0);
  const int64_t in_dim = weight.size(1);

  // Debug-only sanity checks: a rank-0 src has no inner dimension, and a
  // mismatched inner dimension would make the kernel read the wrong number of
  // elements per input row.
  ET_DCHECK_MSG(
      src.dim() >= 1, "HiFi quantized::linear: src must have rank >= 1");
  ET_DCHECK_MSG(
      src.size(src.dim() - 1) == in_dim,
      "HiFi quantized::linear: src inner dim must match weight.size(1)");

  // Number of input rows (vectors): product of all dims of src except the last.
  const int64_t leading_dims =
      static_cast<int64_t>(linear_util::getLeadingDims(src, src.dim() - 1));

  const uint8_t* __restrict__ in_data = src.const_data_ptr<uint8_t>();
  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();

  // NOTE(review): the int64_t sizes / zero points below are implicitly
  // converted to whatever integer types the kernel wrapper declares
  // (presumably int32_t) - confirm against kernels.h.
  const int32_t ret = impl::HiFi::kernels::matmul_asym8uxasym8u_asym8u(
      out_data, // p_out
      weight_data, // p_mat1: weight, [out_dim x in_dim]
      in_data, // p_mat2: src rows, each of length in_dim
      bias_data, // p_bias
      out_dim, // rows of p_mat1
      in_dim, // cols of p_mat1
      in_dim, // row_stride of p_mat1
      leading_dims, // vec_count, i.e. number of src rows
      in_dim, // vec_offset: distance between consecutive src rows
      out_dim, // out_offset: distance between consecutive output rows
      1, // out_stride: distance between consecutive elements within a row
      -weight_zero_point, // mat1_zero_bias
      -src_zero_point, // mat2_zero_bias
      out_multiplier.const_data_ptr<int32_t>(), // out_multiplier
      out_shift.const_data_ptr<int32_t>(), // out_shift
      out_zero_point, // out_zero_bias
      false); // per-channel quantization disabled (per-tensor)
  ET_DCHECK_MSG(ret == 0, "HiFi quantized::linear failed");
  // ET_DCHECK_MSG expands to nothing under NDEBUG; keep `ret` referenced so
  // release builds do not emit an unused-variable warning.
  (void)ret;
}
}; // namespace native
}; // namespace HiFi
}; // namespace impl