Execute the TPU embedding load function exclusively.
PiperOrigin-RevId: 448329928
diff --git a/tensorflow/dtensor/cc/constants.h b/tensorflow/dtensor/cc/constants.h
index 94fa4a5..94a6895 100644
--- a/tensorflow/dtensor/cc/constants.h
+++ b/tensorflow/dtensor/cc/constants.h
@@ -127,6 +127,9 @@
// Name of dtensor load embedding function.
static constexpr char kLoadEmbeddingFn[] = "load_embedding_fn";
+
+// Name of dtensor retrieve embedding function.
+static constexpr char kRetrieveEmbeddingFn[] = "retrieve_embedding_fn";
} // namespace dtensor
} // namespace tensorflow
diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc
index e922382..0afd917 100644
--- a/tensorflow/dtensor/cc/dtensor_device.cc
+++ b/tensorflow/dtensor/cc/dtensor_device.cc
@@ -1441,6 +1441,11 @@
}
VLOG(4) << tensorflow::DumpGraphToFile("after_mlir_spmd_lowering", *graph,
flib_def);
+ if (flib_def->Contains(kLoadEmbeddingFn)) {
+ Status s = InsertFunctionForTPUEmbeddingCheckpoint(
+ status, graph.get(), inputs, kLoadEmbeddingFn);
+ RETURN_C_STATUS_IF_NOT_OK(s, status);
+ }
// After MLIR transformations, exactly one StatefulPartitionedCall op is
// returned for mesh cluster in computation. Identity all functions to execute
@@ -1548,7 +1553,7 @@
std::map<std::string, const MeshWithParallelDevice*>
function_name_and_mesh_mapping;
absl::flat_hash_set<std::string> excluded_fn_names;
- std::unique_ptr<const TranslatedFunction> epu_fn_ptr;
+ std::unique_ptr<const TranslatedFunction> epu_fn_ptr, load_embedding_ptr;
for (const TranslatedFunction& function :
execution_functions->function_list) {
StatusOr<Mesh> maybe_converted_mesh = function.function_mesh;
@@ -1584,6 +1589,14 @@
epu_fn_ptr = std::make_unique<const TranslatedFunction>(function);
excluded_fn_names.insert(function.translated_function_name);
}
+ if (absl::StartsWith(function.translated_function_name, kLoadEmbeddingFn)) {
+ if (load_embedding_ptr != nullptr) {
+        RETURN_STATUS(status, TF_INTERNAL,
+                      "There is more than one load embedding function defined.");
+ }
+ load_embedding_ptr = std::make_unique<const TranslatedFunction>(function);
+ excluded_fn_names.insert(function.translated_function_name);
+ }
}
// Compute the step_id based on the function_mesh_fingerprint and the
@@ -1612,6 +1625,23 @@
/*status=*/status);
}
+ if (load_embedding_ptr != nullptr) {
+ StatusOr<std::vector<parallel_device::ParallelTensor*>> parallel_inputs =
+ PrepareEmbeddingInputs(inputs);
+ if (!parallel_inputs.ok()) {
+ RETURN_STATUS(status, TF_INTERNAL,
+ parallel_inputs.status().error_message().c_str());
+ }
+ ExecuteFunctionAndWait(
+ context,
+ /*function_ptr=*/load_embedding_ptr.get(),
+ /*parallel_device_mesh=*/
+ function_name_and_mesh_mapping[load_embedding_ptr
+ ->translated_function_name],
+ /*parallel_inputs=*/*parallel_inputs, /*step_id=*/step_id,
+ /*attributes=*/attributes, /*status=*/status);
+ }
+
// Execute all functions in parallel.
for (const TranslatedFunction& function :
execution_functions->function_list) {
@@ -1686,6 +1716,10 @@
TF_NewStatus(), TF_DeleteStatus);
for (const TranslatedFunction& function :
execution_functions->function_list) {
+ // Skip execution for a function when it's excluded.
+ if (excluded_fn_names.contains(function.translated_function_name)) {
+ continue;
+ }
const Mesh& mesh = function.function_mesh;
// TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
// object. Ideally we'd just use a fingerprinted int64_t as a unique
diff --git a/tensorflow/dtensor/cc/dtensor_device_util.cc b/tensorflow/dtensor/cc/dtensor_device_util.cc
index 911fd5a..ecb4b59 100644
--- a/tensorflow/dtensor/cc/dtensor_device_util.cc
+++ b/tensorflow/dtensor/cc/dtensor_device_util.cc
@@ -27,12 +27,17 @@
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/dtensor/cc/constants.h"
+#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/small_constant_optimization.h"
namespace tensorflow {
@@ -788,5 +793,115 @@
return parallel_inputs;
}
+StatusOr<std::map<int64_t, Node*>> GetTPUEmbeddingInputNodes(
+ TF_Status* s, const Graph& graph,
+ const std::vector<TensorWithLayout*>& inputs) {
+ std::map<int64_t, Node*> table_id_node_map;
+ for (Node* node : graph.nodes()) {
+ if (!node->IsArg()) continue;
+
+    const int64_t arg_id = node->attrs().Find("index")->i();
+ const AttrValue* embedding_attr =
+ node->attrs().Find("_tpu_embedding_table_id");
+
+ if (embedding_attr == nullptr) continue;
+
+    const int64_t table_id = embedding_attr->i();
+    EmbeddingResourceAttrs embedding_attrs;
+    embedding_attrs.table_id = table_id;
+    // `inputs` excludes the leading device-id argument, hence `arg_id - 1`.
+    inputs[arg_id - 1]->UpdateAttrs(embedding_attrs, s);
+ if (!s->status.ok()) {
+ return errors::Internal(
+ "Failed to set embedding resource attrs. \n Got error: ",
+ s->status.error_message());
+ }
+ table_id_node_map.insert({table_id, node});
+ }
+ return table_id_node_map;
+}
+
+StatusOr<std::string> ValidateResourceMeshConsistency(
+ const std::vector<TensorWithLayout*>& inputs) {
+ std::string mesh_str;
+ for (TensorWithLayout* inp : inputs) {
+ if (inp->tensor_type() != kResource) continue;
+
+ auto* resource = dynamic_cast<ResourceHandleWithLayout*>(inp);
+ if (!resource || !resource->attrs().has_value()) continue;
+ const std::string& input_mesh_str = inp->layout().mesh().ToString();
+ if (mesh_str.empty()) {
+ mesh_str = input_mesh_str;
+ } else if (mesh_str != input_mesh_str) {
+      return errors::Internal(absl::StrCat(
+          "All inputs of embedding resource must be on the same mesh, but got: ",
+          mesh_str, " != ", input_mesh_str));
+ }
+ }
+ VLOG(1) << "Resource input mesh is : " << mesh_str;
+ return mesh_str;
+}
+
+Status InsertFunctionForTPUEmbeddingCheckpoint(
+ TF_Status* status, Graph* graph,
+ const std::vector<TensorWithLayout*>& inputs,
+ const std::string& checkpoint_fn_name) {
+ if (checkpoint_fn_name != kLoadEmbeddingFn &&
+ checkpoint_fn_name != kRetrieveEmbeddingFn) {
+ return errors::InvalidArgument(absl::StrCat(
+ "Found wrong function name: ", checkpoint_fn_name,
+ " \n expects : ", kLoadEmbeddingFn, " or ", kRetrieveEmbeddingFn));
+ }
+
+ StatusOr<std::map<int64_t, Node*>> table_id_node_map =
+ GetTPUEmbeddingInputNodes(status, *graph, inputs);
+ if (!table_id_node_map.ok()) {
+ return errors::Internal(table_id_node_map.status().error_message());
+ }
+
+  // Propagate any mesh-consistency error instead of dereferencing an
+  // unchecked StatusOr below.
+  TF_ASSIGN_OR_RETURN(const std::string mesh_str,
+                      ValidateResourceMeshConsistency(inputs));
+  const int64_t num_tables = table_id_node_map->size();
+  NodeDef func_node_def;
+  std::vector<NodeDefBuilder::NodeOut> func_inputs;
+  std::vector<DataType> input_types, output_types;
+
+  func_inputs.reserve(num_tables);
+  input_types.reserve(num_tables);
+
+  for (int i = 0; i < num_tables; ++i) {
+    auto node_ptr = table_id_node_map->find(i);
+    if (node_ptr == table_id_node_map->end()) {
+      return errors::Internal(
+          absl::StrCat("Embedding table id ", i, " is not found."));
+    }
+    const std::string& node_name = node_ptr->second->name();
+    func_inputs.push_back({node_name, i, DT_RESOURCE});
+    input_types.push_back(DT_RESOURCE);
+  }
+
+  AttrValue mesh_attr;
+  *mesh_attr.mutable_s() = mesh_str;
+ NameAttrList func_attr;
+ func_attr.set_name(checkpoint_fn_name);
+ TF_RETURN_IF_ERROR(
+ NodeDefBuilder(checkpoint_fn_name, "StatefulPartitionedCall")
+ .Attr("Tin", input_types)
+ .Attr("Tout", output_types)
+ .Attr("f", func_attr)
+ .Attr(kMeshAttr, mesh_attr)
+ .Attr("config", mesh_attr)
+ .Input(func_inputs)
+ .Finalize(&func_node_def, true));
+
+ TF_ASSIGN_OR_RETURN(Node * func_node, graph->AddNode(func_node_def));
+ for (int i = 0; i < num_tables; ++i) {
+ Node* node = table_id_node_map->find(i)->second;
+ graph->AddEdge(node, 0, func_node, i);
+ }
+
+ return Status::OK();
+}
+
} // namespace dtensor
} // namespace tensorflow
diff --git a/tensorflow/dtensor/cc/dtensor_device_util.h b/tensorflow/dtensor/cc/dtensor_device_util.h
index b36017d..7063425 100644
--- a/tensorflow/dtensor/cc/dtensor_device_util.h
+++ b/tensorflow/dtensor/cc/dtensor_device_util.h
@@ -548,6 +548,11 @@
StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
const std::vector<TensorWithLayout*>& inputs);
+Status InsertFunctionForTPUEmbeddingCheckpoint(
+ TF_Status* status, Graph* graph,
+ const std::vector<TensorWithLayout*>& inputs,
+ const std::string& checkpoint_fn_name);
+
} // namespace dtensor
} // namespace tensorflow