Add GUI support for kohya latest options
bmaltais committed Jun 4, 2023
1 parent aeefc0d commit fd47104
Showing 1 changed file with 63 additions and 1 deletion.
lora_gui.py: 63 additions, 1 deletion
@@ -153,6 +153,10 @@ def save_configuration(
save_last_n_steps_state,
use_wandb,
wandb_api_key,
scale_weight_norms,
network_dropout,
rank_dropout,
module_dropout,
):
# Get list of function parameters and values
parameters = list(locals().items())
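The handlers rely on a capture-everything pattern: calling locals() at the top of the function yields every named parameter as a (name, value) pair, so newly added GUI options such as the four settings above are serialized without extra bookkeeping. A minimal standalone illustration of the pattern:

def demo(scale_weight_norms=1.0, network_dropout=0.1):
    # locals() at function entry contains exactly the named parameters
    parameters = list(locals().items())
    return parameters

print(demo())  # [('scale_weight_norms', 1.0), ('network_dropout', 0.1)]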
@@ -290,6 +294,10 @@ def open_configuration(
save_last_n_steps_state,
use_wandb,
wandb_api_key,
scale_weight_norms,
network_dropout,
rank_dropout,
module_dropout,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -419,6 +427,10 @@ def train_model(
save_last_n_steps_state,
use_wandb,
wandb_api_key,
scale_weight_norms,
network_dropout,
rank_dropout,
module_dropout,
):
print_only_bool = True if print_only.get('label') == 'True' else False
log.info(f'Start training LoRA {LoRA_type} ...')
@@ -628,6 +640,7 @@ def train_model(
run_cmd += f' --save_model_as={save_model_as}'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'

if LoRA_type == 'LoCon' or LoRA_type == 'LyCORIS/LoCon':
try:
import lycoris
@@ -638,6 +651,7 @@ def train_model(
return
run_cmd += f' --network_module=lycoris.kohya'
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"'
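For concreteness, with assumed sample values conv_dim=8 and conv_alpha=1, these two lines extend the training command with a fragment like:

--network_module=lycoris.kohya --network_args "conv_dim=8" "conv_alpha=1" "algo=lora"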

if LoRA_type == 'LyCORIS/LoHa':
try:
import lycoris
@@ -659,6 +673,8 @@ def train_model(
'block_alphas',
'conv_dims',
'conv_alphas',
'rank_dropout',
'module_dropout',
]

run_cmd += f' --network_module=networks.lora'
@@ -670,7 +686,7 @@ def train_model(

network_args = ''
if LoRA_type == 'Kohya LoCon':
network_args += f' "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"'

for key, value in kohya_lora_vars.items():
if value:
@@ -691,6 +707,8 @@ def train_model(
'block_alphas',
'conv_dims',
'conv_alphas',
'rank_dropout',
'module_dropout',
'unit',
]

@@ -746,6 +764,12 @@ def train_model(
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
if not lr_scheduler_power == '':
run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'

if scale_weight_norms > 0.0:
run_cmd += f' --scale_weight_norms="{scale_weight_norms}"'

if network_dropout > 0.0:
run_cmd += f' --network_dropout="{network_dropout}"'
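With sample slider values of 1.0 and 0.1 (assumed here for illustration), these two guarded blocks append:

--scale_weight_norms="1.0" --network_dropout="0.1"

Leaving either slider at 0.0 keeps the corresponding flag off the command line entirely.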

run_cmd += run_cmd_training(
learning_rate=learning_rate,
@@ -1205,6 +1229,15 @@ def update_LoRA_settings(LoRA_type):
value=False,
info='Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.',
)
scale_weight_norms = gr.Slider(
label="Scale weight norms",
value=0,
minimum=0.0,
maximum=1.0,
step=0.01,
info='Max norm regularization stabilizes training by capping the norm of the network weights. It can help suppress LoRA overfitting and improve stability when the LoRA is combined with others. See the PR for details.',
interactive=True,
)
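Conceptually, scale_weight_norms caps the norm of the effective LoRA delta (up @ down): when the norm exceeds the threshold, the weights are scaled back down. This is a minimal sketch of the idea, not sd-scripts' actual implementation, whose details may differ:

import torch

def apply_max_norm(lora_down: torch.Tensor, lora_up: torch.Tensor,
                   max_norm: float) -> float:
    # Clamp the Frobenius norm of the effective delta W = up @ down.
    with torch.no_grad():
        delta_norm = (lora_up @ lora_down).norm()
        if delta_norm > max_norm:
            ratio = (max_norm / delta_norm).item()
            # split the correction across both factors so their
            # product is scaled by exactly `ratio`
            lora_up.mul_(ratio ** 0.5)
            lora_down.mul_(ratio ** 0.5)
            return ratio
    return 1.0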
with gr.Row():
prior_loss_weight = gr.Number(
label='Prior loss weight', value=1.0
@@ -1218,6 +1251,31 @@ def update_LoRA_settings(LoRA_type):
label='LR power',
placeholder='(Optional) For Cosine with restart and polynomial only',
)
with gr.Row():
network_dropout = gr.Slider(
label='Network dropout',
value=0.0,
minimum=0.0,
maximum=1.0,
step=0.01,
info='Ordinary probability dropout at the neuron level. For LoRA it is applied to the output of the down projection. Recommended range: 0.1 to 0.5.'
)
rank_dropout = gr.Slider(
label='Rank dropout',
value=0.0,
minimum=0.0,
maximum=1.0,
step=0.01,
info='Drops out each rank with the specified probability. Recommended range: 0.1 to 0.3.'
)
module_dropout = gr.Slider(
label='Module dropout',
value=0.0,
minimum=0.0,
maximum=1.0,
step=0.01,
info='Drops out each LoRA module with the specified probability. Recommended range: 0.1 to 0.3.'
)
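The three sliders above map to three granularities of dropout inside each LoRA module. A conceptual sketch of a LoRA forward pass with all three (illustrative only; the real logic lives in kohya's LoRAModule and may differ):

import torch
import torch.nn.functional as F

def lora_delta(x, lora_down, lora_up, scale,
               network_dropout=0.0, rank_dropout=0.0, module_dropout=0.0,
               training=True):
    # module dropout: occasionally skip this LoRA module altogether
    if training and module_dropout > 0.0 and torch.rand(1).item() < module_dropout:
        return torch.zeros(x.size(0), lora_up.size(0), device=x.device)
    h = x @ lora_down.T                      # down-projection: (batch, rank)
    if training and network_dropout > 0.0:
        h = F.dropout(h, p=network_dropout)  # neuron-level dropout on `down` output
    if training and rank_dropout > 0.0:
        # zero whole ranks, then rescale to keep the expected magnitude
        keep = (torch.rand(h.size(-1), device=h.device) > rank_dropout).to(h.dtype)
        h = h * keep / (1.0 - rank_dropout)
    return (h @ lora_up.T) * scale           # up-projection: (batch, out)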
(
# use_8bit_adam,
xformers,
@@ -1395,6 +1453,10 @@ def update_LoRA_settings(LoRA_type):
save_last_n_steps_state,
use_wandb,
wandb_api_key,
scale_weight_norms,
network_dropout,
rank_dropout,
module_dropout,
]
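This ordered list is wired as the inputs to the config and training buttons (the truncated .click() call below), so Gradio passes each component's value positionally: the four new entries must sit at the same index as the matching parameters in save_configuration, open_configuration, and train_model. A hypothetical sketch of the wiring, assuming the list above is bound to a name like settings_list:

button_open_config.click(
    open_configuration,
    inputs=settings_list,   # values arrive positionally, in list order
    outputs=settings_list,
)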

button_open_config.click(