Colocate output identities with outputs when inlining multi-device functions
PiperOrigin-RevId: 280441133
Change-Id: Ie539eaa0d8eeef790a994a1abf8845b362d0a68c
diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc
index 8a8aaac..c337766 100644
--- a/tensorflow/core/common_runtime/colocation_graph.cc
+++ b/tensorflow/core/common_runtime/colocation_graph.cc
@@ -644,14 +644,20 @@
bool found_spec = false;
const AttrValue* attr_value =
node->attrs().Find(kColocationAttrNameStringPiece);
- if (attr_value != nullptr && attr_value->has_list()) {
- for (const string& class_spec : attr_value->list().s()) {
- StringPiece spec(class_spec);
- if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
- found_spec = true;
- TF_RETURN_IF_ERROR(
- ColocateNodeToGroup(&colocation_group_root, node, spec));
+ if (attr_value != nullptr) {
+ if (attr_value->has_list()) {
+ for (const string& class_spec : attr_value->list().s()) {
+ StringPiece spec(class_spec);
+ if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
+ found_spec = true;
+ TF_RETURN_IF_ERROR(
+ ColocateNodeToGroup(&colocation_group_root, node, spec));
+ }
}
+ } else if (!attr_value->s().empty()) {
+ LOG(ERROR) << "The value for colocation attribute '_class' must be a "
+ "list of strings, not a single string: "
+ << node->DebugString();
}
}
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index cdf9248..6d36011 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1545,6 +1545,7 @@
absl::optional<string> OutputNodeDevice(int output_index) const override {
return absl::nullopt;
}
+ bool ColocateOutputIdentity() const override { return false; }
absl::optional<string> ControlNodeDevice() const override {
return absl::nullopt;
}
@@ -1568,6 +1569,7 @@
absl::optional<string> OutputNodeDevice(int output_index) const override {
return caller_device_;
}
+ bool ColocateOutputIdentity() const override { return false; }
absl::optional<string> ControlNodeDevice() const override {
return caller_device_;
}
@@ -1598,6 +1600,7 @@
absl::optional<string> OutputNodeDevice(int output_index) const override {
return absl::nullopt;
}
+ bool ColocateOutputIdentity() const override { return true; }
absl::optional<string> ControlNodeDevice() const override {
return caller_device_;
}
@@ -1914,6 +1917,12 @@
Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
const absl::optional<string> device = placer->OutputNodeDevice(index);
if (device.has_value()) node->set_requested_device(*device);
+ bool colocate_identity = placer->ColocateOutputIdentity();
+ if (colocate_identity) {
+ node->AddAttr(kColocationAttrName,
+ std::vector<string>{absl::StrCat(kColocationGroupPrefix,
+ input.node->name())});
+ }
return node;
};
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index 9ef70d2..790ffae 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -173,6 +173,9 @@
virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
+ // Returns true if the added output identity node should be colocated with the
+ // corresponding output from the function body.
+ virtual bool ColocateOutputIdentity() const = 0;
virtual absl::optional<string> ControlNodeDevice() const = 0;
virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 625b8d6..7cbbe0c 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -35,6 +35,66 @@
namespace {
+struct NameCounts {  // Per-name usage tallies used to build unique file names.
+ mutex counts_mutex;  // Held while `counts` is read or modified.
+ std::unordered_map<string, int> counts;  // File name -> number of prior uses.
+};
+
+// Maps `name` to a sanitized "<name>.txt", adding "_<k>" after k earlier uses.
+string MakeUniqueFilename(string name) {
+  static NameCounts& instance = *new NameCounts;  // Shared by every call.
+
+  // Remove illegal characters from `name`.
+  for (char& ch : name) {
+    if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') {
+      ch = '_';
+    }
+  }
+
+  int count;
+  {
+    mutex_lock lock(instance.counts_mutex);
+    count = instance.counts[name]++;
+  }
+
+  string filename = name;
+  if (count > 0) {
+    absl::StrAppend(&filename, "_", count);
+  }
+  absl::StrAppend(&filename, ".txt");
+  return filename;
+}
+
+// Sets `*fname` to a unique dump path inside $TF_DUMP_GRAPH_PREFIX.
+Status GetFileName(string base_name, string* fname) {
+  const char* const dump_dir = getenv("TF_DUMP_GRAPH_PREFIX");
+  if (dump_dir == nullptr) {
+    return errors::Internal("Failed to get the directory for ", base_name,
+                            " because dump location is not specified through "
+                            "TF_DUMP_GRAPH_PREFIX environment variable");
+  }
+  // Each call reserves a fresh, sanitized file name for `base_name`.
+  *fname = absl::StrCat(dump_dir, "/", MakeUniqueFilename(base_name));
+  return Status::OK();
+}
+
+// Best-effort dump of `colocation_graph.DebugString()`; failures are logged.
+void DumpColocationGraph(const string& base_name,
+                         const ColocationGraph& colocation_graph) {
+  string path;
+  Status result = GetFileName(base_name, &path);
+  if (result.ok()) {
+    result = WriteStringToFile(Env::Default(), path,
+                               colocation_graph.DebugString());
+  }
+  if (result.ok()) {
+    LOG(INFO) << "Wrote ColocationGraph to " << path;
+    return;
+  }
+  LOG(ERROR) << "Failed to write final colocation graph to file " << path
+             << " with " << result.ToString();
+}
+
// Returns true if the node has no inputs and produces outputs
// that are consumed by a single node.
//
@@ -229,6 +289,7 @@
if (VLOG_IS_ON(3)) {
DumpGraphToFile("placer_output", *graph_, nullptr);
+ DumpColocationGraph("colocation_graph", colocation_graph);
}
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 2fbdba5..799599b 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -782,7 +782,7 @@
NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, cpu1),
- NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu0),
+ NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),
NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, cpu0)},
// Function library.
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index a735081..e64d5fe 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -122,6 +122,14 @@
return n.split(":")[0]
+def _get_colocated_node_name(colocated_node_name):
+  """Decodes a colocation entry and strips a leading `loc:@`, if present."""
+  decoded_name = colocated_node_name.decode("utf-8")
+  prefix = "loc:@"
+  has_prefix = decoded_name.startswith(prefix)
+  return decoded_name[len(prefix):] if has_prefix else decoded_name
+
+
def _extract_graph_summary(graph_def):
"""Extracts useful information from the graph and returns them."""
name_to_input_name = {} # Keyed by the dest node name.
@@ -138,9 +146,8 @@
# Prevent colocated nodes from being lost.
if "_class" in node.attr:
for colocated_node_name in node.attr["_class"].list.s:
- colocated_node_decoded = colocated_node_name.decode("utf-8")
- if colocated_node_decoded.startswith("loc:@"):
- name_to_input_name[n].append(colocated_node_decoded[5:])
+ name_to_input_name[n].append(
+ _get_colocated_node_name(colocated_node_name))
name_to_seq_num[n] = seq
seq += 1
return name_to_input_name, name_to_node, name_to_seq_num
@@ -306,20 +313,24 @@
while (source_op_names and map_name_to_node[source_op_names[0]].op in
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
source_op_name = source_op_names.pop()
+ current_node = map_name_to_node[source_op_name]
if source_op_name not in resource_op_types:
resource_op_types[source_op_name] = node.attr["dtype"]
- source_op_names.append(
- get_input_name(map_name_to_node[source_op_name]))
+ source_op_names.append(get_input_name(current_node))
- if map_name_to_node[source_op_name].op == "Merge":
- merge_resource_name = get_input_name(
- map_name_to_node[source_op_name], index=1)
+ if current_node.op == "Merge":  # Compare the op type, not the node itself.
+ merge_resource_name = get_input_name(current_node, index=1)
if merge_resource_name not in resource_op_types:
resource_op_types[merge_resource_name] = node.attr["dtype"]
source_op_names.append(
get_input_name(map_name_to_node[merge_resource_name]))
+ if "_class" in current_node.attr:
+ for colocated_node_name in current_node.attr["_class"].list.s:
+ source_op_names.append(
+ _get_colocated_node_name(colocated_node_name))
+
for source_node in source_op_names:
if map_name_to_node[source_node].op != "VarHandleOp":
raise ValueError("Cannot find the variable that is an input "