From dbadc40ec2eb2de92b21fd3b5aa82994899705cc Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Thu, 23 Mar 2023 12:33:03 +0900
Subject: [PATCH] Fix bug where captions stop changing when
 persistent_workers is enabled
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fine_tune.py               | 3 ++-
 library/config_util.py     | 8 ++++++++
 train_db.py                | 3 ++-
 train_network.py           | 3 ++-
 train_textual_inversion.py | 3 ++-
 5 files changed, 16 insertions(+), 4 deletions(-)

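Notes (kept out of the commit message): with --persistent_data_loader_workers the
DataLoader keeps its worker processes alive across epochs, and each worker only ever
sees the copy of the dataset it received when it was spawned, so values set through
set_current_epoch()/set_current_step() in the main process no longer reach the workers
after the first iteration. The snippet below is a minimal sketch of that PyTorch
behaviour for illustration only; EpochAwareDataset is a hypothetical stand-in, not a
class from this repository.

    from torch.utils.data import Dataset, DataLoader

    class EpochAwareDataset(Dataset):
        # Hypothetical stand-in for the training dataset (illustration only).
        def __init__(self):
            self.current_epoch = 0

        def set_current_epoch(self, epoch):
            self.current_epoch = epoch  # mutated in the main process only

        def __len__(self):
            return 2

        def __getitem__(self, idx):
            return self.current_epoch  # read inside the worker process

    if __name__ == "__main__":
        ds = EpochAwareDataset()
        dl = DataLoader(ds, num_workers=1, persistent_workers=True)
        for epoch in range(1, 4):
            ds.set_current_epoch(epoch)  # only visible to freshly spawned workers
            print(epoch, [int(x) for x in dl])
        # persistent_workers=True : every epoch prints the value captured when the
        #                           worker was first spawned (1, 1, 1).
        # persistent_workers=False: workers are respawned each epoch, so the values
        #                           follow the main process (1, 2, 3).
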
diff --git a/fine_tune.py b/fine_tune.py
index def942fac..ff5804350 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -62,6 +62,7 @@ def train(args):
         }
 
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+    config_util.blueprint_args_conflict(args, blueprint)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
     if args.debug_dataset:
@@ -259,13 +260,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
         train_dataset_group.set_current_epoch(epoch + 1)
+        train_dataset_group.set_current_step(global_step)
 
         for m in training_models:
             m.train()
 
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
-            train_dataset_group.set_current_step(global_step)
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
diff --git a/library/config_util.py b/library/config_util.py
index 84bbf3086..efeb8016d 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -497,6 +497,14 @@ def load_user_config(file: str) -> dict:
 
     return config
 
+def blueprint_args_conflict(args, blueprint: Blueprint):
+    # set_current_epoch() and set_current_step() only reach the DataLoader workers when the workers are created, so with persistent_workers enabled their values never change afterwards; disable the affected options in that case
+    for dataset in blueprint.dataset_group.datasets:
+        for subset in dataset.subsets:
+            if args.persistent_data_loader_workers and (subset.params.caption_dropout_every_n_epochs > 0 or subset.params.token_warmup_step > 0):
+                print("Warning: %s: caption_dropout_every_n_epochs and token_warmup_step are ignored because the --persistent_data_loader_workers option is used / --persistent_data_loader_workersオプションが使われているため、caption_dropout_every_n_epochs及びtoken_warmup_stepは無視されます。" % (subset.params.image_dir))
+                subset.params.caption_dropout_every_n_epochs = 0
+                subset.params.token_warmup_step = 0
 
 # for config test
 if __name__ == "__main__":
diff --git a/train_db.py b/train_db.py
index e17a8b795..87fe771b4 100644
--- a/train_db.py
+++ b/train_db.py
@@ -57,6 +57,7 @@ def train(args):
         }
 
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+    config_util.blueprint_args_conflict(args, blueprint)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
     if args.no_token_padding:
@@ -233,6 +234,7 @@ def train(args):
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
         train_dataset_group.set_current_epoch(epoch + 1)
+        train_dataset_group.set_current_step(global_step)
 
         # 指定したステップ数までText Encoderを学習する:epoch最初の状態
         unet.train()
@@ -241,7 +243,6 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
-            train_dataset_group.set_current_step(global_step)
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
                 print(f"stop text encoder training at step {global_step}")
diff --git a/train_network.py b/train_network.py
index 6d23ab07b..02a2d9255 100644
--- a/train_network.py
+++ b/train_network.py
@@ -98,6 +98,7 @@ def train(args):
         }
 
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+    config_util.blueprint_args_conflict(args, blueprint)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
     if args.debug_dataset:
@@ -501,13 +502,13 @@ def train(args):
         if is_main_process:
             print(f"epoch {epoch+1}/{num_train_epochs}")
         train_dataset_group.set_current_epoch(epoch + 1)
+        train_dataset_group.set_current_step(global_step)
 
         metadata["ss_epoch"] = str(epoch + 1)
 
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
-            train_dataset_group.set_current_step(global_step)
             with accelerator.accumulate(network):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 427461696..63b634267 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -183,6 +183,7 @@ def train(args):
         }
 
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+    config_util.blueprint_args_conflict(args, blueprint)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
@@ -335,12 +336,12 @@ def train(args):
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
         train_dataset_group.set_current_epoch(epoch + 1)
+        train_dataset_group.set_current_step(global_step)
 
         text_encoder.train()
 
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
-            train_dataset_group.set_current_step(global_step)
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None: