diff --git a/README.md b/README.md index 8f2f61a..624d5f0 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ English | [简体中文](./README_zh-CN.md) Pai-Megatron-Patch (https://github.com/alibaba/Pai-Megatron-Patch) is a deep learning training toolkit built for developers to easily train and predict with LLMs & VLMs by using the Megatron framework. With the continuous development of LLMs, the model structure and scale are rapidly evolving. Although these models can be conveniently built with the Transformers or DeepSpeed training frameworks, the training efficiency is comparatively low. This problem becomes even more severe when the model scale exceeds 10 billion parameters. The primary objective of Pai-Megatron-Patch is to effectively utilize the computational power of GPUs for LLMs. This tool allows convenient training of commonly used LLMs with all the acceleration techniques provided by Megatron-LM. What's New: +- **Upgrade Qwen2-VL models to support MG2HF ckpts conversion and training with multi-turn complex multimodal samples.** [🔥🔥 2024.12.27] - **Support training Qwen2-VL models by using Megatron-Core.** [🔥🔥 2024.11.27] - **Support training LLaVA models by using Megatron-Core.** [🔥🔥 2024.11.20] - **Add llm auto configurator and apply per seq sft loss for qwen2/2.5 models.** [🔥🔥 2024.10.30] diff --git a/README_zh-CN.md b/README_zh-CN.md index b897010..08d7b2c 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -41,6 +41,7 @@ Pai-Megatron-Patch是各类开源大模型和Megatron训练加速引擎之间的 - [阿里云PAI获得FewCLUE基于大模型的小样本学习双料冠军](https://developer.aliyun.com/article/788081?spm=a2c6h.12873639.article-detail.17.11c5383cHpFZks&tlog=yuekan_8) 新功能: +- **拓展Qwen2-VL模型权重转换及多轮复杂多模态数据的训练支持** [🔥🔥 2024.12.27] - **支持用Megatron-Core框架训练Qwen2-VL模型** [🔥🔥 2024.11.27] - **支持用Megatron-Core框架训练LLaVA模型** [🔥🔥 2024.11.20] - **添加大模型训练最优吞吐参数自动配置以及针对qwen2/2.5系列模型优化微调per seq sft loss.** [🔥🔥 2024.10.30] diff --git a/examples/deepseek_v2/pretrain_deepseek.py b/examples/deepseek_v2/pretrain_deepseek.py index 2f1d4ad..3a04b5c 100644 --- a/examples/deepseek_v2/pretrain_deepseek.py +++ b/examples/deepseek_v2/pretrain_deepseek.py @@ -86,7 +86,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None, None + return None, None, None, None, None, None, None args = get_args() @@ -130,7 +130,10 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) - return tuple([*batch.values(), packed_seq_params]) + if args.train_mode == "pretrain": + return tuple([*batch.values(), None, packed_seq_params]) + else: + return tuple([*batch.values(), packed_seq_params]) else: raise ValueError("please set correct --dataset ") @@ -181,7 +184,7 @@ def forward_step(data_iterator, model: GPTModel): # Get the batch.
timers("batch-generator", log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator) + tokens, labels, loss_mask, attention_mask, position_ids, _, packed_seq_params = get_batch(data_iterator) timers("batch-generator").stop() output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params) diff --git a/examples/llama2/pretrain_mcore_llama.py b/examples/llama2/pretrain_mcore_llama.py index 035d1a8..8596f2c 100644 --- a/examples/llama2/pretrain_mcore_llama.py +++ b/examples/llama2/pretrain_mcore_llama.py @@ -86,7 +86,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None + return None, None, None, None, None, None args = get_args() @@ -101,7 +101,14 @@ def get_batch(data_iterator): batch = get_batch_on_this_tp_rank(data_iterator) # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) - + return ( + batch['tokens'], + batch['labels'], + batch['loss_mask'], + batch['attention_mask'], + batch['position_ids'], + None + ) else: raise ValueError("please set correct --dataset ") @@ -146,7 +153,7 @@ def forward_step(data_iterator, model): # Get the batch. timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch( data_iterator) timers('batch-generator').stop() diff --git a/examples/llama3/pretrain_llama.py b/examples/llama3/pretrain_llama.py index dc101d5..1e1f5f7 100644 --- a/examples/llama3/pretrain_llama.py +++ b/examples/llama3/pretrain_llama.py @@ -90,7 +90,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None + return None, None, None, None, None, None args = get_args() @@ -105,7 +105,14 @@ def get_batch(data_iterator): batch = get_batch_on_this_tp_rank(data_iterator) # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) - + return ( + batch['tokens'], + batch['labels'], + batch['loss_mask'], + batch['attention_mask'], + batch['position_ids'], + None + ) else: raise ValueError("please set correct --dataset ") @@ -154,7 +161,7 @@ def forward_step(data_iterator, model: GPTModel): # Get the batch. 
timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch( data_iterator) timers('batch-generator').stop() diff --git a/examples/llama3/pretrain_llama_mcore070.py b/examples/llama3/pretrain_llama_mcore070.py index 3836fe9..9a75f52 100644 --- a/examples/llama3/pretrain_llama_mcore070.py +++ b/examples/llama3/pretrain_llama_mcore070.py @@ -72,7 +72,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None + return None, None, None, None, None, None args = get_args() @@ -87,7 +87,14 @@ def get_batch(data_iterator): batch = get_batch_on_this_tp_rank(data_iterator) # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) - + return ( + batch['tokens'], + batch['labels'], + batch['loss_mask'], + batch['attention_mask'], + batch['position_ids'], + None + ) else: raise ValueError("please set correct --dataset ") @@ -137,7 +144,7 @@ def forward_step(data_iterator, model: GPTModel): # Get the batch. timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch( data_iterator) timers('batch-generator').stop() output_tensor = model(tokens, position_ids, attention_mask, diff --git a/examples/llama3_1/pretrain_llama.py b/examples/llama3_1/pretrain_llama.py index c5abf4e..bf4c8ae 100644 --- a/examples/llama3_1/pretrain_llama.py +++ b/examples/llama3_1/pretrain_llama.py @@ -90,7 +90,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None, None + return None, None, None, None, None, None, None args = get_args() @@ -133,7 +133,11 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) - return tuple([*batch.values(), packed_seq_params]) + + if args.train_mode == "pretrain": + return tuple([*batch.values(), None, packed_seq_params]) + else: + return tuple([*batch.values(), packed_seq_params]) else: raise ValueError("please set correct --dataset ") @@ -184,7 +188,7 @@ def forward_step(data_iterator, model: GPTModel): # Get the batch. 
timers("batch-generator", log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator) + tokens, labels, loss_mask, attention_mask, position_ids, _, packed_seq_params = get_batch(data_iterator) timers("batch-generator").stop() output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params) diff --git a/examples/mistral/pretrain_mcore_mistral.py b/examples/mistral/pretrain_mcore_mistral.py index ff0b129..6cf0f2a 100644 --- a/examples/mistral/pretrain_mcore_mistral.py +++ b/examples/mistral/pretrain_mcore_mistral.py @@ -40,12 +40,12 @@ from megatron_patch.data import build_pretrain_dataset_from_original from megatron_patch.data.utils import get_batch_on_this_tp_rank_original, get_batch_on_this_tp_rank_idxmap_sft -from megatron_patch.model.mixtral.layer_specs import ( +from megatron_patch.model.mixtral_bak.layer_specs import ( get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, ) -from megatron_patch.model.mixtral.model import GPTModel -from megatron_patch.model.mixtral.transformer_config import TransformerConfig +from megatron_patch.model.mixtral_bak.model import GPTModel +from megatron_patch.model.mixtral_bak.transformer_config import TransformerConfig from megatron_patch.tokenizer import build_tokenizer, get_tokenizer from megatron.core.packed_seq_params import PackedSeqParams diff --git a/examples/mistral/pretrain_mcore_mistral_bak.py b/examples/mistral/pretrain_mcore_mistral_bak.py index ba2efec..a9c1113 100644 --- a/examples/mistral/pretrain_mcore_mistral_bak.py +++ b/examples/mistral/pretrain_mcore_mistral_bak.py @@ -69,7 +69,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None + return None, None, None, None, None, None args = get_args() @@ -84,7 +84,14 @@ def get_batch(data_iterator): batch = get_batch_on_this_tp_rank(data_iterator) # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) - + return ( + batch['tokens'], + batch['labels'], + batch['loss_mask'], + batch['attention_mask'], + batch['position_ids'], + None + ) else: raise ValueError("please set correct --dataset ") @@ -128,7 +135,7 @@ def forward_step(data_iterator, model): # Get the batch. 
timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch( data_iterator) timers('batch-generator').stop() diff --git a/examples/qwen1_5/pretrain_mcore_qwen.py b/examples/qwen1_5/pretrain_mcore_qwen.py index 6661a0e..ebdc017 100644 --- a/examples/qwen1_5/pretrain_mcore_qwen.py +++ b/examples/qwen1_5/pretrain_mcore_qwen.py @@ -91,7 +91,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None + return None, None, None, None, None, None args = get_args() @@ -107,6 +107,14 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) + return ( + batch['tokens'], + batch['labels'], + batch['loss_mask'], + batch['attention_mask'], + batch['position_ids'], + None + ) else: raise ValueError("please set correct --dataset ") @@ -155,7 +163,7 @@ def forward_step(data_iterator, model: GPTModel): # Get the batch. timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + tokens, labels, loss_mask, attention_mask, position_ids, _ = get_batch( data_iterator) timers('batch-generator').stop() diff --git a/examples/qwen2_vl/README.md b/examples/qwen2_vl/README.md index 88d5dae..4f0e010 100755 --- a/examples/qwen2_vl/README.md +++ b/examples/qwen2_vl/README.md @@ -57,6 +57,9 @@ wget https://atp-modelzoo-wlcb-pai.oss-cn-wulanchabu.aliyuncs.com/release/models tar -zxf wds.tgz ``` +对于视频多模态、单样本中包含多张图片、多轮对话等复杂数据集,您需要将其转换为sharegpt格式数据后再使用Megatron-Patch训练。对于sharegpt格式的数据处理,参见[链接](./dataset_preparation.md)。 + + ## Megatron-Core模型训练流程 ### Megatron-Core模型格式转换 运行`hf2mcore_qwen2_vl_convertor.sh`脚本,需要传入的参数列表如下 @@ -84,6 +87,21 @@ false \ bf16 ``` +当您需要将训练好的checkpoint转换回huggingface格式用于推理时,执行 + +```bash +cd /workspace/Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/qwen +bash hf2mcore_qwen2_vl_convertor.sh \ +7B \ +/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct-tp2pp2 \ +/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct-tp2pp2-back \ +2 \ +2 \ +true \ +bf16 \ +/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct +``` + ### Megatron-Core预训练 #### 预训练命令描述 @@ -102,7 +120,7 @@ TP=${10} # 模型并行度 PP=${11} # 流水并行度 CP=${12} # 上下文并行度 DO=${13} # 是否使用Megatron版Zero-1降显存优化器: true, false -FL=${14} # 是否优先使用Flash Attention: true, false +FL=${14} # 是否优先使用Flash Attention: false AC=${15} # 激活检查点模式: sel, full, offload, false OPTIMIZER_OFFLOAD=${16} # 是否启用Offload optimizer: false, static, auto SAVE_INTERVAL=${17} # 保存ckpt的间隔 @@ -123,17 +141,47 @@ sh run_mcore_qwen.sh \ dsw \ 7B \ 1 \ -256 \ -0.00015 \ +32 \ +1e-5 \ +1e-6 \ +2048 \ +2048 \ +bf16 \ +2 \ +2 \ +1 \ +true \ +false \ +true \ +false \ +100000 \ +/mnt/llava-datasets/LLaVA-Pretrain/wds \ +/mnt/llava-datasets/LLaVA-Pretrain/wds \ +/mnt/qwen2-vl-ckpts/Qwen2-VL-7B-Instruct-tp2pp2 \ +20000 \ +200 \ +/workspace/output_mcore_qwen2vl_pretrain +``` + +由于PP切分时,PP Rank 0额外的ViT会导致其负载略高于其他PP Rank,为了达到最佳性能,您可能需要调整`MP_PP0_LAYERS`变量降低PP Rank 0的LLM层数。 + +```bash +cd /workspace/Pai-Megatron-Patch/examples/qwen2_vl +MP_PP0_LAYERS=12 sh run_mcore_qwen.sh \ +dsw \ +7B \ +1 \ +32 \ 1e-5 \ -1024 \ -1024 \ +1e-6 \ +2048 \ +2048 \ bf16 \ 2 \ 2 \ 1 \ true \ -true \ +false \ true \ false \ 100000 \ diff --git a/examples/qwen2_vl/dataset_helpers.py b/examples/qwen2_vl/dataset_helpers.py index 7041363..474cefe 100644 --- 
a/examples/qwen2_vl/dataset_helpers.py +++ b/examples/qwen2_vl/dataset_helpers.py @@ -1,4 +1,16 @@ -# TODO: Add a License +# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import dataclasses import re import sys @@ -11,6 +23,7 @@ import numpy as np import torch from torchvision import transforms as T +import json from megatron.energon import ( Batch, @@ -18,6 +31,8 @@ VQASample, ) +from megatron_patch.data.energon.chatml import ChatMLSample + from megatron.training import get_args from megatron_patch.tokenizer import get_tokenizer @@ -26,14 +41,16 @@ class ImageTaskSample: __key__: str __subflavors__: Dict - # (c, h, w) - imgs: List[np.ndarray] - image_thw_grids: List[Tuple[int]] - video_thw_grids: List[Tuple[int]] + + imgs: List[np.ndarray] # (c, h, w) + videos: List[np.ndarray] # (c, h, w) + + image_thw_grids: np.ndarray + video_thw_grids: np.ndarray image_input_mask: np.ndarray video_input_mask: np.ndarray text: np.ndarray - target: torch.Tensor = None + target: np.ndarray # Typing for the resulting batch data after encode_batch() @dataclass @@ -42,6 +59,7 @@ class VQATaskBatch(Batch): __subflavors__: List[Dict] # (num_tiles, c, h, w) imgs: torch.Tensor + videos: torch.Tensor image_thw_grids: torch.Tensor video_thw_grids: torch.Tensor image_input_mask: torch.Tensor @@ -90,7 +108,7 @@ def convert_to_qwen2vl_content( return contents -class TaskEncoder(DefaultTaskEncoder[VQASample, ImageTaskSample, VQATaskBatch, dict]): +class TaskEncoder(DefaultTaskEncoder[Union[VQASample, ChatMLSample], ImageTaskSample, VQATaskBatch, dict]): """A simple task encoder for captioning.""" def __init__( @@ -110,43 +128,29 @@ def __init__( self.seq_len = self.args.max_padding_length - def encode_sample(self, sample: VQASample): + def encode_sample(self, sample: Union[VQASample, ChatMLSample]): if isinstance(sample, VQASample): is_llava_training = sample.__subflavors__['is_llava_training'] if 'is_llava_training' in sample.__subflavors__ else False if is_llava_training: raise NotImplementedError('Sample format not supported') else: yield self.encode_vqa(sample) + elif isinstance(sample, ChatMLSample): + yield self.encode_chatml(sample) else: raise NotImplementedError('Sample format not supported') - def encode_vqa(self, sample: VQASample): - augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False - has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False - - if has_video: - # Grab the selected frames of the video as a tensor with shape - # fhwc: (num_frames, height, width, num_channels). 
- # video_fhwc = sample.image.permute(0, 2, 3, 1) - # selected_frames = torch.linspace( - # 0, video_fhwc.shape[0] - 1, self.args.num_frames).long() - # video_frame_fhwc = video_fhwc[selected_frames] - # imgs = [] - # for video_frame_hwc in video_frame_fhwc: - # imgs += get_visual_transform( - # video_frame_hwc, self.img_h, self.img_w, - # self.args.use_tiling, self.args.max_num_tiles, - # self.args.use_thumbnail, augment=False) - raise NotImplementedError() - else: - # TODO: add args - imgs = get_visual_transform( - sample.image - ) - resized_height, resized_width = imgs[0].shape[-2:] - # shape: c x img_h x img_w - # split single image into tiles for dynamic resolution - patches = np.tile(np.array(imgs[0]), (self.temporal_patch_size, 1, 1, 1)) + def _flatten_visual_inputs(self, visuals, is_image: bool = True): + flattened = [] + thw_grids = [] + for visual in visuals: + if is_image: + resized_height, resized_width = visual.shape[-2:] + patches = np.tile(np.array(visual), (self.temporal_patch_size, 1, 1, 1)) + else: + assert len(visual) % self.temporal_patch_size == 0 + patches = np.array(visual) + resized_height, resized_width = patches.shape[-2:] channel = patches.shape[1] grid_t = patches.shape[0] // self.temporal_patch_size @@ -162,18 +166,171 @@ def encode_vqa(self, sample: VQASample): self.merge_size, self.patch_size, ) - patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) flatten_patches = patches.reshape( grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size - ) + ) + flattened.append(flatten_patches) + thw_grids.append((grid_t, grid_h, grid_w)) + return flattened, np.array(thw_grids) + + def encode_chatml(self, sample: ChatMLSample): + # TODO: modify get_visual_transform to add more augmentations + imgs = [get_visual_transform(img)[0] for img in sample.imgs] + videos = [[get_visual_transform(frame)[0] for frame in video] for video in sample.videos] + + # NOTE: flatten all images + flattened_imgs, image_thw_grids = self._flatten_visual_inputs(imgs, is_image=True) + flattened_videos, video_thw_grids = self._flatten_visual_inputs(videos, is_image=False) + + # NOTE: generate qwen2vl conversations + conversation = json.loads(sample.conversation) if isinstance(sample.conversation, (str, bytes)) else sample.conversation + + role_key = 'from' if 'from' in conversation[0] else 'role' + content_key = 'value' if 'from' in conversation[0] else 'content' + + # NOTE: assume the conversation format is: [System]? (User Assistant)+ + converted_conversation = [] + if len(conversation) % 2 == 0: + # Default Prompt + converted_conversation.append({ + 'role': 'system', + 'content': 'You are a helpful assistant.' 
+ }) + else: + converted_conversation.append({ + 'role': 'system', + 'content': conversation[0][content_key] + }) + conversation = conversation[1:] + + EXPECTED_ROLE = ['human', 'gpt'] + for turn_idx, turn in enumerate(conversation): + role = turn[role_key] + if role != EXPECTED_ROLE[turn_idx % len(EXPECTED_ROLE)]: + raise InternalWarning(f"Expected a conversation organized in the order: [sys] human gpt human gpt..., but got role '{role}' in turn {turn_idx}") + content = turn[content_key] + + if role == 'human': + role = 'user' + content = convert_to_qwen2vl_content(content) + elif role == 'gpt': + role = 'assistant' + + converted_conversation.append({ + 'role': role, + 'content': content + }) + conversation = converted_conversation + + # NOTE: we need to mask all system/user input tokens and assistant generation prefix tokens + input_ids = self.tokenizer.apply_chat_template(conversation, tokenize=True, return_tensors="np")[0] + target = input_ids.copy() + + system_prompt_prefix = len(self.tokenizer.apply_chat_template([conversation[0]], tokenize=True)) + assistant_generation_prefix = 3 + pad_token_id = self.tokenizer.pad_token_id + + target[:system_prompt_prefix] = pad_token_id + offset = system_prompt_prefix + for turn_idx, turn in enumerate(conversation[1:]): + turn_tokens = self.tokenizer.apply_chat_template([turn], tokenize=True, return_tensors="np")[0] + turn_content = turn_tokens[system_prompt_prefix:] + n_tokens = len(turn_content) + if (target[offset: offset + n_tokens] != turn_content).any(): + raise InternalWarning("Encode Error") + + if turn['role'] == 'user': + target[offset: offset + n_tokens] = pad_token_id + elif turn['role'] == 'assistant': + target[offset: offset + assistant_generation_prefix] = pad_token_id + offset += n_tokens + + # NOTE: expand image_pad & video_pad + merge_length = self.merge_size**2 + image_token_id, video_token_id = self.tokenizer.encode(['<|image_pad|>', '<|video_pad|>']) + + image_token_indices = np.where(input_ids == image_token_id)[0] + assert len(image_token_indices) == len(image_thw_grids), f"With {len(image_thw_grids)} images in the sample, but {len(image_token_indices)} image placeholders!" + video_token_indices = np.where(input_ids == video_token_id)[0] + assert len(video_token_indices) == len(video_thw_grids), f"With {len(video_thw_grids)} videos in the sample, but {len(video_token_indices)} video placeholders!"
+ image_thw_grids, video_thw_grids = np.array(image_thw_grids, dtype=np.int64), np.array(video_thw_grids, dtype=np.int64) + + target_length = ( + input_ids.shape[0] + - image_thw_grids.shape[0] + image_thw_grids.prod(axis=-1).sum() // merge_length + - video_thw_grids.shape[0] + video_thw_grids.prod(axis=-1).sum() // merge_length + ) + if target_length > self.seq_len: + raise InternalWarning(f"Long sequence with length {target_length} found, dropped...") + final_input_ids = np.zeros(target_length, dtype=input_ids.dtype) + final_input_masks = final_input_ids.copy() + + image_idx, video_idx = 0, 0 + indices = np.sort(np.concatenate([image_token_indices, video_token_indices])) + + cur_x, cur_y = 0, 0 + for idx in indices: + token_id = input_ids[idx] + if token_id == image_token_id: + size = image_thw_grids[image_idx].prod() // merge_length + image_idx += 1 + elif token_id == video_token_id: + size = video_thw_grids[video_idx].prod() // merge_length + video_idx += 1 + # NOTE: + # input_ids[cur_x:idx] -> final_input_ids[cur_y:cur_y + idx - cur_x] + # input_ids[idx] -> final_input_ids[cur_y + idx - cur_x: cur_y + idx - cur_x + size] + final_input_ids[cur_y: cur_y + idx - cur_x] = input_ids[cur_x:idx] + final_input_masks[cur_y: cur_y + idx - cur_x] = target[cur_x:idx] + cur_y += idx - cur_x + final_input_ids[cur_y: cur_y + size] = token_id + final_input_masks[cur_y: cur_y + size] = pad_token_id + cur_y += size + cur_x = idx + 1 + + if cur_x < len(input_ids): + final_input_ids[cur_y:] = input_ids[cur_x:] + final_input_masks[cur_y:] = target[cur_x:] - # flatten_patches, (grid_t, grid_h, grid_w) - thw_grids = [(grid_t, grid_h, grid_w)] + target = np.roll(final_input_masks, shift=-1) + target[-1] = pad_token_id - assert "" in sample.context # ? + if (target == pad_token_id).all(): + raise InternalWarning("Sample with all masked label, dropped.") - # NOTE: we expect a context is a string with conetnt + image_input_mask = final_input_ids == self.tokenizer.image_token_id + video_input_mask = final_input_ids == self.tokenizer.video_token_id + # collect data + return ImageTaskSample( + __key__=sample.__key__, + __subflavors__=sample.__subflavors__, + imgs=flattened_imgs, + videos=flattened_videos, + + image_thw_grids=image_thw_grids, + video_thw_grids=video_thw_grids, + + image_input_mask=image_input_mask, + video_input_mask=video_input_mask, + + text=final_input_ids, + target=target, + ) + + def encode_vqa(self, sample: VQASample): + augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False + has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False + if has_video: + raise NotImplementedError("You should use sharegpt dataset to train with videos.") + else: + # TODO: add args + imgs = get_visual_transform(sample.image) + flatten_patches, thw_grids = self._flatten_visual_inputs(imgs, is_image=True) + + assert "" in sample.context # ? 
+ # NOTE: we expect the context to be a string with content if isinstance(sample.answers, list): answer_list = sample.answers weight_list = np.array(sample.answer_weights).astype(np.float32) @@ -231,30 +388,49 @@ def encode_vqa(self, sample: VQASample): return ImageTaskSample( __key__=sample.__key__, __subflavors__=sample.__subflavors__, + imgs=flatten_patches, + videos=list(), + image_thw_grids=thw_grids, - video_thw_grids=None, + video_thw_grids=torch.empty([0, 3], dtype=torch.long), + image_input_mask=image_input_mask, video_input_mask=None, + text=input_ids, target=target, ) def batch(self, samples: List[ImageTaskSample]) -> VQATaskBatch: # Stack images to [num_tiles, c, h, w]. If there are no images (text-only), then use a dummy image. - imgs = [s.imgs for s in samples] + imgs = [img for s in samples for img in s.imgs] if len(imgs) > 0: imgs = torch.cat([torch.from_numpy(img) for img in imgs]) else: imgs = torch.empty([0, 3 * self.temporal_patch_size * self.patch_size * self.patch_size], dtype=torch.float32) - thw_grids = [thw_grids for s in samples for thw_grids in s.image_thw_grids] - if len(thw_grids) > 0: - thw_grids = torch.from_numpy(np.array(thw_grids)).long() - assert thw_grids.prod(dim=-1).sum() == imgs.shape[0] + image_thw_grids = [thw_grids for s in samples for thw_grids in s.image_thw_grids] + if len(image_thw_grids) > 0: + image_thw_grids = torch.from_numpy(np.array(image_thw_grids)).long() + assert image_thw_grids.prod(dim=-1).sum() == imgs.shape[0] + else: + image_thw_grids = torch.empty([0, 3], dtype=torch.long) + + # Stack videos to [num_tiles, c, h, w]. If there are no videos (text-only), then use a dummy video. + videos = [video for s in samples for video in s.videos] + if len(videos) > 0: + videos = torch.cat([torch.from_numpy(video) for video in videos]) else: - thw_grids = torch.empty([0, 3], dtype=torch.long) + videos = torch.empty([0, 3 * self.temporal_patch_size * self.patch_size * self.patch_size], dtype=torch.float32) + video_thw_grids = [thw_grids for s in samples for thw_grids in s.video_thw_grids] + if len(video_thw_grids) > 0: + video_thw_grids = torch.from_numpy(np.array(video_thw_grids)).long() + assert video_thw_grids.prod(dim=-1).sum() == videos.shape[0] + else: + video_thw_grids = torch.empty([0, 3], dtype=torch.long) + # If the user hasn't defined a target sequence length, then use the max along the sample lengths. max_seq_len = self.seq_len if not max_seq_len: @@ -283,8 +459,9 @@ def batch(self, samples: List[ImageTaskSample]) -> VQATaskBatch: __keys__=[s.__key__ for s in samples], __subflavors__=[s.__subflavors__ for s in samples], imgs=imgs, - image_thw_grids=thw_grids, - video_thw_grids=None, + videos=videos, + image_thw_grids=image_thw_grids, + video_thw_grids=video_thw_grids, image_input_mask=torch.from_numpy(image_input_masks), video_input_mask=torch.from_numpy(video_input_masks), text=torch.from_numpy(text_mat), diff --git a/examples/qwen2_vl/dataset_preparation.md b/examples/qwen2_vl/dataset_preparation.md new file mode 100644 index 0000000..ad5dd58 --- /dev/null +++ b/examples/qwen2_vl/dataset_preparation.md @@ -0,0 +1,72 @@ +# 准备Qwen2-VL多模态数据集 + +当前Qwen2-VL支持特定格式的复杂多模态样本的训练,您可按照下述流程将您的数据集转换为Qwen2-VL的支持格式。 + +## 原始数据 + +在转换前,您可能需要自行将数据集转换为**sharegpt格式**,示例如下: +```json +[ + { + "conversations": [ + { + "from": "human", + "value": "human instruction" + }, + { + "from": "gpt", + "value": "model response" + }, + { + "from": "human", + "value": "