Optimized packing code path for row-major 8-bit inputs on the x86 paths.
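
This adds dedicated Pack8bitRowMajorForAvx2/Pack8bitRowMajorForAvx512
kernels so that row-major 8-bit sources no longer go through the generic
packing path. Each kernel walks the source in strips of 4 rows: it loads
8 (AVX2) or 16 (AVX-512) columns at a time, optionally xors the sign bit
to map uint8 to int8, accumulates the per-column sums consumed later by
the matmul kernels, and transposes 4x4 byte blocks so that every source
column ends up as 4 consecutive bytes in the packed buffer.

As a reference for what the SIMD loops compute, here is a minimal scalar
sketch of one 4-row strip (hypothetical helper name; it mirrors the
scalar tail loops in the kernels below, with src pointing at row
block_row, column 0):

    void PackStripScalarSketch(const std::uint8_t* src, int src_stride,
                               int src_zero_point, std::int8_t* packed,
                               int start_col, int end_col, int src_cols,
                               int block_row, int src_rows, int input_xor,
                               std::int32_t* sums) {
      for (int col = start_col; col < end_col; ++col) {
        for (int r = 0; r < 4; ++r) {
          std::int8_t packed_val = 0;  // columns past src_cols are zero-filled
          if (col < src_cols) {
            // Rows past src_rows are padded with the source zero point.
            const std::uint8_t v = (block_row + r < src_rows)
                                       ? src[r * src_stride + col]
                                       : src_zero_point;
            packed_val = input_xor ^ v;  // flips the sign bit when 0x80
            sums[col] += packed_val;
          }
          *packed++ = packed_val;  // each column becomes 4 consecutive bytes
        }
      }
    }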
PiperOrigin-RevId: 322925521
diff --git a/ruy/pack_avx2_fma.cc b/ruy/pack_avx2_fma.cc
index 601d464..55879a1 100644
--- a/ruy/pack_avx2_fma.cc
+++ b/ruy/pack_avx2_fma.cc
@@ -44,6 +44,11 @@
RUY_DCHECK(false);
}
+void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int,
+ int, int, int, int, int, int, std::int32_t*) {
+ RUY_DCHECK(false);
+}
+
#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
// The first int8_t template parameter is arbitrary: this routine is common to
@@ -811,6 +816,87 @@
}
}
+void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums) {
+ int col = start_col;
+ int src_end_col = std::min(end_col, src_cols);
+
+ for (; col <= src_end_col - 8; col += 8) {
+ std::int8_t* dst_ptr = packed_ptr;
+ __m128i val0, val1, val2, val3;
+ __m128i input_xor_dup = _mm_set1_epi8(input_xor);
+ // Load a 4x8 block.
+ if (block_row + 4 <= src_rows) {
+ val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+ val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+ val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+ val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+ } else {
+ val0 = _mm_set1_epi8(src_zero_point);
+ val1 = val0;
+ val2 = val0;
+ val3 = val0;
+ if (block_row + 0 < src_rows)
+ val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
+ if (block_row + 1 < src_rows)
+ val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
+ if (block_row + 2 < src_rows)
+ val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
+ if (block_row + 3 < src_rows)
+ val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
+ }
+ // Maybe xor the sign bit to convert from uint8 to int8.
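+    // (input_xor is 0x80 for uint8 sources, and v ^ 0x80 reinterprets
+    // v - 128 as int8; for int8 sources input_xor is 0, a no-op.)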
+ val0 = _mm_xor_si128(val0, input_xor_dup);
+ val1 = _mm_xor_si128(val1, input_xor_dup);
+ val2 = _mm_xor_si128(val2, input_xor_dup);
+ val3 = _mm_xor_si128(val3, input_xor_dup);
+    // Update the per-column sums. Widening to 16-bit and then 32-bit keeps
+    // the four byte additions per column from overflowing.
+ __m128i val16_0 = _mm_cvtepi8_epi16(val0);
+ __m128i val16_1 = _mm_cvtepi8_epi16(val1);
+ __m128i val16_2 = _mm_cvtepi8_epi16(val2);
+ __m128i val16_3 = _mm_cvtepi8_epi16(val3);
+ __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
+ _mm_add_epi16(val16_2, val16_3));
+ __m256i sum =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
+ sum = _mm256_add_epi32(sum, _mm256_cvtepi16_epi32(new_sum16));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
+    // Perform the transposition of 4x4 blocks.
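+    // Example: with rows a0..a7, b0..b7, c0..c7, d0..d7, unpacklo_epi8
+    // yields a0 b0 a1 b1 ... and c0 d0 c1 d1 ...; unpack(lo|hi)_epi16 then
+    // produces a0 b0 c0 d0 | a1 b1 c1 d1 | ..., i.e. one source column per
+    // 4 packed bytes.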
+ __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
+ __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
+ __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
+ __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
+ src_ptr += 8;
+ packed_ptr += packed_stride * 8;
+ }
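+  // Scalar handling of the leftover columns (fewer than 8 remain).
+  // Rows past src_rows are padded with the (xored) zero point.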
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ for (int r = 0; r < 4; r++) {
+ std::int8_t packed_val;
+ if (block_row + r < src_rows) {
+ packed_val = input_xor ^ src_ptr[r * src_stride];
+ } else {
+ packed_val = input_xor ^ src_zero_point;
+ }
+ accum += packed_val;
+ *packed_ptr++ = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ }
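+  // Zero-fill any padding columns between src_cols and end_col.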
+ for (; col < end_col; col++) {
+ std::memset(packed_ptr, 0, 4);
+ packed_ptr += 4;
+ }
+}
+
#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
} // namespace ruy
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
index bd04aca..b38bc01 100644
--- a/ruy/pack_avx512.cc
+++ b/ruy/pack_avx512.cc
@@ -44,6 +44,11 @@
RUY_DCHECK(false);
}
+void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
+ int, int, int, int, int, int, std::int32_t*) {
+ RUY_DCHECK(false);
+}
+
#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
// The first int8_t template parameter is arbitrary: this routine is common to
@@ -717,6 +722,107 @@
}
}
+void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums) {
+ int col = start_col;
+ int src_end_col = std::min(end_col, src_cols);
+
+ for (; col <= src_end_col - 16; col += 16) {
+ std::int8_t* dst_ptr = packed_ptr;
+ __m128i val0, val1, val2, val3;
+ __m128i input_xor_dup = _mm_set1_epi8(input_xor);
+ // Load a 4x16 block.
+ if (block_row + 4 <= src_rows) {
+ val0 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
+ val1 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
+ val2 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
+ val3 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
+ } else {
+ val0 = _mm_set1_epi8(src_zero_point);
+ val1 = val0;
+ val2 = val0;
+ val3 = val0;
+ if (block_row + 0 < src_rows)
+ val0 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
+ if (block_row + 1 < src_rows)
+ val1 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
+ if (block_row + 2 < src_rows)
+ val2 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
+ if (block_row + 3 < src_rows)
+ val3 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
+ }
+ // Maybe xor the sign bit to convert from uint8 to int8.
+ val0 = _mm_xor_si128(val0, input_xor_dup);
+ val1 = _mm_xor_si128(val1, input_xor_dup);
+ val2 = _mm_xor_si128(val2, input_xor_dup);
+ val3 = _mm_xor_si128(val3, input_xor_dup);
+    // Update the per-column sums, widening to 16-bit and then 32-bit to
+    // avoid overflow in the four byte additions per column.
+ __m256i val16_0 = _mm256_cvtepi8_epi16(val0);
+ __m256i val16_1 = _mm256_cvtepi8_epi16(val1);
+ __m256i val16_2 = _mm256_cvtepi8_epi16(val2);
+ __m256i val16_3 = _mm256_cvtepi8_epi16(val3);
+ __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
+ _mm256_add_epi16(val16_2, val16_3));
+ __m512i sum =
+ _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
+ sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
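+    // zip interleaves the 16 bytes of x and y. perm_64_0_64_0 places the two
+    // 64-bit halves of a 128-bit value at the bottom of the two 128-bit
+    // lanes, so a single 256-bit unpacklo_epi8 interleaves all byte pairs.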
+ auto zip = [](__m128i x, __m128i y) {
+ auto perm_64_0_64_0 = [](__m128i x) {
+ return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
+ _mm256_castsi128_si256(x));
+ };
+ return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
+ };
+ __m256i t2_val0 = zip(val0, val1);
+ __m256i t2_val1 = zip(val2, val3);
+ __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
+ __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
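+    // Lane 0 of t4_val0/t4_val1 holds columns 0-3/4-7 of the block; lane 1
+    // holds columns 8-11/12-15. Store them in column order.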
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
+ _mm256_extractf128_si256(t4_val0, 0));
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
+ _mm256_extractf128_si256(t4_val1, 0));
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
+ _mm256_extractf128_si256(t4_val0, 1));
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
+ _mm256_extractf128_si256(t4_val1, 1));
+ src_ptr += 16;
+ packed_ptr += packed_stride * 16;
+ }
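+  // Scalar handling of the leftover columns (fewer than 16 remain).
+  // Rows past src_rows are padded with the (xored) zero point.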
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ for (int r = 0; r < 4; r++) {
+ std::int8_t packed_val;
+ if (block_row + r < src_rows) {
+ packed_val = input_xor ^ src_ptr[r * src_stride];
+ } else {
+ packed_val = input_xor ^ src_zero_point;
+ }
+ accum += packed_val;
+ *packed_ptr++ = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ }
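+  // Zero-fill the padding columns between src_cols and end_col.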
+ for (; col < end_col; col++) {
+ std::memset(packed_ptr, 0, 4);
+ packed_ptr += 4;
+ }
+}
+
#endif // RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
} // namespace ruy
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
index d78892a..bc84359 100644
--- a/ruy/pack_x86.h
+++ b/ruy/pack_x86.h
@@ -217,6 +217,77 @@
}
}
};
+
+void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums);
+
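+// Dispatches 8-bit row-major sources on the AVX2/FMA path to
+// Pack8bitRowMajorForAvx2, one 4-row strip at a time.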
+template <typename Scalar>
+struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (kAvx2Fma 8bit row-major)");
+ RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ int block_row = 0;
+ for (; block_row < packed_matrix->layout.rows; block_row += 4) {
+ int src_stride = src_matrix.layout.stride;
+ int packed_stride = packed_matrix->layout.stride;
+ const Scalar* src_ptr =
+ src_matrix.data.get() + block_row * src_stride + start_col;
+ std::int8_t* packed_ptr =
+ packed_matrix->data + start_col * packed_stride + block_row * 8;
+ Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr),
+ src_stride, src_matrix.zero_point, packed_ptr,
+ packed_stride, start_col, end_col,
+ src_matrix.layout.cols, block_row,
+ src_matrix.layout.rows, kInputXor, sums);
+ }
+ }
+};
+
+void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int packed_stride, int start_col, int end_col,
+ int src_cols, int block_row, int src_rows,
+ int input_xor, std::int32_t* sums);
+
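+// Dispatches 8-bit row-major sources on the AVX-512 path to
+// Pack8bitRowMajorForAvx512, one 4-row strip at a time.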
+template <typename Scalar>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (kAvx512 8bit row-major)");
+ RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ int block_row = 0;
+ for (; block_row < packed_matrix->layout.rows; block_row += 4) {
+ int src_stride = src_matrix.layout.stride;
+ int packed_stride = packed_matrix->layout.stride;
+ const Scalar* src_ptr =
+ src_matrix.data.get() + block_row * src_stride + start_col;
+ std::int8_t* packed_ptr =
+ packed_matrix->data + start_col * packed_stride + block_row * 16;
+ Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr),
+ src_stride, src_matrix.zero_point, packed_ptr,
+ packed_stride, start_col, end_col,
+ src_matrix.layout.cols, block_row,
+ src_matrix.layout.rows, kInputXor, sums);
+ }
+ }
+};
+
#endif // RUY_PLATFORM_X86
} // namespace ruy