Fix a bug where captions stop changing when persistent_workers is enabled
u-haru committed Mar 23, 2023
1 parent 447c56b commit dbadc40
Showing 5 changed files with 16 additions and 4 deletions.
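
Background for the fix: these training scripts push schedule state into the dataset by mutating it from the main process (set_current_epoch(), set_current_step()), but a DataLoader worker only ever sees the copy of the dataset it received when it was spawned. With --persistent_data_loader_workers (PyTorch's persistent_workers=True) the workers are spawned once and reused across epochs, so those mutations never arrive and caption dropout / token warmup freeze. A minimal self-contained sketch of this PyTorch behavior (class and variable names are illustrative, not from this repo):

from torch.utils.data import DataLoader, Dataset

class EpochAwareDataset(Dataset):
    # stands in for the repo's dataset group; current_epoch is mutated
    # from the main process, like set_current_epoch() does
    def __init__(self):
        self.current_epoch = 0

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # with num_workers > 0 this runs in a worker process that holds a
        # snapshot of the dataset taken when the worker was spawned
        return self.current_epoch

if __name__ == "__main__":
    ds = EpochAwareDataset()
    loader = DataLoader(ds, num_workers=2, persistent_workers=True)
    for epoch in range(3):
        ds.current_epoch = epoch + 1  # only visible to workers spawned after this point
        print(f"epoch {epoch + 1}:", [int(b) for b in loader])
    # persistent_workers=True keeps the workers alive across epochs, so every
    # epoch prints the value captured at the first spawn: [1, 1, 1, 1].
    # With persistent_workers=False the workers are re-forked each epoch and
    # the update is picked up: [1, 1, 1, 1], then [2, 2, 2, 2], [3, 3, 3, 3].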
3 changes: 2 additions & 1 deletion fine_tune.py
@@ -62,6 +62,7 @@ def train(args):
   }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  config_util.blueprint_args_conflict(args, blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   if args.debug_dataset:
@@ -259,13 +260,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset_group.set_current_epoch(epoch + 1)
+    train_dataset_group.set_current_step(global_step)
 
     for m in training_models:
       m.train()
 
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
-      train_dataset_group.set_current_step(global_step)
       with accelerator.accumulate(training_models[0]):  # it seems this doesn't support multiple models, but do it this way for now
         with torch.no_grad():
           if "latents" in batch and batch["latents"] is not None:
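Note the two halves of this hunk: set_current_step() moves from inside the batch loop up to the top of the epoch loop, beside set_current_epoch(). The write has to happen before iter(train_dataloader) spawns the epoch's workers; a write inside the batch loop lands after the workers already hold their snapshot and is silently lost. A runnable sketch of that ordering under the default persistent_workers=False (illustrative names, not this repo's code):

from torch.utils.data import DataLoader, Dataset

class StepAwareDataset(Dataset):
    def __init__(self):
        self.current_step = 0  # stands in for the state behind set_current_step()

    def __len__(self):
        return 2

    def __getitem__(self, idx):
        return self.current_step

if __name__ == "__main__":
    ds = StepAwareDataset()
    loader = DataLoader(ds, num_workers=1)  # persistent_workers defaults to False
    global_step = 0
    for epoch in range(2):
        ds.current_step = global_step      # new placement: before the workers fork
        seen = []
        for batch in loader:               # workers spawn here with the value above
            seen.append(int(batch))
            ds.current_step = global_step  # old placement: workers already forked, never seen
            global_step += 1
        print(f"epoch {epoch + 1}: workers saw {seen}")
    # prints "epoch 1: workers saw [0, 0]" then "epoch 2: workers saw [2, 2]":
    # only the value written before iteration starts reaches the workers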
8 changes: 8 additions & 0 deletions library/config_util.py
@@ -497,6 +497,14 @@ def load_user_config(file: str) -> dict:

   return config
 
+def blueprint_args_conflict(args, blueprint: Blueprint):
+  # set_current_epoch() and set_current_step() only take effect when the DataLoader workers are spawned, so with persistent_workers enabled they stay constant forever; disable the dependent features in that case
+  for dataset in blueprint.dataset_group.datasets:
+    for subset in dataset.subsets:
+      if args.persistent_data_loader_workers and (subset.params.caption_dropout_every_n_epochs > 0 or subset.params.token_warmup_step > 0):
+        print("Warning: %s: caption_dropout_every_n_epochs and token_warmup_step are ignored because --persistent_data_loader_workers option is used / --persistent_data_loader_workersオプションが使われているため、caption_dropout_every_n_epochs及びtoken_warmup_stepは無視されます。" % (subset.params.image_dir))
+        subset.params.caption_dropout_every_n_epochs = 0
+        subset.params.token_warmup_step = 0
 
 # for config test
 if __name__ == "__main__":
3 changes: 2 additions & 1 deletion train_db.py
@@ -57,6 +57,7 @@ def train(args):
   }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  config_util.blueprint_args_conflict(args, blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   if args.no_token_padding:
@@ -233,6 +234,7 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset_group.set_current_epoch(epoch + 1)
+    train_dataset_group.set_current_step(global_step)
 
     # train the Text Encoder up to the specified number of steps: state at the start of each epoch
     unet.train()
@@ -241,7 +243,6 @@
       text_encoder.train()
 
     for step, batch in enumerate(train_dataloader):
-      train_dataset_group.set_current_step(global_step)
       # stop training the Text Encoder at the specified step
       if global_step == args.stop_text_encoder_training:
         print(f"stop text encoder training at step {global_step}")
3 changes: 2 additions & 1 deletion train_network.py
@@ -98,6 +98,7 @@ def train(args):
   }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  config_util.blueprint_args_conflict(args, blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   if args.debug_dataset:
@@ -501,13 +502,13 @@ def train(args):
     if is_main_process:
       print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset_group.set_current_epoch(epoch + 1)
+    train_dataset_group.set_current_step(global_step)
 
     metadata["ss_epoch"] = str(epoch + 1)
 
     network.on_epoch_start(text_encoder, unet)
 
     for step, batch in enumerate(train_dataloader):
-      train_dataset_group.set_current_step(global_step)
       with accelerator.accumulate(network):
         with torch.no_grad():
           if "latents" in batch and batch["latents"] is not None:
3 changes: 2 additions & 1 deletion train_textual_inversion.py
@@ -183,6 +183,7 @@ def train(args):
   }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+  config_util.blueprint_args_conflict(args, blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   # make captions: a very crude implementation that rewrites every caption to the string "tokenstring tokenstring1 tokenstring2 ... tokenstringn"
@@ -335,12 +336,12 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset_group.set_current_epoch(epoch + 1)
+    train_dataset_group.set_current_step(global_step)
 
     text_encoder.train()
 
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
-      train_dataset_group.set_current_step(global_step)
       with accelerator.accumulate(text_encoder):
         with torch.no_grad():
           if "latents" in batch and batch["latents"] is not None:
