Enable cross-iteration parallelism all in stateful tf.while_loops with stateless condition.
PiperOrigin-RevId: 373217844
Change-Id: Ibca02e9d7d8060dc9f3373e008c3c66da93a1dc7
diff --git a/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc b/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
index e7f6eda..4411c3b 100644
--- a/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
+++ b/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
@@ -279,23 +279,29 @@
// Commit the new functions.
- // TODO(b/183666205): One of these two should not be necessary.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
flib_def->AddFunctionDef(modified_body,
flib_def->GetStackTraces(body_name)),
- "while attaching", body_name, "to flib_def");
+ "while attaching ", new_body_name, " to flib_def");
TF_RETURN_WITH_CONTEXT_IF_ERROR(
flib_def->AddFunctionDef(modified_cond,
flib_def->GetStackTraces(cond_name)),
- "while attaching", cond_name, "to flib_def");
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- g->mutable_flib_def()->AddFunctionDef(
- modified_body, flib_def->GetStackTraces(body_name)),
- "while attaching", body_name, "to graph");
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- g->mutable_flib_def()->AddFunctionDef(
- modified_cond, flib_def->GetStackTraces(cond_name)),
- "while attaching", cond_name, "to grap");
+ "while attaching ", new_cond_name, " to flib_def");
+
+ // TODO(b/183666205): This should not be necessary.
+ // It's unclear why adding the functions here is also required.
+ // Moreover, it's unclear when graph_lib's parent is flib_def itself.
+ auto* graph_lib = g->mutable_flib_def();
+ if (graph_lib->default_registry() != flib_def) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ graph_lib->AddFunctionDef(modified_body,
+ graph_lib->GetStackTraces(body_name)),
+ "while attaching ", new_body_name, " to graph");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ graph_lib->AddFunctionDef(modified_cond,
+ graph_lib->GetStackTraces(cond_name)),
+ "while attaching ", new_cond_name, " to graph");
+ }
}
if (VLOG_IS_ON(1)) {
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index 6f92068..383a357 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -397,6 +397,8 @@
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
continue
# Make merges trigger all other computation which must run
+ # TODO(mdan): Don't do this. Write a transform to chains instead.
+ # See core/common_runtime/control_flow_deps_to_chains.cc.
if op.type == "Merge":
for o in ops_which_must_run:
op._add_control_input(o)
@@ -499,6 +501,10 @@
if self.record_initial_resource_uses:
first_uses_by_output_ops = {}
for op in ops_which_must_run:
+ if op not in resources_by_op:
+ # This may happen with Merge/Switch nodes which are special cased
+ # above.
+ continue
for r in resources_by_op[op]:
if op not in first_uses_by_output_ops:
first_uses_by_output_ops[op] = set()
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 8002849..c1b94b9 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -1700,8 +1700,8 @@
if not tf2.enabled():
self.skipTest("V2-only test.")
- # TODO(b/152548567): Enable this.
- while_v2.glob_stateful_parallelism = False
+ # TODO(b/187340669): Switch to True. Current status is: not yet supported.
+ while_v2.glob_stateful_parallelism = "stateless_cond"
ticker = variables.Variable(0)
counter = variables.Variable(1)
@@ -1737,11 +1737,8 @@
if not tf2.enabled():
self.skipTest("V2-only test.")
- # TODO(b/152548567): Enable experimental_stateful_parallelism.
- # Without proper wiring of control deps in the cond branch, the test is
- # non-deterministic, running cond's record_side_effect ahead of its
- # counterpart in the body.
- while_v2.glob_stateful_parallelism = False
+ # TODO(b/187340669): Switch to True. Current status is: not yet supported.
+ while_v2.glob_stateful_parallelism = "stateless_cond"
state = []
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 2e0d0a2..bda04aa 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -67,8 +67,7 @@
# side-effecting ops, this mode produces unspecified results.
# Setting it to "stateless_cond" automatically sets this mode to True when
# the loop condition is free of side-effecting ops.
-# TODO(b/152548567): Change this to "stateless_cond".
-glob_stateful_parallelism = False
+glob_stateful_parallelism = "stateless_cond"
def while_loop(cond,