blob: b20f66875711aa4c519f5a6f64e87d94774ab8b7 [file] [log] [blame]
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "ruy/kernel_arm.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
namespace ruy {
#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
// Numeric values reserved for asm local labels. Judging by the names, there
// is one label per destination element type (uint8, int8, int16, int32) plus
// a common label placed after the store code.
// NOTE(review): none of these are referenced in the float kernel below;
// presumably they are consumed by the 8-bit kernel's store paths -- confirm.
#define RUY_ASM_LABEL_STORE_UINT8 91
#define RUY_ASM_LABEL_STORE_INT8 92
#define RUY_ASM_LABEL_STORE_INT16 93
#define RUY_ASM_LABEL_STORE_INT32 94
#define RUY_ASM_LABEL_AFTER_STORE 99
// Byte offsets of the kernel-params fields that the float asm kernel reads
// with `ldr rX, [params, #offset]`. Every field up to `flags` is 4 bytes wide
// (pointers are 4 bytes on this 32-bit target), hence the stride of 4.
// These values are only meaningful as long as they match the actual struct
// layout, which CheckOffsetsInKernelParamsFloat32() below verifies at compile
// time.
#define RUY_OFFSET_LHS_BASE_PTR 0
#define RUY_OFFSET_RHS_BASE_PTR 4
#define RUY_OFFSET_DST_BASE_PTR 8
#define RUY_OFFSET_BIAS 12
#define RUY_OFFSET_START_ROW 16
#define RUY_OFFSET_START_COL 20
#define RUY_OFFSET_LAST_ROW 24
#define RUY_OFFSET_LAST_COL 28
#define RUY_OFFSET_DST_ROWS 32
#define RUY_OFFSET_DST_COLS 36
#define RUY_OFFSET_LHS_STRIDE 40
#define RUY_OFFSET_RHS_STRIDE 44
#define RUY_OFFSET_DST_STRIDE 48
#define RUY_OFFSET_DEPTH 52
#define RUY_OFFSET_CLAMP_MIN 56
#define RUY_OFFSET_CLAMP_MAX 60
#define RUY_OFFSET_FLAGS 64
// Layout of the scratch area that the asm kernel reserves on the stack
// (`sub sp, sp, #RUY_STACK_OFFSET_SIZE`). Each slot holds one 32-bit value
// and slots are spaced 16 bytes apart.
#define RUY_STACK_OFFSET_SIZE 96
#define RUY_STACK_OFFSET_DST_COL_PTR 0
#define RUY_STACK_OFFSET_DST_PTR 16
#define RUY_STACK_OFFSET_ROW 32
#define RUY_STACK_OFFSET_COL 48
#define RUY_STACK_OFFSET_LHS_COL_PTR 64
#define RUY_STACK_OFFSET_RHS_COL_PTR 80

// Compile-time guard: fails the build if any RUY_OFFSET_* constant used by the
// float asm kernel disagrees with the real layout of `Params`.
//
// The argument is used only for template type deduction; nothing is read from
// it at run time and the function has no run-time effect.
template <typename Params>
void CheckOffsetsInKernelParamsFloat32(const Params&) {
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
  static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
  static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
  static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
  static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
  // The asm kernel reads RUY_OFFSET_DST_COLS when computing how many columns
  // of the current block fit in the destination, so it must be checked too.
  static_assert(offsetof(Params, dst_cols) == RUY_OFFSET_DST_COLS, "");
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
}
// Float kernel for ARM32 out-of-order cores.
// Just like Float 64 version, except accumulate in to 8x4 block to only
// use 16 128-bit NEON registers. This is a "first pass" kernel and not
// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9.
//
// What the asm does, block by block:
//   * Walks the destination from (start_row, start_col) to
//     (last_row, last_col), advancing 8 rows at a time and, after the last
//     row, 4 columns at a time.
//   * For each 8x4 block, accumulates over `depth` levels, adds the bias
//     (per row or per column, selected by
//     RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL), clamps to
//     [clamp_min, clamp_max] and stores the result.
//   * A block that does not entirely fit in dst_rows x dst_cols is first
//     written to params.dst_tmp_buf and then copied element by element.
//
// All params fields are read by byte offset (RUY_OFFSET_*); the strides are
// added directly to byte addresses, i.e. they are expressed in bytes.
// The accumulation code always performs at least one multiply-add level, so
// this assumes depth >= 1 -- TODO confirm against callers.
void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) {
  CheckOffsetsInKernelParamsFloat32(params);
  profiler::ScopeLabel label("Kernel (kNeon)");
  const float* lhs_ptr = params.lhs_base_ptr;
  const float* rhs_ptr = params.rhs_base_ptr;
  // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are
  // each composed of two 64-bit "d" registers. The asm kernel below has the
  // following NEON register allocation:
  // Registers q3 -- q10 are accumulators. During accumulation,
  // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1
  // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block
  // of RHS, like this:
  //
  //  Register layout in "q" registers:
  //                                    RHS 1x4 block
  //                           /--------------------------|
  //                           |q2.s[0]   ...      q2.s[3] |
  //                           \--------------------------/
  //        LHS 8x1 block
  //  /---------------------\  /--------------------------|
  //  |        q0.s[0]      |  | q3.s[0]   ...    q9.s[0]  |
  //  |         ...         |  |  ...               ...    |
  //  |        q0.s[3]      |  | q3.s[3]          q9.s[3]  |
  //  |        q1.s[0]      |  | q4.s[0]          q10.s[0] |
  //  |         ...         |  |  ...      ...      ...    |
  //  |        q1.s[3]      |  | q4.s[3]   ..     q10.s[3] |
  //  \---------------------/  \--------------------------/
  //                             accumulators 8x4 block
  //
  // q11 -- q15 are only used in the post-accumulation part of the kernel:
  // q12 and q13 receive the bias values and, later, the clamp bounds; when
  // the channel dimension is columns, q11, q13, q14 and q15 each hold one
  // bias value broadcast to all lanes.
  //
  //  For completeness, here is the register layout in "d" registers:
  //                                    RHS 1x4 block
  //                           /--------------------------|
  //                           |d4[0]     ...       d5[1]  |
  //                           \--------------------------/
  //        LHS 8x1 block
  //  /---------------------\  /--------------------------|
  //  |        d0[0]        |  | d6[0]    ...      d18[0]  |
  //  |         ...         |  |  ...               ...    |
  //  |        d1[1]        |  | d7[1]             d19[1]  |
  //  |        d2[0]        |  | d8[0]             d20[0]  |
  //  |         ...         |  |  ...     ...       ...    |
  //  |        d3[1]        |  | d9[1]    ...      d21[1]  |
  //  \---------------------/  \--------------------------/
  //                             accumulators 8x4 block
  asm volatile(
#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n"
      // clang-format off
      // Load the first level of depth: 32 bytes of LHS (8 floats) and
      // 16 bytes of RHS (4 floats).
      // Load q0, q1
      "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n"
      "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
      RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
      // Load q2
      "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
      RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
      // Reserve the stack scratch area and initialize its slots:
      // dst_col_ptr, dst_ptr, row, col, lhs_col_ptr, rhs_col_ptr.
      "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
      "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
      "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
      "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
      "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
      "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
      "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
      "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
      "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
      "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
      "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
      "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
      // Clear accumulators.
      RUY_MAKE_ZERO(q3)
      RUY_MAKE_ZERO(q4)
      RUY_MAKE_ZERO(q5)
      RUY_MAKE_ZERO(q6)
      RUY_MAKE_ZERO(q7)
      RUY_MAKE_ZERO(q8)
      RUY_MAKE_ZERO(q9)
      RUY_MAKE_ZERO(q10)
      // r1 is the number of levels of depth that we have already loaded
      // LHS and RHS data for. Corresponding to the initial vld1 instructions
      // above, this is currently 1.
      "mov r1, #1\n"
      // Main loop of the whole GEMM, over rows and columns of the
      // destination matrix.
      "1:\n"
      // Accumulation loop
      "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
      "cmp r1, r2\n"
      "beq 79f\n"
      "2:\n"
      "vmla.f32 q3, q0, d4[0]\n"
      "vmla.f32 q5, q0, d4[1]\n"
      "vmla.f32 q7, q0, d5[0]\n"
      "vmla.f32 q9, q0, d5[1]\n"
      "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
      "vmla.f32 q4, q1, d4[0]\n"
      "vmla.f32 q6, q1, d4[1]\n"
      "vmla.f32 q8, q1, d5[0]\n"
      "vmla.f32 q10, q1, d5[1]\n"
      "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
      RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
      "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
      RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
      "add r1, r1, #1\n"
      "cmp r1, r2\n"
      "blt 2b\n"
      "79:\n"
      // End of the inner loop on depth. Now perform the remaining
      // multiply-adds of the last level of depth, for which the LHS
      // and RHS data is already loaded.
      "vmla.f32 q3, q0, d4[0]\n"
      "vmla.f32 q5, q0, d4[1]\n"
      "vmla.f32 q7, q0, d5[0]\n"
      "vmla.f32 q9, q0, d5[1]\n"
      "vmla.f32 q4, q1, d4[0]\n"
      "vmla.f32 q6, q1, d4[1]\n"
      "vmla.f32 q8, q1, d5[0]\n"
      "vmla.f32 q10, q1, d5[1]\n"
      // End of accumulation. The registers q3 -- q10 contain the final
      // float32 accumulator values of the current 8x4 destination block.
      // We now have to compute the final values from these accumulators
      // and advance to the next 8x4 block. We intertwine
      // these two aspects whenever possible for optimal pipelining, both
      // at the data flow level (prefetch data for next block as early as
      // possible) and instruction pipelining level (some of the next-block
      // work can dual-issue with some of the final work on the current
      // block).
      // Logic to advance to the next block in preparation for the next
      // iteration of the main loop. For now, we only want to compute
      // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
      // not yet ready to update the values of row and col, as we still need
      // the current values for the rest of the work on the current block.
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
      "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
      "cmp r1, r3\n" // Have we finished the last row?
      "bge 4f\n" // If finished last row, go to 4
      // Not finished last row: then advance to next row.
      // lhs_col_ptr += 8 * lhs_stride (8 rows per block).
      "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
      "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
      "add r4, r4, r1, lsl #3\n"
      "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
      "b 5f\n"
      "4:\n" // Finished last row...
      "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
      // Go back to first row
      "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
      // Now we need to advance to the next column. If we already
      // finished the last column, then in principle we are done, however
      // we can't just return here, as we need to allow the end work of the
      // current block to complete. The good news is that at this point it
      // doesn't matter what data we load for the next column, since
      // we will exit from the main loop below before actually storing
      // anything computed from that data.
      "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
      "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
      "cmp r8, r4\n" // Have we finished the last column?
      "bge 5f\n" // If yes, just carry on without updating the column pointer.
      // Not finished last column: then advance to next column.
      // rhs_col_ptr += 4 * rhs_stride (4 columns per block).
      "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
      "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
      "add r10, r10, r1, lsl #2\n"
      "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
      "5:\n"
      // Set the LHS and RHS data pointers to the start of the columns just
      // computed.
      "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
      "mov %[lhs_ptr], r4\n"
      "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
      "mov %[rhs_ptr], r5\n"
      // Load some parameters needed for the end work on current block.
      "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
      "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
      // Let r8 be stack offset of the row or column variable, whichever
      // is the channel index.
      "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
      "ite eq\n"
      "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
      "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
      // Let r8 be the channel index.
      "ldr r8, [sp, r8]\n"
      // Compute the bias pointer, by conditionally using the channel index
      // (r8) as offset into bias buffer (r1).
      "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
      "it ne\n"
      "addne r1, r1, r8, lsl #2\n"
      // Load 4 bias values. When the channel dimension is rows, we will load
      // another 4 bias values just before performing the bias addition below,
      // as this kernel has a 8x4 rectangular shape.
      "vld1.32 {d24, d25}, [r1]!\n"
      // Now that we know what LHS and RHS data the next iteration of the
      // main loop will need to load, we start loading the first level of
      // depth (32 bytes of LHS, 16 bytes of RHS) into q0 -- q2, as we don't
      // need q0 -- q2 anymore in the rest of the work on the current block.
      // Load q0, q1
      "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
      RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
      // Load q2
      "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
      RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
      // Perform the bias-addition.
      // Jump based on channel dimension.
      "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
      "bne 6f\n"
      // Case where channels are rows.
      // Load the remaining 4 bias values, since we're on the width-8 side
      // of this 8x4 kernel.
      "vld1.32 {d26, d27}, [r1]\n"
      "vadd.f32 q3, q3, q12\n"
      "vadd.f32 q5, q5, q12\n"
      "vadd.f32 q7, q7, q12\n"
      "vadd.f32 q9, q9, q12\n"
      "vadd.f32 q4, q4, q13\n"
      "vadd.f32 q6, q6, q13\n"
      "vadd.f32 q8, q8, q13\n"
      "vadd.f32 q10, q10, q13\n"
      "b 7f\n"
      "6:\n"
      // Case where channels are columns.
      // Broadcast each of the 4 bias values (held in q12 = d24, d25) to a
      // full q register, one per destination column. This writes q11, q14
      // and q15 in addition to q13, which is why they appear in the clobber
      // list below.
      "vdup.32 q11, d24[0]\n"
      "vdup.32 q13, d24[1]\n"
      "vdup.32 q14, d25[0]\n"
      "vdup.32 q15, d25[1]\n"
      "vadd.f32 q3, q3, q11\n"
      "vadd.f32 q4, q4, q11\n"
      "vadd.f32 q5, q5, q13\n"
      "vadd.f32 q6, q6, q13\n"
      "vadd.f32 q7, q7, q14\n"
      "vadd.f32 q8, q8, q14\n"
      "vadd.f32 q9, q9, q15\n"
      "vadd.f32 q10, q10, q15\n"
      "7:\n"
      // Load the clamp_min, clamp_max bounds
      "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
      "vdup.32 q12, r2\n" // clamp_min
      "vdup.32 q13, r3\n" // clamp_max
      // Apply the clamp_min bound
      "vmax.f32 q3, q3, q12\n"
      "vmax.f32 q4, q4, q12\n"
      "vmax.f32 q5, q5, q12\n"
      "vmax.f32 q6, q6, q12\n"
      "vmax.f32 q7, q7, q12\n"
      "vmax.f32 q8, q8, q12\n"
      "vmax.f32 q9, q9, q12\n"
      "vmax.f32 q10, q10, q12\n"
      // Apply the clamp_max bound
      "vmin.f32 q3, q3, q13\n"
      "vmin.f32 q4, q4, q13\n"
      "vmin.f32 q5, q5, q13\n"
      "vmin.f32 q6, q6, q13\n"
      "vmin.f32 q7, q7, q13\n"
      "vmin.f32 q8, q8, q13\n"
      "vmin.f32 q9, q9, q13\n"
      "vmin.f32 q10, q10, q13\n"
      // Compute how much of the 8x4 block of destination values that
      // we have computed, fit in the destination matrix. Typically, all of
      // it fits, but when the destination matrix shape is not a multiple
      // of 8x4, there are some 8x4 blocks along the boundaries that do
      // not fit entirely.
      "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
      "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
      "sub r1, r1, r8\n"
      "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
      "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
      "sub r2, r2, r4\n"
      "mov r3, #8\n"
      "mov r5, #4\n"
      "cmp r1, #8\n"
      // Compute r1 = how many rows of the 8x4 block fit
      "it gt\n"
      "movgt r1, r3\n"
      "cmp r2, #4\n"
      // Compute r2 = how many cols of the 8x4 block fit
      "it gt\n"
      "movgt r2, r5\n"
      // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits.
      "cmp r1, r3\n"
      "it eq\n"
      "cmpeq r2, r5\n"
      // Yes, all of the 8x4 block fits, go to fast path.
      "beq 30f\n"
      // Not all of the 8x4 block fits.
      // Set (r3 address, r4 stride) to write to dst_tmp_buf
      "mov r3, %[dst_tmp_buf]\n"
      "mov r4, #32\n"
      "b 31f\n"
      "30:\n"
      // Yes, all of the 8x4 block fits.
      // Set (r3 address, r4 stride) to write directly to destination matrix.
      "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
      "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
      "mov r4, r5\n"
      "31:\n"
      // Write our float values to the destination described by
      // (r3 address, r4 stride). Each store writes one 8-float column;
      // the accumulators are re-zeroed for the next block as we go.
      "vst1.32 {d6, d7, d8, d9}, [r3]\n"
      "add r3, r3, r4\n"
      RUY_MAKE_ZERO(q3)
      RUY_MAKE_ZERO(q4)
      "vst1.32 {d10, d11, d12, d13}, [r3]\n"
      "add r3, r3, r4\n"
      RUY_MAKE_ZERO(q5)
      RUY_MAKE_ZERO(q6)
      "vst1.32 {d14, d15, d16, d17}, [r3]\n"
      "add r3, r3, r4\n"
      RUY_MAKE_ZERO(q7)
      RUY_MAKE_ZERO(q8)
      "vst1.32 {d18, d19, d20, d21}, [r3]\n"
      "add r3, r3, r4\n"
      RUY_MAKE_ZERO(q9)
      RUY_MAKE_ZERO(q10)
      // If all of the 8x4 block fits, we just finished writing it to the
      // destination, so we skip the next part. This branch relies on the
      // condition flags still holding the result of the "all fits" test
      // above; none of the instructions in between is a flag-setting form.
      "beq 41f\n"
      // Not all of the 8x4 block fits in the destination matrix. We just
      // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
      // it to copy into the destination matrix the part that fits.
      "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
      "mov r3, %[dst_tmp_buf]\n"
      "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
      "mov r6, #0\n"
      "50:\n"
      "mov r5, #0\n"
      "51:\n"
      "ldr r10, [r3, r5, lsl #2]\n"
      "str r10, [r4, r5, lsl #2]\n"
      "add r5, r5, #1\n"
      // r1 = how many rows of the 8x4 block fit
      "cmp r5, r1\n"
      "blt 51b\n"
      "add r6, r6, #1\n"
      // Next column: 32 bytes further in dst_tmp_buf, dst_stride further
      // in the destination.
      "add r3, r3, #32\n"
      "add r4, r4, r8\n"
      // r2 = how many cols of the 8x4 block fit
      "cmp r6, r2\n"
      "blt 50b\n"
      "41:\n"
      // Load dst_ptr, increment, and write back.
      // 32 bytes = 8 floats, i.e. the 8 rows of the block just written.
      "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
      "add r4, r4, #32\n"
      "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
      // At this point we have completely finished writing values to the
      // destination matrix for the current block.
      // Reload some params --- we had used r3, r5, r10 for a few other things
      // since the last time we had loaded them.
      "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
      "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
      // Move to the next block of the destination matrix, for the next iter
      // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
      // been updated earlier.
      // Have we reached the end row?
      "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
      "cmp r8, r3\n"
      "beq 20f\n" // yes, end row.
      // Not end row. Move to the next row.
      "add r8, r8, #8\n"
      // Store new value of row
      "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
      "b 21f\n"
      "20:\n"
      // Was already at end row.
      // Move back to first row.
      "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
      // Move to the next column.
      "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
      "add r4, r4, #4\n"
      "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
      "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
      "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
      // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns)
      "add r1, r1, r8, lsl #2\n"
      // Store dst_col_ptr
      "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
      // Store dst_ptr
      "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
      "21:\n"
      // Main loop exit condition: have we hit the end column?
      "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
      "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
      "cmp r8, r4\n"
      // r1 is the number of levels of depth that we have already loaded
      // LHS and RHS data for. Corresponding to the vld1 instructions that
      // preloaded q0 -- q2 above, this is currently 1. (This mov sits between
      // the cmp and the branch; it is not a flag-setting form.)
      "mov r1, #1\n"
      "ble 1b\n"
      // Restore stack pointer.
      "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
      // clang-format on
      : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
      : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
      // Clobber list must specify q registers (and not their constituent
      // d registers). There is a (currently unexplained) slowdown if
      // d registers are listed in the clobbers list.
      // q11, q14 and q15 are written by the per-column bias broadcast, so
      // they must be listed along with the other q registers.
      : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
        "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
        "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
// Drop the float kernel's definitions. RUY_MAKE_ZERO, the stack-slot offsets
// and the params-field offsets are all defined again below with the values
// used by the 8-bit kernel, so they must be undefined first.
#undef RUY_MAKE_ZERO
#undef RUY_STACK_OFFSET_SIZE
#undef RUY_STACK_OFFSET_DST_COL_PTR
#undef RUY_STACK_OFFSET_DST_PTR
#undef RUY_STACK_OFFSET_ROW
#undef RUY_STACK_OFFSET_COL
#undef RUY_STACK_OFFSET_LHS_COL_PTR
#undef RUY_STACK_OFFSET_RHS_COL_PTR
#undef RUY_OFFSET_LHS_BASE_PTR
#undef RUY_OFFSET_RHS_BASE_PTR
#undef RUY_OFFSET_DST_BASE_PTR
#undef RUY_OFFSET_BIAS
#undef RUY_OFFSET_START_ROW
#undef RUY_OFFSET_START_COL
#undef RUY_OFFSET_LAST_ROW
#undef RUY_OFFSET_LAST_COL
#undef RUY_OFFSET_DST_ROWS
#undef RUY_OFFSET_DST_COLS
#undef RUY_OFFSET_LHS_STRIDE
#undef RUY_OFFSET_RHS_STRIDE
#undef RUY_OFFSET_DST_STRIDE
#undef RUY_OFFSET_DEPTH
#undef RUY_OFFSET_CLAMP_MIN
#undef RUY_OFFSET_CLAMP_MAX
#undef RUY_OFFSET_FLAGS
// Byte offsets of the kernel-params fields that the 8-bit asm kernel reads
// with `ldr`/`ldrb rX, [params, #offset]`. Fields up to `flags` are 4 bytes
// wide (pointers are 4 bytes on this 32-bit target); `flags` and
// `dst_type_id` are single bytes at offsets 96 and 97.
// These values are only meaningful as long as they match the actual struct
// layout, which CheckOffsetsInKernelParams8bit() below verifies at compile
// time.
#define RUY_OFFSET_BIAS 0
#define RUY_OFFSET_LHS_SUMS 4
#define RUY_OFFSET_RHS_SUMS 8
#define RUY_OFFSET_LHS_BASE_PTR 12
#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
#define RUY_OFFSET_MULTIPLIER_EXPONENT 20
#define RUY_OFFSET_RHS_BASE_PTR 24
#define RUY_OFFSET_DST_BASE_PTR 28
#define RUY_OFFSET_LHS_ZERO_POINT 32
#define RUY_OFFSET_RHS_ZERO_POINT 36
#define RUY_OFFSET_DST_ZERO_POINT 40
#define RUY_OFFSET_PROD_ZP_DEPTH 44
#define RUY_OFFSET_START_ROW 48
#define RUY_OFFSET_START_COL 52
#define RUY_OFFSET_LAST_ROW 56
#define RUY_OFFSET_LAST_COL 60
#define RUY_OFFSET_DST_ROWS 64
#define RUY_OFFSET_DST_COLS 68
#define RUY_OFFSET_LHS_STRIDE 72
#define RUY_OFFSET_RHS_STRIDE 76
#define RUY_OFFSET_DST_STRIDE 80
#define RUY_OFFSET_DEPTH 84
#define RUY_OFFSET_CLAMP_MIN 88
#define RUY_OFFSET_CLAMP_MAX 92
#define RUY_OFFSET_FLAGS 96
#define RUY_OFFSET_DST_TYPE_ID 97
// Layout of the scratch area that the asm kernel reserves on the stack
// (`sub sp, sp, #RUY_STACK_OFFSET_SIZE`). Each slot holds one 32-bit value
// and slots are spaced 16 bytes apart.
#define RUY_STACK_OFFSET_SIZE 96
#define RUY_STACK_OFFSET_DST_COL_PTR 0
#define RUY_STACK_OFFSET_DST_PTR 16
#define RUY_STACK_OFFSET_ROW 32
#define RUY_STACK_OFFSET_COL 48
#define RUY_STACK_OFFSET_LHS_COL_PTR 64
#define RUY_STACK_OFFSET_RHS_COL_PTR 80

// Compile-time guard: fails the build if any RUY_OFFSET_* constant used by the
// 8-bit asm kernel disagrees with the real layout of `Params`.
//
// The argument is used only for template type deduction; nothing is read from
// it at run time and the function has no run-time effect. The assertions are
// listed in increasing offset order so that a gap is easy to spot.
template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) {
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
  static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
  static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
  static_assert(offsetof(Params, multiplier_fixedpoint) ==
                    RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
                "");
  static_assert(
      offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
      "");
  // rhs_base_ptr, dst_base_ptr and start_col are read by the kernel's
  // set-up code through their RUY_OFFSET_* constants, so they need the same
  // protection as the other fields.
  static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
  static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
  static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
                "");
  static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
                "");
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
  static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
  // Offsets that have a constant defined above; checked so that every
  // RUY_OFFSET_* value is validated against the struct.
  static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
  static_assert(offsetof(Params, dst_cols) == RUY_OFFSET_DST_COLS, "");
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
  static_assert(offsetof(Params, dst_type_id) == RUY_OFFSET_DST_TYPE_ID, "");
}
// Fast-int8 kernel, ported from ARM 64 version.
// Relevant target CPUs for this kernel include Krait 400 and A9,
// since these are 32-bit, out-of-order CPUs.
void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
profiler::ScopeLabel label("Kernel (kNeon)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
// The asm kernel below has the following NEON register allocation:
//
// q6 - q13 are 128-bit (4x32b) accumulators.
// During accumulation, d0 -- d7 are used to load int8 data from LHS and
// d8 -- d11 from RHS:
// int8 RHS 16x2 block
// /-----------------------------|
// |d8.b[0-7] ..... d10.b[0-7]|
// | ... ... |
// |d9.b[0-7] ..... d11.b[0-7]|
// \-----------------------------/
// int8 LHS 4x16 block
// /------------------------\ /-----------------------------|
// |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 |
// |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 |
// (Reload d0, d1, d2, d3)
// |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 |
// |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 |
// \------------------------/ \-----------------------------/
// 128-bit accumulators 4x2 block
//
// No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
// optimization for this kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
// clang-format off
// Load the first 64 bytes of LHS and RHS data.
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
// Clear accumulators.
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
"vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
RUY_MAKE_ZERO(q11)
"vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
"sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
RUY_MAKE_ZERO(q12)
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
RUY_MAKE_ZERO(q13)
"str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
RUY_MAKE_ZERO(q14)
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
RUY_MAKE_ZERO(q15)
"str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
"str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
// r1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov r1, #16\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// r1 is how many levels of depth we have already loaded
// data for, r10 is the total depth.
"ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
"cmp r1, r10\n"
"beq 79f\n"
"2:\n"
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"
"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vmlal.s8 q15, d3, d9\n"
"vmlal.s8 q3, d3, d11\n"
"vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
// Then pairwise accumulate in to q6, q7, q10, q11
"vpadal.s16 q6, q14\n"
"vpadal.s16 q7, q15\n"
"vpadal.s16 q10, q2\n"
"vpadal.s16 q11, q3\n"
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"
"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vmlal.s8 q15, d3, d9\n"
"vmlal.s8 q3, d3, d11\n"
"vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
// Then pairwise accumulate in to q8, q9, q12, q13
"vpadal.s16 q8, q14\n"
"vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
"vpadal.s16 q9, q15\n"
"vpadal.s16 q12, q2\n"
"vpadal.s16 q13, q3\n"
// Prefetch the next 64 bytes of LHS and RHS data.
RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
// Each iteration of this loop advances by 16 levels of depth.
"add r1, r1, #16\n"
// Loop termination condition
"cmp r1, r10\n"
"blt 2b\n"
"79:\n"
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"
"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vmlal.s8 q15, d3, d9\n"
"vmlal.s8 q3, d3, d11\n"
"vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
// Then pairwise accumulate in to q6, q7, q10, q11
"vpadal.s16 q6, q14\n"
"vpadal.s16 q7, q15\n"
"vpadal.s16 q10, q2\n"
"vpadal.s16 q11, q3\n"
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"
"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vmlal.s8 q15, d3, d9\n"
"vmlal.s8 q3, d3, d11\n"
// Then pairwise accumulate in to q8, q9, q12, q13
"vpadal.s16 q8, q14\n"
"vpadal.s16 q9, q15\n"
"vpadal.s16 q12, q2\n"
"vpadal.s16 q13, q3\n"
// All accumulation over depth done. q6 - q13 contain the 4x32b
// accumulators for the 4x2 final matrix.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 4x2 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// q6-q13 now contain 4 x 32b
"vpadd.i32 d0, d12, d13\n"
"vpadd.i32 d1, d14, d15\n"
"vpadd.i32 d2, d16, d17\n"
"vpadd.i32 d3, d18, d19\n"
"vpadd.i32 d4, d20, d21\n"
"vpadd.i32 d5, d22, d23\n"
"vpadd.i32 d6, d24, d25\n"
"vpadd.i32 d7, d26, d27\n"
// d0-d7 each contain 2 x 32b accumulators.
// Need to add pairwise to get 1 x 32b for each of the 4x2 entries
// of destination, (Four 'd' registers total)
"vpadd.i32 d28, d0, d1\n"
"vpadd.i32 d29, d2, d3\n"
"vpadd.i32 d30, d4, d5\n"
"vpadd.i32 d31, d6, d7\n"
//Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"cmp r1, r3\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"add r4, r4, r1, lsl #2\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"b 5f\n"
"4:\n" // Finished last row...
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
// Go back to first row
"str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"cmp r8, r4\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
"add r10, r10, r1, lsl #1\n"
"str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"mov %[lhs_ptr], r4\n"
"ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
"mov %[rhs_ptr], r5\n"
// Now we load: bias data, LHS sums data, RHS sums data.
// First, load the base pointers from the params.
"ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Let r8 be stack offset of the row or column variable, whichever
// is the channel index.
"tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"ite eq\n"
"moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
"movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
// Let r8 be the channel index.
"ldr r8, [sp, r8]\n"
// Compute the bias pointer, by conditionally using the channel index
// (r8) as offset into bias buffer (r1).
"tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"it ne\n"
"addne r1, r1, r8, lsl #2\n"
// Load 2 bias values. When the channel dimension is rows, we will load
// another 2 bias values just before performing the bias addition below,
// as this kernel has a 4x2 rectangular shape.
"vld1.32 {d24}, [r1]!\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into d0 -- d3 (LHS) and d8 -- d11 (RHS). The
// current block's values in those registers have already been consumed
// by the accumulation and pairwise additions above.
"vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
"vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
// Add to the bias values the product
// (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in
// https://arxiv.org/pdf/1712.05877.pdf
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"vdup.32 q9, r3\n"
"vadd.i32 d24, d24, d18\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// Jump based on channel dimension.
"tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows.
// Load the remaining 2 bias values, since we're on the width-4 side
// of this 4x2 kernel.
"vld1.32 {d25}, [r1]\n"
"vadd.i32 d25, d25, d19\n"
"vadd.i32 q14, q14, q12\n"
"vadd.i32 q15, q15, q12\n"
"b 7f\n"
"6:\n"
// Case where channels are columns.
"vdup.32 q10, d24[0]\n"
"vdup.32 q11, d24[1]\n"
"vadd.i32 q14, q14, q10\n"
"vadd.i32 q15, q15, q11\n"
"7:\n"
// LHS/RHS zero points
// Has RHS sums
"ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
// Offset by current col * number of bytes per value
"add r3, r3, r4, lsl #2\n"
"vld1.32 { d12 }, [r3]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"vdup.32 q10, r5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"vmls.i32 q14, q10, d12[0]\n"
"vmls.i32 q15, q10, d12[1]\n"
"401:\n"
// Has LHS sums
"ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
// Offset by current row * number of bytes per value
"add r2, r2, r4, lsl #2\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"vld1.32 {d22, d23}, [r2]\n"
"vdup.32 d13, r5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"vmul.i32 q11, q11, d13[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"vsub.s32 q14, q14, q11\n"
"vsub.s32 q15, q15, q11\n"
// If the destination is int32, it means the user asks for the raw
// accumulators, no need for us to downquantize the value.
"ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
"cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
// Compute the data pointers for the multiplier data
// r1 = exponent part
// r2 = fixedpoint part
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
// r6 has flags, r8 has channel index
"tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"it ne\n"
"addne r1, r1, r8, lsl #2\n"
"it ne\n"
"addne r2, r2, r8, lsl #2\n"
// Load the first 2 values of multiplier exponent and fixedpoint data
// Since this kernel is rectangular 4x2, we will only conditionally load
// 2 more values below.
"vld1.32 {d20}, [r1]!\n" // 2 values of multiplier_exponent
"vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint
"tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"vmvn.i32 q8, #0\n"
"bne 8f\n"
// Case where channels are rows.
// Load the remaining 2 multiplier exponent and fixedpoint values, since
// we're on the width-4 side of this 4x2 kernel.
"vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent
"vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint
"vmin.s32 q11, q10, q8\n"
"vsub.s32 q10, q10, q11\n"
// Apply the positive exponent part of the multiplier.
"vshl.s32 q14, q14, q10\n"
"vshl.s32 q15, q15, q10\n"
// Apply the fixed-point part of the multiplier.
"vqdmulh.s32 q14, q14, q6\n"
"vqdmulh.s32 q15, q15, q6\n"
// Apply the negative exponent part of the multiplier.
"vrshl.s32 q14, q14, q11\n"
"vrshl.s32 q15, q15, q11\n"
"b 9f\n"
"8:\n"
// Case where channels are columns.
"vmin.s32 d22, d20, d16\n"
"vsub.s32 d20, d20, d22\n"
// Apply the positive exponent part of the multiplier.
"vdup.32 q12, d20[0]\n"
"vdup.32 q13, d20[1]\n"
"vshl.s32 q14, q14, q12\n"
"vshl.s32 q15, q15, q13\n"
// Apply the fixed-point part of the multiplier.
"vqdmulh.s32 q14, q14, d12[0]\n"
"vqdmulh.s32 q15, q15, d12[1]\n"
// Apply the negative exponent part of the multiplier.
"vdup.32 q12, d22[0]\n"
"vdup.32 q13, d22[1]\n"
"vrshl.s32 q14, q14, q12\n"
"vrshl.s32 q15, q15, q13\n"
"9:\n"
"ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
"cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
// Store uint8 values:
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
// After this, all values for output are in q14.
"vqmovn.s32 d28, q14\n"
"vqmovn.s32 d29, q15\n"
// At this point, d12 -- d26, d30, d31 aren't used anymore for the
// current block, so we can start clearing these accumulators for the
// next block (next iteration of the main loop).
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q15)
// Load the destination zero point into each of the 8 16-bit slots
// in a q register.
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
"vadd.i16 q14, q14, q13\n"
// Cast-and-saturate from int16 to uint8
// Now all 8 1-byte values are in d30.
"vqmovun.s16 d30, q14\n"
// Load the clamp_min, clamp_max bounds
"ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"vdup.8 d28, r2\n" // clamp_min
"vdup.8 d29, r3\n" // clamp_max
// Apply the clamp_min bound
"vmax.u8 d30, d30, d28\n"
// Apply the clamp_max bound
"vmin.u8 d30, d30, d29\n"
// Compute how much of the 4x2 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x2, there are some 4x2 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x2 block fit
"it gt\n"
"movgt r1, r3\n"
"cmp r2, #2\n"
// Compute r2 = how many cols of the 4x2 block fit
"it gt\n"
"movgt r2, r5\n"
// Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
"cmp r1, r3\n"
"it eq\n"
"cmpeq r2, r5\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// Yes, all of the 4x2 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x2 block fits.
// Store to dst_tmp_buf
// Set r3 address to write to dst_tmp_buf.
"mov r3, %[dst_tmp_buf]\n"
"vst1.8 {d30}, [r3]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov r6, #0\n"
"50:\n"
"mov r8, #0\n"
"51:\n"
"ldrb r10, [r3, r8]\n"
"strb r10, [r4, r8]\n"
"add r8, r8, #1\n"
"cmp r8, r1\n"
"blt 51b\n"
"add r6, r6, #1\n"
"add r3, r3, #4\n"
"add r4, r4, r5\n"
"cmp r6, r2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x2 block fits.
// r3 address, r5 stride
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r3\n"
"mov r6, #1\n"
"vst1.32 {d30[0]}, [r3]\n"
"add r4, r4, r5\n"
"mov r3, r4\n"
"vst1.32 {d30[1]}, [r3]\n"
"31:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #4\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q14)
RUY_MAKE_ZERO(q15)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
// Store int8 values:
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
// After this, all values for output are in q14.
"vqmovn.s32 d28, q14\n"
"vqmovn.s32 d29, q15\n"
// At this point, d12 -- d26, d30, d31 aren't used anymore for the
// current block, so we can start clearing these accumulators for the
// next block (next iteration of the main loop).
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q15)
// Load the destination zero point into each of the 8 16-bit slots
// in a q register.
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
"vadd.i16 q14, q14, q13\n"
// Cast-and-saturate from int16 to int8
// Now all 8 1-byte values are in d30.
"vqmovn.s16 d30, q14\n"
// Load the clamp_min, clamp_max bounds
"ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"vdup.8 d28, r2\n" // clamp_min
"vdup.8 d29, r3\n" // clamp_max
// Apply the clamp_min bound
"vmax.s8 d30, d30, d28\n"
// Apply the clamp_max bound
"vmin.s8 d30, d30, d29\n"
// Compute how much of the 4x2 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x2, there are some 4x2 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x2 block fit
"it gt\n"
"movgt r1, r3\n"
"cmp r2, #2\n"
// Compute r2 = how many cols of the 4x2 block fit
"it gt\n"
"movgt r2, r5\n"
// Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
"cmp r1, r3\n"
"it eq\n"
"cmpeq r2, r5\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// Yes, all of the 4x2 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x2 block fits.
// Store to dst_tmp_buf
// Set r3 address to write to dst_tmp_buf.
"mov r3, %[dst_tmp_buf]\n"
"vst1.8 {d30}, [r3]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov r6, #0\n"
"50:\n"
"mov r8, #0\n"
"51:\n"
"ldrb r10, [r3, r8]\n"
"strb r10, [r4, r8]\n"
"add r8, r8, #1\n"
"cmp r8, r1\n"
"blt 51b\n"
"add r6, r6, #1\n"
"add r3, r3, #4\n"
"add r4, r4, r5\n"
"cmp r6, r2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x2 block fits.
// r3 address, r5 stride
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r3\n"
"mov r6, #1\n"
"vst1.32 {d30[0]}, [r3]\n"
"add r4, r4, r5\n"
"mov r3, r4\n"
"vst1.32 {d30[1]}, [r3]\n"
"31:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #4\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q14)
RUY_MAKE_ZERO(q15)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Load the destination zero point into each of the 4 32-bit slots
// in a q register.
"ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"vdup.32 q13, r4\n" // dst_zero_point
// Add the destination zero point
"vadd.s32 q14, q14, q13\n"
"vadd.s32 q15, q15, q13\n"
// Cast-and-saturate from int32 to int16
// After this, all values for output are in q14.
"vqmovn.s32 d28, q14\n"
"vqmovn.s32 d29, q15\n"
// At this point, q6 -- q11 and q15 aren't used anymore for the current
// block, so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q15)
// Load the clamp_min, clamp_max bounds
"ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"vdup.16 q12, r2\n" // clamp_min
"vdup.16 q13, r3\n" // clamp_max
// Apply the clamp_min bound
"vmax.s16 q14, q14, q12\n"
// Apply the clamp_max bound
"vmin.s16 q14, q14, q13\n"
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
// Compute how much of the 4x2 block of destination 16-bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x2, there are some 4x2 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x2 block fit
"it gt\n"
"movgt r1, r3\n"
"cmp r2, #2\n"
// Compute r2 = how many cols of the 4x2 block fit
"it gt\n"
"movgt r2, r5\n"
// Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
"cmp r1, r3\n"
"it eq\n"
"cmpeq r2, r5\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// Yes, all of the 4x2 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x2 block fits.
// Store to dst_tmp_buf
// Set r3 address to write to dst_tmp_buf.
"mov r3, %[dst_tmp_buf]\n"
"vst1.16 {q14}, [r3]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov r6, #0\n"
"50:\n"
"mov r8, #0\n"
"51:\n"
// Shift of offset register for half-word loads not allowed in A32,
// so we shift, load/store, then shift back r8.
"lsl r8, r8, #1\n"
"ldrh r10, [r3, r8]\n"
"strh r10, [r4, r8]\n"
"lsr r8, r8, #1\n"
"add r8, r8, #1\n"
"cmp r8, r1\n"
"blt 51b\n"
"add r6, r6, #1\n"
"add r3, r3, #8\n"
"add r4, r4, r5\n"
"cmp r6, r2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x2 block fits.
// r3 address, r5 stride
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r3\n"
"mov r6, #2\n"
"vst1.16 {d28[0]}, [r3], r6\n"
"add r4, r4, r5\n"
"vst1.16 {d28[1]}, [r3], r6\n"
"vst1.16 {d28[2]}, [r3], r6\n"
"vst1.16 {d28[3]}, [r3], r6\n"
"mov r3, r4\n"
"vst1.16 {d29[0]}, [r3], r6\n"
"vst1.16 {d29[1]}, [r3], r6\n"
"vst1.16 {d29[2]}, [r3], r6\n"
"vst1.16 {d29[3]}, [r3], r6\n"
"31:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #8\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q14)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// At this point, q6 -- q13 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
// Clear accumulators.
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
// Compute how much of the 4x2 block of destination 32 bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x2, there are some 4x2 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x2 block fit
"it gt\n"
"movgt r1, r3\n"
"cmp r2, #2\n"
// Compute r2 = how many cols of the 4x2 block fit
"it gt\n"
"movgt r2, r5\n"
// Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
"cmp r1, r3\n"
"it eq\n"
"cmpeq r2, r5\n"
// Yes, all of the 4x2 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x2 block fits.
// Set (r3 address, r4 stride) to write to dst_tmp_buf
"mov r3, %[dst_tmp_buf]\n"
"mov r4, #16\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x2 block fits.
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// r3 address, r4 stride
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r5\n"
"31:\n"
"vst1.32 {d28, d29}, [r3]\n"
"add r3, r3, r4\n"
"vst1.32 {d30, d31}, [r3]\n"
// If all of the 4x2 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 4x2 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"mov r3, %[dst_tmp_buf]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r6, #0\n"
"50:\n"
"mov r5, #0\n"
"51:\n"
"ldr r10, [r3, r5, lsl #2]\n"
"str r10, [r4, r5, lsl #2]\n"
"add r5, r5, #1\n"
"cmp r5, r1\n"
"blt 51b\n"
"add r6, r6, #1\n"
"add r3, r3, #16\n"
"add r4, r4, r8\n"
// r2 = how many cols of the 4x2 block fit
"cmp r6, r2\n"
"blt 50b\n"
"41:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #16\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// Reload some params --- we had used r3, r5 and r6 for a few other
// things since the last time we had loaded them.
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"cmp r8, r3\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add r8, r8, #4\n"
// Store new value of row
"str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"b 21f\n"
"20:\n"
// Was already at end row.
// Move back to first row.
"str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
// Move to the next column.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"add r4, r4, #2\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
// Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
"add r1, r1, r8, lsl #1\n"
// Store dst_col_ptr
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
// Store dst_ptr
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"cmp r8, r4\n"
// r1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial vld1 instructions
// above, this is currently 16.
"mov r1, #16\n"
"ble 1b\n"
// Restore stack pointer.
"add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
// clang-format on
: [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
: [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
: "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
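// Clobber list must specify q registers (and not their constituent
// d registers). There is a (currently unexplained) slowdown if
// d registers are listed in the clobbers list.
// NOTE(review): q11 (d22/d23) is written by the asm above (lhs_sums load,
// multiplier exponent handling, RUY_MAKE_ZERO(q11)) but is not listed
// below --- confirm whether "q11" should be added to the clobbers.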
"memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q12", "q13", "q14", "q15");
}
// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
// is still packed as if it has two columns.
void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
profiler::ScopeLabel label("Kernel (kNeon)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
// The asm kernel below has the following NEON register allocation:
//
// q6 - q13 are 128-bit (4x32b) accumulators.
// During accumulation, d0 -- d7 are used to load int8 data from LHS and
// d8 -- d11 from RHS:
// int8 RHS 16x1 block
// /------------|
// | d8.b[0] |
// | ... |
// | d8.b[7] |
// | d9.b[0] |
// | ... |
// | d9.b[7] |
// \------------/
// int8 LHS 4x16 block
// /-----------------------------------------\ /------------|
// |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 |
// |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 |
// |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 |
// |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 |
// \-----------------------------------------/ \------------/
// 128-bit accumulators 4x1 block
//
// No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
// optimization for this kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
// clang-format off
// Load the first 64 bytes of LHS data and the first 16 bytes of RHS data.
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
"vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
"vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
"vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
// Skip the other column and advance the pointer.
"add %[rhs_ptr], %[rhs_ptr], #16\n"
"sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
"str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
"str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
"str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
// Clear accumulators.
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q14)
RUY_MAKE_ZERO(q15)
// r1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov r1, #16\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// r1 is how many levels of depth we have already loaded
// data for, r10 is the total depth.
"ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
"cmp r1, r10\n"
"beq 79f\n"
"2:\n"
// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q15, d2, d8\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q15, d3, d9\n"
// Then pairwise accumulate in to q6, q7
"vpadal.s16 q6, q14\n"
"vpadal.s16 q7, q15\n"
// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d4, d8\n"
"vmull.s8 q15, d6, d8\n"
"vmlal.s8 q14, d5, d9\n"
"vmlal.s8 q15, d7, d9\n"
// Then pairwise accumulate in to q8, q9
"vpadal.s16 q8, q14\n"
"vpadal.s16 q9, q15\n"
// Load the next 64 bytes of LHS data and the next 16 bytes of RHS data.
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
"vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
"vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
"vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
// Skip the other column and advance the pointer.
"add %[rhs_ptr], %[rhs_ptr], #16\n"
RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
// Each iteration of this loop advances by 16 levels of depth.
"add r1, r1, #16\n"
// Loop termination condition
"cmp r1, r10\n"
"blt 2b\n"
"79:\n"
// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q15, d2, d8\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q15, d3, d9\n"
// Then pairwise accumulate in to q6, q7
"vpadal.s16 q6, q14\n"
"vpadal.s16 q7, q15\n"
// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d4, d8\n"
"vmull.s8 q15, d6, d8\n"
"vmlal.s8 q14, d5, d9\n"
"vmlal.s8 q15, d7, d9\n"
// Then pairwise accumulate in to q8, q9
"vpadal.s16 q8, q14\n"
"vpadal.s16 q9, q15\n"
// All accumulation over depth done. q6 - q9 contain the 4x32b
// accumulators for the 4x1 final matrix.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 4x1 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// q6-q9 now contain 4 x 32b
"vpadd.i32 d0, d12, d13\n"
"vpadd.i32 d1, d14, d15\n"
"vpadd.i32 d2, d16, d17\n"
"vpadd.i32 d3, d18, d19\n"
// d0-d3 each contain 2 x 32b accumulators.
// Need to add pairwise to get 1 x 32b for each of the 4x1 entries
// of destination, (Four 'd' registers total)
"vpadd.i32 d28, d0, d1\n"
"vpadd.i32 d29, d2, d3\n"
// Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries.
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"cmp r1, r3\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"add r4, r4, r1, lsl #2\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"b 5f\n"
"4:\n" // Finished last row...
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
// Go back to first row
"str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"cmp r8, r4\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
"add r10, r10, r1, lsl #1\n"
"str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
"mov %[lhs_ptr], r4\n"
"ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
"mov %[rhs_ptr], r5\n"
// Now we load: bias data, LHS sums data, RHS sums data.
// First, load the base pointers from the params.
"ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Offset these base pointers as needed given the current row, col.
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"it ne\n"
"addne r1, r1, r8, lsl #2\n"
// Load 4 bias values.
"vld1.32 {d24, d25}, [r1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 64 bytes of
// LHS into d0 -- d7 and the first 16 bytes of RHS into d8 -- d9, as we
// don't need those registers anymore in the rest of the work on the
// current block.
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
"vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
"vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
"vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
// Skip the other column and advance the pointer.
"add %[rhs_ptr], %[rhs_ptr], #16\n"
RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
// Add to the bias values the product
// (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in
// https://arxiv.org/pdf/1712.05877.pdf
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"vdup.32 q9, r3\n"
"vadd.i32 q12, q12, q9\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
"vadd.i32 q14, q14, q12\n"
// LHS/RHS zero points
// Has RHS sums
"ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
// Offset by current col * number of bytes per value
"add r3, r3, r4, lsl #2\n"
"vld1.32 { d12 }, [r3]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"vdup.32 q10, r5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"vmls.i32 q14, q10, d12[0]\n"
"401:\n"
// Has LHS sums
"ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
// Offset by current row * number of bytes per value
"add r2, r2, r4, lsl #2\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"vld1.32 {d22, d23}, [r2]\n"
"vdup.32 d13, r5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"vmul.i32 q11, q11, d13[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"vsub.s32 q14, q14, q11\n"
// If the destination is int32, it means the user asks for the raw
// accumulators, no need for us to downquantize the value.
"ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
"cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
// Load the exponent part of the multiplier.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"it ne\n"
"addne r1, r1, r4, lsl #2\n"
"vld1.32 {q10}, [r1]\n"
"vmvn.i32 q8, #0\n"
"vmin.s32 q13, q10, q8\n"
"vsub.s32 q12, q10, q13\n"
// Apply the positive exponent part of the multiplier.
"vshl.s32 q14, q14, q12\n"
// Load fixed point part of the multiplier
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
// r6 has flags, r4 has row
"tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"it ne\n"
"addne r1, r1, r4, lsl #2\n"
"vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
// Apply the fixed-point part of the multiplier.
"vqdmulh.s32 q14, q14, q10\n"
// Apply the negative exponent part of the multiplier.
"vrshl.s32 q14, q14, q13\n"
"ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
"cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
// Store uint8 values:
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
// After this, all values for output are in d28.
"vqmovn.s32 d28, q14\n"
// At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
// current block, so we can start clearing these accumulators for the
// next block (next iteration of the main loop).
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q15)
// Load the destination zero point into each of the 8 16-bit slots
// in a q register.
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
"vadd.i16 q14, q14, q13\n"
// Cast-and-saturate from int16 to uint8
"vqmovun.s16 d30, q14\n"
// At this point, we only need 4 8-bit values in the lower half
// of d30.
// Load the clamp_min, clamp_max bounds
"ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"vdup.8 d28, r2\n" // clamp_min
"vdup.8 d29, r3\n" // clamp_max
// Apply the clamp_min bound
"vmax.u8 d30, d30, d28\n"
// Apply the clamp_max bound
"vmin.u8 d30, d30, d29\n"
// Compute how much of the 4x1 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x1 block fit
"it gt\n"
"movgt r1, r3\n"
// Test if r1==4, i.e. if all of the 4x1 block fits.
"cmp r1, r3\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// Yes, all of the 4x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x1 block fits.
// Store to dst_tmp_buf
// Set r3 address to write to dst_tmp_buf.
"mov r3, %[dst_tmp_buf]\n"
"vst1.8 {d30}, [r3]\n"
// Slow loop copying from dst_tmp_buf to dst.
"50:\n"
"mov r8, #0\n"
"51:\n"
"ldrb r10, [r3, r8]\n"
"strb r10, [r4, r8]\n"
"add r8, r8, #1\n"
"cmp r8, r1\n"
"blt 51b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x1 block fits.
// r3 = destination address, r6 = post-increment of 1 byte per value.
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r3\n"
"mov r6, #1\n"
"vst1.8 {d30[0]}, [r3], r6\n"
"vst1.8 {d30[1]}, [r3], r6\n"
"vst1.8 {d30[2]}, [r3], r6\n"
"vst1.8 {d30[3]}, [r3], r6\n"
"31:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #4\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q14)
RUY_MAKE_ZERO(q15)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
// Store int8 values:
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
// After this, all values for output are in d28.
"vqmovn.s32 d28, q14\n"
// At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
// current block, so we can start clearing these accumulators for the
// next block (next iteration of the main loop).
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q15)
// Load the destination zero point into each of the 8 16-bit slots
// in a q register.
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"vdup.16 q13, r4\n" // dst_zero_point
// Add the destination zero point
"vadd.i16 q14, q14, q13\n"
// Cast-and-saturate from int16 to int8
"vqmovn.s16 d30, q14\n"
// At this point, we only need 4 8-bit values in the lower half
// of d30.
// Load the clamp_min, clamp_max bounds
"ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"vdup.8 d28, r2\n" // clamp_min
"vdup.8 d29, r3\n" // clamp_max
// Apply the clamp_min bound
"vmax.s8 d30, d30, d28\n"
// Apply the clamp_max bound
"vmin.s8 d30, d30, d29\n"
// Compute how much of the 4x1 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x1 block fit
"it gt\n"
"movgt r1, r3\n"
// Test if r1==4 i.e. if all of the 4x1 block fits.
"cmp r1, r3\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// Yes, all of the 4x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x1 block fits.
// Store to dst_tmp_buf
// Set r3 address to write to dst_tmp_buf.
"mov r3, %[dst_tmp_buf]\n"
"vst1.8 {d30}, [r3]\n"
// Slow loop copying from dst_tmp_buf to dst.
"50:\n"
"mov r8, #0\n"
"51:\n"
"ldrb r10, [r3, r8]\n"
"strb r10, [r4, r8]\n"
"add r8, r8, #1\n"
"cmp r8, r1\n"
"blt 51b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x1 block fits.
// r3 = destination address, r6 = post-increment of 1 byte per value.
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r3\n"
"mov r6, #1\n"
"vst1.8 {d30[0]}, [r3], r6\n"
"vst1.8 {d30[1]}, [r3], r6\n"
"vst1.8 {d30[2]}, [r3], r6\n"
"vst1.8 {d30[3]}, [r3], r6\n"
"31:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #4\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q14)
RUY_MAKE_ZERO(q15)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Load the destination zero point into each of the 4 32-bit slots
// in a q register.
"ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"vdup.32 q13, r4\n" // dst_zero_point
// Add the destination zero point
"vadd.s32 q14, q14, q13\n"
//"vadd.s32 q15, q15, q13\n"
// Cast-and-saturate from int32 to int16
// After this, all values for output are in d28.
"vqmovn.s32 d28, q14\n"
// At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
// current block, so we can start clearing these accumulators for the
// next block (next iteration of the main loop).
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q15)
// Load the clamp_min, clamp_max bounds
"ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"vdup.16 d24, r2\n" // clamp_min
"vdup.16 d26, r3\n" // clamp_max
// Apply the clamp_min bound
"vmax.s16 d28, d28, d24\n"
// Apply the clamp_max bound
"vmin.s16 d28, d28, d26\n"
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
// Compute how much of the 4x1 block of destination 16-bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x1 block fit
"it gt\n"
"movgt r1, r3\n"
// Test if r1==4, i.e. if all of the 4x1 block fits.
"cmp r1, r3\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// Yes, all of the 4x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x1 block fits.
// Store to dst_tmp_buf
// Set r3 address to write to dst_tmp_buf.
"mov r3, %[dst_tmp_buf]\n"
"vst1.16 {d28}, [r3]\n"
// Slow loop copying from dst_tmp_buf to dst.
"50:\n"
"mov r8, #0\n"
"51:\n"
// Shift of offset register for half-word loads not allowed in A32,
// so we shift, load/store, then shift back r8.
"lsl r8, r8, #1\n"
"ldrh r10, [r3, r8]\n"
"strh r10, [r4, r8]\n"
"lsr r8, r8, #1\n"
"add r8, r8, #1\n"
"cmp r8, r1\n"
"blt 51b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x1 block fits.
// r3 = destination address, r6 = post-increment of 2 bytes per value.
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r3\n"
"mov r6, #2\n"
"vst1.16 {d28[0]}, [r3], r6\n"
"vst1.16 {d28[1]}, [r3], r6\n"
"vst1.16 {d28[2]}, [r3], r6\n"
"vst1.16 {d28[3]}, [r3], r6\n"
"31:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #8\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q14)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// At this point, q6 -- q13 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
// Clear accumulators.
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
// Compute how much of the 4x1 block of destination 32 bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"sub r1, r1, r8\n"
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"sub r2, r2, r4\n"
"mov r3, #4\n"
"mov r5, #2\n"
"cmp r1, #4\n"
// Compute r1 = how many rows of the 4x1 block fit
"it gt\n"
"movgt r1, r3\n"
// Test if r1==4, i.e. if all of the 4x1 block fits.
"cmp r1, r3\n"
// Yes, all of the 4x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x1 block fits.
// Set (r3 address, r4 stride) to write to dst_tmp_buf
"mov r3, %[dst_tmp_buf]\n"
"mov r4, #16\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x1 block fits.
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
// r3 address, r4 stride
"ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"mov r4, r5\n"
"31:\n"
"vst1.32 {d28, d29}, [r3]\n"
// If all of the 4x1 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 4x1 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"mov r3, %[dst_tmp_buf]\n"
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"50:\n"
"mov r5, #0\n"
"51:\n"
"ldr r10, [r3, r5, lsl #2]\n"
"str r10, [r4, r5, lsl #2]\n"
"add r5, r5, #1\n"
"cmp r5, r1\n"
"blt 51b\n"
"41:\n"
// Load dst_ptr, increment, and write back.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"add r4, r4, #16\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// Reload some params --- we had used r3, r5 and r6 for a few other things
// since the last time we had loaded them.
"ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"cmp r8, r3\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add r8, r8, #4\n"
// Store new value of row
"str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
"b 21f\n"
"20:\n"
// Was already at end row.
// Move back to first row.
"str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
// Move to the next column.
"ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"add r4, r4, #2\n"
"str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
// Increment dst_col_ptr by dst_stride (i.e. 1 column)
"add r1, r1, r8\n"
// Store dst_col_ptr
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
// Store dst_ptr
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
"cmp r8, r4\n"
// r1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial load instructions
// above, this is currently 16.
"mov r1, #16\n"
"ble 1b\n"
// Restore stack pointer.
"add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
// clang-format on
: [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
: [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
: "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
// Clobber list must specify q registers (and not their constituent
// d registers). There is a (currently unexplained) slowdown if
// d registers are listed in the clobbers list.
"memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q12", "q13", "q14", "q15");
}
// Retire the helper macros used by the inline assembly above so that they
// do not leak past this point. RUY_OFFSET_* are byte offsets added to the
// params pointer; RUY_STACK_OFFSET_* are byte offsets added to sp.

// Matrix base pointers.
#undef RUY_OFFSET_LHS_BASE_PTR
#undef RUY_OFFSET_RHS_BASE_PTR
#undef RUY_OFFSET_DST_BASE_PTR

// Strides and depth.
#undef RUY_OFFSET_LHS_STRIDE
#undef RUY_OFFSET_RHS_STRIDE
#undef RUY_OFFSET_DST_STRIDE
#undef RUY_OFFSET_DEPTH

// Row/column range and destination shape.
#undef RUY_OFFSET_START_ROW
#undef RUY_OFFSET_START_COL
#undef RUY_OFFSET_LAST_ROW
#undef RUY_OFFSET_LAST_COL
#undef RUY_OFFSET_DST_ROWS
#undef RUY_OFFSET_DST_COLS

// Bias, sums, zero points and multiplier.
#undef RUY_OFFSET_BIAS
#undef RUY_OFFSET_LHS_SUMS
#undef RUY_OFFSET_RHS_SUMS
#undef RUY_OFFSET_LHS_ZERO_POINT
#undef RUY_OFFSET_RHS_ZERO_POINT
#undef RUY_OFFSET_DST_ZERO_POINT
#undef RUY_OFFSET_PROD_ZP_DEPTH
#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
#undef RUY_OFFSET_MULTIPLIER_EXPONENT

// Clamp bounds, flags and destination type id.
#undef RUY_OFFSET_CLAMP_MIN
#undef RUY_OFFSET_CLAMP_MAX
#undef RUY_OFFSET_FLAGS
#undef RUY_OFFSET_DST_TYPE_ID

// Stack frame layout.
#undef RUY_STACK_OFFSET_SIZE
#undef RUY_STACK_OFFSET_DST_COL_PTR
#undef RUY_STACK_OFFSET_DST_PTR
#undef RUY_STACK_OFFSET_ROW
#undef RUY_STACK_OFFSET_COL
#undef RUY_STACK_OFFSET_LHS_COL_PTR
#undef RUY_STACK_OFFSET_RHS_COL_PTR
#endif // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
} // namespace ruy