Skip to content

Commit

Permalink
Add cudann install to menu
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Jun 5, 2023
1 parent f9a7d53 commit 1179780
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 97 deletions.
2 changes: 1 addition & 1 deletion setup.bat
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ IF NOT EXIST venv (
)

:: Create the directory if it doesn't exist
mkdir ".\logs\status" > nul 2>&1
mkdir ".\logs\setup" > nul 2>&1

:: Deactivate the virtual environment
call .\venv\Scripts\deactivate.bat
Expand Down
78 changes: 0 additions & 78 deletions tools/cudann_1.8_install.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,9 @@
import filecmp
import importlib.util
import os
import shutil
import sys
import sysconfig
import subprocess
from pathlib import Path
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata

req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../requirements.txt")

def run(command, desc=None, errdesc=None, custom_env=None):
if desc is not None:
print(desc)

result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)

if result.returncode != 0:

message = f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
"""
raise RuntimeError(message)

return result.stdout.decode(encoding="utf8", errors="ignore")

def check_versions():
global req_file
reqs = open(req_file, 'r')
lines = reqs.readlines()
reqs_dict = {}
for line in lines:
splits = line.split("==")
if len(splits) == 2:
key = splits[0]
if "torch" not in key:
if "diffusers" in key:
key = "diffusers"
reqs_dict[key] = splits[1].replace("\n", "").strip()
if os.name == "nt":
reqs_dict["torch"] = "1.12.1+cu116"
reqs_dict["torchvision"] = "0.13.1+cu116"

checks = ["xformers","bitsandbytes", "diffusers", "transformers", "torch", "torchvision"]
for check in checks:
check_ver = "N/A"
status = "[ ]"
try:
check_available = importlib.util.find_spec(check) is not None
if check_available:
check_ver = importlib_metadata.version(check)
if check in reqs_dict:
req_version = reqs_dict[check]
if str(check_ver) == str(req_version):
status = "[+]"
else:
status = "[!]"
except importlib_metadata.PackageNotFoundError:
check_available = False
if not check_available:
status = "[!]"
print(f"{status} {check} NOT installed.")
if check == 'xformers':
x_cmd = "-U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl"
print(f"Installing xformers with: pip install {x_cmd}")
run(f"pip install {x_cmd}", desc="Installing xformers")

else:
print(f"{status} {check} version {check_ver} installed.")

base_dir = os.path.dirname(os.path.realpath(__file__))
#repo = git.Repo(base_dir)
#revision = repo.rev_parse("HEAD")
#print(f"Dreambooth revision is {revision}")
check_versions()
# Check for "different" B&B Files and copy only if necessary
if os.name == "nt":
python = sys.executable
Expand All @@ -100,5 +24,3 @@ def check_versions():
print("Copied CUDNN 8.6 files to destination")
else:
print(f"Installation Failed: \"{cudnn_src}\" could not be found. ")


55 changes: 37 additions & 18 deletions tools/setup_windows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import os
import sys
import filecmp
import logging
import shutil
import sysconfig
Expand Down Expand Up @@ -134,6 +135,24 @@ def check_torch():
log.error(f'Could not load torch: {e}')
sys.exit(1)

def cudann_install():
cudnn_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..\cudnn_windows")
cudnn_dest = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib")

log.info(f"Checking for CUDNN files in {cudnn_dest}...")
if os.path.exists(cudnn_src):
if os.path.exists(cudnn_dest):
# check for different files
filecmp.clear_cache()
for file in os.listdir(cudnn_src):
src_file = os.path.join(cudnn_src, file)
dest_file = os.path.join(cudnn_dest, file)
#if dest file exists, check if it's different
if os.path.exists(dest_file):
shutil.copy2(src_file, cudnn_dest)
log.info("Copied CUDNN 8.6 files to destination")
else:
log.error(f"Installation Failed: \"{cudnn_src}\" could not be found. ")

def pip(
arg: str,
Expand Down Expand Up @@ -442,11 +461,7 @@ def install_kohya_ss_torch1():
reinstall=reinstall,
)
install_requirements('requirements_windows_torch1.txt')
delete_file('./logs/status/torch_version')
write_to_file('./logs/status/torch_version', '1')

sync_bits_and_bytes_files()

run_cmd(f'accelerate config')


Expand All @@ -471,11 +486,7 @@ def install_kohya_ss_torch2():
)
install_requirements('requirements_windows_torch2.txt')
# install('https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/triton-2.0.0-cp310-cp310-win_amd64.whl', 'triton', reinstall=reinstall)
delete_file('./logs/status/torch_version')
write_to_file('./logs/status/torch_version', '2')

sync_bits_and_bytes_files()

run_cmd(f'accelerate config')


Expand All @@ -492,8 +503,8 @@ def main_menu():
while True:
print('\nKohya_ss GUI setup menu:\n')
print('0. Cleanup the venv')
print('1. Install kohya_ss gui [torch 1]')
print('2. Install kohya_ss gui [torch 2]')
print('1. Install kohya_ss gui')
print('2. Install cudann files')
print('3. Start Kohya_ss GUI in browser')
print('4. Quit')

Expand All @@ -509,15 +520,23 @@ def main_menu():
else:
print('Cleanup canceled.')
elif choice == '1':
print(
f'{YELLOW}Be patient, this can take quite some time to complete...\033[0m\n'
)
install_kohya_ss_torch1()
while True:
print('1. Torch 1')
print('2. Torch 2')
print('3. Cancel')
choice_torch = input('\nEnter your choice: ')
print('')

if choice_torch == 1:
install_kohya_ss_torch1()
elif choice_torch == '2':
install_kohya_ss_torch2()
elif choice_torch == '3':
break
else:
print('Invalid choice. Please enter a number between 1-3.')
elif choice == '2':
print(
f'{YELLOW}Be patient, this can take quite some time to complete...\033[0m\n'
)
install_kohya_ss_torch2()
cudann_install()
elif choice == '3':
subprocess.Popen('start cmd /c .\gui.bat --inbrowser', shell=True)
elif choice == '4':
Expand Down

0 comments on commit 1179780

Please sign in to comment.