Commit

fix vllm video (#2864)
Jintao-Huang authored Jan 6, 2025
1 parent 0b45cce commit f9fa53b
Showing 7 changed files with 27 additions and 28 deletions.
1 change: 0 additions & 1 deletion swift/llm/argument/export_args.py
@@ -54,7 +54,6 @@ class ExportArguments(MergeArguments, BaseArguments):
     to_peft_format: bool = False
 
     def _init_output_dir(self):
-        suffix = None
         if self.output_dir is None:
             ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
             ckpt_dir, ckpt_name = os.path.split(ckpt_dir)
1 change: 1 addition & 0 deletions swift/llm/infer/infer_engine/vllm_engine.py
@@ -21,6 +21,7 @@

 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
+    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
     os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600'
     import vllm
     from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
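Both settings are plain environment variables that swift exports before `import vllm` runs (the in-code comment notes that this ordering also keeps lint happy). A minimal standalone sketch of the same pattern, with the values taken from the diff; whether `spawn` is actually required depends on your platform and vLLM version:

import os

# Export the settings before importing vllm, mirroring the order used in vllm_engine.py.
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'    # use spawn instead of fork for worker processes
os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600'  # tolerate long generation iterations

import vllm  # noqa: E402  (deliberately placed after the environment setup)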
26 changes: 20 additions & 6 deletions swift/llm/template/base.py
@@ -6,7 +6,7 @@
 from copy import deepcopy
 from dataclasses import asdict
 from functools import wraps
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import json
 import torch
@@ -454,20 +454,19 @@ def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float
         for context, loss_scale in zip(context_list, loss_scale_list):
             for k in ['image', 'video', 'audio']:
                 if context == f'<{k}>':
-                    idx = getattr(inputs, f'{k}_idx')
-                    c_list = self.replace_tag(k, idx, inputs)
-                    setattr(inputs, f'{k}_idx', idx + 1)
+                    c_list = self.replace_tag(k, getattr(inputs, f'{k}_idx'), inputs)
+                    setattr(inputs, f'{k}_idx', getattr(inputs, f'{k}_idx') + 1)
                     loss_scale = 0.
                     break
             else:
                 if context == '<ref-object>':
                     idx = inputs.object_idx
                     c_list = self.replace_object(inputs.objects[idx], idx, inputs)
-                    inputs.object_idx = idx + 1
+                    inputs.object_idx += 1
                 elif context == '<bbox>':
                     idx = inputs.box_idx
                     c_list = self.replace_box(inputs.objects[idx], idx, inputs)
-                    inputs.box_idx = idx + 1
+                    inputs.box_idx += 1
                 else:
                     c_list = [context]
             res += c_list
@@ -689,6 +688,21 @@ def debug_logger(self, inputs):
             for v in val:
                 self.print_inputs({k: v.tolist()})
 
+    def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
+        context_list = []
+        if self.mode == 'pt':
+            video = inputs.videos[inputs.video_idx]
+        else:
+            video = inputs.videos.pop(inputs.video_idx)
+            inputs.video_idx -= 1
+        images = inputs.images
+        new_images = load_video_func(video)
+        inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
+        for i in range(len(new_images)):
+            context_list += replace_tag(i)
+        inputs.image_idx += len(new_images)
+        return context_list
+
     def get_generate_ids(self, generate_ids: Union[torch.Tensor, List[int]],
                          num_prompt_tokens: int) -> Union[torch.Tensor, List[int]]:
         if self.skip_prompt:
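Taken together, the new `Template.replace_video2image` expands a video into frames with `load_video_func` and splices them into `inputs.images` at the current `image_idx`, emitting one context chunk per frame via `replace_tag`. Outside 'pt' mode (for example when the vLLM backend handles multimodal preprocessing) it also pops the video from `inputs.videos` and decrements `video_idx`, which is why the `_pre_tokenize` hunk above now re-reads the index with `getattr` after `replace_tag` returns instead of reusing a value captured before the call. A standalone sketch of that bookkeeping, where `ToyInputs` and `fake_load_video` are illustrative stand-ins rather than swift APIs:

from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyInputs:
    videos: List[str] = field(default_factory=list)
    images: List[str] = field(default_factory=list)
    video_idx: int = 0
    image_idx: int = 0


def fake_load_video(video: str) -> List[str]:
    # Stand-in decoder: pretend every video yields two frames.
    return [f'{video}-frame0', f'{video}-frame1']


def replace_video2image(load_video_func: Callable, inputs: ToyInputs, replace_tag: Callable, mode: str) -> List[str]:
    # Mirrors the logic of the new Template.replace_video2image: outside 'pt' mode the
    # video is consumed (popped) and video_idx is decremented, so the caller's
    # post-call increment lands on the next unprocessed video.
    context_list = []
    if mode == 'pt':
        video = inputs.videos[inputs.video_idx]
    else:
        video = inputs.videos.pop(inputs.video_idx)
        inputs.video_idx -= 1
    new_images = load_video_func(video)
    inputs.images = inputs.images[:inputs.image_idx] + new_images + inputs.images[inputs.image_idx:]
    for i in range(len(new_images)):
        context_list += replace_tag(i)
    inputs.image_idx += len(new_images)
    return context_list


inputs = ToyInputs(videos=['v0', 'v1'])
for _ in range(2):  # process both videos
    replace_video2image(fake_load_video, inputs, lambda i: [f'Frame{i + 1}: '], mode='vllm')
    # Re-read the index after the call, as the updated _pre_tokenize does; incrementing
    # a value captured before the call would undo the -1 from the pop.
    inputs.video_idx += 1
assert inputs.videos == [] and inputs.video_idx == 0 and len(inputs.images) == 4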
4 changes: 2 additions & 2 deletions swift/llm/template/template/internvl.py
@@ -11,7 +11,7 @@
 from ..register import register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, findall
-from ..vision_utils import load_video_internvl, replace_video2image, transform_image
+from ..vision_utils import load_video_internvl, transform_image
 from .microsoft import Phi3TemplateMeta
 from .utils import ChatmlTemplateMeta
 
@@ -98,7 +98,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
         elif media_type == 'video':
             video_segments = get_env_args('video_segments', int, self.video_segments)
             load_video = partial(load_video_internvl, num_segments=video_segments)
-            return replace_video2image(load_video, inputs, lambda i: [f'Frame{i + 1}: '] + image_context)
+            return self.replace_video2image(load_video, inputs, lambda i: [f'Frame{i + 1}: '] + image_context)
 
     def replace_object(self, object_: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
         objects = inputs.objects
4 changes: 2 additions & 2 deletions swift/llm/template/template/minicpm.py
@@ -12,7 +12,7 @@
 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Prompt, findall
-from ..vision_utils import load_video_minicpmv_mplug_owl3, replace_video2image
+from ..vision_utils import load_video_minicpmv_mplug_owl3
 from .llama import Llama3TemplateMeta
 from .qwen import QwenTemplateMeta
 
@@ -166,7 +166,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
         if media_type == 'image':
             return image_context
         elif media_type == 'video':
-            return replace_video2image(load_video, inputs, lambda i: image_context)
+            return self.replace_video2image(load_video, inputs, lambda i: image_context)
 
     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = Template._encode(self, inputs)
4 changes: 2 additions & 2 deletions swift/llm/template/template/mplug.py
@@ -12,7 +12,7 @@
 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Prompt, findall
-from ..vision_utils import load_video_minicpmv_mplug_owl3, replace_video2image
+from ..vision_utils import load_video_minicpmv_mplug_owl3
 from .qwen import QwenTemplateMeta
 
 
@@ -82,7 +82,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
         if media_type == 'image':
             return [[-100], '\n']
         elif media_type == 'video':
-            return replace_video2image(load_video, inputs, lambda i: [[-100]]) + ['\n']
+            return self.replace_video2image(load_video, inputs, lambda i: [[-100]]) + ['\n']
 
     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = super()._encode(inputs)
15 changes: 0 additions & 15 deletions swift/llm/template/vision_utils.py
@@ -12,7 +12,6 @@
 from PIL import Image, ImageDraw
 
 from swift.utils import get_env_args
-from .utils import Context
 
 # >>> internvl
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
@@ -325,20 +324,6 @@ def normalize_bbox(objects: List[Dict[str, Any]], images: List[Image.Image], to_
         object_['bbox_type'] = to_type
 
 
-def replace_video2image(load_video_func, inputs, replace_tag: Callable) -> List[Context]:
-    context_list = []
-    video_idx = inputs.video_idx
-    video = inputs.videos[video_idx]
-    images = inputs.images
-    image_idx = inputs.image_idx
-    new_images = load_video_func(video)
-    inputs.images = images[:image_idx] + new_images + images[image_idx:]
-    for i in range(len(new_images)):
-        context_list += replace_tag(i)
-    inputs.image_idx += len(new_images)
-    return context_list
-
-
 if __name__ == '__main__':
     # TODO:remove
     # A test main to draw bbox
