blob: 47be788151e485d4888cf9dbd3e0816a63bafb44 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class RaggedGatherOpTest : public ::tensorflow::OpsTestBase {
protected:
// Builds the tensorflow test graph for RaggedGather.
template <typename VALUE_TYPE, typename INDEX_TYPE>
void BuildRaggedGatherGraph(
const TensorShape& indices_shape, const std::vector<INDEX_TYPE>& indices,
const std::vector<std::vector<int64>>& params_nested_splits,
const TensorShape& params_dense_values_shape,
const gtl::ArraySlice<VALUE_TYPE> params_dense_values) {
const auto& value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
const auto& index_dtype = DataTypeToEnum<INDEX_TYPE>::v();
int64 PARAMS_RAGGED_RANK = params_nested_splits.size();
int64 num_splits = PARAMS_RAGGED_RANK + indices_shape.dims() - 1;
TF_ASSERT_OK(
NodeDefBuilder("tested_op", "RaggedGather")
.Input(FakeInput(PARAMS_RAGGED_RANK)) // params_nested_splits
.Input(FakeInput(value_dtype)) // params_dense_values
.Input(FakeInput(index_dtype)) // indices
.Attr("PARAMS_RAGGED_RANK", PARAMS_RAGGED_RANK)
.Attr("OUTPUT_RAGGED_RANK", num_splits)
.Attr("Tvalues", value_dtype)
.Attr("Tindices", index_dtype)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
for (const auto& splits : params_nested_splits) {
int64 splits_size = splits.size();
AddInputFromArray<int64>(TensorShape({splits_size}), splits);
}
AddInputFromArray<VALUE_TYPE>(params_dense_values_shape,
params_dense_values);
AddInputFromArray<INDEX_TYPE>(indices_shape, indices);
}
};
TEST_F(RaggedGatherOpTest, RaggedGather) {
// indices = [2, 1, 0, 3]
// params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
// params.shape = [4, None]
BuildRaggedGatherGraph<float, int32>(
TensorShape({4}), // indices.shape
{2, 1, 0, 3}, // indices
{{0, 3, 3, 7, 9}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
TF_ASSERT_OK(RunOpKernel());
// Expected: [[.4, .5, .6, .7], [.1, .2, .3], [], [.8, .9]]
test::ExpectTensorEqual<int64>(*GetOutput(0),
test::AsTensor<int64>({0, 4, 4, 7, 9}));
test::ExpectTensorNear<float>(
*GetOutput(1),
test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1);
}
TEST_F(RaggedGatherOpTest, RaggedGather_3DParams) {
// indices = [2, 1, 0, 2, 3]
// params = [[[]], [[.1, 2], [.3]], [], [[.4, .5], [.6, .7, .8]], [[.9]]]
// params.shape = [5, None, None]
BuildRaggedGatherGraph<float, int32>(
TensorShape({5}), // indices.shape
{2, 1, 0, 2, 3}, // indices
{{0, 1, 3, 3, 5, 6}, {0, 0, 2, 3, 5, 8, 9}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
TF_ASSERT_OK(RunOpKernel());
// Expected: [[], [[.1, 2], [.3]], [[]], [], [[.4, .5], [.6, .7, .8]]]
test::ExpectTensorEqual<int64>(*GetOutput(0),
test::AsTensor<int64>({0, 0, 2, 3, 3, 5}));
test::ExpectTensorEqual<int64>(*GetOutput(1),
test::AsTensor<int64>({0, 2, 3, 3, 5, 8}));
test::ExpectTensorNear<float>(
*GetOutput(2), test::AsTensor<float>({.1, .2, .3, .4, .5, .6, .7, .8}),
0.1);
}
TEST_F(RaggedGatherOpTest, RaggedGather_4DParams) {
// indices = [2, 1, 0, 2]
// params = [[[]], [[[1, 2], [3, 4], [5, 6]], [[7, 8]]], []]
// params.shape = [4, None, None, 2]
BuildRaggedGatherGraph<int32, int32>(
TensorShape({4}), // indices.shape
{2, 1, 0, 2}, // indices
{{0, 1, 3, 3}, {0, 0, 3, 4}}, // params_nested_splits
TensorShape({4, 2}), // params_dense_values.shape
{1, 2, 3, 4, 5, 6, 7, 8} // params_dense_values
);
TF_ASSERT_OK(RunOpKernel());
// Expected: [[],
// [[[1, 2], [3, 4], [5, 6]], [[7, 8]]],
// [[]],
// []]
test::ExpectTensorEqual<int64>(*GetOutput(0),
test::AsTensor<int64>({0, 0, 2, 3, 3}));
test::ExpectTensorEqual<int64>(*GetOutput(1),
test::AsTensor<int64>({0, 3, 4, 4}));
test::ExpectTensorEqual<int32>(
*GetOutput(2),
test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, TensorShape({4, 2})));
}
TEST_F(RaggedGatherOpTest, RaggedGather_2DIndices) {
// indices = [[2, 1], [0, 3]]
// params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
BuildRaggedGatherGraph<float, int32>(
TensorShape({2, 2}), // indices.shape
{2, 1, 0, 3}, // indices
{{0, 3, 3, 7, 9}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
TF_ASSERT_OK(RunOpKernel());
// Expected: [ [ [.4, .5, .6, .7], [.1, .2, .3] ],
// [ [], [.8, .9] ] ]
test::ExpectTensorEqual<int64>(*GetOutput(0),
test::AsTensor<int64>({0, 2, 4}));
test::ExpectTensorEqual<int64>(*GetOutput(1),
test::AsTensor<int64>({0, 4, 4, 7, 9}));
test::ExpectTensorNear<float>(
*GetOutput(2),
test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1);
}
TEST_F(RaggedGatherOpTest, RaggedGather_ScalarIndices) {
// indices = 2
// params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
BuildRaggedGatherGraph<float, int32>(
TensorShape({}), // indices.shape
{2}, // indices
{{0, 3, 3, 7, 9}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
TF_ASSERT_OK(RunOpKernel());
// Expected: [.4, .5, .6, .7]
test::ExpectTensorNear<float>(*GetOutput(0),
test::AsTensor<float>({.4, .5, .6, .7}), 0.1);
}
TEST_F(RaggedGatherOpTest, RaggedGather_OutOfBounds) {
// indices = [2, 10]
// params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
BuildRaggedGatherGraph<float, int32>(
TensorShape({2}), // indices.shape
{2, 10}, // indices
{{0, 3, 3, 7, 9}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
EXPECT_EQ("indices[1] = 10 is not in [0, 4)", RunOpKernel().error_message());
}
TEST_F(RaggedGatherOpTest, InvalidSplitsNotSorted) {
BuildRaggedGatherGraph<float, int32>(
TensorShape({2}), // indices.shape
{0, 2}, // indices
{{0, 3, 5, 2, 9}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
EXPECT_EQ("Ragged splits must be sorted", RunOpKernel().error_message());
}
TEST_F(RaggedGatherOpTest, InvalidSplitsNegative) {
BuildRaggedGatherGraph<float, int32>(
TensorShape({2}), // indices.shape
{0, 2}, // indices
{{-1, 3, 2, 7, 9}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
EXPECT_EQ("Ragged splits must be non-negative",
RunOpKernel().error_message());
}
TEST_F(RaggedGatherOpTest, InvalidSplitsEmpty) {
BuildRaggedGatherGraph<float, int32>(
TensorShape({0}), // indices.shape
{}, // indices
{{}}, // params_nested_splits
TensorShape({0}), // params_dense_values.shape
{} // params_dense_values
);
EXPECT_EQ("Ragged splits may not be empty", RunOpKernel().error_message());
}
TEST_F(RaggedGatherOpTest, InvalidSplitsTooBig) {
BuildRaggedGatherGraph<float, int32>(
TensorShape({2}), // indices.shape
{0, 2}, // indices
{{0, 20, 40, 80, 100}}, // params_nested_splits
TensorShape({9}), // params_dense_values.shape
{.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
);
EXPECT_EQ("Ragged splits must not point past values",
RunOpKernel().error_message());
}
TEST_F(RaggedGatherOpTest, BadValuesShape) {
BuildRaggedGatherGraph<float, int32>(
TensorShape({0}), // indices.shape
{}, // indices
{{0}}, // params_nested_splits
TensorShape({}), // params_dense_values.shape
{.1} // params_dense_values
);
EXPECT_EQ("params.rank must be nonzero", RunOpKernel().error_message());
}
TEST_F(RaggedGatherOpTest, ShapeFn) {
// RaggedGather(param_splits+, param_values, indices) -> [splits+, values]
ShapeInferenceTestOp op("RaggedGather");
(*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
(*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(1);
INFER_OK(op, "?;?;?", "[?];?");
INFER_OK(op, "[?];[?];[?]", "[?];[?]");
INFER_OK(op, "[?];[?,?,?];[?]", "[?];[?,d1_1,d1_2]");
INFER_OK(op, "[5];[10];[15]", "[?];[?]");
INFER_OK(op, "[5];[10,2];[15]", "[?];[?,d1_1]");
INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[5];[];[]");
INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[];[5]");
(*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(2);
(*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2);
INFER_OK(op, "?;?;?;?", "[?];[?];?");
INFER_OK(op, "[?];[?];[?];[?]", "[?];[?];[?]");
INFER_OK(op, "[?];[?];[?,?,?];[?]", "[?];[?];[?,d2_1,d2_2]");
INFER_OK(op, "[5];[10];[15];[20]", "[?];[?];[?]");
(*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
(*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2);
INFER_OK(op, "?;?;?", "[?];[?];?");
INFER_OK(op, "[?];[?];[?,?]", "[?];[?];[?]");
INFER_OK(op, "[?];[?,?,?];[?,?]", "[?];[?];[?,d1_1,d1_2]");
INFER_OK(op, "[15];[20];[5,10]", "[?];[?];[?]");
INFER_OK(op, "[15];[20,2];[5,10]", "[?];[?];[?,d1_1]");
(*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
(*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(0);
INFER_OK(op, "[?];[?];[]", "[?]");
}
} // namespace
} // namespace tensorflow