add web demo (hao-ai-lab#73)
Co-authored-by: BrianChenn1129 <yonqgich@umich>
BrianChen1129 and BrianChenn1129 authored Dec 6, 2024
1 parent 8d41d50 commit 98b92be
Showing 1 changed file with 142 additions and 0 deletions.
142 changes: 142 additions & 0 deletions fastvideo/demo/gradio_web_demo.py
@@ -0,0 +1,142 @@
import argparse
import os
import tempfile

import gradio as gr
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video

from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.model.modeling_mochi import MochiTransformer3DModel
from fastvideo.model.pipeline_mochi import MochiPipeline

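# Command-line flags supply the generation defaults; the Gradio controls below
# are initialized from these values.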
def init_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts", nargs="+", default=[])
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=848)
    parser.add_argument("--num_inference_steps", type=int, default=64)
    parser.add_argument("--guidance_scale", type=float, default=4.5)
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--transformer_path", type=str, default=None)
    parser.add_argument("--scheduler_type", type=str, default="euler")
    parser.add_argument("--lora_checkpoint_dir", type=str, default=None)
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--num_euler_timesteps", type=int, default=100)
    parser.add_argument("--linear_threshold", type=float, default=0.025)
    parser.add_argument("--linear_range", type=float, default=0.5)
    parser.add_argument("--cpu_offload", action="store_true")
    return parser.parse_args()

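# Build the Mochi pipeline: pick a scheduler, optionally load a separately
# trained transformer checkpoint, and place the model on the available device.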
def load_model(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.scheduler_type == "euler":
        scheduler = FlowMatchEulerDiscreteScheduler()
    else:
        scheduler = PCMFMScheduler(
            1000, args.shift, args.num_euler_timesteps, False,
            args.linear_threshold, args.linear_range,
        )

    if args.transformer_path:
        transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path)
    else:
        transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder="transformer/")

    pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler)
    pipe.enable_vae_tiling()
    if args.cpu_offload:
        # Keep weights on the CPU and stream submodules to the GPU as needed;
        # moving the whole pipeline to the GPU first would defeat the offload.
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)
    return pipe

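# Gradio click handler: runs one sampling pass and writes the frames to an MP4.
# Note that the pipeline is rebuilt on every request; caching it across calls
# would avoid the repeated load cost.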
def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_scale,
                   num_frames, height, width, num_inference_steps, randomize_seed=False):
    if randomize_seed:
        seed = torch.randint(0, 1000000, (1,)).item()

    pipe = load_model(args)
    print("Model loaded successfully")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(seed)

    if not use_negative_prompt:
        negative_prompt = None

    with torch.autocast(device, dtype=torch.bfloat16):
        output = pipe(
            prompt=[prompt],
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).frames[0]

    output_path = os.path.join(tempfile.mkdtemp(), "output.mp4")
    export_to_video(output, output_path, fps=30)
    return output_path, seed

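# Example prompts surfaced in the UI via gr.Examples below.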
examples = [
    "A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand’s movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough.",
    "A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds.",
    "A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze.",
]

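# Parse flags once at import time so the UI defaults match the command line.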
args = init_args()

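# Assemble the UI: prompt box and run button on top, sampling controls in a
# collapsible "Advanced options" accordion.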
with gr.Blocks() as demo:
    gr.Markdown("# Mochi Video Generation Demo")

    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Video(label="Result", show_label=False)

    with gr.Accordion("Advanced options", open=False):
        with gr.Group():
            with gr.Row():
                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=args.height)
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=args.width)

            with gr.Row():
                num_frames = gr.Slider(label="Number of Frames", minimum=8, maximum=256, value=args.num_frames)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, value=args.guidance_scale)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=args.num_inference_steps)

            with gr.Row():
                use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
                negative_prompt = gr.Text(
                    label="Negative prompt",
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                    visible=False,
                )

            seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            seed_output = gr.Number(label="Used Seed")

    gr.Examples(examples=examples, inputs=prompt)

    # Reveal the negative-prompt box only when the checkbox is ticked.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
    )

    run_button.click(
        fn=generate_video,
        inputs=[prompt, negative_prompt, use_negative_prompt, seed, guidance_scale,
                num_frames, height, width, num_inference_steps, randomize_seed],
        outputs=[result, seed_output],
    )

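# Queue requests (at most 20 pending) and listen on all network interfaces.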
if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)
