blob: ebc6d8fa4ec5422925e57c25856e0007702299b1 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/sparse_matmul_op.h"
#include <cstdint>
#include <cstring>
#include <iostream>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
// File-global random source used by Sparsify below. It is seeded with the
// fixed values (1, 1), so the generated sparsity patterns are reproducible as
// long as the sequence of calls into `rnd` is the same.
random::PhiloxRandom philox(1, 1);
// `rnd` wraps `philox` by pointer, so it must be declared after it.
random::SimplePhilox rnd(&philox);
// Brings Eigen's operator== into scope. Presumably this is needed so that
// `flat(i) == T(0)` in Sparsify resolves for bfloat16 elements (TODO confirm).
using Eigen::operator==;
// Makes roughly a `sparsity` fraction of the elements of *t zero.
//
// For each element, one draw from the file-global `rnd` decides whether it is
// zeroed. Elements that are kept but happen to be zero are replaced by 1, so
// the fraction of zeros tracks `sparsity`.
//
// `sparsity` must be <= 1 (CHECK-enforced). Exactly 1 zeroes the whole tensor
// without drawing any random numbers. Assuming rnd.Uniform(K) is uniform over
// [0, K), the zeroing probability is quantized to steps of 1/10000.
template <typename T>
void Sparsify(Tensor* t, float sparsity) {
  const int64 element_count = t->NumElements();
  CHECK_LE(sparsity, 1);
  auto values = t->flat<T>();
  if (sparsity == 1) {
    values.setZero();
    return;
  }
  static const uint32 K = 10000;
  for (int64 idx = 0; idx < element_count; ++idx) {
    const bool zero_it = rnd.Uniform(K) < sparsity * K;
    if (zero_it) {
      values(idx) = T(0);
      continue;
    }
    if (values(idx) == T(0)) values(idx) = T(1);
  }
}
// Adds a "SparseMatMul" node to `g` with inputs (in0, in1).
//
// The transpose_a / transpose_b attributes are set from the flags, and the
// a_is_sparse / b_is_sparse hint attributes come from a_sparse / b_sparse.
// Returns the created node. TF_CHECK_OK aborts if the node cannot be built.
Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a,
                       bool transpose_b, bool a_sparse, bool b_sparse) {
  NodeBuilder builder(g->NewName("n"), "SparseMatMul");
  builder.Input(in0)
      .Input(in1)
      .Attr("transpose_a", transpose_a)
      .Attr("transpose_b", transpose_b)
      .Attr("a_is_sparse", a_sparse)
      .Attr("b_is_sparse", b_sparse);
  Node* created = nullptr;
  TF_CHECK_OK(builder.Finalize(g, &created));
  return created;
}
// Appends one SparseMatMul to `g` whose logical product is
// (m x d `TA` matrix) * (d x n `TB` matrix), and returns `g`.
//
// Each operand is a random constant that is then sparsified to the requested
// fraction of zeros. An operand whose transpose flag is set is stored
// transposed (d x m for the left operand, n x d for the right). An operand is
// marked as sparse to the op exactly when its sparsity is positive.
template <typename TA, typename TB>
static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d,
                                 float sparsity_a, float sparsity_b,
                                 bool transpose_a, bool transpose_b) {
  // Left operand: logically m x d.
  Tensor left(DataTypeToEnum<TA>::value,
              transpose_a ? TensorShape({d, m}) : TensorShape({m, d}));
  left.flat<TA>().setRandom();
  Sparsify<TA>(&left, sparsity_a);

  // Right operand: logically d x n.
  Tensor right(DataTypeToEnum<TB>::value,
               transpose_b ? TensorShape({n, d}) : TensorShape({d, n}));
  right.flat<TB>().setRandom();
  Sparsify<TB>(&right, sparsity_b);

  SparseMatMulNode(g, test::graph::Constant(g, left),
                   test::graph::Constant(g, right), transpose_a, transpose_b,
                   /*a_sparse=*/sparsity_a > 0, /*b_sparse=*/sparsity_b > 0);
  return g;
}
// Creates a new graph holding a single SparseMatMul op built by
// SparseMatMulHelper with the same arguments, and returns it.
// The graph is heap-allocated. Presumably the caller's test::Benchmark takes
// ownership of it (TODO confirm).
template <typename TA, typename TB>
static Graph* SparseMatMul(int m, int n, int d, float sparsity_a,
                           float sparsity_b, bool transpose_a,
                           bool transpose_b) {
  return SparseMatMulHelper<TA, TB>(new Graph(OpRegistry::Global()), m, n, d,
                                    sparsity_a, sparsity_b, transpose_a,
                                    transpose_b);
}
// Creates a new graph holding `copies` independent float x float SparseMatMul
// ops. Each op multiplies its own random m x d matrix (zero fraction
// sparsity_1) by its own random d x n matrix (zero fraction sparsity_2).
// Neither operand is transposed. A non-positive `copies` yields an empty
// graph.
static Graph* ReplicatedSparseMatMul(int m, int n, int d, float sparsity_1,
                                     float sparsity_2, int copies) {
  Graph* const graph = new Graph(OpRegistry::Global());
  for (int remaining = copies; remaining > 0; --remaining) {
    SparseMatMulHelper<float, float>(graph, m, n, d, sparsity_1, sparsity_2,
                                     /*transpose_a=*/false,
                                     /*transpose_b=*/false);
  }
  return graph;
}
// Defines and registers a CPU benchmark named
// BM_Sparse_<M>_<K>_<N>_<S1>_<S2>_<TRA>_<TRB>_<TA>_<TB>.
// The benchmark runs one SparseMatMul of an M x K `TA` matrix by a K x N `TB`
// matrix.
//   S1, S2   - percentage (0-100) of zeros in the left / right operand.
//   TRA, TRB - whether the left / right operand is stored transposed.
// Items processed per iteration are 2 * M * K * N (presumably counting each
// multiply-add as two operations). Graph construction happens while timing is
// stopped, so only graph execution is measured.
#define BM_SPARSE(M, K, N, S1, S2, TRA, TRB, TA, TB) \
static void \
BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB( \
int iters) { \
testing::StopTiming(); \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
auto label = strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", \
TRA, TRB, S1 / 100.0, S2 / 100.0); \
testing::SetLabel(label); \
testing::UseRealTime(); \
auto g = SparseMatMul<TA, TB>(M, N, K, S1 / 100.0, S2 / 100.0, TRA, TRB); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK( \
BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB);
// Defines and registers a CPU benchmark named
// BM_Sparse_replicated_<M>_<K>_<N>_<S1>_<S2>_<Copies>.
// It runs `Copies` independent float SparseMatMul ops in one graph, built via
// ReplicatedSparseMatMul. Each op multiplies an M x K matrix (S1 percent
// zeros) by a K x N matrix (S2 percent zeros), with no transposes.
// Items processed per iteration are 2 * M * K * N * Copies.
#define BM_SPARSE_REPLICATED(M, K, N, S1, S2, Copies) \
static void BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies( \
int iters) { \
testing::StopTiming(); \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * Copies * \
2); \
auto label = strings::Printf("copies: %d sp_a: %0.2f sp_b: %0.2f", \
(Copies), S1 / 100.0, S2 / 100.0); \
testing::SetLabel(label); \
testing::UseRealTime(); \
auto g = \
ReplicatedSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, (Copies)); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK(BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies);
// Shorthands for BM_SPARSE that fix the operand element types. The suffix
// names the left operand's type, then the right operand's type; a single
// type name means both operands use it.
#define BM_SPARSE_FLOAT(M, K, N, S1, S2, TRA, TRB) \
BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, float)
#define BM_SPARSE_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, bfloat16)
#define BM_SPARSE_FLOAT_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, bfloat16)
#define BM_SPARSE_BFLOAT16_FLOAT(M, K, N, S1, S2, TRA, TRB) \
BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, float)
// Arguments: M, K, N, S1 (% zeros in the left operand a),
// S2 (% zeros in the right operand b), transpose_a, transpose_b.
// Test a sparse left operand (a) with a dense right operand (b).
// S1 = 0 is the fully dense baseline.
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 1, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 50, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 99, 0, false, false);
// Test a sparse right operand (b) with a dense left operand (a).
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 50, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, false);
// Test transposed operands: first with a sparse a, then with a sparse b.
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, true);
// Test smaller and non-square sizes, including replicated ops.
BM_SPARSE_FLOAT(1024, 1024, 1024, 0, 0, false, false);
BM_SPARSE_FLOAT(1024, 1024, 1024, 1, 0, false, false);
BM_SPARSE_FLOAT(1024, 1024, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(256, 256, 256, 1, 0, false, false);
BM_SPARSE_FLOAT(512, 512, 512, 1, 0, false, false);
BM_SPARSE_FLOAT(2560, 400, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2560, 400, 1024, 85, 0, true, false);
BM_SPARSE_FLOAT(400, 800, 2560, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 2560, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 1024, 256, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 256, 1, 85, 0, false, false);
BM_SPARSE_REPLICATED(400, 800, 2560, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 2560, 1024, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 1024, 256, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 256, 1, 85, 0, 6);
// The same shapes with M = 2048 and M = 2049. Presumably the odd row count
// exercises a size that is not a multiple of the op's internal blocking
// (TODO confirm).
BM_SPARSE_FLOAT(2048, 1792, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 1024, 768, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 768, 512, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 512, 256, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 1792, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 1024, 768, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 768, 512, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 512, 256, 85, 0, false, false);
BM_SPARSE_REPLICATED(2048, 1792, 1024, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 1024, 768, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 768, 512, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 512, 256, 85, 0, 6);
// Test bfloat16: both operands bfloat16, then mixed bfloat16/float operands.
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 0, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 1, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);
BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 99, 0, false, false);
BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);
// Creates a new graph holding `copies` pairs of float SparseMatMul ops.
// Within each pair:
//   1. (stored m x d, zero fraction sparsity_1)^T
//        * (m x n, zero fraction sparsity_2)                  -> d x n
//   2. (m x n, zero fraction sparsity_2)
//        * (stored d x n, dense)^T                            -> m x d
// This shape pattern resembles the two gradient products of an m x d by
// d x n matmul; presumably that is what it is meant to model (TODO confirm).
static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_1,
                                float sparsity_2, int copies) {
  Graph* const graph = new Graph(OpRegistry::Global());
  int built = 0;
  while (built < copies) {
    SparseMatMulHelper<float, float>(graph, /*m=*/d, /*n=*/n, /*d=*/m,
                                     sparsity_1, sparsity_2,
                                     /*transpose_a=*/true,
                                     /*transpose_b=*/false);
    SparseMatMulHelper<float, float>(graph, /*m=*/m, /*n=*/d, /*d=*/n,
                                     sparsity_2, /*sparsity_b=*/0,
                                     /*transpose_a=*/false,
                                     /*transpose_b=*/true);
    ++built;
  }
  return graph;
}
// Defines and registers a CPU benchmark named
// BM_Sparse_Multi_<M>_<K>_<N>_<S1>_<S2>_<Copies>.
// It runs MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, Copies): `Copies`
// pairs of SparseMatMul ops, each op of size M * K * N.
//   S1 - percentage (0-100) of zeros in the M x K operand.
//   S2 - percentage (0-100) of zeros in the M x N operand.
// Items processed per iteration are 2 * 2 * M * K * N * Copies (two ops per
// pair).
#define BM_SPARSE_MULTI(M, K, N, S1, S2, Copies) \
static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies( \
int iters) { \
testing::StopTiming(); \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2 * 2 * \
Copies); \
auto label = strings::Printf("%d_%d_%d_%d_%0.2f_%0.2f", M, K, N, Copies, \
S1 / 100.0, S2 / 100.0); \
testing::SetLabel(label); \
testing::UseRealTime(); \
auto g = MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, Copies); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies);
// Arguments: M, K, N, S1 (% zeros in the M x K operand),
// S2 (% zeros in the M x N operand), Copies (number of op pairs).
BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82, 1);
BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83, 1);
BM_SPARSE_MULTI(400, 800, 2560, 85, 85, 1);
BM_SPARSE_MULTI(400, 2560, 1024, 85, 85, 1);
BM_SPARSE_MULTI(400, 1024, 256, 85, 85, 1);
BM_SPARSE_MULTI(400, 256, 1, 85, 85, 1);
// The same shapes with 1, 3 and 6 op pairs per graph.
BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 1);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 1);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 1);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 1);
BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 3);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 3);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 3);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 3);
BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 6);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 6);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 6);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 6);
} // end namespace tensorflow
namespace Eigen {
namespace internal {
class SparseMatmulOpTest : public ::testing::Test {
protected:
SparseMatmulOpTest()
: PacketSize(Eigen::internal::packet_traits<float>::size) {
typedef typename NumTraits<float>::Real RealFloat;
for (int i = 0; i < kMaxPacketSize; ++i) {
data1[i] = internal::random<float>() / RealFloat(PacketSize);
data2[i] = internal::random<float>() / RealFloat(PacketSize);
data3[i] = internal::random<float>() / RealFloat(PacketSize);
}
for (int i = kMaxPacketSize; i < kMaxPacketSize * 2; ++i) {
data3[i] = internal::random<float>() / RealFloat(PacketSize);
}
// zero out lower 16-bits of mantissa of data3 values
// copy bfloat representation to data3_bfloat16
for (int i = 0; i < kMaxPacketSize * 2; ++i) {
uint16_t* data3_p = reinterpret_cast<uint16_t*>(&data3[i]);
uint16_t* data3_bfloat16_p =
reinterpret_cast<uint16_t*>(data3_bfloat16) + i;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
data3_p[1] = 0;
data3_bfloat16_p[0] = data3_p[0];
#else
data3_p[0] = 0;
data3_bfloat16_p[0] = data3_p[1];
#endif
}
}
bool areApprox(const float* a, const float* b, int size) {
for (int i = 0; i < size; ++i) {
if (a[i] != b[i] && !internal::isApprox(a[i], b[i])) {
auto ma = Map<const Matrix<float, 1, Dynamic> >(a, size);
auto mb = Map<const Matrix<float, 1, Dynamic> >(b, size);
std::cout << "[" << ma << "]"
<< " != [" << mb << "], differences: [" << (mb - ma) << "]\n";
return false;
}
}
return true;
}
#ifdef EIGEN_VECTORIZE_AVX512
static const int kMaxPacketSize = 16;
#elif defined EIGEN_VECTORIZE_AVX || defined EIGEN_VECTORIZE_AVX2
static const int kMaxPacketSize = 8;
#else
static const int kMaxPacketSize = 4;
#endif
typedef typename Eigen::internal::packet_traits<float>::type Packet;
const int PacketSize;
// float values
EIGEN_ALIGN_MAX float data1[kMaxPacketSize];
// output of intrinsics
EIGEN_ALIGN_MAX float data2[kMaxPacketSize];
// float values with only 7 mantissa bits (bfloat representable)
EIGEN_ALIGN_MAX float data3[kMaxPacketSize * 2];
// bfloat16 representation of data3
EIGEN_ALIGN_MAX float data3_bfloat16[kMaxPacketSize];
EIGEN_ALIGN_MAX float ref[kMaxPacketSize];
};
TEST_F(SparseMatmulOpTest, BroadcastPacketTest) {
  // pbroadcast_first/second/third/fourth should copy lane 0/1/2/3 of the
  // loaded packet into every output lane. Lanes that do not exist at the
  // current packet width are not tested.
  const Packet input = internal::ploadu<Packet>(data1);

  for (int lane = 0; lane < PacketSize; ++lane) ref[lane] = data1[0];
  internal::pstoreu(data2, internal::pbroadcast_first<Packet>(input));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
  if (PacketSize <= 1) return;

  for (int lane = 0; lane < PacketSize; ++lane) ref[lane] = data1[1];
  internal::pstoreu(data2, internal::pbroadcast_second<Packet>(input));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
  if (PacketSize <= 2) return;

  for (int lane = 0; lane < PacketSize; ++lane) ref[lane] = data1[2];
  internal::pstoreu(data2, internal::pbroadcast_third<Packet>(input));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
  if (PacketSize <= 3) return;

  for (int lane = 0; lane < PacketSize; ++lane) ref[lane] = data1[3];
  internal::pstoreu(data2, internal::pbroadcast_fourth<Packet>(input));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
}
TEST_F(SparseMatmulOpTest, InterleavePacketTest) {
  // For 8-wide (AVX) packets, pinterleave4x64 is expected to swap the two
  // middle quarters: [q0 q1 q2 q3] -> [q0 q2 q1 q3]. Every other width is
  // expected to pass through unchanged.
  // NOTE(review): the pass-through expectation also covers 16-wide (AVX512)
  // packets; confirm that matches pinterleave4x64 for that width.
  for (int i = 0; i < PacketSize; ++i) {
    int src = i;
    if (PacketSize == 8) {
      const int quarter = PacketSize / 4;
      const int which_quarter = i / quarter;
      if (which_quarter == 1) {
        src = i + quarter;
      } else if (which_quarter == 2) {
        src = i - quarter;
      }
    }
    ref[i] = data1[src];
  }
  internal::pstoreu(data2, internal::pinterleave4x64<Packet>(
                               internal::ploadu<Packet>(data1)));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
}
TEST_F(SparseMatmulOpTest, Bfloat16ExpandTest) {
  // Fills `ref` with the floats that pexpand_bf16_l (upper == false) or
  // pexpand_bf16_u (upper == true) should produce from the packed bfloat16
  // values in data3_bfloat16.
  // - For 8-wide (AVX) packets, each half of the output comes from a
  //   different 128-bit half of the input.
  // - For any other width, the low/high expansion is the first/second run of
  //   PacketSize values.
  // NOTE(review): 16-wide (AVX512) packets also take the second path; confirm.
  auto fill_expected = [&](bool upper) {
    if (PacketSize == 8) {
      const int half = PacketSize / 2;
      const int offset = upper ? half : 0;
      for (int i = 0; i < half; ++i) {
        ref[i] = data3[i + offset];
        ref[i + half] = data3[i + PacketSize + offset];
      }
    } else {
      const int offset = upper ? PacketSize : 0;
      for (int i = 0; i < PacketSize; ++i) ref[i] = data3[i + offset];
    }
  };
  const Packet packed = internal::ploadu<Packet>(data3_bfloat16);

  fill_expected(/*upper=*/false);
  internal::pstoreu(data2, internal::pexpand_bf16_l<Packet>(packed));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));

  fill_expected(/*upper=*/true);
  internal::pstoreu(data2, internal::pexpand_bf16_u<Packet>(packed));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
}
TEST_F(SparseMatmulOpTest, Bfloat16LoadTest) {
  // Needs at least four lanes to check the four loaded values.
  if (PacketSize < 4) return;
  for (int i = 0; i < 4; ++i) ref[i] = data3[i];
  // pload4bf16: the first four output lanes should match data3[0..3].
  internal::pstoreu(data2, internal::pload4bf16<Packet>(data3_bfloat16));
  ASSERT_TRUE(areApprox(ref, data2, 4));
  // pload2bf16: only the first two output lanes are checked.
  internal::pstoreu(data2, internal::pload2bf16<Packet>(data3_bfloat16));
  ASSERT_TRUE(areApprox(ref, data2, 2));
}
} // namespace internal
} // namespace Eigen