Avoid using GraphView after function graph is modified in graph_properties.cc
PiperOrigin-RevId: 335933815
Change-Id: If0503bdc9ec0a6711a8ee3f4091a7eadb0345c5e
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 3299e2b..41910fc 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -854,6 +854,15 @@
}
}
+ // ReplaceInputWithConst() may break GraphView's internal node mapping
+ // structure; hence, we separately build node name to NodeDef* map, for the
+ // output nodes (before GraphView becomes invalid). Note that we use string,
+ // not string_view.
+ absl::flat_hash_map<std::string, NodeDef*> output_nodes;
+ for (const auto& output_arg : grappler_function_item.outputs()) {
+ output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name);
+ }
+
// Replace input nodes with Consts, if values are known. Note that
// we don't check exceptions here as it's done in the above loop.
auto* ctx = GetNodeContext(function_node);
@@ -884,11 +893,13 @@
&grappler_function_item));
}
}
+ // node_name to NodeDef* map in GraphView gv can be broken due to
+ // ReplaceInputWithConst(). gv should not be used after this.
// Replace output _Retval nodes with Identity nodes. _Retval is a system op
// without outputs and registered shape function.
for (const auto& output_arg : grappler_function_item.outputs()) {
- NodeDef* output_node = gv.GetNode(output_arg.node_name);
+ NodeDef* output_node = output_nodes[output_arg.node_name];
DCHECK_EQ(output_node->op(), "_Retval");
output_node->set_op("Identity");
output_node->mutable_attr()->erase("index");
@@ -911,12 +922,12 @@
// inputs, so port_id >= 0.
TensorId out_tensor = ParseTensorName(out_arg.node_name);
- const NodeDef* retnode = gv.GetNode(out_tensor.node());
- if (retnode == nullptr) {
+ if (output_nodes.count(out_tensor.node()) <= 0) {
return errors::FailedPrecondition(
"Unable to find return function_node ", out_tensor.node(), " for ",
function_node->name());
}
+ const NodeDef* retnode = output_nodes[out_tensor.node()];
auto output_properties = gp.GetOutputProperties(retnode->name());
int output_properties_size = output_properties.size();