[tf.data] Apply the `enable_gradient_descent` optimization only to the main dataset pipeline, not to datasets inside function bodies. Also update the unit test.

PiperOrigin-RevId: 335486882
Change-Id: I864f44d90fc3d84f702b3e53ed6b226362dd27a7
diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc
index fa31be8..4ece165 100644
--- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc
+++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc
@@ -31,6 +31,8 @@
 
 constexpr char kAlgorithm[] = "algorithm";
 constexpr char kModelDataset[] = "ModelDataset";
+constexpr char kRetValOp[] = "_Retval";
+
 constexpr int64 HILL_CLIMB = 0;
 constexpr int64 GRADIENT_DESCENT = 1;
 
@@ -47,9 +49,21 @@
   }
   MutableGraphView graph(output);
 
-  int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output);
+  for (const auto& fetch_name : item.fetch) {
+    // If the GrapplerItem is derived from a FunctionDef, we don't optimize it,
+    // because we only want to enable gradient descent on the main dataset
+    // pipeline.
+    auto fetch = graph.GetNode(fetch_name);
+    if (fetch == nullptr || fetch->op() == kRetValOp) {
+      // Heuristic: a `_Retval` fetch node (or a fetch we cannot resolve)
+      // indicates this item came from a FunctionDef, so skip the rewrite.
+      return Status::OK();
+    }
+  }
 
+  int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output);
   NodeDef& model_node = *(output->mutable_node(index));
+
   if (model_node.attr().at(kAlgorithm).i() == HILL_CLIMB) {
     (*model_node.mutable_attr())[kAlgorithm].set_i(GRADIENT_DESCENT);
     stats->num_changes++;
diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc
index 86d66f3..c623c53 100644
--- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc
@@ -41,12 +41,13 @@
   return optimizer.Optimize(nullptr, item, output);
 }
 
-class SimpleRewrite : public ::testing::TestWithParam<std::tuple<bool, int64>> {
-};
+class SimpleRewrite
+    : public ::testing::TestWithParam<std::tuple<bool, int64, string>> {};
 
 TEST_P(SimpleRewrite, EnableGradientDescentTest) {
   const bool autotune = std::get<0>(GetParam());
   const int64 algorithm_index = std::get<1>(GetParam());
+  const string op = std::get<2>(GetParam());
 
   using test::function::NDef;
   GrapplerItem item;
@@ -58,7 +59,9 @@
        NDef("batch_size", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}),
        NDef("batch", "BatchDataset", {"range", "batch_size"}, {}),
        NDef("model", "ModelDataset", {"batch"},
-            {{"algorithm", algorithm_index}})});
+            {{"algorithm", algorithm_index}}),
+       NDef("Sink", op, {"model"}, {})});
+  item.fetch.push_back("Sink");
 
   GraphDef output;
   TF_ASSERT_OK(OptimizeWithEnableGradientDescent(item, &output, autotune));
@@ -67,12 +70,13 @@
   NodeDef model_node =
       output.node(graph_utils::FindGraphNodeWithName("model", output));
   EXPECT_EQ(model_node.attr().at("algorithm").i(),
-            autotune ? 1 : algorithm_index);
+            (autotune && op != "_Retval") ? 1 : algorithm_index);
 }
 
-INSTANTIATE_TEST_SUITE_P(Test, SimpleRewrite,
-                         ::testing::Combine(::testing::Values(false, true),
-                                            ::testing::Values(0, 1)));
+INSTANTIATE_TEST_SUITE_P(
+    Test, SimpleRewrite,
+    ::testing::Combine(::testing::Values(false, true), ::testing::Values(0, 1),
+                       ::testing::Values("Identity", "_Retval")));
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow