arm64: looprestoration: Apply simplifications to align with C code
This applies the same simplifications to the arm64 implementation
that were done for the C code and the x86 assembly in
4613d3a5306e44f3fdc39989e6bb218841e78097. This gives a minor speedup
of a couple of percent; a sketch of the underlying arithmetic follows
the benchmark numbers.
Before:           Cortex A53       A55       A72       A73       A76  Apple M3
sgr_3x3_8bpc_neon:  368583.2  363654.2  279958.1  272065.1  169353.3     354.6
sgr_5x5_8bpc_neon:  258570.7  255018.5  200410.6  199478.3  117968.3     260.9
sgr_mix_8bpc_neon:  603698.1  577383.3  482468.3  436540.4  256632.9     541.8
After:
sgr_3x3_8bpc_neon:  367873.2  357884.1  275462.4  268363.9  165909.8     346.0
sgr_5x5_8bpc_neon:  254988.4  248184.2  190875.1  196939.1  120517.2     252.1
sgr_mix_8bpc_neon:  589204.7  563565.8  414025.6  427702.2  251651.2     533.4
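
In terms of arithmetic, the change is the following (an illustrative C
sketch, not the actual dav1d C code; the helper names are made up).
The old code forms u = src << 4 and computes (u << 7) + w * (t1 - u)
before the final rounding shift by 11; the new code accumulates
b - a * src instead of b + a * src, drops the 256 - x complement when
producing the A values, and adds the source pixel back at the end with
usqadd. Assuming the A/B intermediates are prepared so that the new t1
equals the old t1 minus u, the two forms are bit-exact, because
u << 7 == src << 11 is a multiple of 1 << 11 and passes through the
rounding shift unchanged. The final saturation/clipping steps
(sqxtun/usqadd/umin) are omitted here; the sgr_mix path applies the
same identity with two filtered values, wt[0] * t1 + wt[1] * t2.

    /* Illustrative sketch only, not dav1d code; assumes >> on a negative
     * int is an arithmetic shift, matching the NEON rounding shifts
     * (rshrn/srshr). */
    #include <assert.h>

    /* Old form: u = src << 4, v = (u << 7) + w * (t1 - u),
     * out = round(v >> 11). */
    static int weighted1_old(int src, int t1, int w) {
        const int u = src << 4;
        const int v = (u << 7) + w * (t1 - u);
        return (v + (1 << 10)) >> 11;
    }

    /* New form: v = w * t1, out = src + round(v >> 11), where this t1 is
     * the old t1 minus u (obtained by accumulating b - a * src with the
     * uncomplemented A table). */
    static int weighted1_new(int src, int t1, int w) {
        const int v = w * t1;
        return src + ((v + (1 << 10)) >> 11);
    }

    int main(void) {
        /* src << 11 is a multiple of 1 << 11, so it passes through the
         * rounding shift by 11 unchanged; the two forms match bit-exactly. */
        for (int src = 0; src < 256; src++)
            for (int t1 = 0; t1 < 16 * 256; t1 += 7)
                for (int w = -128; w <= 128; w += 16)
                    assert(weighted1_old(src, t1, w) ==
                           weighted1_new(src, t1 - (src << 4), w));
        return 0;
    }
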
diff --git a/src/arm/64/looprestoration_common.S b/src/arm/64/looprestoration_common.S
index c10a9f3..13c2fd5 100644
--- a/src/arm/64/looprestoration_common.S
+++ b/src/arm/64/looprestoration_common.S
@@ -119,7 +119,6 @@
ldr q23, [x12, #64] //RefLo
dup v6.8h, w9 // -bitdepth_min_8
saddl v7.4s, v6.4h, v6.4h // -2*bitdepth_min_8
- movi v29.8h, #1, lsl #8
dup v30.4s, w13 // one_by_x
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x5], #64
@@ -198,8 +197,6 @@
umull2 v3.4s, v0.8h, v4.8h // x * BB[i]
umull v4.4s, v1.4h, v5.4h // x * BB[i]
umull2 v5.4s, v1.8h, v5.8h // x * BB[i]
- sub v0.8h, v29.8h, v0.8h // 256 - x
- sub v1.8h, v29.8h, v1.8h // 256 - x
mul v2.4s, v2.4s, v30.4s // x * BB[i] * sgr_one_by_x
mul v3.4s, v3.4s, v30.4s // x * BB[i] * sgr_one_by_x
mul v4.4s, v4.4s, v30.4s // x * BB[i] * sgr_one_by_x
@@ -251,7 +248,6 @@
movi v19.8b, #0x7
ldr q18, [x13, #64] // RefLo
saddl v7.4s, v6.4h, v6.4h // -2*bitdepth_min_8
- movi v29.8h, #1, lsl #8
ld1 {v8.4s, v9.4s}, [x5], #32
ld1 {v10.4s, v11.4s}, [x6], #32
@@ -325,7 +321,6 @@
mul v4.4s, v4.4s, v30.4s // x * BB[i] * sgr_one_by_x
srshr v3.4s, v3.4s, #12 // AA[i]
srshr v4.4s, v4.4s, #12 // AA[i]
- sub v5.8h, v29.8h, v5.8h // 256 - x
ld1 {v2.8h}, [x1], #16
st1 {v3.4s, v4.4s}, [x2], #32
diff --git a/src/arm/64/looprestoration_tmpl.S b/src/arm/64/looprestoration_tmpl.S
index 1373f9a..3ee9aaf 100644
--- a/src/arm/64/looprestoration_tmpl.S
+++ b/src/arm/64/looprestoration_tmpl.S
@@ -174,10 +174,10 @@
mla v14.4s, v19.4s, v31.4s // * 3 -> b
mla v15.4s, v20.4s, v31.4s
- umlal v8.4s, v4.4h, v25.4h // b + a * src
- umlal2 v9.4s, v4.8h, v25.8h
- umlal v14.4s, v0.4h, v26.4h // b + a * src
- umlal2 v15.4s, v0.8h, v26.8h
+ umlsl v8.4s, v4.4h, v25.4h // b - a * src
+ umlsl2 v9.4s, v4.8h, v25.8h
+ umlsl v14.4s, v0.4h, v26.4h // b - a * src
+ umlsl2 v15.4s, v0.8h, v26.8h
mov v0.16b, v1.16b
rshrn v8.4h, v8.4s, #9
rshrn2 v8.8h, v9.4s, #9
@@ -292,8 +292,8 @@
uxtl v19.8h, v19.8b // src
.endif
mov v0.16b, v1.16b
- umlal v25.4s, v2.4h, v19.4h // b + a * src
- umlal2 v26.4s, v2.8h, v19.8h
+ umlsl v25.4s, v2.4h, v19.4h // b - a * src
+ umlsl2 v26.4s, v2.8h, v19.8h
mov v2.16b, v3.16b
rshrn v25.4h, v25.4s, #9
rshrn2 v25.8h, v26.4s, #9
@@ -301,30 +301,25 @@
subs w3, w3, #8
// weighted1
- shl v19.8h, v19.8h, #4 // u
mov v4.16b, v5.16b
- sub v25.8h, v25.8h, v19.8h // t1 - u
ld1 {v1.8h}, [x9], #16
- ushll v26.4s, v19.4h, #7 // u << 7
- ushll2 v27.4s, v19.8h, #7 // u << 7
ld1 {v3.8h}, [x10], #16
- smlal v26.4s, v25.4h, v31.4h // v
- smlal2 v27.4s, v25.8h, v31.8h // v
+ smull v26.4s, v25.4h, v31.4h // v = t1 * w1
+ smull2 v27.4s, v25.8h, v31.8h
ld1 {v5.8h}, [x2], #16
-.if \bpc == 8
rshrn v26.4h, v26.4s, #11
rshrn2 v26.8h, v27.4s, #11
+ usqadd v19.8h, v26.8h
+.if \bpc == 8
mov v16.16b, v18.16b
- sqxtun v26.8b, v26.8h
+ sqxtun v26.8b, v19.8h
mov v19.16b, v21.16b
mov v22.16b, v24.16b
st1 {v26.8b}, [x0], #8
.else
- sqrshrun v26.4h, v26.4s, #11
- sqrshrun2 v26.8h, v27.4s, #11
mov v16.16b, v18.16b
- umin v26.8h, v26.8h, v30.8h
+ umin v26.8h, v19.8h, v30.8h
mov v19.16b, v21.16b
mov v22.16b, v24.16b
st1 {v26.8h}, [x0], #16
@@ -424,10 +419,10 @@
uxtl v31.8h, v31.8b
uxtl v30.8h, v30.8b
.endif
- umlal v16.4s, v0.4h, v31.4h // b + a * src
- umlal2 v17.4s, v0.8h, v31.8h
- umlal v9.4s, v8.4h, v30.4h // b + a * src
- umlal2 v10.4s, v8.8h, v30.8h
+ umlsl v16.4s, v0.4h, v31.4h // b - a * src
+ umlsl2 v17.4s, v0.8h, v31.8h
+ umlsl v9.4s, v8.4h, v30.4h // b - a * src
+ umlsl2 v10.4s, v8.8h, v30.8h
mov v0.16b, v1.16b
rshrn v16.4h, v16.4s, #9
rshrn2 v16.8h, v17.4s, #9
@@ -541,10 +536,10 @@
uxtl v31.8h, v31.8b
uxtl v30.8h, v30.8b
.endif
- umlal v16.4s, v0.4h, v31.4h // b + a * src
- umlal2 v17.4s, v0.8h, v31.8h
- umlal v9.4s, v8.4h, v30.4h // b + a * src
- umlal2 v10.4s, v8.8h, v30.8h
+ umlsl v16.4s, v0.4h, v31.4h // b - a * src
+ umlsl2 v17.4s, v0.8h, v31.8h
+ umlsl v9.4s, v8.4h, v30.4h // b - a * src
+ umlsl2 v10.4s, v8.8h, v30.8h
mov v0.16b, v1.16b
rshrn v16.4h, v16.4s, #9
rshrn2 v16.8h, v17.4s, #9
@@ -554,40 +549,30 @@
subs w4, w4, #8
// weighted1
- shl v31.8h, v31.8h, #4 // u
- shl v30.8h, v30.8h, #4
mov v2.16b, v3.16b
- sub v16.8h, v16.8h, v31.8h // t1 - u
- sub v9.8h, v9.8h, v30.8h
ld1 {v1.8h}, [x3], #16
- ushll v22.4s, v31.4h, #7 // u << 7
- ushll2 v23.4s, v31.8h, #7
- ushll v24.4s, v30.4h, #7
- ushll2 v25.4s, v30.8h, #7
ld1 {v3.8h}, [x8], #16
- smlal v22.4s, v16.4h, v14.4h // v
- smlal2 v23.4s, v16.8h, v14.8h
+ smull v22.4s, v16.4h, v14.4h // v
+ smull2 v23.4s, v16.8h, v14.8h
mov v16.16b, v18.16b
- smlal v24.4s, v9.4h, v14.4h
- smlal2 v25.4s, v9.8h, v14.8h
+ smull v24.4s, v9.4h, v14.4h
+ smull2 v25.4s, v9.8h, v14.8h
mov v19.16b, v21.16b
-.if \bpc == 8
rshrn v22.4h, v22.4s, #11
rshrn2 v22.8h, v23.4s, #11
rshrn v23.4h, v24.4s, #11
rshrn2 v23.8h, v25.4s, #11
- sqxtun v22.8b, v22.8h
- sqxtun v23.8b, v23.8h
+ usqadd v31.8h, v22.8h
+ usqadd v30.8h, v23.8h
+.if \bpc == 8
+ sqxtun v22.8b, v31.8h
+ sqxtun v23.8b, v30.8h
st1 {v22.8b}, [x0], #8
st1 {v23.8b}, [x1], #8
.else
- sqrshrun v22.4h, v22.4s, #11
- sqrshrun2 v22.8h, v23.4s, #11
- sqrshrun v23.4h, v24.4s, #11
- sqrshrun2 v23.8h, v25.4s, #11
- umin v22.8h, v22.8h, v15.8h
- umin v23.8h, v23.8h, v15.8h
+ umin v22.8h, v31.8h, v15.8h
+ umin v23.8h, v30.8h, v15.8h
st1 {v22.8h}, [x0], #16
st1 {v23.8h}, [x1], #16
.endif
@@ -653,44 +638,31 @@
ld1 {v18.8h}, [x13], #16
subs w6, w6, #8
.if \bpc == 8
- ushll v0.8h, v0.8b, #4 // u
- ushll v16.8h, v16.8b, #4 // u
-.else
- shl v0.8h, v0.8h, #4 // u
- shl v16.8h, v16.8h, #4 // u
+ uxtl v0.8h, v0.8b
+ uxtl v16.8h, v16.8b
.endif
- sub v1.8h, v1.8h, v0.8h // t1 - u
- sub v2.8h, v2.8h, v0.8h // t2 - u
- sub v17.8h, v17.8h, v16.8h // t1 - u
- sub v18.8h, v18.8h, v16.8h // t2 - u
- ushll v3.4s, v0.4h, #7 // u << 7
- ushll2 v4.4s, v0.8h, #7 // u << 7
- ushll v19.4s, v16.4h, #7 // u << 7
- ushll2 v20.4s, v16.8h, #7 // u << 7
- smlal v3.4s, v1.4h, v30.4h // wt[0] * (t1 - u)
- smlal v3.4s, v2.4h, v31.4h // wt[1] * (t2 - u)
- smlal2 v4.4s, v1.8h, v30.8h // wt[0] * (t1 - u)
- smlal2 v4.4s, v2.8h, v31.8h // wt[1] * (t2 - u)
- smlal v19.4s, v17.4h, v30.4h // wt[0] * (t1 - u)
- smlal v19.4s, v18.4h, v31.4h // wt[1] * (t2 - u)
- smlal2 v20.4s, v17.8h, v30.8h // wt[0] * (t1 - u)
- smlal2 v20.4s, v18.8h, v31.8h // wt[1] * (t2 - u)
-.if \bpc == 8
+ smull v3.4s, v1.4h, v30.4h // wt[0] * t1
+ smlal v3.4s, v2.4h, v31.4h // wt[1] * t2
+ smull2 v4.4s, v1.8h, v30.8h // wt[0] * t1
+ smlal2 v4.4s, v2.8h, v31.8h // wt[1] * t2
+ smull v19.4s, v17.4h, v30.4h // wt[0] * t1
+ smlal v19.4s, v18.4h, v31.4h // wt[1] * t2
+ smull2 v20.4s, v17.8h, v30.8h // wt[0] * t1
+ smlal2 v20.4s, v18.8h, v31.8h // wt[1] * t2
rshrn v3.4h, v3.4s, #11
rshrn2 v3.8h, v4.4s, #11
rshrn v19.4h, v19.4s, #11
rshrn2 v19.8h, v20.4s, #11
- sqxtun v3.8b, v3.8h
- sqxtun v19.8b, v19.8h
+ usqadd v0.8h, v3.8h
+ usqadd v16.8h, v19.8h
+.if \bpc == 8
+ sqxtun v3.8b, v0.8h
+ sqxtun v19.8b, v16.8h
st1 {v3.8b}, [x0], #8
st1 {v19.8b}, [x10], #8
.else
- sqrshrun v3.4h, v3.4s, #11
- sqrshrun2 v3.8h, v4.4s, #11
- sqrshrun v19.4h, v19.4s, #11
- sqrshrun2 v19.8h, v20.4s, #11
- umin v3.8h, v3.8h, v29.8h
- umin v19.8h, v19.8h, v29.8h
+ umin v3.8h, v0.8h, v29.8h
+ umin v19.8h, v16.8h, v29.8h
st1 {v3.8h}, [x0], #16
st1 {v19.8h}, [x10], #16
.endif
@@ -721,27 +693,20 @@
ld1 {v2.8h}, [x5], #16
subs w6, w6, #8
.if \bpc == 8
- ushll v0.8h, v0.8b, #4 // u
-.else
- shl v0.8h, v0.8h, #4 // u
+ uxtl v0.8h, v0.8b
.endif
- sub v1.8h, v1.8h, v0.8h // t1 - u
- sub v2.8h, v2.8h, v0.8h // t2 - u
- ushll v3.4s, v0.4h, #7 // u << 7
- ushll2 v4.4s, v0.8h, #7 // u << 7
- smlal v3.4s, v1.4h, v30.4h // wt[0] * (t1 - u)
- smlal v3.4s, v2.4h, v31.4h // wt[1] * (t2 - u)
- smlal2 v4.4s, v1.8h, v30.8h // wt[0] * (t1 - u)
- smlal2 v4.4s, v2.8h, v31.8h // wt[1] * (t2 - u)
-.if \bpc == 8
+ smull v3.4s, v1.4h, v30.4h // wt[0] * t1
+ smlal v3.4s, v2.4h, v31.4h // wt[1] * t2
+ smull2 v4.4s, v1.8h, v30.8h // wt[0] * t1
+ smlal2 v4.4s, v2.8h, v31.8h // wt[1] * t2
rshrn v3.4h, v3.4s, #11
rshrn2 v3.8h, v4.4s, #11
- sqxtun v3.8b, v3.8h
+ usqadd v0.8h, v3.8h
+.if \bpc == 8
+ sqxtun v3.8b, v0.8h
st1 {v3.8b}, [x0], #8
.else
- sqrshrun v3.4h, v3.4s, #11
- sqrshrun2 v3.8h, v4.4s, #11
- umin v3.8h, v3.8h, v29.8h
+ umin v3.8h, v0.8h, v29.8h
st1 {v3.8h}, [x0], #16
.endif
b.gt 1b