Skip to content

Commit

Permalink
Add support for command-line arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
matth3wmajf committed Nov 3, 2024
1 parent a9d9ee7 commit 6535d66
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,19 @@
from tqdm import tqdm
from einops import rearrange
from torch import autocast
import argparse
assert torch.cuda.is_available()
device = "cuda:0"

parse = argparse.ArgumentParser(description='Generate AI-hallucinated Minecraft videos!')

parse.add_argument('--count', type=int, help='How many frames should be generated?', default=32)
parse.add_argument('--fps', type=int, help='What framerate should be used?', default=20)
parse.add_argument('--file', type=str, help='What should the video\'s file name be?', default="video.mp4")
parse.add_argument('--steps', type=int, help='How many steps?', default=50)

argv = parse.parse_args()

# load DiT checkpoint
ckpt = torch.load("oasis500m.pt")
model = DiT_models["DiT-S/2"]()
Expand All @@ -27,9 +37,9 @@

# sampling params
B = 1
total_frames = 32
total_frames = argv.count
max_noise_level = 1000
ddim_noise_steps = 100
ddim_noise_steps = argv.steps
noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
noise_abs_max = 20
ctx_max_noise_idx = ddim_noise_steps // 10 * 3
Expand Down Expand Up @@ -114,6 +124,5 @@
# save video
x = torch.clamp(x, 0, 1)
x = (x * 255).byte()
write_video("video.mp4", x[0].cpu(), fps=20)
print("generation saved to video.mp4.")

write_video(argv.file, x[0].cpu(), fps=argv.fps)
print("generation saved to video.mp4.")

0 comments on commit 6535d66

Please sign in to comment.