[Grappler] Prevent cycles creation in HoistCWiseUnaryChainsStage optimizer
PiperOrigin-RevId: 263205996
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3bbd988..badfe2a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1464,6 +1464,11 @@
// in place for merging slices into splits.
return false;
}
+ if (NumControlOutputs(*node, *ctx().node_map) > 0) {
+ // TODO(ezhulenev): Unary ops after Split might have a control path to
+ // the Split node, and we currently do not propertly handle cycles.
+ return false;
+ }
return num_split > 1 && !IsAlreadyOptimized(*node);
}
return false;
@@ -1519,6 +1524,11 @@
// control inputs gathered from them to the concat or split node.
Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
std::set<string>* ctrl_inputs, NodeDef* root_node) {
+ VLOG(3) << "Hoist unary op chain:"
+ << " root=" << root_node->name()
+ << " prefix_length=" << prefix_length << " ctrl_inputs=["
+ << absl::StrJoin(*ctrl_inputs, ", ") << "]";
+
if (tails.empty()) {
return Status::OK();
}
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index acbb81a..7054eb2 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -275,6 +275,21 @@
return num_inputs;
}
+int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
+ int num_outputs = 0;
+ for (const NodeDef* output : node_map.GetOutputs(node.name())) {
+ for (const string& node_as_input : output->input()) {
+ if (!IsControlInput(node_as_input)) continue;
+
+ TensorId tensor = ParseTensorName(node_as_input);
+ if (tensor.node() == node.name()) {
+ ++num_outputs;
+ }
+ }
+ }
+ return num_outputs;
+}
+
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 700e431..8a69843 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -259,6 +259,9 @@
// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);
+// Number of connected control outputs.
+int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
+
// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);