QS8/QU8 Multiply ND operators
- New API functions: xnn_create_multiply_nd_qs8, xnn_setup_multiply_nd_qs8,
xnn_create_multiply_nd_qu8, xnn_setup_multiply_nd_qu8
- Unit tests
PiperOrigin-RevId: 388398721
diff --git a/BUILD.bazel b/BUILD.bazel
index 4b0dcb8..40a9c92 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -253,6 +253,8 @@
"src/qs8-igemm/gen/4x4-minmax-fp32-scalar-magic.c",
"src/qs8-vadd/gen/minmax-scalar-x4.c",
"src/qs8-vaddc/gen/minmax-scalar-x4.c",
+ "src/qs8-vmul/gen/minmax-fp32-scalar-x4.c",
+ "src/qs8-vmulc/gen/minmax-fp32-scalar-x4.c",
"src/qu8-avgpool/9p8x-minmax-scalar-c1.c",
"src/qu8-avgpool/9x-minmax-scalar-c1.c",
"src/qu8-dwconv/gen/up1x9-minmax-fp32-scalar-lrint.c",
@@ -277,6 +279,8 @@
"src/qu8-vadd/gen/minmax-scalar-x4.c",
"src/qu8-vaddc/gen/minmax-scalar-x1.c",
"src/qu8-vaddc/gen/minmax-scalar-x4.c",
+ "src/qu8-vmul/gen/minmax-fp32-scalar-x4.c",
+ "src/qu8-vmulc/gen/minmax-fp32-scalar-x4.c",
"src/u8-lut32norm/scalar.c",
"src/u8-maxpool/9p8x-minmax-scalar-c1.c",
"src/u8-rmax/scalar.c",
@@ -1937,6 +1941,8 @@
"src/qs8-vadd/gen/minmax-neon-ld64-x32.c",
"src/qs8-vaddc/gen/minmax-neon-ld64-x16.c",
"src/qs8-vaddc/gen/minmax-neon-ld64-x32.c",
+ "src/qs8-vmul/gen/minmax-fp32-neon-ld64-x16.c",
+ "src/qs8-vmulc/gen/minmax-fp32-neon-ld64-x16.c",
"src/qu8-avgpool/9p8x-minmax-neon-c8.c",
"src/qu8-avgpool/9x-minmax-neon-c8.c",
"src/qu8-dwconv/gen/up8x9-minmax-rndnu-neon-mul16.c",
@@ -1953,6 +1959,8 @@
"src/qu8-igemm/gen/4x16-minmax-rndnu-neon-mlal-lane.c",
"src/qu8-vadd/gen/minmax-neon-ld64-x8.c",
"src/qu8-vaddc/gen/minmax-neon-ld64-x8.c",
+ "src/qu8-vmul/gen/minmax-fp32-neon-ld64-x16.c",
+ "src/qu8-vmulc/gen/minmax-fp32-neon-ld64-x16.c",
"src/u8-maxpool/9p8x-minmax-neon-c16.c",
"src/u8-rmax/neon.c",
"src/u8-vclamp/neon-x64.c",
@@ -2951,6 +2959,10 @@
"src/qc8-igemm/gen/2x8c2-minmax-fp32-neonv8-mlal-padal-dup.c",
"src/qc8-igemm/gen/2x8c8-minmax-fp32-neonv8-mlal-padal.c",
"src/qc8-igemm/gen/4x16-minmax-fp32-neonv8-mlal-lane.c",
+ "src/qs8-vmul/gen/minmax-fp32-neonv8-ld64-x16.c",
+ "src/qs8-vmulc/gen/minmax-fp32-neonv8-ld64-x16.c",
+ "src/qu8-vmul/gen/minmax-fp32-neonv8-ld64-x16.c",
+ "src/qu8-vmulc/gen/minmax-fp32-neonv8-ld64-x16.c",
]
ALL_NEONV8_MICROKERNEL_SRCS = [
@@ -3488,6 +3500,8 @@
"src/qs8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c",
"src/qs8-vadd/gen/minmax-sse2-mul16-ld64-x8.c",
"src/qs8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c",
+ "src/qs8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c",
+ "src/qs8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c",
"src/qu8-avgpool/9p8x-minmax-sse2-c8.c",
"src/qu8-avgpool/9x-minmax-sse2-c8.c",
"src/qu8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c",
@@ -3500,6 +3514,8 @@
"src/qu8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c",
"src/qu8-vadd/gen/minmax-sse2-mul16-ld64-x8.c",
"src/qu8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c",
+ "src/qu8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c",
+ "src/qu8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c",
"src/u8-maxpool/9p8x-minmax-sse2-c16.c",
"src/u8-rmax/sse2.c",
"src/u8-vclamp/sse2-x64.c",
@@ -3858,6 +3874,8 @@
"src/qs8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c",
"src/qs8-vadd/gen/minmax-sse41-mul16-ld64-x8.c",
"src/qs8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c",
+ "src/qs8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c",
+ "src/qs8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c",
"src/qu8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c",
"src/qu8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c",
"src/qu8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c",
@@ -3866,6 +3884,8 @@
"src/qu8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c",
"src/qu8-vadd/gen/minmax-sse41-mul16-ld64-x8.c",
"src/qu8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c",
+ "src/qu8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c",
+ "src/qu8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c",
]
ALL_SSE41_MICROKERNEL_SRCS = [
@@ -4159,6 +4179,8 @@
"src/qs8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c",
"src/qs8-vadd/gen/minmax-avx-mul32-ld32-x8.c",
"src/qs8-vaddc/gen/minmax-avx-mul32-ld32-x8.c",
+ "src/qs8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c",
+ "src/qs8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c",
"src/qu8-dwconv/gen/up16x9-minmax-fp32-avx-mul16.c",
"src/qu8-dwconv/gen/up16x25-minmax-fp32-avx-mul16.c",
"src/qu8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c",
@@ -4167,6 +4189,8 @@
"src/qu8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c",
"src/qu8-vadd/gen/minmax-avx-mul32-ld32-x8.c",
"src/qu8-vaddc/gen/minmax-avx-mul32-ld32-x8.c",
+ "src/qu8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c",
+ "src/qu8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c",
]
ALL_AVX_MICROKERNEL_SRCS = [
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 04d3c58..1e6c5f1 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -380,6 +380,8 @@
src/qs8-igemm/gen/4x4-minmax-fp32-scalar-magic.c
src/qs8-vadd/gen/minmax-scalar-x4.c
src/qs8-vaddc/gen/minmax-scalar-x4.c
+ src/qs8-vmul/gen/minmax-fp32-scalar-x4.c
+ src/qs8-vmulc/gen/minmax-fp32-scalar-x4.c
src/qu8-avgpool/9p8x-minmax-scalar-c1.c
src/qu8-avgpool/9x-minmax-scalar-c1.c
src/qu8-dwconv/gen/up1x9-minmax-fp32-scalar-lrint.c
@@ -404,6 +406,8 @@
src/qu8-vadd/gen/minmax-scalar-x4.c
src/qu8-vaddc/gen/minmax-scalar-x1.c
src/qu8-vaddc/gen/minmax-scalar-x4.c
+ src/qu8-vmul/gen/minmax-fp32-scalar-x4.c
+ src/qu8-vmulc/gen/minmax-fp32-scalar-x4.c
src/u8-lut32norm/scalar.c
src/u8-maxpool/9p8x-minmax-scalar-c1.c
src/u8-rmax/scalar.c
@@ -1117,6 +1121,8 @@
src/qs8-vadd/gen/minmax-neon-ld64-x32.c
src/qs8-vaddc/gen/minmax-neon-ld64-x16.c
src/qs8-vaddc/gen/minmax-neon-ld64-x32.c
+ src/qs8-vmul/gen/minmax-fp32-neon-ld64-x16.c
+ src/qs8-vmulc/gen/minmax-fp32-neon-ld64-x16.c
src/qu8-avgpool/9p8x-minmax-neon-c8.c
src/qu8-avgpool/9x-minmax-neon-c8.c
src/qu8-dwconv/gen/up8x9-minmax-rndnu-neon-mul16.c
@@ -1133,6 +1139,8 @@
src/qu8-igemm/gen/4x16-minmax-rndnu-neon-mlal-lane.c
src/qu8-vadd/gen/minmax-neon-ld64-x8.c
src/qu8-vaddc/gen/minmax-neon-ld64-x8.c
+ src/qu8-vmul/gen/minmax-fp32-neon-ld64-x16.c
+ src/qu8-vmulc/gen/minmax-fp32-neon-ld64-x16.c
src/u8-maxpool/9p8x-minmax-neon-c16.c
src/u8-rmax/neon.c
src/u8-vclamp/neon-x64.c
@@ -2124,7 +2132,11 @@
src/qc8-igemm/gen/1x16-minmax-fp32-neonv8-mlal-lane.c
src/qc8-igemm/gen/2x8c2-minmax-fp32-neonv8-mlal-padal-dup.c
src/qc8-igemm/gen/2x8c8-minmax-fp32-neonv8-mlal-padal.c
- src/qc8-igemm/gen/4x16-minmax-fp32-neonv8-mlal-lane.c)
+ src/qc8-igemm/gen/4x16-minmax-fp32-neonv8-mlal-lane.c
+ src/qs8-vmul/gen/minmax-fp32-neonv8-ld64-x16.c
+ src/qs8-vmulc/gen/minmax-fp32-neonv8-ld64-x16.c
+ src/qu8-vmul/gen/minmax-fp32-neonv8-ld64-x16.c
+ src/qu8-vmulc/gen/minmax-fp32-neonv8-ld64-x16.c)
SET(ALL_NEONV8_MICROKERNEL_SRCS
src/f32-vrnd/gen/vrndd-neonv8-x4.c
@@ -2654,6 +2666,8 @@
src/qs8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c
src/qs8-vadd/gen/minmax-sse2-mul16-ld64-x8.c
src/qs8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c
+ src/qs8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c
+ src/qs8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c
src/qu8-avgpool/9p8x-minmax-sse2-c8.c
src/qu8-avgpool/9x-minmax-sse2-c8.c
src/qu8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c
@@ -2666,6 +2680,8 @@
src/qu8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c
src/qu8-vadd/gen/minmax-sse2-mul16-ld64-x8.c
src/qu8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c
+ src/qu8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c
+ src/qu8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c
src/u8-maxpool/9p8x-minmax-sse2-c16.c
src/u8-rmax/sse2.c
src/u8-vclamp/sse2-x64.c
@@ -3020,6 +3036,8 @@
src/qs8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c
src/qs8-vadd/gen/minmax-sse41-mul16-ld64-x8.c
src/qs8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c
+ src/qs8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c
+ src/qs8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c
src/qu8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c
src/qu8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c
src/qu8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c
@@ -3027,7 +3045,9 @@
src/qu8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c
src/qu8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c
src/qu8-vadd/gen/minmax-sse41-mul16-ld64-x8.c
- src/qu8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c)
+ src/qu8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c
+ src/qu8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c
+ src/qu8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c)
SET(ALL_SSE41_MICROKERNEL_SRCS
src/f32-prelu/gen/sse41-2x4.c
@@ -3319,6 +3339,8 @@
src/qs8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c
src/qs8-vadd/gen/minmax-avx-mul32-ld32-x8.c
src/qs8-vaddc/gen/minmax-avx-mul32-ld32-x8.c
+ src/qs8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c
+ src/qs8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c
src/qu8-dwconv/gen/up16x9-minmax-fp32-avx-mul16.c
src/qu8-dwconv/gen/up16x25-minmax-fp32-avx-mul16.c
src/qu8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c
@@ -3326,7 +3348,9 @@
src/qu8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c
src/qu8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c
src/qu8-vadd/gen/minmax-avx-mul32-ld32-x8.c
- src/qu8-vaddc/gen/minmax-avx-mul32-ld32-x8.c)
+ src/qu8-vaddc/gen/minmax-avx-mul32-ld32-x8.c
+ src/qu8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c
+ src/qu8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c)
SET(ALL_AVX_MICROKERNEL_SRCS
src/f32-dwconv/gen/up8x4-minmax-avx-acc2.c
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 41831df..ca2c284 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -2162,6 +2162,29 @@
int8_t* output,
pthreadpool_t threadpool);
+// Creates a Multiply operator for signed 8-bit quantized (QS8) N-dimensional tensors.
+//
+// Both inputs and the output are each described by a zero point and a scale;
+// output_min/output_max bound the produced values. The implementation rejects
+// scales that are not positive normal floats, output_min >= output_max, and an
+// (input1_scale * input2_scale / output_scale) ratio outside [2**-16, 2**8).
+// The created operator is handed back through multiply_op_out.
+enum xnn_status xnn_create_multiply_nd_qs8(
+ int8_t input1_zero_point,
+ float input1_scale,
+ int8_t input2_zero_point,
+ float input2_scale,
+ int8_t output_zero_point,
+ float output_scale,
+ int8_t output_min,
+ int8_t output_max,
+ uint32_t flags,
+ xnn_operator_t* multiply_op_out);
+
+// Sets up a QS8 Multiply operator with the shapes and buffers for a run.
+//
+// Each input is described by a dimension count and a pointer to its dimension
+// sizes; input1/input2 point to the int8_t operands and output to the int8_t
+// destination. The unit tests exercise size-1 dimensions as broadcast
+// dimensions and pass a null thread pool.
+enum xnn_status xnn_setup_multiply_nd_qs8(
+ xnn_operator_t multiply_op,
+ size_t num_input1_dims,
+ const size_t* input1_shape,
+ size_t num_input2_dims,
+ const size_t* input2_shape,
+ const int8_t* input1,
+ const int8_t* input2,
+ int8_t* output,
+ pthreadpool_t threadpool);
+
#endif // XNN_NO_QS8_OPERATORS
#ifndef XNN_NO_QU8_OPERATORS
@@ -2364,6 +2387,29 @@
uint8_t* output,
pthreadpool_t threadpool);
+// Creates a Multiply operator for unsigned 8-bit quantized (QU8) N-dimensional tensors.
+//
+// Both inputs and the output are each described by a zero point and a scale;
+// output_min/output_max bound the produced values. The implementation rejects
+// scales that are not positive normal floats, output_min >= output_max, and an
+// (input1_scale * input2_scale / output_scale) ratio outside [2**-16, 2**8).
+// The created operator is handed back through multiply_op_out.
+enum xnn_status xnn_create_multiply_nd_qu8(
+ uint8_t input1_zero_point,
+ float input1_scale,
+ uint8_t input2_zero_point,
+ float input2_scale,
+ uint8_t output_zero_point,
+ float output_scale,
+ uint8_t output_min,
+ uint8_t output_max,
+ uint32_t flags,
+ xnn_operator_t* multiply_op_out);
+
+// Sets up a QU8 Multiply operator with the shapes and buffers for a run.
+//
+// Each input is described by a dimension count and a pointer to its dimension
+// sizes; input1/input2 point to the uint8_t operands and output to the uint8_t
+// destination. The unit tests exercise size-1 dimensions as broadcast
+// dimensions and pass a null thread pool.
+enum xnn_status xnn_setup_multiply_nd_qu8(
+ xnn_operator_t multiply_op,
+ size_t num_input1_dims,
+ const size_t* input1_shape,
+ size_t num_input2_dims,
+ const size_t* input2_shape,
+ const uint8_t* input1,
+ const uint8_t* input2,
+ uint8_t* output,
+ pthreadpool_t threadpool);
+
enum xnn_status xnn_create_sigmoid_nc_qu8(
size_t channels,
size_t input_stride,
diff --git a/src/init.c b/src/init.c
index d402b3c..5cebc03 100644
--- a/src/init.c
+++ b/src/init.c
@@ -53,6 +53,7 @@
#include <xnnpack/unpool.h>
#include <xnnpack/vadd.h>
#include <xnnpack/vbinary.h>
+#include <xnnpack/vmul.h>
#include <xnnpack/vmulcaddc.h>
#include <xnnpack/vunary.h>
#include <xnnpack/zip.h>
@@ -185,6 +186,23 @@
.init.qs8_add = xnn_init_qs8_add_minmax_neon_params,
.element_tile = 16,
};
+ if (cpuinfo_has_arm_neon_v8()) {
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_neonv8_params,
+ .element_tile = 16,
+ };
+ } else {
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__neon_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__neon_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__neon_ld64_x16,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_neon_params,
+ .element_tile = 16,
+ };
+ }
#endif // XNN_NO_QS8_OPERATORS
/*************************** QU8 micro-kernels ***************************/
@@ -226,6 +244,23 @@
.init.qu8_add = xnn_init_qu8_add_minmax_neon_params,
.element_tile = 8,
};
+ if (cpuinfo_has_arm_neon_v8()) {
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_neonv8_params,
+ .element_tile = 16,
+ };
+ } else {
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__neon_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__neon_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__neon_ld64_x16,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_neon_params,
+ .element_tile = 16,
+ };
+ }
#endif // XNN_NO_QU8_OPERATORS
/**************************** U8 micro-kernels ****************************/
@@ -1264,6 +1299,13 @@
.init.qs8_add = xnn_init_qs8_add_minmax_neon_params,
.element_tile = 32,
};
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_neonv8_params,
+ .element_tile = 16,
+ };
#endif // XNN_NO_QS8_OPERATORS
/**************************** QU8 micro-kernels ****************************/
@@ -1309,6 +1351,13 @@
.init.qu8_add = xnn_init_qu8_add_minmax_neon_params,
.element_tile = 8,
};
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__neonv8_ld64_x16,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_neonv8_params,
+ .element_tile = 16,
+ };
#endif // XNN_NO_QU8_OPERATORS
/**************************** U8 micro-kernels ****************************/
@@ -2164,6 +2213,31 @@
.element_tile = 8,
};
}
+ if (cpuinfo_has_x86_avx()) {
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__avx_mul16_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__avx_mul16_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__avx_mul16_ld64_x16,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_sse4_params,
+ .element_tile = 16,
+ };
+ } else if (cpuinfo_has_x86_sse4_1()) {
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__sse41_mul16_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__sse41_mul16_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__sse41_mul16_ld64_x16,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_sse4_params,
+ .element_tile = 16,
+ };
+ } else {
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__sse2_mul16_ld64_x8,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__sse2_mul16_ld64_x8,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__sse2_mul16_ld64_x8,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_sse2_params,
+ .element_tile = 8,
+ };
+ }
#endif // XNN_NO_QS8_OPERATORS
/**************************** QU8 micro-kernels ****************************/
@@ -2335,6 +2409,31 @@
.element_tile = 8,
};
}
+ if (cpuinfo_has_x86_avx()) {
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__avx_mul16_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__avx_mul16_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__avx_mul16_ld64_x16,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_sse2_params,
+ .element_tile = 16,
+ };
+ } else if (cpuinfo_has_x86_sse4_1()) {
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__sse41_mul16_ld64_x16,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__sse41_mul16_ld64_x16,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__sse41_mul16_ld64_x16,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_sse2_params,
+ .element_tile = 16,
+ };
+ } else {
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__sse2_mul16_ld64_x8,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__sse2_mul16_ld64_x8,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__sse2_mul16_ld64_x8,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_sse2_params,
+ .element_tile = 8,
+ };
+ }
#endif // XNN_NO_QU8_OPERATORS
/**************************** U8 micro-kernels ****************************/
@@ -2933,6 +3032,13 @@
.init.qs8_add = xnn_init_qs8_add_minmax_wasmsimd_params,
.element_tile = 8,
};
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__wasmsimd_mul32_ld64_x8,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__wasmsimd_mul32_ld64_x8,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__wasmsimd_mul32_ld64_x8,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_wasmsimd_params,
+ .element_tile = 8,
+ };
#endif // XNN_NO_QS8_OPERATORS
/**************************** QU8 micro-kernels ****************************/
@@ -2976,6 +3082,13 @@
.init.qu8_add = xnn_init_qu8_add_minmax_wasmsimd_params,
.element_tile = 8,
};
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__wasmsimd_mul32_ld64_x8,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__wasmsimd_mul32_ld64_x8,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__wasmsimd_mul32_ld64_x8,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_wasmsimd_params,
+ .element_tile = 8,
+ };
#endif // XNN_NO_QU8_OPERATORS
/**************************** U8 micro-kernels ****************************/
@@ -3525,6 +3638,13 @@
.init.qs8_add = xnn_init_qs8_add_minmax_scalar_params,
.element_tile = 4,
};
+ xnn_params.qs8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmul_minmax_fp32_ukernel__scalar_x4,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__scalar_x4,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qs8_vmulc_minmax_fp32_ukernel__scalar_x4,
+ .init.qs8_mul = xnn_init_qs8_mul_minmax_fp32_scalar_params,
+ .element_tile = 4,
+ };
#endif // XNN_NO_QS8_OPERATORS
/**************************** QU8 micro-kernels ****************************/
@@ -3577,6 +3697,13 @@
.init.qu8_add = xnn_init_qu8_add_minmax_scalar_params,
.element_tile = 4,
};
+ xnn_params.qu8.vmul = (struct vbinary_parameters) {
+ .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmul_minmax_fp32_ukernel__scalar_x4,
+ .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__scalar_x4,
+ .minmax.ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_qu8_vmulc_minmax_fp32_ukernel__scalar_x4,
+ .init.qu8_mul = xnn_init_qu8_mul_minmax_fp32_scalar_params,
+ .element_tile = 4,
+ };
#endif // XNN_NO_QU8_OPERATORS
/**************************** U8 micro-kernels ****************************/
diff --git a/src/operator-strings.c b/src/operator-strings.c
index 27f3463..4f706d4 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -114,6 +114,10 @@
return "Multiply (ND, F16)";
case xnn_operator_type_multiply_nd_f32:
return "Multiply (ND, F32)";
+ case xnn_operator_type_multiply_nd_qs8:
+ return "Multiply (ND, QS8)";
+ case xnn_operator_type_multiply_nd_qu8:
+ return "Multiply (ND, QU8)";
case xnn_operator_type_negate_nc_f32:
return "Negate (NC, F32)";
case xnn_operator_type_prelu_nc_f32:
diff --git a/src/operators/binary-elementwise-nd.c b/src/operators/binary-elementwise-nd.c
index 4c5364d..888ed9d 100644
--- a/src/operators/binary-elementwise-nd.c
+++ b/src/operators/binary-elementwise-nd.c
@@ -387,6 +387,144 @@
minimum_op_out);
}
+// Creates a QS8 (signed 8-bit quantized) N-dimensional Multiply operator.
+//
+// Parameters:
+//   input1_zero_point, input1_scale - zero point and scale of the first input.
+//   input2_zero_point, input2_scale - zero point and scale of the second input.
+//   output_zero_point, output_scale - zero point and scale of the output.
+//   output_min, output_max - output bounds; output_min must be strictly below output_max.
+//   flags - forwarded unchanged to create_binary_elementwise_nd.
+//   multiply_op_out - forwarded to create_binary_elementwise_nd, which receives the new operator.
+//
+// Returns:
+//   xnn_status_invalid_parameter if any scale is not a positive normal float,
+//   or if output_min is not strictly below output_max;
+//   xnn_status_unsupported_parameter if input1_scale * input2_scale / output_scale
+//   lies outside [2**-16, 2**8);
+//   otherwise whatever create_binary_elementwise_nd returns.
+enum xnn_status xnn_create_multiply_nd_qs8(
+    int8_t input1_zero_point,
+    float input1_scale,
+    int8_t input2_zero_point,
+    float input2_scale,
+    int8_t output_zero_point,
+    float output_scale,
+    int8_t output_min,
+    int8_t output_max,
+    uint32_t flags,
+    xnn_operator_t* multiply_op_out)
+{
+  // isnormal() is false for zero, subnormal, infinite and NaN values, so
+  // combined with the sign test only positive normal scales get through.
+  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input1_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input2_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (output_scale <= 0.0f || !isnormal(output_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_min, output_max);
+    return xnn_status_invalid_parameter;
+  }
+
+  // Ratio handed to the micro-kernel parameter initializer; only the
+  // [2**-16, 2**8) range is accepted.
+  const float product_scale = input1_scale * input2_scale;
+  const float product_output_scale = product_scale / output_scale;
+  if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), product_output_scale);
+    return xnn_status_unsupported_parameter;
+  }
+
+  // Two parameter sets are prepared: qs8_mul for the (input1, input2) operand
+  // order and qs8_rmul with the input zero points swapped (presumably consumed
+  // by the reversed-operand kernel; confirm in setup_binary_elementwise_nd).
+  struct {
+    union xnn_qs8_mul_minmax_params qs8_mul;
+    union xnn_qs8_mul_minmax_params qs8_rmul;
+  } params;
+  xnn_params.qs8.vmul.init.qs8_mul(
+    &params.qs8_mul, input1_zero_point, input2_zero_point, output_zero_point,
+    product_output_scale, output_min, output_max);
+  xnn_params.qs8.vmul.init.qs8_mul(
+    &params.qs8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
+    product_output_scale, output_min, output_max);
+  return create_binary_elementwise_nd(
+    flags,
+    &params,
+    sizeof(params),
+    XNN_INIT_FLAG_QS8,
+    xnn_operator_type_multiply_nd_qs8,
+    &xnn_params.qs8.vmul.minmax,
+    multiply_op_out);
+}
+
+// Creates a QU8 (unsigned 8-bit quantized) N-dimensional Multiply operator.
+//
+// Parameters:
+//   input1_zero_point, input1_scale - zero point and scale of the first input.
+//   input2_zero_point, input2_scale - zero point and scale of the second input.
+//   output_zero_point, output_scale - zero point and scale of the output.
+//   output_min, output_max - output bounds; output_min must be strictly below output_max.
+//   flags - forwarded unchanged to create_binary_elementwise_nd.
+//   multiply_op_out - forwarded to create_binary_elementwise_nd, which receives the new operator.
+//
+// Returns:
+//   xnn_status_invalid_parameter if any scale is not a positive normal float,
+//   or if output_min is not strictly below output_max;
+//   xnn_status_unsupported_parameter if input1_scale * input2_scale / output_scale
+//   lies outside [2**-16, 2**8);
+//   otherwise whatever create_binary_elementwise_nd returns.
+enum xnn_status xnn_create_multiply_nd_qu8(
+    uint8_t input1_zero_point,
+    float input1_scale,
+    uint8_t input2_zero_point,
+    float input2_scale,
+    uint8_t output_zero_point,
+    float output_scale,
+    uint8_t output_min,
+    uint8_t output_max,
+    uint32_t flags,
+    xnn_operator_t* multiply_op_out)
+{
+  // isnormal() is false for zero, subnormal, infinite and NaN values, so
+  // combined with the sign test only positive normal scales get through.
+  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input1_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input2_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (output_scale <= 0.0f || !isnormal(output_scale)) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_scale);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_min, output_max);
+    return xnn_status_invalid_parameter;
+  }
+
+  // Ratio handed to the micro-kernel parameter initializer; only the
+  // [2**-16, 2**8) range is accepted.
+  const float product_scale = input1_scale * input2_scale;
+  const float product_output_scale = product_scale / output_scale;
+  if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
+    xnn_log_error(
+      "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
+      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), product_output_scale);
+    return xnn_status_unsupported_parameter;
+  }
+
+  // Two parameter sets are prepared: qu8_mul for the (input1, input2) operand
+  // order and qu8_rmul with the input zero points swapped (presumably consumed
+  // by the reversed-operand kernel; confirm in setup_binary_elementwise_nd).
+  struct {
+    union xnn_qu8_mul_minmax_params qu8_mul;
+    union xnn_qu8_mul_minmax_params qu8_rmul;
+  } params;
+  xnn_params.qu8.vmul.init.qu8_mul(
+    &params.qu8_mul, input1_zero_point, input2_zero_point, output_zero_point,
+    product_output_scale, output_min, output_max);
+  xnn_params.qu8.vmul.init.qu8_mul(
+    &params.qu8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
+    product_output_scale, output_min, output_max);
+  return create_binary_elementwise_nd(
+    flags,
+    &params,
+    sizeof(params),
+    XNN_INIT_FLAG_QU8,
+    xnn_operator_type_multiply_nd_qu8,
+    &xnn_params.qu8.vmul.minmax,
+    multiply_op_out);
+}
+
enum xnn_status xnn_create_multiply_nd_f16(
float output_min,
float output_max,
@@ -841,6 +979,54 @@
pthreadpool_get_threads_count(threadpool));
}
+// Sets up a QS8 Multiply operator by delegating to the shared
+// binary-elementwise setup routine with the QS8 multiply parameters,
+// reversed-operand parameters and micro-kernel table.
+enum xnn_status xnn_setup_multiply_nd_qs8(
+    xnn_operator_t multiply_op,
+    size_t num_input1_dims,
+    const size_t* input1_shape,
+    size_t num_input2_dims,
+    const size_t* input2_shape,
+    const int8_t* input1,
+    const int8_t* input2,
+    int8_t* output,
+    pthreadpool_t threadpool)
+{
+  // Resolve the worker count first so the delegating call reads top to bottom.
+  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
+
+  return setup_binary_elementwise_nd(
+    multiply_op,
+    xnn_operator_type_multiply_nd_qs8,
+    num_input1_dims,
+    input1_shape,
+    num_input2_dims,
+    input2_shape,
+    input1,
+    input2,
+    output,
+    XNN_INIT_FLAG_QS8,
+    0 /* log2(sizeof(int8_t)) */,
+    &multiply_op->params.qs8_mul,
+    sizeof(multiply_op->params.qs8_mul),
+    &multiply_op->params.qs8_rmul,
+    sizeof(multiply_op->params.qs8_rmul),
+    &xnn_params.qs8.vmul,
+    num_threads);
+}
+
+// Sets up a QU8 Multiply operator by delegating to the shared
+// binary-elementwise setup routine with the QU8 multiply parameters,
+// reversed-operand parameters and micro-kernel table.
+enum xnn_status xnn_setup_multiply_nd_qu8(
+    xnn_operator_t multiply_op,
+    size_t num_input1_dims,
+    const size_t* input1_shape,
+    size_t num_input2_dims,
+    const size_t* input2_shape,
+    const uint8_t* input1,
+    const uint8_t* input2,
+    uint8_t* output,
+    pthreadpool_t threadpool)
+{
+  // Resolve the worker count first so the delegating call reads top to bottom.
+  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
+
+  return setup_binary_elementwise_nd(
+    multiply_op,
+    xnn_operator_type_multiply_nd_qu8,
+    num_input1_dims,
+    input1_shape,
+    num_input2_dims,
+    input2_shape,
+    input1,
+    input2,
+    output,
+    XNN_INIT_FLAG_QU8,
+    0 /* log2(sizeof(uint8_t)) */,
+    &multiply_op->params.qu8_mul,
+    sizeof(multiply_op->params.qu8_mul),
+    &multiply_op->params.qu8_rmul,
+    sizeof(multiply_op->params.qu8_rmul),
+    &xnn_params.qu8.vmul,
+    num_threads);
+}
+
enum xnn_status xnn_setup_multiply_nd_f16(
xnn_operator_t multiply_op,
size_t num_input1_dims,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index afe0589..8a08e28 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -80,6 +80,8 @@
xnn_operator_type_minimum_nd_f32,
xnn_operator_type_multiply_nd_f16,
xnn_operator_type_multiply_nd_f32,
+ xnn_operator_type_multiply_nd_qs8,
+ xnn_operator_type_multiply_nd_qu8,
xnn_operator_type_negate_nc_f32,
xnn_operator_type_prelu_nc_f32,
xnn_operator_type_resize_bilinear_nchw_f32,
@@ -292,9 +294,17 @@
union xnn_qs8_add_minmax_params qs8_radd;
};
struct {
+ union xnn_qs8_mul_minmax_params qs8_mul;
+ union xnn_qs8_mul_minmax_params qs8_rmul;
+ };
+ struct {
union xnn_qu8_add_minmax_params qu8_add;
union xnn_qu8_add_minmax_params qu8_radd;
};
+ struct {
+ union xnn_qu8_mul_minmax_params qu8_mul;
+ union xnn_qu8_mul_minmax_params qu8_rmul;
+ };
union xnn_qu8_conv_minmax_params qu8_conv_minmax;
// Average Pooling normally use qu8_avgpool_params, but also initialize qu8_gavgpool_params in case it needs to switch
// to Global Average Pooling operation.
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 40b4995..0314b7f 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -2331,7 +2331,9 @@
struct vbinary_fused_ukernels linear;
union {
xnn_init_qs8_add_minmax_params_fn qs8_add;
+ xnn_init_qs8_mul_minmax_params_fn qs8_mul;
xnn_init_qu8_add_minmax_params_fn qu8_add;
+ xnn_init_qu8_mul_minmax_params_fn qu8_mul;
} init;
// Number of elements in a tile.
// For best efficiency, micro-kernel must process a multiple of this number of elements in each call.
@@ -2547,6 +2549,7 @@
struct dwconv_parameters dwconv[XNN_MAX_QS8_DWCONV_UKERNELS];
struct gavgpool_parameters gavgpool;
struct vbinary_parameters vadd;
+ struct vbinary_parameters vmul;
} qs8;
struct {
struct gemm_parameters gemm;
@@ -2554,6 +2557,7 @@
struct avgpool_parameters avgpool;
struct gavgpool_parameters gavgpool;
struct vbinary_parameters vadd;
+ struct vbinary_parameters vmul;
} qu8;
struct {
struct maxpool_parameters maxpool;
diff --git a/test/binary-elementwise-operator-tester.h b/test/binary-elementwise-operator-tester.h
index 576aab7..8e02bcf 100644
--- a/test/binary-elementwise-operator-tester.h
+++ b/test/binary-elementwise-operator-tester.h
@@ -287,6 +287,14 @@
int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
0, &binary_elementwise_op);
break;
+ case OperationType::Multiply:
+ status = xnn_create_multiply_nd_qs8(
+ input1_zero_point(), input1_scale(),
+ input2_zero_point(), input2_scale(),
+ output_zero_point(), output_scale(),
+ int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
+ 0, &binary_elementwise_op);
+ break;
default:
FAIL() << "Unsupported operation type";
}
@@ -311,6 +319,17 @@
input1.data(), input2.data(), output.data(),
nullptr /* thread pool */));
break;
+ case OperationType::Multiply:
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_multiply_nd_qs8(
+ binary_elementwise_op,
+ num_input1_dims(),
+ input1_shape().data(),
+ num_input2_dims(),
+ input2_shape().data(),
+ input1.data(), input2.data(), output.data(),
+ nullptr /* thread pool */));
+ break;
default:
FAIL() << "Unsupported operation type";
}
@@ -431,6 +450,14 @@
qmin(), qmax(),
0, &binary_elementwise_op);
break;
+ case OperationType::Multiply:
+ status = xnn_create_multiply_nd_qu8(
+ input1_zero_point(), input1_scale(),
+ input2_zero_point(), input2_scale(),
+ output_zero_point(), output_scale(),
+ qmin(), qmax(),
+ 0, &binary_elementwise_op);
+ break;
default:
FAIL() << "Unsupported operation type";
}
@@ -455,6 +482,17 @@
input1.data(), input2.data(), output.data(),
nullptr /* thread pool */));
break;
+ case OperationType::Multiply:
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_multiply_nd_qu8(
+ binary_elementwise_op,
+ num_input1_dims(),
+ input1_shape().data(),
+ num_input2_dims(),
+ input2_shape().data(),
+ input1.data(), input2.data(), output.data(),
+ nullptr /* thread pool */));
+ break;
default:
FAIL() << "Unsupported operation type";
}
diff --git a/test/multiply-nd.cc b/test/multiply-nd.cc
index ac92b13..31fb7f3 100644
--- a/test/multiply-nd.cc
+++ b/test/multiply-nd.cc
@@ -15,6 +15,2654 @@
constexpr size_t kDim6 = 7;
+TEST(MULTIPLY_ND_QS8, 0d_x_0d) {
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .TestQS8();
+}
+
+TEST(MULTIPLY_ND_QS8, 1d_x_0d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 1); bm1++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 0d_x_1d) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 1); bm2++) {
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input2_shape({input2_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 1d_x_1d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 1); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 1); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim1})
+ .input2_shape({input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 0d_x_2d) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 2); bm2++) {
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input2_shape({input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 1d_x_2d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 1); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 2); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim1})
+ .input2_shape({input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 2d_x_0d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 2); bm1++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim2, input1_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 2d_x_1d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 2); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 1); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim2, input1_dim1})
+ .input2_shape({input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 2d_x_2d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 2); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 2); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim2, input1_dim1})
+ .input2_shape({input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 0d_x_3d) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 3); bm2++) {
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input2_shape({input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 1d_x_3d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 1); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 3); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim1})
+ .input2_shape({input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 2d_x_3d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 2); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 3); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim2, input1_dim1})
+ .input2_shape({input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 3d_x_0d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 3); bm1++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim3, input1_dim2, input1_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 3d_x_1d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 3); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 1); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 3d_x_2d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 3); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 2); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 3d_x_3d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 3); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 3); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 0d_x_4d) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 4); bm2++) {
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 1d_x_4d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 1); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 4); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim1})
+ .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 2d_x_4d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 2); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 4); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim2, input1_dim1})
+ .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 3d_x_4d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 3); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 4); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 4d_x_0d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 4d_x_1d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 1); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 4d_x_2d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 2); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 4d_x_3d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 3); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 4d_x_4d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 4); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 0d_x_5d) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 5); bm2++) {
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim5 = bm2 & (uint32_t(1) << 4);
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim5 = input2_broadcast_dim5 ? 1 : kDim5;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input2_shape({input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 1d_x_5d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 1); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 5); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim5 = bm2 & (uint32_t(1) << 4);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim5 = input2_broadcast_dim5 ? 1 : kDim5;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim1})
+ .input2_shape({input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 2d_x_5d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 2); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 5); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim5 = bm2 & (uint32_t(1) << 4);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim5 = input2_broadcast_dim5 ? 1 : kDim5;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim2, input1_dim1})
+ .input2_shape({input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 3d_x_5d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 3); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 5); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim5 = bm2 & (uint32_t(1) << 4);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim5 = input2_broadcast_dim5 ? 1 : kDim5;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 4d_x_5d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 5); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+ const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+ const bool input2_broadcast_dim5 = bm2 & (uint32_t(1) << 4);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+ const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+ const size_t input2_dim5 = input2_broadcast_dim5 ? 1 : kDim5;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 5d_x_0d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 5); bm1++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input1_broadcast_dim5 = bm1 & (uint32_t(1) << 4);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input1_dim5 = input1_broadcast_dim5 ? 1 : kDim5;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .TestQS8();
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 5d_x_1d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 5); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 1); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input1_broadcast_dim5 = bm1 & (uint32_t(1) << 4);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input1_dim5 = input1_broadcast_dim5 ? 1 : kDim5;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 5d_x_2d) {
+ for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 5); bm1++) {
+ for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 2); bm2++) {
+ const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+ const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+ const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+ const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+ const bool input1_broadcast_dim5 = bm1 & (uint32_t(1) << 4);
+ const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+ const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+ const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+ const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+ const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+ const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+ const size_t input1_dim5 = input1_broadcast_dim5 ? 1 : kDim5;
+ const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+ const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+ BinaryElementwiseOperatorTester()
+ .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+ .input1_shape({input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+ .input2_shape({input2_dim2, input2_dim1})
+ .TestQS8();
+ }
+ }
+}
+
+TEST(MULTIPLY_ND_QS8, 5d_x_3d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 5); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 3); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 5d_x_4d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 5); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 4); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d4, in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 5d_x_5d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 5); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 5); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      // iterations(1): presumably keeps the 32x32 mask sweep affordable
+      // (NOTE(review): confirm against the tester's default iteration count).
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 0d_x_6d) {
+  // Only input2 gets an explicit shape; input1's shape is left at whatever the
+  // tester defaults to. Bit (k-1) of the mask set collapses input2 dim k to 1.
+  for (uint32_t m2 = 0; m2 < (uint32_t(1) << 6); m2++) {
+    const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+    const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+    const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+    const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+    const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+    const size_t in2_d6 = ((m2 >> 5) & 1u) ? 1 : kDim6;
+    BinaryElementwiseOperatorTester()
+      .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+      .input2_shape({in2_d6, in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+      .TestQS8();
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 1d_x_6d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 1); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 6); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d6 = ((m2 >> 5) & 1u) ? 1 : kDim6;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d1})
+        .input2_shape({in2_d6, in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 2d_x_6d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 2); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 6); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d6 = ((m2 >> 5) & 1u) ? 1 : kDim6;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d2, in1_d1})
+        .input2_shape({in2_d6, in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 3d_x_6d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 3); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 6); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d6 = ((m2 >> 5) & 1u) ? 1 : kDim6;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d6, in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 4d_x_6d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 4); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 6); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d6 = ((m2 >> 5) & 1u) ? 1 : kDim6;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d6, in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 5d_x_6d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 5); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 6); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d6 = ((m2 >> 5) & 1u) ? 1 : kDim6;
+      // iterations(1): presumably keeps the 32x64 mask sweep affordable
+      // (NOTE(review): confirm against the tester's default iteration count).
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d6, in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 6d_x_0d) {
+  // Only input1 gets an explicit shape; input2's shape is left at whatever the
+  // tester defaults to. Bit (k-1) of the mask set collapses input1 dim k to 1.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 6); m1++) {
+    const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+    const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+    const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+    const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+    const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+    const size_t in1_d6 = ((m1 >> 5) & 1u) ? 1 : kDim6;
+    BinaryElementwiseOperatorTester()
+      .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+      .input1_shape({in1_d6, in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+      .TestQS8();
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 6d_x_1d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 6); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 1); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in1_d6 = ((m1 >> 5) & 1u) ? 1 : kDim6;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d6, in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 6d_x_2d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 6); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 2); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in1_d6 = ((m1 >> 5) & 1u) ? 1 : kDim6;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d6, in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 6d_x_3d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 6); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 3); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in1_d6 = ((m1 >> 5) & 1u) ? 1 : kDim6;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d6, in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 6d_x_4d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 6); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 4); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in1_d6 = ((m1 >> 5) & 1u) ? 1 : kDim6;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      // Shapes are listed from the highest-numbered dimension down to dim1.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d6, in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d4, in2_d3, in2_d2, in2_d1})
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 6d_x_5d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 6); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 5); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in1_d6 = ((m1 >> 5) & 1u) ? 1 : kDim6;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      // iterations(1): presumably keeps the 64x32 mask sweep affordable
+      // (NOTE(review): confirm against the tester's default iteration count).
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d6, in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, 6d_x_6d) {
+  // Sweep every broadcast pattern: bit (k-1) of a mask set means dimension k
+  // of that input is collapsed to 1, otherwise it keeps kDim<k>.
+  for (uint32_t m1 = 0; m1 < (uint32_t(1) << 6); m1++) {
+    for (uint32_t m2 = 0; m2 < (uint32_t(1) << 6); m2++) {
+      const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in1_d5 = ((m1 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in1_d6 = ((m1 >> 5) & 1u) ? 1 : kDim6;
+      const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+      const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+      const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+      const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+      const size_t in2_d5 = ((m2 >> 4) & 1u) ? 1 : kDim5;
+      const size_t in2_d6 = ((m2 >> 5) & 1u) ? 1 : kDim6;
+      // iterations(1): presumably keeps the 64x64 mask sweep affordable
+      // (NOTE(review): confirm against the tester's default iteration count).
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({in1_d6, in1_d5, in1_d4, in1_d3, in1_d2, in1_d1})
+        .input2_shape({in2_d6, in2_d5, in2_d4, in2_d3, in2_d2, in2_d1})
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, input1_scale) {
+  // Vary the tester's input1_scale geometrically (start 0.1, x3.14 per step,
+  // while <= 10) over every 4-D x 4-D broadcast pattern.
+  for (float scale = 0.1f; scale <= 10.0f; scale *= 3.14f) {
+    for (uint32_t m1 = 0; m1 < (uint32_t(1) << 4); m1++) {
+      for (uint32_t m2 = 0; m2 < (uint32_t(1) << 4); m2++) {
+        // Bit (k-1) of a mask set => dimension k of that input becomes 1.
+        const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+        const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+        const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+        const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+        const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+        const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+        const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+        const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input1_scale(scale)
+          .input1_shape({in1_d4, in1_d3, in1_d2, in1_d1})
+          .input2_shape({in2_d4, in2_d3, in2_d2, in2_d1})
+          .TestQS8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, input1_zero_point) {
+  // Vary the tester's input1_zero_point from -128 to 127 in steps of 51 over
+  // every 4-D x 4-D broadcast pattern.
+  for (int32_t zero_point = -128; zero_point <= 127; zero_point += 51) {
+    for (uint32_t m1 = 0; m1 < (uint32_t(1) << 4); m1++) {
+      for (uint32_t m2 = 0; m2 < (uint32_t(1) << 4); m2++) {
+        // Bit (k-1) of a mask set => dimension k of that input becomes 1.
+        const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+        const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+        const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+        const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+        const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+        const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+        const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+        const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input1_zero_point(zero_point)
+          .input1_shape({in1_d4, in1_d3, in1_d2, in1_d1})
+          .input2_shape({in2_d4, in2_d3, in2_d2, in2_d1})
+          .TestQS8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, input2_scale) {
+  // Vary the second input's scale geometrically (start 0.1, x3.14 per step,
+  // while <= 10) over every 4-D x 4-D broadcast pattern.
+  for (float input2_scale = 0.1f; input2_scale <= 10.0f; input2_scale *= 3.14f) {
+    for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+      for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 4); bm2++) {
+        // Bit (k-1) of a mask set => dimension k of that input is broadcast (size 1).
+        const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+        const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+        const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+        const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+        const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+        const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+        const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+        const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+        const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+        const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+        const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+        const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+        const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+        const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+        const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+        const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+        // The swept value must go to input2_scale; passing it to input1_scale
+        // would leave the second input's scale untested.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input2_scale(input2_scale)
+          .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+          .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+          .TestQS8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, input2_zero_point) {
+  // Vary the tester's input2_zero_point from -128 to 127 in steps of 51 over
+  // every 4-D x 4-D broadcast pattern.
+  for (int32_t zero_point = -128; zero_point <= 127; zero_point += 51) {
+    for (uint32_t m1 = 0; m1 < (uint32_t(1) << 4); m1++) {
+      for (uint32_t m2 = 0; m2 < (uint32_t(1) << 4); m2++) {
+        // Bit (k-1) of a mask set => dimension k of that input becomes 1.
+        const size_t in1_d1 = ((m1 >> 0) & 1u) ? 1 : kDim1;
+        const size_t in1_d2 = ((m1 >> 1) & 1u) ? 1 : kDim2;
+        const size_t in1_d3 = ((m1 >> 2) & 1u) ? 1 : kDim3;
+        const size_t in1_d4 = ((m1 >> 3) & 1u) ? 1 : kDim4;
+        const size_t in2_d1 = ((m2 >> 0) & 1u) ? 1 : kDim1;
+        const size_t in2_d2 = ((m2 >> 1) & 1u) ? 1 : kDim2;
+        const size_t in2_d3 = ((m2 >> 2) & 1u) ? 1 : kDim3;
+        const size_t in2_d4 = ((m2 >> 3) & 1u) ? 1 : kDim4;
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input2_zero_point(zero_point)
+          .input1_shape({in1_d4, in1_d3, in1_d2, in1_d1})
+          .input2_shape({in2_d4, in2_d3, in2_d2, in2_d1})
+          .TestQS8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, output_scale) {
+  // Vary the output scale geometrically (start 0.1, x3.14 per step,
+  // while <= 10) over every 4-D x 4-D broadcast pattern.
+  for (float output_scale = 0.1f; output_scale <= 10.0f; output_scale *= 3.14f) {
+    for (uint32_t bm1 = 0; bm1 < (uint32_t(1) << 4); bm1++) {
+      for (uint32_t bm2 = 0; bm2 < (uint32_t(1) << 4); bm2++) {
+        // Bit (k-1) of a mask set => dimension k of that input is broadcast (size 1).
+        const bool input1_broadcast_dim1 = bm1 & (uint32_t(1) << 0);
+        const bool input1_broadcast_dim2 = bm1 & (uint32_t(1) << 1);
+        const bool input1_broadcast_dim3 = bm1 & (uint32_t(1) << 2);
+        const bool input1_broadcast_dim4 = bm1 & (uint32_t(1) << 3);
+        const bool input2_broadcast_dim1 = bm2 & (uint32_t(1) << 0);
+        const bool input2_broadcast_dim2 = bm2 & (uint32_t(1) << 1);
+        const bool input2_broadcast_dim3 = bm2 & (uint32_t(1) << 2);
+        const bool input2_broadcast_dim4 = bm2 & (uint32_t(1) << 3);
+        const size_t input1_dim1 = input1_broadcast_dim1 ? 1 : kDim1;
+        const size_t input1_dim2 = input1_broadcast_dim2 ? 1 : kDim2;
+        const size_t input1_dim3 = input1_broadcast_dim3 ? 1 : kDim3;
+        const size_t input1_dim4 = input1_broadcast_dim4 ? 1 : kDim4;
+        const size_t input2_dim1 = input2_broadcast_dim1 ? 1 : kDim1;
+        const size_t input2_dim2 = input2_broadcast_dim2 ? 1 : kDim2;
+        const size_t input2_dim3 = input2_broadcast_dim3 ? 1 : kDim3;
+        const size_t input2_dim4 = input2_broadcast_dim4 ? 1 : kDim4;
+        // The swept value must go to output_scale; passing it to input1_scale
+        // would leave the output scale untested.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .output_scale(output_scale)
+          .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+          .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+          .TestQS8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QS8, output_zero_point) {  // runs TestQS8() for every 4-D x 4-D broadcast pattern at several output zero points
+  for (int32_t zp = -128; zp <= 127; zp += 51) {  // -128, -77, -26, 25, 76, 127
+    for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 4); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+      for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 4); ++rhs_mask) {  // same encoding for input2
+        const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+        const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+        const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+        const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+        const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+        const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+        const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+        const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+        const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+        const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+        const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+        const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+        const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+        const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+        const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+        const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+        using Tester = BinaryElementwiseOperatorTester;
+        Tester().operation_type(Tester::OperationType::Multiply)
+          .output_zero_point(zp)
+          .input1_shape({lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+          .input2_shape({rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+          .TestQS8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 0d_x_0d) {  // neither input shape is set before TestQU8() runs
+  using Tester = BinaryElementwiseOperatorTester;
+  Tester().operation_type(Tester::OperationType::Multiply)
+    .TestQU8();
+}
+
+TEST(MULTIPLY_ND_QU8, 1d_x_0d) {  // input1: 1-D, its dim is 1 or kDim1; input2 shape is not set
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 1); ++lhs_mask) {  // bit 0 set => input1 dim 1 is 1
+    const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+    const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input1_shape({lhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 0d_x_1d) {  // input1 shape is not set; input2: 1-D, its dim is 1 or kDim1
+  for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 1); ++rhs_mask) {  // bit 0 set => input2 dim 1 is 1
+    const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+    const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input2_shape({rhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 1d_x_1d) {  // input1: 1-D, input2: 1-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 1); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 1); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d1})
+        .input2_shape({rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 0d_x_2d) {  // input1 shape is not set; input2: 2-D, every dim is independently 1 or its kDimN
+  for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 2); ++rhs_mask) {  // bit j set => input2 dim j+1 is 1
+    const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+    const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+    const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+    const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input2_shape({rhs_d2, rhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 1d_x_2d) {  // input1: 1-D, input2: 2-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 1); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 2); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d1})
+        .input2_shape({rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 2d_x_0d) {  // input1: 2-D, every dim is independently 1 or its kDimN; input2 shape is not set
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 2); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+    const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+    const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+    const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input1_shape({lhs_d2, lhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 2d_x_1d) {  // input1: 2-D, input2: 1-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 2); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 1); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d2, lhs_d1})
+        .input2_shape({rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 2d_x_2d) {  // input1: 2-D, input2: 2-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 2); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 2); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d2, lhs_d1})
+        .input2_shape({rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 0d_x_3d) {  // input1 shape is not set; input2: 3-D, every dim is independently 1 or its kDimN
+  for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 3); ++rhs_mask) {  // bit j set => input2 dim j+1 is 1
+    const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+    const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+    const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+    const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+    const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+    const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input2_shape({rhs_d3, rhs_d2, rhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 1d_x_3d) {  // input1: 1-D, input2: 3-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 1); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 3); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d1})
+        .input2_shape({rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 2d_x_3d) {  // input1: 2-D, input2: 3-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 2); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 3); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d2, lhs_d1})
+        .input2_shape({rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 3d_x_0d) {  // input1: 3-D, every dim is independently 1 or its kDimN; input2 shape is not set
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 3); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+    const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+    const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+    const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+    const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+    const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input1_shape({lhs_d3, lhs_d2, lhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 3d_x_1d) {  // input1: 3-D, input2: 1-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 3); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 1); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 3d_x_2d) {  // input1: 3-D, input2: 2-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 3); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 2); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 3d_x_3d) {  // input1: 3-D, input2: 3-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 3); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 3); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 0d_x_4d) {  // input1 shape is not set; input2: 4-D, every dim is independently 1 or its kDimN
+  for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 4); ++rhs_mask) {  // bit j set => input2 dim j+1 is 1
+    const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+    const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+    const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+    const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+    const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+    const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+    const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+    const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input2_shape({rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 1d_x_4d) {  // input1: 1-D, input2: 4-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 1); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 4); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d1})
+        .input2_shape({rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 2d_x_4d) {  // input1: 2-D, input2: 4-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 2); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 4); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d2, lhs_d1})
+        .input2_shape({rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 3d_x_4d) {  // input1: 3-D, input2: 4-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 3); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 4); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 4d_x_0d) {  // input1: 4-D, every dim is independently 1 or its kDimN; input2 shape is not set
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 4); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+    const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+    const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+    const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+    const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+    const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+    const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+    const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input1_shape({lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 4d_x_1d) {  // input1: 4-D, input2: 1-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 4); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 1); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+      const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 4d_x_2d) {  // input1: 4-D, input2: 2-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 4); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 2); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+      const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 4d_x_3d) {  // input1: 4-D, input2: 3-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 4); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 3); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+      const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 4d_x_4d) {  // input1: 4-D, input2: 4-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 4); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 4); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+      const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 0d_x_5d) {  // input1 shape is not set; input2: 5-D, every dim is independently 1 or its kDimN
+  for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 5); ++rhs_mask) {  // bit j set => input2 dim j+1 is 1
+    const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+    const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+    const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+    const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+    const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+    const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+    const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+    const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+    const bool rhs_full_d5 = ((rhs_mask >> 4) & 1u) == 0;
+    const size_t rhs_d5 = rhs_full_d5 ? kDim5 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input2_shape({rhs_d5, rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 1d_x_5d) {  // input1: 1-D, input2: 5-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 1); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 5); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d5 = ((rhs_mask >> 4) & 1u) == 0;
+      const size_t rhs_d5 = rhs_full_d5 ? kDim5 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d1})
+        .input2_shape({rhs_d5, rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 2d_x_5d) {  // input1: 2-D, input2: 5-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 2); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 5); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d5 = ((rhs_mask >> 4) & 1u) == 0;
+      const size_t rhs_d5 = rhs_full_d5 ? kDim5 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d2, lhs_d1})
+        .input2_shape({rhs_d5, rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 3d_x_5d) {  // input1: 3-D, input2: 5-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 3); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 5); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d5 = ((rhs_mask >> 4) & 1u) == 0;
+      const size_t rhs_d5 = rhs_full_d5 ? kDim5 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d5, rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 4d_x_5d) {  // input1: 4-D, input2: 5-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 4); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 5); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+      const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      const bool rhs_full_d3 = ((rhs_mask >> 2) & 1u) == 0;
+      const size_t rhs_d3 = rhs_full_d3 ? kDim3 : 1;
+      const bool rhs_full_d4 = ((rhs_mask >> 3) & 1u) == 0;
+      const size_t rhs_d4 = rhs_full_d4 ? kDim4 : 1;
+      const bool rhs_full_d5 = ((rhs_mask >> 4) & 1u) == 0;
+      const size_t rhs_d5 = rhs_full_d5 ? kDim5 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d5, rhs_d4, rhs_d3, rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 5d_x_0d) {  // input1: 5-D, every dim is independently 1 or its kDimN; input2 shape is not set
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 5); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+    const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+    const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+    const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+    const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+    const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+    const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+    const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+    const bool lhs_full_d5 = ((lhs_mask >> 4) & 1u) == 0;
+    const size_t lhs_d5 = lhs_full_d5 ? kDim5 : 1;
+    using Tester = BinaryElementwiseOperatorTester;
+    Tester().operation_type(Tester::OperationType::Multiply)
+      .input1_shape({lhs_d5, lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 5d_x_1d) {  // input1: 5-D, input2: 1-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 5); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 1); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+      const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+      const bool lhs_full_d5 = ((lhs_mask >> 4) & 1u) == 0;
+      const size_t lhs_d5 = lhs_full_d5 ? kDim5 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d5, lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 5d_x_2d) {  // input1: 5-D, input2: 2-D; every dim is independently 1 or its kDimN
+  for (uint32_t lhs_mask = 0; lhs_mask != (uint32_t(1) << 5); ++lhs_mask) {  // bit j set => input1 dim j+1 is 1
+    for (uint32_t rhs_mask = 0; rhs_mask != (uint32_t(1) << 2); ++rhs_mask) {  // same encoding for input2
+      const bool lhs_full_d1 = ((lhs_mask >> 0) & 1u) == 0;
+      const size_t lhs_d1 = lhs_full_d1 ? kDim1 : 1;
+      const bool lhs_full_d2 = ((lhs_mask >> 1) & 1u) == 0;
+      const size_t lhs_d2 = lhs_full_d2 ? kDim2 : 1;
+      const bool lhs_full_d3 = ((lhs_mask >> 2) & 1u) == 0;
+      const size_t lhs_d3 = lhs_full_d3 ? kDim3 : 1;
+      const bool lhs_full_d4 = ((lhs_mask >> 3) & 1u) == 0;
+      const size_t lhs_d4 = lhs_full_d4 ? kDim4 : 1;
+      const bool lhs_full_d5 = ((lhs_mask >> 4) & 1u) == 0;
+      const size_t lhs_d5 = lhs_full_d5 ? kDim5 : 1;
+      const bool rhs_full_d1 = ((rhs_mask >> 0) & 1u) == 0;
+      const size_t rhs_d1 = rhs_full_d1 ? kDim1 : 1;
+      const bool rhs_full_d2 = ((rhs_mask >> 1) & 1u) == 0;
+      const size_t rhs_d2 = rhs_full_d2 ? kDim2 : 1;
+      using Tester = BinaryElementwiseOperatorTester;
+      Tester().operation_type(Tester::OperationType::Multiply)
+        .input1_shape({lhs_d5, lhs_d4, lhs_d3, lhs_d2, lhs_d1})
+        .input2_shape({rhs_d2, rhs_d1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 5d_x_3d) {
+  // Sweep every broadcast pattern of a 5-D input1 against a 3-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 5); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 3); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 5d_x_4d) {
+  // Sweep every broadcast pattern of a 5-D input1 against a 4-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 5); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 4); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 5d_x_5d) {
+  // Sweep every broadcast pattern of a 5-D input1 against a 5-D input2
+  // (32 x 32 = 1024 shape pairs). Bit k of a mask being set shrinks
+  // dimension k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 5); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 5); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      // iterations(1): presumably keeps runtime bounded across the many
+      // shape pairs -- confirm against the tester's default iteration count.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .iterations(1)
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 0d_x_6d) {
+  // Sweep every broadcast pattern of a 6-D input2. input1's shape is never
+  // set here, so it stays at the tester's default (the test name suggests a
+  // 0-D input -- confirm against the tester). Bit k of the mask being set
+  // shrinks dimension k+1 to 1; a clear bit keeps kDim<k+1>.
+  for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 6); mask2++) {
+    const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+    const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+    const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+    const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+    const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+    const size_t input2_dim6 = ((mask2 >> 5) & 1) != 0 ? 1 : kDim6;
+    // Shape is listed outermost dimension first.
+    BinaryElementwiseOperatorTester()
+      .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+      .input2_shape({input2_dim6, input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 1d_x_6d) {
+  // Sweep every broadcast pattern of a 1-D input1 against a 6-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 1); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 6); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim6 = ((mask2 >> 5) & 1) != 0 ? 1 : kDim6;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim1})
+        .input2_shape({input2_dim6, input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 2d_x_6d) {
+  // Sweep every broadcast pattern of a 2-D input1 against a 6-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 2); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 6); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim6 = ((mask2 >> 5) & 1) != 0 ? 1 : kDim6;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim2, input1_dim1})
+        .input2_shape({input2_dim6, input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 3d_x_6d) {
+  // Sweep every broadcast pattern of a 3-D input1 against a 6-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 3); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 6); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim6 = ((mask2 >> 5) & 1) != 0 ? 1 : kDim6;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim6, input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 4d_x_6d) {
+  // Sweep every broadcast pattern of a 4-D input1 against a 6-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 4); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 6); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim6 = ((mask2 >> 5) & 1) != 0 ? 1 : kDim6;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim6, input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 5d_x_6d) {
+  // Sweep every broadcast pattern of a 5-D input1 against a 6-D input2
+  // (32 x 64 = 2048 shape pairs). Bit k of a mask being set shrinks
+  // dimension k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 5); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 6); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim6 = ((mask2 >> 5) & 1) != 0 ? 1 : kDim6;
+      // iterations(1): presumably keeps runtime bounded across the many
+      // shape pairs -- confirm against the tester's default iteration count.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim6, input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .iterations(1)
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 6d_x_0d) {
+  // Sweep every broadcast pattern of a 6-D input1. input2's shape is never
+  // set here, so it stays at the tester's default (the test name suggests a
+  // 0-D input -- confirm against the tester). Bit k of the mask being set
+  // shrinks dimension k+1 to 1; a clear bit keeps kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 6); mask1++) {
+    const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+    const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+    const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+    const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+    const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+    const size_t input1_dim6 = ((mask1 >> 5) & 1) != 0 ? 1 : kDim6;
+    // Shape is listed outermost dimension first.
+    BinaryElementwiseOperatorTester()
+      .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+      .input1_shape({input1_dim6, input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+      .TestQU8();
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 6d_x_1d) {
+  // Sweep every broadcast pattern of a 6-D input1 against a 1-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 6); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 1); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input1_dim6 = ((mask1 >> 5) & 1) != 0 ? 1 : kDim6;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim6, input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 6d_x_2d) {
+  // Sweep every broadcast pattern of a 6-D input1 against a 2-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 6); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 2); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input1_dim6 = ((mask1 >> 5) & 1) != 0 ? 1 : kDim6;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim6, input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 6d_x_3d) {
+  // Sweep every broadcast pattern of a 6-D input1 against a 3-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 6); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 3); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input1_dim6 = ((mask1 >> 5) & 1) != 0 ? 1 : kDim6;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim6, input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 6d_x_4d) {
+  // Sweep every broadcast pattern of a 6-D input1 against a 4-D input2.
+  // Bit k of a mask being set shrinks dimension k+1 of that input to 1;
+  // a clear bit keeps the full extent kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 6); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 4); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input1_dim6 = ((mask1 >> 5) & 1) != 0 ? 1 : kDim6;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      // Shapes are listed outermost dimension first.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim6, input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 6d_x_5d) {
+  // Sweep every broadcast pattern of a 6-D input1 against a 5-D input2
+  // (64 x 32 = 2048 shape pairs). Bit k of a mask being set shrinks
+  // dimension k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 6); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 5); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input1_dim6 = ((mask1 >> 5) & 1) != 0 ? 1 : kDim6;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      // iterations(1): presumably keeps runtime bounded across the many
+      // shape pairs -- confirm against the tester's default iteration count.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim6, input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .iterations(1)
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, 6d_x_6d) {
+  // Sweep every broadcast pattern of a 6-D input1 against a 6-D input2
+  // (64 x 64 = 4096 shape pairs). Bit k of a mask being set shrinks
+  // dimension k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 6); mask1++) {
+    for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 6); mask2++) {
+      const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input1_dim5 = ((mask1 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input1_dim6 = ((mask1 >> 5) & 1) != 0 ? 1 : kDim6;
+      const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+      const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+      const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+      const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+      const size_t input2_dim5 = ((mask2 >> 4) & 1) != 0 ? 1 : kDim5;
+      const size_t input2_dim6 = ((mask2 >> 5) & 1) != 0 ? 1 : kDim6;
+      // iterations(1): presumably keeps runtime bounded across the many
+      // shape pairs -- confirm against the tester's default iteration count.
+      BinaryElementwiseOperatorTester()
+        .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+        .input1_shape({input1_dim6, input1_dim5, input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+        .input2_shape({input2_dim6, input2_dim5, input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+        .iterations(1)
+        .TestQU8();
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, input1_scale) {
+  // Vary input1's quantization scale geometrically (x3.14 per step, starting
+  // at 0.1 and stopping once it exceeds 10) and, for each value, sweep every
+  // 4-D x 4-D broadcast pattern. Bit k of a mask being set shrinks dimension
+  // k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (float input1_scale = 0.1f; input1_scale <= 10.0f; input1_scale *= 3.14f) {
+    for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 4); mask1++) {
+      for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 4); mask2++) {
+        const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+        const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+        // Shapes are listed outermost dimension first.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input1_scale(input1_scale)
+          .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+          .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+          .TestQU8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, input1_zero_point) {
+  // Step input1's zero point across 0, 51, ..., 255 and, for each value,
+  // sweep every 4-D x 4-D broadcast pattern. Bit k of a mask being set
+  // shrinks dimension k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (int32_t input1_zero_point = 0; input1_zero_point <= 255; input1_zero_point += 51) {
+    for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 4); mask1++) {
+      for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 4); mask2++) {
+        const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+        const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+        // Shapes are listed outermost dimension first.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input1_zero_point(input1_zero_point)
+          .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+          .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+          .TestQU8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, input2_scale) {
+  // Vary input2's quantization scale geometrically (x3.14 per step, starting
+  // at 0.1 and stopping once it exceeds 10) and, for each value, sweep every
+  // 4-D x 4-D broadcast pattern. Bit k of a mask being set shrinks dimension
+  // k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (float input2_scale = 0.1f; input2_scale <= 10.0f; input2_scale *= 3.14f) {
+    for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 4); mask1++) {
+      for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 4); mask2++) {
+        const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+        const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+        // The scale must go to input2: routing it to input1_scale() would
+        // only repeat the input1_scale test and leave input2 at its default.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input2_scale(input2_scale)
+          .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+          .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+          .TestQU8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, input2_zero_point) {
+  // Step input2's zero point across 0, 51, ..., 255 and, for each value,
+  // sweep every 4-D x 4-D broadcast pattern. Bit k of a mask being set
+  // shrinks dimension k+1 of that input to 1; a clear bit keeps kDim<k+1>.
+  for (int32_t input2_zero_point = 0; input2_zero_point <= 255; input2_zero_point += 51) {
+    for (uint32_t mask1 = 0; mask1 < (uint32_t(1) << 4); mask1++) {
+      for (uint32_t mask2 = 0; mask2 < (uint32_t(1) << 4); mask2++) {
+        const size_t input1_dim1 = ((mask1 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input1_dim2 = ((mask1 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input1_dim3 = ((mask1 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input1_dim4 = ((mask1 >> 3) & 1) != 0 ? 1 : kDim4;
+        const size_t input2_dim1 = ((mask2 >> 0) & 1) != 0 ? 1 : kDim1;
+        const size_t input2_dim2 = ((mask2 >> 1) & 1) != 0 ? 1 : kDim2;
+        const size_t input2_dim3 = ((mask2 >> 2) & 1) != 0 ? 1 : kDim3;
+        const size_t input2_dim4 = ((mask2 >> 3) & 1) != 0 ? 1 : kDim4;
+        // Shapes are listed outermost dimension first.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .input2_zero_point(input2_zero_point)
+          .input1_shape({input1_dim4, input1_dim3, input1_dim2, input1_dim1})
+          .input2_shape({input2_dim4, input2_dim3, input2_dim2, input2_dim1})
+          .TestQU8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, output_scale) {
+  // Runs the QU8 multiply tester over a geometric sweep of output scales (starting
+  // at 0.1f, multiplied by 3.14f while the value stays <= 10.0f) and, for every
+  // scale, walks all 16 x 16 pairings of broadcast masks for the two 4-D inputs.
+  const uint32_t mask_count = uint32_t(1) << 4;
+  for (float scale = 0.1f; scale <= 10.0f; scale *= 3.14f) {
+    for (uint32_t mask1 = 0; mask1 < mask_count; mask1++) {
+      for (uint32_t mask2 = 0; mask2 < mask_count; mask2++) {
+        // Bit (N - 1) of a mask selects extent 1 for dimension N of that input;
+        // a clear bit keeps the full extent kDimN.
+        const size_t lhs_dim1 = (mask1 & (uint32_t(1) << 0)) != 0 ? 1 : kDim1;
+        const size_t lhs_dim2 = (mask1 & (uint32_t(1) << 1)) != 0 ? 1 : kDim2;
+        const size_t lhs_dim3 = (mask1 & (uint32_t(1) << 2)) != 0 ? 1 : kDim3;
+        const size_t lhs_dim4 = (mask1 & (uint32_t(1) << 3)) != 0 ? 1 : kDim4;
+        const size_t rhs_dim1 = (mask2 & (uint32_t(1) << 0)) != 0 ? 1 : kDim1;
+        const size_t rhs_dim2 = (mask2 & (uint32_t(1) << 1)) != 0 ? 1 : kDim2;
+        const size_t rhs_dim3 = (mask2 & (uint32_t(1) << 2)) != 0 ? 1 : kDim3;
+        const size_t rhs_dim4 = (mask2 & (uint32_t(1) << 3)) != 0 ? 1 : kDim4;
+        // The swept value must go to the OUTPUT scale setter: feeding it to an
+        // input scale would leave the parameter named by this test unexercised.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .output_scale(scale)
+          .input1_shape({lhs_dim4, lhs_dim3, lhs_dim2, lhs_dim1})
+          .input2_shape({rhs_dim4, rhs_dim3, rhs_dim2, rhs_dim1})
+          .TestQU8();
+      }
+    }
+  }
+}
+
+TEST(MULTIPLY_ND_QU8, output_zero_point) {
+  // Runs the QU8 multiply tester with the output zero point set to each of
+  // 0, 51, 102, 153, 204 and 255, and for every such value walks all 16 x 16
+  // pairings of broadcast masks for the two four-dimensional inputs.
+  const uint32_t mask_count = uint32_t(1) << 4;
+  for (int32_t zero_point = 0; zero_point <= 255; zero_point += 51) {
+    for (uint32_t mask1 = 0; mask1 < mask_count; mask1++) {
+      for (uint32_t mask2 = 0; mask2 < mask_count; mask2++) {
+        // Bit (N - 1) of a mask selects extent 1 for dimension N of that input;
+        // a clear bit keeps the full extent kDimN.
+        const size_t lhs_dim1 = (mask1 & (uint32_t(1) << 0)) != 0 ? 1 : kDim1;
+        const size_t lhs_dim2 = (mask1 & (uint32_t(1) << 1)) != 0 ? 1 : kDim2;
+        const size_t lhs_dim3 = (mask1 & (uint32_t(1) << 2)) != 0 ? 1 : kDim3;
+        const size_t lhs_dim4 = (mask1 & (uint32_t(1) << 3)) != 0 ? 1 : kDim4;
+        const size_t rhs_dim1 = (mask2 & (uint32_t(1) << 0)) != 0 ? 1 : kDim1;
+        const size_t rhs_dim2 = (mask2 & (uint32_t(1) << 1)) != 0 ? 1 : kDim2;
+        const size_t rhs_dim3 = (mask2 & (uint32_t(1) << 2)) != 0 ? 1 : kDim3;
+        const size_t rhs_dim4 = (mask2 & (uint32_t(1) << 3)) != 0 ? 1 : kDim4;
+        // Shapes are handed to the tester with dimension 4 first and dimension 1
+        // last; only the output zero point is overridden here.
+        BinaryElementwiseOperatorTester()
+          .operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)
+          .output_zero_point(zero_point)
+          .input1_shape({lhs_dim4, lhs_dim3, lhs_dim2, lhs_dim1})
+          .input2_shape({rhs_dim4, rhs_dim3, rhs_dim2, rhs_dim1})
+          .TestQU8();
+      }
+    }
+  }
+}
+
TEST(MULTIPLY_ND_F16, 0d_x_0d) {
BinaryElementwiseOperatorTester()
.operation_type(BinaryElementwiseOperatorTester::OperationType::Multiply)