| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/framework/fake_input.h" |
| #include "tensorflow/core/framework/node_def_builder.h" |
| #include "tensorflow/core/framework/shape_inference.h" |
| #include "tensorflow/core/framework/shape_inference_testutil.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/framework/tensor_testutil.h" |
| #include "tensorflow/core/kernels/ops_testutil.h" |
| #include "tensorflow/core/platform/test.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| class RaggedGatherOpTest : public ::tensorflow::OpsTestBase { |
| protected: |
| // Builds the tensorflow test graph for RaggedGather. |
| template <typename VALUE_TYPE, typename INDEX_TYPE> |
| void BuildRaggedGatherGraph( |
| const TensorShape& indices_shape, const std::vector<INDEX_TYPE>& indices, |
| const std::vector<std::vector<int64>>& params_nested_splits, |
| const TensorShape& params_dense_values_shape, |
| const gtl::ArraySlice<VALUE_TYPE> params_dense_values) { |
| const auto& value_dtype = DataTypeToEnum<VALUE_TYPE>::v(); |
| const auto& index_dtype = DataTypeToEnum<INDEX_TYPE>::v(); |
| int64 PARAMS_RAGGED_RANK = params_nested_splits.size(); |
| int64 num_splits = PARAMS_RAGGED_RANK + indices_shape.dims() - 1; |
| TF_ASSERT_OK( |
| NodeDefBuilder("tested_op", "RaggedGather") |
| .Input(FakeInput(PARAMS_RAGGED_RANK)) // params_nested_splits |
| .Input(FakeInput(value_dtype)) // params_dense_values |
| .Input(FakeInput(index_dtype)) // indices |
| .Attr("PARAMS_RAGGED_RANK", PARAMS_RAGGED_RANK) |
| .Attr("OUTPUT_RAGGED_RANK", num_splits) |
| .Attr("Tvalues", value_dtype) |
| .Attr("Tindices", index_dtype) |
| .Finalize(node_def())); |
| TF_ASSERT_OK(InitOp()); |
| for (const auto& splits : params_nested_splits) { |
| int64 splits_size = splits.size(); |
| AddInputFromArray<int64>(TensorShape({splits_size}), splits); |
| } |
| AddInputFromArray<VALUE_TYPE>(params_dense_values_shape, |
| params_dense_values); |
| AddInputFromArray<INDEX_TYPE>(indices_shape, indices); |
| } |
| }; |
| |
| TEST_F(RaggedGatherOpTest, RaggedGather) { |
| // indices = [2, 1, 0, 3] |
| // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] |
| // params.shape = [4, None] |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({4}), // indices.shape |
| {2, 1, 0, 3}, // indices |
| {{0, 3, 3, 7, 9}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| |
| TF_ASSERT_OK(RunOpKernel()); |
| |
| // Expected: [[.4, .5, .6, .7], [.1, .2, .3], [], [.8, .9]] |
| test::ExpectTensorEqual<int64>(*GetOutput(0), |
| test::AsTensor<int64>({0, 4, 4, 7, 9})); |
| test::ExpectTensorNear<float>( |
| *GetOutput(1), |
| test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1); |
| } |
| |
| TEST_F(RaggedGatherOpTest, RaggedGather_3DParams) { |
| // indices = [2, 1, 0, 2, 3] |
| // params = [[[]], [[.1, 2], [.3]], [], [[.4, .5], [.6, .7, .8]], [[.9]]] |
| // params.shape = [5, None, None] |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({5}), // indices.shape |
| {2, 1, 0, 2, 3}, // indices |
| {{0, 1, 3, 3, 5, 6}, {0, 0, 2, 3, 5, 8, 9}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| |
| TF_ASSERT_OK(RunOpKernel()); |
| |
| // Expected: [[], [[.1, 2], [.3]], [[]], [], [[.4, .5], [.6, .7, .8]]] |
| test::ExpectTensorEqual<int64>(*GetOutput(0), |
| test::AsTensor<int64>({0, 0, 2, 3, 3, 5})); |
| test::ExpectTensorEqual<int64>(*GetOutput(1), |
| test::AsTensor<int64>({0, 2, 3, 3, 5, 8})); |
| test::ExpectTensorNear<float>( |
| *GetOutput(2), test::AsTensor<float>({.1, .2, .3, .4, .5, .6, .7, .8}), |
| 0.1); |
| } |
| |
| TEST_F(RaggedGatherOpTest, RaggedGather_4DParams) { |
| // indices = [2, 1, 0, 2] |
| // params = [[[]], [[[1, 2], [3, 4], [5, 6]], [[7, 8]]], []] |
| // params.shape = [4, None, None, 2] |
| BuildRaggedGatherGraph<int32, int32>( |
| TensorShape({4}), // indices.shape |
| {2, 1, 0, 2}, // indices |
| {{0, 1, 3, 3}, {0, 0, 3, 4}}, // params_nested_splits |
| TensorShape({4, 2}), // params_dense_values.shape |
| {1, 2, 3, 4, 5, 6, 7, 8} // params_dense_values |
| ); |
| |
| TF_ASSERT_OK(RunOpKernel()); |
| |
| // Expected: [[], |
| // [[[1, 2], [3, 4], [5, 6]], [[7, 8]]], |
| // [[]], |
| // []] |
| test::ExpectTensorEqual<int64>(*GetOutput(0), |
| test::AsTensor<int64>({0, 0, 2, 3, 3})); |
| test::ExpectTensorEqual<int64>(*GetOutput(1), |
| test::AsTensor<int64>({0, 3, 4, 4})); |
| test::ExpectTensorEqual<int32>( |
| *GetOutput(2), |
| test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, TensorShape({4, 2}))); |
| } |
| |
| TEST_F(RaggedGatherOpTest, RaggedGather_2DIndices) { |
| // indices = [[2, 1], [0, 3]] |
| // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({2, 2}), // indices.shape |
| {2, 1, 0, 3}, // indices |
| {{0, 3, 3, 7, 9}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| |
| TF_ASSERT_OK(RunOpKernel()); |
| |
| // Expected: [ [ [.4, .5, .6, .7], [.1, .2, .3] ], |
| // [ [], [.8, .9] ] ] |
| test::ExpectTensorEqual<int64>(*GetOutput(0), |
| test::AsTensor<int64>({0, 2, 4})); |
| test::ExpectTensorEqual<int64>(*GetOutput(1), |
| test::AsTensor<int64>({0, 4, 4, 7, 9})); |
| test::ExpectTensorNear<float>( |
| *GetOutput(2), |
| test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1); |
| } |
| |
| TEST_F(RaggedGatherOpTest, RaggedGather_ScalarIndices) { |
| // indices = 2 |
| // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({}), // indices.shape |
| {2}, // indices |
| {{0, 3, 3, 7, 9}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| TF_ASSERT_OK(RunOpKernel()); |
| |
| // Expected: [.4, .5, .6, .7] |
| test::ExpectTensorNear<float>(*GetOutput(0), |
| test::AsTensor<float>({.4, .5, .6, .7}), 0.1); |
| } |
| |
| TEST_F(RaggedGatherOpTest, RaggedGather_OutOfBounds) { |
| // indices = [2, 10] |
| // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({2}), // indices.shape |
| {2, 10}, // indices |
| {{0, 3, 3, 7, 9}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| EXPECT_EQ("indices[1] = 10 is not in [0, 4)", RunOpKernel().error_message()); |
| } |
| |
| TEST_F(RaggedGatherOpTest, InvalidSplitsNotSorted) { |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({2}), // indices.shape |
| {0, 2}, // indices |
| {{0, 3, 5, 2, 9}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| EXPECT_EQ("Ragged splits must be sorted", RunOpKernel().error_message()); |
| } |
| |
| TEST_F(RaggedGatherOpTest, InvalidSplitsNegative) { |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({2}), // indices.shape |
| {0, 2}, // indices |
| {{-1, 3, 2, 7, 9}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| EXPECT_EQ("Ragged splits must be non-negative", |
| RunOpKernel().error_message()); |
| } |
| |
| TEST_F(RaggedGatherOpTest, InvalidSplitsEmpty) { |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({0}), // indices.shape |
| {}, // indices |
| {{}}, // params_nested_splits |
| TensorShape({0}), // params_dense_values.shape |
| {} // params_dense_values |
| ); |
| EXPECT_EQ("Ragged splits may not be empty", RunOpKernel().error_message()); |
| } |
| |
| TEST_F(RaggedGatherOpTest, InvalidSplitsTooBig) { |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({2}), // indices.shape |
| {0, 2}, // indices |
| {{0, 20, 40, 80, 100}}, // params_nested_splits |
| TensorShape({9}), // params_dense_values.shape |
| {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values |
| ); |
| EXPECT_EQ("Ragged splits must not point past values", |
| RunOpKernel().error_message()); |
| } |
| |
| TEST_F(RaggedGatherOpTest, BadValuesShape) { |
| BuildRaggedGatherGraph<float, int32>( |
| TensorShape({0}), // indices.shape |
| {}, // indices |
| {{0}}, // params_nested_splits |
| TensorShape({}), // params_dense_values.shape |
| {.1} // params_dense_values |
| ); |
| EXPECT_EQ("params.rank must be nonzero", RunOpKernel().error_message()); |
| } |
| |
| TEST_F(RaggedGatherOpTest, ShapeFn) { |
| // RaggedGather(param_splits+, param_values, indices) -> [splits+, values] |
| ShapeInferenceTestOp op("RaggedGather"); |
| |
| (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1); |
| (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(1); |
| INFER_OK(op, "?;?;?", "[?];?"); |
| INFER_OK(op, "[?];[?];[?]", "[?];[?]"); |
| INFER_OK(op, "[?];[?,?,?];[?]", "[?];[?,d1_1,d1_2]"); |
| INFER_OK(op, "[5];[10];[15]", "[?];[?]"); |
| INFER_OK(op, "[5];[10,2];[15]", "[?];[?,d1_1]"); |
| INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[5];[];[]"); |
| INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[];[5]"); |
| |
| (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(2); |
| (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2); |
| INFER_OK(op, "?;?;?;?", "[?];[?];?"); |
| INFER_OK(op, "[?];[?];[?];[?]", "[?];[?];[?]"); |
| INFER_OK(op, "[?];[?];[?,?,?];[?]", "[?];[?];[?,d2_1,d2_2]"); |
| INFER_OK(op, "[5];[10];[15];[20]", "[?];[?];[?]"); |
| |
| (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1); |
| (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2); |
| INFER_OK(op, "?;?;?", "[?];[?];?"); |
| INFER_OK(op, "[?];[?];[?,?]", "[?];[?];[?]"); |
| INFER_OK(op, "[?];[?,?,?];[?,?]", "[?];[?];[?,d1_1,d1_2]"); |
| INFER_OK(op, "[15];[20];[5,10]", "[?];[?];[?]"); |
| INFER_OK(op, "[15];[20,2];[5,10]", "[?];[?];[?,d1_1]"); |
| |
| (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1); |
| (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(0); |
| INFER_OK(op, "[?];[?];[]", "[?]"); |
| } |
| |
| } // namespace |
| } // namespace tensorflow |