Do not create duplicate variables on the TPU

Provide an option to create only one TPU variable for each CPU
variable that must be mirrored to the TPU. By default, a TPU
variable is created for each resource-variable arg to the
TPUPartitionedCall, so if multiple args are backed by the same
CPU variable, that variable is mirrored to the TPU multiple
times.

PiperOrigin-RevId: 378272842
Change-Id: I0fb80cb7afbb7a4dc841bd10530b212692d4c937
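
At its core, the change keys mirrored TPU variables on a fingerprint of the
backing resource handle (container + name) rather than on the arg position,
and reuses the first mirror whenever the same fingerprint comes up again. A
minimal sketch of that idea in plain C++, where std::hash stands in for TF's
Fingerprint64 and every other name is illustrative rather than a TF API:

  #include <cstdint>
  #include <functional>
  #include <iostream>
  #include <string>
  #include <unordered_map>
  #include <vector>

  // Stand-in for a resource handle: the (container, name) pair identifies
  // the backing CPU variable.
  struct HandleKey {
    std::string container;
    std::string name;
  };

  uint64_t HandleFingerprint(const HandleKey& h) {
    return std::hash<std::string>{}(h.container + "/" + h.name);
  }

  int main() {
    // Three resource args; args 0 and 2 are backed by the same variable.
    std::vector<HandleKey> args = {{"c", "v0"}, {"c", "v1"}, {"c", "v0"}};

    // Fingerprint -> index of the arg whose mirror we keep. The fingerprint
    // no longer mixes in the arg index i, so duplicate args collide here.
    std::unordered_map<uint64_t, int> tpu_variables;
    for (int i = 0; i < static_cast<int>(args.size()); ++i) {
      const uint64_t fp = HandleFingerprint(args[i]);
      auto [it, inserted] = tpu_variables.emplace(fp, i);
      if (inserted) {
        std::cout << "arg " << i << ": mirror " << args[i].name << " to TPU\n";
      } else {
        std::cout << "arg " << i << ": reuse mirror of arg " << it->second
                  << "\n";
      }
    }
  }

Running it prints one "mirror" line per distinct variable and a "reuse" line
for the duplicated arg.
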
diff --git a/tensorflow/core/tpu/kernels/tpu_functional_ops.cc b/tensorflow/core/tpu/kernels/tpu_functional_ops.cc
index 668320e..1c562db 100644
--- a/tensorflow/core/tpu/kernels/tpu_functional_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_functional_ops.cc
@@ -1580,6 +1580,18 @@
 Status TPUPartitionedCallOp::ReplaceResourceArgsWithVarHandleOps(
     Graph* graph, OpKernelContext* ctx, int device_ordinal,
     int num_cores_per_replica, bool enable_spmd_xla_partitioning) {
+  // Variable deduplication is currently not supported for XLA SPMD
+  // partitioning; it may be supported in the future.
+  const bool enable_variable_deduplication =
+      runtime_params_.enable_variable_deduplication;
+  if (enable_spmd_xla_partitioning && enable_variable_deduplication) {
+    // If enable_spmd_xla_partitioning is true, the user has set the
+    // enable_auto_xla_input_sharding flag. Return an error, since only one of
+    // the two flags can be set safely.
+    return errors::InvalidArgument(
+        "The following flags are incompatible: enable_auto_xla_input_sharding "
+        "and enable_variable_deduplication. Only enable one of the flags.");
+  }
   std::vector<Node*> tpu_resource_args;
   std::vector<int> arg_indices;
   absl::flat_hash_map<const Node*, xla::OpSharding> variable_to_xla_sharding;
@@ -1601,6 +1613,11 @@
   }
 
   VLOG(3) << "tpu_resource_args.size(): " << tpu_resource_args.size();
+  // Map each ResourceHandle (by fingerprint) to its mirrored variable node.
+  // When several _Arg nodes are backed by the same ResourceHandle, they refer
+  // to the same underlying resource, so only one variable node needs to be
+  // mirrored to the TPU for that resource.
+  absl::flat_hash_map<uint64, Node*> tpu_variables;
   for (int i = 0; i < tpu_resource_args.size(); i++) {
     Node* node = tpu_resource_args[i];
     ResourceHandle handle = HandleFromInput(ctx, arg_indices[i]);
@@ -1616,69 +1633,91 @@
     // Only respect graph's placement when model parallelism enabled.
     if (num_cores_per_replica > 1) device_ordinal = var_info.device_ordinal;
 
-    uint64 fp =
-        Fingerprint64(strings::StrCat(handle.container(), handle.name(), i));
-    NodeDef ndef;
-    ndef.set_name(strings::StrCat(handle.name(), fp));
-    ndef.set_op(kVarHandleOp);
-    if (num_cores_per_replica > 1) {
-      ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
+    const uint64 handle_fp =
+        Fingerprint64(strings::StrCat(handle.container(), handle.name()));
+    if (enable_variable_deduplication && tpu_variables.contains(handle_fp) &&
+        num_cores_per_replica == 1) {
+      Node* tpu_variable = tpu_variables.at(handle_fp);
+      std::vector<Node*> dst_nodes;
+      std::vector<int> src_indices;
+      std::vector<int> dst_indices;
+      for (const Edge* edge : node->out_edges()) {
+        dst_nodes.push_back(edge->dst());
+        src_indices.push_back(edge->src_output());
+        dst_indices.push_back(edge->dst_input());
+      }
+      graph->RemoveNode(node);
+      for (int i = 0; i < dst_nodes.size(); i++) {
+        graph->AddEdge(tpu_variable, src_indices[i], dst_nodes[i],
+                       dst_indices[i]);
+      }
     } else {
-      // Assign this new VarHandleOp to TPU:0 so the partitioner only partiitons
-      // the graph into two subgraphs, one on CPU and one on TPU. The actual
-      // device ordinal on which this VarHandleOp runs is assigned after
-      // partitioning (in SetDeviceOrdinal).
-      ndef.set_device(
-          strings::StrCat(kTPUDeviceNamePrefix, kTPUDefaultDeviceOrdinal));
-    }
+      uint64 fp =
+          Fingerprint64(strings::StrCat(handle.container(), handle.name(), i));
+      NodeDef ndef;
+      ndef.set_name(strings::StrCat(handle.name(), fp));
+      ndef.set_op(kVarHandleOp);
+      if (num_cores_per_replica > 1) {
+        ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
+      } else {
+        // Assign this new VarHandleOp to TPU:0 so the partitioner only
+        // partitions the graph into two subgraphs, one on CPU and one on TPU.
+        // The actual device ordinal on which this VarHandleOp runs is assigned
+        // after partitioning (in SetDeviceOrdinal).
+        ndef.set_device(
+            strings::StrCat(kTPUDeviceNamePrefix, kTPUDefaultDeviceOrdinal));
+      }
 
-    // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
-    // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
-    // shared_name of the variable on CPU and "x" is the rewritten device
-    // ordinal.
-    const string sname =
-        strings::StrCat(handle.name(), "_tpu_", device_ordinal);
-    AddNodeAttr("shared_name", sname, &ndef);
-    const string cname = ctx->resource_manager()->default_container();
-    AddNodeAttr("container", cname, &ndef);
-    core::RefCountPtr<Var> var;
-    TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
-    AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
-    TensorShapeProto proto;
-    var->tensor()->shape().AsProto(&proto);
-    AddNodeAttr("shape", proto, &ndef);
-    Status status;
-    Node* new_node = graph->AddNode(ndef, &status);
-    TF_RETURN_IF_ERROR(status);
-    std::vector<const Edge*> in_edges(node->in_edges().begin(),
-                                      node->in_edges().end());
-    for (const Edge* edge : in_edges) {
-      graph->AddEdge(edge->src(), edge->src_output(), new_node,
-                     edge->dst_input());
-    }
-    std::vector<Node*> dst_nodes;
-    std::vector<int> src_indices;
-    std::vector<int> dst_indices;
-    for (const Edge* edge : node->out_edges()) {
-      dst_nodes.push_back(edge->dst());
-      src_indices.push_back(edge->src_output());
-      dst_indices.push_back(edge->dst_input());
-    }
-    graph->RemoveNode(node);
-    for (int i = 0; i < dst_nodes.size(); i++) {
-      graph->AddEdge(new_node, src_indices[i], dst_nodes[i], dst_indices[i]);
-    }
-    // Don't initialize variables on TPU if it is done for the ordinal already.
-    if (seen_ordinals_.contains(device_ordinal)) continue;
+      // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
+      // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
+      // shared_name of the variable on CPU and "x" is the rewritten device
+      // ordinal.
+      const string sname =
+          strings::StrCat(handle.name(), "_tpu_", device_ordinal);
+      AddNodeAttr("shared_name", sname, &ndef);
+      const string cname = ctx->resource_manager()->default_container();
+      AddNodeAttr("container", cname, &ndef);
+      core::RefCountPtr<Var> var;
+      TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
+      AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
+      TensorShapeProto proto;
+      var->tensor()->shape().AsProto(&proto);
+      AddNodeAttr("shape", proto, &ndef);
+      Status status;
+      Node* new_node = graph->AddNode(ndef, &status);
+      TF_RETURN_IF_ERROR(status);
+      std::vector<const Edge*> in_edges(node->in_edges().begin(),
+                                        node->in_edges().end());
+      for (const Edge* edge : in_edges) {
+        graph->AddEdge(edge->src(), edge->src_output(), new_node,
+                       edge->dst_input());
+      }
+      std::vector<Node*> dst_nodes;
+      std::vector<int> src_indices;
+      std::vector<int> dst_indices;
+      for (const Edge* edge : node->out_edges()) {
+        dst_nodes.push_back(edge->dst());
+        src_indices.push_back(edge->src_output());
+        dst_indices.push_back(edge->dst_input());
+      }
+      graph->RemoveNode(node);
+      for (int i = 0; i < dst_nodes.size(); i++) {
+        graph->AddEdge(new_node, src_indices[i], dst_nodes[i], dst_indices[i]);
+      }
+      // Don't initialize variables on the TPU if that has already been done
+      // for this ordinal.
+      if (seen_ordinals_.contains(device_ordinal)) continue;
 
-    Device* d;
-    TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
-        strings::StrCat(kTPUDeviceNamePrefix, device_ordinal), &d));
-    Var* tpu_var;
-    status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
-    if (!status.ok()) {
-      TF_RETURN_IF_ERROR(InitializeVarOnTPU(ctx, var, &ndef, device_ordinal,
-                                            var_info.fast_mem));
+      Device* d;
+      TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
+          strings::StrCat(kTPUDeviceNamePrefix, device_ordinal), &d));
+      Var* tpu_var;
+      status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
+      if (!status.ok()) {
+        TF_RETURN_IF_ERROR(InitializeVarOnTPU(ctx, var, &ndef, device_ordinal,
+                                              var_info.fast_mem));
+      }
+      tpu_variables[handle_fp] = new_node;
     }
   }
 
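The dedup branch above never builds a second VarHandleOp; it records the
endpoints of every edge leaving the duplicate _Arg node, removes that node,
and re-adds the edges with the canonical TPU variable node as the source. A
toy, self-contained model of that rewiring (plain C++; Node and Edge are
illustrative stand-ins, not TF's graph classes):

  #include <iostream>
  #include <string>
  #include <utility>
  #include <vector>

  struct Node { std::string name; };
  struct Edge { Node* src; int src_output; Node* dst; int dst_input; };

  // Mimics the patch's dedup branch: collect the endpoints of every edge
  // leaving `dup` (the patch drops the whole node), then re-add those edges
  // with `canonical` as the source.
  void ReplaceWithCanonical(std::vector<Edge>& graph, Node* dup,
                            Node* canonical) {
    std::vector<Edge> rewired;
    for (const Edge& e : graph) {
      if (e.src == dup) {
        rewired.push_back({canonical, e.src_output, e.dst, e.dst_input});
      } else {
        rewired.push_back(e);
      }
    }
    graph = std::move(rewired);
  }

  int main() {
    Node canonical{"v0_tpu"}, dup{"arg2_v0"}, use{"tpu_compute"};
    std::vector<Edge> g = {{&canonical, 0, &use, 0}, {&dup, 0, &use, 1}};
    ReplaceWithCanonical(g, &dup, &canonical);
    for (const Edge& e : g)
      std::cout << e.src->name << ":" << e.src_output << " -> "
                << e.dst->name << ":" << e.dst_input << "\n";
  }
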
diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h
index d643464..1884286 100644
--- a/tensorflow/core/tpu/tpu_ops_c_api.h
+++ b/tensorflow/core/tpu/tpu_ops_c_api.h
@@ -106,6 +106,9 @@
   // enable_auto_xla_input_sharding is set to true. Negative numbers are
  // allowed and refer to dimensions starting from the end.
   int32_t auto_xla_input_sharding_dim;
+
+  // If true, only create one variable on the TPU for each variable on the CPU.
+  bool enable_variable_deduplication;
 };
 
 // Compiles Mlir or TF function computation by lowering into HLO IR and returns
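
For callers, the new knob is a single boolean on the runtime-params struct
this hunk extends. The struct's real name sits outside the hunk, so
TpuRuntimeParams below is a hypothetical stand-in, as is the
enable_auto_xla_input_sharding field implied by the surrounding comment:

  #include <cstdint>

  // Hypothetical stand-in for the params struct extended by this hunk;
  // only fields suggested by the diff are reproduced.
  struct TpuRuntimeParams {
    int32_t auto_xla_input_sharding_dim = 0;
    bool enable_auto_xla_input_sharding = false;
    bool enable_variable_deduplication = false;
  };

  int main() {
    TpuRuntimeParams params;
    // Opt in: create at most one TPU variable per distinct CPU variable.
    params.enable_variable_deduplication = true;
    // Keep auto XLA input sharding off: the check added in
    // tpu_functional_ops.cc rejects the combination with InvalidArgument.
    params.enable_auto_xla_input_sharding = false;
    return 0;
  }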