Add a pass which converts while_loop body function control dependencies into chain loop variables (dummy float scalars), allowing inter-iteration loop parallelism in the presence of stateful ops. Previously, all control dependencies were bundled at the end of the body, creating an artificial barrier (possibly to ensure side effects are consistent in the cond function). For now, this only works for loops whose conditional is completely stateless. An experimental flag allows overriding the behavior.
This behavior is currently disabled.

PiperOrigin-RevId: 370737684
Change-Id: I12d062fc25a9f6ee3851c603f28f23f7f6164c21
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 59e146c..dc7330c 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -243,6 +243,7 @@
         "lower_if_op.h",
         "lower_case_op.h",
         "lower_functional_ops.h",
+        "control_flow_deps_to_chains.h",
         "lower_while_op.h",
         "memory_types.h",
         "mkl_cpu_allocator.h",
@@ -984,6 +985,26 @@
 )
 
 cc_library(
+    name = "control_flow_deps_to_chains",
+    srcs = ["control_flow_deps_to_chains.cc"],
+    hdrs = ["control_flow_deps_to_chains.h"],
+    copts = tf_copts(),
+    visibility = default_package_visibility + [
+        "//platforms/performance/autograppler:__subpackages__",
+        "//platforms/performance/tf_sim:__subpackages__",
+    ],
+    deps = [
+        ":optimization_registry",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/framework:attr_value_proto_cc",
+        "//tensorflow/core/framework:node_def_proto_cc",
+        "//tensorflow/core/framework:tensor_proto_cc",
+        "//tensorflow/core/platform:errors",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
     name = "lower_if_op",
     srcs = ["lower_if_op.cc"],
     hdrs = ["lower_if_op.h"],
@@ -1583,6 +1604,7 @@
         ":collective_rma_local",
         ":collective_util",
         ":composite_device",
+        ":control_flow_deps_to_chains",
         ":copy_tensor",
         ":costmodel_manager",
         ":debugger_state_interface",
diff --git a/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc b/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
new file mode 100644
index 0000000..934c3b9
--- /dev/null
+++ b/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc
@@ -0,0 +1,309 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/control_flow_deps_to_chains.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/strcat.h"
+#include "tensorflow/core/util/dump_graph.h"
+
+namespace tensorflow {
+
+// TODO(mdan): Move this into Grappler - cleaner interface.
+Status ControlFlowDepsToChainsPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  VLOG(1) << "ControlFlowDepsToChainsPass::Run";
+
+  if (options.graph == nullptr) {
+    VLOG(1) << "ControlFlowDepsToChainsPass::Run Aborted";
+    return Status::OK();
+  }
+
+  Graph* g = options.graph->get();
+  DCHECK(g != nullptr);
+  FunctionLibraryDefinition* flib_def = options.flib_def;
+  DCHECK(flib_def != nullptr);
+
+  if (VLOG_IS_ON(1)) {
+    DumpGraphToFile("control_flow_deps_to_chains_before", *g, flib_def);
+  }
+
+  for (Node* n : g->nodes()) {
+    if (n == nullptr) continue;
+    if (!n->IsWhileNode()) continue;
+
+    // TODO(mdan): This breaks encapsulation of Node/Graph. Is there any needed?
+    // TODO(mdan): Consolidate this with AddWhileInputHack.
+    NodeDef* while_node = n->mutable_def();
+    const auto& attrs = while_node->attr();
+    auto* mattrs = while_node->mutable_attr();
+
+    string body_name = attrs.at("body").func().name();
+    auto* body_graph = flib_def->Find(body_name);
+    DCHECK(body_graph != nullptr);
+
+    // Look for required annotations.
+
+    if (attrs.find("_stateful_parallelism") == attrs.end()) continue;
+    if (!attrs.at("_stateful_parallelism").b()) continue;
+    // TODO(mdan): We don't really need this attribute.
+    if (attrs.find("_num_original_outputs") == attrs.end()) continue;
+    int body_barrier_loc = -1;
+    std::map<string, int> node_index;
+    for (int i = 0, s = body_graph->node_def_size(); i < s; i++) {
+      node_index.emplace(body_graph->node_def(i).name(), i);
+      if (body_barrier_loc < 0) {
+        const auto& node_attr = body_graph->node_def(i).attr();
+        if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
+          body_barrier_loc = i;
+        }
+      }
+    }
+    if (body_barrier_loc < 0) continue;
+    bool ok_for_lowering = true;
+    for (int i = 0; i < body_graph->control_ret_size(); i++) {
+      const auto& control_node = body_graph->node_def(
+          node_index[body_graph->signature().control_output(i)]);
+      const auto& control_attr = control_node.attr();
+      if (control_attr.find("_res_first_used_by") == control_attr.end()) {
+        ok_for_lowering = false;
+        break;
+      }
+    }
+    if (!ok_for_lowering) continue;
+
+    int num_loop_vars = body_graph->signature().input_arg_size();
+    int num_new_chains = body_graph->control_ret_size();
+    int num_node_inputs = while_node->input_size();
+
+    if (!num_new_chains) continue;  // Nothing to do for stateless loops.
+
+    // Add extra loop vars to the while node.
+
+    // TODO(mdan): If the loop vars contains the resource, we should reuse it.
+    // Note that stateful ops of resource inputs cause their resources to be
+    // captured into the loop vars (through the body/cond captures). We could
+    // effectively use those as chains.
+
+    // TODO(mdan): Is there a more efficient way to do this?
+    // Insert the new While node inputs: at the end of the loop vars, but before
+    // any non-loop var inputs (like control dependencies). Once the initial
+    // chain values are created below, they will be added to these inputs.
+    for (int i = 0; i < num_new_chains; i++) {
+      while_node->add_input();
+    }
+    for (int i = num_node_inputs - 1; i >= num_loop_vars; i--) {
+      while_node->set_input(i + num_new_chains, while_node->input(i));
+    }
+
+    std::vector<Node*> new_inputs;
+    std::vector<int> new_input_locations;
+    // Set their name to a gensym, type to float and shape to scalar.
+    for (int i = 0; i < num_new_chains; i++) {
+      string c_name = g->NewName("acd__chain");
+
+      // The initial value for the i'th chain loop var.
+      NodeDef new_in;
+      new_in.set_name(c_name);
+      new_in.set_op("Const");
+      AttrValue att_dtype;
+      att_dtype.set_type(DT_FLOAT);
+      new_in.mutable_attr()->insert({"dtype", att_dtype});
+      AttrValue att_value;
+      att_value.mutable_tensor()->set_dtype(DT_FLOAT);
+      att_value.mutable_tensor()->mutable_tensor_shape();
+      att_value.mutable_tensor()->add_float_val(0.0f);  // Matches DT_FLOAT.
+      new_in.mutable_attr()->insert({"value", att_value});
+      Status status;
+      new_inputs.push_back(g->AddNode(new_in, &status));
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(status, "while creating chain ", c_name);
+
+      int loc = num_loop_vars + i;
+      new_input_locations.push_back(loc);
+      while_node->set_input(loc, c_name);
+      mattrs->at("T").mutable_list()->add_type(DT_FLOAT);
+      mattrs->at("output_shapes").mutable_list()->add_shape();
+    }
+
+    // TODO(mdan): This should not be necessary to update. Delete?
+    mattrs->at("_num_original_outputs").set_i(num_loop_vars + num_new_chains);
+    n->UpdateProperties();
+    for (int i = 0; i < num_new_chains; i++) {
+      g->AddEdge(new_inputs[i], 0, n, new_input_locations[i]);
+    }
+
+    // TODO(mdan): This is wasteful. Can we just mutate the original proto?
+    FunctionDef modified_body = *body_graph;
+
+    // Disable the global end-of-body barrier from the body function.
+    // Because removing a node is too inefficient (would have to walk all the
+    // inputs of all graph nodes), we instead clear its control dependencies.
+    modified_body.mutable_node_def(body_barrier_loc)->clear_input();
+
+    // Add extra loop vars to the body function.
+
+    for (int i = 0; i < num_new_chains; i++) {
+      // Input loop vars.
+      // TODO(mdan): Double check that this doesn't clash with names in body.
+      string c_name = g->NewName("acd__chain");
+      std::replace(c_name.begin(), c_name.end(), '/', '_');
+      auto* new_arg = modified_body.mutable_signature()->add_input_arg();
+      new_arg->set_name(c_name);
+      new_arg->set_type(DT_FLOAT);
+
+      // Output ops. These are copies of the inputs conditioned on the actual
+      // control outputs.
+      string c_out_name = g->NewName("acd__outchain");
+      auto* new_out = modified_body.add_node_def();
+      new_out->set_name(c_out_name);
+      new_out->set_op("Identity");
+      new_out->add_input(c_name);
+      new_out->add_input(
+          strings::StrCat("^", body_graph->signature().control_output(i)));
+      AttrValue attr;
+      attr.set_type(DT_FLOAT);
+      new_out->mutable_attr()->insert({"T", attr});
+
+      // Output loop var declarations.
+      string c_ret_name = c_out_name;
+      std::replace(c_ret_name.begin(), c_ret_name.end(), '/', '_');
+      auto* new_out_arg = modified_body.mutable_signature()->add_output_arg();
+      new_out_arg->set_name(c_ret_name);
+      new_out_arg->set_type(DT_FLOAT);
+
+      // Actual output loop vars.
+      modified_body.mutable_ret()->insert(
+          {c_ret_name, strings::StrCat(c_out_name, ":output:0")});
+      AttrValue attr_val;
+      attr_val.mutable_list()->mutable_shape();
+      FunctionDef_ArgAttrs arg_attrs;
+      arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
+      modified_body.mutable_arg_attr()->insert({i + num_loop_vars, arg_attrs});
+    }
+
+    // Wire chain loop vars to the ops they need to condition.
+
+    node_index.clear();
+    for (int i = 0; i < modified_body.node_def_size(); i++) {
+      node_index.emplace(modified_body.node_def(i).name(), i);
+    }
+    auto& modified_sig = modified_body.signature();
+    for (int i = 0; i < num_new_chains; i++) {
+      const auto& control_node =
+          modified_body.node_def(node_index[modified_sig.control_output(i)]);
+      for (const auto& r :
+           control_node.attr().at("_res_first_used_by").list().s()) {
+        NodeDef* first_node = modified_body.mutable_node_def(node_index[r]);
+        // This control dependency ensures proper sequencing of stateful ops
+        // upon entry into the loop body, so that they run after the ops
+        // which affected the same resource in the previous iteration.
+        first_node->add_input(strings::StrCat(
+            "^", modified_sig.input_arg(i + num_loop_vars).name()));
+      }
+    }
+
+    // Clear body function's control returns.
+    modified_body.mutable_control_ret()->clear();
+
+    // Add extra loop vars to the cond function.
+
+    // TODO(mdan): This is wasteful. Can't we just mutate the original proto?
+    string cond_name = attrs.at("cond").func().name();
+    auto* cond_graph = flib_def->Find(cond_name);
+    DCHECK(cond_graph != nullptr);
+    FunctionDef modified_cond = *cond_graph;
+
+    int cond_barrier_loc = -1;
+    for (int i = 0, s = cond_graph->node_def_size(); i < s; i++) {
+      if (cond_barrier_loc < 0) {
+        const auto& node_attr = cond_graph->node_def(i).attr();
+        if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
+          cond_barrier_loc = i;
+        }
+      }
+    }
+    if (cond_barrier_loc >= 0) {  // -1 means "not found"; 0 is a valid index.
+      // Disable the global end-of-body barrier from the cond function.
+      // Because removing a node is too inefficient (would have to walk all the
+      // inputs of all graph nodes), we instead clear its control dependencies.
+      modified_cond.mutable_node_def(cond_barrier_loc)->clear_input();
+    }
+
+    for (int i = 0; i < num_new_chains; i++) {
+      // Input loop vars.
+      // TODO(mdan): These should gate the stateful ops in the cond.
+      // Until ACD supplies the necessary information, these are dummies in this
+      // function.
+      string c_name = g->NewName("acd__chain");
+      auto* new_arg = modified_cond.mutable_signature()->add_input_arg();
+      new_arg->set_name(c_name);
+      new_arg->set_type(DT_FLOAT);
+
+      // TODO(mdan): Return values on the cond function? Most likely a bug.
+      AttrValue attr_val;
+      attr_val.mutable_list()->mutable_shape();
+      FunctionDef_ArgAttrs arg_attrs;
+      arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
+      modified_cond.mutable_arg_attr()->insert({i + num_loop_vars, arg_attrs});
+    }
+
+    // Wire the new cond/body functions to the While node.
+
+    string new_cond_name = g->NewName("acd__while_cond");
+    modified_cond.mutable_signature()->set_name(new_cond_name);
+    mattrs->at("cond").mutable_func()->set_name(new_cond_name);
+
+    string new_body_name = g->NewName("acd__while_body");
+    modified_body.mutable_signature()->set_name(new_body_name);
+    mattrs->at("body").mutable_func()->set_name(new_body_name);
+
+    // Commit the new functions.
+
+    // TODO(b/183666205): One of these two should not be necessary.
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(
+        flib_def->AddFunctionDef(modified_body,
+                                 flib_def->GetStackTraces(body_name)),
+        "while attaching ", body_name, " to flib_def");
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(
+        flib_def->AddFunctionDef(modified_cond,
+                                 flib_def->GetStackTraces(cond_name)),
+        "while attaching ", cond_name, " to flib_def");
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(
+        g->mutable_flib_def()->AddFunctionDef(
+            modified_body, flib_def->GetStackTraces(body_name)),
+        "while attaching ", body_name, " to graph");
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(
+        g->mutable_flib_def()->AddFunctionDef(
+            modified_cond, flib_def->GetStackTraces(cond_name)),
+        "while attaching ", cond_name, " to graph");
+  }
+
+  if (VLOG_IS_ON(1)) {
+    DumpGraphToFile("control_flow_deps_to_chains_after", *g, flib_def);
+  }
+
+  return Status::OK();
+}
+
+// Note: This needs to run before functional control flow lowering, which is 10.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 9,
+                      ControlFlowDepsToChainsPass);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/control_flow_deps_to_chains.h b/tensorflow/core/common_runtime/control_flow_deps_to_chains.h
new file mode 100644
index 0000000..ab7358e
--- /dev/null
+++ b/tensorflow/core/common_runtime/control_flow_deps_to_chains.h
@@ -0,0 +1,76 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+
+namespace tensorflow {
+
+// Move control flow dependencies in functional control flow to chains.
+// Chains are extra loop variables that serve as tokens for wiring control
+// dependencies across loop iterations at a finer granularity, compared to just
+// a single barrier at the end of each iteration. This enables the
+// parallel_iterations feature for tf.while_loop.
+//
+// One separate chain is added for each of the body function's `control_ret`.
+//
+// For example:
+//
+//   while i > 0:
+//     r = v.read_value()
+//     s += expensive_operation(r)
+//     assign = v.assign_add(1)  # control: r
+//     i += 1
+//
+// The loop above can safely compute `r` and `assign` ahead of `s`, by the
+// as-if rule. The separate switch/merge nodes that the loop lowers into support
+// that.
+// This transformation enables that to happen by rewriting the loop as follows:
+//
+//   chain = 0.0
+//   while i > 0:
+//     r = v.read_value()  # control: chain
+//     s += expensive_operation(r)
+//     assign = v.assign_add(1)  # control: r
+//     i += 1
+//     chain = identity(chain)  # control: assign
+//
+// This only rewires dependencies which need to cross scope boundaries, as the
+// switch/merge lowering process has no other way of dealing correctly with
+// those.
+//
+// This pass is best-effort and conservative, requiring attributes set by
+// tf.while_loop and automatic_control_dependencies. When the required
+// attributes are missing for a particular While node, no change is made to
+// that node. Other While nodes are still processed if they do have the needed
+// annotations.
+// The pass can also be disabled for a While node by omitting its
+// `_stateful_parallelism` attribute, or by setting that attribute to false.
+// When the pass returns with error, the graph is left in an invalid state.
+// If successful, this pass also clears the body function's control_ret,
+// which in effect removes the hard barrier that gates each loop iteration.
+//
+//
+// TODO(mdan): Can we define that more formally?
+class ControlFlowDepsToChainsPass : public GraphOptimizationPass {
+ public:
+  Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 8ac76ac..975620d 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -178,6 +178,8 @@
 const NodeDef& Node::def() const { return props_->node_def; }
 const OpDef& Node::op_def() const { return *props_->op_def; }
 
+NodeDef* Node::mutable_def() { return &props_->node_def; }
+
 int32 Node::num_inputs() const { return props_->input_types.size(); }
 DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
 const DataTypeVector& Node::input_types() const { return props_->input_types; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index aa55814..a0d85e0 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -101,6 +101,9 @@
   const NodeDef& def() const;
   const OpDef& op_def() const;
 
+  // TODO(mdan): This is only used by control_flow_deps_to_chains. Remove?
+  NodeDef* mutable_def();
+
   // input and output types
   int32 num_inputs() const;
   DataType input_type(int32 i) const;
@@ -253,6 +256,10 @@
     return stack_trace_;
   }
 
+  // Called after an attr has changed. Decides whether we need to update some
+  // property of the node (stored in props_).
+  void UpdateProperties();
+
  private:
   friend class Graph;
   Node();
@@ -270,10 +277,6 @@
   // e.g. in AddAttr.
   void MaybeCopyOnWrite();
 
-  // Called after an attr has changed. Decides whether we need to update some
-  // property of the node (stored in props_).
-  void UpdateProperties();
-
   AttrValue* AddAttrHelper(const std::string& name);
 
   // A set of mutually exclusive classes for different kinds of nodes,
@@ -660,6 +663,9 @@
   const OpRegistryInterface* op_registry() const { return &ops_; }
   const FunctionLibraryDefinition& flib_def() const { return ops_; }
 
+  // TODO(mdan): This is only used by control_flow_deps_to_chains. Remove?
+  FunctionLibraryDefinition* mutable_flib_def() { return &ops_; }
+
   void CheckDeviceNameIndex(int index) {
     DCHECK_GE(index, 0);
     DCHECK_LT(index, static_cast<int>(device_names_.size()));
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index d9b8e47..6f92068 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -21,6 +21,7 @@
 import collections
 import enum
 
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import auto_control_deps_utils as utils
 from tensorflow.python.framework import dtypes as dtypes_module
@@ -180,14 +181,13 @@
   NOT THREAD SAFE
   """
 
-  __slots__ = [
-      "_returned_tensors", "ops_which_must_run", "_graph", "_n_operations",
-      "collective_manager_ids_used"
-  ]
-
-  def __init__(self):
+  def __init__(self,
+               record_initial_resource_uses=False,
+               record_uses_of_resource_ids=None):
     self._returned_tensors = object_identity.ObjectIdentitySet()
     self.ops_which_must_run = set()
+    self.record_initial_resource_uses = record_initial_resource_uses
+    self.record_uses_of_resource_ids = record_uses_of_resource_ids
 
   def mark_as_return(self, tensor):
     """Acts like identity but marks the `Tensor` as a return value.
@@ -341,6 +341,8 @@
     merge_for_resource = {}
 
     new_operations = self._graph.get_operations()[self._n_operations:]
+    first_use_for_res = {}
+    resources_by_op = {}
 
     # Ensures that uses of resource tensors get serialized properly and all
     # execute. This is done by keeping a map from resource tensor to the last op
@@ -432,13 +434,39 @@
         # Ensure merges happen after the closing of a cond block
         if input_id in merge_for_resource:
           merge_for_resource[input_id]._add_control_input(op)
+
+        do_record = (
+            self.record_initial_resource_uses and
+            input_id not in first_use_for_res)
+
         if is_read:
-          reads_since_last_write_to_resource[input_id].append(op)
+          reads_list = reads_since_last_write_to_resource[input_id]
+          reads_list.append(op)
+
+          if do_record:
+            # Note: this will track the entire list that
+            # reads_since_last_write_to_resource maintains. Updates to it will
+            # and should be tracked, until the first write is encountered. At
+            # that point, reads_since_last_write_to_resource will contain a new
+            # empty list. This logic relies on that behavior.
+            first_use_for_res[input_id] = reads_list
+
         else:
           control_inputs.update(reads_since_last_write_to_resource[input_id])
           reads_since_last_write_to_resource[input_id] = []
           last_write_to_resource[input_id] = op
 
+          if do_record:
+            first_use_for_res[input_id] = [op]
+
+      if self.record_initial_resource_uses and op_is_stateful(op):
+        if resource_inputs:
+          resources_by_op[op] = tuple(resource_inputs)
+        else:
+          if None not in first_use_for_res:
+            first_use_for_res[None] = [op]
+          resources_by_op[op] = (None,)
+
       if (op_is_stateful(op) and not resource_inputs
           and op._control_flow_context is None):
         if None in last_write_to_resource:
@@ -467,9 +495,26 @@
 
       op._add_control_inputs(control_inputs)
 
+    # Record the ops which first use resources touched by "ops which must run".
+    if self.record_initial_resource_uses:
+      first_uses_by_output_ops = {}
+      for op in ops_which_must_run:
+        for r in resources_by_op[op]:
+          if op not in first_uses_by_output_ops:
+            first_uses_by_output_ops[op] = set()
+          first_uses_by_output_ops[op].update(first_use_for_res[r])
+      # For each "op which must run", set a private attr indicating the ops that
+      # used the same resources it did.
+      for op in first_uses_by_output_ops:
+        others = [
+            other.name.encode() for other in first_uses_by_output_ops[op]
+        ]
+        l = attr_value_pb2.AttrValue.ListValue(s=others)
+        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
+        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))
+
     # Ensure all ops which must run do run
     self.ops_which_must_run.update(ops_which_must_run)
-
     control_output_op = None
     for idx, r in enumerate(
         nest.flatten(list(self._returned_tensors), expand_composites=True)):
@@ -478,11 +523,13 @@
         if r.graph.building_function:
           # There may be many stateful ops in the graph. Adding them as
           # control inputs to each function output could create excessive
-          # control edges in the graph. Thus we create an intermediate No-op to
-          # chain the control dependencies between stateful ops and function
-          # outputs.
+          # control edges in the graph. Thus we create an intermediate No-op
+          # to chain the control dependencies between stateful ops and
+          # function outputs.
           if idx == 0:
             control_output_op = control_flow_ops.no_op()
+            control_output_op._set_attr("_acd_function_control_output",
+                                        attr_value_pb2.AttrValue(b=True))
             control_output_op._add_control_inputs(self.ops_which_must_run)
           updated_ops_which_must_run = [control_output_op]
         else:
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index 15e69ca..9d7acc7 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -811,6 +811,7 @@
     self._scope_exit_callbacks.append(fn)
 
 
+# TODO(mdan): Too many threaded arguments. Accept an ACD ctx manager instead.
 def func_graph_from_py_func(name,
                             python_func,
                             args,
@@ -824,7 +825,8 @@
                             op_return_value=None,
                             collections=None,
                             capture_by_value=None,
-                            override_flat_arg_shapes=None):
+                            override_flat_arg_shapes=None,
+                            acd_record_initial_resource_uses=False):
   """Returns a `FuncGraph` generated from `python_func`.
 
   Args:
@@ -869,6 +871,11 @@
       containing value `None` must match entries in flattened arguments
       containing non-tensors, while entries containing a `TensorShape` must
       match entries in the flattened arguments containing tensors.
+    acd_record_initial_resource_uses: If `True` and `add_control_dependencies`
+      is enabled, the results (those marked with
+      AutomaticControlDependencies.mark_result) will be annotated with a private
+      attribute, "_res_first_used_by", which points to the first nodes which
+      used any of the resources that the result op is using.
 
   Returns:
     A FuncGraph.
@@ -886,7 +893,8 @@
                            capture_by_value=capture_by_value)
   assert isinstance(func_graph, FuncGraph)
   if add_control_dependencies:
-    deps_control_manager = auto_control_deps.AutomaticControlDependencies()
+    deps_control_manager = auto_control_deps.AutomaticControlDependencies(
+        record_initial_resource_uses=acd_record_initial_resource_uses)
   else:
     deps_control_manager = ops.NullContextmanager()
 
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 9f4306b..8002849 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -59,6 +59,7 @@
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2
 import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import googletest
 from tensorflow.python.training import momentum
@@ -1567,6 +1568,14 @@
 class WhileLoopParallelismTest(test_util.TensorFlowTestCase,
                                parameterized.TestCase):
 
+  def setUp(self):
+    super().setUp()
+    self._while_paralelism = while_v2.glob_stateful_parallelism
+
+  def tearDown(self):
+    while_v2.glob_stateful_parallelism = self._while_paralelism
+    super().tearDown()
+
   @parameterized.parameters(*itertools.product(
       (False, True),
       (False, True),
@@ -1580,6 +1589,8 @@
     if not tf2.enabled():
       self.skipTest("V2-only test.")
 
+    while_v2.glob_stateful_parallelism = True
+
     ticker = variables.Variable(0)
 
     @def_function.function
@@ -1598,6 +1609,7 @@
 
       while i < n:
         directives.set_loop_options(parallel_iterations=10)
+
         if modify_in_loop:
           ticker.assign_add(1)
         t_acc = t_acc.write(i, ticker.read_value())
@@ -1615,7 +1627,7 @@
     # Warm-up.
     self.evaluate(run_loop(1))
 
-    self.evaluate(ticker.assign(0))
+    self.evaluate(ticker.assign(123))
     acc, rb, ra = run_loop(3)
     self.assertEqual(
         self.evaluate(math_ops.reduce_max(acc)),
@@ -1635,11 +1647,168 @@
         self.evaluate(ticker.read_value()),
         int(modify_before) + 3 * int(modify_in_loop) + int(modify_after))
 
+  def testMultiReadsBeforeWrite(self):
+    """Reads in one iteration must not run ahead of an earlier slow write."""
+    if not tf2.enabled():
+      self.skipTest("V2-only test.")
+
+    while_v2.glob_stateful_parallelism = True
+
+    ticker = variables.Variable(0)
+
+    @def_function.function
+    def run_loop(n):
+      ticker.assign(0)
+      i = constant_op.constant(0)
+      t_acc = tensor_array_ops.TensorArray(
+          dtypes.int32, size=0, dynamic_size=True)
+
+      while i < n:
+        directives.set_loop_options(parallel_iterations=10)
+
+        a = ticker.read_value()
+        b = ticker.read_value()
+        t_acc = t_acc.write(2 * i, a)
+        t_acc = t_acc.write(2 * i + 1, b)
+
+        # Slow write forces reads to sprint ahead if they can.
+        # This test verifies that they don't.
+        ticker.assign_add(
+            math_ops.cast(
+                math_ops.reduce_max(
+                    random_ops.random_uniform(
+                        shape=(1000,), minval=1.0, maxval=1.001)),
+                dtypes.int32))
+
+        i += 1
+
+      a = ticker.read_value()
+      b = ticker.read_value()
+      t_acc = t_acc.write(2 * i, a)
+      t_acc = t_acc.write(2 * i + 1, b)
+
+      return t_acc.stack()
+
+    # Warm-up.
+    self.evaluate(run_loop(1))
+
+    acc = run_loop(3)
+    self.assertAllEqual(acc, [0, 0, 1, 1, 2, 2, 3, 3])
+
+  def testCondDependenceOnMutatedResource(self):
+    """Loop cond reads a variable that the body mutates (parallelism off)."""
+    if not tf2.enabled():
+      self.skipTest("V2-only test.")
+
+    # TODO(b/152548567): Enable this.
+    while_v2.glob_stateful_parallelism = False
+
+    ticker = variables.Variable(0)
+    counter = variables.Variable(1)
+
+    @def_function.function
+    def run_loop(n):
+      ticker.assign(0)
+      counter.assign(0)
+
+      while ticker.read_value() < n:
+        directives.set_loop_options(parallel_iterations=10)
+
+        # Run a slow assign, to make sure counter sprints ahead.
+        ticker.assign_add(
+            math_ops.cast(
+                math_ops.reduce_max(
+                    random_ops.random_uniform(
+                        shape=(1000,), minval=1.0, maxval=1.001)),
+                dtypes.int32))
+        counter.assign_add(1)
+
+      return ticker.read_value(), counter.read_value()
+
+    # Warm-up.
+    self.evaluate(run_loop(1))
+
+    t, c = run_loop(3)
+    self.assertEqual(self.evaluate(t), 3)
+    self.assertEqual(self.evaluate(c), 3)
+
+  def testIndependentSideEffectsInCond(self):
+    """Side effects in cond and body must interleave in program order."""
+    if not tf2.enabled():
+      self.skipTest("V2-only test.")
+
+    # TODO(b/152548567): Enable experimental_stateful_parallelism.
+    # Without proper wiring of control deps in the cond branch, the test is
+    # non-deterministic, running cond's record_side_effect ahead of its
+    # counterpart in the body.
+    while_v2.glob_stateful_parallelism = False
+
+    state = []
+
+    def record_side_effect(c):
+
+      def side_effect_py_fn():
+        state.append(c)
+        return 0
+
+      script_ops.eager_py_func(side_effect_py_fn, [], [dtypes.int32])
+
+    @def_function.function
+    def run_loop(n):
+
+      def complex_cond(i):
+        record_side_effect("A")
+        return i < n
+
+      i = constant_op.constant(0)
+
+      while complex_cond(i):
+        directives.set_loop_options(parallel_iterations=10)
+
+        record_side_effect("B")
+        i += 1
+
+      return i
+
+    # Warm-up.
+    self.evaluate(run_loop(1))
+
+    state.clear()
+    i = run_loop(3)
+    self.assertEqual(self.evaluate(i), 3)
+    self.assertListEqual(state, ["A", "B", "A", "B", "A", "B", "A"])
+
+  def testStatelessLoop(self):
+    """Checks the outputs of a loop that has no side-effecting ops."""
+    while_v2.glob_stateful_parallelism = True
+
+    @def_function.function
+    def run_loop(n):
+
+      a = 0
+      b = 1
+
+      i = constant_op.constant(0)
+      while i < n:
+        directives.set_loop_options(parallel_iterations=10)
+        i += 1
+        a += 2
+        b *= 3
+
+      return i, a, b
+
+    i, a, b = run_loop(3)
+    self.assertEqual(self.evaluate(i), 3)
+    self.assertEqual(self.evaluate(a), 6)
+    self.assertEqual(self.evaluate(b), 27)
+
   def testStatefulParallelism(self):
 
     if not tf2.enabled():
       self.skipTest("V2-only test.")
 
+    while_v2.glob_stateful_parallelism = True
+
     ticker = variables.Variable(0)
     # Secondary state for the pyfunc that lets us verify that things ran in
     # the correct relative order.
@@ -1667,6 +1836,7 @@
 
       while i < n:
         directives.set_loop_options(parallel_iterations=10)
+
         wait_then_tick(i + 1)
         # The read is expected to run in much less than `wait_then_tick`,
         # which sleeps for 1s. Hence all reads should complete before the first
@@ -1682,12 +1852,14 @@
     # This test is deterministic so long as the runtime is fast enough to
     # execute `t_acc = t_acc.write(i, ticker.read_value())` in much less than
     # one second.
-    self.evaluate(ticker.assign(0))
+    self.evaluate(ticker.assign(123))
     ticker_state.clear()
     acc = run_loop(3)
-    # Because the loop runs entirely sequentially, the reads in each iteration
-    # see the effects of the pyfunc from the previous iteration.
-    self.assertEqual(self.evaluate(math_ops.reduce_max(acc)), 2)
+    # Because the loop iterations are allowed to run in parallel, reads from
+    # different iterations may proceed ahead of pyfuncs from other iterations.
+    # Because reads are much faster, they should all complete before a single
+    # pyfunc does.
+    self.assertEqual(self.evaluate(math_ops.reduce_max(acc)), 0)
 
     # Double-check that the loop ran completely.
     self.assertEqual(self.evaluate(ticker.read_value()), 3)
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 38668d9..2e0d0a2 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -56,6 +56,20 @@
 
 # pylint: disable=protected-access
 
+# Controls parallelism in the presence of side-effecting ops like variable
+# operations, print, py_function, etc. Can be set to True, False (the current
+# default), or "stateless_cond" (the intended default; see the TODO below).
+# Note that loops without side-effecting operations always execute with maximum
+# parallelism, ignoring this setting. When False, loops with side-effecting ops
+# execute sequentially, one iteration at a time.
+# When True, loops with side-effecting ops may execute parts of different
+# iterations in parallel; caution: if the loop condition contains
+# side-effecting ops, this mode produces unspecified results.
+# Setting it to "stateless_cond" behaves like True when the loop condition is
+# free of side-effecting ops, and like False otherwise.
+# TODO(b/152548567): Change this to "stateless_cond".
+glob_stateful_parallelism = False
+
 
 def while_loop(cond,
                body,
@@ -152,6 +166,12 @@
             cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
         add_control_dependencies=add_control_dependencies)
 
+    if glob_stateful_parallelism == "stateless_cond":
+      stateful_parallelism = (not any(
+          op._is_stateful for op in cond_graph.get_operations()))
+    else:
+      stateful_parallelism = glob_stateful_parallelism
+
     def wrapped_body(loop_counter, maximum_iterations_arg, *args):
       """Loop body augmented with counter update.
 
