add aom_highbd_lpf_horizontal_6_neon

ported from libgav1 @ v0.17.0-57-g43d9fe1

Bug: b/217462944
Change-Id: I9381b46b68d4a112d4e07f48fc44e68ef116fb2a
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index d39c23c..9c2a734 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -475,7 +475,7 @@
   specialize qw/aom_highbd_lpf_horizontal_14_dual sse2 avx2/;
 
   add_proto qw/void aom_highbd_lpf_horizontal_6/, "uint16_t *s, int pitch, const uint8_t *blimit, const uint8_t *limit, const uint8_t *thresh, int bd";
-  specialize qw/aom_highbd_lpf_horizontal_6 sse2/;
+  specialize qw/aom_highbd_lpf_horizontal_6 neon sse2/;
 
   add_proto qw/void aom_highbd_lpf_horizontal_6_dual/, "uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, const uint8_t *thresh1, int bd";
   specialize qw/aom_highbd_lpf_horizontal_6_dual sse2/;
diff --git a/aom_dsp/arm/highbd_loopfilter_neon.c b/aom_dsp/arm/highbd_loopfilter_neon.c
index fc87c1e..991c3a7 100644
--- a/aom_dsp/arm/highbd_loopfilter_neon.c
+++ b/aom_dsp/arm/highbd_loopfilter_neon.c
@@ -65,6 +65,19 @@
   return vand_u16(inner_mask, outer_mask);
 }
 
+// abs(p2 - p1) <= inner_thresh && abs(p1 - p0) <= inner_thresh &&
+//   abs(q1 - q0) <= inner_thresh && abs(q2 - q1) <= inner_thresh &&
+//   OuterThreshold()
+static INLINE uint16x4_t NeedsFilter6(const uint16x8_t abd_p0p1_q0q1,
+                                      const uint16x8_t abd_p1p2_q1q2,
+                                      const uint16_t inner_thresh,
+                                      const uint16x4_t outer_mask) {
+  const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p1p2_q1q2);
+  const uint16x8_t b = vcleq_u16(a, vdupq_n_u16(inner_thresh));
+  const uint16x4_t inner_mask = vand_u16(vget_low_u16(b), vget_high_u16(b));
+  return vand_u16(inner_mask, outer_mask);
+}
+
 // -----------------------------------------------------------------------------
 // FilterNMasks functions.
 
@@ -85,6 +98,29 @@
   *hev_mask = vand_u16(hev_tmp_mask, *needs_filter4_mask);
 }
 
+// abs(p1 - p0) <= flat_thresh && abs(q1 - q0) <= flat_thresh &&
+//   abs(p2 - p0) <= flat_thresh && abs(q2 - q0) <= flat_thresh
+// |flat_thresh| == 4 for 10 bit decode.
+static INLINE uint16x4_t IsFlat3(const uint16x8_t abd_p0p1_q0q1,
+                                 const uint16x8_t abd_p0p2_q0q2) {
+  const int flat_thresh = 1 << 2;
+  const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p0p2_q0q2);
+  const uint16x8_t b = vcleq_u16(a, vdupq_n_u16(flat_thresh));
+  return vand_u16(vget_low_u16(b), vget_high_u16(b));
+}
+
+static INLINE void Filter6Masks(
+    const uint16x8_t p2q2, const uint16x8_t p1q1, const uint16x8_t p0q0,
+    const uint16_t hev_thresh, const uint16x4_t outer_mask,
+    const uint16_t inner_thresh, uint16x4_t *const needs_filter6_mask,
+    uint16x4_t *const is_flat3_mask, uint16x4_t *const hev_mask) {
+  const uint16x8_t abd_p0p1_q0q1 = vabdq_u16(p0q0, p1q1);
+  *hev_mask = Hev(abd_p0p1_q0q1, hev_thresh);
+  *is_flat3_mask = IsFlat3(abd_p0p1_q0q1, vabdq_u16(p0q0, p2q2));
+  *needs_filter6_mask = NeedsFilter6(abd_p0p1_q0q1, vabdq_u16(p1q1, p2q2),
+                                     inner_thresh, outer_mask);
+}
+
 // -----------------------------------------------------------------------------
 // FilterN functions.
 
