Implement OSS TPU ordinal selector.
PiperOrigin-RevId: 360295595
Change-Id: I70168d92e735e5f42c46b0b7742195649b094433
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 08afc82..96a4c2c 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -956,6 +956,7 @@
name = "tpu_ordinal_selector_op",
srcs = ["tpu_ordinal_selector_op.cc"],
deps = [
+ ":tpu_ordinal_selector",
"//tensorflow/core:framework",
],
alwayslink = 1,
@@ -968,3 +969,14 @@
"//tensorflow/core:framework",
],
)
+
+cc_library(
+ name = "tpu_ordinal_selector",
+ hdrs = ["tpu_ordinal_selector.h"],
+ deps = [
+ ":tpu_ordinal_selector_interface",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/tpu:tpu_api",
+ "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
+ ],
+)
diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h
new file mode 100644
index 0000000..faf78f9
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h
@@ -0,0 +1,58 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
+
+#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h"
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_ops_c_api.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// A reserved ID for deferred core selection. Intentionally set at a number
+// that is more than the number of cores available in a future system.
+constexpr int32 kDeferredCoreSelectionReserved = -8193;
+
+class TPUOrdinalSelector : TPUOrdinalSelectorInterface {
+ public:
+ explicit TPUOrdinalSelector(int num_cores_per_replica = 1) {
+ OpsApiFn()->TfTpuOrdinalSelector_CreateFn(&ordinal_selector_,
+ num_cores_per_replica);
+ }
+ ~TPUOrdinalSelector() override {
+ OpsApiFn()->TfTpuOrdinalSelector_DestroyFn(ordinal_selector_);
+ }
+ int64 GetOrdinal(absl::optional<uint64> key, int64_t* req_id) override {
+ int64 ordinal;
+ OpsApiFn()->TfTpuOrdinalSelector_GetOrdinalFn(ordinal_selector_, key,
+ req_id, &ordinal);
+ return ordinal;
+ }
+ void DequeueFromCoreSelector(int32_t device_ordinal,
+ int64_t req_id) override {
+ OpsApiFn()->TfTpuOrdinalSelector_DequeueFromCoreSelectorFn(
+ ordinal_selector_, device_ordinal, req_id);
+ }
+
+ private:
+ TfTpuOrdinalSelector* ordinal_selector_;
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc
index 13a624b..c6da029 100644
--- a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc
@@ -19,14 +19,11 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h"
namespace tensorflow {
namespace {
-// A reserved ID for deferred core selection. Intentionally set at a number
-// that is more than the number of cores available in a future system.
-constexpr int32 kDeferredCoreSelectionReserved = -8193;
-
// TPUOrdinalSelectorOp is a no-op for backward compatibility. The core
// selection algorithm happens inside TPUPartitionedCall.
class TPUOrdinalSelectorOp : public OpKernel {
@@ -37,7 +34,7 @@
void Compute(OpKernelContext* ctx) override {
Tensor output(DT_INT32, TensorShape({}));
- output.flat<int>().setValues({kDeferredCoreSelectionReserved});
+ output.flat<int>().setValues({tpu::kDeferredCoreSelectionReserved});
ctx->set_output(0, output);
ctx->SetStatus(Status::OK());
}