Snap for 6439596 from 18d2e95977453f4c06f999efe21301f5844fafa9 to qt-aml-tzdata-release
Change-Id: I7e68f726e2d4512003b9097bc37aaae81a3922bf
diff --git a/doc/kernel.md b/doc/kernel.md
index f3f2138..261cb92 100644
--- a/doc/kernel.md
+++ b/doc/kernel.md
@@ -40,15 +40,11 @@
The meaning of these terms is explained in the lengthy comment at the top of
internal/kernel.h. Here, they mean that this kernel handles at each iteration
-(along the depth dimension):
-
-- 3 'cells' of size 4x2 each of the lhs, so a total lhs block of size 12x2
-
-- 1 'cell' of size 2x4 of the rhs.
-
-In other words, this kernel handles 12 rows of the lhs and 4 columns of the
-rhs, and handles two levels of depth at once. The 'cells' and `CellFormat`
-detail the layout of these 12x2 and 2x4 blocks.
+(along the depth dimension): 3 'cells' of size 4x2 each of the lhs (a total lhs
+block of size 12x2), and 1 'cell' of size 2x4 of the rhs. In other words, this
+kernel handles 12 rows of the lhs and 4 columns of the rhs, and handles two
+levels of depth at once. The 'cells' and `CellFormat` detail the layout of these
+12x2 and 2x4 blocks.
This kernel then loads these 12x2 and 2x4 blocks and computes the corresponding
12x4 GEMM; for ease of reference let us paste the critical comment and code
diff --git a/doc/public.md b/doc/public.md
index 7739b85..935f6db 100644
--- a/doc/public.md
+++ b/doc/public.md
@@ -14,7 +14,7 @@
multiplication is explained in [low-precision.md](low-precision.md). The
rationale for a specific quantization paradigm is given in
[quantization.md](quantization.md). That specific quantization paradigm is
-implemented at two different stages of the computation: as pre-processing on
+implemented at two different stages of the computation: as pre-processing on
the operands and as post-processing on the result:
* Pre-processing on the LHS, RHS operands, in the form of adding constant
@@ -56,7 +56,7 @@
* `InputScalar`: The scalar type of the LHS and RHS operands. At the moment,
this must be `std::uint8_t`.
-* `OutputScalar`: The scalar type of the result. At the moment,
+* `OutputScalar`: The scalar type of the result. At the moment,
this must be `std::uint8_t`.
* `BitDepthParams`: Defines the bit format of the input and output matrices
and the required accuracy of the computation. At the moment, the only
diff --git a/doc/quantization.md b/doc/quantization.md
index e5055e7..3a8f72b 100644
--- a/doc/quantization.md
+++ b/doc/quantization.md
@@ -13,7 +13,7 @@
perform, specifically, it affects how one goes from internal 32bit accumulator
to final 8bit outputs.
-The part of gemmlowp transforming internal 32bit accumulator to final
+The part of gemmlowp transforming internal 32bit accumulator to final
8bit outputs is the "output pipeline" described in [output.md](output.md).
gemmlowp's `GemmWithOutputPipeline` entry point allows specifying an arbitrary
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc
index a8d9b43..512c483 100644
--- a/eight_bit_int_gemm/eight_bit_int_gemm.cc
+++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#endif
#include "eight_bit_int_gemm.h"
#include <memory>
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index 58e8050..d39341b 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -18,13 +18,10 @@
#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
-#include <algorithm>
#include <cassert>
-#include <cmath>
-#include <cstdint>
#include <limits>
-#include "../internal/detect_platform.h"
+#include "../internal/common.h"
namespace gemmlowp {
@@ -50,13 +47,13 @@
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
typedef std::int32_t ScalarRawType;
- static constexpr int kLanes = 1;
+ static const int kLanes = 1;
};
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
typedef std::int16_t ScalarRawType;
- static constexpr int kLanes = 1;
+ static const int kLanes = 1;
};
// Returns a SIMD value duplicating a scalar value across all lanes.
@@ -112,25 +109,11 @@
return -a;
}
-// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
-// Negative values are OK. In case of overflow, no Undefined
-// Behavior, but the results are implementation-defined (in practice,
-// they currently are saturated, but we make no commitment to that). The idea
-// is that the caller will want to implement the overflowing cases with
-// saturation with compare-and-mask, so we don't care about the results
-// in the overflow case, we just want to avoid undefined behavior.
-//
-// tIntegerType may be int32 or any narrower signed type.
-template <typename tIntegerType, typename OffsetType>
-tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
- const std::int64_t wide_a = static_cast<std::int64_t>(a);
- const std::int64_t wide_shifted = wide_a * (1 << offset);
- const auto min = std::numeric_limits<tIntegerType>::min();
- const auto max = std::numeric_limits<tIntegerType>::max();
- return wide_shifted < min
- ? min
- : wide_shifted > max ? max
- : static_cast<tIntegerType>(wide_shifted);
+// Integer arithmetic left-shift, equivalent to multiplying with a
+// power of two. Not saturating. Overflow is undefined behavior.
+template <typename tIntegerType>
+tIntegerType ShiftLeft(tIntegerType a, int offset) {
+ return a << offset;
}
// Integer arithmetic right-shift. Not rounding.
@@ -154,7 +137,7 @@
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
- static constexpr tIntegerType zero = 0;
+ static const tIntegerType zero = 0;
return a ? BitNot(zero) : zero;
}
@@ -228,7 +211,6 @@
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
- (void)b;
return a;
}
@@ -253,7 +235,6 @@
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
- (void)b;
return a;
}
@@ -263,9 +244,7 @@
std::int32_t a32 = a;
std::int32_t b32 = b;
std::int32_t sum = a32 + b32;
- return static_cast<std::int16_t>(
- std::min(static_cast<std::int32_t>(32767),
- std::max(static_cast<std::int32_t>(-32768), sum)));
+ return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
}
// Returns a+b, saturating if the integers are 16bit or narrower,
@@ -319,7 +298,6 @@
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
- (void)b;
return a;
}
@@ -353,8 +331,8 @@
// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
-template <typename IntegerType, typename ExponentType>
-inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
+template <typename IntegerType>
+inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
assert(exponent >= 0);
assert(exponent <= 31);
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
@@ -454,9 +432,9 @@
typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
- static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
- static constexpr int kIntegerBits = tIntegerBits;
- static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
+ static const int kTotalBits = 8 * sizeof(ScalarRawType);
+ static const int kIntegerBits = tIntegerBits;
+ static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
"bad IntegerBits");
@@ -496,7 +474,7 @@
template <int Exponent>
static FixedPoint ConstantPOT() {
- static constexpr int kOffset = kFractionalBits + Exponent;
+ static const int kOffset = kFractionalBits + Exponent;
static_assert(
kOffset < 31,
"Constant not exactly representable in this fixed-point format");
@@ -667,7 +645,7 @@
template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
FixedPoint<tRawType, tIntegerBitsDst> Rescale(
FixedPoint<tRawType, tIntegerBitsSrc> x) {
- static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
+ static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
FixedPoint<tRawType, tIntegerBitsDst> result;
result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
return result;
@@ -747,9 +725,9 @@
FixedPoint<tRawType, tIntegerBits> a) {
typedef FixedPoint<tRawType, tIntegerBits> InputF;
typedef FixedPoint<tRawType, 0> ResultF;
- static constexpr int kFractionalBits = InputF::kFractionalBits;
- static constexpr int kIntegerBits = InputF::kIntegerBits;
- const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
+ static const int kFractionalBits = InputF::kFractionalBits;
+ static const int kIntegerBits = InputF::kIntegerBits;
+ static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
@@ -777,10 +755,10 @@
#undef GEMMLOWP_EXP_BARREL_SHIFTER
- static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
if (kIntegerBits > 5) {
+ static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
const InputF clamp =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
}
@@ -889,8 +867,6 @@
#ifdef GEMMLOWP_NEON
#include "./fixedpoint_neon.h"
-#elif defined(GEMMLOWP_AVX2)
-#include "./fixedpoint_avx.h"
#elif defined(GEMMLOWP_SSE4)
#include "./fixedpoint_sse.h"
#elif defined(GEMMLOWP_MSA)
diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h
deleted file mode 100644
index 1816386..0000000
--- a/fixedpoint/fixedpoint_avx.h
+++ /dev/null
@@ -1,218 +0,0 @@
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// fixedpoint_avx.h: optimized avx specializations of the templates
-// in fixedpoint.h.
-
-#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
-#define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
-
-#include <smmintrin.h>
-#include "fixedpoint.h"
-#include "fixedpoint_sse.h"
-
-namespace gemmlowp {
-
-template <>
-struct FixedPointRawTypeTraits<__m256i> {
- typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
-};
-
-template <>
-inline __m256i BitAnd(__m256i a, __m256i b) {
- return _mm256_and_si256(a, b);
-}
-
-template <>
-inline __m256i BitOr(__m256i a, __m256i b) {
- return _mm256_or_si256(a, b);
-}
-
-template <>
-inline __m256i BitXor(__m256i a, __m256i b) {
- return _mm256_xor_si256(a, b);
-}
-
-template <>
-inline __m256i BitNot(__m256i a) {
- return _mm256_andnot_si256(a, _mm256_set1_epi32(-1));
-}
-
-template <>
-inline __m256i Add(__m256i a, __m256i b) {
- return _mm256_add_epi32(a, b);
-}
-
-template <>
-inline __m256i Mul(__m256i a, __m256i b) {
- return _mm256_mullo_epi32(a, b);
-}
-
-template <>
-inline __m256i Sub(__m256i a, __m256i b) {
- return _mm256_sub_epi32(a, b);
-}
-
-template <>
-inline __m256i Neg(__m256i a) {
- return _mm256_sign_epi32(a, _mm256_set1_epi32(-1));
-}
-
-template <>
-inline __m256i ShiftLeft(__m256i a, int offset) {
- return _mm256_slli_epi32(a, offset);
-}
-
-template <>
-inline __m256i ShiftRight(__m256i a, int offset) {
- return _mm256_srai_epi32(a, offset);
-}
-
-template <>
-inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
- __m256i else_val) {
- return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val),
- _mm256_castsi256_ps(then_val),
- _mm256_castsi256_ps(if_mask)));
-}
-
-template <>
-inline __m256i MaskIfEqual(__m256i a, __m256i b) {
- return _mm256_cmpeq_epi32(a, b);
-}
-
-template <>
-inline __m256i MaskIfNotEqual(__m256i a, __m256i b) {
- return BitNot(MaskIfEqual(a, b));
-}
-
-template <>
-inline __m256i MaskIfZero(__m256i a) {
- return MaskIfEqual(a, _mm256_set1_epi32(0));
-}
-
-template <>
-inline __m256i MaskIfNonZero(__m256i a) {
- return MaskIfNotEqual(a, _mm256_set1_epi32(0));
-}
-
-template <>
-inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) {
- return _mm256_cmpgt_epi32(a, b);
-}
-
-template <>
-inline __m256i MaskIfLessThan(__m256i a, __m256i b) {
- return _mm256_cmpgt_epi32(b, a);
-}
-
-template <>
-inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) {
- return BitNot(MaskIfLessThan(a, b));
-}
-
-template <>
-inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) {
- return BitNot(MaskIfGreaterThan(a, b));
-}
-
-/* Assumptions:
- - All and Any are used on masks.
- - masks are all_ones for true lanes, all_zeroes otherwise.
-Hence, All means all 128bits set, and Any means any bit set.
-*/
-
-template <>
-inline bool All(__m256i a) {
- return _mm256_testc_si256(a, a);
-}
-
-template <>
-inline bool Any(__m256i a) {
- return BitNot(_mm256_testz_si256(a, a));
-}
-
-template <>
-inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
- /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
- /* We divide the inputs before the add to avoid the overflow and costly test
- */
- /* of checking if an overflow occured on signed add */
- /* round_bit_mask = _mm_set1_epi32(1); */
- /* a_over_2 = _mm_srai_epi32(a, 1); */
- /* b_over_2 = _mm_srai_epi32(b, 1); */
- /* sum = Add(a_over_2, b_over_2); */
- /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */
- /* return Add(sum, round_bit); */
-
- /* Other possibility detecting overflow and xor the sign if an overflow
- * happened*/
- __m256i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
- one = _mm256_set1_epi32(1);
- sign_bit_mask = _mm256_set1_epi32(0x80000000);
- sum = Add(a, b);
- rounded_half_sum = _mm256_srai_epi32(Add(sum, one), 1);
- overflow =
- BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
- sign_bit_mask);
- result = BitXor(rounded_half_sum, overflow);
- return result;
-}
-
-template <>
-inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
- __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
- __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
- __m256i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
- __m256i nudge;
-
- // saturation only happen if a == b == INT_MIN
- min = _mm256_set1_epi32(std::numeric_limits<std::int32_t>::min());
- saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));
-
- // a = a0 | a1 | a2 | a3
- // b = b0 | b1 | b2 | b3
- a0_a2 = a;
- a1_a3 = _mm256_srli_si256(a, 4);
- b0_b2 = b;
- b1_b3 = _mm256_srli_si256(b, 4);
-
- a0b0_a2b2 = _mm256_mul_epi32(a0_a2, b0_b2);
- a1b1_a3b3 = _mm256_mul_epi32(a1_a3, b1_b3);
-
- // do the rounding and take into account that it will be doubled
- nudge = _mm256_set1_epi64x(1 << 30);
- a0b0_a2b2_rounded = _mm256_add_epi64(a0b0_a2b2, nudge);
- a1b1_a3b3_rounded = _mm256_add_epi64(a1b1_a3b3, nudge);
-
- // do the doubling
- a0b0_a2b2_rounded_2x = _mm256_slli_epi64(a0b0_a2b2_rounded, 1);
- a1b1_a3b3_rounded_2x = _mm256_slli_epi64(a1b1_a3b3_rounded, 1);
-
- // get the high part of the products
- result = _mm256_blend_epi16(_mm256_srli_si256(a0b0_a2b2_rounded_2x, 4),
- a1b1_a3b3_rounded_2x, 0xcc);
-
- // saturate those which overflowed
- return SelectUsingMask(saturation_mask, min, result);
-}
-
-template <>
-inline __m256i Dup<__m256i>(std::int32_t x) {
- return _mm256_set1_epi32(x);
-}
-
-} // end namespace gemmlowp
-
-#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h
index b17f32a..c7a110c 100644
--- a/fixedpoint/fixedpoint_msa.h
+++ b/fixedpoint/fixedpoint_msa.h
@@ -25,13 +25,13 @@
template <>
struct FixedPointRawTypeTraits<v4i32> {
typedef std::int32_t ScalarRawType;
- static constexpr int kLanes = 4;
+ static const int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<v8i16> {
typedef std::int16_t ScalarRawType;
- static constexpr int kLanes = 8;
+ static const int kLanes = 8;
};
template <>
@@ -326,71 +326,11 @@
}
};
-template <int Exponent>
-struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> {
- static v4i32 eval(v4i32 x) {
- static_assert(-31 <= Exponent && Exponent <= -1, "");
- // Isolate the sign bits.
- v4i32 sign = __builtin_msa_srli_w(x, 31);
- // Decrement the negative elements by 1 (with saturation).
- x = __builtin_msa_subs_s_w(x, sign);
- // Arithmetic shift right with rounding.
- // The srari instruction rounds all midpoint values towards +infinity.
- // It will correctly round negative midpoint values as we just
- // decremented the negative values by 1.
- return __builtin_msa_srari_w(x, -Exponent);
- }
-};
-
-template <int Exponent>
-struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> {
- static v8i16 eval(v8i16 x) {
- static_assert(-15 <= Exponent && Exponent <= -1, "");
- // Isolate the sign bits.
- v8i16 sign = __builtin_msa_srli_h(x, 15);
- // Decrement the negative elements by 1 (with saturation).
- x = __builtin_msa_subs_s_h(x, sign);
- // Arithmetic shift right with rounding.
- // The srari instruction rounds all midpoint values towards +infinity.
- // It will correctly round negative midpoint values as we just
- // decremented the negative values by 1.
- return __builtin_msa_srari_h(x, -Exponent);
- }
-};
-
-template <>
-inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) {
- v4i32 e = __builtin_msa_fill_w(exponent);
- // Isolate the sign bits.
- v4i32 sign = __builtin_msa_srli_w(x, 31);
- // Reset them to 0 if exponent is 0.
- sign = __builtin_msa_min_s_w(sign, e);
- // Decrement the negative elements by 1 (with saturation)
- // if exponent is non-zero.
- x = __builtin_msa_subs_s_w(x, sign);
- // Arithmetic shift right with rounding.
- // The srar instruction rounds all midpoint values towards +infinity.
- // It will correctly round negative midpoint values as we just
- // decremented the negative values by 1.
- return __builtin_msa_srar_w(x, e);
-}
-
-template <>
-inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) {
- v8i16 e = __builtin_msa_fill_h(exponent);
- // Isolate the sign bits.
- v8i16 sign = __builtin_msa_srli_h(x, 15);
- // Reset them to 0 if exponent is 0.
- sign = __builtin_msa_min_s_h(sign, e);
- // Decrement the negative elements by 1 (with saturation)
- // if exponent is non-zero.
- x = __builtin_msa_subs_s_h(x, sign);
- // Arithmetic shift right with rounding.
- // The srar instruction rounds all midpoint values towards +infinity.
- // It will correctly round negative midpoint values as we just
- // decremented the negative values by 1.
- return __builtin_msa_srar_h(x, e);
-}
+// TODO: possibly implement:
+// template <> v4i32 RoundingDivideByPOT(v4i32, int)
+// template <> v8i16 RoundingDivideByPOT(v8i16, int)
+// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
+// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>
template <>
inline v4i32 Dup<v4i32>(std::int32_t x) {
@@ -406,6 +346,7 @@
template <>
inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
return __builtin_msa_adds_s_h(a, b);
+
}
} // end namespace gemmlowp
diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h
index 4dab6c9..92b349b 100644
--- a/fixedpoint/fixedpoint_neon.h
+++ b/fixedpoint/fixedpoint_neon.h
@@ -25,13 +25,13 @@
template <>
struct FixedPointRawTypeTraits<int32x4_t> {
typedef std::int32_t ScalarRawType;
- static constexpr int kLanes = 4;
+ static const int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<int16x8_t> {
typedef std::int16_t ScalarRawType;
- static constexpr int kLanes = 8;
+ static const int kLanes = 8;
};
template <>
@@ -115,16 +115,6 @@
}
template <>
-inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) {
- return vshlq_s32(a, offset);
-}
-
-template <>
-inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) {
- return vshlq_s16(a, offset);
-}
-
-template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
return vshlq_s32(a, vdupq_n_s32(-offset));
}
@@ -292,22 +282,6 @@
return vrshlq_s16(fixed_up_x, shift_vec);
}
-template <>
-inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) {
- const int32x4_t shift_vec = vnegq_s32(exponent);
- const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
- const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
- return vrshlq_s32(fixed_up_x, shift_vec);
-}
-
-template <>
-inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) {
- const int16x8_t shift_vec = vnegq_s16(exponent);
- const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
- const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
- return vrshlq_s16(fixed_up_x, shift_vec);
-}
-
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index a1fae32..ba990f0 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -42,13 +42,13 @@
template <>
struct FixedPointRawTypeTraits<__m128i> {
typedef std::int32_t ScalarRawType;
- static constexpr int kLanes = 4;
+ static const int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<int16x8_m128i> {
typedef std::int16_t ScalarRawType;
- static constexpr int kLanes = 8;
+ static const int kLanes = 8;
};
template <>
diff --git a/internal/common.h b/internal/common.h
index 332ad07..26b6713 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -26,9 +26,144 @@
#include <cmath>
#include <cstdlib>
-#include "../internal/detect_platform.h"
#include "../profiling/instrumentation.h"
+// Our inline assembly path assume GCC/Clang syntax.
+// Native Client doesn't seem to support inline assembly(?).
+#if defined(__GNUC__) && !defined(__native_client__)
+#define GEMMLOWP_ALLOW_INLINE_ASM
+#endif
+
+// Define macro statement that avoids inlining for GCC.
+// For non-GCC, define as empty macro.
+#if defined(__GNUC__)
+#define GEMMLOWP_NOINLINE __attribute__((noinline))
+#else
+#define GEMMLOWP_NOINLINE
+#endif
+
+// Detect ARM, 32-bit or 64-bit
+#ifdef __arm__
+#define GEMMLOWP_ARM_32
+#endif
+
+#ifdef __aarch64__
+#define GEMMLOWP_ARM_64
+#endif
+
+#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64)
+#define GEMMLOWP_ARM
+#endif
+
+// Detect MIPS, 32-bit or 64-bit
+#if defined(__mips) && !defined(__LP64__)
+#define GEMMLOWP_MIPS_32
+#endif
+
+#if defined(__mips) && defined(__LP64__)
+#define GEMMLOWP_MIPS_64
+#endif
+
+#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MIPS
+#endif
+
+// Detect x86, 32-bit or 64-bit
+#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
+#define GEMMLOWP_X86_32
+#endif
+
+#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64)
+#define GEMMLOWP_X86_64
+#endif
+
+#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_X86
+#endif
+
+// Some of our optimized paths use inline assembly and for
+// now we don't bother enabling some other optimized paths using intrinsics
+// where we can't use inline assembly paths.
+#ifdef GEMMLOWP_ALLOW_INLINE_ASM
+
+// Detect NEON. It's important to check for both tokens.
+#if (defined __ARM_NEON) || (defined __ARM_NEON__)
+#define GEMMLOWP_NEON
+#endif
+
+// Convenience NEON tokens for 32-bit or 64-bit
+#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32)
+#define GEMMLOWP_NEON_32
+#endif
+
+#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64)
+#define GEMMLOWP_NEON_64
+#endif
+
+// Detect MIPS MSA.
+// Limit MSA optimizations to little-endian CPUs for now.
+// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
+#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
+ defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define GEMMLOWP_MSA
+#endif
+
+// Convenience MIPS MSA tokens for 32-bit or 64-bit.
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
+#define GEMMLOWP_MSA_32
+#endif
+
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MSA_64
+#endif
+
+// Detect SSE.
+#ifdef __SSE4_1__
+#define GEMMLOWP_SSE4
+#endif
+
+#ifdef __SSE3__
+#define GEMMLOWP_SSE3
+#endif
+
+// Convenience SSE4 tokens for 32-bit or 64-bit
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
+#define GEMMLOWP_SSE4_32
+#endif
+
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32)
+#define GEMMLOWP_SSE3_32
+#endif
+
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
+#define GEMMLOWP_SSE4_64
+#endif
+
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_SSE3_64
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(memory_sanitizer)
+#include <sanitizer/msan_interface.h>
+#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison
+#elif __has_feature(address_sanitizer)
+#include <sanitizer/asan_interface.h>
+#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region
+#endif
+#endif
+
+#endif // GEMMLOWP_ALLOW_INLINE_ASM
+
+// Detect Android. Don't conflate with ARM - we care about tuning
+// for non-ARM Android devices too. This can be used in conjunction
+// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs.
+#if defined(__ANDROID__) || defined(ANDROID)
+#define GEMMLOWP_ANDROID
+#endif
+
namespace gemmlowp {
// Standard cache line size. Useful to optimize alignment and
@@ -107,12 +242,7 @@
// size, so any size would work there. Different platforms may set this
// to different values but must ensure that their own optimized packing paths
// are consistent with this value.
-
-#ifdef GEMMLOWP_AVX2
-const int kRegisterSize = 32;
-#else
const int kRegisterSize = 16;
-#endif
// Hints the CPU to prefetch the cache line containing ptr.
inline void Prefetch(const void* ptr) {
diff --git a/internal/detect_platform.h b/internal/detect_platform.h
deleted file mode 100644
index 6f06d19..0000000
--- a/internal/detect_platform.h
+++ /dev/null
@@ -1,166 +0,0 @@
-// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// detect_platform.h: Sets up macros that control architecture-specific
-// features of gemmlowp's implementation.
-
-#ifndef GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
-#define GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
-
-// Our inline assembly path assume GCC/Clang syntax.
-// Native Client doesn't seem to support inline assembly(?).
-#if defined(__GNUC__) && !defined(__native_client__)
-#define GEMMLOWP_ALLOW_INLINE_ASM
-#endif
-
-// Define macro statement that avoids inlining for GCC.
-// For non-GCC, define as empty macro.
-#if defined(__GNUC__)
-#define GEMMLOWP_NOINLINE __attribute__((noinline))
-#else
-#define GEMMLOWP_NOINLINE
-#endif
-
-// Detect ARM, 32-bit or 64-bit
-#ifdef __arm__
-#define GEMMLOWP_ARM_32
-#endif
-
-#ifdef __aarch64__
-#define GEMMLOWP_ARM_64
-#endif
-
-#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64)
-#define GEMMLOWP_ARM
-#endif
-
-// Detect MIPS, 32-bit or 64-bit
-#if defined(__mips) && !defined(__LP64__)
-#define GEMMLOWP_MIPS_32
-#endif
-
-#if defined(__mips) && defined(__LP64__)
-#define GEMMLOWP_MIPS_64
-#endif
-
-#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
-#define GEMMLOWP_MIPS
-#endif
-
-// Detect x86, 32-bit or 64-bit
-#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
-#define GEMMLOWP_X86_32
-#endif
-
-#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64)
-#define GEMMLOWP_X86_64
-#endif
-
-#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64)
-#define GEMMLOWP_X86
-#endif
-
-// Some of our optimized paths use inline assembly and for
-// now we don't bother enabling some other optimized paths using intrinddics
-// where we can't use inline assembly paths.
-#ifdef GEMMLOWP_ALLOW_INLINE_ASM
-
-// Detect NEON. It's important to check for both tokens.
-#if (defined __ARM_NEON) || (defined __ARM_NEON__)
-#define GEMMLOWP_NEON
-#endif
-
-// Convenience NEON tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32)
-#define GEMMLOWP_NEON_32
-#endif
-
-#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64)
-#define GEMMLOWP_NEON_64
-#endif
-
-// Detect MIPS MSA.
-// Limit MSA optimizations to little-endian CPUs for now.
-// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
-#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
- defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
-#define GEMMLOWP_MSA
-#endif
-
-// Convenience MIPS MSA tokens for 32-bit or 64-bit.
-#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
-#define GEMMLOWP_MSA_32
-#endif
-
-#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
-#define GEMMLOWP_MSA_64
-#endif
-
-// compiler define for AVX2 -D GEMMLOWP_ENABLE_AVX2
-// Detect AVX2
-#if defined(__AVX2__) && defined(GEMMLOWP_ENABLE_AVX2)
-#define GEMMLOWP_AVX2
-// Detect SSE4.
-// MSVC does not have __SSE4_1__ macro, but will enable SSE4
-// when AVX is turned on.
-#elif defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
-#define GEMMLOWP_SSE4
-// Detect SSE3.
-#elif defined(__SSE3__)
-#define GEMMLOWP_SSE3
-#endif
-
-// Convenience SSE4 tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
- !defined(GEMMLOWP_DISABLE_SSE4)
-#define GEMMLOWP_SSE4_32
-#endif
-
-#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32)
-#define GEMMLOWP_SSE3_32
-#endif
-
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
- !defined(GEMMLOWP_DISABLE_SSE4)
-#define GEMMLOWP_SSE4_64
-#endif
-
-#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64)
-#define GEMMLOWP_SSE3_64
-#endif
-
-#if defined(GEMMLOWP_AVX2) && defined(GEMMLOWP_X86_64)
-#define GEMMLOWP_AVX2_64
-#endif
-
-#if defined(__has_feature)
-#if __has_feature(memory_sanitizer)
-#include <sanitizer/msan_interface.h>
-#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison
-#elif __has_feature(address_sanitizer)
-#include <sanitizer/asan_interface.h>
-#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region
-#endif
-#endif
-
-#endif // GEMMLOWP_ALLOW_INLINE_ASM
-
-// Detect Android. Don't conflate with ARM - we care about tuning
-// for non-ARM Android devices too. This can be used in conjunction
-// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs.
-#if defined(__ANDROID__) || defined(ANDROID)
-#define GEMMLOWP_ANDROID
-#endif
-
-#endif // GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h
index ba4f341..0be0bf3 100644
--- a/internal/dispatch_gemm_shape.h
+++ b/internal/dispatch_gemm_shape.h
@@ -85,22 +85,6 @@
}
};
-template <VectorShape Shape>
-struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
- typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
- static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
- typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
- DstType;
- static DstType Run(const SrcType& src) {
- DstType dst;
- dst.result_fixedpoint_multiplier =
- Transpose(src.result_fixedpoint_multiplier);
- dst.result_exponent = Transpose(src.result_exponent);
- dst.result_offset_after_shift = src.result_offset_after_shift;
- return dst;
- }
-};
-
template <typename VectorMapType>
struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
typedef OutputStageBiasAddition<VectorMapType> SrcType;
diff --git a/internal/kernel.h b/internal/kernel.h
index 3120216..825a7f3 100644
--- a/internal/kernel.h
+++ b/internal/kernel.h
@@ -145,24 +145,12 @@
static const int kCells = tCells;
static const int kWidth = kCells * Cell::kWidth;
static const int kDepth = Cell::kDepth;
- typedef std::uint8_t Scalar; // The scalar type of the Format.
- typedef std::uint8_t InputScalar; // The scalar type of the original input.
+ typedef std::uint8_t Scalar;
};
-// KernelSideFormat for int8 fast kernel trick. The original input is uint8, but
-// packs converts it to int8.
template <typename tCellFormat, int tCells>
struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
typedef std::int8_t Scalar;
- typedef std::uint8_t InputScalar;
-};
-
-// KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without
-// pack conversion.
-template <typename tCellFormat, int tCells>
-struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> {
- typedef std::int8_t Scalar;
- typedef std::int8_t InputScalar;
};
// KernelFormat describes fully the input data layout that a kernel expects.
@@ -228,24 +216,19 @@
virtual ~KernelBase() {}
};
-template <typename InputKernelScalarType, typename KernelScalarType>
+template <typename KernelScalarType>
struct ZeroPointInputValue {};
template <>
-struct ZeroPointInputValue<std::uint8_t, std::uint8_t> {
+struct ZeroPointInputValue<std::uint8_t> {
static constexpr std::uint8_t kValue = 0;
};
template <>
-struct ZeroPointInputValue<std::uint8_t, std::int8_t> {
+struct ZeroPointInputValue<std::int8_t> {
static constexpr std::uint8_t kValue = 128;
};
-template <>
-struct ZeroPointInputValue<std::int8_t, std::int8_t> {
- static constexpr std::uint8_t kValue = 0;
-};
-
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_KERNEL_H_
diff --git a/internal/kernel_avx.h b/internal/kernel_avx.h
deleted file mode 100644
index 2fe1249..0000000
--- a/internal/kernel_avx.h
+++ /dev/null
@@ -1,361 +0,0 @@
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// kernel_SSE.h: a collection of Intel SSE optimized kernels.
-// Check in kernel_default.h which one(s) are actually used by default.
-// Others are mere experiments; they are still covered by tests
-// in case they might be useful some day.
-//
-
-#ifndef GEMMLOWP_INTERNAL_KERNEL_AVX_H_
-#define GEMMLOWP_INTERNAL_KERNEL_AVX_H_
-
-#include "kernel.h"
-
-#include <string.h>
-#include <cassert>
-
-namespace gemmlowp {
-
-#ifdef GEMMLOWP_AVX2_64
-struct AVX2_64_Kernel24x8Depth2 : KernelBase {
- typedef KernelFormat<KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1>>
- Format;
-
- const char *Name() const override { return "AVX, 24x8, depth 2"; }
-
- void Run(std::int32_t *dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride,
- const std::uint8_t *lhs_ptr, const std::uint8_t *rhs_ptr, std::size_t start_depth,
- std::size_t run_depth) const override {
- ScopedProfilingLabel label("optimized kernel");
- assert(dst_row_stride == 1);
- const std::int64_t run_depth_cells = run_depth / Format::kDepth;
- const std::int64_t dst_col_stride_q = dst_col_stride;
-
- /* Main loop */
-
- // A 2x8 cell of Rhs is stored in 16bit in ymm1 .
- // A 24x2 block of 3 8x2 cells Lhs is stored in 16bit in ymm0, replaced
- // every Iteration.
- // A 8x8 block of accumulators is stored in 32bit in xmm4--xmm15.
- //
- // +-------+-------+-------+-------+
- // |ymm1[0] |ymm2[2] |
- // Rhs +-------+---------------+-------+
- // |ymm1[1] |ymm1[4] |
- // +-------+-------+-------+-------+
- //
- // | | | | |
- //
- // Lhs | | | | |
- //
- // +--+--+ - - - - +-------+-------+-------+-------+
- // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 |
- // |ymm0 | (Iter1) | ymm4 | ymm5 | ymm6 | ymm7 |
- // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 |
- // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 |
- // +--+--+ - - - - +-------+-------+-------+-------+
- // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 |
- // |ymm0 | (Iter2) | ymm8 | ymm9 | ymm10 | ymm11 |
- // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 |
- // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 |
- // +--+--+ - - - - +-------+-------+-------+-------+
- // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 |
- // |ymm0 | (Iter3) | ymm12 | ymm13 | ymm14 | ymm15 |
- // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 |
- // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 |
- // +--+--+ - - - - +-------+-------+-------+-------+
- //
- // Accumulator
-
- asm volatile(
- // Set registers for destination
- "movq %[dst_col_stride_q], %%r12\n\t" // stride is r12
- "shlq $2, %%r12\n\t" // set stride dword
- "leaq (%%r12,%%r12,0x2), %%r13\n\t" // load stride aligned r13
-
- // Set accumulators to zero.
- "vpxor %%ymm4, %%ymm4, %%ymm4 \n\t" // zero accumulators
- "vpxor %%ymm5, %%ymm5, %%ymm5 \n\t" // zero accumulators
- "vpxor %%ymm6, %%ymm6, %%ymm6 \n\t" // zero accumulators
- "vpxor %%ymm7, %%ymm7, %%ymm7 \n\t" // zero accumulators
- "vpxor %%ymm8, %%ymm8, %%ymm8 \n\t" // zero accumulators
- "vpxor %%ymm9, %%ymm9, %%ymm9 \n\t" // zero accumulators
- "vpxor %%ymm10, %%ymm10, %%ymm10\n\t" // zero accumulators
- "vpxor %%ymm11, %%ymm11, %%ymm11\n\t" // zero accumulators
- "vpxor %%ymm12, %%ymm12, %%ymm12\n\t" // zero accumulators
- "vpxor %%ymm13, %%ymm13, %%ymm13\n\t" // zero accumulators
- "vpxor %%ymm14, %%ymm14, %%ymm14\n\t" // zero accumulators
- "vpxor %%ymm15, %%ymm15, %%ymm15\n\t" // zero accumulators
-
- "movq %[run_depth_cells], %%r14 \n\t" // load cell depth r14
- "subq $2, %%r14 \n\t" // cell depth is 2
- "js outerLoop1%= \n\t" // outerloop for matrix
-
- // Loop for K unrolled by 4
- "outerLoop2%=: \n\t" // outer loop unroll
-
- // K = 0,1,2,3
- // RHS cell to ymm1
-
- // lower half
- "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
- "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
- // LHS cell elements 0 and 1
- "vpmovzxbw 0x00(%[lhs_ptr]), %%ymm0\n\t" // mov lhs to ymm0
- "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to all ymm2
- "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to all ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3
- "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // add muladd lhs + rhs0 into ymm4
- "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // add muladd lhs + rhs1 into ymm5
- // LHS cell elements 2 and 3
- "vpshufd $0xaa, %%ymm1, %%ymm2 \n\t" // move rhs 2 element to all ymm2
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rh3 into ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element into all ymm3
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rh4 into ymm3
- "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // add muladd lhs + rhs2 into ymm6
- "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // add muladd lhs + rhs3 into ymm7
-
- // cache prefect lhs //see if it works better?
- //"prefetcht0 0x80(%[lhs_ptr]) \n\t" //prefetch cache lines
- "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
- "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
-
- // K = 5,6,7,8
- // next LHS cell elements 0 and 1
- "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
- "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // mov rhs 0 element to all ymm2
- "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // mov rhs 1 element to all ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3
- "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // add muladd lhs + rhs0 into ymm8
- "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // add muladd lhs + rhs1 into ymm9
- // next LHS cell elements 2 and 3
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to all ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to all ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs3 into ymm3
- "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // add muladd lhs + rhs2 into ymm10
- "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // add muladd lhs + rhs3 into ymm11
-
- // rhs lower half
- "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
- "vpermq $0x44,%%ymm1, %%ymm1 \n\t" // duplcate lower 16
-
- // next LHS cell elements 0 and 1
- "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
- "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // mov rhs 0 element to all ymm2
- "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // mov rhs 1 element to all ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3
- "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // add muladd lhs + rhs0 into ymm8
- "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // add muladd lhs + rhs1 into ymm9
-
- // cache prefetch rhs //see if it works better?
- //"prefetcht0 0x80(%[rhs_ptr]) \n\t"
-
- // next LHS cell elements 2 and 3
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to all ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to all ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs3 into ymm3
- "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // add muladd lhs + rhs2 into ymm10
- "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // add muladd lhs + rhs3 into ymm11
-
- // current result in ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10 ymm11 ymm12 ymm13 ymm14 ymm15
-
- // rhs+10 lower half
- "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
- "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
- // next LHS cell elements 0 and 1
- "vpmovzxbw 0x30(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
- "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2
- "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3
- "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // accumulate to ymm4
- "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // accumulate to ymm5
- // next LHS cell elements 2 and 3
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3
- "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // add lhs rhs2 to ymm6
- "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // add lhs rhs3 to ymm7
-
- // rhs+10 lower half
- "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
- "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
-
- // next LHS cell elements 4 and 5
- "vpmovzxbw 0x40(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
- "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2
- "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3
- "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // accumulate to ymm8
- "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // accumulate to ymm9
- // next LHS cell elements 6 and 7
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3
- "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // add lhs rhs2 to ymm10
- "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // add lhs rhs3 to ymm11
-
- "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
- "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
- // next LHS cell elements 9 and 10
- "vpmovzxbw 0x50(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
- "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2
- "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3
- "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // accumulate to ymm12
- "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // accumulate to ymm13
-
- // next LHS cell elements 11 and 12
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3
- "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // add lhs rhs2 to ymm14
- "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // add lhs rhs3 to ymm15
-
- // completed rhs+10
- "addq $0x60, %[lhs_ptr] \n\t" // increment stride lhs
- "addq $0x10, %[rhs_ptr] \n\t" // increment stride rhs
-
- "subq $2, %[run_depth_cells] \n\t"
- "ja outerLoop2%= \n\t"
-
- "movq %[run_depth_cells], %%r14 \n\t"
- "decq %%r14 \n\t"
- "js finish%= \n\t"
-
- // Loop for K unrolled by 2
- "outerLoop1%=: \n\t"
-
- // rhs lower
- "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // get rhs into ymm1
- "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
-
- // LHS cell
- "vpmovzxbw (%[lhs_ptr]), %%ymm0 \n\t" // lhs in into ymm0
- "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // rhs element 0 into ymm2
- "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // rhs element 1 into ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3
- "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // acc element 0 ymm4
- "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // acc element 1 ymm5
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3
- "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // acc element 2 into ymm6
- "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // acc element 3 into ymm7
-
- // lhs+10
- "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0 \n\t" // lhs in into ymm0
- "vpshufd $0x00, %%ymm1, %%ymm2 \n\t" // rhs element 0 into ymm2
- "vpshufd $0x55, %%ymm1, %%ymm3 \n\t" // rhs element 1 into ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3
- "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // acc element 0 ymm8
- "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // acc element 1 ymm9
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3
- "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // acc element 2 into ymm10
- "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // acc element 3 into ymm11
-
- "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0 \n\t"
- "vpshufd $0x00, %%ymm1, %%ymm2 \n\t" // rhs element 0 into ymm2
- "vpshufd $0x55, %%ymm1, %%ymm3 \n\t" // rhs element 1 into ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3
- "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // acc element 0 ymm12
- "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // acc element 1 ymm13
- "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2
- "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3
- "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2
- "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3
- "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // acc element 2 into ymm14
- "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // acc element 3 into ymm15
-
- // update matrix pointers
- "addq $0x30, %[lhs_ptr] \n\t"
- "addq $0x08, %[rhs_ptr] \n\t"
-
- "decq %[run_depth_cells] \n\t"
- "jnz outerLoop1%= \n\t"
-
- "finish%=:\n\t"
-
- "test %[start_depth], %[start_depth] \n\t"
- "jz storeDst%= \n\t"
-
- "vpaddd 0x00(%[dst_ptr]), %%ymm4, %%ymm4 \n\t" // rhs0
- "vpaddd 0x20(%[dst_ptr]), %%ymm8, %%ymm8 \n\t" // rhs0
- "vpaddd 0x40(%[dst_ptr]), %%ymm12, %%ymm12 \n\t" // rhs0
-
- "vpaddd 0x00(%[dst_ptr], %%r12, 1) , %%ymm5, %%ymm5 \n\t" // rhs1
- "vpaddd 0x20(%[dst_ptr], %%r12, 1) , %%ymm9, %%ymm9 \n\t" // rhs1
- "vpaddd 0x40(%[dst_ptr], %%r12, 1) , %%ymm13, %%ymm13 \n\t" // rhs1
-
- "vpaddd 0x00(%[dst_ptr], %%r12, 2) , %%ymm6, %%ymm6 \n\t" // rhs2
- "vpaddd 0x20(%[dst_ptr], %%r12, 2) , %%ymm10, %%ymm10 \n\t" // rhs2
- "vpaddd 0x40(%[dst_ptr], %%r12, 2) , %%ymm14, %%ymm14 \n\t" // rhs2
-
- "vpaddd 0x00(%[dst_ptr], %%r13, 1) , %%ymm7, %%ymm7 \n\t" // rhs3
- "vpaddd 0x20(%[dst_ptr], %%r13, 1) , %%ymm11, %%ymm11 \n\t" // rhs3
- "vpaddd 0x40(%[dst_ptr], %%r13, 1) , %%ymm15, %%ymm15 \n\t" // rhs3
-
- "storeDst%=:\n\t"
-
- "vmovdqu %%ymm4, 0x00(%[dst_ptr]) \n\t" // rhs0
- "vmovdqu %%ymm8, 0x20(%[dst_ptr]) \n\t" // rhs0
- "vmovdqu %%ymm12, 0x40(%[dst_ptr]) \n\t" // rhs0
-
- "vmovdqu %%ymm5, 0x00(%[dst_ptr], %%r12, 1) \n\t" // rhs1
- "vmovdqu %%ymm9, 0x20(%[dst_ptr], %%r12, 1) \n\t" // rhs1
- "vmovdqu %%ymm13, 0x40(%[dst_ptr], %%r12, 1) \n\t" // rhs1
-
- "vmovdqu %%ymm6, 0x00(%[dst_ptr], %%r12, 2) \n\t" // rhs2
- "vmovdqu %%ymm10, 0x20(%[dst_ptr], %%r12, 2) \n\t" // rhs2
- "vmovdqu %%ymm14, 0x40(%[dst_ptr], %%r12, 2) \n\t" // rhs2
-
- "vmovdqu %%ymm7, 0x00(%[dst_ptr], %%r13, 1) \n\t" // rhs3
- "vmovdqu %%ymm11, 0x20(%[dst_ptr], %%r13, 1) \n\t" // rhs3
- "vmovdqu %%ymm15, 0x40(%[dst_ptr], %%r13, 1) \n\t" // rhs3
-
- : // outputs
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [dst_ptr] "+r"(dst_ptr)
- : // inputs
- [start_depth] "r"(start_depth), [dst_col_stride_q] "r"(dst_col_stride_q),
- [run_depth_cells] "r"(run_depth_cells)
- : // clobbers
- "cc", "memory", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7",
- "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15", "%r12",
- "%r13", "%r14");
- }
-};
-#endif
-
-} // namespace gemmlowp
-
-#endif // GEMMLOWP_INTERNAL_KERNEL_AVX_H_
diff --git a/internal/kernel_default.h b/internal/kernel_default.h
index 29b0991..a919ffe 100644
--- a/internal/kernel_default.h
+++ b/internal/kernel_default.h
@@ -20,84 +20,66 @@
#include "../public/bit_depth.h"
#include "common.h"
-#include "kernel.h"
#include "kernel_reference.h"
namespace gemmlowp {
-template <bool MaxProductIsLessThan4096, bool IsUnsigned, bool LhsNonZero>
+template <bool MaxProductIsLessThan4096, bool LhsAlwaysNonzero>
struct DefaultKernelImpl {};
// Partial specialization implementing the logic that if we want to use
+// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall
+// back to a generic kernel not taking advantage of LhsAlwaysNonzero.
+template <bool LhsAlwaysNonzero>
+struct DefaultKernelImpl<true, LhsAlwaysNonzero>
+ : DefaultKernelImpl<false, LhsAlwaysNonzero> {};
+
+// Partial specialization implementing the logic that if we want to use
// a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we
// fall back to a generic kernel not taking advantage of
// MaxProductIsLessThan4096.
-template <bool LhsNonZero>
-struct DefaultKernelImpl<true, true, LhsNonZero>
- : DefaultKernelImpl<false, true, LhsNonZero> {};
-
-// Partial specialization implementing the logic that if we want to use
-// a kernel for LhsNonZero but do not have such a kernel, then we fall
-// back to a generic kernel not taking advantage of LhsNonZero.
template <bool MaxProductIsLessThan4096>
-struct DefaultKernelImpl<MaxProductIsLessThan4096, true, true>
- : DefaultKernelImpl<MaxProductIsLessThan4096, true, false> {};
+struct DefaultKernelImpl<MaxProductIsLessThan4096, true>
+ : DefaultKernelImpl<MaxProductIsLessThan4096, false> {};
template <typename BitDepthParams>
struct DefaultKernel
: DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue *
BitDepthParams::RhsRange::kMaxValue <
4096),
- (BitDepthParams::LhsRange::kMinValue >= 0),
- (BitDepthParams::LhsRange::kMinValue > 0 ||
- (BitDepthParams::LhsRange::kMaxValue <= 127 &&
- BitDepthParams::LhsRange::kMinValue > -128))> {};
+ (BitDepthParams::LhsRange::kMinValue > 0)> {};
} // end namespace gemmlowp
-#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, IsUnsigned, \
- LhsAlwaysNonZero, Kernel) \
- namespace gemmlowp { \
- template <> \
- struct DefaultKernelImpl<MaxProductIsLessThan4096, IsUnsigned, \
- LhsAlwaysNonZero> : Kernel {}; \
+#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
+ LhsAlwaysNonzero, Kernel) \
+ namespace gemmlowp { \
+ template <> \
+ struct DefaultKernelImpl<MaxProductIsLessThan4096, LhsAlwaysNonzero> \
+ : Kernel {}; \
}
-// User-provided int8 inputs is only supported in the NEON path currently.
#if defined GEMMLOWP_NEON_32
#include "kernel_neon.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_32_Kernel12x4Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(true, true, false,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(true, false,
NEON_32_Kernel12x4Depth2Assuming12BitProducts)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
NEON_32bit_GEMM_Int8Operands_LhsNonzero)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
- NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
#elif defined GEMMLOWP_NEON_64
#include "kernel_neon.h"
-#if defined GEMMLOWP_DOTPROD_KERNEL
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false,
- NEON_64_Kernel12x8Depth4_dotprod)
-#else
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_64_Kernel12x8Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
NEON_64bit_GEMM_Int8Operands_LhsNonzero)
-#endif
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
- NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
#elif defined(GEMMLOWP_MSA)
#include "kernel_msa.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, MSA_Kernel12x8Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, MSA_GEMM_Int8Operands_LhsNonzero)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2)
#elif defined GEMMLOWP_SSE4_32
#include "kernel_sse.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_32_Kernel4x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
#elif defined GEMMLOWP_SSE4_64
#include "kernel_sse.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_64_Kernel12x4Depth2)
-#elif defined GEMMLOWP_AVX2_64
-#include "kernel_avx.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, AVX2_64_Kernel24x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2)
#else
#include "kernel_reference.h"
namespace gemmlowp {
@@ -106,7 +88,7 @@
KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > >
DefaultReferenceKernel;
}
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, DefaultReferenceKernel)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel)
#endif
#endif // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
diff --git a/internal/kernel_msa.h b/internal/kernel_msa.h
index a9205f6..4985b73 100644
--- a/internal/kernel_msa.h
+++ b/internal/kernel_msa.h
@@ -42,8 +42,8 @@
// Our main GEMM kernel.
struct MSA_Kernel12x8Depth2 : KernelBase {
- typedef KernelFormat<KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> >
+ typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
+ KernelSideFormat<CellFormat<4, 2>, 2> >
Format;
const char* Name() const override { return "MSA, 12x8, depth 2"; }
@@ -62,6 +62,9 @@
assert(dst_row_stride == 1);
asm volatile(
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
// Multiply dst_col_stride by 4 == sizeof(int32) to use
// it as a byte offset below.
GEMMLOWP_MIPS_XSLL
@@ -72,25 +75,32 @@
"beqz %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
// Load accumulators (start_depth != 0).
- GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ GEMMLOWP_MIPS_XADDU
+ " $a0, %[dst_ptr], %[dst_col_stride]\n"
"ld.w $w0, (0*16)(%[dst_ptr])\n"
"ld.w $w4, (1*16)(%[dst_ptr])\n"
- "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
"ld.w $w1, (0*16)($a0)\n"
"ld.w $w5, (1*16)($a0)\n"
- "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
"ld.w $w2, (0*16)($a1)\n"
"ld.w $w6, (1*16)($a1)\n"
- "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
"ld.w $w3, (0*16)($a0)\n"
"ld.w $w7, (1*16)($a0)\n"
- "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
"ld.w $w12, (0*16)($a1)\n"
"ld.w $w16, (1*16)($a1)\n"
- "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
"ld.w $w13, (0*16)($a0)\n"
"ld.w $w17, (1*16)($a0)\n"
- "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
"ld.w $w14, (0*16)($a1)\n"
"ld.w $w18, (1*16)($a1)\n"
"ld.w $w22, (2*16)($a1)\n"
@@ -99,7 +109,8 @@
"ld.w $w23, (2*16)($a0)\n"
"b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
- GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n"
+ GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+ ":\n"
// Clear accumulators (start_depth == 0).
"ldi.w $w0, 0\n"
"ldi.w $w4, 0\n"
@@ -128,16 +139,17 @@
GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
- GEMMLOWP_LABEL_LOOP ":\n"
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
// Overview of register layout:
//
- // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
// (each register contains 4 replicas of a pair of elements).
// A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
// A 12x8 block of accumulators is stored in 32bit in w0-w23.
//
// +------+------+------+------+
- // Rhs |w28 |w29 |w30 |w31 |
+ // Rhs |w27 |w28 |w29 |w30 |
// +------+------+------+------+
//
// | | | | |
@@ -167,86 +179,128 @@
"ld.b $w24, 0(%[lhs_ptr])\n"
"ld.b $w25, 8(%[lhs_ptr])\n"
- // Load 2 x 8 bytes of rhs[].
- "ld.b $w27, 0(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
// Zero-extend 8-bit elements of lhs[] to 16 bits.
- "ldi.b $w31, 0\n"
"ilvr.b $w24, $w31, $w24\n"
"ilvl.b $w26, $w31, $w25\n"
"ilvr.b $w25, $w31, $w25\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w27, $w31, $w24\n"
+ "ilvl.d $w28, $w31, $w25\n"
+ "ilvl.d $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w27, $w24\n"
+ "ilvr.h $w25, $w28, $w25\n"
+ "ilvr.h $w26, $w29, $w26\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for
+ // dpadd_u.w (for the first half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Load 4 bytes of rhs[] for the second half of depth 0.
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the second half of depth 1.
+ "lbu $v0, 12(%[rhs_ptr])\n"
+ "lbu $v1, 13(%[rhs_ptr])\n"
+ "lbu $t8, 14(%[rhs_ptr])\n"
+ "lbu $t9, 15(%[rhs_ptr])\n"
// First half of depths 0 and 1.
- // Zero-extend 8-bit elements of rhs[] to 16 bits.
- "ilvr.b $w31, $w31, $w27\n"
- // Make 4 replicas of every pair of rhs[] elements.
- "splati.w $w28, $w31[0]\n"
- "splati.w $w29, $w31[1]\n"
- "splati.w $w30, $w31[2]\n"
- "splati.w $w31, $w31[3]\n"
// Dot-product-(and)-add doubles multiplicand width.
- "dpadd_u.w $w0, $w24, $w28\n"
- "dpadd_u.w $w4, $w25, $w28\n"
- "dpadd_u.w $w8, $w26, $w28\n"
- "dpadd_u.w $w1, $w24, $w29\n"
- "dpadd_u.w $w5, $w25, $w29\n"
- "dpadd_u.w $w9, $w26, $w29\n"
- "dpadd_u.w $w2, $w24, $w30\n"
- "dpadd_u.w $w6, $w25, $w30\n"
- "dpadd_u.w $w10, $w26, $w30\n"
- "dpadd_u.w $w3, $w24, $w31\n"
- "dpadd_u.w $w7, $w25, $w31\n"
- "dpadd_u.w $w11, $w26, $w31\n"
+ "dpadd_u.w $w0, $w24, $w27\n"
+ "dpadd_u.w $w4, $w25, $w27\n"
+ "dpadd_u.w $w8, $w26, $w27\n"
+ "dpadd_u.w $w1, $w24, $w28\n"
+ "dpadd_u.w $w5, $w25, $w28\n"
+ "dpadd_u.w $w9, $w26, $w28\n"
+ "dpadd_u.w $w2, $w24, $w29\n"
+ "dpadd_u.w $w6, $w25, $w29\n"
+ "dpadd_u.w $w10, $w26, $w29\n"
+ "dpadd_u.w $w3, $w24, $w30\n"
+ "dpadd_u.w $w7, $w25, $w30\n"
+ "dpadd_u.w $w11, $w26, $w30\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for
+ // dpadd_u.w (for the second half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
// Second half of depths 0 and 1.
- // Zero-extend 8-bit elements of rhs[] to 16 bits.
- "ldi.b $w31, 0\n"
- "ilvl.b $w31, $w31, $w27\n"
- // Make 4 replicas of every pair of rhs[] elements.
- "splati.w $w28, $w31[0]\n"
- "splati.w $w29, $w31[1]\n"
- "splati.w $w30, $w31[2]\n"
- "splati.w $w31, $w31[3]\n"
// Dot-product-(and)-add doubles multiplicand width.
- "dpadd_u.w $w12, $w24, $w28\n"
- "dpadd_u.w $w16, $w25, $w28\n"
- "dpadd_u.w $w20, $w26, $w28\n"
- "dpadd_u.w $w13, $w24, $w29\n"
- "dpadd_u.w $w17, $w25, $w29\n"
- "dpadd_u.w $w21, $w26, $w29\n"
- "dpadd_u.w $w14, $w24, $w30\n"
- "dpadd_u.w $w18, $w25, $w30\n"
- "dpadd_u.w $w22, $w26, $w30\n"
- "dpadd_u.w $w15, $w24, $w31\n"
- "dpadd_u.w $w19, $w25, $w31\n"
- "dpadd_u.w $w23, $w26, $w31\n"
+ "dpadd_u.w $w12, $w24, $w27\n"
+ "dpadd_u.w $w16, $w25, $w27\n"
+ "dpadd_u.w $w20, $w26, $w27\n"
+ "dpadd_u.w $w13, $w24, $w28\n"
+ "dpadd_u.w $w17, $w25, $w28\n"
+ "dpadd_u.w $w21, $w26, $w28\n"
+ "dpadd_u.w $w14, $w24, $w29\n"
+ "dpadd_u.w $w18, $w25, $w29\n"
+ "dpadd_u.w $w22, $w26, $w29\n"
+ "dpadd_u.w $w15, $w24, $w30\n"
+ "dpadd_u.w $w19, $w25, $w30\n"
+ "dpadd_u.w $w23, $w26, $w30\n"
GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU
- " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
+ " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU
+ " %[rhs_ptr], 16\n"
"bnez %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"
GEMMLOWP_LABEL_AFTER_LOOP ":\n"
// Store accumulators.
- GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ GEMMLOWP_MIPS_XADDU
+ " $a0, %[dst_ptr], %[dst_col_stride]\n"
"st.w $w0, (0*16)(%[dst_ptr])\n"
"st.w $w4, (1*16)(%[dst_ptr])\n"
- "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
"st.w $w1, (0*16)($a0)\n"
"st.w $w5, (1*16)($a0)\n"
- "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
"st.w $w2, (0*16)($a1)\n"
"st.w $w6, (1*16)($a1)\n"
- "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
"st.w $w3, (0*16)($a0)\n"
"st.w $w7, (1*16)($a0)\n"
- "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
"st.w $w12, (0*16)($a1)\n"
"st.w $w16, (1*16)($a1)\n"
- "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
"st.w $w13, (0*16)($a0)\n"
"st.w $w17, (1*16)($a0)\n"
- "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
"st.w $w14, (0*16)($a1)\n"
"st.w $w18, (1*16)($a1)\n"
"st.w $w22, (2*16)($a1)\n"
@@ -254,15 +308,18 @@
"st.w $w19, (1*16)($a0)\n"
"st.w $w23, (2*16)($a0)\n"
: // outputs
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [run_depth] "+r"(run_depth),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [run_depth] "+r"(run_depth),
[dst_col_stride] "+r"(dst_col_stride)
: // inputs
[dst_ptr] "r"(dst_ptr),
[start_depth] "r"(start_depth)
: // clobbers
- "memory", "a0", "a1", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9",
- "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
- "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1",
+ "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11",
+ "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
+ "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29",
+ "$f30", "$f31");
#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
#undef GEMMLOWP_LABEL_BEFORE_LOOP
@@ -271,303 +328,6 @@
}
};
-// Fast kernel operating on int8 operands.
-// It is assumed that one of the two int8 operands only takes values
-// in [-127, 127], while the other may freely range in [-128, 127].
-// The issue with both operands taking the value -128 is that:
-// -128*-128 + -128*-128 == -32768 overflows int16.
-// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
-// range. That is the basic idea of this kernel.
-struct MSA_GEMM_Int8Operands_LhsNonzero : KernelBase {
- typedef KernelFormat<
- KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
- KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
- Format;
-
- const char* Name() const override {
- return "MSA, 4x4, depth 16, accumulating two within signed int16";
- }
-
- // TODO(benoitjacob): reorder function arguments so dst comes last
- void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
- std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
- const std::uint8_t* rhs_ptr, std::size_t start_depth,
- std::size_t run_depth) const override {
- (void)dst_row_stride;
-#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
-#define GEMMLOWP_LABEL_LOOP "2"
-#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
-#define GEMMLOWP_LABEL_STORE "4"
- asm volatile(
- GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
- // Load lhs[] and rhs[], zero out internal accumulators.
- "ld.b $w16, 0(%[lhs_ptr])\n"
- "ldi.b $w0, 0\n"
- "ld.b $w20, 0(%[rhs_ptr])\n"
- "ldi.b $w1, 0\n"
- "ld.b $w17, 16(%[lhs_ptr])\n"
- "ldi.b $w2, 0\n"
- "ld.b $w21, 16(%[rhs_ptr])\n"
- "ldi.b $w3, 0\n"
- "ld.b $w18, 32(%[lhs_ptr])\n"
- "ldi.b $w4, 0\n"
- "ld.b $w19, 48(%[lhs_ptr])\n"
- "ldi.b $w5, 0\n"
- "ld.b $w22, 32(%[rhs_ptr])\n"
- "ldi.b $w6, 0\n"
- "ld.b $w23, 48(%[rhs_ptr])\n"
- "ldi.b $w7, 0\n"
- "ldi.b $w8, 0\n"
- "ldi.b $w9, 0\n"
- "ldi.b $w10, 0\n"
- "ldi.b $w11, 0\n"
- "ldi.b $w12, 0\n"
- "ldi.b $w13, 0\n"
- "ldi.b $w14, 0\n"
- "ldi.b $w15, 0\n"
- "ldi.h $w31, 1\n"
- // If the loop depth is only 16, then we can skip the general loop
- // and go straight to the final part of the code.
- "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"
-
- GEMMLOWP_LABEL_LOOP ":\n"
- // Overview of register layout:
- //
- // A 4x16 block of Rhs is stored in 8 bit in w16-w19.
- // A 4x16 block of Lhs is stored in 8 bit in w20-w23.
- //
- // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit
- // components which need to be horizontally added at the end).
- //
- // Dot products of Lhs and Rhs are 16-bit values, which can't
- // immediately be accumulated in 32-bit accumulators by that
- // same instruction that calculates them.
- // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit
- // sums in w25 (note, the 16 sums have already been reduced to 8
- // by the horizontal addition of the dotp instruction).
- // They are then sign-extended to 32 bits, horizontally added
- // (again) to form 4 32-bit sums and then they are finally added
- // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31".
- //
- // +-----+-----+-----+-----+
- // Rhs | w20 | w21 | w22 | w23 |
- // +-----+-----+-----+-----+
- //
- // | | | | |
- //
- // Lhs | | | | |
- //
- // +---+ - - - - +-----+-----+-----+-----+
- // |w16| | w0 | w4 | w8 | w12 |
- // |w17| | w1 | w5 | w9 | w13 |
- // |w18| | w2 | w6 | w10 | w14 |
- // |w19| | w3 | w7 | w11 | w15 |
- // +---+ - - - - +-----+-----+-----+-----+
- //
- // Accumulators
-
- // Calculate the results for 16 depths and load
- // lhs[] and rhs[] for the next iteration.
- GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n"
- GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n"
- GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
-
- // Dot product: multiply-add pairs of adjacent int8 elements.
- // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
- "dotp_s.h $w25, $w16, $w20\n"
- "dotp_s.h $w26, $w17, $w20\n"
- "dotp_s.h $w27, $w16, $w21\n"
- "dotp_s.h $w28, $w17, $w21\n"
- "dotp_s.h $w29, $w18, $w20\n"
- // Horizontal add of pairs of adjacent int16 sums into internal int32
- // accumulators.
- "dpadd_s.w $w0, $w25, $w31\n"
- "dpadd_s.w $w1, $w26, $w31\n"
- "dpadd_s.w $w4, $w27, $w31\n"
- "dpadd_s.w $w5, $w28, $w31\n"
- "dpadd_s.w $w2, $w29, $w31\n"
-
- // Dot product: multiply-add pairs of adjacent int8 elements.
- // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
- "dotp_s.h $w24, $w16, $w22\n"
- "dotp_s.h $w25, $w19, $w20\n"
- "dotp_s.h $w26, $w16, $w23\n"
- "dotp_s.h $w27, $w17, $w22\n"
- "ld.b $w20, 0(%[rhs_ptr])\n"
- "dotp_s.h $w28, $w17, $w23\n"
- "ld.b $w16, 0(%[lhs_ptr])\n"
- "dotp_s.h $w29, $w18, $w21\n"
- "ld.b $w17, 16(%[lhs_ptr])\n"
- // Horizontal add of pairs of adjacent int16 sums into internal int32
- // accumulators.
- "dpadd_s.w $w8, $w24, $w31\n"
- "dpadd_s.w $w3, $w25, $w31\n"
- "dpadd_s.w $w12, $w26, $w31\n"
- "dpadd_s.w $w9, $w27, $w31\n"
- "dpadd_s.w $w13, $w28, $w31\n"
- "dpadd_s.w $w6, $w29, $w31\n"
-
- // Dot product: multiply-add pairs of adjacent int8 elements.
- // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
- "dotp_s.h $w25, $w19, $w21\n"
- "dotp_s.h $w26, $w18, $w22\n"
- "dotp_s.h $w27, $w18, $w23\n"
- "ld.b $w21, 16(%[rhs_ptr])\n"
- "dotp_s.h $w28, $w19, $w22\n"
- "ld.b $w18, 32(%[lhs_ptr])\n"
- "dotp_s.h $w29, $w19, $w23\n"
- "ld.b $w22, 32(%[rhs_ptr])\n"
- // Horizontal add of pairs of adjacent int16 sums into internal int32
- // accumulators.
- "dpadd_s.w $w7, $w25, $w31\n"
- "ld.b $w19, 48(%[lhs_ptr])\n"
- "dpadd_s.w $w10, $w26, $w31\n"
- "ld.b $w23, 48(%[rhs_ptr])\n"
- "dpadd_s.w $w14, $w27, $w31\n"
- "dpadd_s.w $w11, $w28, $w31\n"
- "dpadd_s.w $w15, $w29, $w31\n"
-
- "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n"
-
- GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n"
- // Calculate the results for the last 16 depths.
-
- // Dot product: multiply-add pairs of adjacent int8 elements.
- // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
- "dotp_s.h $w25, $w16, $w20\n"
- "dotp_s.h $w26, $w17, $w20\n"
- "dotp_s.h $w27, $w16, $w21\n"
- "dotp_s.h $w28, $w17, $w21\n"
- "dotp_s.h $w29, $w18, $w20\n"
- // Horizontal add of pairs of adjacent int16 sums into internal int32
- // accumulators.
- "dpadd_s.w $w0, $w25, $w31\n"
- "dpadd_s.w $w1, $w26, $w31\n"
- "dpadd_s.w $w4, $w27, $w31\n"
- "dpadd_s.w $w5, $w28, $w31\n"
- "dpadd_s.w $w2, $w29, $w31\n"
-
- // Dot product: multiply-add pairs of adjacent int8 elements.
- // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
- "dotp_s.h $w24, $w16, $w22\n"
- "dotp_s.h $w25, $w19, $w20\n"
- "dotp_s.h $w26, $w16, $w23\n"
- "dotp_s.h $w27, $w17, $w22\n"
- "dotp_s.h $w28, $w17, $w23\n"
- "dotp_s.h $w29, $w18, $w21\n"
- // Horizontal add of pairs of adjacent int16 sums into internal int32
- // accumulators.
- "dpadd_s.w $w8, $w24, $w31\n"
- "dpadd_s.w $w3, $w25, $w31\n"
- "dpadd_s.w $w12, $w26, $w31\n"
- "dpadd_s.w $w9, $w27, $w31\n"
- "dpadd_s.w $w13, $w28, $w31\n"
- "dpadd_s.w $w6, $w29, $w31\n"
-
- // Dot product: multiply-add pairs of adjacent int8 elements.
- // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
- "dotp_s.h $w25, $w19, $w21\n"
- "dotp_s.h $w26, $w18, $w22\n"
- "dotp_s.h $w27, $w18, $w23\n"
- "dotp_s.h $w28, $w19, $w22\n"
- "dotp_s.h $w29, $w19, $w23\n"
- // Horizontal add of pairs of adjacent int16 sums into internal int32
- // accumulators.
- "dpadd_s.w $w7, $w25, $w31\n"
- "dpadd_s.w $w10, $w26, $w31\n"
- "dpadd_s.w $w14, $w27, $w31\n"
- "dpadd_s.w $w11, $w28, $w31\n"
- "dpadd_s.w $w15, $w29, $w31\n"
-
- // Horizontal-add internal accumulators.
- "hadd_s.d $w0, $w0, $w0\n"
- "hadd_s.d $w1, $w1, $w1\n"
- "hadd_s.d $w2, $w2, $w2\n"
- "hadd_s.d $w3, $w3, $w3\n"
- "hadd_s.d $w4, $w4, $w4\n"
- "hadd_s.d $w5, $w5, $w5\n"
- "hadd_s.d $w6, $w6, $w6\n"
- "hadd_s.d $w7, $w7, $w7\n"
- "hadd_s.d $w8, $w8, $w8\n"
- "hadd_s.d $w9, $w9, $w9\n"
- "hadd_s.d $w10, $w10, $w10\n"
- "hadd_s.d $w11, $w11, $w11\n"
- "hadd_s.d $w12, $w12, $w12\n"
- "hadd_s.d $w13, $w13, $w13\n"
- "hadd_s.d $w14, $w14, $w14\n"
- "hadd_s.d $w15, $w15, $w15\n"
- "pckev.w $w0, $w1, $w0\n"
- "pckev.w $w2, $w3, $w2\n"
- "pckev.w $w4, $w5, $w4\n"
- "pckev.w $w6, $w7, $w6\n"
- "pckev.w $w8, $w9, $w8\n"
- "pckev.w $w10, $w11, $w10\n"
- "pckev.w $w12, $w13, $w12\n"
- "pckev.w $w14, $w15, $w14\n"
- "hadd_s.d $w0, $w0, $w0\n"
- "hadd_s.d $w2, $w2, $w2\n"
- "hadd_s.d $w4, $w4, $w4\n"
- "hadd_s.d $w6, $w6, $w6\n"
- "hadd_s.d $w8, $w8, $w8\n"
- "hadd_s.d $w10, $w10, $w10\n"
- "hadd_s.d $w12, $w12, $w12\n"
- "hadd_s.d $w14, $w14, $w14\n"
- // 4 more pckev instructions follow in both paths below.
-
- // Check if start_depth==0 to decide whether we will load
- // existing accumulators from memory.
- "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n"
-
- "pckev.w $w0, $w2, $w0\n"
- "pckev.w $w1, $w6, $w4\n"
- "pckev.w $w2, $w10, $w8\n"
- "pckev.w $w3, $w14, $w12\n"
-
- "b " GEMMLOWP_LABEL_STORE "f\n"
-
- GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n"
- // Load accumulators from memory.
- "ld.w $w16, 0(%[dst_ptr0])\n"
- "pckev.w $w0, $w2, $w0\n"
- "ld.w $w17, 0(%[dst_ptr1])\n"
- "pckev.w $w1, $w6, $w4\n"
- "ld.w $w18, 0(%[dst_ptr2])\n"
- "pckev.w $w2, $w10, $w8\n"
- "ld.w $w19, 0(%[dst_ptr3])\n"
- "pckev.w $w3, $w14, $w12\n"
-
- // Add them to internal accumulators.
- "addv.w $w0, $w0, $w16\n"
- "addv.w $w1, $w1, $w17\n"
- "addv.w $w2, $w2, $w18\n"
- "addv.w $w3, $w3, $w19\n"
-
- GEMMLOWP_LABEL_STORE ":\n"
- // Store accumulators.
- "st.w $w0, 0(%[dst_ptr0])\n"
- "st.w $w1, 0(%[dst_ptr1])\n"
- "st.w $w2, 0(%[dst_ptr2])\n"
- "st.w $w3, 0(%[dst_ptr3])\n"
- : // outputs
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [run_depth] "+r"(run_depth)
- : // inputs
- [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride),
- [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2),
- [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3),
- [start_depth] "r"(start_depth)
- : // clobbers
- "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
- "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
- "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
- "$f27", "$f28", "$f29", "$f30", "$f31");
-#undef GEMMLOWP_LABEL_LOOP
-#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16
-#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
-#undef GEMMLOWP_LABEL_STORE
- }
-};
-
#undef GEMMLOWP_MIPS_XADDU
#undef GEMMLOWP_MIPS_XADDIU
#undef GEMMLOWP_MIPS_XSLL
diff --git a/internal/kernel_neon.h b/internal/kernel_neon.h
index 9859637..3cd48f4 100644
--- a/internal/kernel_neon.h
+++ b/internal/kernel_neon.h
@@ -55,7 +55,6 @@
#define GEMMLOWP_LABEL_AFTER_LOOP "4"
assert(dst_row_stride == 1);
- (void)dst_row_stride;
asm volatile(
// Overview of register layout:
//
@@ -309,7 +308,6 @@
ScopedProfilingLabel label(
"optimized kernel (NEON 12x4, assuming 12-bit products)");
assert(dst_row_stride == 1);
- (void)dst_row_stride;
// See comments above for why we need local numerical labels in our asm.
#define GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS "1"
@@ -680,7 +678,6 @@
std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
const std::uint8_t* rhs_ptr, std::size_t start_depth,
std::size_t run_depth) const override {
- (void)dst_row_stride;
#define GEMMLOWP_LABEL_AFTER_LOOP "1"
#define GEMMLOWP_LABEL_LOOP "2"
#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
@@ -924,17 +921,6 @@
}
};
-// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
-// requires that user inputs were originally int8. This avoids the uint8->int8
-// conversion in the pack step.
-struct NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
- : NEON_32bit_GEMM_Int8Operands_LhsNonzero {
- typedef KernelFormat<
- KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
- KernelSideFormatInt8Inputs<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
- Format;
-};
-
#endif // GEMMLOWP_NEON_32
// The kernels here are specifically arm 64bit assembly, not arm 32bit.
@@ -954,7 +940,6 @@
std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
const std::uint8_t* rhs_ptr, std::size_t start_depth,
std::size_t run_depth) const override {
- (void)dst_row_stride;
#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
#define GEMMLOWP_LABEL_LOOP "2"
#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
@@ -1276,17 +1261,6 @@
}
};
-// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
-// requires that user inputs were originally int8. This avoids the uint8->int8
-// conversion in the pack step.
-struct NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
- : NEON_64bit_GEMM_Int8Operands_LhsNonzero {
- typedef KernelFormat<
- KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
- KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
- Format;
-};
-
// Our main GEMM kernel.
struct NEON_64_Kernel12x8Depth2 : KernelBase {
typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
@@ -1300,7 +1274,6 @@
std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
const std::uint8_t* rhs_ptr, std::size_t start_depth,
std::size_t run_depth) const override {
- (void)dst_row_stride;
ScopedProfilingLabel label("optimized kernel (NEON 12x8)");
// See comments above for why we need local numerical labels in our asm.
#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
@@ -1638,274 +1611,6 @@
}
};
-#ifdef GEMMLOWP_DOTPROD_KERNEL
-#ifndef __ARM_FEATURE_DOTPROD
-#error This kernel requires ARM dot-product instructions. Enable them by \
- adding '+dotprod' to a compiler flag, e.g. -march=armv8.2-a+dotprod . \
- Note that Clang up to version 7 fails to define the corresponding \
- preprocessor token __ARM_FEATURE_DOTPROD, so you will still have to define \
- it manually.
-#endif
-// Kernels utilizing the Armv8.2 Dot Product extension.
-//
-// The dot product instructions work by taking 4 consecutive 8-bit depth
-// values from each operand, multiplying the 4 pairs together and
-// accumulating all the results into the corresponding 32-bit accumulator
-// lane. As such, the operation is identical to a 32-bit instruction (like
-// FMLA used in SGEMM), except that 4 depth values are processed at a time
-// instead of 1.
-
-// Thus, this first kernel is a carbon copy of
-// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good
-// performance for most processors) below with the opcode (fmla -> udot) and
-// types (float32 -> uint8/uint32) changed.
-//
-// A signed version of this kernel could be produced by replacing "udot"
-// with "sdot" - performance should be identical to this udot kernel.
-struct NEON_64_Kernel12x8Depth4_dotprod : KernelBase {
- typedef KernelFormat<KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
- KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
- Format;
-
- const char* Name() const override { return "NEON, 12x8, depth 4, dotprod"; }
-
- void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride,
- const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr, std::size_t start_depth,
- std::size_t depth) const override {
- (void)dst_row_stride;
- ScopedProfilingLabel label("optimized kernel (NEON 12x8, depth 4, dotprod)");
-// See comments above for why we need local numerical labels in our asm.
-#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
-#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
-#define GEMMLOWP_LABEL_LOOP "3"
-#define GEMMLOWP_LABEL_AFTER_LOOP "4"
-
- assert(dst_row_stride == 1);
- asm volatile(
- // Multiply dst_col_stride by 4 == sizeof(int32) to use
- // it as a byte offset below.
- "lsl %[dst_col_stride], %[dst_col_stride], #2\n"
-
- "cmp %[start_depth], #0\n"
- "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
-
- // Load accumulators
- "mov x1, %[dst_ptr]\n"
- "mov x0, x1\n"
- "ld1 {v8.16b}, [x0], #16\n"
- "ld1 {v16.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "ld1 {v24.16b}, [x0]\n"
- "mov x0, x1\n"
- "ld1 {v9.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "ld1 {v17.16b}, [x0], #16\n"
- "ld1 {v25.16b}, [x0]\n"
- "mov x0, x1\n"
- "ld1 {v10.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "ld1 {v18.16b}, [x0], #16\n"
- "ld1 {v26.16b}, [x0]\n"
- "mov x0, x1\n"
- "ld1 {v11.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "ld1 {v19.16b}, [x0], #16\n"
- "ld1 {v27.16b}, [x0]\n"
- "mov x0, x1\n"
- "ld1 {v12.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "ld1 {v20.16b}, [x0], #16\n"
- "ld1 {v28.16b}, [x0]\n"
- "mov x0, x1\n"
- "ld1 {v13.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "ld1 {v21.16b}, [x0], #16\n"
- "ld1 {v29.16b}, [x0]\n"
- "mov x0, x1\n"
- "ld1 {v14.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "ld1 {v22.16b}, [x0], #16\n"
- "ld1 {v30.16b}, [x0]\n"
- "mov x0, x1\n"
- "ld1 {v15.16b}, [x0], #16\n"
- "ld1 {v23.16b}, [x0], #16\n"
- "ld1 {v31.16b}, [x0]\n"
-
- "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
-
- GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n"
-
- // Clear accumulator registers (see layout below)
- "dup v8.4s, wzr\n"
- "dup v9.4s, wzr\n"
- "dup v10.4s, wzr\n"
- "dup v11.4s, wzr\n"
- "dup v12.4s, wzr\n"
- "dup v13.4s, wzr\n"
- "dup v14.4s, wzr\n"
- "dup v15.4s, wzr\n"
- "dup v16.4s, wzr\n"
- "dup v17.4s, wzr\n"
- "dup v18.4s, wzr\n"
- "dup v19.4s, wzr\n"
- "dup v20.4s, wzr\n"
- "dup v21.4s, wzr\n"
- "dup v22.4s, wzr\n"
- "dup v23.4s, wzr\n"
- "dup v24.4s, wzr\n"
- "dup v25.4s, wzr\n"
- "dup v26.4s, wzr\n"
- "dup v27.4s, wzr\n"
- "dup v28.4s, wzr\n"
- "dup v29.4s, wzr\n"
- "dup v30.4s, wzr\n"
- "dup v31.4s, wzr\n"
-
- GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
-
- "subs %w[depth], %w[depth], #4\n"
-
- // The start of the loop assumes first Rhs cell is already loaded, so
- // do it here for first iteration.
- "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
-
- // And the same for the first Lhs cell.
- "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
-
- "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
-
- GEMMLOWP_LABEL_LOOP ":\n"
-
- // Start the MACs at the head of the loop - 1st cell from each side
- // already loaded.
- ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n"
- ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n"
- "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
- ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n"
- ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n"
- "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
- ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n"
- ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n"
- "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
- ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n"
- ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n"
- "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
- // for the next iteration early.
- ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n"
- ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n"
- ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n"
- ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n"
- ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n"
- ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n"
- ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n"
- ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n"
- ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n"
- ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n"
- ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n"
- ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n"
- "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
- // load for the next iteration early.
- ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n"
- ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n"
-
- // Loop. Decrement loop index (depth) by 4 as udot processes 4
- // depth values.
- "subs %w[depth], %w[depth], #4\n"
- ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n"
- ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n"
-
- "bne " GEMMLOWP_LABEL_LOOP "b\n"
-
- GEMMLOWP_LABEL_AFTER_LOOP ":\n"
-
- // Final iteration. v0 and v2 were already loaded, don't load
- // them again, don't read past the end of buffers.
- ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n"
- ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n"
- "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
- ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n"
- ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n"
- "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
- ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n"
- ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n"
- "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
- ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n"
- ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n"
- ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n"
- ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n"
- ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n"
- ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n"
- ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n"
- ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n"
- ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n"
- ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n"
- ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n"
- ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n"
- ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n"
- ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n"
- ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n"
- ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n"
-
- // Loop. Decrement loop index (depth) by 4 as udot processes 4
- // depth values.
- "subs %w[depth], %w[depth], #4\n"
- ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n"
- ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n"
-
- // Store accumulators
- "mov x1, %[dst_ptr]\n"
- "mov x0, x1\n"
- "st1 {v8.16b}, [x0], #16\n"
- "st1 {v16.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "st1 {v24.16b}, [x0]\n"
- "mov x0, x1\n"
- "st1 {v9.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "st1 {v17.16b}, [x0], #16\n"
- "st1 {v25.16b}, [x0]\n"
- "mov x0, x1\n"
- "st1 {v10.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "st1 {v18.16b}, [x0], #16\n"
- "st1 {v26.16b}, [x0]\n"
- "mov x0, x1\n"
- "st1 {v11.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "st1 {v19.16b}, [x0], #16\n"
- "st1 {v27.16b}, [x0]\n"
- "mov x0, x1\n"
- "st1 {v12.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "st1 {v20.16b}, [x0], #16\n"
- "st1 {v28.16b}, [x0]\n"
- "mov x0, x1\n"
- "st1 {v13.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "st1 {v21.16b}, [x0], #16\n"
- "st1 {v29.16b}, [x0]\n"
- "mov x0, x1\n"
- "st1 {v14.16b}, [x0], #16\n"
- "add x1, x1, %[dst_col_stride]\n"
- "st1 {v22.16b}, [x0], #16\n"
- "st1 {v30.16b}, [x0]\n"
- "mov x0, x1\n"
- "st1 {v15.16b}, [x0], #16\n"
- "st1 {v23.16b}, [x0], #16\n"
- "st1 {v31.16b}, [x0]\n"
- : // outputs
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [depth] "+r"(depth)
- : // inputs
- [dst_ptr] "r"(dst_ptr), [dst_col_stride] "r"(dst_col_stride), [start_depth] "r"(start_depth)
- : // clobbers
- "cc", "memory", "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
- "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
- "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
- }
-};
-#endif // GEMMLOWP_DOTPROD_KERNEL
-
#endif // GEMMLOWP_NEON_64
} // namespace gemmlowp
diff --git a/internal/kernel_sse.h b/internal/kernel_sse.h
index ba7959b..b879fd7 100644
--- a/internal/kernel_sse.h
+++ b/internal/kernel_sse.h
@@ -43,7 +43,6 @@
std::size_t run_depth) const override {
ScopedProfilingLabel label("optimized kernel");
assert(dst_row_stride == 1);
- (void)dst_row_stride;
std::int32_t run_depth_cells = run_depth / Format::kDepth;
/* Main loop */
@@ -218,7 +217,6 @@
std::size_t run_depth) const override {
ScopedProfilingLabel label("optimized kernel");
assert(dst_row_stride == 1);
- (void)dst_row_stride;
const std::int64_t run_depth_cells = run_depth / Format::kDepth;
const std::int64_t dst_col_stride_q = dst_col_stride;
diff --git a/internal/multi_thread_gemm.h b/internal/multi_thread_gemm.h
index 97183e7..791402f 100644
--- a/internal/multi_thread_gemm.h
+++ b/internal/multi_thread_gemm.h
@@ -19,43 +19,23 @@
#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
-#include <atomic> // NOLINT
-#include <chrono> // NOLINT
-#include <thread> // NOLINT
#include <vector>
#include "single_thread_gemm.h"
namespace gemmlowp {
-// This value was empirically derived on an end-to-end application benchmark.
-// That this number of cycles means that we may be sleeping substantially longer
-// than a scheduler timeslice's duration is not necessarily surprising. The
-// idea is to pick up quickly new work after having finished the previous
-// workload. When it's new work within the same GEMM as the previous work, the
-// time interval that we might be busy-waiting is very small, so for that
-// purpose it would be more than enough to sleep for 1 million cycles.
-// That is all what we would observe on a GEMM benchmark. However, in a real
-// application, after having finished a GEMM, we might do unrelated work for
-// a little while, then start on a new GEMM. Think of a neural network
-// application performing inference, where many but not all layers are
-// implemented by a GEMM. In such cases, our worker threads might be idle for
-// longer periods of time before having work again. If we let them passively
-// wait, on a mobile device, the CPU scheduler might aggressively clock down
-// or even turn off the CPU cores that they were running on. That would result
-// in a long delay the next time these need to be turned back on for the next
-// GEMM. So we need to strike a balance that reflects typical time intervals
-// between consecutive GEMM invokations, not just intra-GEMM considerations.
-// Of course, we need to balance keeping CPUs spinning longer to resume work
-// faster, versus passively waiting to conserve power.
-const int kMaxBusyWaitNOPs = 4 * 1000 * 1000;
-
-// On X86 and ARM platforms we may use NOP instructions to know how long we
-// are busy-waiting.
+// On X86 and ARM platforms we enable a busy-wait spinlock before waiting on a
+// pthread condition variable. In order to implement that correctly we need
+// to put some explicit memory load/store barriers.
#if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \
(defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86))
+#define GEMMLOWP_USE_BUSYWAIT
+
+const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;
+
#define GEMMLOWP_NOP "nop\n"
#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
@@ -63,26 +43,46 @@
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
-inline int DoSomeNOPs() {
+inline int Do256NOPs() {
asm volatile(GEMMLOWP_NOP64);
return 64;
}
#undef GEMMLOWP_STRING_CONCAT_4
+#undef GEMMLOWP_NOP256
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP
-#else // May not use asm NOP.
-
-// If we can't use NOPs, let's use a non-inline function call as a basic
-// thing that has some vaguely known, nonzero cost.
-GEMMLOWP_NOINLINE
-inline int DoSomeNOPs() {
- // Pretend that calling an empty function takes as long as 16 NOPs...
- return 16;
+inline void WriteBarrier() {
+#if defined(_MSC_VER)
+ MemoryBarrier();
+#elif defined(GEMMLOWP_ARM_32)
+ asm volatile("" ::: "memory");
+#elif defined(GEMMLOWP_ARM_64)
+ asm volatile("dmb ishst" ::: "memory");
+#elif defined(GEMMLOWP_X86)
+ asm volatile("sfence" ::: "memory");
+#else
+#error "Unsupported architecture for WriteBarrier."
+#endif
}
+
+inline void ReadBarrier() {
+#if defined(_MSC_VER)
+ MemoryBarrier();
+#elif defined(GEMMLOWP_ARM_32)
+ asm volatile("" ::: "memory");
+#elif defined(GEMMLOWP_ARM_64)
+ asm volatile("dmb ishld" ::: "memory");
+#elif defined(GEMMLOWP_X86)
+ asm volatile("lfence" ::: "memory");
+#else
+#error "Unsupported architecture for ReadBarrier."
+#endif
+}
+
#endif
// Waits until *var != initial_value.
@@ -108,29 +108,37 @@
// so as to avoid permanently spinning.
//
template <typename T>
-T WaitForVariableChange(std::atomic<T>* var, T initial_value,
- pthread_cond_t* cond, pthread_mutex_t* mutex) {
- // First, trivial case where the variable already changed value.
- T new_value = var->load(std::memory_order_acquire);
- if (new_value != initial_value) {
- return new_value;
- }
- // Then try busy-waiting.
- int nops = 0;
- while (nops < kMaxBusyWaitNOPs) {
- nops += DoSomeNOPs();
- new_value = var->load(std::memory_order_acquire);
+T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
+ pthread_mutex_t* mutex) {
+#ifdef GEMMLOWP_USE_BUSYWAIT
+ // If we are on a platform that supports it, spin for some time.
+ {
+ int nops = 0;
+ // First, trivial case where the variable already changed value.
+ T new_value = *var;
if (new_value != initial_value) {
+ ReadBarrier();
return new_value;
}
+ // Then try busy-waiting.
+ while (nops < kMaxBusyWaitNOPs) {
+ nops += Do256NOPs();
+ new_value = *var;
+ if (new_value != initial_value) {
+ ReadBarrier();
+ return new_value;
+ }
+ }
}
+#endif
// Finally, do real passive waiting.
pthread_mutex_lock(mutex);
- new_value = var->load(std::memory_order_acquire);
- while (new_value == initial_value) {
+ T new_value = *var;
+ if (new_value == initial_value) {
pthread_cond_wait(cond, mutex);
- new_value = var->load(std::memory_order_acquire);
+ new_value = *var;
+ assert(new_value != initial_value);
}
pthread_mutex_unlock(mutex);
return new_value;
@@ -139,74 +147,73 @@
// A BlockingCounter lets one thread to wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
-// The waiting is done using a naive spinlock waiting for the atomic
-// count_ to hit the value 0. This is acceptable because in our usage
-// pattern, BlockingCounter is used only to synchronize threads after
-// short-lived tasks (performing parts of the same GEMM). It is not used
-// for synchronizing longer waits (resuming work on the next GEMM).
class BlockingCounter {
public:
- BlockingCounter() : count_(0) {}
+ BlockingCounter() : count_(0), initial_count_(0) {
+ pthread_cond_init(&cond_, nullptr);
+ pthread_mutex_init(&mutex_, nullptr);
+ }
+
+ ~BlockingCounter() {
+ pthread_cond_destroy(&cond_);
+ pthread_mutex_destroy(&mutex_);
+ }
// Sets/resets the counter; initial_count is the number of
// decrementing events that the Wait() call will be waiting for.
void Reset(std::size_t initial_count) {
- std::size_t old_count_value = count_.load(std::memory_order_relaxed);
- assert(old_count_value == 0);
- (void)old_count_value;
- count_.store(initial_count, std::memory_order_release);
+ pthread_mutex_lock(&mutex_);
+ assert(count_ == 0);
+ initial_count_ = initial_count;
+ count_ = initial_count_;
+ pthread_mutex_unlock(&mutex_);
}
// Decrements the counter; if the counter hits zero, signals
- // the threads that were waiting for that, and returns true.
+ // the thread that was waiting for that, and returns true.
// Otherwise (if the decremented count is still nonzero),
// returns false.
bool DecrementCount() {
- std::size_t old_count_value =
- count_.fetch_sub(1, std::memory_order_acq_rel);
- assert(old_count_value > 0);
- std::size_t count_value = old_count_value - 1;
- return count_value == 0;
+ pthread_mutex_lock(&mutex_);
+ assert(count_ > 0);
+ count_--;
+#ifdef GEMMLOWP_USE_BUSYWAIT
+ WriteBarrier();
+#endif
+ if (count_ == 0) {
+ pthread_cond_signal(&cond_);
+ }
+ bool retval = count_ == 0;
+ pthread_mutex_unlock(&mutex_);
+ return retval;
}
// Waits for the N other threads (N having been set by Reset())
// to hit the BlockingCounter.
void Wait() {
ScopedProfilingLabel label("BlockingCounter::Wait");
- // Busy-wait until the count value is 0.
- int nops = 0;
- while (count_.load(std::memory_order_acquire)) {
- nops += DoSomeNOPs();
- if (nops > kMaxBusyWaitNOPs) {
- nops = 0;
- // If we are unlucky, the blocking thread (that calls DecrementCount)
- // and the blocked thread (here, calling Wait) may be scheduled on
- // the same CPU, so the busy-waiting of the present thread may prevent
- // the blocking thread from resuming and unblocking.
- // If we are even unluckier, the priorities of the present thread
- // might be higher than that of the blocking thread, so just yielding
- // wouldn't allow the blocking thread to resume. So we sleep for
- // a substantial amount of time in that case. Notice that we only
- // do so after having busy-waited for kMaxBusyWaitNOPs, which is
- // typically several milliseconds, so sleeping 1 more millisecond
- // isn't terrible at that point.
- //
- // How this is mitigated in practice:
- // In practice, it is well known that the application should be
- // conservative in choosing how many threads to tell gemmlowp to use,
- // as it's hard to know how many CPU cores it will get to run on,
- // on typical mobile devices.
- // It seems impossible for gemmlowp to make this choice automatically,
- // which is why gemmlowp's default is to use only 1 thread, and
- // applications may override that if they know that they can count on
- // using more than that.
- std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ while (count_) {
+#ifdef GEMMLOWP_USE_BUSYWAIT
+ ReadBarrier();
+#else
+ // This is likely unnecessary, but is kept to ensure regressions are not
+ // introduced.
+#ifndef _WIN32
+ asm volatile("" ::: "memory");
+#endif
+#endif
+ const std::size_t count_value = count_;
+ if (count_value) {
+ WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
}
}
}
private:
- std::atomic<std::size_t> count_;
+ pthread_cond_t cond_;
+ pthread_mutex_t mutex_;
+ std::size_t count_;
+ std::size_t initial_count_;
};
// A workload for a worker.
@@ -246,15 +253,11 @@
// Changes State; may be called from either the worker thread
// or the master thread; however, not all state transitions are legal,
// which is guarded by assertions.
- //
- // The Task argument is to be used only with new_state==HasWork.
- // It specifies the Task being handed to this Worker.
- void ChangeState(State new_state, Task* task = nullptr) {
+ void ChangeState(State new_state) {
ScopedProfilingLabel label("Worker::ChangeState");
pthread_mutex_lock(&state_mutex_);
- State old_state = state_.load(std::memory_order_relaxed);
- assert(old_state != new_state);
- switch (old_state) {
+ assert(new_state != state_);
+ switch (state_) {
case State::ThreadStartup:
assert(new_state == State::Ready);
break;
@@ -269,33 +272,18 @@
default:
abort();
}
- switch (new_state) {
- case State::Ready:
- if (task_) {
- // Doing work is part of reverting to 'ready' state.
- task_->Run();
- task_ = nullptr;
- }
- break;
- case State::HasWork:
- assert(!task_);
- task->local_allocator = &local_allocator_;
- task_ = task;
- break;
- default:
- break;
- }
- state_.store(new_state, std::memory_order_relaxed);
- pthread_cond_broadcast(&state_cond_);
- pthread_mutex_unlock(&state_mutex_);
- if (new_state == State::Ready) {
+ state_ = new_state;
+ pthread_cond_signal(&state_cond_);
+ if (state_ == State::Ready) {
counter_to_decrement_when_ready_->DecrementCount();
}
+ pthread_mutex_unlock(&state_mutex_);
}
// Thread entry point.
void ThreadFunc() {
ScopedProfilingLabel label("Worker::ThreadFunc");
+ RegisterCurrentThreadForProfiling();
ChangeState(State::Ready);
@@ -311,6 +299,9 @@
switch (state_to_act_upon) {
case State::HasWork:
// Got work to do! So do it, and then revert to 'Ready' state.
+ assert(task_);
+ task_->Run();
+ task_ = nullptr;
ChangeState(State::Ready);
break;
case State::ExitAsSoonAsPossible:
@@ -327,7 +318,17 @@
}
// Called by the master thead to give this worker work to do.
- void StartWork(Task* task) { ChangeState(State::HasWork, task); }
+ // It is only legal to call this if the worker
+ void StartWork(Task* task) {
+ assert(!task_);
+ task->local_allocator = &local_allocator_;
+ task_ = task;
+#ifdef GEMMLOWP_USE_BUSYWAIT
+ WriteBarrier();
+#endif
+ assert(state_ == State::Ready);
+ ChangeState(State::HasWork);
+ }
private:
// The underlying thread.
@@ -341,10 +342,7 @@
pthread_mutex_t state_mutex_;
// The state enum tells if we're currently working, waiting for work, etc.
- // Its concurrent accesses by the worker and main threads are guarded by
- // state_mutex_, and can thus use memory_order_relaxed. This still needs
- // to be a std::atomic because we use WaitForVariableChange.
- std::atomic<State> state_;
+ State state_;
// Each thread had a local allocator so they can allocate temporary
// buffers without blocking each other.
@@ -361,7 +359,9 @@
// waits for all of them to finish.
//
// See MultiThreadGemmContextBase for how other WorkersPool implementations can
-// be used.
+// be used. Note that in those implementations, StartWorker can be free to
+// ignore the <index> value; that is, the caller of WorkersPool does not rely on
+// <index> to order tasks with equal <index>.
class WorkersPool {
public:
WorkersPool() {}
@@ -372,41 +372,18 @@
}
}
- // Just executes the tasks. Does not destroy them. Similar to
- // ruy::ThreadPool::Execute.
- template <typename TaskType>
- void Execute(int tasks_count, TaskType* tasks) {
- assert(tasks_count >= 1);
+ void Execute(const std::vector<Task*>& tasks) {
+ assert(tasks.size() >= 1);
// One of the tasks will be run on the current thread.
- std::size_t workers_count = tasks_count - 1;
+ std::size_t workers_count = tasks.size() - 1;
CreateWorkers(workers_count);
assert(workers_count <= workers_.size());
counter_to_decrement_when_ready_.Reset(workers_count);
- for (std::size_t i = 0; i < tasks_count - 1; i++) {
- workers_[i]->StartWork(&tasks[i]);
- }
+ int n = 0;
+ std::for_each(tasks.begin(), --tasks.end(),
+ [this, &n](Task* task) { workers_[n++]->StartWork(task); });
// Execute the remaining workload immediately on the current thread.
- Task* task = &tasks[tasks_count - 1];
- task->local_allocator = &main_thread_task_allocator_;
- task->Run();
- // Wait for the workers submitted above to finish.
- counter_to_decrement_when_ready_.Wait();
- }
-
- // Legacy: executes the tasks and destroys them
- void LegacyExecuteAndDestroyTasks(const std::vector<Task*>& tasks) {
- std::size_t tasks_count = tasks.size();
- assert(tasks_count >= 1);
- // One of the tasks will be run on the current thread.
- std::size_t workers_count = tasks_count - 1;
- CreateWorkers(workers_count);
- assert(workers_count <= workers_.size());
- counter_to_decrement_when_ready_.Reset(workers_count);
- for (int i = 0; i < tasks_count - 1; i++) {
- workers_[i]->StartWork(tasks[i]);
- }
- // Execute the remaining workload immediately on the current thread.
- Task* task = tasks[tasks_count - 1];
+ Task* task = tasks.back();
task->local_allocator = &main_thread_task_allocator_;
task->Run();
// Wait for the workers submitted above to finish.
@@ -416,11 +393,6 @@
std::for_each(tasks.begin(), tasks.end(), [](Task* task) { delete task; });
}
- // Legacy old name of LegacyExecuteAndDestroyTasks
- void Execute(const std::vector<Task*>& tasks) {
- LegacyExecuteAndDestroyTasks(tasks);
- }
-
private:
// Ensures that the pool has at least the given count of workers.
// If any new worker has to be created, this function waits for it to
diff --git a/internal/output.h b/internal/output.h
index 92bf7b9..dcfe2b5 100644
--- a/internal/output.h
+++ b/internal/output.h
@@ -22,7 +22,6 @@
#include <cmath>
#include <tuple>
#include <type_traits>
-#include <typeinfo>
#include "../fixedpoint/fixedpoint.h"
#include "../public/output_stages.h"
@@ -180,47 +179,7 @@
int right_shift;
};
-template <int Rows, int Cols, VectorShape Shape>
-struct OutputStageEvalImpl<
- OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>,
- RegisterBlock<std::int32_t, Rows, Cols>> {
- typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
- typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
-
- typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage;
-
- OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
-
- OutputType Eval(InputType input, int row, int col) const {
- OutputType output;
- const int pos = Shape == VectorShape::Row ? col : row;
- using RegisterType = typename InputType::RegisterType;
- const RegisterType result_offset_after_shift =
- Dup<RegisterType>(output_stage.result_offset_after_shift);
- auto left_shift =
- LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
- auto right_shift =
- LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
- const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>(
- output_stage.result_fixedpoint_multiplier, pos);
- for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) {
- left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0);
- right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0);
- }
- const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul(
- BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier);
- const auto rdpot_val =
- BroadcastRoundingDivideByPOT(mulhigh_val, right_shift);
- for (int i = 0; i < InputType::kRegisterCount; i++) {
- output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift);
- }
- return output;
- }
-
- const OutputStage& output_stage;
-};
-
-// Implementation of OutputStageSaturatingCastToUint8 for scalar data.
+// Implementation of OutputStageSaturatingCastToUint8 for scalar data
template <int Size>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
RegisterBuffer<std::int32_t, Size>> {
@@ -243,30 +202,7 @@
}
};
-// Implementation of OutputStageSaturatingCastToInt8 for scalar data.
-template <int Size>
-struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
- RegisterBuffer<std::int32_t, Size>> {
- typedef RegisterBuffer<std::int32_t, Size> InputType;
- typedef RegisterBuffer<std::int8_t, Size> OutputType;
- static_assert(InputType::kRegisterLanes == 1,
- "This path is only for scalar values");
-
- typedef OutputStageSaturatingCastToInt8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- for (int i = 0; i < InputType::kRegisterCount; i++) {
- std::int32_t data = input.reg[i];
- output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data;
- }
- return output;
- }
-};
-
-// Implementation of OutputStageSaturatingCastToInt16 for scalar data.
+// Implementation of OutputStageSaturatingCastToInt16 for scalar data
template <int Size>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
RegisterBuffer<std::int32_t, Size>> {
@@ -289,28 +225,6 @@
}
};
-// Implementation of OutputStageTruncatingCastToUint8 for scalar data
-template <int Size>
-struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
- RegisterBuffer<std::int32_t, Size>> {
- typedef RegisterBuffer<std::int32_t, Size> InputType;
- typedef RegisterBuffer<std::uint8_t, Size> OutputType;
- static_assert(InputType::kRegisterLanes == 1,
- "This path is only for scalar values");
-
- typedef OutputStageTruncatingCastToUint8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- for (int i = 0; i < InputType::kRegisterCount; i++) {
- output.reg[i] = input.reg[i];
- }
- return output;
- }
-};
-
template <int Rows, int Cols, typename VectorType>
struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
RegisterBlock<std::int32_t, Rows, Cols>> {
@@ -538,7 +452,7 @@
OutputPipelineExecutor(const OutputPipelineType& output_pipeline)
: output_pipeline_eval_impl_(output_pipeline) {}
- // Execute is the entry point into the output pipeline evaluation
+ // RunOutputPipeline is the entry point into the output pipeline evaluation
// code. It should be the only thing that unpack code calls. It takes the
// result
// of the unpack stage and stores it into the destination matrix.
diff --git a/internal/output_avx.h b/internal/output_avx.h
deleted file mode 100644
index b8f94fb..0000000
--- a/internal/output_avx.h
+++ /dev/null
@@ -1,19 +0,0 @@
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// output_avx.h: optimized AVX 2 specializations of the templates in output.h.
-
-#ifndef GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
-#define GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
-
-#endif // GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
diff --git a/internal/output_msa.h b/internal/output_msa.h
index 0540bb3..4c8eb5d 100644
--- a/internal/output_msa.h
+++ b/internal/output_msa.h
@@ -38,14 +38,18 @@
// Signed saturate each 32-bit element to 9 bits
// (this takes full care of non-negative elements).
v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
- // Zero out negative elements.
- tmp = __builtin_msa_maxi_s_w(tmp, 0);
// Pack every 32-bit element into 16 bits.
tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
+ // Detect negative elements with arithmetic shift right (we
+ // get a 16-bit mask of all zeroes or all ones for every element).
+ v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
+ // Zero out negative elements.
+ signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
+ reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
// Pack every element into 8 bits.
tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
+ reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
// Return 4 uint8_t elements as uint32_t.
output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
return output;
@@ -72,12 +76,15 @@
// combining all 8 elements into one vector.
tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
+ // Detect negative elements with arithmetic shift right (we
+ // get a 16-bit mask of all zeroes or all ones for every element).
+ v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
// Zero out negative elements.
- tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h(
- reinterpret_cast<v8i16>(tmp_lo), 0));
+ signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
+ reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
// Pack every element into 8 bits.
tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(tmp_lo), reinterpret_cast<v16i8>(tmp_lo)));
+ reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
// Return 8 uint8_t elements as 2 uint32_t's.
output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
@@ -95,13 +102,15 @@
reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0))); \
tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( \
reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2))); \
- tmp0 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h( \
- reinterpret_cast<v8i16>(tmp0), 0)); \
- tmp2 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h( \
- reinterpret_cast<v8i16>(tmp2), 0)); \
- tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_b( \
- reinterpret_cast<v16i8>(tmp2), reinterpret_cast<v16i8>(tmp0))); \
- out = reinterpret_cast<v16i8>(tmp0); \
+ v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15); \
+ v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15); \
+ signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
+ reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
+ signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
+ reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
+ signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b( \
+ reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0))); \
+ out = reinterpret_cast<v16i8>(signs0); \
}
template <>
@@ -157,8 +166,8 @@
OutputType Eval(InputType input) const {
OutputType output;
// Signed saturate each 32-bit element to 16 bits.
- v8i16 tmp =
- reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(input.reg[0], 15));
+ v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
+ input.reg[0], 15));
output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
@@ -167,12 +176,12 @@
}
};
-#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \
- { \
- v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \
- v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \
- out = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(tmp1), \
- reinterpret_cast<v8i16>(tmp0)); \
+#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \
+ { \
+ v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \
+ v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \
+ out = __builtin_msa_pckev_h( \
+ reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
}
template <>
@@ -232,117 +241,6 @@
#undef GEMMLOWP_MIPS_SAT_I16_8
-template <>
-struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
- RegBufferInt32<4>> {
- typedef RegBufferInt32<4> InputType;
- typedef RegBufferUint8<4> OutputType;
-
- typedef OutputStageTruncatingCastToUint8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- // Pack every 32-bit element into 16 bits.
- v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[0]),
- reinterpret_cast<v8i16>(input.reg[0])));
- // Pack every element into 8 bits.
- tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
- // Return 4 uint8_t elements as uint32_t.
- output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
- return output;
- }
-};
-
-template <>
-struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
- RegBufferInt32<8>> {
- typedef RegBufferInt32<8> InputType;
- typedef RegBufferUint8<8> OutputType;
-
- typedef OutputStageTruncatingCastToUint8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- // Pack every 32-bit element into 16 bits.
- v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[1]),
- reinterpret_cast<v8i16>(input.reg[0])));
- // Pack every element into 8 bits.
- tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
- // Return 8 uint8_t elements as 2 uint32_t's.
- output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
- output.reg[1] = __builtin_msa_copy_s_w(tmp, 1);
- return output;
- }
-};
-
-template <>
-struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
- RegBufferInt32<16>> {
- typedef RegBufferInt32<16> InputType;
- typedef RegBufferUint8<16> OutputType;
-
- typedef OutputStageTruncatingCastToUint8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- // Pack every 32-bit element into 16 bits.
- v8i16 tmp0 = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[1]),
- reinterpret_cast<v8i16>(input.reg[0]));
- v8i16 tmp1 = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[3]),
- reinterpret_cast<v8i16>(input.reg[2]));
- // Pack every element into 8 bits.
- output.reg[0] = __builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0));
- return output;
- }
-};
-
-template <>
-struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
- RegBufferInt32<32>> {
- typedef RegBufferInt32<32> InputType;
- typedef RegBufferUint8<32> OutputType;
-
- typedef OutputStageTruncatingCastToUint8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- // Pack every 32-bit element into 16 bits.
- v8i16 tmp0 = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[1]),
- reinterpret_cast<v8i16>(input.reg[0]));
- v8i16 tmp1 = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[3]),
- reinterpret_cast<v8i16>(input.reg[2]));
- v8i16 tmp2 = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[5]),
- reinterpret_cast<v8i16>(input.reg[4]));
- v8i16 tmp3 = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(input.reg[7]),
- reinterpret_cast<v8i16>(input.reg[6]));
- // Pack every element into 8 bits.
- output.reg[0] = __builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0));
- output.reg[1] = __builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(tmp3), reinterpret_cast<v16i8>(tmp2));
- return output;
- }
-};
-
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
@@ -576,50 +474,50 @@
}
} else {
// top-left 4x4
- v4i32 t0 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvr_h(src.buf.reg[1], src.buf.reg[0]));
- v4i32 t1 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvr_h(src.buf.reg[3], src.buf.reg[2]));
+ v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
+ src.buf.reg[0]));
+ v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
+ src.buf.reg[2]));
v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
// top-right 4x4
- v4i32 t2 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvr_h(src.buf.reg[5], src.buf.reg[4]));
- v4i32 t3 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvr_h(src.buf.reg[7], src.buf.reg[6]));
+ v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
+ src.buf.reg[4]));
+ v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
+ src.buf.reg[6]));
v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
// bottom-left 4x4
- v4i32 t4 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvl_h(src.buf.reg[1], src.buf.reg[0]));
- v4i32 t5 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvl_h(src.buf.reg[3], src.buf.reg[2]));
+ v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
+ src.buf.reg[0]));
+ v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
+ src.buf.reg[2]));
v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
// bottom-right 4x4
- v4i32 t6 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvl_h(src.buf.reg[5], src.buf.reg[4]));
- v4i32 t7 = reinterpret_cast<v4i32>(
- __builtin_msa_ilvl_h(src.buf.reg[7], src.buf.reg[6]));
+ v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
+ src.buf.reg[4]));
+ v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
+ src.buf.reg[6]));
v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));
- StoreInt16x8(dst->data(row + 0, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u2, u0)));
- StoreInt16x8(dst->data(row + 1, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u2, u0)));
- StoreInt16x8(dst->data(row + 2, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u3, u1)));
- StoreInt16x8(dst->data(row + 3, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u3, u1)));
- StoreInt16x8(dst->data(row + 4, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u6, u4)));
- StoreInt16x8(dst->data(row + 5, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u6, u4)));
- StoreInt16x8(dst->data(row + 6, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u7, u5)));
- StoreInt16x8(dst->data(row + 7, col),
- reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u7, u5)));
+ StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u7, u5)));
+ StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u7, u5)));
}
}
};
@@ -687,391 +585,6 @@
}
};
-// There's no way to express in C++ the desired machine code for
-// StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> and
-// StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType>.
-// Hence, if we can, we use inline assembly, which takes advantage
-// of little-endian byte order and specifics of different CPU revisions.
-// Note, clang currently can't derive MSA register names from floating-
-// point register names and vice versa in inline assembly.
-#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) && \
- !defined(__clang__)
-
-// Instructions for pointer-sized operands.
-#ifdef GEMMLOWP_MIPS_64
-#define GEMMLOWP_MIPS_XADDU "daddu"
-#define GEMMLOWP_MIPS_XLSA "dlsa"
-#else
-#define GEMMLOWP_MIPS_XADDU "addu"
-#define GEMMLOWP_MIPS_XLSA "lsa"
-#endif
-
-// Stores 4 8-byte half-vectors with a stride.
-inline void MipsMsaStore4x8(const RegBlockUint8<8, 4>& src,
- std::uint8_t* dst_ptr, int stride) {
-#if (__mips_isa_rev >= 6)
- // Assembly temporaries that will be handily referred to by their names.
- std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
- v16i8 vtmp0, vtmp1;
- asm volatile(
- GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
- "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
- "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
- "sdc1 %[src0], 0(%[dst_ptr0])\n"
- "sdc1 %[vtmp0], 0(%[dst_ptr1])\n"
- "sdc1 %[src1], 0(%[dst_ptr2])\n"
- "sdc1 %[vtmp1], 0(%[dst_ptr3])\n"
- :
- // Outputs.
- [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
- [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
- [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1)
- :
- // Inputs.
- [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
- [stride] "r"(stride)
- :
- // Clobbers.
- "memory");
-#else
- // Assembly temporaries that will be handily referred to by their names.
- std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
- int tmp0, tmp1, tmp2, tmp3;
- asm volatile(
- GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
- "copy_s.w %[tmp0], %w[src0][0]\n"
- "copy_s.w %[tmp1], %w[src0][1]\n"
- "copy_s.w %[tmp2], %w[src0][2]\n"
- "copy_s.w %[tmp3], %w[src0][3]\n"
- "swr %[tmp0], 0(%[dst_ptr0])\n"
- "swl %[tmp0], 3(%[dst_ptr0])\n"
- "swr %[tmp1], 4(%[dst_ptr0])\n"
- "swl %[tmp1], 7(%[dst_ptr0])\n"
- "swr %[tmp2], 0(%[dst_ptr1])\n"
- "swl %[tmp2], 3(%[dst_ptr1])\n"
- "swr %[tmp3], 4(%[dst_ptr1])\n"
- "swl %[tmp3], 7(%[dst_ptr1])\n"
- "copy_s.w %[tmp0], %w[src1][0]\n"
- "copy_s.w %[tmp1], %w[src1][1]\n"
- "copy_s.w %[tmp2], %w[src1][2]\n"
- "copy_s.w %[tmp3], %w[src1][3]\n"
- "swr %[tmp0], 0(%[dst_ptr2])\n"
- "swl %[tmp0], 3(%[dst_ptr2])\n"
- "swr %[tmp1], 4(%[dst_ptr2])\n"
- "swl %[tmp1], 7(%[dst_ptr2])\n"
- "swr %[tmp2], 0(%[dst_ptr3])\n"
- "swl %[tmp2], 3(%[dst_ptr3])\n"
- "swr %[tmp3], 4(%[dst_ptr3])\n"
- "swl %[tmp3], 7(%[dst_ptr3])\n"
- :
- // Outputs.
- [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
- [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), [tmp0] "=&r"(tmp0),
- [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3)
- :
- // Inputs.
- [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
- [stride] "r"(stride)
- :
- // Clobbers.
- "memory");
-#endif
-}
-
-// Stores 8 4-byte quarter-vectors with a stride.
-inline void MipsMsaStore8x4(const RegBlockUint8<4, 8>& src,
- std::uint8_t* dst_ptr, int stride) {
-#if (__mips_isa_rev >= 6)
- // Assembly temporaries that will be handily referred to by their names.
- std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
- *dst_ptr6, *dst_ptr7;
- int tmp1, tmp2, tmp3;
- asm volatile(
- GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
- "copy_s.w %[tmp1], %w[src0][1]\n"
- "copy_s.w %[tmp2], %w[src0][2]\n"
- "copy_s.w %[tmp3], %w[src0][3]\n"
- "swc1 %[src0], 0(%[dst_ptr0])\n"
- "sw %[tmp1], 0(%[dst_ptr1])\n"
- "sw %[tmp2], 0(%[dst_ptr2])\n"
- "sw %[tmp3], 0(%[dst_ptr3])\n"
- "copy_s.w %[tmp1], %w[src1][1]\n"
- "copy_s.w %[tmp2], %w[src1][2]\n"
- "copy_s.w %[tmp3], %w[src1][3]\n"
- "swc1 %[src1], 0(%[dst_ptr4])\n"
- "sw %[tmp1], 0(%[dst_ptr5])\n"
- "sw %[tmp2], 0(%[dst_ptr6])\n"
- "sw %[tmp3], 0(%[dst_ptr7])\n"
- :
- // Outputs.
- [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
- [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
- [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
- [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
- [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3)
- :
- // Inputs.
- [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
- [stride] "r"(stride)
- :
- // Clobbers.
- "memory");
-#else
- // Assembly temporaries that will be handily referred to by their names.
- std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
- *dst_ptr6, *dst_ptr7;
- int tmp0, tmp1, tmp2, tmp3;
- asm volatile(
- GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
- "copy_s.w %[tmp0], %w[src0][0]\n"
- "copy_s.w %[tmp1], %w[src0][1]\n"
- "copy_s.w %[tmp2], %w[src0][2]\n"
- "copy_s.w %[tmp3], %w[src0][3]\n"
- "swr %[tmp0], 0(%[dst_ptr0])\n"
- "swl %[tmp0], 3(%[dst_ptr0])\n"
- "swr %[tmp1], 0(%[dst_ptr1])\n"
- "swl %[tmp1], 3(%[dst_ptr1])\n"
- "swr %[tmp2], 0(%[dst_ptr2])\n"
- "swl %[tmp2], 3(%[dst_ptr2])\n"
- "swr %[tmp3], 0(%[dst_ptr3])\n"
- "swl %[tmp3], 3(%[dst_ptr3])\n"
- "copy_s.w %[tmp0], %w[src1][0]\n"
- "copy_s.w %[tmp1], %w[src1][1]\n"
- "copy_s.w %[tmp2], %w[src1][2]\n"
- "copy_s.w %[tmp3], %w[src1][3]\n"
- "swr %[tmp0], 0(%[dst_ptr4])\n"
- "swl %[tmp0], 3(%[dst_ptr4])\n"
- "swr %[tmp1], 0(%[dst_ptr5])\n"
- "swl %[tmp1], 3(%[dst_ptr5])\n"
- "swr %[tmp2], 0(%[dst_ptr6])\n"
- "swl %[tmp2], 3(%[dst_ptr6])\n"
- "swr %[tmp3], 0(%[dst_ptr7])\n"
- "swl %[tmp3], 3(%[dst_ptr7])\n"
- :
- // Outputs.
- [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
- [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
- [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
- [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
- [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
- [tmp3] "=&r"(tmp3)
- :
- // Inputs.
- [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
- [stride] "r"(stride)
- :
- // Clobbers.
- "memory");
-#endif
-}
-
-// Stores 8 8-byte half-vectors with a stride.
-inline void MipsMsaStore8x8(const RegBlockUint8<8, 8>& src,
- std::uint8_t* dst_ptr, int stride) {
-#if (__mips_isa_rev >= 6)
- // Assembly temporaries that will be handily referred to by their names.
- std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
- *dst_ptr6, *dst_ptr7;
- v16i8 vtmp0, vtmp1, vtmp2, vtmp3;
- asm volatile(
- "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n"
- GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
- "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
- "ilvl.d %w[vtmp2], %w[src2], %w[src2]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
- "ilvl.d %w[vtmp3], %w[src3], %w[src3]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
- "sdc1 %[src0], 0(%[dst_ptr0])\n"
- "sdc1 %[vtmp0], 0(%[dst_ptr1])\n"
- "sdc1 %[src1], 0(%[dst_ptr2])\n"
- "sdc1 %[vtmp1], 0(%[dst_ptr3])\n"
- "sdc1 %[src2], 0(%[dst_ptr4])\n"
- "sdc1 %[vtmp2], 0(%[dst_ptr5])\n"
- "sdc1 %[src3], 0(%[dst_ptr6])\n"
- "sdc1 %[vtmp3], 0(%[dst_ptr7])\n"
- :
- // Outputs.
- [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
- [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
- [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
- [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
- [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1), [vtmp2] "=&f"(vtmp2),
- [vtmp3] "=&f"(vtmp3)
- :
- // Inputs.
- [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
- [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
- [stride] "r"(stride)
- :
- // Clobbers.
- "memory");
-#else
- // Assembly temporaries that will be handily referred to by their names.
- std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
- *dst_ptr6, *dst_ptr7;
- int tmp0, tmp1, tmp2, tmp3;
- asm volatile(
- GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
- GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
- "copy_s.w %[tmp0], %w[src0][0]\n"
- "copy_s.w %[tmp1], %w[src0][1]\n"
- "copy_s.w %[tmp2], %w[src0][2]\n"
- "copy_s.w %[tmp3], %w[src0][3]\n"
- "swr %[tmp0], 0(%[dst_ptr0])\n"
- "swl %[tmp0], 3(%[dst_ptr0])\n"
- "swr %[tmp1], 4(%[dst_ptr0])\n"
- "swl %[tmp1], 7(%[dst_ptr0])\n"
- "swr %[tmp2], 0(%[dst_ptr1])\n"
- "swl %[tmp2], 3(%[dst_ptr1])\n"
- "swr %[tmp3], 4(%[dst_ptr1])\n"
- "swl %[tmp3], 7(%[dst_ptr1])\n"
- "copy_s.w %[tmp0], %w[src1][0]\n"
- "copy_s.w %[tmp1], %w[src1][1]\n"
- "copy_s.w %[tmp2], %w[src1][2]\n"
- "copy_s.w %[tmp3], %w[src1][3]\n"
- "swr %[tmp0], 0(%[dst_ptr2])\n"
- "swl %[tmp0], 3(%[dst_ptr2])\n"
- "swr %[tmp1], 4(%[dst_ptr2])\n"
- "swl %[tmp1], 7(%[dst_ptr2])\n"
- "swr %[tmp2], 0(%[dst_ptr3])\n"
- "swl %[tmp2], 3(%[dst_ptr3])\n"
- "swr %[tmp3], 4(%[dst_ptr3])\n"
- "swl %[tmp3], 7(%[dst_ptr3])\n"
- "copy_s.w %[tmp0], %w[src2][0]\n"
- "copy_s.w %[tmp1], %w[src2][1]\n"
- "copy_s.w %[tmp2], %w[src2][2]\n"
- "copy_s.w %[tmp3], %w[src2][3]\n"
- "swr %[tmp0], 0(%[dst_ptr4])\n"
- "swl %[tmp0], 3(%[dst_ptr4])\n"
- "swr %[tmp1], 4(%[dst_ptr4])\n"
- "swl %[tmp1], 7(%[dst_ptr4])\n"
- "swr %[tmp2], 0(%[dst_ptr5])\n"
- "swl %[tmp2], 3(%[dst_ptr5])\n"
- "swr %[tmp3], 4(%[dst_ptr5])\n"
- "swl %[tmp3], 7(%[dst_ptr5])\n"
- "copy_s.w %[tmp0], %w[src3][0]\n"
- "copy_s.w %[tmp1], %w[src3][1]\n"
- "copy_s.w %[tmp2], %w[src3][2]\n"
- "copy_s.w %[tmp3], %w[src3][3]\n"
- "swr %[tmp0], 0(%[dst_ptr6])\n"
- "swl %[tmp0], 3(%[dst_ptr6])\n"
- "swr %[tmp1], 4(%[dst_ptr6])\n"
- "swl %[tmp1], 7(%[dst_ptr6])\n"
- "swr %[tmp2], 0(%[dst_ptr7])\n"
- "swl %[tmp2], 3(%[dst_ptr7])\n"
- "swr %[tmp3], 4(%[dst_ptr7])\n"
- "swl %[tmp3], 7(%[dst_ptr7])\n"
- :
- // Outputs.
- [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
- [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
- [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
- [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
- [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
- [tmp3] "=&r"(tmp3)
- :
- // Inputs.
- [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
- [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
- [stride] "r"(stride)
- :
- // Clobbers.
- "memory");
-#endif
-}
-
-#undef GEMMLOWP_MIPS_XADDU
-#undef GEMMLOWP_MIPS_XLSA
-
-// Transposes a column-major 8x4 block for storage into a row-major matrix.
-inline RegBlockUint8<4, 8> Transpose(const RegBlockUint8<8, 4>& src) {
- v16i8 tmp0 = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
- v16i8 tmp1 = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
- RegBlockUint8<4, 8> result;
- result.buf.reg[0] = __builtin_msa_ilvr_b(tmp1, tmp0);
- result.buf.reg[1] = __builtin_msa_ilvl_b(tmp1, tmp0);
- return result;
-}
-
-inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
- v16i8 tmp0[4];
- tmp0[0] = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
- tmp0[1] = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
- tmp0[2] = __builtin_msa_ilvr_b(src.buf.reg[3], src.buf.reg[2]);
- tmp0[3] = __builtin_msa_ilvl_b(src.buf.reg[3], src.buf.reg[2]);
- v16i8 tmp1[4];
- tmp1[0] = __builtin_msa_ilvr_b(tmp0[1], tmp0[0]);
- tmp1[1] = __builtin_msa_ilvl_b(tmp0[1], tmp0[0]);
- tmp1[2] = __builtin_msa_ilvr_b(tmp0[3], tmp0[2]);
- tmp1[3] = __builtin_msa_ilvl_b(tmp0[3], tmp0[2]);
- RegBlockUint8<8, 8> result;
- result.buf.reg[0] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
- reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
- result.buf.reg[1] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
- reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
- result.buf.reg[2] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
- reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
- result.buf.reg[3] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
- reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
- return result;
-}
-
-template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
- static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
- int col) {
- if (DstType::kOrder == MapOrder::ColMajor) {
- std::uint8_t* dst_ptr = dst->data(row, col);
- int col_stride = dst->cols_stride();
- MipsMsaStore4x8(src, dst_ptr, col_stride);
- } else {
- const auto& block = Transpose(src);
- std::uint8_t* dst_ptr = dst->data(row, col);
- int row_stride = dst->rows_stride();
- MipsMsaStore8x4(block, dst_ptr, row_stride);
- }
- }
-};
-
-template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
- static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
- int col) {
- const auto& block =
- (DstType::kOrder == MapOrder::ColMajor) ? src : Transpose(src);
- std::uint8_t* dst_ptr = dst->data(row, col);
- int stride = dst->stride();
- MipsMsaStore8x8(block, dst_ptr, stride);
- }
-};
-
-#else
-
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
@@ -1104,8 +617,6 @@
}
};
-#endif // Endianness, compiler.
-
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
diff --git a/internal/output_neon.h b/internal/output_neon.h
index 52ea1bc..911fed0 100644
--- a/internal/output_neon.h
+++ b/internal/output_neon.h
@@ -108,90 +108,6 @@
};
template <>
-struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
- RegBufferInt32<4>> {
- typedef RegBufferInt32<4> InputType;
- typedef RegBufferInt8<4> OutputType;
-
- typedef OutputStageSaturatingCastToInt8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- int16x4_t res_16 = vqmovn_s32(input.reg[0]);
- int8x8_t res_8 = vqmovn_s16(vcombine_s16(res_16, res_16));
- output.reg[0] = vget_lane_s32(vreinterpret_s32_s8(res_8), 0);
- return output;
- }
-};
-
-template <>
-struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
- RegBufferInt32<8>> {
- typedef RegBufferInt32<8> InputType;
- typedef RegBufferInt8<8> OutputType;
-
- typedef OutputStageSaturatingCastToInt8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- int16x8_t res_16 =
- vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
- output.reg[0] = vqmovn_s16(res_16);
- return output;
- }
-};
-
-template <>
-struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
- RegBufferInt32<16>> {
- typedef RegBufferInt32<16> InputType;
- typedef RegBufferInt8<16> OutputType;
-
- typedef OutputStageSaturatingCastToInt8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- int16x8_t res_16_0 =
- vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
- int16x8_t res_16_1 =
- vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
- output.reg[0] = vqmovn_s16(res_16_0);
- output.reg[1] = vqmovn_s16(res_16_1);
- return output;
- }
-};
-
-template <>
-struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
- RegBufferInt32<32>> {
- typedef RegBufferInt32<32> InputType;
- typedef RegBufferInt8<32> OutputType;
-
- typedef OutputStageSaturatingCastToInt8 OutputStage;
-
- OutputStageEvalBufferImpl(const OutputStage&) {}
-
- OutputType Eval(InputType input) const {
- OutputType output;
- int16x8_t res_16[4];
- for (int i = 0; i < 4; i++) {
- res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
- vqmovn_s32(input.reg[2 * i + 1]));
- }
- for (int i = 0; i < 4; i++) {
- output.reg[i] = vqmovn_s16(res_16[i]);
- }
- return output;
- }
-};
-
-template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
RegBufferInt32<4>> {
typedef RegBufferInt32<4> InputType;
@@ -640,8 +556,8 @@
vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
}
} else {
- int row_stride = dst->rows_stride();
for (int i = 0; i < 4; i++) {
+ int row_stride = dst->rows_stride();
std::uint8_t* col_ptr = dst_ptr + i;
vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
@@ -707,153 +623,6 @@
};
template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockInt8<4, 1>, DstType> {
- static void Run(const RegBlockInt8<4, 1>& src, DstType* dst, int row,
- int col) {
- const std::int32_t src_reg = src.buf.reg[0];
- for (int i = 0; i < 4; i++) {
- *dst->data(row + i, col) = (src_reg >> (8 * i));
- }
- }
-};
-
-template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockInt8<1, 4>, DstType> {
- static void Run(const RegBlockInt8<1, 4>& src, DstType* dst, int row,
- int col) {
- for (int i = 0; i < 4; i++) {
- *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
- }
- }
-};
-
-template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockInt8<8, 1>, DstType> {
- static void Run(const RegBlockInt8<8, 1>& src, DstType* dst, int row,
- int col) {
- std::int8_t* dst_ptr = dst->data(row, col);
- if (DstType::kOrder == MapOrder::ColMajor) {
- vst1_s8(dst_ptr, src.buf.reg[0]);
- } else {
- const int row_stride = dst->rows_stride();
- vst1_lane_s8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
- vst1_lane_s8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
- vst1_lane_s8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
- vst1_lane_s8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
- vst1_lane_s8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
- vst1_lane_s8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
- vst1_lane_s8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
- vst1_lane_s8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
- }
- }
-};
-
-template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockInt8<4, 4>, DstType> {
- static void Run(const RegBlockInt8<4, 4>& src, DstType* dst, int row,
- int col) {
- std::int8_t* dst_ptr = dst->data(row, col);
- const int row_stride = dst->rows_stride();
- const int col_stride = dst->cols_stride();
- for (int i = 0; i < 2; i++) {
- vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
- src.buf.reg[i], 0);
- vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
- src.buf.reg[i], 1);
- vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
- src.buf.reg[i], 2);
- vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
- src.buf.reg[i], 3);
- vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
- src.buf.reg[i], 4);
- vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
- src.buf.reg[i], 5);
- vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
- src.buf.reg[i], 6);
- vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
- src.buf.reg[i], 7);
- }
- }
-};
-
-template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockInt8<8, 4>, DstType> {
- static void Run(const RegBlockInt8<8, 4>& src, DstType* dst, int row,
- int col) {
- std::int8_t* dst_ptr = dst->data(row, col);
- if (DstType::kOrder == MapOrder::ColMajor) {
- int col_stride = dst->cols_stride();
- for (int i = 0; i < 4; i++) {
- vst1_s8(dst_ptr + i * col_stride, src.buf.reg[i]);
- }
- } else {
- int row_stride = dst->rows_stride();
- for (int i = 0; i < 4; i++) {
- std::int8_t* col_ptr = dst_ptr + i;
- vst1_lane_s8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
- vst1_lane_s8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
- vst1_lane_s8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
- vst1_lane_s8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
- vst1_lane_s8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
- vst1_lane_s8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
- vst1_lane_s8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
- vst1_lane_s8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
- }
- }
- }
-};
-
-inline RegBlockInt8<8, 8> Transpose(const RegBlockInt8<8, 8>& src) {
- int8x8x2_t a[4];
- a[0] = vtrn_s8(src.buf.reg[0], src.buf.reg[1]);
- a[1] = vtrn_s8(src.buf.reg[2], src.buf.reg[3]);
- a[2] = vtrn_s8(src.buf.reg[4], src.buf.reg[5]);
- a[3] = vtrn_s8(src.buf.reg[6], src.buf.reg[7]);
- int16x4x2_t b[4];
- b[0] = vtrn_s16(vreinterpret_s16_s8(a[0].val[0]),
- vreinterpret_s16_s8(a[1].val[0]));
- b[1] = vtrn_s16(vreinterpret_s16_s8(a[0].val[1]),
- vreinterpret_s16_s8(a[1].val[1]));
- b[2] = vtrn_s16(vreinterpret_s16_s8(a[2].val[0]),
- vreinterpret_s16_s8(a[3].val[0]));
- b[3] = vtrn_s16(vreinterpret_s16_s8(a[2].val[1]),
- vreinterpret_s16_s8(a[3].val[1]));
- int32x2x2_t c[4];
- c[0] = vtrn_s32(vreinterpret_s32_s16(b[0].val[0]),
- vreinterpret_s32_s16(b[2].val[0]));
- c[1] = vtrn_s32(vreinterpret_s32_s16(b[1].val[0]),
- vreinterpret_s32_s16(b[3].val[0]));
- c[2] = vtrn_s32(vreinterpret_s32_s16(b[0].val[1]),
- vreinterpret_s32_s16(b[2].val[1]));
- c[3] = vtrn_s32(vreinterpret_s32_s16(b[1].val[1]),
- vreinterpret_s32_s16(b[3].val[1]));
- RegBlockInt8<8, 8> result;
- result.buf.reg[0] = vreinterpret_s8_s32(c[0].val[0]);
- result.buf.reg[1] = vreinterpret_s8_s32(c[1].val[0]);
- result.buf.reg[2] = vreinterpret_s8_s32(c[2].val[0]);
- result.buf.reg[3] = vreinterpret_s8_s32(c[3].val[0]);
- result.buf.reg[4] = vreinterpret_s8_s32(c[0].val[1]);
- result.buf.reg[5] = vreinterpret_s8_s32(c[1].val[1]);
- result.buf.reg[6] = vreinterpret_s8_s32(c[2].val[1]);
- result.buf.reg[7] = vreinterpret_s8_s32(c[3].val[1]);
- return result;
-}
-
-template <typename DstType>
-struct StoreFinalOutputImpl<RegBlockInt8<8, 8>, DstType> {
- static void Run(const RegBlockInt8<8, 8>& src, DstType* dst, int row,
- int col) {
- const auto& block =
- DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
- std::int8_t* dst_ptr = dst->data(row, col);
- int stride = dst->stride();
- for (int i = 0; i < 8; i++) {
- vst1_s8(dst_ptr + i * stride, block.buf.reg[i]);
- }
- }
-};
-
-template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
int col) {
diff --git a/internal/pack.h b/internal/pack.h
index 7c43d6e..cb4b93a 100644
--- a/internal/pack.h
+++ b/internal/pack.h
@@ -72,10 +72,6 @@
pos_ += n * KernelSideFormat::Cell::kSize;
}
- // TODO(suharshs): The datatype can now be int8 as well. We could introduce a
- // new int8 current_data impl as well. This change would propagate to all pack
- // impls and the Kernel::Run API, which all assume uint8. For now we leave
- // this as-is pending future refactor.
const std::uint8_t* current_data() const {
return allocator_->GetPointer<std::uint8_t>(data_handle_) + pos_;
}
@@ -212,7 +208,6 @@
public:
typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat;
typedef typename KernelSideFormat::Cell CellFormat;
- typedef typename KernelSideFormat::InputScalar KernelInputScalar;
typedef typename KernelSideFormat::Scalar KernelScalar;
static const int kCells = KernelSideFormat::kCells;
static const int kCellWidth = CellFormat::kWidth;
@@ -221,7 +216,7 @@
static const int kCellSize = CellFormat::kSize;
static const SideMapOrder kSrcOrder = SrcMapType::kOrder;
static const int kZeroPointInputValue =
- ZeroPointInputValue<KernelInputScalar, KernelScalar>::kValue;
+ ZeroPointInputValue<KernelScalar>::kValue;
PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {}
@@ -238,7 +233,7 @@
std::uint8_t buf_[kKernelWidth * kRegisterSize];
public:
- // Selects a block if in-place source data that's already a complete block.
+ // Selects a block if in-place source data that's already a complete block
void UseCompleteSrcInPlace(const SrcMapType& src) { complete_src_ = src; }
// Copies an incomplete block of source data into a local temporary
// complete block by zero-extending it.
@@ -254,10 +249,7 @@
memcpy(buf_ + d * kKernelWidth, src.data(0, d), src.width());
}
}
-
- // Since the KernelInputScalar type may not be uint8, we need to cast buf_.
- complete_src_ = SrcMapType(reinterpret_cast<KernelInputScalar*>(buf_),
- kKernelWidth, kRegisterSize);
+ complete_src_ = SrcMapType(buf_, kKernelWidth, kRegisterSize);
}
// Packs a complete block into the destination. This is the most
// critical part and the part that we most typically want to
@@ -348,7 +340,7 @@
}
}
- // Prefetches the data that will be read by PackL1.
+ // Prefetches the data that will be read by PackL1
void PrefetchL1(int start_width, int width, int start_depth, int depth) {
if (SrcMapType::kOrder == SideMapOrder::WidthMajor) {
for (int d = 0; d < depth; d += kDefaultCacheLineSize) {
@@ -402,7 +394,7 @@
const SrcMapType& src_map_;
};
-// Packs a block of the input LHS matrix, into a PackedSideBlock.
+// Packs a block of the input LHS matrix, into a PackedSideBlock
template <typename PackedSideBlock, typename MatrixMapType>
void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) {
ScopedProfilingLabel label("pack LHS");
@@ -417,7 +409,7 @@
impl.PackL2();
}
-// Packs a block of the input RHS matrix, into a PackedSideBlock.
+// Packs a block of the input RHS matrix, into a PackedSideBlock
template <typename PackedSideBlock, typename MatrixMapType>
void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) {
ScopedProfilingLabel label("pack RHS");
@@ -438,8 +430,6 @@
#include "pack_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "pack_sse.h"
-#elif defined(GEMMLOWP_AVX2)
-#include "pack_avx.h"
#elif defined(GEMMLOWP_MSA)
#include "pack_msa.h"
#endif
diff --git a/internal/pack_avx.h b/internal/pack_avx.h
deleted file mode 100644
index 1ef5ce1..0000000
--- a/internal/pack_avx.h
+++ /dev/null
@@ -1,282 +0,0 @@
-// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// pack_avx.h: optimized AVX specializations of the templates in pack.h.
-
-#ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_
-#define GEMMLOWP_INTERNAL_PACK_AVX_H_
-
-#include <immintrin.h>
-#include "pack.h"
-
-namespace gemmlowp {
-
-// TODO: Add DepthMajorUint8SideMap
-
-typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
- WidthMajorUint8SideMap;
-
-template <int Cells>
-using WidthMajorSideFormatNCells4x2 =
- KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>;
-
-template <int Cells>
-class PackingRegisterBlock<
- WidthMajorUint8SideMap,
- PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
- : public PackingRegisterBlockBase<
- WidthMajorUint8SideMap,
- PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
- public:
- typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
- typedef typename KernelSideFormat::Cell CellFormat;
- static const int kCells = KernelSideFormat::kCells;
- static const int kCellWidth = CellFormat::kWidth;
- static const int kKernelWidth = CellFormat::kWidth * kCells;
- static const int kCellDepth = CellFormat::kDepth;
- static const int kCellSize = CellFormat::kSize;
-
- void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
- std::uint8_t *dst_ptr = dst->current_data();
- const int width_stride = this->complete_src_.width_stride();
- int depth_step = 16;
-
- __m256i one = _mm256_set1_epi16(1);
- for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
- cell_start_depth += depth_step) {
- for (int cell_start_width = 0; cell_start_width < kKernelWidth;
- cell_start_width += kCellWidth) {
- std::int32_t *cell_sums_of_each_slice_ptr =
- dst->sums_of_each_slice() + start_width + cell_start_width;
- const std::uint8_t *src_data =
- this->complete_src_.data(cell_start_width, cell_start_depth);
-
- __m128i xmm1 =
- _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0]));
- __m128i xmm2 = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
- __m128i xmm3 = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
- __m128i xmm4 = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
- __m128i xmm5 = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&src_data[4 * width_stride]));
- __m128i xmm6 = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&src_data[5 * width_stride]));
- __m128i xmm7 = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&src_data[6 * width_stride]));
- __m128i xmm8 = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&src_data[7 * width_stride]));
-
- __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1);
- __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2);
- __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3);
- __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4);
-
- __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2);
- __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4);
-
- __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2);
- __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4);
-
- __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6);
- __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6);
-
- __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10);
- __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10);
-
- __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8);
- __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8);
-
- __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8);
- __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8);
-
- __m128i xmm9 = _mm256_castsi256_si128(ymm11);
- __m128i xmm10 = _mm256_castsi256_si128(ymm12);
- __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1);
- __m128i xmm12 = _mm256_extracti128_si256(ymm12, 1);
-
- xmm1 = _mm256_castsi256_si128(ymm15);
- xmm2 = _mm256_castsi256_si128(ymm16);
- xmm3 = _mm256_extracti128_si256(ymm15, 1);
- xmm4 = _mm256_extracti128_si256(ymm16, 1);
-
- _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11);
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
- xmm10);
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
- xmm12);
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]),
- xmm1);
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]),
- xmm3);
-
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]),
- xmm2);
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]),
- xmm4);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm9);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- __m256i sums_of_each_slice_xmm = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0]));
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm11);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm10);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm12);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm1);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm3);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm2);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- ymm6 = _mm256_cvtepu8_epi16(xmm4);
- ymm7 = _mm256_madd_epi16(ymm6, one);
- sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
-
- _mm256_storeu_si256(
- reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]),
- sums_of_each_slice_xmm);
- dst_ptr += kCellSize;
- }
- dst_ptr += 7 * kCellSize * kCells;
- }
- dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
- }
-};
-
-// Pack format for 4x2 rhs format
-template <int Cells>
-using RhsWidthMajorSideFormatNCells4x2 =
- KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
-
-template <int Cells>
-class PackingRegisterBlock<
- WidthMajorUint8SideMap,
- PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>>
- : public PackingRegisterBlockBase<
- WidthMajorUint8SideMap,
- PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> {
- public:
- typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
- typedef typename KernelSideFormat::Cell CellFormat;
- static const int kCells = KernelSideFormat::kCells;
- static const int kCellWidth = CellFormat::kWidth;
- static const int kKernelWidth = CellFormat::kWidth * kCells;
- static const int kCellDepth = CellFormat::kDepth;
- static const int kCellSize = CellFormat::kSize;
-
- void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
- std::uint8_t *dst_ptr = dst->current_data();
- const int width_stride = this->complete_src_.width_stride();
- int depth_step = 8;
-
- __m128i one = _mm_set1_epi16(1);
- for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
- cell_start_depth += depth_step) {
- for (int cell_start_width = 0; cell_start_width < kKernelWidth;
- cell_start_width += kCellWidth) {
- std::int32_t *cell_sums_of_each_slice_ptr =
- dst->sums_of_each_slice() + start_width + cell_start_width;
- const std::uint8_t *src_data =
- this->complete_src_.data(cell_start_width, cell_start_depth);
-
- __m128i xmm1 =
- _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0]));
- __m128i xmm2 = _mm_loadl_epi64(
- reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
- __m128i xmm3 = _mm_loadl_epi64(
- reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
- __m128i xmm4 = _mm_loadl_epi64(
- reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
-
- __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
- __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
-
- __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
- __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
-
- __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
- __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
-
- _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
- _mm_storel_epi64(
- reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10);
-
- __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
- __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
-
- _mm_storel_epi64(
- reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
- xmm11);
- _mm_storel_epi64(
- reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
- xmm12);
-
- xmm1 = _mm_cvtepu8_epi16(xmm9);
- xmm2 = _mm_madd_epi16(xmm1, one);
- __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
- reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0]));
- sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
- xmm1 = _mm_cvtepu8_epi16(xmm10);
- xmm2 = _mm_madd_epi16(xmm1, one);
- sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
- xmm1 = _mm_cvtepu8_epi16(xmm11);
- xmm2 = _mm_madd_epi16(xmm1, one);
- sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
- xmm1 = _mm_cvtepu8_epi16(xmm12);
- xmm2 = _mm_madd_epi16(xmm1, one);
- sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
- _mm_storeu_si128(
- reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]),
- sums_of_each_slice_xmm);
- dst_ptr += kCellSize;
- }
- dst_ptr += 3 * kCellSize * kCells;
- }
- dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
- }
-};
-
-} // namespace gemmlowp
-
-#endif // GEMMLOWP_INTERNAL_PACK_AVX_H_
diff --git a/internal/pack_msa.h b/internal/pack_msa.h
index 4072229..fba8a0f 100644
--- a/internal/pack_msa.h
+++ b/internal/pack_msa.h
@@ -348,84 +348,6 @@
}
};
-template <int Width>
-using Int8FastKernelFormat =
- KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
-
-template <int Width>
-class PackingRegisterBlock<WidthMajorUint8SideMap,
- PackedSideBlock<Int8FastKernelFormat<Width>>>
- : public PackingRegisterBlockBase<
- WidthMajorUint8SideMap,
- PackedSideBlock<Int8FastKernelFormat<Width>>> {
- public:
- static_assert(Width == 2 || Width == 4, "");
- typedef Int8FastKernelFormat<Width> KernelSideFormat;
- typedef typename KernelSideFormat::Cell CellFormat;
- static const int kCells = KernelSideFormat::kCells;
- static const int kCellWidth = CellFormat::kWidth;
- static const int kKernelWidth = CellFormat::kWidth * kCells;
- static const int kCellDepth = CellFormat::kDepth;
- static const int kCellSize = CellFormat::kSize;
-
- void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
- std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
- std::uint8_t* dst_ptr = dst->current_data();
- const std::uint8_t* const src_ptr = this->complete_src_.data();
- const int stride = this->complete_src_.stride();
- // Load source WidthMajor data.
- v16i8 src_lines[Width];
- for (int i = 0; i < Width; i++) {
- src_lines[i] = __builtin_msa_ld_b(
- const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
- }
- for (int i = 0; i < Width; i++) {
- // Subtract 128 by inverting bit 7.
- src_lines[i] = reinterpret_cast<v16i8>(
- __builtin_msa_bnegi_b(reinterpret_cast<v16u8>(src_lines[i]), 7));
- }
- for (int i = 0; i < Width; i++) {
- __builtin_msa_st_b(src_lines[i], dst_ptr + 16 * i, 0);
- }
- v8i16 sums2[Width];
- for (int i = 0; i < Width; i++) {
- sums2[i] = __builtin_msa_hadd_s_h(src_lines[i], src_lines[i]);
- }
- v4i32 sums4_wide[Width];
- for (int i = 0; i < Width; i++) {
- sums4_wide[i] = __builtin_msa_hadd_s_w(sums2[i], sums2[i]);
- }
- v8i16 sums4[Width / 2];
- for (int i = 0; i < Width / 2; i++) {
- sums4[i] = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(sums4_wide[2 * i + 1]),
- reinterpret_cast<v8i16>(sums4_wide[2 * i]));
- }
- v4i32 sums8_wide[Width / 2];
- for (int i = 0; i < Width / 2; i++) {
- sums8_wide[i] = __builtin_msa_hadd_s_w(sums4[i], sums4[i]);
- }
- if (Width == 4) {
- v4i32 sum = __builtin_msa_ld_w(const_cast<std::int32_t*>(sums_ptr), 0);
- v8i16 sums8 = __builtin_msa_pckev_h(
- reinterpret_cast<v8i16>(sums8_wide[1]),
- reinterpret_cast<v8i16>(sums8_wide[0]));
- v4i32 sums16 = __builtin_msa_hadd_s_w(sums8, sums8);
- sum = __builtin_msa_addv_w(sum, sums16);
- __builtin_msa_st_w(sum, sums_ptr, 0);
- } else {
- assert(Width == 2);
- std::int32_t sum[2] = { sums_ptr[0], sums_ptr[1] };
- v2i64 sums16 = __builtin_msa_hadd_s_d(sums8_wide[0], sums8_wide[0]);
- sum[0] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 0);
- sum[1] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 2);
- sums_ptr[0] = sum[0];
- sums_ptr[1] = sum[1];
- }
- dst->seek_forward_n_cells(1);
- }
-};
-
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_PACK_MSA_H_
diff --git a/internal/pack_neon.h b/internal/pack_neon.h
index f113d9e..2b08464 100644
--- a/internal/pack_neon.h
+++ b/internal/pack_neon.h
@@ -26,9 +26,6 @@
typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
WidthMajorUint8SideMap;
-typedef SideMap<const std::int8_t, SideMapOrder::WidthMajor>
- WidthMajorInt8SideMap;
-
template <int Cells>
using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
@@ -318,67 +315,6 @@
}
};
-template <int Width>
-using Int8InputsFastKernelFormat =
- KernelSideFormatInt8Inputs<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
-
-// Same as above, but for int8 inputs, avoiding the uint8 -> int8 conversion.
-template <int Width>
-class PackingRegisterBlock<WidthMajorInt8SideMap,
- PackedSideBlock<Int8InputsFastKernelFormat<Width>>>
- : public PackingRegisterBlockBase<
- WidthMajorInt8SideMap,
- PackedSideBlock<Int8InputsFastKernelFormat<Width>>> {
- public:
- static_assert(Width == 2 || Width == 4, "");
- typedef Int8InputsFastKernelFormat<Width> KernelSideFormat;
- typedef typename KernelSideFormat::Cell CellFormat;
- static const int kCells = KernelSideFormat::kCells;
- static const int kCellWidth = CellFormat::kWidth;
- static const int kKernelWidth = CellFormat::kWidth * kCells;
- static const int kCellDepth = CellFormat::kDepth;
- static const int kCellSize = CellFormat::kSize;
-
- void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
- std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
- std::int8_t* dst_ptr = reinterpret_cast<std::int8_t*>(dst->current_data());
- const std::int8_t* const src_ptr = this->complete_src_.data();
- const int stride = this->complete_src_.stride();
- // Load source WidthMajor data
- int8x16_t src_lines[Width];
- for (int i = 0; i < Width; i++) {
- src_lines[i] = vld1q_s8(src_ptr + i * stride);
- }
- for (int i = 0; i < Width; i++) {
- vst1q_s8(dst_ptr + 16 * i, src_lines[i]);
- }
- int16x8_t sums2[Width];
- for (int i = 0; i < Width; i++) {
- const int8x8_t lo = vget_low_s8(src_lines[i]);
- const int8x8_t hi = vget_high_s8(src_lines[i]);
- sums2[i] = vaddl_s8(lo, hi);
- }
- int16x8_t sums4[Width / 2];
- for (int i = 0; i < Width / 2; i++) {
- sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
- }
- if (Width == 4) {
- int32x4_t sum = vld1q_s32(sums_ptr);
- int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
- sum = vpadalq_s16(sum, sums8);
- vst1q_s32(sums_ptr, sum);
- } else {
- assert(Width == 2);
- int32x2_t sum = vld1_s32(sums_ptr);
- int16x4_t sums8 =
- vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
- sum = vpadal_s16(sum, sums8);
- vst1_s32(sums_ptr, sum);
- }
- dst->seek_forward_n_cells(1);
- }
-};
-
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_PACK_NEON_H_
diff --git a/internal/platform.h b/internal/platform.h
index ab71414..1114767 100644
--- a/internal/platform.h
+++ b/internal/platform.h
@@ -18,7 +18,6 @@
#define GEMMLOWP_INTERNAL_PLATFORM_H_
#ifdef _WIN32
-#include <malloc.h>
#include <windows.h>
#else
#include <stdlib.h>
@@ -72,8 +71,8 @@
inline double real_time_in_seconds() {
__int64 wintime;
GetSystemTimeAsFileTime((FILETIME *)&wintime);
- wintime -= 116444736000000000LL; // 1jan1601 to 1jan1970
- return wintime / 10000000LL + wintime % 10000000LL * 100 * 1e-9;
+ wintime -= 116444736000000000i64; // 1jan1601 to 1jan1970
+ return wintime / 10000000i64 + wintime % 10000000i64 * 100 * 1e-9;
}
#else
diff --git a/internal/simd_wrappers.h b/internal/simd_wrappers.h
index 4e4cce8..d9721c9 100644
--- a/internal/simd_wrappers.h
+++ b/internal/simd_wrappers.h
@@ -105,12 +105,10 @@
using FlippedRhsType = RhsType;
static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
const RhsType& rhs) {
- (void)rhs;
return lhs;
}
static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
const RhsType& rhs) {
- (void)lhs;
return rhs;
}
};
@@ -121,12 +119,10 @@
using FlippedRhsType = LhsType;
static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
const RhsType& rhs) {
- (void)lhs;
return rhs;
}
static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
const RhsType& rhs) {
- (void)rhs;
return lhs;
}
};
@@ -196,153 +192,6 @@
}
template <typename Lhs, typename Rhs>
-struct BroadcastShiftLeftImpl {
- using ResultBlockType =
- typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
- static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
- ResultBlockType result;
- static constexpr int Rows = ResultBlockType::kRows;
- static constexpr int Cols = ResultBlockType::kCols;
- static constexpr int LhsRows = Lhs::kRows;
- static constexpr int LhsCols = Lhs::kCols;
- static constexpr int RhsRows = Rhs::kRows;
- static constexpr int RhsCols = Rhs::kCols;
-
- static_assert(LhsRows == Rows || LhsRows == 1, "");
- static_assert(RhsRows == Rows || RhsRows == 1, "");
- static_assert(LhsCols == Cols || LhsCols == 1, "");
- static_assert(RhsCols == Cols || RhsCols == 1, "");
- static_assert(ResultBlockType::kRegisterLanes == 1,
- "This path is only for scalar values");
- static_assert(Lhs::kRegisterLanes == 1,
- "This path is only for scalar values");
- static_assert(Rhs::kRegisterLanes == 1,
- "This path is only for scalar values");
-
- for (int c = 0; c < Cols; c++) {
- const int lhs_c = LhsCols == Cols ? c : 0;
- const int rhs_c = RhsCols == Cols ? c : 0;
- for (int r = 0; r < Rows; r++) {
- const int lhs_r = LhsRows == Rows ? r : 0;
- const int rhs_r = RhsRows == Rows ? r : 0;
- result.buf.reg[r + c * Rows] =
- ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
- rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
- }
- }
- return result;
- }
-};
-
-template <typename Lhs, typename Rhs>
-typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
- const Lhs& lhs, const Rhs& rhs) {
- using Flip = FlipLhsRhs<Lhs, Rhs>;
- return BroadcastShiftLeftImpl<
- typename Flip::FlippedLhsType,
- typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
- Flip::FlippedRhs(lhs, rhs));
-}
-
-template <typename Lhs, typename Rhs>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl {
- using ResultBlockType =
- typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
- static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
- ResultBlockType result;
- static constexpr int Rows = ResultBlockType::kRows;
- static constexpr int Cols = ResultBlockType::kCols;
- static constexpr int LhsRows = Lhs::kRows;
- static constexpr int LhsCols = Lhs::kCols;
- static constexpr int RhsRows = Rhs::kRows;
- static constexpr int RhsCols = Rhs::kCols;
-
- static_assert(LhsRows == Rows || LhsRows == 1, "");
- static_assert(RhsRows == Rows || RhsRows == 1, "");
- static_assert(LhsCols == Cols || LhsCols == 1, "");
- static_assert(RhsCols == Cols || RhsCols == 1, "");
- static_assert(ResultBlockType::kRegisterLanes == 1,
- "This path is only for scalar values");
- static_assert(Lhs::kRegisterLanes == 1,
- "This path is only for scalar values");
- static_assert(Rhs::kRegisterLanes == 1,
- "This path is only for scalar values");
-
- for (int c = 0; c < Cols; c++) {
- const int lhs_c = LhsCols == Cols ? c : 0;
- const int rhs_c = RhsCols == Cols ? c : 0;
- for (int r = 0; r < Rows; r++) {
- const int lhs_r = LhsRows == Rows ? r : 0;
- const int rhs_r = RhsRows == Rows ? r : 0;
- result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[lhs_r + lhs_c * LhsRows],
- rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
- }
- }
- return result;
- }
-};
-
-template <typename Lhs, typename Rhs>
-typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
-BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
- using Flip = FlipLhsRhs<Lhs, Rhs>;
- return BroadcastSaturatingRoundingDoublingHighMulImpl<
- typename Flip::FlippedLhsType,
- typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
- Flip::FlippedRhs(lhs, rhs));
-}
-
-template <typename Lhs, typename Rhs>
-struct BroadcastRoundingDivideByPOTImpl {
- using ResultBlockType =
- typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
- static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
- ResultBlockType result;
- static constexpr int Rows = ResultBlockType::kRows;
- static constexpr int Cols = ResultBlockType::kCols;
- static constexpr int LhsRows = Lhs::kRows;
- static constexpr int LhsCols = Lhs::kCols;
- static constexpr int RhsRows = Rhs::kRows;
- static constexpr int RhsCols = Rhs::kCols;
-
- static_assert(LhsRows == Rows || LhsRows == 1, "");
- static_assert(RhsRows == Rows || RhsRows == 1, "");
- static_assert(LhsCols == Cols || LhsCols == 1, "");
- static_assert(RhsCols == Cols || RhsCols == 1, "");
- static_assert(ResultBlockType::kRegisterLanes == 1,
- "This path is only for scalar values");
- static_assert(Lhs::kRegisterLanes == 1,
- "This path is only for scalar values");
- static_assert(Rhs::kRegisterLanes == 1,
- "This path is only for scalar values");
-
- for (int c = 0; c < Cols; c++) {
- const int lhs_c = LhsCols == Cols ? c : 0;
- const int rhs_c = RhsCols == Cols ? c : 0;
- for (int r = 0; r < Rows; r++) {
- const int lhs_r = LhsRows == Rows ? r : 0;
- const int rhs_r = RhsRows == Rows ? r : 0;
- result.buf.reg[r + c * Rows] =
- RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
- rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
- }
- }
- return result;
- }
-};
-
-template <typename Lhs, typename Rhs>
-typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
-BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
- using Flip = FlipLhsRhs<Lhs, Rhs>;
- return BroadcastRoundingDivideByPOTImpl<
- typename Flip::FlippedLhsType,
- typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
- Flip::FlippedRhs(lhs, rhs));
-}
-
-template <typename Lhs, typename Rhs>
struct BroadcastMulImpl {
using ResultBlockType =
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
@@ -645,16 +494,12 @@
using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
template <int N>
using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
-template <int N>
-using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
template <int R, int C>
using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
template <int R, int C>
using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
template <int R, int C>
using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
-template <int R, int C>
-using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
} // end namespace gemmlowp
diff --git a/internal/simd_wrappers_common_neon_sse.h b/internal/simd_wrappers_common_neon_sse.h
index 694bf99..3830eb1 100644
--- a/internal/simd_wrappers_common_neon_sse.h
+++ b/internal/simd_wrappers_common_neon_sse.h
@@ -350,210 +350,6 @@
}
};
-// 4x1 := 4x1 + 1x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<4, 1> result;
- result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 1x4 := 1x4 + 1x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<1, 4> result;
- result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 4x1 := 4x1 + 4x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
- RegBlockInt32<4, 1>> {
- static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
- const RegBlockInt32<4, 1>& rhs) {
- RegBlockInt32<4, 1> result;
- result.buf.reg[0] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 1x4 := 1x4 + 1x4
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
- RegBlockInt32<1, 4>> {
- static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<1, 4> result;
- result.buf.reg[0] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 4x4 := 4x4 + 1x4
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
- RegBlockInt32<1, 4>> {
- static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<4, 4> result;
- result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 4x4 := 4x4 + 4x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
- RegBlockInt32<4, 1>> {
- static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
- const RegBlockInt32<4, 1>& rhs) {
- RegBlockInt32<4, 4> result;
- result.buf.reg[0] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]);
- result.buf.reg[2] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
- result.buf.reg[3] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 8x1 := 8x1 + 1x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<8, 1> result;
- const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
- for (int i = 0; i < 2; i++) {
- result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p);
- }
- return result;
- }
-};
-
-// 8x1 := 8x1 + 8x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
- RegBlockInt32<8, 1>> {
- static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
- const RegBlockInt32<8, 1>& rhs) {
- RegBlockInt32<8, 1> result;
- for (int i = 0; i < 2; i++) {
- result.buf.reg[i] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]);
- }
- return result;
- }
-};
-
-// 8x4 := 8x4 + 1x4
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
- RegBlockInt32<1, 4>> {
- static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<8, 4> result;
- result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[4] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[5] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[6] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
- result.buf.reg[7] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 8x4 := 8x4 + 8x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
- RegBlockInt32<8, 1>> {
- static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
- const RegBlockInt32<8, 1>& rhs) {
- RegBlockInt32<8, 4> result;
- result.buf.reg[0] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
- result.buf.reg[2] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
- result.buf.reg[3] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]);
- result.buf.reg[4] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]);
- result.buf.reg[5] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]);
- result.buf.reg[6] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]);
- result.buf.reg[7] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]);
- return result;
- }
-};
-
-// 1x8 := 1x8 + 1x8
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
- RegBlockInt32<1, 8>> {
- static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
- const RegBlockInt32<1, 8>& rhs) {
- RegBlockInt32<1, 8> result;
- result.buf.reg[0] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] =
- SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
- return result;
- }
-};
-
-// 1x8 := 1x8 + 1x1
-template <>
-struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<1, 8> result;
- result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
- lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
// 4x1 := 4x1 * 1x1
template <>
struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
diff --git a/internal/simd_wrappers_msa.h b/internal/simd_wrappers_msa.h
index 7de01ff..cf5e8e9 100644
--- a/internal/simd_wrappers_msa.h
+++ b/internal/simd_wrappers_msa.h
@@ -33,7 +33,8 @@
template <int ScalarCount>
struct RegisterType<std::int16_t, ScalarCount> {
- using Type = typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
+ using Type =
+ typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
};
template <int ScalarCount>
@@ -68,9 +69,13 @@
return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0);
}
-inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
+ __builtin_msa_st_h(value, dst, 0);
+}
-inline void StoreInt16x8(Int16x8* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
+inline void StoreInt16x8(Int16x8* dst, Int16x8 value) {
+ __builtin_msa_st_h(value, dst, 0);
+}
inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0);
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h
index 6871055..2949173 100644
--- a/internal/simd_wrappers_neon.h
+++ b/internal/simd_wrappers_neon.h
@@ -25,7 +25,6 @@
using Int16x4 = int16x4_t;
using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;
-using Int8x8 = int8x8_t;
template <int ScalarCount>
struct RegisterType<std::int32_t, ScalarCount> {
@@ -49,14 +48,6 @@
std::uint8_t>::type>::type;
};
-template <int ScalarCount>
-struct RegisterType<std::int8_t, ScalarCount> {
- using Type = typename std::conditional<
- ScalarCount >= 8, Int8x8,
- typename std::conditional<ScalarCount >= 4, std::int32_t,
- std::int8_t>::type>::type;
-};
-
inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }
@@ -101,10 +92,6 @@
inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }
-inline Int32x4 Max(Int32x4 a, std::int32_t b) {
- return vmaxq_s32(a, vdupq_n_s32(b));
-}
-
inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
return vqrdmulhq_n_s32(a, b);
}
@@ -177,17 +164,6 @@
};
template <>
-struct LoadContiguousImpl<RegBlockInt8<8, 8>> {
- static RegBlockInt8<8, 8> Run(const std::int8_t* src) {
- RegBlockInt8<8, 8> result;
- for (int i = 0; i < 8; i++) {
- result.buf.reg[i] = vld1_s8(src + 8 * i);
- }
- return result;
- }
-};
-
-template <>
struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
RegBlockInt32<8, 8> result;
@@ -198,352 +174,6 @@
}
};
-// 4x1 := 4x1 + 1x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
- static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<4, 1> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 1x4 := 1x4 + 1x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
- static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<1, 4> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 4x1 := 4x1 + 4x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
- static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
- const RegBlockInt32<4, 1>& rhs) {
- RegBlockInt32<4, 1> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 1x4 := 1x4 + 1x4
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
- static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<1, 4> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 4x4 := 4x4 + 1x4
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
- static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<4, 4> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 4x4 := 4x4 + 4x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
- static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
- const RegBlockInt32<4, 1>& rhs) {
- RegBlockInt32<4, 4> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[0]);
- result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
- result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 8x1 := 8x1 + 1x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
- static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<8, 1> result;
- const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
- for (int i = 0; i < 2; i++) {
- result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], p);
- }
- return result;
- }
-};
-
-// 8x1 := 8x1 + 8x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
- static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
- const RegBlockInt32<8, 1>& rhs) {
- RegBlockInt32<8, 1> result;
- for (int i = 0; i < 2; i++) {
- result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], rhs.buf.reg[i]);
- }
- return result;
- }
-};
-
-// 8x4 := 8x4 + 1x4
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
- static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<8, 4> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
- result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 8x4 := 8x4 + 8x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
- static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
- const RegBlockInt32<8, 1>& rhs) {
- RegBlockInt32<8, 4> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
- result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
- result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[1]);
- result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], rhs.buf.reg[0]);
- result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], rhs.buf.reg[1]);
- result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], rhs.buf.reg[0]);
- result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], rhs.buf.reg[1]);
- return result;
- }
-};
-
-// 1x8 := 1x8 + 1x8
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
- static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
- const RegBlockInt32<1, 8>& rhs) {
- RegBlockInt32<1, 8> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
- return result;
- }
-};
-
-// 1x8 := 1x8 + 1x1
-template <>
-struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
- static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<1, 8> result;
- result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 4x1 := 4x1 + 1x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<4, 1> result;
- result.buf.reg[0] =
- RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 1x4 := 1x4 + 1x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<1, 4> result;
- result.buf.reg[0] =
- RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 4x1 := 4x1 + 4x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
- RegBlockInt32<4, 1>> {
- static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
- const RegBlockInt32<4, 1>& rhs) {
- RegBlockInt32<4, 1> result;
- result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 1x4 := 1x4 + 1x4
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
- RegBlockInt32<1, 4>> {
- static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<1, 4> result;
- result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 4x4 := 4x4 + 1x4
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
- RegBlockInt32<1, 4>> {
- static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<4, 4> result;
- result.buf.reg[0] =
- RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[1] =
- RoundingDivideByPOT(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[2] =
- RoundingDivideByPOT(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[3] =
- RoundingDivideByPOT(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 4x4 := 4x4 + 4x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
- RegBlockInt32<4, 1>> {
- static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
- const RegBlockInt32<4, 1>& rhs) {
- RegBlockInt32<4, 4> result;
- result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[0]);
- result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
- result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[0]);
- return result;
- }
-};
-
-// 8x1 := 8x1 + 1x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<8, 1> result;
- const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
- for (int i = 0; i < 2; i++) {
- result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], p);
- }
- return result;
- }
-};
-
-// 8x1 := 8x1 + 8x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
- RegBlockInt32<8, 1>> {
- static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
- const RegBlockInt32<8, 1>& rhs) {
- RegBlockInt32<8, 1> result;
- for (int i = 0; i < 2; i++) {
- result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], rhs.buf.reg[i]);
- }
- return result;
- }
-};
-
-// 8x4 := 8x4 + 1x4
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
- RegBlockInt32<1, 4>> {
- static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
- const RegBlockInt32<1, 4>& rhs) {
- RegBlockInt32<8, 4> result;
- result.buf.reg[0] =
- RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[1] =
- RoundingDivideByPOT(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
- result.buf.reg[2] =
- RoundingDivideByPOT(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[3] =
- RoundingDivideByPOT(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
- result.buf.reg[4] =
- RoundingDivideByPOT(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[5] =
- RoundingDivideByPOT(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
- result.buf.reg[6] =
- RoundingDivideByPOT(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
- result.buf.reg[7] =
- RoundingDivideByPOT(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
- return result;
- }
-};
-
-// 8x4 := 8x4 + 8x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
- RegBlockInt32<8, 1>> {
- static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
- const RegBlockInt32<8, 1>& rhs) {
- RegBlockInt32<8, 4> result;
- result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
- result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
- result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[1]);
- result.buf.reg[4] = RoundingDivideByPOT(lhs.buf.reg[4], rhs.buf.reg[0]);
- result.buf.reg[5] = RoundingDivideByPOT(lhs.buf.reg[5], rhs.buf.reg[1]);
- result.buf.reg[6] = RoundingDivideByPOT(lhs.buf.reg[6], rhs.buf.reg[0]);
- result.buf.reg[7] = RoundingDivideByPOT(lhs.buf.reg[7], rhs.buf.reg[1]);
- return result;
- }
-};
-
-// 1x8 := 1x8 + 1x8
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
- RegBlockInt32<1, 8>> {
- static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
- const RegBlockInt32<1, 8>& rhs) {
- RegBlockInt32<1, 8> result;
- result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
- result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
- return result;
- }
-};
-
-// 1x8 := 1x8 + 1x1
-template <>
-struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
- RegBlockInt32<1, 1>> {
- static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
- const RegBlockInt32<1, 1>& rhs) {
- RegBlockInt32<1, 8> result;
- result.buf.reg[0] =
- RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
- result.buf.reg[1] =
- RoundingDivideByPOT(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
- return result;
- }
-};
-
} // end namespace gemmlowp
#include "simd_wrappers_common_neon_sse.h"
diff --git a/internal/unpack.h b/internal/unpack.h
index 021f4aa..33aee13 100644
--- a/internal/unpack.h
+++ b/internal/unpack.h
@@ -98,14 +98,12 @@
const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
int depth, int src_row, int src_col, int src_global_row,
int src_global_col, int dst_row, int dst_col) {
- using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar;
using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
- using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar;
using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
static constexpr int KernelLhsZeroPointInput =
- ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue;
+ ZeroPointInputValue<KernelLhsScalar>::kValue;
static constexpr int KernelRhsZeroPointInput =
- ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue;
+ ZeroPointInputValue<KernelRhsScalar>::kValue;
auto acc = Load<RegisterBlockType>(src, src_row, src_col);
const auto& lhs_sums_of_each_slice_block =
LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
diff --git a/meta/multi_thread_common.h b/meta/multi_thread_common.h
index b39c3f2..0b35759 100644
--- a/meta/multi_thread_common.h
+++ b/meta/multi_thread_common.h
@@ -22,15 +22,9 @@
inline int ResolveMaxThreads(int max_threads) {
if (max_threads == 0) {
-#ifdef _WIN32
- SYSTEM_INFO sysinfo;
- GetSystemInfo(&sysinfo);
- return sysinfo.dwNumberOfProcessors;
-#else
static const int hardware_threads_count =
static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
return hardware_threads_count;
-#endif
}
return max_threads;
}
diff --git a/profiling/instrumentation.h b/profiling/instrumentation.h
index c1f852e..437fe54 100644
--- a/profiling/instrumentation.h
+++ b/profiling/instrumentation.h
@@ -108,14 +108,13 @@
// contains pointers to literal strings that were manually entered
// in the instrumented code (see ScopedProfilingLabel).
struct ProfilingStack {
- static const std::size_t kMaxSize = 30;
+ static const std::size_t kMaxSize = 14;
typedef const char* LabelsArrayType[kMaxSize];
LabelsArrayType labels;
std::size_t size;
Mutex* lock;
ProfilingStack() { memset(this, 0, sizeof(ProfilingStack)); }
- ~ProfilingStack() { delete lock; }
void Push(const char* label) {
ScopedLock sl(lock);
@@ -172,6 +171,8 @@
ScopedLock sl(GlobalMutexes::Profiler());
ThreadInfo* self = static_cast<ThreadInfo*>(ptr);
ThreadsUnderProfiling().erase(self);
+ pthread_key_delete(self->key);
+ delete self->stack.lock;
}
};
@@ -184,11 +185,7 @@
}
};
- // key_result is unused. The purpose of this 'static' local object is
- // to have its initializer (the pthread_key_create call) performed exactly
- // once, in a way that is guaranteed (since C++11) to be reentrant.
- static const int key_result = pthread_key_create(&key, DeleteThreadInfo);
- (void)key_result;
+ static int key_result = pthread_key_create(&key, DeleteThreadInfo);
ThreadInfo* threadInfo = static_cast<ThreadInfo*>(pthread_getspecific(key));
if (!threadInfo) {
diff --git a/profiling/pthread_everywhere.h b/profiling/pthread_everywhere.h
index 2569bbc..df17c6f 100644
--- a/profiling/pthread_everywhere.h
+++ b/profiling/pthread_everywhere.h
@@ -60,9 +60,6 @@
*cond = new std::condition_variable;
}
inline void pthread_cond_signal(pthread_cond_t *cond) { (*cond)->notify_one(); }
-inline void pthread_cond_broadcast(pthread_cond_t *cond) {
- (*cond)->notify_all();
-}
inline void pthread_cond_wait(pthread_cond_t *cond, pthread_mutex_t *mutex) {
std::unique_lock<std::mutex> lock(**mutex, std::adopt_lock);
(*cond)->wait(lock);
diff --git a/public/bit_depth.h b/public/bit_depth.h
index 412944e..6cb4ecf 100644
--- a/public/bit_depth.h
+++ b/public/bit_depth.h
@@ -24,15 +24,14 @@
struct OperandRange {
static const int kMinValue = tMinValue;
static const int kMaxValue = tMaxValue;
+ static_assert(0 <= kMinValue, "");
static_assert(kMinValue < kMaxValue, "");
+ static_assert(kMaxValue <= 255, "");
};
using Uint8Range = OperandRange<0, 255>;
using Uint8RangeExcludingZero = OperandRange<1, 255>;
-using Int8Range = OperandRange<-128, 127>;
-using Int8RangeExcludingLow = OperandRange<-127, 127>;
-
template <typename tLhsRange, typename tRhsRange>
struct BitDepthParams {
using LhsRange = tLhsRange;
@@ -48,11 +47,6 @@
using L8R8WithLhsNonzeroBitDepthParams =
BitDepthParams<Uint8RangeExcludingZero, Uint8Range>;
-// Signed Variant: This allows using faster kernels using signed arithmetic, see
-// NEON_64bit_GEMM_Int8Operands_Int32Accumulators_AccumTwoWithin16Bits
-using SignedL8R8WithLhsNonzeroBitDepthParams =
- BitDepthParams<Int8RangeExcludingLow, Int8Range>;
-
// Deprecated: when gemmlowp used to allow requantizing 8bit
// inputs to less-than-8-bit depths, the public setting allowing
// that was DefaultL7R5BitDepthParams. That requantization
diff --git a/public/map.h b/public/map.h
index fe6bc5c..3073e05 100644
--- a/public/map.h
+++ b/public/map.h
@@ -131,7 +131,6 @@
assert(start >= 0);
assert(start + len <= size_);
- (void)start;
return VectorDup(data_, len);
}
};
diff --git a/public/output_stages.h b/public/output_stages.h
index 797b662..1d5fca4 100644
--- a/public/output_stages.h
+++ b/public/output_stages.h
@@ -138,44 +138,12 @@
std::int32_t result_offset_after_shift;
};
-// Variant of OutputStageQuantizeDownInt32ByFixedPoint where the 'shift'
-// is not necessarily just a right shift, so we can represent multipliers
-// greater than 1. This takes an result_exponent parameter; when it's
-// <= 0, this is equivalent to OutputStageQuantizeDownInt32ByFixedPoint
-// with result_shift = -result_exponent.
-// In the general case, this consists in first left-shifting by
-// std::max(result_exponent, 0), before doing the same as
-// OutputStageQuantizeDownInt32ByFixedPoint with
-// result_shift = std::max(-result_exponent, 0).
-//
-// Difference from OutputStageScaleInt32ByFixedPointAndExponent here is that
-// each row or column of the output (depending on tShape) has its own
-// result_fixedpoint_multiplier and result_exponent numbers.
-template <VectorShape tShape>
-struct OutputStageScaleInt32ByFixedPointAndExponentPC {
- VectorMap<const std::int32_t, tShape> result_fixedpoint_multiplier;
- VectorMap<const std::int32_t, tShape> result_exponent;
- std::int32_t result_offset_after_shift;
-};
-
// This output stage takes int32 values that are expected to be already
// on the final uint8 scale, but not necessarily in the [0..255] range.
// It clamps them to the [0..255] range and returns them casted to uint8.
struct OutputStageSaturatingCastToUint8 {};
// This output stage takes int32 values that are expected to be already
-// on the final int8 scale, but not necessarily in the [-128..127] range.
-// It clamps them to the [-128..127] range and returns them casted to int8.
-struct OutputStageSaturatingCastToInt8 {};
-
-// This output stage takes int32 values that are expected to be already
-// in the [0..255] range and returns them casted to uint8.
-// This stage can save time if used instead of the
-// OutputStageSaturatingCastToUint8 stage immediately after the
-// OutputStageClamp stage.
-struct OutputStageTruncatingCastToUint8 {};
-
-// This output stage takes int32 values that are expected to be already
// on the final int16 scale, but not necessarily in the [-32768..32767] range.
// It clamps them to the [-32768..32767] range and returns them casted to int16.
struct OutputStageSaturatingCastToInt16 {};