Fix data race in tpu_platform_interface

PiperOrigin-RevId: 322498158
Change-Id: I56f5e8dbd0aedb7c37d6ef7eb943ec86fef03120
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index a8557aa..931cfde 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -275,6 +275,7 @@
     hdrs = ["tpu_platform_interface.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/core/platform:mutex",
         "//tensorflow/core/platform:types",
         "//tensorflow/stream_executor:multi_platform_manager",
         "//tensorflow/stream_executor:stream_executor_headers",
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
index fa9062c..2843039 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
@@ -17,6 +17,7 @@
 
 #include <atomic>
 
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/stream_executor/multi_platform_manager.h"
 
 namespace tensorflow {
@@ -72,16 +73,19 @@
 /* static */
 TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
     bool initialize_platform) {
+  static auto* mu = new mutex;
   static bool requested_initialize_platform = initialize_platform;
   static TpuPlatformInterface* tpu_registered_platform =
       GetRegisteredPlatformStatic(initialize_platform);
 
+  mutex_lock lock(*mu);
   if (!requested_initialize_platform && initialize_platform) {
     // If the first time this function is called, we did not request
     // initializing the platform, but the next caller wants the platform
     // initialized, we will call GetRegisteredPlatformStatic again to initialize
     // the platform.
     tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform);
+    requested_initialize_platform = true;
   }
 
   return tpu_registered_platform;