arm,dsp: Add 10bpp Convolve{Compound}Vertical_NEON().

~1.41 to ~2.51 faster than vectorized "C" depending on block size.

PiperOrigin-RevId: 379306710
PiperOrigin-RevId: 378673603
Change-Id: Ic39595c9678f3c7fecaf902be21683a1e977d7dc
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h
index 9bf1a49..157aa18 100644
--- a/src/dsp/arm/common_neon.h
+++ b/src/dsp/arm/common_neon.h
@@ -210,6 +210,14 @@
       vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
 }
 
+template <int lane>
+inline uint16x4_t Load2(const void* const buf, uint16x4_t val) {
+  uint32_t temp;
+  memcpy(&temp, buf, 4);
+  return vreinterpret_u16_u32(
+      vld1_lane_u32(&temp, vreinterpret_u32_u16(val), lane));
+}
+
 // Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
 // register before loading the values. Use caution when using this in loops
 // because it will re-zero the register before loading on every iteration.
diff --git a/src/dsp/arm/convolve_10bit_neon.cc b/src/dsp/arm/convolve_10bit_neon.cc
index 0b87e99..3e3e740 100644
--- a/src/dsp/arm/convolve_10bit_neon.cc
+++ b/src/dsp/arm/convolve_10bit_neon.cc
@@ -22,6 +22,7 @@
 #include <cassert>
 #include <cstdint>
 
