5 from subprocess
import Popen, PIPE
7 from .env
import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_env_flag, check_negative_env_flag
# Default CUDA toolkit locations consulted by the detection code below.
# NOTE(review): this file appears to be a corrupted extract -- original line
# numbers are fused into the text and many lines are missing, so it is not
# valid Python as-is. Restore it from version control before relying on it.
#
# LINUX_HOME: default for the CUDA_HOME environment variable on the
#   IS_LINUX / IS_DARWIN path.
# WINDOWS_HOME: every directory matching the versioned toolkit pattern under
#   the standard Windows install prefix; its first entry is used when the
#   CUDA_PATH environment variable is empty or unset.
# NOTE(review): glob.glob does not guarantee sorted results, so
#   WINDOWS_HOME[0] is not necessarily the newest installed toolkit. If that
#   matters, sort with a numeric version key (a plain string sort puts
#   "v10.x" before "v9.x").
9 LINUX_HOME =
'/usr/local/cuda' 10 WINDOWS_HOME = glob.glob(
'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
# Fragment of what appears to be find_nvcc(), the helper the module-level
# detection code calls to locate CUDA via the nvcc compiler.
# NOTE(review): its `def` line and several control-flow lines are missing
# from this excerpt: the condition choosing between the two commands below,
# the condition guarding the CRLF/'..' handling, and any "not found" branch.
#
# Ask the OS where nvcc lives: `where nvcc.exe` or `which nvcc`
# (presumably selected by IS_WINDOWS -- confirm when restoring).
15 proc = Popen([
'where',
'nvcc.exe'], stdout=PIPE, stderr=PIPE)
17 proc = Popen([
'which',
'nvcc'], stdout=PIPE, stderr=PIPE)
# Only stdout is inspected in the visible lines; it is decoded and trimmed.
18 out, err = proc.communicate()
19 out = out.decode().strip()
# If the output spans several CRLF-separated lines (multiple matches),
# keep only the first path.
22 if out.find(
'\r\n') != -1:
23 out = out.split(
'\r\n')[0]
# Step up from nvcc's directory to its parent and normalise separators to '/'.
24 out = os.path.abspath(os.path.join(os.path.dirname(out),
".."))
25 out = out.replace(
'\\',
'/')
# NOTE(review): the module-level caller applies one more os.path.dirname()
# to this result to obtain CUDA_HOME. If the '..' step above and this
# dirname both apply to the same path, the final CUDA_HOME is two
# directories above the toolkit root (the directory containing bin/).
# Without the '..' step, .../bin/nvcc -> .../bin -> caller -> toolkit root.
# Verify this against the restored control flow.
# NOTE(review): if the command prints nothing, `out` is '' and this returns
# '' (not None), unless a guard missing from this excerpt handles it.
27 return os.path.dirname(out)
# Infer a CUDA version string (MAJOR.MINOR.PATCH, or MAJOR.MINOR as a
# fallback) for the toolkit rooted at cuda_home.
# NOTE(review): several lines of this function are missing from this
# excerpt, so its exact return values cannot be confirmed here. The visibly
# absent lines are:
#   - the condition choosing between the two ways candidate_names is built;
#   - the body following `if os.path.exists(...)`;
#   - the statements following each `if len(candidates) > 0:`.
# NOTE(review): the module-level code can pass None here (it later computes
# USE_CUDA = CUDA_HOME is not None). os.path.basename(None) and
# os.path.join(None, ...) raise TypeError unless an early guard not visible
# here returns first.
32 def find_cuda_version(cuda_home):
# Candidate source 1: the toolkit directory's own name (e.g. a "vX.Y" dir).
36 candidate_names = [os.path.basename(cuda_home)]
# Candidate source 2: file names of cudart libraries found in a lib dir,
# probing lib64 then lib (presumably stopping at the first that exists).
39 cuda_lib_dirs = [
'lib64',
'lib']
40 for lib_dir
in cuda_lib_dirs:
41 cuda_lib_path = os.path.join(cuda_home, lib_dir)
42 if os.path.exists(cuda_lib_path):
# NOTE(review): if neither lib64 nor lib exists, cuda_lib_path is left as
# the non-existent <cuda_home>/lib and the glob below finds nothing.
# glob order is filesystem-dependent. If several cudart files carry
# different versions, which one ends up first is not deterministic.
46 candidate_names = list(glob.glob(os.path.join(cuda_lib_path,
'*cudart*')))
47 candidate_names = [os.path.basename(c)
for c
in candidate_names]
# Look for a MAJOR.MINOR.PATCH pattern anywhere in each candidate name and
# collect the first match from each name that has one.
50 version_regex = re.compile(
r'[0-9]+\.[0-9]+\.[0-9]+')
51 candidates = [c.group()
for c
in map(version_regex.search, candidate_names)
if c]
52 if len(candidates) > 0:
# Fallback: a looser MAJOR.MINOR pattern (presumably reached only when no
# three-part version was found above).
56 version_regex = re.compile(
r'[0-9]+\.[0-9]+')
57 candidates = [c.group()
for c
in map(version_regex.search, candidate_names)
if c]
58 if len(candidates) > 0:
# Module-level CUDA detection: resolves CUDA_HOME, then CUDA_VERSION and
# USE_CUDA from it.
# NOTE(review): lines are missing here too. The body run when CUDA is
# disabled, the `else:` branches, and any assignments of None on the
# "not found" paths are not visible in this excerpt.
#
# Branch taken when USE_CUDA is explicitly negated or USE_ROCM is set; its
# body is not in this excerpt.
61 if check_negative_env_flag(
'USE_CUDA')
or check_env_flag(
'USE_ROCM'):
# Linux/macOS: honour $CUDA_HOME, defaulting to LINUX_HOME.
66 if IS_LINUX
or IS_DARWIN:
67 CUDA_HOME = os.getenv(
'CUDA_HOME', LINUX_HOME)
# Otherwise (presumably Windows): honour $CUDA_PATH with '/' separators.
# If it is empty, fall back to the first globbed default install dir.
69 CUDA_HOME = os.getenv(
'CUDA_PATH',
'').replace(
'\\',
'/')
70 if CUDA_HOME ==
'' and len(WINDOWS_HOME) > 0:
71 CUDA_HOME = WINDOWS_HOME[0].replace(
'\\',
'/')
# If the configured path does not exist, derive CUDA_HOME another way:
#   - on Linux/Windows, from the nvcc location;
#   - otherwise (presumably macOS), from the directory of the cudart shared
#     library.
72 if not os.path.exists(CUDA_HOME):
74 if IS_LINUX
or IS_WINDOWS:
75 cuda_path = find_nvcc()
77 cudart_path = ctypes.util.find_library(
'cudart')
78 if cudart_path
is not None:
79 cuda_path = os.path.dirname(cudart_path)
# NOTE(review): on the find_library path, cuda_path is assigned only when
# cudart is found. Unless a missing else-branch sets it to None, this check
# raises NameError when cudart is absent.
82 if cuda_path
is not None:
83 CUDA_HOME = os.path.dirname(cuda_path)
# NOTE(review): USE_CUDA depends only on CUDA_HOME being non-None. Suppose
# the fallbacks above fail and no (missing) branch resets CUDA_HOME to None.
# Then CUDA_HOME still holds the non-existent configured path, and USE_CUDA
# is True. Verify against the restored source.
86 CUDA_VERSION = find_cuda_version(CUDA_HOME)
87 USE_CUDA = CUDA_HOME
is not None