Add TpuEmbedding Enqueue API's.
PiperOrigin-RevId: 421880160
Change-Id: Ibbca63cc351537a3654739f20211ef3456cbb341
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 343e4bf..7e94551 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -336,10 +336,13 @@
cc_library(
name = "tpu_ops_c_api_hdrs",
srcs = [],
- hdrs = ["tpu_ops_c_api.h"],
+ hdrs = [
+ "tpu_ops_c_api.h",
+ ],
visibility = ["//visibility:public"],
deps = [
":libtftpu_header",
+ "//tensorflow/c:tf_tstring",
"//tensorflow/stream_executor/tpu:c_api_decl",
"//tensorflow/stream_executor/tpu:proto_helper",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 89609d4..b4d48bb 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -103,6 +103,9 @@
TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_IsInitialized);
TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_WriteParameters);
TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ReadParameters);
+ TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Create);
+ TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Destroy);
+ TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_EnqueueTensorBatch);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h
index 486f274..53511ac 100644
--- a/tensorflow/core/tpu/tpu_ops_c_api.h
+++ b/tensorflow/core/tpu/tpu_ops_c_api.h
@@ -20,6 +20,7 @@
#include <cstdint>
#include "absl/types/optional.h"
+#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
@@ -64,6 +65,9 @@
typedef struct XLA_TpuEmbeddingEngineState XLA_TpuEmbeddingEngineState;
+typedef struct TpuEmbedding_TensorBatchFixedState
+ TpuEmbedding_TensorBatchFixedState;
+
typedef struct TpuProfiler TpuProfiler;
typedef struct XLA_DeviceAssignment {
@@ -601,6 +605,42 @@
TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ReadParameters(
TpuEmbeddingEngineParameters* params, TF_Status* status);
+typedef struct TpuEmbeddingEngine_EnqueueTensorBatch_Params {
+ int32_t struct_size;
+ void* priv;
+
+ int32_t local_device_ordinal;
+ TpuEmbedding_TensorBatchFixedState* fixed_state;
+
+ TF_Tensor** sample_indices_tensors;
+ size_t sample_indices_tensors_size;
+ TF_Tensor** embedding_indices_tensors;
+ size_t embedding_indices_tensors_size;
+ TF_Tensor** aggregation_weights_tensors;
+ size_t aggregation_weights_tensors_size;
+ TF_Status* status;
+} TpuEmbeddingEngine_EnqueueTensorBatch_Params;
+
+TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_EnqueueTensorBatch(
+ TpuEmbeddingEngine_EnqueueTensorBatch_Params* params);
+
+typedef struct TpuEmbedding_TensorBatchFixedState_Create_Params {
+ int32_t struct_size;
+ void* priv;
+
+ size_t combiners_size;
+ char** combiners;
+
+ // out
+ TF_Status* status;
+} TpuEmbedding_TensorBatchFixedState_Create_Params;
+
+TFTPU_CAPI_EXPORT TpuEmbedding_TensorBatchFixedState*
+TpuEmbeddingTensorBatchFixedState_Create(
+ TpuEmbedding_TensorBatchFixedState_Create_Params* params);
+TFTPU_CAPI_EXPORT void TpuEmbeddingTensorBatchFixedState_Destroy(
+ TpuEmbedding_TensorBatchFixedState* fixed_state);
+
struct TfTpu_OpsApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild);
@@ -692,6 +732,9 @@
TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_IsInitialized);
TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_WriteParameters);
TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ReadParameters);
+ TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Create);
+ TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Destroy);
+ TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_EnqueueTensorBatch);
};
} // extern "C"