itx: Add a bpc parameter to the itx dsp init function
diff --git a/src/arm/itx_init_tmpl.c b/src/arm/itx_init_tmpl.c
index f9c68e9..8fa7a61 100644
--- a/src/arm/itx_init_tmpl.c
+++ b/src/arm/itx_init_tmpl.c
@@ -77,7 +77,7 @@
 decl_itx_fn(dav1d_inv_txfm_add_dct_dct_64x32_neon);
 decl_itx_fn(dav1d_inv_txfm_add_dct_dct_64x64_neon);
 
-COLD void bitfn(dav1d_itx_dsp_init_arm)(Dav1dInvTxfmDSPContext *const c) {
+COLD void bitfn(dav1d_itx_dsp_init_arm)(Dav1dInvTxfmDSPContext *const c, int bpc) {
 #define assign_itx_fn(pfx, w, h, type, type_enum, ext) \
     c->itxfm_add[pfx##TX_##w##X##h][type_enum] = \
         dav1d_inv_txfm_add_##type##_##w##x##h##_##ext
diff --git a/src/decode.c b/src/decode.c
index a5646c6..f678215 100644
--- a/src/decode.c
+++ b/src/decode.c
@@ -3302,7 +3302,7 @@
 #define assign_bitdepth_case(bd) \
             dav1d_cdef_dsp_init_##bd##bpc(&dsp->cdef); \
             dav1d_intra_pred_dsp_init_##bd##bpc(&dsp->ipred); \
-            dav1d_itx_dsp_init_##bd##bpc(&dsp->itx); \
+            dav1d_itx_dsp_init_##bd##bpc(&dsp->itx, bpc); \
             dav1d_loop_filter_dsp_init_##bd##bpc(&dsp->lf); \
             dav1d_loop_restoration_dsp_init_##bd##bpc(&dsp->lr, bpc); \
             dav1d_mc_dsp_init_##bd##bpc(&dsp->mc); \
diff --git a/src/itx.h b/src/itx.h
index 3befc42..a299629 100644
--- a/src/itx.h
+++ b/src/itx.h
@@ -43,8 +43,8 @@
     itxfm_fn itxfm_add[N_RECT_TX_SIZES][N_TX_TYPES_PLUS_LL];
 } Dav1dInvTxfmDSPContext;
 
-bitfn_decls(void dav1d_itx_dsp_init, Dav1dInvTxfmDSPContext *c);
-bitfn_decls(void dav1d_itx_dsp_init_arm, Dav1dInvTxfmDSPContext *c);
+bitfn_decls(void dav1d_itx_dsp_init, Dav1dInvTxfmDSPContext *c, int bpc);
+bitfn_decls(void dav1d_itx_dsp_init_arm, Dav1dInvTxfmDSPContext *c, int bpc);
 bitfn_decls(void dav1d_itx_dsp_init_x86, Dav1dInvTxfmDSPContext *c);
 
 #endif /* DAV1D_SRC_ITX_H */
diff --git a/src/itx_tmpl.c b/src/itx_tmpl.c
index 02f34e8..dcd9f68 100644
--- a/src/itx_tmpl.c
+++ b/src/itx_tmpl.c
@@ -180,7 +180,7 @@
             dst[x] = iclip_pixel(dst[x] + *c++);
 }
 
-COLD void bitfn(dav1d_itx_dsp_init)(Dav1dInvTxfmDSPContext *const c) {
+COLD void bitfn(dav1d_itx_dsp_init)(Dav1dInvTxfmDSPContext *const c, int bpc) {
 #define assign_itx_all_fn64(w, h, pfx) \
     c->itxfm_add[pfx##TX_##w##X##h][DCT_DCT  ] = \
         inv_txfm_add_dct_dct_##w##x##h##_c
@@ -249,7 +249,7 @@
 
 #if HAVE_ASM
 #if ARCH_AARCH64 || ARCH_ARM
-    bitfn(dav1d_itx_dsp_init_arm)(c);
+    bitfn(dav1d_itx_dsp_init_arm)(c, bpc);
 #endif
 #if ARCH_X86
     bitfn(dav1d_itx_dsp_init_x86)(c);
diff --git a/tests/checkasm/itx.c b/tests/checkasm/itx.c
index 9d715c8..1ef0cbc 100644
--- a/tests/checkasm/itx.c
+++ b/tests/checkasm/itx.c
@@ -223,8 +223,11 @@
 }
 
 void bitfn(checkasm_check_itx)(void) {
-    Dav1dInvTxfmDSPContext c;
-    bitfn(dav1d_itx_dsp_init)(&c);
+#if BITDEPTH == 16
+    const int bpc_min = 10, bpc_max = 12;
+#else
+    const int bpc_min = 8, bpc_max = 8;
+#endif
 
     ALIGN_STK_64(coef, coeff, 2, [32 * 32]);
     ALIGN_STK_64(pixel, c_dst, 64 * 64,);
@@ -250,39 +253,39 @@
         const int subsh_max = subsh_iters[imax(dav1d_txfm_dimensions[tx].lw,
                                                dav1d_txfm_dimensions[tx].lh)];
 
-        for (enum TxfmType txtp = 0; txtp < N_TX_TYPES_PLUS_LL; txtp++)
-            for (int subsh = 0; subsh < subsh_max; subsh++)
-                if (check_func(c.itxfm_add[tx][txtp],
-                               "inv_txfm_add_%dx%d_%s_%s_%d_%dbpc",
-                               w, h, itx_1d_names[itx_1d_types[txtp][0]],
-                               itx_1d_names[itx_1d_types[txtp][1]], subsh,
-                               BITDEPTH))
-                {
-#if BITDEPTH == 16
-                    const int bitdepth_max = rnd() & 1 ? 0x3ff : 0xfff;
-#else
-                    const int bitdepth_max = 0xff;
-#endif
-                    const int eob = ftx(coeff[0], tx, txtp, w, h, subsh, bitdepth_max);
-                    memcpy(coeff[1], coeff[0], sizeof(*coeff));
+        for (int bpc = bpc_min; bpc <= bpc_max; bpc += 2) {
+            Dav1dInvTxfmDSPContext c;
+            bitfn(dav1d_itx_dsp_init)(&c, bpc);
+            for (enum TxfmType txtp = 0; txtp < N_TX_TYPES_PLUS_LL; txtp++)
+                for (int subsh = 0; subsh < subsh_max; subsh++)
+                    if (check_func(c.itxfm_add[tx][txtp],
+                                   "inv_txfm_add_%dx%d_%s_%s_%d_%dbpc",
+                                   w, h, itx_1d_names[itx_1d_types[txtp][0]],
+                                   itx_1d_names[itx_1d_types[txtp][1]], subsh,
+                                   bpc))
+                    {
+                        const int bitdepth_max = (1 << bpc) - 1;
+                        const int eob = ftx(coeff[0], tx, txtp, w, h, subsh, bitdepth_max);
+                        memcpy(coeff[1], coeff[0], sizeof(*coeff));
 
-                    for (int j = 0; j < w * h; j++)
-                        c_dst[j] = a_dst[j] = rnd() & bitdepth_max;
+                        for (int j = 0; j < w * h; j++)
+                            c_dst[j] = a_dst[j] = rnd() & bitdepth_max;
 
-                    call_ref(c_dst, w * sizeof(*c_dst), coeff[0], eob
-                             HIGHBD_TAIL_SUFFIX);
-                    call_new(a_dst, w * sizeof(*c_dst), coeff[1], eob
-                             HIGHBD_TAIL_SUFFIX);
+                        call_ref(c_dst, w * sizeof(*c_dst), coeff[0], eob
+                                 HIGHBD_TAIL_SUFFIX);
+                        call_new(a_dst, w * sizeof(*c_dst), coeff[1], eob
+                                 HIGHBD_TAIL_SUFFIX);
 
-                    checkasm_check_pixel(c_dst, w * sizeof(*c_dst),
-                                         a_dst, w * sizeof(*a_dst),
-                                         w, h, "dst");
-                    if (memcmp(coeff[0], coeff[1], sizeof(*coeff)))
-                        fail();
+                        checkasm_check_pixel(c_dst, w * sizeof(*c_dst),
+                                             a_dst, w * sizeof(*a_dst),
+                                             w, h, "dst");
+                        if (memcmp(coeff[0], coeff[1], sizeof(*coeff)))
+                            fail();
 
-                    bench_new(a_dst, w * sizeof(*c_dst), coeff[0], eob
-                              HIGHBD_TAIL_SUFFIX);
-                }
+                        bench_new(a_dst, w * sizeof(*c_dst), coeff[0], eob
+                                  HIGHBD_TAIL_SUFFIX);
+                    }
+        }
         report("add_%dx%d", w, h);
     }
 }