[tf.data] Add more unit test to check the correctness of `disable_intra_op_parallelism` optimization.

PiperOrigin-RevId: 325475258
Change-Id: I3a8480a15c0828add3fa37e7a7afe095a20a73e2
diff --git a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc
index 76d6b46..b1c8865 100644
--- a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism_test.cc
@@ -70,10 +70,11 @@
   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
   EXPECT_EQ(output.node_size(), 6);
   EXPECT_TRUE(graph_utils::ContainsNodeWithOp(op, output));
-  NodeDef test_node = output.node(graph_utils::FindGraphNodeWithOp(op, output));
-  NodeDef test_val = output.node(
-      graph_utils::FindGraphNodeWithName(test_node.input(1), output));
-  EXPECT_EQ(test_val.attr().at("value").tensor().int64_val(0), value);
+  NodeDef parallelism_node =
+      output.node(graph_utils::FindGraphNodeWithOp(op, output));
+  NodeDef parallelism_val = output.node(
+      graph_utils::FindGraphNodeWithName(parallelism_node.input(1), output));
+  EXPECT_EQ(parallelism_val.attr().at("value").tensor().int64_val(0), value);
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -105,11 +106,19 @@
   EXPECT_EQ(output.node_size(), 7);
   EXPECT_TRUE(
       graph_utils::ContainsNodeWithOp("MaxIntraOpParallelismDataset", output));
-  NodeDef test_node = output.node(
-      graph_utils::FindGraphNodeWithOp("MaxIntraOpParallelismDataset", output));
-  NodeDef test_val = output.node(
-      graph_utils::FindGraphNodeWithName(test_node.input(1), output));
-  EXPECT_EQ(test_val.attr().at("value").tensor().int64_val(0), 1);
+  NodeDef sink_node =
+      output.node(graph_utils::FindGraphNodeWithName("Sink", output));
+  EXPECT_EQ(sink_node.input_size(), 1);
+  NodeDef parallelism_node = output.node(
+      graph_utils::FindGraphNodeWithName(sink_node.input(0), output));
+  EXPECT_EQ(parallelism_node.op(), "MaxIntraOpParallelismDataset");
+  EXPECT_EQ(parallelism_node.input_size(), 2);
+  NodeDef range_node = output.node(
+      graph_utils::FindGraphNodeWithName(parallelism_node.input(0), output));
+  EXPECT_EQ(range_node.name(), "range");
+  NodeDef parallelism_val = output.node(
+      graph_utils::FindGraphNodeWithName(parallelism_node.input(1), output));
+  EXPECT_EQ(parallelism_val.attr().at("value").tensor().int64_val(0), 1);
 }
 
 }  // namespace