blob: 17ce80b68a14f8e42566412f11e707c3e2451eba [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <assert.h>
#include <xnnpack/gemm.h>
#include <xnnpack/scalar-utils.h>
// This kernel is a scalar model for a kernel using ARMv8.2 dot-product
// instructions.
//
// XNN_DISABLE_TSAN is used because this kernel reads up to 3 bytes past the
// bounds of the `a` matrix region, which may be a race condition with
// another thread. We deem this acceptable because the values that are
// read out of bounds do not affect the result, and the compiler can't know
// about this undefined behavior.
//
// Computes up to ${MR} rows of quantized uint8 output, ${NR} columns at a time:
// int32 accumulation of uint8 activations times (uint8 weights minus the kernel
// zero point), starting from a per-column int32 bias, followed by fixed-point
// requantization, clamping, and addition of the output zero point.
//
// Arguments:
//   mr        - number of rows of `a`/`c` to compute; must be in [1, ${MR}].
//   nc        - number of output columns to compute; must be non-zero.
//   kc        - reduction length in bytes per row of `a`; must be non-zero. It is
//               consumed 4 bytes at a time, so it is effectively rounded up to a
//               multiple of 4.
//   a         - first row of uint8 activations.
//   a_stride  - distance in bytes between consecutive rows of `a`.
//   w         - packed weights, read sequentially. For each group of ${NR}
//               columns: ${NR} int32 bias values, then for every 4-byte step
//               along kc, ${NR} runs of 4 uint8 weights (one run per column).
//   c         - first row of uint8 output.
//   cm_stride - distance in bytes between consecutive rows of `c`.
//   cn_stride - amount in bytes by which each row pointer of `c` is advanced
//               after a full group of ${NR} columns has been stored.
//   params    - quantization parameters; only the `scalar` member is read.
void xnn_qu8_gemm_minmax_ukernel_${MR}x${NR}c4__scalar(
size_t mr,
size_t nc,
size_t kc,
const uint8_t* restrict a,
size_t a_stride,
const void* restrict w,
uint8_t* restrict c,
size_t cm_stride,
size_t cn_stride,
const union xnn_qu8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN {
assert(mr != 0);
assert(mr <= ${MR});
assert(nc != 0);
assert(kc != 0);
const uint8_t* a0 = a;
uint8_t* c0 = c;
// Set up one input and one output pointer per row. A row whose index is not
// below `mr` gets the previous row's pointers instead, so it recomputes and
// rewrites that row rather than touching memory outside the `mr` valid rows.
// Given `mr <= ${MR}`, the three conditions below (`mr <= M`, `mr != M+1` for
// the last row, `mr < M+1`) all express the same test: "row M is not valid".
$for M in range(1, MR):
const uint8_t* a${M} = (const uint8_t*) ((uintptr_t) a${M-1} + a_stride);
uint8_t* c${M} = (uint8_t*) ((uintptr_t) c${M-1} + cm_stride);
$if M % 2 == 0:
if XNN_UNPREDICTABLE(mr <= ${M}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
$elif M + 1 == MR:
if XNN_UNPREDICTABLE(mr != ${M+1}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
$else:
if XNN_UNPREDICTABLE(mr < ${M+1}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
// Zero point of the weights; it is removed from the products inside the
// accumulation loop via the per-row activation sums.
const int32_t vb_zero_point = params->scalar.kernel_zero_point;
// Loop over groups of ${NR} columns.
do {
// `vaccMN` is the accumulator at row `M`, column `N`.
// Initialize accumulators with bias. ${NR} bias values are loaded from the
// weight matrix, at the start of the group of ${NR} columns.
$for N in range(NR):
int32_t bias${N} = ((const int32_t*)w)[${N}];
$for M in range(MR):
int32_t vacc${M}${N} = bias${N};
w = (const void*) ((uintptr_t) w + ${NR} * sizeof(int32_t));
// Inner accumulation loop over the reduction dimension for the current
// ${NR} columns.
// Handle 4 reduction elements (4 rows of the weight block) at each iteration:
// this is key to modelling what an actual kernel using ARMv8.2 dot-product
// instructions would look like.
size_t k = 0;
while (k < kc) {
// Load a ${MR}x4 block of activations, and compute sums along rows.
// Each row sum is at most 4 * 255 = 1020, so it fits in int16_t.
$for M in range(MR):
int16_t vasum${M} = 0;
$for K in range(4):
int32_t va${M}${K} = *a${M}++;
vasum${M} += (int16_t) va${M}${K};
// Load a 4x${NR} block of weights.
$for N in range(NR):
$for K in range(4):
int32_t vb${K}${N} = (int32_t) ((const uint8_t*)w)[${K}];
w = (const void*) ((uintptr_t) w + 4 * sizeof(uint8_t));
// Multiply-accumulate: ${MR}x4 * 4x${NR} --> ${MR}x${NR}. The inner size 4 here means
// we're computing 4D dot-products, which makes this a model for
// a ARMv8.2 dot-product kernel.
// Subtracting (row sum of activations) * (weight zero point) afterwards makes
// the net contribution of each 4-element step equal to
// sum(activation * (weight - weight zero point)).
$for M in range(MR):
$for N in range(NR):
$for K in range(4):
vacc${M}${N} += va${M}${K} * vb${K}${N};
vacc${M}${N} -= ((int32_t) vasum${M}) * vb_zero_point;
// Always advance by a full group of 4, even when fewer than 4 bytes of `kc`
// remain; this is the source of the read past the end of `a` described in
// the comment above the function.
k += 4 * sizeof(uint8_t);
}
// End of accumulation loop. The variable `k` contains the amount by which
// we advanced the row pointers into `a`, so we rewind by this amount now so
// that the next group of columns starts from the beginning of each row again.
$for M in range(MR):
a${M} = (const uint8_t*)((uintptr_t) a${M} - k);
// Post-accumulation work: requantize the int32 accumulators to uint8.
// NOTE(review): these values do not change between column groups; they are
// re-read from `params` on every pass of the outer loop.
const int32_t vmultiplier = params->scalar.multiplier;
// 2^30, i.e. one half in units of the 31-bit right shift applied below.
const int64_t vq31rounding = INT64_C(0x40000000);
const int32_t vremainder_mask = params->scalar.remainder_mask;
const uint32_t vshift = params->scalar.shift;
const int32_t vremainder_threshold = params->scalar.remainder_threshold;
const int32_t voutput_min = params->scalar.output_min_less_zero_point;
const int32_t voutput_max = params->scalar.output_max_less_zero_point;
const int32_t voutput_zero_point = params->scalar.output_zero_point;
// Step 1: widen to 64 bits and multiply by the fixed-point multiplier.
$for M in range(MR):
$for N in range(NR):
const int64_t vproduct${M}${N} = (int64_t)vacc${M}${N} * (int64_t)vmultiplier;
// Step 2: add the rounding constant and shift right by 31 bits. The shift is
// done on the unsigned representation and the low 32 bits are kept.
$for M in range(MR):
$for N in range(NR):
const int32_t vq31product${M}${N} = (int32_t)(uint32_t)((uint64_t)(vproduct${M}${N} + vq31rounding) >> 31);
// Step 3: extract the bits selected by the remainder mask, decremented by one
// for negative products. Presumably the mask covers the bits discarded by the
// shift in step 4 -- confirm against the code that fills `params`.
$for M in range(MR):
$for N in range(NR):
const int32_t vremainder${M}${N} = (vq31product${M}${N} & vremainder_mask) - (int32_t)(vq31product${M}${N} < 0);
// Step 4: shift right by `vshift` and add one when the remainder exceeds the
// threshold. `asr_s32` is declared outside this file and is assumed to be an
// arithmetic (sign-preserving) shift right -- confirm.
$for M in range(MR):
$for N in range(NR):
int32_t vout${M}${N} = asr_s32(vq31product${M}${N}, vshift) + (int32_t)(vremainder${M}${N} > vremainder_threshold);
// Step 5: clamp from below. The bound already has the output zero point
// subtracted (per its field name), so clamping happens before step 7.
$for M in range(MR):
$for N in range(NR):
vout${M}${N} = vout${M}${N} < voutput_min ? voutput_min : vout${M}${N};
// Step 6: clamp from above, with the same convention as step 5.
$for M in range(MR):
$for N in range(NR):
vout${M}${N} = vout${M}${N} > voutput_max ? voutput_max : vout${M}${N};
// Step 7: add the output zero point.
$for M in range(MR):
$for N in range(NR):
vout${M}${N} += voutput_zero_point;
if XNN_LIKELY (nc >= ${NR}) {
// Main case where all ${NR} columns fit in the destination.
// Each int32_t result is implicitly converted to uint8_t by the assignment.
$for M in range(MR):
$for N in range(NR):
c${M}[${N}] = vout${M}${N};
// Advance to the next ${NR} columns.
$for M in range(MR):
c${M} = (uint8_t*)((uintptr_t) c${M} + cn_stride);
nc -= ${NR};
} else {
// Final case where not all of the ${NR} columns fit in the destination.
// Only the first `nc` columns are stored, then the outer loop terminates.
$for N in range(NR):
if (nc > ${N}) {
$for M in range(MR):
c${M}[${N}] = vout${M}${N};
}
nc = 0;
}
} while (nc != 0);
}