x86-64: Add msac_decode_bool_equi asm
diff --git a/src/msac.c b/src/msac.c
index 31e4004..34868ae 100644
--- a/src/msac.c
+++ b/src/msac.c
@@ -27,7 +27,6 @@
 
 #include "config.h"
 
-#include <assert.h>
 #include <limits.h>
 
 #include "common/intops.h"
@@ -68,7 +67,7 @@
         ctx_refill(s);
 }
 
-unsigned dav1d_msac_decode_bool_equi(MsacContext *const s) {
+unsigned dav1d_msac_decode_bool_equi_c(MsacContext *const s) {
     ec_win vw, dif = s->dif;
     unsigned ret, v, r = s->rng;
     assert((dif >> (EC_WIN_SIZE - 16)) < r);
@@ -99,13 +98,6 @@
     return !ret;
 }
 
-unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
-    unsigned v = 0;
-    while (n--)
-        v = (v << 1) | dav1d_msac_decode_bool_equi(s);
-    return v;
-}
-
 int dav1d_msac_decode_subexp(MsacContext *const s, const int ref,
                              const int n, const unsigned k)
 {
@@ -122,15 +114,6 @@
                           n - 1 - inv_recenter(n - 1 - ref, v);
 }
 
-int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) {
-    assert(n > 0);
-    const int l = ulog2(n) + 1;
-    assert(l > 1);
-    const unsigned m = (1 << l) - n;
-    const unsigned v = dav1d_msac_decode_bools(s, l - 1);
-    return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(s);
-}
-
 /* Decodes a symbol given an inverse cumulative distribution function (CDF)
  * table in Q15. */
 static unsigned decode_symbol(MsacContext *const s, const uint16_t *const cdf,
diff --git a/src/msac.h b/src/msac.h
index cd04c30..f1de11d 100644
--- a/src/msac.h
+++ b/src/msac.h
@@ -28,6 +28,7 @@
 #ifndef DAV1D_SRC_MSAC_H
 #define DAV1D_SRC_MSAC_H
 
+#include <assert.h>
 #include <stdint.h>
 #include <stdlib.h>
 
@@ -47,12 +48,10 @@
                      int disable_cdf_update_flag);
 unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
                                           size_t n_symbols);
-unsigned dav1d_msac_decode_bool_equi(MsacContext *s);
+unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
 unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
 unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
-unsigned dav1d_msac_decode_bools(MsacContext *s, unsigned n);
 int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
-int dav1d_msac_decode_uniform(MsacContext *s, unsigned n);
 
 /* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
 #if ARCH_AARCH64 && HAVE_ASM
@@ -65,6 +64,7 @@
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_neon
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_neon
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
+#define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_c
 #elif ARCH_X86_64 && HAVE_ASM
 unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
                                               size_t n_symbols);
@@ -72,13 +72,32 @@
                                               size_t n_symbols);
 unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
                                                size_t n_symbols);
+unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_sse2
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_sse2
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+#define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_sse2
 #else
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt_c
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt_c
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_c
 #endif
 
+static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
+    unsigned bits = 0; /* decoded bits; the first one decoded ends up most significant */
+    for (unsigned i = 0; i < n; i++) /* one dav1d_msac_decode_bool_equi() call per bit */
+        bits = (bits << 1) | dav1d_msac_decode_bool_equi(s);
+    return bits;
+}
+
+static inline int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) {
+    assert(n > 0);
+    const int l = ulog2(n) + 1; /* bit count derived from n via ulog2() */
+    assert(l > 1);
+    const unsigned cutoff = (1 << l) - n; /* prefixes below this are returned as-is */
+    const unsigned prefix = dav1d_msac_decode_bools(s, l - 1); /* first l - 1 bits */
+    return prefix >= cutoff ? (prefix << 1) - cutoff + dav1d_msac_decode_bool_equi(s) : prefix;
+}
+
 #endif /* DAV1D_SRC_MSAC_H */
diff --git a/src/x86/msac.asm b/src/x86/msac.asm
index 9f3a820..a6b7f33 100644
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -111,6 +111,7 @@
     sub           r2d, r1d ; rng
     shl            r1, 48
     add            r4, r1  ; ~dif
+.renorm3:
     mov           r1d, [sq+msac.cnt]
     movifnidn      t0, sq
     bsr           ecx, r2d
@@ -284,4 +285,21 @@
 %endif
     jmp m(msac_decode_symbol_adapt4).renorm2
 
