diff --git a/generate.py b/generate.py index ecfff31..a805877 100644 --- a/generate.py +++ b/generate.py @@ -10,14 +10,15 @@ from tqdm import tqdm from einops import rearrange from torch import autocast +from safetensors.torch import load_model import argparse assert torch.cuda.is_available() device = "cuda:0" parse = argparse.ArgumentParser() -parse.add_argument('--oasis-ckpt', type=str, help='Path to Oasis DiT checkpoint.', default="oasis500m.pt") -parse.add_argument('--vae-ckpt', type=str, help='Path to Oasis ViT-VAE checkpoint.', default="vit-l-20.pt") +parse.add_argument('--oasis-ckpt', type=str, help='Path to Oasis DiT checkpoint.', default="oasis500m.safetensors") +parse.add_argument('--vae-ckpt', type=str, help='Path to Oasis ViT-VAE checkpoint.', default="vit-l-20.safetensors") parse.add_argument('--num-frames', type=int, help='How many frames should be generated?', default=32) parse.add_argument('--output-path', type=str, help='Path where generated video should be saved.', default="video.mp4") parse.add_argument('--fps', type=int, help='What framerate should be used to save the output?', default=20) @@ -26,15 +27,21 @@ args = parse.parse_args() # load DiT checkpoint -ckpt = torch.load(args.oasis_ckpt) model = DiT_models["DiT-S/2"]() -model.load_state_dict(ckpt, strict=False) +if args.oasis_ckpt.endswith(".pt"): + ckpt = torch.load(args.oasis_ckpt, weights_only=True) + model.load_state_dict(ckpt, strict=False) +elif args.oasis_ckpt.endswith(".safetensors"): + load_model(model, args.oasis_ckpt) model = model.to(device).eval() # load VAE checkpoint -vae_ckpt = torch.load(args.vae_ckpt) vae = VAE_models["vit-l-20-shallow-encoder"]() -vae.load_state_dict(vae_ckpt) +if args.vae_ckpt.endswith(".pt"): + vae_ckpt = torch.load(args.vae_ckpt, weights_only=True) + vae.load_state_dict(vae_ckpt) +elif args.vae_ckpt.endswith(".safetensors"): + load_model(vae, args.vae_ckpt) vae = vae.to(device).eval() # sampling params @@ -126,5 +133,6 @@ # save video x = torch.clamp(x, 0, 1) x = (x * 255).byte() -write_video(args.output_path, x[0].cpu(), fps=argv.fps) +write_video(args.output_path, x[0].cpu(), fps=args.fps) print(f"generation saved to {args.output_path}.") +