Skip to content

Commit

Permalink
Major Updates to app following latest
Browse files Browse the repository at this point in the history
Now verify if input file name by the user as spaces and fixes that
Update to the inference pipeline 
Now uses paths same as inference code ./outputs
fully support both jpg and video
double checks formats to codec outputs
  • Loading branch information
tpc2233 authored Oct 19, 2024
1 parent 53c939c commit 2cb25ef
Showing 1 changed file with 99 additions and 100 deletions.
199 changes: 99 additions & 100 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
import logging
import os
import random
import tempfile
import time

import shutil
import torch
from easydict import EasyDict
import numpy as np
import torch
from dav.pipelines import DAVPipeline
from dav.models import UNetSpatioTemporalRopeConditionModel
from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler
Expand All @@ -16,18 +14,14 @@

def seed_all(seed: int = 0):
"""
Set random seeds for reproducibility.
Set random seeds of all components.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


# Initialize logging
logging.basicConfig(level=logging.INFO)


# Load models once to avoid reloading on every inference
def load_models(model_base, device):
vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
Expand Down Expand Up @@ -59,7 +53,7 @@ def load_models(model_base, device):


def depth_any_video(
file,
file_path,
denoise_steps=3,
num_frames=32,
decode_chunk_size=16,
Expand All @@ -69,99 +63,104 @@ def depth_any_video(
):
"""
Perform depth estimation on the uploaded video/image.
Save the result in the output directory and return the path for display.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
# Save the uploaded file
input_path = os.path.join(tmp_dir, file.name)
with open(input_path, "wb") as f:
f.write(file.read())

# Set up output directory
output_dir = os.path.join(tmp_dir, "output")
os.makedirs(output_dir, exist_ok=True)

# Prepare configuration
cfg = EasyDict(
{
"model_base": MODEL_BASE,
"data_path": input_path,
"output_dir": output_dir,
"denoise_steps": denoise_steps,
"num_frames": num_frames,
"decode_chunk_size": decode_chunk_size,
"num_interp_frames": num_interp_frames,
"num_overlap_frames": num_overlap_frames,
"max_resolution": max_resolution,
"seed": 666,
}
)
output_dir = "./outputs"
os.makedirs(output_dir, exist_ok=True)

# Replace spaces with underscores in the filename
sanitized_file_name = os.path.basename(file_path).replace(" ", "_")
local_input_path = os.path.join(output_dir, sanitized_file_name)
shutil.copy(file_path, local_input_path)

# Prepare configuration
cfg = EasyDict(
{
"model_base": MODEL_BASE,
"data_path": local_input_path,
"output_dir": output_dir,
"denoise_steps": denoise_steps,
"num_frames": num_frames,
"decode_chunk_size": decode_chunk_size,
"num_interp_frames": num_interp_frames,
"num_overlap_frames": num_overlap_frames,
"max_resolution": max_resolution,
"seed": random.randint(0, 10000),
}
)

seed_all(cfg.seed)

file_name = os.path.splitext(os.path.basename(cfg.data_path))[0]
is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))

if is_video:
num_interp_frames = cfg.num_interp_frames
num_overlap_frames = cfg.num_overlap_frames
num_frames = cfg.num_frames
assert num_frames % 2 == 0, "num_frames should be even."
assert (
2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
), "Invalid frame overlap."
max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
num_frames // 2
)
image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)
else:
image = img_utils.read_image(cfg.data_path)

image = img_utils.imresize_max(image, cfg.max_resolution)
image = img_utils.imcrop_multi(image)
image_tensor = np.ascontiguousarray(
[_img.transpose(2, 0, 1) / 255.0 for _img in image]
seed_all(cfg.seed)

file_name = os.path.splitext(sanitized_file_name)[0]
is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))

if is_video:
num_interp_frames = cfg.num_interp_frames
num_overlap_frames = cfg.num_overlap_frames
num_frames = cfg.num_frames
assert num_frames % 2 == 0, "num_frames should be even."
assert (
2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
), "Invalid frame overlap."
max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
num_frames // 2
)
image_tensor = torch.from_numpy(image_tensor).to(DEVICE)

with torch.no_grad(), torch.autocast(
device_type=DEVICE_TYPE, dtype=torch.float16
):
pipe_out = pipe(
image_tensor,
num_frames=cfg.num_frames,
num_overlap_frames=cfg.num_overlap_frames,
num_interp_frames=cfg.num_interp_frames,
decode_chunk_size=cfg.decode_chunk_size,
num_inference_steps=cfg.denoise_steps,
)

disparity = pipe_out.disparity
disparity_colored = pipe_out.disparity_colored
image = pipe_out.image
# (N, H, 2 * W, 3)
merged = np.concatenate(
[
image,
disparity_colored,
],
axis=2,
image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)

if image is None or len(image) == 0:
raise ValueError("No frames extracted from the video. Please check the input file.")
else:
image = img_utils.read_image(cfg.data_path)

if image is None or len(image) == 0:
raise ValueError("Failed to read the image. Please check the input file.")

image = img_utils.imresize_max(image, cfg.max_resolution)
image = img_utils.imcrop_multi(image)
image_tensor = np.ascontiguousarray(
[_img.transpose(2, 0, 1) / 255.0 for _img in image]
)
image_tensor = torch.from_numpy(image_tensor).to(DEVICE)

with torch.no_grad(), torch.autocast(
device_type=DEVICE_TYPE, dtype=torch.float16
):
pipe_out = pipe(
image_tensor,
num_frames=cfg.num_frames,
num_overlap_frames=cfg.num_overlap_frames,
num_interp_frames=cfg.num_interp_frames,
decode_chunk_size=cfg.decode_chunk_size,
num_inference_steps=cfg.denoise_steps,
)

if is_video:
output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4")
img_utils.write_video(
output_path,
merged,
fps,
)
return output_path
else:
output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
img_utils.write_image(
output_path,
merged[0],
)
return output_path
disparity = pipe_out.disparity
disparity_colored = pipe_out.disparity_colored
image = pipe_out.image
# (N, H, 2 * W, 3)
merged = np.concatenate(
[
image,
disparity_colored,
],
axis=2,
)

if is_video:
output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4") # Ensure .mp4 extension
img_utils.write_video(
output_path,
merged,
fps
)
return output_path
else:
output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
img_utils.write_image(
output_path,
merged[0],
)
return output_path


# Define Gradio interface
Expand All @@ -174,7 +173,7 @@ def depth_any_video(
iface = gr.Interface(
fn=depth_any_video,
inputs=[
gr.File(label="Upload Video/Image"),
gr.File(label="Upload Video/Image", type="filepath"), # Correct type usage
gr.Slider(1, 10, step=1, value=3, label="Denoise Steps"),
gr.Slider(16, 64, step=1, value=32, label="Number of Frames"),
gr.Slider(8, 32, step=1, value=16, label="Decode Chunk Size"),
Expand All @@ -191,4 +190,4 @@ def depth_any_video(
)

if __name__ == "__main__":
iface.launch()
iface.launch(share=True)

0 comments on commit 2cb25ef

Please sign in to comment.