[tf.data] Add more unit test to check the correctness of `disable_intra_op_parallelism` optimization.
PiperOrigin-RevId: 325475258
Change-Id: I3a8480a15c0828add3fa37e7a7afe095a20a73e2
diff --git a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc
index 76d6b46..b1c8865 100644
--- a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc
@@ -70,10 +70,11 @@
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_EQ(output.node_size(), 6);
EXPECT_TRUE(graph_utils::ContainsNodeWithOp(op, output));
- NodeDef test_node = output.node(graph_utils::FindGraphNodeWithOp(op, output));
- NodeDef test_val = output.node(
- graph_utils::FindGraphNodeWithName(test_node.input(1), output));
- EXPECT_EQ(test_val.attr().at("value").tensor().int64_val(0), value);
+ NodeDef parallelism_node =
+ output.node(graph_utils::FindGraphNodeWithOp(op, output));
+ NodeDef parallelism_val = output.node(
+ graph_utils::FindGraphNodeWithName(parallelism_node.input(1), output));
+ EXPECT_EQ(parallelism_val.attr().at("value").tensor().int64_val(0), value);
}
INSTANTIATE_TEST_SUITE_P(
@@ -105,11 +106,19 @@
EXPECT_EQ(output.node_size(), 7);
EXPECT_TRUE(
graph_utils::ContainsNodeWithOp("MaxIntraOpParallelismDataset", output));
- NodeDef test_node = output.node(
- graph_utils::FindGraphNodeWithOp("MaxIntraOpParallelismDataset", output));
- NodeDef test_val = output.node(
- graph_utils::FindGraphNodeWithName(test_node.input(1), output));
- EXPECT_EQ(test_val.attr().at("value").tensor().int64_val(0), 1);
+ NodeDef sink_node =
+ output.node(graph_utils::FindGraphNodeWithName("Sink", output));
+ EXPECT_EQ(sink_node.input_size(), 1);
+ NodeDef parallelism_node = output.node(
+ graph_utils::FindGraphNodeWithName(sink_node.input(0), output));
+ EXPECT_EQ(parallelism_node.op(), "MaxIntraOpParallelismDataset");
+ EXPECT_EQ(parallelism_node.input_size(), 2);
+ NodeDef range_node = output.node(
+ graph_utils::FindGraphNodeWithName(parallelism_node.input(0), output));
+ EXPECT_EQ(range_node.name(), "range");
+ NodeDef parallelism_val = output.node(
+ graph_utils::FindGraphNodeWithName(parallelism_node.input(1), output));
+ EXPECT_EQ(parallelism_val.attr().at("value").tensor().int64_val(0), 1);
}
} // namespace