blob: 63133fc5d4ca95dcc4d6fd989f16fb8baf08160b [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y,
const bool fewer_dims_optimization = true) {
tensorflow::BCast b(x, y, fewer_dims_optimization);
if (!b.IsValid()) {
return "invalid";
}
string ret;
strings::StrAppend(&ret, "[", absl::StrJoin(b.x_reshape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.x_bcast(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.y_reshape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.y_bcast(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.result_shape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.output_shape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_x_reduce_idx(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_y_reduce_idx(), ","), "]");
return ret;
}
string BCastBatchIndices(const tensorflow::BCast::Vec& x,
const tensorflow::BCast::Vec& y,
const bool fewer_dims_optimization = true) {
tensorflow::BCast b(x, y, fewer_dims_optimization,
/*return_flattened_batch_indices=*/true);
string ret;
strings::StrAppend(&ret, "[", absl::StrJoin(b.x_batch_indices(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.y_batch_indices(), ","), "]");
return ret;
}
string BCastList3(const tensorflow::BCast::Vec& x,
const tensorflow::BCast::Vec& y,
const tensorflow::BCast::Vec& z,
const bool fewer_dims_optimization = true) {
tensorflow::BCastList<3> b({x, y, z}, fewer_dims_optimization);
if (!b.IsValid()) {
return "invalid";
}
string ret;
strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(0), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(0), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(1), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(1), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(2), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(2), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.result_shape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.output_shape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(0), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(1), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(2), ","), "]");
return ret;
}
TEST(BCastTest, Invalid) {
for (const bool use_optimization : {true, false}) {
EXPECT_EQ("invalid", BCast({5, 3, 2}, {3}, use_optimization));
EXPECT_EQ("invalid", BCast({5, 3, 2}, {2, 2}, use_optimization));
EXPECT_EQ("invalid", BCast({5, 3, 2}, {10, 1, 1}, use_optimization));
EXPECT_EQ("invalid",
BCast({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1}, use_optimization));
}
}
TEST(BCastListTest, Invalid) {
for (const bool use_optimization : {true, false}) {
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {3}, {1}, use_optimization));
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {2, 2}, {1}, use_optimization));
EXPECT_EQ("invalid",
BCastList3({5, 3, 2}, {10, 1, 1}, {1}, use_optimization));
EXPECT_EQ("invalid", BCastList3({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1}, {1},
use_optimization));
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {1}, {3}, use_optimization));
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {1}, {2, 2}, use_optimization));
EXPECT_EQ("invalid",
BCastList3({5, 3, 2}, {1}, {10, 1, 1}, use_optimization));
EXPECT_EQ("invalid", BCastList3({1}, {5, 3, 2}, {3}, use_optimization));
EXPECT_EQ("invalid", BCastList3({1}, {5, 3, 2}, {2, 2}, use_optimization));
EXPECT_EQ("invalid",
BCastList3({1}, {5, 3, 2}, {10, 1, 1}, use_optimization));
}
}
TEST(BCastTest, Basic_SameShape) {
// Effectively no broadcast needed.
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}),
"[2310][1][2310][1]"
"[2310]"
"[11,7,5,3,2]"
"[][]");
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, false),
"[11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][]");
}
TEST(BCastListTest, Basic_SameShape) {
// Effectively no broadcast needed.
EXPECT_EQ(BCastList3({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}),
"[2310][1][2310][1][2310][1]"
"[2310]"
"[11,7,5,3,2]"
"[][][]");
EXPECT_EQ(
BCastList3({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, false),
"[11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][][]");
}
TEST(BCastTest, Basic_SameShapeWithZeroDim) {
// Effectively no broadcast needed.
EXPECT_EQ(BCast({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}),
"[0][1][0][1]"
"[0]"
"[11,7,0,3,2]"
"[][]");
EXPECT_EQ(BCast({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, false),
"[11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1]"
"[11,7,0,3,2]"
"[11,7,0,3,2]"
"[][]");
}
TEST(BCastListTest, Basic_SameShapeWithZeroDim) {
// Effectively no broadcast needed.
EXPECT_EQ(BCastList3({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}),
"[0][1][0][1][0][1]"
"[0]"
"[11,7,0,3,2]"
"[][][]");
EXPECT_EQ(
BCastList3({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, false),
"[11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1]"
"[11,7,0,3,2]"
"[11,7,0,3,2]"
"[][][]");
}
TEST(BCastTest, Basic_Scalar_Scalar) {
// Effectively it's a scalar and a scalar.
// [1, 1] [1]
//
EXPECT_EQ(BCast({1, 1}, {}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({1, 1}, {1}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({1, 1}, {1}, false),
"[1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1]");
// [1] [1, 1]
EXPECT_EQ(BCast({1}, {1, 1}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({1}, {1, 1}, false),
"[1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1]");
}
TEST(BCastTest, Basic_TrueScalar_Scalar) {
// [] []
EXPECT_EQ(BCast({}, {}),
"[1][1][1][1]"
"[1]"
"[]"
"[][]");
// [] [1]
EXPECT_EQ(BCast({}, {1}),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
EXPECT_EQ(BCast({}, {1}, false),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
// [] [1, 1]
EXPECT_EQ(BCast({}, {1, 1}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({}, {1, 1}, false),
"[1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1]");
// [1] []
EXPECT_EQ(BCast({1}, {}),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
EXPECT_EQ(BCast({1}, {}, false),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
// [1, 1] []
EXPECT_EQ(BCast({1, 1}, {}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({1, 1}, {}, false),
"[1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1]");
}
TEST(BCastListTest, Basic_Scalar_Scalar_Scalar) {
// Effectively it's a scalar and a scalar.
// [1, 1] [1] [1]
EXPECT_EQ(BCastList3({1, 1}, {1}, {1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1, 1}, {1}, {1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [1] [1, 1] [1]
EXPECT_EQ(BCastList3({1}, {1, 1}, {1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1}, {1, 1}, {1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [1] [1] [1, 1]
EXPECT_EQ(BCastList3({1}, {1}, {1, 1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1}, {1}, {1, 1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
}
TEST(BCastListTest, Basic_TrueScalar_Scalar_Scalar) {
// Effectively it's a scalar and a scalar.
// [1, 1] [1] []
EXPECT_EQ(BCastList3({1, 1}, {1}, {}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1, 1}, {1}, {}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [] [1, 1] [1]
EXPECT_EQ(BCastList3({}, {1, 1}, {1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({}, {1, 1}, {1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [1] [] [1, 1]
EXPECT_EQ(BCastList3({1}, {}, {1, 1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1}, {}, {1, 1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
}
TEST(BCastTest, Basic_Tensor_Scalar) {
// Effectively it's a tensor and a scalar.
// [11, 7, 5, 3, 2] [1]
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {1}),
"[2310][1][1][2310]"
"[2310]"
"[11,7,5,3,2]"
"[][0,1,2,3,4]");
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {1}, false),
"[11,7,5,3,2][1,1,1,1,1][1,1,1,1,1][11,7,5,3,2]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][0,1,2,3,4]");
// [1] [11, 7, 5, 3, 2]
EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}),
"[1][2310][2310][1]"
"[2310]"
"[11,7,5,3,2]"
"[0,1,2,3,4][]");
EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}, false),
"[1,1,1,1,1][11,7,5,3,2][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[0,1,2,3,4][]");
}
TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) {
// Effectively it's a tensor and a scalar.
// [11, 7, 5, 3, 2, 1] [1]
EXPECT_EQ(BCast({11, 7, 5, 3, 2, 1}, {1}),
"[2310][1][1][2310]"
"[2310]"
"[11,7,5,3,2,1]"
"[5][0,1,2,3,4,5]");
EXPECT_EQ(BCast({11, 7, 5, 3, 2, 1}, {1}, false),
"[11,7,5,3,2,1][1,1,1,1,1,1][1,1,1,1,1,1][11,7,5,3,2,1]"
"[11,7,5,3,2,1]"
"[11,7,5,3,2,1]"
"[5][0,1,2,3,4,5]");
// [1] [11, 7, 5, 3, 2, 1]
EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}),
"[1][2310][2310][1]"
"[2310]"
"[11,7,5,3,2,1]"
"[0,1,2,3,4,5][5]");
EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}, false),
"[1,1,1,1,1,1][11,7,5,3,2,1][11,7,5,3,2,1][1,1,1,1,1,1]"
"[11,7,5,3,2,1]"
"[11,7,5,3,2,1]"
"[0,1,2,3,4,5][5]");
// Effectively it's a tensor and a scalar.
// [11, 7, 5, 1, 1, 3, 2, 1] [1]
EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}),
"[2310][1][1][2310]"
"[2310]"
"[11,7,5,1,1,3,2,1,1]"
"[3,4,7,8][0,1,2,3,4,5,6,7,8]");
EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}, false),
"[11,7,5,1,1,3,2,1,1][1,1,1,1,1,1,1,1,1]" // x_reshape(), x_bcast()
"[1,1,1,1,1,1,1,1,1][11,7,5,1,1,3,2,1,1]" // y_reshape(), y_bcast()
"[11,7,5,1,1,3,2,1,1]"
"[11,7,5,1,1,3,2,1,1]"
"[3,4,7,8][0,1,2,3,4,5,6,7,8]");
// [1] [11, 7, 5, 1, 1, 3, 2, 1]
EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}),
"[1][2310][2310][1]"
"[2310]"
"[11,7,5,1,1,3,2,1,1]"
"[0,1,2,3,4,5,6,7,8][3,4,7,8]");
EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}, false),
"[1,1,1,1,1,1,1,1,1][11,7,5,1,1,3,2,1,1]" // x_reshape(), x_bcast()
"[11,7,5,1,1,3,2,1,1][1,1,1,1,1,1,1,1,1]" // y_reshape(), y_bcast()
"[11,7,5,1,1,3,2,1,1]"
"[11,7,5,1,1,3,2,1,1]"
"[0,1,2,3,4,5,6,7,8][3,4,7,8]");
}
TEST(BCastTest, Basic_Tensor_Vector) {
// [11, 7, 5, 3, 2] [2]
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {2}),
"[1155,2][1,1][1,2][1155,1]"
"[1155,2]"
"[11,7,5,3,2]"
"[][0,1,2,3]");
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {2}, false),
"[11,7,5,3,2][1,1,1,1,1][1,1,1,1,2][11,7,5,3,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][0,1,2,3]");
// [2] [11, 7, 5, 3, 2]
EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}),
"[1,2][1155,1][1155,2][1,1]"
"[1155,2]"
"[11,7,5,3,2]"
"[0,1,2,3][]");
EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}, false),
"[1,1,1,1,2][11,7,5,3,1][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[0,1,2,3][]");
}
TEST(BCastTest, Basic_Tensor_Matrix) {
// [11, 7, 5, 3, 2] [3, 2]
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 2}),
"[385,6][1,1][1,6][385,1]"
"[385,6]"
"[11,7,5,3,2]"
"[][0,1,2]");
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 2}, false),
"[11,7,5,3,2][1,1,1,1,1][1,1,1,3,2][11,7,5,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][0,1,2]");
// [3, 2] [11, 7, 5, 3, 2]
EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}),
"[1,6][385,1][385,6][1,1]"
"[385,6]"
"[11,7,5,3,2]"
"[0,1,2][]");
EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}, false),
"[1,1,1,3,2][11,7,5,1,1][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[0,1,2][]");
}
TEST(BCastTest, Basic_Tensor_Matrix_Column) {
// [11, 7, 5, 3, 2] [3, 1]
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 1}),
"[385,3,2][1,1,1][1,3,1][385,1,2]"
"[385,3,2]"
"[11,7,5,3,2]"
"[][0,1,2,4]");
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 1}, false),
"[11,7,5,3,2][1,1,1,1,1][1,1,1,3,1][11,7,5,1,2]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][0,1,2,4]");
// [3, 1] [11, 7, 5, 3, 2]
EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}),
"[1,3,1][385,1,2][385,3,2][1,1,1]"
"[385,3,2]"
"[11,7,5,3,2]"
"[0,1,2,4][]");
EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}, false),
"[1,1,1,3,1][11,7,5,1,2][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[0,1,2,4][]");
}
TEST(BCastTest, Basic_Tensor_Matrix_As_Tensor) {
// [11, 7, 5, 3, 2] [7, 5, 1, 1]
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {7, 5, 1, 1}),
"[11,35,6][1,1,1][1,35,1][11,1,6]"
"[11,35,6]"
"[11,7,5,3,2]"
"[][0,3,4]");
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {7, 5, 1, 1}, false),
"[11,7,5,3,2][1,1,1,1,1][1,7,5,1,1][11,1,1,3,2]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][0,3,4]");
// [7, 5, 1, 1] [11, 7, 5, 3, 2]
EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}),
"[1,35,1][11,1,6][11,35,6][1,1,1]"
"[11,35,6]"
"[11,7,5,3,2]"
"[0,3,4][]");
EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}, false),
"[1,7,5,1,1][11,1,1,3,2][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2][11,7,5,3,2]"
"[0,3,4][]");
}
TEST(BCastTest, Complex_BCast_To_Each_Other) {
// Rare cases. x and y broadcast to each other. x and y are of
// different ranks.
// Can be verified in numpy as:
// import numpy as np
// x = np.arange(0,110).reshape([11,1,5,1,2])
// y = np.arange(0,21).reshape([7,1,3,1])
// np.shape(x + y)
// Out[.]: (11, 7, 5, 3, 2)
string truth =
"[11,1,5,1,2][1,7,1,3,1][1,7,1,3,1][11,1,5,1,2]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[1,3][0,2,4]";
EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}), truth);
EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}, false), truth);
}
TEST(BCastListTest, Complex_BCast_To_Each_Other) {
// Rare cases. x, y and z broadcast to each other. x,y and z are of
// different ranks.
// Can be verified in numpy as:
// import numpy as np
// x = np.arange(0,22).reshape([11,1,1,1,2])
// y = np.arange(0,21).reshape([7,1,3,1])
// z = np.arange(0,5).reshape([5,1,1])
// np.shape(x + y + z)
// Out[.]: (11, 7, 5, 3, 2)
//
string truth =
"[11,1,1,1,2][1,7,5,3,1]"
"[1,7,1,3,1][11,1,5,1,2]"
"[1,1,5,1,1][11,7,1,3,2]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[1,2,3][0,2,4][0,1,3,4]";
EXPECT_EQ(BCastList3({11, 1, 1, 1, 2}, {7, 1, 3, 1}, {5, 1, 1}), truth);
EXPECT_EQ(BCastList3({11, 1, 1, 1, 2}, {7, 1, 3, 1}, {5, 1, 1}, false),
truth);
}
TEST(BCastTest, TestZeroDimensionShape) {
// (2,0,5) and (5) in both orders
EXPECT_EQ(BCast({2, 0, 5}, {5}),
"[0,5][1,1][1,5][0,1]"
"[0,5]"
"[2,0,5]"
"[][0,1]");
EXPECT_EQ(BCast({5}, {2, 0, 5}),
"[1,5][0,1][0,5][1,1]"
"[0,5]"
"[2,0,5]"
"[0,1][]");
EXPECT_EQ(BCast({2, 0, 5}, {5}, false),
"[2,0,5][1,1,1][1,1,5][2,0,1]"
"[2,0,5]"
"[2,0,5]"
"[][0,1]");
EXPECT_EQ(BCast({5}, {2, 0, 5}, false),
"[1,1,5][2,0,1][2,0,5][1,1,1]"
"[2,0,5]"
"[2,0,5]"
"[0,1][]");
// (2,0,3,0,5) and (5) in both orders
EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}),
"[0,5][1,1][1,5][0,1]"
"[0,5]"
"[2,0,3,0,5]"
"[][0,1,2,3]");
EXPECT_EQ(BCast({5}, {2, 0, 3, 0, 5}),
"[1,5][0,1][0,5][1,1]"
"[0,5]"
"[2,0,3,0,5]"
"[0,1,2,3][]");
EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}, false),
"[2,0,3,0,5][1,1,1,1,1][1,1,1,1,5][2,0,3,0,1]"
"[2,0,3,0,5]"
"[2,0,3,0,5]"
"[][0,1,2,3]");
EXPECT_EQ(BCast({5}, {2, 0, 3, 0, 5}, false),
"[1,1,1,1,5][2,0,3,0,1][2,0,3,0,5][1,1,1,1,1]"
"[2,0,3,0,5]"
"[2,0,3,0,5]"
"[0,1,2,3][]");
// (2,0,3,0,5) and (3,1,5) in both orders
EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}),
"[0,3,0,5][1,1,1,1][1,3,1,5][0,1,0,1]"
"[0,3,0,5]"
"[2,0,3,0,5]"
"[][0,1,3]");
EXPECT_EQ(BCast({3, 1, 5}, {2, 0, 3, 0, 5}),
"[1,3,1,5][0,1,0,1][0,3,0,5][1,1,1,1]"
"[0,3,0,5]"
"[2,0,3,0,5]"
"[0,1,3][]");
EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}, false),
"[2,0,3,0,5][1,1,1,1,1][1,1,3,1,5][2,0,1,0,1]"
"[2,0,3,0,5]"
"[2,0,3,0,5]"
"[][0,1,3]");
EXPECT_EQ(BCast({3, 1, 5}, {2, 0, 3, 0, 5}, false),
"[1,1,3,1,5][2,0,1,0,1][2,0,3,0,5][1,1,1,1,1]"
"[2,0,3,0,5]"
"[2,0,3,0,5]"
"[0,1,3][]");
}
TEST(BCastTest, BatchIndices) {
EXPECT_EQ("[0,0,0,0][0,1,2,3]", BCastBatchIndices({1}, {4}));
// Invalid broadcast.
EXPECT_EQ("[][]", BCastBatchIndices({5}, {7}));
// Same shape, no batch indices.
EXPECT_EQ("[][]", BCastBatchIndices({2, 4, 6}, {2, 4, 6}));
// More complicated broadcasts.
EXPECT_EQ("[0,0,0,0,1,1,1,1,2,2,2,2][0,1,2,3,0,1,2,3,0,1,2,3]",
BCastBatchIndices({3, 1}, {1, 4}));
EXPECT_EQ("[0,0,1,1,2,2,0,0,1,1,2,2][0,1,0,1,0,1,2,3,2,3,2,3]",
BCastBatchIndices({3, 1}, {2, 1, 2}));
}
void BM_BCastSetup(::testing::benchmark::State& state) {
const int same_shape = state.range(0);
if (same_shape) {
state.SetLabel("same_shapes");
for (auto s : state) {
class BCast b({1000, 100}, {1000, 100});
}
} else {
state.SetLabel("different_shapes");
for (auto s : state) {
class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
}
}
}
BENCHMARK(BM_BCastSetup)->Arg(0)->Arg(1);
} // namespace
} // namespace tensorflow