blob: 78e0a123500965f01bdb7828964c7288eeec749a [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <xnnpack/assembly.h>
# void xnn_qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64(
# size_t mr, x0
# size_t nc, x1
# size_t kc, x2 (remaining-k counter in x0, A rewind amount in x10)
# const int8_t* restrict a, x3
# size_t a_stride, x4
# const void* restrict w, x5
# int8_t* restrict c, x6
# size_t cm_stride, x7
# size_t cn_stride, [sp] -> x12
# const union xnn_qs8_gemm_params* params) [sp + 8] -> x11
# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# Register usage
# A0 x3 v0
# A1 x15 v1
# A2 x13 v2
# A3 x4 v3
# B x5 v4 v5 v6 v7
# C0 x6 v16 v20 v24 v28
# C1 x8 v17 v21 v25 v29
# C2 x9 v18 v22 v26 v30
# C3 x7 v19 v23 v27 v31
# unused v8 v9 v10 v11 v12 v13 v14 v15
// Leaf function with no stack frame. It touches only x0-x15 and v0-v7,
// v16-v31, so no callee-saved registers need saving. Stack arguments are
// read directly from [sp] (cn_stride) and [sp, 8] (params pointer).
//
// Each pass of the outer loop (label 0) produces one 4-row x 16-column
// tile of int8 C:
//   - w supplies 16 int32 initial accumulator values (64 bytes, one per
//     column, copied to all four rows), then 64 bytes of packed int8
//     weights for every 4 bytes of K (one q register per 4 columns).
//   - SDOT accumulates 4 bytes of each A row against those weights.
//   - In total round_up(kc, 4) bytes of K are consumed, so each A row may
//     be read up to 3 bytes past kc. Presumably the packed weights are
//     zero-padded and callers allow this over-read - confirm.
//   - The int32 sums are requantized to int8 (label 2) and stored.
// Flag lifetimes are part of the schedule: SUBS x0 (k = kc - 5) feeds
// B.LO 3f, and SUBS x1 (nc -= 16) feeds both B.LO 4f and B.NE 0b. The
// instructions interleaved between them must not set flags.
BEGIN_FUNCTION xnn_qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64
# Clamp A and C pointers
// Rows at or beyond mr reuse the previous row pointers for both A and C,
// so those rows recompute and rewrite identical values at the same address.
CMP x0, 2 // if mr < 2
ADD x15, x3, x4 // a1 = a0 + a_stride
ADD x8, x6, x7 // c1 = c0 + cm_stride
CSEL x15, x3, x15, LO // a1 = a0
CSEL x8, x6, x8, LO // c1 = c0
ADD x13, x15, x4 // a2 = a1 + a_stride
ADD x9, x8, x7 // c2 = c1 + cm_stride
// if mr <= 2
CSEL x13, x15, x13, LS // a2 = a1
CSEL x9, x8, x9, LS // c2 = c1
CMP x0, 4 // if mr < 4
ADD x4, x13, x4 // a3 = a2 + a_stride
ADD x7, x9, x7 // c3 = c2 + cm_stride
CSEL x4, x13, x4, LO // a3 = a2
CSEL x7, x9, x7, LO // c3 = c2
// Outer loop: one 16-column tile of C per iteration.
.p2align 3
0:
# Load initial bias from w into accumulators
LDP q16, q20, [x5], 32
SUBS x0, x2, 5 // k = kc - 5 (flags consumed by B.LO 3f below)
MOV v17.16b, v16.16b
MOV v18.16b, v16.16b
LDP q24, q28, [x5], 32
MOV v19.16b, v16.16b
MOV v21.16b, v20.16b
MOV v22.16b, v20.16b
MOV v23.16b, v20.16b
LDR x11, [sp, 8] // params (reloaded per tile: the requantize loads post-increment x11)
MOV v25.16b, v24.16b
MOV v26.16b, v24.16b
// x10 = bytes the main loop advances each A pointer: 8 per iteration over
// floor((kc + 3) / 8) iterations. The remainder path does not advance A.
ADD x10, x2, 3 // rewind = (kc + 3) & ~7
MOV v27.16b, v24.16b
MOV v29.16b, v28.16b
BIC x10, x10, 7
MOV v30.16b, v28.16b
MOV v31.16b, v28.16b
# Is there at least 5 bytes?
// kc < 5 (unsigned borrow): skip the main loop. One 4-byte step in the
// remainder covers it.
B.LO 3f
# Main loop - 8 bytes of A
// Each iteration reads 8 bytes from each A row (lane [0] = K bytes 0-3,
// lane [1] = K bytes 4-7) and 128 bytes of weights. v4-v7 are reloaded
// for the second half as soon as their first-half uses are done.
.p2align 3
1:
LDR d0, [x3], 8
LDR q4, [x5], 16
LDR d1, [x15], 8
LDR d2, [x13], 8
LDR d3, [x4], 8
SDOT v16.4s, v4.16b, v0.4b[0]
LDR q5, [x5], 16
SDOT v17.4s, v4.16b, v1.4b[0]
SDOT v18.4s, v4.16b, v2.4b[0]
SDOT v19.4s, v4.16b, v3.4b[0]
SDOT v20.4s, v5.16b, v0.4b[0]
LDP q6, q7, [x5], 32
SDOT v21.4s, v5.16b, v1.4b[0]
SDOT v22.4s, v5.16b, v2.4b[0]
SDOT v23.4s, v5.16b, v3.4b[0]
SDOT v24.4s, v6.16b, v0.4b[0]
LDP q4, q5, [x5], 32
SDOT v25.4s, v6.16b, v1.4b[0]
SDOT v26.4s, v6.16b, v2.4b[0]
SDOT v27.4s, v6.16b, v3.4b[0]
SDOT v28.4s, v7.16b, v0.4b[0]
SDOT v29.4s, v7.16b, v1.4b[0]
SDOT v30.4s, v7.16b, v2.4b[0]
SDOT v31.4s, v7.16b, v3.4b[0]
SDOT v16.4s, v4.16b, v0.4b[1]
LDP q6, q7, [x5], 32
SDOT v17.4s, v4.16b, v1.4b[1]
SDOT v18.4s, v4.16b, v2.4b[1]
SDOT v19.4s, v4.16b, v3.4b[1]
SDOT v20.4s, v5.16b, v0.4b[1]
SDOT v21.4s, v5.16b, v1.4b[1]
SDOT v22.4s, v5.16b, v2.4b[1]
SDOT v23.4s, v5.16b, v3.4b[1]
SDOT v24.4s, v6.16b, v0.4b[1]
SDOT v25.4s, v6.16b, v1.4b[1]
SDOT v26.4s, v6.16b, v2.4b[1]
SDOT v27.4s, v6.16b, v3.4b[1]
SDOT v28.4s, v7.16b, v0.4b[1]
SDOT v29.4s, v7.16b, v1.4b[1]
SDOT v30.4s, v7.16b, v2.4b[1]
SUBS x0, x0, 8 // k -= 8; B.HS repeats while the pre-decrement k was >= 8
SDOT v31.4s, v7.16b, v3.4b[1]
B.HS 1b
# Is there a remainder?- 1 to 4 bytes of A
// Here x0 is in [-8, -1]. Bit 2 is set exactly for [-4, -1], which is
// when kc rounded up to a multiple of 4 still has one 4-byte step left.
TBNZ x0, 2, 3f
2:
# Apply params - scale, shift, bias and clamp
// Params bytes read through x11 (x11 advances as they are loaded):
//   [0, 4)   int32 to every lane of v0, the SQRDMULH multiplier
//   [4, 8)   int32 to every lane of v1, the SRSHL shift amount (presumably
//            zero or negative, making SRSHL a rounding right shift - confirm)
//   [8, 10)  int16 added with saturation after narrowing to 16 bits
//   [10]     int8 lower clamp
//   [11]     int8 upper clamp
LD2R {v0.4s, v1.4s}, [x11], 8
CMEQ v2.4s, v1.4s, 0 // v2 = all-ones in lanes where the shift is 0
// For each group of four accumulators: v4-v7 = pre-multiply values with
// the shift==0 lanes zeroed. SSRA then adds (v4-v7 >> 31), subtracting 1
// from lanes that were negative, ahead of the rounding shift below.
BIC v4.16b, v16.16b, v2.16b
BIC v5.16b, v17.16b, v2.16b
BIC v6.16b, v18.16b, v2.16b
BIC v7.16b, v19.16b, v2.16b
SQRDMULH v16.4s, v16.4s, v0.4s
SQRDMULH v17.4s, v17.4s, v0.4s
SQRDMULH v18.4s, v18.4s, v0.4s
SQRDMULH v19.4s, v19.4s, v0.4s
SSRA v16.4s, v4.4s, 31 // signed shift right accumulate
SSRA v17.4s, v5.4s, 31
SSRA v18.4s, v6.4s, 31
SSRA v19.4s, v7.4s, 31
BIC v4.16b, v20.16b, v2.16b
BIC v5.16b, v21.16b, v2.16b
BIC v6.16b, v22.16b, v2.16b
BIC v7.16b, v23.16b, v2.16b
SQRDMULH v20.4s, v20.4s, v0.4s
SQRDMULH v21.4s, v21.4s, v0.4s
SQRDMULH v22.4s, v22.4s, v0.4s
SQRDMULH v23.4s, v23.4s, v0.4s
SSRA v20.4s, v4.4s, 31
SSRA v21.4s, v5.4s, 31
SSRA v22.4s, v6.4s, 31
SSRA v23.4s, v7.4s, 31
BIC v4.16b, v24.16b, v2.16b
BIC v5.16b, v25.16b, v2.16b
BIC v6.16b, v26.16b, v2.16b
BIC v7.16b, v27.16b, v2.16b
SQRDMULH v24.4s, v24.4s, v0.4s
SQRDMULH v25.4s, v25.4s, v0.4s
SQRDMULH v26.4s, v26.4s, v0.4s
SQRDMULH v27.4s, v27.4s, v0.4s
SSRA v24.4s, v4.4s, 31
SSRA v25.4s, v5.4s, 31
SSRA v26.4s, v6.4s, 31
SSRA v27.4s, v7.4s, 31
BIC v4.16b, v28.16b, v2.16b
BIC v5.16b, v29.16b, v2.16b
BIC v6.16b, v30.16b, v2.16b
BIC v7.16b, v31.16b, v2.16b
SQRDMULH v28.4s, v28.4s, v0.4s
SQRDMULH v29.4s, v29.4s, v0.4s
SQRDMULH v30.4s, v30.4s, v0.4s
SQRDMULH v31.4s, v31.4s, v0.4s
SSRA v28.4s, v4.4s, 31
SSRA v29.4s, v5.4s, 31
SSRA v30.4s, v6.4s, 31
SSRA v31.4s, v7.4s, 31
SRSHL v16.4s, v16.4s, v1.4s // signed rounding shift left
SRSHL v17.4s, v17.4s, v1.4s
SRSHL v18.4s, v18.4s, v1.4s
SRSHL v19.4s, v19.4s, v1.4s
SRSHL v20.4s, v20.4s, v1.4s
SRSHL v21.4s, v21.4s, v1.4s
SRSHL v22.4s, v22.4s, v1.4s
SRSHL v23.4s, v23.4s, v1.4s
SRSHL v24.4s, v24.4s, v1.4s
SRSHL v25.4s, v25.4s, v1.4s
SRSHL v26.4s, v26.4s, v1.4s
SRSHL v27.4s, v27.4s, v1.4s
SRSHL v28.4s, v28.4s, v1.4s
SRSHL v29.4s, v29.4s, v1.4s
SRSHL v30.4s, v30.4s, v1.4s
SRSHL v31.4s, v31.4s, v1.4s
// Narrow int32 to int16 with saturation. Low halves take columns 0-3
// (v16-v19) and 8-11 (v24-v27). SQXTN2 fills the high halves with
// columns 4-7 (v20-v23) and 12-15 (v28-v31).
SQXTN v16.4h, v16.4s
SQXTN v17.4h, v17.4s
SQXTN v18.4h, v18.4s
SQXTN v19.4h, v19.4s
SQXTN v24.4h, v24.4s
SQXTN v25.4h, v25.4s
SQXTN v26.4h, v26.4s
SQXTN v27.4h, v27.4s
LD1R {v2.8h}, [x11], 2 // add bias: int16 at params + 8 (presumably the output zero point - confirm)
SQXTN2 v16.8h, v20.4s
SQXTN2 v17.8h, v21.4s
SQXTN2 v18.8h, v22.4s
SQXTN2 v19.8h, v23.4s
SQXTN2 v24.8h, v28.4s
SQXTN2 v25.8h, v29.4s
SQXTN2 v26.8h, v30.4s
SQXTN2 v27.8h, v31.4s
SQADD v16.8h, v16.8h, v2.8h
SQADD v17.8h, v17.8h, v2.8h
SQADD v18.8h, v18.8h, v2.8h
SQADD v19.8h, v19.8h, v2.8h
SQADD v24.8h, v24.8h, v2.8h
SQADD v25.8h, v25.8h, v2.8h
SQADD v26.8h, v26.8h, v2.8h
SQADD v27.8h, v27.8h, v2.8h
LD1R {v0.16b}, [x11], 1 // clamp min value
// Narrow int16 to int8 with saturation: v4-v7 = rows 0-3, 16 columns each.
SQXTN v4.8b, v16.8h
SQXTN v5.8b, v17.8h
SQXTN v6.8b, v18.8h
SQXTN v7.8b, v19.8h
LD1R {v1.16b}, [x11] // clamp max value
SQXTN2 v4.16b, v24.8h
SQXTN2 v5.16b, v25.8h
SQXTN2 v6.16b, v26.8h
SQXTN2 v7.16b, v27.8h
LDR x12, [sp] // cn_stride
SMAX v4.16b, v4.16b, v0.16b
SMAX v5.16b, v5.16b, v0.16b
SMAX v6.16b, v6.16b, v0.16b
SMAX v7.16b, v7.16b, v0.16b
SUBS x1, x1, 16 // nc -= 16; flags used by B.LO 4f and B.NE 0b
SMIN v4.16b, v4.16b, v1.16b
SMIN v5.16b, v5.16b, v1.16b
SMIN v6.16b, v6.16b, v1.16b
SMIN v7.16b, v7.16b, v1.16b
B.LO 4f // fewer than 16 columns were left: partial-width store
# Store full 4 x 16
// Each row store advances its C pointer by cn_stride. The A pointers are
// moved back by the bytes the main loop consumed, ready for the next tile.
ST1 {v4.16b}, [x6], x12
SUB x3, x3, x10 // a0 -= rewind
ST1 {v5.16b}, [x8], x12
SUB x15, x15, x10 // a1 -= rewind
ST1 {v6.16b}, [x9], x12
SUB x13, x13, x10 // a2 -= rewind
ST1 {v7.16b}, [x7], x12
SUB x4, x4, x10 // a3 -= rewind
B.NE 0b // columns remain (flags still from SUBS x1): next tile
RET
# Remainder- 1 to 4 bytes of A
// One 4-byte K step using lane [0]. The A pointers are read without
// post-increment, so the rewind in x10 excludes these bytes. LDR s reads
// all 4 bytes even when fewer than 4 bytes of kc remain.
.p2align 3
3:
LDR s0, [x3]
LDR q4, [x5], 16
LDR s1, [x15]
LDR s2, [x13]
LDR s3, [x4]
SDOT v16.4s, v4.16b, v0.4b[0]
LDR q5, [x5], 16
SDOT v17.4s, v4.16b, v1.4b[0]
SDOT v18.4s, v4.16b, v2.4b[0]
SDOT v19.4s, v4.16b, v3.4b[0]
SDOT v20.4s, v5.16b, v0.4b[0]
LDP q6, q7, [x5], 32
SDOT v21.4s, v5.16b, v1.4b[0]
SDOT v22.4s, v5.16b, v2.4b[0]
SDOT v23.4s, v5.16b, v3.4b[0]
SDOT v24.4s, v6.16b, v0.4b[0]
SDOT v25.4s, v6.16b, v1.4b[0]
SDOT v26.4s, v6.16b, v2.4b[0]
SDOT v27.4s, v6.16b, v3.4b[0]
SDOT v28.4s, v7.16b, v0.4b[0]
SDOT v29.4s, v7.16b, v1.4b[0]
SDOT v30.4s, v7.16b, v2.4b[0]
SDOT v31.4s, v7.16b, v3.4b[0]
B 2b
# Store odd width
// x1 = (columns left) - 16, negative here. Its low 4 bits equal the number
// of columns left (1-15). Store 8, 4, 2 and 1 bytes for each set bit,
// shifting the not-yet-stored bytes down to lane 0 after each step.
.p2align 3
4:
TBZ x1, 3, 5f
STR d4, [x6], 8
DUP d4, v4.d[1]
STR d5, [x8], 8
DUP d5, v5.d[1]
STR d6, [x9], 8
DUP d6, v6.d[1]
STR d7, [x7], 8
DUP d7, v7.d[1]
5:
TBZ x1, 2, 6f
STR s4, [x6], 4
DUP s4, v4.s[1]
STR s5, [x8], 4
DUP s5, v5.s[1]
STR s6, [x9], 4
DUP s6, v6.s[1]
STR s7, [x7], 4
DUP s7, v7.s[1]
6:
TBZ x1, 1, 7f
ST1 {v4.h}[0], [x6], 2
DUP h4, v4.h[1]
ST1 {v5.h}[0], [x8], 2
DUP h5, v5.h[1]
ST1 {v6.h}[0], [x9], 2
DUP h6, v6.h[1]
ST1 {v7.h}[0], [x7], 2
DUP h7, v7.h[1]
7:
TBZ x1, 0, 8f
ST1 {v4.b}[0], [x6]
ST1 {v5.b}[0], [x8]
ST1 {v6.b}[0], [x9]
ST1 {v7.b}[0], [x7]
8:
RET
END_FUNCTION xnn_qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64
#if defined(__ELF__)
// An empty .note.GNU-stack section without the executable flag tells GNU
// linkers that this object does not need an executable stack.
.section ".note.GNU-stack","",%progbits
#endif