-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfinetune.py
565 lines (501 loc) · 22.5 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
import os
import argparse
import accelerate
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
import logging
import datasets
import math
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
from transformers.utils import ContextManagers
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from typing import List, Optional, Tuple, Union
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
)
from accelerate.utils import ProjectConfiguration, set_seed
from cogvideox_interpolation.datasets import ImageVideoDataset
# Module-level logger created through accelerate's `get_logger` at INFO level.
logger = get_logger(__name__, log_level="INFO")
def _str2bool(value):
    """Convert a command-line token such as ``"false"`` or ``"1"`` into a real bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True`` for
    every non-empty string, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string is kept falsy, matching what `bool("")` returned before.
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args(input_args=None):
    """Build the command-line parser for the fine-tuning script and parse the arguments.

    Args:
        input_args: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed options. ``local_rank`` is replaced by the
        ``LOCAL_RANK`` environment variable when that variable is set and differs.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="Root directory of the training data passed to the dataset.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="results",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=8,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        type=_str2bool,
        default=True,
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=3e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        type=_str2bool,
        default=True,
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--use_came",
        type=_str2bool,
        default=False,
        help="whether to use came",
    )
    parser.add_argument(
        "--allow_tf32",
        type=_str2bool,
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=50,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args(input_args)

    # The LOCAL_RANK environment variable, when present, wins over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank
    return args
def compute_prompt_embeds(
    text_encoder,
    text_input_ids,
    device=None,
    dtype=None,
    num_videos_per_prompt=1,
):
    """Encode token ids with ``text_encoder`` and tile the result per generated video.

    The encoder is called without gradient tracking on ``text_input_ids`` moved to
    ``device``; element ``[0]`` of its output is cast to ``dtype``/``device``. That
    tensor must be 3-D (it is unpacked into three dimensions below).

    Returns:
        A tensor viewed as ``(batch * num_videos_per_prompt, seq_len, -1)``.
    """
    num_prompts = text_input_ids.size(0)

    with torch.no_grad():
        encoder_output = text_encoder(text_input_ids.to(device))
        embeds = encoder_output[0].to(dtype=dtype, device=device)

    # Tile along the sequence axis, then fold the copies into the batch axis
    # (the repeat + view form is the mps-friendly way to duplicate embeddings).
    _, seq_len, _ = embeds.shape
    tiled = embeds.repeat(1, num_videos_per_prompt, 1)
    return tiled.view(num_prompts * num_videos_per_prompt, seq_len, -1)
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the 3D rotary position embedding tables for a video of the given size.

    ``height``/``width`` and the base sizes are divided (floor) by
    ``vae_scale_factor_spatial * patch_size`` to obtain patch-grid sizes, a crop
    region is derived with ``get_resize_crop_region_for_grid``, and the cos/sin
    tables come from ``get_3d_rotary_pos_embed``.

    Returns:
        ``(freqs_cos, freqs_sin)``, both moved to ``device``.
    """
    # Pixels covered by one transformer patch along a spatial axis.
    pixels_per_patch = vae_scale_factor_spatial * patch_size

    grid_size = (height // pixels_per_patch, width // pixels_per_patch)
    base_size_width = base_width // pixels_per_patch
    base_size_height = base_height // pixels_per_patch

    crop_region = get_resize_crop_region_for_grid(grid_size, base_size_width, base_size_height)
    cos_table, sin_table = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid_size,
        temporal_size=num_frames,
    )
    return cos_table.to(device=device), sin_table.to(device=device)
def main():
    """Fine-tune selected CogVideoX transformer sub-modules on first/last-frame conditioned videos.

    Flow: parse CLI arguments, create the `Accelerator`, load tokenizer / transformer /
    scheduler / text encoder / VAE, freeze everything except the transformer parameters
    whose names contain one of `trainable_modules`, build the optimizer and LR scheduler,
    then run the training loop. The loop stops after `args.max_train_steps` optimizer
    updates, and the full training state is saved every `args.checkpointing_steps` updates.

    Raises:
        ImportError: if the requested optimizer backend (bitsandbytes or came_pytorch)
            is not installed.
    """
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    set_seed(args.seed)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Create the output directory once, on the main process.
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # For mixed precision training the frozen VAE is cast to half precision: it is only
    # used for inference, so keeping its weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        args.mixed_precision = accelerator.mixed_precision
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
        args.mixed_precision = accelerator.mixed_precision

    # Prepare models and scheduler
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )

    # CogVideoX-2b weights are stored in float16
    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
    load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )
    scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    def deepspeed_zero_init_disabled_context_manager():
        """
        returns either a context list that includes one that will disable zero.Init or an empty context list
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
        if deepspeed_plugin is None:
            return []
        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        text_encoder = T5EncoderModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
        )
        vae = AutoencoderKLCogVideoX.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
        )

    # Freeze everything first; selected transformer parameters are re-enabled below.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    transformer.requires_grad_(False)

    # Move text_encoder and vae to the accelerator device; only the vae is cast to weight_dtype.
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device)
    transformer.train()

    # Parameter-name substrings that select what gets trained, see
    # https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/transformer/diffusion_pytorch_model.safetensors.index.json
    trainable_modules = ['ff.net', 'to_q', 'to_v', 'proj_out',]
    trainable_modules_low_learning_rate = []
    for name, param in transformer.named_parameters():
        for trainable_module_name in trainable_modules + trainable_modules_low_learning_rate:
            if trainable_module_name in name:
                param.requires_grad = True
                break

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            ) from err
        optimizer_cls = bnb.optim.AdamW8bit
    elif args.use_came:
        try:
            from came_pytorch import CAME
        except ImportError as err:
            # Only a missing package is translated; any other failure propagates unchanged.
            raise ImportError(
                "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`"
            ) from err
        optimizer_cls = CAME
    else:
        optimizer_cls = torch.optim.AdamW

    # Two parameter groups: full learning rate and half learning rate.
    trainable_params_optim = [
        {'params': [], 'lr': args.learning_rate},
        {'params': [], 'lr': args.learning_rate / 2},
    ]
    assigned_names = set()
    for name, param in transformer.named_parameters():
        if name in assigned_names:
            continue
        if any(module_name in name for module_name in trainable_modules):
            assigned_names.add(name)
            trainable_params_optim[0]['params'].append(param)
            if accelerator.is_main_process:
                print(f"Set {name} to lr : {args.learning_rate}")
            continue
        if any(module_name in name for module_name in trainable_modules_low_learning_rate):
            assigned_names.add(name)
            trainable_params_optim[1]['params'].append(param)
            if accelerator.is_main_process:
                print(f"Set {name} to lr : {args.learning_rate / 2}")

    if args.use_came:
        optimizer = optimizer_cls(
            trainable_params_optim,
            lr=args.learning_rate,
            # weight_decay=args.adam_weight_decay,
            betas=(0.9, 0.999, 0.9999),
            eps=(1e-30, 1e-16)
        )
    else:
        optimizer = optimizer_cls(
            trainable_params_optim,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )

    # Dataset and Dataloader
    train_dataset = ImageVideoDataset(data_root=args.data_path, tokenizer=tokenizer)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_train_steps_from_epochs = args.max_train_steps is None
    if max_train_steps_from_epochs:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    # `len(train_dataloader)` can change after `prepare` (e.g. when batches are split
    # across processes), so redo the step math with the prepared dataloader.
    # `max(1, ...)` avoids a division by zero for an empty dataloader.
    num_update_steps_per_epoch = max(1, math.ceil(len(train_dataloader) / args.gradient_accumulation_steps))
    if max_train_steps_from_epochs:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # An explicit --max_train_steps overrides the epoch count, as its help text states.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Trackers requested via `log_with` only exist after `init_trackers`; without this call
    # `accelerator.log` has no tracker to write to.
    accelerator.init_trackers("cogvideox-interpolation-finetune")

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")

    initial_global_step = 0
    first_epoch = 0
    global_step = 0
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)

    # For DeepSpeed training the prepared model may be wrapped; the config then lives on `.module`.
    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

    for _epoch in range(first_epoch, args.num_train_epochs):
        for batch in train_dataloader:
            video = batch[0].to(accelerator.device, dtype=vae.dtype)
            video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
            first_image = video[:, :, :1].clone()
            last_image = video[:, :, -1:].clone()

            with torch.no_grad():
                video_latents = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor

                # One log-normal noise scale per step, shared by the first and last frame.
                image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=first_image.device)
                image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=first_image.dtype)
                first_noisy_image = first_image + torch.randn_like(first_image) * image_noise_sigma[:, None, None, None, None]
                first_image_latents = vae.encode(first_noisy_image).latent_dist.sample() * vae.config.scaling_factor
                last_noisy_image = last_image + torch.randn_like(last_image) * image_noise_sigma[:, None, None, None, None]
                last_image_latents = vae.encode(last_noisy_image).latent_dist.sample() * vae.config.scaling_factor

                video_latents = video_latents.permute(0, 2, 1, 3, 4)
                first_image_latents = first_image_latents.permute(0, 2, 1, 3, 4)
                last_image_latents = last_image_latents.permute(0, 2, 1, 3, 4)

                # Conditioning: first-frame latents, zeros for the frames in between, last-frame latents.
                padding_shape = (video_latents.shape[0], video_latents.shape[1] - 2, *video_latents.shape[2:])
                latent_padding = first_image_latents.new_zeros(padding_shape)
                image_latents = torch.cat([first_image_latents, latent_padding, last_image_latents], dim=1)

            input_text_ids = batch[1]
            prompt_embeds = compute_prompt_embeds(text_encoder, input_text_ids, accelerator.device, weight_dtype,)

            models_to_accumulate = [transformer]
            with accelerator.accumulate(models_to_accumulate):
                video_latents = video_latents.to(dtype=weight_dtype)  # [B, F, C, H, W]
                image_latents = image_latents.to(dtype=weight_dtype)  # [B, F, C, H, W]
                batch_size, num_frames, num_channels, height, width = video_latents.shape

                # Sample a random timestep for each sample in the batch
                timesteps = torch.randint(
                    0, scheduler.config.num_train_timesteps, (batch_size,), device=video_latents.device
                )
                timesteps = timesteps.long()

                # Sample noise that will be added to the latents
                noise = torch.randn_like(video_latents)

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
                # Conditioning latents are concatenated on dim 2 (the channel axis of [B, F, C, H, W]).
                noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)

                # Prepare rotary embeds
                image_rotary_emb = (
                    prepare_rotary_positional_embeddings(
                        height=train_dataset.height,
                        width=train_dataset.width,
                        num_frames=num_frames,
                        vae_scale_factor_spatial=vae_scale_factor_spatial,
                        patch_size=model_config.patch_size,
                        attention_head_dim=model_config.attention_head_dim,
                        device=accelerator.device,
                    )
                    if model_config.use_rotary_positional_embeddings
                    else None
                )

                # Run the transformer on the noisy input
                model_output = transformer(
                    hidden_states=noisy_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timesteps,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                )[0]
                model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)

                # Weight the squared error by 1 / (1 - alphas_cumprod[t]), broadcast to the latent shape.
                alphas_cumprod = scheduler.alphas_cumprod[timesteps]
                weights = 1 / (1 - alphas_cumprod)
                while len(weights.shape) < len(model_pred.shape):
                    weights = weights.unsqueeze(-1)

                target = video_latents
                loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
                loss = loss.mean()
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = transformer.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            # `sync_gradients` is True only on iterations where an optimizer update happened.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if accelerator.sync_gradients:
                if global_step % args.checkpointing_steps == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

            # Honour --max_train_steps: stop as soon as the requested number of updates is done.
            if global_step >= args.max_train_steps:
                break

        if global_step >= args.max_train_steps:
            break

    # Flush and close the trackers opened by `init_trackers`.
    accelerator.wait_for_everyone()
    accelerator.end_training()
# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()