Skip to content

Commit

Permalink
fix lmdeploy qwen_vl (modelscope#2009)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang authored Sep 11, 2024
1 parent 37acf8e commit e9b126b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions swift/llm/utils/client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def decode_base64(*,
prompt: Optional[str] = None,
images: Optional[List[str]] = None,
tmp_dir: str = 'tmp') -> Dict[str, Any]:
# base64 -> local_path
os.makedirs(tmp_dir, exist_ok=True)
res = {}
if messages is not None:
Expand Down
4 changes: 2 additions & 2 deletions swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.models.auto.tokenization_auto import get_tokenizer_config
from transformers.utils import strtobool
from transformers.utils import is_torch_bf16_gpu_available, strtobool
from transformers.utils.versions import require_version

from swift import get_logger
Expand Down Expand Up @@ -6554,7 +6554,7 @@ def get_torch_dtype(model_dir: str) -> Dtype:
if isinstance(torch_dtype, str):
torch_dtype = eval(f'torch.{torch_dtype}')
if torch_dtype in {torch.float32, None}:
torch_dtype = torch.float16
torch_dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
return torch_dtype


Expand Down
6 changes: 3 additions & 3 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,14 @@ def _preprocess_media(self, example):
images = example.get('images') or []
if images:
if example.get('objects') or self.load_medias or self._is_lmdeploy or self._is_vllm:
images = load_batch(images, load_image)
if not self.load_medias:
images = decode_base64(images=images)['images']
images = load_batch(images, load_image) # base64/local_path -> PIL.Image
if example.get('objects'):
# Normalize grounding bboxes
self.normalize_bbox(example['objects'], images, to_type=self.grounding_type)
if self.load_medias and self.grounding_type != 'real':
images = [rescale_image(img, self.rescale_image) for img in images]
if not self.load_medias and not self._is_lmdeploy and not self._is_vllm: # fix pt & qwen-vl
images = decode_base64(images=images)['images'] # PIL.Image/base64 -> local_path
example['images'] = images

# Check the example that whether matching the very template's rules
Expand Down

0 comments on commit e9b126b

Please sign in to comment.