Execute the TPU embedding load function exclusively, before the remaining functions run in parallel.

PiperOrigin-RevId: 448329928
diff --git a/tensorflow/dtensor/cc/constants.h b/tensorflow/dtensor/cc/constants.h
index 94fa4a5..94a6895 100644
--- a/tensorflow/dtensor/cc/constants.h
+++ b/tensorflow/dtensor/cc/constants.h
@@ -127,6 +127,9 @@
 
 // Name of dtensor load embedding function.
 static constexpr char kLoadEmbeddingFn[] = "load_embedding_fn";
+
+// Name of dtensor retrieve embedding function.
+static constexpr char kRetrieveEmbeddingFn[] = "retrieve_embedding_fn";
 }  // namespace dtensor
 }  // namespace tensorflow
 
diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc
index e922382..0afd917 100644
--- a/tensorflow/dtensor/cc/dtensor_device.cc
+++ b/tensorflow/dtensor/cc/dtensor_device.cc
@@ -1441,6 +1441,11 @@
   }
   VLOG(4) << tensorflow::DumpGraphToFile("after_mlir_spmd_lowering", *graph,
                                          flib_def);
+  if (flib_def->Contains(kLoadEmbeddingFn)) {
+    Status s = InsertFunctionForTPUEmbeddingCheckpoint(
+        status, graph.get(), inputs, kLoadEmbeddingFn);
+    RETURN_C_STATUS_IF_NOT_OK(s, status);
+  }
 
   // After MLIR transformations, exactly one StatefulPartitionedCall op is
   // returned for mesh cluster in computation. Identity all functions to execute
@@ -1548,7 +1553,7 @@
   std::map<std::string, const MeshWithParallelDevice*>
       function_name_and_mesh_mapping;
   absl::flat_hash_set<std::string> excluded_fn_names;
-  std::unique_ptr<const TranslatedFunction> epu_fn_ptr;
+  std::unique_ptr<const TranslatedFunction> epu_fn_ptr, load_embedding_ptr;
   for (const TranslatedFunction& function :
        execution_functions->function_list) {
     StatusOr<Mesh> maybe_converted_mesh = function.function_mesh;
@@ -1584,6 +1589,14 @@
       epu_fn_ptr = std::make_unique<const TranslatedFunction>(function);
       excluded_fn_names.insert(function.translated_function_name);
     }
+    if (absl::StartsWith(function.translated_function_name, kLoadEmbeddingFn)) {
+      if (load_embedding_ptr != nullptr) {  // A second match is an error.
+        RETURN_STATUS(status, TF_INTERNAL,
+                      "There are more than one load embedding function.");
+      }
+      load_embedding_ptr = std::make_unique<const TranslatedFunction>(function);
+      excluded_fn_names.insert(function.translated_function_name);
+    }
   }
 
   // Compute the step_id based on the function_mesh_fingerprint and the
@@ -1612,6 +1625,23 @@
         /*status=*/status);
   }
 
