Commit

update sd-scripts

Akegarasu committed Nov 24, 2024
1 parent 15019c7 commit b06fb56
Showing 25 changed files with 3,707 additions and 2,794 deletions.
4 changes: 4 additions & 0 deletions scripts/dev/README-ja.md
@@ -3,6 +3,10 @@ A repository containing scripts for Stable Diffusion training, image generation, and more

[README in English](./README.md) ← update information is available there

The in-development version lives on the dev branch. Please check the dev branch for the latest changes.

FLUX.1 and SD3/SD3.5 support is being developed on the sd3 branch. To train those models, please use the sd3 branch.

A GUI, PowerShell scripts, and other usability features are provided in [bmaltais's repository](https://github.com/bmaltais/kohya_ss) (in English); please see it as well. Many thanks to bmaltais.

The following scripts are available.
388 changes: 213 additions & 175 deletions scripts/dev/README.md

Large diffs are not rendered by default.

218 changes: 35 additions & 183 deletions scripts/dev/flux_train.py

Large diffs are not rendered by default.

142 changes: 76 additions & 66 deletions scripts/dev/flux_train_network.py
@@ -25,6 +25,7 @@ def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -51,10 +52,23 @@ def assert_extra_args(self, args, train_dataset_group):
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")

assert not args.split_mode or not args.cpu_offload_checkpointing, (
"split_mode and cpu_offload_checkpointing cannot be used together"
" / split_modeとcpu_offload_checkpointingは同時に使用できません"
)
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

# deprecated split_mode option
if args.split_mode:
if args.blocks_to_swap is not None:
logger.warning(
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
)
else:
logger.warning(
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
)
args.blocks_to_swap = 18 # 18 is safe for most cases

train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
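
The hunk above replaces the old split_mode/cpu_offload_checkpointing assertion: `--blocks_to_swap` now subsumes `--split_mode`, and a legacy `--split_mode` run without an explicit value is mapped to `--blocks_to_swap 18`. A minimal sketch of the resulting resolution rule (the helper name is ours, not part of sd-scripts):

from typing import Optional

def resolve_blocks_to_swap(split_mode: bool, blocks_to_swap: Optional[int]) -> Optional[int]:
    # Hypothetical helper mirroring the deprecation shim above.
    if split_mode and blocks_to_swap is None:
        return 18  # "18 is safe for most cases" per the code above
    return blocks_to_swap  # an explicit --blocks_to_swap always wins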

@@ -74,9 +88,21 @@ def load_target_model(self, args, weight_dtype, accelerator):
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 FLUX model")
else:
logger.info(
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
)
model.to(torch.float8_e4m3fn)
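# --- editor's sketch, not part of the diff: what the fp8 cast above does ---
# Storing FLUX weights as float8_e4m3fn roughly halves memory vs fp16/bf16;
# compute still upcasts as needed. Requires a recent PyTorch (~2.1+).
import torch
w = torch.randn(8, 8, dtype=torch.bfloat16)  # stand-in for checkpoint weights
w_fp8 = w.to(torch.float8_e4m3fn)            # lossy, one-time cast at load
assert w_fp8.dtype == torch.float8_e4m3fn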

if args.split_mode:
model = self.prepare_split_model(model, weight_dtype, accelerator)
# if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator)

self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)

clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()
@@ -101,43 +127,6 @@ def load_target_model(self, args, weight_dtype, accelerator):

return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
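
enable_block_swap above is the heart of the new memory-saving path. A minimal sketch of the idea (our reading, not sd-scripts' actual implementation): keep only part of the transformer resident on the GPU and stream the remaining blocks in from CPU as they execute, trading transfer time for VRAM.

import torch

class BlockSwapSketch(torch.nn.Module):
    """Hypothetical illustration of --blocks_to_swap; names are ours."""

    def __init__(self, blocks: torch.nn.ModuleList, blocks_to_swap: int, device: torch.device):
        super().__init__()
        self.blocks = blocks
        self.resident = len(blocks) - blocks_to_swap  # blocks pinned on the GPU
        self.device = device
        for i, block in enumerate(blocks):
            block.to(device if i < self.resident else "cpu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, block in enumerate(self.blocks):
            swapped = i >= self.resident
            if swapped:
                block.to(self.device)  # stream the block in just-in-time
            x = block(x)
            if swapped:
                block.to("cpu")        # evict it to make room for the next one
        return x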

def prepare_split_model(self, model, weight_dtype, accelerator):
from accelerate import init_empty_weights

logger.info("prepare split model")
with init_empty_weights():
flux_upper = flux_models.FluxUpper(model.params)
flux_lower = flux_models.FluxLower(model.params)
sd = model.state_dict()

# lower (trainable)
logger.info("load state dict for lower")
flux_lower.load_state_dict(sd, strict=False, assign=True)
flux_lower.to(dtype=weight_dtype)

# upper (frozen)
logger.info("load state dict for upper")
flux_upper.load_state_dict(sd, strict=False, assign=True)

logger.info("prepare upper model")
target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
flux_upper.to(accelerator.device, dtype=target_dtype)
flux_upper.eval()

if args.fp8_base:
# this is required to run on fp8
flux_upper = accelerator.prepare(flux_upper)

flux_upper.to("cpu")

self.flux_upper = flux_upper
del model # we don't need model anymore
clean_memory_on_device(accelerator.device)

logger.info("split model prepared")

return flux_lower

def get_tokenize_strategy(self, args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)

@@ -231,7 +220,7 @@ def cache_text_encoder_outputs_if_needed(
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

prompts = sd3_train_utils.load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
@@ -284,12 +273,12 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

if not args.split_mode:
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
return
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
# return

"""
class FluxUpperLowerWrapper(torch.nn.Module):
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
super().__init__()
@@ -316,6 +305,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
"""

def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
@@ -363,7 +353,7 @@ def get_noise_pred_and_target(
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t.dtype.is_floating_point:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
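# --- editor's note on the added `t is not None` guard above ---
# text_encoder_conds can contain None entries (e.g. when text-encoder outputs
# come from the cache), and `None.dtype` would raise AttributeError.
# Minimal standalone repro of the guarded loop:
import torch
text_encoder_conds = [torch.randn(2, 4), None]
for t in text_encoder_conds:
    if t is not None and t.dtype.is_floating_point:
        t.requires_grad_(True)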
@@ -374,20 +364,21 @@
t5_attn_mask = None

def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
if not args.split_mode:
# normal forward
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# if not args.split_mode:
# normal forward
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
@@ -421,6 +412,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""

return model_pred

@@ -453,6 +445,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t

if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
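# (editor's note) the call below is new in this commit: before the extra
# no-grad forward pass for diff-output-preservation, the block-swap state
# presumably has to be re-staged so the right blocks are on the GPU again.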
unet.prepare_block_swap_before_forward()
with torch.no_grad():
model_pred_prior = call_dit(
img=packed_noisy_model_input[diff_output_pr_indices],
@@ -539,16 +532,33 @@ def forward(hidden_states):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)

def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)

# if we don't swap blocks, we can move the model to the device
flux: flux_models.Flux = unet
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()

return flux
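
prepare_unet_with_accelerator above hands the FLUX model to accelerate without letting it do device placement, then moves everything except the swapped blocks itself. A minimal sketch of that pattern, with a stand-in module (the real call operates on the FLUX model):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)  # stand-in for the FLUX model
# device_placement=[False] tells accelerate not to move this argument;
# the caller then performs its own partial placement (here, everything
# except the swapped blocks via move_to_device_except_swap_blocks()).
model = accelerator.prepare(model, device_placement=[False])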


def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)

parser.add_argument(
"--split_mode",
action="store_true",
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)
return parser

2 changes: 1 addition & 1 deletion scripts/dev/library/config_util.py
@@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}