Skip to content

Commit

Permalink
Merge pull request mst272#13 from aJupyter/main
Browse files Browse the repository at this point in the history
update readme and run_example.sh
  • Loading branch information
mst272 authored Aug 7, 2024
2 parents 7620834 + 050ea09 commit b56c449
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 4 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ RLHF训练框架,支持并持续更新Reward训练、PPO、DPO、RLOO、SimPO
- [x] [deepseek-coder](https://github.com/deepseek-ai/DeepSeek-Coder)
- [x] [哔哩哔哩 Index-1.9B](https://github.com/bilibili/Index-1.9B)
- [x] [baichuan系列](https://github.com/baichuan-inc/Baichuan2)
- 待更新GLM
- [x] [GLM系列](https://github.com/THUDM/GLM-4)
- 待更新Mistral系列
### 已更新tricks讲解
所有相关的tricks及讲解都在llm_tricks文件夹下
- [Dora代码讲解(llm_tricks/dora/READEME.md)](./llm_tricks/dora/READEME.md)
Expand Down
3 changes: 1 addition & 2 deletions main_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, Trainer, \
BitsAndBytesConfig, HfArgumentParser, set_seed
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, cast_mixed_precision_params
from train_args import dpo_TrainArgument, sft_TrainArgument
import bitsandbytes as bnb
from utils.template import template_dict
from utils.data_process import MultiRoundDataProcess, DpoDataset
Expand All @@ -31,8 +32,6 @@ def load_config(train_args_path):


def initial_args():
# parser = HfArgumentParser((CommonArgs, TrainArgument))
# reward_args, train_args = parser.parse_args_into_dataclasses()
parser = HfArgumentParser((CommonArgs,))
args = parser.parse_args_into_dataclasses()[0]
# 根据CommonArgs中的config_option动态加载配置
Expand Down
10 changes: 10 additions & 0 deletions run_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Launch training via the DeepSpeed launcher on local GPUs 0 and 1.
# The Chinese tokens (数据集路径 / 模型路径 / 输出路径) are placeholders —
# replace them with real paths before running.
# NOTE: a space is required before each trailing backslash; without it the
# filename and the following flag are glued into one shell word.
deepspeed --include localhost:0,1 main_train.py \
    --train_data_path 数据集路径 \
    --model_name_or_path 模型路径 \
    --task_type sft \
    --train_mode qlora \
    --output_dir 输出路径

# task_type:[pretrain, sft, dpo_multi, dpo_single]

# python main_train.py --train_data_path 数据集路径 --model_name_or_path 模型路径
7 changes: 7 additions & 0 deletions train_args/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Package entry point for training-argument dataclasses.

Re-exports the per-task ``TrainArgument`` classes under task-prefixed
aliases so callers can write
``from train_args import dpo_TrainArgument, sft_TrainArgument``
instead of importing from the nested task sub-modules.
"""
from .dpo.dpo_config import TrainArgument as dpo_TrainArgument
from .sft.lora_qlora.base import TrainArgument as sft_TrainArgument

# Explicit public API of the package.
__all__ = [
"dpo_TrainArgument",
"sft_TrainArgument",
]
2 changes: 1 addition & 1 deletion utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class CommonArgs:
一些常用的自定义参数
"""
# Deepspeed相关参数
local_rank: int = field(default=1, metadata={"help": "deepspeed所需参数,单机无需修改"})
local_rank: int = field(default=1, metadata={"help": "deepspeed所需参数,单机无需修改,如出现报错可注释掉"})

train_args_path: TrainArgPath = field(default=TrainArgPath.SFT_LORA_QLORA_BASE.value,
metadata={"help": "当前模式的训练参数,分为sft和dpo参数"})
Expand Down

0 comments on commit b56c449

Please sign in to comment.