[feat]: Add Image-Video Mixture training to main repo (hao-ai-lab#50)
Co-authored-by: runlong <[email protected]>
rlsu9 committed Nov 24, 2024
1 parent 45e4adc commit 5a5d0ef
Showing 7 changed files with 79 additions and 100 deletions.
6 changes: 3 additions & 3 deletions fastvideo/dataset/latent_datasets.py
@@ -22,18 +22,18 @@ def __init__(
self.prompt_attention_mask_dir = os.path.join(self.datase_dir_path, "prompt_attention_mask")
with open(self.json_path, 'r') as f:
self.data_anno = json.load(f)
self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
# json.load(f) already keeps the order
# self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
self.num_latent_t = num_latent_t
self.uncond_prompt_embed = torch.load(os.path.join(uncond_prompt_embed_mask_dir, "embed.pt"), map_location="cpu", weights_only=True)
self.uncond_prompt_mask = torch.load(os.path.join(uncond_prompt_embed_mask_dir, "mask.pt"), map_location="cpu", weights_only=True)
self.lengths = [torch.load(os.path.join(self.latent_dir, data_item["latent_path"]), map_location="cpu").shape[1] for data_item in self.data_anno]
def __getitem__(self, idx):
latent_file = self.data_anno[idx]["latent_path"]
prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
prompt_attention_mask_file = self.data_anno[idx]["prompt_attention_mask"]
# load
latent = torch.load(os.path.join(self.latent_dir, latent_file), map_location="cpu", weights_only=True)
# TODO: Hack
latent = latent.squeeze()[:, -self.num_latent_t:]
if random.random() < self.cfg_rate:
prompt_embed = self.uncond_prompt_embed
prompt_attention_mask = self.uncond_prompt_mask
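For context, the cfg_rate branch above implements classifier-free-guidance dropout: with probability cfg_rate the sample's prompt embedding is swapped for the precomputed unconditional one. A minimal self-contained sketch of the same pattern (the names and paths here are illustrative, not the repo's API):

import os
import random

import torch

def load_prompt(prompt_dir, item, uncond_embed, uncond_mask, cfg_rate=0.1):
    # With probability cfg_rate, drop the prompt so the model also learns the
    # unconditional distribution required for classifier-free guidance.
    if random.random() < cfg_rate:
        return uncond_embed, uncond_mask
    embed = torch.load(os.path.join(prompt_dir, item["prompt_embed_path"]), map_location="cpu")
    mask = torch.load(os.path.join(prompt_dir, item["prompt_attention_mask"]), map_location="cpu")
    return embed, mask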
16 changes: 9 additions & 7 deletions fastvideo/dataset/t2v_datasets.py
@@ -12,13 +12,12 @@
from tqdm import tqdm
from PIL import Image
from accelerate.logging import get_logger

from fastvideo.utils.dataset_utils import DecordInit
from fastvideo.utils.utils import text_preprocessing
import torchvision
logger = get_logger(__name__)



class SingletonMeta(type):
"""
This is a metaclass used to create singleton classes.
@@ -132,14 +131,14 @@ def get_data(self, idx):
if path.endswith('.mp4'):
return self.get_video(idx)
else:
assert False
return self.get_image(idx)

def get_video(self, idx):
video_path = dataset_prog.cap_list[idx]['path']
assert os.path.exists(video_path), f"file {video_path} does not exist!"
frame_indices = dataset_prog.cap_list[idx]['sample_frame_index']
video = self.decord_read(video_path, frame_indices=frame_indices)
torchvision_video, _, metadata = torchvision.io.read_video(video_path, output_format="TCHW")
video = torchvision_video[frame_indices]
video = self.transform(video)
video = rearrange(video, 't c h w -> c t h w')
video = video.unsqueeze(0)
@@ -156,7 +155,7 @@ def get_video(self, idx):
text = [text]
text = [random.choice(text)]

text = text_preprocessing(text, support_Chinese=self.support_Chinese) if random.random() > self.cfg else ""
text = text[0] if random.random() > self.cfg else ""
text_tokens_and_mask = self.tokenizer(
text,
max_length=self.text_max_length,
@@ -182,7 +181,7 @@ def get_image(self, idx):

image = self.transform_topcrop(image) if 'human_images' in image_data['path'] else self.transform(image) # [1 C H W] -> num_img [1 C H W]
image = image.transpose(0, 1) # [1 C H W] -> [C 1 H W]

image = image.unsqueeze(0)

image = image.float() / 127.5 - 1.0

caps = image_data['cap'] if isinstance(image_data['cap'], list) else [image_data['cap']]
caps = [random.choice(caps)]
text = text_preprocessing(caps, support_Chinese=self.support_Chinese)
@@ -199,7 +201,7 @@ def get_image(self, idx):
)
input_ids = text_tokens_and_mask['input_ids'] # 1, l
cond_mask = text_tokens_and_mask['attention_mask'] # 1, l
return dict(pixel_values=image, input_ids=input_ids, cond_mask=cond_mask)
return dict(pixel_values=image, text=text, input_ids=input_ids, cond_mask=cond_mask, path=image_data['path'])

def define_frame_index(self, cap_list):

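The get_video change above swaps the decord reader for torchvision. A minimal sketch of the frame-selection pattern (the file path and stride are placeholders):

import torchvision

# read_video returns (video, audio, metadata); output_format="TCHW" yields a
# uint8 tensor of shape [T, C, H, W], so integer indexing subsamples frames.
video, _, metadata = torchvision.io.read_video("sample.mp4", output_format="TCHW")
frame_indices = list(range(0, video.shape[0], 2))  # e.g. keep every other frame
frames = video[frame_indices]                      # [len(frame_indices), C, H, W]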
2 changes: 1 addition & 1 deletion fastvideo/model/pipeline_mochi.py
@@ -247,7 +247,7 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

# duplicate text embeddings for each generation per prompt, using mps friendly method
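The one-line fix above forwards the prompt attention mask to the T5 text encoder so padding tokens no longer leak into the prompt embeddings. A minimal sketch of the call pattern with Hugging Face transformers (a small checkpoint is used purely for illustration):

import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
tokens = tokenizer("a cat playing piano", padding="max_length",
                   max_length=64, truncation=True, return_tensors="pt")
with torch.no_grad():
    # Without attention_mask, pad positions are encoded like real tokens and
    # bias the embeddings; with it, their contribution is masked out.
    embeds = encoder(tokens.input_ids, attention_mask=tokens.attention_mask)[0]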
11 changes: 10 additions & 1 deletion fastvideo/train.py
@@ -20,6 +20,7 @@
)
import json
from torch.utils.data.distributed import DistributedSampler
from fastvideo.utils.dataset_utils import LengthGroupedSampler
import wandb
from accelerate.utils import set_seed
from tqdm.auto import tqdm
@@ -354,7 +355,15 @@ def main(args):
)

train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
sampler = LengthGroupedSampler(
args.train_batch_size,
rank=rank,
world_size=world_size,
lengths=train_dataset.lengths,
group_frame=args.group_frame,
group_resolution=args.group_resolution,
) if (args.group_frame or args.group_resolution) else DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=False)

train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
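For reference, a toy sketch of wiring the new sampler into a DataLoader on a single process, assuming the constructor signature introduced above (the dataset and lengths are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset

from fastvideo.utils.dataset_utils import LengthGroupedSampler

dataset = TensorDataset(torch.arange(8))
lengths = [1, 1, 1, 1, 16, 16, 16, 16]  # image latents (t=1) mixed with video latents

sampler = LengthGroupedSampler(batch_size=2, rank=0, world_size=1,
                               lengths=lengths, group_frame=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)
for batch in loader:
    # With group_frame=True each batch holds indices of a single latent
    # length, so image and video samples never mix within a batch.
    print(batch)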
22 changes: 13 additions & 9 deletions fastvideo/utils/communications.py
@@ -283,12 +283,16 @@ def sp_parallel_dataloader_wrapper(dataloader, device, train_batch_size, sp_size
cond = cond.to(device)
attn_mask = attn_mask.to(device)
cond_mask = cond_mask.to(device)
latents, cond, attn_mask, cond_mask = prepare_sequence_parallel_data(latents, cond, attn_mask, cond_mask)

for iter in range(train_batch_size * sp_size // train_sp_batch_size):
st_idx = iter * train_sp_batch_size
ed_idx = (iter + 1) * train_sp_batch_size
encoder_hidden_states=cond[st_idx: ed_idx]
attention_mask=attn_mask[st_idx: ed_idx]
encoder_attention_mask=cond_mask[st_idx: ed_idx]
yield latents[st_idx: ed_idx], encoder_hidden_states, attention_mask, encoder_attention_mask
frame = latents.shape[2]
if frame == 1:
yield latents, cond, attn_mask, cond_mask
else:
latents, cond, attn_mask, cond_mask = prepare_sequence_parallel_data(latents, cond, attn_mask, cond_mask)

for iter in range(train_batch_size * sp_size // train_sp_batch_size):
st_idx = iter * train_sp_batch_size
ed_idx = (iter + 1) * train_sp_batch_size
encoder_hidden_states=cond[st_idx: ed_idx]
attention_mask=attn_mask[st_idx: ed_idx]
encoder_attention_mask=cond_mask[st_idx: ed_idx]
yield latents[st_idx: ed_idx], encoder_hidden_states, attention_mask, encoder_attention_mask
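The frame == 1 branch exists because an image batch has a single latent frame and cannot be split along the temporal dimension, so it bypasses sequence-parallel preparation entirely. A toy sketch of the slice arithmetic used for video batches (the sizes are illustrative):

train_batch_size, sp_size, train_sp_batch_size = 4, 2, 2

# After prepare_sequence_parallel_data the effective batch holds
# train_batch_size * sp_size samples, consumed train_sp_batch_size at a time.
for it in range(train_batch_size * sp_size // train_sp_batch_size):
    st_idx, ed_idx = it * train_sp_batch_size, (it + 1) * train_sp_batch_size
    print(f"iteration {it}: yields slice [{st_idx}:{ed_idx}]")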
2 changes: 1 addition & 1 deletion fastvideo/utils/data_preprocess/finetune_data_VAE.py
@@ -91,7 +91,7 @@ def main(args):
# text encoder & vae & diffusion model
parser.add_argument("--text_encoder_name", type=str, default='google/t5-v1_1-xxl')
parser.add_argument("--cache_dir", type=str, default='./cache_dir')
parser.add_argument('--cfg', type=float, default=0.1)
parser.add_argument('--cfg', type=float, default=0.0)
parser.add_argument("--output_dir", type=str, default=None, help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--logging_dir", type=str, default="logs",
help=(
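Note: the --cfg default drops to 0.0 here, which lines up with prompt dropout now happening at training time via LatentDataset's cfg rate (see latent_datasets.py above), so the preprocessed embeddings stay fully conditional.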
120 changes: 42 additions & 78 deletions fastvideo/utils/dataset_utils.py
@@ -97,8 +97,6 @@ def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max
count_dict = Counter(len_each_batch)
if len(count_dict) != 1:
sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])
# import ipdb;ipdb.set_trace()
# print(batch, idx_length_dict, count_dict, sorted_by_value)
pick_length = sorted_by_value[-1][0] # the highest frequency
candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length]
random_select_batch = [random.choice(candidate_batch) for _ in range(len(len_each_batch) - len(candidate_batch))]
@@ -198,100 +196,55 @@ def group_frame_fun(indices, lengths):
indices.sort(key=lambda i: lengths[i], reverse=True)
return indices

def group_resolution_fun(indices):
raise NotImplementedError
return indices

def group_frame_and_resolution_fun(indices):
raise NotImplementedError
return indices

def last_group_frame_fun(shuffled_megabatches, lengths):
re_shuffled_megabatches = []
# print('shuffled_megabatches', len(shuffled_megabatches))
for i_megabatch, megabatch in enumerate(shuffled_megabatches):
re_megabatch = []
for i_batch, batch in enumerate(megabatch):
assert len(batch) != 0
len_each_batch = [lengths[i] for i in batch]
idx_length_dict = dict([*zip(batch, len_each_batch)])
count_dict = Counter(len_each_batch)
if len(count_dict) != 1:
sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])
# print(batch, idx_length_dict, count_dict, sorted_by_value)
pick_length = sorted_by_value[-1][0] # the highest frequency
candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length]
random_select_batch = [random.choice(candidate_batch) for i in range(len(len_each_batch) - len(candidate_batch))]
# print(batch, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch)
batch = candidate_batch + random_select_batch
# print(batch)
re_megabatch.append(batch)
re_shuffled_megabatches.append(re_megabatch)

def megabatch_frame_alignment(megabatches, lengths):
aligned_megabatches = []
for _, megabatch in enumerate(megabatches):
assert len(megabatch) != 0
len_each_megabatch = [lengths[i] for i in megabatch]
idx_length_dict = dict([*zip(megabatch, len_each_megabatch)])
count_dict = Counter(len_each_megabatch)

# mixed frame lengths: align within the megabatch
if len(count_dict) != 1:
sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])
pick_length = sorted_by_value[-1][0] # the highest frequency
candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length]
random_select_batch = [random.choice(candidate_batch) for i in range(len(idx_length_dict) - len(candidate_batch))]
aligned_megabatch = candidate_batch + random_select_batch
aligned_megabatches.append(aligned_megabatch)
# already aligned megabatches
else:
aligned_megabatches.append(megabatch)

# for megabatch, re_megabatch in zip(shuffled_megabatches, re_shuffled_megabatches):
# for batch, re_batch in zip(megabatch, re_megabatch):
# for i, re_i in zip(batch, re_batch):
# if i != re_i:
# print(i, re_i)
return re_shuffled_megabatches
return aligned_megabatches


def last_group_resolution_fun(indices):
raise NotImplementedError
return indices

def last_group_frame_and_resolution_fun(indices):
raise NotImplementedError
return indices

def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, group_frame=False, group_resolution=False, seed=42):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
if generator is None:
generator = torch.Generator().manual_seed(seed) # every rank generates the same fixed random permutation
# print('lengths', lengths)

indices = torch.randperm(len(lengths), generator=generator).tolist()
# print('indices', len(indices))

if group_frame and not group_resolution:
indices = group_frame_fun(indices, lengths)
elif not group_frame and group_resolution:
indices = group_resolution_fun(indices)
elif group_frame and group_resolution:
indices = group_frame_and_resolution_fun(indices)
# print('sort indices', len(indices))
# print('sort indices', indices)
# print('sort lengths', [lengths[i] for i in indices])

# sort dataset indices by frame length
indices = group_frame_fun(indices, lengths)

# chunk dataset to megabatches
megabatch_size = world_size * batch_size
megabatches = [indices[i: i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
# print('megabatches', len(megabatches))
# print('\nmegabatches', megabatches)
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
# print('sort megabatches', len(megabatches))
# megabatches_len = [[lengths[i] for i in megabatch] for megabatch in megabatches]
# print('\nsorted megabatches', megabatches)
# print('\nsorted megabatches_len', megabatches_len)

# make sure the lengths within each megabatch align with each other
megabatches = megabatch_frame_alignment(megabatches, lengths)

# split aligned megabatches into batches
megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches]
# print('nsplit_to_even_chunks megabatches', len(megabatches))
# print('\nsplit_to_even_chunks megabatches', megabatches)
# print('\nsplit_to_even_chunks len', [lengths[i] for megabatch in megabatches for batch in megabatch for i in batch])
# return [i for megabatch in megabatches for batch in megabatch for i in batch]

# shuffle megabatches to mix video and image training steps
indices = torch.randperm(len(megabatches), generator=generator).tolist()
shuffled_megabatches = [megabatches[i] for i in indices]
# print('shuffled_megabatches', len(shuffled_megabatches))
if group_frame and not group_resolution:
shuffled_megabatches = last_group_frame_fun(shuffled_megabatches, lengths)
elif not group_frame and group_resolution:
shuffled_megabatches = last_group_resolution_fun(shuffled_megabatches, indices)
elif group_frame and group_resolution:
shuffled_megabatches = last_group_frame_and_resolution_fun(shuffled_megabatches, indices)
# print('\nshuffled_megabatches', shuffled_megabatches)
# import ipdb;ipdb.set_trace()
# print('\nshuffled_megabatches len', [lengths[i] for megabatch in shuffled_megabatches for batch in megabatch for i in batch])

# expand indices and return
return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch]
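A toy walk-through of what the rewritten grouping produces for a mixed image/video dataset (assuming the function is importable as below; the lengths are made up):

from fastvideo.utils.dataset_utils import get_length_grouped_indices

# Four image latents (length 1) interleaved with four video latents (length 16):
lengths = [1, 16, 1, 16, 1, 16, 1, 16]
indices = get_length_grouped_indices(lengths, batch_size=2, world_size=2)

# Each megabatch (world_size * batch_size = 4 indices) is aligned to a single
# length, so a training step never mixes image and video samples.
print([lengths[i] for i in indices])  # e.g. [16, 16, 16, 16, 1, 1, 1, 1]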


@@ -304,6 +257,7 @@ class LengthGroupedSampler(Sampler):
def __init__(
self,
batch_size: int,
rank: int,
world_size: int,
lengths: Optional[List[int]] = None,
group_frame=False,
@@ -314,6 +268,7 @@ def __init__(
raise ValueError("Lengths must be provided.")

self.batch_size = batch_size
self.rank = rank
self.world_size = world_size
self.lengths = lengths
self.group_frame = group_frame
@@ -326,4 +281,13 @@ def __len__(self):
def __iter__(self):
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, group_frame=self.group_frame,
group_resolution=self.group_resolution, generator=self.generator)
def distributed_sampler(lst, rank, batch_size, world_size):
result = []
index = rank * batch_size
while index < len(lst):
result.extend(lst[index:index + batch_size])
index += batch_size * world_size
return result

indices = distributed_sampler(indices, self.rank, self.batch_size, self.world_size)
return iter(indices)
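The inner distributed_sampler shards the global index order in batch-size strides, so consecutive batches land on consecutive ranks. A self-contained toy run of the same logic:

def shard_indices(lst, rank, batch_size, world_size):
    # Each rank takes batch_size consecutive indices, then skips the
    # batch_size * (world_size - 1) indices owned by the other ranks.
    result, index = [], rank * batch_size
    while index < len(lst):
        result.extend(lst[index:index + batch_size])
        index += batch_size * world_size
    return result

print(shard_indices(list(range(8)), rank=0, batch_size=2, world_size=2))  # [0, 1, 4, 5]
print(shard_indices(list(range(8)), rank=1, batch_size=2, world_size=2))  # [2, 3, 6, 7]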
