When calculating gradients, we should gather the nodes to propagate from y_node_outputs using a backward search.
Before this change, we gathered the nodes to propagate from x_node_outputs using a forward search. Nodes that are included in this forward search but are not reverse reachable from y_node_outputs generate unnecessary gradient outputs.
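To illustrate the difference, here is a minimal, self-contained sketch (not the TensorFlow implementation; the graph, node numbering, and names are made up) of a backward search from the y outputs over data edges:

    #include <deque>
    #include <iostream>
    #include <unordered_set>
    #include <vector>

    int main() {
      // Hypothetical 5-node graph; data_inputs[n] lists the nodes feeding n via
      // data edges. Node 3 is the only y output. Node 4 consumes x (node 0) but
      // never reaches y, so it is not reverse reachable from y_node_outputs.
      std::vector<std::vector<int>> data_inputs = {{}, {0}, {}, {1, 2}, {0}};
      std::vector<int> y_nodes = {3};

      // Backward BFS from the y outputs over data edges: only nodes that can
      // reach y are marked as needing backprop.
      std::unordered_set<int> needs_backprop(y_nodes.begin(), y_nodes.end());
      std::deque<int> queue(y_nodes.begin(), y_nodes.end());
      while (!queue.empty()) {
        int n = queue.front();
        queue.pop_front();
        for (int src : data_inputs[n]) {
          if (needs_backprop.insert(src).second) queue.push_back(src);
        }
      }

      for (int n = 0; n < 5; ++n) {
        std::cout << "node " << n
                  << (needs_backprop.count(n) ? ": backprop\n" : ": skipped\n");
      }
      // Node 4 is skipped here; the old forward search from x_node_outputs would
      // have visited it and generated an unnecessary gradient output for it.
      return 0;
    }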
In core/common_runtime/function_test.cc, the SymbolicGradient instantiation result now contains more nodes, but they are all pruned later in OptimizeGraph(), so run-time performance is not affected.
PiperOrigin-RevId: 265803750
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index acce8fb..2f5cef3 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -1515,13 +1515,21 @@
auto two = ops::Const(s.WithOpName("two"), 2LL);
auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
auto y = ops::Mul(s.WithOpName("y"), x, scale);
- NameAttrList fn;
- fn.set_name("Mul");
- (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
+ NameAttrList fn0;
+ fn0.set_name("Mul");
+ (*fn0.mutable_attr())["T"].set_type(DT_FLOAT);
auto func1 = ops::SymbolicGradient(
s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0},
- {DT_FLOAT, DT_FLOAT}, fn);
- auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1[0], 0);
+ {DT_FLOAT, DT_FLOAT}, fn0);
+ NameAttrList fn1;
+ fn1.set_name("Cast");
+ (*fn1.mutable_attr())["SrcT"].set_type(DT_INT64);
+ (*fn1.mutable_attr())["DstT"].set_type(DT_FLOAT);
+ (*fn1.mutable_attr())["Truncate"].set_b(false);
+ auto func2 = ops::SymbolicGradient(
+ s.WithOpName("Func/_2"),
+ std::initializer_list<Input>{two, func1.output[1]}, {DT_INT64}, fn1);
+ auto func3 = ops::_Retval(s.WithOpName("Func/_3"), func1[0], 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
@@ -1552,7 +1560,7 @@
ops::Sum(s.WithOpName("Func/_1/sum_gx"), func1_gx, func1_rx.r0);
auto func1_dx =
ops::Reshape(s.WithOpName("Func/_1/dx"), func1_sum_gx, func1_sx);
- auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1_dx, 0);
+ auto func2 = ops::_Retval(s.WithOpName("Func/_3"), func1_dx, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
@@ -1755,20 +1763,23 @@
std::initializer_list<Input>{grad0_z, grad0_indices, func2},
{DT_FLOAT, DT_INT32}, sum);
- auto grad0_func2 = ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_r);
+ auto grad0_func2 =
+ ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_zero);
+ auto grad0_func3 = ops::ZerosLike(s.WithOpName("grad0/Func/_3"), grad0_r);
+ auto grad0_func4 = ops::ZerosLike(s.WithOpName("grad0/Func/_4"), grad0_one);
NameAttrList add;
add.set_name("Add");
(*add.mutable_attr())["T"].set_type(DT_FLOAT);
- auto grad0_func3 = ops::SymbolicGradient(
- s.WithOpName("grad0/Func/_3"),
+ auto grad0_func5 = ops::SymbolicGradient(
+ s.WithOpName("grad0/Func/_5"),
std::initializer_list<Input>{func0, func1, grad0_func1[0]},
{DT_FLOAT, DT_FLOAT}, add);
auto func3 =
- ops::Identity(s.WithOpName("Func/grad0/output/_3"), grad0_func3[0]);
+ ops::Identity(s.WithOpName("Func/grad0/output/_3"), grad0_func5[0]);
auto func4 =
- ops::Identity(s.WithOpName("Func/grad0/output/_4"), grad0_func3[1]);
+ ops::Identity(s.WithOpName("Func/grad0/output/_4"), grad0_func5[1]);
auto dx = ops::Identity(s.WithOpName("dx"), func3);
auto dy = ops::Identity(s.WithOpName("dy"), func4);
auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index 75352fc..3cf8e8b 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -254,30 +254,35 @@
backprops_.clear();
std::unordered_set<Node*> visited;
std::deque<Node*> queue;
- for (const NodeOut& nout : x_node_outputs_) {
+ for (const NodeOut& nout : y_node_outputs_) {
queue.push_back(nout.node);
visited.insert(nout.node);
}
// Going forward to figure out which endpoints need backprop-ed.
// A node's endpoints need to be backprop-ed only if one of the
- // arg node can reach the node via data edges.
+ // return nodes can reach backwards to the node via data edges.
while (!queue.empty()) {
Node* n = queue.front();
queue.pop_front();
for (int i = 0; i < n->num_outputs(); ++i) {
backprops_[{n, i}].clear();
}
- int num_expected_backprops = 0;
- for (const Edge* e : n->out_edges()) {
+ for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
- ++num_expected_backprops;
- if (visited.find(e->dst()) == visited.end()) {
- queue.push_back(e->dst());
- visited.insert(e->dst());
+ pending_[e->src()->id()]++;
+ if (visited.find(e->src()) == visited.end()) {
+ queue.push_back(e->src());
+ visited.insert(e->src());
}
}
- pending_[n->id()] = num_expected_backprops;
+ }
+
+ // Create entries in backprops_ for all x_node_outputs_, because they will
+ // not be added in above loop if they are not reverse reachable from
+ // y_node_outputs_.
+ for (const NodeOut& nout : x_node_outputs_) {
+ backprops_[{nout.node, nout.index}].clear();
}
}