Skip to content

Commit

Permalink
[Easy] fix: save_model_card utility of the DreamBooth SDXL LoRA script (
Browse files Browse the repository at this point in the history
huggingface#7258)

* fix: save_model_card utility.

* fix a little more to make it more lenient.

* remove lower()
  • Loading branch information
sayakpaul authored Mar 8, 2024
1 parent d9a3b69 commit 9d97440
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def save_model_card(
)

model_description = f"""
# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
<Gallery />
Expand All @@ -139,7 +139,7 @@ def save_model_card(
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
if "playgroundai" in args.pretrained_model_name_or_path:
if "playground" in base_model:
model_description += """\n
## License
Expand All @@ -148,7 +148,7 @@ def save_model_card(
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
Expand All @@ -162,7 +162,7 @@ def save_model_card(
"lora" if not use_dora else "dora",
"template:sd-lora",
]
if "playgroundai" in base_model:
if "playground" in base_model:
tags.extend(["playground", "playground-diffusers"])
else:
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
Expand Down Expand Up @@ -206,7 +206,7 @@ def log_validation(
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
inference_ctx = (
contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
)

with inference_ctx:
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if accelerator.is_main_process:
tracker_name = (
"dreambooth-lora-sd-xl"
if "playgroundai" not in args.pretrained_model_name_or_path
if "playground" not in args.pretrained_model_name_or_path
else "dreambooth-lora-playground"
)
accelerator.init_trackers(tracker_name, config=vars(args))
Expand Down

0 comments on commit 9d97440

Please sign in to comment.