[Grappler] Prevent cycles creation in HoistCWiseUnaryChainsStage optimizer

PiperOrigin-RevId: 263205996
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3bbd988..badfe2a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1464,6 +1464,11 @@
         // in place for merging slices into splits.
         return false;
       }
+      if (NumControlOutputs(*node, *ctx().node_map) > 0) {
+        // TODO(ezhulenev): Unary ops after Split might have a control path to
+        // the Split node, and we currently do not propertly handle cycles.
+        return false;
+      }
       return num_split > 1 && !IsAlreadyOptimized(*node);
     }
     return false;
@@ -1519,6 +1524,11 @@
   // control inputs gathered from them to the concat or split node.
   Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
                            std::set<string>* ctrl_inputs, NodeDef* root_node) {
+    VLOG(3) << "Hoist unary op chain:"
+            << " root=" << root_node->name()
+            << " prefix_length=" << prefix_length << " ctrl_inputs=["
+            << absl::StrJoin(*ctrl_inputs, ", ") << "]";
+
     if (tails.empty()) {
       return Status::OK();
     }
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index acbb81a..7054eb2 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -275,6 +275,21 @@
   return num_inputs;
 }
 
+int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
+  int num_outputs = 0;
+  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
+    for (const string& node_as_input : output->input()) {
+      if (!IsControlInput(node_as_input)) continue;
+
+      TensorId tensor = ParseTensorName(node_as_input);
+      if (tensor.node() == node.name()) {
+        ++num_outputs;
+      }
+    }
+  }
+  return num_outputs;
+}
+
 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
   int num_outputs = 0;
   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 700e431..8a69843 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -259,6 +259,9 @@
 // Number of connected non-control inputs.
 int NumNonControlInputs(const NodeDef& node);
 
+// Number of connected control outputs.
+int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
+
 // Number of connected non-control outputs.
 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);