| /* Copyright 2019 Google LLC. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "ruy/kernel_arm.h" |
| #include "ruy/opt_set.h" |
| #include "ruy/platform.h" |
| #include "ruy/profiler/instrumentation.h" |
| |
| namespace ruy { |
| |
| #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) |
| |
// Numeric asm labels used as branch targets by the store-to-destination
// paths of the kernels in this file.
#define RUY_ASM_LABEL_STORE_UINT8 91
#define RUY_ASM_LABEL_STORE_INT8 92
#define RUY_ASM_LABEL_STORE_INT16 93
#define RUY_ASM_LABEL_STORE_INT32 94
#define RUY_ASM_LABEL_AFTER_STORE 99

// Byte offsets of each field of KernelParamsFloat, hard-coded so the inline
// asm below can address fields with immediate offsets. They are verified
// against the actual struct layout by CheckOffsetsInKernelParamsFloat32.
#define RUY_OFFSET_LHS_BASE_PTR 0
#define RUY_OFFSET_RHS_BASE_PTR 4
#define RUY_OFFSET_DST_BASE_PTR 8
#define RUY_OFFSET_BIAS 12
#define RUY_OFFSET_START_ROW 16
#define RUY_OFFSET_START_COL 20
#define RUY_OFFSET_LAST_ROW 24
#define RUY_OFFSET_LAST_COL 28
#define RUY_OFFSET_DST_ROWS 32
#define RUY_OFFSET_DST_COLS 36
#define RUY_OFFSET_LHS_STRIDE 40
#define RUY_OFFSET_RHS_STRIDE 44
#define RUY_OFFSET_DST_STRIDE 48
#define RUY_OFFSET_DEPTH 52
#define RUY_OFFSET_CLAMP_MIN 56
#define RUY_OFFSET_CLAMP_MAX 60
#define RUY_OFFSET_FLAGS 64

// Stack-frame layout used by the float kernel: ARM32 is register-starved, so
// loop-carried pointers/indices are spilled to the stack. Each slot is given
// 16 bytes; RUY_STACK_OFFSET_SIZE is the total frame size subtracted from sp.
#define RUY_STACK_OFFSET_SIZE 96
#define RUY_STACK_OFFSET_DST_COL_PTR 0
#define RUY_STACK_OFFSET_DST_PTR 16
#define RUY_STACK_OFFSET_ROW 32
#define RUY_STACK_OFFSET_COL 48
#define RUY_STACK_OFFSET_LHS_COL_PTR 64
#define RUY_STACK_OFFSET_RHS_COL_PTR 80
| |
| template <typename Params> |
| void CheckOffsetsInKernelParamsFloat32(const Params&) { |
| static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); |
| static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); |
| static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); |
| static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); |
| static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); |
| static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); |
| static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); |
| static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); |
| static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, ""); |
| static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); |
| static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); |
| static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); |
| static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); |
| static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); |
| static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); |
| static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); |
| } |
| |
// Float kernel for ARM32 out-of-order cores.
// Just like the 64-bit float kernel, except it accumulates into an 8x4 block
// so as to only use the 16 128-bit NEON registers available on ARM32. This is
// a "first pass" kernel and is not tuned. It is meant to run on out-of-order
// CPUs like the Krait 400 or A9.
| void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) { |
| CheckOffsetsInKernelParamsFloat32(params); |
| profiler::ScopeLabel label("Kernel (kNeon)"); |
| |
| const float* lhs_ptr = params.lhs_base_ptr; |
| const float* rhs_ptr = params.rhs_base_ptr; |
| // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are |
| // each composed of two 64-bit "d" registers. The asm kernel below has the |
| // following NEON register allocation: |
| // Registers q3 -- q10 are accumulators. During accumulation, |
| // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1 |
| // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block |
| // of RHS, like this: |
| |
| // Register layout in "q" registers: |
| // RHS 1x4 block |
| // /--------------------------| |
| // |q2.s[0] ... q2.s[3] | |
| // \--------------------------/ |
| // LHS 8x1 block |
| // /---------------------\ /--------------------------| |
| // | q0.s[0] | | q3.s[0] ... q9.s[0] | |
| // | ... | | ... ... | |
| // | q0.s[3] | | q3.s[3] q9.s[3] | |
| // | q1.s[0] | | q4.s[0] q10.s[0] | |
| // | ... | | ... ... ... | |
| // | q1.s[3] | | q4.s[3] .. q10.s[3] | |
| // \---------------------/ \--------------------------/ |
| // accumulators 8x4 block |
| // q11, q14, q15 currently unused. q12 and q13 are used to load |
| // parameters used for the post-accumulation part of the kernel. |
| // For completeness, here is the register layout in "d" registers: |
| // RHS 1x4 block |
| // /--------------------------| |
| // |d4[0] ... d5[1] | |
| // \--------------------------/ |
| // LHS 8x1 block |
| // /---------------------\ /--------------------------| |
| // | d0[0] | | d6[0] ... d18[0] | |
| // | ... | | ... ... | |
| // | d1[1] | | d7[1] d19[1] | |
| // | d2[0] | | d8[0] d20[0] | |
| // | ... | | ... ... ... | |
| // | d3[1] | | d9[1] ... d21[1] | |
| // \---------------------/ \--------------------------/ |
| // accumulators 8x4 block |
| asm volatile( |
| #define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" |
| |
| // clang-format off |
| |
| // Load the first 32 bytes of LHS and RHS data. |
| // Load q0, q1 |
| "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" |
| "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") |
| // Load q2 |
| "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") |
| |
| "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
| "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" |
| "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" |
| "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| // Clear accumulators. |
| RUY_MAKE_ZERO(q3) |
| RUY_MAKE_ZERO(q4) |
| RUY_MAKE_ZERO(q5) |
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| |
| // r1 is the number of levels of depth that we have already loaded |
| // LHS and RHS data for. Corresponding to the initial ld1 instructions |
| // above, this is currently 1. |
| "mov r1, #1\n" |
| |
| // Main loop of the whole GEMM, over rows and columns of the |
| // destination matrix. |
| "1:\n" |
| |
| // Accumulation loop |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
| "cmp r1, r2\n" |
| "beq 79f\n" |
| |
| "2:\n" |
| |
| "vmla.f32 q3, q0, d4[0]\n" |
| "vmla.f32 q5, q0, d4[1]\n" |
| "vmla.f32 q7, q0, d5[0]\n" |
| "vmla.f32 q9, q0, d5[1]\n" |
| "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS |
| |
| "vmla.f32 q4, q1, d4[0]\n" |
| "vmla.f32 q6, q1, d4[1]\n" |
| "vmla.f32 q8, q1, d5[0]\n" |
| "vmla.f32 q10, q1, d5[1]\n" |
| "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
| RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") |
| "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS |
| RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") |
| |
| "add r1, r1, #1\n" |
| "cmp r1, r2\n" |
| |
| "blt 2b\n" |
| |
| "79:\n" |
| |
| // End of the inner loop on depth. Now perform the remaining |
| // multiply-adds of the last level of depth, for which the LHS |
| // and RHS data is already loaded. |
| |
| "vmla.f32 q3, q0, d4[0]\n" |
| "vmla.f32 q5, q0, d4[1]\n" |
| "vmla.f32 q7, q0, d5[0]\n" |
| "vmla.f32 q9, q0, d5[1]\n" |
| |
| "vmla.f32 q4, q1, d4[0]\n" |
| "vmla.f32 q6, q1, d4[1]\n" |
| "vmla.f32 q8, q1, d5[0]\n" |
| "vmla.f32 q10, q1, d5[1]\n" |
| |
| // End of accumulation. The registers q3 -- q10 contain the final |
| // float32 accumulator values of the current 8x8 destination block. |
| // We now have to compute the final values from these accumulators |
| // and advance to the next 8x8 block. We intertwine |
| // these two aspects whenever possible for optimal pipelining, both |
| // at the data flow level (prefetch data for next block as early as |
| // possible) and instruction pipelining level (some of the next-block |
| // work can dual-issue with some of the final work on the current |
| // block). |
| |
| // Logic to advance to the next block in preparation for the next |
| // iteration of the main loop. For now, we only want to compute |
| // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
| // not yet ready to update the values of row and col, as we still need |
| // the current values for the rest of the work on the current block. |
| |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
| "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "cmp r1, r3\n" // Have we finished the last row? |
| |
| "bge 4f\n" // If finished last row, go to 4 |
| // Not finished last row: then advance to next row. |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "add r4, r4, r1, lsl #3\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "b 5f\n" |
| "4:\n" // Finished last row... |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| // Go back to first row |
| "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| // Now we need to advance to the next column. If we already |
| // finished the last column, then in principle we are done, however |
| // we can't just return here, as we need to allow the end work of the |
| // current block to complete. The good news is that at this point it |
| // doesn't matter what data we load for the next column, since |
| // we will exit from the main loop below before actually storing |
| // anything computed from that data. |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "cmp r8, r4\n" // Have we finished the last column? |
| "bge 5f\n" // If yes, just carry on without updating the column pointer. |
| // Not finished last column: then advance to next column. |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
| "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "add r10, r10, r1, lsl #2\n" |
| "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "5:\n" |
| |
| // Set the LHS and RHS data pointers to the start of the columns just |
| // computed. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "mov %[lhs_ptr], r4\n" |
| "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "mov %[rhs_ptr], r5\n" |
| |
| // Load some parameters needed for the end work on current block. |
| "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
| |
| // Let r8 be stack offset of the row or column variable, whichever |
| // is the channel index. |
| "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
| "ite eq\n" |
| "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n" |
| "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n" |
| // Let r8 be the channel index. |
| "ldr r8, [sp, r8]\n" |
| // Compute the bias pointer, by conditionally using the channel index |
| // (r8) as offset into bias buffer (r1). |
| "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
| "it ne\n" |
| "addne r1, r1, r8, lsl #2\n" |
| |
| // Load 4 bias values. When the channel dimension is rows, we will load |
| // another 4 bias values just before performing the bias addition below, |
| // as this kernel has a 8x4 rectangular shape. |
| "vld1.32 {d24, d25}, [r1]!\n" |
| |
| // Now that we know what LHS and RHS data the next iteration of the |
| // main loop will need to load, we start loading the first 32 bytes of |
| // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore |
| // in the rest of the work on the current block. |
| // Load q0, q1 |
| "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") |
| // Load q2 |
| "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") |
| |
| // Perform the bias-addition. |
| // Jump based on channel dimension. |
| "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
| "bne 6f\n" |
| // Case where channels are rows. |
| // Load the remaining 4 bias values, since we're on the width-8 side |
| // of this 8x4 kernel. |
| "vld1.32 {d26, d27}, [r1]\n" |
| "vadd.f32 q3, q3, q12\n" |
| "vadd.f32 q5, q5, q12\n" |
| "vadd.f32 q7, q7, q12\n" |
| "vadd.f32 q9, q9, q12\n" |
| "vadd.f32 q4, q4, q13\n" |
| "vadd.f32 q6, q6, q13\n" |
| "vadd.f32 q8, q8, q13\n" |
| "vadd.f32 q10, q10, q13\n" |
| "b 7f\n" |
| |
| "6:\n" |
| // Case where channels are columns. |
| "vdup.32 q11, d24[0]\n" |
| "vdup.32 q13, d24[1]\n" |
| "vdup.32 q14, d25[0]\n" |
| "vdup.32 q15, d25[1]\n" |
| "vadd.f32 q3, q3, q11\n" |
| "vadd.f32 q4, q4, q11\n" |
| "vadd.f32 q5, q5, q13\n" |
| "vadd.f32 q6, q6, q13\n" |
| "vadd.f32 q7, q7, q14\n" |
| "vadd.f32 q8, q8, q14\n" |
| "vadd.f32 q9, q9, q15\n" |
| "vadd.f32 q10, q10, q15\n" |
| "7:\n" |
| |
| // Load the clamp_min, clamp_max bounds |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
| "vdup.32 q12, r2\n" // clamp_min |
| "vdup.32 q13, r3\n" // clamp_max |
| |
| // Apply the clamp_min bound |
| "vmax.f32 q3, q3, q12\n" |
| "vmax.f32 q4, q4, q12\n" |
| "vmax.f32 q5, q5, q12\n" |
| "vmax.f32 q6, q6, q12\n" |
| "vmax.f32 q7, q7, q12\n" |
| "vmax.f32 q8, q8, q12\n" |
| "vmax.f32 q9, q9, q12\n" |
| "vmax.f32 q10, q10, q12\n" |
| |
| // Apply the clamp_max bound |
| "vmin.f32 q3, q3, q13\n" |
| "vmin.f32 q4, q4, q13\n" |
| "vmin.f32 q5, q5, q13\n" |
| "vmin.f32 q6, q6, q13\n" |
| "vmin.f32 q7, q7, q13\n" |
| "vmin.f32 q8, q8, q13\n" |
| "vmin.f32 q9, q9, q13\n" |
| "vmin.f32 q10, q10, q13\n" |
| |
| // Compute how much of the 8x4 block of destination values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
| // of 8x4, there are some 8x8 blocks along the boundaries that do |
| // not fit entirely. |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #8\n" |
| "mov r5, #4\n" |
| "cmp r1, #8\n" |
| // Compute r1 = how many rows of the 8x4 block fit |
| "it gt\n" |
| "movgt r1, r3\n" |
| "cmp r2, #4\n" |
| // Compute r2 = how many cols of the 8x4 block fit |
| "it gt\n" |
| "movgt r2, r5\n" |
| |
| // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. |
| "cmp r1, r3\n" |
| "it eq\n" |
| "cmpeq r2, r5\n" |
| // Yes, all of the 8x4 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 8x4 block fits. |
| // Set (r3 address, r4 stride) to write to dst_tmp_buf |
| "mov r3, %[dst_tmp_buf]\n" |
| "mov r4, #32\n" |
| "b 31f\n" |
| "30:\n" |
| // Yes, all of the 8x4 block fits. |
| // Set (r3 address, r4 stride) to write directly to destination matrix. |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r5\n" |
| "31:\n" |
| |
| // Write our float values to the destination described by |
| // (r3 address, r4 stride) |
| "vst1.32 {d6, d7, d8, d9}, [r3]\n" |
| "add r3, r3, r4\n" |
| RUY_MAKE_ZERO(q3) |
| RUY_MAKE_ZERO(q4) |
| "vst1.32 {d10, d11, d12, d13}, [r3]\n" |
| "add r3, r3, r4\n" |
| RUY_MAKE_ZERO(q5) |
| RUY_MAKE_ZERO(q6) |
| "vst1.32 {d14, d15, d16, d17}, [r3]\n" |
| "add r3, r3, r4\n" |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| "vst1.32 {d18, d19, d20, d21}, [r3]\n" |
| "add r3, r3, r4\n" |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| |
| // If all of the 8x4 block fits, we just finished writing it to the |
| // destination, so we skip the next part. |
| "beq 41f\n" |
| // Not all of the 8x8 block fits in the destination matrix. We just |
| // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
| // it to copy into the destination matrix the part that fits. |
| "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| "mov r3, %[dst_tmp_buf]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r6, #0\n" |
| "50:\n" |
| "mov r5, #0\n" |
| "51:\n" |
| "ldr r10, [r3, r5, lsl #2]\n" |
| "str r10, [r4, r5, lsl #2]\n" |
| "add r5, r5, #1\n" |
| "cmp r5, r1\n" |
| "blt 51b\n" |
| "add r6, r6, #1\n" |
| "add r3, r3, #32\n" |
| "add r4, r4, r8\n" |
| // r2 = how many cols of the 8x4 block fit |
| "cmp r6, r2\n" |
| "blt 50b\n" |
| "41:\n" |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #32\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| // At this point we have completely finished writing values to the |
| // destination matrix for the current block. |
| |
| // Reload some params --- we had used r3, r5, r10 for a few other things |
| // since the last time we had loaded them. |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
| |
| // Move to the next block of the destination matrix, for the next iter |
| // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
| // been updated earlier. |
| // Have we reached the end row? |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "cmp r8, r3\n" |
| |
| "beq 20f\n" // yes, end row. |
| // Not end row. Move to the next row. |
| "add r8, r8, #8\n" |
| // Store new value of row |
| "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| |
| "b 21f\n" |
| "20:\n" |
| // Was already at end row. |
| // Move back to first row. |
| "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| // Move to the next column. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "add r4, r4, #4\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| |
| "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns) |
| "add r1, r1, r8, lsl #2\n" |
| // Store dst_col_ptr |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| // Store dst_ptr |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "21:\n" |
| |
| // Main loop exit condition: have we hit the end column? |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "cmp r8, r4\n" |
| |
| // r1 is the number of levels of depth that we have already loaded |
| // LHS and RHS data for. Corresponding to the initial ld1 instructions |
| // above, this is currently 1. |
| "mov r1, #1\n" |
| |
| "ble 1b\n" |
| |
| // Restore stack pointer. |
| "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
| |
| // clang-format on |
| : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) |
| : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) |
| // Clobber list must specify q registers (and not their constituent |
| // d registers). There is a (currently unexplained) slowdown if |
| // d registers are listed in the clobbers list. |
| : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", |
| "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", |
| "q9", "q10", "q12", "q13"); |
| } |
| |
// The float kernel above is done with these macros; undefine them so the
// 8-bit kernel below can redefine the RUY_OFFSET_* / RUY_STACK_OFFSET_*
// names for the different KernelParams8bit layout without collisions.
#undef RUY_MAKE_ZERO
#undef RUY_STACK_OFFSET_SIZE
#undef RUY_STACK_OFFSET_DST_COL_PTR
#undef RUY_STACK_OFFSET_DST_PTR
#undef RUY_STACK_OFFSET_ROW
#undef RUY_STACK_OFFSET_COL
#undef RUY_STACK_OFFSET_LHS_COL_PTR
#undef RUY_STACK_OFFSET_RHS_COL_PTR

#undef RUY_OFFSET_LHS_BASE_PTR
#undef RUY_OFFSET_RHS_BASE_PTR
#undef RUY_OFFSET_DST_BASE_PTR
#undef RUY_OFFSET_BIAS
#undef RUY_OFFSET_START_ROW
#undef RUY_OFFSET_START_COL
#undef RUY_OFFSET_LAST_ROW
#undef RUY_OFFSET_LAST_COL
#undef RUY_OFFSET_DST_ROWS
#undef RUY_OFFSET_DST_COLS
#undef RUY_OFFSET_LHS_STRIDE
#undef RUY_OFFSET_RHS_STRIDE
#undef RUY_OFFSET_DST_STRIDE
#undef RUY_OFFSET_DEPTH
#undef RUY_OFFSET_CLAMP_MIN
#undef RUY_OFFSET_CLAMP_MAX
#undef RUY_OFFSET_FLAGS
| |
// Byte offsets of each field of KernelParams8bit, hard-coded so the inline
// asm of the 8-bit kernel can address fields with immediate offsets. They
// are verified against the struct layout by CheckOffsetsInKernelParams8bit.
#define RUY_OFFSET_BIAS 0
#define RUY_OFFSET_LHS_SUMS 4
#define RUY_OFFSET_RHS_SUMS 8
#define RUY_OFFSET_LHS_BASE_PTR 12
#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
#define RUY_OFFSET_MULTIPLIER_EXPONENT 20
#define RUY_OFFSET_RHS_BASE_PTR 24
#define RUY_OFFSET_DST_BASE_PTR 28
#define RUY_OFFSET_LHS_ZERO_POINT 32
#define RUY_OFFSET_RHS_ZERO_POINT 36
#define RUY_OFFSET_DST_ZERO_POINT 40
#define RUY_OFFSET_PROD_ZP_DEPTH 44
#define RUY_OFFSET_START_ROW 48
#define RUY_OFFSET_START_COL 52
#define RUY_OFFSET_LAST_ROW 56
#define RUY_OFFSET_LAST_COL 60
#define RUY_OFFSET_DST_ROWS 64
#define RUY_OFFSET_DST_COLS 68
#define RUY_OFFSET_LHS_STRIDE 72
#define RUY_OFFSET_RHS_STRIDE 76
#define RUY_OFFSET_DST_STRIDE 80
#define RUY_OFFSET_DEPTH 84
#define RUY_OFFSET_CLAMP_MIN 88
#define RUY_OFFSET_CLAMP_MAX 92
#define RUY_OFFSET_FLAGS 96
#define RUY_OFFSET_DST_TYPE_ID 97

// Stack-frame layout for the 8-bit kernel (same scheme as the float kernel:
// 16-byte slots for spilled loop-carried pointers/indices).
#define RUY_STACK_OFFSET_SIZE 96
#define RUY_STACK_OFFSET_DST_COL_PTR 0
#define RUY_STACK_OFFSET_DST_PTR 16
#define RUY_STACK_OFFSET_ROW 32
#define RUY_STACK_OFFSET_COL 48
#define RUY_STACK_OFFSET_LHS_COL_PTR 64
#define RUY_STACK_OFFSET_RHS_COL_PTR 80
| |
| template <typename Params> |
| void CheckOffsetsInKernelParams8bit(const Params&) { |
| static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, |
| ""); |
| static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, |
| ""); |
| static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, |
| ""); |
| static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, |
| ""); |
| static_assert(offsetof(Params, multiplier_fixedpoint) == |
| RUY_OFFSET_MULTIPLIER_FIXEDPOINT, |
| ""); |
| static_assert( |
| offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, |
| ""); |
| static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); |
| static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); |
| static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); |
| static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); |
| static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); |
| static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); |
| static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); |
| static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); |
| static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); |
| static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); |
| static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); |
| static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); |
| static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); |
| static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); |
| } |
| |
| // Fast-int8 kernel, ported from ARM 64 version. |
| // Relevant target CPUs for this kernel include Krait 400 and A9, |
| // since these are 32-bit, out-of-order CPUs. |
| void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) { |
| profiler::ScopeLabel label("Kernel (kNeon)"); |
| |
| CheckOffsetsInKernelParams8bit(params); |
| |
| const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
| const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; |
| const std::int8_t* lhs_ptr = lhs_col_ptr; |
| const std::int8_t* rhs_ptr = rhs_col_ptr; |
| |
| // The asm kernel below has the following NEON register allocation: |
| // |
| // q6 - q13 are 128-bit (4x32b) accumulators. |
| // During accumulation, d0 -- d7 are used to load int8 data from LHS and |
| // d8 -- d11 from RHS: |
| // int8 RHS 16x2 block |
| // /-----------------------------| |
| // |d8.b[0-7] ..... d10.b[0-7]| |
| // | ... ... | |
| // |d9.b[0-7] ..... d11.b[0-7]| |
| // \-----------------------------/ |
| // int8 LHS 4x16 block |
| // /------------------------\ /-----------------------------| |
| // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | |
| // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | |
| // (Reload d0, d1, d2, d3) |
| // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | |
| // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | |
| // \------------------------/ \-----------------------------/ |
| // 128-bit accumulators 4x2 block |
| // |
| // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING |
| // optimization for this kernel. |
| asm volatile( |
| #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" |
| |
| // clang-format off |
| |
| // Load the first 64 bytes of LHS and RHS data. |
| "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
| // Clear accumulators. |
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
| RUY_MAKE_ZERO(q11) |
| "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" |
| |
| "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
| RUY_MAKE_ZERO(q12) |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
| RUY_MAKE_ZERO(q13) |
| "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
| RUY_MAKE_ZERO(q14) |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" |
| RUY_MAKE_ZERO(q15) |
| "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" |
| "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| |
| |
| // r1 is the number of levels of depth that we have already loaded |
| // LHS and RHS data for. Corresponding to the initial ld1 instructions |
| // above, this is currently 16. |
| "mov r1, #16\n" |
| |
| // Main loop of the whole GEMM, over rows and columns of the |
| // destination matrix. |
| "1:\n" |
| |
| // r1 is how many levels of depth we have already loaded |
| // data for, r10 is the total depth. |
| "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
| "cmp r1, r10\n" |
| "beq 79f\n" |
| |
| "2:\n" |
| |
| // Mult, mult-acc in to q14, q15, q2, q3 |
| "vmull.s8 q14, d0, d8\n" |
| "vmull.s8 q2, d0, d10\n" |
| |
| "vmull.s8 q15, d2, d8\n" |
| "vmull.s8 q3, d2, d10\n" |
| |
| "vmlal.s8 q14, d1, d9\n" |
| "vmlal.s8 q2, d1, d11\n" |
| "vmlal.s8 q15, d3, d9\n" |
| "vmlal.s8 q3, d3, d11\n" |
| "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
| |
| // Then pairwise accumulate in to q6, q7, q10, q11 |
| "vpadal.s16 q6, q14\n" |
| "vpadal.s16 q7, q15\n" |
| "vpadal.s16 q10, q2\n" |
| "vpadal.s16 q11, q3\n" |
| |
| // Mult, mult-acc in to q14, q15, q2, q3 |
| "vmull.s8 q14, d0, d8\n" |
| "vmull.s8 q2, d0, d10\n" |
| |
| "vmull.s8 q15, d2, d8\n" |
| "vmull.s8 q3, d2, d10\n" |
| |
| "vmlal.s8 q14, d1, d9\n" |
| "vmlal.s8 q2, d1, d11\n" |
| "vmlal.s8 q15, d3, d9\n" |
| "vmlal.s8 q3, d3, d11\n" |
| "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
| |
| // Then pairwise accumulate in to q8, q9, q12, q13 |
| "vpadal.s16 q8, q14\n" |
| "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" |
| "vpadal.s16 q9, q15\n" |
| "vpadal.s16 q12, q2\n" |
| "vpadal.s16 q13, q3\n" |
| |
| // Prefetch the next 64 bytes of LHS and RHS data. |
| RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") |
| RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") |
| |
| // Each iteration of this loop advances by 16 levels of depth. |
| "add r1, r1, #16\n" |
| |
| // Loop termination condition |
| "cmp r1, r10\n" |
| |
| "blt 2b\n" |
| |
| "79:\n" |
| |
| // Mult, mult-acc in to q14, q15, q2, q3 |
| "vmull.s8 q14, d0, d8\n" |
| "vmull.s8 q2, d0, d10\n" |
| |
| "vmull.s8 q15, d2, d8\n" |
| "vmull.s8 q3, d2, d10\n" |
| |
| "vmlal.s8 q14, d1, d9\n" |
| "vmlal.s8 q2, d1, d11\n" |
| "vmlal.s8 q15, d3, d9\n" |
| "vmlal.s8 q3, d3, d11\n" |
| "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
| |
| // Then pairwise accumulate in to q6, q7, q10, q11 |
| "vpadal.s16 q6, q14\n" |
| "vpadal.s16 q7, q15\n" |
| "vpadal.s16 q10, q2\n" |
| "vpadal.s16 q11, q3\n" |
| |
| // Mult, mult-acc in to q14, q15, q2, q3 |
| "vmull.s8 q14, d0, d8\n" |
| "vmull.s8 q2, d0, d10\n" |
| |
| "vmull.s8 q15, d2, d8\n" |
| "vmull.s8 q3, d2, d10\n" |
| |
| "vmlal.s8 q14, d1, d9\n" |
| "vmlal.s8 q2, d1, d11\n" |
| "vmlal.s8 q15, d3, d9\n" |
| "vmlal.s8 q3, d3, d11\n" |
| |
| // Then pairwise accumulate in to q8, q9, q12, q13 |
| "vpadal.s16 q8, q14\n" |
| "vpadal.s16 q9, q15\n" |
| "vpadal.s16 q12, q2\n" |
| "vpadal.s16 q13, q3\n" |
| |
| |
| // All accumulation over depth done. q6 - q13 contain the 4x32b |
| // accumulators for the 4x2 final matrix. |
| // We now have to compute the final 8-bit values from these int32 |
| // accumulators, and advance to the next 4x2 block. We intertwine |
| // these two aspects whenever possible for optimal pipelining, both |
| // at the data flow level (prefetch data for next block as early as |
| // possible) and instruction pipelining level (some of the next-block |
| // work can dual-issue with some of the final work on the current |
| // block). |
| |
| // q6-q13 now contain 4 x 32b |
| "vpadd.i32 d0, d12, d13\n" |
| "vpadd.i32 d1, d14, d15\n" |
| "vpadd.i32 d2, d16, d17\n" |
| "vpadd.i32 d3, d18, d19\n" |
| "vpadd.i32 d4, d20, d21\n" |
| "vpadd.i32 d5, d22, d23\n" |
| "vpadd.i32 d6, d24, d25\n" |
| "vpadd.i32 d7, d26, d27\n" |
| |
| // d0-d7 each contain 2 x 32b accumulators. |
| // Need to add pairwise to get 1 x 32b for each of the 4x2 entries |
| // of destination, (Four 'd' registers total) |
| "vpadd.i32 d28, d0, d1\n" |
| "vpadd.i32 d29, d2, d3\n" |
| "vpadd.i32 d30, d4, d5\n" |
| "vpadd.i32 d31, d6, d7\n" |
| |
| //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries |
| |
| // Logic to advance to the next block in preparation for the next |
| // iteration of the main loop. For now, we only want to compute |
| // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
| // not yet ready to update the values of row and col, as we still need |
| // the current values for the rest of the work on the current block. |
| |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
| "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "cmp r1, r3\n" // Have we finished the last row? |
| |
| "bge 4f\n" // If finished last row, go to 4 |
| // Not finished last row: then advance to next row. |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "add r4, r4, r1, lsl #2\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "b 5f\n" |
| "4:\n" // Finished last row... |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| // Go back to first row |
| "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| |
| // Now we need to advance to the next column. If we already |
| // finished the last column, then in principle we are done, however |
| // we can't just return here, as we need to allow the end work of the |
| // current block to complete. The good news is that at this point it |
| // doesn't matter what data we load for the next column, since |
| // we will exit from the main loop below before actually storing |
| // anything computed from that data. |
| |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "cmp r8, r4\n" // Have we finished the last column? |
| "bge 5f\n" // If yes, just carry on without updating the column pointer. |
| // Not finished last column: then advance to next column. |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
| "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "add r10, r10, r1, lsl #1\n" |
| "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "5:\n" |
| |
| // Set the LHS and RHS data pointers to the start of the columns just |
| // computed. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "mov %[lhs_ptr], r4\n" |
| "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "mov %[rhs_ptr], r5\n" |
| |
| // Now we load: bias data, LHS sums data, RHS sums data. |
| |
| // First, load the base pointers from the params. |
| "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
| |
| // Let r8 be stack offset of the row or column variable, whichever |
| // is the channel index. |
| "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
| "ite eq\n" |
| "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n" |
| "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n" |
| // Let r8 be the channel index. |
| "ldr r8, [sp, r8]\n" |
| // Compute the bias pointer, by conditionally using the channel index |
| // (r8) as offset into bias buffer (r1). |
| "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
| "it ne\n" |
| "addne r1, r1, r8, lsl #2\n" |
| |
| // Load 2 bias values. When the channel dimension is rows, we will load |
| // another 2 bias values just before performing the bias addition below, |
| // as this kernel has a 4x2 rectangular shape. |
| "vld1.32 {d24}, [r1]!\n" |
| |
| // Now that we know what LHS and RHS data the next iteration of the |
| // main loop will need to load, we start loading the first 32 bytes of |
| // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
| // in the rest of the work on the current block. |
| "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") |
| "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") |
| |
| // Add to the bias values the product |
| // (depth * lhs_zero_point * rhs_zero_point), |
| // See the term NZ1Z2 in equation (7) in |
| // https://arxiv.org/pdf/1712.05877.pdf |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
| "vdup.32 q9, r3\n" |
| "vadd.i32 d24, d24, d18\n" |
| |
| // Perform the bias-addition (per the above, we have just folded into |
| // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
| // Jump based on channel dimension. |
| "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
| "bne 6f\n" |
| // Case where channels are rows. |
| // Load the remaining 2 bias values, since we're on the width-4 side |
| // of this 4x2 kernel. |
| "vld1.32 {d25}, [r1]\n" |
| "vadd.i32 d25, d25, d19\n" |
| "vadd.i32 q14, q14, q12\n" |
| "vadd.i32 q15, q15, q12\n" |
| "b 7f\n" |
| |
| "6:\n" |
| // Case where channels are columns. |
| "vdup.32 q10, d24[0]\n" |
| "vdup.32 q11, d24[1]\n" |
| "vadd.i32 q14, q14, q10\n" |
| "vadd.i32 q15, q15, q11\n" |
| "7:\n" |
| |
| // LHS/RHS zero points |
| // Has RHS sums |
| "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
| "beq 401f\n" |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| // Offset by current col * number of bytes per value |
| "add r3, r3, r4, lsl #2\n" |
| "vld1.32 { d12 }, [r3]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
| "vdup.32 q10, r5\n" // create lhs_zero_point_vec |
| // Subtract rhs_sums * lhs_zero_point, per |
| // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
| "vmls.i32 q14, q10, d12[0]\n" |
| "vmls.i32 q15, q10, d12[1]\n" |
| "401:\n" |
| |
| // Has LHS sums |
| "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
| "beq 402f\n" |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| // Offset by current row * number of bytes per value |
| "add r2, r2, r4, lsl #2\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
| |
| // Load 4 lhs_sums values. |
| "vld1.32 {d22, d23}, [r2]\n" |
| "vdup.32 d13, r5\n" // rhs_zero_point |
| |
| // Compute lhs_sums * rhs_zero_point. |
| "vmul.i32 q11, q11, d13[1]\n" |
| // Subtract lhs_sums * rhs_zero_point, per |
| // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
| "vsub.s32 q14, q14, q11\n" |
| "vsub.s32 q15, q15, q11\n" |
| |
| // If the destination is int32, it means the user asks for the raw |
| // accumulators, no need for us to downquantize the value. |
| "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
| "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
| "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
| |
| "402:\n" |
| |
| // At this point we have computed the final int32 values. Now we |
| // start down-quantizing them to obtain the final 8bit values from them. |
| |
| // As part of this down-quantization, our int32 values will be |
| // multiplied by a multiplier that has a fixed-point component and an |
| // exponent component. |
| |
| // Compute the data pointers for the multiplier data |
| // r1 = exponent part |
| // r2 = fixedpoint part |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
| // r6 has flags, r8 has channel index |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
| "it ne\n" |
| "addne r1, r1, r8, lsl #2\n" |
| "it ne\n" |
| "addne r2, r2, r8, lsl #2\n" |
| |
| // Load the first 2 values of multiplier exponent and fixedpoint data |
| // Since this kernel is rectangular 4x2, we will only conditionally load |
| // 2 more values below. |
| "vld1.32 {d20}, [r1]!\n" // 2 values of multiplier_exponent |
| "vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint |
| |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
| "vmvn.i32 q8, #0\n" |
| "bne 8f\n" |
| // Case where channels are rows. |
        // Load the remaining 2 multiplier values (exponent and fixedpoint),
        // since we're on the width-4 side of this 4x2 kernel.
| "vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent |
| "vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint |
| "vmin.s32 q11, q10, q8\n" |
| "vsub.s32 q10, q10, q11\n" |
| |
| // Apply the positive exponent part of the multiplier. |
| "vshl.s32 q14, q14, q10\n" |
| "vshl.s32 q15, q15, q10\n" |
| |
| // Apply the fixed-point part of the multiplier. |
| "vqdmulh.s32 q14, q14, q6\n" |
| "vqdmulh.s32 q15, q15, q6\n" |
| |
| // Apply the negative exponent part of the multiplier. |
| "vrshl.s32 q14, q14, q11\n" |
| "vrshl.s32 q15, q15, q11\n" |
| "b 9f\n" |
| |
| "8:\n" |
| // Case where channels are columns. |
| "vmin.s32 d22, d20, d16\n" |
| "vsub.s32 d20, d20, d22\n" |
| |
| // Apply the positive exponent part of the multiplier. |
| "vdup.32 q12, d20[0]\n" |
| "vdup.32 q13, d20[1]\n" |
| "vshl.s32 q14, q14, q12\n" |
| "vshl.s32 q15, q15, q13\n" |
| |
| // Apply the fixed-point part of the multiplier. |
| "vqdmulh.s32 q14, q14, d12[0]\n" |
| "vqdmulh.s32 q15, q15, d12[1]\n" |
| |
| // Apply the negative exponent part of the multiplier. |
| "vdup.32 q12, d22[0]\n" |
| "vdup.32 q13, d22[1]\n" |
| "vrshl.s32 q14, q14, q12\n" |
| "vrshl.s32 q15, q15, q13\n" |
| |
| "9:\n" |
| |
| "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
| "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
| "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
| "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
| "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
| |
| // Store uint8 values: |
| RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
| |
| // Cast-and-saturate from int32 to int16 |
| // After this, all values for output are in q14. |
| "vqmovn.s32 d28, q14\n" |
| "vqmovn.s32 d29, q15\n" |
| |
        // At this point, d12 -- d27, d30, d31 aren't used anymore for the
        // current block, so we can start clearing these accumulators for the
        // next block (next iteration of the main loop).
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q15) |
| |
| // Load the destination zero point into each of the 8 16-bit slots |
| // in a q register. |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
| "vdup.16 q13, r4\n" // dst_zero_point |
| |
| // Add the destination zero point |
| "vqadd.s16 q14, q14, q13\n" |
| |
| // Cast-and-saturate from int16 to uint8 |
| // Now all 8 1-byte values are in d30. |
| "vqmovun.s16 d30, q14\n" |
| |
| // Load the clamp_min, clamp_max bounds |
| "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
| "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
| "vdup.8 d28, r2\n" // clamp_min |
| "vdup.8 d29, r3\n" // clamp_max |
| |
| // Apply the clamp_min bound |
| "vmax.u8 d30, d30, d28\n" |
| // Apply the clamp_max bound |
| "vmin.u8 d30, d30, d29\n" |
| |
| // Compute how much of the 4x2 block of destination 8bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
| // of 4x2, there are some 4x2 blocks along the boundaries that do |
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
| // Compute r1 = how many rows of the 4x2 block fit |
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| "cmp r2, #2\n" |
| // Compute r2 = how many cols of the 4x2 block fit |
| "it gt\n" |
| "movgt r2, r5\n" |
| |
| // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
| "cmp r1, r3\n" |
| "it eq\n" |
| "cmpeq r2, r5\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| // Yes, all of the 4x2 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 4x2 block fits. |
| // Store to dst_tmp_buf |
| // Set r3 address to write to dst_tmp_buf. |
| "mov r3, %[dst_tmp_buf]\n" |
| "vst1.8 {d30}, [r3]\n" |
| |
| // Slow loop copying from dst_tmp_buf to dst. |
| "mov r6, #0\n" |
| "50:\n" |
| "mov r8, #0\n" |
| "51:\n" |
| "ldrb r10, [r3, r8]\n" |
| "strb r10, [r4, r8]\n" |
| "add r8, r8, #1\n" |
| "cmp r8, r1\n" |
| "blt 51b\n" |
| "add r6, r6, #1\n" |
| "add r3, r3, #4\n" |
| "add r4, r4, r5\n" |
| "cmp r6, r2\n" |
| "blt 50b\n" |
| "b 31f\n" |
| "30:\n" |
| // Yes, all of the 4x2 block fits. |
| // r3 address, r5 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r3\n" |
| "mov r6, #1\n" |
| |
| "vst1.32 {d30[0]}, [r3]\n" |
| "add r4, r4, r5\n" |
| "mov r3, r4\n" |
| "vst1.32 {d30[1]}, [r3]\n" |
| |
| "31:\n" |
| |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #4\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q14) |
| RUY_MAKE_ZERO(q15) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| // Store int8 values: |
| RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
| |
| // Cast-and-saturate from int32 to int16 |
| // After this, all values for output are in q14. |
| "vqmovn.s32 d28, q14\n" |
| "vqmovn.s32 d29, q15\n" |
| |
        // At this point, d12 -- d27, d30, d31 aren't used anymore for the
        // current block, so we can start clearing these accumulators for the
        // next block (next iteration of the main loop).
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q15) |
| |
| // Load the destination zero point into each of the 8 16-bit slots |
| // in a q register. |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
| "vdup.16 q13, r4\n" // dst_zero_point |
| |
| // Add the destination zero point |
| "vqadd.s16 q14, q14, q13\n" |
| |
| // Cast-and-saturate from int16 to int8 |
| // Now all 8 1-byte values are in d30. |
| "vqmovn.s16 d30, q14\n" |
| |
| // Load the clamp_min, clamp_max bounds |
| "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
| "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
| "vdup.8 d28, r2\n" // clamp_min |
| "vdup.8 d29, r3\n" // clamp_max |
| |
| // Apply the clamp_min bound |
| "vmax.s8 d30, d30, d28\n" |
| // Apply the clamp_max bound |
| "vmin.s8 d30, d30, d29\n" |
| |
| // Compute how much of the 4x2 block of destination 8bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
| // of 4x2, there are some 4x2 blocks along the boundaries that do |
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
| // Compute r1 = how many rows of the 4x2 block fit |
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| "cmp r2, #2\n" |
| // Compute r2 = how many cols of the 4x2 block fit |
| "it gt\n" |
| "movgt r2, r5\n" |
| |
| // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
| "cmp r1, r3\n" |
| "it eq\n" |
| "cmpeq r2, r5\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| // Yes, all of the 4x2 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 4x2 block fits. |
| // Store to dst_tmp_buf |
| // Set r3 address to write to dst_tmp_buf. |
| "mov r3, %[dst_tmp_buf]\n" |
| "vst1.8 {d30}, [r3]\n" |
| |
| // Slow loop copying from dst_tmp_buf to dst. |
| "mov r6, #0\n" |
| "50:\n" |
| "mov r8, #0\n" |
| "51:\n" |
| "ldrb r10, [r3, r8]\n" |
| "strb r10, [r4, r8]\n" |
| "add r8, r8, #1\n" |
| "cmp r8, r1\n" |
| "blt 51b\n" |
| "add r6, r6, #1\n" |
| "add r3, r3, #4\n" |
| "add r4, r4, r5\n" |
| "cmp r6, r2\n" |
| "blt 50b\n" |
| "b 31f\n" |
| "30:\n" |
| // Yes, all of the 4x2 block fits. |
| // r3 address, r5 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r3\n" |
| "mov r6, #1\n" |
| |
| "vst1.32 {d30[0]}, [r3]\n" |
| "add r4, r4, r5\n" |
| "mov r3, r4\n" |
| "vst1.32 {d30[1]}, [r3]\n" |
| |
| "31:\n" |
| |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #4\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q14) |
| RUY_MAKE_ZERO(q15) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
| |
| // Load the destination zero point into each of the 4 32-bit slots |
| // in a q register. |
| "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
| "vdup.32 q13, r4\n" // dst_zero_point |
| // Add the destination zero point |
| "vadd.s32 q14, q14, q13\n" |
| "vadd.s32 q15, q15, q13\n" |
| |
| // Cast-and-saturate from int32 to int16 |
| // After this, all values for output are in q14. |
| "vqmovn.s32 d28, q14\n" |
| "vqmovn.s32 d29, q15\n" |
| |
        // At this point, q6 -- q11 and q15 aren't used anymore for the current
        // block, so we can start clearing these accumulators for the next block
        // (next iteration of the main loop).
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q15) |
| |
| // Load the clamp_min, clamp_max bounds |
| "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
| "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
| "vdup.16 q12, r2\n" // clamp_min |
| "vdup.16 q13, r3\n" // clamp_max |
| |
| // Apply the clamp_min bound |
| "vmax.s16 q14, q14, q12\n" |
| // Apply the clamp_max bound |
| "vmin.s16 q14, q14, q13\n" |
| |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| |
| // Compute how much of the 4x2 block of destination 16-bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
| // of 4x2, there are some 4x2 blocks along the boundaries that do |
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
| // Compute r1 = how many rows of the 4x2 block fit |
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| "cmp r2, #2\n" |
| // Compute r2 = how many cols of the 4x2 block fit |
| "it gt\n" |
| "movgt r2, r5\n" |
| |
| // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
| "cmp r1, r3\n" |
| "it eq\n" |
| "cmpeq r2, r5\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| // Yes, all of the 4x2 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 4x2 block fits. |
| // Store to dst_tmp_buf |
| // Set r3 address to write to dst_tmp_buf. |
| "mov r3, %[dst_tmp_buf]\n" |
| "vst1.16 {q14}, [r3]\n" |
| |
| // Slow loop copying from dst_tmp_buf to dst. |
| "mov r6, #0\n" |
| "50:\n" |
| "mov r8, #0\n" |
| "51:\n" |
| // Shift of offset register for half-word loads not allowed in A32, |
| // so we shift, load/store, then shift back r8. |
| "lsl r8, r8, #1\n" |
| "ldrh r10, [r3, r8]\n" |
| "strh r10, [r4, r8]\n" |
| "lsr r8, r8, #1\n" |
| "add r8, r8, #1\n" |
| "cmp r8, r1\n" |
| "blt 51b\n" |
| "add r6, r6, #1\n" |
| "add r3, r3, #8\n" |
| "add r4, r4, r5\n" |
| "cmp r6, r2\n" |
| "blt 50b\n" |
| "b 31f\n" |
| "30:\n" |
| // Yes, all of the 4x2 block fits. |
| // r3 address, r5 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r3\n" |
| "mov r6, #2\n" |
| |
| "vst1.16 {d28[0]}, [r3], r6\n" |
| "add r4, r4, r5\n" |
| "vst1.16 {d28[1]}, [r3], r6\n" |
| "vst1.16 {d28[2]}, [r3], r6\n" |
| "vst1.16 {d28[3]}, [r3], r6\n" |
| "mov r3, r4\n" |
| "vst1.16 {d29[0]}, [r3], r6\n" |
| "vst1.16 {d29[1]}, [r3], r6\n" |
| "vst1.16 {d29[2]}, [r3], r6\n" |
| "vst1.16 {d29[3]}, [r3], r6\n" |
| "31:\n" |
| |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #8\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q14) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
| |
| // Since the store type is the same as the accum type, no need for |
| // downcast. There's also no need for clamp by min/max. |
| |
        // At this point, q6 -- q13 aren't used anymore for the current block,
        // so we can start clearing these accumulators for the next block
        // (next iteration of the main loop).
| // Clear accumulators. |
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| |
| // Compute how much of the 4x2 block of destination 32 bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
        // of 4x2, there are some 4x2 blocks along the boundaries that do
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
| // Compute r1 = how many rows of the 4x2 block fit |
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| "cmp r2, #2\n" |
| // Compute r2 = how many cols of the 4x2 block fit |
| "it gt\n" |
| "movgt r2, r5\n" |
| |
| // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
| "cmp r1, r3\n" |
| "it eq\n" |
| "cmpeq r2, r5\n" |
| // Yes, all of the 4x2 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 4x2 block fits. |
| // Set (r3 address, r4 stride) to write to dst_tmp_buf |
| "mov r3, %[dst_tmp_buf]\n" |
| "mov r4, #16\n" |
| "b 31f\n" |
| |
| "30:\n" |
| // Yes, all of the 4x2 block fits. |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| // r3 address, r4 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r5\n" |
| |
| "31:\n" |
| |
| "vst1.32 {d28, d29}, [r3]\n" |
| "add r3, r3, r4\n" |
| "vst1.32 {d30, d31}, [r3]\n" |
| |
| // If all of the 4x2 block fits, we just finished writing it to the |
| // destination, so we skip the next part. |
| "beq 41f\n" |
| // Not all of the 4x2 block fits in the destination matrix. We just |
| // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
| // it to copy into the destination matrix the part that fits. |
| "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| "mov r3, %[dst_tmp_buf]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r6, #0\n" |
| "50:\n" |
| "mov r5, #0\n" |
| "51:\n" |
| "ldr r10, [r3, r5, lsl #2]\n" |
| "str r10, [r4, r5, lsl #2]\n" |
| "add r5, r5, #1\n" |
| "cmp r5, r1\n" |
| "blt 51b\n" |
| "add r6, r6, #1\n" |
| "add r3, r3, #16\n" |
| "add r4, r4, r8\n" |
        // r2 = how many cols of the 4x2 block fit
| "cmp r6, r2\n" |
| "blt 50b\n" |
| |
| "41:\n" |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #16\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
| |
        // Reload some params --- we had used r3, r5 and r6 for a few other
        // things since the last time we had loaded them.
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
| |
| // Move to the next block of the destination matrix, for the next iter |
| // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
| // been updated earlier. |
| // Have we reached the end row? |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "cmp r8, r3\n" |
| |
| "beq 20f\n" // yes, end row. |
| // Not end row. Move to the next row. |
| "add r8, r8, #4\n" |
| // Store new value of row |
| "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| |
| "b 21f\n" |
| "20:\n" |
| // Was already at end row. |
| // Move back to first row. |
| "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| // Move to the next column. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "add r4, r4, #2\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| |
| "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns) |
| "add r1, r1, r8, lsl #1\n" |
| // Store dst_col_ptr |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| // Store dst_ptr |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "21:\n" |
| |
| // Main loop exit condition: have we hit the end column? |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "cmp r8, r4\n" |
| |
        // r1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial vld1 instructions
        // above, this is currently 16.
| "mov r1, #16\n" |
| |
| "ble 1b\n" |
| |
| // Restore stack pointer. |
| "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
| |
| // clang-format on |
| |
| : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) |
| : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) |
| : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", |
| // Clobber list must specify q registers (and not their constituent |
| // d registers). There is a (currently unexplained) slowdown if |
| // d registers are listed in the clobbers list. |
| "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", |
| "q9", "q10", "q12", "q13", "q14", "q15"); |
| } |
| |
| // Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS |
| // is still packed as if it has two columns |
| void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) { |
| profiler::ScopeLabel label("Kernel (kNeon)"); |
| |
| CheckOffsetsInKernelParams8bit(params); |
| |
| const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
| const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; |
| const std::int8_t* lhs_ptr = lhs_col_ptr; |
| const std::int8_t* rhs_ptr = rhs_col_ptr; |
| |
| RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)); |
| |
| // The asm kernel below has the following NEON register allocation: |
| // |
| // q6 - q13 are 128-bit (4x32b) accumulators. |
| // During accumulation, d0 -- d7 are used to load int8 data from LHS and |
| // d8 -- d11 from RHS: |
| // int8 RHS 16x1 block |
| // /------------| |
| // | d8.b[0] | |
| // | ... | |
| // | d8.b[7] | |
| // | d9.b[0] | |
| // | ... | |
| // | d9.b[7] | |
| // \------------/ |
| // int8 LHS 4x16 block |
| // /-----------------------------------------\ /------------| |
| // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 | |
| // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 | |
| // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 | |
| // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 | |
| // \-----------------------------------------/ \------------/ |
| // 128-bit accumulators 4x1 block |
| // |
| // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING |
| // optimization for this kernel. |
| asm volatile( |
| #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" |
| |
| // clang-format off |
| |
        // Load the first 64 bytes of LHS and 16 bytes of RHS data.
| "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
| // Skip the other column and advance the pointer. |
| "add %[rhs_ptr], %[rhs_ptr], #16\n" |
| |
| "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
| "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" |
| "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" |
| "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| |
| // Clear accumulators. |
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q14) |
| RUY_MAKE_ZERO(q15) |
| |
        // r1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial vld1 instructions
        // above, this is currently 16.
| "mov r1, #16\n" |
| |
| // Main loop of the whole GEMM, over rows and columns of the |
| // destination matrix. |
| "1:\n" |
| |
| // r1 is how many levels of depth we have already loaded |
| // data for, r10 is the total depth. |
| "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
| "cmp r1, r10\n" |
| "beq 79f\n" |
| |
| "2:\n" |
| |
| // Mult, mult-acc in to q14, q15 |
| "vmull.s8 q14, d0, d8\n" |
| "vmull.s8 q15, d2, d8\n" |
| "vmlal.s8 q14, d1, d9\n" |
| "vmlal.s8 q15, d3, d9\n" |
| |
| // Then pairwise accumulate in to q6, q7 |
| "vpadal.s16 q6, q14\n" |
| "vpadal.s16 q7, q15\n" |
| |
| // Mult, mult-acc in to q14, q15 |
| "vmull.s8 q14, d4, d8\n" |
| "vmull.s8 q15, d6, d8\n" |
| "vmlal.s8 q14, d5, d9\n" |
| "vmlal.s8 q15, d7, d9\n" |
| |
| // Then pairwise accumulate in to q8, q9 |
| "vpadal.s16 q8, q14\n" |
| "vpadal.s16 q9, q15\n" |
| |
| |
| // Load the next 64 bytes of LHS and RHS data. |
| "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") |
| "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
| // Skip the other column and advance the pointer. |
| "add %[rhs_ptr], %[rhs_ptr], #16\n" |
| RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") |
| |
| // Each iteration of this loop advances by 16 levels of depth. |
| "add r1, r1, #16\n" |
| |
| // Loop termination condition |
| "cmp r1, r10\n" |
| |
| "blt 2b\n" |
| |
| "79:\n" |
| |
| // Mult, mult-acc in to q14, q15 |
| "vmull.s8 q14, d0, d8\n" |
| "vmull.s8 q15, d2, d8\n" |
| "vmlal.s8 q14, d1, d9\n" |
| "vmlal.s8 q15, d3, d9\n" |
| |
| // Then pairwise accumulate in to q6, q7 |
| "vpadal.s16 q6, q14\n" |
| "vpadal.s16 q7, q15\n" |
| |
| // Mult, mult-acc in to q14, q15 |
| "vmull.s8 q14, d4, d8\n" |
| "vmull.s8 q15, d6, d8\n" |
| "vmlal.s8 q14, d5, d9\n" |
| "vmlal.s8 q15, d7, d9\n" |
| |
| // Then pairwise accumulate in to q8, q9 |
| "vpadal.s16 q8, q14\n" |
| "vpadal.s16 q9, q15\n" |
| |
| // All accumulation over depth done. q6 - q9 contain the 4x32b |
| // accumulators for the 4x1 final matrix. |
| // We now have to compute the final 8-bit values from these int32 |
| // accumulators, and advance to the next 4x2 block. We intertwine |
| // these two aspects whenever possible for optimal pipelining, both |
| // at the data flow level (prefetch data for next block as early as |
| // possible) and instruction pipelining level (some of the next-block |
| // work can dual-issue with some of the final work on the current |
| // block). |
| |
| // q6-q9 now contain 4 x 32b |
| "vpadd.i32 d0, d12, d13\n" |
| "vpadd.i32 d1, d14, d15\n" |
| "vpadd.i32 d2, d16, d17\n" |
| "vpadd.i32 d3, d18, d19\n" |
| |
// d0-d3 each contain 2 x 32b accumulators.
| // Need to add pairwise to get 1 x 32b for each of the 4x1 entries |
| // of destination, (Four 'd' registers total) |
| "vpadd.i32 d28, d0, d1\n" |
| "vpadd.i32 d29, d2, d3\n" |
| |
| // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. |
| |
| // Logic to advance to the next block in preparation for the next |
| // iteration of the main loop. For now, we only want to compute |
| // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
| // not yet ready to update the values of row and col, as we still need |
| // the current values for the rest of the work on the current block. |
| |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
| "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "cmp r1, r3\n" // Have we finished the last row? |
| |
| "bge 4f\n" // If finished last row, go to 4 |
| // Not finished last row: then advance to next row. |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "add r4, r4, r1, lsl #2\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "b 5f\n" |
| "4:\n" // Finished last row... |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| // Go back to first row |
| "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| |
| // Now we need to advance to the next column. If we already |
| // finished the last column, then in principle we are done, however |
| // we can't just return here, as we need to allow the end work of the |
| // current block to complete. The good news is that at this point it |
| // doesn't matter what data we load for the next column, since |
| // we will exit from the main loop below before actually storing |
| // anything computed from that data. |
| |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "cmp r8, r4\n" // Have we finished the last column? |
| "bge 5f\n" // If yes, just carry on without updating the column pointer. |
| // Not finished last column: then advance to next column. |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
| "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "add r10, r10, r1, lsl #1\n" |
| "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "5:\n" |
| |
| // Set the LHS and RHS data pointers to the start of the columns just |
| // computed. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
| "mov %[lhs_ptr], r4\n" |
| "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
| "mov %[rhs_ptr], r5\n" |
| |
| // Now we load: bias data, LHS sums data, RHS sums data. |
| |
| // First, load the base pointers from the params. |
| "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
| |
| // Offset these base pointers as needed given the current row, col. |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| |
| "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
| "it ne\n" |
| "addne r1, r1, r8, lsl #2\n" |
| |
| // Load 4 bias values. |
| "vld1.32 {d24, d25}, [r1]\n" |
| |
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 64 bytes of
// LHS and 16 bytes of RHS, into q0 -- q4, as we don't need q0 -- q4
// anymore in the rest of the work on the current block.
| "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" |
| "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" |
| RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") |
| "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
| // Skip the other column and advance the pointer. |
| "add %[rhs_ptr], %[rhs_ptr], #16\n" |
| RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") |
| |
| // Add to the bias values the product |
| // (depth * lhs_zero_point * rhs_zero_point), |
| // See the term NZ1Z2 in equation (7) in |
| // https://arxiv.org/pdf/1712.05877.pdf |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
| "vdup.32 q9, r3\n" |
| "vadd.i32 q12, q12, q9\n" |
| |
| // Perform the bias-addition (per the above, we have just folded into |
| // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
| "vadd.i32 q14, q14, q12\n" |
| |
| // LHS/RHS zero points |
| // Has RHS sums |
| "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
| "beq 401f\n" |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| // Offset by current col * number of bytes per value |
| "add r3, r3, r4, lsl #2\n" |
| "vld1.32 { d12 }, [r3]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
| "vdup.32 q10, r5\n" // create lhs_zero_point_vec |
| // Subtract rhs_sums * lhs_zero_point, per |
| // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
| "vmls.i32 q14, q10, d12[0]\n" |
| "401:\n" |
| |
| // Has LHS sums |
| "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
| "beq 402f\n" |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| // Offset by current row * number of bytes per value |
| "add r2, r2, r4, lsl #2\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
| |
| // Load 4 lhs_sums values. |
| "vld1.32 {d22, d23}, [r2]\n" |
| "vdup.32 d13, r5\n" // rhs_zero_point |
| |
| // Compute lhs_sums * rhs_zero_point. |
| "vmul.i32 q11, q11, d13[1]\n" |
| // Subtract lhs_sums * rhs_zero_point, per |
| // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
| "vsub.s32 q14, q14, q11\n" |
| |
| // If the destination is int32, it means the user asks for the raw |
| // accumulators, no need for us to downquantize the value. |
| "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
| "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
| "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
| |
| "402:\n" |
| |
| // At this point we have computed the final int32 values. Now we |
| // start down-quantizing them to obtain the final 8bit values from them. |
| |
| // As part of this down-quantization, our int32 values will be |
| // multiplied by a multiplier that has a fixed-point component and an |
| // exponent component. |
| |
// Load the exponent part of the multiplier.
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "it ne\n" |
| "addne r1, r1, r4, lsl #2\n" |
| |
| "vld1.32 {q10}, [r1]\n" |
| |
| "vmvn.i32 q8, #0\n" |
| "vmin.s32 q13, q10, q8\n" |
| "vsub.s32 q12, q10, q13\n" |
| |
| // Apply the positive exponent part of the multiplier. |
| "vshl.s32 q14, q14, q12\n" |
| |
| // Load fixed point part of the multiplier |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
| // r6 has flags, r4 has row |
| "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
| "it ne\n" |
| "addne r1, r1, r4, lsl #2\n" |
| "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint |
| |
| // Apply the fixed-point part of the multiplier. |
| "vqdmulh.s32 q14, q14, q10\n" |
| |
| // Apply the negative exponent part of the multiplier. |
| "vrshl.s32 q14, q14, q13\n" |
| |
| "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
| "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
| "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
| "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
| "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
| |
| // Store uint8 values: |
| RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
| |
| // Cast-and-saturate from int32 to int16 |
| // After this, all values for output are in d28. |
| "vqmovn.s32 d28, q14\n" |
| |
// At this point, d12 -- d27, d30, d31 aren't used anymore for the
// current block, so we can start clearing these accumulators for the
// next block (next iteration of the main loop).
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q15) |
| |
| // Load the destination zero point into each of the 8 16-bit slots |
| // in a q register. |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
| "vdup.16 q13, r4\n" // dst_zero_point |
| |
| // Add the destination zero point |
| "vqadd.s16 q14, q14, q13\n" |
| |
| // Cast-and-saturate from int16 to uint8 |
| "vqmovun.s16 d30, q14\n" |
| // At this point, we only need 4 8-bit values in the lower half |
| // of d30. |
| |
| |
| // Load the clamp_min, clamp_max bounds |
| "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
| "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
| "vdup.8 d28, r2\n" // clamp_min |
| "vdup.8 d29, r3\n" // clamp_max |
| |
| // Apply the clamp_min bound |
| "vmax.u8 d30, d30, d28\n" |
| // Apply the clamp_max bound |
| "vmin.u8 d30, d30, d29\n" |
| |
| // Compute how much of the 4x1 block of destination 8bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
| // of 4x1, there are some 4x1 blocks along the boundaries that do |
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
| // Compute r1 = how many rows of the 4x1 block fit |
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| // Test if r1==4, i.e. if all of the 4x1 block fits. |
| "cmp r1, r3\n" |
| |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| // Yes, all of the 4x1 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 4x1 block fits. |
| // Store to dst_tmp_buf |
| // Set r3 address to write to dst_tmp_buf. |
| "mov r3, %[dst_tmp_buf]\n" |
| "vst1.8 {d30}, [r3]\n" |
| |
| // Slow loop copying from dst_tmp_buf to dst. |
| "50:\n" |
| "mov r8, #0\n" |
| "51:\n" |
| "ldrb r10, [r3, r8]\n" |
| "strb r10, [r4, r8]\n" |
| "add r8, r8, #1\n" |
| "cmp r8, r1\n" |
| "blt 51b\n" |
| "b 31f\n" |
| "30:\n" |
| // Yes, all of the 4x1 block fits. |
| // r3 address, r5 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r3\n" |
| "mov r6, #1\n" |
| |
| "vst1.8 {d30[0]}, [r3], r6\n" |
| "vst1.8 {d30[1]}, [r3], r6\n" |
| "vst1.8 {d30[2]}, [r3], r6\n" |
| "vst1.8 {d30[3]}, [r3], r6\n" |
| "31:\n" |
| |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #4\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q14) |
| RUY_MAKE_ZERO(q15) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| // Store int8 values: |
| RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
| |
| // Cast-and-saturate from int32 to int16 |
| // After this, all values for output are in d28. |
| "vqmovn.s32 d28, q14\n" |
| |
// At this point, d12 -- d27, d30, d31 aren't used anymore for the
// current block, so we can start clearing these accumulators for the
// next block (next iteration of the main loop).
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q15) |
| |
| // Load the destination zero point into each of the 8 16-bit slots |
| // in a q register. |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
| "vdup.16 q13, r4\n" // dst_zero_point |
| |
| // Add the destination zero point |
| "vqadd.s16 q14, q14, q13\n" |
| |
| // Cast-and-saturate from int16 to int8 |
| "vqmovn.s16 d30, q14\n" |
| // At this point, we only need 4 8-bit values in the lower half |
| // of d30. |
| |
| // Load the clamp_min, clamp_max bounds |
| "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
| "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
| "vdup.8 d28, r2\n" // clamp_min |
| "vdup.8 d29, r3\n" // clamp_max |
| |
| // Apply the clamp_min bound |
| "vmax.s8 d30, d30, d28\n" |
| // Apply the clamp_max bound |
| "vmin.s8 d30, d30, d29\n" |
| |
| // Compute how much of the 4x1 block of destination 8bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
// of 4x1, there are some 4x1 blocks along the boundaries that do
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
// Compute r1 = how many rows of the 4x1 block fit
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| // Test if r1==4 i.e. if all of the 4x1 block fits. |
| "cmp r1, r3\n" |
| |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
// Yes, all of the 4x1 block fits, go to fast path.
| "beq 30f\n" |
// Not all of the 4x1 block fits.
| // Store to dst_tmp_buf |
| // Set r3 address to write to dst_tmp_buf. |
| "mov r3, %[dst_tmp_buf]\n" |
| "vst1.8 {d30}, [r3]\n" |
| |
| // Slow loop copying from dst_tmp_buf to dst. |
| "50:\n" |
| "mov r8, #0\n" |
| "51:\n" |
| "ldrb r10, [r3, r8]\n" |
| "strb r10, [r4, r8]\n" |
| "add r8, r8, #1\n" |
| "cmp r8, r1\n" |
| "blt 51b\n" |
| "b 31f\n" |
| "30:\n" |
| // Yes, all of the 4x1 block fits. |
| // r3 address, r5 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r3\n" |
| "mov r6, #1\n" |
| |
| "vst1.8 {d30[0]}, [r3], r6\n" |
| "vst1.8 {d30[1]}, [r3], r6\n" |
| "vst1.8 {d30[2]}, [r3], r6\n" |
| "vst1.8 {d30[3]}, [r3], r6\n" |
| "31:\n" |
| |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #4\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q13) |
| RUY_MAKE_ZERO(q14) |
| RUY_MAKE_ZERO(q15) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
| |
| // Load the destination zero point into each of the 4 32-bit slots |
| // in a q register. |
| "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
| "vdup.32 q13, r4\n" // dst_zero_point |
| // Add the destination zero point |
| "vadd.s32 q14, q14, q13\n" |
| //"vadd.s32 q15, q15, q13\n" |
| |
| // Cast-and-saturate from int32 to int16 |
| // After this, all values for output are in d28. |
| "vqmovn.s32 d28, q14\n" |
| |
| // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the |
| // so we can start clearing these accumulators for the next block |
| // (next iteration of the main loop). |
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q15) |
| |
| // Load the clamp_min, clamp_max bounds |
| "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
| "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
| "vdup.16 d24, r2\n" // clamp_min |
| "vdup.16 d26, r3\n" // clamp_max |
| |
| // Apply the clamp_min bound |
| "vmax.s16 d28, d28, d24\n" |
| // Apply the clamp_max bound |
| "vmin.s16 d28, d28, d26\n" |
| |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| |
| // Compute how much of the 4x1 block of destination 16-bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
| // of 4x1, there are some 4x1 blocks along the boundaries that do |
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
| // Compute r1 = how many rows of the 4x1 block fit |
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| // Test if r1==4, i.e. if all of the 4x1 block fits. |
| "cmp r1, r3\n" |
| |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| // Yes, all of the 4x1 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 4x1 block fits. |
| // Store to dst_tmp_buf |
| // Set r3 address to write to dst_tmp_buf. |
| "mov r3, %[dst_tmp_buf]\n" |
| "vst1.16 {d28}, [r3]\n" |
| |
| // Slow loop copying from dst_tmp_buf to dst. |
| "50:\n" |
| "mov r8, #0\n" |
| "51:\n" |
| // Shift of offset register for half-word loads not allowed in A32, |
| // so we shift, load/store, then shift back r8. |
| "lsl r8, r8, #1\n" |
| "ldrh r10, [r3, r8]\n" |
| "strh r10, [r4, r8]\n" |
| "lsr r8, r8, #1\n" |
| "add r8, r8, #1\n" |
| "cmp r8, r1\n" |
| "blt 51b\n" |
| "b 31f\n" |
| "30:\n" |
| // Yes, all of the 4x1 block fits. |
| // r3 address, r5 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r3\n" |
| "mov r6, #2\n" |
| |
| "vst1.16 {d28[0]}, [r3], r6\n" |
| "vst1.16 {d28[1]}, [r3], r6\n" |
| "vst1.16 {d28[2]}, [r3], r6\n" |
| "vst1.16 {d28[3]}, [r3], r6\n" |
| "31:\n" |
| |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #8\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q14) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
| |
| // Since the store type is the same as the accum type, no need for |
| // downcast. There's also no need for clamp by min/max. |
| |
// At this point, the accumulators q6 -- q13 aren't used anymore for the
// current block, so we can start clearing them for the next block
// (next iteration of the main loop).
| // Clear accumulators. |
| RUY_MAKE_ZERO(q6) |
| RUY_MAKE_ZERO(q7) |
| RUY_MAKE_ZERO(q8) |
| RUY_MAKE_ZERO(q9) |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| RUY_MAKE_ZERO(q12) |
| RUY_MAKE_ZERO(q13) |
| |
| // Compute how much of the 4x1 block of destination 32 bit values that |
| // we have computed, fit in the destination matrix. Typically, all of |
| // it fits, but when the destination matrix shape is not a multiple |
// of 4x1, there are some 4x1 blocks along the boundaries that do
| // not fit entirely. |
| |
| "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "sub r1, r1, r8\n" |
| |
| "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "sub r2, r2, r4\n" |
| "mov r3, #4\n" |
| "mov r5, #2\n" |
| "cmp r1, #4\n" |
// Compute r1 = how many rows of the 4x1 block fit
| "it gt\n" |
| "movgt r1, r3\n" |
| |
| // Test if r1==4, i.e. if all of the 4x1 block fits. |
| "cmp r1, r3\n" |
| |
| // Yes, all of the 4x1 block fits, go to fast path. |
| "beq 30f\n" |
| // Not all of the 4x1 block fits. |
| // Set (r3 address, r4 stride) to write to dst_tmp_buf |
| "mov r3, %[dst_tmp_buf]\n" |
| "mov r4, #16\n" |
| "b 31f\n" |
| |
| "30:\n" |
| // Yes, all of the 4x1 block fits. |
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| // r3 address, r4 stride |
| "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "mov r4, r5\n" |
| |
| "31:\n" |
| |
| "vst1.32 {d28, d29}, [r3]\n" |
| |
| // If all of the 4x1 block fits, we just finished writing it to the |
| // destination, so we skip the next part. |
| "beq 41f\n" |
| // Not all of the 4x1 block fits in the destination matrix. We just |
| // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
| // it to copy into the destination matrix the part that fits. |
| "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| "mov r3, %[dst_tmp_buf]\n" |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "50:\n" |
| "mov r5, #0\n" |
| "51:\n" |
| "ldr r10, [r3, r5, lsl #2]\n" |
| "str r10, [r4, r5, lsl #2]\n" |
| "add r5, r5, #1\n" |
| "cmp r5, r1\n" |
| "blt 51b\n" |
| |
| "41:\n" |
| // Load dst_ptr, increment, and write back. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "add r4, r4, #16\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| |
| RUY_MAKE_ZERO(q10) |
| RUY_MAKE_ZERO(q11) |
| |
| "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
| |
| RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
| |
// Reload some params --- we had used r3, r5, r6 for a few other things
// since the last time we had loaded them.
| "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
| "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
| "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
| |
| // Move to the next block of the destination matrix, for the next iter |
| // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
| // been updated earlier. |
| // Have we reached the end row? |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| "cmp r8, r3\n" |
| |
| "beq 20f\n" // yes, end row. |
| // Not end row. Move to the next row. |
| "add r8, r8, #4\n" |
| // Store new value of row |
| "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| |
| "b 21f\n" |
| "20:\n" |
| // Was already at end row. |
| // Move back to first row. |
| "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
| // Move to the next column. |
| "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "add r4, r4, #2\n" |
| "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| |
| "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
| "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| // Increment dst_col_ptr by dst_stride (i.e. 1 column) |
| "add r1, r1, r8\n" |
| // Store dst_col_ptr |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
| // Store dst_ptr |
| "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
| "21:\n" |
| |
| // Main loop exit condition: have we hit the end column? |
| "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
| "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
| "cmp r8, r4\n" |
| |
// r1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial vld1 instructions
// above, this is currently 16.
| "mov r1, #16\n" |
| |
| "ble 1b\n" |
| |
| // Restore stack pointer. |
| "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
| |
| // clang-format on |
| |
| : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) |
| : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) |
| : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", |
| // Clobber list must specify q registers (and not their constituent |
| // d registers). There is a (currently unexplained) slowdown if |
| // d registers are listed in the clobbers list. |
| "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", |
| "q9", "q10", "q12", "q13", "q14", "q15"); |
| } |
| |
| #undef RUY_OFFSET_BIAS |
| #undef RUY_OFFSET_LHS_SUMS |
| #undef RUY_OFFSET_RHS_SUMS |
| #undef RUY_OFFSET_LHS_BASE_PTR |
| #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT |
| #undef RUY_OFFSET_MULTIPLIER_EXPONENT |
| #undef RUY_OFFSET_RHS_BASE_PTR |
| #undef RUY_OFFSET_DST_BASE_PTR |
| #undef RUY_OFFSET_LHS_ZERO_POINT |
| #undef RUY_OFFSET_RHS_ZERO_POINT |
| #undef RUY_OFFSET_DST_ZERO_POINT |
| #undef RUY_OFFSET_PROD_ZP_DEPTH |
| #undef RUY_OFFSET_START_ROW |
| #undef RUY_OFFSET_START_COL |
| #undef RUY_OFFSET_LAST_ROW |
| #undef RUY_OFFSET_LAST_COL |
| #undef RUY_OFFSET_DST_ROWS |
| #undef RUY_OFFSET_DST_COLS |
| #undef RUY_OFFSET_LHS_STRIDE |
| #undef RUY_OFFSET_RHS_STRIDE |
| #undef RUY_OFFSET_DST_STRIDE |
| #undef RUY_OFFSET_DEPTH |
| #undef RUY_OFFSET_CLAMP_MIN |
| #undef RUY_OFFSET_CLAMP_MAX |
| #undef RUY_OFFSET_FLAGS |
| #undef RUY_OFFSET_DST_TYPE_ID |
| |
| #undef RUY_STACK_OFFSET_SIZE |
| #undef RUY_STACK_OFFSET_DST_COL_PTR |
| #undef RUY_STACK_OFFSET_DST_PTR |
| #undef RUY_STACK_OFFSET_ROW |
| #undef RUY_STACK_OFFSET_COL |
| #undef RUY_STACK_OFFSET_LHS_COL_PTR |
| #undef RUY_STACK_OFFSET_RHS_COL_PTR |
| |
#endif // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
| } // namespace ruy |