[feat]: Add vae encoder embedded generator to main (hao-ai-lab#30)
rlsu9 committed Nov 6, 2024
1 parent 52ba538 commit 035ba5f
Showing 8 changed files with 252 additions and 26 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -5,7 +5,6 @@ __pycache__
*.pth
UCF-101/
results/
vae
build/
fastvideo.egg-info/
wandb/
1 change: 1 addition & 0 deletions README.md
@@ -8,6 +8,7 @@ conda create -n fastvideo python=3.10.12
conda activate fastvideo
pip3 install torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/cu121
pip3 install -U xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu121
pip3 install ray
cd .. && git clone https://github.com/huggingface/diffusers
cd diffusers && git checkout mochi && pip install -e . && cd ../FastVideo-OSP
```
8 changes: 4 additions & 4 deletions fastvideo/dataset/__init__.py
@@ -1,21 +1,21 @@
from transformers import AutoTokenizer

from torchvision import transforms

from torchvision.transforms import Lambda
from fastvideo.dataset.t2v_datasets import T2V_dataset
from fastvideo.dataset.latent_datasets import LatentDataset
from fastvideo.dataset.transform import Normalize255, TemporalRandomCrop,CenterCropResizeVideo

def getdataset(args):
temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x
norm_fun = ae_norm[args.ae]
norm_fun = Lambda(lambda x: 2. * x - 1.)
resize_topcrop = [CenterCropResizeVideo((args.max_height, args.max_width), top_crop=True), ]
resize = [CenterCropResizeVideo((args.max_height, args.max_width)), ]
transform = transforms.Compose([
Normalize255(),
# Normalize255(),
*resize,
# RandomHorizontalFlipVideo(p=0.5), # in case their caption have position decription
norm_fun
# norm_fun
])
transform_topcrop = transforms.Compose([
Normalize255(),
55 changes: 35 additions & 20 deletions fastvideo/dataset/t2v_datasets.py
@@ -114,8 +114,9 @@ def filter_resolution(h, w, max_h_div_w_ratio=17/16, min_h_div_w_ratio=8 / 16):

class T2V_dataset(Dataset):
def __init__(self, args, transform, temporal_sample, tokenizer, transform_topcrop):
self.data = args.data
self.data = args.data_merge_path
self.num_frames = args.num_frames
self.target_length = args.target_length
self.train_fps = args.train_fps
self.use_image_num = args.use_image_num
self.transform = transform
@@ -130,7 +131,7 @@ def __init__(self, args, transform, temporal_sample, tokenizer, transform_topcrop):
self.drop_short_ratio = args.drop_short_ratio
assert self.speed_factor >= 1
self.v_decoder = DecordInit()
self.video_length_tolerance_range = args.video_length_tolerance_range
# self.video_length_tolerance_range = args.video_length_tolerance_range
self.support_Chinese = True
if not ('mt5' in args.text_encoder_name):
self.support_Chinese = False
@@ -168,23 +169,36 @@ def get_data(self, idx):
if path.endswith('.mp4'):
return self.get_video(idx)
else:
assert False
return self.get_image(idx)

def get_video(self, idx):

video_path = dataset_prog.cap_list[idx]['path']
assert os.path.exists(video_path), f"file {video_path} do not exist!"
frame_indice = dataset_prog.cap_list[idx]['sample_frame_index']
video = self.decord_read(video_path, predefine_num_frames=len(frame_indice))
video, _, metadata = torchvision.io.read_video(video_path, output_format='TCHW')
video = self.transform(video)
video = rearrange(video, 't c h w -> c t h w')
video = video.unsqueeze(0)
video = video.to(torch.uint8)
assert video.dtype == torch.uint8
target_length = self.target_length
current_length = video.shape[2] # This is t (92 in your example)

if current_length < target_length:
# Calculate indices for frames spaced across the target length
indices = np.linspace(0, current_length - 1, target_length).astype(int)
# Select frames based on indices
video = video[:, :, indices, :, :]
elif current_length > target_length:
# Slice to reduce the time dimension to the target length
video = video[:, :, :target_length, :, :]

h, w = video.shape[-2:]
assert h / w <= 17 / 16 and h / w >= 8 / 16, f'Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) found ratio is {round(h / w, 2)} with the shape of {video.shape}'
t = video.shape[0]
video = self.transform(video) # T C H W -> T C H W

# video = torch.rand(221, 3, 480, 640)

video = video.transpose(0, 1) # T C H W -> C T H W

video = video.float() / 127.5 - 1.0

text = dataset_prog.cap_list[idx]['cap']
if not isinstance(text, list):
text = [text]
@@ -202,7 +216,7 @@ def get_video(self, idx):
)
input_ids = text_tokens_and_mask['input_ids']
cond_mask = text_tokens_and_mask['attention_mask']
return dict(pixel_values=video, input_ids=input_ids, cond_mask=cond_mask)
return dict(pixel_values=video, text=text, input_ids=input_ids, cond_mask=cond_mask, path=video_path)

def get_image(self, idx):
image_data = dataset_prog.cap_list[idx] # [{'path': path, 'cap': cap}, ...]
@@ -277,10 +291,10 @@ def define_frame_index(self, cap_list):
hw_aspect_thr = 1.5
is_pick = filter_resolution(height, width, max_h_div_w_ratio=hw_aspect_thr*aspect,
min_h_div_w_ratio=1/hw_aspect_thr*aspect)
if not is_pick:
print("resolution mismatch")
cnt_resolution_mismatch += 1
continue
# if not is_pick:
# print("resolution mismatch")
# cnt_resolution_mismatch += 1
# continue

# # ignore image resolution mismatch
# if self.max_height > resolution['height'] or self.max_width > resolution['width']:
@@ -290,9 +304,9 @@ def define_frame_index(self, cap_list):
# import ipdb;ipdb.set_trace()
i['num_frames'] = int(fps * duration)
# max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration.
if i['num_frames'] / fps > self.video_length_tolerance_range * (self.num_frames / self.train_fps * self.speed_factor): # too long video is not suitable for this training stage (self.num_frames)
cnt_too_long += 1
continue
# if i['num_frames'] / fps > self.video_length_tolerance_range * (self.num_frames / self.train_fps * self.speed_factor): # too long video is not suitable for this training stage (self.num_frames)
# cnt_too_long += 1
# continue

# resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24)
frame_interval = fps / self.train_fps
@@ -368,8 +382,8 @@ def decord_read(self, path, predefine_num_frames):
frame_indices = frame_indices[:end_frame_idx]
if predefine_num_frames != len(frame_indices):
raise ValueError(f'predefine_num_frames ({predefine_num_frames}) is not equal with frame_indices ({len(frame_indices)})')
if len(frame_indices) < self.num_frames and self.drop_short_ratio >= 1:
raise IndexError(f'video ({path}) has {total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})')
# if len(frame_indices) < self.num_frames and self.drop_short_ratio >= 1:
# raise IndexError(f'video ({path}) has {total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})')
video_data = decord_vr.get_batch(frame_indices).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
@@ -379,6 +393,7 @@ def read_jsons(self, data):
cap_lists = []
with open(data, 'r') as f:
folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
print(folder_anno)
for folder, anno in folder_anno:
with open(anno, 'r') as f:
sub_list = json.load(f)
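The new `get_video` path above reads the whole clip with `torchvision.io.read_video`, rearranges it to `(1, C, T, H, W)`, casts it to uint8, and then pads short clips (by repeating evenly spaced frames) or truncates long ones so every sample has exactly `target_length` frames before VAE encoding. A minimal, self-contained sketch of that adjustment is shown below; the helper name `adjust_frame_count` is hypothetical and not part of the commit.

```python
# Sketch of the temporal padding/truncation in get_video, assuming a
# (1, C, T, H, W) uint8 video tensor. adjust_frame_count is a hypothetical name.
import numpy as np
import torch


def adjust_frame_count(video: torch.Tensor, target_length: int) -> torch.Tensor:
    current_length = video.shape[2]  # T, the number of frames
    if current_length < target_length:
        # Upsample by repeating frames: indices spread evenly across the
        # existing clip so the output has exactly target_length frames.
        indices = torch.from_numpy(
            np.linspace(0, current_length - 1, target_length).astype(np.int64)
        )
        return video[:, :, indices, :, :]
    if current_length > target_length:
        # Truncate long clips to the first target_length frames.
        return video[:, :, :target_length, :, :]
    return video


# Example with small spatial dims: a 92-frame clip stretched to 163 frames,
# matching --target_length=163 in finetune_data_gen.sh further down.
clip = torch.zeros(1, 3, 92, 8, 8, dtype=torch.uint8)
print(adjust_frame_count(clip, 163).shape)  # torch.Size([1, 3, 163, 8, 8])
```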
2 changes: 1 addition & 1 deletion fastvideo/sample/generate_synthetic.py
@@ -67,7 +67,7 @@ def generate_video_and_latent(pipe, prompt, height, width, num_frames, num_infer
os.makedirs(os.path.join(args.dataset_output_dir, "prompt_embed"), exist_ok=True)
os.makedirs(os.path.join(args.dataset_output_dir, "prompt_attention_mask"), exist_ok=True)
data = []
for i, prompt in enumerate(text_prompt[:10]):
for i, prompt in enumerate(text_prompt):
if i % world_size != local_rank:
continue
noise, video, latent, prompt_embed, prompt_attention_mask = generate_video_and_latent(pipe, prompt, args.height, args.width, args.num_frames, args.num_inference_steps, args.guidance_scale)
85 changes: 85 additions & 0 deletions fastvideo/utils/data_preprocess/finetune_data_T5.py
@@ -0,0 +1,85 @@
import argparse
import torch
from accelerate.logging import get_logger
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
import json
import os
import torch.distributed as dist
logger = get_logger(__name__)

def main(args):
local_rank = int(os.getenv('RANK', 0))
world_size = int(os.getenv('WORLD_SIZE', 1))
print('world_size', world_size, 'local rank', local_rank)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(local_rank)
if not dist.is_initialized():
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=local_rank)

pipe = MochiPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16).to(device)
pipe.vae.enable_tiling()
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True)

latents_json_path = os.path.join(args.output_dir, "videos2caption_temp.json")
with open(latents_json_path, "r") as f:
train_dataset = json.load(f)
train_dataset = sorted(train_dataset, key=lambda x: x['latent_path'])

json_data = []
for _, data in enumerate(train_dataset):
video_name =data['latent_path'].split(".")[0]
if int(video_name) % world_size != local_rank:
continue
try:
with torch.inference_mode():
with torch.autocast("cuda", dtype=torch.bfloat16):
latent = torch.load(os.path.join(args.output_dir, 'latent', data['latent_path']))
prompt_embeds, prompt_attention_mask, _, _ = pipe.encode_prompt(
prompt=data['caption'],
)
prompt_embed_path = os.path.join(args.output_dir, "prompt_embed", video_name + ".pt")
video_path = os.path.join(args.output_dir, "video", video_name + ".mp4")
prompt_attention_mask_path = os.path.join(args.output_dir, "prompt_attention_mask", video_name + ".pt")
# save latent
torch.save(prompt_embeds[0], prompt_embed_path)
torch.save(prompt_attention_mask[0], prompt_attention_mask_path)
print(f"sample {video_name} saved")
video = pipe.vae.decode(latent.unsqueeze(0).to(device), return_dict=False)[0]
video = pipe.video_processor.postprocess_video(video)
export_to_video(video[0], video_path, fps=30)
item = {}
item["latent_path"] = video_name + ".pt"
item["prompt_embed_path"] = video_name + ".pt"
item["prompt_attention_mask"] = video_name + ".pt"
item["caption"] = data['caption']
json_data.append(item)
except:
print("video out of memory")
continue
dist.barrier()
local_data = json_data
gathered_data = [None] * world_size
dist.all_gather_object(gathered_data, local_data)
if local_rank == 0:
# os.remove(latents_json_path)
all_json_data = [item for sublist in gathered_data for item in sublist]
with open(os.path.join(args.output_dir, "videos2caption.json"), 'w') as f:
json.dump(all_json_data, f, indent=4)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
# dataset & dataloader
parser.add_argument("--model_path", type=str, default="data/mochi")
# text encoder & vae & diffusion model
parser.add_argument("--text_encoder_name", type=str, default='google/t5-v1_1-xxl')
parser.add_argument("--cache_dir", type=str, default='./cache_dir')
parser.add_argument("--output_dir", type=str, default=None, help="The output directory where the model predictions and checkpoints will be written.")

args = parser.parse_args()
main(args)
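`finetune_data_T5.py` reads `videos2caption_temp.json` (written by the VAE stage shown next), encodes each caption with the Mochi pipeline's text encoder, saves the prompt embedding and attention mask as per-sample `.pt` files, decodes the cached latent back to an `.mp4` for inspection, and finally writes the merged `videos2caption.json` index on rank 0. A minimal sketch of how a downstream consumer might read that index follows; the `load_sample` helper and the `output_dir` value are illustrative, not part of the commit.

```python
# Sketch of reading the videos2caption.json index produced above.
# load_sample and output_dir are hypothetical; the directory layout and
# JSON keys mirror the save calls in finetune_data_T5.py.
import json
import os

import torch

output_dir = "./data/BW-Finetune-Synthetic-Data_test"  # assumed output dir


def load_sample(entry, output_dir):
    latent = torch.load(os.path.join(output_dir, "latent", entry["latent_path"]))
    prompt_embed = torch.load(os.path.join(output_dir, "prompt_embed", entry["prompt_embed_path"]))
    prompt_mask = torch.load(os.path.join(output_dir, "prompt_attention_mask", entry["prompt_attention_mask"]))
    return latent, prompt_embed, prompt_mask, entry["caption"]


with open(os.path.join(output_dir, "videos2caption.json")) as f:
    index = json.load(f)

latent, prompt_embed, prompt_mask, caption = load_sample(index[0], output_dir)
print(latent.shape, prompt_embed.shape, prompt_mask.shape, caption)
```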
104 changes: 104 additions & 0 deletions fastvideo/utils/data_preprocess/finetune_data_VAE.py
@@ -0,0 +1,104 @@
from fastvideo.dataset import getdataset
from torch.utils.data import DataLoader
from fastvideo.utils.dataset_utils import Collate
import argparse
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
import json
import os
from diffusers import AutoencoderKLMochi
import torch.distributed as dist

logger = get_logger(__name__)

def main(args):
local_rank = int(os.getenv('RANK', 0))
world_size = int(os.getenv('WORLD_SIZE', 1))
print('world_size', world_size, 'local rank', local_rank)
args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = 4, 8, 8
args.ae_stride = args.ae_stride_h
patch_size_t, patch_size_h, patch_size_w = 1, 2, 2
args.patch_size = patch_size_h
args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=args.logging_dir)
accelerator = Accelerator(
project_config=accelerator_project_config,
)
train_dataset = getdataset(args)
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)


encoder_device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(local_rank)
if not dist.is_initialized():
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=local_rank)
vae = AutoencoderKLMochi.from_pretrained(args.model_path, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
vae.enable_tiling()
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)

json_data = []
for _, data in enumerate(train_dataloader):
video_name = os.path.basename(data['path'][0]).split(".")[0]
if int(video_name) % world_size != local_rank:
continue
with torch.inference_mode():
with torch.autocast("cuda", dtype=torch.bfloat16):
latents = vae.encode(data['pixel_values'][0].to(encoder_device))['latent_dist'].sample()
latent_path = os.path.join(args.output_dir, "latent", video_name + ".pt")
torch.save(latents[0].to(torch.bfloat16), latent_path)
item = {}
item["latent_path"] = video_name + ".pt"
item["caption"] = data['text']
json_data.append(item)
print(f"{video_name} processed")
dist.barrier()
local_data = json_data
gathered_data = [None] * world_size
dist.all_gather_object(gathered_data, local_data)
if local_rank == 0:
all_json_data = [item for sublist in gathered_data for item in sublist]
with open(os.path.join(args.output_dir, "videos2caption_temp.json"), 'w') as f:
json.dump(all_json_data, f, indent=4)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
# dataset & dataloader
parser.add_argument("--model_path", type=str, default="data/mochi")
parser.add_argument("--data_merge_path", type=str, required=True)
parser.add_argument("--num_frames", type=int, default=65)
parser.add_argument("--target_length", type=int, default=65)
parser.add_argument("--dataloader_num_workers", type=int, default=1, help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.")
parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader.")
parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
parser.add_argument("--max_height", type=int, default=480)
parser.add_argument("--max_width", type=int, default=848)
parser.add_argument("--group_frame", action="store_true") # TODO
parser.add_argument("--group_resolution", action="store_true") # TODO
parser.add_argument("--dataset", default='t2v')
parser.add_argument("--train_fps", type=int, default=24)
parser.add_argument("--use_image_num", type=int, default=0)
parser.add_argument("--text_max_length", type=int, default=256)
parser.add_argument("--speed_factor", type=float, default=1.0)
parser.add_argument("--drop_short_ratio", type=float, default=1.0)
# text encoder & vae & diffusion model
parser.add_argument("--text_encoder_name", type=str, default='google/t5-v1_1-xxl')
parser.add_argument("--cache_dir", type=str, default='./cache_dir')
parser.add_argument('--cfg', type=float, default=0.1)
parser.add_argument("--output_dir", type=str, default=None, help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--logging_dir", type=str, default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)

args = parser.parse_args()
main(args)
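`finetune_data_VAE.py` shards the dataset across ranks by taking the integer video filename modulo the world size, encodes each clip to a bfloat16 latent with the Mochi VAE, and gathers every rank's JSON entries with `all_gather_object` before rank 0 writes `videos2caption_temp.json`. A single-process sketch of that sharding rule, assuming filenames are integer indices such as `0.mp4`, `1.mp4`, ..., is shown below.

```python
# Sketch of the rank-sharding rule used above, assuming video filenames are
# integer indices like "0.mp4", "1.mp4", ... The helper name is hypothetical.
import os


def assigned_to_rank(video_path: str, local_rank: int, world_size: int) -> bool:
    video_name = os.path.basename(video_path).split(".")[0]
    # A video is processed by the rank whose id equals int(name) % world_size.
    return int(video_name) % world_size == local_rank


world_size = 8
paths = [f"{i}.mp4" for i in range(20)]
for rank in range(world_size):
    shard = [p for p in paths if assigned_to_rank(p, rank, world_size)]
    print(rank, shard)
```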
22 changes: 22 additions & 0 deletions scripts/data_preprocess/finetune_data_gen.sh
@@ -0,0 +1,22 @@
# export WANDB_MODE="offline"
GPU_NUM=8
MODEL_PATH="/ephemeral/hao.zhang/outputfolder/ckptfolder/mochi_diffuser"
MOCHI_DIR="/ephemeral/hao.zhang/resourcefolder/mochi/mochi-1-preview"
DATA_MERGE_PATH="/ephemeral/hao.zhang/resourcefolder/Mochi-Synthetic-Data-BW-Finetune/merge.txt"
OUTPUT_DIR="./data/BW-Finetune-Synthetic-Data_test"

torchrun --nproc_per_node=$GPU_NUM \
./fastvideo/utils/data_preprocess/finetune_data_VAE.py \
--model_path $MODEL_PATH \
--data_merge_path $DATA_MERGE_PATH \
--train_batch_size=1 \
--max_height=480 \
--max_width=848 \
--target_length=163 \
--dataloader_num_workers 1 \
--output_dir=$OUTPUT_DIR

torchrun --nproc_per_node=$GPU_NUM \
./fastvideo/utils/data_preprocess/finetune_data_T5.py \
--model_path $MODEL_PATH \
--output_dir=$OUTPUT_DIR
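
Both stages write into the same OUTPUT_DIR: the VAE stage produces latent/*.pt plus videos2caption_temp.json, and the T5 stage then adds prompt_embed/*.pt, prompt_attention_mask/*.pt, decoded video/*.mp4 files, and the final videos2caption.json index.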
