[tf.data] Several optimizations for the graph hashing code.

1. Avoid copying the `GraphDef` each time a `GraphHasher` is created. The graph always outlives the hasher, so an unowned pointer is acceptable here. Should save O(#nodes) copies.
2. Use the same `FunctionLibraryDefinition` for all hashing. Previously we were converting it to and from a submessage of `GraphDef`, which led to a lot of copies, dynamic allocations, etc. Instead, we either build it once for the root node, or (ideally) the user passes in an already-constructed library, then we use that for all nodes. Since the function library typically has O(1) functions per node, this saves O(#nodes^2) copies.

PiperOrigin-RevId: 301307984
Change-Id: I6e28ffd1df908840e946e43d3be3dc2f5106eb55
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index 1502ff1..ee135b3 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -153,10 +153,11 @@
 // https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali
 class GraphHasher {
  public:
-  explicit GraphHasher(const GraphDef& graph_def, const NodeDef* root_node)
-      : graph_def_(graph_def),
-        root_node_(root_node),
-        flib_def_(OpRegistry::Global(), graph_def.library()) {}
+  // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or
+  // `flib_def`.
+  explicit GraphHasher(const GraphDef* graph_def, const NodeDef* root_node,
+                       const FunctionLibraryDefinition* flib_def)
+      : graph_def_(graph_def), root_node_(root_node), flib_def_(flib_def) {}
 
   Status ComputeHash(uint64* hash) {
     TF_RETURN_IF_ERROR(Init());
@@ -189,7 +190,7 @@
         TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name,
                                               &suffix, &is_control_input));
         const NodeDef* input_node;
-        TF_RETURN_IF_ERROR(FindNode(graph_def_, node_name, &input_node));
+        TF_RETURN_IF_ERROR(FindNode(*graph_def_, node_name, &input_node));
 
         // If we've already seen this node before, skip it and don't add it to
         // the queue.
@@ -308,20 +309,19 @@
   }
 
   Status HashFunction(const NameAttrList& func, uint64* hash) {
-    const FunctionDef* fdef = flib_def_.Find(func.name());
+    const FunctionDef* fdef = flib_def_->Find(func.name());
 
     // Convert to a GraphDef.
     std::unique_ptr<FunctionBody> fbody;
     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
-                                               &flib_def_, &fbody));
+                                               flib_def_, &fbody));
     GraphDef graph_def = fbody->graph->ToGraphDefDebug();
-    graph_def.mutable_library()->MergeFrom(flib_def_.ToProto());
 
     // For each return node, we create a new GraphHasher to compute a hash.
     // We then combine these hashes to produce the hash ordered.
     uint64 ret_nodes_hash = 0;
     for (const auto& ret_node : fbody->ret_nodes) {
-      GraphHasher ret_node_hasher(graph_def, &ret_node->def());
+      GraphHasher ret_node_hasher(&graph_def, &ret_node->def(), flib_def_);
       uint64 ret_node_hash = 0;
       TF_RETURN_IF_ERROR(ret_node_hasher.ComputeHash(&ret_node_hash));
       ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash);
@@ -359,9 +359,9 @@
     }
   };
 
-  const GraphDef graph_def_;
-  const NodeDef* root_node_;
-  const FunctionLibraryDefinition flib_def_;
+  const GraphDef* const graph_def_;                  // Not owned.
+  const NodeDef* const root_node_;                   // Not owned.
+  const FunctionLibraryDefinition* const flib_def_;  // Not owned.
   // Edges that need to be pruned as their presence will cause cycles.
   absl::flat_hash_set<uint64> cycle_forming_edges_;
   absl::flat_hash_map<const NodeDef*, NodeRep> nodes_;
@@ -397,7 +397,14 @@
 }
 
 Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) {
-  GraphHasher graph_hasher(graph, &node);
+  const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
+                                           graph.library());
+  return HashNode(graph, node, flib_def, hash);
+}
+
+Status HashNode(const GraphDef& graph, const NodeDef& node,
+                const FunctionLibraryDefinition& flib_def, uint64* hash) {
+  GraphHasher graph_hasher(&graph, &node, &flib_def);
   return graph_hasher.ComputeHash(hash);
 }
 
@@ -414,7 +421,9 @@
     return errors::Internal("Cannot find sink node for dataset graph.");
   }
 
-  GraphHasher graph_hasher(graph_def, sink);
+  const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
+                                           graph_def.library());
+  GraphHasher graph_hasher(&graph_def, sink, &flib_def);
   TF_RETURN_IF_ERROR(graph_hasher.ComputeHash(hash));
   return Status::OK();
 }
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 7c0857a..bedd5fa 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -135,6 +135,8 @@
 // NOTE: There is currently no guarantee that the hash of a subgraph will stay
 // the same between TensorFlow builds.
 Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash);
+Status HashNode(const GraphDef& graph, const NodeDef& node,
+                const FunctionLibraryDefinition& flib_def, uint64* hash);
 
 // Returns a stable hash of the given tensor.
 //
diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc
index 3717016..609c402 100644
--- a/tensorflow/core/kernels/data/rewrite_utils.cc
+++ b/tensorflow/core/kernels/data/rewrite_utils.cc
@@ -195,8 +195,10 @@
 
   if (record_fingerprint) {
     (*ctx->runner())([graph_def = std::move(graph_def),
+                      lib_def = lib_def.release(),
                       input_list = std::move(input_list),
                       output_node = std::move(output_node)]() {
+      std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
       const NodeDef* node_def = nullptr;
       for (const auto& node : graph_def.node()) {
         if (node.name() == output_node) {
@@ -209,7 +211,7 @@
         return;
       }
       uint64 hash = 0;
-      Status s = HashNode(graph_def, *node_def, &hash);
+      Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
       if (!s.ok()) {
         VLOG(3) << "Failed to hash graph: " << s.ToString();
         return;