diff --git a/train/README.md b/train/README.md
index 911dfee3..f5f6fc23 100644
--- a/train/README.md
+++ b/train/README.md
@@ -96,7 +96,7 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
 The training launch script is scripts/run.sh; modify the parameters in run.sh according to your actual needs

 ```bash
-bash scripts/run.sh
+bash scripts/run_sft.sh
 ```

 - model_name_or_path is the pretrained model (for a LLaMA model, it must first be converted to HF format so it can be loaded with from_pretrained)
@@ -116,7 +116,7 @@ run.sh contains the launch commands for both full-parameter fine-tuning and LoRA training,
 The command below runs full-parameter fine-tuning on a single machine with multiple GPUs, using deepspeed, with LLaMA as the base model

 ```bash
-torchrun --nproc_per_node 8 train.py \
+torchrun --nproc_per_node 8 src/entry_point/sft_train.py \
 --model_name_or_path ${model_name_or_path} \
 --llama \
 --deepspeed configs/deepspeed_config.json \
@@ -180,7 +180,7 @@ trainer_state.json records how loss and learning_rate evolve
 #### 2.2.2 LoRA

 ```bash
-torchrun --nproc_per_node 8 train.py \
+torchrun --nproc_per_node 8 src/entry_point/sft_train.py \
 --model_name_or_path ${model_name_or_path} \
 --llama \
 --use_lora True \
@@ -284,7 +284,7 @@ torchrun --nproc_per_node 8 --nnodes 2 --master_addr ${master_addr} --master_por
 If you have made it this far, you have finished training. Now we load the trained model and check the quality of the text it generates.

 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/inference.py \
+CUDA_VISIBLE_DEVICES=0 python src/entry_point/inference.py \
 --model_name_or_path model_name_or_path \
 --ckpt_path ckpt_path \
 --llama \
@@ -307,7 +307,7 @@ CUDA_VISIBLE_DEVICES=0 python src/inference.py \
 We also provide a simple gradio-based interactive web UI; start the service with:

 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/interface.py \
+CUDA_VISIBLE_DEVICES=0 python src/entry_point/interface.py \
 --model_name_or_path model_name_or_path \
 --ckpt_path ckpt_path \
 --llama \
@@ -334,7 +334,7 @@ bash scripts/run_multi_backend.sh
 First, you need to obtain access to the LLaMA model from [facebookresearch/llama](https://github.com/facebookresearch/llama) and download the official checkpoints

 ```bash
-python training_scripts/convert_llama_weights_to_hf.py --input_dir download_official_llama_path --model_size 7B --output_dir xx/llama-7b-hf
+python scripts/convert_llama_weights_to_hf.py --input_dir download_official_llama_path --model_size 7B --output_dir xx/llama-7b-hf
 ```

 When running the training scripts, simply set model_name_or_path to xx/llama-7b-hf
diff --git a/train/scripts/multinode_run.sh b/train/scripts/multinode_run.sh
index f9319c11..86e61ad9 100644
--- a/train/scripts/multinode_run.sh
+++ b/train/scripts/multinode_run.sh
@@ -17,7 +17,7 @@ cutoff_len=1024
 master_addr="10.111.112.223"

 # #Multi-node
-torchrun --nproc_per_node 8 --nnodes 2 --master_addr ${master_addr} --master_port 14545 --node_rank ${node_rank} src/entry_point/train.py \
+torchrun --nproc_per_node 8 --nnodes 2 --master_addr ${master_addr} --master_port 14545 --node_rank ${node_rank} src/entry_point/sft_train.py \
 --model_name_or_path ${model_name_or_path} \
 --llama \
 --deepspeed configs/deepspeed_config.json \
diff --git a/train/scripts/run.sh b/train/scripts/run_pt.sh
similarity index 94%
rename from train/scripts/run.sh
rename to train/scripts/run_pt.sh
index 8de8c14d..da95adea 100644
--- a/train/scripts/run.sh
+++ b/train/scripts/run_pt.sh
@@ -17,7 +17,7 @@ mkdir -p ${cache_dir}
 cutoff_len=1024

 #FT
-# torchrun --nproc_per_node 8 src/entry_point/train.py \
+# torchrun --nproc_per_node 8 src/entry_point/pt_train.py \
 # --ddp_timeout 36000 \
 # --model_name_or_path ${model_name_or_path} \
 # --llama \
@@ -46,7 +46,7 @@ cutoff_len=1024


 #LoRA with 8bit
-# torchrun --nproc_per_node 8 src/entry_point/train.py \
+# torchrun --nproc_per_node 8 src/entry_point/pt_train.py \
 # --ddp_timeout 36000 \
 # --model_name_or_path ${model_name_or_path} \
 # --llama \
@@ -76,7 +76,7 @@ cutoff_len=1024
 # # --resume_from_checkpoint ...
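Before launching any of the run_*.sh variants above, it can save GPU hours to confirm that model_name_or_path really points at an HF-format checkpoint, since the entry points load it directly with from_pretrained. A minimal sketch, assuming the hypothetical xx/llama-7b-hf output of the conversion command in the README:

```python
# Sanity-check a converted LLaMA checkpoint the same way the training entry points
# read it (AutoModelForCausalLM + LlamaTokenizer). "xx/llama-7b-hf" is the assumed
# --output_dir of convert_llama_weights_to_hf.py; replace it with your own path.
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

model_path = "xx/llama-7b-hf"

tokenizer = LlamaTokenizer.from_pretrained(model_path)
# pt_train.py / sft_train.py pin these special-token ids for LLaMA
tokenizer.eos_token_id = 2
tokenizer.bos_token_id = 1
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
print(model.config.model_type, sum(p.numel() for p in model.parameters()))
```

If this prints a llama config and roughly 7B parameters, the path is ready to be passed as model_name_or_path.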
# LoRA without 8bit -torchrun --nproc_per_node 8 src/entry_point/train.py \ +torchrun --nproc_per_node 8 src/entry_point/pt_train.py \ --ddp_timeout 36000 \ --model_name_or_path ${model_name_or_path} \ --llama \ diff --git a/train/scripts/run_sft.sh b/train/scripts/run_sft.sh new file mode 100644 index 00000000..eee53431 --- /dev/null +++ b/train/scripts/run_sft.sh @@ -0,0 +1,106 @@ +#! /bin/bash +export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' +export WANDB_PROJECT=... +export WANDB_RUN_ID=... +export WANDB_RESUME=allow +export ABS_PATH=... +export PYTHONPATH="$ABS_PATH/BELLE/train" +model_name_or_path=/path_to_llm/hf_llama_7b/ # or bloomz-7b1-mt + +train_file=belleMath.json +validation_file=belleMath-dev1K.json +output_dir="$ABS_PATH/saved_models/${WANDB_PROJECT}_${WANDB_RUN_ID}" +mkdir -p ${output_dir} + +cache_dir=hf_cache_dir +mkdir -p ${cache_dir} +cutoff_len=1024 + +#FT +# torchrun --nproc_per_node 8 src/entry_point/sft_train.py \ +# --ddp_timeout 36000 \ +# --model_name_or_path ${model_name_or_path} \ +# --llama \ +# --deepspeed configs/deepspeed_config.json \ +# --train_file ${train_file} \ +# --validation_file ${validation_file} \ +# --per_device_train_batch_size 2 \ +# --per_device_eval_batch_size 2 \ +# --gradient_accumulation_steps 4 \ +# --num_train_epochs 2 \ +# --model_max_length ${cutoff_len} \ +# --save_strategy "steps" \ +# --save_total_limit 3 \ +# --learning_rate 8e-6 \ +# --weight_decay 0.00001 \ +# --warmup_ratio 0.05 \ +# --lr_scheduler_type "cosine" \ +# --logging_steps 10 \ +# --evaluation_strategy "steps" \ +# --fp16 \ +# --seed 1234 \ +# --gradient_checkpointing \ +# --cache_dir ${cache_dir} \ +# --output_dir ${output_dir} \ +# # --resume_from_checkpoint ... + + +#LoRA with 8bit +# torchrun --nproc_per_node 8 src/entry_point/sft_train.py \ +# --ddp_timeout 36000 \ +# --model_name_or_path ${model_name_or_path} \ +# --llama \ +# --use_lora \ +# --use_int8_training \ +# --lora_config configs/lora_config_llama.json \ +# --train_file ${train_file} \ +# --validation_file ${validation_file} \ +# --per_device_train_batch_size 1 \ +# --per_device_eval_batch_size 1 \ +# --gradient_accumulation_steps 8 \ +# --num_train_epochs 2 \ +# --model_max_length ${cutoff_len} \ +# --save_strategy "steps" \ +# --save_total_limit 3 \ +# --learning_rate 8e-6 \ +# --weight_decay 0.00001 \ +# --warmup_ratio 0.05 \ +# --lr_scheduler_type "cosine" \ +# --logging_steps 10 \ +# --evaluation_strategy "steps" \ +# --fp16 \ +# --seed 1234 \ +# --gradient_checkpointing \ +# --cache_dir ${cache_dir} \ +# --output_dir ${output_dir} \ +# # --resume_from_checkpoint ... + +# LoRA without 8bit +torchrun --nproc_per_node 8 src/entry_point/sft_train.py \ + --ddp_timeout 36000 \ + --model_name_or_path ${model_name_or_path} \ + --llama \ + --use_lora \ + --deepspeed configs/deepspeed_config_stage3.json \ + --lora_config configs/lora_config_llama.json \ + --train_file ${train_file} \ + --validation_file ${validation_file} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --num_train_epochs 10 \ + --model_max_length ${cutoff_len} \ + --save_strategy "steps" \ + --save_total_limit 3 \ + --learning_rate 3e-4 \ + --weight_decay 0.00001 \ + --warmup_ratio 0.01 \ + --lr_scheduler_type "cosine" \ + --logging_steps 10 \ + --evaluation_strategy "steps" \ + --fp16 \ + --seed 1234 \ + --gradient_checkpointing \ + --cache_dir ${cache_dir} \ + --output_dir ${output_dir} \ + # --resume_from_checkpoint ... 
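run_sft.sh and run_pt.sh only set the per-device batch size and gradient accumulation; the effective batch size and the evaluation/checkpoint cadence are derived inside the entry points (see pt_train.py below, and the renamed sft_train.py). A standalone sketch of that arithmetic, using the commented FT settings above and a made-up dataset of 100,000 samples:

```python
# Reproduces the step bookkeeping done in src/entry_point/pt_train.py so you can
# predict eval/save frequency before launching. The dataset size is hypothetical.
import math

training_nums = 100_000            # len(train_data); made up for illustration
world_size = 8                     # matches CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
per_device_train_batch_size = 2    # FT block in run_sft.sh
gradient_accumulation_steps = 4
num_train_epochs = 2
warmup_ratio = 0.05

batch_size = per_device_train_batch_size * world_size * gradient_accumulation_steps
t_total = math.ceil(training_nums / batch_size) * num_train_epochs
eval_steps = max(t_total // 5, 5)  # the entry point evaluates about 5 times per run
save_steps = eval_steps            # checkpoints follow the eval cadence
warmup_steps = int(t_total * warmup_ratio)

print(batch_size, t_total, eval_steps, save_steps, warmup_steps)
# -> 64 3126 625 625 156
```

With the active LoRA block (batch size 1, no accumulation, 10 epochs) the same formula gives an effective batch size of 8, and save_total_limit 3 then bounds how many of the derived checkpoints survive a run.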
diff --git a/train/src/entry_point/pt_train.py b/train/src/entry_point/pt_train.py new file mode 100644 index 00000000..ebff9a61 --- /dev/null +++ b/train/src/entry_point/pt_train.py @@ -0,0 +1,512 @@ + +from transformers.utils import add_start_docstrings +from transformers.trainer_utils import get_last_checkpoint +from transformers.trainer_pt_utils import torch_distributed_zero_first +from transformers import (AutoModelForCausalLM, AutoTokenizer, + HfArgumentParser, LlamaTokenizer, TrainingArguments, + set_seed) +from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training +from datasets import load_dataset +import transformers +import torch + +from typing import Optional +from functools import partial +from dataclasses import dataclass, field +import os +import math +import logging +import json +import sys + +from src.utils import get_model_param_count +from src.trainer import MyTrainer as Trainer +from src.sample_generator import batch_grouped_pretrain_generate + + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." + ) + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights." + ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + llama: bool = field(default=False, metadata={"help": "Llama model"}) + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={ + "help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a text file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." 
+ }, + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class TrainingArguments(TrainingArguments): + model_max_length: int = field( + default=512, + metadata={"help": "Maximum sequence length."}, + ) + use_lora: bool = field( + default=False, + metadata={"help": "Whether to use LoRA."} + ) + use_int8_training: bool = field( + default=False, metadata={"help": "Whether to use int8 training."} + ) + lora_config: Optional[str] = field( + default=None, + metadata={"help": "LoRA config file."}, + ) + ddp_find_unused_parameters: bool = field( + default=False, metadata={"help": "ddp_find_unused_parameters"} + ) + gradient_checkpointing: bool = field( + default=False, metadata={"help": "gradient_checkpointing"} + ) + # https://discuss.huggingface.co/t/wandb-does-not-display-train-eval-loss-except-for-last-one/9170 + evaluation_strategy: str = field( + default="steps", metadata={"help": "wandb bug fix"} + ) + save_total_limit: Optional[int] = field( + default=3, + metadata={ + "help": "keep saved model less than save_total_limit, delete old checkpoints when save new model"} + ) + report_to: str = field( + default="wandb", + metadata={"help": "places where report the results"} + ) + + +def print_rank_0(msg, log_file, rank=0): + if rank <= 0: + with open(log_file, "a") as f: + print(msg) + f.write(msg + "\n") + + +def main(): + parser = HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + global_rank = torch.distributed.get_rank() + log_file = os.path.join(training_args.output_dir, "print_log.txt") + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif ( + last_checkpoint is not None and training_args.resume_from_checkpoint is None + ): + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. 
To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + training_args.data_seed = training_args.seed + + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + # int8 is not compatible with DeepSpeed (require not to pass device_map) + if training_args.use_int8_training: + print_rank_0( + "int8 is not compatible with DeepSpeed. ", + log_file, + global_rank + ) + device_map = ( + {"": int(os.environ.get("LOCAL_RANK") or 0)} + if world_size != 1 else "auto" + ) + # device_map = "auto" + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + load_in_8bit=True, # xxx: int8 load in + device_map=device_map, # xxx: int8 requires passing device_map + torch_dtype=torch_dtype, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + torch_dtype=torch_dtype, + ) + + if model_args.llama: + tokenizer = LlamaTokenizer.from_pretrained( + model_args.model_name_or_path + ) + print_rank_0( + "Set the eos_token_id and bos_token_id of LLama model tokenizer", + log_file, + global_rank, + ) + tokenizer.eos_token_id = 2 + tokenizer.bos_token_id = 1 + else: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path + ) + + tokenizer.pad_token_id = 0 + tokenizer.padding_side = "left" # Allow batched inference + + print_rank_0( + "tokenizer.eos_token_id = {}".format(tokenizer.eos_token_id), + log_file, + global_rank, + ) + print_rank_0( + "tokenizer.pad_token_id = {}".format(tokenizer.pad_token_id), + log_file, + global_rank, + ) + print_rank_0( + "tokenizer.bos_token_id = {}".format(tokenizer.bos_token_id), + log_file, + global_rank, + ) + + # peft model + if training_args.use_lora: + print_rank_0( + "Loading lora config from {}".format(training_args.lora_config), + log_file, + global_rank, + ) + lora_config = json.load(open(training_args.lora_config)) + print_rank_0( + "Lora config: {}".format(lora_config), + log_file, + global_rank + ) + if training_args.use_int8_training: + print_rank_0( + "training_args.use_int8_training!!! 
(int8 is not compatible with DeepSpeed)", + log_file, + global_rank, + ) + model = prepare_model_for_int8_training(model) + config = LoraConfig( + r=lora_config["lora_r"], + lora_alpha=lora_config["lora_alpha"], + target_modules=lora_config["lora_target_modules"], + lora_dropout=lora_config["lora_dropout"], + bias="none", + task_type="CAUSAL_LM", + ) + + # "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn" + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, config) + model.print_trainable_parameters() + + if training_args.gradient_checkpointing: + model.gradient_checkpointing_enable() + + # model.is_parallelizable = True + # model.model_parallel = True + + assert os.path.exists(data_args.train_file), "{} file not exists".format( + data_args.train_file + ) + + with torch_distributed_zero_first(global_rank): + train_data = load_dataset( + "json", + data_files=data_args.train_file, + cache_dir=model_args.cache_dir + ) + + val_data = load_dataset( + "json", + data_files=data_args.validation_file, + cache_dir=model_args.cache_dir + ) + + train_data = train_data["train"].shuffle().map( + partial( + batch_grouped_pretrain_generate, + training_args.model_max_length, + tokenizer + ), + batched=True, + desc=f"Grouping texts in chunks of {training_args.model_max_length}", + remove_columns='text' + ) + + val_data = val_data["train"].shuffle().map( + partial( + batch_grouped_pretrain_generate, + training_args.model_max_length, + tokenizer + ), + batched=True, + desc=f"Grouping texts in chunks of {training_args.model_max_length}", + remove_columns='text' + ) + + + for i in range(2): + print_rank_0( + "Eval tokenized example: {}".format(val_data[i]), + log_file, + global_rank + ) + for i in range(2): + print_rank_0( + "Train tokenized example: {}".format(train_data[i]), + log_file, + global_rank + ) + + training_nums = len(train_data) + num_gpus = torch.cuda.device_count() + + batch_size = ( + training_args.per_device_train_batch_size + * training_args.world_size + * training_args.gradient_accumulation_steps + ) + # train steps + t_total = math.ceil(training_nums / batch_size) * \ + training_args.num_train_epochs + # eval steps + training_args.eval_steps = max(t_total // 5, 5) + # save steps + training_args.save_steps = training_args.eval_steps + training_args.warmup_steps = ( + int(t_total * training_args.warmup_ratio) + if training_args.warmup_ratio > 0.0 + else training_args.warmup_steps + ) + print_rank_0( + "num_gpus = {}, training_nums = {}, t_total = {}, warmup_steps = {}, eval_steps = {}, save_steps = {}".format( + num_gpus, + training_nums, + t_total, + training_args.warmup_steps, + training_args.eval_steps, + training_args.save_steps, + ), + log_file, + global_rank, + ) + print_rank_0( + "val data nums = {}, training_nums = {}, batch_size = {}".format( + len(val_data), training_nums, batch_size + ), + log_file, + global_rank, + ) + + # Trainer + # https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py + # https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py + # https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py + # https://www.deepspeed.ai/docs/config-json/ + # 
https://huggingface.co/docs/accelerate/usage_guides/deepspeed + # https://huggingface.co/transformers/v4.10.1/main_classes/deepspeed.html + # https://github.com/tatsu-lab/stanford_alpaca/issues/176 + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_data, + eval_dataset=val_data, + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True + ), + ) + + print_rank_0( + f"Using {training_args.half_precision_backend} half precision backend", + log_file, + global_rank, + ) + # Train! + len_dataloader = len(trainer.get_train_dataloader()) + num_update_steps_per_epoch = ( + len_dataloader // training_args.gradient_accumulation_steps + ) + + total_train_batch_size = ( + training_args.train_batch_size + * training_args.gradient_accumulation_steps + * training_args.world_size + ) + num_examples = trainer.num_examples(trainer.get_train_dataloader()) + num_train_samples = num_examples * training_args.num_train_epochs + max_steps = math.ceil(training_args.num_train_epochs * \ + num_update_steps_per_epoch) + print_rank_0("***** Running training *****", log_file, global_rank) + print_rank_0(f" Num examples = {num_examples}", log_file, global_rank) + print_rank_0( + f" Num train samples = {num_train_samples}", + log_file, + global_rank + ) + print_rank_0(f" world_size = {world_size}", log_file, global_rank) + print_rank_0( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}", + log_file, + global_rank, + ) + print_rank_0( + f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}", + log_file, + global_rank, + ) + print_rank_0( + f" Total optimization steps = {max_steps}", + log_file, + global_rank + ) + + print_rank_0( + f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}", + log_file, + global_rank, + ) + + # https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/3 + model.config.use_cache = False + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + trainer.train(resume_from_checkpoint=checkpoint) + + print_rank_0( + "\n Training completed!!! 
If there's a warning about missing keys above, please disregard :)", + log_file, + global_rank, + ) + + +if __name__ == "__main__": + main() diff --git a/train/src/entry_point/train.py b/train/src/entry_point/sft_train.py similarity index 99% rename from train/src/entry_point/train.py rename to train/src/entry_point/sft_train.py index b36c4285..0b338f8f 100644 --- a/train/src/entry_point/train.py +++ b/train/src/entry_point/sft_train.py @@ -9,6 +9,7 @@ from datasets import load_dataset import transformers import torch + from typing import Optional from functools import partial from dataclasses import dataclass, field @@ -23,6 +24,7 @@ from src.sample_generator import generate_and_tokenize_prompt + logger = logging.getLogger(__name__) @@ -351,7 +353,7 @@ def make_inputs_require_grad(module, input, output): tokenizer ) ) - + val_data = val_data["train"].shuffle().map( partial( generate_and_tokenize_prompt, diff --git a/train/src/sample_generator.py b/train/src/sample_generator.py index a872f3e7..7c2c0128 100644 --- a/train/src/sample_generator.py +++ b/train/src/sample_generator.py @@ -1,11 +1,20 @@ +from itertools import chain +from typing import Any, Dict, List import pudb import copy from transformers import PreTrainedTokenizer import json + IGNORE_INDEX = -100 -def generate_and_tokenize_prompt(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point): +def generate_and_tokenize_prompt( + model_max_length: int, + tokenizer: PreTrainedTokenizer, + data_point: Dict[str, Any], + fix_length=False, + padding_side="left", +): input_ids = [] labels = [] source = data_point["conversations"] @@ -29,19 +38,40 @@ def generate_and_tokenize_prompt(model_max_length: int, tokenizer: PreTrainedTok labels += label # add eos at every end of assistant sentence if sentence_from != "human": - input_ids += [ - tokenizer.eos_token_id - ] # make sure eos_token_id is correct + input_ids += [tokenizer.eos_token_id] # make sure eos_token_id is correct labels += [tokenizer.eos_token_id] - input_ids = input_ids[: model_max_length - 1] - labels = labels[: model_max_length - 1] + input_ids = input_ids[:model_max_length] + labels = labels[:model_max_length] + if all(x == IGNORE_INDEX for x in labels): labels[18:24] = input_ids[ 18:24 ] # labels can not have all values being -100. 
18 and 24 are just random numbers - attention_mask = [1] * len(input_ids) + + if fix_length: + if padding_side == "left": + input_ids = [tokenizer.pad_token_id] * ( + model_max_length - len(input_ids) + ) + input_ids + labels = [tokenizer.pad_token_id] * ( + model_max_length - len(labels) + ) + labels + attention_mask = [0] * ( + model_max_length - len(attention_mask) + ) + attention_mask + else: + input_ids = input_ids + [tokenizer.pad_token_id] * ( + model_max_length - len(input_ids) + ) + labels = labels + [tokenizer.pad_token_id] * ( + model_max_length - len(labels) + ) + attention_mask = attention_mask + [0] * ( + model_max_length - len(attention_mask) + ) + tokenized_full_prompt = { "input_ids": input_ids, "attention_mask": attention_mask, @@ -50,39 +80,42 @@ def generate_and_tokenize_prompt(model_max_length: int, tokenizer: PreTrainedTok return tokenized_full_prompt -def pretrain_generate(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point): - input_ids = tokenizer.encode(data_point['text']) - labels = copy.deepcopy(input_ids) - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - input_ids = input_ids[: model_max_length] - labels = labels[: model_max_length] - return { - "input_ids": input_ids, - "attention_mask": [1] * len(input_ids), - "labels": labels, - } +def batch_grouped_pretrain_generate( + model_max_length: int, + tokenizer: PreTrainedTokenizer, + examples: Dict[str, List[str]], +) -> Dict[str, List[List[int]]]: + # build grouped texts with format `X1 X2 X3 ... X1 X2 X3 ... []` + token_ids_list: List[List[int]] = tokenizer( + examples["text"], add_special_tokens=False + )["input_ids"] + token_ids_list = [ + token_ids + [tokenizer.eos_token_id] for token_ids in token_ids_list + ] + concatenated_ids = list(chain(*token_ids_list)) + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (len(concatenated_ids) // model_max_length) * model_max_length + result = [ + concatenated_ids[i : i + model_max_length] + for i in range(0, total_length, model_max_length) + ] + return {"input_ids": result, "labels": result.copy()} def exam_generate(model_max_length: int, tokenizer: PreTrainedTokenizer, data_point): - template = 'Human: \n{human}\n\nAssistant: \n' + template = "Human: \n{human}\n\nAssistant: \n" # pudb.set_trace() input_str = template.format( human=f'回答下面的{data_point["type"]}题,用json返回答案,包括原因和答案,如{{"reason":..., "answer":...}}\n{data_point["question"]}\n选项:{" ".join(data_point["candidates"])}' ) - input_ids = tokenizer.encode( - input_str, - add_special_tokens=False - ) + input_ids = tokenizer.encode(input_str, add_special_tokens=False) labels = [IGNORE_INDEX] * len(input_ids) bot_ids = tokenizer.encode( json.dumps( - { - 'reason': data_point['reason'], - 'answer': data_point['answer'] - }, ensure_ascii=False + {"reason": data_point["reason"], "answer": data_point["answer"]}, + ensure_ascii=False, ), - add_special_tokens=False + add_special_tokens=False, ) input_ids += bot_ids labels += bot_ids
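To see concretely what batch_grouped_pretrain_generate emits, the sketch below calls it with a toy whitespace "tokenizer" stub instead of a real PreTrainedTokenizer, so it runs without downloading a model; with a real tokenizer the call is identical. It assumes PYTHONPATH points at train/, as exported in run_sft.sh.

```python
# Toy walk-through of the grouping logic: each text is tokenized, an EOS id is
# appended, everything is concatenated, and the stream is cut into fixed
# model_max_length blocks; the tail remainder is dropped. WhitespaceStub is a
# hypothetical stand-in for a real tokenizer (ids start at 3 to avoid eos_token_id=2).
from src.sample_generator import batch_grouped_pretrain_generate


class WhitespaceStub:
    eos_token_id = 2

    def __init__(self):
        self.vocab = {}

    def __call__(self, texts, add_special_tokens=False):
        return {
            "input_ids": [
                [self.vocab.setdefault(w, len(self.vocab) + 3) for w in t.split()]
                for t in texts
            ]
        }


examples = {"text": ["a b c d e", "f g h", "i j k l"]}
out = batch_grouped_pretrain_generate(4, WhitespaceStub(), examples)

# 12 word tokens + 3 EOS ids = 15 ids -> three full blocks of 4; the remainder of 3 is dropped
print([len(block) for block in out["input_ids"]])   # [4, 4, 4]
print(out["labels"] == out["input_ids"])            # True: labels mirror input_ids
```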