@@ -199,7 +219,8 @@
         signature=signature,
         func_graph=util.WhileBodyFuncGraph(
             body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
-        add_control_dependencies=add_control_dependencies)
+        add_control_dependencies=add_control_dependencies,
+        acd_record_initial_resource_uses=stateful_parallelism)
     # Add external captures of body to the list of loop vars.
     # Note that external tensors will be treated as loop invariants, i.e.,
     # the value of that tensor in each iteration is the same as it was at the
@@ -279,7 +300,8 @@
           output_shapes=output_shapes,
           parallel_iterations=parallel_iterations,
           name=scope,
-          num_original_outputs=num_original_outputs)
+          num_original_outputs=num_original_outputs,
+          stateful_parallelism=stateful_parallelism)
     if not ops.get_default_graph().building_function:
       # In V1 graph mode, return identities for each output of the While op,
       # rather than the output of the While op directly. This makes pruning work
@@ -327,6 +349,11 @@
   except:  # pylint: disable=bare-except
     num_original_outputs = len(while_op.outputs)
 
+  try:
+    stateful_parallelism = while_op.get_attr("_stateful_parallelism")
+  except:  # pylint: disable=bare-except
+    stateful_parallelism = False
+
   num_intermediates = len(while_op.outputs) - num_original_outputs
   grads = [
       _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
@@ -353,7 +380,8 @@
 
   body_grad_graph, args = _create_grad_func(
       ys, xs, non_none_grads, cond_graph, body_graph,
-      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)
+      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations,
+      stateful_parallelism)
 
   if body_grad_graph.while_op_needs_rewrite:
     # Modify 'op' to output the intermediate accumulators needed by the grad
@@ -414,7 +442,8 @@
       output_shapes=[t.shape for t in body_grad_graph.outputs],
       parallel_iterations=parallel_iterations,
       name="%s_grad" % while_op.name,
-      num_original_outputs=len(body_grad_graph.outputs))
+      num_original_outputs=len(body_grad_graph.outputs),
+      stateful_parallelism=stateful_parallelism)
 
   # See comment in while_loop.
   outputs = [array_ops.identity(t) for t in outputs]
@@ -422,7 +451,8 @@
 
 
 def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
-                    parallel_iterations, name, num_original_outputs):
+                    parallel_iterations, name, num_original_outputs,
+                    stateful_parallelism):
   """Builds the functional StatelessWhile/While op."""
   cond_stateful_ops = [
       op for op in cond_graph.get_operations() if op._is_stateful
@@ -450,6 +480,8 @@
     # This is needed so we do not compute derivative wrt these extra outputs.
     while_op._set_attr("_num_original_outputs",
                        attr_value_pb2.AttrValue(i=num_original_outputs))
+    while_op._set_attr("_stateful_parallelism",
+                       attr_value_pb2.AttrValue(b=stateful_parallelism))
     # The while op may be created inside a tf.function, in which case ops
     # needs to capture "through" it when taking gradients; outer_graph is used
     # as a sanity check that capturing only happens from parent to child.
@@ -606,7 +638,7 @@
 
 
 def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
-                      maximum_iterations):
+                      maximum_iterations, stateful_parallelism):
   """Builds and returns the gradient FuncGraph of `func_graph` and its args.
 
   The returned grad_func_graph must be called with the returned
@@ -621,6 +653,7 @@
     name: Name of the returned gradient function.
     while_op: The forward While op.
     maximum_iterations: Tensor. The maximum number of iterations.
+    stateful_parallelism: Bool; see `glob_stateful_parallelism` in this module.
 
   Returns:
     2-tuple of (grad_func_graph, args).
@@ -650,7 +683,8 @@
       args, {},
       func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                          maximum_iterations, while_op,
-                                         body_graph_inputs, body_graph_outputs))
+                                         body_graph_inputs, body_graph_outputs),
+      acd_record_initial_resource_uses=stateful_parallelism)
 
   # Update the list of outputs with tensors corresponding to the captured
   # tensors. We capture 3 types of tensors when building the grad fn: