| // Copyright 2015 Google Inc. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
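// The EightBitIntGemm interface is meant to be usable on any platform, so
// allow gemmlowp's slow scalar fallback where no optimized SIMD path exists.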
| #ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK |
| #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK |
| #endif |
| #include "eight_bit_int_gemm.h" |
| |
#include <cassert>
#include <cstdint>
#include <memory>
| |
// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility=hidden (and -fvisibility-inlines-hidden).
// TODO: it would be safer to hardcode it here with a #pragma.
| #include "../public/gemmlowp.h" |
| |
// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the generated binary is quite large
// (approx. 200 KB), which might be prohibitive in low-memory situations.
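//
// For example (illustrative; the exact mechanism depends on the build
// system), a build might opt in by defining
//   -DGEMMLOWP_USE_META_FASTPATH
// on a NEON-enabled target. As the conditional below shows, the meta kernels
// are only compiled in when GEMMLOWP_NEON is also defined.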
| |
| #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) |
| #include "../meta/legacy_multi_thread_gemm.h" |
| #else |
| |
| #if defined(GEMMLOWP_USE_META_FASTPATH) |
| #warning "META fast path turned on without NEON!" |
| #endif |
| |
| #endif |
| |
| namespace gemmlowp { |
| namespace eight_bit_int_gemm { |
| namespace { |
| |
// The EightBitIntGemm entry points below are serialized by taking a
// ScopedLock on GlobalMutexes::EightBitIntGemm(), which protects
// EightBitIntGemm's global state.

// Global state: one lazily created global GemmContext instance.
GemmContext* global_context;
| |
| GemmContext* GetOrCreateGlobalContext() { |
| if (!global_context) { |
| global_context = new GemmContext; |
| } |
| return global_context; |
| } |
| |
| void DestroyGlobalContext() { |
| delete global_context; |
| global_context = nullptr; |
| } |
| |
| template <bool transpose_a, bool transpose_b, bool transpose_c> |
| void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k, |
| const std::uint8_t* a, std::int32_t a_offset, int lda, |
| const std::uint8_t* b, std::int32_t b_offset, int ldb, |
| std::uint8_t* c, std::int32_t c_offset, |
| std::int32_t c_mult_int, std::int32_t c_shift, int ldc, |
| BitDepthSetting bit_depth) { |
| const int lhs_offset = a_offset; |
| const int rhs_offset = b_offset; |
| const int result_offset = c_offset; |
| const int result_mult_int = c_mult_int; |
| const int result_shift = c_shift; |
| |
| static const MapOrder ResultOrder = |
| transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; |
| static const MapOrder LhsOrder = |
| transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor; |
| static const MapOrder RhsOrder = |
| transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor; |
| |
| MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda); |
| MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb); |
| MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc); |
| |
| switch (bit_depth) { |
| #define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \ |
| case BitDepthSetting::BIT_DEPTH_SETTING: \ |
| Gemm<std::uint8_t, BIT_DEPTH_PARAMS>( \ |
| context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \ |
| result_mult_int, result_shift); \ |
| return; |
| GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams) |
| GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams) |
| default: |
| abort(); |
| #undef GEMMLOWP_HANDLE_BIT_DEPTH |
| } |
| } |
| |
| template <bool transpose_a, bool transpose_b, bool transpose_c> |
| void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k, |
| const std::uint8_t* a, std::int32_t a_offset, |
| int lda, const std::uint8_t* b, |
| std::int32_t b_offset, int ldb, std::int32_t* c, |
| int ldc, BitDepthSetting bit_depth) { |
| const int lhs_offset = a_offset; |
| const int rhs_offset = b_offset; |
| |
| static const MapOrder ResultOrder = |
| transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; |
| static const MapOrder LhsOrder = |
| transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor; |
| static const MapOrder RhsOrder = |
| transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor; |
| |
| MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda); |
| MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb); |
| MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc); |
| |
| auto empty_pipeline = std::make_tuple(); |
| |
| switch (bit_depth) { |
| #define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \ |
| case BitDepthSetting::BIT_DEPTH_SETTING: \ |
| GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>( \ |
| context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \ |
| return; |
| GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams) |
| GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams) |
| default: |
| abort(); |
| #undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32 |
| } |
| } |
| |
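// A lazily grown scratch buffer. The pointer returned by buffer() is aligned
// to a 32-byte boundary.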
| class Scratch { |
| public: |
| Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {} |
| |
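  // Grows the buffer to hold at least required_size bytes. Note that the
  // previous contents are NOT preserved when a reallocation takes place.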
| void AssureSize(std::int32_t required_size) { |
| if (size_ >= required_size) { |
| return; |
| } |
| buffer_.reset(new std::uint8_t[required_size + 32]); |
    buffer_32_ =
        buffer_.get() +
        ((32 - (reinterpret_cast<std::uintptr_t>(buffer_.get()) % 32)) % 32);
    assert((reinterpret_cast<std::uintptr_t>(buffer_32_) % 32) == 0);
| size_ = required_size; |
| } |
| |
| void Clear() { |
| buffer_.reset(nullptr); |
| buffer_32_ = nullptr; |
| size_ = 0; |
| } |
| |
| std::uint8_t* buffer() { return buffer_32_; } |
| |
| private: |
| std::unique_ptr<std::uint8_t[]> buffer_; |
| std::uint8_t* buffer_32_; |
| std::int32_t size_; |
| }; |
| |
| Scratch* global_scratch = nullptr; |
| |
| Scratch* GetOrCreateGlobalScratch() { |
| if (global_scratch == nullptr) { |
| global_scratch = new Scratch(); |
| } |
| return global_scratch; |
| } |
| |
| void DestroyGlobalScratch() { |
| delete global_scratch; |
| global_scratch = nullptr; |
| } |
| |
| #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) |
| |
| bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) { |
| // Is it row major and nicely packed? |
| if (transpose && stride == cols) { |
| return true; |
| } |
| |
  // Is it a one-row vector? (A vector is both row major and column major.)
| if (rows == 1) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
| bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) { |
| // Is it column major and nicely packed? |
| if (!transpose && stride == rows) { |
| return true; |
| } |
| |
  // Is it a one-column vector? (A vector is both row major and column major.)
| if (cols == 1) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
| bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c, |
| int m, int n, int k, int lda, int ldb, int ldc, |
| BitDepthSetting depth_setting) { |
  // The meta fastpath only supports 8-bit x 8-bit operands, with k between
  // 8 and 2048.
| if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) { |
| return false; |
| } |
| |
| // The first operand needs to be a row major matrix or a vector. |
| if (!IsRowMajorOrVector(transpose_a, lda, m, k)) { |
| return false; |
| } |
| |
| // The second operand needs to be a column major matrix or a vector. |
| if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) { |
| return false; |
| } |
| |
  // The result can be a row major matrix, a column major matrix, or a
  // vector.
| if (IsRowMajorOrVector(transpose_c, ldc, m, n)) { |
| return true; |
| } |
| |
| if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
// Ensures enough scratch memory is allocated, then runs the fastpath GEMM.
| void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs, |
| const std::uint8_t* rhs, int m, int n, int k, |
| std::int32_t lhs_offset, std::int32_t rhs_offset, |
| std::int32_t sum_offset, |
| std::int32_t multiplicative_offset, |
| std::int32_t shift, bool result_transpose, |
| std::int32_t result_stride, std::uint8_t* result) { |
| Scratch* scratch = GetOrCreateGlobalScratch(); |
| const std::int32_t max_num_threads = context->max_num_threads(); |
| if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { |
| scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads)); |
| meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, |
| scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, |
| rhs_offset, sum_offset, multiplicative_offset, |
| shift, result); |
| } else { |
| scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads)); |
| meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, |
| scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, |
| lhs_offset, sum_offset, multiplicative_offset, |
| shift, result); |
| } |
| } |
| |
// Ensures enough scratch memory is allocated, then runs the 8-bit-to-float
// fastpath GEMM.
| void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs, |
| const std::uint8_t* rhs, int m, int n, int k, |
| std::int32_t lhs_offset, std::int32_t rhs_offset, |
| float result_offset, bool result_transpose, |
| std::int32_t result_stride, float* result) { |
| Scratch* scratch = GetOrCreateGlobalScratch(); |
| const std::int32_t max_num_threads = context->max_num_threads(); |
| if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { |
| scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads)); |
| meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, |
| scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, |
| rhs_offset, result_offset, result); |
| } else { |
| scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads)); |
| meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, |
| scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, |
| lhs_offset, result_offset, result); |
| } |
| } |
| |
| #endif |
| |
} // namespace
| |
| // Public interface entry points |
| |
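// A minimal usage sketch (hypothetical values, shown purely for
// illustration): compute a 1x1 result from a 1x4 row-major LHS and a 4x1
// column-major RHS at 8-bit depth, with no offsets and no rescaling.
//
//   const std::uint8_t a[] = {1, 2, 3, 4};  // LHS (transpose_a = true)
//   const std::uint8_t b[] = {5, 6, 7, 8};  // RHS (transpose_b = false)
//   std::uint8_t c[1];
//   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
//       true, false, false, /*m=*/1, /*n=*/1, /*k=*/4, a, /*a_offset=*/0,
//       /*lda=*/4, b, /*b_offset=*/0, /*ldb=*/4, c, /*c_offset=*/0,
//       /*c_mult_int=*/1, /*c_shift=*/0, /*ldc=*/1,
//       gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
//   // c[0] is now 1*5 + 2*6 + 3*7 + 4*8 = 70.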
| void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, |
| int m, int n, int k, const std::uint8_t* a, |
| std::int32_t a_offset, int lda, const std::uint8_t* b, |
| std::int32_t b_offset, int ldb, std::uint8_t* c, |
| std::int32_t c_offset, std::int32_t c_mult_int, |
| std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) { |
| ScopedLock sl(GlobalMutexes::EightBitIntGemm()); |
| GemmContext* context = GetOrCreateGlobalContext(); |
| |
| #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) |
| if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, |
| ldb, ldc, bit_depth)) { |
| MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset, |
| c_mult_int, c_shift, transpose_c, ldc, c); |
| return; |
| } |
| #endif |
| |
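// Dispatch over the 8 combinations of the three transposition flags, so that
// each combination instantiates EightBitIntGemmImpl with the storage orders
// known at compile time.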
| #define GEMMLOWP_HANDLE_CASE(ta, tb, tc) \ |
| if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \ |
| EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b, \ |
| b_offset, ldb, c, c_offset, c_mult_int, \ |
| c_shift, ldc, bit_depth); \ |
| } |
| |
| GEMMLOWP_HANDLE_CASE(false, false, false) |
| GEMMLOWP_HANDLE_CASE(false, false, true) |
| GEMMLOWP_HANDLE_CASE(false, true, false) |
| GEMMLOWP_HANDLE_CASE(false, true, true) |
| GEMMLOWP_HANDLE_CASE(true, false, false) |
| GEMMLOWP_HANDLE_CASE(true, false, true) |
| GEMMLOWP_HANDLE_CASE(true, true, false) |
| GEMMLOWP_HANDLE_CASE(true, true, true) |
| |
| #undef GEMMLOWP_HANDLE_CASE |
| } |
| |
| void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, |
| int m, int n, int k, const std::uint8_t* a, |
| std::int32_t a_offset, std::int32_t lda, |
| const std::uint8_t* b, std::int32_t b_offset, |
| std::int32_t ldb, float* c, float c_offset, |
| std::int32_t ldc, BitDepthSetting bit_depth) { |
| ScopedLock sl(GlobalMutexes::EightBitIntGemm()); |
| GemmContext* context = GetOrCreateGlobalContext(); |
| |
| #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) |
| if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, |
| ldb, ldc, bit_depth)) { |
| MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset, |
| transpose_c, ldc, c); |
| return; |
| } |
| #endif |
| |
| // TODO(maciekc): implement a float output stage, get rid of scratch memory. |
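  // Note: despite its name, c_offset is applied below as a multiplicative
  // scale on the int32 accumulators, not as an additive offset.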
| Scratch* scratch = GetOrCreateGlobalScratch(); |
| if (transpose_c) { |
| scratch->AssureSize(m * ldc * sizeof(std::int32_t)); |
| } else { |
| scratch->AssureSize(n * ldc * sizeof(std::int32_t)); |
| } |
| std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer()); |
| |
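// Same compile-time dispatch over transposition flags as in the 8-bit entry
// point above, but accumulating into the int32 scratch buffer.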
| #define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc) \ |
| if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \ |
| EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \ |
| b, b_offset, ldb, temp_c, ldc, \ |
| bit_depth); \ |
| } |
| |
| GEMMLOWP_HANDLE_INT32_CASE(false, false, false) |
| GEMMLOWP_HANDLE_INT32_CASE(false, false, true) |
| GEMMLOWP_HANDLE_INT32_CASE(false, true, false) |
| GEMMLOWP_HANDLE_INT32_CASE(false, true, true) |
| GEMMLOWP_HANDLE_INT32_CASE(true, false, false) |
| GEMMLOWP_HANDLE_INT32_CASE(true, false, true) |
| GEMMLOWP_HANDLE_INT32_CASE(true, true, false) |
| GEMMLOWP_HANDLE_INT32_CASE(true, true, true) |
| |
| #undef GEMMLOWP_HANDLE_INT32_CASE |
| |
| if (transpose_c) { |
| // Row major. |
| for (int i = 0; i < m; ++i) { |
| float* dest_row = c + i * ldc; |
| std::int32_t* src_row = temp_c + i * ldc; |
| for (int j = 0; j < n; ++j) { |
| dest_row[j] = static_cast<float>(src_row[j]) * c_offset; |
| } |
| } |
| } else { |
| // Column major. |
| for (int i = 0; i < n; ++i) { |
| float* dest_column = c + i * ldc; |
| std::int32_t* src_column = temp_c + i * ldc; |
| for (int j = 0; j < m; ++j) { |
| dest_column[j] = static_cast<float>(src_column[j]) * c_offset; |
| } |
| } |
| } |
| } |
| |
| void SetMaxNumThreads(int n) { |
| ScopedLock sl(GlobalMutexes::EightBitIntGemm()); |
| GemmContext* context = GetOrCreateGlobalContext(); |
| context->set_max_num_threads(n); |
| } |
| |
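// Note: calling FreePersistentResources is optional. Subsequent calls into
// this interface lazily recreate the global context and scratch buffer.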
| void FreePersistentResources() { |
| ScopedLock sl(GlobalMutexes::EightBitIntGemm()); |
| DestroyGlobalContext(); |
| DestroyGlobalScratch(); |
| } |
| |
| } // namespace eight_bit_int_gemm |
| } // namespace gemmlowp |