Fix bugs (#2838)
Jintao-Huang authored Jan 2, 2025
1 parent f61aa4e commit a5dca04
Showing 7 changed files with 46 additions and 32 deletions.
17 changes: 9 additions & 8 deletions docs/source/Instruction/命令行参数.md
@@ -19,7 +19,7 @@
 - 🔥model: Model id or local model path. For a custom model, use it together with `model_type` and `template`; see [Custom Models](../Customization/自定义模型.md) for details.
 - model_type: Model type. The same model architecture, template, and model loading process define a model_type.
 - model_revision: Model version.
-- task_type: Defaults to 'causal_lm'. Options include 'causal_lm' and 'seq_cls'. You can view examples [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls).
+- task_type: Defaults to 'causal_lm'. Options include 'causal_lm' and 'seq_cls'. You can view examples of seq_cls [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls).
 - 🔥torch_dtype: Data type for model weights, supports `float16`, `bfloat16`, `float32`, default is read from the config file.
 - attn_impl: Attention type, supports `flash_attn`, `sdpa`, `eager`, default is sdpa.
 - num_labels: To be specified for classification models, representing the number of labels, default is None.
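@@ -45,7 +45,7 @@
 ### Template Arguments
 - 🔥template: Type of dialogue template, which defaults to the template type corresponding to the model. `swift pt` converts the dialogue template into a generation template.
 - 🔥system: Custom system field, default is None, which uses the default system of the template.
-- 🔥max_length: Maximum length of tokens for a single sample, default is None (no limit).
+- 🔥max_length: Maximum length of tokens for a single sample. Defaults to None, in which case it is set to the maximum length of tokens supported by the model (max_model_len).
 - truncation_strategy: How to handle over-length samples, supports `delete`, `left`, and `right`, meaning deletion, left truncation, and right truncation, default is 'delete'.
 - 🔥max_pixels: Maximum pixel count (H\*W) for image preprocessing in multimodal models, default is no scaling.
 - tools_prompt: Format used to convert the tool list into the system field during agent training, refer to [Agent Training](./智能体的支持.md), default is 'react_en'.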
@@ -96,7 +96,7 @@
 - lr_scheduler_type: lr_scheduler type, default is cosine.
 - lr_scheduler_kwargs: Other parameters for the lr_scheduler.
 - 🔥gradient_checkpointing_kwargs: Parameters passed to `torch.utils.checkpoint`. For example, set to `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`.
-- report_to: Default is `tensorboard`.
+- report_to: Default is `tensorboard`. You can also specify `--report_to tensorboard wandb` or `--report_to all`.
 - remove_unused_columns: Default is False.
 - logging_first_step: Whether to log the first step, default is True.
 - logging_steps: Logging interval, default is 5.
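@@ -139,7 +139,7 @@
 #### Full-parameter
 - freeze_parameters: Prefixes of the parameters to be frozen, default is `[]`.
 - freeze_parameters_ratio: Ratio of parameters to freeze from the bottom up, default is 0. Setting it to 1 freezes all parameters; combine it with `trainable_parameters` to choose the trainable parameters.
-- trainable_parameters: Prefixes of the trainable parameters, default is `[]`.
+- trainable_parameters: Prefixes of the trainable parameters, default is `[]`. `trainable_parameters` takes priority over `freeze_parameters` and `freeze_parameters_ratio`.

 #### LoRA
 - 🔥lora_rank: Default is `8`.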
@@ -306,7 +306,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, and `modules_to_save`.
 ### RLHF Arguments
 RLHF arguments inherit from the [training arguments](#训练参数).

-- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`.
+- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`.
 - ref_model: Original comparison model in algorithms like DPO.
 - ref_model_type: Same as model_type.
 - ref_model_revision: Same as model_revision.
@@ -403,22 +403,23 @@ App arguments inherit from the [deployment arguments](#部署参数) and [Web-UI arguments](#Web-UI参数)

 - IMAGE_FACTOR: Default is 28
 - MIN_PIXELS: Default is `4 * 28 * 28`
-- MAX_PIXELS: Default is `16384 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/ocr.sh#L3)
+- 🔥MAX_PIXELS: Default is `16384 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/ocr.sh#L3)
 - MAX_RATIO: Default is 200
 - VIDEO_MIN_PIXELS: Default is `128 * 28 * 28`
-- VIDEO_MAX_PIXELS: Default is `768 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L7)
+- 🔥VIDEO_MAX_PIXELS: Default is `768 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L7)
 - VIDEO_TOTAL_PIXELS: Default is `24576 * 28 * 28`
 - FRAME_FACTOR: Default is 2
 - FPS: Default is 2.0
 - FPS_MIN_FRAMES: Default is 4
-- FPS_MAX_FRAMES: Default is 768, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L8)
+- 🔥FPS_MAX_FRAMES: Default is 768, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L8)

 ### internvl, internvl_phi3
 For the meaning of the arguments, please refer to [here](https://modelscope.cn/models/OpenGVLab/Mini-InternVL-Chat-2B-V1-5)
 - MAX_NUM: Default is 12
 - INPUT_SIZE: Default is 448

 ### internvl2, internvl2_phi3, internvl2_5
+For the meaning of the arguments, please refer to [here](https://modelscope.cn/models/OpenGVLab/InternVL2_5-2B)
 - MAX_NUM: Default is 12
 - INPUT_SIZE: Default is 448
 - VIDEO_MAX_NUM: Default is 1, which is the MAX_NUM for videos
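The priority rule that this file's diff adds to `trainable_parameters` can be illustrated with a small sketch. It is not the ms-swift implementation: the helper `resolve_trainable` and the parameter names are hypothetical, and it assumes prefix matching on parameter names and that `freeze_parameters_ratio` freezes the leading share of parameters in order.

```python
from typing import Dict, List


def resolve_trainable(param_names: List[str],
                      freeze_parameters: List[str],
                      freeze_parameters_ratio: float,
                      trainable_parameters: List[str]) -> Dict[str, bool]:
    """Return {param_name: requires_grad}. Hypothetical helper, not the ms-swift API."""
    # Assumption: the ratio freezes the first `ratio` share of parameters, in order.
    n_frozen_by_ratio = int(len(param_names) * freeze_parameters_ratio)
    result = {}
    for idx, name in enumerate(param_names):
        frozen = idx < n_frozen_by_ratio or any(name.startswith(p) for p in freeze_parameters)
        # trainable_parameters wins over both freezing options.
        if any(name.startswith(p) for p in trainable_parameters):
            frozen = False
        result[name] = not frozen
    return result


names = ['embed.weight', 'layers.0.attn.weight', 'layers.1.attn.weight', 'lm_head.weight']
flags = resolve_trainable(names, freeze_parameters=[], freeze_parameters_ratio=1.0,
                          trainable_parameters=['lm_head'])
print(flags)
```

With a ratio of 1 every parameter is frozen first, and only the `lm_head` prefix is re-enabled, which is the combination the documentation suggests.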
17 changes: 9 additions & 8 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -20,7 +20,7 @@ The introduction to command line parameters will cover base arguments, atomic ar
 - model_type: Model type. The same model architecture, template, and loading process define a model_type.
 - model_revision: Model version.
 - 🔥torch_dtype: Data type for model weights, supports `float16`, `bfloat16`, `float32`, default is read from the config file.
-- task_type: Defaults to 'causal_lm'. Options include 'causal_lm' and 'seq_cls'. You can view examples [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls).
+- task_type: Defaults to 'causal_lm'. Options include 'causal_lm' and 'seq_cls'. You can view examples of seq_cls [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls).
 - attn_impl: Attention type, supports `flash_attn`, `sdpa`, `eager`, default is sdpa.
 - num_labels: To be specified for classification models, representing the number of labels, default is None.
 - rope_scaling: Rope type, supports `linear` and `dynamic`, to be used with `max_length`.
@@ -45,7 +45,7 @@ The introduction to command line parameters will cover base arguments, atomic ar
 ### Template Arguments
 - 🔥template: Type of dialogue template, which defaults to the template type corresponding to the model. `swift pt` will convert the dialogue template into a generation template for use.
 - 🔥system: Custom system field, default is None, uses the default system of the template.
-- 🔥max_length: Maximum length of tokens for a single sample, default is None (no limit).
+- 🔥max_length: The maximum length of tokens for a single sample. Defaults to None, set to the maximum length of tokens supported by the model (max_model_len).
 - truncation_strategy: How to handle overly long tokens, supports `delete`, `left`, `right`, representing deletion, left trimming, and right trimming, default is 'delete'.
 - 🔥max_pixels: Maximum pixel count for pre-processing images in multimodal models (H*W), default is no scaling.
 - tools_prompt: The list of tools for agent training converted to system format, refer to [Agent Training](./Agent-support.md), default is 'react_en'.
@@ -97,7 +97,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
 - lr_scheduler_type: LR scheduler type, default is cosine.
 - lr_scheduler_kwargs: Other parameters for the LR scheduler.
 - 🔥gradient_checkpointing_kwargs: Parameters passed to `torch.utils.checkpoint`. For example, set to `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`.
-- report_to: Default is `tensorboard`.
+- report_to: Default is `tensorboard`. You can also specify `--report_to tensorboard wandb`, `--report_to all`.
 - remove_unused_columns: Default is False.
 - logging_first_step: Whether to log the first step print, default is True.
 - logging_steps: Interval for logging prints, default is 5.
@@ -141,7 +141,7 @@ Other important parameters:

 - freeze_parameters: Prefix of parameters to be frozen, default is `[]`.
 - freeze_parameters_ratio: Ratio of parameters to freeze from the bottom up, default is 0. Setting it to 1 will freeze all parameters. Combine with `trainable_parameters` to set trainable parameters.
-- trainable_parameters: Prefix of trainable parameters, default is `[]`.
+- trainable_parameters: Prefix of trainable parameters, default is `[]`. The priority of `trainable_parameters` is higher than that of `freeze_parameters` and `freeze_parameters_ratio`.

 #### LoRA

@@ -310,7 +310,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine

 RLHF arguments inherit from the [training arguments](#training-arguments).

-- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`.
+- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`.
 - ref_model: Original comparison model in algorithms like DPO.
 - ref_model_type: Same as model_type.
 - ref_model_revision: Same as model_revision.
@@ -403,22 +403,23 @@ For the meaning of the arguments, please refer to [here](https://github.com/Qwen

 - IMAGE_FACTOR: Default is 28
 - MIN_PIXELS: Default is `4 * 28 * 28`
-- MAX_PIXELS: Default is `16384 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/ocr.sh#L3)
+- 🔥MAX_PIXELS: Default is `16384 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/ocr.sh#L3)
 - MAX_RATIO: Default is 200
 - VIDEO_MIN_PIXELS: Default is `128 * 28 * 28`
-- VIDEO_MAX_PIXELS: Default is `768 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L7)
+- 🔥VIDEO_MAX_PIXELS: Default is `768 * 28 * 28`, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L7)
 - VIDEO_TOTAL_PIXELS: Default is `24576 * 28 * 28`
 - FRAME_FACTOR: Default is 2
 - FPS: Default is 2.0
 - FPS_MIN_FRAMES: Default is 4
-- FPS_MAX_FRAMES: Default is 768, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L8)
+- 🔥FPS_MAX_FRAMES: Default is 768, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/video.sh#L8)

 ### internvl, internvl_phi3
 For the meaning of the arguments, please refer to [here](https://modelscope.cn/models/OpenGVLab/Mini-InternVL-Chat-2B-V1-5)
 - MAX_NUM: Default is 12
 - INPUT_SIZE: Default is 448

 ### internvl2, internvl2_phi3, internvl2_5
+For the meaning of the arguments, please refer to [here](https://modelscope.cn/models/OpenGVLab/InternVL2_5-2B)
 - MAX_NUM: Default is 12
 - INPUT_SIZE: Default is 448
 - VIDEO_MAX_NUM: Default is 1, which is the MAX_NUM for videos
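The pixel defaults in this hunk are written as products of `28 * 28`, so it may help to see them as plain integers next to the `MAX_PIXELS=1003520` value used in the example scripts. The sketch below only does that arithmetic. The helper `get_pixel_limit` is hypothetical, and reading the value from an environment variable with a fallback is an assumption about how such a setting could be consumed, not a description of ms-swift internals.

```python
import os

PATCH_AREA = 28 * 28  # 784 pixels; 28 matches the documented IMAGE_FACTOR default

# Defaults as listed in the documentation hunk.
DEFAULTS = {
    'MIN_PIXELS': 4 * PATCH_AREA,
    'MAX_PIXELS': 16384 * PATCH_AREA,
    'VIDEO_MIN_PIXELS': 128 * PATCH_AREA,
    'VIDEO_MAX_PIXELS': 768 * PATCH_AREA,
    'VIDEO_TOTAL_PIXELS': 24576 * PATCH_AREA,
}


def get_pixel_limit(name: str) -> int:
    """Hypothetical helper: use the environment variable if set, else the documented default."""
    value = os.environ.get(name)
    return int(value) if value is not None else DEFAULTS[name]


for key, value in DEFAULTS.items():
    print(f'{key}: {value} pixels')

# 1003520 is the MAX_PIXELS value set in examples/train/multimodal/caption.sh.
print(1003520 // PATCH_AREA)  # 1280
```

In these terms the example scripts cap images at 1280 patch areas, well below the documented default of 16384.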
2 changes: 1 addition & 1 deletion examples/train/multimodal/caption.sh
@@ -5,7 +5,7 @@ CUDA_VISIBLE_DEVICES=0 \
 MAX_PIXELS=1003520 \
 swift sft \
     --model Qwen/Qwen2-VL-7B-Instruct \
-    --dataset 'modelscope/coco_2014_caption#20000' \
+    --dataset 'modelscope/coco_2014_caption:validation#20000' \
     --train_type lora \
     --torch_dtype bfloat16 \
     --num_train_epochs 1 \
1 change: 1 addition & 0 deletions examples/train/multimodal/infer.sh
@@ -1,5 +1,6 @@
 # Perform inference using the validation set from the training phase.
 CUDA_VISIBLE_DEVICES=0 \
+MAX_PIXELS=1003520 \
 swift infer \
     --adapters output/vx-xxx/checkpoint-xxx \
     --stream true \
29 changes: 19 additions & 10 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -159,11 +159,13 @@ def _infer_stream(self,
             raise ValueError(error_msg)
         streamer = TokensIteratorStreamer()
         generate_kwargs = {
-            'adapter_names': self._get_adapter_names(adapter_request),
             'generation_config': generation_config,
             'streamer': streamer,
             **inputs,
         }
+        adapter_names = self._get_adapter_names(adapter_request)
+        if adapter_names is not None:
+            generate_kwargs['adapter_names'] = adapter_names
         num_prompt_tokens = self._get_num_tokens(inputs)

         logits_streamer = None
@@ -272,12 +274,20 @@ def _infer_seq_cls(self,
                        inputs: Dict[str, Any],
                        adapter_request: Optional[AdapterRequest] = None,
                        **kwargs):
-        call_kwargs = {'adapter_names': self._get_adapter_names(adapter_request)}
+        call_kwargs = {}
+        adapter_names = self._get_adapter_names(adapter_request)
+        if adapter_names is not None:
+            call_kwargs['adapter_names'] = adapter_names
         num_prompt_tokens = self._get_num_tokens(inputs)
         inputs.pop('labels')
         logits = self.model(**inputs, **call_kwargs).logits
-        logprobs = torch.log_softmax(logits, -1)
-        preds = torch.argmax(logits, dim=-1).tolist()
+        if logits.shape[-1] > 1:
+            preds = torch.argmax(logits, dim=-1).tolist()
+            logprobs = torch.log_softmax(logits, -1)
+            logprobs = [self._get_seq_cls_logprobs(logprobs[i]) for i in range(len(preds))]
+        else:
+            preds = logits.squeeze(dim=-1).tolist()
+            logprobs = [None] * len(preds)
         res = []
         for i, pred in enumerate(preds):
             usage_info = self._get_usage_info(num_prompt_tokens, 1)
@@ -286,7 +296,7 @@ def _infer_seq_cls(self,
                     index=0,
                     message=ChatMessage(role='assistant', content=str(pred), tool_calls=None),
                     finish_reason='stop',
-                    logprobs=self._get_seq_cls_logprobs(logprobs[i]))
+                    logprobs=logprobs[i])
             ]
             res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
         return res
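The two `_infer_seq_cls` hunks make the post-processing branch on the width of the last logits dimension: with more than one output the prediction is the argmax and log-probabilities come from a log-softmax, while a single output is returned as a raw score with no log-probabilities. The sketch below mirrors that branching in plain Python so it runs without `torch`. The function `postprocess` is a hypothetical stand-in, and it returns the bare log-softmax rows rather than whatever structure `_get_seq_cls_logprobs` builds, which the diff does not show.

```python
import math
from typing import List, Tuple


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def postprocess(logits: List[List[float]]) -> Tuple[list, list]:
    """Hypothetical stand-in for the branching in `_infer_seq_cls`."""
    if len(logits[0]) > 1:
        # Several outputs: predicted class index plus per-class log-probabilities.
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
        logprobs = [log_softmax(row) for row in logits]
    else:
        # Single output: return the raw score; log-probabilities are not defined.
        preds = [row[0] for row in logits]
        logprobs = [None] * len(preds)
    return preds, logprobs


print(postprocess([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])[0])  # [1, 0]
print(postprocess([[0.73], [-1.2]]))  # ([0.73, -1.2], [None, None])
```

Computing the per-sample log-probabilities inside the branch is what lets the response construction use `logprobs[i]` in both cases.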
@@ -299,11 +309,10 @@ def _infer_full(self,
                     adapter_request: Optional[AdapterRequest] = None,
                     template_inputs=None) -> Union[List[ChatCompletionResponse]]:
         # bos_token TODO: encoder-decoder
-        generate_kwargs = {
-            'adapter_names': self._get_adapter_names(adapter_request),
-            'generation_config': generation_config,
-            **inputs
-        }
+        generate_kwargs = {'generation_config': generation_config, **inputs}
+        adapter_names = self._get_adapter_names(adapter_request)
+        if adapter_names is not None:
+            generate_kwargs['adapter_names'] = adapter_names
         num_prompt_tokens = self._get_num_tokens(inputs)

         generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
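All three changed call sites in this file follow the same pattern: `adapter_names` is added to the keyword arguments only when it is not `None`. The commit does not state the motivation. A plausible reading, treated here as an assumption, is that a model without adapters does not accept that keyword at all. The sketch uses hypothetical names (`build_generate_kwargs`, `plain_generate`) to show the difference between omitting a key and passing `None`.

```python
from typing import Any, Dict, List, Optional


def build_generate_kwargs(inputs: Dict[str, Any],
                          generation_config: Any,
                          adapter_names: Optional[List[str]] = None) -> Dict[str, Any]:
    """Hypothetical helper mirroring the pattern in the diff."""
    kwargs = {'generation_config': generation_config, **inputs}
    # Omit the key entirely instead of passing adapter_names=None.
    if adapter_names is not None:
        kwargs['adapter_names'] = adapter_names
    return kwargs


def plain_generate(input_ids, generation_config):
    """Stand-in for a model whose generate() has no adapter_names parameter."""
    return len(input_ids)


kwargs = build_generate_kwargs({'input_ids': [1, 2, 3]}, generation_config=None)
# Works because 'adapter_names' is absent; with the key present this call raises TypeError.
print(plain_generate(**kwargs))  # 3
```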
6 changes: 4 additions & 2 deletions swift/llm/model/constant.py
@@ -93,7 +93,9 @@ class LLMModelType:
     mamba = 'mamba'
     polylm = 'polylm'
     aya = 'aya'
-    # bert
+
+
+class BertModelType:
     modern_bert = 'modern_bert'
     bert = 'bert'

@@ -174,7 +176,7 @@ class MLLMModelType:
     megrez_omni = 'megrez_omni'


-class ModelType(LLMModelType, MLLMModelType):
+class ModelType(LLMModelType, MLLMModelType, BertModelType):

     @classmethod
     def get_model_name_list(cls) -> List[str]:
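This refactor moves the BERT constants into their own class and adds that class to the bases of `ModelType`, so `ModelType.bert` keeps resolving as before. The body of `get_model_name_list` is not part of the diff. The sketch below assumes it collects every public string constant reachable on the class, which would pick up a new mixin automatically once it is listed as a base. The class contents are trimmed to the names visible in the hunks.

```python
from typing import List


class LLMModelType:
    mamba = 'mamba'
    aya = 'aya'


class BertModelType:
    modern_bert = 'modern_bert'
    bert = 'bert'


class MLLMModelType:
    megrez_omni = 'megrez_omni'


class ModelType(LLMModelType, MLLMModelType, BertModelType):

    @classmethod
    def get_model_name_list(cls) -> List[str]:
        # Assumed implementation: gather every public string constant,
        # including those inherited from the mixin classes.
        names = []
        for attr in dir(cls):
            value = getattr(cls, attr)
            if not attr.startswith('_') and isinstance(value, str):
                names.append(value)
        return names


print(ModelType.bert)  # bert
print(sorted(ModelType.get_model_name_list()))
```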
6 changes: 3 additions & 3 deletions swift/llm/model/model/bert.py
@@ -2,7 +2,7 @@
 from transformers import AutoConfig

 from swift.utils import get_logger
-from ..constant import LLMModelType
+from ..constant import BertModelType
 from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model

 logger = get_logger()
@@ -17,7 +17,7 @@ def get_model_tokenizer_modern_bert(model_dir, *args, **kwargs):

 register_model(
     ModelMeta(
-        LLMModelType.modern_bert, [
+        BertModelType.modern_bert, [
             ModelGroup([
                 Model('answerdotai/ModernBERT-base', 'answerdotai/ModernBERT-base'),
                 Model('answerdotai/ModernBERT-large', 'answerdotai/ModernBERT-large'),
@@ -30,7 +30,7 @@ def get_model_tokenizer_modern_bert(model_dir, *args, **kwargs):

 register_model(
     ModelMeta(
-        LLMModelType.bert, [ModelGroup([
+        BertModelType.bert, [ModelGroup([
             Model('iic/nlp_structbert_backbone_base_std'),
         ])],
         None,