blob: 2b2bf09814543325c52e7476ef74e4bf678a305c [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
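// Tests of 1D convolution: parameterized all-ones convolutions over a range of
// shapes, plus small hand-computed cases covering valid padding, explicit
// padding, and input (LHS) / filter (RHS) dilation. All cases use stride 1.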
#include <memory>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
class ConvolutionTest : public ClientLibraryTestBase {
 protected:
  // Tolerance used by every comparison in this file. Only the absolute bound
  // depends on the backend; the relative bound is 1e-3 everywhere.
  ErrorSpec error_spec_ = ErrorSpec(
#if XLA_TEST_BACKEND_GPU
      // XLA:GPU sometimes uses FFT convolution, which is less precise than
      // spatial convolution, so the absolute bound is relaxed there.
      1e-2,
#else
      1e-4,
#endif
      1e-3);
};
// Element types exercised by the typed tests in this file. Eigen::half is
// dropped on backends that define XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16.
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
using TestTypes = ::testing::Types<float, Eigen::half>;
#else
using TestTypes = ::testing::Types<float>;
#endif
// Shape parameters for the all-ones 1D convolution tests. Instantiations use
// aggregate initialization, so the field order below is significant:
//   Convolve1DTestParam{input_feature, output_feature, batch, window_size,
//                       num_windows}
struct Convolve1DTestParam {
int64 input_feature;   // Input feature count (input dim 2, filter dim 1).
int64 output_feature;  // Output feature count (filter dim 2, output dim 2).
int64 batch;           // Batch size (input and output dim 0).
int64 window_size;     // Filter spatial extent (filter dim 0).
int64 num_windows;     // Output spatial size; the input spatial size is
                       // window_size + num_windows - 1.
};
class Convolve1D1WindowTestBase
    : public ConvolutionTest,
      public ::testing::WithParamInterface<Convolve1DTestParam> {
 protected:
  // Convolves an all-ones input with an all-ones filter (stride 1, valid
  // padding) and checks that every output element equals
  // window_size * input_feature, i.e. the number of ones under one window.
  //
  // Shapes, using TensorFlow-style 1D dimension numbers:
  //   input:  [batch, window_size + num_windows - 1, input_feature]
  //   filter: [window_size, input_feature, output_feature]
  //   output: [batch, num_windows, output_feature]
  template <typename T>
  void TestImpl() {
    const Convolve1DTestParam& param = GetParam();
    const std::vector<int64> input_dims = {
        param.batch, param.window_size + param.num_windows - 1,
        param.input_feature};
    const std::vector<int64> filter_dims = {
        param.window_size, param.input_feature, param.output_feature};
    const std::vector<int64> output_dims = {param.batch, param.num_windows,
                                            param.output_feature};

    XlaBuilder builder(TestName());
    {
      auto input = Parameter(&builder, 0,
                             ShapeUtil::MakeShapeWithType<T>(input_dims),
                             "input");
      auto filter = Parameter(&builder, 1,
                              ShapeUtil::MakeShapeWithType<T>(filter_dims),
                              "filter");

      // Input and output are [batch, spatial, feature]; the kernel is
      // [spatial, input feature, output feature].
      ConvolutionDimensionNumbers dnums;
      dnums.set_input_batch_dimension(0);
      dnums.add_input_spatial_dimensions(1);
      dnums.set_input_feature_dimension(2);
      dnums.add_kernel_spatial_dimensions(0);
      dnums.set_kernel_input_feature_dimension(1);
      dnums.set_kernel_output_feature_dimension(2);
      dnums.set_output_batch_dimension(0);
      dnums.add_output_spatial_dimensions(1);
      dnums.set_output_feature_dimension(2);
      ConvWithGeneralDimensions(input, filter, /*window_strides=*/{1},
                                Padding::kValid, dnums);
    }

    // Returns a literal of shape `dims` whose elements all equal `value`.
    auto make_filled = [](const std::vector<int64>& dims, T value) {
      int64 num_elements = 1;
      for (int64 dim : dims) {
        num_elements *= dim;
      }
      const std::vector<T> elements(num_elements, value);
      return LiteralUtil::CreateR1<T>(elements)
          .Reshape(dims)
          .ConsumeValueOrDie();
    };

    const T one = static_cast<T>(1.0f);
    auto input_r3 = make_filled(input_dims, one);
    auto filter_r3 = make_filled(filter_dims, one);
    auto expected_r3 = make_filled(
        output_dims, static_cast<T>(param.window_size * param.input_feature));

    auto input_data = client_->TransferToServer(input_r3).ConsumeValueOrDie();
    auto filter_data =
        client_->TransferToServer(filter_r3).ConsumeValueOrDie();
    ComputeAndCompareLiteral(&builder, expected_r3,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
  }
};
// float instantiation of the all-ones 1D convolution test. Each entry is
// {input_feature, output_feature, batch, window_size, num_windows}; every
// entry is distinct so no configuration runs twice.
class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
INSTANTIATE_TEST_CASE_P(
    Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
    ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
                      Convolve1DTestParam{160, 1, 1, 5, 1},
                      Convolve1DTestParam{24, 1, 1, 20, 1},
                      Convolve1DTestParam{30, 1, 1, 20, 1},
                      Convolve1DTestParam{23, 1, 1, 20, 20},
                      Convolve1DTestParam{25, 1, 1, 20, 1},
                      Convolve1DTestParam{24, 1, 1, 10, 5},
                      Convolve1DTestParam{160, 1, 1, 10, 1},
                      Convolve1DTestParam{255, 1, 1, 3, 1},
                      Convolve1DTestParam{130, 1, 1, 1, 2},
                      Convolve1DTestParam{136, 1, 1, 1, 2},
                      Convolve1DTestParam{64, 1, 1, 1, 1},
                      Convolve1DTestParam{128, 1, 1, 1, 1},
                      Convolve1DTestParam{139, 1, 1, 128, 1},
                      Convolve1DTestParam{1, 10, 10, 1, 10},
                      Convolve1DTestParam{1, 10, 130, 1, 2},
                      Convolve1DTestParam{1, 10, 130, 1, 1},
                      Convolve1DTestParam{1, 64, 64, 1, 10},
                      Convolve1DTestParam{1, 65, 65, 1, 1},
                      Convolve1DTestParam{1, 128, 128, 1, 1},
                      Convolve1DTestParam{128, 128, 128, 128, 1},
                      Convolve1DTestParam{2, 2, 2, 2, 1},
                      Convolve1DTestParam{161, 1, 1, 10, 1},
                      Convolve1DTestParam{900, 1, 1, 10, 1},
                      Convolve1DTestParam{640, 3, 3, 128, 1}));
#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
// Eigen::half instantiation of the all-ones 1D convolution test, built only
// for the GPU and CPU backends. Each entry is
// {input_feature, output_feature, batch, window_size, num_windows}; every
// entry is distinct so no configuration runs twice.
//
// NOTE(review): {640, 3, 3, 128, 1} expects 640 * 128 = 81920, which is above
// the largest finite Eigen::half (65504), so its expected literal is
// infinite. Confirm this case is intended for half.
class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
  TestImpl<Eigen::half>();
}
INSTANTIATE_TEST_CASE_P(
    Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
    ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
                      Convolve1DTestParam{160, 1, 1, 5, 1},
                      Convolve1DTestParam{24, 1, 1, 20, 1},
                      Convolve1DTestParam{30, 1, 1, 20, 1},
                      Convolve1DTestParam{23, 1, 1, 20, 20},
                      Convolve1DTestParam{25, 1, 1, 20, 1},
                      Convolve1DTestParam{24, 1, 1, 10, 5},
                      Convolve1DTestParam{160, 1, 1, 10, 1},
                      Convolve1DTestParam{255, 1, 1, 3, 1},
                      Convolve1DTestParam{130, 1, 1, 1, 3},
                      Convolve1DTestParam{64, 1, 1, 1, 1},
                      Convolve1DTestParam{128, 1, 1, 1, 1},
                      Convolve1DTestParam{139, 1, 1, 128, 1},
                      Convolve1DTestParam{640, 3, 3, 128, 1},
                      Convolve1DTestParam{900, 1, 1, 10, 1},
                      Convolve1DTestParam{1, 10, 10, 1, 10},
                      Convolve1DTestParam{1, 10, 130, 1, 1},
                      Convolve1DTestParam{1, 10, 130, 1, 2},
                      Convolve1DTestParam{1, 64, 64, 1, 10},
                      Convolve1DTestParam{1, 65, 65, 1, 1},
                      Convolve1DTestParam{1, 128, 128, 1, 1},
                      Convolve1DTestParam{128, 128, 128, 128, 1},
                      Convolve1DTestParam{2, 2, 2, 2, 1},
                      Convolve1DTestParam{161, 1, 1, 10, 1}));
#endif
// Valid-padding 1D convolution of a [1, 2, 5] input with a [1, 2, 2] filter.
// Each output sums input[f][x + k] * filter[0][f][k] over the two features f
// and the two taps k, e.g. out[0] = 1*10 + 2*20 + 6*30 + 7*40 = 510.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
  XlaBuilder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
  const Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  auto input_param = Parameter(&builder, 0, input_shape, "input");
  auto filter_param = Parameter(&builder, 1, filter_shape, "filter");
  Conv(input_param, filter_param, /*window_strides=*/{1}, Padding::kValid);

  const Array3D<float> input_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> filter_values({{{10, 20}, {30, 40}}});
  const Array3D<float> expected_values({{{510, 610, 710, 810}}});

  auto input_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input_values))
          .ConsumeValueOrDie();
  auto filter_data =
      client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter_values))
          .ConsumeValueOrDie();
  ComputeAndCompareR3<float>(&builder, expected_values,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
}
// 1D convolution with a filter (RHS) dilation of 2, run for every element type
// in TestTypes. Dilation places the two filter taps at offsets 0 and 2, so the
// effective window is 3 wide and the width-5 input yields 3 outputs, e.g.
// out[0] = 1*10 + 3*20 + 6*30 + 8*40 = 570.
template <typename T>
class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
 public:
  void RunTest() {
    XlaBuilder builder(TestName());
    {
      Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
      Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
      auto input = Parameter(&builder, 0, input_shape, "input");
      auto filter = Parameter(&builder, 1, filter_shape, "filter");
      // Convolution dimensions are bf0_oi0->bo0.
      ConvGeneralDilated(
          input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
          /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
          /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
    }
    Array3D<T> input(
        {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
    Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
    Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
    auto input_literal =
        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
            .ConsumeValueOrDie();
    auto filter_literal =
        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
            .ConsumeValueOrDie();
    ComputeAndCompareR3<T>(&builder, expected,
                           {input_literal.get(), filter_literal.get()},
                           error_spec_);
  }
};
// Registers the RHS-dilation fixture as a typed test over TestTypes; each
// instantiation calls RunTest() for one element type.
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
// Input (LHS) dilation of 2 inserts a zero between neighbouring input
// elements, widening each row from 5 to 9, so the 2-tap filter yields 8
// outputs. Even outputs pair real inputs with taps 10/30 (e.g.
// 1*10 + 6*30 = 190); odd outputs pair them with taps 20/40 (e.g.
// 2*20 + 7*40 = 320).
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
  XlaBuilder builder(TestName());
  {
    auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}),
                         "input");
    auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}),
                         "filter");
    // Default 1D layout: bf0_oi0->bo0.
    const ConvolutionDimensionNumbers dnums =
        builder.CreateDefaultConvDimensionNumbers(1);
    ConvGeneralDilated(lhs, rhs, /*window_strides=*/{1},
                       /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2},
                       /*rhs_dilation=*/{1}, /*dimension_numbers=*/dnums);
  }

  const Array3D<float> lhs_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> rhs_values({{{10, 20}, {30, 40}}});
  const Array3D<float> want({{{190, 320, 230, 380, 270, 440, 310, 500}}});

  auto lhs_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(lhs_values))
          .ConsumeValueOrDie();
  auto rhs_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(rhs_values))
          .ConsumeValueOrDie();
  ComputeAndCompareR3<float>(&builder, want, {lhs_data.get(), rhs_data.get()},
                             error_spec_);
}
// Both dilations set to 2. The dilated input is 9 wide and the dilated filter
// spans 3, giving 7 outputs. At even positions both taps land on real input
// elements and reproduce the valid convolution (510, 610, 710, 810); at odd
// positions both taps land on inserted zeros, so the output is 0.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
  XlaBuilder builder(TestName());
  {
    auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}),
                         "input");
    auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}),
                         "filter");
    // Default 1D layout: bf0_oi0->bo0.
    const ConvolutionDimensionNumbers dnums =
        builder.CreateDefaultConvDimensionNumbers(1);
    ConvGeneralDilated(lhs, rhs, /*window_strides=*/{1},
                       /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2},
                       /*rhs_dilation=*/{2}, /*dimension_numbers=*/dnums);
  }

  const Array3D<float> lhs_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> rhs_values({{{10, 20}, {30, 40}}});
  const Array3D<float> want({{{510, 0, 610, 0, 710, 0, 810}}});

  auto lhs_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(lhs_values))
          .ConsumeValueOrDie();
  auto rhs_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(rhs_values))
          .ConsumeValueOrDie();
  ComputeAndCompareR3<float>(&builder, want, {lhs_data.get(), rhs_data.get()},
                             error_spec_);
}
template <typename T>
class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
public:
void RunTest() {
XlaBuilder builder(TestName());
{
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
auto input = Parameter(&builder, 0, input_shape, "input");
auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Convolution dimensions are bf0_oi0->bo0.
ConvGeneralDilated(
input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
}
Array3D<T> input(
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
Array3D<T> expected(
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
};
// Registers the padding fixture as a typed test over TestTypes; each
// instantiation calls RunTest() for one element type.
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
} // namespace
} // namespace xla