| // Copyright 2015 Google Inc. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include "test.h"

#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "../public/gemmlowp.h"
#include "../internal/kernel_reference.h"
#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "test_data.h"
| |
| namespace gemmlowp { |
| |
// Slow but straightforward reference implementation of the eight-bit integer
// GEMM, used as the ground truth that the other code paths are compared to.
//
// For each i in [0, m) and j in [0, n):
//   total   = sum over l in [0, k) of
//             (a(i, l) + a_offset) * (b(l, j) + b_offset)
//   c(i, j) = clamp(((total + c_offset) * c_mult_int + rounding) >> c_shift,
//                   0, 255)
// where rounding is 2^(c_shift - 1) for c_shift > 0 and 0 for c_shift == 0.
//
// The transpose_* flags select, together with lda / ldb / ldc, the strides
// used to address the three buffers (see the stride selection below).
// c_shift must lie in [0, 32] and a, b, c must be non-null; both conditions
// are only checked with assert.
void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const uint8_t* a, int32_t a_offset, int lda,
                              const uint8_t* b, int32_t b_offset, int ldb,
                              uint8_t* c, int32_t c_offset, int32_t c_mult_int,
                              int32_t c_shift, int ldc) {
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  // Stride selection: a and b are addressed relative to the layout of c.
  const bool a_matches_c = (transpose_a == transpose_c);
  const int a_i_stride = a_matches_c ? 1 : lda;
  const int a_l_stride = a_matches_c ? lda : 1;

  const bool b_matches_c = (transpose_b == transpose_c);
  const int b_j_stride = b_matches_c ? ldb : 1;
  const int b_l_stride = b_matches_c ? 1 : ldb;

  const int c_i_stride = transpose_c ? ldc : 1;
  const int c_j_stride = transpose_c ? 1 : ldc;

  // Round-to-nearest term. There is nothing to round when c_shift == 0; the
  // previous `1 << (c_shift - 1)` was a shift by -1 (undefined behavior) in
  // that case. Kept unsigned so that c_shift == 32 (1u << 31) is valid too.
  const uint32_t rounding =
      (c_shift > 0) ? (uint32_t{1} << (c_shift - 1)) : uint32_t{0};

  for (int j = 0; j < n; j++) {
    for (int i = 0; i < m; i++) {
      int32_t total = 0;
      for (int l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const int32_t a_as_int = static_cast<int32_t>(a[a_index]) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const int32_t b_as_int = static_cast<int32_t>(b[b_index]) + b_offset;
        total += a_as_int * b_as_int;
      }

      // Offset, scale and round using arithmetic modulo 2^32, spelled out
      // with unsigned operands so that an overflowing intermediate wraps
      // instead of being undefined. The conversion back to int32_t
      // reinterprets the bits as two's complement.
      const uint32_t scaled_bits =
          (static_cast<uint32_t>(total) + static_cast<uint32_t>(c_offset)) *
              static_cast<uint32_t>(c_mult_int) +
          rounding;
      const int32_t scaled = static_cast<int32_t>(scaled_bits);

      // Shift in 64 bits so that a shift count of 32 is valid; for counts
      // below 32 this yields the same value as a 32-bit arithmetic shift.
      const int64_t shifted = static_cast<int64_t>(scaled) >> c_shift;

      // Saturate to the uint8_t range.
      const int64_t saturated =
          std::min<int64_t>(255, std::max<int64_t>(0, shifted));

      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<uint8_t>(saturated);
    }
  }
}
// The *GemmWrapper structs wrap the various Gemm entry points behind a uniform
// interface, so that the same testing code can be used to test all of them.
| |
| template <typename Kernel, typename Scalar, BitDepthSetting BitDepth> |
| struct SingleThreadGemmWrapper { |
| static const BitDepthSetting kBitDepthSetting = BitDepth; |
| |
| static const char* Name() { |
| static char buf[256]; |
| snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name()); |
| return buf; |
| } |
| |
| typedef SingleThreadGemmContext Context; |
| |
| template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> |
| static void Gemm(Context* context, |
| const MatrixMap<const Scalar, LhsOrder>& lhs, |
| const MatrixMap<const Scalar, RhsOrder>& rhs, |
| MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, |
| int rhs_offset, int result_offset, int result_mult_int, |
| int result_shift) { |
| SingleThreadGemm<typename Kernel::Format, Scalar, BitDepth, LhsOrder, |
| RhsOrder, ResultOrder>( |
| context, Kernel(), lhs, rhs, result, lhs_offset, rhs_offset, |
| result_offset, result_mult_int, result_shift); |
| } |
| }; |
| |
| template <typename Kernel, typename Scalar, BitDepthSetting BitDepth> |
| struct MultiThreadGemmWrapper { |
| static const BitDepthSetting kBitDepthSetting = BitDepth; |
| |
| static const char* Name() { |
| static char buf[256]; |
| snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name()); |
| return buf; |
| } |
| |
| typedef MultiThreadGemmContext Context; |
| |
| template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> |
| static void Gemm(Context* context, |
| const MatrixMap<const Scalar, LhsOrder>& lhs, |
| const MatrixMap<const Scalar, RhsOrder>& rhs, |
| MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, |
| int rhs_offset, int result_offset, int result_mult_int, |
| int result_shift) { |
| MultiThreadGemm<typename Kernel::Format, Scalar, BitDepth, LhsOrder, |
| RhsOrder, ResultOrder>( |
| context, Kernel(), lhs, rhs, result, lhs_offset, rhs_offset, |
| result_offset, result_mult_int, result_shift); |
| } |
| }; |
| |
| template <typename Scalar, BitDepthSetting BitDepth> |
| struct PublicGemmWrapper { |
| static const BitDepthSetting kBitDepthSetting = BitDepth; |
| |
| static const char* Name() { return "public Gemm"; } |
| |
| typedef GemmContext Context; |
| |
| template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> |
| static void Gemm(Context* context, |
| const MatrixMap<const Scalar, LhsOrder>& lhs, |
| const MatrixMap<const Scalar, RhsOrder>& rhs, |
| MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, |
| int rhs_offset, int result_offset, int result_mult_int, |
| int result_shift) { |
| gemmlowp::Gemm<uint8_t, BitDepth, LhsOrder, RhsOrder, ResultOrder>( |
| context, lhs, rhs, result, lhs_offset, rhs_offset, result_offset, |
| result_mult_int, result_shift); |
| } |
| }; |
| |
| template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth> |
| struct EightBitIntGemmWrapper { |
| static const eight_bit_int_gemm::BitDepthSetting kEBitDepthSetting = BitDepth; |
| static const BitDepthSetting kBitDepthSetting = |
| BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7 |
| ? BitDepthSetting::L7R5 |
| : BitDepthSetting::L8R8; |
| |
| static const char* Name() { return "EightBitIntGemm"; } |
| |
| typedef void Context; |
| |
| template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> |
| static void Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs, |
| const MatrixMap<const Scalar, RhsOrder>& rhs, |
| MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, |
| int rhs_offset, int result_offset, int result_mult_int, |
| int result_shift) { |
| const bool transpose_c = ResultOrder == MapOrder::ColMajor; |
| const bool transpose_a = |
| RhsOrder == MapOrder::RowMajor ? transpose_c : !transpose_c; |
| const bool transpose_b = |
| LhsOrder == MapOrder::RowMajor ? transpose_c : !transpose_c; |
| |
| eight_bit_int_gemm::EightBitIntGemm( |
| transpose_a, transpose_b, transpose_c, rhs.cols(), lhs.rows(), |
| lhs.cols(), rhs.data(), rhs_offset, rhs.stride(), lhs.data(), |
| lhs_offset, lhs.stride(), result->data(), result_offset, |
| result_mult_int, result_shift, result->stride(), kEBitDepthSetting); |
| } |
| }; |
| |
| template <typename Scalar> |
| struct ReferenceEightBitIntGemmWrapper { |
| static const BitDepthSetting kBitDepthSetting = BitDepthSetting::L8R8; |
| |
| static const char* Name() { return "ReferenceEightBitIntGemm"; } |
| |
| template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> |
| static void Gemm(bool transpose_a, bool transpose_b, bool transpose_c, |
| const MatrixMap<const Scalar, LhsOrder>& lhs, |
| const MatrixMap<const Scalar, RhsOrder>& rhs, |
| MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, |
| int rhs_offset, int result_offset, int result_mult_int, |
| int result_shift) { |
| ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, rhs.cols(), |
| lhs.rows(), lhs.cols(), rhs.data(), rhs_offset, |
| rhs.stride(), lhs.data(), lhs_offset, lhs.stride(), |
| result->data(), result_offset, result_mult_int, |
| result_shift, result->stride()); |
| } |
| }; |
| |
| const char* OrderName(MapOrder order) { |
| return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor"; |
| } |
| |
// Summary statistics comparing an actual result buffer against an expected
// one; filled in by GetResultStats().
struct ResultStats {
  ResultStats()
      : count(0),
        med_val(0),
        mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  int count;               // number of entries compared
  int med_val;             // median of the actual values
  float mean_signed_diff;  // mean of (actual - expected)
  int med_signed_diff;     // median of (actual - expected)
  int med_unsigned_diff;   // median of |actual - expected|
  int max_unsigned_diff;   // maximum of |actual - expected|

