When calculating gradients, we should gather the nodes to propagate through by doing a backward search from y_node_outputs.

Before this change, we gathered these nodes by doing a forward search from x_node_outputs. Nodes that are included in this forward search but are not reverse-reachable from y_node_outputs generate unnecessary gradient outputs.

In core/common_runtime/function_test.cc, the SymbolicGradient instantiation result now contains more nodes. They are all pruned later in OptimizeGraph(), so runtime performance is not affected.

PiperOrigin-RevId: 265803750
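
A minimal standalone sketch of the new backward search follows. The Node, Edge, and NodeOut structs and the InitBackprop signature here are simplified placeholders for illustration, not the real TensorFlow graph classes; the actual change is in tensorflow/core/graph/gradients.cc in the diff below.

  // Sketch: backward BFS from the y (return) nodes over data edges.
  #include <cstdio>
  #include <deque>
  #include <unordered_map>
  #include <unordered_set>
  #include <vector>

  struct Node;

  struct Edge {
    Node* src;        // node producing the value consumed over this edge
    bool is_control;  // control edges carry no data, so no gradient flows
  };

  struct Node {
    int id;
    std::vector<Edge*> in_edges;
  };

  struct NodeOut {
    Node* node;
    int index;
  };

  // Walk backwards over data edges starting from the y nodes. Every data
  // edge reached this way contributes one gradient, so the producer of that
  // edge expects one more backprop. Nodes that are never reached are skipped,
  // which is what avoids the unnecessary gradient outputs described above.
  void InitBackprop(const std::vector<NodeOut>& y_node_outputs,
                    std::unordered_map<int, int>* pending) {
    std::unordered_set<Node*> visited;
    std::deque<Node*> queue;
    for (const NodeOut& nout : y_node_outputs) {
      queue.push_back(nout.node);
      visited.insert(nout.node);
    }
    while (!queue.empty()) {
      Node* n = queue.front();
      queue.pop_front();
      for (const Edge* e : n->in_edges) {
        if (e->is_control) continue;
        (*pending)[e->src->id]++;  // one more expected backprop
        if (visited.insert(e->src).second) {
          queue.push_back(e->src);
        }
      }
    }
    // The real code additionally creates empty backprops_ entries for all
    // x_node_outputs_ here, since x nodes that are not reverse-reachable
    // from y still need entries.
  }

  int main() {
    // Tiny graph: x --> mul --> y, one data edge each.
    Node x{0, {}};
    Node mul{1, {}};
    Node y{2, {}};
    Edge x_to_mul{&x, false};
    Edge mul_to_y{&mul, false};
    mul.in_edges.push_back(&x_to_mul);
    y.in_edges.push_back(&mul_to_y);

    std::unordered_map<int, int> pending;
    InitBackprop({{&y, 0}}, &pending);
    // Both mul and x are reverse-reachable from y, so each expects 1 backprop.
    std::printf("pending[mul]=%d pending[x]=%d\n", pending[1], pending[0]);
    return 0;
  }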
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index acce8fb..2f5cef3 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -1515,13 +1515,21 @@
     auto two = ops::Const(s.WithOpName("two"), 2LL);
     auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
     auto y = ops::Mul(s.WithOpName("y"), x, scale);
-    NameAttrList fn;
-    fn.set_name("Mul");
-    (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
+    NameAttrList fn0;
+    fn0.set_name("Mul");
+    (*fn0.mutable_attr())["T"].set_type(DT_FLOAT);
     auto func1 = ops::SymbolicGradient(
         s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0},
-        {DT_FLOAT, DT_FLOAT}, fn);
-    auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1[0], 0);
+        {DT_FLOAT, DT_FLOAT}, fn0);
+    NameAttrList fn1;
+    fn1.set_name("Cast");
+    (*fn1.mutable_attr())["SrcT"].set_type(DT_INT64);
+    (*fn1.mutable_attr())["DstT"].set_type(DT_FLOAT);
+    (*fn1.mutable_attr())["Truncate"].set_b(false);
+    auto func2 = ops::SymbolicGradient(
+        s.WithOpName("Func/_2"),
+        std::initializer_list<Input>{two, func1.output[1]}, {DT_INT64}, fn1);
+    auto func3 = ops::_Retval(s.WithOpName("Func/_3"), func1[0], 0);
     GraphDef expected;
     TF_ASSERT_OK(s.ToGraphDef(&expected));
 
@@ -1552,7 +1560,7 @@
         ops::Sum(s.WithOpName("Func/_1/sum_gx"), func1_gx, func1_rx.r0);
     auto func1_dx =
         ops::Reshape(s.WithOpName("Func/_1/dx"), func1_sum_gx, func1_sx);
-    auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1_dx, 0);
+    auto func2 = ops::_Retval(s.WithOpName("Func/_3"), func1_dx, 0);
     GraphDef expected;
     TF_ASSERT_OK(s.ToGraphDef(&expected));
 
@@ -1755,20 +1763,23 @@
         std::initializer_list<Input>{grad0_z, grad0_indices, func2},
         {DT_FLOAT, DT_INT32}, sum);
 
-    auto grad0_func2 = ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_r);
+    auto grad0_func2 =
+        ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_zero);
+    auto grad0_func3 = ops::ZerosLike(s.WithOpName("grad0/Func/_3"), grad0_r);
+    auto grad0_func4 = ops::ZerosLike(s.WithOpName("grad0/Func/_4"), grad0_one);
 
     NameAttrList add;
     add.set_name("Add");
     (*add.mutable_attr())["T"].set_type(DT_FLOAT);
-    auto grad0_func3 = ops::SymbolicGradient(
-        s.WithOpName("grad0/Func/_3"),
+    auto grad0_func5 = ops::SymbolicGradient(
+        s.WithOpName("grad0/Func/_5"),
         std::initializer_list<Input>{func0, func1, grad0_func1[0]},
         {DT_FLOAT, DT_FLOAT}, add);
 
     auto func3 =
-        ops::Identity(s.WithOpName("Func/grad0/output/_3"), grad0_func3[0]);
+        ops::Identity(s.WithOpName("Func/grad0/output/_3"), grad0_func5[0]);
     auto func4 =
-        ops::Identity(s.WithOpName("Func/grad0/output/_4"), grad0_func3[1]);
+        ops::Identity(s.WithOpName("Func/grad0/output/_4"), grad0_func5[1]);
     auto dx = ops::Identity(s.WithOpName("dx"), func3);
     auto dy = ops::Identity(s.WithOpName("dy"), func4);
     auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index 75352fc..3cf8e8b 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -254,30 +254,35 @@
     backprops_.clear();
     std::unordered_set<Node*> visited;
     std::deque<Node*> queue;
-    for (const NodeOut& nout : x_node_outputs_) {
+    for (const NodeOut& nout : y_node_outputs_) {
       queue.push_back(nout.node);
       visited.insert(nout.node);
     }
 
-    // Going forward to figure out which endpoints need backprop-ed.
+    // Going backward to figure out which endpoints need backprop-ed.
     // A node's endpoints need to be backprop-ed only if one of the
-    // arg node can reach the node via data edges.
+    // return nodes can reach backwards to the node via data edges.
     while (!queue.empty()) {
       Node* n = queue.front();
       queue.pop_front();
       for (int i = 0; i < n->num_outputs(); ++i) {
         backprops_[{n, i}].clear();
       }
-      int num_expected_backprops = 0;
-      for (const Edge* e : n->out_edges()) {
+      for (const Edge* e : n->in_edges()) {
         if (e->IsControlEdge()) continue;
-        ++num_expected_backprops;
-        if (visited.find(e->dst()) == visited.end()) {
-          queue.push_back(e->dst());
-          visited.insert(e->dst());
+        pending_[e->src()->id()]++;
+        if (visited.find(e->src()) == visited.end()) {
+          queue.push_back(e->src());
+          visited.insert(e->src());
         }
       }
-      pending_[n->id()] = num_expected_backprops;
+    }
+
+    // Create entries in backprops_ for all x_node_outputs_, because they will
+    // not be added in above loop if they are not reverse reachable from
+    // y_node_outputs_.
+    for (const NodeOut& nout : x_node_outputs_) {
+      backprops_[{nout.node, nout.index}].clear();
     }
   }