/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNodeName[] = "reduce_dataset";
constexpr char kOpName[] = "ReduceDataset";
class ReduceDatasetOpTest : public DatasetOpsTestBase {
protected:
// Create a new `ReduceDataset` op kernel.
Status CreateReduceDatasetOpKernel(
const FunctionDefHelper::AttrValueWrapper &func,
const DataTypeVector &t_state, const DataTypeVector &output_types,
const std::vector<PartialTensorShape> &output_shapes,
bool use_inter_op_parallelism,
std::unique_ptr<OpKernel> *reduce_dataset_op_kernel) {
std::vector<string> components;
components.reserve(1 + t_state.size());
components.emplace_back("input_dataset");
for (int i = 0; i < t_state.size(); ++i) {
components.emplace_back(strings::StrCat("initial_state_", i));
}
NodeDef node_def = test::function::NDef(
kNodeName, kOpName, components,
{{"f", func},
{"Tstate", t_state},
{"Targuments", {}},
{"output_types", output_types},
{"output_shapes", output_shapes},
{"use_inter_op_parallelism", use_inter_op_parallelism}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, reduce_dataset_op_kernel));
return Status::OK();
}
// Create a new `ReduceDataset` op kernel context.
Status CreateReduceDatasetContext(
OpKernel *const op_kernel,
gtl::InlinedVector<TensorValue, 4> *const inputs,
std::unique_ptr<OpKernelContext> *context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK();
}
};
struct RangeDatasetParam {
int64 start;
int64 end;
int64 step;
};
struct TestCase {
RangeDatasetParam range_data_param;
std::vector<Tensor> initial_state;
FunctionDefHelper::AttrValueWrapper func;
std::vector<FunctionDef> func_lib;
DataTypeVector t_state;
bool use_inter_op_parallelism;
std::vector<Tensor> expected_outputs;
DataTypeVector output_dtypes;
std::vector<PartialTensorShape> output_shapes;
};
// Test case 1: the reduce function has one output.
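// With `XAddY` as the reduce function and an initial state of 0, the result is
// the sum of range(0, 10), i.e. 0 + 1 + ... + 9 = 45.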
TestCase TestCase1() {
return {/*range_data_param*/ {0, 10, 1},
/*initial_state*/
{CreateTensor<int64>(TensorShape({}), {0})},
/*func*/
FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}}),
/*func_lib*/ {test::function::XAddY()},
/*t_state*/ {DT_INT64},
/*use_inter_op_parallelism*/ true,
/*expected_outputs*/
{CreateTensor<int64>(TensorShape({}), {45})},
/*output_dtypes*/ {DT_INT64},
/*output_shapes*/ {PartialTensorShape({})}};
}
// Test case 2: the reduce function has two inputs and two outputs. Because the
// number of components in `initial_state` has to match the number of reduce
// function outputs, `initial_state` has two components. As a result, the
// components of `initial_state` supply all of the reduce function's inputs,
// and the input dataset is not involved in the reduce/aggregation process.
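// `XPlusOneXTimesY` maps the state (x, y) to (x + 1, x * y). Starting from
// (1, 1) and applying it once per element of range(1, 10), the final state is
// (10, 9!) = (10, 362880).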
TestCase TestCase2() {
return {/*range_data_param*/ {1, 10, 1},
/*initial_state*/
{CreateTensor<int64>(TensorShape({}), {1}),
CreateTensor<int64>(TensorShape({}), {1})},
/*func*/
FunctionDefHelper::FunctionRef("XPlusOneXTimesY", {{"T", DT_INT64}}),
/*func_lib*/ {test::function::XPlusOneXTimesY()},
/*t_state*/ {DT_INT64, DT_INT64},
/*use_inter_op_parallelism*/ true,
/*expected_outputs*/
{CreateTensor<int64>(TensorShape({}), {10}),
CreateTensor<int64>(TensorShape({}),
{1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9})},
/*output_dtypes*/ {DT_INT64, DT_INT64},
/*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}};
}
// Test case 3: the input dataset produces no elements, so the reduce dataset
// just returns the initial state.
TestCase TestCase3() {
return {/*range_data_param*/ {0, 0, 1},
/*initial_state*/
{CreateTensor<int64>(TensorShape({}), {1}),
CreateTensor<int64>(TensorShape({}), {3})},
/*func*/
FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}}),
/*func_lib*/ {test::function::XAddY()},
/*t_state*/ {DT_INT64, DT_INT64},
/*use_inter_op_parallelism*/ true,
/*expected_outputs*/
{CreateTensor<int64>(TensorShape({}), {1}),
CreateTensor<int64>(TensorShape({}), {3})},
/*output_dtypes*/ {DT_INT64, DT_INT64},
/*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}};
}
class ParameterizedReduceDatasetOpTest
: public ReduceDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedReduceDatasetOpTest, Compute) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
std::unique_ptr<OpKernel> reduce_dataset_kernel;
TF_ASSERT_OK(CreateReduceDatasetOpKernel(
test_case.func, test_case.t_state, test_case.output_dtypes,
test_case.output_shapes, test_case.use_inter_op_parallelism,
&reduce_dataset_kernel));
DatasetBase *range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
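  // The op's inputs are the input dataset variant followed by the initial
  // state tensors.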
std::vector<Tensor> initial_state = test_case.initial_state;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_dataset_tensor)});
for (auto &t : initial_state) {
inputs.emplace_back(&t);
}
std::unique_ptr<OpKernelContext> reduce_dataset_context;
TF_ASSERT_OK(CreateReduceDatasetContext(reduce_dataset_kernel.get(), &inputs,
&reduce_dataset_context));
TF_ASSERT_OK(
RunOpKernel(reduce_dataset_kernel.get(), reduce_dataset_context.get()));
int num_outputs = reduce_dataset_context->num_outputs();
EXPECT_EQ(num_outputs, test_case.expected_outputs.size());
for (int i = 0; i < num_outputs; i++) {
// output will be released by the op kernel context.
Tensor *output = reduce_dataset_context->mutable_output(i);
TF_EXPECT_OK(ExpectEqual(test_case.expected_outputs[i], *output));
}
}
INSTANTIATE_TEST_SUITE_P(ReduceDatasetOpTest, ParameterizedReduceDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{TestCase1(), TestCase2(), TestCase3()})));
} // namespace
} // namespace data
} // namespace tensorflow