extend qwen2-vl and bugfix (#407)
lostkevin authored Dec 27, 2024
1 parent 24fa3e9 commit d7cab4e
Showing 24 changed files with 994 additions and 637 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -21,6 +21,7 @@ English | [简体中文](./README_zh-CN.md)
Pai-Megatron-Patch (https://github.com/alibaba/Pai-Megatron-Patch) is a deep learning training toolkit that lets developers easily train and run inference with LLMs & VLMs using the Megatron framework. As LLMs continue to develop, model structures and scales are evolving rapidly. Although these models can be conveniently built with the Transformers or DeepSpeed training frameworks, their training efficiency is comparatively low, and the problem becomes even more severe once the model scale exceeds 10 billion parameters. The primary objective of Pai-Megatron-Patch is to make effective use of GPU compute for LLMs, enabling convenient training of commonly used LLMs with all the acceleration techniques provided by Megatron-LM.

What's New:
- **Upgrade Qwen2-VL models to support MG2HF checkpoint conversion and training on multi-turn, complex multimodal samples.** [🔥🔥 2024.12.27]
- **Support training Qwen2-VL models by using Megatron-Core.** [🔥🔥 2024.11.27]
- **Support training LLaVA models by using Megatron-Core.** [🔥🔥 2024.11.20]
- **Add an LLM auto configurator and apply per-sequence SFT loss for qwen2/2.5 models.** [🔥🔥 2024.10.30]
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -41,6 +41,7 @@ Pai-Megatron-Patch是各类开源大模型和Megatron训练加速引擎之间的
- [Alibaba Cloud PAI wins dual championships in FewCLUE few-shot learning with large models](https://developer.aliyun.com/article/788081?spm=a2c6h.12873639.article-detail.17.11c5383cHpFZks&tlog=yuekan_8)

What's New:
- **Extend Qwen2-VL checkpoint conversion and add support for training on multi-turn, complex multimodal data** [🔥🔥 2024.12.27]
- **Support training Qwen2-VL models with the Megatron-Core framework** [🔥🔥 2024.11.27]
- **Support training LLaVA models with the Megatron-Core framework** [🔥🔥 2024.11.20]
- **Add automatic configuration of optimal-throughput training parameters for LLMs and apply per-sequence SFT loss for fine-tuning qwen2/2.5 series models.** [🔥🔥 2024.10.30]
9 changes: 6 additions & 3 deletions examples/deepseek_v2/pretrain_deepseek.py
@@ -86,7 +86,7 @@ def get_batch(data_iterator):

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None, None
return None, None, None, None, None, None, None

args = get_args()

@@ -130,7 +130,10 @@ def get_batch(data_iterator):
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return tuple([*batch.values(), packed_seq_params])
if args.train_mode == "pretrain":
return tuple([*batch.values(), None, packed_seq_params])
else:
return tuple([*batch.values(), packed_seq_params])
else:
raise ValueError("please set correct --dataset ")

@@ -181,7 +184,7 @@ def forward_step(data_iterator, model: GPTModel):

# Get the batch.
timers("batch-generator", log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator)
tokens, labels, loss_mask, attention_mask, position_ids, _, packed_seq_params = get_batch(data_iterator)
timers("batch-generator").stop()
output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)

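The same widening of the `get_batch` return recurs in the example scripts below: scripts that use sequence packing now return seven values (a placeholder slot inserted before `packed_seq_params` on the pretrain path), the dense-model scripts return six (a trailing `None`), and `forward_step` discards the new slot with `_`. A minimal, self-contained sketch of the adjusted contract, using hypothetical simplified names (the real scripts build `batch` via `get_batch_on_this_tp_rank`/`get_batch_on_this_cp_rank`):

```python
# Sketch only: illustrates the widened get_batch return introduced by this commit,
# not the repository's exact code. `extra_slot` is a hypothetical name for the new
# placeholder (None during pretraining; on the SFT path the real scripts carry the
# extra value inside `batch` itself).

def get_batch_sketch(batch, packed_seq_params=None, extra_slot=None):
    """Return the 7-tuple that forward_step now expects."""
    return (
        batch["tokens"],
        batch["labels"],
        batch["loss_mask"],
        batch["attention_mask"],
        batch["position_ids"],
        extra_slot,          # new slot added in this commit
        packed_seq_params,
    )

# Caller side, mirroring forward_step: unpack seven values and ignore the new slot.
fake_batch = {
    "tokens": [[1, 2, 3]],
    "labels": [[2, 3, 4]],
    "loss_mask": [[1, 1, 1]],
    "attention_mask": None,
    "position_ids": [[0, 1, 2]],
}
tokens, labels, loss_mask, attention_mask, position_ids, _, packed_seq_params = (
    get_batch_sketch(fake_batch)
)
```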
13 changes: 10 additions & 3 deletions examples/llama2/pretrain_mcore_llama.py
@@ -86,7 +86,7 @@ def get_batch(data_iterator):

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
return None, None, None, None, None, None

args = get_args()

@@ -101,7 +101,14 @@ def get_batch(data_iterator):
batch = get_batch_on_this_tp_rank(data_iterator)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
None
)
else:
raise ValueError("please set correct --dataset ")

@@ -146,7 +153,7 @@ def forward_step(data_iterator, model):

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch(
data_iterator)
timers('batch-generator').stop()

13 changes: 10 additions & 3 deletions examples/llama3/pretrain_llama.py
@@ -90,7 +90,7 @@ def get_batch(data_iterator):

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
return None, None, None, None, None, None

args = get_args()

@@ -105,7 +105,14 @@ def get_batch(data_iterator):
batch = get_batch_on_this_tp_rank(data_iterator)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
None
)
else:
raise ValueError("please set correct --dataset ")

@@ -154,7 +161,7 @@ def forward_step(data_iterator, model: GPTModel):

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch(
data_iterator)
timers('batch-generator').stop()

13 changes: 10 additions & 3 deletions examples/llama3/pretrain_llama_mcore070.py
@@ -72,7 +72,7 @@ def get_batch(data_iterator):

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
return None, None, None, None, None, None

args = get_args()

@@ -87,7 +87,14 @@ def get_batch(data_iterator):
batch = get_batch_on_this_tp_rank(data_iterator)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
None
)
else:
raise ValueError("please set correct --dataset ")

@@ -137,7 +144,7 @@ def forward_step(data_iterator, model: GPTModel):

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch(
data_iterator)
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
10 changes: 7 additions & 3 deletions examples/llama3_1/pretrain_llama.py
@@ -90,7 +90,7 @@ def get_batch(data_iterator):

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None, None
return None, None, None, None, None, None, None

args = get_args()

@@ -133,7 +133,11 @@ def get_batch(data_iterator):
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return tuple([*batch.values(), packed_seq_params])

if args.train_mode == "pretrain":
return tuple([*batch.values(), None, packed_seq_params])
else:
return tuple([*batch.values(), packed_seq_params])
else:
raise ValueError("please set correct --dataset ")

@@ -184,7 +188,7 @@ def forward_step(data_iterator, model: GPTModel):

# Get the batch.
timers("batch-generator", log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator)
tokens, labels, loss_mask, attention_mask, position_ids, _, packed_seq_params = get_batch(data_iterator)
timers("batch-generator").stop()
output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)

6 changes: 3 additions & 3 deletions examples/mistral/pretrain_mcore_mistral.py
@@ -40,12 +40,12 @@
from megatron_patch.data import build_pretrain_dataset_from_original

from megatron_patch.data.utils import get_batch_on_this_tp_rank_original, get_batch_on_this_tp_rank_idxmap_sft
from megatron_patch.model.mixtral.layer_specs import (
from megatron_patch.model.mixtral_bak.layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron_patch.model.mixtral.model import GPTModel
from megatron_patch.model.mixtral.transformer_config import TransformerConfig
from megatron_patch.model.mixtral_bak.model import GPTModel
from megatron_patch.model.mixtral_bak.transformer_config import TransformerConfig
from megatron_patch.tokenizer import build_tokenizer, get_tokenizer
from megatron.core.packed_seq_params import PackedSeqParams

13 changes: 10 additions & 3 deletions examples/mistral/pretrain_mcore_mistral_bak.py
@@ -69,7 +69,7 @@ def get_batch(data_iterator):

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
return None, None, None, None, None, None

args = get_args()

@@ -84,7 +84,14 @@ def get_batch(data_iterator):
batch = get_batch_on_this_tp_rank(data_iterator)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
None
)
else:
raise ValueError("please set correct --dataset ")

@@ -128,7 +135,7 @@ def forward_step(data_iterator, model):

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch(
data_iterator)
timers('batch-generator').stop()

12 changes: 10 additions & 2 deletions examples/qwen1_5/pretrain_mcore_qwen.py
@@ -91,7 +91,7 @@ def get_batch(data_iterator):

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
return None, None, None, None, None, None

args = get_args()

@@ -107,6 +107,14 @@ def get_batch(data_iterator):
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
None
)
else:
raise ValueError("please set correct --dataset ")

@@ -155,7 +163,7 @@ def forward_step(data_iterator, model: GPTModel):

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch(
data_iterator)
timers('batch-generator').stop()

60 changes: 54 additions & 6 deletions examples/qwen2_vl/README.md
@@ -57,6 +57,9 @@ wget https://atp-modelzoo-wlcb-pai.oss-cn-wulanchabu.aliyuncs.com/release/models
tar -zxf wds.tgz
```
For complex datasets such as video multimodal data, samples containing multiple images, or multi-turn dialogues, you need to convert the data into the sharegpt format before training with Megatron-Patch. For details on processing sharegpt-format data, see [this link](./dataset_preparation.md).
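For orientation only, here is a hypothetical sharegpt-style multi-turn multimodal record sketched in Python; the field names (`conversations`, `images`, the `from`/`value` keys) are assumptions, and the authoritative schema is the one described in dataset_preparation.md:

```python
# Hypothetical sharegpt-style sample (field names are assumptions; follow
# dataset_preparation.md for the schema Megatron-Patch actually expects).
sample = {
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is shown in this picture?"},
        {"from": "gpt", "value": "A cat sitting on a windowsill."},
        {"from": "human", "value": "What color is the cat?"},
        {"from": "gpt", "value": "It is orange with white paws."},
    ],
    "images": ["images/0001.jpg"],
}
```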
## Megatron-Core Model Training Workflow
### Megatron-Core Model Format Conversion
Run the `hf2mcore_qwen2_vl_convertor.sh` script; the parameters to pass are listed below
@@ -84,6 +87,21 @@ false \
bf16
```
To convert a trained checkpoint back to the Hugging Face format for inference, run
```bash
cd /workspace/Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/qwen
bash hf2mcore_qwen2_vl_convertor.sh \
7B \
/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct-tp2pp2 \
/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct-tp2pp2-back \
2 \
2 \
true \
bf16 \
/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct
```
### Megatron-Core Pretraining
#### Pretraining Command Description
@@ -102,7 +120,7 @@ TP=${10} # Model (tensor) parallel size
PP=${11} # Pipeline parallel size
CP=${12} # Context parallel size
DO=${13} # Whether to use Megatron's Zero-1 memory-saving optimizer: true, false
FL=${14} # Whether to prefer Flash Attention: true, false
FL=${14} # Whether to prefer Flash Attention: false
AC=${15} # Activation checkpointing mode: sel, full, offload, false
OPTIMIZER_OFFLOAD=${16} # Whether to enable optimizer offload: false, static, auto
SAVE_INTERVAL=${17} # Checkpoint save interval
@@ -123,17 +141,47 @@ sh run_mcore_qwen.sh \
dsw \
7B \
1 \
256 \
0.00015 \
32 \
1e-5 \
1e-6 \
2048 \
2048 \
bf16 \
2 \
2 \
1 \
true \
false \
true \
false \
100000 \
/mnt/llava-datasets/LLaVA-Pretrain/wds \
/mnt/llava-datasets/LLaVA-Pretrain/wds \
/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct-tp2pp2 \
20000 \
200 \
/workspace/output_mcore_qwen2vl_pretrain
```
Because the extra ViT on PP Rank 0 gives it a slightly higher load than the other PP ranks when pipeline parallelism is used, you may need to set the `MP_PP0_LAYERS` variable to reduce the number of LLM layers on PP Rank 0 to achieve the best performance.
```bash
cd /workspace/Pai-Megatron-Patch/examples/qwen2_vl
MP_PP0_LAYERS=12 sh run_mcore_qwen.sh \
dsw \
7B \
1 \
32 \
1e-5 \
1024 \
1024 \
1e-6 \
2048 \
2048 \
bf16 \
2 \
2 \
1 \
true \
true \
false \
true \
false \
100000 \