Register the TPU platform when the relevant TPU libraries are dynamically loaded
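
InitializeTpuStructFns() now returns an error (instead of only logging one) when a
symbol is missing from the loaded shared object, and platform registration is gated
on that status. The static REGISTER_MODULE_INITIALIZER moves out of tpu_platform.cc
into the new tpu_platform_registration.cc (built into the :tpu_executor target with
alwayslink), while the dlsym initializer depends on :tpu_executor_base and calls
RegisterTpuPlatform() explicitly, avoiding the old circular dependency.

Rough sketch of the resulting flow in the dlsym initializer (the wrapper name and
signature below are illustrative; InitializeTpuStructFns and RegisterTpuPlatform are
the entry points actually touched by this change):

  Status InitializeTpuLibrary(void* library_handle, bool shared_object_loaded) {
    // TFTPU_SET_FN now returns errors::Unimplemented when dlsym() cannot
    // resolve a symbol, so a partially populated API surface is an error.
    Status s = InitializeTpuStructFns(library_handle);

    // Register the TPU platform with XLA only when the shared object was
    // actually loaded and supplied all of the required entry points.
    if (shared_object_loaded && s.ok()) {
      RegisterTpuPlatform();
    }
    return s;
  }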

PiperOrigin-RevId: 318420259
Change-Id: Ic7807130ab4717031e9176b466ed9a51da8b27a7
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 589af63..dd7435d 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -139,6 +139,7 @@
         "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
         "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
         "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
+        "//tensorflow/stream_executor/tpu:tpu_executor_base",
         "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
         "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
     ],
diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
index 495e6a2..450f7aa 100644
--- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
@@ -21,12 +21,14 @@
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
 
-#define TFTPU_SET_FN(Struct, FnName)                                       \
-  Struct->FnName##Fn =                                                     \
-      reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName)); \
-  if (!(Struct->FnName##Fn)) {                                             \
-    LOG(ERROR) << #FnName " not available in this library.";               \
+#define TFTPU_SET_FN(Struct, FnName)                                         \
+  Struct->FnName##Fn =                                                       \
+      reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName));   \
+  if (!(Struct->FnName##Fn)) {                                               \
+    LOG(ERROR) << #FnName " not available in this library.";                 \
+    return errors::Unimplemented(#FnName " not available in this library."); \
   }
 
 // Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
@@ -44,19 +46,20 @@
     shared_object_loaded = false;
   }
 
-  TF_RETURN_IF_ERROR(InitializeTpuStructFns(library_handle));
+  Status s = InitializeTpuStructFns(library_handle);
 
-  if (shared_object_loaded) {
+  // The TPU platform must be registered only after the library is loaded and
+  // all of its C API symbols have resolved; we do not want to register a TPU
+  // platform in XLA without a supporting library that provides those APIs.
+  if (shared_object_loaded && s.ok()) {
     // TODO(frankchn): Make initialization actually work
     // Initialize TPU platform when the platform code is loaded from a library.
     // InitializeApiFn()->TfTpu_InitializeFn();
 
-    // We should only register the TPU platform when the library is loaded.
-    // TODO(frankchn): Resolve the circular dependency and register the platform
-    // RegisterTpuPlatform();
+    RegisterTpuPlatform();
   }
 
-  return Status::OK();
+  return s;
 }
 
 }  // namespace tpu
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index 66b0013..add4db1 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -35,6 +35,7 @@
     deps = [
         ":tpu_executor_c_api_hdrs",
         "//tensorflow/core/platform:status",
+        "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
     ],
 )
@@ -61,7 +62,16 @@
 )
 
 cc_library(
-    name = "tpu_executor",
+    name = "device_memory_base_helper",
+    hdrs = ["device_memory_base_helper.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        "//tensorflow/stream_executor:device_memory",
+    ],
+)
+
+cc_library(
+    name = "tpu_executor_base",
     srcs = [
         "tpu_executor.cc",
         "tpu_platform.cc",
@@ -73,7 +83,7 @@
         "tpu_timer.h",
     ],
     deps = [
-        ":c_api_conversions",
+        ":device_memory_base_helper",
         ":status_helper",
         ":tpu_executor_c_api_hdrs",
         ":tpu_executor_interface",
@@ -91,6 +101,16 @@
 )
 
 cc_library(
+    name = "tpu_executor",
+    srcs = ["tpu_platform_registration.cc"],
+    deps = [
+        ":tpu_executor_base",
+        "//tensorflow/stream_executor/platform",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
     name = "tpu_node_context",
     srcs = ["tpu_node_context.cc"],
     hdrs = ["tpu_node_context.h"],
@@ -118,6 +138,7 @@
     srcs = ["tpu_transfer_manager_registration.cc"],
     deps = [
         ":tpu_executor",
+        ":tpu_executor_base",
         ":tpu_transfer_manager_base",
         "//tensorflow/compiler/xla/service:transfer_manager",
     ],
@@ -131,7 +152,7 @@
         ":c_api_conversions",
         ":proto_helper",
         ":status_helper",
-        ":tpu_executor",
+        ":tpu_executor_base",
         ":tpu_executor_c_api_hdrs",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/stream_executor/tpu/device_memory_base_helper.h b/tensorflow/stream_executor/tpu/device_memory_base_helper.h
new file mode 100644
index 0000000..9937dc2
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/device_memory_base_helper.h
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_DEVICE_MEMORY_BASE_HELPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_DEVICE_MEMORY_BASE_HELPER_H_
+
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class DeviceMemoryBaseHelper {
+ public:
+  static stream_executor::DeviceMemoryBase
+  SE_DeviceMemoryBaseToDeviceMemoryBase(SE_DeviceMemoryBase se_base) {
+    stream_executor::DeviceMemoryBase base(se_base.opaque, se_base.size);
+    base.SetPayload(se_base.payload);
+    return base;
+  }
+
+  static SE_DeviceMemoryBase DeviceMemoryBaseToSE_DeviceMemoryBase(
+      const stream_executor::DeviceMemoryBase& base) {
+    SE_DeviceMemoryBase se_base;
+    se_base.opaque = const_cast<void*>(base.opaque());
+    se_base.payload = base.payload();
+    se_base.size = base.size();
+    return se_base;
+  }
+};
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_DEVICE_MEMORY_BASE_HELPER_H_
diff --git a/tensorflow/stream_executor/tpu/status_helper.h b/tensorflow/stream_executor/tpu/status_helper.h
index bc8820f..0129abb 100644
--- a/tensorflow/stream_executor/tpu/status_helper.h
+++ b/tensorflow/stream_executor/tpu/status_helper.h
@@ -18,22 +18,34 @@
 
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
+#include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 
-struct StatusHelper {
-  StatusHelper() : c_status(TpuStatus_New()) {}
-  ~StatusHelper() { TpuStatus_Free(c_status); }
-  bool ok() { return TpuStatus_Code(c_status) == 0; }
-  tensorflow::Status status() {
+class StatusHelper {
+ public:
+  StatusHelper()
+      : c_status(tensorflow::tpu::ExecutorApiFn()->TpuStatus_NewFn()) {}
+
+  ~StatusHelper() {
+    tensorflow::tpu::ExecutorApiFn()->TpuStatus_FreeFn(c_status);
+  }
+
+  bool ok() const {
+    return tensorflow::tpu::ExecutorApiFn()->TpuStatus_CodeFn(c_status) == 0;
+  }
+
+  tensorflow::Status status() const {
     if (!ok()) {
       return tensorflow::Status(
-          tensorflow::error::Code(TpuStatus_Code(c_status)),
-          TpuStatus_Message(c_status));
+          tensorflow::error::Code(
+              tensorflow::tpu::ExecutorApiFn()->TpuStatus_CodeFn(c_status)),
+          tensorflow::tpu::ExecutorApiFn()->TpuStatus_MessageFn(c_status));
     } else {
       return tensorflow::Status::OK();
     }
   }
-  SE_Status* c_status;
+
+  SE_Status* c_status;  // NOLINT
 };
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor.cc b/tensorflow/stream_executor/tpu/tpu_executor.cc
index 95c3271..60d6d22b 100644
--- a/tensorflow/stream_executor/tpu/tpu_executor.cc
+++ b/tensorflow/stream_executor/tpu/tpu_executor.cc
@@ -20,7 +20,7 @@
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/device_memory_base_helper.h"
 #include "tensorflow/stream_executor/tpu/status_helper.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_stream.h"
@@ -123,25 +123,26 @@
 
 stream_executor::Event::Status TpuExecutor::PollForEventStatus(
     stream_executor::Event* event) {
-  return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
-      executor_, event_map().at(event->implementation())));
+  return stream_executor::Event::Status(
+      tpu::ExecutorApiFn()->TpuExecutor_PollForEventStatusFn(
+          executor_, event_map().at(event->implementation())));
 }
 
 Status TpuExecutor::RecordEvent(Stream* stream,
                                 ::stream_executor::Event* event) {
   StatusHelper status;
-  TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
-                          event_map().at(event->implementation()),
-                          status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_RecordEventFn(
+      executor_, stream_map().at(stream->implementation()),
+      event_map().at(event->implementation()), status.c_status);
   return status.status();
 }
 
 Status TpuExecutor::WaitForEvent(Stream* stream,
                                  ::stream_executor::Event* event) {
   StatusHelper status;
-  TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
-                           event_map().at(event->implementation()),
-                           status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_WaitForEventFn(
+      executor_, stream_map().at(stream->implementation()),
+      event_map().at(event->implementation()), status.c_status);
   return status.status();
 }
 
@@ -181,18 +182,18 @@
 DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
   SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
       executor_, size, memory_space);
-  return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
+  return DeviceMemoryBaseHelper::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
 }
 
 void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
   SE_DeviceMemoryBase se_base =
-      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
+      DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
   tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
 }
 
 void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
   SE_DeviceMemoryBase se_base =
-      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
+      DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
   tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
 }
 
@@ -268,7 +269,7 @@
                          const ::stream_executor::DeviceMemoryBase& device_src,
                          uint64 size) {
   SE_DeviceMemoryBase se_base =
-      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+      DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
   return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
       executor_, stream_map().at(stream->implementation()), host_dst, &se_base,
       size);
@@ -278,7 +279,8 @@
                          ::stream_executor::DeviceMemoryBase* device_dst,
                          const void* host_src, uint64 size) {
   SE_DeviceMemoryBase se_base =
-      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+      DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+          *device_dst);
   return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
       executor_, stream_map().at(stream->implementation()), &se_base, host_src,
       size);
@@ -289,7 +291,8 @@
     uint64 size) {
   StatusHelper status;
   SE_DeviceMemoryBase se_base =
-      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+      DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+          *device_dst);
   tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
       executor_, &se_base, host_src, size, status.c_status);
   return status.status();
@@ -300,7 +303,7 @@
     uint64 size) {
   StatusHelper status;
   SE_DeviceMemoryBase se_base =
-      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+      DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
   tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
       executor_, host_dst, &se_base, size, status.c_status);
   return status.status();
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.cc b/tensorflow/stream_executor/tpu/tpu_platform.cc
index db6324e..24767a8 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform.cc
+++ b/tensorflow/stream_executor/tpu/tpu_platform.cc
@@ -109,7 +109,7 @@
 }
 
 const std::string& TpuPlatform::Name() const {
-  static std::string* name = new std::string(kName);
+  static std::string* name = new std::string("TPU");
   return *name;
 }
 
@@ -122,7 +122,7 @@
       ->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
 }
 
-void RegisterTpuPlatform() {
+bool RegisterTpuPlatform() {
   static bool tpu_platform_registered = false;
   if (!tpu_platform_registered) {
     tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
@@ -132,14 +132,7 @@
         std::move(platform)));
     tpu_platform_registered = true;
   }
+  return true;
 }
 
-REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
-
-// Note that module initialization sequencing is not supported in the
-// open-source project, so this will be a no-op there.
-REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
-REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
-                                     tpu_platform);
-
 }  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.h b/tensorflow/stream_executor/tpu/tpu_platform.h
index c2673ab..a3852b0 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform.h
+++ b/tensorflow/stream_executor/tpu/tpu_platform.h
@@ -38,7 +38,6 @@
                           SE_Event*>;
 
   static const ::stream_executor::Platform::Id kId;
-  static constexpr char kName[] = "TPU";
 
   using Status = ::stream_executor::port::Status;
   template <typename T>
@@ -122,7 +121,7 @@
   EventMap event_map_;
 };
 
-void RegisterTpuPlatform();
+bool RegisterTpuPlatform();
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_registration.cc b/tensorflow/stream_executor/tpu/tpu_platform_registration.cc
new file mode 100644
index 0000000..6f054f5
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform_registration.cc
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/platform/initialize.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+REGISTER_MODULE_INITIALIZER(tpu_platform, tensorflow::RegisterTpuPlatform());
+
+DECLARE_MODULE_INITIALIZER(multi_platform_manager);
+DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener);
+
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
+REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
+                                     tpu_platform);
diff --git a/tensorflow/stream_executor/tpu/tpu_stream.h b/tensorflow/stream_executor/tpu/tpu_stream.h
index 209a624..09b496b 100644
--- a/tensorflow/stream_executor/tpu/tpu_stream.h
+++ b/tensorflow/stream_executor/tpu/tpu_stream.h
@@ -18,7 +18,7 @@
 
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/device_memory_base_helper.h"
 #include "tensorflow/stream_executor/tpu/status_helper.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
@@ -46,8 +46,10 @@
     tensorflow::tpu::ExecutorApiFn()
         ->TpuStream_TpuEnqueueOnDeviceSendRecvLocalFn(
             stream_,
-            TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
-            TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
+            DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+                send_buffer),
+            DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+                recv_buffer),
             status.c_status);
     return status.status();
   }