Skip to content

Commit

Permalink
Editing setup.py to locate nvcc and detect cuda major version more st…
Browse files Browse the repository at this point in the history
…rictly
  • Loading branch information
definitelynotmcarilli committed May 2, 2018
1 parent b0d7d60 commit 5dfa4c3
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,18 @@ def findcuda():
raise RuntimeError("Error: Could not find cuda on this system."+
" Please set your CUDA_HOME enviornment variable to the CUDA base directory.")

NVCC = find(CUDA_HOME, re.compile('nvcc').search)
CUDA_LIB = find(CUDA_HOME, re.compile('libcudart.so.*.*.*').search)
NVCC = find(CUDA_HOME+os.sep+"bin",
re.compile('nvcc$').search)
print("Found NVCC = ", NVCC)

# Parse output of nvcc to get cuda major version
nvcc_output = subprocess.check_output([NVCC, '--version']).decode("utf-8")
CUDA_LIB = re.compile(', V[0-9]+\.[0-9]+\.[0-9]+').search(nvcc_output).group(0).split('V')[1]
print("Found CUDA_LIB = ", CUDA_LIB)

if CUDA_LIB:
try:
CUDA_VERSION = int(CUDA_LIB.split('.')[2])
CUDA_VERSION = int(CUDA_LIB.split('.')[0])
except (ValueError, TypeError):
CUDA_VERSION = 9
else:
Expand Down

0 comments on commit 5dfa4c3

Please sign in to comment.