Optimized packing code path for row-major 8bit inputs on the kNeon path. Written in intrinsics to handle 3 cases at once:
ARM64, KernelCols==4
ARM32, KernelCols==4 (LHS)
ARM32, KernelCols==2 (RHS)
PiperOrigin-RevId: 322921106
diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc
index 1dcac5c..f29b214 100644
--- a/ruy/pack_arm.cc
+++ b/ruy/pack_arm.cc
@@ -23,6 +23,10 @@
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
+#if RUY_PLATFORM_NEON
+#include <arm_neon.h>
+#endif
+
namespace ruy {
#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
@@ -2216,4 +2220,263 @@
}
#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+#if RUY_PLATFORM_NEON
+
+namespace {
+// transpose_*bit_vals are wrappers around the ARM TRN1/TRN2 instructions,
+// letting us use these instructions as we would in assembly --- this is one
+// instance where assembly is more idiomatic than intrinsics.
+//
+// The way that TRN1/TRN2 are exposed by the vtrn_* intrinsics makes their
+// usage very cumbersome. The issue is that transposing groups of values is
+// exposed only as transposing values of a wider type, so this requires many
+// vreinterpret's, and to make it worse, vtrn_* return NEON array types like
+// int8x8x2_t for which vreinterpret's are not defined!
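+//
+// For example, with a = {a0, ..., a7} and b = {b0, ..., b7},
+// transpose_8bit_vals(a, b) leaves a = {a0, b0, a2, b2, a4, b4, a6, b6} and
+// b = {a1, b1, a3, b3, a5, b5, a7, b7}, i.e. it transposes each pair of
+// adjacent lanes across the two vectors.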
+void transpose_8bit_vals(int8x8_t& a, int8x8_t& b) {
+ int8x8x2_t t = vtrn_s8(a, b);
+ a = t.val[0];
+ b = t.val[1];
+}
+
+void transpose_16bit_vals(int8x8_t& a, int8x8_t& b) {
+ int16x4x2_t t = vtrn_s16(vreinterpret_s16_s8(a), vreinterpret_s16_s8(b));
+ a = vreinterpret_s8_s16(t.val[0]);
+ b = vreinterpret_s8_s16(t.val[1]);
+}
+
+void transpose_32bit_vals(int8x8_t& a, int8x8_t& b) {
+ int32x2x2_t t = vtrn_s32(vreinterpret_s32_s8(a), vreinterpret_s32_s8(b));
+ a = vreinterpret_s8_s32(t.val[0]);
+ b = vreinterpret_s8_s32(t.val[1]);
+}
+} // namespace
+
+void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
+ int src_rows, int src_cols, int block_row,
+ int start_col, int end_col,
+ std::int8_t* packed_ptr, int packed_stride,
+ int packed_zero_point, std::int32_t* sums,
+ int input_xor, int kernel_cols) {
+ profiler::ScopeLabel label("Pack (kNeon, from row-major)");
+
+ int src_end_col = std::min(end_col, src_cols);
+ int col = start_col;
+ for (; col <= src_end_col - 8; col += 8) {
+ // Each iteration of this loop handles 8 columns, and the kernel format
+ // has 16 rows, so each iteration handles a 16x8 block.
+ //
+ // Since the source is row-major, handling 8 columns at a time means
+ // loading only 8 bytes, i.e. 64bit, from each row. This may seem surprising
+ // on a 128bit SIMD architecture like NEON. While we could handle 16 columns
+ // at a time, we prefer to stick with 8 for the following reasons:
+ // 1. The arithmetic (computing sums and transposing data) done on these
+ // values is such that even though we initially start from 64bit vectors,
+ // most of our NEON instructions are full 128bit instructions. For the
+ // sums computation, that is because summing 8bit values requires
+ // expansion to 16bit anyway. For the matrix transposition code, that is
+ // because the ARM ZIP instructions take 64bit of data from two input
+ // registers and zip it into a 128bit output. If we had 128bit of data
+ // in each input register, we would need 2x more ARM NEON instructions
+ // to zip it.
+ // 2. The main optimization target for this (ARM, 8bit, non-dotprod)
+ // code path is in-order ARM cores such as the Cortex-A53, which prefer
+ // 64bit loads anyway.
+ // 3. Handling only 8 columns at a time limits the size of the final
+ // leftover columns handled with slow scalar code.
+ //
+ // This code is not heavily optimized anyway, as evidenced by the facts that
+ // (1) it's written in intrinsics and (2) it does not use separate versions
+ // tuned for different types of CPU cores. At its current level of
+ // optimization, this seems like a fair compromise. If one wanted to
+ // maximize performance at the cost of more code complexity/size, one could
+ // have code handling 16 columns at a time (maybe limited to
+ // Tuning::kOutOfOrder), then 8, then 4 to minimize the amount of slow
+ // leftovers.
+ //
+ // Load the 8 column sums into sums0, sums1.
+ int32x4_t sums0 = vld1q_s32(sums + col);
+ int32x4_t sums1 = vld1q_s32(sums + col + 4);
+ // Load the 16x8 block from the source matrix.
+ // Each val* here is the data from one row.
+ int8x8_t val0, val1, val2, val3, val4, val5, val6, val7, val8, val9, val10,
+ val11, val12, val13, val14, val15;
+ // Even though this function takes a uint8_t* src_ptr, that's only a
+ // type-erased pointer (using uint8_t* so that pointer arithmetic is
+ // allowed). The actual type may be either uint8_t or int8_t. The only
+ // difference it makes is that if it's uint8_t then we need to flip the
+ // sign bit. This is specified by the input_xor value (which is 0x80 if the
+ // input data is uint8_t, and 0x0 otherwise).
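+ //
+ // Note that XOR with 0x80 just flips the sign bit: a uint8 value v becomes
+ // the int8 value v - 128, which is the usual uint8 -> int8 reinterpretation.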
+ auto load_and_convert = [=](const std::uint8_t* from) {
+ return vreinterpret_s8_u8(veor_u8(vdup_n_u8(input_xor), vld1_u8(from)));
+ };
+ if (block_row <= src_rows - 16) {
+ // Load data in the regular case: there are still 16 rows to be read from
+ // the source matrix.
+ val0 = load_and_convert(src_ptr + 0 * src_stride);
+ val1 = load_and_convert(src_ptr + 1 * src_stride);
+ val2 = load_and_convert(src_ptr + 2 * src_stride);
+ val3 = load_and_convert(src_ptr + 3 * src_stride);
+ val4 = load_and_convert(src_ptr + 4 * src_stride);
+ val5 = load_and_convert(src_ptr + 5 * src_stride);
+ val6 = load_and_convert(src_ptr + 6 * src_stride);
+ val7 = load_and_convert(src_ptr + 7 * src_stride);
+ val8 = load_and_convert(src_ptr + 8 * src_stride);
+ val9 = load_and_convert(src_ptr + 9 * src_stride);
+ val10 = load_and_convert(src_ptr + 10 * src_stride);
+ val11 = load_and_convert(src_ptr + 11 * src_stride);
+ val12 = load_and_convert(src_ptr + 12 * src_stride);
+ val13 = load_and_convert(src_ptr + 13 * src_stride);
+ val14 = load_and_convert(src_ptr + 14 * src_stride);
+ val15 = load_and_convert(src_ptr + 15 * src_stride);
+ } else {
+ // Boundary case: there are fewer than 16 rows to be read from the source
+ // matrix. We pad with the zero_point.
+ val0 = vdup_n_s8(packed_zero_point);
+ val1 = val0;
+ val2 = val0;
+ val3 = val0;
+ val4 = val0;
+ val5 = val0;
+ val6 = val0;
+ val7 = val0;
+ val8 = val0;
+ val9 = val0;
+ val10 = val0;
+ val11 = val0;
+ val12 = val0;
+ val13 = val0;
+ val14 = val0;
+ val15 = val0;
+ if (block_row + 0 < src_rows)
+ val0 = load_and_convert(src_ptr + 0 * src_stride);
+ if (block_row + 1 < src_rows)
+ val1 = load_and_convert(src_ptr + 1 * src_stride);
+ if (block_row + 2 < src_rows)
+ val2 = load_and_convert(src_ptr + 2 * src_stride);
+ if (block_row + 3 < src_rows)
+ val3 = load_and_convert(src_ptr + 3 * src_stride);
+ if (block_row + 4 < src_rows)
+ val4 = load_and_convert(src_ptr + 4 * src_stride);
+ if (block_row + 5 < src_rows)
+ val5 = load_and_convert(src_ptr + 5 * src_stride);
+ if (block_row + 6 < src_rows)
+ val6 = load_and_convert(src_ptr + 6 * src_stride);
+ if (block_row + 7 < src_rows)
+ val7 = load_and_convert(src_ptr + 7 * src_stride);
+ if (block_row + 8 < src_rows)
+ val8 = load_and_convert(src_ptr + 8 * src_stride);
+ if (block_row + 9 < src_rows)
+ val9 = load_and_convert(src_ptr + 9 * src_stride);
+ if (block_row + 10 < src_rows)
+ val10 = load_and_convert(src_ptr + 10 * src_stride);
+ if (block_row + 11 < src_rows)
+ val11 = load_and_convert(src_ptr + 11 * src_stride);
+ if (block_row + 12 < src_rows)
+ val12 = load_and_convert(src_ptr + 12 * src_stride);
+ if (block_row + 13 < src_rows)
+ val13 = load_and_convert(src_ptr + 13 * src_stride);
+ if (block_row + 14 < src_rows)
+ val14 = load_and_convert(src_ptr + 14 * src_stride);
+ if (block_row + 15 < src_rows)
+ val15 = load_and_convert(src_ptr + 15 * src_stride);
+ }
+ src_ptr += 8;
+ // Compute sums.
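+ // vaddl_s8 widens to 16bit before adding. Each 16bit lane accumulates at
+ // most 16 int8 values, so |sum| <= 16 * 128 = 2048 and the 16bit
+ // accumulators cannot overflow before the final widening to 32bit below.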
+ int16x8_t sums16_0 = vaddl_s8(val0, val1);
+ int16x8_t sums16_1 = vaddl_s8(val2, val3);
+ sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val4, val5));
+ sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val6, val7));
+ sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val8, val9));
+ sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val10, val11));
+ sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val12, val13));
+ sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val14, val15));
+ int16x8_t sums16 = vaddq_s16(sums16_0, sums16_1);
+ sums0 = vaddw_s16(sums0, vget_low_s16(sums16));
+ sums1 = vaddw_s16(sums1, vget_high_s16(sums16));
+ // Store sums.
+ vst1q_s32(sums + col, sums0);
+ vst1q_s32(sums + col + 4, sums1);
+
+ // Transpose the data, i.e. change the storage order of the
+ // 16x8 block, to convert from the row-major source to the
+ // column-major packed format.
+ //
+ // Before, for i in [0, 15], val<i> is the i-th row.
+ // After, for i in [0, 7], { val<i> val<i+8> } is the i-th column.
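+ //
+ // The three rounds of transposes below (at 8bit, then 16bit, then 32bit
+ // granularity) are the log2(8) = 3 butterfly stages of an 8x8 byte
+ // transpose, applied separately to rows 0..7 and to rows 8..15.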
+ transpose_8bit_vals(val0, val1);
+ transpose_8bit_vals(val2, val3);
+ transpose_8bit_vals(val4, val5);
+ transpose_8bit_vals(val6, val7);
+ transpose_8bit_vals(val8, val9);
+ transpose_8bit_vals(val10, val11);
+ transpose_8bit_vals(val12, val13);
+ transpose_8bit_vals(val14, val15);
+ transpose_16bit_vals(val0, val2);
+ transpose_16bit_vals(val1, val3);
+ transpose_16bit_vals(val4, val6);
+ transpose_16bit_vals(val5, val7);
+ transpose_16bit_vals(val8, val10);
+ transpose_16bit_vals(val9, val11);
+ transpose_16bit_vals(val12, val14);
+ transpose_16bit_vals(val13, val15);
+ transpose_32bit_vals(val0, val4);
+ transpose_32bit_vals(val1, val5);
+ transpose_32bit_vals(val2, val6);
+ transpose_32bit_vals(val3, val7);
+ transpose_32bit_vals(val8, val12);
+ transpose_32bit_vals(val9, val13);
+ transpose_32bit_vals(val10, val14);
+ transpose_32bit_vals(val11, val15);
+ // Store to the packed_matrix.
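+ // The packed format stores blocks of 16 rows x kernel_cols columns
+ // contiguously, column-major within each block, so consecutive columns of a
+ // block are 16 bytes apart. With kernel_cols == 4, columns (col+2, col+3)
+ // follow (col+0, col+1) within the same block, 32 bytes later; with
+ // kernel_cols == 2, they start a new block, 2 * packed_stride bytes later.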
+ std::int8_t* dst_ptr = packed_ptr;
+ vst1q_s8(dst_ptr, vcombine_s8(val0, val8));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val1, val9));
+ dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+ vst1q_s8(dst_ptr, vcombine_s8(val2, val10));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val3, val11));
+ packed_ptr += 4 * packed_stride;
+ dst_ptr = packed_ptr;
+ vst1q_s8(dst_ptr, vcombine_s8(val4, val12));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val5, val13));
+ dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
+ vst1q_s8(dst_ptr, vcombine_s8(val6, val14));
+ vst1q_s8(dst_ptr + 16, vcombine_s8(val7, val15));
+ packed_ptr += 4 * packed_stride;
+ }
+ // Handle remaining columns, not fitting in a full block of 8 columns, but
+ // still true columns from the source matrix (as opposed to the final columns
+ // below).
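+ // Note: kernel_cols is a power of two, so (col & (kernel_cols - 1)) is
+ // col % kernel_cols, i.e. this column's position within its kernel block;
+ // each column occupies 16 bytes within the block.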
+ for (; col < src_end_col; col++) {
+ std::int32_t accum = 0;
+ std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+ for (int r = 0; r < 16; r++) {
+ std::int8_t packed_val = (block_row + r < src_rows)
+ ? (src_ptr[r * src_stride] ^ input_xor)
+ : packed_zero_point;
+ accum += packed_val;
+ dst_ptr[r] = packed_val;
+ }
+ if (sums) {
+ sums[col] += accum;
+ }
+ src_ptr++;
+ if (((col + 1) & (kernel_cols - 1)) == 0) {
+ packed_ptr += kernel_cols * packed_stride;
+ }
+ }
+ // Handle the final columns of the packed matrix, beyond the last column of
+ // the source matrix. The values here don't matter, we just want to avoid
+ // leaving uninitialized data. Since the sums are already initialized above,
+ // we don't need to do anything about them here.
+ for (; col < end_col; col++) {
+ std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
+ std::memset(dst_ptr, 0, 16);
+ if (((col + 1) & (kernel_cols - 1)) == 0) {
+ packed_ptr += kernel_cols * packed_stride;
+ }
+ }
+}
+
+#endif  // RUY_PLATFORM_NEON
+
} // namespace ruy
diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h
index 27498b9..12d5cab 100644
--- a/ruy/pack_arm.h
+++ b/ruy/pack_arm.h
@@ -51,6 +51,15 @@
};
#endif
+#if RUY_PLATFORM_NEON
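+// Packs a block of a row-major 8bit source matrix into the column-major
+// packed format expected by the kNeon 8bit kernels. input_xor is 0x80 when
+// the source data is uint8 (so that XOR'ing converts it to int8) and 0 when
+// it is already int8. kernel_cols is the number of columns per kernel block:
+// 4 on ARM64 and for the ARM32 LHS, 2 for the ARM32 RHS.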
+void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
+ int src_rows, int src_cols, int block_row,
+ int start_col, int end_col,
+ std::int8_t* packed_ptr, int packed_stride,
+ int packed_zero_point, std::int32_t* sums_ptr,
+ int input_xor, int kernel_cols);
+#endif
+
#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
void Pack8bitColMajorForNeonOutOfOrder(
@@ -81,7 +90,6 @@
int src_zero_point, std::int8_t* packed_ptr,
int packed_stride, std::int32_t* sums_ptr,
int input_xor);
-
#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
struct PackParams8bit {
@@ -564,6 +572,42 @@
#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
+#if RUY_PLATFORM_NEON
+
+template <typename Scalar, int KernelCols>
+struct PackImpl<Path::kNeon,
+ FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar,
+ std::int8_t, std::int32_t, Order::kRowMajor> {
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (kNeon, from row-major source)");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
+ RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
+ int block_row = 0;
+ for (; block_row < packed_matrix->layout.rows; block_row += 16) {
+ int src_stride = src_matrix.layout.stride;
+ int packed_stride = packed_matrix->layout.stride;
+ const Scalar* src_ptr =
+ src_matrix.data.get() + block_row * src_stride + start_col;
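+ // In the packed matrix, each group of KernelCols columns is stored as a
+ // stack of 16-row blocks of 16 * KernelCols bytes each, so advancing by
+ // block_row rows within such a group advances the data pointer by
+ // block_row * KernelCols bytes.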
+ std::int8_t* packed_ptr = packed_matrix->data +
+ start_col * packed_stride +
+ block_row * KernelCols;
+
+ Pack8bitRowMajorForNeon(
+ reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride,
+ src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col,
+ end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums,
+ kInputXor, KernelCols);
+ }
+ }
+};
+#endif  // RUY_PLATFORM_NEON
+
} // namespace ruy
#endif // RUY_RUY_PACK_ARM_H_