Fix data race in tpu_platform_interface
PiperOrigin-RevId: 322498158
Change-Id: I56f5e8dbd0aedb7c37d6ef7eb943ec86fef03120
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index a8557aa..931cfde 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -275,6 +275,7 @@
hdrs = ["tpu_platform_interface.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:stream_executor_headers",
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
index fa9062c..2843039 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
@@ -17,6 +17,7 @@
#include <atomic>
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
namespace tensorflow {
@@ -72,16 +73,19 @@
/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
bool initialize_platform) {
+ static auto* mu = new mutex;
static bool requested_initialize_platform = initialize_platform;
static TpuPlatformInterface* tpu_registered_platform =
GetRegisteredPlatformStatic(initialize_platform);
+ mutex_lock lock(*mu);
if (!requested_initialize_platform && initialize_platform) {
// If the first time this function is called, we did not request
// initializing the platform, but the next caller wants the platform
// initialized, we will call GetRegisteredPlatformStatic again to initialize
// the platform.
tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform);
+ requested_initialize_platform = true;
}
return tpu_registered_platform;