@@ -281,3 +317,139 @@
   vst1_u16(dst_q0, output[2]);
   vst1_u16(dst_q1, output[3]);
 }
+
+static INLINE void Filter6(const uint16x8_t p2q2, const uint16x8_t p1q1,
+                           const uint16x8_t p0q0, uint16x8_t *const p1q1_output,
+                           uint16x8_t *const p0q0_output) {
+  // Sum p1 and q1 output from opposite directions.
+  // The formula is regrouped to allow 3 doubling operations to be combined.
+  //
+  // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
+  //      ^^^^^^^^
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
+  //                                 ^^^^^^^^
+  // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0
+  //                    ^^^^^^^^^^^
+  uint16x8_t sum = vaddq_u16(p2q2, p1q1);
+
+  // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0
+  //                                ^^^^^^
+  sum = vaddq_u16(sum, p0q0);
+
+  // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0
+  //               ^^^^^
+  sum = vshlq_n_u16(sum, 1);
+
+  // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0
+  //        ^^^^^^                          ^^^^^^
+  // Should dual issue with the left shift.
+  const uint16x8_t q0p0 = Transpose64(p0q0);
+  const uint16x8_t outer_sum = vaddq_u16(p2q2, q0p0);
+  sum = vaddq_u16(sum, outer_sum);
+
+  *p1q1_output = vrshrq_n_u16(sum, 3);
+
+  // Convert to p0 and q0 output:
+  // p0 = p1 - (2 * p2) + q0 + q1
+  // q0 = q1 - (2 * q2) + p0 + p1
+  // p0q0 = p1q1 - (2 * p2q2) + q0p0 + q1p1
+  //                ^^^^^^^^
+  const uint16x8_t p2q2_double = vshlq_n_u16(p2q2, 1);
+  // p0q0 = p1q1 - (2 * p2q2) + q0p0 + q1p1
+  //        ^^^^^^^^
+  sum = vsubq_u16(sum, p2q2_double);
+  const uint16x8_t q1p1 = Transpose64(p1q1);
+  sum = vaddq_u16(sum, vaddq_u16(q0p0, q1p1));
+
+  *p0q0_output = vrshrq_n_u16(sum, 3);
+}
+
+void aom_highbd_lpf_horizontal_6_neon(uint16_t *s, int pitch,
+                                      const uint8_t *blimit,
+                                      const uint8_t *limit,
+                                      const uint8_t *thresh, int bd) {
+  // TODO(b/217462944): add support for 8/12-bit.
+  if (bd != 10) {
+    aom_highbd_lpf_horizontal_6_c(s, pitch, blimit, limit, thresh, bd);
+    return;
+  }
+  uint16_t *const dst_p2 = s - 3 * pitch;
+  uint16_t *const dst_p1 = s - 2 * pitch;
+  uint16_t *const dst_p0 = s - pitch;
+  uint16_t *const dst_q0 = s;
+  uint16_t *const dst_q1 = s + pitch;
+  uint16_t *const dst_q2 = s + 2 * pitch;
+
+  const uint16x4_t src[6] = { vld1_u16(dst_p2), vld1_u16(dst_p1),
+                              vld1_u16(dst_p0), vld1_u16(dst_q0),
+                              vld1_u16(dst_q1), vld1_u16(dst_q2) };
+
+  // Adjust thresholds to bitdepth.
+  const int outer_thresh = *blimit << 2;
+  const int inner_thresh = *limit << 2;
+  const int hev_thresh = *thresh << 2;
+  const uint16x4_t outer_mask =
+      OuterThreshold(src[1], src[2], src[3], src[4], outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter_mask;
+  uint16x4_t is_flat3_mask;
+  const uint16x8_t p0q0 = vcombine_u16(src[2], src[3]);
+  const uint16x8_t p1q1 = vcombine_u16(src[1], src[4]);
+  const uint16x8_t p2q2 = vcombine_u16(src[0], src[5]);
+  Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh,
+               &needs_filter_mask, &is_flat3_mask, &hev_mask);
+
+#if defined(__aarch64__)
+  if (vaddv_u16(needs_filter_mask) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#else   // !defined(__aarch64__)
+  // This might be faster than vaddv (latency 3) because mov to general register
+  // has latency 2.
+  const uint64x1_t needs_filter_mask64 =
+      vreinterpret_u64_u16(needs_filter_mask);
+  if (vget_lane_u64(needs_filter_mask64, 0) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#endif  // defined(__aarch64__)
+
+  // Copy the masks to the high bits for packed comparisons later.
+  const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask);
+  const uint16x8_t is_flat3_mask_8 = vcombine_u16(is_flat3_mask, is_flat3_mask);
+  const uint16x8_t needs_filter_mask_8 =
+      vcombine_u16(needs_filter_mask, needs_filter_mask);
+
+  uint16x8_t f4_p1q1;
+  uint16x8_t f4_p0q0;
+  // ZIP1 p0q0, p1q1 may perform better here.
+  const uint16x8_t p0q1 = vcombine_u16(src[2], src[4]);
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f4_p1q1, &f4_p0q0);
+  f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1);
+
+  uint16x8_t p0q0_output, p1q1_output;
+  // Because we did not return after testing |needs_filter_mask| we know it is
+  // nonzero. |is_flat3_mask| controls whether the needed filter is Filter4 or
+  // Filter6. Therefore if it is false when |needs_filter_mask| is true, Filter6
+  // output is not used.
+  uint16x8_t f6_p1q1, f6_p0q0;
+  const uint64x1_t need_filter6 = vreinterpret_u64_u16(is_flat3_mask);
+  if (vget_lane_u64(need_filter6, 0) == 0) {
+    // Filter6() does not apply, but Filter4() applies to one or more values.
+    p0q0_output = p0q0;
+    p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0);
+  } else {
+    Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0);
+    p1q1_output = vbslq_u16(is_flat3_mask_8, f6_p1q1, f4_p1q1);
+    p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1);
+    p0q0_output = vbslq_u16(is_flat3_mask_8, f6_p0q0, f4_p0q0);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0);
+  }
+
+  vst1_u16(dst_p1, vget_low_u16(p1q1_output));
+  vst1_u16(dst_p0, vget_low_u16(p0q0_output));
+  vst1_u16(dst_q0, vget_high_u16(p0q0_output));
+  vst1_u16(dst_q1, vget_high_u16(p1q1_output));
+}
diff --git a/aom_dsp/arm/transpose_neon.h b/aom_dsp/arm/transpose_neon.h
index b697b48..c838fc7 100644
--- a/aom_dsp/arm/transpose_neon.h
+++ b/aom_dsp/arm/transpose_neon.h
@@ -16,6 +16,11 @@
 // TODO(b/217462944): rename functions from libgav1 to libaom style after
 // highbd loop filter code is ported.
 
