Add control edges to and from functional control ops in BreakUpIslands.

We add an explicit control edge to ensure the side effects of a functional
control flow op (if, while, and case) are preserved in the LowerFunctionalOps
pass, because it appears only having an output dependency on the control flow
op is not enough to ensure the side effects of the control flow op take place.

PiperOrigin-RevId: 368678026
Change-Id: Ie0caebfa297902d6e48285947554998dc1865735
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
index ffa3394..6dfaa41 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
@@ -406,3 +406,54 @@
  }
  return
 }
+
+// CHECK: func @stateful_composite_op_control
+func @stateful_composite_op_control(%arg0: tensor<i1>, %arg1: tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32> {
+  %0 = tf_executor.graph {
+    %output, %control = tf_executor.island {
+      // CHECK: {{%.+}}, [[IF_CONTROL:%.+]] = tf_executor.island wraps "tf.If"
+      %1 = "tf.If"(%arg0, %arg1) {device = "", else_branch = @stateful_composite_op_control_else, is_stateless = false, then_branch = @stateful_composite_op_control_then} : (tensor<i1>, tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
+      // CHECK: [[IDENTITY_OUTPUT:%.+]], [[IDENTITY_CONTROL:%.+]] = tf_executor.island wraps "tf.Identity"
+      %2 = "tf.Identity"(%1) {device = ""} : (tensor<i32>) -> tensor<i32>
+
+      // The side effects of the If op might not be executed without an
+      // explicit control dependency on the tf.If op, due to the way the
+      // LowerFunctionalOpsPass in TF operates (b/185483669). Check that we
+      // output an explicit control dependency on the tf.If op in this case to
+      // be on the safe side.
+      // CHECK: [[SINK:%.+]] = tf_executor.island([[IF_CONTROL]], [[IDENTITY_CONTROL]]) wraps "tf.NoOp"
+      tf_executor.yield %2 : tensor<i32>
+    }
+    // CHECK: tf_executor.fetch [[IDENTITY_OUTPUT]], [[SINK]]
+    tf_executor.fetch %output : tensor<i32>
+  }
+  return %0 : tensor<i32>
+}
+
+// CHECK: func @stateful_composite_op_control_else
+// This is a helper function for the stateful_composite_op_control test.
+func @stateful_composite_op_control_else(%arg0: tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32> {
+  %0 = tf_executor.graph {
+    %outputs, %control = tf_executor.island {
+      %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+      "tf.AssignVariableOp"(%arg0, %1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
+      tf_executor.yield %1 : tensor<i32>
+    }
+    tf_executor.fetch %outputs : tensor<i32>
+  }
+  return %0 : tensor<i32>
+}
+
+// CHECK: func @stateful_composite_op_control_then
+// This is a helper function for the stateful_composite_op_control test.
+func @stateful_composite_op_control_then(%arg0: tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32> {
+  %0 = tf_executor.graph {
+    %outputs, %control = tf_executor.island {
+      %1 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+      "tf.AssignVariableOp"(%arg0, %1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
+      tf_executor.yield %1 : tensor<i32>
+    }
+    tf_executor.fetch %outputs : tensor<i32>
+  }
+  return %0 : tensor<i32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
index 74e7b5c..b60f035 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
@@ -169,16 +169,36 @@
   return island;
 }
 
-// A struct contains the operations in an island that do not have incoming or
-// outgoing dependencies.
+// A struct that contains the operations in an island that need explicit control
+// dependencies added going into and out of the island to capture inter-island
+// dependencies properly.
 struct IslandSourcesAndSinks {
-  // Sub-ops that do not depend on other sub-ops in the island.
+  // Sub-ops that need a control dependency going into the island. This includes
+  // sub-ops that do not depend on other sub-ops in the island and functional
+  // control ops (e.g. if, while, case) with side effects that must not take
+  // effect before the previous island is finished executing.
   llvm::SmallPtrSet<Operation*, 4> sources;
-  // Sub-ops that do not have other sub-ops in the island depending on them
-  // (excluding yield).
+
+  // Sub-ops that need a control dependency going out of the island. This
+  // includes sub-ops that do not have other sub-ops in the island depending on
+  // them (excluding yield) and functional control ops (e.g. if, while, case)
+  // with side effects that must take effect before the next island starts
+  // executing.
   llvm::SmallPtrSet<Operation*, 4> sinks;
 };
 
+// Returns true if the operation is a stateful If, Case, or While op.
+bool IsStatefulFunctionalControlFlowOp(Operation* op) {
+  if (!isa<TF::IfOp, TF::CaseOp, TF::WhileOp>(op)) {
+    return false;
+  }
+
+  if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless")) {
+    return !is_stateless.getValue();
+  }
+  return false;
+}
+
 // Finds IslandSourcesAndSinks for an unmodified island.
 IslandSourcesAndSinks FindSourcesAndSinksInIsland(
     tf_executor::IslandOp island,
@@ -194,11 +214,19 @@
     for (auto operand : sub_op.getOperands()) {
       auto defining_op = operand.getDefiningOp();
       if (!defining_op || defining_op->getParentOp() != island) continue;
-      // Remove operands from sinks.
-      result.sinks.erase(defining_op);
       has_in_island_operands = true;
+
+      // Remove operands from sinks.
+      // We don't remove the operand if it is a stateful functional control flow
+      // op to work around an issue in LowerFunctionalOpsPass where the operand
+      // dependency isn't enough to ensure the side effects take place
+      // (b/185483669).
+      if (!IsStatefulFunctionalControlFlowOp(defining_op)) {
+        result.sinks.erase(defining_op);
+      }
     }
-    if (predecessors.empty() && !has_in_island_operands) {
+    if (predecessors.empty() && (!has_in_island_operands ||
+                                 IsStatefulFunctionalControlFlowOp(&sub_op))) {
       result.sources.insert(&sub_op);
     }
   }
@@ -251,7 +279,7 @@
     island_control_inputs.push_back(new_island.control());
   }
   // Find sources and sinks inside the original island.
-  auto sources_and_sinks =
+  IslandSourcesAndSinks sources_and_sinks =
       FindSourcesAndSinksInIsland(island_op, side_effect_analysis);
   // The corresponding control output of the new island created for each sub-op.
   llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;