blob: 5b7432e885867686916a95b1b07dffa565ba5bd8 [file] [log] [blame]
From fd6e7a61a16a17fa155cbd717de0c79001af71e6 Mon Sep 17 00:00:00 2001
From: Artem Belevich <tra@google.com>
Date: Mon, 23 Sep 2019 11:18:56 -0700
Subject: [PATCH] Fix CUDA version detection in CUB
This fixes the problem with CUB using deprecated shfl/vote instructions when CUB
is compiled with clang (e.g. some TensorFlow builds).
---
cub/util_arch.cuh | 3 ++-
cub/util_type.cuh | 4 ++--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/cub/util_arch.cuh b/cub/util_arch.cuh
index 87c5ea2fb..9ad9d1cbb 100644
--- a/cub/util_arch.cuh
+++ b/cub/util_arch.cuh
@@ -44,7 +44,8 @@ namespace cub {
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
-#if (__CUDACC_VER_MAJOR__ >= 9) && !defined(CUB_USE_COOPERATIVE_GROUPS)
+#if !defined(CUB_USE_COOPERATIVE_GROUPS) && \
+ (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
#define CUB_USE_COOPERATIVE_GROUPS
#endif
diff --git a/cub/util_type.cuh b/cub/util_type.cuh
index 0ba41e1ed..b2433d735 100644
--- a/cub/util_type.cuh
+++ b/cub/util_type.cuh
@@ -37,7 +37,7 @@
#include <limits>
#include <cfloat>
-#if (__CUDACC_VER_MAJOR__ >= 9)
+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
#include <cuda_fp16.h>
#endif
@@ -1063,7 +1063,7 @@ struct FpLimits<double>
};
-#if (__CUDACC_VER_MAJOR__ >= 9)
+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
template <>
struct FpLimits<__half>
{