Commit
[Fix]: Resolve config bug and seed (hao-ai-lab#51)
Co-authored-by: Peiyuan Zhang <[email protected]>
jzhang38 committed Nov 16, 2024
1 parent 7106ead commit 45e4adc
Showing 2 changed files with 31 additions and 14 deletions.
6 changes: 3 additions & 3 deletions fastvideo/model/modeling_mochi.py
@@ -25,7 +25,7 @@
 from diffusers.models.embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
+from fastvideo.model.norm import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
 
 
 from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info
@@ -199,9 +199,9 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
 
         attn_mask = encoder_attention_mask[:, :].bool()
         attn_mask = F.pad(attn_mask, (sequence_length, 0), value=True)
-        # hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)
+        hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask = None, dropout_p=0.0, is_causal=False)
+        # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask = None, dropout_p=0.0, is_causal=False)
 
         # valid_lengths = encoder_attention_mask.sum(dim=1) + sequence_length
         # def no_padding_mask(score, b, h, q_idx, kv_idx):
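Note on this hunk: it re-enables the mask-aware flash_attn_no_pad path and comments out the SDPA fallback, which was being called with attn_mask = None and therefore let queries attend to padded text tokens. A minimal sketch of what a mask-respecting SDPA call would look like, assuming (batch, heads, seqlen, headdim) tensors and the (batch, seqlen) boolean mask built above; this is an illustration, not FastVideo's code:

```python
import torch
import torch.nn.functional as F

def sdpa_with_padding_mask(query, key, value, attn_mask_bool):
    # Broadcast the key-padding mask to (batch, 1, 1, seqlen); True entries
    # participate in attention, False (padded) keys are excluded.
    mask = attn_mask_bool[:, None, None, :]
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
```

The flash_attn_no_pad call presumably achieves the same effect by unpadding with attn_mask before invoking a variable-length FlashAttention kernel, so re-enabling it restores padding awareness.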
39 changes: 28 additions & 11 deletions fastvideo/train.py
@@ -5,7 +5,6 @@
 import os
 import shutil
 from pathlib import Path
-from diffusers.training_utils import cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
 from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, \
     destroy_sequence_parallel_group, get_sequence_parallel_state, nccl_info
 from fastvideo.utils.communications import sp_parallel_dataloader_wrapper, broadcast
@@ -32,19 +31,13 @@
 from diffusers.optimization import get_scheduler
 from fastvideo.model.modeling_mochi import MochiTransformer3DModel
 from diffusers.utils import check_min_version
-from fastvideo.utils.ema import EMAModel
 from fastvideo.dataset.latent_datasets import LatentDataset, latent_collate_function
 import torch.distributed as dist
-from safetensors.torch import save_file, load_file
-from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict, get_peft_model, inject_adapter_in_model
+from peft import LoraConfig, inject_adapter_in_model
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     MixedPrecision,
 )
-from torch.distributed.checkpoint.state_dict import get_state_dict
-from typing import Optional
-import copy
-from typing import Dict, Type
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.31.0")
 import time
@@ -90,6 +83,27 @@ def save_checkpoint(transformer: MochiTransformer3DModel, rank, output_dir, step
         json.dump(config_dict, f, indent=4)
     main_print(f"--> checkpoint saved at step {step}")
 
+def compute_density_for_timestep_sampling(
+    weighting_scheme: str, batch_size: int, generator, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+):
+    """
+    Compute the density for sampling the timesteps when doing SD3 training.
+    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+    """
+    if weighting_scheme == "logit_normal":
+        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu", generator=generator)
+        u = torch.nn.functional.sigmoid(u)
+    elif weighting_scheme == "mode":
+        u = torch.rand(size=(batch_size,), device="cpu", generator=generator)
+        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+    else:
+        u = torch.rand(size=(batch_size,), device="cpu", generator=generator)
+    return u

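A usage sketch of the function added above: with a seeded CPU generator the sampled density is reproducible, which is exactly what threading noise_random_generator through the trainer enables. The mapping to discrete timesteps mirrors what SD3-style scripts typically do; num_train_timesteps = 1000 is an assumed scheduler value, not read from this diff:

```python
import torch

gen = torch.Generator(device="cpu").manual_seed(42)  # reproducible sampling
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=4, generator=gen,
    logit_mean=0.0, logit_std=1.0, mode_scale=None,
)
# Map densities in [0, 1) to discrete scheduler timesteps.
num_train_timesteps = 1000  # assumed; taken from the scheduler config in practice
indices = (u * num_train_timesteps).long()
```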
 def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
     sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
     schedule_timesteps = noise_scheduler.timesteps.to(device)
@@ -102,7 +116,7 @@ def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32)
     return sigma
 
 
-def train_one_step_mochi(transformer, optimizer, lr_scheduler, loader,noise_scheduler, gradient_accumulation_steps, sp_size, precondition_outputs, max_grad_norm, weighting_scheme, logit_mean, logit_std, mode_scale):
+def train_one_step_mochi(transformer, optimizer, lr_scheduler,loader, noise_scheduler, noise_random_generator, gradient_accumulation_steps, sp_size, precondition_outputs, max_grad_norm, weighting_scheme, logit_mean, logit_std, mode_scale):
     total_loss = 0.0
     optimizer.zero_grad()
     for _ in range(gradient_accumulation_steps):
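train_one_step_mochi sums gradients over gradient_accumulation_steps micro-batches before one optimizer step. A generic sketch of that pattern with an assumed loss_fn (the real function also clips gradients to max_grad_norm and returns the norm):

```python
import torch

def accumulate_and_step(model, optimizer, micro_batches, accum_steps, loss_fn, max_grad_norm=1.0):
    optimizer.zero_grad()
    total_loss = 0.0
    for batch in micro_batches[:accum_steps]:
        loss = loss_fn(model, batch) / accum_steps  # scale so summed grads average
        loss.backward()                             # grads accumulate across micro-batches
        total_loss += loss.detach().item()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return total_loss, grad_norm
```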
@@ -114,6 +128,7 @@
         u = compute_density_for_timestep_sampling(
             weighting_scheme=weighting_scheme,
             batch_size=batch_size,
+            generator=noise_random_generator,
             logit_mean=logit_mean,
             logit_std=logit_std,
             mode_scale=mode_scale,
@@ -126,7 +141,6 @@

         sigmas = get_sigmas(noise_scheduler, latents.device, timesteps, n_dim=latents.ndim, dtype=latents.dtype)
         noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
-
         model_pred = transformer(
             noisy_model_input,
             encoder_hidden_states,
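The interpolation above is the rectified-flow forward process, x_t = (1 - sigma) * x_0 + sigma * noise. A hedged sketch of the SD3-style objective such a step typically trains; the repository's exact loss (e.g. with precondition_outputs or sigma-dependent weighting) may differ:

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model_pred, latents, noise):
    # Velocity target for x_t = (1 - s) * x_0 + s * eps: dx_t/ds = eps - x_0.
    target = noise - latents
    return F.mse_loss(model_pred.float(), target.float())
```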
@@ -252,6 +266,8 @@ def main(args):
     if args.seed is not None:
         # TODO: t within the same seq parallel group should be the same. Noise should be different.
         set_seed(args.seed + rank)
+    # We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
+    noise_random_generator = None
 
     # Handle the repository creation
     if rank <=0 and args.output_dir is not None:
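Note that noise_random_generator is initialized to None: since set_seed(args.seed + rank) seeds torch's global RNG differently on each process, passing generator=None later makes both the timestep sampling and the noise draws per-rank distinct yet reproducible. A small sketch of the effect (set_seed assumed to come from accelerate.utils, as in typical diffusers training scripts):

```python
import torch
from accelerate.utils import set_seed  # assumed import; diffusers ships an equivalent

rank = 0  # the torch.distributed rank in real training
set_seed(1234 + rank)                       # global RNG differs per process
noise = torch.randn(2, 4)                   # per-rank distinct, per-run reproducible
u = torch.rand(size=(2,), generator=None)   # also drawn from the seeded global RNG
```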
@@ -263,6 +279,7 @@
     # Create model:
 
     main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
+    # keep the master weight to float32
     load_dtype = torch.float32
     transformer = MochiTransformer3DModel.from_pretrained(
         args.pretrained_model_name_or_path,
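Loading in float32 keeps the master weights in full precision; compute can still run in bf16 through FSDP's MixedPrecision policy, which the file already imports. A sketch under assumed settings; the actual FSDP wrapping lives outside this hunk:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameters cast to bf16 for forward/backward
    reduce_dtype=torch.float32,   # gradient reduction in fp32 for stability
    buffer_dtype=torch.bfloat16,
)
# transformer = FSDP(transformer, mixed_precision=mp_policy, ...)  # sketch only
```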
@@ -395,7 +412,7 @@
     next(loader)
     for step in range(init_steps + 1, args.max_train_steps+1):
         start_time = time.time()
-        loss, grad_norm= train_one_step_mochi(transformer, optimizer, lr_scheduler, loader, noise_scheduler, args.gradient_accumulation_steps, args.sp_size, args.precondition_outputs, args.max_grad_norm, args.weighting_scheme, args.logit_mean, args.logit_std, args.mode_scale)
+        loss, grad_norm= train_one_step_mochi(transformer, optimizer, lr_scheduler, loader, noise_scheduler, noise_random_generator, args.gradient_accumulation_steps, args.sp_size, args.precondition_outputs, args.max_grad_norm, args.weighting_scheme, args.logit_mean, args.logit_std, args.mode_scale)
 
         step_time = time.time() - start_time
         step_times.append(step_time)
