diff --git a/install.bash b/install.bash index 8a241cb..374ae66 100644 --- a/install.bash +++ b/install.bash @@ -24,16 +24,17 @@ fi echo "Installing torch & xformers..." -cuda_version_line=$(nvcc --version | grep 'release') -cuda_version=$(echo $cuda_version_line | sed -n -e 's/^.*release \([0-9]\+\.[0-9]\+\),.*$/\1/p') +cuda_version=$(nvcc --version | grep 'release' | sed -n -e 's/^.*release \([0-9]\+\.[0-9]\+\),.*$/\1/p') +cuda_major_version=$(echo "$cuda_version" | awk -F'.' '{print $1}') +cuda_minor_version=$(echo "$cuda_version" | awk -F'.' '{print $2}') echo "Cuda Version:$cuda_version" -if [[ $cuda_version == "11.8" ]]; then +if (( cuda_major_version >= 12 )) || (( cuda_major_version == 11 && cuda_minor_version >= 8 )); then echo "install torch 2.0.0+cu118" pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 pip install xformers==0.0.19 -elif [[ $cuda_version == "11.6" ]]; then +elif (( cuda_major_version == 11 && cuda_minor_version == 6 )); then echo "install torch 1.12.1+cu116" pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 pip install --upgrade git+https://github.com/facebookresearch/xformers.git@0bad001ddd56c080524d37c84ff58d9cd030ebfd