Enable cross-iteration parallelism in all stateful tf.while_loops with a stateless condition.

PiperOrigin-RevId: 373217844
Change-Id: Ibca02e9d7d8060dc9f3373e008c3c66da93a1dc7
diff --git a/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc b/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
index e7f6eda..4411c3b 100644
--- a/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
+++ b/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
@@ -279,23 +279,29 @@
 
     // Commit the new functions.
 
-    // TODO(b/183666205): One of these two should not be necessary.
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
         flib_def->AddFunctionDef(modified_body,
                                  flib_def->GetStackTraces(body_name)),
-        "while attaching", body_name, "to flib_def");
+        "while attaching ", new_body_name, " to flib_def");
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
         flib_def->AddFunctionDef(modified_cond,
                                  flib_def->GetStackTraces(cond_name)),
-        "while attaching", cond_name, "to flib_def");
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        g->mutable_flib_def()->AddFunctionDef(
-            modified_body, flib_def->GetStackTraces(body_name)),
-        "while attaching", body_name, "to graph");
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        g->mutable_flib_def()->AddFunctionDef(
-            modified_cond, flib_def->GetStackTraces(cond_name)),
-        "while attaching", cond_name, "to grap");
+        "while attaching ", new_cond_name, " to flib_def");
+
+    // TODO(b/183666205): This should not be necessary.
+    // It's unclear why adding the functions here is also required.
+    // Moreover, it's unclear when graph_lib's parent is flib_def itself.
+    auto* graph_lib = g->mutable_flib_def();
+    if (graph_lib->default_registry() != flib_def) {
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(
+          graph_lib->AddFunctionDef(modified_body,
+                                    graph_lib->GetStackTraces(body_name)),
+          "while attaching ", new_body_name, " to graph");
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(
+          graph_lib->AddFunctionDef(modified_cond,
+                                    graph_lib->GetStackTraces(cond_name)),
+          "while attaching ", new_cond_name, " to graph");
+    }
   }
 
   if (VLOG_IS_ON(1)) {
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index 6f92068..383a357 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -397,6 +397,8 @@
       if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
         continue
       # Make merges trigger all other computation which must run
+      # TODO(mdan): Don't do this. Write a transform to chains instead.
+      # See core/common_runtime/control_flow_deps_to_chains.cc.
       if op.type == "Merge":
         for o in ops_which_must_run:
           op._add_control_input(o)
@@ -499,6 +501,10 @@
     if self.record_initial_resource_uses:
       first_uses_by_output_ops = {}
       for op in ops_which_must_run:
+        if op not in resources_by_op:
+          # This may happen with Merge/Switch nodes which are special cased
+          # above.
+          continue
         for r in resources_by_op[op]:
           if op not in first_uses_by_output_ops:
             first_uses_by_output_ops[op] = set()
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 8002849..c1b94b9 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -1700,8 +1700,8 @@
     if not tf2.enabled():
       self.skipTest("V2-only test.")
 
-    # TODO(b/152548567): Enable this.
-    while_v2.glob_stateful_parallelism = False
+    # TODO(b/187340669): Switch to True. Current status is: not yet supported.
+    while_v2.glob_stateful_parallelism = "stateless_cond"
 
     ticker = variables.Variable(0)
     counter = variables.Variable(1)
@@ -1737,11 +1737,8 @@
     if not tf2.enabled():
       self.skipTest("V2-only test.")
 
-    # TODO(b/152548567): Enable experimental_stateful_parallelism.
-    # Without proper wiring of control deps in the cond branch, the test is
-    # non-deterministic, running cond's record_side_effect ahead of its
-    # counterpart in the body.
-    while_v2.glob_stateful_parallelism = False
+    # TODO(b/187340669): Switch to True. Current status is: not yet supported.
+    while_v2.glob_stateful_parallelism = "stateless_cond"
 
     state = []
 
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 2e0d0a2..bda04aa 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -67,8 +67,7 @@
 # side-effecting ops, this mode produces unspecified results.
 # Setting it to "stateless_cond" automatically sets this mode to True when
 # the loop condition is free of side-effecting ops.
-# TODO(b/152548567): Change this to "stateless_cond".
-glob_stateful_parallelism = False
+glob_stateful_parallelism = "stateless_cond"
 
 
 def while_loop(cond,