Add TpuEmbedding Enqueue APIs.

PiperOrigin-RevId: 421880160
Change-Id: Ibbca63cc351537a3654739f20211ef3456cbb341
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 343e4bf..7e94551 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -336,10 +336,13 @@
 cc_library(
     name = "tpu_ops_c_api_hdrs",
     srcs = [],
-    hdrs = ["tpu_ops_c_api.h"],
+    hdrs = [
+        "tpu_ops_c_api.h",
+    ],
     visibility = ["//visibility:public"],
     deps = [
         ":libtftpu_header",
+        "//tensorflow/c:tf_tstring",
         "//tensorflow/stream_executor/tpu:c_api_decl",
         "//tensorflow/stream_executor/tpu:proto_helper",
         "@com_google_absl//absl/types:optional",
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 89609d4..b4d48bb 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -103,6 +103,9 @@
   TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_IsInitialized);
   TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_WriteParameters);
   TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ReadParameters);
+  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Create);
+  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Destroy);
+  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_EnqueueTensorBatch);
 
   return tensorflow::Status::OK();
 }
diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h
index 486f274..53511ac 100644
--- a/tensorflow/core/tpu/tpu_ops_c_api.h
+++ b/tensorflow/core/tpu/tpu_ops_c_api.h
@@ -20,6 +20,7 @@
 #include <cstdint>
 
 #include "absl/types/optional.h"
+#include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/tpu/libtftpu.h"
 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
 #include "tensorflow/stream_executor/tpu/proto_helper.h"
@@ -64,6 +65,9 @@
 
 typedef struct XLA_TpuEmbeddingEngineState XLA_TpuEmbeddingEngineState;
 
+typedef struct TpuEmbedding_TensorBatchFixedState
+    TpuEmbedding_TensorBatchFixedState;  // opaque here: forward-declared only
+
 typedef struct TpuProfiler TpuProfiler;
 
 typedef struct XLA_DeviceAssignment {
@@ -601,6 +605,42 @@
 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ReadParameters(
     TpuEmbeddingEngineParameters* params, TF_Status* status);
 
+typedef struct TpuEmbeddingEngine_EnqueueTensorBatch_Params {
+  int32_t struct_size;  // NOTE(review): presumably sizeof this struct; confirm
+  void* priv;  // NOTE(review): purpose not visible in this hunk; confirm
+
+  int32_t local_device_ordinal;
+  TpuEmbedding_TensorBatchFixedState* fixed_state;  // opaque handle type
+
+  TF_Tensor** sample_indices_tensors;
+  size_t sample_indices_tensors_size;  // presumably count of array above
+  TF_Tensor** embedding_indices_tensors;
+  size_t embedding_indices_tensors_size;  // presumably count of array above
+  TF_Tensor** aggregation_weights_tensors;
+  size_t aggregation_weights_tensors_size;  // presumably count of array above
+  TF_Status* status;  // NOTE(review): likely an out-param, as in Create params
+} TpuEmbeddingEngine_EnqueueTensorBatch_Params;
+
+TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_EnqueueTensorBatch(
+    TpuEmbeddingEngine_EnqueueTensorBatch_Params* params);
+
+typedef struct TpuEmbedding_TensorBatchFixedState_Create_Params {
+  int32_t struct_size;  // NOTE(review): presumably sizeof this struct; confirm
+  void* priv;  // NOTE(review): purpose not visible in this hunk; confirm
+
+  size_t combiners_size;  // presumably the number of entries in `combiners`
+  char** combiners;  // NOTE(review): presumably C strings; confirm ownership
+
+  // out
+  TF_Status* status;
+} TpuEmbedding_TensorBatchFixedState_Create_Params;
+
+TFTPU_CAPI_EXPORT TpuEmbedding_TensorBatchFixedState*
+TpuEmbeddingTensorBatchFixedState_Create(
+    TpuEmbedding_TensorBatchFixedState_Create_Params* params);
+TFTPU_CAPI_EXPORT void TpuEmbeddingTensorBatchFixedState_Destroy(
+    TpuEmbedding_TensorBatchFixedState* fixed_state);  // likely from _Create
+
 struct TfTpu_OpsApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild);
@@ -692,6 +732,9 @@
   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_IsInitialized);
   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_WriteParameters);
   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ReadParameters);
+  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Create);
+  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Destroy);
+  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_EnqueueTensorBatch);
 };
 
 }  // extern "C"