Do not enable TensorRT if we cannot load the TensorRT shared libraries.
PiperOrigin-RevId: 263017759
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index 3de09b2..ac32500 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -503,7 +503,10 @@
srcs = ["utils/py_utils.cc"],
hdrs = ["utils/py_utils.h"],
copts = tf_copts(),
- deps = if_tensorrt([":tensorrt_lib"]),
+ deps = if_tensorrt([
+ ":tensorrt_lib",
+ "//tensorflow/stream_executor/platform:dso_loader",
+ ]),
)
tf_py_wrap_cc(
diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
index 008cabb..885f58c 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
@@ -16,6 +16,7 @@
#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h"
#if GOOGLE_CUDA && GOOGLE_TENSORRT
+#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInfer.h"
#endif
@@ -23,13 +24,16 @@
namespace tensorrt {
bool IsGoogleTensorRTEnabled() {
- // TODO(laigd): consider also checking if tensorrt shared libraries are
- // accessible. We can then direct users to this function to make sure they can
- // safely write code that uses tensorrt conditionally. E.g. if it does not
- // check for for tensorrt, and user mistakenly uses tensorrt, they will just
- // crash and burn.
#if GOOGLE_CUDA && GOOGLE_TENSORRT
- return true;
+ auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
+ if (!handle_or.ok()) {
+ LOG(WARNING) << "Cannot dlopen some TensorRT libraries. If you would like "
+ "to use Nvidia GPU with TensorRT, please make sure the "
+ "missing libraries mentioned above are installed properly.";
+ return false;
+ } else {
+ return true;
+ }
#else
return false;
#endif
diff --git a/tensorflow/stream_executor/platform/default/dlopen_checker.cc b/tensorflow/stream_executor/platform/default/dlopen_checker.cc
index 750c1f2..b55c9f5 100644
--- a/tensorflow/stream_executor/platform/default/dlopen_checker.cc
+++ b/tensorflow/stream_executor/platform/default/dlopen_checker.cc
@@ -20,7 +20,7 @@
namespace internal {
namespace DsoLoader {
-port::Status MaybeTryDlopenCUDALibraries() {
+port::Status TryDlopenCUDALibraries() {
auto cudart_status = GetCudaRuntimeDsoHandle();
auto cublas_status = GetCublasDsoHandle();
auto cufft_status = GetCufftDsoHandle();
@@ -39,7 +39,7 @@
}
}
-port::Status MaybeTryDlopenROCmLibraries() {
+port::Status TryDlopenROCmLibraries() {
auto rocblas_status = GetRocblasDsoHandle();
auto miopen_status = GetMiopenDsoHandle();
auto rocfft_status = GetRocfftDsoHandle();
@@ -55,14 +55,26 @@
port::Status MaybeTryDlopenGPULibraries() {
#if GOOGLE_CUDA
- return MaybeTryDlopenCUDALibraries();
+ return TryDlopenCUDALibraries();
#elif TENSORFLOW_USE_ROCM
- return MaybeTryDlopenROCmLibraries();
+ return TryDlopenROCmLibraries();
#else
LOG(INFO) << "Not built with GPU enabled. Skip GPU library dlopen check.";
return port::Status::OK();
#endif
}
+
+port::Status TryDlopenTensorRTLibraries() {
+ auto nvinfer_status = GetNvInferDsoHandle();
+ auto nvinferplugin_status = GetNvInferPluginDsoHandle();
+ if (!nvinfer_status.status().ok() || !nvinferplugin_status.status().ok()) {
+ return port::Status(port::error::INTERNAL,
+ absl::StrCat("Cannot dlopen all TensorRT libraries."));
+ } else {
+ return port::Status::OK();
+ }
+}
+
} // namespace DsoLoader
} // namespace internal
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h
index e9927a7..7eee2e6 100644
--- a/tensorflow/stream_executor/platform/default/dso_loader.h
+++ b/tensorflow/stream_executor/platform/default/dso_loader.h
@@ -57,6 +57,11 @@
// dynamically loaded. Error status is returned when any of the libraries cannot
// be dlopened.
port::Status MaybeTryDlopenGPULibraries();
+
+// The following method tries to dlopen all necessary TensorRT libraries when
+// these libraries should be dynamically loaded. Error status is returned when
+// any of the libraries cannot be dlopened.
+port::Status TryDlopenTensorRTLibraries();
} // namespace DsoLoader
// Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs