
Commit: fix rlhf bugs
mst272 committed Nov 6, 2024
1 parent 41a63ed commit 4d0bc19
Showing 10 changed files with 38 additions and 18 deletions.
10 changes: 8 additions & 2 deletions rlhf/README.md
@@ -84,7 +84,10 @@

There are two kinds of parameter configuration files: the first is ```common_args.py```, and the configurations for the different methods are in the ```rlhf_args``` folder.

-We recommend launching with deepspeed; the launch script is ```script/rlhf_run.sh```
+We recommend launching with deepspeed; the launch script is ```rlhf_run.sh```
+```bash
+bash rlhf_run.sh
+```

- rlhf_type: [PPO,RLOO,CPO,DPO,SimPO,CPOSimPO,Reward]
- train_mode: [lora, qlora, full]
@@ -130,7 +133,10 @@ res_length is 64
For a detailed introduction, see the article: [Knowledge Distillation](https://zhuanlan.zhihu.com/p/1064724364)

### Quick Start
-Just go into the script directory and run ```gkd_run.sh``` with bash, adjusting the corresponding parameters as needed. Deepspeed is also supported; see the article above for a description of the parameters.
+Just go into the script directory and run ```gkd_run.sh``` with bash, adjusting the corresponding parameters as needed. Deepspeed is also supported.
+```bash
+bash gkd_run.sh
+```

**Parameter descriptions**
- lmbda: 0 corresponds to Supervised KD and 1 to GKD. Any value in [0,1] can be chosen, which mixes the two in that proportion (see the sketch after this file's diff)
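As a side note on the lmbda bullet above: the sketch below (illustrative only, not code from this commit) shows the mixing it describes, assuming TRL-style GKD semantics in which lmbda is the per-batch probability of training on student-generated, on-policy completions; `student_generate` is a hypothetical stand-in for sampling from the current student model.

```python
import random

def build_gkd_batch(prompts, dataset_completions, student_generate, lmbda):
    """With probability lmbda, train on on-policy student generations (GKD);
    otherwise fall back to the fixed dataset completions (supervised KD)."""
    if random.random() < lmbda:      # lmbda = 1 -> always on-policy GKD
        completions = student_generate(prompts)
    else:                            # lmbda = 0 -> pure supervised KD
        completions = dataset_completions
    return prompts, completions
```

Values strictly between 0 and 1 therefore mix the two regimes at the batch level, which is the mixing proportion the bullet refers to.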
2 changes: 1 addition & 1 deletion rlhf/script/gkd_run.sh → rlhf/gkd_run.sh
@@ -1,7 +1,7 @@
# The number of GPUs to use is set by changing the num_processes parameter in the yaml file

# Lora mode; for QLora or full-parameter training, just adjust the parameters slightly
-CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ./ds_zero3.yaml ../train_gkd.py \
+CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ./ds_config/ds_zero3.yaml ./train_gkd.py \
--model_name_or_path deepseek-coder-6.7b-instruct \
--teacher_model_name_or_path deepseek-coder-33b-instruct\
--dataset_name ../data_example/gkd_data.jsonl \
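Regarding the comment at the top of this script: the process count normally comes from the accelerate yaml, but `accelerate launch` also accepts `--num_processes` on the command line, which overrides the config file. A hypothetical 4-GPU variant of the launch line above (paths and model names copied from the script as shown, not verified):

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --config_file ./ds_config/ds_zero3.yaml \
    --num_processes 4 \
    ./train_gkd.py \
    --model_name_or_path deepseek-coder-6.7b-instruct \
    --teacher_model_name_or_path deepseek-coder-33b-instruct \
    --dataset_name ../data_example/gkd_data.jsonl
```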
4 changes: 2 additions & 2 deletions rlhf/rlhf_args/base_config.py
@@ -7,7 +7,7 @@ class BaseConfig(TrainingArguments):
"""
Training arguments
"""
-output_dir: str = field(default='', metadata={"help": "save path for the model after training"})
+output_dir: str = field(default='./output', metadata={"help": "save path for the model after training"})
num_train_epochs: int = 1

per_device_train_batch_size: int = 2
@@ -24,5 +24,5 @@ class BaseConfig(TrainingArguments):
optim: str = 'adamw_torch'
report_to: str = 'tensorboard'
remove_unused_columns: bool = False
-bf16: bool = True
+bf16: bool = False
fp16: bool = False
2 changes: 1 addition & 1 deletion rlhf/rlhf_args/cpo_config.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from base_config import BaseConfig
+from rlhf_args.base_config import BaseConfig
from trl import CPOConfig as TrlCPOConfig


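The same one-line import fix is applied to the four config files that follow. A minimal sketch of why the package-qualified form is needed, under the assumption that train_rlhf.py is launched from the rlhf/ directory as in rlhf_run.sh:

```python
# Running `python ./train_rlhf.py` (or launching it via accelerate) puts the
# script's directory, rlhf/, on sys.path, so imports resolve relative to it:
from rlhf_args.base_config import BaseConfig   # -> rlhf/rlhf_args/base_config.py

# The old bare import only resolves when rlhf/rlhf_args/ itself is on sys.path,
# e.g. when a module inside that folder is executed directly:
# from base_config import BaseConfig           # ModuleNotFoundError when run from rlhf/
```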
2 changes: 1 addition & 1 deletion rlhf/rlhf_args/dpo_config.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from base_config import BaseConfig
+from rlhf_args.base_config import BaseConfig
from typing import Literal
from trl import DPOConfig as TrlDPOConfig

2 changes: 1 addition & 1 deletion rlhf/rlhf_args/ppo_config.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from trl import PPOConfig as TrlPPOConfig
-from base_config import BaseConfig
+from rlhf_args.base_config import BaseConfig


@dataclass
2 changes: 1 addition & 1 deletion rlhf/rlhf_args/reward_config.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from base_config import BaseConfig
+from rlhf_args.base_config import BaseConfig
from trl import RewardConfig as TrlRewardConfig


2 changes: 1 addition & 1 deletion rlhf/rlhf_args/rloo_config.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import Optional
-from base_config import BaseConfig
+from rlhf_args.base_config import BaseConfig
from trl import RLOOConfig as TrlPLOOConfig


6 changes: 3 additions & 3 deletions rlhf/script/rlhf_run.sh → rlhf/rlhf_run.sh
@@ -7,7 +7,7 @@ TRAIN_DATA='./'
MODEL_PATH='./'
OUTPUT_PATH='./'

-CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ../ds_config/ds_zero2.yaml ../train_rlhf.py \
+CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ./ds_config/ds_zero2.yaml ./train_rlhf.py \
--model_name_or_path "$MODEL_PATH" \
--train_data_path "$TRAIN_DATA" \
--output_dir "$OUTPUT_PATH" \
@@ -19,13 +19,13 @@ CUDA_VISIBLE_DEVICES=2,3 accelerate launch --config_file ../ds_config/ds_zero2.y
--gradient_accumulation_steps 8 \
--logging_steps 2 \
--num_train_epochs 1 \
---fb16 \
+--bf16 \
--save_strategy "steps" \
--report_to "wandb" \
--save_steps 180 \
--save_total_limit 5 \
--warmup_steps 10 \
---no_remove_unused_columns \
+--remove_unused_columns False\
--lr_scheduler_type "cosine"

# [CPO,DPO,SimPO,CPOSimPO,Reward] can be run directly with the command above
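Two flag fixes land in this hunk: --fb16 was simply a typo for --bf16, and --no_remove_unused_columns is replaced because of how HfArgumentParser maps boolean dataclass fields to CLI options. A field that defaults to False (as bf16 and remove_unused_columns now do in BaseConfig) is enabled with a bare --name or set explicitly with --name True/False; the --no_name negation is only generated for fields whose default is True. A small self-contained sketch with a stand-in dataclass, not the repo's BaseConfig:

```python
from dataclasses import dataclass
from transformers import HfArgumentParser

@dataclass
class DemoArgs:                          # stand-in dataclass for illustration
    bf16: bool = False                   # False default: enable with a bare --bf16
    remove_unused_columns: bool = False  # no --no_remove_unused_columns flag is generated

parser = HfArgumentParser(DemoArgs)
(demo,) = parser.parse_args_into_dataclasses(
    args=["--bf16", "--remove_unused_columns", "False"]
)
print(demo.bf16, demo.remove_unused_columns)  # True False
```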
24 changes: 19 additions & 5 deletions rlhf/train_rlhf.py
@@ -1,4 +1,5 @@
import importlib
+import os
from peft import LoraConfig, TaskType
from datasets import load_dataset
from transformers import (
@@ -16,6 +17,8 @@
from loguru import logger
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template

+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
WITH_REWARD_MODEL = ['RLOO', 'PPO']
USE_REF_MODEL = ['DPO']

@@ -49,8 +52,8 @@ def load_config(args, remaining_args):
class_name = args.rlhf_type + "Config"
# Use getattr to fetch the class from the module
argument = getattr(module, class_name)
-train_argument = argument()
-parser_b = HfArgumentParser((train_argument,))
+
+parser_b = HfArgumentParser((argument,))
train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args)
return train_args

@@ -179,15 +182,21 @@ def main():
eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=lora_config,
-),
+)
+if args.rlhf_type == 'DPO'
+else dict()
+,
'CPO': dict(
model=policy,
args=config,
train_dataset=train_dataset['train'],
eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=lora_config,
-),
+)
+if args.rlhf_type == 'CPO'
+else dict()
+,
"PPO": dict(
),
"RLOO": dict(
@@ -198,7 +207,10 @@ def main():
reward_model=reward_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
-),
+)
+if args.rlhf_type == 'RLOO'
+else dict()
+,
'Reward': dict(
model=policy,
processing_class=tokenizer,
@@ -207,6 +219,8 @@ def main():
eval_dataset=train_dataset['test'] if config.eval_strategy != "no" else None,
peft_config=lora_config,
)
+if args.rlhf_type == 'Reward'
+else dict()
}
trainer_kwargs_map['SimPO'] = trainer_kwargs_map['CPO'].copy()
trainer_kwargs_map['CPOSimPO'] = trainer_kwargs_map['CPO'].copy()
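Two fixes stand out in this file beyond the import and environment tweaks. First, HfArgumentParser now receives the Config dataclass itself rather than an instance, which matches the parser's expectation of dataclass types that it instantiates from the parsed values. Second, every entry of trainer_kwargs_map is guarded with `if args.rlhf_type == ... else dict()`, so only the kwargs for the selected method are built and the other branches no longer touch objects (reward model, eval split) that may not exist for this run. A minimal sketch of that guard pattern with made-up names:

```python
def build_trainer_kwargs(rlhf_type, policy, config, splits, reward_model=None):
    """Only the branch matching rlhf_type constructs real kwargs; the other
    entries collapse to empty dicts so unused objects are never dereferenced."""
    kwargs_map = {
        "DPO": dict(
            model=policy,
            args=config,
            train_dataset=splits.get("train"),
            eval_dataset=splits.get("test"),
        ) if rlhf_type == "DPO" else dict(),
        "RLOO": dict(
            policy=policy,
            config=config,
            reward_model=reward_model,   # only loaded when the run actually needs it
            train_dataset=splits.get("train"),
        ) if rlhf_type == "RLOO" else dict(),
    }
    return kwargs_map[rlhf_type]
```

A dict of factory functions would avoid building the map at all, but the inline conditionals keep the original structure of the code intact.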
