Skip to content

Commit

Permalink
add safetensors format
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-q committed Nov 4, 2024
1 parent fa312b3 commit 02d9301
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from tqdm import tqdm
from einops import rearrange
from torch import autocast
from safetensors.torch import load_model
import argparse
assert torch.cuda.is_available()
device = "cuda:0"

parse = argparse.ArgumentParser()

parse.add_argument('--oasis-ckpt', type=str, help='Path to Oasis DiT checkpoint.', default="oasis500m.pt")
parse.add_argument('--vae-ckpt', type=str, help='Path to Oasis ViT-VAE checkpoint.', default="vit-l-20.pt")
parse.add_argument('--oasis-ckpt', type=str, help='Path to Oasis DiT checkpoint.', default="oasis500m.safetensors")
parse.add_argument('--vae-ckpt', type=str, help='Path to Oasis ViT-VAE checkpoint.', default="vit-l-20.safetensors")
parse.add_argument('--num-frames', type=int, help='How many frames should be generated?', default=32)
parse.add_argument('--output-path', type=str, help='Path where generated video should be saved.', default="video.mp4")
parse.add_argument('--fps', type=int, help='What framerate should be used to save the output?', default=20)
Expand All @@ -26,15 +27,21 @@
args = parse.parse_args()

# load DiT checkpoint
ckpt = torch.load(args.oasis_ckpt)
model = DiT_models["DiT-S/2"]()
model.load_state_dict(ckpt, strict=False)
if args.oasis_ckpt.endswith(".pt"):
ckpt = torch.load(args.oasis_ckpt, weights_only=True)
model.load_state_dict(ckpt, strict=False)
elif args.oasis_ckpt.endswith(".safetensors"):
load_model(model, args.oasis_ckpt)
model = model.to(device).eval()

# load VAE checkpoint
vae_ckpt = torch.load(args.vae_ckpt)
vae = VAE_models["vit-l-20-shallow-encoder"]()
vae.load_state_dict(vae_ckpt)
if args.vae_ckpt.endswith(".pt"):
vae_ckpt = torch.load(args.vae_ckpt, weights_only=True)
vae.load_state_dict(vae_ckpt)
elif args.vae_ckpt.endswith(".safetensors"):
load_model(vae, args.vae_ckpt)
vae = vae.to(device).eval()

# sampling params
Expand Down Expand Up @@ -126,5 +133,6 @@
# save video
x = torch.clamp(x, 0, 1)
x = (x * 255).byte()
write_video(args.output_path, x[0].cpu(), fps=argv.fps)
write_video(args.output_path, x[0].cpu(), fps=args.fps)
print(f"generation saved to {args.output_path}.")

0 comments on commit 02d9301

Please sign in to comment.