  // Histogram of |actual - expected|: slot 0 counts exact matches, slot e
  // (for e >= 1) counts differences in [2^(e-1), 2^e - 1].
  std::vector<int> count_diff_by_pot_slice;
};

// Computes the statistics of `actual` versus `expected`, both holding `count`
// entries, and stores them in *stats. Medians are the element at index
// count / 2 of the sorted values.
//
// When count is 0 every statistic is set to 0 and the histogram to nine
// zeros; the buffers are not read in that case. Previously an empty input
// indexed empty vectors and divided by zero.
void GetResultStats(const uint8_t* actual, const uint8_t* expected,
                    size_t count, ResultStats* stats) {
  // Size 9 for the 9 slices delimited by the powers of two 2^0, ..., 2^8.
  const int kNumPotSlices = 9;

  stats->count = static_cast<int>(count);
  stats->count_diff_by_pot_slice.assign(static_cast<size_t>(kNumPotSlices), 0);

  if (count == 0) {
    stats->med_val = 0;
    stats->mean_signed_diff = 0;
    stats->med_signed_diff = 0;
    stats->med_unsigned_diff = 0;
    stats->max_unsigned_diff = 0;
    return;
  }

  std::vector<uint8_t> results(actual, actual + count);
  std::vector<int16_t> signed_diffs;
  std::vector<uint8_t> unsigned_diffs;
  signed_diffs.reserve(count);
  unsigned_diffs.reserve(count);
  int64_t signed_diffs_sum = 0;
  for (size_t i = 0; i < count; i++) {
    const int16_t signed_diff = actual[i] - expected[i];
    signed_diffs.push_back(signed_diff);
    // The magnitude is at most 255, so it fits in a uint8_t.
    unsigned_diffs.push_back(static_cast<uint8_t>(std::abs(signed_diff)));
    signed_diffs_sum += signed_diff;
  }

  std::sort(results.begin(), results.end());
  std::sort(signed_diffs.begin(), signed_diffs.end());
  std::sort(unsigned_diffs.begin(), unsigned_diffs.end());

  const size_t middle = count / 2;

  stats->med_val = results[middle];
  stats->mean_signed_diff = float(signed_diffs_sum) / count;
  stats->med_signed_diff = signed_diffs[middle];
  stats->med_unsigned_diff = unsigned_diffs[middle];
  stats->max_unsigned_diff = unsigned_diffs.back();

  // Walk the sorted magnitudes once, cutting them at each power of two.
  auto slice_begin = unsigned_diffs.begin();
  size_t checksum = 0;
  for (int exponent = 0; exponent < kNumPotSlices; exponent++) {
    const int pot = 1 << exponent;
    const auto slice_end =
        std::lower_bound(slice_begin, unsigned_diffs.end(), pot);
    const int slice_count = static_cast<int>(slice_end - slice_begin);
    stats->count_diff_by_pot_slice[exponent] = slice_count;
    checksum += static_cast<size_t>(slice_count);
    slice_begin = slice_end;
  }
  // Every entry must have landed in exactly one slice.
  assert(checksum == count);
  (void)checksum;  // Only read by the assert above.
}
| |
// Upper bounds that CheckResultStatsBounds() applies to the fields of the
// same name in ResultStats. A default-constructed instance has every bound
// set to zero.
struct ResultStatsBounds {
  float mean_signed_diff = 0;
  int med_signed_diff = 0;
  int med_unsigned_diff = 0;
  int max_unsigned_diff = 0;
};
| |
| bool CheckResultStatsBounds(const ResultStats& stats, |
| const ResultStatsBounds& bounds) { |
| return stats.max_unsigned_diff <= bounds.max_unsigned_diff && |
| stats.med_unsigned_diff <= bounds.med_unsigned_diff && |
| std::abs(stats.med_signed_diff) <= bounds.med_signed_diff && |
| std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff; |
| } |
| |
| void ReportResultStats(const ResultStats& stats, |
| const ResultStatsBounds& bounds) { |
| printf(" number of matrix entries: %d\n", stats.count); |
| printf(" median value: %d\n", stats.med_val); |
| printf(" median unsigned diff: %d (tolerating %d)\n", |
| stats.med_unsigned_diff, bounds.med_unsigned_diff); |
| printf(" max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff, |
| bounds.max_unsigned_diff); |
| printf(" median signed diff: %d (tolerating %d)\n", stats.med_signed_diff, |
| bounds.med_signed_diff); |
| printf(" mean signed diff: %.3g (tolerating %.3g)\n", |
| stats.mean_signed_diff, bounds.mean_signed_diff); |
| |
| printf("No error: %.2f %% of entries\n", |
| 100.f * stats.count_diff_by_pot_slice[0] / stats.count); |
| for (int exponent = 1; exponent < 9; exponent++) { |
| printf("Error in %d..%d range: %.2f %% of entries\n", |
| 1 << (exponent - 1), |
| (1 << exponent) - 1, |
| 100.f * stats.count_diff_by_pot_slice[exponent] / stats.count); |
| } |
| } |
| |
| // Our approach to choosing result_shift values for testing, is bisection. |
| // This function takes an interval, [result_shift_min .. result_shift_max]. |
| // If too much saturation occurred in either direction, it bisects accordingly, |
| // recursing until the interval contains only one value. |
| // The primary reason why we prefer this over computing optimal shift values, |
| // is that we actually want to exercise some saturation, as there is nontrivial |
| // code handling that in gemmlowp. |
| // Secondarily, this is faster than computing optimal shifts, since in 90% of |
| // cases the first-tried shift value 16 turns out to be good enough. |
| template <typename GemmWrapper, typename LhsType, typename RhsType, |
| typename ResultType> |
| void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs, |
| const RhsType& rhs, ResultType* result, int lhs_offset, |
| int rhs_offset, int result_offset, int result_mult_int, |
| int result_shift_min, int result_shift_max) { |
| const int rows = lhs.rows(); |
| const int cols = rhs.cols(); |
| Check(lhs.cols() == rhs.rows()); |
| const int depth = lhs.cols(); |
| |
| const int result_shift = (result_shift_min + result_shift_max) / 2; |
| |
| GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(), &result->map(), |
| lhs_offset, rhs_offset, result_offset, result_mult_int, |
| result_shift); |
| |
| typedef typename ResultType::Scalar Scalar; |
| static const MapOrder kLhsOrder = LhsType::kOrder; |
| static const MapOrder kRhsOrder = RhsType::kOrder; |
| static const MapOrder kResultOrder = ResultType::kOrder; |
| ResultType ref_result(rows, cols); |
| const bool transpose_c = kResultOrder == MapOrder::ColMajor; |
| const bool transpose_a = |
| kRhsOrder == MapOrder::RowMajor ? transpose_c : !transpose_c; |
| const bool transpose_b = |
| kLhsOrder == MapOrder::RowMajor ? transpose_c : !transpose_c; |
| ReferenceEightBitIntGemmWrapper<Scalar>::Gemm( |
| transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(), |
| &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int, |
| result_shift); |
| |
| static const BitDepthSetting BitDepth = GemmWrapper::kBitDepthSetting; |
| |
| ResultStats stats; |
| GetResultStats(result->data(), ref_result.data(), rows * cols, &stats); |
| |
| // Adjust shifts until we get meaningful results |
| int new_result_shift_min = result_shift_min; |
| int new_result_shift_max = result_shift_max; |
| bool retry = false; |
| |
| if (stats.med_val < 32) { |
| new_result_shift_max = (result_shift_min + result_shift_max) / 2; |
| retry = true; |
| } |
| |
| if (stats.med_val > 224) { |
| new_result_shift_min = (result_shift_min + result_shift_max) / 2; |
| retry = true; |
| } |
| |
| if (retry) { |
| if (result_shift_min != result_shift_max) { |
| test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, |
| rhs_offset, result_offset, result_mult_int, |
| new_result_shift_min, new_result_shift_max); |
| } |
| return; |
| } |
| |
| ResultStatsBounds bounds; |
| |
| if (BitDepth == BitDepthSetting::L7R5) { |
| // We have very lax requirements on unsigned diff. |
| // We have tighter requirements on signed diff (bias), but only |
| // if the matrix is large enough for things to average out. |
| // For very small sizes, we... basically don't test anything. |
| // The problem is that this test uses unrealistic combinations of |
| // result_mult_int |
| // and result_shift, resulting in potentially wild requantization artifacts |
| // on small GEMMs. |
| int adjust_for_small_sizes = 1000 / (rows * cols); |
| bounds.max_unsigned_diff = |
| std::max(stats.med_val / 2, adjust_for_small_sizes); |
| bounds.med_unsigned_diff = |
| std::max(stats.med_val / 8, adjust_for_small_sizes); |
| bounds.med_signed_diff = std::max(2, adjust_for_small_sizes); |
| bounds.mean_signed_diff = std::max(2, adjust_for_small_sizes); |
| } |
| |
| // Check results |
| const bool good = CheckResultStatsBounds(stats, bounds); |
| |
| printf( |
| "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n", |
| good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder), |
| OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(), |
| lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift); |
| |
| if (!good) { |
| ReportResultStats(stats, bounds); |
| |
| int bad_coeffs_printed = 0; |
| for (int c = 0; c < result->cols() && bad_coeffs_printed < 20; c++) { |
| for (int r = 0; r < result->rows() && bad_coeffs_printed < 20; r++) { |
| if (ref_result(r, c) != (*result)(r, c)) { |
| printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c, |
| ref_result(r, c), (*result)(r, c)); |
| bad_coeffs_printed++; |
| } |
| } |
| } |
| } |
| |
| Check(good); |
| } |
| |
| template <typename GemmWrapper, typename LhsType, typename RhsType, |
| typename ResultType> |
| void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs, |
| const RhsType& rhs, ResultType* result, int lhs_offset, |
| int rhs_offset, int result_offset, int result_mult_int) { |
| test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset, |
| result_offset, result_mult_int, 0, 32); |
| } |
| |
| enum class WhatParamsToTest { |
| All, |
| OnlyGenericCase, |
| }; |
| |
| template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder, |
| MapOrder ResultOrder> |
| void test_gemm(typename GemmWrapper::Context* context, int rows, int depth, |
| int cols, WhatParamsToTest params_to_test) { |
| typedef std::uint8_t Scalar; |
| typedef Matrix<Scalar, LhsOrder> LhsType; |
| LhsType lhs(rows, depth); |
| MakeRandom(&lhs, 8); |
| typedef Matrix<Scalar, RhsOrder> RhsType; |
| RhsType rhs(depth, cols); |
| MakeRandom(&rhs, 8); |
| typedef Matrix<Scalar, ResultOrder> ResultType; |
| ResultType result(rows, cols); |
| MakeZero(&result); |
| |
| if (params_to_test == WhatParamsToTest::All) { |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1); |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1); |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1); |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1); |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10); |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10); |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4); |
| } |
| test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123); |
| } |
| |
| enum class WhatOrdersToTest { All, OnlyRCC }; |
| |
| template <typename GemmWrapper> |
| void test_gemm(typename GemmWrapper::Context* context, int rows, int depth, |
| int cols, WhatParamsToTest params_to_test, |
| WhatOrdersToTest orders_to_test) { |
| #define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder) \ |
| do { \ |
| test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \ |
| MapOrder::ResultOrder>(context, rows, depth, cols, \ |
| params_to_test); \ |
| } while (false) |
| |
| if (orders_to_test == WhatOrdersToTest::All) { |
| GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor); |
| GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor); |
| GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor); |
| GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor); |
| |
| GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor); |
| GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor); |
| GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor); |
| GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor); |
| } else { |
| GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor); |
| } |
| |
| #undef GEMMLOWP_ONE_TEST |
| } |
| |
| template <typename Kernel> |
| void test_gemm_kernel(MultiThreadGemmContext* context) { |
| typedef MultiThreadGemmWrapper<Kernel, std::uint8_t, BitDepthSetting::L8R8> |
| GemmWrapper; |
| test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 200, 200, 200, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::All); |
| test_gemm<GemmWrapper>(context, 50, 5000, 50, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| } |
| |
| template <typename GemmWrapper> |
| void test_gemm(typename GemmWrapper::Context* context) { |
| test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| |
| test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| |
| test_gemm<GemmWrapper>(context, 512, 512, 512, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1024, 1024, 1024, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 567, 2345, 123, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 100, 5000, 100, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1, 1000, 1000, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1000, 1, 1000, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 1000, 1000, 1, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 777, 3456, 1, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 4567, 555, 1, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::OnlyRCC); |
| |
| // Test all storage orders |
| test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All, |
| WhatOrdersToTest::All); |
| test_gemm<GemmWrapper>(context, 300, 400, 500, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::All); |
| } |
| |
| template <typename GemmWrapper> |
| void test_gemv(typename GemmWrapper::Context* context) { |
| test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::OnlyRCC); |
| |
| // Test all storage orders |
| test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All, |
| WhatOrdersToTest::All); |
| test_gemm<GemmWrapper>(context, 300, 400, 1, |
| WhatParamsToTest::OnlyGenericCase, |
| WhatOrdersToTest::All); |
| } |
| |
| const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) { |
| switch (b) { |
| case eight_bit_int_gemm::BitDepthSetting::A8B8: |
| return "Lhs: 8 bit, Rhs: 8 bit"; |
| case eight_bit_int_gemm::BitDepthSetting::A5B7: |
| return "Lhs: 7 bit, Rhs: 5 bit"; |
| default: |
| abort(); |
| return nullptr; |
| } |
| } |
| |
| // This is the most realistic test of how we'll be using the low-precision GEMM |
| // function in applications. It takes in large input matrices that have been |
| // captured from an actual neural network run. |
| void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth, |
| int tolerance_median, int tolerance_max) { |
| std::unique_ptr<uint8_t[]> output_data(new uint8_t[test_data::c_count]); |
| gemmlowp::eight_bit_int_gemm::EightBitIntGemm( |
| test_data::is_a_transposed, test_data::is_b_transposed, |
| test_data::is_c_transposed, test_data::m, test_data::n, test_data::k, |
| test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data, |
| test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset, |
| test_data::c_mult_int, test_data::c_shift, test_data::n, BitDepth); |
| |
| ResultStats stats; |
| GetResultStats(output_data.get(), test_data::expected_c_data, |
| test_data::c_count, &stats); |
| |
| ResultStatsBounds bounds; |
| if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) { |
| bounds.med_unsigned_diff = tolerance_median; |
| bounds.max_unsigned_diff = tolerance_max; |
| bounds.med_signed_diff = 0; |
| bounds.mean_signed_diff = 0.2f; |
| } |
| |
| const bool good = CheckResultStatsBounds(stats, bounds); |
| printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL", |
| GetBitDepthName(BitDepth)); |
| ReportResultStats(stats, bounds); |
| Check(good); |
| } |
| |
| void test() { |
| #ifdef GEMMLOWP_TEST_PROFILE |
| RegisterCurrentThreadForProfiling(); |
| StartProfiling(); |
| #endif |
| |
| GemmContext context; |
| |
| // Test the internal GEMM interfaces |
| test_gemm<SingleThreadGemmWrapper<DefaultKernelForGemm<BitDepthSetting::L8R8>, |
| std::uint8_t, BitDepthSetting::L8R8>>( |
| &context); |
| |
| test_gemm<MultiThreadGemmWrapper<DefaultKernelForGemm<BitDepthSetting::L8R8>, |
| std::uint8_t, BitDepthSetting::L8R8>>( |
| &context); |
| |
| // Test the public GEMM interfaces |
| test_gemm<PublicGemmWrapper<uint8_t, BitDepthSetting::L8R8>>(&context); |
| |
| test_gemm<EightBitIntGemmWrapper<uint8_t, |
| eight_bit_int_gemm::BitDepthSetting::A8B8>>( |
| &context); |
| |
| // Test GEMV cases (internal interfaces) |
| test_gemv<SingleThreadGemmWrapper<DefaultKernelForGemv<BitDepthSetting::L8R8>, |
| std::uint8_t, BitDepthSetting::L8R8>>( |
| &context); |
| |
| test_gemv<MultiThreadGemmWrapper<DefaultKernelForGemv<BitDepthSetting::L8R8>, |
| std::uint8_t, BitDepthSetting::L8R8>>( |
| &context); |
| |
| // Test GEMV cases (public interfaces) |
| test_gemv<PublicGemmWrapper<uint8_t, BitDepthSetting::L8R8>>(&context); |
| |
| test_gemv<EightBitIntGemmWrapper<uint8_t, |
| eight_bit_int_gemm::BitDepthSetting::A8B8>>( |
| &context); |
| |
| // Test other bit depths |
| // L7R5 |
| for (int foo = 0; foo < 4; foo++) { |
| test_gemm< |
| SingleThreadGemmWrapper<DefaultKernelForGemm<BitDepthSetting::L7R5>, |
| std::uint8_t, BitDepthSetting::L7R5>>(&context); |
| |
| test_gemv< |
| SingleThreadGemmWrapper<DefaultKernelForGemv<BitDepthSetting::L7R5>, |
| std::uint8_t, BitDepthSetting::L7R5>>(&context); |
| |
| test_gemm<EightBitIntGemmWrapper< |
| std::uint8_t, eight_bit_int_gemm::BitDepthSetting::A5B7>>(&context); |
| } |
| |
| // Test specific kernels with various different formats, |
| // to exercises corner cases especially in the packing code. |
| test_gemm_kernel< |
| ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>, |
| KernelSideFormat<CellFormat<1, 1>, 1>>>>( |
| &context); |
| |
| test_gemm_kernel< |
| ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>, |
| KernelSideFormat<CellFormat<4, 2>, 2>>>>( |
| &context); |
| |
| test_gemm_kernel< |
| ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>, |
| KernelSideFormat<CellFormat<4, 2>, 5>>>>( |
| &context); |
| |
| test_gemm_kernel<ReferenceKernel<KernelFormat< |
| KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>, |
| KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context); |
| |
| test_gemm_kernel<ReferenceKernel<KernelFormat< |
| KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>, |
| KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context); |
| |
| test_gemm_kernel<ReferenceKernel<KernelFormat< |
| KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>, |
| KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context); |
| |
| test_gemm_kernel<ReferenceKernel<KernelFormat< |
| KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>, |
| KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context); |
| |
| test_gemm_kernel<ReferenceKernel<KernelFormat< |
| KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>, |
| KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context); |
| |
| test_gemm_kernel<ReferenceKernel<KernelFormat< |
| KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>, |
| KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context); |
| |
| // Run against actual data from a network evaluation. |
| TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0); |
| TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10); |
| |
| #ifdef GEMMLOWP_TEST_PROFILE |
| FinishProfiling(); |
| #endif |
| |
| std::cerr << "All tests passed." << std::endl; |
| |
| // We have been testing the eight_bit_int_gemm, so we should free its |
| // persistent |
| // resources now to avoid having leak-checking tools report leaks. |
| eight_bit_int_gemm::FreePersistentResources(); |
| } |
| |
| } // end namespace gemmlowp |
| |
| int main() { gemmlowp::test(); } |