Refactor DotProd convolve_x_sr_12tap_neon path

Remove unnecessary variable forward declarations and tidy up the naming
of the convolve helper functions: convolve12_{4,8}_sdot becomes
convolve12_{4,8}_x, and the final narrowing step moves into the helpers.
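
For reference, a scalar sketch of the "correction" constant that the
DotProd helpers accumulate into (an illustration only, not part of the
patch; correction_scalar is a hypothetical name, and the concrete values
assume libaom's FILTER_BITS == 7 and ROUND0_BITS == 3). Biasing the
unsigned samples by 128 to fit the signed sdot instruction means

    sum(x[i] * f[i]) == sum((x[i] - 128) * f[i]) + 128 * sum(f[i])

so the missing 128 * sum(f[i]) == sum(f[i] << FILTER_BITS) term is
pre-added, together with the 1 << (ROUND0_BITS - 1) shim that lets a
single vqrshrn_n_s32 by FILTER_BITS stand in for both rounding shifts:

    static int32_t correction_scalar(const int16_t f[12]) {
      int32_t sum = 0;
      for (int i = 0; i < 12; ++i) {
        sum += f[i] << FILTER_BITS;  // 128 * f[i]
      }
      // Shim so one rounding shift by FILTER_BITS replaces the usual
      // ROUND0_BITS shift followed by a FILTER_BITS - ROUND0_BITS shift.
      return sum + (1 << (ROUND0_BITS - 1));
    }

This is what the vshlq_n_s16/vpaddlq_s16/vaddvq_s32 sequence in the
patch computes with vector operations.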

Change-Id: I6b7c44430d56b28fbfc2d82b234ca6b96183baf5
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index e5908e9..a46cb16 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -81,78 +81,6 @@
                       vqrshrn_n_s32(sum[1], FILTER_BITS));
 }
 
-#elif AOM_ARCH_AARCH64 && defined(__ARM_FEATURE_DOTPROD)
-
-static INLINE int32x4_t convolve12_4_sdot(uint8x16_t samples,
-                                          const int8x16_t filters,
-                                          const int32x4_t correction,
-                                          const uint8x16_t range_limit,
-                                          const uint8x16x3_t permute_tbl) {
-  int8x16_t clamped_samples, permuted_samples[3];
-  int32x4_t sum;
-
-  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
-  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
-
-  /* Permute samples ready for dot product. */
-  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
-  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
-  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
-  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
-  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
-  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
-
-  /* Accumulate dot product into 'correction' to account for range clamp. */
-  /* First 4 output values. */
-  sum = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
-  sum = vdotq_laneq_s32(sum, permuted_samples[1], filters, 1);
-  sum = vdotq_laneq_s32(sum, permuted_samples[2], filters, 2);
-
-  return sum;
-}
-
-static INLINE int16x8_t convolve12_8_sdot(uint8x16_t samples0,
-                                          uint8x16_t samples1,
-                                          const int8x16_t filters,
-                                          const int32x4_t correction,
-                                          const uint8x16_t range_limit,
-                                          const uint8x16x3_t permute_tbl) {
-  int8x16_t clamped_samples[2], permuted_samples[4];
-  int32x4_t sum[2];
-
-  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
-  clamped_samples[0] = vreinterpretq_s8_u8(vsubq_u8(samples0, range_limit));
-  clamped_samples[1] = vreinterpretq_s8_u8(vsubq_u8(samples1, range_limit));
-
-  /* Permute samples ready for dot product. */
-  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
-  permuted_samples[0] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[0]);
-  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
-  permuted_samples[1] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[1]);
-  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
-  permuted_samples[2] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[2]);
-  /* {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 } */
-  permuted_samples[3] = vqtbl1q_s8(clamped_samples[1], permute_tbl.val[2]);
-
-  /* Accumulate dot product into 'correction' to account for range clamp. */
-  /* First 4 output values. */
-  sum[0] = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
-  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[1], filters, 1);
-  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[2], filters, 2);
-  /* Second 4 output values. */
-  sum[1] = vdotq_laneq_s32(correction, permuted_samples[1], filters, 0);
-  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[2], filters, 1);
-  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[3], filters, 2);
-
-  /* Narrow and re-pack. */
-  return vcombine_s16(vqrshrn_n_s32(sum[0], FILTER_BITS),
-                      vqrshrn_n_s32(sum[1], FILTER_BITS));
-}
-
-#endif  // AOM_ARCH_AARCH64 && defined(__ARM_FEATURE_MATMUL_INT8)
-
-#if AOM_ARCH_AARCH64 && defined(__ARM_FEATURE_MATMUL_INT8)
-
 void convolve_x_sr_12tap_neon(const uint8_t *src, int src_stride, uint8_t *dst,
                               int dst_stride, int w, int h,
                               const int16_t *x_filter_ptr) {
@@ -391,6 +319,72 @@
 
 #elif AOM_ARCH_AARCH64 && defined(__ARM_FEATURE_DOTPROD)
 
+static INLINE int16x4_t convolve12_4_x(uint8x16_t samples,
+                                       const int8x16_t filter,
+                                       const int32x4_t correction,
+                                       const uint8x16_t range_limit,
+                                       const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum;
+
+  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  // Accumulate dot product into 'correction' to account for range clamp.
+  // First 4 output values.
+  sum = vdotq_laneq_s32(correction, permuted_samples[0], filter, 0);
+  sum = vdotq_laneq_s32(sum, permuted_samples[1], filter, 1);
+  sum = vdotq_laneq_s32(sum, permuted_samples[2], filter, 2);
+
+  return vqrshrn_n_s32(sum, FILTER_BITS);
+}
+
+static INLINE uint8x8_t convolve12_8_x(uint8x16_t samples[2],
+                                       const int8x16_t filter,
+                                       const int32x4_t correction,
+                                       const uint8x16_t range_limit,
+                                       const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples[2], permuted_samples[4];
+  int32x4_t sum[2];
+
+  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
+  clamped_samples[0] = vreinterpretq_s8_u8(vsubq_u8(samples[0], range_limit));
+  clamped_samples[1] = vreinterpretq_s8_u8(vsubq_u8(samples[1], range_limit));
+
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[0]);
+  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[1]);
+  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[2]);
+  // {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 }
+  permuted_samples[3] = vqtbl1q_s8(clamped_samples[1], permute_tbl.val[2]);
+
+  // Accumulate dot product into 'correction' to account for range clamp.
+  // First 4 output values.
+  sum[0] = vdotq_laneq_s32(correction, permuted_samples[0], filter, 0);
+  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[1], filter, 1);
+  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[2], filter, 2);
+  // Second 4 output values.
+  sum[1] = vdotq_laneq_s32(correction, permuted_samples[1], filter, 0);
+  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[2], filter, 1);
+  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[3], filter, 2);
+
+  // Narrow and re-pack.
+  int16x8_t sum_s16 = vcombine_s16(vqrshrn_n_s32(sum[0], FILTER_BITS),
+                                   vqrshrn_n_s32(sum[1], FILTER_BITS));
+  return vqmovun_s16(sum_s16);
+}
+
 void convolve_x_sr_12tap_neon(const uint8_t *src, int src_stride, uint8_t *dst,
                               int dst_stride, int w, int h,
                               const int16_t *x_filter_ptr) {
@@ -400,15 +394,14 @@
   const int8x16_t filter =
       vcombine_s8(vmovn_s16(filter_0_7), vmovn_s16(filter_8_15));
 
-  const int32x4_t correct_tmp =
-      vaddq_s32(vpaddlq_s16(vshlq_n_s16(filter_0_7, 7)),
-                vpaddlq_s16(vshlq_n_s16(filter_8_15, 7)));
-  // This shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding
-  // right shift by FILTER_BITS - instead of a first rounding right shift by
+  const int32_t correction_s32 =
+      vaddvq_s32(vaddq_s32(vpaddlq_s16(vshlq_n_s16(filter_0_7, FILTER_BITS)),
+                           vpaddlq_s16(vshlq_n_s16(filter_8_15, FILTER_BITS))));
+  // A shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding right
+  // shift by FILTER_BITS - instead of a first rounding right shift by
   // ROUND0_BITS, followed by second rounding right shift by FILTER_BITS -
   // ROUND0_BITS.
-  int32x4_t correction =
-      vdupq_n_s32(vaddvq_s32(correct_tmp) + (1 << (ROUND0_BITS - 1)));
+  int32x4_t correction = vdupq_n_s32(correction_s32 + (1 << (ROUND0_BITS - 1)));
   const uint8x16_t range_limit = vdupq_n_u8(128);
   const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
 
@@ -416,49 +409,48 @@
   // 8-bit signed dot-product instruction:
   // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
   if (vgetq_lane_s16(filter_0_7, 5) == 128) {
-    uint8x8_t d0;
-
     // Undo the horizontal offset in the calling function.
     src += 5;
 
-    for (int i = 0; i < h; i++) {
-      for (int j = 0; j < w; j += 8) {
-        d0 = vld1_u8(src + i * src_stride + j);
-        if (w == 2) {
-          store_u8_2x1(dst + i * dst_stride, d0, 0);
-        } else if (w == 4) {
-          store_u8_4x1(dst + i * dst_stride, d0, 0);
-        } else {
-          vst1_u8(dst + i * dst_stride + j, d0);
-        }
-      }
-    }
-  } else {
-    if (w <= 4) {
-      uint8x16_t s0, s1, s2, s3;
-      int32x4_t d0, d1, d2, d3;
-      int16x8_t t01, t23;
-      uint8x8_t d01, d23;
+    do {
+      const uint8_t *s = src;
+      uint8_t *d = dst;
+      int width = w;
 
       do {
+        uint8x8_t d0 = vld1_u8(s);
+        if (w == 2) {
+          store_u8_2x1(d, d0, 0);
+        } else if (w == 4) {
+          store_u8_4x1(d, d0, 0);
+        } else {
+          vst1_u8(d, d0);
+        }
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width > 0);
+      src += src_stride;
+      dst += dst_stride;
+    } while (--h != 0);
+  } else {
+    if (w <= 4) {
+      do {
+        uint8x16_t s0, s1, s2, s3;
         load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 =
-            convolve12_4_sdot(s0, filter, correction, range_limit, permute_tbl);
-        d1 =
-            convolve12_4_sdot(s1, filter, correction, range_limit, permute_tbl);
-        d2 =
-            convolve12_4_sdot(s2, filter, correction, range_limit, permute_tbl);
-        d3 =
-            convolve12_4_sdot(s3, filter, correction, range_limit, permute_tbl);
+        int16x4_t d0 =
+            convolve12_4_x(s0, filter, correction, range_limit, permute_tbl);
+        int16x4_t d1 =
+            convolve12_4_x(s1, filter, correction, range_limit, permute_tbl);
+        int16x4_t d2 =
+            convolve12_4_x(s2, filter, correction, range_limit, permute_tbl);
+        int16x4_t d3 =
+            convolve12_4_x(s3, filter, correction, range_limit, permute_tbl);
 
-        t01 = vcombine_s16(vqrshrn_n_s32(d0, FILTER_BITS),
-                           vqrshrn_n_s32(d1, FILTER_BITS));
-        t23 = vcombine_s16(vqrshrn_n_s32(d2, FILTER_BITS),
-                           vqrshrn_n_s32(d3, FILTER_BITS));
-
-        d01 = vqmovun_s16(t01);
-        d23 = vqmovun_s16(t23);
+        uint8x8_t d01 = vqmovun_s16(vcombine_s16(d0, d1));
+        uint8x8_t d23 = vqmovun_s16(vcombine_s16(d2, d3));
 
         if (w == 2) {
           store_u8_2x1(dst + 0 * dst_stride, d01, 0);
@@ -481,36 +473,28 @@
         h -= 4;
       } while (h > 0);
     } else {
-      uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7;
-      int16x8_t d0, d1, d2, d3;
-      uint8x8_t dd0, dd1, dd2, dd3;
-
       do {
         const uint8_t *s = src;
         uint8_t *d = dst;
         int width = w;
 
         do {
-          load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
-          load_u8_16x4(s + 4, src_stride, &s4, &s5, &s6, &s7);
+          uint8x16_t s0[2], s1[2], s2[2], s3[2];
+          load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
+          load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);
 
-          d0 = convolve12_8_sdot(s0, s4, filter, correction, range_limit,
-                                 permute_tbl);
-          d1 = convolve12_8_sdot(s1, s5, filter, correction, range_limit,
-                                 permute_tbl);
-          d2 = convolve12_8_sdot(s2, s6, filter, correction, range_limit,
-                                 permute_tbl);
-          d3 = convolve12_8_sdot(s3, s7, filter, correction, range_limit,
-                                 permute_tbl);
+          uint8x8_t d0 =
+              convolve12_8_x(s0, filter, correction, range_limit, permute_tbl);
+          uint8x8_t d1 =
+              convolve12_8_x(s1, filter, correction, range_limit, permute_tbl);
+          uint8x8_t d2 =
+              convolve12_8_x(s2, filter, correction, range_limit, permute_tbl);
+          uint8x8_t d3 =
+              convolve12_8_x(s3, filter, correction, range_limit, permute_tbl);
 
-          dd0 = vqmovun_s16(d0);
-          dd1 = vqmovun_s16(d1);
-          dd2 = vqmovun_s16(d2);
-          dd3 = vqmovun_s16(d3);
-
-          store_u8_8x2(d + 0 * dst_stride, dst_stride, dd0, dd1);
+          store_u8_8x2(d + 0 * dst_stride, dst_stride, d0, d1);
           if (h != 2) {
-            store_u8_8x2(d + 2 * dst_stride, dst_stride, dd2, dd3);
+            store_u8_8x2(d + 2 * dst_stride, dst_stride, d2, d3);
           }
 
           s += 8;