Create a lockfile when loading libtpu.so to prevent attempts at double loading and initialization
PiperOrigin-RevId: 357513705
Change-Id: Iafc6a83f0a3bdfa580c98b286a4113508852c3b8
diff --git a/tensorflow/core/profiler/internal/tpu/BUILD b/tensorflow/core/profiler/internal/tpu/BUILD
index 67c5c34..0817946 100644
--- a/tensorflow/core/profiler/internal/tpu/BUILD
+++ b/tensorflow/core/profiler/internal/tpu/BUILD
@@ -23,6 +23,7 @@
"//tensorflow/core/profiler/utils:xplane_utils",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_api_dlsym_initializer",
+ "//tensorflow/core/tpu:tpu_initializer_helper",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:status_helper",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
index 528432f..b69b5b7 100644
--- a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
+++ b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
@@ -29,6 +29,7 @@
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_initializer_helper.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
@@ -134,7 +135,9 @@
}
auto register_tpu_tracer_factory = [] {
- RegisterProfilerFactory(&CreateTpuTracer);
+ if (tensorflow::tpu::TryAcquireTpuLock()) {  // Register the TPU tracer factory only when the process-wide TPU lock was acquired.
+ RegisterProfilerFactory(&CreateTpuTracer);
+ }  // NOTE(review): TryAcquireTpuLock() returns false when LIBTPU_ON_GCE is not defined, so no tracer is registered in such builds; confirm this is intended.
return 0;
}();
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 8753f9e..5d15e27 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -11,6 +11,7 @@
"//tensorflow/compiler/mlir/tensorflow:__subpackages__",
"//tensorflow/compiler/tf2xla/kernels:__subpackages__",
"//tensorflow/compiler/xrt:__subpackages__",
+ "//tensorflow/core/profiler/internal/tpu:__subpackages__",
"//tensorflow/core/tpu:__subpackages__",
"//tensorflow/stream_executor/tpu:__subpackages__",
],
@@ -105,7 +106,11 @@
name = "tpu_initializer_helper",
srcs = ["tpu_initializer_helper.cc"],
hdrs = ["tpu_initializer_helper.h"],
- deps = ["@com_google_absl//absl/strings"],
+ deps = [
+ "//tensorflow/core/platform:logging",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ ],
)
cc_library(
diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
index 4c67d59..eb32962 100644
--- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
@@ -65,6 +65,8 @@
}
bool FindAndLoadTpuLibrary() {
+ if (!TryAcquireTpuLock()) return false;
+
void* library = dlopen("libtpu.so", RTLD_NOW);
if (library) {
InitializeTpuLibrary(library);
diff --git a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
index 8c2ae85..d3a70da 100644
--- a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
@@ -62,6 +62,8 @@
}
bool FindAndLoadTpuLibrary() {
+ if (!TryAcquireTpuLock()) return false;
+
void* library = dlopen("libtpu.so", RTLD_NOW);
if (library) {
InitializeTpuLibrary(library);
diff --git a/tensorflow/core/tpu/tpu_initializer_helper.cc b/tensorflow/core/tpu/tpu_initializer_helper.cc
index c97a09b..07be9f8 100644
--- a/tensorflow/core/tpu/tpu_initializer_helper.cc
+++ b/tensorflow/core/tpu/tpu_initializer_helper.cc
@@ -15,13 +15,56 @@
#include "tensorflow/core/tpu/tpu_initializer_helper.h"
+#if defined(LIBTPU_ON_GCE)
+#include <fcntl.h>
#include <stdlib.h>
+#include <unistd.h>
+#endif // LIBTPU_ON_GCE
#include "absl/strings/str_split.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace tpu {
+// Decides whether this process is allowed to load libtpu.so.
+//
+// On LIBTPU_ON_GCE builds the decision is made once, on the first call, and
+// cached for every later call (calls are serialized by an internal mutex):
+//   * If the TPU_HOST_BOUNDS env var is set, returns true without locking.
+//   * Otherwise tries to take a non-blocking lock on /tmp/libtpu_lockfile and
+//     returns true only if the lock was obtained. On success the descriptor
+//     is deliberately left open so the lock is held until the process exits.
+//
+// On all other builds this always returns false.
+bool TryAcquireTpuLock() {
+#if defined(LIBTPU_ON_GCE)
+  static absl::Mutex* mu = new absl::Mutex();
+  absl::MutexLock l(mu);
+
+  static bool attempted_file_open = false;
+  static bool should_load_library = false;
+
+  if (!attempted_file_open) {
+    // Remember that the attempt happened so later calls reuse the cached
+    // answer instead of re-opening the lockfile (and leaking a descriptor)
+    // on every call.
+    attempted_file_open = true;
+
+    // If the TPU_HOST_BOUNDS env var is set, that means we are loading each
+    // chip in a different process and thus multiple libtpu loads are OK.
+    if (getenv("TPU_HOST_BOUNDS") == nullptr) {
+      // open() requires an explicit mode argument whenever O_CREAT is given;
+      // without it the permissions of a newly created file are indeterminate.
+      int fd = open("/tmp/libtpu_lockfile", O_CREAT | O_RDWR, 0644);
+
+      if (fd < 0) {
+        // Without a valid descriptor we cannot take the lock; report the real
+        // cause rather than blaming another process.
+        LOG(WARNING) << "Failed to open /tmp/libtpu_lockfile. Not attempting "
+                        "to load libtpu.so in this process.";
+        should_load_library = false;
+      } else if (lockf(fd, F_TLOCK, 0) != 0) {
+        LOG(WARNING) << "libtpu.so already in used by another process. Not "
+                        "attempting to load libtpu.so in this process.";
+        should_load_library = false;
+        // The lock was not obtained, so there is nothing to keep open.
+        close(fd);
+      } else {
+        // This lock is held until the process exits intentionally. The
+        // underlying TPU device will be held on until it quits, so the
+        // descriptor is never closed.
+        should_load_library = true;
+      }
+    } else {
+      LOG(INFO) << "TPU_HOST_BOUNDS is set, allowing multiple libtpu.so loads.";
+      should_load_library = true;
+    }
+  }
+
+  return should_load_library;
+#else  // LIBTPU_ON_GCE
+  return false;
+#endif
+}
+
std::pair<std::vector<std::string>, std::vector<const char*>>
GetLibTpuInitArguments() {
// We make copies of the arguments returned by getenv because the memory
diff --git a/tensorflow/core/tpu/tpu_initializer_helper.h b/tensorflow/core/tpu/tpu_initializer_helper.h
index cd9b419..3abad8b 100644
--- a/tensorflow/core/tpu/tpu_initializer_helper.h
+++ b/tensorflow/core/tpu/tpu_initializer_helper.h
@@ -22,6 +22,11 @@
namespace tensorflow {
namespace tpu {
+// This will acquire a system-wide lock on behalf of the whole process.
+// Follow-up calls to this function will return true if the lock has been
+// acquired and false if we failed to acquire the lock.
+bool TryAcquireTpuLock();
+
// Returns arguments (e.g. flags) set in the LIBTPU_INIT_ARGS environment
// variable. The first return value is the arguments, the second return value is
// pointers to the arguments suitable for passing into the C API.