blob: ced16e0e989394010382e32a0f288b7038efc432 [file] [log] [blame]
import os
import platform
import ctypes.util
from subprocess import Popen, PIPE
from .env import check_env_flag
def find_nvcc():
proc = Popen(['which', 'nvcc'], stdout=PIPE, stderr=PIPE)
out, err = proc.communicate()
out = out.decode().strip()
if len(out) > 0:
return os.path.dirname(out)
else:
return None
if check_env_flag('NO_CUDA'):
WITH_CUDA = False
CUDA_HOME = None
else:
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
if not os.path.exists(CUDA_HOME):
# We use nvcc path on Linux and cudart path on macOS
osname = platform.system()
if osname == 'Linux':
cuda_path = find_nvcc()
else:
cudart_path = ctypes.util.find_library('cudart')
if cudart_path is not None:
cuda_path = os.path.dirname(cudart_path)
else:
cuda_path = None
if cuda_path is not None:
CUDA_HOME = os.path.dirname(cuda_path)
else:
CUDA_HOME = None
WITH_CUDA = CUDA_HOME is not None