# preprocess_validation_text_embeddings.py (from hao-ai-lab/FastVideo)
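"""Precompute text embeddings for validation prompts.

Each line of --validation_prompt_txt is encoded with the chosen model's text
encoder, and the resulting embedding / attention-mask pair is saved under
output_dir/validation/ as .pt files named after the prompt's first sentence.
"""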
import argparse
import os

import torch
import torch.distributed as dist
from accelerate.logging import get_logger

from fastvideo.utils.load import load_text_encoder

logger = get_logger(__name__)

def main(args):
    # torchrun sets RANK (global) and LOCAL_RANK (per-node); use the local
    # rank to select the GPU and the global rank to join the process group.
    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "rank", rank, "local rank", local_rank)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=rank
        )
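    # NOTE: every rank runs the full prompt loop below; the ranks encode and
    # save identical files, which is redundant but harmless for a small
    # validation set.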
    text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
    # fp16 autocast for the hunyuan text encoder, bf16 for everything else.
    autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16

    # Output layout:
    #   output_dir/validation/prompt_embed/<name>.pt
    #   output_dir/validation/prompt_attention_mask/<name>.pt
    os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True)
    os.makedirs(
        os.path.join(args.output_dir, "validation", "prompt_attention_mask"),
        exist_ok=True,
    )
    os.makedirs(
        os.path.join(args.output_dir, "validation", "prompt_embed"), exist_ok=True
    )

    # One validation prompt per line.
    with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
        lines = file.readlines()
    prompts = [line.strip() for line in lines]

    for prompt in prompts:
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=autocast_type):
                prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
                    prompt
                )
        # The prompt's first sentence doubles as the file name, so prompts that
        # share a first sentence will overwrite each other's embeddings.
        file_name = prompt.split(".")[0]
        prompt_embed_path = os.path.join(
            args.output_dir, "validation", "prompt_embed", f"{file_name}.pt"
        )
        prompt_attention_mask_path = os.path.join(
            args.output_dir,
            "validation",
            "prompt_attention_mask",
            f"{file_name}.pt",
        )
        # Save without the batch dimension.
        torch.save(prompt_embeds[0], prompt_embed_path)
        torch.save(prompt_attention_mask[0], prompt_attention_mask_path)
        print(f"sample {file_name} saved")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument(
        "--validation_prompt_txt",
        type=str,
        help="Text file with one validation prompt per line.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    args = parser.parse_args()
    main(args)
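# Example launch (a sketch; the concrete paths below are assumptions, only the
# flags come from the argparse definitions above):
#
#   torchrun --nproc_per_node=1 preprocess_validation_text_embeddings.py \
#       --model_type mochi \
#       --model_path data/mochi \
#       --validation_prompt_txt my_prompts.txt \
#       --output_dir data/validation_embeddings
#
# Downstream, the saved tensors can be read back with torch.load, e.g.:
#
#   embed = torch.load("data/validation_embeddings/validation/prompt_embed/<name>.pt")
#   mask = torch.load("data/validation_embeddings/validation/prompt_attention_mask/<name>.pt")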