Disable the driver linking API for specific CUDA 11 and driver versions.
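
With a CUDA 11.3+ toolkit, LinkGpuAsm now only attempts cuLink-based linking
when the kernel driver is at version 465 or newer, and returns Unimplemented
otherwise. When linking fails, GpuCompiler logs the
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 workaround.
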
PiperOrigin-RevId: 376826989
Change-Id: If56eac421657a1eaef6f35701093736e77edc3be
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 8508b39..527c0b0 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -873,11 +873,19 @@
submodule_compile_results.push_back(result.second);
}
- TF_ASSIGN_OR_RETURN(
- std::vector<uint8> backend_result,
- this->LinkModules(stream_exec, std::move(submodule_compile_results)));
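+ // Linking can fail on certain CUDA/driver combinations; handle the status
+ // manually so the error can point users at a workaround.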
+ auto maybe_backend_result =
+ this->LinkModules(stream_exec, std::move(submodule_compile_results));
+ if (!maybe_backend_result.ok()) {
+ LOG(ERROR) << "The CUDA linking API did not work. Please use "
+ "XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to "
+ "bypass it, but expect to get longer compilation time due to "
+ "the lack of multi-threading.";
+ return maybe_backend_result.status();
+ }
- return std::make_pair(ptx_snippets, backend_result);
+ return std::make_pair(ptx_snippets, std::move(*maybe_backend_result));
}
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
diff --git a/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
index f92d3c4..dc62326 100644
--- a/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
+++ b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
namespace stream_executor {
@@ -34,14 +35,42 @@
port::StatusOr<std::vector<uint8>> LinkGpuAsm(
gpu::GpuContext* context, std::vector<CubinOrPTXImage> images) {
+ const bool linking_supported = [] {
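+ // Toolkits before CUDA 11.3 are not affected; keep linking enabled there.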
+ if (CUDA_VERSION < 11030) {
+ return true;
+ }
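+ // With CUDA 11.3+, the cuLink API requires a kernel driver from the
+ // matching release series, i.e. version 465 or newer.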
+ auto version_or_status = gpu::Diagnostician::FindKernelDriverVersion();
+ if (!version_or_status.ok()) {
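+ // Be conservative: treat an unknown driver version as unsupported.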
+ LOG(WARNING) << "Couldn't read CUDA driver version.";
+ return false;
+ }
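+ // DriverVersion is a (major, minor, patch) tuple; only the major part matters.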
+ return std::get<0>(*version_or_status) >= 465;
+ }();
+
+ if (!linking_supported) {
+ return tensorflow::errors::Unimplemented("Linking is unsupported");
+ }
+
gpu::ScopedActivateContext activation(context);
CUlinkState link_state;
RETURN_IF_CUDA_ERROR(cuLinkCreate(0, nullptr, nullptr, &link_state));
for (auto& image : images) {
- RETURN_IF_CUDA_ERROR(cuLinkAddData(
- link_state, CU_JIT_INPUT_CUBIN, static_cast<void*>(image.bytes.data()),
- image.bytes.size(), "", 0, nullptr, nullptr));
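+ // Capture the raw CUresult so a failure can be logged with a likely cause
+ // before it is returned.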
+ auto status = cuLinkAddData(link_state, CU_JIT_INPUT_CUBIN,
+ static_cast<void*>(image.bytes.data()),
+ image.bytes.size(), "", 0, nullptr, nullptr);
+ if (status != CUDA_SUCCESS) {
+ LOG(ERROR) << "cuLinkAddData fails. This is usually caused by stale "
+ "driver version.";
+ }
+ RETURN_IF_CUDA_ERROR(status);
}
void* cubin_out;
size_t cubin_size;