blob: 5424107908b41ab1205299a6d34b18fb32a29b40 [file] [log] [blame]
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>
#include "ruy/asm_helpers.h"
#include "ruy/check_macros.h"
#include "ruy/kernel_arm.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
namespace ruy {
#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
// Numeric local labels used by the inline asm below. Once a destination block
// has been computed, the asm either branches to the INT8 / INT16 / INT32 store
// path matching the destination type or falls through into the UINT8 one;
// every store path then continues at AFTER_STORE.
#define RUY_ASM_LABEL_STORE_UINT8 91
#define RUY_ASM_LABEL_STORE_INT8 92
#define RUY_ASM_LABEL_STORE_INT16 93
#define RUY_ASM_LABEL_STORE_INT32 94
#define RUY_ASM_LABEL_AFTER_STORE 99
// Byte offsets of fields within the kernel params struct. The asm pastes these
// constants (via RUY_STR) into load instructions that read fields relative to
// the params pointer, so they must match the struct's real layout;
// CheckOffsetsInKernelParams8bit below static_asserts most of them against
// offsetof().
// Offsets 0 .. 56 are spaced 8 bytes apart.
#define RUY_OFFSET_BIAS 0
#define RUY_OFFSET_LHS_SUMS 8
#define RUY_OFFSET_RHS_SUMS 16
#define RUY_OFFSET_LHS_BASE_PTR 24
#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
#define RUY_OFFSET_MULTIPLIER_EXPONENT 40
#define RUY_OFFSET_RHS_BASE_PTR 48
#define RUY_OFFSET_DST_BASE_PTR 56
// Offsets 64 .. 124 are spaced 4 bytes apart.
#define RUY_OFFSET_LHS_ZERO_POINT 64
#define RUY_OFFSET_RHS_ZERO_POINT 68
#define RUY_OFFSET_DST_ZERO_POINT 72
#define RUY_OFFSET_PROD_ZP_DEPTH 76
#define RUY_OFFSET_START_ROW 80
#define RUY_OFFSET_START_COL 84
#define RUY_OFFSET_LAST_ROW 88
#define RUY_OFFSET_LAST_COL 92
#define RUY_OFFSET_DST_ROWS 96
#define RUY_OFFSET_DST_COLS 100
#define RUY_OFFSET_LHS_STRIDE 104
#define RUY_OFFSET_RHS_STRIDE 108
#define RUY_OFFSET_DST_STRIDE 112
#define RUY_OFFSET_DEPTH 116
#define RUY_OFFSET_CLAMP_MIN 120
#define RUY_OFFSET_CLAMP_MAX 124
// The asm reads the flags with a single-byte load (ldrb).
#define RUY_OFFSET_FLAGS 128
template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) {
static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
"");
static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
"");
static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
"");
static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
"");
static_assert(offsetof(Params, multiplier_fixedpoint) ==
RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
"");
static_assert(
offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
"");
static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
}
// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
//
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
profiler::ScopeLabel label("Kernel (kNeon)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// v4 -- v7 from RHS:
//
// int8 RHS 16x4 block
// /-----------------------------------------|
// |v4.b[0] ... v7.b[0] |
// | ... ... |
// |v4.b[15] ... v7.b[15] |
// \-----------------------------------------/
// int8 LHS 4x16 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
// |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
// |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
// |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
// \---------------------/ \-----------------------------------------/
// int32 accumulators 4x4 block
//
// No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
// optimization for this kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 64 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smull v12.8h, v0.8b, v5.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Some multiplications and 16-bit accumulation were already done above,
// so we start right away in the middle.
"sadalp v16.4s, v8.8h\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"sadalp v24.4s, v8.8h\n"
"smull v8.8h, v0.8b, v4.8b\n"
"sadalp v25.4s, v9.8h\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
"smull v9.8h, v1.8b, v4.8b\n"
"sadalp v26.4s, v10.8h\n"
"smull v10.8h, v2.8b, v4.8b\n"
"sadalp v27.4s, v11.8h\n"
"smull v11.8h, v3.8b, v4.8b\n"
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v5.8b\n"
"sadalp v29.4s, v13.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v30.4s, v14.8h\n"
"smull v14.8h, v2.8b, v5.8b\n"
"sadalp v31.4s, v15.8h\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Each iteration of this loop advances by 16 levels of depth.
"add w1, w1, #16\n"
// Loop termination condition
"cmp w1, w12\n"
"blt 2b\n"
"79:\n"
"sadalp v16.4s, v8.8h\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v24.4s, v8.8h\n"
"sadalp v25.4s, v9.8h\n"
"sadalp v26.4s, v10.8h\n"
"sadalp v27.4s, v11.8h\n"
"sadalp v28.4s, v12.8h\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 4x4 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 4x4 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Reduce 32bit accumulators horizontally.
"addp v16.4s, v16.4s, v17.4s\n"
"addp v18.4s, v18.4s, v19.4s\n"
"addp v20.4s, v20.4s, v21.4s\n"
"addp v22.4s, v22.4s, v23.4s\n"
"addp v24.4s, v24.4s, v25.4s\n"
"addp v26.4s, v26.4s, v27.4s\n"
"addp v28.4s, v28.4s, v29.4s\n"
"addp v30.4s, v30.4s, v31.4s\n"
// Reduce 32bit accumulators horizontally, second pass
// (each pass adds pairwise. we need to add 4-wise).
"addp v16.4s, v16.4s, v18.4s\n"
"addp v17.4s, v20.4s, v22.4s\n"
"addp v18.4s, v24.4s, v26.4s\n"
"addp v19.4s, v28.4s, v30.4s\n"
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"mvni v8.4s, #0\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
// Now we load: bias data, LHS sums data, RHS sums data.
// First, load the base pointers from the params.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 4 bias values.
"ld1 {v14.4s}, [x1]\n"
// Load the multiplier_fixedpoint values.
"add x5, x4, x3, lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"csel x4, x4, x5, eq\n"
"ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 64 bytes of
// each of LHS and RHS, into v0 -- v3 and v4 -- v7 respectively, as we
// don't need v0 -- v7 anymore in the rest of the work on the current block.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v14.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v14.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
"dup v20.4s, v14.s[0]\n"
"dup v21.4s, v14.s[1]\n"
"dup v22.4s, v14.s[2]\n"
"dup v23.4s, v14.s[3]\n"
"add v16.4s, v16.4s, v20.4s\n"
"add v17.4s, v17.4s, v21.4s\n"
"add v18.4s, v18.4s, v22.4s\n"
"add v19.4s, v19.4s, v23.4s\n"
"7:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[1]\n"
"mls v18.4s, v10.4s, v14.s[2]\n"
"mls v19.4s, v10.4s, v14.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"ld1 {v11.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v11.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v11.4s\n"
// If the destination is int32, it means the user asks for the raw
// accumulators, no need for us to downquantize the value.
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
// Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
"ld1 {v14.4s}, [x1]\n"
"smin v11.4s, v8.4s, v14.4s\n"
"sub v12.4s, v14.4s, v11.4s\n"
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 8f\n"
// Case where channels are rows
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v12.4s\n"
"sshl v17.4s, v17.4s, v12.4s\n"
"sshl v18.4s, v18.4s, v12.4s\n"
"sshl v19.4s, v19.4s, v12.4s\n"
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v15.4s\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"
"sqdmulh v18.4s, v18.4s, v15.4s\n"
"sqdmulh v19.4s, v19.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v11.4s\n"
"srshl v18.4s, v18.4s, v11.4s\n"
"srshl v19.4s, v19.4s, v11.4s\n"
"b 9f\n"
"8:\n"
// Case where channels are columns
// Apply the positive exponent part of the multiplier.
"dup v20.4s, v12.s[0]\n"
"dup v21.4s, v12.s[1]\n"
"dup v22.4s, v12.s[2]\n"
"dup v23.4s, v12.s[3]\n"
"sshl v16.4s, v16.4s, v20.4s\n"
"sshl v17.4s, v17.4s, v21.4s\n"
"sshl v18.4s, v18.4s, v22.4s\n"
"sshl v19.4s, v19.4s, v23.4s\n"
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v15.s[0]\n"
"sqdmulh v17.4s, v17.4s, v15.s[1]\n"
"sqdmulh v18.4s, v18.4s, v15.s[2]\n"
"sqdmulh v19.4s, v19.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"dup v20.4s, v11.s[0]\n"
"dup v21.4s, v11.s[1]\n"
"dup v22.4s, v11.s[2]\n"
"dup v23.4s, v11.s[3]\n"
"srshl v16.4s, v16.4s, v20.4s\n"
"srshl v17.4s, v17.4s, v21.4s\n"
"srshl v18.4s, v18.4s, v22.4s\n"
"srshl v19.4s, v19.4s, v23.4s\n"
"9:\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
"sqxtun2 v16.16b, v17.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
"sqxtn2 v16.16b, v17.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.4h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Load the clamp_min, clamp_max bounds
"ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
// Compute how much of the 4x4 block of destination 16bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[1], [x3], #2\n"
"st1 {v16.h}[2], [x3], #2\n"
"st1 {v16.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[5], [x3], #2\n"
"st1 {v16.h}[6], [x3], #2\n"
"st1 {v16.h}[7], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[1], [x3], #2\n"
"st1 {v17.h}[2], [x3], #2\n"
"st1 {v17.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[5], [x3], #2\n"
"st1 {v17.h}[6], [x3], #2\n"
"st1 {v17.h}[7], [x3], #2\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// At this point, v20 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Compute how much of the 4x4 block of destination 32bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
"str q18, [%[dst_tmp_buf], #32]\n"
"str q19, [%[dst_tmp_buf], #48]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v16.s}[1], [x3], #4\n"
"st1 {v16.s}[2], [x3], #4\n"
"st1 {v16.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v17.s}[1], [x3], #4\n"
"st1 {v17.s}[2], [x3], #4\n"
"st1 {v17.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v18.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v18.s}[1], [x3], #4\n"
"st1 {v18.s}[2], [x3], #4\n"
"st1 {v18.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v19.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v19.s}[1], [x3], #4\n"
"st1 {v19.s}[2], [x3], #4\n"
"st1 {v19.s}[3], [x3], #4\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smull v12.8h, v0.8b, v5.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #4\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #4\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #16\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// Similar to existing Kernel8bitNeon but specialized for the case of
// RHS cols == 1: each iteration of the main loop below produces one 4x1
// block of the destination instead of a 4x4 block.
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
//
// All inputs are taken from `params`. Inside the asm, fields are read through
// the RUY_OFFSET_* byte offsets (checked by CheckOffsetsInKernelParams8bit).
// Results are written through params.dst_base_ptr; the element type of the
// destination (uint8 / int8 / int16 / int32) is selected at runtime by
// params.dst_type_id.
void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
  profiler::ScopeLabel label("Kernel (kNeon)");
  CheckOffsetsInKernelParams8bit(params);
  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  const std::int8_t* lhs_ptr = lhs_col_ptr;
  const std::int8_t* rhs_ptr = rhs_col_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  void* dst_ptr = dst_col_ptr;
  int row = params.start_row;
  int col = params.start_col;
  // The asm below indexes the bias and multiplier arrays by row only, so the
  // channel dimension must not be the column dimension.
  RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
  // The asm kernel below has the following NEON register allocation:
  //
  // v16 -- v19 are int32 accumulators.
  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
  // v4 from RHS:
  //
  //                         int8 RHS 16x1 block
  //                           /-----------|
  //                           |v4.b[0]    |
  //                           |  ...      |
  //                           |v4.b[15]   |
  //                           \-----------/
  //    int8 LHS 4x16 block
  //  /---------------------\  /-----------|
  //  |v0.b[0] ... v0.b[15] |  |v16.4s     |
  //  |v1.b[0] ... v1.b[15] |  |v17.4s     |
  //  |v2.b[0] ... v2.b[15] |  |v18.4s     |
  //  |v3.b[0] ... v3.b[15] |  |v19.4s     |
  //  \---------------------/  \-----------/
  //                         int32 accumulators 4x1 block
  //
  // General-purpose registers holding parameters across the main loop:
  //   x5 = lhs_base_ptr, w6 = start_row, w7 = last_row, w8 = last_col,
  //   w9 = lhs_stride, w10 = rhs_stride, w11 = dst_stride, w12 = depth,
  //   w1 = number of levels of depth already loaded.
  // x1 -- x7 are also used as scratch during the end-of-block work; x5 -- x7
  // are reloaded from params before the next iteration.
  //
  // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
  // optimization for this kernel.
  asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
      // clang-format off
      // Load some parameters into registers.
      "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
      "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
      "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
      "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
      "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
      "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
      "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
      "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
      // Load the first 64 bytes of LHS data and the first 16 bytes of RHS
      // data. Only one 16-byte RHS column is consumed per 64 bytes, so the
      // RHS pointer then skips the remaining 48 bytes (presumably the other
      // three columns of a 4-column packed block -- TODO confirm against the
      // packing code).
      "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
      "add %[rhs_ptr], %[rhs_ptr], #48\n"
      // Clear accumulators.
      RUY_MAKE_ZERO(v16)
      RUY_MAKE_ZERO(v17)
      RUY_MAKE_ZERO(v18)
      RUY_MAKE_ZERO(v19)
      // w1 is the number of levels of depth that we have already loaded
      // LHS and RHS data for. Corresponding to the initial ld1 instructions
      // above, this is currently 16.
      "mov w1, #16\n"
      // Perform the first few multiply-adds on the data that we have already
      // loaded.
      "smull v8.8h, v0.8b, v4.8b\n"
      "smull v9.8h, v1.8b, v4.8b\n"
      "smull v10.8h, v2.8b, v4.8b\n"
      "smull v11.8h, v3.8b, v4.8b\n"
      // Multiply-accumulate second-half, again into the same
      // 16bit local accumulator registers. This is where we
      // take advantage of having int8 instead of uint8 and therefore
      // being able to accumulate two products into int16.
      "smlal2 v8.8h, v0.16b, v4.16b\n"
      "smlal2 v9.8h, v1.16b, v4.16b\n"
      "smlal2 v10.8h, v2.16b, v4.16b\n"
      "smlal2 v11.8h, v3.16b, v4.16b\n"
      // Main loop of the whole GEMM, over rows and columns of the
      // destination matrix.
      "1:\n"
      // Reminder - w1 is how many levels of depth we have already loaded
      // data for, w12 is the total depth.
      "cmp w1, w12\n"
      "beq 79f\n"
      "2:\n"
      // Some multiplications and 16-bit accumulation were already done above,
      // so we start right away in the middle.
      "sadalp v16.4s, v8.8h\n"
      "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
      "add %[rhs_ptr], %[rhs_ptr], #48\n"
      "sadalp v17.4s, v9.8h\n"
      "sadalp v18.4s, v10.8h\n"
      "sadalp v19.4s, v11.8h\n"
      "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
      "smull v8.8h, v0.8b, v4.8b\n"
      "smull v9.8h, v1.8b, v4.8b\n"
      "smull v10.8h, v2.8b, v4.8b\n"
      "smull v11.8h, v3.8b, v4.8b\n"
      // Multiply-accumulate second-half, again into the same
      // 16bit local accumulator registers. This is where we
      // take advantage of having int8 instead of uint8 and therefore
      // being able to accumulate two products into int16.
      "smlal2 v8.8h, v0.16b, v4.16b\n"
      "smlal2 v9.8h, v1.16b, v4.16b\n"
      "smlal2 v10.8h, v2.16b, v4.16b\n"
      "smlal2 v11.8h, v3.16b, v4.16b\n"
      // Each iteration of this loop advances by 16 levels of depth.
      "add w1, w1, #16\n"
      // Loop termination condition
      "cmp w1, w12\n"
      "blt 2b\n"
      "79:\n"
      // Fold the last pending 16-bit products into the int32 accumulators.
      "sadalp v16.4s, v8.8h\n"
      "sadalp v17.4s, v9.8h\n"
      "sadalp v18.4s, v10.8h\n"
      "sadalp v19.4s, v11.8h\n"
      // End of accumulation. The registers v16 -- v19 contain the final
      // int32 accumulator values of the current 4x1 destination block.
      // We now have to compute the final 8-bit values from these int32
      // accumulators, and advance to the next 4x1 block. We intertwine
      // these two aspects whenever possible for optimal pipelining, both
      // at the data flow level (prefetch data for next block as early as
      // possible) and instruction pipelining level (some of the next-block
      // work can dual-issue with some of the final work on the current
      // block).
      // Reduce 32bit accumulators horizontally.
      "addp v16.4s, v16.4s, v17.4s\n"
      "addp v18.4s, v18.4s, v19.4s\n"
      // Reduce 32bit accumulators horizontally, second pass
      // (each pass adds pairwise. we need to add 4-wise).
      // After this, lane i of v16 holds the accumulator for row i of the
      // current 4x1 block.
      "addp v16.4s, v16.4s, v18.4s\n"
      // Logic to advance to the next block in preparation for the next
      // iteration of the main loop. For now, we only want to compute
      // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
      // not yet ready to update the values of row and col, as we still need
      // the current values for the rest of the work on the current block.
      "cmp %w[row], w7\n"  // Have we finished the last row?
      "bge 4f\n"           // If finished last row, go to 4
      // Not finished last row: then advance to next row.
      "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
      "b 5f\n"
      "4:\n"  // Finished last row...
      "mov %[lhs_col_ptr], x5\n"  // Go back to first row
      // Now we need to advance to the next column. If we already
      // finished the last column, then in principle we are done, however
      // we can't just return here, as we need to allow the end work of the
      // current block to complete. The good news is that at this point it
      // doesn't matter what data we load for the next column, since
      // we will exit from the main loop below before actually storing
      // anything computed from that data.
      "cmp %w[col], w8\n"  // Have we finished the last column?
      "bge 5f\n" // If yes, just carry on without updating the column pointer.
      // Not finished last column: then advance to next column.
      // (still multiply column stride by 4 due to packing)
      "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
      "5:\n"
      // Set the LHS and RHS data pointers to the start of the columns just
      // computed.
      "mov %[lhs_ptr], %[lhs_col_ptr]\n"
      "mov %[rhs_ptr], %[rhs_col_ptr]\n"
      // Load some parameters needed for the end work on current block.
      // v8 = all-ones (-1 in every int32 lane); used below to split the
      // multiplier exponent into its positive and negative parts.
      "mvni v8.4s, #0\n"
      "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
      "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
      "ins v13.h[4], w4\n" // dst_zero_point
      "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
      "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
      "dup v9.4s, w3\n"   // create prod_zp_depth_vec
      // Offset the multiplier pointer by the current row only if the
      // HAS_PERCHANNEL flag is set; otherwise keep the base pointer.
      "add x5, x4, %x[row], lsl #2\n"
      "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
      "csel x4, x4, x5, eq\n"
      "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
      // Now we load: bias data, LHS sums data, RHS sums data.
      // First, load the base pointers from the params.
      "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
      // Offset the bias pointer by the current row only if the HAS_BIAS flag
      // is set; otherwise keep the base pointer.
      "add x5, x1, %x[row], lsl #2\n"
      "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
      "csel x1, x1, x5, eq\n"
      // Load 4 bias values.
      "ld1 {v14.4s}, [x1]\n"
      // Now that we know what LHS and RHS data the next iteration of the
      // main loop will need to load, we start loading the first 64 bytes of
      // LHS into v0 -- v3 and the first 16 bytes of RHS into v4, as we don't
      // need v0 -- v4 anymore in the rest of the work on the current block.
      "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
      "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
      "add %[rhs_ptr], %[rhs_ptr], #48\n"
      // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
      // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
      "add v14.4s, v14.4s, v9.4s\n"
      // Perform the bias-addition (per the above, we have just folded into
      // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
      // (all four 32-bit accumulators are in v16 at this point)
      "add v16.4s, v16.4s, v14.4s\n"
      "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
      "beq 401f\n"
      "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
      "add x3, x3, %x[col], lsl #2\n"
      "ld1 {v14.4s}, [x3]\n"
      "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
      "dup v10.4s, w5\n"  // create lhs_zero_point_vec
      // Subtract rhs_sums * lhs_zero_point, per
      // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
      // Only lane 0 of the loaded rhs_sums is used: there is a single
      // destination column per block.
      "mls v16.4s, v10.4s, v14.s[0]\n"
      "401:\n"
      "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
      "beq 402f\n"
      "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
      "add x2, x2, %x[row], lsl #2\n"
      "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
      // Load 4 lhs_sums values.
      "ld1 {v11.4s}, [x2]\n"
      "ins v13.s[1], w5\n" // rhs_zero_point
      // Compute lhs_sums * rhs_zero_point.
      "mul v11.4s, v11.4s, v13.s[1]\n"
      // Subtract lhs_sums * rhs_zero_point, per
      // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
      "sub v16.4s, v16.4s, v11.4s\n"
      // If the destination is int32, it means the user asks for the raw
      // accumulators, no need for us to downquantize the value.
      // NOTE(review): this check is only reached when HAS_LHS_SUMS is set,
      // because the "beq 402f" above jumps past it. Presumably the flag is
      // always set when the destination is int32 -- confirm against the code
      // that builds the params.
      "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
      "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
      "402:\n"
      // At this point we have computed the final int32 values. Now we
      // start down-quantizing them to obtain the final 8bit values from them.
      // As part of this down-quantization, our int32 values will be
      // multiplied by a multiplier that has a fixed-point component and an
      // exponent component.
      // Load the exponent part of the multiplier.
      "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
      "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
      "add x5, x1, %x[row], lsl #2\n"
      "csel x1, x1, x5, eq\n"
      "ld1 {v14.4s}, [x1]\n"
      // v11 = min(exponent, -1) is applied as a rounding right shift after
      // the fixed-point multiplication; v12 = exponent - v11 is applied as a
      // shift before it.
      "smin v11.4s, v8.4s, v14.4s\n"
      "sub v12.4s, v14.4s, v11.4s\n"
      // Apply the positive exponent part of the multiplier.
      "sshl v16.4s, v16.4s, v12.4s\n"
      // Apply the fixed-point part of the multiplier.
      "sqdmulh v16.4s, v16.4s, v15.4s\n"
      // Apply the negative exponent part of the multiplier.
      "srshl v16.4s, v16.4s, v11.4s\n"
      // Dispatch on the destination type; uint8 is the fall-through case.
      "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
      "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
      "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
      "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
      RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
      // Cast-and-saturate from int32 to int16
      // After this instruction, all data is in lower half (64-bits) of v16
      "sqxtn v16.4h, v16.4s\n"
      // At this point, v18 and v19 aren't used anymore for the current block,
      // so we can start clearing these accumulators for the next block
      // (next iteration of the main loop).
      RUY_MAKE_ZERO(v18)
      RUY_MAKE_ZERO(v19)
      // Add the destination zero point
      "dup v14.8h, v13.h[4]\n"
      "sqadd v16.8h, v16.8h, v14.8h\n"
      // Cast-and-saturate from int16 to uint8
      // Now all data is in the first 32-bits of v16
      "sqxtun v16.8b, v16.8h\n"
      // Load the clamp_min, clamp_max bounds
      "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
      "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
      "dup v14.16b, w2\n"  // clamp_min
      "dup v15.16b, w3\n"  // clamp_max
      // Apply the clamp_min bound
      "umax v16.16b, v16.16b, v14.16b\n"
      // Apply the clamp_max bound
      "umin v16.16b, v16.16b, v15.16b\n"
      // Compute how much of the 4x1 block of destination 8bit values that
      // we have computed, fit in the destination matrix. Typically, all of
      // it fits, but when the destination matrix shape is not a multiple
      // of 4x1, there are some 4x1 blocks along the boundaries that do
      // not fit entirely.
      "sub w1, %w[dst_rows], %w[row]\n"
      "mov w3, #4\n"
      "cmp w1, #4\n"
      // Compute w1 = how many rows of the 4x1 block fit
      "csel w1, w1, w3, le\n"
      // Test if w1==4, i.e. if all of the 4x1 block fits.
      "cmp w1, w3\n"
      "mov x4, %[dst_ptr]\n"
      // Yes, all of the 4x1 block fits, go to fast path.
      "beq 30f\n"
      // Not all of the 4x1 block fits.
      // Store to dst_tmp_buf
      "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
      // Slow loop copying from dst_tmp_buf to dst.
      // (w6 and label 50 are not used by this single-column copy loop.)
      "mov x3, %[dst_tmp_buf]\n"
      "mov w6, #0\n"
      "50:\n"
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov w5, #0\n"
      "51:\n"
      "ldrb w7, [x3, w5, uxtw]\n"
      "strb w7, [x4, w5, uxtw]\n"
      "add w5, w5, #1\n"
      "cmp w5, w1\n"
      "blt 51b\n"
      "b 31f\n"
      "30:\n"
      // Yes, all of the 4x1 block fits.
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov x3, x4\n"
      "st1 {v16.b}[0], [x3], #1\n"
      "st1 {v16.b}[1], [x3], #1\n"
      "st1 {v16.b}[2], [x3], #1\n"
      "st1 {v16.b}[3], [x3], #1\n"
      "31:\n"
      // Advance dst_ptr by the 4 one-byte values of this block.
      "add %[dst_ptr], %[dst_ptr], #4\n"
      RUY_MAKE_ZERO(v16)
      RUY_MAKE_ZERO(v17)
      "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
      RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
      // Cast-and-saturate from int32 to int16
      // After this, all values for output are in the lower half (64 bits) of v16.
      "sqxtn v16.4h, v16.4s\n"
      // At this point, v18 and v19 aren't used anymore for the current block,
      // so we can start clearing these accumulators for the next block
      // (next iteration of the main loop).
      RUY_MAKE_ZERO(v18)
      RUY_MAKE_ZERO(v19)
      // Add the destination zero point
      "dup v14.8h, v13.h[4]\n"
      "sqadd v16.8h, v16.8h, v14.8h\n"
      // Cast-and-saturate from int16 to int8
      "sqxtn v16.8b, v16.8h\n"
      // At this point, we only need 4 lowest 8-bit values in v16.
      // Load the clamp_min, clamp_max bounds
      "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
      "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
      "dup v14.16b, w2\n"  // clamp_min
      "dup v15.16b, w3\n"  // clamp_max
      // Apply the clamp_min bound
      "smax v16.16b, v16.16b, v14.16b\n"
      // Apply the clamp_max bound
      "smin v16.16b, v16.16b, v15.16b\n"
      // Compute how much of the 4x1 block of destination 8bit values that
      // we have computed, fit in the destination matrix. Typically, all of
      // it fits, but when the destination matrix shape is not a multiple
      // of 4x1, there are some 4x1 blocks along the boundaries that do
      // not fit entirely.
      // Only the row count matters here: both store paths below write a
      // single column, so there is no column condition (same as the uint8
      // path above).
      "sub w1, %w[dst_rows], %w[row]\n"
      "mov w3, #4\n"
      "cmp w1, #4\n"
      // Compute w1 = how many rows of the 4x1 block fit
      "csel w1, w1, w3, le\n"
      // Test if w1==4, i.e. if all of the 4x1 block fits.
      "cmp w1, w3\n"
      "mov x4, %[dst_ptr]\n"
      // Yes, all of the 4x1 block fits, go to fast path.
      "beq 30f\n"
      // Not all of the 4x1 block fits.
      // Store to dst_tmp_buf
      "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
      // Slow loop copying from dst_tmp_buf to dst.
      "mov x3, %[dst_tmp_buf]\n"
      "mov w6, #0\n"
      "50:\n"
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov w5, #0\n"
      "51:\n"
      "ldrb w7, [x3, w5, uxtw]\n"
      "strb w7, [x4, w5, uxtw]\n"
      "add w5, w5, #1\n"
      "cmp w5, w1\n"
      "blt 51b\n"
      "b 31f\n"
      "30:\n"
      // Yes, all of the 4x1 block fits.
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov x3, x4\n"
      "st1 {v16.b}[0], [x3], #1\n"
      "st1 {v16.b}[1], [x3], #1\n"
      "st1 {v16.b}[2], [x3], #1\n"
      "st1 {v16.b}[3], [x3], #1\n"
      "31:\n"
      // Advance dst_ptr by the 4 one-byte values of this block.
      "add %[dst_ptr], %[dst_ptr], #4\n"
      RUY_MAKE_ZERO(v16)
      RUY_MAKE_ZERO(v17)
      "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
      RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
      // Add the destination zero point
      "dup v14.4h, v13.h[4]\n"
      "saddw v16.4s, v16.4s, v14.4h\n"
      // Cast-and-saturate from int32 to int16
      // After this instruction, all data is in lower half of v16.
      "sqxtn v16.4h, v16.4s\n"
      // At this point, v18 and v19 aren't used anymore for the current block,
      // so we can start clearing these accumulators for the next block
      // (next iteration of the main loop).
      RUY_MAKE_ZERO(v18)
      RUY_MAKE_ZERO(v19)
      // Load the clamp_min, clamp_max bounds
      "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
      "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
      "dup v14.8h, w2\n"  // clamp_min
      "dup v15.8h, w3\n"  // clamp_max
      // Apply the clamp_min bound
      "smax v16.8h, v16.8h, v14.8h\n"
      // Apply the clamp_max bound
      "smin v16.8h, v16.8h, v15.8h\n"
      // Compute how much of the 4x1 block of destination 16bit values that
      // we have computed, fit in the destination matrix. Typically, all of
      // it fits, but when the destination matrix shape is not a multiple
      // of 4x1, there are some 4x1 blocks along the boundaries that do
      // not fit entirely.
      "sub w1, %w[dst_rows], %w[row]\n"
      "mov w3, #4\n"
      "cmp w1, #4\n"
      // Compute w1 = how many rows of the 4x1 block fit
      "csel w1, w1, w3, le\n"
      // Test if w1==4, i.e. if all of the 4x1 block fits.
      "cmp w1, w3\n"
      "mov x4, %[dst_ptr]\n"
      // Yes, all of the 4x1 block fits, go to fast path.
      "beq 30f\n"
      // Not all of the 4x1 block fits.
      // Store to dst_tmp_buf
      "str q16, [%[dst_tmp_buf], #0]\n"
      // Slow loop copying from dst_tmp_buf to dst.
      "mov x3, %[dst_tmp_buf]\n"
      "mov w6, #0\n"
      "50:\n"
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov w5, #0\n"
      "51:\n"
      "ldrh w7, [x3, x5, lsl #1]\n"
      "strh w7, [x4, x5, lsl #1]\n"
      "add w5, w5, #1\n"
      "cmp w5, w1\n"
      "blt 51b\n"
      "b 31f\n"
      "30:\n"
      // Yes, all of the 4x1 block fits.
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov x3, x4\n"
      "st1 {v16.h}[0], [x3], #2\n"
      "st1 {v16.h}[1], [x3], #2\n"
      "st1 {v16.h}[2], [x3], #2\n"
      "st1 {v16.h}[3], [x3], #2\n"
      "31:\n"
      // Advance dst_ptr by the 4 two-byte values of this block.
      "add %[dst_ptr], %[dst_ptr], #8\n"
      RUY_MAKE_ZERO(v16)
      RUY_MAKE_ZERO(v17)
      "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
      RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
      // Since the store type is the same as the accum type, no need for
      // downcast. There's also no need for clamp by min/max.
      // Compute how much of the 4x1 block of destination 32bit values that
      // we have computed, fit in the destination matrix. Typically, all of
      // it fits, but when the destination matrix shape is not a multiple
      // of 4x1, there are some 4x1 blocks along the boundaries that do
      // not fit entirely.
      // Only the row count matters here: both store paths below write a
      // single column, so there is no column condition.
      "sub w1, %w[dst_rows], %w[row]\n"
      "mov w3, #4\n"
      "cmp w1, #4\n"
      // Compute w1 = how many rows of the 4x1 block fit
      "csel w1, w1, w3, le\n"
      // Test if w1==4 i.e. if all of the 4x1 block fits.
      "cmp w1, w3\n"
      "mov x4, %[dst_ptr]\n"
      // Yes, all of the 4x1 block fits, go to fast path.
      "beq 30f\n"
      // Not all of the 4x1 block fits.
      // Store to dst_tmp_buf
      "str q16, [%[dst_tmp_buf], #0]\n"
      // Slow loop copying from dst_tmp_buf to dst.
      "mov x3, %[dst_tmp_buf]\n"
      "mov w6, #0\n"
      "50:\n"
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov w5, #0\n"
      "51:\n"
      "ldr w7, [x3, x5, lsl #2]\n"
      "str w7, [x4, x5, lsl #2]\n"
      "add w5, w5, #1\n"
      "cmp w5, w1\n"
      "blt 51b\n"
      "b 31f\n"
      "30:\n"
      // Yes, all of the 4x1 block fits.
      RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
      "mov x3, x4\n"
      "st1 {v16.s}[0], [x3], #4\n"
      "st1 {v16.s}[1], [x3], #4\n"
      "st1 {v16.s}[2], [x3], #4\n"
      "st1 {v16.s}[3], [x3], #4\n"
      "31:\n"
      // Advance dst_ptr by the 4 four-byte values of this block.
      "add %[dst_ptr], %[dst_ptr], #16\n"
      RUY_MAKE_ZERO(v16)
      RUY_MAKE_ZERO(v17)
      RUY_MAKE_ZERO(v18)
      RUY_MAKE_ZERO(v19)
      RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
      // For the next block: perform the first few multiply-adds on the data
      // that we have already loaded.
      "smull v8.8h, v0.8b, v4.8b\n"
      "smull v9.8h, v1.8b, v4.8b\n"
      "smull v10.8h, v2.8b, v4.8b\n"
      "smull v11.8h, v3.8b, v4.8b\n"
      "smlal2 v8.8h, v0.16b, v4.16b\n"
      "smlal2 v9.8h, v1.16b, v4.16b\n"
      "smlal2 v10.8h, v2.16b, v4.16b\n"
      "smlal2 v11.8h, v3.16b, v4.16b\n"
      // Reload some params --- we had used x5 -- x7 for a few other things
      // since the last time we had loaded them.
      "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
      "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
      "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
      // Move to the next block of the destination matrix, for the next iter
      // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
      // been updated earlier.
      // Have we reached the end row?
      "cmp %w[row], w7\n"
      "beq 20f\n"  // yes, end row.
      // Not end row. Move to the next row.
      "add %w[row], %w[row], #4\n"
      "b 21f\n"
      "20:\n"
      // Was already at end row.
      "mov %w[row], w6\n"  // Move back to first row.
      "add %w[col], %w[col], #4\n"  // Move to the next column.
      "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
      "mov %[dst_ptr], %[dst_col_ptr]\n"
      "21:\n"
      // Main loop exit condition: have we hit the end column?
      "cmp %w[col], w8\n"
      // w1 is the number of levels of depth that we have already loaded
      // LHS and RHS data for. Corresponding to the initial ld1 instructions
      // above, this is currently 16.
      // (mov does not modify the flags set by the cmp above.)
      "mov w1, #16\n"
      "ble 1b\n"
      // clang-format on
      : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
      : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
        [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
        [dst_type_id] "r"(params.dst_type_id)
      : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
        "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
        "v13", "v14", "v15", "v16", "v17", "v18", "v19");
}
// Variant of the above Kernel8bitNeon, tuned for A55-ish CPUs.
// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
// the original Cortex-A55, since these are 64-bit and do not support dotprod.
//
// While this kernel does not have a direct equivalent in gemmlowp, it was
// developed based on insights that David Mansell at ARM shared with their
// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A53:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// v4 -- v7 from RHS:
//
// int8 RHS 16x4 block
// /-----------------------------------------|
// |v4.b[0] ... v7.b[0] |
// | ... ... |
// |v4.b[15] ... v7.b[15] |
// \-----------------------------------------/
// int8 LHS 4x16 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
// |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
// |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
// |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
// \---------------------/ \-----------------------------------------/
// int32 accumulators 4x4 block
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
RUY_MAKE_ZERO(v16)
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
RUY_MAKE_ZERO(v17)
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
RUY_MAKE_ZERO(v18)
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
RUY_MAKE_ZERO(v19)
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v20)
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v21)
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
RUY_MAKE_ZERO(v22)
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
RUY_MAKE_ZERO(v23)
// Load the first 64 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v24)
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v25)
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v26)
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v27)
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v28)
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v29)
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v30)
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smull v12.8h, v0.8b, v5.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Some multiplications and 16-bit accumulation were already done above,
// so we start right away in the middle.
"sadalp v16.4s, v8.8h\n"
"ldr d4, [%[rhs_ptr], #0]\n"
"smull v8.8h, v0.8b, v6.8b\n"
"ldr x7, [%[rhs_ptr], #8]\n"
"sadalp v17.4s, v9.8h\n"
"ldr d5, [%[rhs_ptr], #16]\n"
"smull v9.8h, v1.8b, v6.8b\n"
"ldr x8, [%[rhs_ptr], #24]\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"add %[lhs_ptr], %[lhs_ptr], #64\n"
"smull v11.8h, v3.8b, v6.8b\n"
"add %[rhs_ptr], %[rhs_ptr], #64\n"
"sadalp v20.4s, v12.8h\n"
// Each iteration of this loop advances by 16 levels of depth.
"add w1, w1, #16\n"
"smull v12.8h, v0.8b, v7.8b\n"
// Loop termination condition
"cmp w1, w12\n"
"sadalp v21.4s, v13.8h\n"
"ldr x3, [%[lhs_ptr], #-56]\n"
"smull v13.8h, v1.8b, v7.8b\n"
"ldr x4, [%[lhs_ptr], #-40]\n"
"sadalp v22.4s, v14.8h\n"
"ldr x5, [%[lhs_ptr], #-24]\n"
"smull v14.8h, v2.8b, v7.8b\n"
"ldr x6, [%[lhs_ptr], #-8]\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"ldr x9, [%[rhs_ptr], #-24]\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"ldr d6, [%[rhs_ptr], #-32]\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"ldr d0, [%[lhs_ptr], #-64]\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"ldr d1, [%[lhs_ptr], #-48]\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"ins v4.d[1], x7\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"ins v5.d[1], x8\n"
"ldr d2, [%[lhs_ptr], #-32]\n"
"ins v0.d[1], x3\n"
"sadalp v24.4s, v8.8h\n"
"ldr d3, [%[lhs_ptr], #-16]\n"
"ins v1.d[1], x4\n"
"smull v8.8h, v0.8b, v4.8b\n"
"ins v2.d[1], x5\n"
"sadalp v25.4s, v9.8h\n"
"ins v3.d[1], x6\n"
"smull v9.8h, v1.8b, v4.8b\n"
"ldr d7, [%[rhs_ptr], #-16]\n"
"sadalp v26.4s, v10.8h\n"
"ldr x10, [%[rhs_ptr], #-8]\n"
"smull v10.8h, v2.8b, v4.8b\n"
"sadalp v27.4s, v11.8h\n"
"smull v11.8h, v3.8b, v4.8b\n"
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v5.8b\n"
"sadalp v29.4s, v13.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v30.4s, v14.8h\n"
"smull v14.8h, v2.8b, v5.8b\n"
"sadalp v31.4s, v15.8h\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"ins v6.d[1], x9\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"ins v7.d[1], x10\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
"blt 2b\n"
"79:\n"
"sadalp v16.4s, v8.8h\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v24.4s, v8.8h\n"
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"sadalp v25.4s, v9.8h\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"sadalp v26.4s, v10.8h\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"sadalp v27.4s, v11.8h\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"sadalp v28.4s, v12.8h\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"sadalp v29.4s, v13.8h\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 4x4 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 4x4 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Reduce 32bit accumulators horizontally.
"addp v16.4s, v16.4s, v17.4s\n"
"addp v18.4s, v18.4s, v19.4s\n"
"addp v20.4s, v20.4s, v21.4s\n"
"addp v22.4s, v22.4s, v23.4s\n"
"addp v24.4s, v24.4s, v25.4s\n"
"addp v26.4s, v26.4s, v27.4s\n"
"addp v28.4s, v28.4s, v29.4s\n"
"addp v30.4s, v30.4s, v31.4s\n"
// Reduce 32bit accumulators horizontally, second pass
// (each pass adds pairwise. we need to add 4-wise).
"addp v16.4s, v16.4s, v18.4s\n"
"addp v17.4s, v20.4s, v22.4s\n"
"addp v18.4s, v24.4s, v26.4s\n"
"addp v19.4s, v28.4s, v30.4s\n"
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"mvni v8.4s, #0\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 4 bias values.
"ld1 {v14.4s}, [x1]\n"
// Load the multiplier_fixedpoint values.
"add x5, x4, x3, lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"csel x4, x4, x5, eq\n"
"ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 64 bytes of
// each of LHS and RHS, into v0 -- v7, as we don't need v0 -- v7 anymore
// in the rest of the work on the current block.
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
"ldr d0, [%[lhs_ptr], #0]\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"add v16.4s, v16.4s, v14.4s\n"
"ldr d1, [%[lhs_ptr], #16]\n"
"add v17.4s, v17.4s, v14.4s\n"
"ldr d2, [%[lhs_ptr], #32]\n"
"add v18.4s, v18.4s, v14.4s\n"
"ldr d3, [%[lhs_ptr], #48]\n"
"add v19.4s, v19.4s, v14.4s\n"
"ldr d4, [%[rhs_ptr], #0]\n"
"ldr d5, [%[rhs_ptr], #16]\n"
"ldr d6, [%[rhs_ptr], #32]\n"
"ldr d7, [%[rhs_ptr], #48]\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
"dup v20.4s, v14.s[0]\n"
"ldr d1, [%[lhs_ptr], #16]\n"
"dup v21.4s, v14.s[1]\n"
"ldr d2, [%[lhs_ptr], #32]\n"
"dup v22.4s, v14.s[2]\n"
"ldr d3, [%[lhs_ptr], #48]\n"
"dup v23.4s, v14.s[3]\n"
"ldr d4, [%[rhs_ptr], #0]\n"
"add v16.4s, v16.4s, v20.4s\n"
"ldr d5, [%[rhs_ptr], #16]\n"
"add v17.4s, v17.4s, v21.4s\n"
"ldr d6, [%[rhs_ptr], #32]\n"
"add v18.4s, v18.4s, v22.4s\n"
"ldr d7, [%[rhs_ptr], #48]\n"
"add v19.4s, v19.4s, v23.4s\n"
"7:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[1]\n"
"mls v18.4s, v10.4s, v14.s[2]\n"
"mls v19.4s, v10.4s, v14.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"ld1 {v11.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v11.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v11.4s\n"
// If the destination is int32, it means the user asks for the raw
// accumulators, no need for us to downquantize the value.
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
"ld1 {v14.4s}, [x1]\n"
"smin v11.4s, v8.4s, v14.4s\n"
"ldr x1, [%[lhs_ptr], #8]\n"
"sub v12.4s, v14.4s, v11.4s\n"
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 8f\n"
// Case where channels are rows
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v12.4s\n"
"ldr x2, [%[lhs_ptr], #24]\n"
"sshl v17.4s, v17.4s, v12.4s\n"
"ldr x3, [%[lhs_ptr], #40]\n"
"sshl v18.4s, v18.4s, v12.4s\n"
"ldr x4, [%[lhs_ptr], #56]\n"
"sshl v19.4s, v19.4s, v12.4s\n"
// Apply the fixed-point part of the multiplier.
"ins v0.d[1], x1\n"
"ldr x1, [%[rhs_ptr], #8]\n"
"sqdmulh v16.4s, v16.4s, v15.4s\n"
"ins v1.d[1], x2\n"
"ldr x2, [%[rhs_ptr], #24]\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"
"ins v2.d[1], x3\n"
"ldr x3, [%[rhs_ptr], #40]\n"
"sqdmulh v18.4s, v18.4s, v15.4s\n"
"ins v3.d[1], x4\n"
"ldr x4, [%[rhs_ptr], #56]\n"
"sqdmulh v19.4s, v19.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v11.4s\n"
"srshl v18.4s, v18.4s, v11.4s\n"
"srshl v19.4s, v19.4s, v11.4s\n"
"b 9f\n"
"8:\n"
// Case where channels are columns
// Apply the positive exponent part of the multiplier.
"dup v20.4s, v12.s[0]\n"
"ldr x2, [%[lhs_ptr], #24]\n"
"ldr x3, [%[lhs_ptr], #40]\n"
"dup v21.4s, v12.s[1]\n"
"ldr x4, [%[lhs_ptr], #56]\n"
"dup v22.4s, v12.s[2]\n"
"ins v0.d[1], x1\n"
"dup v23.4s, v12.s[3]\n"
"ldr x1, [%[rhs_ptr], #8]\n"
"sshl v16.4s, v16.4s, v20.4s\n"
"ins v1.d[1], x2\n"
"sshl v17.4s, v17.4s, v21.4s\n"
"ldr x2, [%[rhs_ptr], #24]\n"
"sshl v18.4s, v18.4s, v22.4s\n"
"ins v2.d[1], x3\n"
"sshl v19.4s, v19.4s, v23.4s\n"
"ldr x3, [%[rhs_ptr], #40]\n"
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v15.s[0]\n"
"ins v3.d[1], x4\n"
"sqdmulh v17.4s, v17.4s, v15.s[1]\n"
"ldr x4, [%[rhs_ptr], #56]\n"
"sqdmulh v18.4s, v18.4s, v15.s[2]\n"
"dup v20.4s, v11.s[0]\n"
"sqdmulh v19.4s, v19.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"dup v21.4s, v11.s[1]\n"
"srshl v16.4s, v16.4s, v20.4s\n"
"dup v22.4s, v11.s[2]\n"
"srshl v17.4s, v17.4s, v21.4s\n"
"dup v23.4s, v11.s[3]\n"
"srshl v18.4s, v18.4s, v22.4s\n"
"srshl v19.4s, v19.4s, v23.4s\n"
"9:\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
"ins v4.d[1], x1\n"
"sqxtn v16.4h, v16.4s\n"
"ins v5.d[1], x2\n"
"sqxtn2 v16.8h, v17.4s\n"
"ins v6.d[1], x3\n"
"sqxtn v17.4h, v18.4s\n"
"ins v7.d[1], x4\n"
RUY_MAKE_ZERO(v18)
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v19)
// Add the destination zero point
"add %[lhs_ptr], %[lhs_ptr], #64\n"
"dup v14.8h, v13.h[4]\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
RUY_MAKE_ZERO(v21)
"sqadd v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v22)
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
RUY_MAKE_ZERO(v23)
"sqxtun2 v16.16b, v17.8h\n"
RUY_MAKE_ZERO(v24)
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
RUY_MAKE_ZERO(v25)
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
RUY_MAKE_ZERO(v26)
"dup v14.16b, w2\n" // clamp_min
RUY_MAKE_ZERO(v27)
"dup v15.16b, w3\n" // clamp_max
RUY_MAKE_ZERO(v28)
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
RUY_MAKE_ZERO(v29)
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
"ins v4.d[1], x1\n"
"sqxtn v16.4h, v16.4s\n"
"ins v5.d[1], x2\n"
"sqxtn2 v16.8h, v17.4s\n"
"ins v6.d[1], x3\n"
"sqxtn v17.4h, v18.4s\n"
"ins v7.d[1], x4\n"
RUY_MAKE_ZERO(v18)
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v19)
// Add the destination zero point
"add %[lhs_ptr], %[lhs_ptr], #64\n"
"dup v14.8h, v13.h[4]\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
RUY_MAKE_ZERO(v21)
"sqadd v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v22)
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
RUY_MAKE_ZERO(v23)
"sqxtn2 v16.16b, v17.8h\n"
RUY_MAKE_ZERO(v24)
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
RUY_MAKE_ZERO(v25)
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
RUY_MAKE_ZERO(v26)
"dup v14.16b, w2\n" // clamp_min
RUY_MAKE_ZERO(v27)
"dup v15.16b, w3\n" // clamp_max
RUY_MAKE_ZERO(v28)
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
RUY_MAKE_ZERO(v29)
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.4h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"ins v4.d[1], x1\n"
"sqxtn v16.4h, v16.4s\n"
"ins v5.d[1], x2\n"
"sqxtn2 v16.8h, v17.4s\n"
"ins v6.d[1], x3\n"
"sqxtn v17.4h, v18.4s\n"
"ins v7.d[1], x4\n"
RUY_MAKE_ZERO(v18)
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v19)
"add %[lhs_ptr], %[lhs_ptr], #64\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
// Load the clamp_min, clamp_max bounds
"ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
RUY_MAKE_ZERO(v25)
"ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
RUY_MAKE_ZERO(v26)
"dup v14.8h, w2\n" // clamp_min
RUY_MAKE_ZERO(v27)
"dup v15.8h, w3\n" // clamp_max
RUY_MAKE_ZERO(v28)
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v29)
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination 16bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[1], [x3], #2\n"
"st1 {v16.h}[2], [x3], #2\n"
"st1 {v16.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[5], [x3], #2\n"
"st1 {v16.h}[6], [x3], #2\n"
"st1 {v16.h}[7], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[1], [x3], #2\n"
"st1 {v17.h}[2], [x3], #2\n"
"st1 {v17.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[5], [x3], #2\n"
"st1 {v17.h}[6], [x3], #2\n"
"st1 {v17.h}[7], [x3], #2\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
"ldr x1, [%[lhs_ptr], #8]\n"
"ldr x2, [%[lhs_ptr], #24]\n"
"ldr x3, [%[lhs_ptr], #40]\n"
"ldr x4, [%[lhs_ptr], #56]\n"
"ins v0.d[1], x1\n"
"ldr x1, [%[rhs_ptr], #8]\n"
"ins v1.d[1], x2\n"
"ldr x2, [%[rhs_ptr], #24]\n"
"ins v2.d[1], x3\n"
"ldr x3, [%[rhs_ptr], #40]\n"
"ins v3.d[1], x4\n"
"ldr x4, [%[rhs_ptr], #56]\n"
"ins v4.d[1], x1\n"
"ins v5.d[1], x2\n"
"ins v6.d[1], x3\n"
"ins v7.d[1], x4\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// At this point, v20 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v20)
"add %[lhs_ptr], %[lhs_ptr], #64\n"
RUY_MAKE_ZERO(v21)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination 32bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
"str q18, [%[dst_tmp_buf], #32]\n"
"str q19, [%[dst_tmp_buf], #48]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v16.s}[1], [x3], #4\n"
"st1 {v16.s}[2], [x3], #4\n"
"st1 {v16.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v17.s}[1], [x3], #4\n"
"st1 {v17.s}[2], [x3], #4\n"
"st1 {v17.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v18.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v18.s}[1], [x3], #4\n"
"st1 {v18.s}[2], [x3], #4\n"
"st1 {v18.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v19.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v19.s}[1], [x3], #4\n"
"st1 {v19.s}[2], [x3], #4\n"
"st1 {v19.s}[3], [x3], #4\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"smull v11.8h, v3.8b, v4.8b\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"smull v12.8h, v0.8b, v5.8b\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #4\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #4\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the 64 bytes of LHS and of RHS
// data loaded above for the next block, this is currently 16.
"mov w1, #16\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params),[dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// Kernel taking advantage of the optional dotprod instruction.
// This is very similar to (and directly inspired by) this gemmlowp kernel
// which was contributed by David Mansell at ARM:
// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391
//
// Besides the ruy-ification, the main difference here is that we use an 8x8
// instead of 12x8 width, so as to stick to power-of-two widths. This slightly
// narrower kernel layout is still wide enough to achieve high performance
// although we haven't actually performed a real comparison to know exactly
// how this compares to ARM's aforementioned kernel.
//
// Relevant target CPUs for this kernel include ARM Cortex-A76,
// since these are 64-bit, out-of-order and with dotprod support.
void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
profiler::ScopeLabel label("Kernel (kNeonDotprod)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v15 are used to load int8 data from LHS and
// RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and
// v3 are used to load a 4x8 block of RHS, like this:
//
// int8 RHS 4x8 block
// /-----------------------------------------|
// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
// | ... ... |
// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
// \-----------------------------------------/
// int8 LHS 8x4 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
// | ... ... | | ... ... |
// |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
// |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
// | ... ... | | ... ... |
// |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// int32 accumulators 8x8 block
//
// In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
// is repeated 4 times, using 4x more registers for LHS and RHS, so that
// is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
//
// Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
// unused, and v8 -- v15 are used for loading parameters used for the
// post-accumulation part of the kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Optional, maximally-streaming, partial-unrolling (4x unrolled)
// optimization of the kernel inner loop (over depth). For more
// comments, see the non-unrolled loop below after the #endif.
#if RUY_OPT(MAX_STREAMING)
"cmp w12, #32\n"
"blt 78f\n"
"ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v8.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v9.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
"mov w1, #16\n"
"and w3, w12, #-16\n"
"81:\n"
"add w1, w1, #16\n"
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ldr q0, [%[lhs_ptr], #0]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ldr q2, [%[rhs_ptr], #0]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ldr q1, [%[lhs_ptr], #16]\n"
".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
"ldr q3, [%[rhs_ptr], #16]\n"
".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
"ldr q5, [%[lhs_ptr], #48]\n"
".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
"ldr q7, [%[rhs_ptr], #48]\n"
".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
"ldr q4, [%[lhs_ptr], #32]\n"
".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
"ldr q6, [%[rhs_ptr], #32]\n"
".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
"ldr q9, [%[lhs_ptr], #80]\n"
".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
"ldr q11, [%[rhs_ptr], #80]\n"
".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
"ldr q8, [%[lhs_ptr], #64]\n"
".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
"ldr q10, [%[rhs_ptr], #64]\n"
".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
"add %[lhs_ptr], %[lhs_ptr], #128\n"
".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
"add %[rhs_ptr], %[rhs_ptr], #128\n"
".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
"cmp w1, w3\n"
".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
"ldr q13, [%[lhs_ptr], #-16]\n"
".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
"ldr q15, [%[rhs_ptr], #-16]\n"
".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
"ldr q12, [%[lhs_ptr], #-32]\n"
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
"ldr q14, [%[rhs_ptr], #-32]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
"blt 81b\n"
".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
"78:\n"
#endif // #if RUY_OPT(MAX_STREAMING)
// Ordinary kernel inner loop (over depth), the simpler loop that the
// above was an equivalent 4x-partially-unrolled version of.
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Because of the data that we have already loaded, we can start the
// loop body right away with some multiply-adds.
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
// Each iteration of this loop advances by 4 levels of depth.
"add w1, w1, #4\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
// Loop termination condition.
"cmp w1, w12\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"blt 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last 4 levels of depth, for which the LHS
// and RHS data is already loaded.
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 8x8 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x8 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"mvni v8.4s, #0\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
"add v15.4s, v15.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v15.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v15.4s\n"
"add v20.4s, v20.4s, v14.4s\n"
"add v21.4s, v21.4s, v15.4s\n"
"add v22.4s, v22.4s, v14.4s\n"
"add v23.4s, v23.4s, v15.4s\n"
"add v24.4s, v24.4s, v14.4s\n"
"add v25.4s, v25.4s, v15.4s\n"
"add v26.4s, v26.4s, v14.4s\n"
"add v27.4s, v27.4s, v15.4s\n"
"add v28.4s, v28.4s, v14.4s\n"
"add v29.4s, v29.4s, v15.4s\n"
"add v30.4s, v30.4s, v14.4s\n"
"add v31.4s, v31.4s, v15.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
"dup v10.4s, v14.s[0]\n"
"dup v11.4s, v14.s[1]\n"
"dup v12.4s, v14.s[2]\n"
"dup v13.4s, v14.s[3]\n"
"add v16.4s, v16.4s, v10.4s\n"
"add v17.4s, v17.4s, v10.4s\n"
"add v18.4s, v18.4s, v11.4s\n"
"add v19.4s, v19.4s, v11.4s\n"
"add v20.4s, v20.4s, v12.4s\n"
"add v21.4s, v21.4s, v12.4s\n"
"add v22.4s, v22.4s, v13.4s\n"
"add v23.4s, v23.4s, v13.4s\n"
"dup v10.4s, v15.s[0]\n"
"dup v11.4s, v15.s[1]\n"
"dup v12.4s, v15.s[2]\n"
"dup v13.4s, v15.s[3]\n"
"add v24.4s, v24.4s, v10.4s\n"
"add v25.4s, v25.4s, v10.4s\n"
"add v26.4s, v26.4s, v11.4s\n"
"add v27.4s, v27.4s, v11.4s\n"
"add v28.4s, v28.4s, v12.4s\n"
"add v29.4s, v29.4s, v12.4s\n"
"add v30.4s, v30.4s, v13.4s\n"
"add v31.4s, v31.4s, v13.4s\n"
"7:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3], #16\n"
"ld1 {v15.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[0]\n"
"mls v18.4s, v10.4s, v14.s[1]\n"
"mls v19.4s, v10.4s, v14.s[1]\n"
"mls v20.4s, v10.4s, v14.s[2]\n"
"mls v21.4s, v10.4s, v14.s[2]\n"
"mls v22.4s, v10.4s, v14.s[3]\n"
"mls v23.4s, v10.4s, v14.s[3]\n"
"mls v24.4s, v10.4s, v15.s[0]\n"
"mls v25.4s, v10.4s, v15.s[0]\n"
"mls v26.4s, v10.4s, v15.s[1]\n"
"mls v27.4s, v10.4s, v15.s[1]\n"
"mls v28.4s, v10.4s, v15.s[2]\n"
"mls v29.4s, v10.4s, v15.s[2]\n"
"mls v30.4s, v10.4s, v15.s[3]\n"
"mls v31.4s, v10.4s, v15.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 8 lhs_sums values (4 into v11, 4 into v12).
"ld1 {v11.4s}, [x2], #16\n"
"ld1 {v12.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
"mul v12.4s, v12.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v12.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v12.4s\n"
"sub v20.4s, v20.4s, v11.4s\n"
"sub v21.4s, v21.4s, v12.4s\n"
"sub v22.4s, v22.4s, v11.4s\n"
"sub v23.4s, v23.4s, v12.4s\n"
"sub v24.4s, v24.4s, v11.4s\n"
"sub v25.4s, v25.4s, v12.4s\n"
"sub v26.4s, v26.4s, v11.4s\n"
"sub v27.4s, v27.4s, v12.4s\n"
"sub v28.4s, v28.4s, v11.4s\n"
"sub v29.4s, v29.4s, v12.4s\n"
"sub v30.4s, v30.4s, v11.4s\n"
"sub v31.4s, v31.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
// Load the pointer to the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Compute the multiplier_exponent pointer
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
// Load multiplier_exponent
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
// Separate positive and negative exponents
"smin v11.4s, v8.4s, v9.4s\n"
"smin v12.4s, v8.4s, v10.4s\n"
"sub v9.4s, v9.4s, v11.4s\n"
"sub v10.4s, v10.4s, v12.4s\n"
// Compute the multiplier_fixedpoint pointer
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"add x5, x4, x3, lsl #2\n"
"csel x4, x4, x5, eq\n"
// Load multiplier_fixedpoint
"ldr q14, [x4]\n"
"ldr q15, [x4, #16]\n"
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 8f\n"
// Case where channels are rows
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v9.4s\n"
"sshl v17.4s, v17.4s, v10.4s\n"
"sshl v18.4s, v18.4s, v9.4s\n"
"sshl v19.4s, v19.4s, v10.4s\n"
"sshl v20.4s, v20.4s, v9.4s\n"
"sshl v21.4s, v21.4s, v10.4s\n"
"sshl v22.4s, v22.4s, v9.4s\n"
"sshl v23.4s, v23.4s, v10.4s\n"
"sshl v24.4s, v24.4s, v9.4s\n"
"sshl v25.4s, v25.4s, v10.4s\n"
"sshl v26.4s, v26.4s, v9.4s\n"
"sshl v27.4s, v27.4s, v10.4s\n"
"sshl v28.4s, v28.4s, v9.4s\n"
"sshl v29.4s, v29.4s, v10.4s\n"
"sshl v30.4s, v30.4s, v9.4s\n"
"sshl v31.4s, v31.4s, v10.4s\n"
"10:\n"
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v14.4s\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"
"sqdmulh v18.4s, v18.4s, v14.4s\n"
"sqdmulh v19.4s, v19.4s, v15.4s\n"
"sqdmulh v20.4s, v20.4s, v14.4s\n"
"sqdmulh v21.4s, v21.4s, v15.4s\n"
"sqdmulh v22.4s, v22.4s, v14.4s\n"
"sqdmulh v23.4s, v23.4s, v15.4s\n"
"sqdmulh v24.4s, v24.4s, v14.4s\n"
"sqdmulh v25.4s, v25.4s, v15.4s\n"
"sqdmulh v26.4s, v26.4s, v14.4s\n"
"sqdmulh v27.4s, v27.4s, v15.4s\n"
"sqdmulh v28.4s, v28.4s, v14.4s\n"
"sqdmulh v29.4s, v29.4s, v15.4s\n"
"sqdmulh v30.4s, v30.4s, v14.4s\n"
"sqdmulh v31.4s, v31.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"srshl v18.4s, v18.4s, v11.4s\n"
"srshl v19.4s, v19.4s, v12.4s\n"
"srshl v20.4s, v20.4s, v11.4s\n"
"srshl v21.4s, v21.4s, v12.4s\n"
"srshl v22.4s, v22.4s, v11.4s\n"
"srshl v23.4s, v23.4s, v12.4s\n"
"srshl v24.4s, v24.4s, v11.4s\n"
"srshl v25.4s, v25.4s, v12.4s\n"
"srshl v26.4s, v26.4s, v11.4s\n"
"srshl v27.4s, v27.4s, v12.4s\n"
"srshl v28.4s, v28.4s, v11.4s\n"
"srshl v29.4s, v29.4s, v12.4s\n"
"srshl v30.4s, v30.4s, v11.4s\n"
"srshl v31.4s, v31.4s, v12.4s\n"
"b 9f\n"
"8:\n"
// Case where channels are columns
// Apply the positive exponent part of the multiplier.
"dup v4.4s, v9.s[0]\n"
"dup v5.4s, v9.s[1]\n"
"dup v6.4s, v9.s[2]\n"
"dup v7.4s, v9.s[3]\n"
"sshl v16.4s, v16.4s, v4.4s\n"
"sshl v17.4s, v17.4s, v4.4s\n"
"sshl v18.4s, v18.4s, v5.4s\n"
"sshl v19.4s, v19.4s, v5.4s\n"
"sshl v20.4s, v20.4s, v6.4s\n"
"sshl v21.4s, v21.4s, v6.4s\n"
"sshl v22.4s, v22.4s, v7.4s\n"
"sshl v23.4s, v23.4s, v7.4s\n"
"dup v4.4s, v10.s[0]\n"
"dup v5.4s, v10.s[1]\n"
"dup v6.4s, v10.s[2]\n"
"dup v7.4s, v10.s[3]\n"
"sshl v24.4s, v24.4s, v4.4s\n"
"sshl v25.4s, v25.4s, v4.4s\n"
"sshl v26.4s, v26.4s, v5.4s\n"
"sshl v27.4s, v27.4s, v5.4s\n"
"sshl v28.4s, v28.4s, v6.4s\n"
"sshl v29.4s, v29.4s, v6.4s\n"
"sshl v30.4s, v30.4s, v7.4s\n"
"sshl v31.4s, v31.4s, v7.4s\n"
"11:\n"
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v14.s[0]\n"
"sqdmulh v17.4s, v17.4s, v14.s[0]\n"
"sqdmulh v18.4s, v18.4s, v14.s[1]\n"
"sqdmulh v19.4s, v19.4s, v14.s[1]\n"
"sqdmulh v20.4s, v20.4s, v14.s[2]\n"
"sqdmulh v21.4s, v21.4s, v14.s[2]\n"
"sqdmulh v22.4s, v22.4s, v14.s[3]\n"
"sqdmulh v23.4s, v23.4s, v14.s[3]\n"
"sqdmulh v24.4s, v24.4s, v15.s[0]\n"
"sqdmulh v25.4s, v25.4s, v15.s[0]\n"
"sqdmulh v26.4s, v26.4s, v15.s[1]\n"
"sqdmulh v27.4s, v27.4s, v15.s[1]\n"
"sqdmulh v28.4s, v28.4s, v15.s[2]\n"
"sqdmulh v29.4s, v29.4s, v15.s[2]\n"
"sqdmulh v30.4s, v30.4s, v15.s[3]\n"
"sqdmulh v31.4s, v31.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"dup v4.4s, v11.s[0]\n"
"dup v5.4s, v11.s[1]\n"
"dup v6.4s, v11.s[2]\n"
"dup v7.4s, v11.s[3]\n"
"srshl v16.4s, v16.4s, v4.4s\n"
"srshl v17.4s, v17.4s, v4.4s\n"
"srshl v18.4s, v18.4s, v5.4s\n"
"srshl v19.4s, v19.4s, v5.4s\n"
"srshl v20.4s, v20.4s, v6.4s\n"
"srshl v21.4s, v21.4s, v6.4s\n"
"srshl v22.4s, v22.4s, v7.4s\n"
"srshl v23.4s, v23.4s, v7.4s\n"
"dup v4.4s, v12.s[0]\n"
"dup v5.4s, v12.s[1]\n"
"dup v6.4s, v12.s[2]\n"
"dup v7.4s, v12.s[3]\n"
"srshl v24.4s, v24.4s, v4.4s\n"
"srshl v25.4s, v25.4s, v4.4s\n"
"srshl v26.4s, v26.4s, v5.4s\n"
"srshl v27.4s, v27.4s, v5.4s\n"
"srshl v28.4s, v28.4s, v6.4s\n"
"srshl v29.4s, v29.4s, v6.4s\n"
"srshl v30.4s, v30.4s, v7.4s\n"
"srshl v31.4s, v31.4s, v7.4s\n"
"9:\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
"sqadd v18.8h, v18.8h, v14.8h\n"
"sqadd v19.8h, v19.8h, v14.8h\n"
"sqadd v20.8h, v20.8h, v14.8h\n"
"sqadd v21.8h, v21.8h, v14.8h\n"
"sqadd v22.8h, v22.8h, v14.8h\n"
"sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
"sqxtun2 v16.16b, v17.8h\n"
"sqxtun v17.8b, v18.8h\n"
"sqxtun2 v17.16b, v19.8h\n"
"sqxtun v18.8b, v20.8h\n"
"sqxtun2 v18.16b, v21.8h\n"
"sqxtun v19.8b, v22.8h\n"
"sqxtun2 v19.16b, v23.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
"umax v17.16b, v17.16b, v14.16b\n"
"umax v18.16b, v18.16b, v14.16b\n"
"umax v19.16b, v19.16b, v14.16b\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
"umin v17.16b, v17.16b, v15.16b\n"
"umin v18.16b, v18.16b, v15.16b\n"
"umin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
"sqadd v18.8h, v18.8h, v14.8h\n"
"sqadd v19.8h, v19.8h, v14.8h\n"
"sqadd v20.8h, v20.8h, v14.8h\n"
"sqadd v21.8h, v21.8h, v14.8h\n"
"sqadd v22.8h, v22.8h, v14.8h\n"
"sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
"sqxtn2 v16.16b, v17.8h\n"
"sqxtn v17.8b, v18.8h\n"
"sqxtn2 v17.16b, v19.8h\n"
"sqxtn v18.8b, v20.8h\n"
"sqxtn2 v18.16b, v21.8h\n"
"sqxtn v19.8b, v22.8h\n"
"sqxtn2 v19.16b, v23.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
"smax v17.16b, v17.16b, v14.16b\n"
"smax v18.16b, v18.16b, v14.16b\n"
"smax v19.16b, v19.16b, v14.16b\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
"smin v17.16b, v17.16b, v15.16b\n"
"smin v18.16b, v18.16b, v15.16b\n"
"smin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 130f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 131f\n"
"130:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"131:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 141f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"150:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"151:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 151b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 150b\n"
"141:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
"saddw v20.4s, v20.4s, v14.4h\n"
"saddw v21.4s, v21.4s, v14.4h\n"
"saddw v22.4s, v22.4s, v14.4h\n"
"saddw v23.4s, v23.4s, v14.4h\n"
"saddw v24.4s, v24.4s, v14.4h\n"
"saddw v25.4s, v25.4s, v14.4h\n"
"saddw v26.4s, v26.4s, v14.4h\n"
"saddw v27.4s, v27.4s, v14.4h\n"
"saddw v28.4s, v28.4s, v14.4h\n"
"saddw v29.4s, v29.4s, v14.4h\n"
"saddw v30.4s, v30.4s, v14.4h\n"
"saddw v31.4s, v31.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Load the clamp_min, clamp_max bounds
"ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
"smax v18.8h, v18.8h, v14.8h\n"
"smax v19.8h, v19.8h, v14.8h\n"
"smax v20.8h, v20.8h, v14.8h\n"
"smax v21.8h, v21.8h, v14.8h\n"
"smax v22.8h, v22.8h, v14.8h\n"
"smax v23.8h, v23.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
"smin v18.8h, v18.8h, v15.8h\n"
"smin v19.8h, v19.8h, v15.8h\n"
"smin v20.8h, v20.8h, v15.8h\n"
"smin v21.8h, v21.8h, v15.8h\n"
"smin v22.8h, v22.8h, v15.8h\n"
"smin v23.8h, v23.8h, v15.8h\n"
// Compute how much of the 8x8 block of destination 16bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 230f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
"b 231f\n"
"230:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"231:\n"
// Write our 16bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 241f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"250:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"251:\n"
"ldrsh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 251b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 250b\n"
"241:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// Compute how much of the 8x8 block of destination 32bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 330f\n"
// Not all of the 8x8 block fits.
// Write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"st1 {v16.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v16)
"st1 {v17.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v17)
"st1 {v18.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v18)
"st1 {v19.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v19)
"st1 {v20.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v20)
"st1 {v21.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v21)
"st1 {v22.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v22)
"st1 {v23.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v23)
"st1 {v24.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v24)
"st1 {v25.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v25)
"st1 {v26.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v26)
"st1 {v27.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v27)
"st1 {v28.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v28)
"st1 {v29.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v29)
"st1 {v30.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v30)
"st1 {v31.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v31)
"b 331f\n"
"330:\n"
// Yes, all of the 8x8 block fits.
"mov x4, %[dst_ptr]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.4s, v17.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v18.4s, v19.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v20.4s, v21.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v22.4s, v23.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v24.4s, v25.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v26.4s, v27.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v28.4s, v29.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v30.4s, v31.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
"331:\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 341f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"350:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"351:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 351b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 350b\n"
"341:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// A fork of the above 8bitNeonDotprod kernel but removes the max streaming
// manual unrolling. Manually unrolling the inner loops benefits some GEMM
// shapes on the Cortex-A76 but destroys performance on the X1 by increasing
// backend stalls. Therefore, we remove the MAX_STREAMING option in this
// kernel. The target CPU for this kernel is currently only the Cortex-X1.
void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) {
profiler::ScopeLabel label("Kernel (kNeonDotprod)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// RHS: v0 and v1 are used to load a 8x4 block of LHS, and v2 and
// v3 are used to load a 4x8 block of RHS, like this:
//
// int8 RHS 4x8 block
// /-----------------------------------------|
// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
// | ... ... |
// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
// \-----------------------------------------/
// int8 LHS 8x4 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
// | ... ... | | ... ... |
// |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
// |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
// | ... ... | | ... ... |
// |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// int32 accumulators 8x8 block
//
// Unlike the kernel that this one was forked from, this kernel has no
// RUY_OPT_MAX_STREAMING part, so this elementary step is not repeated with
// more registers: LHS and RHS data only ever live in v0 -- v3.
//
// After accumulation, v4 -- v7 are used as scratch registers and v8 -- v15
// are used for loading parameters used for the post-accumulation part of
// the kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Kernel inner loop (over depth).
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Because of the data that we have already loaded, we can start the
// loop body right away with some multiply-adds.
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
// Each iteration of this loop advances by 4 levels of depth.
"add w1, w1, #4\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
// Loop termination condition.
"cmp w1, w12\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"blt 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last 4 levels of depth, for which the LHS
// and RHS data is already loaded.
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 8x8 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x8 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"mvni v8.4s, #0\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
"add v15.4s, v15.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v15.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v15.4s\n"
"add v20.4s, v20.4s, v14.4s\n"
"add v21.4s, v21.4s, v15.4s\n"
"add v22.4s, v22.4s, v14.4s\n"
"add v23.4s, v23.4s, v15.4s\n"
"add v24.4s, v24.4s, v14.4s\n"
"add v25.4s, v25.4s, v15.4s\n"
"add v26.4s, v26.4s, v14.4s\n"
"add v27.4s, v27.4s, v15.4s\n"
"add v28.4s, v28.4s, v14.4s\n"
"add v29.4s, v29.4s, v15.4s\n"
"add v30.4s, v30.4s, v14.4s\n"
"add v31.4s, v31.4s, v15.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
"dup v10.4s, v14.s[0]\n"
"dup v11.4s, v14.s[1]\n"
"dup v12.4s, v14.s[2]\n"
"dup v13.4s, v14.s[3]\n"
"add v16.4s, v16.4s, v10.4s\n"
"add v17.4s, v17.4s, v10.4s\n"
"add v18.4s, v18.4s, v11.4s\n"
"add v19.4s, v19.4s, v11.4s\n"
"add v20.4s, v20.4s, v12.4s\n"
"add v21.4s, v21.4s, v12.4s\n"
"add v22.4s, v22.4s, v13.4s\n"
"add v23.4s, v23.4s, v13.4s\n"
"dup v10.4s, v15.s[0]\n"
"dup v11.4s, v15.s[1]\n"
"dup v12.4s, v15.s[2]\n"
"dup v13.4s, v15.s[3]\n"
"add v24.4s, v24.4s, v10.4s\n"
"add v25.4s, v25.4s, v10.4s\n"
"add v26.4s, v26.4s, v11.4s\n"
"add v27.4s, v27.4s, v11.4s\n"
"add v28.4s, v28.4s, v12.4s\n"
"add v29.4s, v29.4s, v12.4s\n"
"add v30.4s, v30.4s, v13.4s\n"
"add v31.4s, v31.4s, v13.4s\n"
"7:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3], #16\n"
"ld1 {v15.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[0]\n"
"mls v18.4s, v10.4s, v14.s[1]\n"
"mls v19.4s, v10.4s, v14.s[1]\n"
"mls v20.4s, v10.4s, v14.s[2]\n"
"mls v21.4s, v10.4s, v14.s[2]\n"
"mls v22.4s, v10.4s, v14.s[3]\n"
"mls v23.4s, v10.4s, v14.s[3]\n"
"mls v24.4s, v10.4s, v15.s[0]\n"
"mls v25.4s, v10.4s, v15.s[0]\n"
"mls v26.4s, v10.4s, v15.s[1]\n"
"mls v27.4s, v10.4s, v15.s[1]\n"
"mls v28.4s, v10.4s, v15.s[2]\n"
"mls v29.4s, v10.4s, v15.s[2]\n"
"mls v30.4s, v10.4s, v15.s[3]\n"
"mls v31.4s, v10.4s, v15.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 8 lhs_sums values.
"ld1 {v11.4s}, [x2], #16\n"
"ld1 {v12.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
"mul v12.4s, v12.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v12.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v12.4s\n"
"sub v20.4s, v20.4s, v11.4s\n"
"sub v21.4s, v21.4s, v12.4s\n"
"sub v22.4s, v22.4s, v11.4s\n"
"sub v23.4s, v23.4s, v12.4s\n"
"sub v24.4s, v24.4s, v11.4s\n"
"sub v25.4s, v25.4s, v12.4s\n"
"sub v26.4s, v26.4s, v11.4s\n"
"sub v27.4s, v27.4s, v12.4s\n"
"sub v28.4s, v28.4s, v11.4s\n"
"sub v29.4s, v29.4s, v12.4s\n"
"sub v30.4s, v30.4s, v11.4s\n"
"sub v31.4s, v31.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
//Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Compute the multiplier_exponent pointer
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
// Load multiplier_exponent
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
// Separate positive and negative exponents
"smin v11.4s, v8.4s, v9.4s\n"
"smin v12.4s, v8.4s, v10.4s\n"
"sub v9.4s, v9.4s, v11.4s\n"
"sub v10.4s, v10.4s, v12.4s\n"
// Compute the multiplier_fixedpoint pointer
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"add x5, x4, x3, lsl #2\n"
"csel x4, x4, x5, eq\n"
// Load multiplier_fixedpoint
"ldr q14, [x4]\n"
"ldr q15, [x4, #16]\n"
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 8f\n"
// Case where channels are rows
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v9.4s\n"
"sshl v17.4s, v17.4s, v10.4s\n"
"sshl v18.4s, v18.4s, v9.4s\n"
"sshl v19.4s, v19.4s, v10.4s\n"
"sshl v20.4s, v20.4s, v9.4s\n"
"sshl v21.4s, v21.4s, v10.4s\n"
"sshl v22.4s, v22.4s, v9.4s\n"
"sshl v23.4s, v23.4s, v10.4s\n"
"sshl v24.4s, v24.4s, v9.4s\n"
"sshl v25.4s, v25.4s, v10.4s\n"
"sshl v26.4s, v26.4s, v9.4s\n"
"sshl v27.4s, v27.4s, v10.4s\n"
"sshl v28.4s, v28.4s, v9.4s\n"
"sshl v29.4s, v29.4s, v10.4s\n"
"sshl v30.4s, v30.4s, v9.4s\n"
"sshl v31.4s, v31.4s, v10.4s\n"
"10:\n"
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v14.4s\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"
"sqdmulh v18.4s, v18.4s, v14.4s\n"
"sqdmulh v19.4s, v19.4s, v15.4s\n"
"sqdmulh v20.4s, v20.4s, v14.4s\n"
"sqdmulh v21.4s, v21.4s, v15.4s\n"
"sqdmulh v22.4s, v22.4s, v14.4s\n"
"sqdmulh v23.4s, v23.4s, v15.4s\n"
"sqdmulh v24.4s, v24.4s, v14.4s\n"
"sqdmulh v25.4s, v25.4s, v15.4s\n"
"sqdmulh v26.4s, v26.4s, v14.4s\n"
"sqdmulh v27.4s, v27.4s, v15.4s\n"
"sqdmulh v28.4s, v28.4s, v14.4s\n"
"sqdmulh v29.4s, v29.4s, v15.4s\n"
"sqdmulh v30.4s, v30.4s, v14.4s\n"
"sqdmulh v31.4s, v31.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"srshl v18.4s, v18.4s, v11.4s\n"
"srshl v19.4s, v19.4s, v12.4s\n"
"srshl v20.4s, v20.4s, v11.4s\n"
"srshl v21.4s, v21.4s, v12.4s\n"
"srshl v22.4s, v22.4s, v11.4s\n"
"srshl v23.4s, v23.4s, v12.4s\n"
"srshl v24.4s, v24.4s, v11.4s\n"
"srshl v25.4s, v25.4s, v12.4s\n"
"srshl v26.4s, v26.4s, v11.4s\n"
"srshl v27.4s, v27.4s, v12.4s\n"
"srshl v28.4s, v28.4s, v11.4s\n"
"srshl v29.4s, v29.4s, v12.4s\n"
"srshl v30.4s, v30.4s, v11.4s\n"
"srshl v31.4s, v31.4s, v12.4s\n"
"b 9f\n"
"8:\n"
// Case where channels are columns
// Apply the positive exponent part of the multiplier.
"dup v4.4s, v9.s[0]\n"
"dup v5.4s, v9.s[1]\n"
"dup v6.4s, v9.s[2]\n"
"dup v7.4s, v9.s[3]\n"
"sshl v16.4s, v16.4s, v4.4s\n"
"sshl v17.4s, v17.4s, v4.4s\n"
"sshl v18.4s, v18.4s, v5.4s\n"
"sshl v19.4s, v19.4s, v5.4s\n"
"sshl v20.4s, v20.4s, v6.4s\n"
"sshl v21.4s, v21.4s, v6.4s\n"
"sshl v22.4s, v22.4s, v7.4s\n"
"sshl v23.4s, v23.4s, v7.4s\n"
"dup v4.4s, v10.s[0]\n"
"dup v5.4s, v10.s[1]\n"
"dup v6.4s, v10.s[2]\n"
"dup v7.4s, v10.s[3]\n"
"sshl v24.4s, v24.4s, v4.4s\n"
"sshl v25.4s, v25.4s, v4.4s\n"
"sshl v26.4s, v26.4s, v5.4s\n"
"sshl v27.4s, v27.4s, v5.4s\n"
"sshl v28.4s, v28.4s, v6.4s\n"
"sshl v29.4s, v29.4s, v6.4s\n"
"sshl v30.4s, v30.4s, v7.4s\n"
"sshl v31.4s, v31.4s, v7.4s\n"
"11:\n"
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v14.s[0]\n"
"sqdmulh v17.4s, v17.4s, v14.s[0]\n"
"sqdmulh v18.4s, v18.4s, v14.s[1]\n"
"sqdmulh v19.4s, v19.4s, v14.s[1]\n"
"sqdmulh v20.4s, v20.4s, v14.s[2]\n"
"sqdmulh v21.4s, v21.4s, v14.s[2]\n"
"sqdmulh v22.4s, v22.4s, v14.s[3]\n"
"sqdmulh v23.4s, v23.4s, v14.s[3]\n"
"sqdmulh v24.4s, v24.4s, v15.s[0]\n"
"sqdmulh v25.4s, v25.4s, v15.s[0]\n"
"sqdmulh v26.4s, v26.4s, v15.s[1]\n"
"sqdmulh v27.4s, v27.4s, v15.s[1]\n"
"sqdmulh v28.4s, v28.4s, v15.s[2]\n"
"sqdmulh v29.4s, v29.4s, v15.s[2]\n"
"sqdmulh v30.4s, v30.4s, v15.s[3]\n"
"sqdmulh v31.4s, v31.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"dup v4.4s, v11.s[0]\n"
"dup v5.4s, v11.s[1]\n"
"dup v6.4s, v11.s[2]\n"
"dup v7.4s, v11.s[3]\n"
"srshl v16.4s, v16.4s, v4.4s\n"
"srshl v17.4s, v17.4s, v4.4s\n"
"srshl v18.4s, v18.4s, v5.4s\n"
"srshl v19.4s, v19.4s, v5.4s\n"
"srshl v20.4s, v20.4s, v6.4s\n"
"srshl v21.4s, v21.4s, v6.4s\n"
"srshl v22.4s, v22.4s, v7.4s\n"
"srshl v23.4s, v23.4s, v7.4s\n"
"dup v4.4s, v12.s[0]\n"
"dup v5.4s, v12.s[1]\n"
"dup v6.4s, v12.s[2]\n"
"dup v7.4s, v12.s[3]\n"
"srshl v24.4s, v24.4s, v4.4s\n"
"srshl v25.4s, v25.4s, v4.4s\n"
"srshl v26.4s, v26.4s, v5.4s\n"
"srshl v27.4s, v27.4s, v5.4s\n"
"srshl v28.4s, v28.4s, v6.4s\n"
"srshl v29.4s, v29.4s, v6.4s\n"
"srshl v30.4s, v30.4s, v7.4s\n"
"srshl v31.4s, v31.4s, v7.4s\n"
"9:\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
"sqadd v18.8h, v18.8h, v14.8h\n"
"sqadd v19.8h, v19.8h, v14.8h\n"
"sqadd v20.8h, v20.8h, v14.8h\n"
"sqadd v21.8h, v21.8h, v14.8h\n"
"sqadd v22.8h, v22.8h, v14.8h\n"
"sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
"sqxtun2 v16.16b, v17.8h\n"
"sqxtun v17.8b, v18.8h\n"
"sqxtun2 v17.16b, v19.8h\n"
"sqxtun v18.8b, v20.8h\n"
"sqxtun2 v18.16b, v21.8h\n"
"sqxtun v19.8b, v22.8h\n"
"sqxtun2 v19.16b, v23.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
"umax v17.16b, v17.16b, v14.16b\n"
"umax v18.16b, v18.16b, v14.16b\n"
"umax v19.16b, v19.16b, v14.16b\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
"umin v17.16b, v17.16b, v15.16b\n"
"umin v18.16b, v18.16b, v15.16b\n"
"umin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
"sqadd v18.8h, v18.8h, v14.8h\n"
"sqadd v19.8h, v19.8h, v14.8h\n"
"sqadd v20.8h, v20.8h, v14.8h\n"
"sqadd v21.8h, v21.8h, v14.8h\n"
"sqadd v22.8h, v22.8h, v14.8h\n"
"sqadd v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
"sqxtn2 v16.16b, v17.8h\n"
"sqxtn v17.8b, v18.8h\n"
"sqxtn2 v17.16b, v19.8h\n"
"sqxtn v18.8b, v20.8h\n"
"sqxtn2 v18.16b, v21.8h\n"
"sqxtn v19.8b, v22.8h\n"
"sqxtn2 v19.16b, v23.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
"smax v17.16b, v17.16b, v14.16b\n"
"smax v18.16b, v18.16b, v14.16b\n"
"smax v19.16b, v19.16b, v14.16b\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
"smin v17.16b, v17.16b, v15.16b\n"
"smin v18.16b, v18.16b, v15.16b\n"
"smin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 130f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 131f\n"
"130:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"131:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 141f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"150:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"151:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 151b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 150b\n"
"141:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
"saddw v20.4s, v20.4s, v14.4h\n"
"saddw v21.4s, v21.4s, v14.4h\n"
"saddw v22.4s, v22.4s, v14.4h\n"
"saddw v23.4s, v23.4s, v14.4h\n"
"saddw v24.4s, v24.4s, v14.4h\n"
"saddw v25.4s, v25.4s, v14.4h\n"
"saddw v26.4s, v26.4s, v14.4h\n"
"saddw v27.4s, v27.4s, v14.4h\n"
"saddw v28.4s, v28.4s, v14.4h\n"
"saddw v29.4s, v29.4s, v14.4h\n"
"saddw v30.4s, v30.4s, v14.4h\n"
"saddw v31.4s, v31.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Load the clamp_min, clamp_max bounds
"ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
"smax v18.8h, v18.8h, v14.8h\n"
"smax v19.8h, v19.8h, v14.8h\n"
"smax v20.8h, v20.8h, v14.8h\n"
"smax v21.8h, v21.8h, v14.8h\n"
"smax v22.8h, v22.8h, v14.8h\n"
"smax v23.8h, v23.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
"smin v18.8h, v18.8h, v15.8h\n"
"smin v19.8h, v19.8h, v15.8h\n"
"smin v20.8h, v20.8h, v15.8h\n"
"smin v21.8h, v21.8h, v15.8h\n"
"smin v22.8h, v22.8h, v15.8h\n"
"smin v23.8h, v23.8h, v15.8h\n"
// Compute how much of the 8x8 block of destination 16bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 230f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
"b 231f\n"
"230:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"231:\n"
// Write our 16bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 241f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"250:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"251:\n"
"ldrsh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 251b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 250b\n"
"241:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// Compute how much of the 8x8 block of destination 32bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 330f\n"
// Not all of the 8x8 block fits.
// Write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"st1 {v16.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v16)
"st1 {v17.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v17)
"st1 {v18.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v18)
"st1 {v19.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v19)
"st1 {v20.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v20)
"st1 {v21.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v21)
"st1 {v22.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v22)
"st1 {v23.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v23)
"st1 {v24.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v24)
"st1 {v25.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v25)
"st1 {v26.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v26)
"st1 {v27.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v27)
"st1 {v28.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v28)
"st1 {v29.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v29)
"st1 {v30.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v30)
"st1 {v31.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v31)
"b 331f\n"
"330:\n"
// Yes, all of the 8x8 block fits.
"mov x4, %[dst_ptr]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.4s, v17.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v18.4s, v19.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v20.4s, v21.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v22.4s, v23.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v24.4s, v25.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v26.4s, v27.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v28.4s, v29.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
"add x4, x4, x11\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v30.4s, v31.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
"331:\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 341f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"350:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"351:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 351b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 350b\n"
"341:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// Similar to the above 8-bit dotprod kernel, but specialized for the case of
// RHS cols == 1.
// Relevant target CPUs for this kernel include ARM Cortex-A76,
// since these are 64-bit, out-of-order and with dotprod support.
void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
profiler::ScopeLabel label("Kernel (kNeonDotprod)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v15 are used to load int8 data from LHS and
// RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and
// v3 are used to load a 4x8 block of RHS, like this:
//
// int8 RHS 4x1 block
// /-------|
// |v2.b[0]|
// | ... |
// |v2.b[3]|
// \-------/
// int8 LHS 8x4 block
// /---------------------\ /--------|
// |v0.b[0] ... v0.b[3] | |v16.s[0]|
// | ... ... | | ... |
// |v0.b[12] ... v0.b[15]| |v16.s[3]|
// |v1.b[0] ... v1.b[3] | |v17.s[0]|
// | ... ... | | ... |
// |v1.b[12] ... v1.b[15]| |v17.s[3]|
// \---------------------/ \--------/
// int32 accumulators 8x1 block
//
// In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
// is repeated 4 times, using 4x more registers for LHS and RHS, so that
// is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
//
// Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
// unused, and v8 -- v15 are used for loading parameters used for the
// post-accumulation part of the kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.8b}, [%[rhs_ptr]]\n"
"add %[rhs_ptr], %[rhs_ptr], #32\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Ordinary kernel inner loop (over depth), the simpler loop that the
// above was an equivalent 4x-partially-unrolled version of.
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Because of the data that we have already loaded, we can start the
// loop body right away with some multiply-adds.
// Each iteration of this loop advances by 4 levels of depth.
"add w1, w1, #4\n"
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
// Loop termination condition.
"cmp w1, w12\n"
"ld1 {v2.8b}, [%[rhs_ptr]]\n"
"add %[rhs_ptr], %[rhs_ptr], #32\n"
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"blt 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last 4 levels of depth, for which the LHS
// and RHS data is already loaded.
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
// End of accumulation. The registers v16 -- v17 contain the final
// int32 accumulator values of the current 8x1 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x1 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"mvni v8.4s, #0\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"add x5, x4, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"csel x4, x4, x5, eq\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
"add x5, x1, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.8b}, [%[rhs_ptr]]\n"
"add %[rhs_ptr], %[rhs_ptr], #32\n"
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
"add v15.4s, v15.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v15.4s\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3], #16\n"
"ld1 {v15.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[0]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 8 lhs_sums values (two vectors of 4 int32 each).
"ld1 {v11.4s}, [x2], #16\n"
"ld1 {v12.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
"mul v12.4s, v12.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
//Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, %x[row], lsl #2\n"
"csel x1, x1, x5, eq\n"
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
"smin v11.4s, v8.4s, v9.4s\n"
"smin v12.4s, v8.4s, v10.4s\n"
"sub v9.4s, v9.4s, v11.4s\n"
"sub v10.4s, v10.4s, v12.4s\n"
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v9.4s\n"
"sshl v17.4s, v17.4s, v10.4s\n"
"403:\n"
"ldr q14, [x4]\n" // multiplier_fixedpoint
"ldr q15, [x4, #16]\n" // multiplier_fixedpoint
// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v14.4s\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
// All data in v16 at this point.
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8, leaving all data in the
// lower half of v16.
"sqxtun v16.8b, v16.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
// Compute how much of the 8x1 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x1, there are some 8x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Test if w1==8, i.e. if all of the 8x1 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our 8bit values to the destination
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v16.8b}, [x3]\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"sqadd v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
// Compute how much of the 8x1 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Test if w1==8, i.e. if all of the 8x1 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 130f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 131f\n"
"130:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"131:\n"
// Write our 8bit values to the destination
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v16.8b}, [x3]\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 141f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"150:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"151:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 151b\n"
"141:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
// Load the clamp_min, clamp_max bounds
"ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
// Compute how much of the 8x1 block of destination 16bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Test if w1==8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 230f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
"b 231f\n"
"230:\n"
// Yes, all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"231:\n"
// Write our 16bit values to the destination
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v16.8h}, [x3]\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
// If all of the 8x1 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 241f\n"
// Not all of the 8x1 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"250:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"251:\n"
"ldrsh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 251b\n"
"241:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// Compute how much of the 8x1 block of destination 32 bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x1, there are some 8x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = min(dst_cols - col, 8), i.e. how many cols fit (capped at 8)
"csel w2, w2, w3, le\n"
// Test if w1==8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 330f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
// Write our 32bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
"b 331f\n"
"330:\n"
// Yes, all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x4, %[dst_ptr]\n"
"mov x3, x4\n"
// Write our 32bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.4s, v17.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"331:\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 341f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"350:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"mov w5, #0\n"
"351:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 351b\n"
"341:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17");
}
// Variant of the above Kernel8bitNeonDotprod, tuned for in-order
// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1,
// since these are 64-bit and support dotprod.
//
// While this kernel does not have a direct equivalent in gemmlowp, it was
// developed based on insights that David Mansell at ARM shared with their
// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A55r1:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
profiler::ScopeLabel label(
"Kernel (kNeonDotprod, optimized for in-order cores)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// RHS.
//
// int8 RHS 4x8 block
// /-----------------------------------------|
// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
// | ... ... |
// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
// \-----------------------------------------/
// int8 LHS 8x4 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
// | ... ... | | ... ... |
// |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
// |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
// | ... ... | | ... ... |
// |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// int32 accumulators 8x8 block
//
// There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
// we did not observe a benefit of such partial unrolling on in-order CPUs.
//
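// v4 -- v7 are unused during accumulation; in the post-accumulation part they
// serve as scratch registers for broadcasting individual multiplier-exponent
// lanes in the channels-are-columns case. v8 -- v15 are used for loading
// parameters used for the post-accumulation part of the kernel.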
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
RUY_MAKE_ZERO(v16)
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
RUY_MAKE_ZERO(v17)
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
RUY_MAKE_ZERO(v18)
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
RUY_MAKE_ZERO(v19)
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v20)
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v21)
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
RUY_MAKE_ZERO(v22)
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
// Perform the first few multiply-adds on the data that we have already
// loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
RUY_MAKE_ZERO(v28)
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
RUY_MAKE_ZERO(v29)
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
RUY_MAKE_ZERO(v30)
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
RUY_MAKE_ZERO(v31)
"1:\n"
"add x5, %[lhs_ptr], x12, lsl #3\n"
"sub x5, x5, #32\n"
"cmp %[lhs_ptr], x5\n"
"beq 79f\n"
// Main accumulation loop
"2:\n"
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
"ldr x1, [%[lhs_ptr], #8]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
"ldr x3, [%[rhs_ptr], #8]\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
"ldr x4, [%[rhs_ptr], #24]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ldr d0, [%[lhs_ptr], #0]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
"ins v0.d[1], x1\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
"ldr x2, [%[lhs_ptr], #24]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
"add %[lhs_ptr], %[lhs_ptr], #32\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ldr d2, [%[rhs_ptr], #0]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
"ins v2.d[1], x3\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
"cmp %[lhs_ptr], x5\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
"add %[rhs_ptr], %[rhs_ptr], #32\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ldr d3, [%[rhs_ptr], #-16]\n"
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
"ldr d1, [%[lhs_ptr], #-16]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
"ins v3.d[1], x4\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
"ins v1.d[1], x2\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
"blt 2b\n"
// Last accumulation steps, nothing left to load.
"79:\n"
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
"cmp %w[row], w7\n" // Have we finished the last row?
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 8x8 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x8 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"mvni v8.4s, #0\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.2s}, [x1], #8\n"
"ldr x5, [x1], #8\n"
"ins v14.d[1], x5\n"
"ld1 {v15.2s}, [x1], #8\n"
"ldr x5, [x1], #8\n"
"ins v15.d[1], x5\n"
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
"add v15.4s, v15.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v15.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v15.4s\n"
"add v20.4s, v20.4s, v14.4s\n"
"add v21.4s, v21.4s, v15.4s\n"
"add v22.4s, v22.4s, v14.4s\n"
"add v23.4s, v23.4s, v15.4s\n"
"add v24.4s, v24.4s, v14.4s\n"
"add v25.4s, v25.4s, v15.4s\n"
"add v26.4s, v26.4s, v14.4s\n"
"add v27.4s, v27.4s, v15.4s\n"
"add v28.4s, v28.4s, v14.4s\n"
"add v29.4s, v29.4s, v15.4s\n"
"add v30.4s, v30.4s, v14.4s\n"
"add v31.4s, v31.4s, v15.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
"dup v10.4s, v14.s[0]\n"
"dup v11.4s, v14.s[1]\n"
"add v16.4s, v16.4s, v10.4s\n"
"dup v12.4s, v14.s[2]\n"
"add v17.4s, v17.4s, v10.4s\n"
"dup v13.4s, v14.s[3]\n"
"add v18.4s, v18.4s, v11.4s\n"
"dup v10.4s, v15.s[0]\n"
"add v19.4s, v19.4s, v11.4s\n"
"dup v11.4s, v15.s[1]\n"
"add v20.4s, v20.4s, v12.4s\n"
"add v21.4s, v21.4s, v12.4s\n"
"dup v12.4s, v15.s[2]\n"
"add v22.4s, v22.4s, v13.4s\n"
"add v23.4s, v23.4s, v13.4s\n"
"dup v13.4s, v15.s[3]\n"
"add v24.4s, v24.4s, v10.4s\n"
"add v25.4s, v25.4s, v10.4s\n"
"add v26.4s, v26.4s, v11.4s\n"
"add v27.4s, v27.4s, v11.4s\n"
"add v28.4s, v28.4s, v12.4s\n"
"add v29.4s, v29.4s, v12.4s\n"
"add v30.4s, v30.4s, v13.4s\n"
"add v31.4s, v31.4s, v13.4s\n"
"7:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x5, x5, %x[col], lsl #2\n"
// Load 8 rhs_sums values.
"ld1 {v14.2s}, [x5], #8\n"
"ldr x7, [x5], #8\n"
"ld1 {v15.2s}, [x5], #8\n"
"ins v14.d[1], x7\n"
"ldr x7, [x5], #8\n"
"ins v15.d[1], x7\n"
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[0]\n"
"mls v18.4s, v10.4s, v14.s[1]\n"
"mls v19.4s, v10.4s, v14.s[1]\n"
"mls v20.4s, v10.4s, v14.s[2]\n"
"mls v21.4s, v10.4s, v14.s[2]\n"
"mls v22.4s, v10.4s, v14.s[3]\n"
"mls v23.4s, v10.4s, v14.s[3]\n"
"mls v24.4s, v10.4s, v15.s[0]\n"
"mls v25.4s, v10.4s, v15.s[0]\n"
"mls v26.4s, v10.4s, v15.s[1]\n"
"mls v27.4s, v10.4s, v15.s[1]\n"
"mls v28.4s, v10.4s, v15.s[2]\n"
"mls v29.4s, v10.4s, v15.s[2]\n"
"mls v30.4s, v10.4s, v15.s[3]\n"
"mls v31.4s, v10.4s, v15.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Load 8 lhs_sums values.
"ld1 {v11.2s}, [x2], #8\n"
"ldr x4, [x2], #8\n"
"ins v11.d[1], x4\n"
"ld1 {v12.2s}, [x2], #8\n"
"ldr x4, [x2], #8\n"
"ins v12.d[1], x4\n"
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
"mul v12.4s, v12.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v12.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v12.4s\n"
"sub v20.4s, v20.4s, v11.4s\n"
"sub v21.4s, v21.4s, v12.4s\n"
"sub v22.4s, v22.4s, v11.4s\n"
"sub v23.4s, v23.4s, v12.4s\n"
"sub v24.4s, v24.4s, v11.4s\n"
"sub v25.4s, v25.4s, v12.4s\n"
"sub v26.4s, v26.4s, v11.4s\n"
"sub v27.4s, v27.4s, v12.4s\n"
"sub v28.4s, v28.4s, v11.4s\n"
"sub v29.4s, v29.4s, v12.4s\n"
"sub v30.4s, v30.4s, v11.4s\n"
"sub v31.4s, v31.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
//Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
// Compute the multiplier_exponent pointer
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
// Load multiplier_exponent
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
// Separate positive and negative exponents
"smin v11.4s, v8.4s, v9.4s\n"
"smin v12.4s, v8.4s, v10.4s\n"
"sub v9.4s, v9.4s, v11.4s\n"
"sub v10.4s, v10.4s, v12.4s\n"
// Compute the multiplier_fixedpoint pointer
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"add x5, x4, x3, lsl #2\n"
"csel x4, x4, x5, eq\n"
// Load multiplier_fixedpoint
"ldr q14, [x4]\n"
"ldr q15, [x4, #16]\n"
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 8f\n"
// Case where channels are rows
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v9.4s\n"
"sshl v17.4s, v17.4s, v10.4s\n"
"sshl v18.4s, v18.4s, v9.4s\n"
"sshl v19.4s, v19.4s, v10.4s\n"
"sshl v20.4s, v20.4s, v9.4s\n"
"sshl v21.4s, v21.4s, v10.4s\n"
"sshl v22.4s, v22.4s, v9.4s\n"
"sshl v23.4s, v23.4s, v10.4s\n"
"sshl v24.4s, v24.4s, v9.4s\n"
"sshl v25.4s, v25.4s, v10.4s\n"
"sshl v26.4s, v26.4s, v9.4s\n"
"sshl v27.4s, v27.4s, v10.4s\n"
"sshl v28.4s, v28.4s, v9.4s\n"
"sshl v29.4s, v29.4s, v10.4s\n"
"sshl v30.4s, v30.4s, v9.4s\n"
"sshl v31.4s, v31.4s, v10.4s\n"
"10:\n"
// Apply the fixed-point part of the multiplier.
//
// ... and, interleaved into that:
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v16.4s, v16.4s, v14.4s\n"
"ldr x1, [%[lhs_ptr]], #8\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"
"ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v18.4s, v18.4s, v14.4s\n"
"ldr x2, [%[lhs_ptr]], #8\n"
"sqdmulh v19.4s, v19.4s, v15.4s\n"
"ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v20.4s, v20.4s, v14.4s\n"
"ldr x5, [%[rhs_ptr]], #8\n"
"sqdmulh v21.4s, v21.4s, v15.4s\n"
"ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v22.4s, v22.4s, v14.4s\n"
"ldr x6, [%[rhs_ptr]], #8\n"
"sqdmulh v23.4s, v23.4s, v15.4s\n"
"sqdmulh v24.4s, v24.4s, v14.4s\n"
"sqdmulh v25.4s, v25.4s, v15.4s\n"
"sqdmulh v26.4s, v26.4s, v14.4s\n"
"sqdmulh v27.4s, v27.4s, v15.4s\n"
"sqdmulh v28.4s, v28.4s, v14.4s\n"
"sqdmulh v29.4s, v29.4s, v15.4s\n"
"sqdmulh v30.4s, v30.4s, v14.4s\n"
"sqdmulh v31.4s, v31.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"srshl v18.4s, v18.4s, v11.4s\n"
"srshl v19.4s, v19.4s, v12.4s\n"
"srshl v20.4s, v20.4s, v11.4s\n"
"srshl v21.4s, v21.4s, v12.4s\n"
"srshl v22.4s, v22.4s, v11.4s\n"
"srshl v23.4s, v23.4s, v12.4s\n"
"srshl v24.4s, v24.4s, v11.4s\n"
"srshl v25.4s, v25.4s, v12.4s\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"srshl v26.4s, v26.4s, v11.4s\n"
"ins v13.h[4], w4\n" // dst_zero_point
"srshl v27.4s, v27.4s, v12.4s\n"
"ins v0.d[1], x1\n"
"srshl v28.4s, v28.4s, v11.4s\n"
"ins v1.d[1], x2\n"
"srshl v29.4s, v29.4s, v12.4s\n"
"ins v2.d[1], x5\n"
"srshl v30.4s, v30.4s, v11.4s\n"
"ins v3.d[1], x6\n"
"srshl v31.4s, v31.4s, v12.4s\n"
"b 9f\n"
"8:\n"
// Case where channels are columns
// Apply the positive exponent part of the multiplier.
"dup v4.4s, v9.s[0]\n"
"dup v5.4s, v9.s[1]\n"
"sshl v16.4s, v16.4s, v4.4s\n"
"dup v6.4s, v9.s[2]\n"
"sshl v17.4s, v17.4s, v4.4s\n"
"dup v7.4s, v9.s[3]\n"
"sshl v18.4s, v18.4s, v5.4s\n"
"dup v4.4s, v10.s[0]\n"
"sshl v19.4s, v19.4s, v5.4s\n"
"dup v5.4s, v10.s[1]\n"
"sshl v20.4s, v20.4s, v6.4s\n"
"sshl v21.4s, v21.4s, v6.4s\n"
"dup v6.4s, v10.s[2]\n"
"sshl v22.4s, v22.4s, v7.4s\n"
"sshl v23.4s, v23.4s, v7.4s\n"
"dup v7.4s, v10.s[3]\n"
"sshl v24.4s, v24.4s, v4.4s\n"
"sshl v25.4s, v25.4s, v4.4s\n"
"sshl v26.4s, v26.4s, v5.4s\n"
"sshl v27.4s, v27.4s, v5.4s\n"
"sshl v28.4s, v28.4s, v6.4s\n"
"sshl v29.4s, v29.4s, v6.4s\n"
"sshl v30.4s, v30.4s, v7.4s\n"
"sshl v31.4s, v31.4s, v7.4s\n"
"11:\n"
// Apply the fixed-point part of the multiplier.
//
// ... and, interleaved into that:
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v16.4s, v16.4s, v14.s[0]\n"
"ldr x1, [%[lhs_ptr]], #8\n"
"sqdmulh v17.4s, v17.4s, v14.s[0]\n"
"ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v18.4s, v18.4s, v14.s[1]\n"
"ldr x2, [%[lhs_ptr]], #8\n"
"sqdmulh v19.4s, v19.4s, v14.s[1]\n"
"ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v20.4s, v20.4s, v14.s[2]\n"
"ldr x5, [%[rhs_ptr]], #8\n"
"sqdmulh v21.4s, v21.4s, v14.s[2]\n"
"ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v22.4s, v22.4s, v14.s[3]\n"
"ldr x6, [%[rhs_ptr]], #8\n"
"sqdmulh v23.4s, v23.4s, v14.s[3]\n"
"dup v4.4s, v11.s[0]\n"
"sqdmulh v24.4s, v24.4s, v15.s[0]\n"
"dup v5.4s, v11.s[1]\n"
"sqdmulh v25.4s, v25.4s, v15.s[0]\n"
"dup v6.4s, v11.s[2]\n"
"sqdmulh v26.4s, v26.4s, v15.s[1]\n"
"dup v7.4s, v11.s[3]\n"
"sqdmulh v27.4s, v27.4s, v15.s[1]\n"
"sqdmulh v28.4s, v28.4s, v15.s[2]\n"
"sqdmulh v29.4s, v29.4s, v15.s[2]\n"
"sqdmulh v30.4s, v30.4s, v15.s[3]\n"
"sqdmulh v31.4s, v31.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v4.4s\n"
"srshl v17.4s, v17.4s, v4.4s\n"
"dup v4.4s, v12.s[0]\n"
"srshl v18.4s, v18.4s, v5.4s\n"
"srshl v19.4s, v19.4s, v5.4s\n"
"dup v5.4s, v12.s[1]\n"
"srshl v20.4s, v20.4s, v6.4s\n"
"srshl v21.4s, v21.4s, v6.4s\n"
"dup v6.4s, v12.s[2]\n"
"srshl v22.4s, v22.4s, v7.4s\n"
"srshl v23.4s, v23.4s, v7.4s\n"
"dup v7.4s, v12.s[3]\n"
"srshl v24.4s, v24.4s, v4.4s\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"srshl v25.4s, v25.4s, v4.4s\n"
"ins v13.h[4], w4\n" // dst_zero_point
"srshl v26.4s, v26.4s, v5.4s\n"
"ins v0.d[1], x1\n"
"srshl v27.4s, v27.4s, v5.4s\n"
"ins v1.d[1], x2\n"
"srshl v28.4s, v28.4s, v6.4s\n"
"ins v2.d[1], x5\n"
"srshl v29.4s, v29.4s, v6.4s\n"
"ins v3.d[1], x6\n"
"srshl v30.4s, v30.4s, v7.4s\n"
"srshl v31.4s, v31.4s, v7.4s\n"
"9:\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// Destination zero_point
"dup v14.8h, v13.h[4]\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
"sqadd v18.8h, v18.8h, v14.8h\n"
"sqadd v19.8h, v19.8h, v14.8h\n"
"sqadd v20.8h, v20.8h, v14.8h\n"
"sqadd v21.8h, v21.8h, v14.8h\n"
"sqadd v22.8h, v22.8h, v14.8h\n"
"sqadd v23.8h, v23.8h, v14.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"sqxtun2 v16.16b, v17.8h\n"
"sqxtun v17.8b, v18.8h\n"
"sqxtun2 v17.16b, v19.8h\n"
"sqxtun v18.8b, v20.8h\n"
"sqxtun2 v18.16b, v21.8h\n"
"sqxtun v19.8b, v22.8h\n"
"sqxtun2 v19.16b, v23.8h\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
"sub w2, %w[dst_cols], %w[col]\n"
"umax v17.16b, v17.16b, v14.16b\n"
"mov w3, #8\n"
"umax v18.16b, v18.16b, v14.16b\n"
"cmp w1, #8\n"
"umax v19.16b, v19.16b, v14.16b\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
"cmp w2, #8\n"
"umin v17.16b, v17.16b, v15.16b\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
"umin v18.16b, v18.16b, v15.16b\n"
"umin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// Destination zero_point
"dup v14.8h, v13.h[4]\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"sqadd v16.8h, v16.8h, v14.8h\n"
"sqadd v17.8h, v17.8h, v14.8h\n"
"sqadd v18.8h, v18.8h, v14.8h\n"
"sqadd v19.8h, v19.8h, v14.8h\n"
"sqadd v20.8h, v20.8h, v14.8h\n"
"sqadd v21.8h, v21.8h, v14.8h\n"
"sqadd v22.8h, v22.8h, v14.8h\n"
"sqadd v23.8h, v23.8h, v14.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"sqxtn2 v16.16b, v17.8h\n"
"sqxtn v17.8b, v18.8h\n"
"sqxtn2 v17.16b, v19.8h\n"
"sqxtn v18.8b, v20.8h\n"
"sqxtn2 v18.16b, v21.8h\n"
"sqxtn v19.8b, v22.8h\n"
"sqxtn2 v19.16b, v23.8h\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
"sub w2, %w[dst_cols], %w[col]\n"
"smax v17.16b, v17.16b, v14.16b\n"
"mov w3, #8\n"
"smax v18.16b, v18.16b, v14.16b\n"
"cmp w1, #8\n"
"smax v19.16b, v19.16b, v14.16b\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
"cmp w2, #8\n"
"smin v17.16b, v17.16b, v15.16b\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
"smin v18.16b, v18.16b, v15.16b\n"
"smin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 130f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 131f\n"
"130:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"131:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 141f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"150:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"151:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 151b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 150b\n"
"141:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
"saddw v20.4s, v20.4s, v14.4h\n"
"saddw v21.4s, v21.4s, v14.4h\n"
"saddw v22.4s, v22.4s, v14.4h\n"
"saddw v23.4s, v23.4s, v14.4h\n"
"saddw v24.4s, v24.4s, v14.4h\n"
"saddw v25.4s, v25.4s, v14.4h\n"
"saddw v26.4s, v26.4s, v14.4h\n"
"saddw v27.4s, v27.4s, v14.4h\n"
"saddw v28.4s, v28.4s, v14.4h\n"
"saddw v29.4s, v29.4s, v14.4h\n"
"saddw v30.4s, v30.4s, v14.4h\n"
"saddw v31.4s, v31.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Load the clamp_min, clamp_max bounds
"ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
"smax v18.8h, v18.8h, v14.8h\n"
"smax v19.8h, v19.8h, v14.8h\n"
"smax v20.8h, v20.8h, v14.8h\n"
"smax v21.8h, v21.8h, v14.8h\n"
"smax v22.8h, v22.8h, v14.8h\n"
"smax v23.8h, v23.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
"smin v18.8h, v18.8h, v15.8h\n"
"smin v19.8h, v19.8h, v15.8h\n"
"smin v20.8h, v20.8h, v15.8h\n"
"smin v21.8h, v21.8h, v15.8h\n"
"smin v22.8h, v22.8h, v15.8h\n"
"smin v23.8h, v23.8h, v15.8h\n"
// Compute how much of the 8x8 block of destination 16bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 230f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
"b 231f\n"
"230:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"231:\n"
// Write our 16bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 241f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"250:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"251:\n"
"ldrsh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 251b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 250b\n"
"241:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
"ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
"ldr x1, [%[lhs_ptr]], #8\n"
"ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
"ldr x2, [%[lhs_ptr]], #8\n"
"ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
"ldr x5, [%[rhs_ptr]], #8\n"
"ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
"ldr x6, [%[rhs_ptr]], #8\n"
"ins v0.d[1], x1\n"
"ins v1.d[1], x2\n"
"ins v2.d[1], x5\n"
"ins v3.d[1], x6\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// Compute how much of the 8x8 block of destination 32bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 330f\n"
// Not all of the 8x8 block fits.
// Write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"st1 {v16.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v16)
"st1 {v17.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v17)
"st1 {v18.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v18)
"st1 {v19.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v19)
"st1 {v20.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v20)
"st1 {v21.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v21)
"st1 {v22.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v22)
"st1 {v23.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v23)
"st1 {v24.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v24)
"st1 {v25.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v25)
"st1 {v26.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v26)
"st1 {v27.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v27)
"st1 {v28.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v28)
"st1 {v29.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v29)
"st1 {v30.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v30)
"st1 {v31.4s}, [x3], #16\n"
RUY_MAKE_ZERO(v31)
"b 331f\n"
"330:\n"
// Yes, all of the 8x8 block fits.
"mov x4, %[dst_ptr]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v16.4s, v17.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v18.4s, v19.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v20.4s, v21.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v22.4s, v23.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v24.4s, v25.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v26.4s, v27.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v28.4s, v29.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"st1 {v30.4s, v31.4s}, [x4], x11\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
"331:\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 341f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"350:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"351:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 351b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 350b\n"
"341:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// The quantized kernels above are finished with their params-struct field
// offsets. Release every RUY_OFFSET_* name here so that the float kernel
// section that follows can redefine the ones it needs with different values
// without triggering macro-redefinition diagnostics.
#undef RUY_OFFSET_BIAS
#undef RUY_OFFSET_LHS_SUMS
#undef RUY_OFFSET_RHS_SUMS
#undef RUY_OFFSET_LHS_BASE_PTR
#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
#undef RUY_OFFSET_MULTIPLIER_EXPONENT
#undef RUY_OFFSET_RHS_BASE_PTR
#undef RUY_OFFSET_DST_BASE_PTR
#undef RUY_OFFSET_LHS_ZERO_POINT
#undef RUY_OFFSET_RHS_ZERO_POINT
#undef RUY_OFFSET_DST_ZERO_POINT
#undef RUY_OFFSET_PROD_ZP_DEPTH
#undef RUY_OFFSET_START_ROW
#undef RUY_OFFSET_START_COL
#undef RUY_OFFSET_LAST_ROW
#undef RUY_OFFSET_LAST_COL
#undef RUY_OFFSET_DST_ROWS
#undef RUY_OFFSET_DST_COLS
#undef RUY_OFFSET_LHS_STRIDE
#undef RUY_OFFSET_RHS_STRIDE
#undef RUY_OFFSET_DST_STRIDE
#undef RUY_OFFSET_DEPTH
#undef RUY_OFFSET_CLAMP_MIN
#undef RUY_OFFSET_CLAMP_MAX
#undef RUY_OFFSET_FLAGS
// Byte offsets of the params-struct fields that the float asm kernel below
// addresses as "[%[params], #offset]". The values are hard-coded because they
// are pasted into the asm text via RUY_STR(); CheckOffsetsInKernelParamsFloat()
// verifies them against the real struct layout at compile time.
// NOTE(review): offsets 48..55 have no macro here; presumably they hold fields
// that the asm does not address through these macros -- confirm against the
// params struct definition.
#define RUY_OFFSET_LHS_BASE_PTR 0
#define RUY_OFFSET_RHS_BASE_PTR 8
#define RUY_OFFSET_DST_BASE_PTR 16
#define RUY_OFFSET_BIAS 24
#define RUY_OFFSET_START_ROW 32
#define RUY_OFFSET_START_COL 36
#define RUY_OFFSET_LAST_ROW 40
#define RUY_OFFSET_LAST_COL 44
#define RUY_OFFSET_LHS_STRIDE 56
#define RUY_OFFSET_RHS_STRIDE 60
#define RUY_OFFSET_DST_STRIDE 64
#define RUY_OFFSET_DEPTH 68
#define RUY_OFFSET_CLAMP_MIN 72
#define RUY_OFFSET_CLAMP_MAX 76
#define RUY_OFFSET_FLAGS 80

// Compile-time guard for the hand-written asm: instantiating this template
// fails to compile unless each listed field of Params sits at exactly the byte
// offset hard-coded in the matching RUY_OFFSET_* macro. The argument is used
// only for template type deduction; the function does nothing at run time.
//
// Each assertion carries a message naming the macro and the field so that a
// layout change points directly at the constant that has to be updated.
template <typename Params>
void CheckOffsetsInKernelParamsFloat(const Params&) {
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR,
                "RUY_OFFSET_LHS_BASE_PTR != offsetof(Params, lhs_base_ptr)");
  static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR,
                "RUY_OFFSET_RHS_BASE_PTR != offsetof(Params, rhs_base_ptr)");
  static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR,
                "RUY_OFFSET_DST_BASE_PTR != offsetof(Params, dst_base_ptr)");
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS,
                "RUY_OFFSET_BIAS != offsetof(Params, bias)");
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW,
                "RUY_OFFSET_START_ROW != offsetof(Params, start_row)");
  static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL,
                "RUY_OFFSET_START_COL != offsetof(Params, start_col)");
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW,
                "RUY_OFFSET_LAST_ROW != offsetof(Params, last_row)");
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL,
                "RUY_OFFSET_LAST_COL != offsetof(Params, last_col)");
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE,
                "RUY_OFFSET_LHS_STRIDE != offsetof(Params, lhs_stride)");
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE,
                "RUY_OFFSET_RHS_STRIDE != offsetof(Params, rhs_stride)");
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE,
                "RUY_OFFSET_DST_STRIDE != offsetof(Params, dst_stride)");
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH,
                "RUY_OFFSET_DEPTH != offsetof(Params, depth)");
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN,
                "RUY_OFFSET_CLAMP_MIN != offsetof(Params, clamp_min)");
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX,
                "RUY_OFFSET_CLAMP_MAX != offsetof(Params, clamp_max)");
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS,
                "RUY_OFFSET_FLAGS != offsetof(Params, flags)");
}
// Just a plain float kernel; good enough for out-of-order cores.
// The closest to it in the gemmlowp collection would be
// NEON_64bit_GEMM_Float32_WithScalar,
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925
//
// Besides ruy-ification, the main nuance here is that we stick to a 8x8
// width instead of the wider 12x8 that the register space permits and that
// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now
// and we don't have evidence that going beyond 8x8 is needed.
void KernelFloatNeon(const KernelParamsFloat<8, 8>& params) {
CheckOffsetsInKernelParamsFloat(params);
profiler::ScopeLabel label("Kernel (kNeon)");
const float* lhs_col_ptr = params.lhs_base_ptr;
const float* rhs_col_ptr = params.rhs_base_ptr;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr;
float* dst_col_ptr = params.dst_base_ptr;
float* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are accumulators.
// During accumulation, v0 -- v15 are used to load data from LHS and RHS.
// At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and
// v3 are used to load a 1x8 block of RHS, like this:
//
// RHS 1x8 block
// /-----------------------------------------|
// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
// \-----------------------------------------/
// LHS 8x1 block
// /---------------------\ /-----------------------------------------|
// | v0.s[0] | |v16.s[0] ... v30.s[0]|
// | ... | | ... ... |
// | v0.s[3] | |v16.s[3] ... v30.s[3]|
// | v1.s[0] | |v17.s[0] ... v31.s[0]|
// | ... | | ... ... |
// | v1.s[3] | |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// accumulators 8x8 block
//
// In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
// is repeated 4 times, using 4x more registers for LHS and RHS, so that
// is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
//
// Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
// unused, and v8 -- v15 are used for loading parameters used for the
// post-accumulation part of the kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 1.
"mov w1, #1\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
"fmla v16.4s, v0.4s, v2.s[0]\n"
"fmla v18.4s, v0.4s, v2.s[1]\n"
"fmla v20.4s, v0.4s, v2.s[2]\n"
"fmla v22.4s, v0.4s, v2.s[3]\n"
#if RUY_OPT(MAX_STREAMING)
"cmp w12, #8\n"
"blt 78f\n"
"and w2, w12, #-4\n"
"ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v5.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v6.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v7.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v8.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v9.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v10.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v11.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v12.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v13.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v14.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v15.4s}, [%[rhs_ptr]], #16\n"
"mov w1, #4\n"
"80:\n"
"add %[lhs_ptr], %[lhs_ptr], #128\n"
"add %[rhs_ptr], %[rhs_ptr], #128\n"
"fmla v24.4s, v0.4s, v3.s[0]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
"ldr q0, [%[lhs_ptr], #-128]\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
"ldr q3, [%[rhs_ptr], #-112]\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
"ldr q1, [%[lhs_ptr], #-112]\n"
"fmla v16.4s, v4.4s, v6.s[0]\n"
"fmla v18.4s, v4.4s, v6.s[1]\n"
"ldr q2, [%[rhs_ptr], #-128]\n"
"fmla v20.4s, v4.4s, v6.s[2]\n"
"fmla v22.4s, v4.4s, v6.s[3]\n"
"fmla v24.4s, v4.4s, v7.s[0]\n"
"fmla v26.4s, v4.4s, v7.s[1]\n"
"fmla v28.4s, v4.4s, v7.s[2]\n"
"fmla v30.4s, v4.4s, v7.s[3]\n"
"ldr q4, [%[lhs_ptr], #-96]\n"
"fmla v25.4s, v5.4s, v7.s[0]\n"
"fmla v27.4s, v5.4s, v7.s[1]\n"
"fmla v29.4s, v5.4s, v7.s[2]\n"
"fmla v31.4s, v5.4s, v7.s[3]\n"
"ldr q7, [%[rhs_ptr], #-80]\n"
"fmla v17.4s, v5.4s, v6.s[0]\n"
"fmla v19.4s, v5.4s, v6.s[1]\n"
"fmla v21.4s, v5.4s, v6.s[2]\n"
"fmla v23.4s, v5.4s, v6.s[3]\n"
"ldr q5, [%[lhs_ptr], #-80]\n"
"fmla v16.4s, v8.4s, v10.s[0]\n"
"fmla v18.4s, v8.4s, v10.s[1]\n"
"ldr q6, [%[rhs_ptr], #-96]\n"
"fmla v20.4s, v8.4s, v10.s[2]\n"
"fmla v22.4s, v8.4s, v10.s[3]\n"
"fmla v24.4s, v8.4s, v11.s[0]\n"
"fmla v26.4s, v8.4s, v11.s[1]\n"
"fmla v28.4s, v8.4s, v11.s[2]\n"
"fmla v30.4s, v8.4s, v11.s[3]\n"
"ldr q8, [%[lhs_ptr], #-64]\n"
"fmla v25.4s, v9.4s, v11.s[0]\n"
"fmla v27.4s, v9.4s, v11.s[1]\n"
"fmla v29.4s, v9.4s, v11.s[2]\n"
"fmla v31.4s, v9.4s, v11.s[3]\n"
"ldr q11, [%[rhs_ptr], #-48]\n"
"fmla v17.4s, v9.4s, v10.s[0]\n"
"fmla v19.4s, v9.4s, v10.s[1]\n"
"fmla v21.4s, v9.4s, v10.s[2]\n"
"fmla v23.4s, v9.4s, v10.s[3]\n"
"ldr q9, [%[lhs_ptr], #-48]\n"
"fmla v16.4s, v12.4s, v14.s[0]\n"
"fmla v18.4s, v12.4s, v14.s[1]\n"
"ldr q10, [%[rhs_ptr], #-64]\n"
"fmla v20.4s, v12.4s, v14.s[2]\n"
"fmla v22.4s, v12.4s, v14.s[3]\n"
"fmla v24.4s, v12.4s, v15.s[0]\n"
"fmla v26.4s, v12.4s, v15.s[1]\n"
"fmla v28.4s, v12.4s, v15.s[2]\n"
"fmla v30.4s, v12.4s, v15.s[3]\n"
"ldr q12, [%[lhs_ptr], #-32]\n"
"fmla v25.4s, v13.4s, v15.s[0]\n"
"fmla v27.4s, v13.4s, v15.s[1]\n"
"fmla v29.4s, v13.4s, v15.s[2]\n"
"fmla v31.4s, v13.4s, v15.s[3]\n"
"ldr q15, [%[rhs_ptr], #-16]\n"
"fmla v17.4s, v13.4s, v14.s[0]\n"
"fmla v19.4s, v13.4s, v14.s[1]\n"
"fmla v21.4s, v13.4s, v14.s[2]\n"
"fmla v23.4s, v13.4s, v14.s[3]\n"
"ldr q13, [%[lhs_ptr], #-16]\n"
"fmla v16.4s, v0.4s, v2.s[0]\n"
"fmla v18.4s, v0.4s, v2.s[1]\n"
"ldr q14, [%[rhs_ptr], #-32]\n"
"fmla v20.4s, v0.4s, v2.s[2]\n"
"fmla v22.4s, v0.4s, v2.s[3]\n"
"add w1, w1, #4\n"
"cmp w1, w2\n"
"blt 80b\n"
"fmla v16.4s, v4.4s, v6.s[0]\n"
"fmla v18.4s, v4.4s, v6.s[1]\n"
"fmla v20.4s, v4.4s, v6.s[2]\n"
"fmla v22.4s, v4.4s, v6.s[3]\n"
"fmla v24.4s, v4.4s, v7.s[0]\n"
"fmla v26.4s, v4.4s, v7.s[1]\n"
"fmla v28.4s, v4.4s, v7.s[2]\n"
"fmla v30.4s, v4.4s, v7.s[3]\n"
"fmla v25.4s, v5.4s, v7.s[0]\n"
"fmla v27.4s, v5.4s, v7.s[1]\n"
"fmla v29.4s, v5.4s, v7.s[2]\n"
"fmla v31.4s, v5.4s, v7.s[3]\n"
"fmla v17.4s, v5.4s, v6.s[0]\n"
"fmla v19.4s, v5.4s, v6.s[1]\n"
"fmla v21.4s, v5.4s, v6.s[2]\n"
"fmla v23.4s, v5.4s, v6.s[3]\n"
"fmla v16.4s, v8.4s, v10.s[0]\n"
"fmla v18.4s, v8.4s, v10.s[1]\n"
"fmla v20.4s, v8.4s, v10.s[2]\n"
"fmla v22.4s, v8.4s, v10.s[3]\n"
"fmla v24.4s, v8.4s, v11.s[0]\n"
"fmla v26.4s, v8.4s, v11.s[1]\n"
"fmla v28.4s, v8.4s, v11.s[2]\n"
"fmla v30.4s, v8.4s, v11.s[3]\n"
"fmla v25.4s, v9.4s, v11.s[0]\n"
"fmla v27.4s, v9.4s, v11.s[1]\n"
"fmla v29.4s, v9.4s, v11.s[2]\n"
"fmla v31.4s, v9.4s, v11.s[3]\n"
"fmla v17.4s, v9.4s, v10.s[0]\n"
"fmla v19.4s, v9.4s, v10.s[1]\n"
"fmla v21.4s, v9.4s, v10.s[2]\n"
"fmla v23.4s, v9.4s, v10.s[3]\n"
"fmla v16.4s, v12.4s, v14.s[0]\n"
"fmla v18.4s, v12.4s, v14.s[1]\n"
"fmla v20.4s, v12.4s, v14.s[2]\n"
"fmla v22.4s, v12.4s, v14.s[3]\n"
"fmla v24.4s, v12.4s, v15.s[0]\n"
"fmla v26.4s, v12.4s, v15.s[1]\n"
"fmla v28.4s, v12.4s, v15.s[2]\n"
"fmla v30.4s, v12.4s, v15.s[3]\n"
"fmla v25.4s, v13.4s, v15.s[0]\n"
"fmla v27.4s, v13.4s, v15.s[1]\n"
"fmla v29.4s, v13.4s, v15.s[2]\n"
"fmla v31.4s, v13.4s, v15.s[3]\n"
"fmla v17.4s, v13.4s, v14.s[0]\n"
"fmla v19.4s, v13.4s, v14.s[1]\n"
"fmla v21.4s, v13.4s, v14.s[2]\n"
"fmla v23.4s, v13.4s, v14.s[3]\n"
"78:\n"
#endif
// Accumulation loop
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
"fmla v24.4s, v0.4s, v3.s[0]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
"ld1 {v4.4s}, [%[rhs_ptr]], #16\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"add w1, w1, #1\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"cmp w1, w12\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v18.4s, v0.4s, v4.s[1]\n"
"mov v2.16b, v4.16b\n"
"fmla v20.4s, v0.4s, v4.s[2]\n"
"fmla v22.4s, v0.4s, v4.s[3]\n"
"blt 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last level of depth, for which the LHS
// and RHS data is already loaded.
"fmla v24.4s, v0.4s, v3.s[0]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 8x8 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x8 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
// Perform the bias-addition.
// Jump based on channel dimension.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"fadd v16.4s, v16.4s, v14.4s\n"
"fadd v17.4s, v17.4s, v15.4s\n"
"fadd v18.4s, v18.4s, v14.4s\n"
"fadd v19.4s, v19.4s, v15.4s\n"
"fadd v20.4s, v20.4s, v14.4s\n"
"fadd v21.4s, v21.4s, v15.4s\n"
"fadd v22.4s, v22.4s, v14.4s\n"
"fadd v23.4s, v23.4s, v15.4s\n"
"fadd v24.4s, v24.4s, v14.4s\n"
"fadd v25.4s, v25.4s, v15.4s\n"
"fadd v26.4s, v26.4s, v14.4s\n"
"fadd v27.4s, v27.4s, v15.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v15.4s\n"
"fadd v30.4s, v30.4s, v14.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
"dup v8.4s, v14.s[0]\n"
"dup v9.4s, v14.s[1]\n"
"dup v10.4s, v14.s[2]\n"
"dup v11.4s, v14.s[3]\n"
"dup v12.4s, v15.s[0]\n"
"dup v13.4s, v15.s[1]\n"
"dup v14.4s, v15.s[2]\n"
"dup v15.4s, v15.s[3]\n"
"fadd v16.4s, v16.4s, v8.4s\n"
"fadd v17.4s, v17.4s, v8.4s\n"
"fadd v18.4s, v18.4s, v9.4s\n"
"fadd v19.4s, v19.4s, v9.4s\n"
"fadd v20.4s, v20.4s, v10.4s\n"
"fadd v21.4s, v21.4s, v10.4s\n"
"fadd v22.4s, v22.4s, v11.4s\n"
"fadd v23.4s, v23.4s, v11.4s\n"
"fadd v24.4s, v24.4s, v12.4s\n"
"fadd v25.4s, v25.4s, v12.4s\n"
"fadd v26.4s, v26.4s, v13.4s\n"
"fadd v27.4s, v27.4s, v13.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v14.4s\n"
"fadd v30.4s, v30.4s, v15.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"7:\n"
// Load the clamp_min, clamp_max bounds
"ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.4s, w2\n" // clamp_min
"dup v15.4s, w3\n" // clamp_max
// Apply the clamp_min bound
"fmax v16.4s, v16.4s, v14.4s\n"
"fmax v17.4s, v17.4s, v14.4s\n"
"fmax v18.4s, v18.4s, v14.4s\n"
"fmax v19.4s, v19.4s, v14.4s\n"
"fmax v20.4s, v20.4s, v14.4s\n"
"fmax v21.4s, v21.4s, v14.4s\n"
"fmax v22.4s, v22.4s, v14.4s\n"
"fmax v23.4s, v23.4s, v14.4s\n"
"fmax v24.4s, v24.4s, v14.4s\n"
"fmax v25.4s, v25.4s, v14.4s\n"
"fmax v26.4s, v26.4s, v14.4s\n"
"fmax v27.4s, v27.4s, v14.4s\n"
"fmax v28.4s, v28.4s, v14.4s\n"
"fmax v29.4s, v29.4s, v14.4s\n"
"fmax v30.4s, v30.4s, v14.4s\n"
"fmax v31.4s, v31.4s, v14.4s\n"
// Apply the clamp_max bound
"fmin v16.4s, v16.4s, v15.4s\n"
"fmin v17.4s, v17.4s, v15.4s\n"
"fmin v18.4s, v18.4s, v15.4s\n"
"fmin v19.4s, v19.4s, v15.4s\n"
"fmin v20.4s, v20.4s, v15.4s\n"
"fmin v21.4s, v21.4s, v15.4s\n"
"fmin v22.4s, v22.4s, v15.4s\n"
"fmin v23.4s, v23.4s, v15.4s\n"
"fmin v24.4s, v24.4s, v15.4s\n"
"fmin v25.4s, v25.4s, v15.4s\n"
"fmin v26.4s, v26.4s, v15.4s\n"
"fmin v27.4s, v27.4s, v15.4s\n"
"fmin v28.4s, v28.4s, v15.4s\n"
"fmin v29.4s, v29.4s, v15.4s\n"
"fmin v30.4s, v30.4s, v15.4s\n"
"fmin v31.4s, v31.4s, v15.4s\n"
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #32\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"str q16, [x3, #0]\n"
"str q17, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"str q18, [x3, #0]\n"
"str q19, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
"str q20, [x3, #0]\n"
"str q21, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
"str q22, [x3, #0]\n"
"str q23, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
"str q24, [x3, #0]\n"
"str q25, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
"str q26, [x3, #0]\n"
"str q27, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
"str q28, [x3, #0]\n"
"str q29, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
"str q30, [x3, #0]\n"
"str q31, [x3, #16]\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 1.
"mov w1, #1\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// A fork of the standard float kernel where we omit the manual loop unrolling
// to recover performance on the X1. For now, the X1 core is the only CPU that
// uses this kernel.
//
// The kernel walks the destination matrix in 8x8 blocks of float values,
// starting at (params.start_row, params.start_col), advancing down the rows
// first and then across the columns, until the block at (last_row, last_col)
// has been written. For each block it:
//   1. accumulates the LHS * RHS products over all `depth` levels,
//   2. adds the bias (per row or per column, depending on the
//      RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL flag),
//   3. clamps the result to [clamp_min, clamp_max],
//   4. stores the block, going through params.dst_tmp_buf when the block
//      does not entirely fit in the destination matrix.
void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params) {
// NOTE(review): presumably verifies that the RUY_OFFSET_* constants used by
// the asm below match the actual layout of KernelParamsFloat -- confirm
// against its definition.
CheckOffsetsInKernelParamsFloat(params);
profiler::ScopeLabel label("Kernel (kNeon) X1");
// lhs_col_ptr / rhs_col_ptr / dst_col_ptr point to the start of the data for
// the current block; lhs_ptr / rhs_ptr / dst_ptr are the running pointers
// that the asm advances while it works on that block.
const float* lhs_col_ptr = params.lhs_base_ptr;
const float* rhs_col_ptr = params.rhs_base_ptr;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr;
float* dst_col_ptr = params.dst_base_ptr;
float* dst_ptr = dst_col_ptr;
// Coordinates of the top-left corner of the current 8x8 destination block.
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are accumulators.
// During accumulation, v0 and v1 hold a 8x1 block of LHS, and v2 and v3 hold
// a 1x8 block of RHS, like this:
//
//                              RHS 1x8 block
//                       /-----------------------------------------|
//                       |v2.s[0] ... v2.s[3]   v3.s[0] ... v3.s[3]|
//                       \-----------------------------------------/
//    LHS 8x1 block
//  /---------------------\ /-----------------------------------------|
//  |        v0.s[0]      | |v16.s[0]           ...           v30.s[0]|
//  |         ...         | |  ...                              ...   |
//  |        v0.s[3]      | |v16.s[3]           ...           v30.s[3]|
//  |        v1.s[0]      | |v17.s[0]           ...           v31.s[0]|
//  |         ...         | |  ...                              ...   |
//  |        v1.s[3]      | |v17.s[3]           ...           v31.s[3]|
//  \---------------------/ \-----------------------------------------/
//                                  accumulators 8x8 block
//
// So each pair (v16, v17), (v18, v19), ..., (v30, v31) holds one 8-float
// column of the destination block.
//
// Unlike the standard float kernel that this one was forked from, there is
// no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel, so the
// accumulation only uses v0 -- v4: v4 is a staging register receiving the
// next level's first 4 RHS values before they are moved into v2.
//
// v5 -- v7 are unused, and v8 -- v15 are used for loading parameters (bias
// values and clamp bounds) used for the post-accumulation part of the kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
// Note: the strides and depth are loaded as 32-bit values into w9 -- w12.
// On AArch64 a write to a W register zeroes the upper half of the
// corresponding X register, which is what allows x9 -- x11 to be used
// directly in the 64-bit address arithmetic below.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data, i.e. the 8 LHS floats and
// the 8 RHS floats of the first level of depth.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 1.
"mov w1, #1\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Start the multiply-adds of the first level of depth: rows 0..3 (v0)
// times RHS columns 0..3 (v2). The rest of that level is done either at
// the top of the accumulation loop or, if depth == 1, at label 79.
"fmla v16.4s, v0.4s, v2.s[0]\n"
"fmla v18.4s, v0.4s, v2.s[1]\n"
"fmla v20.4s, v0.4s, v2.s[2]\n"
"fmla v22.4s, v0.4s, v2.s[3]\n"
// Accumulation loop
// Skip it entirely when all levels of depth are already loaded.
"cmp w1, w12\n"
"beq 79f\n"
// Each iteration finishes the multiply-adds of the currently loaded level
// of depth while loading the next level into v0, v1, v3 and v4, then
// starts the next level (v16/v18/v20/v22) and moves v4 into v2. The loads,
// the counter update and the compare are interleaved with the fmla
// instructions.
"2:\n"
"fmla v24.4s, v0.4s, v3.s[0]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
// Next level: first 4 RHS values, staged in v4 because v2 is still needed.
"ld1 {v4.4s}, [%[rhs_ptr]], #16\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
// Next level: first 4 LHS values (the old v0 is no longer needed).
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"add w1, w1, #1\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
// Next level: last 4 RHS values (the old v3 is no longer needed).
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
// Sets the flags consumed by the blt at the end of the loop; none of the
// instructions in between modify the flags.
"cmp w1, w12\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
// Next level: last 4 LHS values (the old v1 is no longer needed).
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
// Start the next level, reading its first 4 RHS values from v4.
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v18.4s, v0.4s, v4.s[1]\n"
"mov v2.16b, v4.16b\n"
"fmla v20.4s, v0.4s, v4.s[2]\n"
"fmla v22.4s, v0.4s, v4.s[3]\n"
"blt 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last level of depth, for which the LHS
// and RHS data is already loaded.
"fmla v24.4s, v0.4s, v3.s[0]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// float accumulator values of the current 8x8 destination block.
// We now have to compute the final destination values from these
// accumulators (bias addition and clamping), and advance to the next
// 8x8 block. We intertwine these two aspects whenever possible for
// optimal pipelining, both at the data flow level (prefetch data for
// next block as early as possible) and instruction pipelining level
// (some of the next-block work can dual-issue with some of the final
// work on the current block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
// This advances lhs_col_ptr by 8 * lhs_stride bytes.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
// This advances rhs_col_ptr by 8 * rhs_stride bytes.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
// From here on w1 no longer holds the depth counter (x1 becomes the bias
// pointer); it is reset to 1 at the bottom of the main loop.
"ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index: w3 = row if the CHANNEL_DIMENSION_IS_COL
// flag is clear, w3 = col if it is set.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
// (x5 = bias + 4 * channel index. This overwrites the LHS base pointer
// held in x5, which is reloaded after the stores below.)
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
// Perform the bias-addition.
// Jump based on channel dimension.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
// v14 holds the bias for rows 0..3 and v15 for rows 4..7, so the same two
// vectors are added to every column of the block.
"fadd v16.4s, v16.4s, v14.4s\n"
"fadd v17.4s, v17.4s, v15.4s\n"
"fadd v18.4s, v18.4s, v14.4s\n"
"fadd v19.4s, v19.4s, v15.4s\n"
"fadd v20.4s, v20.4s, v14.4s\n"
"fadd v21.4s, v21.4s, v15.4s\n"
"fadd v22.4s, v22.4s, v14.4s\n"
"fadd v23.4s, v23.4s, v15.4s\n"
"fadd v24.4s, v24.4s, v14.4s\n"
"fadd v25.4s, v25.4s, v15.4s\n"
"fadd v26.4s, v26.4s, v14.4s\n"
"fadd v27.4s, v27.4s, v15.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v15.4s\n"
"fadd v30.4s, v30.4s, v14.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
// Broadcast each of the 8 bias values to a full vector (v8 -- v15), one
// per column of the block. v14 and v15 are overwritten last, after all
// their lanes have been read.
"dup v8.4s, v14.s[0]\n"
"dup v9.4s, v14.s[1]\n"
"dup v10.4s, v14.s[2]\n"
"dup v11.4s, v14.s[3]\n"
"dup v12.4s, v15.s[0]\n"
"dup v13.4s, v15.s[1]\n"
"dup v14.4s, v15.s[2]\n"
"dup v15.4s, v15.s[3]\n"
// Add the broadcast bias of column j to both registers of column j.
"fadd v16.4s, v16.4s, v8.4s\n"
"fadd v17.4s, v17.4s, v8.4s\n"
"fadd v18.4s, v18.4s, v9.4s\n"
"fadd v19.4s, v19.4s, v9.4s\n"
"fadd v20.4s, v20.4s, v10.4s\n"
"fadd v21.4s, v21.4s, v10.4s\n"
"fadd v22.4s, v22.4s, v11.4s\n"
"fadd v23.4s, v23.4s, v11.4s\n"
"fadd v24.4s, v24.4s, v12.4s\n"
"fadd v25.4s, v25.4s, v12.4s\n"
"fadd v26.4s, v26.4s, v13.4s\n"
"fadd v27.4s, v27.4s, v13.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v14.4s\n"
"fadd v30.4s, v30.4s, v15.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"7:\n"
// Load the clamp_min, clamp_max bounds
// They are loaded as raw 32-bit values and broadcast to all 4 lanes; the
// fmax / fmin instructions below interpret them as floats.
"ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.4s, w2\n" // clamp_min
"dup v15.4s, w3\n" // clamp_max
// Apply the clamp_min bound
"fmax v16.4s, v16.4s, v14.4s\n"
"fmax v17.4s, v17.4s, v14.4s\n"
"fmax v18.4s, v18.4s, v14.4s\n"
"fmax v19.4s, v19.4s, v14.4s\n"
"fmax v20.4s, v20.4s, v14.4s\n"
"fmax v21.4s, v21.4s, v14.4s\n"
"fmax v22.4s, v22.4s, v14.4s\n"
"fmax v23.4s, v23.4s, v14.4s\n"
"fmax v24.4s, v24.4s, v14.4s\n"
"fmax v25.4s, v25.4s, v14.4s\n"
"fmax v26.4s, v26.4s, v14.4s\n"
"fmax v27.4s, v27.4s, v14.4s\n"
"fmax v28.4s, v28.4s, v14.4s\n"
"fmax v29.4s, v29.4s, v14.4s\n"
"fmax v30.4s, v30.4s, v14.4s\n"
"fmax v31.4s, v31.4s, v14.4s\n"
// Apply the clamp_max bound
"fmin v16.4s, v16.4s, v15.4s\n"
"fmin v17.4s, v17.4s, v15.4s\n"
"fmin v18.4s, v18.4s, v15.4s\n"
"fmin v19.4s, v19.4s, v15.4s\n"
"fmin v20.4s, v20.4s, v15.4s\n"
"fmin v21.4s, v21.4s, v15.4s\n"
"fmin v22.4s, v22.4s, v15.4s\n"
"fmin v23.4s, v23.4s, v15.4s\n"
"fmin v24.4s, v24.4s, v15.4s\n"
"fmin v25.4s, v25.4s, v15.4s\n"
"fmin v26.4s, v26.4s, v15.4s\n"
"fmin v27.4s, v27.4s, v15.4s\n"
"fmin v28.4s, v28.4s, v15.4s\n"
"fmin v29.4s, v29.4s, v15.4s\n"
"fmin v30.4s, v30.4s, v15.4s\n"
"fmin v31.4s, v31.4s, v15.4s\n"
// Compute how much of the 8x8 block of destination float values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
// (If w1 != 8, the ccmp does not compare and instead sets the flags to
// the immediate 0, which reads as not-equal.)
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
// The stride of 32 bytes is 8 floats, i.e. the columns of the block are
// stored back to back in dst_tmp_buf.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #32\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our float values to the destination described by
// (x3 address, x4 stride): each pair of registers is one column of
// 8 floats, and x4 is the byte stride from one column to the next.
// Each accumulator is cleared right after being stored, so that it is
// ready for the next iteration of the main loop.
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"str q16, [x3, #0]\n"
"str q17, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"str q18, [x3, #0]\n"
"str q19, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
"str q20, [x3, #0]\n"
"str q21, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
"str q22, [x3, #0]\n"
"str q23, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
"str q24, [x3, #0]\n"
"str q25, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
"str q26, [x3, #0]\n"
"str q27, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
"str q28, [x3, #0]\n"
"str q29, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
"str q30, [x3, #0]\n"
"str q31, [x3, #16]\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
// This branch still tests the flags set by the cmp / ccmp pair above:
// the mov, str, add, prfm and movi instructions executed since then do
// not modify the condition flags.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
// x3 = source column in dst_tmp_buf, x4 = destination column,
// w6 = column counter (bounded by w2), w5 = row counter (bounded by w1).
// This clobbers w5 -- w7, which are reloaded from params below.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
// Copy one 32-bit value (one float) at index w5 within the column.
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
// Advance dst_ptr by 32 bytes, i.e. 8 floats, to the next block of rows
// within the same columns.
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
// Advance dst_col_ptr by 8 * dst_stride bytes and restart dst_ptr there.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 1.
// (The mov does not modify the flags set by the cmp above, which the
// ble below consumes.)
"mov w1, #1\n"
"ble 1b\n"
// clang-format on
// Outputs: the pointers and the row / col indices are read and written
// by the asm, hence the +r constraints.
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
// Inputs: read-only values.
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
// Clobbers: every scratch register named in the asm, the condition flags
// and memory. x13 and v5 -- v7 are listed although the asm above does not
// reference them.
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// Variant of KernelFloatNeon tuned for in-order CPUs that do not
// support dotprod (while dotprod by itself is not relevant to floating-point,
// this additional bit of information that we have about the target happens to
// be useful here).
//
// So a typical target CPU here would be ARM Cortex-A53 or the original
// Cortex-A55.
//
// This kernel is similar to and inspired by gemmlowp's
// NEON_64bit_GEMM_Float32_WithScalar_A53,
// which was contributed by David Mansell with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A53:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params) {
profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
CheckOffsetsInKernelParamsFloat(params);
const float* lhs_col_ptr = params.lhs_base_ptr;
const float* rhs_col_ptr = params.rhs_base_ptr;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr;
float* dst_col_ptr = params.dst_base_ptr;
float* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are accumulators.
// During accumulation, v0 -- v3 are used to load data from LHS and RHS.
//
// RHS 1x8 block
// /-----------------------------------------|
// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
// \-----------------------------------------/
// LHS 8x1 block
// /---------------------\ /-----------------------------------------|
// | v0.s[0] | |v16.s[0] ... v30.s[0]|
// | ... | | ... ... |
// | v0.s[3] | |v16.s[3] ... v30.s[3]|
// | v1.s[0] | |v17.s[0] ... v31.s[0]|
// | ... | | ... ... |
// | v1.s[3] | |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// accumulators 8x8 block
//
// There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
// we did not observe a benefit of such partial unrolling on in-order CPUs.
//
// v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used
// for the post-accumulation part of the kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v17)
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v18)
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v19)
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
RUY_MAKE_ZERO(v23)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
RUY_MAKE_ZERO(v24)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
RUY_MAKE_ZERO(v25)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
RUY_MAKE_ZERO(v26)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
RUY_MAKE_ZERO(v27)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that remain to load
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently depth - 1.
"sub w1, w12, #1\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
"cmp w1, #0\n"
"fmla v16.4s, v0.4s, v2.s[0]\n"
"fmla v18.4s, v0.4s, v2.s[1]\n"
"fmla v20.4s, v0.4s, v2.s[2]\n"
"fmla v22.4s, v0.4s, v2.s[3]\n"
// Accumulation loop
"beq 79f\n"
"2:\n"
"fmla v24.4s, v0.4s, v3.s[0]\n"
"ldr x2, [%[lhs_ptr], #8]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
"ldr x3, [%[lhs_ptr], #24]\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"ldr x5, [%[rhs_ptr], #24]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
"ldr x4, [%[rhs_ptr], #8]\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"subs w1, w1, #1\n"
"ldr d0, [%[lhs_ptr]], #32\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
"ins v0.d[1], x2\n"
"ldr d3, [%[rhs_ptr], #16]\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"ins v3.d[1], x5\n"
"ldr d4, [%[rhs_ptr]], #32\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"ins v4.d[1], x4\n"
"ldr d1, [%[lhs_ptr], #-16]\n"
"fmla v18.4s, v0.4s, v4.s[1]\n"
"fmla v20.4s, v0.4s, v4.s[2]\n"
"ins v1.d[1], x3\n"
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
"mov v2.16b, v4.16b\n"
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
"fmla v22.4s, v0.4s, v4.s[3]\n"
"bne 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last level of depth, for which the LHS
// and RHS data is already loaded.
"fmla v24.4s, v0.4s, v3.s[0]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 8x8 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x8 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
// Perform the bias-addition.
// Jump based on channel dimension.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"fadd v16.4s, v16.4s, v14.4s\n"
"fadd v17.4s, v17.4s, v15.4s\n"
"fadd v18.4s, v18.4s, v14.4s\n"
"fadd v19.4s, v19.4s, v15.4s\n"
"fadd v20.4s, v20.4s, v14.4s\n"
"fadd v21.4s, v21.4s, v15.4s\n"
"fadd v22.4s, v22.4s, v14.4s\n"
"fadd v23.4s, v23.4s, v15.4s\n"
"fadd v24.4s, v24.4s, v14.4s\n"
"fadd v25.4s, v25.4s, v15.4s\n"
"fadd v26.4s, v26.4s, v14.4s\n"
"fadd v27.4s, v27.4s, v15.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v15.4s\n"
"fadd v30.4s, v30.4s, v14.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
"dup v8.4s, v14.s[0]\n"
"dup v9.4s, v14.s[1]\n"
"fadd v16.4s, v16.4s, v8.4s\n"
"dup v10.4s, v14.s[2]\n"
"fadd v17.4s, v17.4s, v8.4s\n"
"dup v11.4s, v14.s[3]\n"
"fadd v18.4s, v18.4s, v9.4s\n"
"dup v12.4s, v15.s[0]\n"
"fadd v19.4s, v19.4s, v9.4s\n"
"dup v13.4s, v15.s[1]\n"
"fadd v20.4s, v20.4s, v10.4s\n"
"dup v14.4s, v15.s[2]\n"
"fadd v21.4s, v21.4s, v10.4s\n"
"dup v15.4s, v15.s[3]\n"
"fadd v22.4s, v22.4s, v11.4s\n"
"fadd v23.4s, v23.4s, v11.4s\n"
"fadd v24.4s, v24.4s, v12.4s\n"
"fadd v25.4s, v25.4s, v12.4s\n"
"fadd v26.4s, v26.4s, v13.4s\n"
"fadd v27.4s, v27.4s, v13.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v14.4s\n"
"fadd v30.4s, v30.4s, v15.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"7:\n"
// Load the clamp_min, clamp_max bounds
"ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.4s, w2\n" // clamp_min
"dup v15.4s, w3\n" // clamp_max
// Apply the clamp_min bound
"fmax v16.4s, v16.4s, v14.4s\n"
"fmax v17.4s, v17.4s, v14.4s\n"
"fmax v18.4s, v18.4s, v14.4s\n"
"fmax v19.4s, v19.4s, v14.4s\n"
"fmax v20.4s, v20.4s, v14.4s\n"
"fmax v21.4s, v21.4s, v14.4s\n"
"fmax v22.4s, v22.4s, v14.4s\n"
"fmax v23.4s, v23.4s, v14.4s\n"
"fmax v24.4s, v24.4s, v14.4s\n"
"fmax v25.4s, v25.4s, v14.4s\n"
"fmax v26.4s, v26.4s, v14.4s\n"
"fmax v27.4s, v27.4s, v14.4s\n"
"fmax v28.4s, v28.4s, v14.4s\n"
"fmax v29.4s, v29.4s, v14.4s\n"
"fmax v30.4s, v30.4s, v14.4s\n"
"fmax v31.4s, v31.4s, v14.4s\n"
// Apply the clamp_max bound
"fmin v16.4s, v16.4s, v15.4s\n"
"fmin v17.4s, v17.4s, v15.4s\n"
"fmin v18.4s, v18.4s, v15.4s\n"
"fmin v19.4s, v19.4s, v15.4s\n"
"fmin v20.4s, v20.4s, v15.4s\n"
"fmin v21.4s, v21.4s, v15.4s\n"
"fmin v22.4s, v22.4s, v15.4s\n"
"fmin v23.4s, v23.4s, v15.4s\n"
"fmin v24.4s, v24.4s, v15.4s\n"
"fmin v25.4s, v25.4s, v15.4s\n"
"fmin v26.4s, v26.4s, v15.4s\n"
"fmin v27.4s, v27.4s, v15.4s\n"
"fmin v28.4s, v28.4s, v15.4s\n"
"fmin v29.4s, v29.4s, v15.4s\n"
"fmin v30.4s, v30.4s, v15.4s\n"
"fmin v31.4s, v31.4s, v15.4s\n"
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #32\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"str q16, [x3, #0]\n"
"str q17, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"str q18, [x3, #0]\n"
"str q19, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
"str q20, [x3, #0]\n"
"str q21, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
"str q22, [x3, #0]\n"
"str q23, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
"str q24, [x3, #0]\n"
"str q25, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
"str q26, [x3, #0]\n"
"str q27, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
"str q28, [x3, #0]\n"
"str q29, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
"str q30, [x3, #0]\n"
"str q31, [x3, #16]\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that remain to load
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently depth - 1.
"sub w1, w12, #1\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// Variant of KernelFloatNeonA55ish tuned for in-order CPUs that do
// support dotprod (while dotprod by itself is not relevant to floating-point,
// this additional bit of information that we have about the target happens to
// be useful here).
//
// So a typical target CPU here would be ARM Cortex-A55r1.
//
// This kernel is similar to and inspired by gemmlowp's
// NEON_64bit_GEMM_Float32_WithScalar_A55r1,
// which was contributed by David Mansell with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A55r1:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params) {
profiler::ScopeLabel label(
"Kernel (kNeonDotprod, optimized for in-order cores)");
CheckOffsetsInKernelParamsFloat(params);
// Running pointers and block coordinates. They are bound to the asm below as
// read-write ("+r") operands and are updated by it.
// lhs_col_ptr / rhs_col_ptr / dst_col_ptr track the start of the current
// LHS / RHS / destination block; lhs_ptr / rhs_ptr / dst_ptr are the cursors
// that the asm advances while working on that block.
const float* lhs_col_ptr = params.lhs_base_ptr;
const float* rhs_col_ptr = params.rhs_base_ptr;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr;
float* dst_col_ptr = params.dst_base_ptr;
float* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are accumulators.
// During accumulation, v0 -- v3 are used to load data from LHS and RHS.
//
// RHS 1x8 block
// /-----------------------------------------|
// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
// \-----------------------------------------/
// LHS 8x1 block
// /---------------------\ /-----------------------------------------|
// | v0.s[0] | |v16.s[0] ... v30.s[0]|
// | ... | | ... ... |
// | v0.s[3] | |v16.s[3] ... v30.s[3]|
// | v1.s[0] | |v17.s[0] ... v31.s[0]|
// | ... | | ... ... |
// | v1.s[3] | |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// accumulators 8x8 block
//
// There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
// we did not observe a benefit of such partial unrolling on in-order CPUs.
//
// v4 is used inside the accumulation loop as a staging register for the next
// RHS values (it becomes the next v2), v5 -- v7 are unused, and v8 -- v15 are
// used for loading parameters (bias values, clamp bounds) used for the
// post-accumulation part of the kernel.
//
// General-purpose registers holding values loaded from params:
//   x5  = lhs_base_ptr (clobbered and reloaded several times below)
//   w6  = start_row, w7 = last_row (clobbered by the scalar copy loop and
//         reloaded afterwards)
//   w8  = last_col
//   w9  = lhs_stride, w10 = rhs_stride, w11 = dst_stride
//   w12 = depth
// x1 -- x4 are scratch registers.
asm volatile(
// RUY_MAKE_ZERO(vN) expands to a `movi vN.4s, #0` instruction string, i.e. it
// zeroes the whole vector register vN.
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Clear accumulators. The zeroing instructions are interleaved with the
// initial loads and prefetches below.
RUY_MAKE_ZERO(v16)
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v17)
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v18)
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v19)
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v20)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
RUY_MAKE_ZERO(v22)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
RUY_MAKE_ZERO(v23)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
RUY_MAKE_ZERO(v24)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
RUY_MAKE_ZERO(v25)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
RUY_MAKE_ZERO(v26)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
RUY_MAKE_ZERO(v27)
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that remain to load
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently depth - 1.
"sub w1, w12, #1\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// The cmp sets the flags tested by `beq 79f` below; the fmla instructions
// in between do not alter them. When w1 == 0 (depth == 1) the accumulation
// loop is skipped entirely.
"cmp w1, #0\n"
"fmla v16.4s, v0.4s, v2.s[0]\n"
"fmla v18.4s, v0.4s, v2.s[1]\n"
"fmla v20.4s, v0.4s, v2.s[2]\n"
"fmla v22.4s, v0.4s, v2.s[3]\n"
// Accumulation loop
"beq 79f\n"
"2:\n"
// One iteration per remaining level of depth. Each iteration finishes the
// multiply-adds for the level currently held in v0 -- v3 while loading the
// next level. Each 128-bit vector of the next level is assembled from two
// 64-bit halves: `ldr dN` loads the low half and `ldr xM` followed by
// `ins vN.d[1], xM` supplies the high half.
// NOTE(review): presumably this split-load scheme is the in-order-core
// tuning discussed in the gemmlowp comment referenced above this function
// -- confirm there.
// The next RHS low vector is staged in v4 rather than loaded into v2
// directly, because the current v2 is still needed by the fmla's into
// v17/v19/v21/v23; v4 is copied into v2 at the end of the iteration.
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
"fmla v24.4s, v0.4s, v3.s[0]\n"
"ldr x2, [%[lhs_ptr], #8]\n" // High half of the next v0.
"fmla v26.4s, v0.4s, v3.s[1]\n"
"ldr x3, [%[lhs_ptr], #24]\n" // High half of the next v1.
"fmla v28.4s, v0.4s, v3.s[2]\n"
"ldr x5, [%[rhs_ptr], #24]\n" // High half of the next v3 (clobbers x5).
"fmla v30.4s, v0.4s, v3.s[3]\n"
"ldr d0, [%[lhs_ptr]], #32\n" // Low half of the next v0; lhs_ptr += 32.
"fmla v25.4s, v1.4s, v3.s[0]\n"
"ldr x4, [%[rhs_ptr], #8]\n" // High half of the next v2 (staged in v4).
"fmla v27.4s, v1.4s, v3.s[1]\n"
"subs w1, w1, #1\n" // Sets the flags tested by `bne 2b` below.
"fmla v29.4s, v1.4s, v3.s[2]\n"
"ins v0.d[1], x2\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
"ldr d3, [%[rhs_ptr], #16]\n" // Low half of the next v3.
"fmla v17.4s, v1.4s, v2.s[0]\n"
"ins v3.d[1], x5\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"ldr d4, [%[rhs_ptr]], #32\n" // Low half of the next v2; rhs_ptr += 32.
"fmla v21.4s, v1.4s, v2.s[2]\n"
"ins v4.d[1], x4\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
// Start the multiply-adds of the next level of depth, using the freshly
// loaded v0 and v4 (the next v2). These mirror the four fmla's at label 1.
"fmla v16.4s, v0.4s, v4.s[0]\n"
"ldr d1, [%[lhs_ptr], #-16]\n" // Low half of the next v1 (lhs_ptr was already advanced).
"fmla v18.4s, v0.4s, v4.s[1]\n"
"ins v1.d[1], x3\n"
"fmla v20.4s, v0.4s, v4.s[2]\n"
"mov v2.16b, v4.16b\n"
"fmla v22.4s, v0.4s, v4.s[3]\n"
"bne 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last level of depth, for which the LHS
// and RHS data is already loaded.
"fmla v24.4s, v0.4s, v3.s[0]\n"
"fmla v26.4s, v0.4s, v3.s[1]\n"
"fmla v28.4s, v0.4s, v3.s[2]\n"
"fmla v30.4s, v0.4s, v3.s[3]\n"
"fmla v25.4s, v1.4s, v3.s[0]\n"
"fmla v27.4s, v1.4s, v3.s[1]\n"
"fmla v29.4s, v1.4s, v3.s[2]\n"
"fmla v31.4s, v1.4s, v3.s[3]\n"
// Reload lhs_base_ptr into x5: the accumulation loop used x5 as a scratch
// register.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"fmla v17.4s, v1.4s, v2.s[0]\n"
"fmla v19.4s, v1.4s, v2.s[1]\n"
"fmla v21.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// float accumulator values of the current 8x8 destination block.
// We now have to apply the bias and the clamp bounds to these
// accumulators, store them, and advance to the next 8x8 block. We
// intertwine these aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
// (lhs_col_ptr advances by 8 * lhs_stride, added as a byte offset.)
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done, however
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
// (rhs_col_ptr advances by 8 * rhs_stride, added as a byte offset.)
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
// w3 = row if the CHANNEL_DIMENSION_IS_COL flag is clear, col otherwise.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"
// Offset the bias pointer as needed given the current row, col.
// (x5 = bias pointer + 4 bytes * channel index.)
"add x5, x1, x3, lsl #2\n"
// If there is no bias, use no offset, just address the passed zero
// data.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
"ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
"ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
// Perform the bias-addition.
// Jump based on channel dimension.
"tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
// v14 (rows 0-3) and v15 (rows 4-7) are added to every column of the block.
"fadd v16.4s, v16.4s, v14.4s\n"
"fadd v17.4s, v17.4s, v15.4s\n"
"fadd v18.4s, v18.4s, v14.4s\n"
"fadd v19.4s, v19.4s, v15.4s\n"
"fadd v20.4s, v20.4s, v14.4s\n"
"fadd v21.4s, v21.4s, v15.4s\n"
"fadd v22.4s, v22.4s, v14.4s\n"
"fadd v23.4s, v23.4s, v15.4s\n"
"fadd v24.4s, v24.4s, v14.4s\n"
"fadd v25.4s, v25.4s, v15.4s\n"
"fadd v26.4s, v26.4s, v14.4s\n"
"fadd v27.4s, v27.4s, v15.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v15.4s\n"
"fadd v30.4s, v30.4s, v14.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"b 7f\n"
"6:\n"
// Case where channels are columns
// Each of the 8 bias values is broadcast to a full vector (v8 -- v15) and
// added to both accumulator registers of the corresponding column. v14 and
// v15 are overwritten in place by their own broadcasts, after their other
// lanes have been extracted.
"dup v8.4s, v14.s[0]\n"
"dup v9.4s, v14.s[1]\n"
"fadd v16.4s, v16.4s, v8.4s\n"
"dup v10.4s, v14.s[2]\n"
"fadd v17.4s, v17.4s, v8.4s\n"
"dup v11.4s, v14.s[3]\n"
"fadd v18.4s, v18.4s, v9.4s\n"
"dup v12.4s, v15.s[0]\n"
"fadd v19.4s, v19.4s, v9.4s\n"
"dup v13.4s, v15.s[1]\n"
"fadd v20.4s, v20.4s, v10.4s\n"
"dup v14.4s, v15.s[2]\n"
"fadd v21.4s, v21.4s, v10.4s\n"
"dup v15.4s, v15.s[3]\n"
"fadd v22.4s, v22.4s, v11.4s\n"
"fadd v23.4s, v23.4s, v11.4s\n"
"fadd v24.4s, v24.4s, v12.4s\n"
"fadd v25.4s, v25.4s, v12.4s\n"
"fadd v26.4s, v26.4s, v13.4s\n"
"fadd v27.4s, v27.4s, v13.4s\n"
"fadd v28.4s, v28.4s, v14.4s\n"
"fadd v29.4s, v29.4s, v14.4s\n"
"fadd v30.4s, v30.4s, v15.4s\n"
"fadd v31.4s, v31.4s, v15.4s\n"
"7:\n"
// Load the clamp_min, clamp_max bounds
// (loaded as raw 32-bit words and broadcast; they are then used as float
// operands by fmax/fmin below).
"ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.4s, w2\n" // clamp_min
"dup v15.4s, w3\n" // clamp_max
// Apply the clamp_min bound
"fmax v16.4s, v16.4s, v14.4s\n"
"fmax v17.4s, v17.4s, v14.4s\n"
"fmax v18.4s, v18.4s, v14.4s\n"
"fmax v19.4s, v19.4s, v14.4s\n"
"fmax v20.4s, v20.4s, v14.4s\n"
"fmax v21.4s, v21.4s, v14.4s\n"
"fmax v22.4s, v22.4s, v14.4s\n"
"fmax v23.4s, v23.4s, v14.4s\n"
"fmax v24.4s, v24.4s, v14.4s\n"
"fmax v25.4s, v25.4s, v14.4s\n"
"fmax v26.4s, v26.4s, v14.4s\n"
"fmax v27.4s, v27.4s, v14.4s\n"
"fmax v28.4s, v28.4s, v14.4s\n"
"fmax v29.4s, v29.4s, v14.4s\n"
"fmax v30.4s, v30.4s, v14.4s\n"
"fmax v31.4s, v31.4s, v14.4s\n"
// Apply the clamp_max bound
"fmin v16.4s, v16.4s, v15.4s\n"
"fmin v17.4s, v17.4s, v15.4s\n"
"fmin v18.4s, v18.4s, v15.4s\n"
"fmin v19.4s, v19.4s, v15.4s\n"
"fmin v20.4s, v20.4s, v15.4s\n"
"fmin v21.4s, v21.4s, v15.4s\n"
"fmin v22.4s, v22.4s, v15.4s\n"
"fmin v23.4s, v23.4s, v15.4s\n"
"fmin v24.4s, v24.4s, v15.4s\n"
"fmin v25.4s, v25.4s, v15.4s\n"
"fmin v26.4s, v26.4s, v15.4s\n"
"fmin v27.4s, v27.4s, v15.4s\n"
"fmin v28.4s, v28.4s, v15.4s\n"
"fmin v29.4s, v29.4s, v15.4s\n"
"fmin v30.4s, v30.4s, v15.4s\n"
"fmin v31.4s, v31.4s, v15.4s\n"
// Compute how much of the 8x8 block of destination float values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
// The resulting flags are tested twice: by `beq 30f` just below and again
// by `beq 41f` after the stores. None of the instructions in between
// (mov, str, add, prfm, movi) modify the flags.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
// (stride 32 bytes, i.e. 8 floats per column of the temporary block).
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #32\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our float values to the destination described by
// (x3 address, x4 stride). Each column of the block is stored as two
// q registers (8 floats); each pair of accumulators is zeroed right after
// it has been stored, ready for the next block.
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"str q16, [x3, #0]\n"
"str q17, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"str q18, [x3, #0]\n"
"str q19, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
"str q20, [x3, #0]\n"
"str q21, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
"str q22, [x3, #0]\n"
"str q23, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
"str q24, [x3, #0]\n"
"str q25, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
"str q26, [x3, #0]\n"
"str q27, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
"str q28, [x3, #0]\n"
"str q29, [x3, #16]\n"
"add x3, x3, x4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
"str q30, [x3, #0]\n"
"str q31, [x3, #16]\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
// Register roles in this loop:
//   x3 = source column in dst_tmp_buf (advances by 32 bytes per column)
//   x4 = destination column (advances by the dst stride x11 per column)
//   w5 = row index within the column, w6 = column index
//   w7 = the 32-bit value being copied
//   w1 / w2 = number of rows / columns that fit (computed above)
// This clobbers w5 -- w7, which are reloaded from params after label 41.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
// Advance dst_ptr by 32 bytes (8 floats). If the block just written was in
// the last row, dst_ptr is instead reset from dst_col_ptr at label 20.
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
// The flags set by this cmp are the ones tested by `ble 1b` below; the
// intervening `sub` (not `subs`) leaves them untouched.
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that remain to load
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently depth - 1.
"sub w1, w12, #1\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// The RUY_OFFSET_* macros are the byte offsets that the inline asm above adds
// to the params pointer to load individual fields. Undefine them here so that
// they do not remain defined past this point.
#undef RUY_OFFSET_BIAS
#undef RUY_OFFSET_FLAGS
#undef RUY_OFFSET_LHS_BASE_PTR
#undef RUY_OFFSET_CLAMP_MIN
#undef RUY_OFFSET_CLAMP_MAX
#undef RUY_OFFSET_START_ROW
#undef RUY_OFFSET_LAST_ROW
#undef RUY_OFFSET_LAST_COL
#undef RUY_OFFSET_LHS_STRIDE
#undef RUY_OFFSET_RHS_STRIDE
#undef RUY_OFFSET_DST_STRIDE
#undef RUY_OFFSET_DEPTH
#undef RUY_OFFSET_START_COL
#undef RUY_OFFSET_RHS_BASE_PTR
#undef RUY_OFFSET_DST_BASE_PTR
#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
} // namespace ruy