4 from .env
import IS_WINDOWS, IS_CONDA, CONDA_DIR, check_negative_env_flag, gather_paths, lib_paths_from_base
5 from .cuda
import USE_CUDA, CUDA_HOME
10 CUDNN_INCLUDE_DIR =
None 12 WITH_STATIC_CUDNN = os.getenv(
"USE_STATIC_CUDNN")
14 if USE_CUDA
and not check_negative_env_flag(
'USE_CUDNN'):
15 lib_paths = list(filter(bool, [
16 os.getenv(
'CUDNN_LIB_DIR')
17 ] + lib_paths_from_base(CUDA_HOME) + [
18 '/usr/lib/x86_64-linux-gnu/',
19 '/usr/lib/powerpc64le-linux-gnu/',
20 '/usr/lib/aarch64-linux-gnu/',
26 include_paths = list(filter(bool, [
27 os.getenv(
'CUDNN_INCLUDE_DIR'),
28 os.path.join(CUDA_HOME,
'include'),
37 lib_paths.append(os.path.join(CONDA_DIR,
'lib'))
38 include_paths.append(os.path.join(CONDA_DIR,
'include'))
39 for path
in include_paths:
40 if path
is None or not os.path.exists(path):
42 include_file_path = os.path.join(path,
'cudnn.h')
43 CUDNN_INCLUDE_VERSION =
None 44 if os.path.exists(include_file_path):
45 CUDNN_INCLUDE_DIR = path
46 with open(include_file_path)
as f:
48 if "#define CUDNN_MAJOR" in line:
49 CUDNN_INCLUDE_VERSION = int(line.split()[-1])
51 if CUDNN_INCLUDE_VERSION
is None:
52 raise AssertionError(
"Could not find #define CUDNN_MAJOR in " + include_file_path)
55 if CUDNN_INCLUDE_VERSION
is None:
59 if CUDNN_INCLUDE_DIR
is not None:
60 cudnn_path = os.path.join(os.path.dirname(CUDNN_INCLUDE_DIR))
61 cudnn_lib_paths = lib_paths_from_base(cudnn_path)
62 lib_paths.extend(cudnn_lib_paths)
64 for path
in lib_paths:
65 if path
is None or not os.path.exists(path):
68 library = os.path.join(path,
'cudnn.lib')
69 if os.path.exists(library):
70 CUDNN_LIBRARY = library
74 if WITH_STATIC_CUDNN
is not None:
75 search_name =
'libcudnn_static.a' 77 search_name =
'libcudnn*' + str(CUDNN_INCLUDE_VERSION) +
"*" 78 libraries = sorted(glob.glob(os.path.join(path, search_name)))
80 CUDNN_LIBRARY = libraries[0]
84 library = os.getenv(
'CUDNN_LIBRARY')
85 if library
is not None and os.path.exists(library):
86 CUDNN_LIBRARY = library
87 CUDNN_LIB_DIR = os.path.dirname(CUDNN_LIBRARY)
89 if not all([CUDNN_LIBRARY, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR]):
90 CUDNN_LIBRARY = CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR =
None 92 real_cudnn_library = os.path.realpath(CUDNN_LIBRARY)
93 real_cudnn_lib_dir = os.path.realpath(CUDNN_LIB_DIR)
94 assert os.path.dirname(real_cudnn_library) == real_cudnn_lib_dir, (
95 'cudnn library and lib_dir must agree')