+// Swap high and low halves.
+static INLINE uint16x8_t Transpose64(const uint16x8_t a) {
+  return vextq_u16(a, a, 4);
+}
+
 static INLINE void transpose_u8_8x8(uint8x8_t *a0, uint8x8_t *a1, uint8x8_t *a2,
                                     uint8x8_t *a3, uint8x8_t *a4, uint8x8_t *a5,
                                     uint8x8_t *a6, uint8x8_t *a7) {
diff --git a/test/lpf_test.cc b/test/lpf_test.cc
index 20a77b1..3bfacc3 100644
--- a/test/lpf_test.cc
+++ b/test/lpf_test.cc
@@ -666,6 +666,12 @@
              10),
   make_tuple(&aom_highbd_lpf_horizontal_4_neon, &aom_highbd_lpf_horizontal_4_c,
              12),
+  make_tuple(&aom_highbd_lpf_horizontal_6_neon, &aom_highbd_lpf_horizontal_6_c,
+             8),
+  make_tuple(&aom_highbd_lpf_horizontal_6_neon, &aom_highbd_lpf_horizontal_6_c,
+             10),
+  make_tuple(&aom_highbd_lpf_horizontal_6_neon, &aom_highbd_lpf_horizontal_6_c,
+             12),
   make_tuple(&aom_highbd_lpf_vertical_4_neon, &aom_highbd_lpf_vertical_4_c, 8),
   make_tuple(&aom_highbd_lpf_vertical_4_neon, &aom_highbd_lpf_vertical_4_c, 10),
   make_tuple(&aom_highbd_lpf_vertical_4_neon, &aom_highbd_lpf_vertical_4_c, 12),