+cglobal msac_decode_bool_equi, 1, 7, 0, s ; x86inc: 1 argument (s), 7 GPRs, no XMM registers
+    mov           r1d, [sq+msac.rng] ; r1d = s->rng
+    mov            r4, [sq+msac.dif] ; r4  = s->dif
+    mov           r2d, r1d ; r2d = r (copy of rng)
+    mov           r1b, 8   ; replace the low byte of rng with 8
+    mov            r3, r4  ; keep the original dif for the borrow case
+    mov           eax, r1d ; eax = modified rng, to be scaled into vw
+    shr           r1d, 1   ; v
+    shl           rax, 47  ; vw
+    sub           r2d, r1d ; r - v
+    sub            r4, rax ; dif - vw
+    cmovb         r2d, r1d ; borrow (dif < vw): new rng = v, else r - v
+    cmovb          r4, r3  ; borrow (dif < vw): restore dif, else dif - vw
+    setb           al ; al = 1 on borrow; the upper 32 bits contain garbage but that's OK
+    not            r4 ; ~dif, matching the "~dif" value produced just before .renorm3
+    jmp m(msac_decode_symbol_adapt4).renorm3 ; continue in the shared renormalization code
+
 %endif
diff --git a/tests/checkasm/msac.c b/tests/checkasm/msac.c
index 3808e31..ccd3c55 100644
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -32,14 +32,18 @@
 
 #include <string.h>
 
+#define BUF_SIZE 8192
+
 /* The normal code doesn't use function pointers */
 typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
                                            size_t n_symbols);
+typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
 
 typedef struct {
     decode_symbol_adapt_fn symbol_adapt4;
     decode_symbol_adapt_fn symbol_adapt8;
     decode_symbol_adapt_fn symbol_adapt16;
+    decode_bool_equi_fn    bool_equi;
 } MsacDSPContext;
 
 static void randomize_cdf(uint16_t *const cdf, int n) {
@@ -61,7 +65,7 @@
     if (check_func(c->symbol_adapt##n, "msac_decode_symbol_adapt%d", n)) { \
         for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {          \
             for (int ns = n_min; ns <= n_max; ns++) {                      \
-                dav1d_msac_init(&s_c, buf, sizeof(buf), !cdf_update);      \
+                dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);      \
                 s_a = s_c;                                                 \
                 randomize_cdf(cdf[0], ns);                                 \
                 memcpy(cdf[1], cdf[0], sizeof(*cdf));                      \
@@ -81,14 +85,13 @@
     }                                                                      \
 } while (0)
 
-static void check_decode_symbol_adapt(MsacDSPContext *const c) {
+static void check_decode_symbol_adapt(MsacDSPContext *const c,
+                                      uint8_t *const buf)
+{
     /* Use an aligned CDF buffer for more consistent benchmark
      * results, and a misaligned one for checking correctness. */
     ALIGN_STK_16(uint16_t, cdf, 2, [17]);
     MsacContext s_c, s_a;
-    uint8_t buf[1024];
-    for (int i = 0; i < 1024; i++)
-        buf[i] = rnd();
 
     declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);
     CHECK_SYMBOL_ADAPT( 4, 1,  5);
@@ -97,11 +100,33 @@
     report("decode_symbol_adapt");
 }
 
+static void check_decode_bool_equi(MsacDSPContext *const c,
+                                   uint8_t *const buf)
+{ /* Compares c->bool_equi against the reference on the shared input buffer buf. */
+    declare_func(unsigned, MsacContext *s);
+
+    if (check_func(c->bool_equi, "msac_decode_bool_equi")) {
+        MsacContext s_c, s_a; /* s_c is passed to call_ref, s_a to call_new */
+        dav1d_msac_init(&s_c, buf, BUF_SIZE, 1); /* last argument: disable_cdf_update_flag */
+        s_a = s_c; /* both decoders start from the same state */
+        for (int i = 0; i < 64; i++) { /* 64 consecutive decodes from one stream */
+            unsigned c_res = call_ref(&s_c);
+            unsigned a_res = call_new(&s_a);
+            if (c_res != a_res || msac_cmp(&s_c, &s_a)) /* presumably nonzero msac_cmp = contexts differ; verify */
+                fail();
+        }
+        bench_new(&s_a); /* benchmark call, continuing from the context left by the loop */
+    }
+
+    report("decode_bool_equi");
+}
+
 void checkasm_check_msac(void) {
     MsacDSPContext c;
     c.symbol_adapt4  = dav1d_msac_decode_symbol_adapt_c;
     c.symbol_adapt8  = dav1d_msac_decode_symbol_adapt_c;
     c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
+    c.bool_equi      = dav1d_msac_decode_bool_equi_c;
 
 #if ARCH_AARCH64 && HAVE_ASM
     if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
@@ -114,8 +139,14 @@
         c.symbol_adapt4  = dav1d_msac_decode_symbol_adapt4_sse2;
         c.symbol_adapt8  = dav1d_msac_decode_symbol_adapt8_sse2;
         c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+        c.bool_equi      = dav1d_msac_decode_bool_equi_sse2;
     }
 #endif
 
-    check_decode_symbol_adapt(&c);
+    uint8_t buf[BUF_SIZE];
+    for (int i = 0; i < BUF_SIZE; i++)
+        buf[i] = rnd();
+
+    check_decode_symbol_adapt(&c, buf);
+    check_decode_bool_equi(&c, buf);
 }