// blob: 94a56aa1cf0ff213619525792e5e8c94e75c11e7 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <arm_neon.h>
#include <xnnpack/common.h>
#include <xnnpack/dwconv.h>
// Depthwise-convolution microkernel for quantized uint8 data, written with
// NEON intrinsics. (The "up8x9" in the name presumably denotes 8 channels per
// step and 9 kernel taps, which is what the body does.)
//
// For every output pixel the kernel reads 9 pointers from `input` and, for each
// group of 8 channels:
//   1. loads 8 int32 values from `weights` as the initial accumulators,
//   2. loads the 9 groups of 8 uint8 kernel values that follow them, subtracts
//      the kernel zero point and multiply-accumulates them with the
//      zero-extended input bytes,
//   3. requantizes: saturating rounding doubling high multiply by `multiplier`,
//      rounding shift by `right_shift`, saturating narrow to int16, saturating
//      add of `output_zero_point`, saturating narrow to uint8, clamp to
//      [output_min, output_max],
//   4. stores 8 bytes (or `channels % 8` bytes for the tail group).
//
// Input bytes are only zero-extended here; no input zero point is subtracted in
// this function (presumably it is folded into the packed int32 values - TODO
// confirm against the weight-packing code).
//
// Parameters:
//   channels         - number of channels (bytes) processed per output pixel.
//   output_width     - number of output pixels; must be non-zero, because the
//                      generic loop below is a do/while that decrements first.
//   input            - array of input pointers; 9 are consumed per output pixel
//                      and each is read for `channels` bytes (rounded up to a
//                      multiple of 8, see the tail handling).
//   weights          - packed weights; per group of 8 channels: 8 int32 values
//                      followed by 9 x 8 uint8 kernel values.
//   output           - destination for the uint8 results.
//   input_stride     - number of bytes `input` is advanced per output pixel.
//   output_increment - number of bytes added to `output` after the `channels`
//                      bytes of each pixel have been written.
//   params           - quantization parameters; only the `neon` member is read.
void xnn_q8_dwconv_ukernel_up8x9__neon(
size_t channels,
size_t output_width,
const uint8_t** input,
const void* weights,
uint8_t* output,
size_t input_stride,
size_t output_increment,
const union xnn_q8_gemm_params params[restrict static 1])
{
// Broadcast every quantization parameter into all lanes of a vector once, up front.
const uint8x8_t vkernel_zero_point = vld1_dup_u8((const uint8_t*) &params->neon.kernel_zero_point);
const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);
#if XNN_ARCH_ARM64
// The larger number of registers on AArch64 makes it possible to process several pixels at a time.
// This path is taken only when consecutive output pixels are 3 pointers apart, so the
// 9-pointer windows of three neighbouring pixels overlap and fit into 15 pointers.
if (input_stride == 3 * sizeof(void*)) {
for (; output_width >= 3; output_width -= 3) {
// Pointer naming: iRC, where R (0..2) is the index within a group of three
// pointers and C (0..4) is the group. Pixel p (0..2) uses groups p..p+2.
const uint8_t* i00 = input[ 0];
const uint8_t* i10 = input[ 1];
const uint8_t* i20 = input[ 2];
const uint8_t* i01 = input[ 3];
const uint8_t* i11 = input[ 4];
const uint8_t* i21 = input[ 5];
const uint8_t* i02 = input[ 6];
const uint8_t* i12 = input[ 7];
const uint8_t* i22 = input[ 8];
const uint8_t* i03 = input[ 9];
const uint8_t* i13 = input[10];
const uint8_t* i23 = input[11];
const uint8_t* i04 = input[12];
const uint8_t* i14 = input[13];
const uint8_t* i24 = input[14];
// Each of the three output pixels occupies `channels` bytes followed by `output_increment` bytes.
uint8_t* output0 = output;
uint8_t* output1 = output0 + channels + output_increment;
uint8_t* output2 = output1 + channels + output_increment;
// Three pixels at 3 pointers each: advance by 9 pointers.
input += 9;
size_t c = channels;
const void* w = weights;
// Main channel loop: 8 channels of 3 pixels per iteration.
for (; c >= 8; c -= 8) {
// The 8 packed int32 values seed the accumulators of all three pixels.
int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
int32x4_t vacc1_lo = vacc0_lo;
int32x4_t vacc2_lo = vacc0_lo;
int32x4_t vacc1_hi = vacc0_hi;
int32x4_t vacc2_hi = vacc0_hi;
// Taps of group 0 (k00, k10, k20): pixel 0 uses input group 0, pixel 1 group 1, pixel 2 group 2.
const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi00 = vld1_u8(i00); i00 += 8;
const uint8x8_t vi01 = vld1_u8(i01); i01 += 8;
const uint8x8_t vi02 = vld1_u8(i02); i02 += 8;
// Kernel values become (k - kernel_zero_point) in 16 bits; inputs are zero-extended to 16 bits.
const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi10 = vld1_u8(i10); i10 += 8;
const uint8x8_t vi11 = vld1_u8(i11); i11 += 8;
const uint8x8_t vi12 = vld1_u8(i12); i12 += 8;
const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi20 = vld1_u8(i20); i20 += 8;
const uint8x8_t vi21 = vld1_u8(i21); i21 += 8;
const uint8x8_t vi22 = vld1_u8(i22); i22 += 8;
const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
// Taps of group 1 (k01, k11, k21): pixels 0 and 1 reuse the already widened inputs of
// groups 1 and 2; only group 3 (needed by pixel 2) is newly loaded.
const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi03 = vld1_u8(i03); i03 += 8;
const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
const uint8x8_t vk11 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi13 = vld1_u8(i13); i13 += 8;
const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi23 = vld1_u8(i23); i23 += 8;
const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
// Taps of group 2 (k02, k12, k22): only input group 4 (needed by pixel 2) is newly loaded.
const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi04 = vld1_u8(i04); i04 += 8;
const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi14 = vld1_u8(i14); i14 += 8;
const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi24 = vld1_u8(i24); i24 += 8;
const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
// Requantization step 1: saturating rounding doubling multiply, keeping the high 32 bits.
vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);
// Step 2: rounding fix-up. (vacc & ~mask) >> 31 is -1 for negative lanes and 0 otherwise,
// and is forced to 0 when right_shift == 0. Subtracting 1 from negative values makes the
// rounding shift below round exact ties away from zero instead of toward +infinity.
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
// Step 3: rounding shift. vrshlq_s32 shifts left by a signed per-lane amount, so it acts as
// a right shift only for non-positive values (presumably what params->neon.right_shift
// holds - TODO confirm against the code that fills the params).
vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
// Step 4: saturating narrow to int16 and saturating add of the output zero point.
const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
// Step 5: saturating narrow to uint8 and clamp to [output_min, output_max].
uint8x8_t vout0 = vqmovun_s16(vacc0);
uint8x8_t vout1 = vqmovun_s16(vacc1);
uint8x8_t vout2 = vqmovun_s16(vacc2);
vout0 = vmax_u8(vout0, voutput_min);
vout1 = vmax_u8(vout1, voutput_min);
vout2 = vmax_u8(vout2, voutput_min);
vout0 = vmin_u8(vout0, voutput_max);
vout1 = vmin_u8(vout1, voutput_max);
vout2 = vmin_u8(vout2, voutput_max);
vst1_u8(output0, vout0); output0 += 8;
vst1_u8(output1, vout1); output1 += 8;
vst1_u8(output2, vout2); output2 += 8;
}
// Channel tail (1..7 channels left) for the three pixels. Same computation as above, but
// NOTE: full 8-byte vectors are still loaded from the inputs and from `weights`, so those
// buffers must be readable past the last valid channel. Only `c` bytes are stored.
if (c != 0) {
int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
int32x4_t vacc1_lo = vacc0_lo;
int32x4_t vacc2_lo = vacc0_lo;
int32x4_t vacc1_hi = vacc0_hi;
int32x4_t vacc2_hi = vacc0_hi;
const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi00 = vld1_u8(i00);
const uint8x8_t vi01 = vld1_u8(i01);
const uint8x8_t vi02 = vld1_u8(i02);
const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi10 = vld1_u8(i10);
const uint8x8_t vi11 = vld1_u8(i11);
const uint8x8_t vi12 = vld1_u8(i12);
const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi20 = vld1_u8(i20);
const uint8x8_t vi21 = vld1_u8(i21);
const uint8x8_t vi22 = vld1_u8(i22);
const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi03 = vld1_u8(i03);
const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
const uint8x8_t vk11 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi13 = vld1_u8(i13);
const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi23 = vld1_u8(i23);
const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi04 = vld1_u8(i04);
const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi14 = vld1_u8(i14);
const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi24 = vld1_u8(i24);
const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
// Requantization, identical to the main channel loop above.
vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
uint8x8_t vout0 = vqmovun_s16(vacc0);
uint8x8_t vout1 = vqmovun_s16(vacc1);
uint8x8_t vout2 = vqmovun_s16(vacc2);
vout0 = vmax_u8(vout0, voutput_min);
vout1 = vmax_u8(vout1, voutput_min);
vout2 = vmax_u8(vout2, voutput_min);
vout0 = vmin_u8(vout0, voutput_max);
vout1 = vmin_u8(vout1, voutput_max);
vout2 = vmin_u8(vout2, voutput_max);
// Partial store: write 4, 2 and/or 1 bytes according to the bits of c. After each store
// vext rotates the vector so the next unwritten byte sits in lane 0.
// __builtin_assume_aligned(p, 1) states that only byte alignment may be assumed.
if (c & 4) {
vst1_lane_u32(__builtin_assume_aligned(output0, 1), vreinterpret_u32_u8(vout0), 0); output0 += 4;
vst1_lane_u32(__builtin_assume_aligned(output1, 1), vreinterpret_u32_u8(vout1), 0); output1 += 4;
vst1_lane_u32(__builtin_assume_aligned(output2, 1), vreinterpret_u32_u8(vout2), 0); output2 += 4;
vout0 = vext_u8(vout0, vout0, 4);
vout1 = vext_u8(vout1, vout1, 4);
vout2 = vext_u8(vout2, vout2, 4);
}
if (c & 2) {
vst1_lane_u16(__builtin_assume_aligned(output0, 1), vreinterpret_u16_u8(vout0), 0); output0 += 2;
vst1_lane_u16(__builtin_assume_aligned(output1, 1), vreinterpret_u16_u8(vout1), 0); output1 += 2;
vst1_lane_u16(__builtin_assume_aligned(output2, 1), vreinterpret_u16_u8(vout2), 0); output2 += 2;
vout0 = vext_u8(vout0, vout0, 2);
vout1 = vext_u8(vout1, vout1, 2);
vout2 = vext_u8(vout2, vout2, 2);
}
if (c & 1) {
vst1_lane_u8(__builtin_assume_aligned(output0, 1), vout0, 0); output0++;
vst1_lane_u8(__builtin_assume_aligned(output1, 1), vout1, 0); output1++;
vst1_lane_u8(__builtin_assume_aligned(output2, 1), vout2, 0); output2++;
}
}
// output2 now points just past the third pixel's channels; skip its trailing increment.
output = (uint8_t*) ((uintptr_t) output2 + output_increment);
}
// All pixels handled by the 3-pixel loop: nothing is left for the do/while below,
// which would otherwise run once more with output_width == 0.
if (output_width == 0) {
return;
}
}
#endif // XNN_ARCH_ARM64
// Generic path: one output pixel per iteration. Also handles the 1-2 pixels left over by
// the AArch64 path above. Executes at least once, so output_width must be non-zero here.
do {
const uint8_t* i0 = input[0];
const uint8_t* i1 = input[1];
const uint8_t* i2 = input[2];
const uint8_t* i3 = input[3];
const uint8_t* i4 = input[4];
const uint8_t* i5 = input[5];
const uint8_t* i6 = input[6];
const uint8_t* i7 = input[7];
const uint8_t* i8 = input[8];
// input_stride is a byte count, hence the integer arithmetic on the pointer.
input = (const uint8_t**) ((uintptr_t) input + input_stride);
size_t c = channels;
const void* w = weights;
// Main channel loop: 8 channels per iteration. Two independent accumulator sets are used:
// vaccX1 starts from the packed int32 values and takes the odd taps (1, 3, 5, 7);
// vaccX0 starts from the product of tap 0 and takes the even taps (2, 4, 6, 8).
// They are summed before requantization.
for (; c >= 8; c -= 8) {
int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
const uint8x8_t vk8 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;
const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
// Merge the two partial sums.
int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
// Requantization: multiply by the multiplier (saturating, rounding, doubling, high half).
vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);
// Subtract 1 from negative lanes unless right_shift == 0, so that the rounding shift
// below rounds exact ties away from zero.
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);
vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
vacc_hi = vrshlq_s32(vacc_hi, vright_shift);
// Saturating narrow to int16 plus output zero point; AArch64 has a narrow-into-high-half
// instruction, other targets narrow both halves and combine them.
#if XNN_ARCH_ARM64
const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif
// Saturating narrow to uint8 and clamp to [output_min, output_max].
uint8x8_t vout = vqmovun_s16(vacc);
vout = vmax_u8(vout, voutput_min);
vout = vmin_u8(vout, voutput_max);
vst1_u8(output, vout); output += 8;
}
// Channel tail (1..7 channels left). NOTE: full 8-byte vectors are still loaded from the
// inputs and from `weights`, so those buffers must be readable past the last valid
// channel. Pointers are not advanced because nothing follows for this pixel.
if (c != 0) {
int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi0 = vld1_u8(i0);
const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi1 = vld1_u8(i1);
const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi2 = vld1_u8(i2);
const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi3 = vld1_u8(i3);
const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi4 = vld1_u8(i4);
const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi5 = vld1_u8(i5);
const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi6 = vld1_u8(i6);
const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
const uint8x8_t vi7 = vld1_u8(i7);
const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
// Last weight load for this pixel: w is not advanced because it is not used again.
const uint8x8_t vk8 = vld1_u8(w);
const uint8x8_t vi8 = vld1_u8(i8);
const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
// Requantization, identical to the main channel loop above.
vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);
vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
vacc_hi = vrshlq_s32(vacc_hi, vright_shift);
#if XNN_ARCH_ARM64
const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif
uint8x8_t vout = vqmovun_s16(vacc);
vout = vmax_u8(vout, voutput_min);
vout = vmin_u8(vout, voutput_max);
// Partial store: write 4, 2 and/or 1 bytes according to the bits of c, rotating the
// vector with vext after each store so the next unwritten byte sits in lane 0.
if (c & 4) {
vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
vout = vext_u8(vout, vout, 4);
}
if (c & 2) {
vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
vout = vext_u8(vout, vout, 2);
}
if (c & 1) {
vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0); output++;
}
}
// Skip output_increment bytes to reach the start of the next output pixel.
output = (uint8_t*) ((uintptr_t) output + output_increment);
} while (--output_width != 0);
}