| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <memory> |
| #include <vector> |
| |
| #include "absl/strings/str_cat.h" |
| #include "tensorflow/compiler/xla/array2d.h" |
| #include "tensorflow/compiler/xla/array3d.h" |
| #include "tensorflow/compiler/xla/client/lib/matrix.h" |
| #include "tensorflow/compiler/xla/client/local_client.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/primitive_util.h" |
| #include "tensorflow/compiler/xla/reference_util.h" |
| #include "tensorflow/compiler/xla/service/hlo_parser.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/tests/client_library_test_base.h" |
| #include "tensorflow/compiler/xla/tests/hlo_test_base.h" |
| #include "tensorflow/compiler/xla/tests/literal_test_util.h" |
| #include "tensorflow/compiler/xla/tests/test_macros.h" |
| #include "tensorflow/compiler/xla/tests/test_utils.h" |
| #include "tensorflow/core/platform/test.h" |
| #include "tensorflow/core/platform/test_benchmark.h" |
| |
| namespace xla { |
| namespace { |
| |
| class DotOperationTest : public ClientLibraryTestBase { |
| public: |
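  // Default comparison tolerances: 1e-4 absolute, 1e-5 relative.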
  ErrorSpec error_spec_{1e-4, 1e-5};
| }; |
| |
| using TypesF16F32 = ::testing::Types< |
| #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) |
| Eigen::half, |
| #endif |
| float>; |
| |
| using TypesF16F32F64 = ::testing::Types< |
| #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) |
| Eigen::half, |
| #endif |
| #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) |
| double, |
| #endif |
| float>; |
| |
| using TypesF16F32F64CF64 = ::testing::Types< |
| #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) |
| Eigen::half, |
| #endif |
| #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) |
| double, complex64, |
| #endif |
| float>; |
| |
| // Check that we can safely pass an input tuple's elements to a dot operation. |
| XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { |
| XlaBuilder builder(TestName()); |
| |
| XlaOp param; |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto param_data, |
| CreateParameterAndTransferLiteral( |
| 0, |
| LiteralUtil::MakeTupleFromSlices( |
| {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}), |
| LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}), |
| "arg0", &builder, ¶m)); |
| auto lhs = GetTupleElement(param, 0); |
| auto rhs = GetTupleElement(param, 1); |
| Dot(lhs, rhs); |
| |
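  // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]].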
| ComputeAndCompareLiteral(&builder, |
| LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}), |
| {param_data.get()}); |
| } |
| |
| template <typename T> |
| class DotOperationTest_F16F32F64CF64 : public DotOperationTest {}; |
TYPED_TEST_SUITE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| |
| auto lhs = ConstantR1<T>(&builder, {}); |
| auto rhs = ConstantR1<T>(&builder, {}); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(0.0), {}, |
| this->error_spec_); |
| } |
| |
| template <typename T> |
| class DotOperationTest_F16F32F64 : public DotOperationTest {}; |
TYPED_TEST_SUITE(DotOperationTest_F16F32F64, TypesF16F32F64);
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto lhs = ConstantR2FromArray2D<T>(&builder, {{3.0f, 4.0f}}); |
| auto rhs = ConstantFromArray<T>(&builder, {3.0f, 4.0f}); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto lhs = ConstantR1<T>(&builder, {static_cast<T>(2.0f)}); |
| auto rhs = ConstantR1<T>(&builder, {static_cast<T>(3.0f)}); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(6.0f), {}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto lhs = ConstantFromArray<T>(&builder, {1.0f, 2.5f, 42.0f}); |
| auto rhs = ConstantFromArray<T>(&builder, {11.0f, -1.0f, 0.5f}); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(29.5f), {}, |
| this->error_spec_); |
| } |
| |
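// Returns a 2D minor-to-major order: {1, 0} when row-major (dimension 1
// varies fastest in memory), {0, 1} when column-major.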
| std::vector<int64_t> MinorToMajorForIsRowMajor(bool row_major) { |
| return {row_major ? 1 : 0, row_major ? 0 : 1}; |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2)); |
| auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0)); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2)); |
| auto rhs = ConstantR2FromArray2D<T>( |
| &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto lhs = ConstantR2FromArray2D<T>( |
| &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); |
| auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0)); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0)); |
| auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2)); |
| Dot(lhs, rhs); |
| |
| this->template ComputeAndCompareR2<T>( |
| &builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto param0 = |
| Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 4}), "arg0"); |
| auto param1 = |
| Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1"); |
| auto exp0 = Exp(param0); |
| Dot(exp0, param1); |
| |
| auto lhs_handle = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>( |
| {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) |
| .value(); |
| auto rhs_handle = this->client_ |
| ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>( |
| {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) |
| .value(); |
| |
| if (std::is_same<Eigen::half, T>::value) { |
| this->error_spec_ = ErrorSpec{0.0001, 1e-3}; |
| } |
| |
| this->template ComputeAndCompareR2<T>( |
| &builder, Array2D<T>({{296.14560492846033f}, {0.8611737683031964f}}), |
| {lhs_handle.get(), rhs_handle.get()}, this->error_spec_); |
| } |
| |
| template <typename T> |
| class SquareMatrixDot : public DotOperationTest { |
| public: |
| void TestImpl(bool lhs_row_major, bool rhs_row_major) { |
| auto lhs_handle = |
| client_ |
| ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>( |
| {{1.0f, 2.0f}, {3.0f, -4.0f}}, |
| LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(lhs_row_major)))) |
| .value(); |
| auto rhs_handle = |
| client_ |
| ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>( |
| {{1.0f, 6.0f}, {7.0f, -4.0f}}, |
| LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(rhs_row_major)))) |
| .value(); |
| XlaBuilder builder(TestName()); |
| auto prim_type = primitive_util::NativeToPrimitiveType<T>(); |
| Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), |
| Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); |
| |
| Array2D<T> expected({{15.0f, -2.0f}, {-25.0f, 34.0f}}); |
| ComputeAndCompareR2<T>(&builder, expected, |
| {lhs_handle.get(), rhs_handle.get()}, error_spec_); |
| } |
| }; |
| |
TYPED_TEST_SUITE(SquareMatrixDot, TypesF16F32F64CF64);
| XLA_TYPED_TEST(SquareMatrixDot, TypesFF) { this->TestImpl(false, false); } |
| XLA_TYPED_TEST(SquareMatrixDot, TypesFT) { this->TestImpl(false, true); } |
| XLA_TYPED_TEST(SquareMatrixDot, TypesTF) { this->TestImpl(true, false); } |
| XLA_TYPED_TEST(SquareMatrixDot, TypesTT) { this->TestImpl(true, true); } |
| |
| struct DotTestParam { |
| int m; |
| int k; |
| int n; |
| bool dot_lhs_row_major; |
| bool dot_rhs_row_major; |
| bool has_addend; |
| bool addend_row_major; |
| }; |
| |
| std::string PrintDotTestParam( |
| const ::testing::TestParamInfo<DotTestParam>& test_param) { |
| const DotTestParam& param = test_param.param; |
| if (param.has_addend) { |
| return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", |
| param.dot_lhs_row_major ? "T" : "F", |
| param.dot_rhs_row_major ? "T" : "F", |
| param.addend_row_major ? "T" : "F"); |
| } else { |
| return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", |
| param.dot_lhs_row_major ? "T" : "F", |
| param.dot_rhs_row_major ? "T" : "F"); |
| } |
| } |
| |
| class ParametricDotTest : public DotOperationTest, |
| public ::testing::WithParamInterface<DotTestParam> { |
| protected: |
| template <typename NativeT> |
| void TestImpl(); |
| |
| template <typename NativeT> |
| void ComputeAndCompareR2WithError(XlaBuilder* builder, |
| const Array2D<NativeT>& expected, |
| absl::Span<GlobalData* const> arguments); |
| }; |
| |
| template <typename NativeT> |
| void ParametricDotTest::ComputeAndCompareR2WithError( |
| XlaBuilder* builder, const Array2D<NativeT>& expected, |
| absl::Span<GlobalData* const> arguments) { |
| ErrorSpec error_spec(0.3, 3e-3); |
| ComputeAndCompareR2(builder, expected, arguments, error_spec); |
| } |
| |
| template <> |
| void ParametricDotTest::ComputeAndCompareR2WithError<Eigen::half>( |
| XlaBuilder* builder, const Array2D<Eigen::half>& expected, |
| absl::Span<GlobalData* const> arguments) { |
| ErrorSpec error_spec(0.3, 7e-3); |
| ComputeAndCompareR2(builder, expected, arguments, error_spec); |
| } |
| |
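// Integer dots are exact, so compare without an error spec.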
| template <> |
| void ParametricDotTest::ComputeAndCompareR2WithError<int32_t>( |
| XlaBuilder* builder, const Array2D<int32_t>& expected, |
| absl::Span<GlobalData* const> arguments) { |
| ComputeAndCompareR2(builder, expected, arguments); |
| } |
| |
| template <typename NativeT> |
| void ParametricDotTest::TestImpl() { |
| DotTestParam param = GetParam(); |
| |
| std::unique_ptr<Array2D<NativeT>> dot_lhs_data = |
| MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k); |
| Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( |
| *dot_lhs_data, LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); |
| std::unique_ptr<GlobalData> dot_lhs_handle = |
| client_->TransferToServer(dot_lhs_lit).value(); |
| |
| std::unique_ptr<Array2D<NativeT>> dot_rhs_data = |
| MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n); |
| Layout rhs_layout = LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); |
| Literal dot_rhs_lit = |
| LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); |
| std::unique_ptr<GlobalData> dot_rhs_handle = |
| client_->TransferToServer(dot_rhs_lit).value(); |
| |
| std::unique_ptr<Array2D<NativeT>> addend_data; |
| Literal addend_lit; |
| std::unique_ptr<GlobalData> addend_handle; |
| |
| if (param.has_addend) { |
| addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n); |
| addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( |
| *addend_data, LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(param.addend_row_major))); |
| addend_handle = client_->TransferToServer(addend_lit).value(); |
| } |
| |
| XlaBuilder builder(TestName()); |
| auto prim_type = primitive_util::NativeToPrimitiveType<NativeT>(); |
| auto result = |
| Dot(Parameter(&builder, 0, |
| ShapeUtil::MakeShapeWithLayout( |
| prim_type, {param.m, param.k}, |
| MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), |
| "dot_lhs"), |
| Parameter(&builder, 1, |
| ShapeUtil::MakeShapeWithLayout( |
| prim_type, {param.k, param.n}, |
| MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), |
| "dot_rhs")); |
| |
| if (param.has_addend) { |
| result = |
| Add(result, |
| Parameter(&builder, 2, |
| ShapeUtil::MakeShapeWithLayout( |
| prim_type, {param.m, param.n}, |
| MinorToMajorForIsRowMajor(param.addend_row_major)), |
| "addend")); |
| } |
| |
| std::unique_ptr<Array2D<NativeT>> expected; |
| if (param.has_addend) { |
| expected = ReferenceUtil::ApplyElementwise2D( |
| std::plus<NativeT>(), |
| *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), |
| *addend_data); |
| } else { |
| expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data); |
| } |
| |
| std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()}; |
| if (param.has_addend) { |
| args.push_back(addend_handle.get()); |
| } |
| ComputeAndCompareR2WithError<NativeT>(&builder, *expected, args); |
| } |
| |
| std::vector<DotTestParam> CreateDotTestParameters() { |
| std::vector<DotTestParam> params; |
| |
| auto add_matrix_matrix_dot_test = [&](int m, int k, int n) { |
| for (bool lhs_row_major : {true, false}) { |
| for (bool rhs_row_major : {true, false}) { |
| params.push_back({/*m=*/m, /*k=*/k, /*n=*/n, |
| /*dot_lhs_row_major=*/lhs_row_major, |
| /*dot_rhs_row_major=*/rhs_row_major, |
| /*has_addend=*/false, /*addend_row_major=*/true}); |
| } |
| } |
| }; |
| |
| add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42); |
| add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42); |
| add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1); |
| add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1); |
| add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1); |
| add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); |
| add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); |
| add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); |
| |
| return params; |
| } |
| |
| #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 |
| XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl<Eigen::half>(); } |
| #endif |
| XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl<float>(); } |
| XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl<double>(); } |
| XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl<std::complex<float>>(); } |
| #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128 |
| XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl<std::complex<double>>(); } |
| #endif |
| XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl<int32_t>(); } |
| |
INSTANTIATE_TEST_SUITE_P(DotTests, ParametricDotTest,
                         ::testing::ValuesIn(CreateDotTestParameters()),
                         PrintDotTestParam);
| |
| class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { |
| public: |
| ParametricDotTestWithoutLayoutAssignment() { |
| execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( |
| "layout-assignment"); |
    // Also disable the HLO verifier; with layout assignment disabled the
    // module may not satisfy the layout invariants the verifier checks.
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "hlo-verifier");
| // Disable algebraic simplification because the pass may replace a dot |
| // instruction with a layout-changing multiplication instruction. |
| execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( |
| "algsimp"); |
| } |
| }; |
| |
| std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() { |
| std::vector<DotTestParam> params; |
| |
| auto add_matrix_vector_dot_test = [&](int k, int n) { |
| for (bool lhs_row_major : {true, false}) { |
| for (bool rhs_row_major : {true, false}) { |
| for (bool has_addend : {true, false}) { |
| // The addend needs to be row major to match the result of the dot. |
| params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, |
| /*dot_lhs_row_major=*/lhs_row_major, |
| /*dot_rhs_row_major=*/rhs_row_major, |
| /*has_addend=*/has_addend, |
| /*addend_row_major=*/true}); |
| if (n != 1) { |
| params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, |
| /*dot_lhs_row_major=*/lhs_row_major, |
| /*dot_rhs_row_major=*/rhs_row_major, |
| /*has_addend=*/has_addend, |
| /*addend_row_major=*/true}); |
| } |
| } |
| } |
| } |
| }; |
| |
| add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); |
| add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); |
| add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); |
| add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); |
| add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); |
| add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); |
| add_matrix_vector_dot_test(/*k=*/1, /*n=*/4); |
| add_matrix_vector_dot_test(/*k=*/1, /*n=*/3); |
| add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); |
| add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); |
| add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); |
| add_matrix_vector_dot_test(/*k=*/8, /*n=*/2); |
| add_matrix_vector_dot_test(/*k=*/2, /*n=*/8); |
| add_matrix_vector_dot_test(/*k=*/259, /*n=*/258); |
| |
| return params; |
| } |
| |
| #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 |
| XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) { |
| TestImpl<Eigen::half>(); |
| } |
| #endif |
| XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) { |
| TestImpl<float>(); |
| } |
| // TODO(b/147505663): Disabled for now. |
| XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, DISABLED_TestF64) { |
| TestImpl<double>(); |
| } |
| |
INSTANTIATE_TEST_SUITE_P(
| DotTests, ParametricDotTestWithoutLayoutAssignment, |
| ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()), |
| PrintDotTestParam); |
| |
| template <typename T> |
| class NonsquareMatrixDot : public DotOperationTest { |
| public: |
| void TestImpl(bool lhs_row_major, bool rhs_row_major) { |
| auto lhs_handle = |
| client_ |
| ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>( |
| {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, |
| LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(lhs_row_major)))) |
| .value(); |
| auto rhs_handle = |
| client_ |
| ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>( |
| {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, |
| LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(rhs_row_major)))) |
| .value(); |
| |
| XlaBuilder builder(TestName()); |
| auto prim_type = primitive_util::NativeToPrimitiveType<T>(); |
| Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), |
| Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); |
| |
| Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); |
| |
| ComputeAndCompareR2<T>(&builder, expected, |
| {lhs_handle.get(), rhs_handle.get()}, error_spec_); |
| } |
| }; |
| |
TYPED_TEST_SUITE(NonsquareMatrixDot, TypesF16F32F64CF64);
| XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) { this->TestImpl(false, false); } |
| XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) { this->TestImpl(false, true); } |
| XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) { this->TestImpl(true, false); } |
| XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } |
| |
| XLA_TEST_F(DotOperationTest, MatrixVectorC64) { |
| auto lhs_handle = |
| client_ |
| ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>( |
| {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) |
| .value(); |
| auto rhs_handle = |
| client_ |
| ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>( |
| {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, |
| LayoutUtil::MakeLayout({1, 0}))) |
| .value(); |
| |
| XlaBuilder builder(TestName()); |
| auto prim_type = primitive_util::NativeToPrimitiveType<complex64>(); |
| Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), |
| Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); |
| |
| Array2D<complex64> expected({{30.0, -2.0}}); |
| |
| ComputeAndCompareR2<complex64>( |
| &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) { |
| using T = TypeParam; |
| |
| XlaBuilder builder(this->TestName()); |
| auto matrix1 = |
| ConstantR2FromArray2D<T>(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}}); |
| auto matrix2 = |
| ConstantR2FromArray2D<T>(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}}); |
| auto matrix12 = Dot(matrix1, matrix2); |
| auto matrix21 = Dot(matrix2, matrix1); |
| Add(matrix12, matrix21); |
| |
| Array2D<T> expected({{42.0f, 56.0f}, {74.0f, 96.0f}}); |
| this->template ComputeAndCompareR2<T>(&builder, expected, {}, |
| this->error_spec_); |
| } |
| |
| template <typename T> |
| class DotOperationTestForBatchMatMul : public DotOperationTest {}; |
TYPED_TEST_SUITE(DotOperationTestForBatchMatMul, TypesF16F32F64);
| |
| // Regression test for b/32055648. The root of the graph is a kFusion of 4 |
| // bitcasts. Although bitcasts don't map to thunks, the root should still be |
| // sync-dependent on bitcasts' operands. |
| XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { |
| using T = TypeParam; |
| XlaBuilder builder(this->TestName()); |
| auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), |
| "x"); |
| auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), |
| "y"); |
| |
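  // Collapse the two batch dimensions of each input into a single dimension.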
| auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); |
| auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); |
| |
| // Slice batches into individual matrices and multiply them. |
| std::vector<XlaOp> out_slices; |
| const auto n = 4; |
| out_slices.reserve(n); |
| for (int i = 0; i < n; ++i) { |
| // Slice off individual matrices and reshape to 2D tensors. |
| auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); |
| x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2}); |
| auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); |
| y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2}); |
| |
| auto out = Dot(x_slice, y_slice); |
| out = Reshape(out, {0, 1}, {1, 2, 2}); |
| out_slices.push_back(out); |
| } |
| auto out_flat = ConcatInDim(&builder, out_slices, 0); |
| Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); |
| |
| auto x_data = this->client_ |
| ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>( |
| {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, |
| {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, |
| {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, |
| {{4000.0f, 400.0f}, {40.0f, 4.0f}}}})) |
| .value(); |
| auto y_data = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>( |
| {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, |
| {{{11.0f, 22.0f}, {33.0f, 44.0f}}, |
| {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) |
| .value(); |
| |
| if (std::is_same<Eigen::half, T>::value) { |
| this->error_spec_ = ErrorSpec{0.0001, 1e-3}; |
| } |
| this->template ComputeAndCompareR4<T>( |
| &builder, |
| /*expected=*/ |
| {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}}, |
| {{11400.0f, 13600.0f}, {114.0f, 136.0f}}}, |
| {{{42900.0f, 79200.0f}, {429.0f, 792.0f}}, |
| {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}}, |
| {x_data.get(), y_data.get()}, this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { |
| using T = TypeParam; |
| |
| XlaBuilder builder(this->TestName()); |
| auto x = |
| Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x"); |
| auto y = |
| Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y"); |
| |
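  // Batched matmul over batch dimension 0: contract lhs dim 2 with rhs dim 1,
  // i.e. result[b, m, n] = sum_k x[b, m, k] * y[b, k, n].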
| DotDimensionNumbers dnums; |
| dnums.add_lhs_contracting_dimensions(2); |
| dnums.add_rhs_contracting_dimensions(1); |
| dnums.add_lhs_batch_dimensions(0); |
| dnums.add_rhs_batch_dimensions(0); |
| |
| DotGeneral(x, y, dnums); |
| |
| auto x_data = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>( |
| {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) |
| .value(); |
| |
| auto y_data = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>( |
| {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) |
| .value(); |
| |
| this->template ComputeAndCompareR3<T>( |
| &builder, |
| /*expected=*/ |
| {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, |
| {x_data.get(), y_data.get()}, this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) { |
| using T = TypeParam; |
| |
| XlaBuilder builder(this->TestName()); |
| auto x = |
| Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x"); |
| auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2}), "y"); |
| |
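  // Per batch b, contract lhs dim 1 with rhs dim 1:
  // result[b, m] = sum_k x[b, k, m] * y[b, k].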
| DotDimensionNumbers dnums; |
| dnums.add_lhs_contracting_dimensions(1); |
| dnums.add_rhs_contracting_dimensions(1); |
| dnums.add_lhs_batch_dimensions(0); |
| dnums.add_rhs_batch_dimensions(0); |
| |
| DotGeneral(x, y, dnums); |
| |
| auto x_data = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>( |
| {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) |
| .value(); |
| |
| auto y_data = this->client_ |
| ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>( |
| {{1.0f, 0.0f}, {0.0f, 1.0f}})) |
| .value(); |
| |
| this->template ComputeAndCompareR2<T>( |
| &builder, |
| /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) { |
| using T = TypeParam; |
| |
| XlaBuilder builder(this->TestName()); |
| auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "x"); |
| auto y = |
| Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y"); |
| |
| DotDimensionNumbers dnums; |
| dnums.add_lhs_contracting_dimensions(1); |
| dnums.add_rhs_contracting_dimensions(1); |
| dnums.add_lhs_batch_dimensions(0); |
| dnums.add_rhs_batch_dimensions(0); |
| |
| DotGeneral(x, y, dnums); |
| |
| auto x_data = this->client_ |
| ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>( |
| {{1.0f, 0.0f}, {0.0f, 1.0f}})) |
| .value(); |
| |
| auto y_data = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>( |
| {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) |
| .value(); |
| |
| this->template ComputeAndCompareR2<T>( |
| &builder, |
| /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { |
| using T = TypeParam; |
| |
| XlaBuilder builder(this->TestName()); |
| auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), |
| "x"); |
| auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), |
| "y"); |
| |
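  // Two batch dimensions (0 and 1); contract lhs dim 3 with rhs dim 2.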
| DotDimensionNumbers dnums; |
| dnums.add_lhs_contracting_dimensions(3); |
| dnums.add_rhs_contracting_dimensions(2); |
| dnums.add_lhs_batch_dimensions(0); |
| dnums.add_lhs_batch_dimensions(1); |
| dnums.add_rhs_batch_dimensions(0); |
| dnums.add_rhs_batch_dimensions(1); |
| |
| DotGeneral(x, y, dnums); |
| |
| auto x_data = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>( |
| {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, |
| {{{9.0f, 10.0f}, {11.0f, 12.0f}}, |
| {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) |
| .value(); |
| |
| auto y_data = |
| this->client_ |
| ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>( |
| {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, |
| {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) |
| .value(); |
| |
| this->template ComputeAndCompareR4<T>( |
| &builder, |
| /*expected=*/ |
| {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, |
| {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}}, |
| {x_data.get(), y_data.get()}, this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { |
| using T = TypeParam; |
| for (bool transpose_lhs : {false, true}) { |
| for (bool transpose_rhs : {false, true}) { |
| for (bool row_major : {false, true}) { |
| std::unique_ptr<Array2D<T>> lhs( |
| new Array2D<T>({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}})); |
| std::unique_ptr<Array2D<T>> rhs( |
| new Array2D<T>({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}})); |
| |
| if (transpose_lhs) { |
| lhs = ReferenceUtil::TransposeArray2D(*lhs); |
| } |
| if (transpose_rhs) { |
| rhs = ReferenceUtil::TransposeArray2D(*rhs); |
| } |
| auto lhs_handle = |
| this->client_ |
| ->TransferToServer( |
| LiteralUtil::CreateR2FromArray2DWithLayout<T>( |
| *lhs, LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(row_major)))) |
| .value(); |
| auto rhs_handle = |
| this->client_ |
| ->TransferToServer( |
| LiteralUtil::CreateR2FromArray2DWithLayout<T>( |
| *rhs, LayoutUtil::MakeLayout( |
| MinorToMajorForIsRowMajor(row_major)))) |
| .value(); |
| |
| XlaBuilder builder(this->TestName()); |
| auto prim_type = primitive_util::NativeToPrimitiveType<T>(); |
| auto lhs_arg = Parameter( |
| &builder, 0, |
| ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), |
| "lhs"); |
| auto rhs_arg = Parameter( |
| &builder, 1, |
| ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), |
| "rhs"); |
| if (transpose_lhs) { |
| lhs_arg = Transpose(lhs_arg, {1, 0}); |
| } |
| if (transpose_rhs) { |
| rhs_arg = Transpose(rhs_arg, {1, 0}); |
| } |
| Dot(lhs_arg, rhs_arg); |
| |
| Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); |
| VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " |
| << transpose_rhs << " " << row_major; |
| this->template ComputeAndCompareR2<T>( |
| &builder, expected, {lhs_handle.get(), rhs_handle.get()}, |
| this->error_spec_); |
| } |
| } |
| } |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, |
| DotOfConcatOptimizationWithConstLHS) { |
| using T = TypeParam; |
| auto prim_type = primitive_util::NativeToPrimitiveType<T>(); |
| |
| std::unique_ptr<Array2D<T>> constant_lhs_array( |
| new Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, |
| {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); |
| |
| XlaBuilder builder(this->TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_arg_0 = Parameter( |
| &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); |
| auto rhs_arg_1 = Parameter( |
| &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1"); |
| auto rhs_arg_2 = Parameter( |
| &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2"); |
| Dot(lhs_constant, |
| ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); |
| |
| std::unique_ptr<Array2D<T>> arg_0_value_array( |
| new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}})); |
| std::unique_ptr<Array2D<T>> arg_1_value_array( |
| new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}})); |
| std::unique_ptr<Array2D<T>> arg_2_value_array(new Array2D<T>({{1.0f, 2.0f}})); |
| |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto arg_0_value, |
| this->client_->TransferToServer( |
| LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array))); |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto arg_1_value, |
| this->client_->TransferToServer( |
| LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array))); |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto arg_2_value, |
| this->client_->TransferToServer( |
| LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array))); |
| |
| Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); |
| this->template ComputeAndCompareR2<T>( |
| &builder, expected, |
| {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, |
| this->error_spec_); |
| } |
| |
| XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, |
| DotOfConcatOptimizationWithConstRHS) { |
| using T = TypeParam; |
| std::unique_ptr<Array2D<T>> constant_rhs_array( |
| new Array2D<T>({{1.0f, 2.0f}, |
| {3.0f, 4.0f}, |
| {5.0f, 6.0f}, |
| {6.0f, 5.0f}, |
| {4.0f, 3.0f}, |
| {2.0f, 1.0f}})); |
| |
| XlaBuilder builder(this->TestName()); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto lhs_arg_0 = Parameter( |
| &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "lhs_arg_0"); |
| auto lhs_arg_1 = Parameter( |
| &builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 3}), "lhs_arg_1"); |
| auto lhs_arg_2 = Parameter( |
| &builder, 2, ShapeUtil::MakeShapeWithType<T>({2, 1}), "lhs_arg_2"); |
| Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), |
| rhs_constant); |
| |
| std::unique_ptr<Array2D<T>> arg_0_value_array( |
| new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}})); |
| std::unique_ptr<Array2D<T>> arg_1_value_array( |
| new Array2D<T>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}})); |
| std::unique_ptr<Array2D<T>> arg_2_value_array( |
| new Array2D<T>({{1.0f}, {2.0f}})); |
| |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto arg_0_value, |
| this->client_->TransferToServer( |
| LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array))); |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto arg_1_value, |
| this->client_->TransferToServer( |
| LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array))); |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto arg_2_value, |
| this->client_->TransferToServer( |
| LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array))); |
| |
| Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); |
| this->template ComputeAndCompareR2<T>( |
| &builder, expected, |
| {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, |
| this->error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { |
| std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( |
| {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0}, |
| {4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0}, |
| {9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0}, |
| {3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
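  // Slice out row 1 of the constant LHS (start indices {1, 0}, sizes {1, 6}).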
| auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(1); |
| dot_dnums.add_rhs_contracting_dimensions(0); |
| DotGeneral(dynamic_slice, rhs_constant, dot_dnums); |
| |
| Array2D<float> expected({{96.0, 105.0, 114.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { |
| std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( |
| {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0}, |
| {4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0}, |
| {9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0}, |
| {3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(1); |
| dot_dnums.add_rhs_contracting_dimensions(0); |
| DotGeneral(lhs_constant, dynamic_slice, dot_dnums); |
| |
| Array2D<float> expected({{105.0}, {105.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
XLA_TEST_F(DotOperationTest,
           DotOfGatherOptimizationWithConstRHSReverseMM) {
| std::unique_ptr<Array2D<float>> constant_lhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0}, |
| {4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0}, |
| {9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0}, |
| {3.0, 2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>( |
| {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(0); |
| dot_dnums.add_rhs_contracting_dimensions(1); |
| DotGeneral(dynamic_slice, rhs_constant, dot_dnums); |
| |
| Array2D<float> expected({{105.0, 105.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { |
| std::unique_ptr<Array2D<float>> constant_lhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0}, |
| {4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0}, |
| {9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0}, |
| {3.0, 2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>( |
| {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(0); |
| dot_dnums.add_rhs_contracting_dimensions(1); |
| DotGeneral(lhs_constant, dynamic_slice, dot_dnums); |
| |
| Array2D<float> expected({{96.0}, {105.0}, {114.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { |
| std::unique_ptr<Array2D<float>> constant_lhs_array( |
| new Array2D<float>({{1.0, 2.0}, |
| {3.0, 4.0}, |
| {5.0, 6.0}, |
| {6.0, 5.0}, |
| {4.0, 3.0}, |
| {2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0}, |
| {4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0}, |
| {9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0}, |
| {3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(0); |
| dot_dnums.add_rhs_contracting_dimensions(0); |
| DotGeneral(dynamic_slice, rhs_constant, dot_dnums); |
| |
| Array2D<float> expected({{126.0, 129.0, 132.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { |
| std::unique_ptr<Array2D<float>> constant_lhs_array( |
| new Array2D<float>({{1.0, 2.0}, |
| {3.0, 4.0}, |
| {5.0, 6.0}, |
| {6.0, 5.0}, |
| {4.0, 3.0}, |
| {2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0}, |
| {4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0}, |
| {9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0}, |
| {3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(0); |
| dot_dnums.add_rhs_contracting_dimensions(0); |
| DotGeneral(lhs_constant, dynamic_slice, dot_dnums); |
| |
| Array2D<float> expected({{129.0}, {129.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { |
| std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( |
| {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(1); |
| dot_dnums.add_rhs_contracting_dimensions(1); |
| DotGeneral(dynamic_slice, rhs_constant, dot_dnums); |
| |
| Array2D<float> expected({{56.0, 168.0, 91.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { |
| std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( |
| {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| std::unique_ptr<Array2D<float>> constant_rhs_array( |
| new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, |
| {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, |
| {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); |
| // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} |
| |
| XlaBuilder builder(TestName()); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); |
| auto zero = ConstantR0<int32_t>(&builder, 0); |
| auto one = ConstantR0<int32_t>(&builder, 1); |
| auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); |
| |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(1); |
| dot_dnums.add_rhs_contracting_dimensions(1); |
| DotGeneral(lhs_constant, dynamic_slice, dot_dnums); |
| |
| Array2D<float> expected({{168.0}, {168.0}}); |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { |
| XlaBuilder builder(TestName()); |
| |
| Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); |
| auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array); |
| |
| Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}}); |
| auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array); |
| |
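  // Contracting dimension 0 on both operands computes transpose(lhs) * rhs.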
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(0); |
| dot_dnums.add_rhs_contracting_dimensions(0); |
| DotGeneral(lhs_constant, rhs_constant, dot_dnums); |
| |
| Array2D<float> expected({ |
| {26.f, 30.f}, |
| {38.f, 44.f}, |
| }); |
| |
| ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); |
| } |
| |
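// An einsum test case: {lhs shape, rhs shape, einsum equation}.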
| using EinsumParamType = |
| std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::string>; |
| class EinsumTest : public DotOperationTest, |
| public ::testing::WithParamInterface<EinsumParamType> {}; |
| XLA_TEST_P(EinsumTest, SimpleEinsumTest) { |
| XlaBuilder builder(TestName()); |
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .value(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .value(),
      &builder);
| auto config = std::get<2>(GetParam()); |
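  // An equation without a ',' has a single operand; use the unary Einsum.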
| if (config.find(',') == config.npos) { |
| Einsum(x, config); |
| } else { |
| Einsum(x, y, config); |
| } |
| ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); |
| } |
| |
| std::vector<EinsumParamType> GetEinsumTestCases() { |
| using v = std::vector<int64_t>; |
| using p = EinsumParamType; |
| std::vector<p> test_cases = { |
| p{v{5, 6}, v{6, 7}, "mk,kn->mn"}, |
| p{v{5, 6}, v{6, 7}, "mk,kn->nm"}, |
| p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"}, |
| p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"}, |
| p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"}, |
| p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, |
| p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, |
| p{v{6}, v{6, 7}, "b,bc->c"}, |
| p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"}, |
| p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"}, |
| p{v{77}, v{77}, "a,a->a"}, |
| p{v{77}, v{77, 55}, "a,ab->ba"}, |
| p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, |
| p{v{55}, v{}, "a,->a"}, |
| p{v{11, 111}, v{11}, "ab,a->ab"}, |
| p{v{16, 34}, v{16, 34}, "ab,ab->ab"}, |
| p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"}, |
| p{v{5, 19}, v{}, "ab,->ab"}, |
| p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf->bhqk"}, |
| p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"}, |
| p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"}, |
| p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"}, |
| p{v{6, 6}, v{7, 7}, "mm,nn->mn"}, |
| p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"}, |
| p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"}, |
| p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"}, |
| p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"}, |
| p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"}, |
      p{v{5, 6}, v{6, 7}, "mk,kn"},
| p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"}, |
| p{v{5, 6}, v{6, 7}, "ab,cd"}, |
| p{v{6}, v{6, 7}, "b,bc"}, |
| p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"}, |
| p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"}, |
| p{v{77}, v{77}, "a,a"}, |
| p{v{77}, v{77, 55}, "a,ab"}, |
| p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"}, |
| p{v{55}, v{}, "a"}, |
| p{v{11, 111}, v{11}, "ab,a"}, |
| p{v{16, 34}, v{16, 34}, "ab,ab"}, |
| p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"}, |
| p{v{5, 19}, v{}, "ab"}, |
| p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"}, |
| p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"}, |
| p{v{5, 6}, v{}, "...mk"}, |
| p{v{5, 6, 12, 13}, v{}, "...mk"}, |
| p{v{5, 6, 12, 13}, v{}, "m...k"}, |
| p{v{5, 6, 12, 13}, v{}, "mk..."}, |
| p{v{5, 6}, v{6, 7}, "...mk->km..."}, |
| p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"}, |
| p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"}, |
| p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"}, |
| p{v{16, 16, 16}, v{}, "iii"}, |
| p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."}, |
| }; |
| return test_cases; |
| } |
| |
| INSTANTIATE_TEST_SUITE_P(Einsum, EinsumTest, |
| ::testing::ValuesIn(GetEinsumTestCases())); |
| |
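// A batch dot test case: {lhs shape, rhs shape, expected output shape}.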
| using BatchDotParamType = std::tuple<std::vector<int64_t>, std::vector<int64_t>, |
| std::vector<int64_t>>; |
| class BatchDotTest : public DotOperationTest, |
| public ::testing::WithParamInterface<BatchDotParamType> {}; |
| XLA_TEST_P(BatchDotTest, BroadcastingBatchDotTest) { |
| XlaBuilder builder(TestName()); |
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .value(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .value(),
      &builder);
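  // BatchDot matrix-multiplies the trailing two dimensions and broadcasts the
  // leading batch dimensions between the operands.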
| auto batch_dot = BatchDot(x, y); |
  auto output_shape = builder.GetShape(batch_dot).value();
| EXPECT_EQ(output_shape.dimensions(), std::get<2>(GetParam())); |
| ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); |
| } |
| |
| std::vector<BatchDotParamType> GetBatchDotTestCases() { |
| using v = std::vector<int64_t>; |
| using p = BatchDotParamType; |
| std::vector<p> test_cases = { |
| p{v{5, 6}, v{6, 7}, v{5, 7}}, |
| p{v{5, 6, 11}, v{5, 11, 7}, v{5, 6, 7}}, |
| p{v{5, 6, 11}, v{11, 7}, v{5, 6, 7}}, |
| p{v{5, 6, 11}, v{1, 11, 7}, v{5, 6, 7}}, |
| p{v{6, 11}, v{5, 11, 7}, v{5, 6, 7}}, |
| p{v{1, 6, 11}, v{5, 11, 7}, v{5, 6, 7}}, |
| p{v{8, 1, 2, 3}, v{8, 3, 4}, v{8, 8, 2, 4}}, |
| p{v{8, 8, 2, 3}, v{8, 1, 3, 2}, v{8, 8, 2, 2}}, |
| }; |
| return test_cases; |
| } |
| |
| INSTANTIATE_TEST_SUITE_P(BatchDot, BatchDotTest, |
| ::testing::ValuesIn(GetBatchDotTestCases())); |
| |
| class DotOperationTextTest : public HloTestBase {}; |
| |
| XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule ComplexDotMultipleNonContracting |
| |
| ENTRY %test { |
| %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0) |
| %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1) |
| ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule ComplexDotMultipleNonContracting |
| |
| ENTRY %test { |
| %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0) |
| %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1) |
| ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule DotWithNoDnums |
| |
| ENTRY %test { |
| %lhs = f32[2,3]{1,0} parameter(0) |
| %rhs = f32[4,5]{1,0} parameter(1) |
| ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs) |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, Einsum) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule Einsum |
| |
| ENTRY %test { |
| %lhs = f32[8,64,96]{2,1,0} parameter(0) |
| %rhs = f32[96,32,4]{2,1,0} parameter(1) |
| ROOT %dot = f32[8,64,32,4]{3,2,1,0} dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) { |
| // Tests for a caching bug in the XLA CPU backend. |
| absl::string_view hlo_string = |
| R"( |
| HloModule CpuTiledDotEmitterCachingBug |
| |
| ENTRY main { |
| lhs = f32[20,40] parameter(0) |
| rhs_0 = f32[40,1] parameter(2) |
| rhs_1 = f32[1,40] parameter(1) |
| |
| dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} |
| |
| ROOT result = f32[20,1] divide(dot_0, dot_1) |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) { |
| // Tests for a caching bug in the XLA CPU backend. |
| absl::string_view hlo_string = |
| R"( |
| HloModule CpuTiledDotEmitterCachingBug |
| |
| ENTRY main { |
| lhs_0 = f32[20,40] parameter(0) |
| rhs_0 = f32[40,1] parameter(1) |
| lhs_1 = f32[1,40] parameter(2) |
| rhs_1 = f32[20,40] parameter(3) |
| |
| dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} |
| |
| dot_0_reshaped = f32[20] reshape(dot_0) |
| dot_1_reshaped = f32[20] reshape(dot_1) |
| |
| ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped) |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, S32IotaDot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = s32[5,55,8] iota(), iota_dimension=1 |
| arg1 = s32[5,8,200] iota(), iota_dimension=2 |
| ROOT dot = s32[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, S32IotaSquaredDot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = s32[16,2] iota(), iota_dimension=0 |
| a = s32[16,2] multiply(arg0, arg0) |
| r = s32[16,2] multiply(a, a) |
| arg1 = s32[2,98] iota(), iota_dimension=1 |
| b = s32[2,98] multiply(arg1, arg1) |
| s = s32[2,98] multiply(b, b) |
| ROOT dot = s32[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, U16IotaDot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = u16[5,55,8] parameter(0) |
| arg1 = u16[5,8,200] parameter(1) |
| dot = u16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} |
| ROOT c = s32[5,55,200] convert(dot) |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, U16IotaSquaredDot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = u16[16,2] iota(), iota_dimension=0 |
| a = u16[16,2] multiply(arg0, arg0) |
| r = u16[16,2] multiply(a, a) |
| arg1 = u16[2,98] iota(), iota_dimension=1 |
| b = u16[2,98] multiply(arg1, arg1) |
| s = u16[2,98] multiply(b, b) |
| ROOT dot = u16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, S16IotaDot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = s16[5,55,8] iota(), iota_dimension=1 |
| arg1 = s16[5,8,200] iota(), iota_dimension=2 |
| ROOT dot = s16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, S16IotaSquaredDot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = s16[16,2] iota(), iota_dimension=0 |
| a = s16[16,2] multiply(arg0, arg0) |
| r = s16[16,2] multiply(a, a) |
| arg1 = s16[2,98] iota(), iota_dimension=1 |
| b = s16[2,98] multiply(arg1, arg1) |
| s = s16[2,98] multiply(b, b) |
| ROOT dot = s16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, PREDDot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = pred[20,2] parameter(0) |
| arg1 = pred[2,20] parameter(1) |
| ROOT dot = pred[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, S8Dot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = s8[20,2] parameter(0) |
| arg1 = s8[2,20] parameter(1) |
| ROOT dot = s8[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, S32Dot) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SmallIntegerDot |
| |
| ENTRY SmallIntegerDot { |
| arg0 = s32[20,55] parameter(0) |
| arg1 = s32[55,20] parameter(1) |
| ROOT dot = s32[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); |
| } |
| |
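| // A dot whose output is immediately transposed; this pattern can typically |
| // be folded into the dot itself (by swapping the operands), so check that |
| // the folded form still produces the right values. |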
| XLA_TEST_F(DotOperationTextTest, GpuTransposeOutput) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule TransposeOutput |
| |
| ENTRY TransposeOutput { |
| p0 = f32[32,32] parameter(0) |
| p1 = f32[32,64] parameter(1) |
| dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0} |
| ROOT tr = f32[64,32] transpose(dot), dimensions={1,0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, MatrixVectorComplex) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule MatrixVectorComplex |
| |
| ENTRY MatrixVectorComplex { |
| p0 = c64[5,5] parameter(0) |
| p1 = c64[5,1] parameter(1) |
| p2 = c64[5,1] parameter(2) |
| dot = c64[5,1] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| ROOT add = c64[5,1] add(dot, p2) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, |
| ParseAndReturnUnverifiedModule(hlo_string)); |
| EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{4e-3, 4e-3})); |
| } |
| |
| XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule MatrixVectorBF16 |
| |
| ENTRY MatrixVectorBF16 { |
| p0 = bf16[128] parameter(0) |
| p1 = bf16[128,256] parameter(1) |
| p2 = bf16[256] parameter(2) |
| dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0} |
| ROOT add = bf16[256] add(dot, p2) |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); |
| } |
| |
| // Regression test for b/138155357, where we incorrectly created a dot-add |
| // fusion in which the dot had a batch dimension; such fusions aren't |
| // supported on the CPU backend. |
| XLA_TEST_F(DotOperationTextTest, FusedBatchDotRegressionTest) { |
| absl::string_view module_string = R"( |
| HloModule jaxpr_computation__5.33 |
| |
| jaxpr_computation__6.8 { |
| tuple.9 = () tuple() |
| parameter.14 = () parameter(4) |
| parameter.13 = (f32[2]{0}) parameter(3) |
| get-tuple-element.15 = f32[2]{0} get-tuple-element(parameter.13), index=0 |
| reshape.16 = f32[1,2]{1,0} reshape(get-tuple-element.15) |
| parameter.10 = f32[2,2]{1,0} parameter(0) |
| reshape.17 = f32[2,1]{1,0} reshape(get-tuple-element.15) |
| dot.18 = f32[2,1]{1,0} dot(parameter.10, reshape.17), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| reshape.19 = f32[2]{0} reshape(dot.18) |
| reshape.20 = f32[2,1]{1,0} reshape(reshape.19) |
| dot.21 = f32[1,1]{1,0} dot(reshape.16, reshape.20), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| reshape.22 = f32[] reshape(dot.21) |
| parameter.11 = f32[2,1,2]{2,1,0} parameter(1) |
| broadcast.23 = f32[2,2,1]{2,1,0} broadcast(reshape.20), dimensions={1,2} |
| dot.24 = f32[2,1,1]{2,1,0} dot(parameter.11, broadcast.23), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} |
| broadcast.25 = f32[2,1,2]{2,1,0} broadcast(reshape.16), dimensions={1,2} |
| parameter.12 = f32[2,2,1]{2,1,0} parameter(2) |
| dot.26 = f32[2,1,1]{2,1,0} dot(broadcast.25, parameter.12), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} |
| add.27 = f32[2,1,1]{2,1,0} add(dot.24, dot.26) |
| reshape.28 = f32[2]{0} reshape(add.27) |
| ROOT tuple.29 = (f32[], f32[2]{0}) tuple(reshape.22, reshape.28) |
| } |
| |
| ENTRY jaxpr_computation__5.33 { |
| constant.2 = f32[] constant(1) |
| broadcast.3 = f32[2,2]{1,0} broadcast(constant.2), dimensions={} |
| constant.5 = f32[2,1,2]{2,1,0} constant({ { { 1, 0 } }, { { 0, 1 } } }) |
| constant.4 = f32[2,2,1]{2,1,0} constant({ { {1}, {1} }, { {1}, {1} } }) |
| parameter.6 = f32[2]{0} parameter(0) |
| tuple.7 = (f32[2]{0}) tuple(parameter.6) |
| tuple.1 = () tuple() |
| call.30 = (f32[], f32[2]{0}) call(broadcast.3, constant.5, constant.4, tuple.7, tuple.1), to_apply=jaxpr_computation__6.8 |
| get-tuple-element.31 = f32[] get-tuple-element(call.30), index=0 |
| ROOT get-tuple-element.32 = f32[2]{0} get-tuple-element(call.30), index=1 |
| })"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(module_string)); |
| EXPECT_TRUE(RunAndCompare(std::move(module), /*error=*/std::nullopt)); |
| } |
| |
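| // The ReorderContractingDims tests build dot(reshape(transpose(x)), const) |
| // patterns that the algebraic simplifier (algsimp) can rewrite by moving |
| // the transpose/reshape onto the constant operand; see the benchmark |
| // comment near the end of this file. |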
| XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) { |
| Array3D<float> input_arr(2, 3, 2); |
| Array2D<float> const_arr(2, 6); |
| input_arr.FillIota(0); |
| const_arr.FillIota(0); |
| |
| XlaBuilder builder(TestName()); |
| auto t0 = |
| AddParam(LiteralUtil::CreateR3FromArray3D<float>(input_arr), &builder); |
| auto t1 = Transpose(t0, {1, 0, 2}); |
| auto rhs = Reshape(t1, {6, 2}); |
| auto lhs = ConstantR2FromArray2D(&builder, const_arr); |
| Dot(lhs, rhs); |
| |
| ComputeAndCompare(&builder, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) { |
| Array3D<float> input_arr(2, 3, 2); |
| Array2D<float> const_arr(2, 6); |
| input_arr.FillIota(0); |
| const_arr.FillIota(0); |
| |
| XlaBuilder builder(TestName()); |
| auto t0 = |
| AddParam(LiteralUtil::CreateR3FromArray3D<float>(input_arr), &builder); |
| auto t1 = Transpose(t0, {1, 0, 2}); |
| auto lhs = Reshape(t1, {6, 2}); |
| auto rhs = ConstantR2FromArray2D(&builder, const_arr); |
| |
| DotDimensionNumbers dims; |
| dims.add_lhs_contracting_dimensions(0); |
| dims.add_rhs_contracting_dimensions(1); |
| DotGeneral(lhs, rhs, dims); |
| |
| ComputeAndCompare(&builder, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) { |
| Array4D<float> input_arr(2, 2, 3, 4); |
| Array2D<float> const_arr(24, 2); |
| input_arr.FillIota(0); |
| const_arr.FillIota(0); |
| |
| XlaBuilder builder(TestName()); |
| auto t0 = |
| AddParam(LiteralUtil::CreateR4FromArray4D<float>(input_arr), &builder); |
| auto t1 = Transpose(t0, {0, 2, 3, 1}); |
| auto lhs = Reshape(t1, {2, 24}); |
| auto rhs = ConstantR2FromArray2D(&builder, const_arr); |
| Dot(lhs, rhs); |
| |
| ComputeAndCompare(&builder, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) { |
| Array3D<float> input_arr(2, 6, 2); |
| Array3D<float> const_arr(2, 6, 3); |
| input_arr.FillIota(0); |
| const_arr.FillIota(0); |
| |
| XlaBuilder builder(TestName()); |
| auto t0 = |
| AddParam(LiteralUtil::CreateR3FromArray3D<float>(input_arr), &builder); |
| auto t1 = Reshape(t0, {2, 2, 3, 2}); |
| auto t2 = Transpose(t1, {0, 2, 1, 3}); |
| auto lhs = Reshape(t2, {2, 6, 2}); |
| auto rhs = ConstantR3FromArray3D(&builder, const_arr); |
| |
| DotDimensionNumbers dims; |
| dims.add_lhs_contracting_dimensions(1); |
| dims.add_rhs_contracting_dimensions(1); |
| dims.add_lhs_batch_dimensions(0); |
| dims.add_rhs_batch_dimensions(0); |
| DotGeneral(lhs, rhs, dims); |
| |
| ComputeAndCompare(&builder, {}, error_spec_); |
| } |
| |
| XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) { |
| Array4D<float> input_arr(2, 2, 3, 5); |
| Array2D<float> const_arr(2, 30); |
| input_arr.FillIota(0); |
| const_arr.FillIota(0); |
| |
| XlaBuilder builder(TestName()); |
| auto t0 = |
| AddParam(LiteralUtil::CreateR4FromArray4D<float>(input_arr), &builder); |
| auto t1 = Transpose(t0, {0, 2, 1, 3}); |
| auto t2 = Reshape(t1, {2, 6, 5}); |
| auto t3 = Transpose(t2, {0, 2, 1}); |
| auto lhs = Reshape(t3, {2, 30}); |
| auto rhs = ConstantR2FromArray2D(&builder, const_arr); |
| |
| DotDimensionNumbers dims; |
| dims.add_lhs_contracting_dimensions(1); |
| dims.add_rhs_contracting_dimensions(1); |
| DotGeneral(lhs, rhs, dims); |
| |
| // Constant folding is disabled by default in unit tests. The algsimp |
| // optimization can be applied multiple times if we fold the transpose and |
| // reshape that are moved to the constant side of the dot. |
| mutable_debug_options()->clear_xla_disable_hlo_passes(); |
| ComputeAndCompare(&builder, {}, error_spec_); |
| } |
| |
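| // A dot of narrow integer operands (s8 x s16) that accumulates directly |
| // into a wider s32 result. |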
| XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule WiderIntegralAccumulation |
| |
| ENTRY WiderIntegralAccumulation { |
| p0 = s8[5,5]{1,0} parameter(0) |
| p1 = s16[5,1]{0,1} parameter(1) |
| ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1}, |
| rhs_contracting_dims={0} |
| } |
| )"; |
| |
| EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); |
| } |
| |
| // This benchmark shows the performance impact of the following |
| // transformation: |
| //   dot(reshape(transpose(A)), Const) ==> |
| //   dot(reshape(A), reshape(transpose(reshape(Const)))), |
| // after which the reshape and transpose on the Const side can be folded |
| // away. Comparing runs with and without the algsimp pass shows the impact. |
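| // |
| // As an illustrative sketch (dimension names are for exposition only, not |
| // taken from the HLO built below): |
| //   t = f32[B,K2,K1] transpose(A), dimensions={0,2,1}    // A: f32[B,K1,K2] |
| //   l = f32[B,K2*K1] reshape(t) |
| //   ROOT d = f32[B,N] dot(l, Const)                      // Const: f32[K2*K1,N] |
| // After the rewrite, the dot consumes reshape(A) directly, and the |
| // compensating reshape/transpose/reshape lands on Const, where it folds |
| // into a new constant at compile time. |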
| void DOT_ReorderContracting(::testing::benchmark::State& state) { |
| se::Platform* platform = PlatformUtil::GetDefaultPlatform().value(); |
| auto executors = PlatformUtil::GetStreamExecutors(platform).value(); |
| se::StreamExecutorMemoryAllocator allocator(platform, executors); |
| |
| xla::LocalClientOptions client_options; |
| client_options.set_platform(platform); |
| auto client = |
| ClientLibrary::GetOrCreateLocalClient(client_options).value(); |
| |
| int device_ordinal = client->default_device_ordinal(); |
| |
| const int64_t d0 = 128; |
| const int64_t d1 = 128; |
| const int64_t d2 = 128; |
| const int64_t d3 = 128; |
| |
| Array3D<float> input_arr(d0, d1, d2); |
| Array2D<float> const_arr(d1 * d2, d3); |
| input_arr.FillIota(0); |
| const_arr.FillIota(0); |
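| // Build dot(reshape(transpose(param)), const) -- the pattern described in |
| // the comment above. |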
| XlaBuilder builder("ReorderContracting"); |
| auto t0 = |
| Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0"); |
| auto t1 = Transpose(t0, {0, 2, 1}); |
| auto lhs = Reshape(t1, {d0, d2 * d1}); |
| auto rhs = ConstantR2FromArray2D(&builder, const_arr); |
| Dot(lhs, rhs); |
| auto computation = builder.Build().value(); |
| |
| auto input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_arr); |
| ScopedShapedBuffer buffer0 = |
| client->LiteralToShapedBuffer(input_literal, device_ordinal).value(); |
| |
| TF_ASSERT_OK_AND_ASSIGN( |
| auto executables, client->Compile(computation, {&buffer0.on_host_shape()}, |
| ExecutableBuildOptions())); |
| auto executable = std::move(executables[0]); |
| |
| se::Stream stream(executors[device_ordinal]); |
| stream.Init(); |
| |
| ExecutableRunOptions options; |
| options.set_allocator(&allocator); |
| |
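| // Run a couple of untimed warm-up executions before the measured loop. |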
| const int kWarmups = 2; |
| for (int i = 0; i < kWarmups; ++i) { |
| ASSERT_IS_OK(executable->Run({&buffer0}, options)); |
| } |
| |
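| // Element counts touched per run: input (d0*d1*d2), constant (d1*d2*d3), |
| // and output (d0*d3); converted to bytes when reported below. |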
| const int64_t total_floats = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3; |
| for (auto s : state) { |
| ASSERT_IS_OK(executable->Run({&buffer0}, options)); |
| } |
| state.SetBytesProcessed(state.iterations() * total_floats * sizeof(float)); |
| } |
| |
| BENCHMARK(DOT_ReorderContracting)->UseRealTime(); |
| |
| } // namespace |
| } // namespace xla |