
Commit fae048c
Add support for new SDXL parameters
bmaltais committed Jul 3, 2023
1 parent fb0c3df commit fae048c
Showing 8 changed files with 306 additions and 60 deletions.
finetune_gui.py (49 changes: 26 additions & 23 deletions)
@@ -22,6 +22,7 @@
     update_my_data,
     check_if_model_exist,
     output_message,
+    SDXLParameters
 )
 from library.tensorboard_gui import (
     gradio_tensorboard,
@@ -130,6 +131,8 @@ def save_configuration(
     scale_v_pred_loss_like_noise_pred,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
+    sdxl_min_timestep,
+    sdxl_max_timestep,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -256,6 +259,8 @@ def open_configuration(
     scale_v_pred_loss_like_noise_pred,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
+    sdxl_min_timestep,
+    sdxl_max_timestep,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -370,6 +375,8 @@ def train_model(
     scale_v_pred_loss_like_noise_pred,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
+    sdxl_min_timestep,
+    sdxl_max_timestep,
 ):
     print_only_bool = True if print_only.get('label') == 'True' else False
     log.info(f'Start Finetuning...')
@@ -391,13 +398,13 @@ def train_model(
     #     )
     #     return
 
-    if optimizer == 'Adafactor' and lr_warmup != '0':
-        output_message(
-            msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
-            title='Warning',
-            headless=headless_bool,
-        )
-        lr_warmup = '0'
+    # if optimizer == 'Adafactor' and lr_warmup != '0':
+    #     output_message(
+    #         msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
+    #         title='Warning',
+    #         headless=headless_bool,
+    #     )
+    #     lr_warmup = '0'
 
     # create caption json file
     if generate_caption_database:
@@ -533,6 +540,12 @@ def train_model(
 
     if sdxl_no_half_vae:
         run_cmd += f' --no_half_vae'
+
+    if sdxl_min_timestep > 0:
+        run_cmd += f' --min_timestep={sdxl_min_timestep}'
+
+    if not sdxl_max_timestep == 1000:
+        run_cmd += f' --max_timestep={sdxl_max_timestep}'
 
     run_cmd += run_cmd_training(
         learning_rate=learning_rate,
@@ -794,20 +807,8 @@ def finetune_tab(headless=False):
         optimizer_args,
     ) = gradio_training(learning_rate_value='1e-5')
 
-    # SDXL parameters
-    with gr.Row(visible=False) as sdxl_row:
-        sdxl_cache_text_encoder_outputs = gr.Checkbox(
-            label='(SDXL) Cache text encoder outputs',
-            info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.',
-            value=False
-        )
-        sdxl_no_half_vae = gr.Checkbox(
-            label='(SDXL) No half VAE',
-            info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.',
-            value=False
-        )
-
-    sdxl_checkbox.change(lambda sdxl_checkbox: gr.Row.update(visible=sdxl_checkbox), inputs=[sdxl_checkbox], outputs=[sdxl_row])
+    # Add SDXL Parameters
+    sdxl_params = SDXLParameters(sdxl_checkbox)
 
     with gr.Row():
         dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
@@ -966,8 +967,10 @@ def finetune_tab(headless=False):
         use_wandb,
         wandb_api_key,
         scale_v_pred_loss_like_noise_pred,
-        sdxl_cache_text_encoder_outputs,
-        sdxl_no_half_vae,
+        sdxl_params.sdxl_cache_text_encoder_outputs,
+        sdxl_params.sdxl_no_half_vae,
+        sdxl_params.sdxl_min_timestep,
+        sdxl_params.sdxl_max_timestep,
     ]
 
     button_run.click(
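The two new sliders only emit flags when they differ from the sd-scripts defaults (0 and 1000), so configurations saved before this change keep producing the same command. A minimal sketch of the flag composition; the base command string below is a placeholder for illustration, not the exact string the GUI builds:

    # Placeholder base command, for illustration only.
    run_cmd = 'accelerate launch "./sdxl_train.py"'

    sdxl_min_timestep = 200  # example slider values
    sdxl_max_timestep = 800

    if sdxl_min_timestep > 0:  # default 0 adds no flag
        run_cmd += f' --min_timestep={sdxl_min_timestep}'
    if not sdxl_max_timestep == 1000:  # default 1000 adds no flag
        run_cmd += f' --max_timestep={sdxl_max_timestep}'

    print(run_cmd)
    # accelerate launch "./sdxl_train.py" --min_timestep=200 --max_timestep=800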
library/common_gui.py (44 changes: 44 additions & 0 deletions)
@@ -1413,3 +1413,47 @@ def verify_image_folder_pattern(folder_path):
 
     log.info(f'Valid image folder names found in: {folder_path}')
     return true_response
+
+### SDXL Parameters class
+class SDXLParameters:
+    def __init__(self, sdxl_checkbox):
+        self.sdxl_checkbox = sdxl_checkbox
+
+        with gr.Accordion(visible=False, open=True, label='SDXL Specific Parameters') as self.sdxl_row:
+            with gr.Row():
+                self.sdxl_cache_text_encoder_outputs = gr.Checkbox(
+                    label='Cache text encoder outputs',
+                    info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.',
+                    value=False,
+                )
+                self.sdxl_no_half_vae = gr.Checkbox(
+                    label='No half VAE',
+                    info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.',
+                    value=False
+                )
+                self.sdxl_min_timestep = gr.Slider(
+                    label='Min Timestep',
+                    value=0,
+                    step=1,
+                    minimum=0,
+                    maximum=1000,
+                    info='Train U-Net with different timesteps'
+                )
+                self.sdxl_max_timestep = gr.Slider(
+                    label='Max Timestep',
+                    value=1000,
+                    step=1,
+                    minimum=0,
+                    maximum=1000,
+                    info='Train U-Net with different timesteps',
+                )
+
+        # def timestep_minimum(value):
+        #     if value < 0:
+        #         value = 0
+        #     return gr.Number.update(value=value)
+
+        # self.sdxl_min_timestep.blur(timestep_minimum, inputs=[self.sdxl_min_timestep], outputs=[self.sdxl_min_timestep])
+        # self.sdxl_max_timestep.blur(timestep_minimum, inputs=[self.sdxl_max_timestep], outputs=[self.sdxl_max_timestep])
+
+        self.sdxl_checkbox.change(lambda sdxl_checkbox: gr.Accordion.update(visible=sdxl_checkbox), inputs=[self.sdxl_checkbox], outputs=[self.sdxl_row])
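SDXLParameters bundles the four SDXL-only widgets and their show/hide wiring so each tab can add them with a single call. A minimal usage sketch; the Blocks wrapper and checkbox label below are illustrative assumptions, while in the repo the class is instantiated inside the existing tab layouts:

    import gradio as gr
    from library.common_gui import SDXLParameters

    with gr.Blocks() as demo:
        # The accordion's visibility follows this checkbox via the
        # .change() handler registered inside SDXLParameters.
        sdxl_checkbox = gr.Checkbox(label='SDXL model', value=False)  # label is an assumption
        sdxl_params = SDXLParameters(sdxl_checkbox)

        # The individual components are then appended to the settings
        # list handed to the save/open/train handlers:
        settings_list = [
            sdxl_params.sdxl_cache_text_encoder_outputs,
            sdxl_params.sdxl_no_half_vae,
            sdxl_params.sdxl_min_timestep,
            sdxl_params.sdxl_max_timestep,
        ]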
lora_gui.py (48 changes: 26 additions & 22 deletions)
@@ -33,6 +33,7 @@
     check_if_model_exist,
     output_message,
     verify_image_folder_pattern,
+    SDXLParameters
 )
 from library.dreambooth_folder_creation_gui import (
     gradio_dreambooth_folder_creation_tab,
@@ -170,6 +171,8 @@ def save_configuration(
     module_dropout,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
+    sdxl_min_timestep,
+    sdxl_max_timestep,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -319,6 +322,8 @@ def open_configuration(
     module_dropout,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
+    sdxl_min_timestep,
+    sdxl_max_timestep,
     training_preset,
 ):
     # Get list of function parameters and values
@@ -485,6 +490,8 @@ def train_model(
     module_dropout,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
+    sdxl_min_timestep,
+    sdxl_max_timestep,
 ):
     print_only_bool = True if print_only.get('label') == 'True' else False
     log.info(f'Start training LoRA {LoRA_type} ...')
@@ -570,13 +577,13 @@
     ):
         return
 
-    if optimizer == 'Adafactor' and lr_warmup != '0':
-        output_message(
-            msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
-            title='Warning',
-            headless=headless_bool,
-        )
-        lr_warmup = '0'
+    # if optimizer == 'Adafactor' and lr_warmup != '0':
+    #     output_message(
+    #         msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
+    #         title='Warning',
+    #         headless=headless_bool,
+    #     )
+    #     lr_warmup = '0'
 
     # If string is empty set string to 0.
     if text_encoder_lr == '':
@@ -887,6 +894,12 @@ def train_model(
 
     if sdxl_no_half_vae:
         run_cmd += f' --no_half_vae'
+
+    if sdxl_min_timestep > 0:
+        run_cmd += f' --min_timestep={sdxl_min_timestep}'
+
+    if not sdxl_max_timestep == 1000:
+        run_cmd += f' --max_timestep={sdxl_max_timestep}'
 
     run_cmd += run_cmd_training(
         learning_rate=learning_rate,
@@ -1197,20 +1210,9 @@ def list_presets(path):
             value='0.0001',
             info='Optional',
         )
-        # SDXL parameters
-        with gr.Row(visible=False) as sdxl_row:
-            sdxl_cache_text_encoder_outputs = gr.Checkbox(
-                label='(SDXL) Cache text encoder outputs',
-                info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.',
-                value=False
-            )
-            sdxl_no_half_vae = gr.Checkbox(
-                label='(SDXL) No half VAE',
-                info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.',
-                value=False
-            )
 
-        sdxl_checkbox.change(lambda sdxl_checkbox: gr.Row.update(visible=sdxl_checkbox), inputs=[sdxl_checkbox], outputs=[sdxl_row])
+        # Add SDXL Parameters
+        sdxl_params = SDXLParameters(sdxl_checkbox)
 
         with gr.Row():
             factor = gr.Slider(
@@ -1717,8 +1719,10 @@ def update_LoRA_settings(LoRA_type):
         network_dropout,
         rank_dropout,
         module_dropout,
-        sdxl_cache_text_encoder_outputs,
-        sdxl_no_half_vae,
+        sdxl_params.sdxl_cache_text_encoder_outputs,
+        sdxl_params.sdxl_no_half_vae,
+        sdxl_params.sdxl_min_timestep,
+        sdxl_params.sdxl_max_timestep,
     ]
 
     button_open_config.click(
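For context on what the two flags control: diffusion training draws a random timestep per example, and restricting the range focuses the U-Net on noisier (high t) or cleaner (low t) steps. A hedged sketch of how such flags are typically consumed downstream; this is an assumption about sd-scripts' behavior, not code from this commit:

    import torch

    def sample_timesteps(batch_size: int, min_timestep: int = 0, max_timestep: int = 1000) -> torch.Tensor:
        # Uniform integer timesteps in [min_timestep, max_timestep).
        return torch.randint(min_timestep, max_timestep, (batch_size,), dtype=torch.long)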
presets/lora/ia3-sd15.json (5 changes: 0 additions & 5 deletions)
@@ -30,7 +30,6 @@
   "gradient_checkpointing": false,
   "keep_tokens": "0",
   "learning_rate": 1.0,
-  "logging_dir": "",
   "lora_network_weights": "",
   "lr_scheduler": "cosine",
   "lr_scheduler_num_cycles": "",
@@ -57,14 +56,11 @@
   "num_cpu_threads_per_process": 2,
   "optimizer": "Prodigy",
   "optimizer_args": "",
-  "output_dir": "",
-  "output_name": "",
   "persistent_data_loader_workers": false,
   "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
   "prior_loss_weight": 1.0,
   "random_crop": false,
   "rank_dropout": 0,
-  "reg_data_dir": "",
   "resume": "",
   "sample_every_n_epochs": 0,
   "sample_every_n_steps": 0,
@@ -84,7 +80,6 @@
   "stop_text_encoder_training": 0,
   "text_encoder_lr": 1.0,
   "train_batch_size": 1,
-  "train_data_dir": "",
   "train_on_input": true,
   "training_comment": "",
   "unet_lr": 1.0,
presets/lora/locon-dadaptation-sdxl.json (5 changes: 0 additions & 5 deletions)
@@ -30,7 +30,6 @@
   "gradient_checkpointing": false,
   "keep_tokens": "0",
   "learning_rate": 4e-07,
-  "logging_dir": "",
   "lora_network_weights": "",
   "lr_scheduler": "constant_with_warmup",
   "lr_scheduler_num_cycles": "",
@@ -57,14 +56,11 @@
   "num_cpu_threads_per_process": 2,
   "optimizer": "Adafactor",
   "optimizer_args": "scale_parameter=False relative_step=False warmup_init=False",
-  "output_dir": "",
-  "output_name": "",
   "persistent_data_loader_workers": false,
   "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
   "prior_loss_weight": 1.0,
   "random_crop": false,
   "rank_dropout": 0,
-  "reg_data_dir": "",
   "resume": "",
   "sample_every_n_epochs": 0,
   "sample_every_n_steps": 0,
@@ -85,7 +81,6 @@
   "stop_text_encoder_training": 0,
   "text_encoder_lr": 0.0,
   "train_batch_size": 1,
-  "train_data_dir": "",
   "train_on_input": true,
   "training_comment": "",
   "unet_lr": 4e-07,
presets/lora/loha-sd15.json (5 changes: 0 additions & 5 deletions)
@@ -30,7 +30,6 @@
   "gradient_checkpointing": false,
   "keep_tokens": "0",
   "learning_rate": 0.0001,
-  "logging_dir": "",
   "lora_network_weights": "",
   "lr_scheduler": "cosine",
   "lr_scheduler_num_cycles": "",
@@ -57,14 +56,11 @@
   "num_cpu_threads_per_process": 2,
   "optimizer": "AdamW",
   "optimizer_args": "",
-  "output_dir": "",
-  "output_name": "",
   "persistent_data_loader_workers": false,
   "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
   "prior_loss_weight": 1.0,
   "random_crop": false,
   "rank_dropout": 0,
-  "reg_data_dir": "",
   "resume": "",
   "sample_every_n_epochs": 0,
   "sample_every_n_steps": 0,
@@ -84,7 +80,6 @@
   "stop_text_encoder_training": 0,
   "text_encoder_lr": 0.0001,
   "train_batch_size": 1,
-  "train_data_dir": "",
   "train_on_input": true,
   "training_comment": "",
   "unet_lr": 0.0001,
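The preset edits all delete user-specific path and naming fields. The practical effect, assuming presets are applied by merging their keys over the current form values (an assumption about the loader, not shown in this commit), is that applying a preset no longer clobbers the user's folders:

    import json

    def apply_preset(current: dict, preset_path: str) -> dict:
        with open(preset_path) as f:
            preset = json.load(f)
        merged = dict(current)
        merged.update(preset)  # keys absent from the preset keep their current values
        return merged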
(2 remaining changed files not shown)
