Split convolution tests into two.
PiperOrigin-RevId: 321634608
Change-Id: I1dfd1c5ab7010af10962ec021cc66f8fe9c6ce6e
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 83851fa..b3353cf 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1115,10 +1115,24 @@
name = "convolution_test",
timeout = "long",
srcs = ["convolution_test.cc"],
- shard_count = 40,
+ shard_count = 50,
tags = [
"no_rocm",
- "nozapfhahn",
+ "optonly",
+ ],
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_test(
+ name = "convolution_test_1d",
+ timeout = "long",
+ srcs = ["convolution_test_1d.cc"],
+ shard_count = 50,
+ tags = [
+ "no_rocm",
"optonly",
],
deps = CONVOLUTION_TEST_DEPS + [
@@ -1148,6 +1162,23 @@
)
xla_test(
+ name = "convolution_test_1d_autotune_disabled",
+ timeout = "long",
+ srcs = ["convolution_test_1d.cc"],
+ args = ["--xla_gpu_autotune_level=0"],
+ backends = ["gpu"],
+ shard_count = 40,
+ tags = [
+ "no_rocm",
+ "optonly",
+ ],
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_test(
name = "convolution_test_gpu_alternative_layout",
timeout = "long",
srcs = ["convolution_test.cc"],
@@ -1164,6 +1195,22 @@
)
xla_test(
+ name = "convolution_test_1d_gpu_alternative_layout",
+ timeout = "long",
+ srcs = ["convolution_test_1d.cc"],
+ backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
+ backends = ["gpu"],
+ shard_count = 25,
+ tags = [
+ "no_rocm",
+ ],
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_test(
name = "convolution_variants_test",
timeout = "long",
srcs = ["convolution_variants_test.cc"],
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index c63f1d0..8021d6f 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-// Tests of convolution with trivial kernels and no special variations (like
+// Tests of 2+D convolution with trivial kernels and no special variations (like
// strides and padding).
#include <memory>
@@ -240,174 +240,6 @@
TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
- XlaBuilder builder(TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = Parameter(&builder, 0, input_shape, "input");
- auto filter = Parameter(&builder, 1, filter_shape, "filter");
- Conv(input, filter, {1}, Padding::kValid);
- }
-
- Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
- Array3D<float> filter({{{10, 20}, {30, 40}}});
-
- Array3D<float> expected({{{510, 610, 710, 810}}});
-
- auto input_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR3<float>(&builder, expected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
-}
-
-template <typename T>
-class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
- public:
- void RunTest() {
- XlaBuilder builder(TestName());
- {
- Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
- Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
- auto input = Parameter(&builder, 0, input_shape, "input");
- auto filter = Parameter(&builder, 1, filter_shape, "filter");
- // Convolution dimensions are bf0_oi0->bo0.
- ConvGeneralDilated(
- input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
- /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
- /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
- }
-
- Array3D<T> input(
- {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
- Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
-
- Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
-
- auto input_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR3<T>(&builder, expected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
- }
-}; // namespace
-
-TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
-TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
-
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
- XlaBuilder builder(TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = Parameter(&builder, 0, input_shape, "input");
- auto filter = Parameter(&builder, 1, filter_shape, "filter");
- // Convolution dimensions are bf0_oi0->bo0.
- ConvGeneralDilated(
- input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
- /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
- /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
- }
-
- Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
- Array3D<float> filter({{{10, 20}, {30, 40}}});
-
- Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
-
- auto input_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR3<float>(&builder, expected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
-}
-
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
- XlaBuilder builder(TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = Parameter(&builder, 0, input_shape, "input");
- auto filter = Parameter(&builder, 1, filter_shape, "filter");
- // Convolution dimensions are bf0_oi0->bo0.
- ConvGeneralDilated(
- input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
- /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
- /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
- }
-
- Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
- Array3D<float> filter({{{10, 20}, {30, 40}}});
-
- Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
-
- auto input_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR3<float>(&builder, expected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
-}
-
-template <typename T>
-class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
- public:
- void RunTest() {
- XlaBuilder builder(TestName());
- {
- Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
- Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
- auto input = Parameter(&builder, 0, input_shape, "input");
- auto filter = Parameter(&builder, 1, filter_shape, "filter");
- // Convolution dimensions are bf0_oi0->bo0.
- ConvGeneralDilated(
- input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
- /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
- /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
- }
-
- Array3D<T> input(
- {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
- Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
-
- Array3D<T> expected(
- {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
-
- auto input_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR3<T>(&builder, expected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
- }
-};
-
-TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
-TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
-
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
XlaBuilder builder(TestName());
std::vector<int64> input_dims = {1, 4, 2, 3, 3};
@@ -1714,150 +1546,7 @@
ConvolveWithAndWithoutCanonicalization,
::testing::Values(true, false));
-struct Convolve1DTestParam {
- int64 input_feature;
- int64 output_feature;
- int64 batch;
- int64 window_size;
- int64 num_windows;
-};
-class Convolve1D1WindowTestBase
- : public ConvolutionTest,
- public ::testing::WithParamInterface<Convolve1DTestParam> {
- protected:
- template <typename T>
- void TestImpl() {
- XlaBuilder builder(TestName());
- int64 input_feature = GetParam().input_feature;
- int64 output_feature = GetParam().output_feature;
- int64 batch = GetParam().batch;
- int64 num_windows = GetParam().num_windows;
- int64 window_size = GetParam().window_size;
- std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
- input_feature};
- std::vector<int64> filter_dims = {window_size, input_feature,
- output_feature};
- Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
- Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
- {
- auto input = Parameter(&builder, 0, input_shape, "input");
- auto filter = Parameter(&builder, 1, filter_shape, "filter");
-
- // Tensorflow dimension numbers for 1D convolution.
- ConvolutionDimensionNumbers dnums;
- dnums.set_input_batch_dimension(0);
- dnums.set_output_batch_dimension(0);
- dnums.add_input_spatial_dimensions(1);
- dnums.add_output_spatial_dimensions(1);
- dnums.set_input_feature_dimension(2);
- dnums.set_output_feature_dimension(2);
- dnums.add_kernel_spatial_dimensions(0);
- dnums.set_kernel_input_feature_dimension(1);
- dnums.set_kernel_output_feature_dimension(2);
-
- ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
- }
-
- std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
- static_cast<T>(1.0f));
- auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
-
- std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
- static_cast<T>(1.0f));
-
- auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
-
- std::vector<T> expect_elems(batch * output_feature * num_windows,
- static_cast<T>(window_size * input_feature));
- auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
- auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
- .ConsumeValueOrDie();
-
- auto input_literal =
- client_->TransferToServer(input_r3).ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(filter_r3).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, expected_r3,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
- }
-};
-
-class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
-
-XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
-
-INSTANTIATE_TEST_CASE_P(
- Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
- ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
- Convolve1DTestParam{160, 1, 1, 5, 1},
- Convolve1DTestParam{24, 1, 1, 20, 1},
- Convolve1DTestParam{30, 1, 1, 20, 1},
- Convolve1DTestParam{23, 1, 1, 20, 20},
- Convolve1DTestParam{25, 1, 1, 20, 1},
- Convolve1DTestParam{24, 1, 1, 10, 5},
- Convolve1DTestParam{160, 1, 1, 10, 1},
- Convolve1DTestParam{255, 1, 1, 3, 1},
- Convolve1DTestParam{130, 1, 1, 1, 2},
- Convolve1DTestParam{136, 1, 1, 1, 2},
- Convolve1DTestParam{64, 1, 1, 1, 1},
- Convolve1DTestParam{128, 1, 1, 1, 1},
- Convolve1DTestParam{139, 1, 1, 128, 1},
- Convolve1DTestParam{1, 10, 10, 1, 10},
- Convolve1DTestParam{1, 10, 130, 1, 2},
- Convolve1DTestParam{1, 10, 130, 1, 1},
- Convolve1DTestParam{1, 64, 64, 1, 10},
- Convolve1DTestParam{1, 65, 65, 1, 1},
- Convolve1DTestParam{1, 128, 128, 1, 1},
- Convolve1DTestParam{128, 128, 128, 128, 1},
- Convolve1DTestParam{1, 128, 128, 1, 1},
- Convolve1DTestParam{2, 2, 2, 2, 1},
- Convolve1DTestParam{161, 1, 1, 10, 1},
- Convolve1DTestParam{900, 1, 1, 10, 1},
- Convolve1DTestParam{640, 3, 3, 128, 1})
-
-);
-
-#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
-class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
-
-XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
- TestImpl<Eigen::half>();
-}
-
-INSTANTIATE_TEST_CASE_P(
- Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
- ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
- Convolve1DTestParam{160, 1, 1, 5, 1},
- Convolve1DTestParam{24, 1, 1, 20, 1},
- Convolve1DTestParam{30, 1, 1, 20, 1},
- Convolve1DTestParam{23, 1, 1, 20, 20},
- Convolve1DTestParam{25, 1, 1, 20, 1},
- Convolve1DTestParam{24, 1, 1, 10, 5},
- Convolve1DTestParam{160, 1, 1, 10, 1},
- Convolve1DTestParam{255, 1, 1, 3, 1},
- Convolve1DTestParam{130, 1, 1, 1, 3},
- Convolve1DTestParam{64, 1, 1, 1, 1},
- Convolve1DTestParam{128, 1, 1, 1, 1},
- Convolve1DTestParam{139, 1, 1, 128, 1},
- Convolve1DTestParam{640, 3, 3, 128, 1},
- Convolve1DTestParam{900, 1, 1, 10, 1},
- Convolve1DTestParam{1, 10, 10, 1, 10},
- Convolve1DTestParam{1, 10, 130, 1, 1},
- Convolve1DTestParam{1, 10, 130, 1, 2},
- Convolve1DTestParam{1, 64, 64, 1, 10},
- Convolve1DTestParam{1, 65, 65, 1, 1},
- Convolve1DTestParam{1, 128, 128, 1, 1},
- Convolve1DTestParam{128, 128, 128, 128, 1},
- Convolve1DTestParam{1, 128, 128, 1, 1},
- Convolve1DTestParam{2, 2, 2, 2, 1},
- Convolve1DTestParam{161, 1, 1, 10, 1})
-
-);
-#endif
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/convolution_test_1d.cc b/tensorflow/compiler/xla/tests/convolution_test_1d.cc
new file mode 100644
index 0000000..2b2bf09
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convolution_test_1d.cc
@@ -0,0 +1,376 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests of 1D convolution with trivial kernels and no special variations (like
+// strides and padding).
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvolutionTest : public ClientLibraryTestBase {
+ protected:
+#if XLA_TEST_BACKEND_GPU
+ // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
+ // convolution. So relax the absolute error threshold.
+ ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-3);
+#else
+ ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-3);
+#endif
+};
+
+#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
+using TestTypes = ::testing::Types<float>;
+#else
+using TestTypes = ::testing::Types<float, Eigen::half>;
+#endif
+
+struct Convolve1DTestParam {
+ int64 input_feature;
+ int64 output_feature;
+ int64 batch;
+ int64 window_size;
+ int64 num_windows;
+};
+
+class Convolve1D1WindowTestBase
+ : public ConvolutionTest,
+ public ::testing::WithParamInterface<Convolve1DTestParam> {
+ protected:
+ template <typename T>
+ void TestImpl() {
+ XlaBuilder builder(TestName());
+ int64 input_feature = GetParam().input_feature;
+ int64 output_feature = GetParam().output_feature;
+ int64 batch = GetParam().batch;
+ int64 num_windows = GetParam().num_windows;
+ int64 window_size = GetParam().window_size;
+ std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
+ input_feature};
+ std::vector<int64> filter_dims = {window_size, input_feature,
+ output_feature};
+ Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
+ {
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+
+ // Tensorflow dimension numbers for 1D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.set_input_feature_dimension(2);
+ dnums.set_output_feature_dimension(2);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.set_kernel_input_feature_dimension(1);
+ dnums.set_kernel_output_feature_dimension(2);
+
+ ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
+ }
+
+ std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
+ static_cast<T>(1.0f));
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
+ auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
+
+ std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
+ static_cast<T>(1.0f));
+
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
+ auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
+
+ std::vector<T> expect_elems(batch * output_feature * num_windows,
+ static_cast<T>(window_size * input_feature));
+ auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
+ auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
+ .ConsumeValueOrDie();
+
+ auto input_literal =
+ client_->TransferToServer(input_r3).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(filter_r3).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, expected_r3,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+ }
+};
+
+class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
+
+XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
+
+INSTANTIATE_TEST_CASE_P(
+ Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
+ ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
+ Convolve1DTestParam{160, 1, 1, 5, 1},
+ Convolve1DTestParam{24, 1, 1, 20, 1},
+ Convolve1DTestParam{30, 1, 1, 20, 1},
+ Convolve1DTestParam{23, 1, 1, 20, 20},
+ Convolve1DTestParam{25, 1, 1, 20, 1},
+ Convolve1DTestParam{24, 1, 1, 10, 5},
+ Convolve1DTestParam{160, 1, 1, 10, 1},
+ Convolve1DTestParam{255, 1, 1, 3, 1},
+ Convolve1DTestParam{130, 1, 1, 1, 2},
+ Convolve1DTestParam{136, 1, 1, 1, 2},
+ Convolve1DTestParam{64, 1, 1, 1, 1},
+ Convolve1DTestParam{128, 1, 1, 1, 1},
+ Convolve1DTestParam{139, 1, 1, 128, 1},
+ Convolve1DTestParam{1, 10, 10, 1, 10},
+ Convolve1DTestParam{1, 10, 130, 1, 2},
+ Convolve1DTestParam{1, 10, 130, 1, 1},
+ Convolve1DTestParam{1, 64, 64, 1, 10},
+ Convolve1DTestParam{1, 65, 65, 1, 1},
+ Convolve1DTestParam{1, 128, 128, 1, 1},
+ Convolve1DTestParam{128, 128, 128, 128, 1},
+ Convolve1DTestParam{1, 128, 128, 1, 1},
+ Convolve1DTestParam{2, 2, 2, 2, 1},
+ Convolve1DTestParam{161, 1, 1, 10, 1},
+ Convolve1DTestParam{900, 1, 1, 10, 1},
+ Convolve1DTestParam{640, 3, 3, 128, 1})
+
+);
+
+#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
+class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
+
+XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
+ TestImpl<Eigen::half>();
+}
+
+INSTANTIATE_TEST_CASE_P(
+ Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
+ ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
+ Convolve1DTestParam{160, 1, 1, 5, 1},
+ Convolve1DTestParam{24, 1, 1, 20, 1},
+ Convolve1DTestParam{30, 1, 1, 20, 1},
+ Convolve1DTestParam{23, 1, 1, 20, 20},
+ Convolve1DTestParam{25, 1, 1, 20, 1},
+ Convolve1DTestParam{24, 1, 1, 10, 5},
+ Convolve1DTestParam{160, 1, 1, 10, 1},
+ Convolve1DTestParam{255, 1, 1, 3, 1},
+ Convolve1DTestParam{130, 1, 1, 1, 3},
+ Convolve1DTestParam{64, 1, 1, 1, 1},
+ Convolve1DTestParam{128, 1, 1, 1, 1},
+ Convolve1DTestParam{139, 1, 1, 128, 1},
+ Convolve1DTestParam{640, 3, 3, 128, 1},
+ Convolve1DTestParam{900, 1, 1, 10, 1},
+ Convolve1DTestParam{1, 10, 10, 1, 10},
+ Convolve1DTestParam{1, 10, 130, 1, 1},
+ Convolve1DTestParam{1, 10, 130, 1, 2},
+ Convolve1DTestParam{1, 64, 64, 1, 10},
+ Convolve1DTestParam{1, 65, 65, 1, 1},
+ Convolve1DTestParam{1, 128, 128, 1, 1},
+ Convolve1DTestParam{128, 128, 128, 128, 1},
+ Convolve1DTestParam{1, 128, 128, 1, 1},
+ Convolve1DTestParam{2, 2, 2, 2, 1},
+ Convolve1DTestParam{161, 1, 1, 10, 1})
+
+);
+#endif
+
+XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
+ XlaBuilder builder(TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1}, Padding::kValid);
+ }
+
+ Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
+ Array3D<float> filter({{{10, 20}, {30, 40}}});
+
+ Array3D<float> expected({{{510, 610, 710, 810}}});
+
+ auto input_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR3<float>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+template <typename T>
+class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
+ public:
+ void RunTest() {
+ XlaBuilder builder(TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ // Convolution dimensions are bf0_oi0->bo0.
+ ConvGeneralDilated(
+ input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
+ /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
+ /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+ }
+
+ Array3D<T> input(
+ {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
+ Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
+
+ Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
+
+ auto input_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR3<T>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+ }
+}; // namespace
+
+TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
+TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
+
+XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
+ XlaBuilder builder(TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ // Convolution dimensions are bf0_oi0->bo0.
+ ConvGeneralDilated(
+ input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
+ /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
+ /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+ }
+
+ Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
+ Array3D<float> filter({{{10, 20}, {30, 40}}});
+
+ Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
+
+ auto input_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR3<float>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
+ XlaBuilder builder(TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ // Convolution dimensions are bf0_oi0->bo0.
+ ConvGeneralDilated(
+ input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
+ /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
+ /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+ }
+
+ Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
+ Array3D<float> filter({{{10, 20}, {30, 40}}});
+
+ Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
+
+ auto input_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR3<float>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+template <typename T>
+class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
+ public:
+ void RunTest() {
+ XlaBuilder builder(TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ // Convolution dimensions are bf0_oi0->bo0.
+ ConvGeneralDilated(
+ input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
+ /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
+ /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+ }
+
+ Array3D<T> input(
+ {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
+ Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
+
+ Array3D<T> expected(
+ {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
+
+ auto input_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR3<T>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+ }
+};
+
+TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
+TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
+
+} // namespace
+} // namespace xla