Use 16-bit vectors between h/v filters in highbd_warp_affine_neon
In highbd warp affine we currently perform a widening multiply to
32 bits during the horizontal filter and then maintain this width
through to the vertical filter, where we do a non-widening 32-bit
multiply. This also means that we must first widen the filter values
in the vertical filter step, since these are stored in memory as
16-bit values.
It seems that once the horizontal filter is complete we can instead
narrow back to 16 bits and then perform a second widening multiply in
the vertical step, saving the need to separately widen the filter
values.
This approach should also allow us to make use of the SVE 16-bit dot
product instructions in a future commit.
Benchmarking this on a Neoverse N2 machine with Clang 16 and GCC 12
shows a geomean 5.2% reduction in the runtime of the warp affine speed
tests.
Change-Id: I9891d919b6a53edc7e5ca8b6dca1ee23f46ef45a
diff --git a/av1/common/arm/highbd_warp_plane_neon.c b/av1/common/arm/highbd_warp_plane_neon.c
index b83cd4d..9dc76ec 100644
--- a/av1/common/arm/highbd_warp_plane_neon.c
+++ b/av1/common/arm/highbd_warp_plane_neon.c
@@ -65,11 +65,8 @@
out[7] = vld1q_s16(base + ofs7 * 8);
}
-static INLINE int32x4x2_t warp_affine_horizontal_step_4x1_f4_neon(
+static INLINE int16x8_t warp_affine_horizontal_step_4x1_f4_neon(
int bd, int sx, int alpha, uint16x8x2_t in) {
- const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
- const int offset_bits_horiz = bd + FILTER_BITS - 1;
-
int16x8_t f[4];
load_filters_4(f, sx, alpha);
@@ -93,15 +90,16 @@
int32x4_t m0123[] = { m0, m1, m2, m3 };
- int32x4x2_t res;
- res.val[0] = horizontal_add_4d_s32x4(m0123);
- res.val[0] = vaddq_s32(res.val[0], vdupq_n_s32(1 << offset_bits_horiz));
- res.val[0] = vrshlq_s32(res.val[0], vdupq_n_s32(-round0));
- res.val[1] = vdupq_n_s32(0);
- return res;
+ const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
+ const int offset_bits_horiz = bd + FILTER_BITS - 1;
+
+ int32x4_t res = horizontal_add_4d_s32x4(m0123);
+ res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
+ res = vrshlq_s32(res, vdupq_n_s32(-round0));
+ return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}
-static INLINE int32x4x2_t warp_affine_horizontal_step_8x1_f8_neon(
+static INLINE int16x8_t warp_affine_horizontal_step_8x1_f8_neon(
int bd, int sx, int alpha, uint16x8x2_t in) {
const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
const int offset_bits_horiz = bd + FILTER_BITS - 1;
@@ -146,22 +144,20 @@
int32x4_t m0123[] = { m0, m1, m2, m3 };
int32x4_t m4567[] = { m4, m5, m6, m7 };
- int32x4x2_t res;
- res.val[0] = horizontal_add_4d_s32x4(m0123);
- res.val[1] = horizontal_add_4d_s32x4(m4567);
-
- res.val[0] = vaddq_s32(res.val[0], vdupq_n_s32(1 << offset_bits_horiz));
- res.val[1] = vaddq_s32(res.val[1], vdupq_n_s32(1 << offset_bits_horiz));
- res.val[0] = vrshlq_s32(res.val[0], vdupq_n_s32(-round0));
- res.val[1] = vrshlq_s32(res.val[1], vdupq_n_s32(-round0));
- return res;
+ int32x4_t res0 = horizontal_add_4d_s32x4(m0123);
+ int32x4_t res1 = horizontal_add_4d_s32x4(m4567);
+ res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
+ res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
+ res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
+ res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
+ return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}
static INLINE void warp_affine_horizontal_neon(const uint16_t *ref, int width,
int height, int stride,
int p_width, int16_t alpha,
int16_t beta, int iy4, int sx4,
- int ix4, int32x4x2_t tmp[],
+ int ix4, int16x8_t tmp[],
int bd) {
const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
@@ -170,8 +166,7 @@
int iy = clamp(iy4 + k - 7, 0, height - 1);
int32_t dup_val = (1 << (bd + FILTER_BITS - round0 - 1)) +
ref[iy * stride] * (1 << (FILTER_BITS - round0));
- tmp[k].val[0] = vdupq_n_s32(dup_val);
- tmp[k].val[1] = vdupq_n_s32(dup_val);
+ tmp[k] = vdupq_n_s16(dup_val);
}
return;
} else if (ix4 >= width + 6) {
@@ -180,8 +175,7 @@
int32_t dup_val =
(1 << (bd + FILTER_BITS - round0 - 1)) +
ref[iy * stride + (width - 1)] * (1 << (FILTER_BITS - round0));
- tmp[k].val[0] = vdupq_n_s32(dup_val);
- tmp[k].val[1] = vdupq_n_s32(dup_val);
+ tmp[k] = vdupq_n_s16(dup_val);
}
return;
}
@@ -230,106 +224,103 @@
}
static INLINE int32x4_t
-warp_affine_vertical_filter_4x1_f1_neon(const int32x4x2_t *tmp, int sy) {
+warp_affine_vertical_filter_4x1_f1_neon(const int16x8_t *tmp, int sy) {
const int16x8_t f = load_filters_1(sy);
- const int32x2_t f01 = vget_low_s32(vmovl_s16(vget_low_s16(f)));
- const int32x2_t f23 = vget_high_s32(vmovl_s16(vget_low_s16(f)));
- const int32x2_t f45 = vget_low_s32(vmovl_s16(vget_high_s16(f)));
- const int32x2_t f67 = vget_high_s32(vmovl_s16(vget_high_s16(f)));
+ const int16x4_t f0123 = vget_low_s16(f);
+ const int16x4_t f4567 = vget_high_s16(f);
- int32x4_t m0123 = vmulq_lane_s32(tmp[0].val[0], f01, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[1].val[0], f01, 1);
- m0123 = vmlaq_lane_s32(m0123, tmp[2].val[0], f23, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[3].val[0], f23, 1);
- m0123 = vmlaq_lane_s32(m0123, tmp[4].val[0], f45, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[5].val[0], f45, 1);
- m0123 = vmlaq_lane_s32(m0123, tmp[6].val[0], f67, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[7].val[0], f67, 1);
+ int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
return m0123;
}
static INLINE int32x4x2_t
-warp_affine_vertical_filter_8x1_f1_neon(const int32x4x2_t *tmp, int sy) {
+warp_affine_vertical_filter_8x1_f1_neon(const int16x8_t *tmp, int sy) {
const int16x8_t f = load_filters_1(sy);
- const int32x2_t f01 = vget_low_s32(vmovl_s16(vget_low_s16(f)));
- const int32x2_t f23 = vget_high_s32(vmovl_s16(vget_low_s16(f)));
- const int32x2_t f45 = vget_low_s32(vmovl_s16(vget_high_s16(f)));
- const int32x2_t f67 = vget_high_s32(vmovl_s16(vget_high_s16(f)));
+ const int16x4_t f0123 = vget_low_s16(f);
+ const int16x4_t f4567 = vget_high_s16(f);
- int32x4_t m0123 = vmulq_lane_s32(tmp[0].val[0], f01, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[1].val[0], f01, 1);
- m0123 = vmlaq_lane_s32(m0123, tmp[2].val[0], f23, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[3].val[0], f23, 1);
- m0123 = vmlaq_lane_s32(m0123, tmp[4].val[0], f45, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[5].val[0], f45, 1);
- m0123 = vmlaq_lane_s32(m0123, tmp[6].val[0], f67, 0);
- m0123 = vmlaq_lane_s32(m0123, tmp[7].val[0], f67, 1);
+ int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
+ m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
- int32x4_t m4567 = vmulq_lane_s32(tmp[0].val[1], f01, 0);
- m4567 = vmlaq_lane_s32(m4567, tmp[1].val[1], f01, 1);
- m4567 = vmlaq_lane_s32(m4567, tmp[2].val[1], f23, 0);
- m4567 = vmlaq_lane_s32(m4567, tmp[3].val[1], f23, 1);
- m4567 = vmlaq_lane_s32(m4567, tmp[4].val[1], f45, 0);
- m4567 = vmlaq_lane_s32(m4567, tmp[5].val[1], f45, 1);
- m4567 = vmlaq_lane_s32(m4567, tmp[6].val[1], f67, 0);
- m4567 = vmlaq_lane_s32(m4567, tmp[7].val[1], f67, 1);
+ int32x4_t m4567 = vmull_lane_s16(vget_high_s16(tmp[0]), f0123, 0);
+ m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[1]), f0123, 1);
+ m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[2]), f0123, 2);
+ m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[3]), f0123, 3);
+ m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[4]), f4567, 0);
+ m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[5]), f4567, 1);
+ m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[6]), f4567, 2);
+ m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[7]), f4567, 3);
return (int32x4x2_t){ { m0123, m4567 } };
}
static INLINE int32x4_t warp_affine_vertical_filter_4x1_f4_neon(
- const int32x4x2_t *tmp, int sy, int gamma) {
- int32x4x2_t s0, s1, s2, s3;
- transpose_s32_4x8(tmp[0].val[0], tmp[1].val[0], tmp[2].val[0], tmp[3].val[0],
- tmp[4].val[0], tmp[5].val[0], tmp[6].val[0], tmp[7].val[0],
- &s0, &s1, &s2, &s3);
+ const int16x8_t *tmp, int sy, int gamma) {
+ int16x8_t s0, s1, s2, s3;
+ transpose_s16_4x8(
+ vget_low_s16(tmp[0]), vget_low_s16(tmp[1]), vget_low_s16(tmp[2]),
+ vget_low_s16(tmp[3]), vget_low_s16(tmp[4]), vget_low_s16(tmp[5]),
+ vget_low_s16(tmp[6]), vget_low_s16(tmp[7]), &s0, &s1, &s2, &s3);
int16x8_t f[4];
load_filters_4(f, sy, gamma);
- int32x4_t m0 = vmulq_s32(s0.val[0], vmovl_s16(vget_low_s16(f[0])));
- m0 = vmlaq_s32(m0, s0.val[1], vmovl_s16(vget_high_s16(f[0])));
- int32x4_t m1 = vmulq_s32(s1.val[0], vmovl_s16(vget_low_s16(f[1])));
- m1 = vmlaq_s32(m1, s1.val[1], vmovl_s16(vget_high_s16(f[1])));
- int32x4_t m2 = vmulq_s32(s2.val[0], vmovl_s16(vget_low_s16(f[2])));
- m2 = vmlaq_s32(m2, s2.val[1], vmovl_s16(vget_high_s16(f[2])));
- int32x4_t m3 = vmulq_s32(s3.val[0], vmovl_s16(vget_low_s16(f[3])));
- m3 = vmlaq_s32(m3, s3.val[1], vmovl_s16(vget_high_s16(f[3])));
+ int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
+ m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
+ int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
+ m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
+ int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
+ m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
+ int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
+ m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));
int32x4_t m0123[] = { m0, m1, m2, m3 };
return horizontal_add_4d_s32x4(m0123);
}
static INLINE int32x4x2_t warp_affine_vertical_filter_8x1_f8_neon(
- const int32x4x2_t *tmp, int sy, int gamma) {
- int32x4x2_t s0 = tmp[0];
- int32x4x2_t s1 = tmp[1];
- int32x4x2_t s2 = tmp[2];
- int32x4x2_t s3 = tmp[3];
- int32x4x2_t s4 = tmp[4];
- int32x4x2_t s5 = tmp[5];
- int32x4x2_t s6 = tmp[6];
- int32x4x2_t s7 = tmp[7];
- transpose_s32_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
+ const int16x8_t *tmp, int sy, int gamma) {
+ int16x8_t s0 = tmp[0];
+ int16x8_t s1 = tmp[1];
+ int16x8_t s2 = tmp[2];
+ int16x8_t s3 = tmp[3];
+ int16x8_t s4 = tmp[4];
+ int16x8_t s5 = tmp[5];
+ int16x8_t s6 = tmp[6];
+ int16x8_t s7 = tmp[7];
+ transpose_s16_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
int16x8_t f[8];
load_filters_8(f, sy, gamma);
- int32x4_t m0 = vmulq_s32(s0.val[0], vmovl_s16(vget_low_s16(f[0])));
- m0 = vmlaq_s32(m0, s0.val[1], vmovl_s16(vget_high_s16(f[0])));
- int32x4_t m1 = vmulq_s32(s1.val[0], vmovl_s16(vget_low_s16(f[1])));
- m1 = vmlaq_s32(m1, s1.val[1], vmovl_s16(vget_high_s16(f[1])));
- int32x4_t m2 = vmulq_s32(s2.val[0], vmovl_s16(vget_low_s16(f[2])));
- m2 = vmlaq_s32(m2, s2.val[1], vmovl_s16(vget_high_s16(f[2])));
- int32x4_t m3 = vmulq_s32(s3.val[0], vmovl_s16(vget_low_s16(f[3])));
- m3 = vmlaq_s32(m3, s3.val[1], vmovl_s16(vget_high_s16(f[3])));
- int32x4_t m4 = vmulq_s32(s4.val[0], vmovl_s16(vget_low_s16(f[4])));
- m4 = vmlaq_s32(m4, s4.val[1], vmovl_s16(vget_high_s16(f[4])));
- int32x4_t m5 = vmulq_s32(s5.val[0], vmovl_s16(vget_low_s16(f[5])));
- m5 = vmlaq_s32(m5, s5.val[1], vmovl_s16(vget_high_s16(f[5])));
- int32x4_t m6 = vmulq_s32(s6.val[0], vmovl_s16(vget_low_s16(f[6])));
- m6 = vmlaq_s32(m6, s6.val[1], vmovl_s16(vget_high_s16(f[6])));
- int32x4_t m7 = vmulq_s32(s7.val[0], vmovl_s16(vget_low_s16(f[7])));
- m7 = vmlaq_s32(m7, s7.val[1], vmovl_s16(vget_high_s16(f[7])));
+ int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
+ m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
+ int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
+ m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
+ int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
+ m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
+ int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
+ m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));
+ int32x4_t m4 = vmull_s16(vget_low_s16(s4), vget_low_s16(f[4]));
+ m4 = vmlal_s16(m4, vget_high_s16(s4), vget_high_s16(f[4]));
+ int32x4_t m5 = vmull_s16(vget_low_s16(s5), vget_low_s16(f[5]));
+ m5 = vmlal_s16(m5, vget_high_s16(s5), vget_high_s16(f[5]));
+ int32x4_t m6 = vmull_s16(vget_low_s16(s6), vget_low_s16(f[6]));
+ m6 = vmlal_s16(m6, vget_high_s16(s6), vget_high_s16(f[6]));
+ int32x4_t m7 = vmull_s16(vget_low_s16(s7), vget_low_s16(f[7]));
+ m7 = vmlal_s16(m7, vget_high_s16(s7), vget_high_s16(f[7]));
int32x4_t m0123[] = { m0, m1, m2, m3 };
int32x4_t m4567[] = { m4, m5, m6, m7 };
@@ -343,7 +334,7 @@
static INLINE void warp_affine_vertical_step_4x1_f4_neon(
uint16_t *pred, int p_stride, int bd, uint16_t *dst, int dst_stride,
int is_compound, int do_average, int use_dist_wtd_comp_avg, int fwd,
- int bwd, int16_t gamma, const int32x4x2_t *tmp, int i, int sy, int j) {
+ int bwd, int16_t gamma, const int16x8_t *tmp, int i, int sy, int j) {
int32x4_t sum0 =
gamma == 0 ? warp_affine_vertical_filter_4x1_f1_neon(tmp, sy)
: warp_affine_vertical_filter_4x1_f4_neon(tmp, sy, gamma);
@@ -399,7 +390,7 @@
static INLINE void warp_affine_vertical_step_8x1_f8_neon(
uint16_t *pred, int p_stride, int bd, uint16_t *dst, int dst_stride,
int is_compound, int do_average, int use_dist_wtd_comp_avg, int fwd,
- int bwd, int16_t gamma, const int32x4x2_t *tmp, int i, int sy, int j) {
+ int bwd, int16_t gamma, const int16x8_t *tmp, int i, int sy, int j) {
int32x4x2_t sums =
gamma == 0 ? warp_affine_vertical_filter_8x1_f1_neon(tmp, sy)
: warp_affine_vertical_filter_8x1_f8_neon(tmp, sy, gamma);
@@ -475,8 +466,9 @@
uint16_t *pred, int p_width, int p_height, int p_stride, int bd,
uint16_t *dst, int dst_stride, int is_compound, int do_average,
int use_dist_wtd_comp_avg, int fwd, int bwd, int16_t gamma, int16_t delta,
- const int32x4x2_t *tmp, int i, int sy4, int j) {
+ const int16x8_t *tmp, int i, int sy4, int j) {
int limit_height = p_height > 4 ? 8 : 4;
+
if (p_width > 4) {
// p_width == 8
for (int k = 0; k < limit_height; ++k) {
@@ -539,7 +531,21 @@
sx4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
sy4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
- int32x4x2_t tmp[15];
+ // Each horizontal filter result is formed by the sum of up to eight
+ // multiplications by filter values and then a shift. Although both the
+ // inputs and filters are loaded as int16, the input data is at most bd
+ // bits and the filters are at most 8 bits each. Additionally since we
+ // know all possible filter values we know that the sum of absolute
+ // filter values will fit in at most 9 bits. With this in mind we can
+ // conclude that the sum of each filter application will fit in bd + 9
+ // bits. The shift following the summation is ROUND0_BITS (which is 3),
+ // +2 for 12-bit, which gives us a final storage of:
+ // bd == 8: ( 8 + 9) - 3 => 14 bits
+ // bd == 10: (10 + 9) - 3 => 16 bits
+ // bd == 12: (12 + 9) - 5 => 16 bits
+ // So it is safe to use int16x8_t as the intermediate storage type here.
+ int16x8_t tmp[15];
+
warp_affine_horizontal_neon(ref, width, height, stride, p_width, alpha,
beta, iy4, sx4, ix4, tmp, bd);
warp_affine_vertical_neon(pred, p_width, p_height, p_stride, bd, dst,