[feat] add hunyuan adv (hao-ai-lab#79)
jzhang38 authored Dec 13, 2024
1 parent 6ab2263 commit 85639d1
Showing 7 changed files with 105 additions and 88 deletions.
2 changes: 1 addition & 1 deletion fastvideo/distill.py
@@ -30,7 +30,7 @@
from diffusers import (
FlowMatchEulerDiscreteScheduler,
)
from fastvideo.utils.load import get_no_split_modules, load_transformer
from fastvideo.utils.load import load_transformer
from fastvideo.distill.solver import EulerSolver, extract_into_tensor
from copy import deepcopy
from diffusers.optimization import get_scheduler
9 changes: 5 additions & 4 deletions fastvideo/distill/discriminator.py
@@ -62,9 +62,10 @@ def __init__(
stride=8,
num_h_per_head=1,
adapter_channel_dims=[3072],
total_layers = 48,
):
super().__init__()
adapter_channel_dims = adapter_channel_dims * (48 // stride)
adapter_channel_dims = adapter_channel_dims * (total_layers // stride)
self.stride = stride
self.num_h_per_head = num_h_per_head
self.head_num = len(adapter_channel_dims)
@@ -89,9 +90,9 @@ def custom_forward(*inputs):

return custom_forward

assert len(features) // self.stride == len(self.heads)
for i in range(0, len(features), self.stride):
for h in self.heads[i // self.stride]:
assert len(features) == len(self.heads)
for i in range(0, len(features)):
for h in self.heads[i]:
# out = torch.utils.checkpoint.checkpoint(
# create_custom_forward(h),
# features[i],
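
A minimal sketch (not repository code) of the head-count arithmetic behind this change: the discriminator now builds one head group per `stride` transformer blocks, with the block count passed in from the call site in fastvideo/distill_adv.py below (48 for Mochi, 40 for the Hunyuan model). The transformer is assumed to return exactly that many intermediate features, which is what the simplified assert checks.

def num_discriminator_heads(total_layers: int, stride: int = 8) -> int:
    # One head group per `stride` transformer blocks; the transformer applies
    # output_features_stride itself, so it hands back this many features.
    return total_layers // stride

print(num_discriminator_heads(48))  # Mochi: 6 head groups
print(num_discriminator_heads(40))  # Hunyuan (per this commit): 5 head groups
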
134 changes: 73 additions & 61 deletions fastvideo/distill_adv.py
@@ -22,6 +22,7 @@
StateDictType,
FullStateDictConfig,
)
from fastvideo.utils.load import load_transformer

from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
import json
@@ -76,6 +77,7 @@ def gan_d_loss(
encoder_hidden_states,
encoder_attention_mask,
weight,
discriminator_head_stride
):
loss = 0.0
# collate sample_fake and sample_real
@@ -85,15 +87,17 @@
encoder_hidden_states,
timestep,
encoder_attention_mask,
output_attn=True,
output_features=True,
output_features_stride=discriminator_head_stride,
return_dict=False,
)[1]
real_features = teacher_transformer(
sample_real,
encoder_hidden_states,
timestep,
encoder_attention_mask,
output_attn=True,
output_features=True,
output_features_stride=discriminator_head_stride,
return_dict=False,
)[1]
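
For orientation, a self-contained toy of the feature path these calls rely on; the stub tensors, dimensions, and hinge-style term below are illustrative assumptions, not the repository's implementation. The transformer, called with output_features=True and output_features_stride, is assumed to return one intermediate hidden state per stride blocks, and the discriminator scores each of them.

import torch
import torch.nn as nn

stride, total_layers, dim = 8, 48, 3072      # Mochi-like depth (assumption)
num_feats = total_layers // stride           # features handed to the heads

def stub_strided_features(batch_size: int):
    # Stand-in for transformer(..., output_features=True,
    # output_features_stride=stride, return_dict=False)[1]
    return [torch.randn(batch_size, dim) for _ in range(num_feats)]

heads = nn.ModuleList(nn.Linear(dim, 1) for _ in range(num_feats))  # toy discriminator

fake_features = stub_strided_features(2)
assert len(fake_features) == len(heads)      # mirrors the new Discriminator assert
fake_logits = [h(f) for h, f in zip(heads, fake_features)]
d_loss_fake = sum(torch.relu(1.0 + logit).mean() for logit in fake_logits)  # hinge-style term, illustrative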

@@ -115,14 +119,16 @@ def gan_g_loss(
encoder_hidden_states,
encoder_attention_mask,
weight,
discriminator_head_stride
):
loss = 0.0
features = teacher_transformer(
sample_fake,
encoder_hidden_states,
timestep,
encoder_attention_mask,
output_attn=True,
output_features=True,
output_features_stride=discriminator_head_stride,
return_dict=False,
)[1]
fake_outputs = discriminator(
@@ -135,20 +141,19 @@
return loss


def train_one_step_mochi(
def distill_one_step_adv(
transformer,
model_type,
teacher_transformer,
optimizer,
discriminator,
discriminator_optimizer,
global_step,
lr_scheduler,
loader,
noise_scheduler,
solver,
noise_random_generator,
sp_size,
precondition_outputs,
max_grad_norm,
uncond_prompt_embed,
uncond_prompt_mask,
@@ -157,6 +162,7 @@ def train_one_step_mochi(
not_apply_cfg_solver,
distill_cfg,
adv_weight,
discriminator_head_stride
):
optimizer.zero_grad()
discriminator_optimizer.zero_grad()
@@ -167,7 +173,7 @@
latents_attention_mask,
encoder_attention_mask,
) = next(loader)
model_input = normalize_mochi_dit_input(latents)
model_input = normalize_dit_input(model_type, latents)
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
index = torch.randint(
@@ -284,6 +290,7 @@ def train_one_step_mochi(
encoder_hidden_states.float(),
encoder_attention_mask,
1.0,
discriminator_head_stride
)
g_loss += g_gan_loss
g_loss.backward()
@@ -308,6 +315,7 @@ def train_one_step_mochi(
encoder_hidden_states,
encoder_attention_mask,
1.0,
discriminator_head_stride,
)

d_loss.backward()
@@ -347,21 +355,14 @@ def main(args):

main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
# keep the master weight to float32
if args.dit_model_name_or_path:
transformer = MochiTransformer3DModel.from_pretrained(
args.dit_model_name_or_path,
torch_dtype=torch.float32,
# torch_dtype=torch.bfloat16 if args.use_lora else torch.float32,
)
else:
transformer = MochiTransformer3DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=torch.float32,
# torch_dtype=torch.bfloat16 if args.use_lora else torch.float32,
)
transformer = load_transformer(
args.model_type,
args.dit_model_name_or_path,
args.pretrained_model_name_or_path,
torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
)
teacher_transformer = deepcopy(transformer)
discriminator = Discriminator(args.discriminator_head_stride)
discriminator = Discriminator(args.discriminator_head_stride, total_layers=48 if args.model_type == "mochi" else 40)

if args.use_lora:
transformer.requires_grad_(False)
@@ -383,15 +384,20 @@ def main(args):
main_print(
f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}"
)
fsdp_kwargs = get_dit_fsdp_kwargs(
args.fsdp_sharding_startegy, args.use_lora, args.use_cpu_offload
fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
transformer,
args.fsdp_sharding_startegy,
args.use_lora,
args.use_cpu_offload,
args.master_weight_type,
)
discriminator_fsdp_kwargs = get_discriminator_fsdp_kwargs(args.master_weight_type)
if args.use_lora:
assert args.model_type == "mochi", "LoRA is only supported for Mochi model."
transformer.config.lora_rank = args.lora_rank
transformer.config.lora_alpha = args.lora_alpha
transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
transformer._no_split_modules = ["MochiTransformerBlock"]
transformer._no_split_modules = no_split_modules
fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)

transformer = FSDP(
@@ -409,8 +415,12 @@
main_print(f"--> model loaded")

if args.gradient_checkpointing:
apply_fsdp_checkpointing(transformer, args.selective_checkpointing)
apply_fsdp_checkpointing(teacher_transformer, args.selective_checkpointing)
apply_fsdp_checkpointing(
transformer, no_split_modules, args.selective_checkpointing
)
apply_fsdp_checkpointing(
teacher_transformer, no_split_modules, args.selective_checkpointing
)
# Set model as trainable.
transformer.train()
teacher_transformer.requires_grad_(False)
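
A hedged sketch of how a per-model no_split_modules list (now returned by get_dit_fsdp_kwargs and reused for checkpointing above) can drive FSDP auto-wrapping. MochiTransformerBlock appears in the removed code; the Hunyuan block class here is a stand-in, and the repository's own wrap policy may differ.

import functools
import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class MochiTransformerBlock(nn.Module): ...    # stand-in block classes
class HunyuanTransformerBlock(nn.Module): ...  # hypothetical name

no_split_modules = {MochiTransformerBlock, HunyuanTransformerBlock}
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=no_split_modules,    # shard each block as one FSDP unit
)
# auto_wrap_policy would be passed to FSDP(..., auto_wrap_policy=auto_wrap_policy);
# the diff above passes the same class set to apply_fsdp_checkpointing.
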
@@ -562,39 +572,49 @@ def main(args):

step_times = deque(maxlen=100)
# log_validation(args, transformer, device,
# torch.bfloat16, init_steps, scheduler_type=args.scheduler_type, shift=args.shift, num_euler_timesteps=args.num_euler_timesteps, linear_quadratic_threshold=args.linear_quadratic_threshold, ema=False)

# torch.bfloat16, 0, scheduler_type=args.scheduler_type, shift=args.shift, num_euler_timesteps=args.num_euler_timesteps, linear_quadratic_threshold=args.linear_quadratic_threshold,ema=False)
def get_num_phases(multi_phased_distill_schedule, step):
# step-phase,step-phase
multi_phases = multi_phased_distill_schedule.split(",")
phase = multi_phases[-1].split("-")[-1]
for step_phases in multi_phases:
phase_step, phase = step_phases.split("-")
if step <= int(phase_step):
return int(phase)
return phase
for i in range(init_steps):
_ = next(loader)
for step in range(init_steps + 1, args.max_train_steps + 1):
assert args.multi_phased_distill_schedule is not None
num_phases = get_num_phases(args.multi_phased_distill_schedule, step)
start_time = time.time()
(
generator_loss,
generator_grad_norm,
discriminator_loss,
discriminator_grad_norm,
) = train_one_step_mochi(
) = distill_one_step_adv(
transformer,
args.model_type,
teacher_transformer,
optimizer,
discriminator,
discriminator_optimizer,
step,
lr_scheduler,
loader,
noise_scheduler,
solver,
noise_random_generator,
args.sp_size,
args.precondition_outputs,
args.max_grad_norm,
uncond_prompt_embed,
uncond_prompt_mask,
args.num_euler_timesteps,
args.validation_sampling_steps,
num_phases,
args.not_apply_cfg_solver,
args.distill_cfg,
args.adv_weight,
args.discriminator_head_stride
)

step_time = time.time() - start_time
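
A hedged, self-contained restatement of the "step-phase,step-phase" format consumed by get_num_phases above, with an illustrative schedule string (not taken from the repository's training scripts). Note that the version in the diff returns the trailing phase as a string once the step passes every boundary; the sketch below always returns an int.

def num_phases(schedule: str, step: int) -> int:
    # schedule: "step-phase,step-phase", e.g. "4000-2,8000-4" (hypothetical values)
    entries = [e.split("-") for e in schedule.split(",")]
    for boundary, phase in entries:
        if step <= int(boundary):
            return int(phase)
    return int(entries[-1][1])  # past the last boundary, keep the final phase count

print(num_phases("4000-2,8000-4", 1000))   # -> 2
print(num_phases("4000-2,8000-4", 6000))   # -> 4
print(num_phases("4000-2,8000-4", 20000))  # -> 4
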
@@ -633,15 +653,17 @@ def main(args):
)
else:
# Your existing checkpoint saving code
save_checkpoint_generator_discriminator(
transformer,
optimizer,
discriminator,
discriminator_optimizer,
rank,
args.output_dir,
step,
)
# TODO
# save_checkpoint_generator_discriminator(
# transformer,
# optimizer,
# discriminator,
# discriminator_optimizer,
# rank,
# args.output_dir,
# step,
# )
save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
main_print(f"--> checkpoint saved at step {step}")
dist.barrier()
if args.log_validation and step % args.validation_steps == 0:
@@ -655,25 +677,17 @@
shift=args.shift,
num_euler_timesteps=args.num_euler_timesteps,
linear_quadratic_threshold=args.linear_quadratic_threshold,
linear_range=args.linear_range,
ema=False,
)


if args.use_lora:
save_lora_checkpoint(
transformer, optimizer, rank, args.output_dir, args.max_train_steps
)
else:
save_checkpoint(
transformer, optimizer, rank, args.output_dir, args.max_train_steps
)
save_checkpoint(
discriminator,
discriminator_optimizer,
rank,
args.output_dir,
step,
discriminator=True,
)
save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)

if get_sequence_parallel_state():
destroy_sequence_parallel_group()
@@ -682,6 +696,9 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--model_type", type=str, default="mochi", help="The type of model to train."
)
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_frames", type=int, default=163)
@@ -712,16 +729,9 @@ def main(args):
parser.add_argument("--ema_decay", type=float, default=0.999)
parser.add_argument("--ema_start_step", type=int, default=0)
parser.add_argument("--cfg", type=float, default=0.1)
parser.add_argument(
"--precondition_outputs",
action="store_true",
help="Whether to precondition the outputs of the model.",
)

# validation & logs
parser.add_argument("--validation_prompt_dir", type=str)
parser.add_argument("--validation_sampling_steps", type=int, default=64)
parser.add_argument("--validation_guidance_scale", type=float, default=4.5)
parser.add_argument("--validation_sampling_steps", type=str, default="64")
parser.add_argument("--validation_guidance_scale", type=str, default="4.5")
parser.add_argument("--validation_steps", type=float, default=64)
parser.add_argument("--log_validation", action="store_true")
parser.add_argument("--tracker_project_name", type=str, default=None)
@@ -750,6 +760,7 @@ def main(args):
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument("--validation_prompt_dir", type=str)
parser.add_argument("--shift", type=float, default=1.0)
parser.add_argument(
"--resume_from_checkpoint",
@@ -866,6 +877,7 @@ def main(args):
"--lora_rank", type=int, default=128, help="LoRA rank parameter. "
)
parser.add_argument("--fsdp_sharding_startegy", default="full")
parser.add_argument("--multi_phased_distill_schedule", type=str, default=None)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
(Diffs for the remaining 4 changed files are not shown here.)
