arm,dsp: Add 10bpp Convolve{Compound}Vertical_NEON().
~1.41 to ~2.51 faster than vectorized "C" depending on block size.
PiperOrigin-RevId: 379306710
PiperOrigin-RevId: 378673603
Change-Id: Ic39595c9678f3c7fecaf902be21683a1e977d7dc
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h
index 9bf1a49..157aa18 100644
--- a/src/dsp/arm/common_neon.h
+++ b/src/dsp/arm/common_neon.h
@@ -210,6 +210,14 @@
vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
}
+template <int lane>
+inline uint16x4_t Load2(const void* const buf, uint16x4_t val) {
+ uint32_t temp;
+ memcpy(&temp, buf, 4);
+ return vreinterpret_u16_u32(
+ vld1_lane_u32(&temp, vreinterpret_u32_u16(val), lane));
+}
+
// Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
// register before loading the values. Use caution when using this in loops
// because it will re-zero the register before loading on every iteration.
diff --git a/src/dsp/arm/convolve_10bit_neon.cc b/src/dsp/arm/convolve_10bit_neon.cc
index 0b87e99..3e3e740 100644
--- a/src/dsp/arm/convolve_10bit_neon.cc
+++ b/src/dsp/arm/convolve_10bit_neon.cc
@@ -22,6 +22,7 @@
#include <cassert>
#include <cstdint>
+#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
@@ -32,7 +33,8 @@
namespace dsp {
namespace {
-constexpr int kHorizontalOffset = 3;
+// Include the constants and utility functions inside the anonymous namespace.
+#include "src/dsp/convolve.inc"
template <int filter_index>
int32x4x2_t SumOnePassTaps(const uint16x8_t* const src,
@@ -405,6 +407,410 @@
filter_index);
}
+template <int filter_index, bool is_compound = false>
+void FilterVertical(const uint16_t* LIBGAV1_RESTRICT const src,
+ const ptrdiff_t src_stride,
+ void* LIBGAV1_RESTRICT const dst,
+ const ptrdiff_t dst_stride, const int width,
+ const int height, const int16x4_t* const taps) {
+ const int num_taps = GetNumTapsInFilter(filter_index);
+ const int next_row = num_taps - 1;
+ const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+ auto* const dst16 = static_cast<uint16_t*>(dst);
+ assert(width >= 8);
+
+ int x = 0;
+ do {
+ const uint16_t* src_x = src + x;
+ uint16x8_t srcs[8];
+ srcs[0] = vld1q_u16(src_x);
+ src_x += src_stride;
+ if (num_taps >= 4) {
+ srcs[1] = vld1q_u16(src_x);
+ src_x += src_stride;
+ srcs[2] = vld1q_u16(src_x);
+ src_x += src_stride;
+ if (num_taps >= 6) {
+ srcs[3] = vld1q_u16(src_x);
+ src_x += src_stride;
+ srcs[4] = vld1q_u16(src_x);
+ src_x += src_stride;
+ if (num_taps == 8) {
+ srcs[5] = vld1q_u16(src_x);
+ src_x += src_stride;
+ srcs[6] = vld1q_u16(src_x);
+ src_x += src_stride;
+ }
+ }
+ }
+
+ // Decreasing the y loop counter produces worse code with clang.
+ // Don't unroll this loop since it generates too much code and the decoder
+ // is even slower.
+ int y = 0;
+ do {
+ srcs[next_row] = vld1q_u16(src_x);
+ src_x += src_stride;
+
+ const int32x4x2_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+ if (is_compound) {
+ const int16x4_t v_compound_offset = vdup_n_s16(kCompoundOffset);
+ const int16x4_t d0 =
+ vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
+ const int16x4_t d1 =
+ vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
+ vst1_u16(dst16 + x + y * dst_stride,
+ vreinterpret_u16_s16(vadd_s16(d0, v_compound_offset)));
+ vst1_u16(dst16 + x + 4 + y * dst_stride,
+ vreinterpret_u16_s16(vadd_s16(d1, v_compound_offset)));
+ } else {
+ const uint16x4_t d0 = vmin_u16(
+ vqrshrun_n_s32(v_sum.val[0], kFilterBits - 1), v_max_bitdepth);
+ const uint16x4_t d1 = vmin_u16(
+ vqrshrun_n_s32(v_sum.val[1], kFilterBits - 1), v_max_bitdepth);
+ vst1_u16(dst16 + x + y * dst_stride, d0);
+ vst1_u16(dst16 + x + 4 + y * dst_stride, d1);
+ }
+
+ srcs[0] = srcs[1];
+ if (num_taps >= 4) {
+ srcs[1] = srcs[2];
+ srcs[2] = srcs[3];
+ if (num_taps >= 6) {
+ srcs[3] = srcs[4];
+ srcs[4] = srcs[5];
+ if (num_taps == 8) {
+ srcs[5] = srcs[6];
+ srcs[6] = srcs[7];
+ }
+ }
+ }
+ } while (++y < height);
+ x += 8;
+ } while (x < width);
+}
+
+template <int filter_index, bool is_compound = false>
+void FilterVertical4xH(const uint16_t* LIBGAV1_RESTRICT src,
+ const ptrdiff_t src_stride,
+ void* LIBGAV1_RESTRICT const dst,
+ const ptrdiff_t dst_stride, const int height,
+ const int16x4_t* const taps) {
+ const int num_taps = GetNumTapsInFilter(filter_index);
+ const int next_row = num_taps - 1;
+ const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+ auto* dst16 = static_cast<uint16_t*>(dst);
+
+ uint16x4_t srcs[9];
+ srcs[0] = vld1_u16(src);
+ src += src_stride;
+ if (num_taps >= 4) {
+ srcs[1] = vld1_u16(src);
+ src += src_stride;
+ srcs[2] = vld1_u16(src);
+ src += src_stride;
+ if (num_taps >= 6) {
+ srcs[3] = vld1_u16(src);
+ src += src_stride;
+ srcs[4] = vld1_u16(src);
+ src += src_stride;
+ if (num_taps == 8) {
+ srcs[5] = vld1_u16(src);
+ src += src_stride;
+ srcs[6] = vld1_u16(src);
+ src += src_stride;
+ }
+ }
+ }
+
+ int y = height;
+ do {
+ srcs[next_row] = vld1_u16(src);
+ src += src_stride;
+ srcs[num_taps] = vld1_u16(src);
+ src += src_stride;
+
+ const int32x4_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+ const int32x4_t v_sum_1 = SumOnePassTaps<filter_index>(srcs + 1, taps);
+ if (is_compound) {
+ const int16x4_t d0 = vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1);
+ const int16x4_t d1 =
+ vqrshrn_n_s32(v_sum_1, kInterRoundBitsHorizontal - 1);
+ vst1_u16(dst16,
+ vreinterpret_u16_s16(vadd_s16(d0, vdup_n_s16(kCompoundOffset))));
+ dst16 += dst_stride;
+ vst1_u16(dst16,
+ vreinterpret_u16_s16(vadd_s16(d1, vdup_n_s16(kCompoundOffset))));
+ dst16 += dst_stride;
+ } else {
+ const uint16x4_t d0 =
+ vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+ const uint16x4_t d1 =
+ vmin_u16(vqrshrun_n_s32(v_sum_1, kFilterBits - 1), v_max_bitdepth);
+ vst1_u16(dst16, d0);
+ dst16 += dst_stride;
+ vst1_u16(dst16, d1);
+ dst16 += dst_stride;
+ }
+
+ srcs[0] = srcs[2];
+ if (num_taps >= 4) {
+ srcs[1] = srcs[3];
+ srcs[2] = srcs[4];
+ if (num_taps >= 6) {
+ srcs[3] = srcs[5];
+ srcs[4] = srcs[6];
+ if (num_taps == 8) {
+ srcs[5] = srcs[7];
+ srcs[6] = srcs[8];
+ }
+ }
+ }
+ y -= 2;
+ } while (y != 0);
+}
+
+template <int filter_index>
+void FilterVertical2xH(const uint16_t* LIBGAV1_RESTRICT src,
+ const ptrdiff_t src_stride,
+ void* LIBGAV1_RESTRICT const dst,
+ const ptrdiff_t dst_stride, const int height,
+ const int16x4_t* const taps) {
+ const int num_taps = GetNumTapsInFilter(filter_index);
+ const int next_row = num_taps - 1;
+ const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+ auto* dst16 = static_cast<uint16_t*>(dst);
+ const uint16x4_t v_zero = vdup_n_u16(0);
+
+ uint16x4_t srcs[9];
+ srcs[0] = Load2<0>(src, v_zero);
+ src += src_stride;
+ if (num_taps >= 4) {
+ srcs[0] = Load2<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[2] = Load2<0>(src, v_zero);
+ src += src_stride;
+ srcs[1] = vext_u16(srcs[0], srcs[2], 2);
+ if (num_taps >= 6) {
+ srcs[2] = Load2<1>(src, srcs[2]);
+ src += src_stride;
+ srcs[4] = Load2<0>(src, v_zero);
+ src += src_stride;
+ srcs[3] = vext_u16(srcs[2], srcs[4], 2);
+ if (num_taps == 8) {
+ srcs[4] = Load2<1>(src, srcs[4]);
+ src += src_stride;
+ srcs[6] = Load2<0>(src, v_zero);
+ src += src_stride;
+ srcs[5] = vext_u16(srcs[4], srcs[6], 2);
+ }
+ }
+ }
+
+ int y = height;
+ do {
+ srcs[next_row - 1] = Load2<1>(src, srcs[next_row - 1]);
+ src += src_stride;
+ srcs[num_taps] = Load2<0>(src, v_zero);
+ src += src_stride;
+ srcs[next_row] = vext_u16(srcs[next_row - 1], srcs[num_taps], 2);
+
+ const int32x4_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+ const uint16x4_t d0 =
+ vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+ Store2<0>(dst16, d0);
+ dst16 += dst_stride;
+ Store2<1>(dst16, d0);
+ dst16 += dst_stride;
+
+ srcs[0] = srcs[2];
+ if (num_taps >= 4) {
+ srcs[1] = srcs[3];
+ srcs[2] = srcs[4];
+ if (num_taps >= 6) {
+ srcs[3] = srcs[5];
+ srcs[4] = srcs[6];
+ if (num_taps == 8) {
+ srcs[5] = srcs[7];
+ srcs[6] = srcs[8];
+ }
+ }
+ }
+ y -= 2;
+ } while (y != 0);
+}
+
+void ConvolveVertical_NEON(
+ const void* LIBGAV1_RESTRICT const reference,
+ const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+ const int vertical_filter_index, const int /*horizontal_filter_id*/,
+ const int vertical_filter_id, const int width, const int height,
+ void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+ const int filter_index = GetFilterIndex(vertical_filter_index, height);
+ const int vertical_taps = GetNumTapsInFilter(filter_index);
+ const ptrdiff_t src_stride = reference_stride >> 1;
+ const auto* src = static_cast<const uint16_t*>(reference) -
+ (vertical_taps / 2 - 1) * src_stride;
+ auto* const dest = static_cast<uint16_t*>(prediction);
+ const ptrdiff_t dest_stride = pred_stride >> 1;
+ assert(vertical_filter_id != 0);
+
+ int16x4_t taps[8];
+ for (int k = 0; k < kSubPixelTaps; ++k) {
+ taps[k] =
+ vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+ }
+
+ if (filter_index == 0) { // 6 tap.
+ if (width == 2) {
+ FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else if (width == 4) {
+ FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else {
+ FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
+ taps + 1);
+ }
+ } else if ((filter_index == 1) &
+ ((vertical_filter_id == 1) | (vertical_filter_id == 7) |
+ (vertical_filter_id == 8) | (vertical_filter_id == 9) |
+ (vertical_filter_id == 15))) { // 6 tap.
+ if (width == 2) {
+ FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else if (width == 4) {
+ FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else {
+ FilterVertical<1>(src, src_stride, dest, dest_stride, width, height,
+ taps + 1);
+ }
+ } else if (filter_index == 2) { // 8 tap.
+ if (width == 2) {
+ FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
+ } else if (width == 4) {
+ FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
+ } else {
+ FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
+ taps);
+ }
+ } else if (filter_index == 3) { // 2 tap.
+ if (width == 2) {
+ FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height,
+ taps + 3);
+ } else if (width == 4) {
+ FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height,
+ taps + 3);
+ } else {
+ FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
+ taps + 3);
+ }
+ } else {
+ // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
+ // below map to 4 tap filters.
+ assert(filter_index == 5 || filter_index == 4 ||
+ (filter_index == 1 &&
+ (vertical_filter_id == 0 || vertical_filter_id == 2 ||
+ vertical_filter_id == 3 || vertical_filter_id == 4 ||
+ vertical_filter_id == 5 || vertical_filter_id == 6 ||
+ vertical_filter_id == 10 || vertical_filter_id == 11 ||
+ vertical_filter_id == 12 || vertical_filter_id == 13 ||
+ vertical_filter_id == 14)));
+ // According to GetNumTapsInFilter() this has 6 taps but here we are
+ // treating it as though it has 4.
+ if (filter_index == 1) src += src_stride;
+ if (width == 2) {
+ FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height,
+ taps + 2);
+ } else if (width == 4) {
+ FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height,
+ taps + 2);
+ } else {
+ FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
+ taps + 2);
+ }
+ }
+}
+
+void ConvolveCompoundVertical_NEON(
+ const void* LIBGAV1_RESTRICT const reference,
+ const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+ const int vertical_filter_index, const int /*horizontal_filter_id*/,
+ const int vertical_filter_id, const int width, const int height,
+ void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
+ const int filter_index = GetFilterIndex(vertical_filter_index, height);
+ const int vertical_taps = GetNumTapsInFilter(filter_index);
+ const ptrdiff_t src_stride = reference_stride >> 1;
+ const auto* src = static_cast<const uint16_t*>(reference) -
+ (vertical_taps / 2 - 1) * src_stride;
+ auto* const dest = static_cast<uint16_t*>(prediction);
+ assert(vertical_filter_id != 0);
+
+ int16x4_t taps[8];
+ for (int k = 0; k < kSubPixelTaps; ++k) {
+ taps[k] =
+ vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+ }
+
+ if (filter_index == 0) { // 6 tap.
+ if (width == 4) {
+ FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 1);
+ } else {
+ FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 1);
+ }
+ } else if ((filter_index == 1) &
+ ((vertical_filter_id == 1) | (vertical_filter_id == 7) |
+ (vertical_filter_id == 8) | (vertical_filter_id == 9) |
+ (vertical_filter_id == 15))) { // 6 tap.
+ if (width == 4) {
+ FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 1);
+ } else {
+ FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 1);
+ }
+ } else if (filter_index == 2) { // 8 tap.
+ if (width == 4) {
+ FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps);
+ } else {
+ FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps);
+ }
+ } else if (filter_index == 3) { // 2 tap.
+ if (width == 4) {
+ FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 3);
+ } else {
+ FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 3);
+ }
+ } else {
+ // 4 tap. When |filter_index| == 1 the |filter_id| values listed below map
+ // to 4 tap filters.
+ assert(filter_index == 5 || filter_index == 4 ||
+ (filter_index == 1 &&
+ (vertical_filter_id == 2 || vertical_filter_id == 3 ||
+ vertical_filter_id == 4 || vertical_filter_id == 5 ||
+ vertical_filter_id == 6 || vertical_filter_id == 10 ||
+ vertical_filter_id == 11 || vertical_filter_id == 12 ||
+ vertical_filter_id == 13 || vertical_filter_id == 14)));
+ // According to GetNumTapsInFilter() this has 6 taps but here we are
+ // treating it as though it has 4.
+ if (filter_index == 1) src += src_stride;
+ if (width == 4) {
+ FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 2);
+ } else {
+ FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 2);
+ }
+ }
+}
+
void ConvolveCompoundCopy_NEON(
const void* const reference, const ptrdiff_t reference_stride,
const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
@@ -476,9 +882,11 @@
Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
assert(dsp != nullptr);
dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
+ dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
+ dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
}
} // namespace
diff --git a/src/dsp/arm/convolve_neon.h b/src/dsp/arm/convolve_neon.h
index 8e2b88d..1cac618 100644
--- a/src/dsp/arm/convolve_neon.h
+++ b/src/dsp/arm/convolve_neon.h
@@ -48,9 +48,11 @@
#define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON
#define LIBGAV1_Dsp10bpp_ConvolveHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveVertical LIBGAV1_CPU_NEON
#define LIBGAV1_Dsp10bpp_ConvolveCompoundCopy LIBGAV1_CPU_NEON
#define LIBGAV1_Dsp10bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveCompoundVertical LIBGAV1_CPU_NEON
#endif // LIBGAV1_ENABLE_NEON
#endif // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_