dreambooth upscaling fix added latents (huggingface#3659)
williamberman authored Jun 5, 2023
1 parent 523a50a commit 0fc2fb7
Showing 4 changed files with 32 additions and 30 deletions.
20 changes: 14 additions & 6 deletions docs/source/en/training/dreambooth.mdx
@@ -540,10 +540,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
LoRA finetuning of stage II.

- For finegrained detail like faces, we find that lower learning rates work best.
+ For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.

For stage II, we find that lower learning rates are also needed.

+ We found experimentally that the DDPM scheduler with the default larger number of denoising steps sometimes works better than the DPM Solver scheduler
+ used in the training scripts.
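
A minimal sketch of that scheduler swap at inference time, assuming the stock DeepFloyd stage II checkpoint (in practice you would point this at your finetuned weights) and an fp16-capable GPU:

```python
import torch
from diffusers import DDPMScheduler, IFSuperResolutionPipeline

# Assumption: stock stage II weights; substitute your DreamBooth output directory.
pipe = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)
# Swap the default DPM Solver scheduler for DDPM, reusing the pipeline's scheduler config.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# `image` would be the 64x64 stage I output to upscale, e.g.:
# upscaled = pipe(prompt="a sks dog", image=image).images[0]
```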

### Stage II additional validation images

The stage II validation requires images to upscale; we can download a downsized version of the training set:
@@ -631,7 +634,8 @@ with a T5 loaded from the original model.

`use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.

- `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+ `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates, model quality will degrade. Note that it is
+ likely the learning rate can be increased with larger batch sizes.

Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
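
As a rough, simplified sketch of how an `--use_8bit_adam` flag is commonly wired up (the `build_optimizer` helper is illustrative, not the script's actual function; the real train_dreambooth.py follows roughly the same pattern inline):

```python
import torch


def build_optimizer(params, args):
    """Pick 8-bit AdamW when --use_8bit_adam is set, else regular AdamW."""
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, install the `bitsandbytes` package.")
        optimizer_class = bnb.optim.AdamW8bit  # 8-bit states keep optimizer memory small
    else:
        optimizer_class = torch.optim.AdamW

    return optimizer_class(
        params,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )


# e.g. optimizer = build_optimizer(unet.parameters(), args)
```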

@@ -656,18 +660,22 @@ accelerate launch train_dreambooth.py \
--text_encoder_use_attention_mask \
--tokenizer_max_length 77 \
--pre_compute_text_embeddings \
- --use_8bit_adam \ #
+ --use_8bit_adam \
--set_grads_to_none \
--skip_save_text_encoder \
--push_to_hub
```

### IF Stage II Full Dreambooth

- `--learning_rate=1e-8`: Even lower learning rate.
+ `--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as
+ 1e-8.

`--resolution=256`: The upscaler expects higher resolution inputs.

+ `--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II, particularly with
+ faces, required large effective batch sizes.
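
For clarity, the effective batch size is the per-device batch size times the gradient accumulation steps (times the number of processes), so the command below trains with an effective batch size of 12 on a single GPU, versus the effective batch size of 4 mentioned above:

```python
# Effective batch size for the stage II command below (single-GPU assumption):
train_batch_size = 2
gradient_accumulation_steps = 6
num_processes = 1  # multiply by the number of GPUs/processes when using more
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 12
```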

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
@@ -682,8 +690,8 @@ accelerate launch train_dreambooth.py \
--instance_prompt="a sks dog" \
--resolution=256 \
--train_batch_size=2 \
- --gradient_accumulation_steps=2 \
- --learning_rate=1e-8 \
+ --gradient_accumulation_steps=6 \
+ --learning_rate=5e-6 \
--max_train_steps=2000 \
--validation_prompt="a sks dog" \
--validation_steps=150 \
20 changes: 14 additions & 6 deletions examples/dreambooth/README.md
@@ -574,10 +574,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
LoRA finetuning of stage II.

- For finegrained detail like faces, we find that lower learning rates work best.
+ For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.

For stage II, we find that lower learning rates are also needed.

+ We found experimentally that the DDPM scheduler with the default larger number of denoising steps sometimes works better than the DPM Solver scheduler
+ used in the training scripts.

### Stage II additional validation images

The stage II validation requires images to upscale; we can download a downsized version of the training set:
@@ -665,7 +668,8 @@ with a T5 loaded from the original model.

`use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.

- `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+ `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates, model quality will degrade. Note that it is
+ likely the learning rate can be increased with larger batch sizes.

Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.

@@ -690,18 +694,22 @@ accelerate launch train_dreambooth.py \
--text_encoder_use_attention_mask \
--tokenizer_max_length 77 \
--pre_compute_text_embeddings \
- --use_8bit_adam \ #
+ --use_8bit_adam \
--set_grads_to_none \
--skip_save_text_encoder \
--push_to_hub
```

### IF Stage II Full Dreambooth

- `--learning_rate=1e-8`: Even lower learning rate.
+ `--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as
+ 1e-8.

`--resolution=256`: The upscaler expects higher resolution inputs.

+ `--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II, particularly with
+ faces, required large effective batch sizes.

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
@@ -716,8 +724,8 @@ accelerate launch train_dreambooth.py \
--instance_prompt="a sks dog" \
--resolution=256 \
--train_batch_size=2 \
- --gradient_accumulation_steps=2 \
- --learning_rate=1e-8 \
+ --gradient_accumulation_steps=6 \
+ --learning_rate=5e-6 \
--max_train_steps=2000 \
--validation_prompt="a sks dog" \
--validation_steps=150 \
11 changes: 2 additions & 9 deletions examples/dreambooth/train_dreambooth.py
@@ -52,7 +52,6 @@
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
- from diffusers.utils.torch_utils import randn_tensor


if is_wandb_available():
@@ -1212,14 +1211,8 @@ def compute_text_embeddings(prompt):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

- if unet.config.in_channels > channels:
-     needed_additional_channels = unet.config.in_channels - channels
-     additional_latents = randn_tensor(
-         (bsz, needed_additional_channels, height, width),
-         device=noisy_model_input.device,
-         dtype=noisy_model_input.dtype,
-     )
-     noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+ if unet.config.in_channels == channels * 2:
+     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
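
For context on the replacement a few lines above: the stage II upscaler UNet has twice as many input channels as the image itself, since at inference it is fed the noisy high-resolution image concatenated channel-wise with the upscaled low-resolution conditioning image. Rather than filling those extra channels with freshly sampled random latents, the script now concatenates a copy of `noisy_model_input`. A self-contained sketch of the shape handling (sizes are made up for illustration; the real script takes them from the batch and the loaded UNet config):

```python
import torch

# Illustrative sizes only.
bsz, channels, height, width = 1, 3, 256, 256
unet_in_channels = 6  # stage II UNet expects 2x the image channels

noisy_model_input = torch.randn(bsz, channels, height, width)

if unet_in_channels == channels * 2:
    # Fill the extra conditioning channels with a copy of the noisy input
    # rather than random latents (the behavior this commit removes).
    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

print(noisy_model_input.shape)  # torch.Size([1, 6, 256, 256])
```
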
11 changes: 2 additions & 9 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -60,7 +60,6 @@
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
- from diffusers.utils.torch_utils import randn_tensor


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1157,14 +1156,8 @@ def compute_text_embeddings(prompt):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

- if unet.config.in_channels > channels:
-     needed_additional_channels = unet.config.in_channels - channels
-     additional_latents = randn_tensor(
-         (bsz, needed_additional_channels, height, width),
-         device=noisy_model_input.device,
-         dtype=noisy_model_input.dtype,
-     )
-     noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+ if unet.config.in_channels == channels * 2:
+     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
