| /* Copyright 2020 Google LLC. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <cstring> |
| |
| #include "ruy/check_macros.h" |
| #include "ruy/kernel_common.h" |
| #include "ruy/kernel_x86.h" |
| #include "ruy/opt_set.h" |
| #include "ruy/platform.h" |
| #include "ruy/profiler/instrumentation.h" |
| |
| #if RUY_PLATFORM_AVX && RUY_OPT(ASM) |
| #include <immintrin.h> // IWYU pragma: keep |
| #endif |
| |
| namespace ruy { |
| |
| #if !(RUY_PLATFORM_AVX && RUY_OPT(ASM)) |
| |
| void Kernel8bitAvx(const KernelParams8bit<8, 8>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| void KernelFloatAvx(const KernelParamsFloat<8, 8>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| #else // RUY_PLATFORM_AVX && RUY_OPT(ASM) |
| |
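| // The kernel computes an 8x8 block of the destination (kAvx8bitBlockSize)
| // and consumes the packed operands 4 levels of depth at a time
| // (kAvx8bitInnerSize), i.e. 8 * 4 = 32 bytes per operand per inner-loop
| // iteration.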
| static constexpr int kAvx8bitBlockSize = 8; |
| static constexpr int kAvx8bitInnerSize = 4; |
| |
| namespace { |
| namespace intrin_utils { |
| |
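| // AVX (unlike AVX2) lacks most 256-bit integer operations, so the helpers
| // below emulate the corresponding AVX2 intrinsics: each splits its 256-bit
| // arguments into two 128-bit halves, applies the SSE4.1 equivalent to each
| // half, and recombines the halves with _mm256_set_m128i.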
| template <> |
| inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a, |
| const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo); |
| __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi); |
| return _mm256_set_m128i(dst_hi, dst_lo); |
| } |
| |
| template <> |
| inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a, |
| const int imm) { |
| switch (imm) { |
| case 0: |
| return _mm256_extractf128_si256(a, 0); |
| case 1: |
| return _mm256_extractf128_si256(a, 1); |
| default: |
| RUY_DCHECK_LT(imm, 2); |
| return _mm_setzero_si128(); |
| } |
| } |
| |
| template <Path path> |
| inline __m256i mm256_cvtepi8_epi16(const __m128i& a) { |
| // Take the upper 64 bits of a and put them in the lower 64 bits of 'hi'.
| __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); |
| return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a)); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_cvtepi32_epi64(const __m128i& a) { |
| // Sign-extend the two 32-bit values in each 64-bit half of a to 64 bits.
| __m128i lo = _mm_cvtepi32_epi64(a); |
| __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128())); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
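| // Helper for emulating _mm256_permute2x128_si256: returns the 128-bit lane
| // selected by the low 4 bits of `imm`, following the VPERM2F128 control
| // encoding (bits 1:0 select among {a.lo, a.hi, b.lo, b.hi}; bit 3 forces a
| // zero result).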
| inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b, |
| const int imm) { |
| __m128i tmp = _mm_setzero_si128(); |
| if (!(imm & 8)) { |
| switch (imm & 3) { |
| case 0: |
| return _mm256_extractf128_si256(a, 0); |
| case 1: |
| return _mm256_extractf128_si256(a, 1); |
| case 2: |
| return _mm256_extractf128_si256(b, 0); |
| case 3: |
| return _mm256_extractf128_si256(b, 1); |
| } |
| } |
| return tmp; |
| } |
| |
| template <Path path> |
| inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b, |
| const int imm) { |
| const int lo_imm = imm & 15; |
| __m128i lo = mm_permute_helper(a, b, lo_imm); |
| const int hi_imm = (imm >> 4) & 15; |
| __m128i hi = mm_permute_helper(a, b, hi_imm); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_max_epi32(a_lo, b_lo); |
| __m128i hi = _mm_max_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_min_epi32(a_lo, b_lo); |
| __m128i hi = _mm_min_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_add_epi32(a_lo, b_lo); |
| __m128i hi = _mm_add_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_add_epi64(a_lo, b_lo); |
| __m128i hi = _mm_add_epi64(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_slli_epi64(const __m256i& a, int imm) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i lo = _mm_slli_epi64(a_lo, imm); |
| __m128i hi = _mm_slli_epi64(a_hi, imm); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_mullo_epi32(a_lo, b_lo); |
| __m128i hi = _mm_mullo_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| // Defined as a macro since `imm` must be an immediate. |
| #define BlendM128_epi32(a, b, imm) \ |
| _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm)) |
| |
| // Defined as a macro since `imm` must be an immediate. |
| #define BlendM128_epi64(a, b, imm) \ |
| _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm)) |
| |
| // Defined as a macro since `imm` must be an immediate. |
| #define mm256_blend_epi32(ans, a, b, imm) \ |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); \ |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); \ |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); \ |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); \ |
| __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \ |
| __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4); \ |
| ans = _mm256_set_m128i(hi, lo); |
| |
| #define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm) \ |
| a_lo = _mm256_extractf128_si256(a, 0); \ |
| a_hi = _mm256_extractf128_si256(a, 1); \ |
| ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \ |
| _mm_shuffle_epi32(a_lo, imm)); |
| |
| template <Path path> |
| inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_madd_epi16(a_lo, b_lo); |
| __m128i hi = _mm_madd_epi16(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
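| // Emulates the AVX2 intrinsic _mm_srlv_epi64: each 64-bit element of a is
| // shifted right (logically) by the count in the corresponding 64-bit
| // element of b. AVX only has _mm_srl_epi64, which shifts both elements by
| // one shared count, so the two lanes are computed separately and merged.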
| inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) { |
| // shift both elements of a by the lower 64 bits of b.
| __m128i res_lo = _mm_srl_epi64(a, b); |
| // shift both elements of a by the upper 64 bits of b.
| __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); |
| __m128i res_hi = _mm_srl_epi64(a, hi_count); |
| // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
| // 1. Swap the upper and lower 64 bits of res_hi |
| __m128i tmp_hi = |
| _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); |
| // 2. Unpack the lower 64 bits of res_lo with the lower 64 bits of tmp_hi.
| return _mm_unpacklo_epi64(res_lo, tmp_hi); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = mm_srlv_epi64(a_lo, b_lo); |
| __m128i hi = mm_srlv_epi64(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) { |
| // shift both elements of a by the lower 64 bits of b.
| __m128i res_lo = _mm_sll_epi64(a, b); |
| // shift both elements of a by the upper 64 bits of b.
| __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); |
| __m128i res_hi = _mm_sll_epi64(a, hi_count); |
| // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
| // 1. Swap the upper and lower 64 bits of res_hi |
| __m128i tmp_hi = |
| _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); |
| // 2. Unpack the lower 64 bits of res_lo with the lower 64 bits of tmp_hi.
| return _mm_unpacklo_epi64(res_lo, tmp_hi); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo); |
| __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| #define PermuteM128_epi32(a, imm) \ |
| _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm)); |
| |
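| // Emulates the AVX2 intrinsic _mm_sllv_epi32: each 32-bit element of a is
| // shifted left by the count in the corresponding element of b. For
| // example, a = {1, 1, 1, 1} with b = {0, 1, 2, 3} yields {1, 2, 4, 8}.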
| inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) { |
| // shift all elements of a by the first 32 bits of b.
| __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1)); |
| |
| // put bits 32-63 of b in the first slot. |
| __m128i tmp1 = PermuteM128_epi32(b, 1); |
| // put bits 32-63 of a in the first slot. |
| __m128i a1 = PermuteM128_epi32(a, 1); |
| // shift all elements of a by the second 32 bits of b.
| __m128i res1 = |
| _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1)); |
| |
| // put bits 64-95 of b in the first slot. |
| __m128i tmp2 = PermuteM128_epi32(b, 2); |
| // shift all elements of a by the third 32 bits of b.
| __m128i res2 = |
| _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1)); |
| |
| // put bits 96-127 of b in the first slot. |
| __m128i tmp3 = PermuteM128_epi32(b, 3); |
| // put bits 96-127 of a in the third slot. |
| __m128i a3 = PermuteM128_epi32(a, 48); |
| // shift all elements of a3 by the fourth 32 bits of b.
| __m128i res3 = |
| _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1)); |
| |
| // Take bits 0-31 of res0, bits 0-31 of res1, |
| // bits 64-95 of res2, and bits 64-95 of res3. |
| // res0 _ _ _ 0 |
| // res1 _ _ _ 1 |
| // res2 _ 2 _ _ |
| // res3 _ 3 _ _ |
| // f_01 _ _ 1 0 |
| // f_23 _ _ 3 2 |
| __m128i f_01 = _mm_unpacklo_epi32(res0, res1); |
| __m128i f_23 = _mm_unpackhi_epi32(res2, res3); |
| // The lower 64 bits of f_01 and the lower 64 bits of f_23.
| return _mm_unpacklo_epi64(f_01, f_23); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = mm_sllv_epi32(a_lo, b_lo); |
| __m128i hi = mm_sllv_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_sub_epi32(a_lo, b_lo); |
| __m128i hi = _mm_sub_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_mul_epi32(a_lo, b_lo); |
| __m128i hi = _mm_mul_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| // Perform the equivalent of mm256_permutevar8x32 with |
| // a second argument of {7, 5, 3, 1, 6, 4, 2, 0} |
| template <Path path> |
| inline __m256i PermuteEpi32EvenOdds(const __m256i& a) { |
| // a_lo = 3 2 1 0 |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| // a_hi = 7 6 5 4 |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| // shuffle a_lo to get 3 1 2 0 |
| __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8); |
| // shuffle a_hi to get 7 5 6 4 |
| __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8); |
| // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
| __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi); |
| // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
| __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi); |
| return _mm256_set_m128i(res_hi, res_lo); |
| } |
| |
| template <Path path> |
| inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) { |
| const __m256i bias0 = _mm256_set1_epi32(*(bias + offset)); |
| return mm256_add_epi32<path>(a, bias0); |
| } |
| |
| inline __m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
| const __m256i& mask) { |
| __m256 result = |
| _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), |
| _mm256_castsi256_ps(mask)); |
| return _mm256_castps_si256(result); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| __m128i b_lo = _mm256_extractf128_si256(b, 0); |
| __m128i b_hi = _mm256_extractf128_si256(b, 1); |
| __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo); |
| __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi); |
| return _mm256_set_m128i(hi, lo); |
| } |
| |
| template <Path path> |
| inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) { |
| __m128i a_lo = _mm256_extractf128_si256(a, 0); |
| __m128i a_hi = _mm256_extractf128_si256(a, 1); |
| |
| __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0)); |
| __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1)); |
| __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2)); |
| __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3)); |
| __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4)); |
| __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5)); |
| __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6)); |
| __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7)); |
| |
| // get element 0 from r0, element 1 from r1 |
| __m128i r01 = BlendM128_epi32(r0, r1, 2); |
| // get element 2 from r2, element 3 from r3 |
| __m128i r23 = BlendM128_epi32(r2, r3, 8); |
| // get element 0 from r4, element 1 from r5 |
| __m128i r45 = BlendM128_epi32(r4, r5, 2); |
| // get element 2 from r6, element 3 from r7 |
| __m128i r67 = BlendM128_epi32(r6, r7, 8); |
| // get lower 64 bits of r01, upper 64 bits of r23 |
| __m128i r0123 = BlendM128_epi64(r01, r23, 2); |
| // get lower 64 bits of r45, upper 64 bits of r67 |
| __m128i r4567 = BlendM128_epi64(r45, r67, 2); |
| return _mm256_set_m128i(r4567, r0123); |
| } |
| |
| // AVX doesn't have fused multiply-add, so we define an inline function to
| // be used in the common code that follows.
| template <> |
| inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, |
| const __m256& c) { |
| const __m256 prod = _mm256_mul_ps(a, b); |
| return _mm256_add_ps(prod, c); |
| } |
| |
| } // namespace intrin_utils |
| } // namespace |
| |
| template <Path path> |
| void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) { |
| profiler::ScopeLabel label("Kernel kAvx 8-bit"); |
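| // Shuffle control that, within each 128-bit lane, gathers the bytes of the
| // even-numbered 16-bit pairs (bytes 0, 1, 4, 5, ...) into the lower half
| // and the bytes of the odd-numbered pairs (bytes 2, 3, 6, 7, ...) into the
| // upper half, so each group can be sign-extended for _mm_madd_epi16.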
| const std::int8_t splitter_idx_data[32] = { |
| 0, 1, 4, 5, 8, 9, 12, 13, // |
| 2, 3, 6, 7, 10, 11, 14, 15, // |
| 0, 1, 4, 5, 8, 9, 12, 13, // |
| 2, 3, 6, 7, 10, 11, 14, 15 // |
| }; |
| |
| std::int32_t dst_stride = 0; |
| if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || |
| (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { |
| dst_stride = params.dst_stride; |
| } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
| dst_stride = params.dst_stride / sizeof(std::int16_t); |
| } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
| dst_stride = params.dst_stride / sizeof(std::int32_t); |
| } else { |
| RUY_DCHECK(false); |
| } |
| |
| const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; |
| void* dst_col_ptr = params.dst_base_ptr; |
| |
| for (int col = params.start_col; col <= params.last_col; |
| col += kAvx8bitBlockSize) { |
| const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
| void* dst_ptr = dst_col_ptr; |
| |
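| // Zero-point corrections for quantized matmul: expanding
| // (lhs - lhs_zp) * (rhs - rhs_zp) and summing over the depth gives
| // lhs.rhs - lhs_zp * sum(rhs) - rhs_zp * sum(lhs) + depth * lhs_zp * rhs_zp,
| // so the accumulators start from the prod_zp_depth term and are adjusted
| // by the lhs_sums and rhs_sums terms computed below.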
| const std::int32_t lhs_zero_point = params.lhs_zero_point; |
| const bool has_rhs_sums_offsets = |
| (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
| std::int32_t rhs_sums_offsets[8]; |
| if (has_rhs_sums_offsets) { |
| const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( |
| _mm256_set1_epi32(lhs_zero_point), |
| _mm256_loadu_si256( |
| reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); |
| _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), |
| rhs_sums_offset_v); |
| } |
| |
| for (int row = params.start_row; row <= params.last_row; |
| row += kAvx8bitBlockSize) { |
| int channel = |
| (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row; |
| int multiplier_channel = |
| (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0; |
| const int residual_rows = |
| std::min(params.dst_rows - row, kAvx8bitBlockSize); |
| const int residual_cols = |
| std::min(params.dst_cols - col, kAvx8bitBlockSize); |
| |
| const __m256i splitter_idx = _mm256_loadu_si256( |
| reinterpret_cast<__m256i const*>(splitter_idx_data)); |
| |
| __m256i accum_data_v0; |
| __m256i accum_data_v1; |
| __m256i accum_data_v2; |
| __m256i accum_data_v3; |
| __m256i accum_data_v4; |
| __m256i accum_data_v5; |
| __m256i accum_data_v6; |
| __m256i accum_data_v7; |
| |
| // initial_accum_data holds the initial value for each of the accum_data_*
| // accumulator registers. We compute into it the terms that are identical
| // across columns.
| __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth); |
| __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth); |
| |
| // In the channels-are-rows case, we can load bias here. |
| if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && |
| !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { |
| initial_accum_data_lo = _mm_add_epi32( |
| initial_accum_data_lo, |
| _mm_loadu_si128( |
| reinterpret_cast<const __m128i*>(params.bias + row))); |
| initial_accum_data_hi = _mm_add_epi32( |
| initial_accum_data_hi, |
| _mm_loadu_si128( |
| reinterpret_cast<const __m128i*>(params.bias + row + 4))); |
| } |
| |
| // Adjustments common across columns. |
| const std::int32_t rhs_zero_point = params.rhs_zero_point; |
| if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
| const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point); |
| const __m128i lhs_sums_offset_lo = _mm_mullo_epi32( |
| rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( |
| ¶ms.lhs_sums[row]))); |
| const __m128i lhs_sums_offset_hi = _mm_mullo_epi32( |
| rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( |
| ¶ms.lhs_sums[row + 4]))); |
| |
| initial_accum_data_lo = |
| _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo); |
| initial_accum_data_hi = |
| _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi); |
| } |
| |
| // Adjustments differing across columns. |
| if (has_rhs_sums_offsets) { |
| __m256i initial_accum_data = |
| _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); |
| |
| accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); |
| accum_data_v1 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); |
| accum_data_v2 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); |
| accum_data_v3 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); |
| accum_data_v4 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); |
| accum_data_v5 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); |
| accum_data_v6 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); |
| accum_data_v7 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); |
| } else { |
| __m256i initial_accum_data = |
| _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); |
| accum_data_v0 = initial_accum_data; |
| accum_data_v1 = initial_accum_data; |
| accum_data_v2 = initial_accum_data; |
| accum_data_v3 = initial_accum_data; |
| accum_data_v4 = initial_accum_data; |
| accum_data_v5 = initial_accum_data; |
| accum_data_v6 = initial_accum_data; |
| accum_data_v7 = initial_accum_data; |
| } |
| |
| // Finally, in the channels-are-columns case, load bias data here. |
| if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && |
| (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { |
| accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0, |
| params.bias + col, 0); |
| accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1, |
| params.bias + col, 1); |
| accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2, |
| params.bias + col, 2); |
| accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3, |
| params.bias + col, 3); |
| accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4, |
| params.bias + col, 4); |
| accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5, |
| params.bias + col, 5); |
| accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6, |
| params.bias + col, 6); |
| accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7, |
| params.bias + col, 7); |
| } |
| |
| const std::int8_t* lhs_ptr = lhs_col_ptr; |
| const std::int8_t* rhs_ptr = rhs_col_ptr; |
| for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { |
| const __m256i lhs_data = |
| _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); |
| const __m256i rhs_data_8bit = |
| _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); |
| |
| // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
| std::int32_t rhs_data[16]; |
| const __m128i rhs_data_bottom_lane = |
| _mm256_castsi256_si128(rhs_data_8bit); |
| const __m128i rhs_data_top_lane = |
| _mm256_extractf128_si256(rhs_data_8bit, 1); |
| const __m256i rhs_16_bit_dup_low = |
| intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane); |
| const __m256i rhs_16_bit_dup_high = |
| intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane); |
| // Now that we have cast the RHS data, we store it so that each value |
| // can be separately loaded in the accumulation loop. |
| _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), |
| rhs_16_bit_dup_low); |
| _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), |
| rhs_16_bit_dup_high); |
| |
| // NOTE: There may be opportunities for permuting the data in the |
| // packing code instead of here. |
| const __m256i lhs_data_split = |
| intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); |
| const __m256i lhs_data_split_expand_bottom = |
| intrin_utils::mm256_cvtepi8_epi16<path>( |
| _mm256_extractf128_si256(lhs_data_split, 0)); |
| const __m256i lhs_data_split_expand_top = |
| intrin_utils::mm256_cvtepi8_epi16<path>( |
| _mm256_extractf128_si256(lhs_data_split, 1)); |
| |
| // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. |
| const __m256i lhs_16_bit_low = |
| intrin_utils::mm256_permute2x128_si256<path>( |
| lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); |
| // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. |
| const __m256i lhs_16_bit_high = |
| intrin_utils::mm256_permute2x128_si256<path>( |
| lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); |
| |
| __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>( |
| rhs_data)); // Load [0 1 2 3 4 5 6 7] |
| __m256i rhs1 = _mm256_lddqu_si256( |
| reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15] |
| __m256i rhs0_3 = |
| _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3] |
| __m256i rhs4_7 = |
| _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7] |
| __m256i rhs8_11 = |
| _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11] |
| __m256i rhs12_15 = |
| _mm256_permute2f128_si256(rhs1, rhs1, 0x11); // [12 - 15, 12 - 15]
| |
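| // process_column accumulates 4 levels of depth for one destination column:
| // a 16-bit madd of lhs_16_bit_low against the duplicated even RHS pair and
| // one of lhs_16_bit_high against the odd pair, carried out on 128-bit
| // halves since AVX has no 256-bit integer madd.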
| auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi, |
| __m256i& accum) { |
| // Perform mul-adds on low and high components of accum separately. |
| __m128i accum_lo = _mm256_extractf128_si256(accum, 0); |
| __m128i accum_hi = _mm256_extractf128_si256(accum, 1); |
| |
| __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0); |
| __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1); |
| __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0); |
| __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1); |
| __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0); |
| __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1); |
| |
| accum_lo = _mm_add_epi32(accum_lo, lo_0); |
| accum_hi = _mm_add_epi32(accum_hi, lo_1); |
| |
| __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0); |
| __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1); |
| __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0); |
| __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1); |
| __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0); |
| __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1); |
| |
| accum_lo = _mm_add_epi32(accum_lo, hi_0); |
| accum_hi = _mm_add_epi32(accum_hi, hi_1); |
| accum = _mm256_set_m128i(accum_hi, accum_lo); |
| }; |
| __m256i tmp0, tmp1, tmp2, tmp3; |
| __m128i lo0, lo1, hi0, hi1; |
| mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0); |
| mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55); |
| process_column(tmp0, tmp1, accum_data_v0); |
| mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa); |
| mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff); |
| process_column(tmp2, tmp3, accum_data_v1); |
| |
| mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0); |
| mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55); |
| process_column(tmp0, tmp1, accum_data_v2); |
| mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa); |
| mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff); |
| process_column(tmp2, tmp3, accum_data_v3); |
| |
| mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0); |
| mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55); |
| process_column(tmp0, tmp1, accum_data_v4); |
| mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa); |
| mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff); |
| process_column(tmp2, tmp3, accum_data_v5); |
| |
| mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0); |
| mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55); |
| process_column(tmp0, tmp1, accum_data_v6); |
| mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa); |
| mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff); |
| process_column(tmp2, tmp3, accum_data_v7); |
| |
| lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
| rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
| } |
| |
| if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
| __m256i m_vector; |
| __m256i e_vector; |
| // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
| m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
| params.multiplier_fixedpoint + multiplier_channel)); |
| e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
| params.multiplier_exponent + multiplier_channel)); |
| |
| const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(m_vector, 0)); |
| const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(m_vector, 1)); |
| |
| const __m256i zero_vector = _mm256_setzero_si256(); |
| const __m256i left_shift = |
| intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); |
| const __m256i neg_e_vector = |
| intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); |
| const __m256i right_shift = |
| intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); |
| const __m256i final_right_shift = _mm256_set1_epi32(31); |
| const __m256i final_right_shift_low = |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(final_right_shift, 0)); |
| const __m256i final_right_shift_high = |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(final_right_shift, 1)); |
| const __m256i convert_to_unsigned_64 = |
| _mm256_set1_epi64x(0x8000000000000000); |
| |
| __m256i post_scaling_offset = _mm256_setzero_si256(); |
| |
| // A "half" is added for rounding prior to truncation of the 64-bit value.
| const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>( |
| intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30), |
| convert_to_unsigned_64); |
| |
| if (params.dst_zero_point) { |
| post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); |
| } |
| |
| // We cannot do |
| // |
| // scaled_v_low = |
| // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); |
| // scaled_v_high = |
| // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); |
| // |
| // since a 64-bit arithmetic right shift only exists in AVX-512. Instead we
| // use the logical shift mm256_srlv_epi64 (emulated above for AVX). Since
| // that is an unsigned shift, we add an offset beforehand
| // (convert_to_unsigned_64); the offset vanishes when the 64-bit results are
| // truncated to their low 32 bits during repacking below.
| //
| // The overall process, for each 64-bit scaled accumulator, is:
| // unsigned_accum = signed_accum + (1 << 63);
| // unsigned_accum = unsigned_accum >> 31;
| // signed_accum = low 32 bits of unsigned_accum (the bias is 0 mod 2^32).
| |
| // There are various ways to repack the results, in the absence of |
| // _mm256_cvtepi64_epi32() or anything like it. |
| // A. |
| // accum_data_v[j] = |
| // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), |
| // _mm256_extract_epi32(scaled_v_high, 4), |
| // _mm256_extract_epi32(scaled_v_high, 2), |
| // _mm256_extract_epi32(scaled_v_high, 0), |
| // _mm256_extract_epi32(scaled_v_low, 6), |
| // _mm256_extract_epi32(scaled_v_low, 4), |
| // _mm256_extract_epi32(scaled_v_low, 2), |
| // _mm256_extract_epi32(scaled_v_low, 0)); |
| // B. |
| // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); |
| // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); |
| // accum_data_v[j] = |
| // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), |
| // _mm256_extract_epi64(scaled_v_high, 0), |
| // _mm256_extract_epi64(scaled_v_low, 2), |
| // _mm256_extract_epi64(scaled_v_low, 0)); |
| // C. |
| // scaled_v_low = |
| // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); |
| // scaled_v_high = |
| // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); |
| // accum_data_v[j] = |
| // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); |
| // |
| // However, we choose the following because it uses two lighter
| // instructions. The permutation does have a longer latency, but this
| // loop can be unrolled. (In this AVX path, the blend and the permutation
| // are emulated by the mm256_blend_epi32 macro and PermuteEpi32EvenOdds
| // defined above.)
| // D. |
| // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); |
| // __m256i results = |
| // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); |
| // results = _mm256_permutevar8x32_epi32(results, repack_perm); |
| // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results, |
| // post_scaling_offset); |
| |
| // This multiplier code is complex and expensive enough on x86, that |
| // we prefer to implement the channels-are-columns case by transposing |
| // around it, rather than duplicate it (which would also require |
| // duplicating the above code computing the multiplier constants). |
| // This is one instance where channels-are-columns has lower performance |
| // than channels-are-rows. |
| const bool transpose_around_multiplier = |
| (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) && |
| (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL); |
| if (transpose_around_multiplier) { |
| // Transpose the 8x8 block of accumulators. It will be un-transposed below,
| // after the multiplier implementation.
| intrin_utils::mm256_transpose8x8_epi32<path>( |
| &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, |
| &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); |
| } |
| |
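| // rounding_right_shift computes, per 32-bit lane, the rounding arithmetic
| // shift (results + (1 << (exponent - 1))) >> exponent, substituting
| // 1 << (31 - exponent) in lanes where adding the nudge would overflow.
| // For example, 24 with exponent 4 yields (24 + 8) >> 4 = 2.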
| auto rounding_right_shift = [=](__m256i& results, |
| const __m256i& exponent) { |
| // Construct the "nudge" value for each lane if the exponent is |
| // greater than 0. Otherwise, the nudge is 0. |
| const __m256i zeros = _mm256_setzero_si256(); |
| const __m256i mask_rightshift_gtz = |
| intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros); |
| const __m256i one_shift_exp_minus1 = |
| intrin_utils::mm256_sllv_epi32<path>( |
| _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
| exponent, _mm256_set1_epi32(1))); |
| __m256i nudge = intrin_utils::mm256_blendv_epi32( |
| zeros, one_shift_exp_minus1, mask_rightshift_gtz); |
| // Calculate the shifted sum (results + nudge) >> exp. |
| const __m256i r_plus_nudge = |
| intrin_utils::mm256_add_epi32<path>(results, nudge); |
| const __m256i shifted_sum = |
| intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent); |
| |
| // Identify overflow in each lane and create mask. |
| const __m256i one_shift_31minus_exp = |
| intrin_utils::mm256_sllv_epi32<path>( |
| _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
| _mm256_set1_epi32(31), exponent)); |
| const __m256i mask_num_plus_nudge_overflow = |
| intrin_utils::mm256_cmpgt_epi32<path>( |
| results, intrin_utils::mm256_sub_epi32<path>( |
| _mm256_set1_epi32(0x7fffffff), nudge)); |
| // Fill results with either (results + nudge) >> exponent or |
| // 1 << (31 - exp) in the case of overflow. |
| results = intrin_utils::mm256_blendv_epi32( |
| shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); |
| }; |
| |
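| // apply_multiplier rescales one accumulator vector by the quantized
| // multiplier, computing approximately
| // RoundingRightShift((acc << left_shift) * M >> 31, right_shift) + dst_zp,
| // where M is the Q31 fixed-point multiplier and the 64-bit products are
| // rounded and repacked to 32 bits as described above.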
| auto apply_multiplier = [=](__m256i& accum) { |
| __m256i shifted_accum = |
| intrin_utils::mm256_sllv_epi32<path>(accum, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(shifted_accum, 0)), |
| m_64bit_low); |
| __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(shifted_accum, 1)), |
| m_64bit_high); |
| scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, |
| offset_vector); |
| scaled_v_high = intrin_utils::mm256_add_epi64<path>( |
| scaled_v_high, offset_vector); |
| |
| scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( |
| scaled_v_low, final_right_shift_low); |
| scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( |
| scaled_v_high, final_right_shift_high); |
| scaled_v_high = |
| intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); |
| __m256i results; |
| mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); |
| // Permute results to this ordering of int32 elements |
| // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); |
| results = intrin_utils::PermuteEpi32EvenOdds<path>(results); |
| |
| rounding_right_shift(results, right_shift); |
| accum = |
| intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset); |
| }; |
| apply_multiplier(accum_data_v0); |
| apply_multiplier(accum_data_v1); |
| apply_multiplier(accum_data_v2); |
| apply_multiplier(accum_data_v3); |
| apply_multiplier(accum_data_v4); |
| apply_multiplier(accum_data_v5); |
| apply_multiplier(accum_data_v6); |
| apply_multiplier(accum_data_v7); |
| // See above comment: here we transpose again to undo the transposition |
| // of the 8x8 block of accumulators used to implement the |
| // channels-are-columns case. |
| if (transpose_around_multiplier) { |
| intrin_utils::mm256_transpose8x8_epi32<path>( |
| &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, |
| &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); |
| } |
| } |
| const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); |
| const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); |
| const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && |
| (residual_cols == kAvx8bitBlockSize); |
| |
| __m256i accum_data_v[kAvx8bitBlockSize]; |
| if (!store_full_block) { |
| accum_data_v[0] = accum_data_v0; |
| accum_data_v[1] = accum_data_v1; |
| accum_data_v[2] = accum_data_v2; |
| accum_data_v[3] = accum_data_v3; |
| accum_data_v[4] = accum_data_v4; |
| accum_data_v[5] = accum_data_v5; |
| accum_data_v[6] = accum_data_v6; |
| accum_data_v[7] = accum_data_v7; |
| } |
| |
| if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
| std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
| if (store_full_block) { |
| accum_data_v0 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); |
| accum_data_v0 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); |
| accum_data_v1 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); |
| accum_data_v1 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); |
| accum_data_v2 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); |
| accum_data_v2 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); |
| accum_data_v3 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); |
| accum_data_v3 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); |
| accum_data_v4 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); |
| accum_data_v4 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); |
| accum_data_v5 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); |
| accum_data_v5 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); |
| accum_data_v6 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); |
| accum_data_v6 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); |
| accum_data_v7 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); |
| accum_data_v7 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[0 * dst_stride], accum_data_v0); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[1 * dst_stride], accum_data_v1); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[2 * dst_stride], accum_data_v2); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[3 * dst_stride], accum_data_v3); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[4 * dst_stride], accum_data_v4); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[5 * dst_stride], accum_data_v5); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[6 * dst_stride], accum_data_v6); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[7 * dst_stride], accum_data_v7); |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| __m256i result = accum_data_v[j]; |
| result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
| result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
| intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( |
| tmp_ptr, residual_rows, result); |
| tmp_ptr += dst_stride; |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
| std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
| if (store_full_block) { |
| accum_data_v0 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); |
| accum_data_v0 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); |
| accum_data_v1 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); |
| accum_data_v1 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); |
| accum_data_v2 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); |
| accum_data_v2 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); |
| accum_data_v3 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); |
| accum_data_v3 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); |
| accum_data_v4 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); |
| accum_data_v4 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); |
| accum_data_v5 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); |
| accum_data_v5 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); |
| accum_data_v6 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); |
| accum_data_v6 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); |
| accum_data_v7 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); |
| accum_data_v7 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0], |
| accum_data_v0); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride], |
| accum_data_v1); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[2 * dst_stride], accum_data_v2); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[3 * dst_stride], accum_data_v3); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[4 * dst_stride], accum_data_v4); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[5 * dst_stride], accum_data_v5); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[6 * dst_stride], accum_data_v6); |
| intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
| &tmp_ptr[7 * dst_stride], accum_data_v7); |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| __m256i result = accum_data_v[j]; |
| result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
| result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
| intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( |
| tmp_ptr, residual_rows, result); |
| tmp_ptr += dst_stride; |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
| std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
| if (store_full_block) { |
| accum_data_v0 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); |
| accum_data_v0 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); |
| accum_data_v1 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); |
| accum_data_v1 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); |
| accum_data_v2 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); |
| accum_data_v2 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); |
| accum_data_v3 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); |
| accum_data_v3 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); |
| accum_data_v4 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); |
| accum_data_v4 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); |
| accum_data_v5 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); |
| accum_data_v5 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); |
| accum_data_v6 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); |
| accum_data_v6 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); |
| accum_data_v7 = |
| intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); |
| accum_data_v7 = |
| intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0], |
| accum_data_v0); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride], |
| accum_data_v1); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
| &tmp_ptr[2 * dst_stride], accum_data_v2); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
| &tmp_ptr[3 * dst_stride], accum_data_v3); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
| &tmp_ptr[4 * dst_stride], accum_data_v4); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
| &tmp_ptr[5 * dst_stride], accum_data_v5); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
| &tmp_ptr[6 * dst_stride], accum_data_v6); |
| intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
| &tmp_ptr[7 * dst_stride], accum_data_v7); |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| __m256i result = accum_data_v[j]; |
| result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
| result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
| intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>( |
| tmp_ptr, residual_rows, result); |
| tmp_ptr += dst_stride; |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
| if (store_full_block) { |
| std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride], |
| accum_data_v1); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride], |
| accum_data_v2); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride], |
| accum_data_v3); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride], |
| accum_data_v4); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride], |
| accum_data_v5); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride], |
| accum_data_v6); |
| intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride], |
| accum_data_v7); |
| } else { |
| std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); |
| for (int j = 0; j < residual_cols; ++j) { |
| intrin_utils::mm256_n_storeu_epi32<path>( |
| dst_block_ptr, residual_rows, accum_data_v[j]); |
| dst_block_ptr += dst_stride; |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else { |
| RUY_DCHECK(false); |
| } |
| |
| lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; |
| } // End row-block loop. |
| |
| dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + |
| kAvx8bitBlockSize * params.dst_stride); |
| rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; |
| } // End col-block loop. |
| } // NOLINT(readability/fn_size) |
| |
| void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) { |
| Kernel8bitAvxImpl<Path::kAvx>(params); |
| } |
| |
| template <Path path> |
| void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) { |
| profiler::ScopeLabel label("Kernel kAvx 8-bit GEMV");
| |
| RUY_DCHECK_EQ(params.dst_cols, 1); |
| RUY_DCHECK_EQ(params.last_col, 0); |
| RUY_DCHECK_EQ(params.start_col, 0); |
| |
| const std::int8_t splitter_idx_data[32] = { |
| 0, 1, 4, 5, 8, 9, 12, 13, // |
| 2, 3, 6, 7, 10, 11, 14, 15, // |
| 0, 1, 4, 5, 8, 9, 12, 13, // |
| 2, 3, 6, 7, 10, 11, 14, 15 // |
| }; |
| |
| int bias_ptr_block_increment = |
| params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; |
| |
| const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; |
| void* dst_col_ptr = params.dst_base_ptr; |
| const std::int32_t* bias_col_ptr = params.bias; |
| if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { |
| bias_col_ptr += params.start_row; |
| } |
| |
| const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
| void* dst_ptr = dst_col_ptr; |
| const std::int32_t* bias_ptr = bias_col_ptr; |
| |
| const std::int32_t lhs_zero_point = params.lhs_zero_point; |
| const bool has_rhs_sums_offsets = |
| (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
| std::int32_t rhs_sums_offsets[8]; |
| if (has_rhs_sums_offsets) { |
| const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( |
| _mm256_set1_epi32(lhs_zero_point), |
| _mm256_loadu_si256( |
| reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); |
| _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), |
| rhs_sums_offset_v); |
| } |
| |
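| // Single-column (GEMV) case: only one destination column, hence a single
| // live accumulator vector per row block. The zero-point corrections mirror
| // the GEMM kernel above.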
| for (int row = params.start_row; row <= params.last_row; |
| row += kAvx8bitBlockSize) { |
| const int residual_rows = |
| std::min(params.dst_rows - row, kAvx8bitBlockSize); |
| |
| const __m256i splitter_idx = |
| _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); |
| |
| __m256i accum_data_v0; |
| |
| // Initialize with bias. |
| __m256i initial_accum_data = |
| _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr)); |
| bias_ptr += bias_ptr_block_increment; |
| |
| // Adjustments common across columns. |
| const std::int32_t rhs_zero_point = params.rhs_zero_point; |
| if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
| const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>( |
| _mm256_set1_epi32(rhs_zero_point), |
| _mm256_loadu_si256( |
| reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); |
| initial_accum_data = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, lhs_sums_offset); |
| } |
| const std::int32_t prod_zp_depth = params.prod_zp_depth; |
| if (prod_zp_depth) { |
| initial_accum_data = intrin_utils::mm256_add_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(prod_zp_depth)); |
| } |
| |
| // Adjustments differing across columns. |
| if (has_rhs_sums_offsets) { |
| accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( |
| initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); |
| } else { |
| accum_data_v0 = initial_accum_data; |
| } |
| |
| const std::int8_t* lhs_ptr = lhs_col_ptr; |
| const std::int8_t* rhs_ptr = rhs_col_ptr; |
| for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { |
| const __m256i lhs_data = |
| _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); |
| const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr); |
| |
| // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
| // Load the four RHS values for this column. The sign-extension below
| // processes twice the data that we need, and we store only the 64 bits
| // that we do need.
| std::int32_t rhs_data[2]; |
| const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); |
| // Now that we have cast the RHS data, we store it so that each value |
| // can be separately loaded in the accumulation loop. |
| _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); |
| |
| // NOTE: There may be opportunities for permuting the data in the packing |
| // code instead of here. |
| const __m256i lhs_data_split = |
| intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); |
| const __m256i lhs_data_split_expand_bottom = |
| intrin_utils::mm256_cvtepi8_epi16<path>( |
| _mm256_extractf128_si256(lhs_data_split, 0)); |
| const __m256i lhs_data_split_expand_top = |
| intrin_utils::mm256_cvtepi8_epi16<path>( |
| _mm256_extractf128_si256(lhs_data_split, 1)); |
| |
| // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. |
| const __m256i lhs_16_bit_low = |
| intrin_utils::mm256_permute2x128_si256<path>( |
| lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); |
| // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. |
| const __m256i lhs_16_bit_high = |
| intrin_utils::mm256_permute2x128_si256<path>( |
| lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); |
| // Accumulate for column 0. |
| const std::int32_t low_rhs_value = rhs_data[0]; |
| const std::int32_t high_rhs_value = rhs_data[1]; |
| |
| const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); |
| const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); |
| |
| accum_data_v0 = intrin_utils::mm256_add_epi32<path>( |
| accum_data_v0, intrin_utils::mm256_madd_epi16<path>( |
| lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_data_v0 = intrin_utils::mm256_add_epi32<path>( |
| accum_data_v0, intrin_utils::mm256_madd_epi16<path>( |
| lhs_16_bit_high, rhs_16_bit_dup_high)); |
| |
| lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
| rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
| } |
| |
| if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
| __m256i m_vector; |
| __m256i e_vector; |
| // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
| int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0; |
| m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
| params.multiplier_fixedpoint + channel)); |
| e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
| params.multiplier_exponent + channel)); |
| |
| const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(m_vector, 0)); |
| const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(m_vector, 1)); |
| |
| const __m256i zero_vector = _mm256_setzero_si256(); |
| const __m256i left_shift = |
| intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); |
| const __m256i neg_e_vector = |
| intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); |
| const __m256i right_shift = |
| intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); |
| const __m256i final_right_shift = _mm256_set1_epi32(31); |
| const __m256i final_right_shift_low = |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(final_right_shift, 0)); |
| const __m256i final_right_shift_high = |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(final_right_shift, 1)); |
| const __m256i convert_to_unsigned_64 = |
| _mm256_set1_epi64x(0x8000000000000000); |
| |
| __m256i post_scaling_offset = _mm256_setzero_si256(); |
| |
| // A "half" is added for rounding prior to truncation of the 64-bit value.
| const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>( |
| intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30), |
| convert_to_unsigned_64); |
| |
| if (params.dst_zero_point) { |
| post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); |
| } |
| |
| // See GEMM version for details of this process. |
| { |
| __m256i shifted_accum = |
| intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(shifted_accum, 0)), |
| m_64bit_low); |
| __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( |
| intrin_utils::mm256_cvtepi32_epi64<path>( |
| _mm256_extractf128_si256(shifted_accum, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, |
| offset_vector); |
| scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high, |
| offset_vector); |
| |
| scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( |
| scaled_v_low, final_right_shift_low); |
| scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( |
| scaled_v_high, final_right_shift_high); |
| |
| scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); |
| __m256i results; |
| mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); |
| // Permute results to this ordering of int32 elements |
| // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); |
| results = intrin_utils::PermuteEpi32EvenOdds<path>(results); |
| |
| // Now perform the Rounding Right Shift. |
| // First, construct the "nudge" value for each lane if the exponent is |
| // greater than 0. Otherwise, the nudge is 0. |
| const __m256i zeros = _mm256_setzero_si256(); |
| const __m256i mask_rightshift_gtz = |
| intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros); |
| const __m256i one_shift_exp_minus1 = |
| intrin_utils::mm256_sllv_epi32<path>( |
| _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
| right_shift, _mm256_set1_epi32(1))); |
| __m256i nudge = intrin_utils::mm256_blendv_epi32( |
| zeros, one_shift_exp_minus1, mask_rightshift_gtz); |
| // Calculate the shifted sum (results + nudge) >> exp. |
| const __m256i r_plus_nudge = |
| intrin_utils::mm256_add_epi32<path>(results, nudge); |
| const __m256i shifted_sum = |
| intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift); |
| |
| // Identify overflow in each lane and create mask. |
| const __m256i one_shift_31minus_exp = |
| intrin_utils::mm256_sllv_epi32<path>( |
| _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
| _mm256_set1_epi32(31), right_shift)); |
| const __m256i mask_num_plus_nudge_overflow = |
| intrin_utils::mm256_cmpgt_epi32<path>( |
| results, intrin_utils::mm256_sub_epi32<path>( |
| _mm256_set1_epi32(0x7fffffff), nudge)); |
| // Fill results with either (results + nudge) >> exponent or |
| // 1 << (31 - exp) in the case of overflow. |
| results = intrin_utils::mm256_blendv_epi32( |
| shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); |
| accum_data_v0 = |
| intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset); |
| } |
| } |
| const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); |
| const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); |
| |
| if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
| std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
| __m256i result = accum_data_v0; |
| result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
| result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
| intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, |
| result); |
| dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
| std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
| __m256i result = accum_data_v0; |
| result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
| result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
| intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, |
| result); |
| dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
| std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
| __m256i result = accum_data_v0; |
| result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
| result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
| intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows, |
| result); |
| dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
| std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); |
| intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows, |
| accum_data_v0); |
| dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + |
| kAvx8bitBlockSize); |
| } else { |
| RUY_DCHECK(false); |
| } |
| |
| lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; |
| } // End row-block loop. |
| |
| dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + |
| kAvx8bitBlockSize * params.dst_stride); |
| rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; |
| } // NOLINT(readability/fn_size) |
| |
| void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) { |
| Kernel8bitAvxSingleColImpl<Path::kAvx>(params); |
| } |
| |
| void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) { |
| profiler::ScopeLabel label("Kernel kAvx float"); |
| KernelFloatAvxCommon<Path::kAvx>(params); |
| } |
| |
| void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) { |
| profiler::ScopeLabel label("Kernel kAvx float GEMV"); |
| KernelFloatAvxCommonSingleCol<Path::kAvx>(params); |
| } |
| |
| #endif // RUY_PLATFORM_AVX && RUY_OPT(ASM) |
| |
| } // namespace ruy |