Split MC blend

The mstride == 0, mstride == 1, and mstride == w cases are very different
from each other, and splitting them into separate functions makes it easier
to optimize them.

Also add some further optimizations to the AVX2 asm that became possible
after this change.
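
For reference, the old mask-stride parameter maps onto the new entry
points roughly as follows (a sketch based on the C templates below,
variable names illustrative; blend_px is the usual
(a * (64 - m) + b * m + 32) >> 6 weighted blend):

    c->blend(dst, ds, tmp, w, h, mask); /* was mstride == w: full per-pixel mask  */
    c->blend_v(dst, ds, tmp, w, h);     /* was mstride == 0: obmc mask indexed by column */
    c->blend_h(dst, ds, tmp, w, h);     /* was mstride == 1: obmc mask indexed by row    */
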
diff --git a/src/mc.h b/src/mc.h
index abc978f..727cc5a 100644
--- a/src/mc.h
+++ b/src/mc.h
@@ -81,11 +81,14 @@
 typedef decl_w_mask_fn(*w_mask_fn);
 
 #define decl_blend_fn(name) \
-void (name)(pixel *dst, ptrdiff_t dst_stride, \
-            const pixel *tmp, int w, int h, \
-            const uint8_t *mask, ptrdiff_t mstride)
+void (name)(pixel *dst, ptrdiff_t dst_stride, const pixel *tmp, \
+            int w, int h, const uint8_t *mask)
 typedef decl_blend_fn(*blend_fn);
 
+#define decl_blend_dir_fn(name) \
+void (name)(pixel *dst, ptrdiff_t dst_stride, const pixel *tmp, int w, int h)
+typedef decl_blend_dir_fn(*blend_dir_fn);
+
 #define decl_emu_edge_fn(name) \
 void (name)(intptr_t bw, intptr_t bh, intptr_t iw, intptr_t ih, intptr_t x, intptr_t y, \
             pixel *dst, ptrdiff_t dst_stride, const pixel *src, ptrdiff_t src_stride)
@@ -99,6 +102,8 @@
     mask_fn mask;
     w_mask_fn w_mask[3 /* 444, 422, 420 */];
     blend_fn blend;
+    blend_dir_fn blend_v;
+    blend_dir_fn blend_h;
     warp8x8_fn warp8x8;
     warp8x8t_fn warp8x8t;
     emu_edge_fn emu_edge;
diff --git a/src/mc_tmpl.c b/src/mc_tmpl.c
index c43745e..cef6972 100644
--- a/src/mc_tmpl.c
+++ b/src/mc_tmpl.c
@@ -373,19 +373,46 @@
     } while (--h);
 }
 
-static void blend_c(pixel *dst, const ptrdiff_t dst_stride,
-                    const pixel *tmp, const int w, const int h,
-                    const uint8_t *mask, const ptrdiff_t m_stride)
-{
-    for (int y = 0; y < h; y++) {
-        for (int x = 0; x < w; x++) {
 #define blend_px(a, b, m) (((a * (64 - m) + b * m) + 32) >> 6)
-            dst[x] = blend_px(dst[x], tmp[x], mask[m_stride == 1 ? 0 : x]);
+static NOINLINE void
+blend_internal_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                 const int w, int h, const uint8_t *mask,
+                 const ptrdiff_t mask_stride)
+{
+    do {
+        for (int x = 0; x < w; x++) {
+            dst[x] = blend_px(dst[x], tmp[x], mask[x]);
         }
         dst += PXSTRIDE(dst_stride);
         tmp += w;
-        mask += m_stride;
-    }
+        mask += mask_stride;
+    } while (--h);
+}
+
+static void blend_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                    const int w, const int h, const uint8_t *mask)
+{
+    blend_internal_c(dst, dst_stride, tmp, w, h, mask, w);
+}
+
+static void blend_v_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                      const int w, const int h)
+{
+    blend_internal_c(dst, dst_stride, tmp, w, h, &dav1d_obmc_masks[w], 0);
+}
+
+static void blend_h_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                      const int w, int h)
+{
+    const uint8_t *mask = &dav1d_obmc_masks[h];
+    do {
+        const int m = *mask++;
+        for (int x = 0; x < w; x++) {
+            dst[x] = blend_px(dst[x], tmp[x], m);
+        }
+        dst += PXSTRIDE(dst_stride);
+        tmp += w;
+    } while (--h);
 }
 
 static void w_mask_c(pixel *dst, const ptrdiff_t dst_stride,
@@ -591,6 +618,8 @@
     c->w_avg    = w_avg_c;
     c->mask     = mask_c;
     c->blend    = blend_c;
+    c->blend_v  = blend_v_c;
+    c->blend_h  = blend_h_c;
     c->w_mask[0] = w_mask_444_c;
     c->w_mask[1] = w_mask_422_c;
     c->w_mask[2] = w_mask_420_c;
diff --git a/src/recon_tmpl.c b/src/recon_tmpl.c
index 46c5178..e4ea731 100644
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -579,9 +579,8 @@
                          &f->refp[a_r->ref[0] - 1],
                          dav1d_filter_2d[t->a->filter[1][bx4 + x + 1]][t->a->filter[0][bx4 + x + 1]]);
                 if (res) return res;
-                f->dsp->mc.blend(&dst[x * h_mul], dst_stride, lap,
-                                 h_mul * ow4, v_mul * oh4,
-                                 &dav1d_obmc_masks[v_mul * oh4], 1);
+                f->dsp->mc.blend_h(&dst[x * h_mul], dst_stride, lap,
+                                   h_mul * ow4, v_mul * oh4);
                 i++;
             }
             x += imax(a_b_dim[0], 2);
@@ -603,9 +602,8 @@
                          &f->refp[l_r->ref[0] - 1],
                          dav1d_filter_2d[t->l.filter[1][by4 + y + 1]][t->l.filter[0][by4 + y + 1]]);
                 if (res) return res;
-                f->dsp->mc.blend(&dst[y * v_mul * PXSTRIDE(dst_stride)],
-                                 dst_stride, lap, h_mul * ow4, v_mul * oh4,
-                                 &dav1d_obmc_masks[h_mul * ow4], 0);
+                f->dsp->mc.blend_v(&dst[y * v_mul * PXSTRIDE(dst_stride)],
+                                   dst_stride, lap, h_mul * ow4, v_mul * oh4);
                 i++;
             }
             y += imax(l_b_dim[1], 2);
@@ -1144,7 +1142,7 @@
                      dav1d_ii_masks[bs][0][b->interintra_mode] :
                      dav1d_wedge_masks[bs][0][0][b->wedge_idx];
             dsp->mc.blend(dst, f->cur.p.stride[0], tmp,
-                          bw4 * 4, bh4 * 4, ii_mask, bw4 * 4);
+                          bw4 * 4, bh4 * 4, ii_mask);
         }
 
         if (!has_chroma) goto skip_inter_chroma_pred;
@@ -1277,7 +1275,7 @@
                     dsp->ipred.intra_pred[m](tmp, cbw4 * 4 * sizeof(pixel),
                                              tl_edge, cbw4 * 4, cbh4 * 4, 0);
                     dsp->mc.blend(uvdst, f->cur.p.stride[1], tmp,
-                                  cbw4 * 4, cbh4 * 4, ii_mask, cbw4 * 4);
+                                  cbw4 * 4, cbh4 * 4, ii_mask);
                 }
             }
         }
diff --git a/src/x86/mc.asm b/src/x86/mc.asm
index 860e935..b63e8a5 100644
--- a/src/x86/mc.asm
+++ b/src/x86/mc.asm
@@ -30,6 +30,23 @@
 
 SECTION_RODATA 32
 
+; dav1d_obmc_masks[] with 64-x interleaved
+obmc_masks: db  0,  0,  0,  0
+            ; 2
+            db 45, 19, 64,  0
+            ; 4
+            db 39, 25, 50, 14, 59,  5, 64,  0
+            ; 8
+            db 36, 28, 42, 22, 48, 16, 53, 11, 57,  7, 61,  3, 64,  0, 64,  0
+            ; 16
+            db 34, 30, 37, 27, 40, 24, 43, 21, 46, 18, 49, 15, 52, 12, 54, 10
+            db 56,  8, 58,  6, 60,  4, 61,  3, 64,  0, 64,  0, 64,  0, 64,  0
+            ; 32
+            db 33, 31, 35, 29, 36, 28, 38, 26, 40, 24, 41, 23, 43, 21, 44, 20
+            db 45, 19, 47, 17, 48, 16, 50, 14, 51, 13, 52, 12, 53, 11, 55,  9
+            db 56,  8, 57,  7, 58,  6, 59,  5, 60,  4, 60,  4, 61,  3, 62,  2
+            db 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0
+
 warp_8x8_shufA: db 0,  2,  4,  6,  1,  3,  5,  7,  1,  3,  5,  7,  2,  4,  6,  8
                 db 4,  6,  8, 10,  5,  7,  9, 11,  5,  7,  9, 11,  6,  8, 10, 12
 warp_8x8_shufB: db 2,  4,  6,  8,  3,  5,  7,  9,  3,  5,  7,  9,  4,  6,  8, 10
@@ -42,10 +59,9 @@
 bilin_h_shuf4:  db 1,  0,  2,  1,  3,  2,  4,  3,  9,  8, 10,  9, 11, 10, 12, 11
 bilin_h_shuf8:  db 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7
 deint_shuf4:    db 0,  4,  1,  5,  2,  6,  3,  7,  4,  8,  5,  9,  6, 10,  7, 11
+blend_shuf:     db 0,  1,  0,  1,  0,  1,  0,  1,  2,  3,  2,  3,  2,  3,  2,  3
 
-blend_shuf: ; bits 0-3: 0, 0, 0, 0, 1, 1, 1, 1
 pb_64:   times 4 db 64
-         times 4 db 1
 pw_8:    times 2 dw 8
 pw_26:   times 2 dw 26
 pw_34:   times 2 dw 34
@@ -61,7 +77,7 @@
 cextern mc_subpel_filters
 %define subpel_filters (mangle(private_prefix %+ _mc_subpel_filters)-8)
 
-%macro BIDIR_JMP_TABLE 1-* 4, 8, 16, 32, 64, 128
+%macro BIDIR_JMP_TABLE 1-*
     %xdefine %1_table (%%table - 2*%2)
     %xdefine %%base %1_table
     %xdefine %%prefix mangle(private_prefix %+ _%1)
@@ -72,11 +88,13 @@
     %endrep
 %endmacro
 
-BIDIR_JMP_TABLE avg_avx2
-BIDIR_JMP_TABLE w_avg_avx2
-BIDIR_JMP_TABLE mask_avx2
-BIDIR_JMP_TABLE w_mask_420_avx2
-BIDIR_JMP_TABLE blend_avx2, 2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE avg_avx2,        4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_avg_avx2,      4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE mask_avx2,       4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_420_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE blend_avx2,      4, 8, 16, 32
+BIDIR_JMP_TABLE blend_v_avx2, 2, 4, 8, 16, 32
+BIDIR_JMP_TABLE blend_h_avx2, 2, 4, 8, 16, 32, 32, 32
 
 %macro BASE_JMP_TABLE 3-*
     %xdefine %1_%2_table (%%table - %3)
@@ -3286,7 +3304,7 @@
     jg .w128_loop
     RET
 
-cglobal blend, 3, 7, 6, dst, ds, tmp, w, h, mask, ms
+cglobal blend, 3, 7, 7, dst, ds, tmp, w, h, mask
 %define base r6-blend_avx2_table
     lea                  r6, [blend_avx2_table]
     tzcnt                wd, wm
@@ -3296,219 +3314,68 @@
     vpbroadcastd         m4, [base+pb_64]
     vpbroadcastd         m5, [base+pw_512]
     add                  wq, r6
-    mov                 msq, msmp
+    lea                  r6, [dsq*3]
     jmp                  wq
-.w2:
-    cmp                 msq, 1
-    jb .w2_s0
-    je .w2_s1
-.w2_s2:
-    movd                xm1, [maskq]
-    movd                xm0, [dstq+dsq*0]
-    pinsrw              xm0, [dstq+dsq*1], 1
-    psubb               xm2, xm4, xm1
-    punpcklbw           xm2, xm1
-    movd                xm1, [tmpq]
-    add               maskq, 2*2
-    add                tmpq, 2*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm2
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    pextrw     [dstq+dsq*0], xm0, 0
-    pextrw     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w2_s2
-    RET
-.w2_s1:
-    movd                xm1, [maskq]
-    movd                xm0, [dstq+dsq*0]
-    psubb               xm2, xm4, xm1
-    punpcklbw           xm2, xm1
-    pinsrw              xm0, [dstq+dsq*1], 1
-    movd                xm1, [tmpq]
-    punpcklwd           xm2, xm2
-    add               maskq, 2
-    add                tmpq, 2*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm2
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    pextrw     [dstq+dsq*0], xm0, 0
-    pextrw     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w2_s1
-    RET
-.w2_s0:
-    vpbroadcastw        xm0, [maskq]
-    psubb               xm4, xm0
-    punpcklbw           xm4, xm0
-.w2_s0_loop:
-    movd                xm0, [dstq+dsq*0]
-    pinsrw              xm0, [dstq+dsq*1], 1
-    movd                xm1, [tmpq]
-    add                tmpq, 2*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm4
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    pextrw     [dstq+dsq*0], xm0, 0
-    pextrw     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w2_s0_loop
-    RET
-ALIGN function_align
 .w4:
-    cmp                 msq, 1
-    jb .w4_s0
-    je .w4_s1
-.w4_s4:
-    movq                xm1, [maskq]
     movd                xm0, [dstq+dsq*0]
     pinsrd              xm0, [dstq+dsq*1], 1
-    psubb               xm2, xm4, xm1
-    punpcklbw           xm2, xm1
-    movq                xm1, [tmpq]
-    add               maskq, 4*2
-    add                tmpq, 4*2
-    punpcklbw           xm0, xm1
+    vpbroadcastd        xm1, [dstq+dsq*2]
+    pinsrd              xm1, [dstq+r6   ], 3
+    mova                xm6, [maskq]
+    psubb               xm3, xm4, xm6
+    punpcklbw           xm2, xm3, xm6
+    punpckhbw           xm3, xm6
+    mova                xm6, [tmpq]
+    add               maskq, 4*4
+    add                tmpq, 4*4
+    punpcklbw           xm0, xm6
+    punpckhbw           xm1, xm6
     pmaddubsw           xm0, xm2
+    pmaddubsw           xm1, xm3
     pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
+    pmulhrsw            xm1, xm5
+    packuswb            xm0, xm1
     movd       [dstq+dsq*0], xm0
     pextrd     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w4_s4
-    RET
-.w4_s1:
-    movq                xm3, [blend_shuf]
-.w4_s1_loop:
-    movd                xm1, [maskq]
-    movd                xm0, [dstq+dsq*0]
-    pshufb              xm1, xm3
-    psubb               xm2, xm4, xm1
-    pinsrd              xm0, [dstq+dsq*1], 1
-    punpcklbw           xm2, xm1
-    movq                xm1, [tmpq]
-    add               maskq, 2
-    add                tmpq, 4*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm2
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    movd       [dstq+dsq*0], xm0
-    pextrd     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w4_s1_loop
-    RET
-.w4_s0:
-    vpbroadcastd        xm0, [maskq]
-    psubb               xm4, xm0
-    punpcklbw           xm4, xm0
-.w4_s0_loop:
-    movd                xm0, [dstq+dsq*0]
-    pinsrd              xm0, [dstq+dsq*1], 1
-    movq                xm1, [tmpq]
-    add                tmpq, 4*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm4
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    movd       [dstq+dsq*0], xm0
-    pextrd     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w4_s0_loop
+    pextrd     [dstq+dsq*2], xm0, 2
+    pextrd     [dstq+r6   ], xm0, 3
+    lea                dstq, [dstq+dsq*4]
+    sub                  hd, 4
+    jg .w4
     RET
 ALIGN function_align
 .w8:
-    cmp                 msq, 1
-    jb .w8_s0
-    je .w8_s1
-.w8_s8:
-    movq                xm1, [maskq+8*1]
-    vinserti128          m1, [maskq+8*0], 1
-    vpbroadcastq         m2, [dstq+dsq*0]
-    movq                xm0, [dstq+dsq*1]
-    vpblendd             m0, m2, 0x30
-    psubb                m2, m4, m1
-    punpcklbw            m2, m1
-    movq                xm1, [tmpq+8*1]
-    vinserti128          m1, [tmpq+8*0], 1
-    add               maskq, 8*2
-    add                tmpq, 8*2
-    punpcklbw            m0, m1
+    movq                xm1, [dstq+dsq*0]
+    movhps              xm1, [dstq+dsq*1]
+    vpbroadcastq         m2, [dstq+dsq*2]
+    vpbroadcastq         m3, [dstq+r6   ]
+    mova                 m0, [maskq]
+    mova                 m6, [tmpq]
+    add               maskq, 8*4
+    add                tmpq, 8*4
+    vpblendd             m1, m2, 0x30
+    vpblendd             m1, m3, 0xc0
+    psubb                m3, m4, m0
+    punpcklbw            m2, m3, m0
+    punpckhbw            m3, m0
+    punpcklbw            m0, m1, m6
+    punpckhbw            m1, m6
     pmaddubsw            m0, m2
+    pmaddubsw            m1, m3
     pmulhrsw             m0, m5
+    pmulhrsw             m1, m5
+    packuswb             m0, m1
     vextracti128        xm1, m0, 1
-    packuswb            xm0, xm1
-    movhps     [dstq+dsq*0], xm0
-    movq       [dstq+dsq*1], xm0
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w8_s8
-    RET
-.w8_s1:
-    vpbroadcastd         m0, [blend_shuf+0]
-    vpbroadcastd        xm3, [blend_shuf+4]
-    vpblendd             m3, m0, 0xf0
-.w8_s1_loop:
-    vpbroadcastd         m0, [maskq]
-    vpbroadcastq         m1, [dstq+dsq*0]
-    pshufb               m0, m3
-    psubb                m2, m4, m0
-    punpcklbw            m2, m0
-    movq                xm0, [dstq+dsq*1]
-    vpblendd             m0, m1, 0x30
-    movq                xm1, [tmpq+8*1]
-    vinserti128          m1, [tmpq+8*0], 1
-    add               maskq, 2
-    add                tmpq, 8*2
-    punpcklbw            m0, m1
-    pmaddubsw            m0, m2
-    pmulhrsw             m0, m5
-    vextracti128        xm1, m0, 1
-    packuswb            xm0, xm1
-    movhps     [dstq+dsq*0], xm0
-    movq       [dstq+dsq*1], xm0
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w8_s1_loop
-    RET
-.w8_s0:
-    vpbroadcastq         m0, [maskq]
-    psubb                m4, m0
-    punpcklbw            m4, m0
-.w8_s0_loop:
-    vpbroadcastq         m2, [dstq+dsq*0]
-    movq                xm0, [dstq+dsq*1]
-    vpblendd             m0, m2, 0x30
-    movq                xm1, [tmpq+8*1]
-    vinserti128          m1, [tmpq+8*0], 1
-    add                tmpq, 8*2
-    punpcklbw            m0, m1
-    pmaddubsw            m0, m4
-    pmulhrsw             m0, m5
-    vextracti128        xm1, m0, 1
-    packuswb            xm0, xm1
-    movhps     [dstq+dsq*0], xm0
-    movq       [dstq+dsq*1], xm0
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w8_s0_loop
+    movq       [dstq+dsq*0], xm0
+    movhps     [dstq+dsq*1], xm0
+    movq       [dstq+dsq*2], xm1
+    movhps     [dstq+r6   ], xm1
+    lea                dstq, [dstq+dsq*4]
+    sub                  hd, 4
+    jg .w8
     RET
 ALIGN function_align
 .w16:
-    cmp                 msq, 1
-    jb .w16_s0
-    WIN64_SPILL_XMM       7
-    je .w16_s1
-.w16_s16:
     mova                 m0, [maskq]
     mova                xm1, [dstq+dsq*0]
     vinserti128          m1, [dstq+dsq*1], 1
@@ -3529,43 +3396,102 @@
     vextracti128 [dstq+dsq*1], m0, 1
     lea                dstq, [dstq+dsq*2]
     sub                  hd, 2
-    jg .w16_s16
+    jg .w16
     RET
-.w16_s1:
-    vpbroadcastd        xm6, [blend_shuf]
-    vpbroadcastd         m0, [blend_shuf+4]
-    vpblendd             m6, m0, 0xf0
-.w16_s1_loop:
-    vpbroadcastd         m2, [maskq]
-    mova                xm1, [dstq+dsq*0]
-    pshufb               m2, m6
-    psubb                m3, m4, m2
-    vinserti128          m1, [dstq+dsq*1], 1
-    punpcklbw            m3, m2
-    mova                 m2, [tmpq]
-    add               maskq, 2
-    add                tmpq, 16*2
-    punpcklbw            m0, m1, m2
-    punpckhbw            m1, m2
-    pmaddubsw            m0, m3
+ALIGN function_align
+.w32:
+    mova                 m0, [maskq]
+    mova                 m1, [dstq]
+    mova                 m6, [tmpq]
+    add               maskq, 32
+    add                tmpq, 32
+    psubb                m3, m4, m0
+    punpcklbw            m2, m3, m0
+    punpckhbw            m3, m0
+    punpcklbw            m0, m1, m6
+    punpckhbw            m1, m6
+    pmaddubsw            m0, m2
     pmaddubsw            m1, m3
     pmulhrsw             m0, m5
     pmulhrsw             m1, m5
     packuswb             m0, m1
-    mova         [dstq+dsq*0], xm0
-    vextracti128 [dstq+dsq*1], m0, 1
+    mova             [dstq], m0
+    add                dstq, dsq
+    dec                  hd
+    jg .w32
+    RET
+
+cglobal blend_v, 3, 6, 6, dst, ds, tmp, w, h, mask
+%define base r5-blend_v_avx2_table
+    lea                  r5, [blend_v_avx2_table]
+    tzcnt                wd, wm
+    movifnidn            hd, hm
+    movsxd               wq, dword [r5+wq*4]
+    vpbroadcastd         m5, [base+pw_512]
+    add                  wq, r5
+    add               maskq, obmc_masks-blend_v_avx2_table
+    jmp                  wq
+.w2:
+    vpbroadcastd        xm2, [maskq+2*2]
+.w2_s0_loop:
+    movd                xm0, [dstq+dsq*0]
+    pinsrw              xm0, [dstq+dsq*1], 1
+    movd                xm1, [tmpq]
+    add                tmpq, 2*2
+    punpcklbw           xm0, xm1
+    pmaddubsw           xm0, xm2
+    pmulhrsw            xm0, xm5
+    packuswb            xm0, xm0
+    pextrw     [dstq+dsq*0], xm0, 0
+    pextrw     [dstq+dsq*1], xm0, 1
     lea                dstq, [dstq+dsq*2]
     sub                  hd, 2
-    jg .w16_s1_loop
+    jg .w2_s0_loop
     RET
-.w16_s0:
-    %assign stack_offset stack_offset - stack_size_padded
-    WIN64_SPILL_XMM       6
-    vbroadcasti128       m0, [maskq]
-    psubb                m4, m0
-    punpcklbw            m3, m4, m0
-    punpckhbw            m4, m0
-.w16_s0_loop:
+ALIGN function_align
+.w4:
+    vpbroadcastq        xm2, [maskq+4*2]
+.w4_loop:
+    movd                xm0, [dstq+dsq*0]
+    pinsrd              xm0, [dstq+dsq*1], 1
+    movq                xm1, [tmpq]
+    add                tmpq, 4*2
+    punpcklbw           xm0, xm1
+    pmaddubsw           xm0, xm2
+    pmulhrsw            xm0, xm5
+    packuswb            xm0, xm0
+    movd       [dstq+dsq*0], xm0
+    pextrd     [dstq+dsq*1], xm0, 1
+    lea                dstq, [dstq+dsq*2]
+    sub                  hd, 2
+    jg .w4_loop
+    RET
+ALIGN function_align
+.w8:
+    vbroadcasti128       m4, [maskq+8*2]
+.w8_loop:
+    vpbroadcastq         m2, [dstq+dsq*0]
+    movq                xm0, [dstq+dsq*1]
+    vpblendd             m0, m2, 0x30
+    movq                xm1, [tmpq+8*1]
+    vinserti128          m1, [tmpq+8*0], 1
+    add                tmpq, 8*2
+    punpcklbw            m0, m1
+    pmaddubsw            m0, m4
+    pmulhrsw             m0, m5
+    vextracti128        xm1, m0, 1
+    packuswb            xm0, xm1
+    movhps     [dstq+dsq*0], xm0
+    movq       [dstq+dsq*1], xm0
+    lea                dstq, [dstq+dsq*2]
+    sub                  hd, 2
+    jg .w8_loop
+    RET
+ALIGN function_align
+.w16:
+    vbroadcasti128       m3, [maskq+16*2]
+    vbroadcasti128       m4, [maskq+16*3]
+.w16_loop:
     mova                xm1, [dstq+dsq*0]
     vinserti128          m1, [dstq+dsq*1], 1
     mova                 m2, [tmpq]
@@ -3581,58 +3507,135 @@
     vextracti128 [dstq+dsq*1], m0, 1
     lea                dstq, [dstq+dsq*2]
     sub                  hd, 2
-    jg .w16_s0_loop
+    jg .w16_loop
     RET
 ALIGN function_align
 .w32:
-    mov                  wd, 32
-    jmp .w32_start
-.w64:
-    mov                  wd, 64
-    jmp .w32_start
-.w128:
-    mov                  wd, 128
-.w32_start:
-    WIN64_SPILL_XMM       7
-    cmp                 msq, 1
-    jb .w32_s0
-    je .w32_s1
-    sub                 dsq, wq
-.w32_s32:
-    mov                 r6d, wd
-.w32_s32_loop:
-    mova                 m0, [maskq]
+    mova                xm3, [maskq+16*4]
+    vinserti128          m3, [maskq+16*6], 1
+    mova                xm4, [maskq+16*5]
+    vinserti128          m4, [maskq+16*7], 1
+.w32_loop:
     mova                 m1, [dstq]
-    psubb                m3, m4, m0
-    punpcklbw            m2, m3, m0
-    punpckhbw            m3, m0
-    mova                 m6, [tmpq]
-    add               maskq, 32
+    mova                 m2, [tmpq]
     add                tmpq, 32
-    punpcklbw            m0, m1, m6
-    punpckhbw            m1, m6
-    pmaddubsw            m0, m2
-    pmaddubsw            m1, m3
+    punpcklbw            m0, m1, m2
+    punpckhbw            m1, m2
+    pmaddubsw            m0, m3
+    pmaddubsw            m1, m4
     pmulhrsw             m0, m5
     pmulhrsw             m1, m5
     packuswb             m0, m1
     mova             [dstq], m0
-    add                dstq, 32
-    sub                 r6d, 32
-    jg .w32_s32_loop
     add                dstq, dsq
     dec                  hd
-    jg .w32_s32
+    jg .w32_loop
     RET
-.w32_s1:
-    sub                 dsq, wq
-.w32_s1_loop0:
-    vpbroadcastb         m0, [maskq]
+
+cglobal blend_h, 4, 7, 6, dst, ds, tmp, w, h, mask
+%define base r5-blend_h_avx2_table
+    lea                  r5, [blend_h_avx2_table]
     mov                 r6d, wd
-    inc               maskq
-    psubb                m3, m4, m0
-    punpcklbw            m3, m0
-.w32_s1_loop:
+    tzcnt                wd, wd
+    mov                  hd, hm
+    movsxd               wq, dword [r5+wq*4]
+    vpbroadcastd         m5, [base+pw_512]
+    add                  wq, r5
+    lea               maskq, [base+obmc_masks+hq*4]
+    neg                  hq
+    jmp                  wq
+.w2:
+    movd                xm0, [dstq+dsq*0]
+    pinsrw              xm0, [dstq+dsq*1], 1
+    movd                xm2, [maskq+hq*2]
+    movd                xm1, [tmpq]
+    add                tmpq, 2*2
+    punpcklwd           xm2, xm2
+    punpcklbw           xm0, xm1
+    pmaddubsw           xm0, xm2
+    pmulhrsw            xm0, xm5
+    packuswb            xm0, xm0
+    pextrw     [dstq+dsq*0], xm0, 0
+    pextrw     [dstq+dsq*1], xm0, 1
+    lea                dstq, [dstq+dsq*2]
+    add                  hq, 2
+    jl .w2
+    RET
+ALIGN function_align
+.w4:
+    mova                xm3, [blend_shuf]
+.w4_loop:
+    movd                xm0, [dstq+dsq*0]
+    pinsrd              xm0, [dstq+dsq*1], 1
+    movq                xm2, [maskq+hq*2]
+    movq                xm1, [tmpq]
+    add                tmpq, 4*2
+    pshufb              xm2, xm3
+    punpcklbw           xm0, xm1
+    pmaddubsw           xm0, xm2
+    pmulhrsw            xm0, xm5
+    packuswb            xm0, xm0
+    movd       [dstq+dsq*0], xm0
+    pextrd     [dstq+dsq*1], xm0, 1
+    lea                dstq, [dstq+dsq*2]
+    add                  hq, 2
+    jl .w4_loop
+    RET
+ALIGN function_align
+.w8:
+    vbroadcasti128       m4, [blend_shuf]
+    shufpd               m4, m4, 0x03
+.w8_loop:
+    vpbroadcastq         m1, [dstq+dsq*0]
+    movq                xm0, [dstq+dsq*1]
+    vpblendd             m0, m1, 0x30
+    vpbroadcastd         m3, [maskq+hq*2]
+    movq                xm1, [tmpq+8*1]
+    vinserti128          m1, [tmpq+8*0], 1
+    add                tmpq, 8*2
+    pshufb               m3, m4
+    punpcklbw            m0, m1
+    pmaddubsw            m0, m3
+    pmulhrsw             m0, m5
+    vextracti128        xm1, m0, 1
+    packuswb            xm0, xm1
+    movhps     [dstq+dsq*0], xm0
+    movq       [dstq+dsq*1], xm0
+    lea                dstq, [dstq+dsq*2]
+    add                  hq, 2
+    jl .w8_loop
+    RET
+ALIGN function_align
+.w16:
+    vbroadcasti128       m4, [blend_shuf]
+    shufpd               m4, m4, 0x0c
+.w16_loop:
+    mova                xm1, [dstq+dsq*0]
+    vinserti128          m1, [dstq+dsq*1], 1
+    vpbroadcastd         m3, [maskq+hq*2]
+    mova                 m2, [tmpq]
+    add                tmpq, 16*2
+    pshufb               m3, m4
+    punpcklbw            m0, m1, m2
+    punpckhbw            m1, m2
+    pmaddubsw            m0, m3
+    pmaddubsw            m1, m3
+    pmulhrsw             m0, m5
+    pmulhrsw             m1, m5
+    packuswb             m0, m1
+    mova         [dstq+dsq*0], xm0
+    vextracti128 [dstq+dsq*1], m0, 1
+    lea                dstq, [dstq+dsq*2]
+    add                  hq, 2
+    jl .w16_loop
+    RET
+ALIGN function_align
+.w32: ; w32/w64/w128
+    sub                 dsq, r6
+.w32_loop0:
+    vpbroadcastw         m3, [maskq+hq*2]
+    mov                  wd, r6d
+.w32_loop:
     mova                 m1, [dstq]
     mova                 m2, [tmpq]
     add                tmpq, 32
@@ -3645,49 +3648,11 @@
     packuswb             m0, m1
     mova             [dstq], m0
     add                dstq, 32
-    sub                 r6d, 32
-    jg .w32_s1_loop
+    sub                  wd, 32
+    jg .w32_loop
     add                dstq, dsq
-    dec                  hd
-    jg .w32_s1_loop0
-    RET
-.w32_s0:
-%if WIN64
-    PUSH                 r7
-    PUSH                 r8
-    %define regs_used 9
-%endif
-    lea                 r6d, [hq+wq*8-256]
-    mov                  r7, dstq
-    mov                  r8, tmpq
-.w32_s0_loop0:
-    mova                 m0, [maskq]
-    add               maskq, 32
-    psubb                m3, m4, m0
-    punpcklbw            m2, m3, m0
-    punpckhbw            m3, m0
-.w32_s0_loop:
-    mova                 m1, [dstq]
-    mova                 m6, [tmpq]
-    add                tmpq, wq
-    punpcklbw            m0, m1, m6
-    punpckhbw            m1, m6
-    pmaddubsw            m0, m2
-    pmaddubsw            m1, m3
-    pmulhrsw             m0, m5
-    pmulhrsw             m1, m5
-    packuswb             m0, m1
-    mova             [dstq], m0
-    add                dstq, dsq
-    dec                  hd
-    jg .w32_s0_loop
-    add                  r7, 32
-    add                  r8, 32
-    mov                dstq, r7
-    mov                tmpq, r8
-    mov                  hb, r6b
-    sub                 r6d, 256
-    jg .w32_s0_loop0
+    inc                  hq
+    jl .w32_loop0
     RET
 
 cglobal emu_edge, 10, 13, 1, bw, bh, iw, ih, x, y, dst, dstride, src, sstride, \
diff --git a/src/x86/mc_init_tmpl.c b/src/x86/mc_init_tmpl.c
index ced7305..7579019 100644
--- a/src/x86/mc_init_tmpl.c
+++ b/src/x86/mc_init_tmpl.c
@@ -55,6 +55,8 @@
 decl_mask_fn(dav1d_mask_avx2);
 decl_w_mask_fn(dav1d_w_mask_420_avx2);
 decl_blend_fn(dav1d_blend_avx2);
+decl_blend_dir_fn(dav1d_blend_v_avx2);
+decl_blend_dir_fn(dav1d_blend_h_avx2);
 
 decl_warp8x8_fn(dav1d_warp_affine_8x8_avx2);
 decl_warp8x8t_fn(dav1d_warp_affine_8x8t_avx2);
@@ -98,6 +100,8 @@
     c->mask = dav1d_mask_avx2;
     c->w_mask[2] = dav1d_w_mask_420_avx2;
     c->blend = dav1d_blend_avx2;
+    c->blend_v = dav1d_blend_v_avx2;
+    c->blend_h = dav1d_blend_h_avx2;
 
     c->warp8x8  = dav1d_warp_affine_8x8_avx2;
     c->warp8x8t = dav1d_warp_affine_8x8t_avx2;
diff --git a/tests/checkasm/mc.c b/tests/checkasm/mc.c
index c44ca44..0a54eb2 100644
--- a/tests/checkasm/mc.c
+++ b/tests/checkasm/mc.c
@@ -237,38 +237,93 @@
 }
 
 static void check_blend(Dav1dMCDSPContext *const c) {
-    ALIGN_STK_32(pixel, tmp, 128 * 32,);
-    ALIGN_STK_32(pixel, c_dst, 128 * 32,);
-    ALIGN_STK_32(pixel, a_dst, 128 * 32,);
-    ALIGN_STK_32(uint8_t, mask, 128 * 32,);
+    ALIGN_STK_32(pixel, tmp, 32 * 32,);
+    ALIGN_STK_32(pixel, c_dst, 32 * 32,);
+    ALIGN_STK_32(pixel, a_dst, 32 * 32,);
+    ALIGN_STK_32(uint8_t, mask, 32 * 32,);
 
-    for (int i = 0; i < 128 * 32; i++) {
+    for (int i = 0; i < 32 * 32; i++) {
         tmp[i] = rand() & ((1 << BITDEPTH) - 1);
         mask[i] = rand() % 65;
     }
 
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
-                 int w, int h, const uint8_t *mask, ptrdiff_t mstride);
+                 int w, int h, const uint8_t *mask);
+
+    for (int w = 4; w <= 32; w <<= 1) {
+        const ptrdiff_t dst_stride = w * sizeof(pixel);
+        if (check_func(c->blend, "blend_w%d_%dbpc", w, BITDEPTH))
+            for (int h = imax(w / 2, 4); h <= imin(w * 2, 32); h <<= 1) {
+                for (int i = 0; i < w * h; i++)
+                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+
+                call_ref(c_dst, dst_stride, tmp, w, h, mask);
+                call_new(a_dst, dst_stride, tmp, w, h, mask);
+                if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
+                    fail();
+
+                bench_new(a_dst, dst_stride, tmp, w, h, mask);
+            }
+    }
+    report("blend");
+}
+
+static void check_blend_v(Dav1dMCDSPContext *const c) {
+    ALIGN_STK_32(pixel, tmp,   32 * 128,);
+    ALIGN_STK_32(pixel, c_dst, 32 * 128,);
+    ALIGN_STK_32(pixel, a_dst, 32 * 128,);
+
+    for (int i = 0; i < 32 * 128; i++)
+        tmp[i] = rand() & ((1 << BITDEPTH) - 1);
+
+    declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
+                 int w, int h);
+
+    for (int w = 2; w <= 32; w <<= 1) {
+        const ptrdiff_t dst_stride = w * sizeof(pixel);
+        if (check_func(c->blend_v, "blend_v_w%d_%dbpc", w, BITDEPTH))
+            for (int h = 2; h <= (w == 2 ? 64 : 128); h <<= 1) {
+                for (int i = 0; i < w * h; i++)
+                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+
+                call_ref(c_dst, dst_stride, tmp, w, h);
+                call_new(a_dst, dst_stride, tmp, w, h);
+                if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
+                    fail();
+
+                bench_new(a_dst, dst_stride, tmp, w, h);
+            }
+    }
+    report("blend_v");
+}
+
+static void check_blend_h(Dav1dMCDSPContext *const c) {
+    ALIGN_STK_32(pixel, tmp,   128 * 32,);
+    ALIGN_STK_32(pixel, c_dst, 128 * 32,);
+    ALIGN_STK_32(pixel, a_dst, 128 * 32,);
+
+    for (int i = 0; i < 128 * 32; i++)
+        tmp[i] = rand() & ((1 << BITDEPTH) - 1);
+
+    declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
+                 int w, int h);
 
     for (int w = 2; w <= 128; w <<= 1) {
         const ptrdiff_t dst_stride = w * sizeof(pixel);
-        const int h_min = (w == 128) ? 4 : 2;
-        const int h_max = (w > 32) ? 32 : (w == 2) ? 64 : 128;
-        for (int ms = 0; ms <= w; ms += ms ? w - 1 : 1)
-            if (check_func(c->blend, "blend_w%d_ms%d_%dbpc", w, ms, BITDEPTH))
-                for (int h = h_min; h <= h_max; h <<= 1) {
-                    for (int i = 0; i < w * h; i++)
-                        c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+        if (check_func(c->blend_h, "blend_h_w%d_%dbpc", w, BITDEPTH))
+            for (int h = (w == 128 ? 4 : 2); h <= 32; h <<= 1) {
+                for (int i = 0; i < w * h; i++)
+                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
 
-                    call_ref(c_dst, dst_stride, tmp, w, h, mask, ms);
-                    call_new(a_dst, dst_stride, tmp, w, h, mask, ms);
-                    if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
-                        fail();
+                call_ref(c_dst, dst_stride, tmp, w, h);
+                call_new(a_dst, dst_stride, tmp, w, h);
+                if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
+                    fail();
 
-                    bench_new(a_dst, dst_stride, tmp, w, h, mask, ms);
-                }
+                bench_new(a_dst, dst_stride, tmp, w, h);
+            }
     }
-    report("blend");
+    report("blend_h");
 }
 
 static void check_warp8x8(Dav1dMCDSPContext *const c) {
@@ -430,6 +485,8 @@
     check_mask(&c);
     check_w_mask(&c);
     check_blend(&c);
+    check_blend_v(&c);
+    check_blend_h(&c);
     check_warp8x8(&c);
     check_warp8x8t(&c);
     check_emuedge(&c);