blob: 94a7078898ba5d391939187b4746c316b18a241a [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <numeric>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace xla {
namespace {
class DynamicSliceTest : public ClientLibraryTestBase {
protected:
template <typename IndexT, typename DataT>
void TestR1() {
// Slice at dimension start.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {0}, {5}, {0, 1, 2, 3, 4});
// Slice in the middle.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4});
// Slice at dimension boundaries.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7});
// Zero element slice.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {});
}
template <typename IndexT, typename DataT>
void TestR1OOB() {
// Slice at dimension boundaries, but with out of bounds indices.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7});
}
template <typename IndexT, typename DataT>
void TestR2() {
// Slice at dimension start.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 2},
{{1, 2}, {4, 5}});
// Slice in the middle.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
{{5}, {8}});
// Slice at dimension boundaries.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
{{5}, {8}});
// Zero element slice: 2x0.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0},
{{}, {}});
// Zero element slice: 0x2.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2},
Array2D<int>(0, 2));
}
template <typename IndexT, typename DataT>
void TestR2OOB() {
// Slice at dimension boundaries, but with out of bounds indices.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
}
template <typename IndexT, typename DataT>
void TestR3() {
// R3 Shape: [2, 3, 2]
// clang-format off
// Slice at dimension start.
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}},
{{7, 8}, {9, 10}, {11, 12}}},
{0, 0, 0}, {2, 1, 2},
{{{1, 2}}, {{7, 8}}});
// Slice in the middle.
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}},
{{7, 8}, {9, 10}, {11, 12}}},
{0, 1, 1}, {2, 2, 1},
{{{4}, {6}}, {{10}, {12}}});
// clang-format on
}
template <typename IndexT, typename DataT>
void TestR3OOB() {
// Slice at dimension boundaries, but with out of bounds indices.
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
{2, 1, 2}, {{{5, 6}}, {{11, 12}}});
}
template <typename IndexT, typename DataT>
void RunR1(absl::Span<const int> input_values_int,
const std::vector<IndexT> slice_starts,
const std::vector<int64_t>& slice_sizes,
absl::Span<const int> expected_values_int) {
// bfloat16 has explicit constructors, so it does not implicitly convert the
// way built-in types do, which is why we can't take the parameter as an
// Span<DataT>. We also can't convert it to a vector, because
// vector<bool> is special so that it cannot be a Span<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
LiteralUtil::CreateR1(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie();
Literal expected_values =
std::move(LiteralUtil::CreateR1(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR0Parameter<IndexT>(
slice_starts[0], 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
auto input = ConstantLiteral(&builder, input_values);
DynamicSlice(input, absl::Span<const XlaOp>({starts}), slice_sizes);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
template <typename IndexT, typename DataT>
void RunR2(const Array2D<int>& input_values_int,
const std::vector<IndexT> slice_starts,
const std::vector<int64_t>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
std::vector<XlaOp> starts(2);
std::vector<std::unique_ptr<GlobalData>> start_data(2);
for (int i = 0; i < 2; ++i) {
start_data[i] = CreateR0Parameter<IndexT>(
slice_starts[i], i, "slice_starts", &builder, &starts[i]);
}
// Build dynamic slice computation.
auto input = ConstantLiteral(&builder, input_values);
DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
std::vector<GlobalData*> argument_ptrs;
absl::c_transform(start_data, std::back_inserter(argument_ptrs),
[](const std::unique_ptr<GlobalData>& argument) {
return argument.get();
});
ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
}
template <typename IndexT, typename DataT>
void RunR3(const Array3D<int>& input_values_int,
const std::vector<IndexT> slice_starts,
const std::vector<int64_t>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
std::vector<XlaOp> starts(3);
std::vector<std::unique_ptr<GlobalData>> start_data(3);
for (int i = 0; i < 3; ++i) {
start_data[i] = CreateR0Parameter<IndexT>(
slice_starts[i], i, "slice_starts", &builder, &starts[i]);
}
// Build dynamic slice computation.
auto input = ConstantLiteral(&builder, input_values);
DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
std::vector<GlobalData*> argument_ptrs;
absl::c_transform(start_data, std::back_inserter(argument_ptrs),
[](const std::unique_ptr<GlobalData>& argument) {
return argument.get();
});
ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
}
};
// Typed instantiations of the DynamicSliceTest drivers. In each test name the
// leading type (Int32, Int64, UInt32, UInt64) is the start-index type and Rn
// is the operand rank; a BF16 suffix means bfloat16 data and an OOB suffix
// means the requested start index is out of bounds. The operand element type
// is the second template argument of each call.

// Rank-1 operand.
XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1<int32_t, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32_t, int32_t>(); }
XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB<int32_t, int32_t>(); }
XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64_t, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64_t, float>(); }
// Start index 2147483648 (2^31) is larger than any int32_t value; the expected
// result is the last two elements of the operand.
XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) {
RunR1<uint32_t, int32_t>({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4});
}
// Rank-2 operand.
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2<int32_t, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32_t, int32_t>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB<int32_t, int32_t>(); }
XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64_t, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64_t, int32_t>(); }
// Only the row index is 2^31; the expected element is the one in the last row,
// column 0.
XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) {
RunR2<uint32_t, int32_t>({{0, 1}, {2, 3}}, {2147483648u, 0}, {1, 1}, {{2}});
}
// Rank-3 operand.
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3<int32_t, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32_t, float>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB<int32_t, float>(); }
XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64_t, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64_t, float>(); }
// The first and last start indices are 2^31; the expected element is the one
// at position (1, 0, 1) of the 2x2x2 operand.
XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) {
RunR3<uint32_t, int32_t>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}},
{2147483648u, 0, 2147483648u}, {1, 1, 1}, {{{5}}});
}
// Rank-1 DynamicSlice over bool data with int32_t start indices.
XLA_TEST_F(DynamicSliceTest, Int32R1Pred) {
  // Every case slices the same operand. RunR1 accepts int data and converts
  // it to bool, so the bool literals are stored as ints here.
  const std::vector<int> operand = {true,  false, false, true,
                                    false, true,  true,  false};

  // Slice at dimension start.
  RunR1<int32_t, bool>(operand, {0}, {5}, {true, false, false, true, false});

  // Slice in the middle.
  RunR1<int32_t, bool>(operand, {2}, {3}, {false, true, false});

  // Slice at dimension boundaries.
  RunR1<int32_t, bool>(operand, {5}, {3}, {true, true, false});

  // Zero element slice.
  RunR1<int32_t, bool>(operand, {2}, {0}, {});
}
// Rank-2 DynamicSlice over bool data with int32_t start indices.
XLA_TEST_F(DynamicSliceTest, Int32R2Pred) {
  // Every case slices the same 3x3 operand. RunR2 accepts int data and
  // converts it to bool.
  const Array2D<int> operand = {
      {true, false, true}, {false, false, true}, {true, true, false}};

  // Slice at dimension start.
  RunR2<int32_t, bool>(operand, {0, 0}, {2, 2},
                       {{true, false}, {false, false}});

  // Slice in the middle.
  RunR2<int32_t, bool>(operand, {1, 1}, {2, 1}, {{false}, {true}});

  // Slice at dimension boundaries: the slice ends exactly at the last row and
  // the last column. (This case used to repeat the "middle" slice verbatim.)
  RunR2<int32_t, bool>(operand, {1, 1}, {2, 2},
                       {{false, true}, {true, false}});

  // Zero element slice: 2x0.
  RunR2<int32_t, bool>(operand, {0, 0}, {2, 0}, {{}, {}});

  // Zero element slice: 0x2.
  RunR2<int32_t, bool>(operand, {0, 0}, {0, 2}, Array2D<int>(0, 2));
}
// Rank-3 DynamicSlice over bool data with int32_t start indices.
XLA_TEST_F(DynamicSliceTest, Int32R3Pred) {
  // R3 Shape: [2, 3, 2]. Both cases slice the same operand; RunR3 accepts int
  // data and converts it to bool.
  // clang-format off
  const Array3D<int> operand =
    {{{true, false}, {false, true}, {true, true}},
     {{false, true}, {true, false}, {false, false}}};

  // Slice at dimension start.
  RunR3<int32_t, bool>(
    operand,
    {0, 0, 0}, {2, 1, 2},
    {{{true, false}}, {{false, true}}});

  // Slice in the middle.
  RunR3<int32_t, bool>(
    operand,
    {0, 1, 1}, {2, 2, 1},
    {{{true}, {true}}, {{false}, {false}}});
  // clang-format on
}
class DynamicUpdateSliceTest : public ClientLibraryTestBase {
protected:
template <typename IndexT, typename DataT>
void TestR0() {
// Disable algebraic simplifier, otherwise the op will be replaced by a
// constant.
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"algsimp");
RunR0<IndexT, DataT>(0, 123, {}, 123);
}
template <typename IndexT, typename DataT>
void TestR1() {
// Slice at dimension start.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0},
{8, 9, 10, 3, 4, 5, 6, 7});
// Slice in the middle.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {2},
{0, 1, 8, 9, 10, 5, 6, 7});
// Slice at dimension boundaries.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5},
{0, 1, 2, 3, 4, 8, 9, 10});
// Zero-sized update.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2},
{0, 1, 2, 3, 4, 5, 6, 7});
}
template <typename IndexT, typename DataT>
void TestR2() {
// Slice at dimension start.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {0, 0},
{{10, 11, 3}, {4, 5, 6}, {7, 8, 9}});
// Slice in the middle.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {1, 1},
{{1, 2, 3}, {4, 10, 11}, {7, 8, 9}});
// Slice at dimension boundaries.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1},
{{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
// Zero-sized update.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1},
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
}
template <typename IndexT, typename DataT>
void TestR3() {
// R3 Shape: [2, 3, 2]
// Slice at dimension start.
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}},
{{{13, 14}, {15, 16}}, {{17, 18}, {19, 20}}}, {0, 0, 0},
{{{13, 14}, {15, 16}, {5, 6}}, {{17, 18}, {19, 20}, {11, 12}}});
// Slice in the middle.
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
{1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
}
template <typename IndexT, typename DataT>
void TestOOB() {
// // Slice at dimension boundaries, but with out of bounds indices.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
{0, 1, 2, 3, 4, 8, 9, 10});
// R2 Shape: [3, 3]
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
{{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
// R3 Shape: [2, 3, 2]
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
{1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
}
template <typename IndexT, typename DataT>
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
std::move(LiteralUtil::CreateR0(input_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_value =
std::move(LiteralUtil::CreateR0(update_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_value =
std::move(LiteralUtil::CreateR0(expected_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Build dynamic slice computation.
auto input = ConstantLiteral(&builder, input_value);
auto update = ConstantLiteral(&builder, update_value);
DynamicUpdateSlice(input, update, absl::Span<const XlaOp>({}));
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_value, {});
}
template <typename IndexT, typename DataT>
void RunR1(absl::Span<const int> input_values_int,
absl::Span<const int> update_values_int,
const std::vector<IndexT> slice_starts,
absl::Span<const int> expected_values_int) {
Literal input_values =
std::move(LiteralUtil::CreateR1(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(LiteralUtil::CreateR1(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(LiteralUtil::CreateR1(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR0Parameter<IndexT>(
slice_starts[0], 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
auto input = ConstantLiteral(&builder, input_values);
auto update = ConstantLiteral(&builder, update_values);
DynamicUpdateSlice(input, update, absl::Span<const XlaOp>({starts}));
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
template <typename IndexT, typename DataT>
void RunR2(const Array2D<int>& input_values_int,
const Array2D<int>& update_values_int,
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
std::vector<XlaOp> starts(2);
std::vector<std::unique_ptr<GlobalData>> start_data(2);
for (int i = 0; i < 2; ++i) {
start_data[i] = CreateR0Parameter<IndexT>(
slice_starts[i], i, "slice_starts", &builder, &starts[i]);
}
// Build dynamic slice computation.
auto input = ConstantLiteral(&builder, input_values);
auto update = ConstantLiteral(&builder, update_values);
DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
std::vector<GlobalData*> argument_ptrs;
absl::c_transform(start_data, std::back_inserter(argument_ptrs),
[](const std::unique_ptr<GlobalData>& argument) {
return argument.get();
});
ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
}
template <typename IndexT, typename DataT>
void RunR3(const Array3D<int>& input_values_int,
const Array3D<int>& update_values_int,
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
std::vector<XlaOp> starts(3);
std::vector<std::unique_ptr<GlobalData>> start_data(3);
for (int i = 0; i < 3; ++i) {
start_data[i] = CreateR0Parameter<IndexT>(
slice_starts[i], i, "slice_starts", &builder, &starts[i]);
}
// Build dynamic slice computation.
auto input = ConstantLiteral(&builder, input_values);
auto update = ConstantLiteral(&builder, update_values);
DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
std::vector<GlobalData*> argument_ptrs;
absl::c_transform(start_data, std::back_inserter(argument_ptrs),
[](const std::unique_ptr<GlobalData>& argument) {
return argument.get();
});
ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs);
}
template <class T>
void RunR3Contiguous(std::vector<int32_t> operand_shape, int32_t index,
int32_t size) {
const int32_t kSeq = operand_shape[0];
const int32_t kBatch = operand_shape[1];
const int32_t kDim = operand_shape[2];
Array3D<T> input_values(kSeq, kBatch, kDim);
Array3D<T> update_values(size, kBatch, kDim);
Array3D<T> expected_values(kSeq, kBatch, kDim);
index = std::min(std::max(0, index), kSeq - size);
input_values.FillIota(static_cast<T>(0));
T value = static_cast<T>(10);
update_values.FillIota(static_cast<T>(value));
// TODO(b/34128753) Expected values may vary depending on backend when
// the indices are out of bounds.
expected_values.FillIota(static_cast<T>(0));
for (int i = 0; i < size; i++) {
for (int j = 0; j < kBatch; j++) {
for (int k = 0; k < kDim; k++) {
expected_values(index + i, j, k) = value++;
}
}
}
if (VLOG_IS_ON(1)) {
DumpArray<T>("input", input_values);
DumpArray<T>("update", update_values);
DumpArray<T>("expected", expected_values);
}
// Build dynamic slice computation.
XlaBuilder builder(TestName());
// Initialize and transfer input parameter.
XlaOp input;
std::unique_ptr<GlobalData> input_data =
CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input);
// Initialize and transfer update parameter.
XlaOp update;
std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>(
update_values, 1, "update_values", &builder, &update);
auto constant_index = ConstantR0<int32_t>(&builder, index);
auto zero = ConstantR0<int32_t>(&builder, 0);
DynamicUpdateSlice(input, update, {constant_index, zero, zero});
// Run computation and compare against expected values.
ComputeAndCompareR3<T>(&builder, expected_values,
{input_data.get(), update_data.get()},
ErrorSpec(0.000001));
}
template <typename NativeT>
void DumpArray(const std::string& name, const Array3D<NativeT> values) {
Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal.ToString();
}
};
// Typed instantiations of the DynamicUpdateSliceTest drivers. In each test
// name the leading type (Int32, Int64, UInt32, UInt64) is the start-index type
// and Rn is the operand rank; a BF16 suffix means bfloat16 data and OOB means
// the requested start index is out of bounds. The operand element type is the
// second template argument of each call.

// Rank-0 (scalar) operand.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0<int32_t, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0<int32_t, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0<int64_t, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0<uint64_t, float>(); }
// Rank-1 operand.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1<int32_t, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32_t, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64_t, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64_t, float>(); }
// Start index 2147483648 (2^31) is larger than any int32_t value; the expected
// result has the update written over the last two elements.
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) {
RunR1<uint32_t, int32_t>({0, 1, 2, 3, 4}, {5, 6}, {2147483648u},
{0, 1, 2, 5, 6});
}
// Rank-2 operand.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2<int32_t, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32_t, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64_t, int64_t>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64_t, int32_t>(); }
// Only the row index is 2^31; the expected result has the update written into
// the last row, column 0.
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) {
RunR2<uint32_t, int32_t>({{0, 1}, {2, 3}}, {{4}}, {2147483648u, 0},
{{0, 1}, {4, 3}});
}
// Rank-3 operand.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3<int32_t, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32_t, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64_t, int64_t>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64_t, uint64_t>(); }
// The first and last start indices are 2^31; the expected result has the
// update written at position (1, 0, 1) of the 2x2x2 operand.
XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) {
RunR3<uint32_t, int32_t>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {{{8}}},
{2147483648u, 0, 2147483648u},
{{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}});
}
// Out-of-bounds start indices for ranks 1 through 3 (see TestOOB).
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) {
TestOOB<int32_t, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB<int32_t, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB<int64_t, int64_t>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB<uint64_t, uint64_t>(); }
// Rank-1 DynamicUpdateSlice over bool data with int32_t start indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
  // Every case updates the same operand. RunR1 accepts int data and converts
  // it to bool, so the bool literals are stored as ints here.
  const std::vector<int> operand = {false, false, true, true,
                                    false, true,  true, false};

  // Slice at dimension start.
  RunR1<int32_t, bool>(operand, {true, true, false}, {0},
                       {true, true, false, true, false, true, true, false});

  // Slice in the middle.
  RunR1<int32_t, bool>(operand, {false, true, true}, {2},
                       {false, false, false, true, true, true, true, false});

  // Slice at dimension boundaries.
  RunR1<int32_t, bool>(operand, {false, true, true}, {5},
                       {false, false, true, true, false, false, true, true});

  // Zero-sized update.
  RunR1<int32_t, bool>(operand, {}, {2},
                       {false, false, true, true, false, true, true, false});
}
// Rank-2 DynamicUpdateSlice over bool data with int32_t start indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) {
  // Every case updates the same 3x3 operand. RunR2 accepts int data and
  // converts it to bool.
  const Array2D<int> operand = {
      {false, true, false}, {true, false, true}, {false, true, true}};

  // Slice at dimension start.
  RunR2<int32_t, bool>(
      operand, {{true, false}}, {0, 0},
      {{true, false, false}, {true, false, true}, {false, true, true}});

  // Slice in the middle.
  RunR2<int32_t, bool>(
      operand, {{true, false}}, {1, 1},
      {{false, true, false}, {true, true, false}, {false, true, true}});

  // Slice at dimension boundaries.
  RunR2<int32_t, bool>(
      operand, {{true, false}}, {2, 1},
      {{false, true, false}, {true, false, true}, {false, true, false}});

  // Zero-sized update.
  RunR2<int32_t, bool>(
      operand, {{}}, {2, 1},
      {{false, true, false}, {true, false, true}, {false, true, true}});
}
// Rank-3 DynamicUpdateSlice over bool data with int32_t start indices.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
  // R3 Shape: [2, 3, 2]. Both cases update the same operand; RunR3 accepts
  // int data and converts it to bool.
  const Array3D<int> operand = {
      {{true, false}, {false, true}, {true, true}},
      {{false, false}, {false, true}, {true, false}}};

  // Slice at dimension start.
  RunR3<int32_t, bool>(
      operand,
      {{{false, true}, {true, false}}, {{true, true}, {false, true}}},
      {0, 0, 0},
      {{{false, true}, {true, false}, {true, true}},
       {{true, true}, {false, true}, {true, false}}});

  // Slice in the middle.
  RunR3<int32_t, bool>(operand, {{{false}, {true}}}, {1, 1, 1},
                       {{{true, false}, {false, true}, {true, true}},
                        {{false, false}, {false, false}, {true, true}}});
}
// Tests for simple R3 case where the update is contiguous (i.e. the minor
// two dimensions are not sliced). Each test passes the operand shape, the
// requested start index along dimension 0, and the update's extent along
// dimension 0 to RunR3Contiguous; the BF16 variants repeat the same case with
// bfloat16 data.
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
// Single element, index in-bounds
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) {
// Single element, index in-bounds
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
// Multiple elements, index in-bounds.
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) {
// Multiple elements, index in-bounds.
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) {
// Multiple elements; index + size (3 + 2) runs past the end of dimension 0
// (extent 4).
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/3, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) {
// Multiple elements; index + size (3 + 2) runs past the end of dimension 0
// (extent 4).
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/3, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) {
// Multiple elements; the start index itself (5) is already past the end of
// dimension 0 (extent 4). The update (size 2) still fits in the operand.
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/5, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) {
// Multiple elements; the start index itself (5) is already past the end of
// dimension 0 (extent 4). The update (size 2) still fits in the operand.
std::vector<int32_t> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/5, /*size=*/2);
}
// NOTE(review): the odd dimension sizes (3, 123, 247) are presumably chosen so
// that rows do not line up with any power-of-two width -- confirm intent.
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
std::vector<int32_t> operand_shape({3, 123, 247});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) {
std::vector<int32_t> operand_shape({3, 123, 247});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}
// Larger operand (32 x 128 x 1024 elements) with a single-entry update.
// TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
std::vector<int32_t> operand_shape({32, 128, 1024});
RunR3Contiguous<float>(operand_shape, /*index=*/7, /*size=*/1);
}
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLargerBF16)) {
std::vector<int32_t> operand_shape({32, 128, 1024});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/7, /*size=*/1);
}
// This tests that buffer assignment does not alias constants with the output
// of dynamic update slice.
//
// The module updates a constant with dynamic-update-slice, adds the result to
// itself, and dynamically slices the sum. RunAndCompare is given an
// ErrorSpec of {0, 0}, i.e. no tolerance is allowed in the comparison.
XLA_TEST_F(HloTestBase, AddOfDUS) {
const char* hlo_string = R"(
HloModule m
test {
o = s32[6] constant({2,3,4,5,6,7})
i = s32[] parameter(0)
u = s32[2] parameter(1)
dus = s32[6] dynamic-update-slice(o,u,i)
a = s32[6] add(dus, dus)
j = s32[] parameter(2)
ROOT ds = s32[2] dynamic-slice(a, j), dynamic_slice_sizes={2}
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
// Benchmarks running a compiled DynamicSlice computation through LocalClient.
//
// Setup (not timed by the benchmark loop): builds a computation that slices a
// single element out of a constant [1, 2, 3, 4] operand using four scalar S32
// start-index parameters, transfers the index values to the device, compiles
// the computation and performs a few warm-up runs. The benchmark loop then
// repeatedly executes the compiled executable.
void BM_DynamicSlice(::testing::benchmark::State& state) {
  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
  se::StreamExecutorMemoryAllocator allocator(platform, executors);
  LocalClient* client =
      ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  auto* transfer_manager =
      TransferManager::GetForPlatform(platform).ValueOrDie();
  int device_ordinal = client->default_device_ordinal();

  XlaBuilder builder("DynamicSlice");

  // Create input as a constant: shape [1, 2, 3, 4]
  auto input_literal = LiteralUtil::CreateR4(
      {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
        {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
  auto input = ConstantLiteral(&builder, input_literal);

  auto stream =
      client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();

  // Create dynamic slice start indices as parameters: one scalar (S32[])
  // parameter per operand dimension.
  constexpr int kRank = 4;
  auto start_indices_shape = ShapeUtil::MakeShape(S32, {});
  std::vector<XlaOp> start_indices(kRank);
  std::vector<ScopedShapedBuffer> shaped_buffers;
  // host_shapes (and, later, shaped_buffer_ptrs) hold pointers into the
  // elements of shaped_buffers. Reserve up front so that emplace_back never
  // reallocates and leaves the earlier pointers dangling.
  shaped_buffers.reserve(kRank);
  std::vector<const Shape*> host_shapes(kRank);
  for (int i = 0; i < kRank; ++i) {
    start_indices[i] =
        Parameter(&builder, i, start_indices_shape, "start_indices");
    auto start_index_literal = LiteralUtil::CreateR0<int32_t>(i + 1);
    // Initialize and transfer parameter buffer. The buffer is allocated on
    // the same device the stream above was borrowed for.
    shaped_buffers.emplace_back(
        client->backend()
            .transfer_manager()
            ->AllocateScopedShapedBuffer(start_indices_shape, &allocator,
                                         device_ordinal)
            .value());
    host_shapes[i] = &shaped_buffers[i].on_host_shape();
    ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
        stream.get(), start_index_literal, shaped_buffers[i]));
  }

  // Add DynamicSlice op to the computation.
  DynamicSlice(input, start_indices, {1, 1, 1, 1});
  auto computation = builder.Build().value();

  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      client->Compile(computation, host_shapes, ExecutableBuildOptions()));
  auto executable = std::move(executables[0]);

  // Non-owning argument list in the form Run() expects.
  std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
  shaped_buffer_ptrs.reserve(shaped_buffers.size());
  for (const ScopedShapedBuffer& buffer : shaped_buffers) {
    shaped_buffer_ptrs.push_back(&buffer);
  }

  // Run some warm-up executions.
  ExecutableRunOptions options;
  options.set_allocator(&allocator);
  constexpr int kWarmups = 2;
  for (int i = 0; i < kWarmups; ++i) {
    auto result = executable->Run(shaped_buffer_ptrs, options);
    ASSERT_TRUE(result.ok());
  }

  // Run benchmark.
  for (auto s : state) {
    auto result = executable->Run(shaped_buffer_ptrs, options);
    ASSERT_TRUE(result.ok());
  }
}
BENCHMARK(BM_DynamicSlice);
} // namespace
} // namespace xla