forked from huggingface/smollm
Showing 1 changed file with 156 additions and 0 deletions.
@@ -0,0 +1,156 @@
import logging
from typing import List

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, Idefics3ForConditionalGeneration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class VideoFrameExtractor:
    def __init__(self, max_frames: int = 50):
        self.max_frames = max_frames

    def resize_and_center_crop(self, image: Image.Image, target_size: int) -> Image.Image:
        # Get current dimensions
        width, height = image.size

        # Calculate new dimensions, keeping the aspect ratio
        if width < height:
            new_width = target_size
            new_height = int(height * (target_size / width))
        else:
            new_height = target_size
            new_width = int(width * (target_size / height))

        # Resize so the short side equals target_size
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center crop to target_size x target_size
        left = (new_width - target_size) // 2
        top = (new_height - target_size) // 2
        right = left + target_size
        bottom = top + target_size

        return image.crop((left, top, right, bottom))

    def extract_frames(self, video_path: str) -> List[Image.Image]:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        # Get video properties; guard against a reported FPS of 0 (which some
        # containers produce) so the range() step below is never zero
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = max(int(cap.get(cv2.CAP_PROP_FPS)), 1)

        # Calculate frame indices to extract (roughly 1 frame per second)
        frame_indices = list(range(0, total_frames, fps))

        # If we have more frames than max_frames, sample evenly
        if len(frame_indices) > self.max_frames:
            indices = np.linspace(0, len(frame_indices) - 1, self.max_frames, dtype=int)
            frame_indices = [frame_indices[i] for i in indices]

        frames = []
        for frame_idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; convert to RGB for PIL
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame)
                pil_image = self.resize_and_center_crop(pil_image, 384)
                frames.append(pil_image)

        cap.release()
        return frames

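# Quick standalone sanity check for the extractor (illustrative only; the
# path below is a placeholder, not part of this script's configuration):
#
#   extractor = VideoFrameExtractor(max_frames=8)
#   frames = extractor.extract_frames("/path/to/clip.mp4")
#   print(len(frames), frames[0].size)  # e.g. 8 (384, 384)
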
def load_model(checkpoint_path: str, base_model_id: str = "HuggingFaceTB/SmolVLM-Instruct", device: str = "cuda"):
    # Load processor from the base model
    processor = AutoProcessor.from_pretrained(base_model_id)

    # Load the fine-tuned checkpoint if one is given, otherwise the base model
    model_source = checkpoint_path if checkpoint_path else base_model_id
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_source,
        torch_dtype=torch.bfloat16,
        device_map=device
    )

    # Configure processor for video frames: frames arrive already resized and
    # center-cropped to 384x384, so disable the processor's own resizing and
    # image splitting
    processor.image_processor.size = (384, 384)
    processor.image_processor.do_resize = False
    processor.image_processor.do_image_splitting = False

    return model, processor

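# Note: torch.bfloat16 assumes hardware with bf16 support. On older GPUs or
# CPU-only machines, loading with torch_dtype=torch.float32 (at a memory cost)
# is a reasonable fallback; that is an adaptation, not part of this script.
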
def generate_response(model, processor, video_path: str, question: str, max_frames: int = 50):
    # Extract frames
    frame_extractor = VideoFrameExtractor(max_frames)
    frames = frame_extractor.extract_frames(video_path)
    logger.info(f"Extracted {len(frames)} frames from video")

    # Build a chat message with one image token per extracted frame
    image_tokens = [{"type": "image"} for _ in range(len(frames))]
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer briefly."},
                *image_tokens,
                {"type": "text", "text": question}
            ]
        }
    ]

    # Process inputs
    inputs = processor(
        text=processor.apply_chat_template(messages, add_generation_prompt=True),
        images=frames,
        return_tensors="pt"
    ).to(model.device)

    # Generate response
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        temperature=0.7,
        do_sample=True,
        use_cache=True
    )

    # Decode response
    response = processor.decode(outputs[0], skip_special_tokens=True)
    return response

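# Note: decoding outputs[0] returns the prompt plus the generated answer. To
# keep only the newly generated tokens, one could slice off the prompt length
# first (a standard transformers pattern, not part of the original commit):
#
#   new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
#   answer = processor.decode(new_tokens, skip_special_tokens=True)
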
def main():
    # Configuration
    # checkpoint_path = "/path/to/your/checkpoint"
    checkpoint_path = None
    base_model_id = "HuggingFaceTB/SmolVLM-Instruct"
    video_path = "/path/to/video.mp4"
    question = "Describe the video"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    logger.info("Loading model...")
    model, processor = load_model(checkpoint_path, base_model_id, device)

    # Generate response
    logger.info("Generating response...")
    response = generate_response(model, processor, video_path, question)

    # Print results
    print("Question:", question)
    print("Response:", response)


if __name__ == "__main__":
    main()
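
# To run (hypothetical file name; edit video_path and, optionally,
# checkpoint_path in main() first):
#
#   pip install torch transformers pillow opencv-python numpy
#   python smolvlm_video_inference.py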