[tf.data] Several optimizations for the graph hashing code.
1. Avoid copying the `GraphDef` each time a `GraphHasher` is created. The graph always outlives the hasher, so an unowned pointer is acceptable here. Should save O(#nodes) copies.
2. Use the same `FunctionLibraryDefinition` for all hashing. Previously we were converting it to and from a submessage of `GraphDef`, which led to a lot of copies, dynamic allocations, etc. Instead, we either build it once for the root node, or (ideally) the user passes in an already-constructed library, then we use that for all nodes. Since the function library typically has O(1) functions per node, this saves O(#nodes^2) copies.
PiperOrigin-RevId: 301307984
Change-Id: I6e28ffd1df908840e946e43d3be3dc2f5106eb55
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index 1502ff1..ee135b3 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -153,10 +153,11 @@
// https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali
class GraphHasher {
public:
- explicit GraphHasher(const GraphDef& graph_def, const NodeDef* root_node)
- : graph_def_(graph_def),
- root_node_(root_node),
- flib_def_(OpRegistry::Global(), graph_def.library()) {}
+ // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or
+ // `flib_def`.
+ explicit GraphHasher(const GraphDef* graph_def, const NodeDef* root_node,
+ const FunctionLibraryDefinition* flib_def)
+ : graph_def_(graph_def), root_node_(root_node), flib_def_(flib_def) {}
Status ComputeHash(uint64* hash) {
TF_RETURN_IF_ERROR(Init());
@@ -189,7 +190,7 @@
TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name,
&suffix, &is_control_input));
const NodeDef* input_node;
- TF_RETURN_IF_ERROR(FindNode(graph_def_, node_name, &input_node));
+ TF_RETURN_IF_ERROR(FindNode(*graph_def_, node_name, &input_node));
// If we've already seen this node before, skip it and don't add it to
// the queue.
@@ -308,20 +309,19 @@
}
Status HashFunction(const NameAttrList& func, uint64* hash) {
- const FunctionDef* fdef = flib_def_.Find(func.name());
+ const FunctionDef* fdef = flib_def_->Find(func.name());
// Convert to a GraphDef.
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
- &flib_def_, &fbody));
+ flib_def_, &fbody));
GraphDef graph_def = fbody->graph->ToGraphDefDebug();
- graph_def.mutable_library()->MergeFrom(flib_def_.ToProto());
// For each return node, we create a new GraphHasher to compute a hash.
// We then combine these hashes to produce the hash ordered.
uint64 ret_nodes_hash = 0;
for (const auto& ret_node : fbody->ret_nodes) {
- GraphHasher ret_node_hasher(graph_def, &ret_node->def());
+ GraphHasher ret_node_hasher(&graph_def, &ret_node->def(), flib_def_);
uint64 ret_node_hash = 0;
TF_RETURN_IF_ERROR(ret_node_hasher.ComputeHash(&ret_node_hash));
ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash);
@@ -359,9 +359,9 @@
}
};
- const GraphDef graph_def_;
- const NodeDef* root_node_;
- const FunctionLibraryDefinition flib_def_;
+ const GraphDef* const graph_def_; // Not owned.
+ const NodeDef* const root_node_; // Not owned.
+ const FunctionLibraryDefinition* const flib_def_; // Not owned.
// Edges that need to be pruned as their presence will cause cycles.
absl::flat_hash_set<uint64> cycle_forming_edges_;
absl::flat_hash_map<const NodeDef*, NodeRep> nodes_;
@@ -397,7 +397,14 @@
}
Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) {
- GraphHasher graph_hasher(graph, &node);
+ const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
+ graph.library());
+ return HashNode(graph, node, flib_def, hash);
+}
+
+Status HashNode(const GraphDef& graph, const NodeDef& node,
+ const FunctionLibraryDefinition& flib_def, uint64* hash) {
+ GraphHasher graph_hasher(&graph, &node, &flib_def);
return graph_hasher.ComputeHash(hash);
}
@@ -414,7 +421,9 @@
return errors::Internal("Cannot find sink node for dataset graph.");
}
- GraphHasher graph_hasher(graph_def, sink);
+ const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
+ graph_def.library());
+ GraphHasher graph_hasher(&graph_def, sink, &flib_def);
TF_RETURN_IF_ERROR(graph_hasher.ComputeHash(hash));
return Status::OK();
}
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 7c0857a..bedd5fa 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -135,6 +135,8 @@
// NOTE: There is currently no guarantee that the hash of a subgraph will stay
// the same between TensorFlow builds.
Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash);
+Status HashNode(const GraphDef& graph, const NodeDef& node,
+ const FunctionLibraryDefinition& flib_def, uint64* hash);
// Returns a stable hash of the given tensor.
//
diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc
index 3717016..609c402 100644
--- a/tensorflow/core/kernels/data/rewrite_utils.cc
+++ b/tensorflow/core/kernels/data/rewrite_utils.cc
@@ -195,8 +195,10 @@
if (record_fingerprint) {
(*ctx->runner())([graph_def = std::move(graph_def),
+ lib_def = lib_def.release(),
input_list = std::move(input_list),
output_node = std::move(output_node)]() {
+ std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
const NodeDef* node_def = nullptr;
for (const auto& node : graph_def.node()) {
if (node.name() == output_node) {
@@ -209,7 +211,7 @@
return;
}
uint64 hash = 0;
- Status s = HashNode(graph_def, *node_def, &hash);
+ Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
if (!s.ok()) {
VLOG(3) << "Failed to hash graph: " << s.ToString();
return;