Skip to content

Commit

Permalink
fix glm4v-merge_lora (modelscope#1104)
Browse files: browse the repository at this point in the history
  • Loading branch information
Jintao-Huang authored Jun 8, 2024
1 parent 9ecfda7 commit a5454a3
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 5 deletions.
3 changes: 2 additions & 1 deletion swift/llm/utils/argument.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import datetime as dt
import inspect
import math
import os
Expand Down Expand Up @@ -1260,7 +1261,7 @@ def __post_init__(self):
@dataclass
class EvalArguments(InferArguments):

name: Optional[str] = None
name: Optional[str] = field(default_factory=lambda: dt.datetime.now().strftime('%Y%m%d-%H%M%S'))

eval_url: Optional[str] = None

Expand Down
2 changes: 2 additions & 0 deletions swift/llm/utils/client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def inference_client(

if is_chat_request is None:
is_chat_request = _is_chat
assert is_chat_request is not None, (
'Please set the `is_chat_request` parameter to indicate whether the model is a chat model.')
data = {k: v for k, v in request_config.__dict__.items() if not k.startswith('__')}
if is_chat_request:
messages = history_to_messages(history, query, system)
Expand Down
1 change: 1 addition & 0 deletions swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,7 @@ def get_model_tokenizer_chatglm(model_dir: str,
remove_property(tokenizer_cls, tokenizer_config)
kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
model, tokenizer = get_model_tokenizer_from_repo(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
tokenizer.init_kwargs['image_size'] = 1120
if model is not None:
from torch.nn import CrossEntropyLoss
__old_forward = CrossEntropyLoss.forward
Expand Down
4 changes: 2 additions & 2 deletions swift/llm/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ def __call__(self, dataset: HfDataset) -> HfDataset:
kwargs = {}
if has_system:
kwargs['system'] = system
if has_history:
kwargs['history'] = history
kwargs.update({
'query': query,
'response': response,
})
if has_history:
kwargs['history'] = history
dataset = HfDataset.from_dict(kwargs)
return dataset

Expand Down
2 changes: 1 addition & 1 deletion swift/llm/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def random_uuid() -> str:
@dataclass
class Model:
id: str # model_type
is_chat: bool # chat model or generation model
is_chat: Optional[bool] = None # chat model or generation model
is_multimodal: bool = False

object: str = 'model'
Expand Down
4 changes: 3 additions & 1 deletion tests/llm/test_run2.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def test_glm4v_9b_chat(self):
lazy_tokenize=False))
best_model_checkpoint = output['best_model_checkpoint']
torch.cuda.empty_cache()
infer_main(InferArguments(ckpt_dir=best_model_checkpoint, load_dataset_config=True, val_dataset_sample=2))
infer_main(
InferArguments(
ckpt_dir=best_model_checkpoint, load_dataset_config=True, val_dataset_sample=2, merge_lora=True))

def test_baichuan2_chat_int4(self):
if not __name__ == '__main__':
Expand Down

0 comments on commit a5454a3

Please sign in to comment.