Skip to content

Commit

Permalink
update rlhf
Browse files Browse the repository at this point in the history
  • Loading branch information
mst272 committed Aug 29, 2024
1 parent 89faa35 commit ac7fe10
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 2 additions & 0 deletions rlhf/rlhf_args/cpo-simpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ class CPOSimPOConfig(CPOConfig):
cpo_alpha"""
eval_samples: int = 30
"""Number of eval samples; must not be less than batch_size * gradient_accumulation_steps."""
sft_model_path: str = "./"
"""the path to the sft model"""

8 changes: 4 additions & 4 deletions rlhf/rlhf_args/cpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ class CPOConfig(CPOConfig):
"""Whether to sample and log generations during evaluation step."""
is_encoder_decoder: Optional[bool] = None
"""If no model is provided, we need to know if the model_init returns an encoder-decoder."""
model_init_kwargs: Optional[Dict] = None
"""Dict of Optional kwargs to pass when instantiating the model from a string"""
sft_model_path: str = "./"
"""the path to the sft model"""



dataset_num_proc: Optional[int] = None
"""The number of workers to use to tokenize the data. Defaults to None."""

# Parameters related to TrainingArguments
train_data_path: Optional[str] = field(default='./', metadata={"help": "训练集路径"})
Expand Down
2 changes: 2 additions & 0 deletions rlhf/rlhf_args/simpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ class SimPOConfig(CPOConfig):
"""A target reward margin for the SimPO loss, used only when the "simpo" option is enabled."""
eval_samples: int = 30
"""Number of eval samples; must not be less than batch_size * gradient_accumulation_steps."""
sft_model_path: str = "./"
"""the path to the sft model"""

0 comments on commit ac7fe10

Please sign in to comment.