| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // See docs in ../ops/math_ops.cc. |
| |
| #define EIGEN_USE_THREADS |
| |
| #include "tensorflow/core/kernels/sparse_matmul_op.h" |
| |
| #include <map> |
| #include <memory> |
| #include <vector> |
| |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/core/common_runtime/device.h" |
| #include "tensorflow/core/framework/bfloat16.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/kernels/fill_functor.h" |
| #include "tensorflow/core/lib/core/blocking_counter.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/macros.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/thread_annotations.h" |
| #include "tensorflow/core/platform/types.h" |
| #ifdef TENSORFLOW_USE_LIBXSMM |
| #include "include/libxsmm_intrinsics_x86.h" |
| #include "include/libxsmm_malloc.h" |
| #include "include/libxsmm_spmdm.h" |
| #endif |
| |
| #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) |
| #include "tensorflow/core/kernels/eigen_contraction_kernel.h" |
| #endif |
| |
| #define ALWAYS_INLINE EIGEN_ALWAYS_INLINE |
| |
| namespace tensorflow { |
| namespace { |
| |
| template <typename T> |
| using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>; |
| |
| template <typename T> |
| using BasicMatrixMap = |
| Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>; |
| |
| using Matrix = BasicMatrix<float>; |
| using MatrixMap = BasicMatrixMap<float>; |
| using CPUDevice = Eigen::ThreadPoolDevice; |
| using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>; |
| |
| // Two commonly used static dsizes. We use Eigen::type2index to allow as much |
| // compile-time optimization as possible. |
| #ifdef EIGEN_HAS_INDEX_LIST |
| inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>> |
| dsizes_00() { |
| return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>(); |
| } |
| inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>> |
| dsizes_10() { |
| return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>(); |
| } |
| #else |
| inline DSizes dsizes_00() { return DSizes(0, 0); } |
| inline DSizes dsizes_10() { return DSizes(1, 0); } |
| #endif |
| |
| // Block sizes. |
| // TODO(agarwal): compute these sizes based on cache sizes. |
| // M is the number of rows per sparse slice of the left matrix, K is the |
| // column block size used inside each slice (and the number of rows of |
| // "right" consumed per block), and N is the number of columns per dense |
| // slice of the right matrix. |
| const int K = 64; |
| const int M = 64; |
| const int N = 128; |
| |
| // This stores a sparse representation of a slice of a matrix with size |
| // (num_rows, num_cols). The slice is represented as a series of blocks of size |
| // (num_rows, b), where b = block_size for all but the last block, which may |
| // have fewer columns. |
| // |
| // num_rows and block_size are assumed to be <= 256. This allows storing the |
| // row and column indices as uint8. |
| // |
| // For each block, we store all the non-zero entries in the data/data3 vectors |
| // and the coordinates of each element in the index/index3 vectors. The index3 |
| // vector stores the indices of 3 elements in the same row, so that these |
| // elements can share a single row coordinate. Each entry in index3 corresponds |
| // to 3 entries in data3. |
| // |
| // Note that the data/indices of all the blocks are stored in the same vectors |
| // respectively. To identify block boundaries, we store the block offsets using |
| // index3_offset/index_offset. If there are n blocks in the slice, index3_offset |
| // and index_offset have n entries each. The Index3 entries for the ith block |
| // occupy positions [index3_offset[i-1], index3_offset[i]) of index3, with |
| // index3_offset[-1] taken to be 0. Similarly for index_offset. |
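| // |
| // As a worked example (values chosen purely for illustration), consider one |
| // block with num_rows = 2 and 4 columns: |
| // |
| //   row 0: [1 0 2 3] |
| //   row 1: [4 5 0 0] |
| // |
| // Row 0 has three non-zeros, so it contributes one Index3 entry |
| // {m=0, k1=0, k2=2, k3=3} with data3 = [1, 2, 3]. Row 1 has only two |
| // non-zeros, which go into index/data as {m=1, k=1} and {m=1, k=0} with |
| // data = [5, 4]; as Initialize() below shows, a leftover pair is pushed |
| // second-value-first. |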
| template <typename T> |
| struct SparseSlice { |
| using ConstMatrixMap = BasicMatrixMap<const T>; |
| |
| public: |
| // Indices of three elements on the same row. |
| struct Index3 { |
| uint8 m; // row |
| // columns |
| uint8 k1; |
| uint8 k2; |
| uint8 k3; |
| }; |
| |
| // Index of one element. |
| struct Index { |
| uint8 m; |
| uint8 k; |
| }; |
| |
| SparseSlice(int nrows, int ncols, int bsize) |
| : num_rows(nrows), num_cols(ncols), block_size(bsize) { |
| DCHECK_LE(nrows, 256); |
| DCHECK_LE(block_size, 256); |
| } |
| |
| // Initializes the slice with data starting at mat(0, col_offset) and with |
| // size (num_rows, num_cols). |
| // If Transpose is true, implicitly transposes mat. |
| template <bool Transpose = false> |
| void Initialize(const ConstMatrixMap& mat, int col_offset); |
| |
| void Clear(); |
| |
| // See comments above. |
| std::vector<int> index3_offset; |
| std::vector<Index3> index3; |
| std::vector<T> data3; |
| |
| // See comments above. Similar to "index3" except that each element in "index" |
| // corresponds to one element in data. |
| std::vector<int> index_offset; |
| std::vector<Index> index; |
| std::vector<T> data; |
| |
| // Number of rows and columns for the slice. |
| const int num_rows; |
| const int num_cols; |
| |
| // Block size used to initialize from a matrix. |
| const int block_size; |
| }; |
| |
| template <typename T> |
| bool IsZero(T v); |
| |
| template <> |
| ALWAYS_INLINE bool IsZero(bfloat16 v) { |
| return !static_cast<bool>(v); |
| } |
| |
| template <> |
| ALWAYS_INLINE bool IsZero(float v) { |
| return v == 0.0f; |
| } |
| |
| template <typename T> |
| template <bool Transpose> |
| void SparseSlice<T>::Initialize( |
| const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) { |
| const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0); |
| const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1); |
| DCHECK_LE(num_rows, mat_rows); |
| DCHECK_LE(num_cols + col_offset, mat_cols); |
| |
| int num_blocks = (num_cols + block_size - 1) / block_size; |
| int mat_size = num_rows * num_cols; |
| |
| index3_offset.reserve(num_blocks); |
| data3.reserve(mat_size); |
| index3.reserve(mat_size / 3); |
| |
| index_offset.reserve(num_blocks); |
| data.reserve(num_blocks * num_rows * 2); |
| index.reserve(num_blocks * num_rows * 2); |
| |
| Index3 idx3; |
| const int stride = Transpose ? mat.dimension(1) : 1; |
| |
| for (int i = 0; i < num_blocks; ++i) { |
| int num_block_cols = std::min(block_size, num_cols - block_size * i); |
| for (int row = 0; row < num_rows; ++row) { |
| idx3.m = static_cast<uint8>(row); |
| // Safety note: The following code has a race, since it checks whether |
| // *curr is nonzero and then reads it again on use. However, the result |
| // of the race is only that some of the "nonzeros" in the resulting sparse |
| // representation may actually be zero, which is harmless. |
| const auto* start = |
| Transpose ? &mat(col_offset, row) : &mat(row, col_offset); |
| const auto* curr = start; |
| const auto* end = start + stride * num_block_cols; |
| uint8 k = 0; |
| #define NEXT_ELEM \ |
| curr += stride; \ |
| ++k; |
| #define EAT_ZEROS \ |
| while (curr < end && IsZero<T>(*curr)) { \ |
| NEXT_ELEM; \ |
| } |
| while (true) { |
| EAT_ZEROS |
| if (curr >= end) break; |
| idx3.k1 = k; |
| const T value1 = *curr; |
| NEXT_ELEM; |
| |
| EAT_ZEROS |
| if (curr >= end) { |
| data.push_back(value1); |
| index.push_back({idx3.m, idx3.k1}); |
| break; |
| } |
| idx3.k2 = k; |
| const T value2 = *curr; |
| NEXT_ELEM; |
| |
| EAT_ZEROS |
| if (curr >= end) { |
| data.push_back(value2); |
| index.push_back({idx3.m, idx3.k2}); |
| data.push_back(value1); |
| index.push_back({idx3.m, idx3.k1}); |
| break; |
| } |
| idx3.k3 = k; |
| data3.push_back(value1); |
| data3.push_back(value2); |
| data3.push_back(*curr); |
| NEXT_ELEM; |
| index3.push_back(idx3); |
| #undef NEXT_ELEM |
| #undef EAT_ZEROS |
| } |
| } |
| col_offset += block_size; |
| index3_offset.push_back(index3.size()); |
| index_offset.push_back(index.size()); |
| } |
| DCHECK_EQ(index3_offset.size(), num_blocks); |
| DCHECK_EQ(index_offset.size(), num_blocks); |
| DCHECK_EQ(3 * index3.size(), data3.size()); |
| DCHECK_EQ(index.size(), data.size()); |
| } |
| |
| template <typename T> |
| void SparseSlice<T>::Clear() { |
| index3_offset.clear(); |
| index3.clear(); |
| data3.clear(); |
| index_offset.clear(); |
| index.clear(); |
| data.clear(); |
| } |
| |
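| // SIMD helpers for the kernels below. "Packet" is Eigen's native SIMD vector |
| // of floats, and kNumOperands is the number of floats it holds. The |
| // EXPAND_BFLOAT_L/EXPAND_BFLOAT_U macros expand the lower/upper half of a |
| // packet of bfloat16 values into a packet of floats, and FMA is a fused |
| // multiply-add. |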
| using Packet = Eigen::internal::packet_traits<float>::type; |
| const int kNumOperands = (sizeof(Packet) / sizeof(float)); |
| #define LOAD(x) Eigen::internal::pload<Packet>(x); |
| #define EXPAND_BFLOAT_L(x, y) \ |
| const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x); |
| #define EXPAND_BFLOAT_U(x, y) \ |
| const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x); |
| #define STORE(x, y) Eigen::internal::pstore<float>(x, y); |
| #define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c); |
| |
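| // A bfloat16 holds the 16 most significant bits of a float32, so conversion |
| // just writes those bits into the upper half of a zero-initialized float; the |
| // byte-order check selects which half that is. |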
| ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) { |
| float out = 0; |
| auto tmp = reinterpret_cast<bfloat16*>(&out); |
| #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
| tmp[0] = *src; |
| #else |
| tmp[1] = *src; |
| #endif |
| return out; |
| } |
| |
| ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) { |
| return Eigen::internal::pload4bf16<Packet>( |
| reinterpret_cast<const float*>(src)); |
| } |
| |
| ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) { |
| return Eigen::internal::pload2bf16<Packet>( |
| reinterpret_cast<const float*>(src)); |
| } |
| |
| ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) { |
| **out += a * **inp; |
| ++*inp; |
| ++*out; |
| } |
| |
| ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp, |
| float** out) { |
| float inp_f = ConvertBfloat16ToFloat(*inp); |
| **out += a * inp_f; |
| ++*inp; |
| ++*out; |
| } |
| ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2, |
| const float a3, const bfloat16** inp1, |
| const bfloat16** inp2, |
| const bfloat16** inp3, float** out) { |
| float inp1_f = ConvertBfloat16ToFloat(*inp1); |
| float inp2_f = ConvertBfloat16ToFloat(*inp2); |
| float inp3_f = ConvertBfloat16ToFloat(*inp3); |
| **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f; |
| ++*out; |
| ++*inp1; |
| ++*inp2; |
| ++*inp3; |
| } |
| |
| ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2, |
| const float a3, const float** inp1, |
| const float** inp2, const float** inp3, |
| float** out) { |
| **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3; |
| ++*out; |
| ++*inp1; |
| ++*inp2; |
| ++*inp3; |
| } |
| |
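| // The LoadSingleScalar/LoadTwoScalars/LoadFourScalars family loads scalars |
| // from "data" and broadcasts each one across a full packet. The bfloat16 |
| // variants use the vectorized pload2bf16/pload4bf16 conversions when the |
| // packet is wide enough, and fall back to one-at-a-time conversion otherwise. |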
| ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) { |
| auto tmp = ConvertBfloat16ToFloat(*data); |
| *l = Eigen::internal::pset1<Packet>(tmp); |
| ++*data; |
| } |
| |
| ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1, |
| Packet* l2) { |
| if (kNumOperands >= 2) { |
| auto tmp = ConvertTwoBfloat16ToFloat(*data); |
| *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp); |
| *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp); |
| *data += 2; |
| } else { |
| LoadSingleScalar(data, l1); |
| LoadSingleScalar(data, l2); |
| } |
| } |
| |
| ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1, |
| Packet* l2, Packet* l3, Packet* l4) { |
| if (kNumOperands >= 4) { |
| auto tmp = ConvertFourBfloat16ToFloat(*data); |
| *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp); |
| *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp); |
| *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp); |
| *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp); |
| *data += 4; |
| } else { |
| LoadTwoScalars(data, l1, l2); |
| LoadTwoScalars(data, l3, l4); |
| } |
| } |
| |
| ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) { |
| *l = Eigen::internal::pload1<Packet>(*data); |
| ++(*data); |
| } |
| |
| ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) { |
| LoadSingleScalar(data, l1); |
| LoadSingleScalar(data, l2); |
| } |
| |
| ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2, |
| Packet* l3, Packet* l4) { |
| LoadTwoScalars(data, l1, l2); |
| LoadTwoScalars(data, l3, l4); |
| } |
| |
| template <typename T> |
| ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2, |
| Packet* l3) { |
| LoadTwoScalars(data, l1, l2); |
| LoadSingleScalar(data, l3); |
| } |
| |
| template <typename T> |
| ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2, |
| Packet* l3, Packet* l4, Packet* l5, |
| Packet* l6) { |
| LoadFourScalars(data, l1, l2, l3, l4); |
| LoadTwoScalars(data, l5, l6); |
| } |
| |
| // Vectorized version of ScalarMulAdd. |
| ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) { |
| auto inp = reinterpret_cast<const float*>(*binp); |
| const auto b = LOAD(inp); |
| EXPAND_BFLOAT_L(b, b_0); |
| EXPAND_BFLOAT_U(b, b_1); |
| *binp += 2 * kNumOperands; |
| auto c1 = LOAD(*out); |
| auto c2 = LOAD(*out + kNumOperands); |
| FMA(a, b_0, c1, c1); |
| FMA(a, b_1, c2, c2); |
| STORE(*out, c1); |
| STORE(*out + kNumOperands, c2); |
| *out += 2 * kNumOperands; |
| } |
| |
| // Vectorized version of ScalarMulAdd3Way. |
| ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3, |
| const bfloat16** binp1, const bfloat16** binp2, |
| const bfloat16** binp3, float** out) { |
| auto inp1 = reinterpret_cast<const float*>(*binp1); |
| auto inp2 = reinterpret_cast<const float*>(*binp2); |
| auto inp3 = reinterpret_cast<const float*>(*binp3); |
| auto c1 = LOAD(*out); |
| auto c2 = LOAD(*out + kNumOperands); |
| const auto b1 = LOAD(inp1); |
| EXPAND_BFLOAT_L(b1, b1_0); |
| EXPAND_BFLOAT_U(b1, b1_1); |
| *binp1 += 2 * kNumOperands; |
| const auto b2 = LOAD(inp2); |
| EXPAND_BFLOAT_L(b2, b2_0); |
| EXPAND_BFLOAT_U(b2, b2_1); |
| *binp2 += 2 * kNumOperands; |
| const auto b3 = LOAD(inp3); |
| EXPAND_BFLOAT_L(b3, b3_0); |
| EXPAND_BFLOAT_U(b3, b3_1); |
| *binp3 += 2 * kNumOperands; |
| FMA(a1, b1_0, c1, c1); |
| FMA(a1, b1_1, c2, c2); |
| FMA(a2, b2_0, c1, c1); |
| FMA(a2, b2_1, c2, c2); |
| FMA(a3, b3_0, c1, c1); |
| FMA(a3, b3_1, c2, c2); |
| STORE(*out, c1); |
| STORE(*out + kNumOperands, c2); |
| *out += 2 * kNumOperands; |
| } |
| |
| // Unroll MulAdd3Way for two iterations |
| ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2, |
| const Packet a3, const bfloat16** binp1, |
| const bfloat16** binp2, const bfloat16** binp3, |
| float** out) { |
| auto inp1 = reinterpret_cast<const float*>(*binp1); |
| auto inp2 = reinterpret_cast<const float*>(*binp2); |
| auto inp3 = reinterpret_cast<const float*>(*binp3); |
| auto c1 = LOAD(*out); |
| auto c2 = LOAD(*out + kNumOperands); |
| const auto b1 = LOAD(inp1); |
| const auto b2 = LOAD(inp2); |
| const auto b3 = LOAD(inp3); |
| |
| EXPAND_BFLOAT_L(b1, b1_0); |
| EXPAND_BFLOAT_U(b1, b1_1); |
| EXPAND_BFLOAT_L(b2, b2_0); |
| EXPAND_BFLOAT_U(b2, b2_1); |
| EXPAND_BFLOAT_L(b3, b3_0); |
| EXPAND_BFLOAT_U(b3, b3_1); |
| auto c3 = LOAD(*out + 2 * kNumOperands); |
| auto c4 = LOAD(*out + 3 * kNumOperands); |
| const auto b4 = LOAD(inp1 + kNumOperands); |
| const auto b5 = LOAD(inp2 + kNumOperands); |
| const auto b6 = LOAD(inp3 + kNumOperands); |
| |
| EXPAND_BFLOAT_L(b4, b4_0); |
| EXPAND_BFLOAT_U(b4, b4_1); |
| EXPAND_BFLOAT_L(b5, b5_0); |
| EXPAND_BFLOAT_U(b5, b5_1); |
| EXPAND_BFLOAT_L(b6, b6_0); |
| EXPAND_BFLOAT_U(b6, b6_1); |
| |
| FMA(a1, b1_0, c1, c1); |
| FMA(a1, b1_1, c2, c2); |
| FMA(a1, b4_0, c3, c3); |
| FMA(a1, b4_1, c4, c4); |
| FMA(a2, b2_0, c1, c1); |
| FMA(a2, b2_1, c2, c2); |
| FMA(a2, b5_0, c3, c3); |
| FMA(a2, b5_1, c4, c4); |
| FMA(a3, b3_0, c1, c1); |
| FMA(a3, b3_1, c2, c2); |
| FMA(a3, b6_0, c3, c3); |
| FMA(a3, b6_1, c4, c4); |
| STORE(*out, c1); |
| STORE(*out + kNumOperands, c2); |
| STORE(*out + 2 * kNumOperands, c3); |
| STORE(*out + 3 * kNumOperands, c4); |
| *out += 4 * kNumOperands; |
| *binp1 += 4 * kNumOperands; |
| *binp2 += 4 * kNumOperands; |
| *binp3 += 4 * kNumOperands; |
| } |
| |
| // Apply MulAdd3Way on 128 operands. |
| ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2, |
| const Packet a3, const bfloat16** inp1, |
| const bfloat16** inp2, const bfloat16** inp3, |
| float** out) { |
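| // Each TwoMulAdd3Way call advances the output by 4 * kNumOperands floats, so |
| // the two calls per iteration cover 8 * kNumOperands columns and the loop |
| // covers all 128. |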
| for (int k = 0; k < 128 / (8 * kNumOperands); ++k) { |
| TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| } |
| } |
| |
| // Vectorized version of ScalarMulAdd |
| ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) { |
| const auto b = LOAD(*inp); |
| *inp += kNumOperands; |
| auto c = LOAD(*out); |
| FMA(a, b, c, c); |
| STORE(*out, c); |
| *out += kNumOperands; |
| } |
| |
| // Vectorized version of ScalarMulAdd3Way |
| ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3, |
| const float** inp1, const float** inp2, |
| const float** inp3, float** out) { |
| auto c = LOAD(*out); |
| const auto b1 = LOAD(*inp1); |
| *inp1 += kNumOperands; |
| const auto b2 = LOAD(*inp2); |
| *inp2 += kNumOperands; |
| const auto b3 = LOAD(*inp3); |
| *inp3 += kNumOperands; |
| FMA(a1, b1, c, c); |
| FMA(a2, b2, c, c); |
| FMA(a3, b3, c, c); |
| STORE(*out, c); |
| *out += kNumOperands; |
| } |
| |
| // Unroll MulAdd3Way for two iterations |
| ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2, |
| const Packet a3, const float** inp1, |
| const float** inp2, const float** inp3, |
| float** out) { |
| auto c1 = LOAD(*out); |
| const auto b1 = LOAD(*inp1); |
| const auto b2 = LOAD(*inp2); |
| const auto b3 = LOAD(*inp3); |
| |
| auto c2 = LOAD(*out + kNumOperands); |
| const auto b4 = LOAD(*inp1 + kNumOperands); |
| const auto b5 = LOAD(*inp2 + kNumOperands); |
| const auto b6 = LOAD(*inp3 + kNumOperands); |
| |
| FMA(a1, b1, c1, c1); |
| FMA(a1, b4, c2, c2); |
| FMA(a2, b2, c1, c1); |
| FMA(a2, b5, c2, c2); |
| FMA(a3, b3, c1, c1); |
| FMA(a3, b6, c2, c2); |
| STORE(*out, c1); |
| STORE(*out + kNumOperands, c2); |
| *out += 2 * kNumOperands; |
| *inp1 += 2 * kNumOperands; |
| *inp2 += 2 * kNumOperands; |
| *inp3 += 2 * kNumOperands; |
| } |
| |
| // Unroll MulAdd3Way for four iterations |
| ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2, |
| const Packet a3, const float** inp1, |
| const float** inp2, const float** inp3, |
| float** out) { |
| TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| } |
| |
| // Apply MulAdd3Way on 128 operands. |
| ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2, |
| const Packet a3, const float** inp1, |
| const float** inp2, const float** inp3, |
| float** out) { |
| if (kNumOperands == 8) { |
| FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| } else { |
| DCHECK_LE(4 * kNumOperands, 128); |
| for (int i = 0; i < 128 / (4 * kNumOperands); ++i) { |
| MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); |
| } |
| } |
| } |
| |
| // Computes the product of "left_slices" with "num_cols" columns of "right", |
| // and stores the output in *output. (GEPP is the general panel-panel multiply |
| // of Goto-style GEMM decompositions.) |
| // Note that left_slices is a list of SparseSlices, which are conceptually |
| // assumed to be concatenated along the column dimension. Also, each |
| // SparseSlice is encoded as a list of blocks with up to N columns. See |
| // SparseSlice for more details. |
| template <typename TL, typename TR, int Cols> |
| inline void GEPP( |
| const std::vector<SparseSlice<TL>*>& left_slices, |
| const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>, |
| Eigen::Aligned>& right, |
| const int num_cols, Matrix* output) { |
| const int cols = (Cols == -1) ? num_cols : Cols; |
| DCHECK_EQ(num_cols, cols); |
| const int right_num_cols = right.dimension(1); |
| const int output_num_cols = output->dimension(1); |
| static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR); |
| const int cols_mod = cols % kNumOperandsR; |
| int k_offset = 0; |
| // Pre-compute pointers for output matrix. |
| float* out_ptrs[M]; |
| float* const out_start = &(*output)(0, 0); |
| for (int j = 0; j < M; ++j) { |
| out_ptrs[j] = out_start + output_num_cols * j; |
| } |
| for (const auto* left_slice : left_slices) { |
| const auto& left = *left_slice; |
| const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr; |
| const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr; |
| const int num_blocks = left.index3_offset.size(); |
| int begin3 = 0; |
| int begin = 0; |
| for (int i = 0; i < num_blocks; ++i) { |
| // Pre-compute pointers for right matrix |
| const TR* right_ptrs[K]; |
| const auto* const right_start = &right(k_offset, 0); |
| DCHECK_LT(k_offset, right.dimension(0)); |
| for (int j = 0; j < K; ++j) { |
| right_ptrs[j] = right_start + right_num_cols * j; |
| } |
| |
| const int end3 = left.index3_offset[i]; |
| int j = begin3; |
| // Loop unrolled for 2 iterations. |
| for (; j + 1 < end3; j += 2) { |
| Packet l1, l2, l3, nl1, nl2, nl3; |
| LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3); |
| const auto& index = left.index3[j]; |
| const auto& nindex = left.index3[j + 1]; |
| float* out = out_ptrs[index.m]; |
| float* nout = out_ptrs[nindex.m]; |
| const auto* r1 = right_ptrs[index.k1]; |
| const auto* r2 = right_ptrs[index.k2]; |
| const auto* r3 = right_ptrs[index.k3]; |
| |
| const auto* nr1 = right_ptrs[nindex.k1]; |
| const auto* nr2 = right_ptrs[nindex.k2]; |
| const auto* nr3 = right_ptrs[nindex.k3]; |
| if (cols == 128) { |
| MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out); |
| MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout); |
| } else { |
| for (int n = 0; n < cols / kNumOperandsR; ++n) { |
| MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out); |
| MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout); |
| } |
| |
| const float sl1 = Eigen::internal::pfirst<Packet>(l1); |
| const float sl2 = Eigen::internal::pfirst<Packet>(l2); |
| const float sl3 = Eigen::internal::pfirst<Packet>(l3); |
| const float nsl1 = Eigen::internal::pfirst<Packet>(nl1); |
| const float nsl2 = Eigen::internal::pfirst<Packet>(nl2); |
| const float nsl3 = Eigen::internal::pfirst<Packet>(nl3); |
| for (int k = 0; k < cols_mod; ++k) { |
| ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out); |
| ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout); |
| } |
| } |
| } |
| if (j < end3) { |
| Packet l1, l2, l3; |
| LoadThreeScalars(&data3, &l1, &l2, &l3); |
| |
| const auto& index = left.index3[j]; |
| float* out = out_ptrs[index.m]; |
| const auto* r1 = right_ptrs[index.k1]; |
| const auto* r2 = right_ptrs[index.k2]; |
| const auto* r3 = right_ptrs[index.k3]; |
| if (cols == 128) { |
| MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out); |
| } else { |
| for (int n = 0; n < cols / kNumOperandsR; ++n) { |
| MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out); |
| } |
| const float sl1 = Eigen::internal::pfirst<Packet>(l1); |
| const float sl2 = Eigen::internal::pfirst<Packet>(l2); |
| const float sl3 = Eigen::internal::pfirst<Packet>(l3); |
| for (int k = 0; k < cols_mod; ++k) { |
| ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out); |
| } |
| } |
| } |
| begin3 = end3; |
| int end = left.index_offset[i]; |
| // Loop unrolled for 4 iterations. |
| j = begin; |
| for (; j + 3 < end; j += 4) { |
| Packet l, nl, n2l, n3l; |
| LoadFourScalars(&data, &l, &nl, &n2l, &n3l); |
| |
| const auto& index = left.index[j]; |
| const auto& nindex = left.index[j + 1]; |
| const auto& n2index = left.index[j + 2]; |
| const auto& n3index = left.index[j + 3]; |
| const auto* r = right_ptrs[index.k]; |
| const auto* nr = right_ptrs[nindex.k]; |
| const auto* n2r = right_ptrs[n2index.k]; |
| const auto* n3r = right_ptrs[n3index.k]; |
| float* out = out_ptrs[index.m]; |
| float* nout = out_ptrs[nindex.m]; |
| float* n2out = out_ptrs[n2index.m]; |
| float* n3out = out_ptrs[n3index.m]; |
| |
| for (int n = 0; n < cols / kNumOperandsR; ++n) { |
| MulAdd(l, &r, &out); |
| MulAdd(nl, &nr, &nout); |
| MulAdd(n2l, &n2r, &n2out); |
| MulAdd(n3l, &n3r, &n3out); |
| } |
| |
| const float sl1 = Eigen::internal::pfirst<Packet>(l); |
| const float sl2 = Eigen::internal::pfirst<Packet>(nl); |
| const float sl3 = Eigen::internal::pfirst<Packet>(n2l); |
| const float sl4 = Eigen::internal::pfirst<Packet>(n3l); |
| for (int k = 0; k < cols_mod; ++k) { |
| ScalarMulAdd(sl1, &r, &out); |
| ScalarMulAdd(sl2, &nr, &nout); |
| ScalarMulAdd(sl3, &n2r, &n2out); |
| ScalarMulAdd(sl4, &n3r, &n3out); |
| } |
| } |
| while (j < end) { |
| Packet l; |
| LoadSingleScalar(&data, &l); |
| const auto& index = left.index[j]; |
| const auto* r = right_ptrs[index.k]; |
| float* out = out_ptrs[index.m]; |
| for (int n = 0; n < cols / kNumOperandsR; ++n) { |
| MulAdd(l, &r, &out); |
| } |
| const float sl = Eigen::internal::pfirst<Packet>(l); |
| for (int k = 0; k < cols_mod; ++k) { |
| ScalarMulAdd(sl, &r, &out); |
| } |
| j++; |
| } |
| k_offset += left.block_size; |
| begin = end; |
| } |
| } |
| } |
| |
| #undef LOAD |
| #undef EXPAND_BFLOAT_L |
| #undef EXPAND_BFLOAT_U |
| #undef STORE |
| #undef FMA |
| |
| } // namespace |
| |
| template <typename TL, typename TR> |
| class SparseMatMul { |
| using MatrixL = BasicMatrix<TL>; |
| using MatrixR = BasicMatrix<TR>; |
| using ConstMatrixMapL = BasicMatrixMap<const TL>; |
| using ConstMatrixMapR = BasicMatrixMap<const TR>; |
| using MatrixMapR = BasicMatrixMap<TR>; |
| |
| public: |
| // Not used; added to match interface of LibxsmmSparseMatMul |
| struct TensorInfoCache {}; |
| |
| // Perform matrix multiplication of "left" and "right", and store the result |
| // in *"output". |
| public: |
| static inline void Compute(TensorInfoCache* cache, |
| const ConstMatrixMapL& left, |
| const ConstMatrixMapR& right, bool transpose_left, |
| const DeviceBase::CpuWorkerThreads* thread_pool, |
| bool transpose_output, MatrixMap* output); |
| |
| private: |
| // Computes the multiplication of "left" with "num_cols" columns of "right", |
| // and stores the output block in *output at offsets "output_row_offset" and |
| // "output_col_offset". If "assign" is true, assigns the values to that block; |
| // otherwise accumulates into the existing values. |
| static inline void ComputeOutputBlock( |
| const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right, |
| int num_cols, int output_row_offset, int output_col_offset, bool assign, |
| bool transpose_output, MatrixMap* output); |
| |
| // Encodes "mat" using a sparse representation and stores that in |
| // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and |
| // "slice_num_cols"; each grid element is converted into a SparseSlice and |
| // stored in mat_slices. "slice_block_size" is used to perform further column |
| // blocking of each slice. |
| static inline std::unique_ptr<BlockingCounter> CreateSparseSlices( |
| const ConstMatrixMapL& mat, bool transpose, int slice_num_rows, |
| int slice_block_size, int slice_num_cols, |
| std::vector<std::vector<SparseSlice<TL>*>>* mat_slices, |
| const DeviceBase::CpuWorkerThreads* thread_pool); |
| |
| // This function chops "mat" along the column dimension into pieces with at |
| // most N columns, concatenates the pieces one after the other in "buffer", |
| // and returns the list of pieces in "slices". The returned BlockingCounter |
| // should be used to wait for the shuffle operations to complete. |
| static inline std::unique_ptr<BlockingCounter> CreateDenseSlices( |
| const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start, |
| int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, |
| MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices); |
| |
| // Helper function for CreateDenseSlices to move the data around. It returns a |
| // BlockingCounter which should be used to wait for the shuffle operations to |
| // complete. |
| static inline BlockingCounter* ShuffleMatrix( |
| const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows, |
| int slice_col_start, int slice_num_cols, const int N, |
| const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer); |
| |
| // Helper function for CreateDenseSlices to create slices. |
| static inline void SliceMatrix(const MatrixR& mat, const int num_rows, |
| const int num_slices, |
| std::vector<ConstMatrixMapR*>* slices); |
| |
| // Heuristics to compute various block sizes. |
| // KR, NR: block sizes for "right". We run blocking iterations that operate on |
| // matrices with at most this size. |
| // KL: grid size along the column dimension used while encoding left. |
| // IB, JB: number of left and right slices to multiply together. This is used |
| // for ordering different ComputeBlockOutput operations inside each blocking |
| // iteration so as to potentially reduce the working set size. |
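| // As a worked illustration (numbers purely illustrative), 16 threads and |
| // 4096 x 4096 float inputs yield est_num_cores = 8, KR = 4096, NR = 256, |
| // KL = 1024, JB = 2 and IB = 16 under these heuristics. |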
| static inline void ComputeBlockSizes(const ConstMatrixMapL& left, |
| const ConstMatrixMapR& right, |
| bool transpose_left, int num_threads, |
| int* KR, int* NR, int* KL, int* JB, |
| int* IB); |
| |
| TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul); |
| }; |
| |
| #ifdef TENSORFLOW_USE_LIBXSMM |
| template <typename TL, typename TR> |
| class LibxsmmSparseMatMul { |
| using MatrixL = BasicMatrix<TL>; |
| using MatrixR = BasicMatrix<TR>; |
| using ConstMatrixMapL = BasicMatrixMap<const TL>; |
| using ConstMatrixMapR = BasicMatrixMap<const TR>; |
| using MatrixMapR = BasicMatrixMap<TR>; |
| |
| public: |
| // This structure contains a set of libxsmm kernels for sizes that have been |
| // encountered previously by this operator so that libxsmm does not need to |
| // reallocate its scratchpad memory each time (which hurts performance |
| // substantially). |
| struct TensorInfoCache { |
| struct TensorInfoCacheEntry { |
| // Parameters for kernel |
| int M; |
| int K; |
| int N; |
| int max_threads; |
| // libxsmm handle and matrix data |
| libxsmm_spmdm_handle handle; |
| libxsmm_CSR_sparseslice* output_csr; |
| // Chain to non-libxsmm implementation's cache in case that ever becomes |
| // useful (it is an empty struct right now) |
| typename SparseMatMul<TL, TR>::TensorInfoCache |
| non_libxsmm_cache; // Currently not used |
| }; |
| // protects entries; invariant: entries is a valid std::multimap |
| tensorflow::mutex lock; |
| // Because there could be multiple matrix multiplies with the same sizes |
| // going on at the same time, we need to allow multiple cache entries for a |
| // given set of parameters. Taking and returning entries is used to make |
| // sure the same cache entry is not used from two threads at a time. |
| std::multimap<std::tuple<int, int, int, int>, |
| std::unique_ptr<TensorInfoCacheEntry>> |
| entries TF_GUARDED_BY(lock); |
| |
| TensorInfoCache() : lock(), entries() {} |
| // Looks up and removes the first entry with these parameters, creating one |
| // if there isn't one. |
| std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N, |
| int max_threads) |
| TF_LOCKS_EXCLUDED(lock) { |
| tensorflow::mutex_lock ml(lock); |
| auto key = std::make_tuple(M, K, N, max_threads); |
| auto it = entries.find(key); |
| if (it != entries.end()) { |
| auto val = std::move(it->second); |
| entries.erase(it); |
| return val; |
| } else { |
| std::unique_ptr<TensorInfoCacheEntry> e{ |
| new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}}; |
| // Set up a scoped allocator, which uses cpu_allocator() for this scope. |
| const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator; |
| libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr); |
| return e; |
| } |
| } |
| // Returns an entry to the cache so it can be reused later. |
| void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e) |
| TF_LOCKS_EXCLUDED(lock) { |
| tensorflow::mutex_lock ml(lock); |
| auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads); |
| entries.insert(std::make_pair(key, std::move(e))); |
| } |
| ~TensorInfoCache() { |
| tensorflow::mutex_lock ml(lock); |
| for (auto& p : entries) { |
| libxsmm_spmdm_destroy(&p.second->handle); |
| } |
| entries.clear(); |
| } |
| |
| private: |
| TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache); |
| }; |
| |
| // Perform matrix multiplication of "left" and "right", and store the result |
| // in *"output". |
| public: |
| static inline void Compute(TensorInfoCache* cache, |
| const ConstMatrixMapL& left, |
| const ConstMatrixMapR& right, bool transpose_left, |
| const DeviceBase::CpuWorkerThreads* thread_pool, |
| bool transpose_output, MatrixMap* output); |
| |
| private: |
| TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul); |
| }; |
| #endif |
| |
| template <typename TL, typename TR, |
| template <typename TL2, typename TR2> class DoMatMul> |
| class SparseMatMulOp : public OpKernel { |
| using MatrixR = BasicMatrix<TR>; |
| using ConstMatrixMapR = BasicMatrixMap<const TR>; |
| |
| public: |
| explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_)); |
| } |
| |
| void Compute(OpKernelContext* ctx) override { |
| const Tensor& a = ctx->input(0); |
| const Tensor& b = ctx->input(1); |
| OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()), |
| errors::InvalidArgument("a is not a matrix")); |
| OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), |
| errors::InvalidArgument("b is not a matrix")); |
| |
| const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0); |
| const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1); |
| const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1); |
| const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0); |
| |
| OP_REQUIRES(ctx, k == k2, |
| errors::InvalidArgument( |
| "Matrix size incompatible: a: ", a.shape().DebugString(), |
| ", b: ", b.shape().DebugString())); |
| Tensor* output = nullptr; |
| OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output)); |
| |
| if (k == 0) { |
| // If the inner dimension k in the matrix multiplication is zero, we fill |
| // the output with zeros. |
| functor::SetZeroFunctor<CPUDevice, float> f; |
| f(ctx->eigen_device<CPUDevice>(), output->flat<float>()); |
| return; |
| } |
| |
| auto out = output->matrix<float>(); |
| |
| std::unique_ptr<Tensor> a_float; |
| std::unique_ptr<Tensor> b_float; |
| if (!a_is_sparse_ && !b_is_sparse_) { |
| auto left = &a; |
| auto right = &b; |
| // TODO(agarwal): multi-thread the conversions from bfloat16 to float. |
| if (std::is_same<TL, bfloat16>::value) { |
| a_float.reset(new Tensor(DT_FLOAT, a.shape())); |
| BFloat16ToFloat(a.flat<bfloat16>().data(), |
| a_float->flat<float>().data(), a.NumElements()); |
| left = a_float.get(); |
| } |
| if (std::is_same<TR, bfloat16>::value) { |
| b_float.reset(new Tensor(DT_FLOAT, b.shape())); |
| BFloat16ToFloat(b.flat<bfloat16>().data(), |
| b_float->flat<float>().data(), b.NumElements()); |
| right = b_float.get(); |
| } |
| Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; |
| dim_pair[0].first = transpose_a_ ? 0 : 1; |
| dim_pair[0].second = transpose_b_ ? 1 : 0; |
| |
| out.device(ctx->template eigen_device<CPUDevice>()) = |
| left->matrix<float>().contract(right->matrix<float>(), dim_pair); |
| return; |
| } |
| |
| auto left = &a; |
| auto right = &b; |
| bool transpose_output = false; |
| bool transpose_a = transpose_a_; |
| bool transpose_b = transpose_b_; |
| if (!a_is_sparse_) { |
| // Swap the order of multiplications using the identity: |
| // A * B = (B' * A')'. |
| std::swap(left, right); |
| std::swap(transpose_a, transpose_b); |
| transpose_a = !transpose_a; |
| transpose_b = !transpose_b; |
| transpose_output = !transpose_output; |
| } |
| |
| std::unique_ptr<Tensor> right_tr; |
| if (transpose_b) { |
| // TODO(agarwal): avoid transposing the matrix here and directly handle |
| // transpose in CreateDenseSlices. |
| OP_REQUIRES(ctx, right->dim_size(0) != 0, |
| errors::InvalidArgument("b has an entry 0 in its shape.")); |
| OP_REQUIRES(ctx, right->dim_size(1) != 0, |
| errors::InvalidArgument("b has an entry 0 in its shape.")); |
| right_tr.reset( |
| new Tensor(right->dtype(), |
| TensorShape({right->dim_size(1), right->dim_size(0)}))); |
| |
| const auto perm = dsizes_10(); |
| if (transpose_output) { |
| right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) = |
| right->matrix<TL>().shuffle(perm); |
| } else { |
| right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) = |
| right->matrix<TR>().shuffle(perm); |
| } |
| right = right_tr.get(); |
| } |
| |
| if (transpose_output) { |
| DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(), |
| right->matrix<TL>(), transpose_a, |
| ctx->device()->tensorflow_cpu_worker_threads(), |
| transpose_output, &out); |
| } else { |
| DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(), |
| right->matrix<TR>(), transpose_a, |
| ctx->device()->tensorflow_cpu_worker_threads(), |
| transpose_output, &out); |
| } |
| } |
| |
| private: |
| bool transpose_a_; |
| bool transpose_b_; |
| bool a_is_sparse_; |
| bool b_is_sparse_; |
| |
| // Cache for non-transposed-output multiply |
| typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_; |
| // Cache for transposed-output multiply |
| typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_; |
| |
| TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp); |
| }; |
| |
| template <typename TL, typename TR> |
| inline void SparseMatMul<TL, TR>::ComputeOutputBlock( |
| const std::vector<SparseSlice<TL>*>& left, |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols, |
| int output_row_offset, int output_col_offset, bool assign, |
| bool transpose_output, MatrixMap* output) { |
| const auto perm = dsizes_10(); |
| int num_rows = left[0]->num_rows; |
| const int rhs_num_cols = right.dimension(1); |
| DCHECK_LE(num_cols, rhs_num_cols); |
| Matrix out(num_rows, rhs_num_cols); |
| out.setZero(); |
| if (num_cols == N) { |
| GEPP<TL, TR, N>(left, right, num_cols, &out); |
| } else { |
| GEPP<TL, TR, -1>(left, right, num_cols, &out); |
| } |
| if (!assign) { |
| const DSizes begin(output_row_offset, output_col_offset); |
| const DSizes sizes(num_rows, num_cols); |
| if (transpose_output) { |
| if (num_cols == rhs_num_cols) { |
| output->shuffle(perm).slice(begin, sizes) += out; |
| } else { |
| const auto zero = dsizes_00(); |
| output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes); |
| } |
| } else { |
| if (num_cols == rhs_num_cols) { |
| output->slice(begin, sizes) += out; |
| } else { |
| const auto zero = dsizes_00(); |
| output->slice(begin, sizes) += out.slice(zero, sizes); |
| } |
| } |
| } else { |
| std::unique_ptr<Matrix> out_tr; |
| if (transpose_output) { |
| out_tr.reset(new Matrix(rhs_num_cols, num_rows)); |
| *out_tr = out.shuffle(perm); |
| std::swap(output_row_offset, output_col_offset); |
| std::swap(num_rows, num_cols); |
| } |
| const Matrix& final_out = transpose_output ? *out_tr : out; |
| for (int i = 0; i < num_rows; ++i) { |
| memcpy(&(*output)(output_row_offset + i, output_col_offset), |
| &final_out(i, 0), num_cols * sizeof(float)); |
| } |
| } |
| } |
| |
| template <typename TL, typename TR> |
| inline std::unique_ptr<BlockingCounter> |
| SparseMatMul<TL, TR>::CreateSparseSlices( |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose, |
| int slice_num_rows, int slice_block_size, int slice_num_cols, |
| std::vector<std::vector<SparseSlice<TL>*>>* mat_slices, |
| const DeviceBase::CpuWorkerThreads* thread_pool) { |
| const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0); |
| const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1); |
| const int num_slices_dim0 = |
| std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows); |
| const int num_slices_dim1 = |
| std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols); |
| mat_slices->resize(num_slices_dim0); |
| BlockingCounter* counter = |
| new BlockingCounter(num_slices_dim0 * num_slices_dim1); |
| auto work = [counter, transpose](SparseSlice<TL>* sparse_slice, |
| SparseMatMul<TL, TR>::ConstMatrixMapL* slice, |
| int col_offset) { |
| if (transpose) { |
| sparse_slice->template Initialize<true>(*slice, col_offset); |
| } else { |
| sparse_slice->template Initialize<false>(*slice, col_offset); |
| } |
| delete slice; |
| counter->DecrementCount(); |
| }; |
| for (int i = 0; i < num_slices_dim0; ++i) { |
| (*mat_slices)[i].resize(num_slices_dim1); |
| int num_rows = |
| std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows); |
| for (int j = 0; j < num_slices_dim1; ++j) { |
| int num_cols = |
| std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols); |
| SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr; |
| if (transpose) { |
| slice = new SparseMatMul<TL, TR>::ConstMatrixMapL( |
| &mat(0, i * slice_num_rows), mat.dimensions()); |
| } else { |
| DSizes d(num_rows, mat_num_cols); |
| slice = new SparseMatMul<TL, TR>::ConstMatrixMapL( |
| &mat(i * slice_num_rows, 0), d); |
| } |
| auto* sparse_slice = |
| new SparseSlice<TL>(num_rows, num_cols, slice_block_size); |
| (*mat_slices)[i][j] = sparse_slice; |
| thread_pool->workers->Schedule( |
| [=]() { work(sparse_slice, slice, slice_num_cols * j); }); |
| } |
| } |
| return std::unique_ptr<BlockingCounter>(counter); |
| } |
| #define LOAD(x) Eigen::internal::ploadu<Packet>((x)); |
| #define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x); |
| #define STORE(x, y) Eigen::internal::pstoreu<float>(x, y); |
| |
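| // Copies "num_elements" bfloat16 values from "bsrc" to "bdst". When the |
| // packet type is wide enough, the 4x64-bit interleave rearranges the values |
| // so that the EXPAND_BFLOAT_L/EXPAND_BFLOAT_U expansions in the MulAdd |
| // kernels above read them back in their original order. |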
| template <int NUM_ELEM = -1> |
| ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc, |
| int num_elements) { |
| DCHECK_GE(kNumOperands, 8); |
| static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16); |
| const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM; |
| DCHECK_EQ(num, num_elements); |
| const float* src = reinterpret_cast<const float*>(bsrc); |
| float* dst = reinterpret_cast<float*>(bdst); |
| for (int index = 0; index + kStep <= num; index += kStep) { |
| auto in = LOAD(src); |
| auto tmp = INTERLEAVE(in); |
| STORE(dst, tmp); |
| src += kNumOperands; |
| dst += kNumOperands; |
| } |
| if (num % kStep != 0) { |
| memcpy(dst, src, (num % kStep) * sizeof(bfloat16)); |
| } |
| } |
| |
| template <typename T> |
| ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src, |
| int num_elements) { |
| if (std::is_same<T, float>::value || kNumOperands < 8) { |
| memcpy(dst, src, num_elements * sizeof(T)); |
| } else if (std::is_same<T, bfloat16>::value) { |
| if (num_elements == N) { |
| CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements); |
| } else { |
| CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements); |
| } |
| } else { |
| LOG(FATAL) << "Unsupported type"; |
| } |
| } |
| |
| #undef LOAD |
| #undef INTERLEAVE |
| #undef STORE |
| |
| template <typename TL, typename TR> |
| inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix( |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, |
| int slice_row_start, int slice_num_rows, int slice_col_start, |
| int slice_num_cols, const int N, |
| const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) { |
| DCHECK_EQ(N % 2, 0); |
| DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N); |
| // Note(nikhilsarda): This heuristic is optimal in benchmarks as of |
| // Jan 21, 2020. |
| int num_threads = std::min(thread_pool->num_threads, 8); |
| BlockingCounter* counter = new BlockingCounter(num_threads); |
| DCHECK_EQ(N, buffer->dimension(1)); |
| auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start, |
| slice_num_cols, N, buffer, counter](int s, int e) { |
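| // Each output row s of "buffer" receives N (or fewer, in the ragged tail) |
| // columns of one input row: s % slice_num_rows selects the input row, and |
| // s / slice_num_rows selects which N-column block of it to copy. |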
| const int row_start = s % slice_num_rows + slice_row_start; |
| const int col_start = s / slice_num_rows * N + slice_col_start; |
| auto* out_start = &(*buffer)(s, 0); |
| const auto* input_start = &mat(row_start, col_start); |
| const auto* input_end = &mat(slice_row_start + slice_num_rows - 1, |
| slice_col_start + slice_num_cols - 1); |
| const int mat_num_cols = mat.dimension(1); |
| const int row_slice_size = slice_num_rows * mat_num_cols; |
| |
| const int aligned_end = slice_num_cols / N * slice_num_rows; |
| const int e1 = std::min(e, aligned_end); |
| while (s < e1) { |
| CopyAndMayBeInterleave<TR>(out_start, input_start, N); |
| out_start += N; |
| input_start += mat_num_cols; |
| if (input_start > input_end) { |
| input_start = input_start - row_slice_size + N; |
| } |
| ++s; |
| } |
| int s1 = std::max(s, aligned_end); |
| const int copy_num_cols = slice_num_cols % N; |
| while (s1 < e) { |
| CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols); |
| out_start += N; |
| input_start += mat_num_cols; |
| ++s1; |
| } |
| if (counter) counter->DecrementCount(); |
| }; |
| |
| int start = 0; |
| int end = 0; |
| int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows; |
| DCHECK_LE(num_out_rows, buffer->dimension(0)); |
| for (int i = std::max(1, num_threads); i > 0; --i) { |
| end = start + num_out_rows / i; |
| thread_pool->workers->Schedule([=]() { shuffle_work(start, end); }); |
| num_out_rows -= (end - start); |
| start = end; |
| } |
| return counter; |
| } |
| |
| template <typename TL, typename TR> |
| inline void SparseMatMul<TL, TR>::SliceMatrix( |
| const MatrixR& mat, const int num_rows, const int num_slices, |
| std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) { |
| slices->resize(num_slices); |
| DSizes d(num_rows, mat.dimension(1)); |
| DCHECK_LE(num_rows * num_slices, mat.dimension(0)); |
| for (int i = 0; i < num_slices; ++i) { |
| (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d); |
| } |
| } |
| |
| template <typename TL, typename TR> |
| inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices( |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start, |
| int num_rows, int col_start, int num_cols, |
| const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer, |
| std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) { |
| std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix( |
| mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer)); |
| const int num_slices = (num_cols + N - 1) / N; |
| SliceMatrix(*buffer, num_rows, num_slices, slices); |
| return shuffle_counter; |
| } |
| |
| template <typename TL, typename TR> |
| inline void SparseMatMul<TL, TR>::ComputeBlockSizes( |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left, |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, |
| bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB, |
| int* IB) { |
| // Heuristics for calculating block sizes |
| // Assume two hyperthreads per core. |
| const int est_num_cores = std::max(1, (num_threads + 1) / 2); |
| // Use block of rhs with at most 128K floats per core. |
| const int mem = est_num_cores * 128 * 1024; |
| *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256); |
| *NR = right.dimension(1); |
| if (*KR * *NR > mem) { |
| // 4096 may be enough to amortize the cost of writes. |
| *KR = std::min<int>(*KR, 4096); |
| } |
| // Use sizes that are multiples of K and 256. |
| *KR = std::max(1, *KR / K) * K; |
| *NR = std::max(1, *NR / 256) * 256; |
| if (*KR * *NR > mem) { |
| *NR = mem / *KR; |
| } |
| *NR = std::max(1, *NR / 256) * 256; |
| |
| const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0); |
| const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); |
| for (*KL = 1024; *KL > K; *KL /= 2) { |
| if (*KR % *KL == 0 && |
| std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) { |
| break; |
| } |
| } |
| DCHECK_EQ(*KL % K, 0); |
| DCHECK_GE(*KR, *KL); |
| if (*KR < right.dimension(0)) { |
| CHECK_EQ(*KR % *KL, 0); |
| } |
| |
| *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0)); |
| *IB = 8 * *JB; |
| DCHECK_EQ(N * sizeof(float) % 64, size_t{0}); |
| } |
| |
| #ifdef TENSORFLOW_USE_LIBXSMM |
| |
| template <typename F> |
| void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool, |
| const F& f) { |
| int num_threads = thread_pool->num_threads; |
| if (num_threads == 0) { |
| LOG(FATAL) << "Have 0 threads in thread pool"; |
| } else if (num_threads == 1) { |
| f(0); |
| } else { |
| BlockingCounter counter(num_threads - 1); |
| for (int i = 1; i < num_threads; ++i) { |
| thread_pool->workers->Schedule([&, i]() { |
| f(i); |
| counter.DecrementCount(); |
| }); |
| } |
| f(0); |
| counter.Wait(); |
| } |
| } |
| |
| template <typename T> |
| struct empty_type_wrapper {}; |
| |
| // Wrappers for the libxsmm_spmdm_createSparseSlice_*_notrans_thread |
| // interfaces, to allow overloading on the element type. |
| void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( |
| empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA, |
| const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, |
| int tid, int nthreads) { |
| return libxsmm_spmdm_createSparseSlice_fp32_thread( |
| handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads); |
| } |
| void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( |
| empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle, |
| char transA, const bfloat16* A, |
| libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid, |
| int nthreads) { |
| return libxsmm_spmdm_createSparseSlice_bfloat16_thread( |
| handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A), |
| libxsmm_output_csr_a, block_id, tid, nthreads); |
| } |
| |
| void wrapper_libxsmm_spmdm_compute_generic_thread( |
| empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle, |
| char transA, char transB, const bfloat16* alpha, |
| libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC, |
| const bfloat16* beta, float* C, int block_id, int tid, int nthreads) { |
| return libxsmm_spmdm_compute_bfloat16_thread( |
| handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha), |
| A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC, |
| reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid, |
| nthreads); |
| } |
| void wrapper_libxsmm_spmdm_compute_generic_thread( |
| empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA, |
| char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse, |
| const float* B, char transC, const float* beta, float* C, int block_id, |
| int tid, int nthreads) { |
| return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha, |
| A_sparse, B, transC, beta, C, |
| block_id, tid, nthreads); |
| } |
| |
| template <typename TL, typename TR> |
| inline void LibxsmmSparseMatMul<TL, TR>::Compute( |
| typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache, |
| const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left, |
| const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right, |
| bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, |
| bool transpose_output, MatrixMap* output) { |
| const int num_threads = thread_pool->num_threads; |
| const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0); |
| const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); |
| const int right_dim0 = right.dimension(0); |
| const int right_dim1 = right.dimension(1); |
| CHECK_EQ(left_dim1, right_dim0); |
| CHECK_EQ(left_dim0, |
| (transpose_output ? output->dimension(1) : output->dimension(0))); |
| CHECK_EQ(right_dim1, |
| (transpose_output ? output->dimension(0) : output->dimension(1))); |
| #if 0 // this issue seems to be resolved |
| if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) { |
| // Causes problems in libxsmm |
| SparseMatMul<TL, TR>::Compute( |
| nullptr /* Assumes no cached data for fallback */, left, right, |
| transpose_left, thread_pool, transpose_output, output); |
| return; |
| } |
| #endif |
| auto left_data = left.data(); |
| auto right_data = right.data(); |
| auto output_data = output->data(); |
| // Initialize libxsmm for this matrix; make sure another thread doesn't use |
| // this handle |
| auto entry = |
| cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads); |
| // Convert the left matrix to compressed sparse row (CSR) format |
| ptrdiff_t total_num_creation_blocks = |
| libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle); |
| std::atomic<int> cur_create_block_number; |
| cur_create_block_number.store(0); |
| do_on_all_threads(thread_pool, [&](int i) { |
| while (true) { |
| int work_item = cur_create_block_number.fetch_add(1); |
| if (work_item >= total_num_creation_blocks) break; |
| wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( |
| empty_type_wrapper<TL>{}, &entry->handle, |
| (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item, |
| i, num_threads); |
| } |
| }); |
| // Do matrix-matrix multiplication |
| ptrdiff_t total_num_mult_blocks = |
| libxsmm_spmdm_get_num_compute_blocks(&entry->handle); |
| std::atomic<int> cur_mult_block_number; |
| cur_mult_block_number.store(0); |
| do_on_all_threads(thread_pool, [&](int i) { |
| while (true) { |
| int work_item = cur_mult_block_number.fetch_add(1); |
| if (work_item >= total_num_mult_blocks) break; |
| const TL alpha(1.0); // Stored in a variable so we can get a pointer |
| const TL beta(0.0); // Stored in a variable so we can get a pointer |
| wrapper_libxsmm_spmdm_compute_generic_thread( |
| empty_type_wrapper<TL>{}, &entry->handle, |
| (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr, |
| right_data, (transpose_output ? 'T' : 'N'), &beta, output_data, |
| work_item, i, num_threads); |
| } |
| }); |
| // Put handle + CSR storage back into cache |
| cache->return_cache_entry(std::move(entry)); |
| } |
| |
| #endif // TENSORFLOW_USE_LIBXSMM |
| |
| // Here is an overview of the SparseMatMul code. Note that we assume that the |
| // left matrix is sparse. |
| // |
| // The matrix "left" is divided into a grid with block size (M, KL). Each |
| // block is encoded as a SparseSlice. These grid elements are stored as |
| // std::vector<std::vector<SparseSlice>>. Each element of the outer vector |
| // represents M rows of the left matrix. Let's call these elements l_i, and |
| // let's call each element of the inner vector L_mk. |
| // |
| // The matrix "right" is divided into a grid with block size KR * NR. Let's |
| // denote the blocks on the right as R_kn. Note that we ensure that KL divides |
| // KR so that for each element R_kn, we don't need to multiply it with any |
| // partial L_mk blocks. |
| // |
| // We then multiply each right side block R_kn with the full "left" matrix and |
| // update the output. These iterations are run sequentially since R_kn are |
| // packed into the same underlying temporary buffer. |
| // |
| // In each iteration we do the following: |
| // 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N |
| //    (=128) columns and then concatenate these slices into a buffer. This is |
| //    done so that each slice r_j of R_kn is stored contiguously in memory. |
| //    Note that if R_kn has dimensions (KR, NR), we create NR / N slices, and |
| //    the buffer has dimensions (KR * NR / N, N) (assuming N divides NR). |
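| //    For example (sizes purely illustrative), a block R_kn with KR = 512 and |
| //    NR = 256 yields NR / N = 2 slices and a buffer of dimensions (1024, 128). |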
| // 2. For each (l_i, r_j), we compute the inner product using the GEPP function |
| // and update the output block o_ij. These calls are further blocked to |
| // reduce the working set size. In each iteration we take IB elements from |
| // {l_i} and JB elements from {r_j} and compute the IB * JB inner products. |
| template <typename TL, typename TR> |
| inline void SparseMatMul<TL, TR>::Compute( |
| typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/, |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left, |
| const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, |
| bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, |
| bool transpose_output, MatrixMap* output) { |
| const int num_threads = thread_pool->num_threads; |
| int KR, NR, KL, JB, IB; |
| ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL, |
| &JB, &IB); |
| // Slice the left matrix |
| std::vector<std::vector<SparseSlice<TL>*>> left_slices; |
| std::unique_ptr<BlockingCounter> sparse_slice_counter = |
| CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()), |
| transpose_left, M, K, KL, &left_slices, thread_pool); |
| const int num_left_slices = left_slices.size(); |
| |
| const int right_dim0 = right.dimension(0); |
| const int right_dim1 = right.dimension(1); |
| // Allocate buffer for storing slices of right matrix. |
| // Note buffer needs enough space to hold at most a KR * NR matrix since that |
| // is the block size per iteration. |
| const int buffer_num_rows = |
| std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N); |
| MatrixR buffer(buffer_num_rows, N); |
| std::vector<ConstMatrixMapR*> right_slices; |
| |
| std::vector<SparseSlice<TL>*> block_left_slices; |
| std::vector<std::function<void(void)>> tasks; |
| // Number of blocks based on block sizes of KR * NR. |
| const int num_k_blocks = (right_dim0 + KR - 1) / KR; |
| const int num_n_blocks = (right_dim1 + NR - 1) / NR; |
| std::unique_ptr<BlockingCounter> dense_slice_counter; |
| |
| for (int nb = 0; nb < num_n_blocks; ++nb) { |
| const int right_num_cols = |
| std::min(NR, static_cast<int>(right_dim1 - NR * nb)); |
| for (int kb = 0; kb < num_k_blocks; ++kb) { |
| const int right_num_rows = |
| std::min(KR, static_cast<int>(right_dim0 - KR * kb)); |
| dense_slice_counter = CreateDenseSlices( |
| right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool, |
| &buffer, &right_slices); |
| const int num_right_slices = right_slices.size(); |
| tasks.reserve(num_left_slices * num_right_slices); |
| for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) { |
| for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) { |
| for (int j_inner = j_outer; |
| j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) { |
| const int num_cols = std::min(N, right_num_cols - N * j_inner); |
| for (int i_inner = i_outer; |
| i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) { |
| block_left_slices.clear(); |
| int begin = kb * KR / KL; |
| int end = std::min<int>((kb + 1) * KR / KL, |
| (right.dimension(0) + KL - 1) / KL); |
| DCHECK_LT(begin, end); |
| block_left_slices.insert(block_left_slices.begin(), |
| left_slices[i_inner].begin() + begin, |
| left_slices[i_inner].begin() + end); |
| tasks.push_back(std::bind( |
| &ComputeOutputBlock, block_left_slices, |
| std::ref(*right_slices[j_inner]), num_cols, M * i_inner, |
| N * j_inner + nb * NR, kb == 0, transpose_output, output)); |
| } |
| } |
| } |
| } |
| if (sparse_slice_counter) { |
| sparse_slice_counter->Wait(); |
| sparse_slice_counter.reset(nullptr); |
| } |
| if (dense_slice_counter) { |
| dense_slice_counter->Wait(); |
| dense_slice_counter.reset(nullptr); |
| } |
| BlockingCounter bc(tasks.size()); |
| for (const auto& t : tasks) { |
| thread_pool->workers->Schedule([&bc, &t]() { |
| t(); |
| bc.DecrementCount(); |
| }); |
| } |
| bc.Wait(); |
| tasks.clear(); |
| for (auto& temp : right_slices) { |
| delete temp; |
| } |
| right_slices.clear(); |
| } |
| } |
| for (auto& left_slice : left_slices) { |
| for (auto& temp : left_slice) { |
| delete temp; |
| } |
| left_slice.clear(); |
| } |
| } |
| |
| #define REGISTER_SPARSE_MATMUL(TA, TB) \ |
| REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<TA>("Ta") \ |
| .TypeConstraint<TB>("Tb"), \ |
| SparseMatMulOp<TA, TB, SparseMatMul>); |
| #ifdef TENSORFLOW_USE_LIBXSMM |
| #define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB) \ |
| REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<TA>("Ta") \ |
| .TypeConstraint<TB>("Tb"), \ |
| SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>); |
| #endif |
| |
| REGISTER_SPARSE_MATMUL(float, bfloat16); |
| REGISTER_SPARSE_MATMUL(bfloat16, float); |
| |
| #ifdef TENSORFLOW_USE_LIBXSMM |
| REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16); |
| REGISTER_SPARSE_MATMUL_LIBXSMM(float, float); |
| #else |
| REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); |
| REGISTER_SPARSE_MATMUL(float, float); |
| #endif |
| |
| #undef REGISTER_SPARSE_MATMUL |
| #ifdef TENSORFLOW_USE_LIBXSMM |
| #undef REGISTER_SPARSE_MATMUL_LIBXSMM |
| #endif |
| |
| } // end namespace tensorflow |