Skip to content

Commit

Permalink
Update command-line running (argument parsing)
Browse files Browse the repository at this point in the history
  • Loading branch information
mst272 committed Aug 7, 2024
1 parent b56c449 commit 72e4d44
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
32 changes: 14 additions & 18 deletions main_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,26 @@
from utils.data_process import MultiRoundDataProcess, DpoDataset
from utils.data_collator import SftDataCollator
from utils.args import CommonArgs
import importlib
from datasets import load_dataset
from trl import DPOTrainer


def load_config(train_args_path):
    """Dynamically import a training-config module and instantiate its config class.

    Args:
        train_args_path: relative path to a config file, e.g.
            "train_args/sft/lora_qlora/base.py". Slashes are converted to
            dots to form an importable module path.

    Returns:
        An instance of the module's ``TrainArgument`` class (every config
        module is expected to define a class with exactly this name).

    Raises:
        ModuleNotFoundError: if the derived module path cannot be imported.
        AttributeError: if the module does not define ``TrainArgument``.
    """
    # Convert a file path into a dotted module path for importlib.
    module_path = train_args_path.replace("/", ".")
    # BUG FIX: the original used rstrip(".py"), which strips any trailing
    # run of the characters '.', 'p', 'y' — e.g. "happy.py" becomes "ha",
    # "base_copy.py" becomes "base_co". Remove the ".py" suffix explicitly.
    if module_path.endswith(".py"):
        module_path = module_path[: -len(".py")]
    # Dynamically import the config module.
    module = importlib.import_module(module_path)
    # Every config module exposes its settings via a class named TrainArgument.
    TrainArgument = getattr(module, "TrainArgument")
    return TrainArgument()


def initial_args():
parser = HfArgumentParser((CommonArgs,))
args = parser.parse_args_into_dataclasses()[0]
# args = parser.parse_args_into_dataclasses()[0]
args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
# 根据CommonArgs中的config_option动态加载配置
train_args = load_config(args.train_args_path)
# train_args = load_config(args.train_args_path)
if args.train_args_path == "sft_args":
parser_b = HfArgumentParser((sft_TrainArgument,))
train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args)
print("Loaded instance sft_args")
elif args.train_args_path == "dpo_args":
parser_c = HfArgumentParser((dpo_TrainArgument,))
train_args, = parser_c.parse_args_into_dataclasses(args=remaining_args)
print(f"Loaded instance dpo_args")
else:
raise ValueError("Invalid train_args_path choice")

if not os.path.exists(train_args.output_dir):
os.mkdir(train_args.output_dir)
Expand Down Expand Up @@ -79,7 +75,7 @@ def create_tokenizer(args):
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.eod_id
tokenizer.eos_token_id = tokenizer.eod_id
if tokenizer.bos_token is None: # qwen没有bos_token,要设置一下,不然dpo train时会报错。
if tokenizer.bos_token is None: # qwen没有bos_token,要设置一下,不然dpo train时会报错。
tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
tokenizer.bos_token_id = tokenizer.eos_token_id

Expand Down
6 changes: 3 additions & 3 deletions utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class TrainMode(Enum):


class TrainArgPath(Enum):
    """Valid choices for ``CommonArgs.train_args_path``.

    Each member's value is the string key used to select which
    TrainArgument parser to load (SFT vs. DPO training).
    """
    # BUG FIX: the diff left both the old path-style values and the new
    # key-style values in the class; duplicate Enum member names raise
    # TypeError at class creation. Keep only the post-commit values.
    SFT_LORA_QLORA_BASE = 'sft_args'
    DPO_LORA_QLORA_BASE = 'dpo_args'


@dataclass
Expand All @@ -22,7 +22,7 @@ class CommonArgs:
# Deepspeed相关参数
local_rank: int = field(default=1, metadata={"help": "deepspeed所需参数,单机无需修改,如出现报错可注释掉"})

train_args_path: TrainArgPath = field(default=TrainArgPath.SFT_LORA_QLORA_BASE.value,
train_args_path: TrainArgPath = field(default='sft_args',
metadata={"help": "当前模式的训练参数,分为sft和dpo参数"})
max_len: int = field(default=1024, metadata={"help": "最大输入长度,dpo时该参数在dpo_config中设置"})
max_prompt_length: int = field(default=512, metadata={
Expand Down

0 comments on commit 72e4d44

Please sign in to comment.