# blob: 4723c9ea54e91fed7ca2a2131b4e14590b82a596 [file] [log] [blame]
import platform
import glob
import os
import sys
from itertools import chain
from .env import check_env_flag
def gather_paths(env_vars):
    """Collect the path entries stored in several environment variables.

    Every variable named in ``env_vars`` is read with ``os.getenv`` and split
    on ``os.pathsep`` (the platform's search-path separator, ``':'`` on
    POSIX), and the pieces are concatenated in the order the variables are
    listed.

    An unset or empty variable contributes a single empty string, so callers
    that need usable directories should drop falsy entries themselves.

    Args:
        env_vars: iterable of environment variable names.

    Returns:
        A flat list of path strings.
    """
    return list(chain.from_iterable(
        os.getenv(name, '').split(os.pathsep) for name in env_vars))
# Host platform and interpreter flavour, derived from the running Python.
IS_LINUX = platform.system() == 'Linux'
IS_CONDA = any(tag in sys.version for tag in ('conda', 'Continuum'))
# Parent of the directory holding the interpreter executable (left
# un-normalised, i.e. it still ends in the parent-directory component).
CONDA_DIR = os.path.join(os.path.dirname(sys.executable), os.pardir)

# Root directory searched for MKL-DNN; the MKLDNN_HOME environment variable
# takes precedence over the built-in default.
MKLDNN_HOME = os.environ.get('MKLDNN_HOME', '/usr/local/mkl-dnn')

# Defaults describing "MKL-DNN not available".
WITH_MKLDNN = False
MKLDNN_LIBRARY = MKLDNN_LIB_DIR = MKLDNN_INCLUDE_DIR = None
# Probe for an MKL-DNN installation. The probe only runs on Linux and is
# skipped when the NO_MKLDNN environment flag is set; otherwise the values
# assigned before this block stay in effect.
if IS_LINUX and not check_env_flag('NO_MKLDNN'):
    # Candidate library directories, highest priority first: the explicit
    # MKLDNN_LIB_DIR override, the MKLDNN_HOME tree, the system directories,
    # then the entries of LIBRARY_PATH and LD_LIBRARY_PATH. filter(bool)
    # drops an unset override (None) and empty strings from unset variables.
    lib_paths = list(filter(bool, [
        os.getenv('MKLDNN_LIB_DIR'),
        os.path.join(MKLDNN_HOME, 'lib'),
        os.path.join(MKLDNN_HOME, 'lib64'),
        '/usr/lib/',
        '/usr/lib64/',
    ] + gather_paths([
        'LIBRARY_PATH',
        'LD_LIBRARY_PATH',
    ])))
    # Candidate header directories, built the same way.
    include_paths = list(filter(bool, [
        os.getenv('MKLDNN_INCLUDE_DIR'),
        os.path.join(MKLDNN_HOME, 'include'),
        '/usr/include/',
    ] + gather_paths([
        'CPATH',
        'C_INCLUDE_PATH',
        'CPLUS_INCLUDE_PATH',
    ])))
    # The conda prefix is consulted last.
    if IS_CONDA:
        lib_paths.append(os.path.join(CONDA_DIR, 'lib'))
        include_paths.append(os.path.join(CONDA_DIR, 'include'))

    # First directory containing a libmkldnn* file wins. glob yields nothing
    # for a missing directory, so no separate existence check is required.
    for path in lib_paths:
        libraries = sorted(glob.glob(os.path.join(path, 'libmkldnn*')))
        if not libraries:
            continue
        # Warn when the MKL small library is not sitting next to MKL-DNN.
        if not glob.glob(os.path.join(path, 'libmklml_intel*')):
            print("WARNING: MKL-DNN is not compiled with Intel MKL small library")
            print("Convolution performance might be suboptimal")
            print("Refer https://github.com/01org/mkl-dnn for detail info")
        MKLDNN_LIBRARY = libraries[0]
        MKLDNN_LIB_DIR = path
        break

    # First directory containing mkldnn.hpp wins.
    for path in include_paths:
        if os.path.exists(os.path.join(path, 'mkldnn.hpp')):
            MKLDNN_INCLUDE_DIR = path
            break

    # An existing file named by MKLDNN_LIBRARY overrides whatever the search
    # found, and the library directory is re-derived from it.
    library = os.getenv('MKLDNN_LIBRARY')
    if library is not None and os.path.exists(library):
        MKLDNN_LIBRARY = library
        # A bare file name has an empty dirname, which would be treated as
        # "not found" below; fall back to the current directory instead.
        MKLDNN_LIB_DIR = os.path.dirname(library) or os.curdir

    if not all([MKLDNN_LIBRARY, MKLDNN_LIB_DIR, MKLDNN_INCLUDE_DIR]):
        # Partial detection is useless: reset everything to "not found".
        MKLDNN_LIBRARY = MKLDNN_LIB_DIR = MKLDNN_INCLUDE_DIR = None
    else:
        real_mkldnn_library = os.path.realpath(MKLDNN_LIBRARY)
        real_mkldnn_lib_dir = os.path.realpath(MKLDNN_LIB_DIR)
        # Raised explicitly (rather than via ``assert``) so the consistency
        # check still runs when Python is started with -O.
        if os.path.dirname(real_mkldnn_library) != real_mkldnn_lib_dir:
            raise AssertionError('mkldnn library and lib_dir must agree')
        WITH_MKLDNN = True