Use 16-bit vectors between h/v filters in highbd_warp_affine_neon

In highbd warp affine we currently perform a widening multiply to
32 bits during the horizontal filter and then carry these 32-bit
intermediates through to the vertical filter, where we do a
non-widening 32-bit multiply. This also means that we must first
widen the filter values in the vertical filter step, since they are
stored in memory as 16-bit values.
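
For reference, the existing scheme in the vertical filter looks roughly
like the sketch below (the helper name is illustrative rather than one
used in this file): the 16-bit filter taps have to be widened with
vmovl_s16 before the non-widening 32-bit multiply-accumulate.

  #include <arm_neon.h>

  // Current scheme: 32-bit data carried over from the horizontal pass;
  // the 16-bit filter taps must be widened before the 32-bit MLA.
  static inline int32x4_t vertical_tap_old(int32x4_t acc,
                                           int32x4_t data_s32,
                                           int16x4_t filter_s16) {
    return vmlaq_s32(acc, data_s32, vmovl_s16(filter_s16));
  }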

Once the horizontal filter is complete we can instead narrow the
results back to 16 bits and then perform a second widening multiply in
the vertical step, saving the need to separately widen the filter
values.
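
A minimal sketch of the replacement pattern, again with illustrative
helper names: the rounded 32-bit horizontal results are narrowed with
vmovn_s32, and the vertical pass uses widening multiplies so the 16-bit
filter taps can be consumed directly.

  #include <arm_neon.h>

  // Narrow a pair of rounded 32-bit horizontal results back to 16 bits.
  static inline int16x8_t narrow_horizontal_result(int32x4_t lo,
                                                   int32x4_t hi) {
    return vcombine_s16(vmovn_s32(lo), vmovn_s32(hi));
  }

  // New scheme: the data and filter taps stay 16-bit and the
  // multiply-accumulate itself widens to 32 bits.
  static inline int32x4_t vertical_tap_new(int32x4_t acc,
                                           int16x4_t data_s16,
                                           int16x4_t filter_s16) {
    return vmlal_s16(acc, data_s16, filter_s16);
  }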

This approach should also allow us to make use of the SVE 16-bit dot
product instructions in a future commit.

Benchmarking this on a Neoverse N2 machine with Clang 16 and GCC 12
shows a geomean 5.2% reduction in the runtime of the warp affine speed
tests.

Change-Id: I9891d919b6a53edc7e5ca8b6dca1ee23f46ef45a
diff --git a/av1/common/arm/highbd_warp_plane_neon.c b/av1/common/arm/highbd_warp_plane_neon.c
index b83cd4d..9dc76ec 100644
--- a/av1/common/arm/highbd_warp_plane_neon.c
+++ b/av1/common/arm/highbd_warp_plane_neon.c
@@ -65,11 +65,8 @@
   out[7] = vld1q_s16(base + ofs7 * 8);
 }
 
-static INLINE int32x4x2_t warp_affine_horizontal_step_4x1_f4_neon(
+static INLINE int16x8_t warp_affine_horizontal_step_4x1_f4_neon(
     int bd, int sx, int alpha, uint16x8x2_t in) {
-  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
-  const int offset_bits_horiz = bd + FILTER_BITS - 1;
-
   int16x8_t f[4];
   load_filters_4(f, sx, alpha);
 
@@ -93,15 +90,16 @@
 
   int32x4_t m0123[] = { m0, m1, m2, m3 };
 
-  int32x4x2_t res;
-  res.val[0] = horizontal_add_4d_s32x4(m0123);
-  res.val[0] = vaddq_s32(res.val[0], vdupq_n_s32(1 << offset_bits_horiz));
-  res.val[0] = vrshlq_s32(res.val[0], vdupq_n_s32(-round0));
-  res.val[1] = vdupq_n_s32(0);
-  return res;
+  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
+  const int offset_bits_horiz = bd + FILTER_BITS - 1;
+
+  int32x4_t res = horizontal_add_4d_s32x4(m0123);
+  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
+  res = vrshlq_s32(res, vdupq_n_s32(-round0));
+  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
 }
 
-static INLINE int32x4x2_t warp_affine_horizontal_step_8x1_f8_neon(
+static INLINE int16x8_t warp_affine_horizontal_step_8x1_f8_neon(
     int bd, int sx, int alpha, uint16x8x2_t in) {
   const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
   const int offset_bits_horiz = bd + FILTER_BITS - 1;
@@ -146,22 +144,20 @@
   int32x4_t m0123[] = { m0, m1, m2, m3 };
   int32x4_t m4567[] = { m4, m5, m6, m7 };
 
-  int32x4x2_t res;
-  res.val[0] = horizontal_add_4d_s32x4(m0123);
-  res.val[1] = horizontal_add_4d_s32x4(m4567);
-
-  res.val[0] = vaddq_s32(res.val[0], vdupq_n_s32(1 << offset_bits_horiz));
-  res.val[1] = vaddq_s32(res.val[1], vdupq_n_s32(1 << offset_bits_horiz));
-  res.val[0] = vrshlq_s32(res.val[0], vdupq_n_s32(-round0));
-  res.val[1] = vrshlq_s32(res.val[1], vdupq_n_s32(-round0));
-  return res;
+  int32x4_t res0 = horizontal_add_4d_s32x4(m0123);
+  int32x4_t res1 = horizontal_add_4d_s32x4(m4567);
+  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
+  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
+  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
+  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
+  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
 }
 
 static INLINE void warp_affine_horizontal_neon(const uint16_t *ref, int width,
                                                int height, int stride,
                                                int p_width, int16_t alpha,
                                                int16_t beta, int iy4, int sx4,
-                                               int ix4, int32x4x2_t tmp[],
+                                               int ix4, int16x8_t tmp[],
                                                int bd) {
   const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
 
@@ -170,8 +166,7 @@
       int iy = clamp(iy4 + k - 7, 0, height - 1);
       int32_t dup_val = (1 << (bd + FILTER_BITS - round0 - 1)) +
                         ref[iy * stride] * (1 << (FILTER_BITS - round0));
-      tmp[k].val[0] = vdupq_n_s32(dup_val);
-      tmp[k].val[1] = vdupq_n_s32(dup_val);
+      tmp[k] = vdupq_n_s16(dup_val);
     }
     return;
   } else if (ix4 >= width + 6) {
@@ -180,8 +175,7 @@
       int32_t dup_val =
           (1 << (bd + FILTER_BITS - round0 - 1)) +
           ref[iy * stride + (width - 1)] * (1 << (FILTER_BITS - round0));
-      tmp[k].val[0] = vdupq_n_s32(dup_val);
-      tmp[k].val[1] = vdupq_n_s32(dup_val);
+      tmp[k] = vdupq_n_s16(dup_val);
     }
     return;
   }
@@ -230,106 +224,103 @@
 }
 
 static INLINE int32x4_t
-warp_affine_vertical_filter_4x1_f1_neon(const int32x4x2_t *tmp, int sy) {
+warp_affine_vertical_filter_4x1_f1_neon(const int16x8_t *tmp, int sy) {
   const int16x8_t f = load_filters_1(sy);
-  const int32x2_t f01 = vget_low_s32(vmovl_s16(vget_low_s16(f)));
-  const int32x2_t f23 = vget_high_s32(vmovl_s16(vget_low_s16(f)));
-  const int32x2_t f45 = vget_low_s32(vmovl_s16(vget_high_s16(f)));
-  const int32x2_t f67 = vget_high_s32(vmovl_s16(vget_high_s16(f)));
+  const int16x4_t f0123 = vget_low_s16(f);
+  const int16x4_t f4567 = vget_high_s16(f);
 
-  int32x4_t m0123 = vmulq_lane_s32(tmp[0].val[0], f01, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[1].val[0], f01, 1);
-  m0123 = vmlaq_lane_s32(m0123, tmp[2].val[0], f23, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[3].val[0], f23, 1);
-  m0123 = vmlaq_lane_s32(m0123, tmp[4].val[0], f45, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[5].val[0], f45, 1);
-  m0123 = vmlaq_lane_s32(m0123, tmp[6].val[0], f67, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[7].val[0], f67, 1);
+  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
   return m0123;
 }
 
 static INLINE int32x4x2_t
-warp_affine_vertical_filter_8x1_f1_neon(const int32x4x2_t *tmp, int sy) {
+warp_affine_vertical_filter_8x1_f1_neon(const int16x8_t *tmp, int sy) {
   const int16x8_t f = load_filters_1(sy);
-  const int32x2_t f01 = vget_low_s32(vmovl_s16(vget_low_s16(f)));
-  const int32x2_t f23 = vget_high_s32(vmovl_s16(vget_low_s16(f)));
-  const int32x2_t f45 = vget_low_s32(vmovl_s16(vget_high_s16(f)));
-  const int32x2_t f67 = vget_high_s32(vmovl_s16(vget_high_s16(f)));
+  const int16x4_t f0123 = vget_low_s16(f);
+  const int16x4_t f4567 = vget_high_s16(f);
 
-  int32x4_t m0123 = vmulq_lane_s32(tmp[0].val[0], f01, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[1].val[0], f01, 1);
-  m0123 = vmlaq_lane_s32(m0123, tmp[2].val[0], f23, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[3].val[0], f23, 1);
-  m0123 = vmlaq_lane_s32(m0123, tmp[4].val[0], f45, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[5].val[0], f45, 1);
-  m0123 = vmlaq_lane_s32(m0123, tmp[6].val[0], f67, 0);
-  m0123 = vmlaq_lane_s32(m0123, tmp[7].val[0], f67, 1);
+  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
+  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
 
-  int32x4_t m4567 = vmulq_lane_s32(tmp[0].val[1], f01, 0);
-  m4567 = vmlaq_lane_s32(m4567, tmp[1].val[1], f01, 1);
-  m4567 = vmlaq_lane_s32(m4567, tmp[2].val[1], f23, 0);
-  m4567 = vmlaq_lane_s32(m4567, tmp[3].val[1], f23, 1);
-  m4567 = vmlaq_lane_s32(m4567, tmp[4].val[1], f45, 0);
-  m4567 = vmlaq_lane_s32(m4567, tmp[5].val[1], f45, 1);
-  m4567 = vmlaq_lane_s32(m4567, tmp[6].val[1], f67, 0);
-  m4567 = vmlaq_lane_s32(m4567, tmp[7].val[1], f67, 1);
+  int32x4_t m4567 = vmull_lane_s16(vget_high_s16(tmp[0]), f0123, 0);
+  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[1]), f0123, 1);
+  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[2]), f0123, 2);
+  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[3]), f0123, 3);
+  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[4]), f4567, 0);
+  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[5]), f4567, 1);
+  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[6]), f4567, 2);
+  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[7]), f4567, 3);
   return (int32x4x2_t){ { m0123, m4567 } };
 }
 
 static INLINE int32x4_t warp_affine_vertical_filter_4x1_f4_neon(
-    const int32x4x2_t *tmp, int sy, int gamma) {
-  int32x4x2_t s0, s1, s2, s3;
-  transpose_s32_4x8(tmp[0].val[0], tmp[1].val[0], tmp[2].val[0], tmp[3].val[0],
-                    tmp[4].val[0], tmp[5].val[0], tmp[6].val[0], tmp[7].val[0],
-                    &s0, &s1, &s2, &s3);
+    const int16x8_t *tmp, int sy, int gamma) {
+  int16x8_t s0, s1, s2, s3;
+  transpose_s16_4x8(
+      vget_low_s16(tmp[0]), vget_low_s16(tmp[1]), vget_low_s16(tmp[2]),
+      vget_low_s16(tmp[3]), vget_low_s16(tmp[4]), vget_low_s16(tmp[5]),
+      vget_low_s16(tmp[6]), vget_low_s16(tmp[7]), &s0, &s1, &s2, &s3);
 
   int16x8_t f[4];
   load_filters_4(f, sy, gamma);
 
-  int32x4_t m0 = vmulq_s32(s0.val[0], vmovl_s16(vget_low_s16(f[0])));
-  m0 = vmlaq_s32(m0, s0.val[1], vmovl_s16(vget_high_s16(f[0])));
-  int32x4_t m1 = vmulq_s32(s1.val[0], vmovl_s16(vget_low_s16(f[1])));
-  m1 = vmlaq_s32(m1, s1.val[1], vmovl_s16(vget_high_s16(f[1])));
-  int32x4_t m2 = vmulq_s32(s2.val[0], vmovl_s16(vget_low_s16(f[2])));
-  m2 = vmlaq_s32(m2, s2.val[1], vmovl_s16(vget_high_s16(f[2])));
-  int32x4_t m3 = vmulq_s32(s3.val[0], vmovl_s16(vget_low_s16(f[3])));
-  m3 = vmlaq_s32(m3, s3.val[1], vmovl_s16(vget_high_s16(f[3])));
+  int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
+  m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
+  int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
+  m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
+  int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
+  m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
+  int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
+  m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));
 
   int32x4_t m0123[] = { m0, m1, m2, m3 };
   return horizontal_add_4d_s32x4(m0123);
 }
 
 static INLINE int32x4x2_t warp_affine_vertical_filter_8x1_f8_neon(
-    const int32x4x2_t *tmp, int sy, int gamma) {
-  int32x4x2_t s0 = tmp[0];
-  int32x4x2_t s1 = tmp[1];
-  int32x4x2_t s2 = tmp[2];
-  int32x4x2_t s3 = tmp[3];
-  int32x4x2_t s4 = tmp[4];
-  int32x4x2_t s5 = tmp[5];
-  int32x4x2_t s6 = tmp[6];
-  int32x4x2_t s7 = tmp[7];
-  transpose_s32_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
+    const int16x8_t *tmp, int sy, int gamma) {
+  int16x8_t s0 = tmp[0];
+  int16x8_t s1 = tmp[1];
+  int16x8_t s2 = tmp[2];
+  int16x8_t s3 = tmp[3];
+  int16x8_t s4 = tmp[4];
+  int16x8_t s5 = tmp[5];
+  int16x8_t s6 = tmp[6];
+  int16x8_t s7 = tmp[7];
+  transpose_s16_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
 
   int16x8_t f[8];
   load_filters_8(f, sy, gamma);
 
-  int32x4_t m0 = vmulq_s32(s0.val[0], vmovl_s16(vget_low_s16(f[0])));
-  m0 = vmlaq_s32(m0, s0.val[1], vmovl_s16(vget_high_s16(f[0])));
-  int32x4_t m1 = vmulq_s32(s1.val[0], vmovl_s16(vget_low_s16(f[1])));
-  m1 = vmlaq_s32(m1, s1.val[1], vmovl_s16(vget_high_s16(f[1])));
-  int32x4_t m2 = vmulq_s32(s2.val[0], vmovl_s16(vget_low_s16(f[2])));
-  m2 = vmlaq_s32(m2, s2.val[1], vmovl_s16(vget_high_s16(f[2])));
-  int32x4_t m3 = vmulq_s32(s3.val[0], vmovl_s16(vget_low_s16(f[3])));
-  m3 = vmlaq_s32(m3, s3.val[1], vmovl_s16(vget_high_s16(f[3])));
-  int32x4_t m4 = vmulq_s32(s4.val[0], vmovl_s16(vget_low_s16(f[4])));
-  m4 = vmlaq_s32(m4, s4.val[1], vmovl_s16(vget_high_s16(f[4])));
-  int32x4_t m5 = vmulq_s32(s5.val[0], vmovl_s16(vget_low_s16(f[5])));
-  m5 = vmlaq_s32(m5, s5.val[1], vmovl_s16(vget_high_s16(f[5])));
-  int32x4_t m6 = vmulq_s32(s6.val[0], vmovl_s16(vget_low_s16(f[6])));
-  m6 = vmlaq_s32(m6, s6.val[1], vmovl_s16(vget_high_s16(f[6])));
-  int32x4_t m7 = vmulq_s32(s7.val[0], vmovl_s16(vget_low_s16(f[7])));
-  m7 = vmlaq_s32(m7, s7.val[1], vmovl_s16(vget_high_s16(f[7])));
+  int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
+  m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
+  int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
+  m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
+  int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
+  m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
+  int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
+  m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));
+  int32x4_t m4 = vmull_s16(vget_low_s16(s4), vget_low_s16(f[4]));
+  m4 = vmlal_s16(m4, vget_high_s16(s4), vget_high_s16(f[4]));
+  int32x4_t m5 = vmull_s16(vget_low_s16(s5), vget_low_s16(f[5]));
+  m5 = vmlal_s16(m5, vget_high_s16(s5), vget_high_s16(f[5]));
+  int32x4_t m6 = vmull_s16(vget_low_s16(s6), vget_low_s16(f[6]));
+  m6 = vmlal_s16(m6, vget_high_s16(s6), vget_high_s16(f[6]));
+  int32x4_t m7 = vmull_s16(vget_low_s16(s7), vget_low_s16(f[7]));
+  m7 = vmlal_s16(m7, vget_high_s16(s7), vget_high_s16(f[7]));
 
   int32x4_t m0123[] = { m0, m1, m2, m3 };
   int32x4_t m4567[] = { m4, m5, m6, m7 };
@@ -343,7 +334,7 @@
 static INLINE void warp_affine_vertical_step_4x1_f4_neon(
     uint16_t *pred, int p_stride, int bd, uint16_t *dst, int dst_stride,
     int is_compound, int do_average, int use_dist_wtd_comp_avg, int fwd,
-    int bwd, int16_t gamma, const int32x4x2_t *tmp, int i, int sy, int j) {
+    int bwd, int16_t gamma, const int16x8_t *tmp, int i, int sy, int j) {
   int32x4_t sum0 =
       gamma == 0 ? warp_affine_vertical_filter_4x1_f1_neon(tmp, sy)
                  : warp_affine_vertical_filter_4x1_f4_neon(tmp, sy, gamma);
@@ -399,7 +390,7 @@
 static INLINE void warp_affine_vertical_step_8x1_f8_neon(
     uint16_t *pred, int p_stride, int bd, uint16_t *dst, int dst_stride,
     int is_compound, int do_average, int use_dist_wtd_comp_avg, int fwd,
-    int bwd, int16_t gamma, const int32x4x2_t *tmp, int i, int sy, int j) {
+    int bwd, int16_t gamma, const int16x8_t *tmp, int i, int sy, int j) {
   int32x4x2_t sums =
       gamma == 0 ? warp_affine_vertical_filter_8x1_f1_neon(tmp, sy)
                  : warp_affine_vertical_filter_8x1_f8_neon(tmp, sy, gamma);
@@ -475,8 +466,9 @@
     uint16_t *pred, int p_width, int p_height, int p_stride, int bd,
     uint16_t *dst, int dst_stride, int is_compound, int do_average,
     int use_dist_wtd_comp_avg, int fwd, int bwd, int16_t gamma, int16_t delta,
-    const int32x4x2_t *tmp, int i, int sy4, int j) {
+    const int16x8_t *tmp, int i, int sy4, int j) {
   int limit_height = p_height > 4 ? 8 : 4;
+
   if (p_width > 4) {
     // p_width == 8
     for (int k = 0; k < limit_height; ++k) {
@@ -539,7 +531,21 @@
       sx4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
       sy4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
 
-      int32x4x2_t tmp[15];
+      // Each horizontal filter result is formed by the sum of up to eight
+      // multiplications by filter values and then a shift. Although both the
+      // inputs and filters are loaded as int16, the input data is at most bd
+      // bits and the filters are at most 8 bits each. Additionally since we
+      // know all possible filter values we know that the sum of absolute
+      // filter values will fit in at most 9 bits. With this in mind we can
+      // conclude that the sum of each filter application will fit in bd + 9
+      // bits. The shift following the summation is ROUND0_BITS (which is 3),
+      // +2 for 12-bit, which gives us a final storage of:
+      // bd ==  8: ( 8 + 9) - 3 => 14 bits
+      // bd == 10: (10 + 9) - 3 => 16 bits
+      // bd == 12: (12 + 9) - 5 => 16 bits
+      // So it is safe to use int16x8_t as the intermediate storage type here.
+      int16x8_t tmp[15];
+
       warp_affine_horizontal_neon(ref, width, height, stride, p_width, alpha,
                                   beta, iy4, sx4, ix4, tmp, bd);
       warp_affine_vertical_neon(pred, p_width, p_height, p_stride, bd, dst,