| import os |
| import glob |
| import ctypes.util |
| |
| from . import which |
| from .env import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_negative_env_flag |
| |
| LINUX_HOME = '/usr/local/cuda' |
| WINDOWS_HOME = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') |
| |
| |
def find_nvcc():
    """Locate the directory that contains the ``nvcc`` compiler.

    Uses the package's ``which`` helper to look up ``nvcc``.

    Returns:
        The directory portion of the path ``which`` reports for ``nvcc``, or
        ``None`` when ``which`` does not find it.
    """
    nvcc_path = which('nvcc')
    if nvcc_path is None:
        return None
    return os.path.dirname(nvcc_path)
| |
| |
# Resolve CUDA_HOME (the CUDA toolkit root) and USE_CUDA. Order of precedence:
#   1. USE_CUDA explicitly disabled via the environment -> no CUDA at all.
#   2. On Linux/macOS: $CUDA_HOME, defaulting to LINUX_HOME.
#      On any other platform (i.e. Windows): $CUDA_PATH, else WINDOWS_HOME[0].
#   3. If that directory does not exist, infer the root from the location of a
#      CUDA artifact: nvcc on Linux/Windows, the cudart library elsewhere.
# On Windows every branch yields a forward-slash path.
if check_negative_env_flag('USE_CUDA'):
    USE_CUDA = False
    CUDA_HOME = None
else:
    if IS_LINUX or IS_DARWIN:
        CUDA_HOME = os.getenv('CUDA_HOME', LINUX_HOME)
    else:
        CUDA_HOME = os.getenv('CUDA_PATH', '').replace('\\', '/')
        if CUDA_HOME == '' and WINDOWS_HOME:
            CUDA_HOME = WINDOWS_HOME[0].replace('\\', '/')
    if not os.path.exists(CUDA_HOME):
        # Use the nvcc location on Linux and Windows, and the cudart library
        # location on macOS. Per the ctypes docs, find_library returns only a
        # bare file name on Linux but a full path on macOS, so only there can
        # the toolkit root be derived from it.
        if IS_LINUX or IS_WINDOWS:
            cuda_path = find_nvcc()
        else:
            cudart_path = ctypes.util.find_library('cudart')
            if cudart_path is not None:
                cuda_path = os.path.dirname(cudart_path)
            else:
                cuda_path = None
        if cuda_path is not None:
            # cuda_path is the directory holding the artifact (typically
            # <root>/bin or <root>/lib); its parent is taken as the root.
            CUDA_HOME = os.path.dirname(cuda_path)
            if IS_WINDOWS:
                # Match the forward-slash form produced by the CUDA_PATH and
                # WINDOWS_HOME branches above.
                CUDA_HOME = CUDA_HOME.replace('\\', '/')
        else:
            CUDA_HOME = None
    USE_CUDA = CUDA_HOME is not None