Implement OSS TPU ordinal selector.

PiperOrigin-RevId: 360295595
Change-Id: I70168d92e735e5f42c46b0b7742195649b094433
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 08afc82..96a4c2c 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -956,6 +956,7 @@
     name = "tpu_ordinal_selector_op",
     srcs = ["tpu_ordinal_selector_op.cc"],
     deps = [
+        ":tpu_ordinal_selector",
         "//tensorflow/core:framework",
     ],
     alwayslink = 1,
@@ -968,3 +969,14 @@
         "//tensorflow/core:framework",
     ],
 )
+
+cc_library(
+    name = "tpu_ordinal_selector",
+    hdrs = ["tpu_ordinal_selector.h"],
+    deps = [
+        ":tpu_ordinal_selector_interface",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_api",
+        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
+    ],
+)
diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h
new file mode 100644
index 0000000..faf78f9
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h
@@ -0,0 +1,58 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
+
+#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h"
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_ops_c_api.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// A reserved ID for deferred core selection. Intentionally set at a number
+// that is more than the number of cores available in a future system.
+constexpr int32 kDeferredCoreSelectionReserved = -8193;
+
+class TPUOrdinalSelector : TPUOrdinalSelectorInterface {
+ public:
+  explicit TPUOrdinalSelector(int num_cores_per_replica = 1) {
+    OpsApiFn()->TfTpuOrdinalSelector_CreateFn(&ordinal_selector_,
+                                              num_cores_per_replica);
+  }
+  ~TPUOrdinalSelector() override {
+    OpsApiFn()->TfTpuOrdinalSelector_DestroyFn(ordinal_selector_);
+  }
+  int64 GetOrdinal(absl::optional<uint64> key, int64_t* req_id) override {
+    int64 ordinal;
+    OpsApiFn()->TfTpuOrdinalSelector_GetOrdinalFn(ordinal_selector_, key,
+                                                  req_id, &ordinal);
+    return ordinal;
+  }
+  void DequeueFromCoreSelector(int32_t device_ordinal,
+                               int64_t req_id) override {
+    OpsApiFn()->TfTpuOrdinalSelector_DequeueFromCoreSelectorFn(
+        ordinal_selector_, device_ordinal, req_id);
+  }
+
+ private:
+  TfTpuOrdinalSelector* ordinal_selector_;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc
index 13a624b..c6da029 100644
--- a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc
@@ -19,14 +19,11 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h"
 
 namespace tensorflow {
 namespace {
 
-// A reserved ID for deferred core selection. Intentionally set at a number
-// that is more than the number of cores available in a future system.
-constexpr int32 kDeferredCoreSelectionReserved = -8193;
-
 // TPUOrdinalSelectorOp is a no-op for backward compatibility. The core
 // selection algorithm happens inside TPUPartitionedCall.
 class TPUOrdinalSelectorOp : public OpKernel {
@@ -37,7 +34,7 @@
 
   void Compute(OpKernelContext* ctx) override {
     Tensor output(DT_INT32, TensorShape({}));
-    output.flat<int>().setValues({kDeferredCoreSelectionReserved});
+    output.flat<int>().setValues({tpu::kDeferredCoreSelectionReserved});
     ctx->set_output(0, output);
     ctx->SetStatus(Status::OK());
   }