Register the TPU platform when the relevant TPU libraries are dynamically loaded
PiperOrigin-RevId: 318420259
Change-Id: Ic7807130ab4717031e9176b466ed9a51da8b27a7
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 589af63..dd7435d 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -139,6 +139,7 @@
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
+ "//tensorflow/stream_executor/tpu:tpu_executor_base",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
],
diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
index 495e6a2..450f7aa 100644
--- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
@@ -21,12 +21,14 @@
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
-#define TFTPU_SET_FN(Struct, FnName) \
- Struct->FnName##Fn = \
- reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName)); \
- if (!(Struct->FnName##Fn)) { \
- LOG(ERROR) << #FnName " not available in this library."; \
+#define TFTPU_SET_FN(Struct, FnName) \
+ Struct->FnName##Fn = \
+ reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName)); \
+ if (!(Struct->FnName##Fn)) { \
+ LOG(ERROR) << #FnName " not available in this library."; \
+ return errors::Unimplemented(#FnName " not available in this library."); \
}
// Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
@@ -44,19 +46,20 @@
shared_object_loaded = false;
}
- TF_RETURN_IF_ERROR(InitializeTpuStructFns(library_handle));
+ Status s = InitializeTpuStructFns(library_handle);
- if (shared_object_loaded) {
+ // TPU platform registration must only be performed after the library is
+ // loaded. We do not want to register a TPU platform in XLA without the
+ // supporting library providing the necessary APIs.
+ if (shared_object_loaded && s.ok()) {
// TODO(frankchn): Make initialization actually work
// Initialize TPU platform when the platform code is loaded from a library.
// InitializeApiFn()->TfTpu_InitializeFn();
- // We should only register the TPU platform when the library is loaded.
- // TODO(frankchn): Resolve the circular dependency and register the platform
- // RegisterTpuPlatform();
+ RegisterTpuPlatform();
}
- return Status::OK();
+ return s;
}
} // namespace tpu
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index 66b0013..add4db1 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -35,6 +35,7 @@
deps = [
":tpu_executor_c_api_hdrs",
"//tensorflow/core/platform:status",
+ "//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
],
)
@@ -61,7 +62,16 @@
)
cc_library(
- name = "tpu_executor",
+ name = "device_memory_base_helper",
+ hdrs = ["device_memory_base_helper.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ "//tensorflow/stream_executor:device_memory",
+ ],
+)
+
+cc_library(
+ name = "tpu_executor_base",
srcs = [
"tpu_executor.cc",
"tpu_platform.cc",
@@ -73,7 +83,7 @@
"tpu_timer.h",
],
deps = [
- ":c_api_conversions",
+ ":device_memory_base_helper",
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_executor_interface",
@@ -91,6 +101,16 @@
)
cc_library(
+ name = "tpu_executor",
+ srcs = ["tpu_platform_registration.cc"],
+ deps = [
+ ":tpu_executor_base",
+ "//tensorflow/stream_executor/platform",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
name = "tpu_node_context",
srcs = ["tpu_node_context.cc"],
hdrs = ["tpu_node_context.h"],
@@ -118,6 +138,7 @@
srcs = ["tpu_transfer_manager_registration.cc"],
deps = [
":tpu_executor",
+ ":tpu_executor_base",
":tpu_transfer_manager_base",
"//tensorflow/compiler/xla/service:transfer_manager",
],
@@ -131,7 +152,7 @@
":c_api_conversions",
":proto_helper",
":status_helper",
- ":tpu_executor",
+ ":tpu_executor_base",
":tpu_executor_c_api_hdrs",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/stream_executor/tpu/device_memory_base_helper.h b/tensorflow/stream_executor/tpu/device_memory_base_helper.h
new file mode 100644
index 0000000..9937dc2
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/device_memory_base_helper.h
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_DEVICE_MEMORY_BASE_HELPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_DEVICE_MEMORY_BASE_HELPER_H_
+
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class DeviceMemoryBaseHelper {  // Static-only converters between SE_DeviceMemoryBase and stream_executor::DeviceMemoryBase.
+ public:
+ static stream_executor::DeviceMemoryBase
+ SE_DeviceMemoryBaseToDeviceMemoryBase(SE_DeviceMemoryBase se_base) {  // Builds a DeviceMemoryBase from se_base's opaque, size and payload.
+ stream_executor::DeviceMemoryBase base(se_base.opaque, se_base.size);  // opaque and size go through the constructor...
+ base.SetPayload(se_base.payload);  // ...payload is applied separately.
+ return base;
+ }
+
+ static SE_DeviceMemoryBase DeviceMemoryBaseToSE_DeviceMemoryBase(  // Inverse direction: fills opaque, payload and size from base.
+ const stream_executor::DeviceMemoryBase& base) {
+ SE_DeviceMemoryBase se_base;  // NOTE(review): default-initialized; only the three fields below are assigned - confirm the struct has no others.
+ se_base.opaque = const_cast<void*>(base.opaque());  // const_cast yields the void* that is stored in se_base.opaque.
+ se_base.payload = base.payload();
+ se_base.size = base.size();
+ return se_base;
+ }
+};
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_DEVICE_MEMORY_BASE_HELPER_H_
diff --git a/tensorflow/stream_executor/tpu/status_helper.h b/tensorflow/stream_executor/tpu/status_helper.h
index bc8820f..0129abb 100644
--- a/tensorflow/stream_executor/tpu/status_helper.h
+++ b/tensorflow/stream_executor/tpu/status_helper.h
@@ -18,22 +18,34 @@
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
+#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
-struct StatusHelper {
- StatusHelper() : c_status(TpuStatus_New()) {}
- ~StatusHelper() { TpuStatus_Free(c_status); }
- bool ok() { return TpuStatus_Code(c_status) == 0; }
- tensorflow::Status status() {
+class StatusHelper {  // Holds an SE_Status created and freed through ExecutorApiFn(), and converts it to tensorflow::Status.
+ public:
+ StatusHelper()  // Creates c_status via TpuStatus_NewFn. NOTE(review): presumably ExecutorApiFn() must already be populated - confirm.
+ : c_status(tensorflow::tpu::ExecutorApiFn()->TpuStatus_NewFn()) {}
+
+ ~StatusHelper() {  // Frees c_status. NOTE(review): no copy/move operations are declared here, so a copy would free it twice - confirm none is made.
+ tensorflow::tpu::ExecutorApiFn()->TpuStatus_FreeFn(c_status);
+ }
+
+ bool ok() const {  // True when TpuStatus_CodeFn reports code 0.
+ return tensorflow::tpu::ExecutorApiFn()->TpuStatus_CodeFn(c_status) == 0;
+ }
+
+ tensorflow::Status status() const {  // Code + message when !ok(), otherwise Status::OK().
if (!ok()) {
return tensorflow::Status(
- tensorflow::error::Code(TpuStatus_Code(c_status)),
- TpuStatus_Message(c_status));
+ tensorflow::error::Code(
+ tensorflow::tpu::ExecutorApiFn()->TpuStatus_CodeFn(c_status)),
+ tensorflow::tpu::ExecutorApiFn()->TpuStatus_MessageFn(c_status));
} else {
return tensorflow::Status::OK();
}
}
- SE_Status* c_status;
+
+ SE_Status* c_status; // NOLINT
};
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor.cc b/tensorflow/stream_executor/tpu/tpu_executor.cc
index 95c3271..60d6d22b 100644
--- a/tensorflow/stream_executor/tpu/tpu_executor.cc
+++ b/tensorflow/stream_executor/tpu/tpu_executor.cc
@@ -20,7 +20,7 @@
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/device_memory_base_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
@@ -123,25 +123,26 @@
stream_executor::Event::Status TpuExecutor::PollForEventStatus(
stream_executor::Event* event) {
- return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
- executor_, event_map().at(event->implementation())));
+ return stream_executor::Event::Status(  // The API result is converted to Event::Status.
+ tpu::ExecutorApiFn()->TpuExecutor_PollForEventStatusFn(  // Now called through ExecutorApiFn() instead of the direct symbol.
+ executor_, event_map().at(event->implementation())));  // The event handle is looked up in event_map() by implementation().
}
Status TpuExecutor::RecordEvent(Stream* stream,
::stream_executor::Event* event) {
StatusHelper status;
- TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
- event_map().at(event->implementation()),
- status.c_status);
+ tpu::ExecutorApiFn()->TpuExecutor_RecordEventFn(  // Now called through ExecutorApiFn() instead of the direct symbol.
+ executor_, stream_map().at(stream->implementation()),  // The stream handle is looked up in stream_map() by implementation().
+ event_map().at(event->implementation()), status.c_status);  // status.c_status is passed to the call and read back by status.status() below.
return status.status();
}
Status TpuExecutor::WaitForEvent(Stream* stream,
::stream_executor::Event* event) {
StatusHelper status;
- TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
- event_map().at(event->implementation()),
- status.c_status);
+ tpu::ExecutorApiFn()->TpuExecutor_WaitForEventFn(  // Now called through ExecutorApiFn() instead of the direct symbol.
+ executor_, stream_map().at(stream->implementation()),  // The stream handle is looked up in stream_map() by implementation().
+ event_map().at(event->implementation()), status.c_status);  // status.c_status is passed to the call and read back by status.status() below.
return status.status();
}
@@ -181,18 +182,18 @@
DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
executor_, size, memory_space);
- return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
+ return DeviceMemoryBaseHelper::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);  // Converts the returned SE_DeviceMemoryBase; the helper replaces TpuConversions here.
}
void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
SE_DeviceMemoryBase se_base =
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);  // SE_DeviceMemoryBase copy of memory, passed by address to the call below.
tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
}
void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
SE_DeviceMemoryBase se_base =
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);  // memory is dereferenced here without a null check.
tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
}
@@ -268,7 +269,7 @@
const ::stream_executor::DeviceMemoryBase& device_src,
uint64 size) {
SE_DeviceMemoryBase se_base =
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
executor_, stream_map().at(stream->implementation()), host_dst, &se_base,
size);
@@ -278,7 +279,8 @@
::stream_executor::DeviceMemoryBase* device_dst,
const void* host_src, uint64 size) {
SE_DeviceMemoryBase se_base =
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+ *device_dst);
return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
executor_, stream_map().at(stream->implementation()), &se_base, host_src,
size);
@@ -289,7 +291,8 @@
uint64 size) {
StatusHelper status;
SE_DeviceMemoryBase se_base =
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+ *device_dst);
tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
executor_, &se_base, host_src, size, status.c_status);
return status.status();
@@ -300,7 +303,7 @@
uint64 size) {
StatusHelper status;
SE_DeviceMemoryBase se_base =
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
executor_, host_dst, &se_base, size, status.c_status);
return status.status();
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.cc b/tensorflow/stream_executor/tpu/tpu_platform.cc
index db6324e..24767a8 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform.cc
+++ b/tensorflow/stream_executor/tpu/tpu_platform.cc
@@ -109,7 +109,7 @@
}
const std::string& TpuPlatform::Name() const {
- static std::string* name = new std::string(kName);
+ static std::string* name = new std::string("TPU");  // The literal replaces kName; the string is allocated once and never deleted in this function.
return *name;
}
@@ -122,7 +122,7 @@
->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
}
-void RegisterTpuPlatform() {
+bool RegisterTpuPlatform() {
static bool tpu_platform_registered = false;
if (!tpu_platform_registered) {
tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
@@ -132,14 +132,7 @@
std::move(platform)));
tpu_platform_registered = true;
}
+ return true;
}
-REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
-
-// Note that module initialization sequencing is not supported in the
-// open-source project, so this will be a no-op there.
-REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
-REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
- tpu_platform);
-
} // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.h b/tensorflow/stream_executor/tpu/tpu_platform.h
index c2673ab..a3852b0 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform.h
+++ b/tensorflow/stream_executor/tpu/tpu_platform.h
@@ -38,7 +38,6 @@
SE_Event*>;
static const ::stream_executor::Platform::Id kId;
- static constexpr char kName[] = "TPU";
using Status = ::stream_executor::port::Status;
template <typename T>
@@ -122,7 +121,7 @@
EventMap event_map_;
};
-void RegisterTpuPlatform();
+bool RegisterTpuPlatform();
} // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_registration.cc b/tensorflow/stream_executor/tpu/tpu_platform_registration.cc
new file mode 100644
index 0000000..6f054f5
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform_registration.cc
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/platform/initialize.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+REGISTER_MODULE_INITIALIZER(tpu_platform, tensorflow::RegisterTpuPlatform());  // Presumably runs RegisterTpuPlatform() as the initializer named tpu_platform - confirm against initialize.h.
+
+DECLARE_MODULE_INITIALIZER(multi_platform_manager);  // Presumably defined in another translation unit; declared so it can be named below - confirm.
+DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener);  // Same as above - confirm.
+
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
+REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
+ tpu_platform);
diff --git a/tensorflow/stream_executor/tpu/tpu_stream.h b/tensorflow/stream_executor/tpu/tpu_stream.h
index 209a624..09b496b 100644
--- a/tensorflow/stream_executor/tpu/tpu_stream.h
+++ b/tensorflow/stream_executor/tpu/tpu_stream.h
@@ -18,7 +18,7 @@
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
-#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/device_memory_base_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
@@ -46,8 +46,10 @@
tensorflow::tpu::ExecutorApiFn()
->TpuStream_TpuEnqueueOnDeviceSendRecvLocalFn(
stream_,
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
- TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+ send_buffer),
+ DeviceMemoryBaseHelper::DeviceMemoryBaseToSE_DeviceMemoryBase(
+ recv_buffer),
status.c_status);
return status.status();
}