Add an SSE implementation of av1_fwht4x4

This is actually used in lossless and the SSE implementation is
more than 3 times faster.

Also add a test for lossless.

Bug: b/191463451

Change-Id: Ia503a3426b51b14323ad94badd9968d047a22be5
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index ea16b43..c4777ea 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -302,7 +302,7 @@
   # fdct functions
 
   add_proto qw/void av1_fwht4x4/, "const int16_t *input, tran_low_t *output, int stride";
-  specialize qw/av1_fwht4x4 neon/;
+  specialize qw/av1_fwht4x4 sse4_1 neon/;
 
   #fwd txfm
   add_proto qw/void av1_lowbd_fwd_txfm/, "const int16_t *src_diff, tran_low_t *coeff, int diff_stride, TxfmParam *txfm_param";
@@ -383,7 +383,7 @@
   }
 
   add_proto qw/void av1_highbd_fwht4x4/, "const int16_t *input, tran_low_t *output, int stride";
-  specialize qw/av1_highbd_fwht4x4 neon/;
+  specialize qw/av1_highbd_fwht4x4 sse4_1 neon/;
 
   # End av1_high encoder functions
 
diff --git a/av1/encoder/x86/highbd_fwd_txfm_sse4.c b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
index 9a0a36c..73f9b44 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_sse4.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
@@ -11,16 +11,70 @@
 #include <assert.h>
 #include <smmintrin.h> /* SSE4.1 */
 
-#include "config/aom_config.h"
-#include "config/av1_rtcd.h"
-
+#include "aom_dsp/txfm_common.h"
+#include "aom_dsp/x86/transpose_sse2.h"
+#include "aom_dsp/x86/txfm_common_sse2.h"
+#include "aom_ports/mem.h"
 #include "av1/common/av1_txfm.h"
 #include "av1/common/x86/highbd_txfm_utility_sse4.h"
 #include "av1/encoder/av1_fwd_txfm1d_cfg.h"
 #include "av1/encoder/x86/av1_txfm1d_sse4.h"
-#include "aom_dsp/txfm_common.h"
-#include "aom_dsp/x86/txfm_common_sse2.h"
-#include "aom_ports/mem.h"
+#include "config/aom_config.h"
+#include "config/av1_rtcd.h"
+
+void av1_fwht4x4_sse4_1(const int16_t *input, tran_low_t *output, int stride) {
+  __m128i in[4];
+  in[0] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride));
+  in[1] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride));
+  in[2] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride));
+  in[3] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride));
+
+  // Convert to int32_t.
+  __m128i op[4];
+  op[0] = _mm_cvtepi16_epi32(in[0]);
+  op[1] = _mm_cvtepi16_epi32(in[1]);
+  op[2] = _mm_cvtepi16_epi32(in[2]);
+  op[3] = _mm_cvtepi16_epi32(in[3]);
+
+  for (int i = 0; i < 2; ++i) {
+    __m128i a1 = op[0];
+    __m128i b1 = op[1];
+    __m128i c1 = op[2];
+    __m128i d1 = op[3];
+    __m128i e1;
+
+    a1 = _mm_add_epi32(a1, b1);  // a1 += b1
+    d1 = _mm_sub_epi32(d1, c1);  // d1 = d1 - c1
+    e1 = _mm_sub_epi32(a1, d1);  // e1 = (a1 - d1) >> 1
+    e1 = _mm_srai_epi32(e1, 1);
+    b1 = _mm_sub_epi32(e1, b1);  // b1 = e1 - b1
+    c1 = _mm_sub_epi32(e1, c1);  // c1 = e1 - c1
+    a1 = _mm_sub_epi32(a1, c1);  // a1 -= c1
+    d1 = _mm_add_epi32(d1, b1);  // d1 += b1
+
+    op[0] = a1;
+    op[1] = c1;
+    op[2] = d1;
+    op[3] = b1;
+
+    transpose_32bit_4x4(op, op);
+  }
+
+  op[0] = _mm_slli_epi32(op[0], UNIT_QUANT_SHIFT);
+  op[1] = _mm_slli_epi32(op[1], UNIT_QUANT_SHIFT);
+  op[2] = _mm_slli_epi32(op[2], UNIT_QUANT_SHIFT);
+  op[3] = _mm_slli_epi32(op[3], UNIT_QUANT_SHIFT);
+
+  _mm_storeu_si128((__m128i *)(output + 0), op[0]);
+  _mm_storeu_si128((__m128i *)(output + 4), op[1]);
+  _mm_storeu_si128((__m128i *)(output + 8), op[2]);
+  _mm_storeu_si128((__m128i *)(output + 12), op[3]);
+}
+
+void av1_highbd_fwht4x4_sse4_1(const int16_t *input, tran_low_t *output,
+                               int stride) {
+  av1_fwht4x4_sse4_1(input, output, stride);
+}
 
 static INLINE void load_buffer_4x4(const int16_t *input, __m128i *in,
                                    int stride, int flipud, int fliplr,
diff --git a/test/fwht4x4_test.cc b/test/fwht4x4_test.cc
index b600d26..23a9fe5 100644
--- a/test/fwht4x4_test.cc
+++ b/test/fwht4x4_test.cc
@@ -182,6 +182,19 @@
                       make_tuple(&av1_highbd_fwht4x4_c, &iwht4x4_12, DCT_DCT,
                                  AOM_BITS_12, 16,
                                  static_cast<FdctFunc>(NULL))));
+#if HAVE_SSE4_1
+
+INSTANTIATE_TEST_SUITE_P(
+    SSE4_1, Trans4x4WHT,
+    ::testing::Values(make_tuple(&av1_highbd_fwht4x4_sse4_1, &iwht4x4_10,
+                                 DCT_DCT, AOM_BITS_10, 16,
+                                 static_cast<FdctFunc>(NULL)),
+                      make_tuple(&av1_highbd_fwht4x4_sse4_1, &iwht4x4_12,
+                                 DCT_DCT, AOM_BITS_12, 16,
+                                 static_cast<FdctFunc>(NULL))));
+
+#endif  // HAVE_SSE4_1
+
 #if HAVE_NEON
 
 INSTANTIATE_TEST_SUITE_P(