Skip to content

Commit

Permalink
fix: handle CUDA arch major >= 10 (facebookresearch#766)
Browse files Browse the repository at this point in the history
* fix: handle CUDA arch major >= 10

* Fix: handle +<foo> in version string
  • Loading branch information
BMDan authored Jun 22, 2023
1 parent 6e88cf8 commit 6f0602f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
for arch in archs_list.replace(" ", ";").split(";"):
assert len(arch) >= 3, f"Invalid sm version: {arch}"

num = 10 * int(arch[0]) + int(arch[2])
arch_arr = arch.split('.')
num = 10 * int(arch_arr[0]) + int(arch_arr[1].partition("+")[0])
# Need at least 7.5
if num < 75:
continue
Expand Down

0 comments on commit 6f0602f

Please sign in to comment.