x86: Add splat_mv AVX2 asm
diff --git a/src/x86/refmvs.asm b/src/x86/refmvs.asm
index b94fea3..b3c47d2 100644
--- a/src/x86/refmvs.asm
+++ b/src/x86/refmvs.asm
@@ -26,7 +26,7 @@
%include "config.asm"
%include "ext/x86/x86inc.asm"
-SECTION_RODATA
+SECTION_RODATA 32
%macro JMP_TABLE 2-*
%xdefine %%prefix mangle(private_prefix %+ _%1)
@@ -38,6 +38,12 @@
%endrep
%endmacro
+%if ARCH_X86_64
+splat_mv_shuf: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3 ; pshufb control: repeats a 12-byte block
+ db 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7 ; pattern across all 32 bytes (0-11,0-11,0-7)
+
+JMP_TABLE splat_mv_avx2, 1, 2, 4, 8, 16, 32 ; offsets of .w1 .. .w32, indexed by log2(bw4)
+%endif
JMP_TABLE splat_mv_sse2, 1, 2, 4, 8, 16, 32
SECTION .text
@@ -104,3 +110,60 @@
dec bh4d
jg .loop
RET
+
+%if ARCH_X86_64
+INIT_YMM avx2
+cglobal splat_mv, 4, 5, 3, rr, a, bx4, bw4, bh4
+ add bx4d, bw4d ; bx4 += bw4: column index one past the right edge of the span
+ tzcnt bw4d, bw4d ; log2(bw4) — bw4 is a power of two; selects the jump-table entry
+ vbroadcasti128 m0, [aq] ; 16 bytes of *a into both lanes (reads 4 bytes past the 12-byte block; presumably padded — TODO confirm caller)
+ lea aq, [splat_mv_avx2_table] ; a is consumed above; reuse the register for the table base
+ lea bx4q, [bx4q*3-32] ; bx4*3, scaled by 4 below -> bx4*12 bytes; -32 bias folds into the store offsets
+ movsxd bw4q, [aq+bw4q*4] ; sign-extend the 32-bit table entry (offset of .w<bw4>)
+ pshufb m0, [splat_mv_shuf] ; replicate the 12-byte block pattern across the full 32 bytes
+ movifnidn bh4d, bh4m ; bh4 = number of rows to fill
+ pshufd m1, m0, q2102 ; m0..m2 are dword rotations so that together they hold
+ pshufd m2, m0, q1021 ; 96 contiguous bytes = 8 consecutive 12-byte copies
+ add bw4q, aq ; absolute address of the width-specific store tail
+.loop:
+ mov aq, [rrq] ; base pointer of the current row
+ add rrq, gprsize ; advance to the next row pointer
+ lea aq, [aq+bx4q*4] ; aq = row + bx4*12 - 128: span end minus the store bias
+ jmp bw4q ; dispatch on width; wider cases fall through the narrower ones
+.w32: ; 32 blocks = 384 bytes; continues through .w16 and .w8
+ mova [aq-32*8], m0
+ mova [aq-32*7], m1
+ mova [aq-32*6], m2
+ mova [aq-32*5], m0
+ mova [aq-32*4], m1
+ mova [aq-32*3], m2
+.w16: ; 16 blocks = 192 bytes including the .w8 tail
+ mova [aq-32*2], m0
+ mova [aq-32*1], m1
+ mova [aq+32*0], m2
+.w8: ; 8 blocks = 96 bytes, ending exactly at row + bx4*12 (mova: assumes 32-byte-aligned rows — TODO confirm)
+ mova [aq+32*1], m0
+ mova [aq+32*2], m1
+ mova [aq+32*3], m2
+ dec bh4d
+ jg .loop
+ RET
+.w4: ; 4 blocks = 48 bytes (aq+80 .. aq+128)
+ movu [aq+ 80], m0
+ mova [aq+112], xm1
+ dec bh4d
+ jg .loop
+ RET
+.w2: ; 2 blocks = 24 bytes (aq+104 .. aq+128)
+ movu [aq+104], xm0
+ movq [aq+120], xm2
+ dec bh4d
+ jg .loop
+ RET
+.w1: ; 1 block = 12 bytes (aq+116 .. aq+128)
+ movq [aq+116], xm0
+ movd [aq+124], xm1
+ dec bh4d
+ jg .loop
+ RET
+%endif
diff --git a/src/x86/refmvs_init.c b/src/x86/refmvs_init.c
index a9deba8..2d3da67 100644
--- a/src/x86/refmvs_init.c
+++ b/src/x86/refmvs_init.c
@@ -29,6 +29,7 @@
#include "src/refmvs.h"
decl_splat_mv_fn(dav1d_splat_mv_sse2);
+decl_splat_mv_fn(dav1d_splat_mv_avx2);
COLD void dav1d_refmvs_dsp_init_x86(Dav1dRefmvsDSPContext *const c) {
const unsigned flags = dav1d_get_cpu_flags();
@@ -36,4 +37,10 @@
if (!(flags & DAV1D_X86_CPU_FLAG_SSE2)) return;
c->splat_mv = dav1d_splat_mv_sse2;
+
+#if ARCH_X86_64
+ if (!(flags & DAV1D_X86_CPU_FLAG_AVX2)) return;
+
+ c->splat_mv = dav1d_splat_mv_avx2; /* AVX2 version is 64-bit only; overrides the SSE2 pointer */
+#endif
}