blob: 79e3ad253602d89abb5416b3ca868ab589f0a0c7 [file] [log] [blame]
import os
import glob
from itertools import chain
from .env import check_env_flag
from .cuda import WITH_CUDA, CUDA_HOME
def gather_paths(env_vars):
return list(chain(*(os.getenv(v, '').split(':') for v in env_vars)))
WITH_CUDNN = False
CUDNN_LIB_DIR = None
CUDNN_INCLUDE_DIR = None
if WITH_CUDA and not check_env_flag('NO_CUDNN'):
lib_paths = list(filter(bool, [
os.getenv('CUDNN_LIB_DIR'),
os.path.join(CUDA_HOME, 'lib'),
os.path.join(CUDA_HOME, 'lib64'),
'/usr/lib/x86_64-linux-gnu/',
'/usr/lib/powerpc64le-linux-gnu/',
] + gather_paths([
'LIBRARY_PATH',
])))
include_paths = list(filter(bool, [
os.getenv('CUDNN_INCLUDE_DIR'),
os.path.join(CUDA_HOME, 'include'),
'/usr/include/',
] + gather_paths([
'CPATH',
'C_INCLUDE_PATH',
'CPLUS_INCLUDE_PATH',
])))
for path in lib_paths:
if path is None or not os.path.exists(path):
continue
if glob.glob(os.path.join(path, 'libcudnn*')):
CUDNN_LIB_DIR = path
break
for path in include_paths:
if path is None or not os.path.exists(path):
continue
if os.path.exists((os.path.join(path, 'cudnn.h'))):
CUDNN_INCLUDE_DIR = path
break
if not CUDNN_LIB_DIR or not CUDNN_INCLUDE_DIR:
CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
else:
WITH_CUDNN = True