+  if (load_embedding_ptr != nullptr) {
+    StatusOr<std::vector<parallel_device::ParallelTensor*>> parallel_inputs =
+        PrepareEmbeddingInputs(inputs);
+    if (!parallel_inputs.ok()) {
+      RETURN_STATUS(status, TF_INTERNAL,
+                    parallel_inputs.status().error_message().c_str());
+    }
+    ExecuteFunctionAndWait(
+        context,
+        /*function_ptr=*/load_embedding_ptr.get(),
+        /*parallel_device_mesh=*/
+        function_name_and_mesh_mapping[load_embedding_ptr
+                                           ->translated_function_name],
+        /*parallel_inputs=*/*parallel_inputs, /*step_id=*/step_id,
+        /*attributes=*/attributes, /*status=*/status);
+  }
+
   // Execute all functions in parallel.
   for (const TranslatedFunction& function :
        execution_functions->function_list) {
@@ -1686,6 +1716,10 @@
       TF_NewStatus(), TF_DeleteStatus);
   for (const TranslatedFunction& function :
        execution_functions->function_list) {
+    // Skip execution for a function when it's excluded.
+    if (excluded_fn_names.contains(function.translated_function_name)) {
+      continue;
+    }
     const Mesh& mesh = function.function_mesh;
     // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
     // object. Ideally we'd just use a fingerprinted int64_t as a unique
diff --git a/tensorflow/dtensor/cc/dtensor_device_util.cc b/tensorflow/dtensor/cc/dtensor_device_util.cc
index 911fd5a..ecb4b59 100644
--- a/tensorflow/dtensor/cc/dtensor_device_util.cc
+++ b/tensorflow/dtensor/cc/dtensor_device_util.cc
@@ -27,12 +27,17 @@
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/dtensor/cc/constants.h"
+#include "tensorflow/dtensor/cc/dstatus.h"
 #include "tensorflow/dtensor/cc/small_constant_optimization.h"
 
 namespace tensorflow {
@@ -788,5 +793,157 @@
   return parallel_inputs;
 }
 
+// Collects the `_Arg` nodes of `graph` that carry a `_tpu_embedding_table_id`
+// attribute, keyed by that table id.
+//
+// For every such node, the entry of `inputs` at position (`index` attr - 1) is
+// tagged with the table id through `UpdateAttrs`.
+//
+// Returns an Internal error when such a node has no `index` attr, when the
+// index does not address an entry of `inputs`, or when `UpdateAttrs` reports a
+// failure through `s`.
+StatusOr<std::map<int64_t, Node*>> GetTPUEmbeddingInputNodes(
+    TF_Status* s, const Graph& graph,
+    const std::vector<TensorWithLayout*>& inputs) {
+  std::map<int64_t, Node*> table_id_node_map;
+  for (Node* node : graph.nodes()) {
+    if (!node->IsArg()) continue;
+
+    const AttrValue* embedding_attr =
+        node->attrs().Find("_tpu_embedding_table_id");
+    if (embedding_attr == nullptr) continue;
+
+    // A missing attr is reported as nullptr, so check before dereferencing.
+    const AttrValue* index_attr = node->attrs().Find("index");
+    if (index_attr == nullptr) {
+      return errors::Internal("Embedding input node ", node->name(),
+                              " has no `index` attribute.");
+    }
+
+    // Offset due to device id.
+    const int64_t input_idx = index_attr->i() - 1;
+    if (input_idx < 0 || input_idx >= static_cast<int64_t>(inputs.size())) {
+      return errors::Internal("Embedding input node ", node->name(),
+                              " has argument index ", index_attr->i(),
+                              ", which is out of range for ", inputs.size(),
+                              " inputs.");
+    }
+
+    const int64_t table_id = embedding_attr->i();
+    EmbeddingResourceAttrs embedding_attrs;
+    embedding_attrs.table_id = table_id;
+    inputs[input_idx]->UpdateAttrs(embedding_attrs, s);
+    if (!s->status.ok()) {
+      return errors::Internal(
+          "Failed to set embedding resource attrs. \n Got error: ",
+          s->status.error_message());
+    }
+    table_id_node_map.insert({table_id, node});
+  }
+  return table_id_node_map;
+}
+
+// Verifies that all resource-handle inputs whose attrs are set share one mesh
+// and returns that mesh's string form (empty when no input qualifies).
+// Returns an Internal error naming both meshes on the first mismatch.
+StatusOr<std::string> ValidateResourceMeshConsistency(
+    const std::vector<TensorWithLayout*>& inputs) {
+  std::string mesh_str;
+  for (TensorWithLayout* inp : inputs) {
+    if (inp->tensor_type() != kResource) continue;
+
+    auto* resource = dynamic_cast<ResourceHandleWithLayout*>(inp);
+    if (!resource || !resource->attrs().has_value()) continue;
+    const std::string input_mesh_str = inp->layout().mesh().ToString();
+    if (mesh_str.empty()) {
+      // The first qualifying input defines the reference mesh.
+      mesh_str = input_mesh_str;
+    } else if (mesh_str != input_mesh_str) {
+      return errors::Internal(absl::StrCat(
+          "All inputs of embedding resource must be on same mesh. but get : ",
+          mesh_str, " != ", input_mesh_str));
+    }
+  }
+  VLOG(1) << "Resource input mesh is : " << mesh_str;
+  return mesh_str;
+}
+
+// Adds a `StatefulPartitionedCall` node named `checkpoint_fn_name` to `graph`.
+// The node calls the function of the same name and takes the embedding table
+// `_Arg` nodes as resource inputs, ordered by table id.
+//
+// `checkpoint_fn_name` must be kLoadEmbeddingFn or kRetrieveEmbeddingFn;
+// anything else yields InvalidArgument. Table ids must cover [0, num_tables)
+// without gaps, otherwise an Internal error is returned. `status` is handed
+// to GetTPUEmbeddingInputNodes while `inputs` are tagged with table ids.
+Status InsertFunctionForTPUEmbeddingCheckpoint(
+    TF_Status* status, Graph* graph,
+    const std::vector<TensorWithLayout*>& inputs,
+    const std::string& checkpoint_fn_name) {
+  if (checkpoint_fn_name != kLoadEmbeddingFn &&
+      checkpoint_fn_name != kRetrieveEmbeddingFn) {
+    return errors::InvalidArgument(absl::StrCat(
+        "Found wrong function name: ", checkpoint_fn_name,
+        " \n expects : ", kLoadEmbeddingFn, " or ", kRetrieveEmbeddingFn));
+  }
+
+  StatusOr<std::map<int64_t, Node*>> table_id_node_map =
+      GetTPUEmbeddingInputNodes(status, *graph, inputs);
+  if (!table_id_node_map.ok()) {
+    return errors::Internal(table_id_node_map.status().error_message());
+  }
+
+  // The mesh string is dereferenced below, so a failed validation must be
+  // surfaced here instead of reading an errored StatusOr.
+  StatusOr<std::string> mesh_str = ValidateResourceMeshConsistency(inputs);
+  if (!mesh_str.ok()) {
+    return errors::Internal(mesh_str.status().error_message());
+  }
+
+  const int64_t num_tables = table_id_node_map->size();
+  NodeDef func_node_def;
+  std::vector<NodeDefBuilder::NodeOut> func_inputs;
+  std::vector<DataType> input_types, output_types;
+
+  func_inputs.reserve(num_tables);
+  input_types.reserve(num_tables);
+
+  for (int i = 0; i < num_tables; ++i) {
+    auto node_ptr = table_id_node_map->find(i);
+    if (node_ptr == table_id_node_map->end()) {
+      return errors::Internal(
+          absl::StrCat("Embedding table id ", i, " is not found."));
+    }
+    const std::string& node_name = node_ptr->second->name();
+    // NOTE(review): the output slot is `i` here while the edge added below
+    // uses slot 0 of the same node -- confirm which one is intended.
+    func_inputs.push_back({node_name, i, DT_RESOURCE});
+    input_types.push_back(DT_RESOURCE);
+  }
+
+  AttrValue mesh_attr;
+  *mesh_attr.mutable_s() = *mesh_str;
+  NameAttrList func_attr;
+  func_attr.set_name(checkpoint_fn_name);
+  TF_RETURN_IF_ERROR(
+      NodeDefBuilder(checkpoint_fn_name, "StatefulPartitionedCall")
+          .Attr("Tin", input_types)
+          .Attr("Tout", output_types)
+          .Attr("f", func_attr)
+          .Attr(kMeshAttr, mesh_attr)
+          .Attr("config", mesh_attr)
+          .Input(func_inputs)
+          .Finalize(&func_node_def, true));
+
+  TF_ASSIGN_OR_RETURN(Node * func_node, graph->AddNode(func_node_def));
+  // Wire output 0 of each table's `_Arg` node to the input matching its id.
+  for (int i = 0; i < num_tables; ++i) {
+    Node* node = table_id_node_map->find(i)->second;
+    graph->AddEdge(node, 0, func_node, i);
+  }
+
+  return Status::OK();
+}
+
 }  // namespace dtensor
 }  // namespace tensorflow
diff --git a/tensorflow/dtensor/cc/dtensor_device_util.h b/tensorflow/dtensor/cc/dtensor_device_util.h
index b36017d..7063425 100644
--- a/tensorflow/dtensor/cc/dtensor_device_util.h
+++ b/tensorflow/dtensor/cc/dtensor_device_util.h
@@ -548,6 +548,11 @@
 StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
     const std::vector<TensorWithLayout*>& inputs);
 
+Status InsertFunctionForTPUEmbeddingCheckpoint(
+    TF_Status* status, Graph* graph,
+    const std::vector<TensorWithLayout*>& inputs,
+    const std::string& checkpoint_fn_name);
+
 }  // namespace dtensor
 }  // namespace tensorflow