+#include "src/dsp/arm/common_neon.h"
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
@@ -32,7 +33,8 @@
 namespace dsp {
 namespace {
 
-constexpr int kHorizontalOffset = 3;
+// Include the constants and utility functions inside the anonymous namespace.
+#include "src/dsp/convolve.inc"
 
 template <int filter_index>
 int32x4x2_t SumOnePassTaps(const uint16x8_t* const src,
@@ -405,6 +407,410 @@
                                          filter_index);
 }
 
+template <int filter_index, bool is_compound = false>
+void FilterVertical(const uint16_t* LIBGAV1_RESTRICT const src,
+                    const ptrdiff_t src_stride,
+                    void* LIBGAV1_RESTRICT const dst,
+                    const ptrdiff_t dst_stride, const int width,
+                    const int height, const int16x4_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  const int next_row = num_taps - 1;
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  auto* const dst16 = static_cast<uint16_t*>(dst);
+  assert(width >= 8);
+
+  int x = 0;
+  do {
+    const uint16_t* src_x = src + x;
+    uint16x8_t srcs[8];
+    srcs[0] = vld1q_u16(src_x);
+    src_x += src_stride;
+    if (num_taps >= 4) {
+      srcs[1] = vld1q_u16(src_x);
+      src_x += src_stride;
+      srcs[2] = vld1q_u16(src_x);
+      src_x += src_stride;
+      if (num_taps >= 6) {
+        srcs[3] = vld1q_u16(src_x);
+        src_x += src_stride;
+        srcs[4] = vld1q_u16(src_x);
+        src_x += src_stride;
+        if (num_taps == 8) {
+          srcs[5] = vld1q_u16(src_x);
+          src_x += src_stride;
+          srcs[6] = vld1q_u16(src_x);
+          src_x += src_stride;
+        }
+      }
+    }
+
+    // Decreasing the y loop counter produces worse code with clang.
+    // Don't unroll this loop since it generates too much code and the decoder
+    // is even slower.
+    int y = 0;
+    do {
+      srcs[next_row] = vld1q_u16(src_x);
+      src_x += src_stride;
+
+      const int32x4x2_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+      if (is_compound) {
+        const int16x4_t v_compound_offset = vdup_n_s16(kCompoundOffset);
+        const int16x4_t d0 =
+            vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
+        const int16x4_t d1 =
+            vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
+        vst1_u16(dst16 + x + y * dst_stride,
+                 vreinterpret_u16_s16(vadd_s16(d0, v_compound_offset)));
+        vst1_u16(dst16 + x + 4 + y * dst_stride,
+                 vreinterpret_u16_s16(vadd_s16(d1, v_compound_offset)));
+      } else {
+        const uint16x4_t d0 = vmin_u16(
+            vqrshrun_n_s32(v_sum.val[0], kFilterBits - 1), v_max_bitdepth);
+        const uint16x4_t d1 = vmin_u16(
+            vqrshrun_n_s32(v_sum.val[1], kFilterBits - 1), v_max_bitdepth);
+        vst1_u16(dst16 + x + y * dst_stride, d0);
+        vst1_u16(dst16 + x + 4 + y * dst_stride, d1);
+      }
+
+      srcs[0] = srcs[1];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
+          if (num_taps == 8) {
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
+          }
+        }
+      }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+template <int filter_index, bool is_compound = false>
+void FilterVertical4xH(const uint16_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int height,
+                       const int16x4_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  const int next_row = num_taps - 1;
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  uint16x4_t srcs[9];
+  srcs[0] = vld1_u16(src);
+  src += src_stride;
+  if (num_taps >= 4) {
+    srcs[1] = vld1_u16(src);
+    src += src_stride;
+    srcs[2] = vld1_u16(src);
+    src += src_stride;
+    if (num_taps >= 6) {
+      srcs[3] = vld1_u16(src);
+      src += src_stride;
+      srcs[4] = vld1_u16(src);
+      src += src_stride;
+      if (num_taps == 8) {
+        srcs[5] = vld1_u16(src);
+        src += src_stride;
+        srcs[6] = vld1_u16(src);
+        src += src_stride;
+      }
+    }
+  }
+
+  int y = height;
+  do {
+    srcs[next_row] = vld1_u16(src);
+    src += src_stride;
+    srcs[num_taps] = vld1_u16(src);
+    src += src_stride;
+
+    const int32x4_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+    const int32x4_t v_sum_1 = SumOnePassTaps<filter_index>(srcs + 1, taps);
+    if (is_compound) {
+      const int16x4_t d0 = vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1);
+      const int16x4_t d1 =
+          vqrshrn_n_s32(v_sum_1, kInterRoundBitsHorizontal - 1);
+      vst1_u16(dst16,
+               vreinterpret_u16_s16(vadd_s16(d0, vdup_n_s16(kCompoundOffset))));
+      dst16 += dst_stride;
+      vst1_u16(dst16,
+               vreinterpret_u16_s16(vadd_s16(d1, vdup_n_s16(kCompoundOffset))));
+      dst16 += dst_stride;
+    } else {
+      const uint16x4_t d0 =
+          vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+      const uint16x4_t d1 =
+          vmin_u16(vqrshrun_n_s32(v_sum_1, kFilterBits - 1), v_max_bitdepth);
+      vst1_u16(dst16, d0);
+      dst16 += dst_stride;
+      vst1_u16(dst16, d1);
+      dst16 += dst_stride;
+    }
+
+    srcs[0] = srcs[2];
+    if (num_taps >= 4) {
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      if (num_taps >= 6) {
+        srcs[3] = srcs[5];
+        srcs[4] = srcs[6];
+        if (num_taps == 8) {
+          srcs[5] = srcs[7];
+          srcs[6] = srcs[8];
+        }
+      }
+    }
+    y -= 2;
+  } while (y != 0);
+}
+
+template <int filter_index>
+void FilterVertical2xH(const uint16_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int height,
+                       const int16x4_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  const int next_row = num_taps - 1;
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+  const uint16x4_t v_zero = vdup_n_u16(0);
+
+  uint16x4_t srcs[9];
+  srcs[0] = Load2<0>(src, v_zero);
+  src += src_stride;
+  if (num_taps >= 4) {
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[2] = Load2<0>(src, v_zero);
+    src += src_stride;
+    srcs[1] = vext_u16(srcs[0], srcs[2], 2);
+    if (num_taps >= 6) {
+      srcs[2] = Load2<1>(src, srcs[2]);
+      src += src_stride;
+      srcs[4] = Load2<0>(src, v_zero);
+      src += src_stride;
+      srcs[3] = vext_u16(srcs[2], srcs[4], 2);
+      if (num_taps == 8) {
+        srcs[4] = Load2<1>(src, srcs[4]);
+        src += src_stride;
+        srcs[6] = Load2<0>(src, v_zero);
+        src += src_stride;
+        srcs[5] = vext_u16(srcs[4], srcs[6], 2);
+      }
+    }
+  }
+
+  int y = height;
+  do {
+    srcs[next_row - 1] = Load2<1>(src, srcs[next_row - 1]);
+    src += src_stride;
+    srcs[num_taps] = Load2<0>(src, v_zero);
+    src += src_stride;
+    srcs[next_row] = vext_u16(srcs[next_row - 1], srcs[num_taps], 2);
+
+    const int32x4_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+    const uint16x4_t d0 =
+        vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+    Store2<0>(dst16, d0);
+    dst16 += dst_stride;
+    Store2<1>(dst16, d0);
+    dst16 += dst_stride;
+
+    srcs[0] = srcs[2];
+    if (num_taps >= 4) {
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      if (num_taps >= 6) {
+        srcs[3] = srcs[5];
+        srcs[4] = srcs[6];
+        if (num_taps == 8) {
+          srcs[5] = srcs[7];
+          srcs[6] = srcs[8];
+        }
+      }
+    }
+    y -= 2;
+  } while (y != 0);
+}
+
+void ConvolveVertical_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  const auto* src = static_cast<const uint16_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
+  auto* const dest = static_cast<uint16_t*>(prediction);
+  const ptrdiff_t dest_stride = pred_stride >> 1;
+  assert(vertical_filter_id != 0);
+
+  int16x4_t taps[8];
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    taps[k] =
+        vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+  }
+
+  if (filter_index == 0) {  // 6 tap.
+    if (width == 2) {
+      FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else if (width == 4) {
+      FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else {
+      FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 1);
+    }
+  } else if ((filter_index == 1) &
+             ((vertical_filter_id == 1) | (vertical_filter_id == 7) |
+              (vertical_filter_id == 8) | (vertical_filter_id == 9) |
+              (vertical_filter_id == 15))) {  // 6 tap.
+    if (width == 2) {
+      FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else if (width == 4) {
+      FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else {
+      FilterVertical<1>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 1);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    if (width == 2) {
+      FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    if (width == 2) {
+      FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height,
+                           taps + 3);
+    } else if (width == 4) {
+      FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height,
+                           taps + 3);
+    } else {
+      FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 3);
+    }
+  } else {
+    // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
+    // below map to 4 tap filters.
+    assert(filter_index == 5 || filter_index == 4 ||
+           (filter_index == 1 &&
+            (vertical_filter_id == 0 || vertical_filter_id == 2 ||
+             vertical_filter_id == 3 || vertical_filter_id == 4 ||
+             vertical_filter_id == 5 || vertical_filter_id == 6 ||
+             vertical_filter_id == 10 || vertical_filter_id == 11 ||
+             vertical_filter_id == 12 || vertical_filter_id == 13 ||
+             vertical_filter_id == 14)));
+    // According to GetNumTapsInFilter() this has 6 taps but here we are
+    // treating it as though it has 4.
+    if (filter_index == 1) src += src_stride;
+    if (width == 2) {
+      FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else if (width == 4) {
+      FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else {
+      FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 2);
+    }
+  }
+}
+
+void ConvolveCompoundVertical_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  const auto* src = static_cast<const uint16_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
+  auto* const dest = static_cast<uint16_t*>(prediction);
+  assert(vertical_filter_id != 0);
+
+  int16x4_t taps[8];
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    taps[k] =
+        vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+  }
+
+  if (filter_index == 0) {  // 6 tap.
+    if (width == 4) {
+      FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 1);
+    } else {
+      FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 1);
+    }
+  } else if ((filter_index == 1) &
+             ((vertical_filter_id == 1) | (vertical_filter_id == 7) |
+              (vertical_filter_id == 8) | (vertical_filter_id == 9) |
+              (vertical_filter_id == 15))) {  // 6 tap.
+    if (width == 4) {
+      FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 1);
+    } else {
+      FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 1);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    if (width == 4) {
+      FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    if (width == 4) {
+      FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 3);
+    } else {
+      FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 3);
+    }
+  } else {
+    // 4 tap. When |filter_index| == 1 the |filter_id| values listed below map
+    // to 4 tap filters.
+    assert(filter_index == 5 || filter_index == 4 ||
+           (filter_index == 1 &&
+            (vertical_filter_id == 2 || vertical_filter_id == 3 ||
+             vertical_filter_id == 4 || vertical_filter_id == 5 ||
+             vertical_filter_id == 6 || vertical_filter_id == 10 ||
+             vertical_filter_id == 11 || vertical_filter_id == 12 ||
+             vertical_filter_id == 13 || vertical_filter_id == 14)));
+    // According to GetNumTapsInFilter() this has 6 taps but here we are
+    // treating it as though it has 4.
+    if (filter_index == 1) src += src_stride;
+    if (width == 4) {
+      FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 2);
+    } else {
+      FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 2);
+    }
+  }
+}
+
 void ConvolveCompoundCopy_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
@@ -476,9 +882,11 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
   assert(dsp != nullptr);
   dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
+  dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
 
   dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
   dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
+  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
 }
 
 }  // namespace
diff --git a/src/dsp/arm/convolve_neon.h b/src/dsp/arm/convolve_neon.h
index 8e2b88d..1cac618 100644
--- a/src/dsp/arm/convolve_neon.h
+++ b/src/dsp/arm/convolve_neon.h
@@ -48,9 +48,11 @@
 #define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON
 
 #define LIBGAV1_Dsp10bpp_ConvolveHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveVertical LIBGAV1_CPU_NEON
 
 #define LIBGAV1_Dsp10bpp_ConvolveCompoundCopy LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveCompoundVertical LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_