Do not create a cycle when adding control edges in ScopedAllocatorOptimizer.

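ScopedAllocatorOptimizer adds control edges to the ScopedAllocator node from the
parents of the inputs to collective ops, in order to delay allocation of the
backing tensor until just before first use.  More context: cl/265532843.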

Before this change, the optimizer would not check if the node from which it was
adding a control edge to the ScopedAllocator node was already in the set of
inputs to the collective op group (input set).  If that was the case, this
would result in a cycle, because the optimizer also adds control edges from
ScopedAllocator node to all nodes in the input set.

After this change, the optimizer avoids adding a control edge if the source node
of that edge belongs to the transitive fanout of the input set (which includes
the input set itself).

PiperOrigin-RevId: 280249308
Change-Id: Iacfb2543b2ac4ec4f0f0966ea73c903ab507cdaa
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
index bc9ddf3..12018a5 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
@@ -457,6 +457,33 @@
     return Status::OK();
   }
 
+  // Inserts into `fanout` the `source_nodes` and all nodes reachable from them
+  // via data or control edges, not traversing nodes that modify frame info.
+  Status TransitiveFanoutWithinFrame(
+      GraphDef* graph, NodeMap* node_map,
+      const std::vector<const NodeDef*>& source_nodes,
+      absl::flat_hash_set<const NodeDef*>* fanout) {
+    std::deque<const NodeDef*> queue(source_nodes.begin(), source_nodes.end());
+    absl::flat_hash_set<const NodeDef*> visited;
+    while (!queue.empty()) {
+      const NodeDef* node = queue.front();
+      queue.pop_front();
+      if (!visited.insert(node).second) {
+        continue;
+      }
+      fanout->insert(node);
+      for (const NodeDef* output : node_map->GetOutputs(node->name())) {
+        if (!ModifiesFrameInfo(*output)) {
+          queue.push_back(output);
+        }
+        VLOG(2) << "TransitiveFanout parent: " << node->name()
+                << " child: " << output->name() << " of type " << output->op();
+      }
+    }
+
+    return Status::OK();
+  }
+
   // Build the ScopedAllocator node that will be assigned to allocate
   // the output tensors of the input node set.
   Status ConstructScopedAllocatorNode(
@@ -478,6 +505,15 @@
     LOG_WARNING_AND_RETURN_IF_ERROR(sa_builder.Finalize(sa_node));
     node_map->AddNode(sa_name, sa_node);
 
+    std::vector<const NodeDef*> fanout_sources;
+    fanout_sources.reserve(inputs.size());
+    for (const auto& input : inputs) {
+      fanout_sources.push_back(input.from_node_def);
+    }
+    absl::flat_hash_set<const NodeDef*> fanout;
+    TF_RETURN_IF_ERROR(
+        TransitiveFanoutWithinFrame(graph, node_map, fanout_sources, &fanout));
+
     // Add control edges from the ScopedAllocatorOp to all of the
     // input nodes and mark them for allocation from backing tensor.
     for (int i = 0; i < inputs.size(); ++i) {
@@ -496,18 +532,36 @@
 
     // We add control edges in order to delay execution of the ScopedAllocatorOp
     // until just before first use in order to conserve memory.
-    {
-      auto& nd = inputs[0];
+    bool added_delay_edge = false;
+    for (auto& nd : inputs) {
       std::vector<InputDesc> inputs_to_first;
       LOG_WARNING_AND_RETURN_IF_ERROR(GetDataInputs(
           graph, sa_opti->node_map(), nd.from_node_def, &inputs_to_first));
       for (int i = 0; i < inputs_to_first.size(); ++i) {
+        if (fanout.find(inputs_to_first[i].from_node_def) != fanout.end()) {
+          VLOG(2) << "Found node " << inputs_to_first[i].from_node_def->name()
+                  << " in the fanout of " << sa_name;
+          continue;
+        }
         sa_node->add_input(
             strings::StrCat("^", inputs_to_first[i].from_node_def->name()));
+        node_map->AddOutput(inputs_to_first[i].from_node_def->name(), sa_name);
+        added_delay_edge = true;
         VLOG(2) << "Adding control dependency from "
                 << inputs_to_first[i].from_node_def->name() << " to "
                 << sa_node->name();
+        break;
       }
+      if (added_delay_edge) {
+        break;
+      }
+    }
+
+    if (!added_delay_edge) {
+      LOG(WARNING) << "Found no node from which a control edge can be added to "
+                      "scoped allocator node.  If you run into issues with "
+                      "graphs that contain control flow, turn off the "
+                      "ScopedAllocatorOptimizer and file a bug.";
     }
 
     return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
index 5fd8a12..7cab416 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
@@ -64,7 +64,7 @@
   // (Flow is top to bottom, like nature intends.)
   //
   // The intended optimization is to have s1 and s2 allocate from
-  // an new ScopedAllocator, then replace a1 and a2 with a3 that
+  // a new ScopedAllocator, then replace a1 and a2 with a3 that
   // reads from the backing buffer.
   /*
         a    b    c
@@ -105,6 +105,49 @@
   }
 
   // Constructs the following graph.
+  // (Flow is top to bottom, like nature intends.)
+  //
+  // a, b, and c are constants.  s1 is an Add op.  a1, a2, and a3 are Abs ops.
+  // r1, r2, and r3 are Reshape ops.
+  //
+  // After this graph undergoes SA optimization, we expect a, b, and s1 to be
+  // allocated from a new ScopedAllocator.  There will be control edges from the
+  // ScopedAllocator node to a, b, and s1, to ensure that we allocate the
+  // backing tensor before we need it.  There will also be a control edge from c
+  // to ScopedAllocator node, so that we delay allocation as much as possible.
+  // There should be no edge from b to ScopedAllocator node, because that would
+  // imply a cycle in the graph.
+  /*
+      a      b     c
+      |     / \   /
+      |    /   \ /
+      |    |    s1
+      |    |    |
+      a1   a2   a3
+      |    |    |
+      r1   r2   r3
+  */
+  void BuildAbsGraphWithInputDependencies(GraphDef* graph_def) {
+    Scope s = Scope::NewRootScope();
+    s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
+
+    Output a =
+        ops::Const<float>(s.WithOpName("a"), {1.0, 0.0, 0.0, -1.0}, {2, 2});
+    Output b =
+        ops::Const<float>(s.WithOpName("b"), {1.0, -2.0, 3.0, 4.0}, {2, 2});
+    Output c =
+        ops::Const<float>(s.WithOpName("c"), {-5.0, -2.0, 0.0, -2.0}, {2, 2});
+    Output s1 = ops::Add(s.WithOpName("s1"), b, c);
+    Output a1 = ops::Abs(s.WithOpName("a1"), a);
+    Output a2 = ops::Abs(s.WithOpName("a2"), b);
+    Output a3 = ops::Abs(s.WithOpName("a3"), s1);
+    Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
+    Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
+    Output r3 = ops::Reshape(s.WithOpName("r3"), a3, {4, 1});
+    TF_CHECK_OK(s.ToGraphDef(graph_def));
+  }
+
+  // Constructs the following graph.
   //
   // We have 2 different name scopes in this graph.  s3, a3, a4, r3, and r4 are
   // all under "sub" scope.  All other nodes are in the root scope.
@@ -203,6 +246,27 @@
       }
     }
   }
+
+  // Validate that a node has a single control input from scoped allocator node.
+  // Return the scoped allocator node.
+  NodeDef* ValidateSAControlInput(GraphDef* graph, NodeMap* node_map,
+                                  const string& node_name) {
+    NodeDef* node = node_map->GetNode(node_name);
+    EXPECT_TRUE(node);
+    int num_control_inputs = 0;
+    string control_input_name;
+    for (const auto& input : node->input()) {
+      if (input[0] == '^') {
+        ++num_control_inputs;
+        control_input_name = input;
+      }
+    }
+    EXPECT_EQ(num_control_inputs, 1);
+    NodeDef* control_input_node = node_map->GetNode(control_input_name);
+    EXPECT_TRUE(control_input_node);
+    EXPECT_EQ(control_input_node->op(), "_ScopedAllocator");
+    return control_input_node;
+  }
 };
 
 TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
@@ -324,6 +388,38 @@
   ValidateValues(outputs, /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}});
 }
 
+// Test that graphs with a dependency upstream from the inputs, such as the one
+// produced by `BuildAbsGraphWithInputDependencies`, are handled well by this
+// optimizer.  In particular, the optimizer should not create cycles.
+TEST_F(ScopedAllocatorOptimizerTest, InputDependencies) {
+  GrapplerItem item;
+  BuildAbsGraphWithInputDependencies(&item.graph);
+  SetShapes(&item.graph);
+
+  ScopedAllocatorOptions opts;
+  opts.add_enable_op("Abs");
+  ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
+  ScopedAllocatorOptimizer::OpNameSet ons;
+  ons.insert("Add");
+
+  GraphDef optimized_graph;
+  TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
+  NodeMap node_map(&optimized_graph);
+
+  // Check that all inputs to Abs ops have ScopedAllocator as a control
+  // dependency.
+  NodeDef* scoped_allocator_node =
+      ValidateSAControlInput(&optimized_graph, &node_map, "a");
+  VLOG(1) << scoped_allocator_node->DebugString();
+  EXPECT_TRUE(ValidateSAControlInput(&optimized_graph, &node_map, "b"));
+  EXPECT_TRUE(ValidateSAControlInput(&optimized_graph, &node_map, "s1"));
+
+  // Check that ScopedAllocator node has a single input, which is a control edge
+  // from c.
+  EXPECT_EQ(scoped_allocator_node->input_size(), 1);
+  EXPECT_EQ(scoped_allocator_node->input(0), "^c");
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow