Colocate output identities with outputs when inlining multi-device functions

PiperOrigin-RevId: 280441133
Change-Id: Ie539eaa0d8eeef790a994a1abf8845b362d0a68c
diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc
index 8a8aaac..c337766 100644
--- a/tensorflow/core/common_runtime/colocation_graph.cc
+++ b/tensorflow/core/common_runtime/colocation_graph.cc
@@ -644,14 +644,20 @@
     bool found_spec = false;
     const AttrValue* attr_value =
         node->attrs().Find(kColocationAttrNameStringPiece);
-    if (attr_value != nullptr && attr_value->has_list()) {
-      for (const string& class_spec : attr_value->list().s()) {
-        StringPiece spec(class_spec);
-        if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
-          found_spec = true;
-          TF_RETURN_IF_ERROR(
-              ColocateNodeToGroup(&colocation_group_root, node, spec));
+    if (attr_value != nullptr) {
+      if (attr_value->has_list()) {
+        for (const string& class_spec : attr_value->list().s()) {
+          StringPiece spec(class_spec);
+          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
+            found_spec = true;
+            TF_RETURN_IF_ERROR(
+                ColocateNodeToGroup(&colocation_group_root, node, spec));
+          }
         }
+      } else if (!attr_value->s().empty()) {
+        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
+                      "list of strings, not a single string: "
+                   << node->DebugString();
       }
     }
 
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index cdf9248..6d36011 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1545,6 +1545,7 @@
   absl::optional<string> OutputNodeDevice(int output_index) const override {
     return absl::nullopt;
   }
+  bool ColocateOutputIdentity() const override { return false; }
   absl::optional<string> ControlNodeDevice() const override {
     return absl::nullopt;
   }
@@ -1568,6 +1569,7 @@
   absl::optional<string> OutputNodeDevice(int output_index) const override {
     return caller_device_;
   }
+  bool ColocateOutputIdentity() const override { return false; }
   absl::optional<string> ControlNodeDevice() const override {
     return caller_device_;
   }
@@ -1598,6 +1600,7 @@
   absl::optional<string> OutputNodeDevice(int output_index) const override {
     return absl::nullopt;
   }
+  bool ColocateOutputIdentity() const override { return true; }
   absl::optional<string> ControlNodeDevice() const override {
     return caller_device_;
   }
@@ -1914,6 +1917,12 @@
     Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
     const absl::optional<string> device = placer->OutputNodeDevice(index);
     if (device.has_value()) node->set_requested_device(*device);
+    bool colocate_identity = placer->ColocateOutputIdentity();
+    if (colocate_identity) {
+      node->AddAttr(kColocationAttrName,
+                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
+                                                     input.node->name())});
+    }
     return node;
   };
 
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index 9ef70d2..790ffae 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -173,6 +173,9 @@
 
   virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
   virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
+  // Returns true if the added output identity node should be colocated with the
+  // corresponding output from the function body.
+  virtual bool ColocateOutputIdentity() const = 0;
   virtual absl::optional<string> ControlNodeDevice() const = 0;
   virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;
 
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 625b8d6..7cbbe0c 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -35,6 +35,66 @@
 
 namespace {
 
+struct NameCounts {  // Process-wide state used by MakeUniqueFilename().
+  mutex counts_mutex;  // Held while reading or updating `counts`.
+  std::unordered_map<string, int> counts;  // Requests seen so far, per name.
+};
+
+// Returns `name` with each of '/', '[', ']', '*', '?' replaced by '_' and
+// ".txt" appended. When the same sanitized name has been requested N > 0 times
+// before, "_N" is inserted ahead of ".txt".
+string MakeUniqueFilename(string name) {
+  static NameCounts& instance = *new NameCounts;  // Never deleted.
+
+  for (char& ch : name) {
+    if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') {
+      ch = '_';
+    }
+  }
+
+  int count;  // Earlier requests for this sanitized name.
+  {
+    mutex_lock lock(instance.counts_mutex);
+    count = instance.counts[name]++;
+  }
+
+  if (count > 0) {
+    absl::StrAppend(&name, "_", count);
+  }
+  absl::StrAppend(&name, ".txt");
+  return name;
+}
+
+// Sets `*fname` to "$TF_DUMP_GRAPH_PREFIX/" + MakeUniqueFilename(base_name).
+// Fails with an Internal error if that environment variable is not set.
+Status GetFileName(string base_name, string* fname) {
+  const char* const dump_dir = getenv("TF_DUMP_GRAPH_PREFIX");
+  if (dump_dir == nullptr) {
+    return errors::Internal("Failed to get the directory for ", base_name,
+                            " because dump location is not specified through "
+                            "TF_DUMP_GRAPH_PREFIX environment variable");
+  }
+  *fname = absl::StrCat(dump_dir, "/", MakeUniqueFilename(base_name));
+  return Status::OK();
+}
+
+// Dumps colocation_graph.DebugString() to a file; logs success or failure.
+void DumpColocationGraph(const string& base_name,
+                         const ColocationGraph& colocation_graph) {
+  string path;
+  Status result = GetFileName(base_name, &path);
+  if (result.ok()) {
+    result = WriteStringToFile(Env::Default(), path,
+                               colocation_graph.DebugString());
+  }
+  if (result.ok()) {
+    LOG(INFO) << "Wrote ColocationGraph to " << path;
+    return;
+  }
+  LOG(ERROR) << "Failed to write final colocation graph to file " << path
+             << " with " << result.ToString();
+}
+
 // Returns true if the node has no inputs and produces outputs
 // that are consumed by a single node.
 //
@@ -229,6 +289,7 @@
 
   if (VLOG_IS_ON(3)) {
     DumpGraphToFile("placer_output", *graph_, nullptr);
+    DumpColocationGraph("colocation_graph", colocation_graph);
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 2fbdba5..799599b 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -782,7 +782,7 @@
        NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
        NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
        NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, cpu1),
-       NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu0),
+       NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),
 
        NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, cpu0)},
       // Function library.
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index a735081..e64d5fe 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -122,6 +122,14 @@
     return n.split(":")[0]
 
 
+def _get_colocated_node_name(colocated_node_name):
+  """Returns the UTF-8 decoded name with a leading "loc:@" stripped, if any."""
+  decoded = colocated_node_name.decode("utf-8")
+  if not decoded.startswith("loc:@"):
+    return decoded
+  return decoded[len("loc:@"):]
+
+
 def _extract_graph_summary(graph_def):
   """Extracts useful information from the graph and returns them."""
   name_to_input_name = {}  # Keyed by the dest node name.
@@ -138,9 +146,8 @@
     # Prevent colocated nodes from being lost.
     if "_class" in node.attr:
       for colocated_node_name in node.attr["_class"].list.s:
-        colocated_node_decoded = colocated_node_name.decode("utf-8")
-        if colocated_node_decoded.startswith("loc:@"):
-          name_to_input_name[n].append(colocated_node_decoded[5:])
+        name_to_input_name[n].append(
+            _get_colocated_node_name(colocated_node_name))
     name_to_seq_num[n] = seq
     seq += 1
   return name_to_input_name, name_to_node, name_to_seq_num
@@ -306,20 +313,25 @@
       while (source_op_names and map_name_to_node[source_op_names[0]].op in
              _CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
         source_op_name = source_op_names.pop()
+        current_node = map_name_to_node[source_op_name]
 
         if source_op_name not in resource_op_types:
           resource_op_types[source_op_name] = node.attr["dtype"]
-          source_op_names.append(
-              get_input_name(map_name_to_node[source_op_name]))
+          source_op_names.append(get_input_name(current_node))
 
-        if map_name_to_node[source_op_name].op == "Merge":
-          merge_resource_name = get_input_name(
-              map_name_to_node[source_op_name], index=1)
+        if current_node.op == "Merge":
+          merge_resource_name = get_input_name(current_node, index=1)
           if merge_resource_name not in resource_op_types:
             resource_op_types[merge_resource_name] = node.attr["dtype"]
             source_op_names.append(
                 get_input_name(map_name_to_node[merge_resource_name]))
 
+        # Also visit every node named in this node's `_class` attribute.
+        if "_class" in current_node.attr:
+          for colocated_node_name in current_node.attr["_class"].list.s:
+            source_op_names.append(
+                _get_colocated_node_name(colocated_node_name))
+
       for source_node in source_op_names:
         if map_name_to_node[source_node].op != "VarHandleOp":
           raise ValueError("Cannot find the variable that is an input "