blob: 0c04a82818f486421b6b4220a3f572a18a1cdb67 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
// Adds a "BroadcastTo" node to `g` with `input` and `shape` as its two inputs
// and returns the newly created node. Node construction is wrapped in
// TF_CHECK_OK, so a failure to build the node is fatal.
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
  // Keep the generated name in a local so it outlives the builder.
  const auto node_name = g->NewName("n");
  NodeBuilder builder(node_name, "BroadcastTo");
  builder.Input(input).Input(shape);

  Node* created = nullptr;
  TF_CHECK_OK(builder.Finalize(g, &created));
  return created;
}
// Adds a "BatchMatMulV2" node to `g` that takes `in0` and `in1` as inputs,
// with the op's "adj_x" / "adj_y" attributes set from `adj_x` / `adj_y`.
// Returns the newly created node. Node construction is wrapped in
// TF_CHECK_OK, so a failure to build the node is fatal.
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
  // Keep the generated name in a local so it outlives the builder.
  const auto node_name = g->NewName("n");
  NodeBuilder builder(node_name, "BatchMatMulV2");
  builder.Input(in0).Input(in1);
  builder.Attr("adj_x", adj_x).Attr("adj_y", adj_y);

  Node* created = nullptr;
  TF_CHECK_OK(builder.Finalize(g, &created));
  return created;
}
// Builds a graph holding a single test::graph::BatchMatmul node fed by two
// randomly initialised constant tensors of dtype `type` (element type `T`).
//
// The left operand has shape [b, m, k], or [b, k, m] when `adjoint_a` is set;
// the right operand has shape [b, k, n], or [b, n, k] when `adjoint_b` is set.
// Returns a newly allocated Graph; this function does not free it.
template <typename T>
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
                          bool adjoint_b, DataType type) {
  // An operand flagged as adjoint is stored with its last two dims swapped.
  const TensorShape lhs_shape =
      adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k});
  const TensorShape rhs_shape =
      adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n});

  Tensor lhs(type, lhs_shape);
  lhs.flat<T>().setRandom();
  Tensor rhs(type, rhs_shape);
  rhs.flat<T>().setRandom();

  Graph* graph = new Graph(OpRegistry::Global());
  Node* lhs_node = test::graph::Constant(graph, lhs);
  Node* rhs_node = test::graph::Constant(graph, rhs);
  test::graph::BatchMatmul(graph, lhs_node, rhs_node, adjoint_a, adjoint_b);
  return graph;
}
// Builds a graph that multiplies a random [b0, m, k] tensor by a random
// [b1, k, n] tensor (dtype `type`, element type `T`) with a BatchMatMulV2
// node whose adj_x / adj_y attributes are both false.
//
// When `manual_broadcast` is true, each operand is first passed through an
// explicit BroadcastTo node whose target shape keeps the operand's matrix
// dimensions and sets the batch dimension to max(b0, b1). Otherwise the
// unbroadcast constants are fed to BatchMatMulV2 directly.
//
// Returns a newly allocated Graph; this function does not free it.
template <typename T>
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
                                       bool manual_broadcast, DataType type) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor in0(type, TensorShape({b0, m, k}));
  in0.flat<T>().setRandom();
  Tensor in1(type, TensorShape({b1, k, n}));
  in1.flat<T>().setRandom();

  Node* in0_node = nullptr;
  Node* in1_node = nullptr;
  if (manual_broadcast) {
    // The shape tensors are only needed (and only built) on this path.
    const int64 out_batch = std::max(b0, b1);
    // Wraps `operand` in a BroadcastTo node that expands its batch dimension
    // to `out_batch` while leaving the two matrix dimensions unchanged.
    auto broadcast_batch_dim = [g, out_batch](const Tensor& operand) -> Node* {
      Tensor target_shape(DT_INT64, TensorShape({3}));
      auto dims = target_shape.vec<int64>();
      dims(0) = out_batch;
      dims(1) = operand.shape().dim_size(1);
      dims(2) = operand.shape().dim_size(2);
      return BroadcastTo(g, test::graph::Constant(g, operand),
                         test::graph::Constant(g, target_shape));
    };
    in0_node = broadcast_batch_dim(in0);
    in1_node = broadcast_batch_dim(in1);
  } else {
    in0_node = test::graph::Constant(g, in0);
    in1_node = test::graph::Constant(g, in1);
  }

  BatchMatmulV2(g, in0_node, in1_node, false, false);
  return g;
}
// Defines and registers one BatchMatmul benchmark function.
//   B, M, K, N: batch size and matrix dimensions forwarded to BatchMatmul<T>.
//   TA, TB:     adjoint flags for the left / right operand.
//   T:          C++ scalar type used to fill the input tensors.
//   TFTYPE:     TensorFlow DataType of the input tensors.
//   DEVICE:     device token (e.g. cpu, gpu); stringified for test::Benchmark.
// Every argument except T is pasted into the generated function name, so each
// must be a token that is valid inside an identifier.
// The reported item count per iteration is 2 * B * M * K * N.
// The expansion intentionally does not end with a semicolon; the invocation
// supplies it.
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
  static void \
      BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
          int iters) { \
    testing::UseRealTime(); \
    testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
    test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
        .Run(iters); \
  } \
  BENCHMARK( \
      BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE)
// Registers the BatchMatmul benchmark for one problem shape. Only the
// float / cpu variant is currently enabled.
// The expansion intentionally does not end with a semicolon; the invocation
// supplies it.
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
  BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu)
// Disabled variants. To enable any of them, add the invocation(s) to the body
// of BM_BatchMatmul above, separating invocations with semicolons and joining
// the lines with backslash continuations.
//   BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
//                     cpu)
//   BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu)
//   BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
//                     gpu)
// Variants for double & complex<double> types:
//   BM_BatchMatmulDev(B, M, K, N, TA, TB, double, DT_DOUBLE, cpu)
//   BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128,
//                     cpu)
//   BM_BatchMatmulDev(B, M, K, N, TA, TB, double, DT_DOUBLE, gpu)
//   BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128,
//                     gpu)
// Defines and registers one broadcasting BatchMatMulV2 benchmark function.
// Macro arguments:
//   B1: batch size of LHS
//   B2: batch size of RHS
//   M:  outer dimension of LHS
//   K:  inner dimension shared by LHS and RHS
//   N:  outer dimension of RHS
//   MB: boolean indicating whether to use manual broadcasting
//   T:  C++ type of scalars (e.g. float, std::complex<float>)
//   TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
//   D:  device (e.g. cpu, gpu)
// Every argument except T is pasted into the generated function name, so each
// must be a token that is valid inside an identifier.
// The reported item count per iteration is 2 * max(B1, B2) * M * K * N.
// The expansion intentionally does not end with a semicolon; the invocation
// supplies it.
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
  static void \
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
          int iters) { \
    testing::UseRealTime(); \
    testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
                            K * N * 2); \
    test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
        .Run(iters); \
  } \
  BENCHMARK( \
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D)
// Registers the broadcasting BatchMatMulV2 benchmark for one problem shape,
// using float scalars on the cpu device.
// The expansion intentionally does not end with a semicolon; the invocation
// supplies it.
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
  BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu)
// Broadcasting benchmark instances. Arguments are (B1, B2, M, K, N, MB):
// LHS batch size, RHS batch size, LHS outer dim, shared inner dim, RHS outer
// dim, and the manual-broadcast flag. Each shape is registered twice: once
// with MB = true (operands go through explicit BroadcastTo nodes) and once
// with MB = false (operands are passed to BatchMatMulV2 as-is). Each shape is
// also run with the batch of 128 on either the RHS or the LHS.
// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
// Matrix-vector multiplies (N == 1).
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
// Vector-matrix multiplies (M == 1).
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
// Non-broadcasting benchmark instances. Arguments are (B, M, K, N, TA, TB):
// batch size, the three matmul dimensions, and the adjoint flags for the
// left / right operand.
// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
// Square matmul.
BM_BatchMatmul(1, 32, 32, 32, false, false);
BM_BatchMatmul(1, 128, 128, 128, false, false);
BM_BatchMatmul(1, 256, 256, 256, false, false);
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
BM_BatchMatmul(2, 32, 32, 32, false, false);
BM_BatchMatmul(2, 128, 128, 128, false, false);
BM_BatchMatmul(2, 256, 256, 256, false, false);
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
BM_BatchMatmul(4, 32, 32, 32, false, false);
BM_BatchMatmul(4, 128, 128, 128, false, false);
BM_BatchMatmul(4, 256, 256, 256, false, false);
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
BM_BatchMatmul(8, 32, 32, 32, false, false);
BM_BatchMatmul(8, 128, 128, 128, false, false);
BM_BatchMatmul(8, 256, 256, 256, false, false);
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
BM_BatchMatmul(32, 32, 32, 32, false, false);
BM_BatchMatmul(32, 128, 128, 128, false, false);
BM_BatchMatmul(32, 256, 256, 256, false, false);
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
// Matrix-vector multiplies (N == 1), covering all four adjoint combinations.
BM_BatchMatmul(1, 10000, 200, 1, false, false);
BM_BatchMatmul(8, 10000, 200, 1, false, false);
BM_BatchMatmul(32, 10000, 200, 1, false, false);
BM_BatchMatmul(1, 10000, 200, 1, true, false);
BM_BatchMatmul(8, 10000, 200, 1, true, false);
BM_BatchMatmul(32, 10000, 200, 1, true, false);
BM_BatchMatmul(1, 10000, 200, 1, false, true);
BM_BatchMatmul(8, 10000, 200, 1, false, true);
BM_BatchMatmul(32, 10000, 200, 1, false, true);
BM_BatchMatmul(1, 10000, 200, 1, true, true);
BM_BatchMatmul(8, 10000, 200, 1, true, true);
BM_BatchMatmul(32, 10000, 200, 1, true, true);
// Vector-matrix multiplies (M == 1), covering all four adjoint combinations.
BM_BatchMatmul(1, 1, 200, 10000, false, false);
BM_BatchMatmul(8, 1, 200, 10000, false, false);
BM_BatchMatmul(32, 1, 200, 10000, false, false);
BM_BatchMatmul(1, 1, 200, 10000, true, false);
BM_BatchMatmul(8, 1, 200, 10000, true, false);
BM_BatchMatmul(32, 1, 200, 10000, true, false);
BM_BatchMatmul(1, 1, 200, 10000, false, true);
BM_BatchMatmul(8, 1, 200, 10000, false, true);
BM_BatchMatmul(32, 1, 200, 10000, false, true);
BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);
} // namespace
} // namespace tensorflow