[examples] misc fixes (huggingface#1886)
* misc fixes

* more comments

* Update examples/textual_inversion/textual_inversion.py

Co-authored-by: Patrick von Platen <[email protected]>

* set transformers verbosity to warning

Co-authored-by: Patrick von Platen <[email protected]>
patil-suraj and patrickvonplaten authored Jan 2, 2023
1 parent 423c3a4 commit fa1f470
Showing 3 changed files with 145 additions and 70 deletions.
94 changes: 64 additions & 30 deletions examples/dreambooth/train_dreambooth.py
@@ -1,6 +1,7 @@
import argparse
import hashlib
import itertools
import logging
import math
import os
import warnings
@@ -12,6 +13,9 @@
import torch.utils.checkpoint
from torch.utils.data import Dataset

import datasets
import diffusers
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
@@ -236,6 +240,24 @@ def parse_args(input_args=None):
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
@@ -422,7 +444,7 @@ def main(args):
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
log_with=args.report_to,
logging_dir=logging_dir,
)

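The two flags added in this commit meet the code here: --report_to is forwarded to Accelerate as log_with (the line changed above), and --allow_tf32 toggles a PyTorch backend switch later in main(). Below is a minimal sketch of how the pair is consumed; the helper name and the init_trackers call are illustrative and not part of this diff.

import torch
from accelerate import Accelerator

def configure_speed_and_logging(report_to: str = "tensorboard", allow_tf32: bool = False) -> Accelerator:
    # --allow_tf32 (applied further down in main()): let float32 matmuls use
    # TF32 tensor cores on Ampere or newer GPUs. Only the matmul switch is set,
    # matching the script; cuDNN's own TF32 switch already defaults to True.
    if allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # --report_to: Accelerate owns the tracker integrations ("tensorboard",
    # "wandb", "comet_ml" or "all"), so the value is passed through unchanged.
    accelerator = Accelerator(log_with=report_to, logging_dir="logs")  # newer Accelerate releases name this project_dir
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config={"allow_tf32": allow_tf32})
    return accelerator
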
@@ -435,9 +457,27 @@ def main(args):
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)

# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
@@ -502,11 +542,7 @@ def main(args):

# Load the tokenizer
if args.tokenizer_name:
- tokenizer = AutoTokenizer.from_pretrained(
- args.tokenizer_name,
- revision=args.revision,
- use_fast=False,
- )
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
@@ -518,38 +554,36 @@ def main(args):
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

- # Load models and create wrapper for stable diffusion
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder",
- revision=args.revision,
- )
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="vae",
- revision=args.revision,
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="unet",
- revision=args.revision,
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

+ vae.requires_grad_(False)
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

- vae.requires_grad_(False)
- if not args.train_text_encoder:
- text_encoder.requires_grad_(False)

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
@@ -568,6 +602,7 @@ def main(args):
else:
optimizer_class = torch.optim.AdamW

# Optimizer creation
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
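
As context for the hunk above, this is roughly the optimizer selection the script performs, written as a standalone helper for illustration; the helper itself and its signature are assumptions, while the hyperparameter values mirror the script's --adam_* defaults.

import itertools
import torch

def build_optimizer(unet, text_encoder, train_text_encoder: bool, use_8bit_adam: bool, lr: float = 5e-6):
    if use_8bit_adam:
        import bitsandbytes as bnb  # optional dependency, as in the script
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # When the text encoder is trained too, both parameter sets share one optimizer.
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if train_text_encoder
        else unet.parameters()
    )
    return optimizer_class(params_to_optimize, lr=lr, betas=(0.9, 0.999), weight_decay=1e-2, eps=1e-8)
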
@@ -579,8 +614,7 @@ def main(args):
eps=args.adam_epsilon,
)

- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-
+ # Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
@@ -615,6 +649,7 @@ def main(args):
power=args.lr_power,
)

# Prepare everything with our `accelerator`.
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -623,17 +658,16 @@ def main(args):
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
- accelerator.register_for_checkpointing(lr_scheduler)

+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

- # Move text_encode and vae to gpu.
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
- # as these models are only used for inference, keeping weights in full precision is not required.
+ # Move vae and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
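
The reshuffled comments above carry the key point: under mixed precision only the trainable UNet (and, optionally, the text encoder) stays in float32, while the frozen VAE and text encoder are cast down because they only run inference. A small sketch of the dtype choice, assuming the same three Accelerate modes; the helper name is illustrative.

import torch

def frozen_weight_dtype(mixed_precision: str) -> torch.dtype:
    # "no" (or anything unrecognized) keeps full precision.
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)

# weight_dtype = frozen_weight_dtype(accelerator.mixed_precision)
# vae.to(accelerator.device, dtype=weight_dtype)
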
@@ -664,6 +698,7 @@ def main(args):
global_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
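
The rest of this hunk is cut off, but the pattern the examples use to resolve --resume_from_checkpoint="latest" is a simple directory scan. A sketch, assuming the checkpoint-<step> naming convention these scripts save with; the helper is illustrative, not a line from this commit.

import os

def find_latest_checkpoint(output_dir: str):
    # Checkpoints are written as "<output_dir>/checkpoint-<global_step>".
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    if not dirs:
        return None
    dirs.sort(key=lambda name: int(name.split("-")[1]))
    return os.path.join(output_dir, dirs[-1])

# accelerator.load_state(find_latest_checkpoint(args.output_dir))
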
@@ -772,9 +807,8 @@ def main(args):
if global_step >= args.max_train_steps:
break

- accelerator.wait_for_everyone()
-
# Create the pipeline using using the trained modules and save it.
+ accelerator.wait_for_everyone()
if accelerator.is_main_process:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
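
The tail of this file's diff is truncated above. For orientation, the save step it leads into rebuilds a pipeline around the trained weights; treat the following as a hedged sketch of the surrounding script, based on the from_pretrained call that opens the truncated hunk, not on lines shown here.

from diffusers import DiffusionPipeline

def save_trained_pipeline(accelerator, args, unet, text_encoder):
    # Only the main process writes to disk; unwrap_model strips the wrapper
    # that accelerator.prepare added around the trained modules.
    if accelerator.is_main_process:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
            revision=args.revision,
        )
        pipeline.save_pretrained(args.output_dir)
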
7 changes: 6 additions & 1 deletion examples/text_to_image/train_text_to_image.py
@@ -411,6 +411,7 @@ def main():
logging_dir=logging_dir,
)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
@@ -419,7 +420,7 @@ def main():
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
- transformers.utils.logging.set_verbosity_info()
+ transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
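
This is the verbosity change named in the commit message: the local main process keeps library warnings (plus diffusers progress info), while every other rank is limited to errors so multi-GPU runs do not repeat each message once per process. The same block, factored into a helper purely for illustration; the scripts inline it rather than calling a function.

import logging

import datasets
import diffusers
import transformers

def configure_example_logging(is_local_main_process: bool) -> None:
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    if is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
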
@@ -577,6 +578,7 @@ def tokenize_captions(examples, is_train=True):
)
return inputs.input_ids

# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -605,6 +607,7 @@ def collate_fn(examples):
input_ids = torch.stack([example["input_ids"] for example in examples])
return {"pixel_values": pixel_values, "input_ids": input_ids}

# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
)
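
For completeness, the collate_fn whose tail is shown above also casts pixel values to contiguous float32 before batching; that cast sits outside the displayed hunk, so treat this as a sketch of the surrounding script rather than part of the diff.

import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    # Contiguous, float32 inputs for the VAE encode step.
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([example["input_ids"] for example in examples])
    return {"pixel_values": pixel_values, "input_ids": input_ids}
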
@@ -623,6 +626,7 @@ def collate_fn(examples):
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)

# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
@@ -668,6 +672,7 @@ def collate_fn(examples):
global_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)