update dpo example
mst272 committed Aug 26, 2024
1 parent 407f57e commit ced93fd
Showing 7 changed files with 110 additions and 458 deletions.
16 changes: 16 additions & 0 deletions llm_tricks/DPO_example/README.md
@@ -0,0 +1,16 @@
# Implementing DPO (SimPO) Reinforcement Learning Training Code from Scratch

## Quick start
```bash
python dpo_train.py
```

## Notes
The from-scratch implementation in this directory is only a learning demo meant for understanding the underlying principles; it does not add distributed training or similar features. As a result, even with a small 2B model, GPU memory usage can reach more than 30 GB.

```dpo_train.py``` is the main training entry point; the related loss computation is in ```loss.py```.

If you want to do real training with RL methods such as DPO, SimPO, or CPO,
you can use the RLHF framework built under rlhf in this project: [RLHF](../../rlhf/README.md)

It supports DeepSpeed-based single-node multi-GPU LoRA, DoRA, QLoRA, and full-parameter training, and automatically adapts to each model's chat template.
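
For orientation, the objectives implemented in ```loss.py``` can be written as below. This is a paraphrase based on the standard DPO formulation and the SimPO code added in this commit, not text from the original README; β and γ correspond to the ```beta``` and ```gamma``` arguments of ```DPOLoss``` and ```SimPo```.

$$
\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\Big(\beta\big[(\log \pi_\theta(y_w\mid x) - \log \pi_{\mathrm{ref}}(y_w\mid x)) - (\log \pi_\theta(y_l\mid x) - \log \pi_{\mathrm{ref}}(y_l\mid x))\big]\Big)
$$

$$
\mathcal{L}_{\mathrm{SimPO}} = -\log \sigma\Big(\beta\big[\log \pi_\theta(y_w\mid x) - \log \pi_\theta(y_l\mid x) - \gamma\big]\Big)
$$

Here $y_w$ and $y_l$ are the chosen and rejected responses and $\sigma$ is the sigmoid; SimPO drops the reference model and instead subtracts a fixed margin γ.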
454 changes: 0 additions & 454 deletions llm_tricks/DPO_example/dpo.ipynb

This file was deleted.

7 changes: 4 additions & 3 deletions llm_tricks/DPO_example/dpo_train.py
@@ -31,7 +31,7 @@
IGNORE_INDEX = False


def data_collector(batch, pad_token_id, device, max_length=None, if_mask_prompt=True):
def data_collate(batch, pad_token_id, device, max_length=None, if_mask_prompt=True):
    batch_data = {
        "prompt": [],
        "chosen": [],
@@ -75,7 +75,7 @@ def data_collector(batch, pad_token_id, device, max_length=None, if_mask_prompt=


customized_collate_fn = partial(
    data_collector,
    data_collate,
    pad_token_id=tokenizer.pad_token_id,
    device=device,
    if_mask_prompt=True,
@@ -98,8 +98,9 @@ def data_collector(batch, pad_token_id, device, max_length=None, if_mask_prompt=
    drop_last=False
)

# 3. Start computing the DPO (or other) loss

# 3. Start computing the DPO (or other) loss
# The relevant code can be viewed in loss.py rather than in the main function.

# 4. Write the training function
def train_model(
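
The body of `train_model` is collapsed in this diff. Purely as an illustration, a single training step built on `compute_batch_loss` from `loss.py` could look like the sketch below; the argument names (`policy_model`, `reference_model`, `optimizer`, `beta`) are assumptions, not taken from the commit.

```python
from loss import compute_batch_loss  # assumes the script runs from llm_tricks/DPO_example/


def train_step(batch, policy_model, reference_model, optimizer, beta=0.1):
    """One hedged example step of DPO training; not the repo's actual train_model."""
    policy_model.train()
    optimizer.zero_grad()
    # compute_batch_loss returns the loss plus the chosen/rejected reward terms.
    loss, chosen_rewards, rejected_rewards = compute_batch_loss(
        batch, policy_model, reference_model, beta
    )
    loss.backward()
    optimizer.step()
    return loss.item(), chosen_rewards.item(), rejected_rewards.item()
```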
37 changes: 36 additions & 1 deletion llm_tricks/DPO_example/loss.py
@@ -40,6 +40,33 @@ def forward(
        return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()


class SimPo(nn.Module):
    """
    SimPO Loss
    """

    def __init__(self, beta: float = 0.1, gamma: float = 0.5) -> None:
        super().__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(
            self,
            policy_chosen_logps: torch.Tensor,
            policy_rejected_logps: torch.Tensor,
    ):
        """
        policy_chosen_logps: Log probabilities output by the policy model. Shape: (batch_size,)
        policy_rejected_logps: Shape: (batch_size,)
        """
        logits = policy_chosen_logps - policy_rejected_logps
        logits = logits - self.gamma
        loss = -F.logsigmoid(self.beta * logits)

        # Average (take the expectation) over the batch
        return loss.mean()


# Compute the log probabilities for each model
def compute_logprobs(logits, labels, mask=None):
    """
@@ -101,7 +128,8 @@ def compute_logprobs_f_cross(logits, labels, mask=None):

def compute_batch_loss(batch, policy_model, reference_model, beta):
    # Decide which loss to use
    loss_fn = DPOLoss(beta)
    # loss_fn = SimPo(beta, 0.5)  # SimPO loss
    loss_fn = DPOLoss(beta)  # DPO loss

    policy_chosen_logps = compute_logprobs(
        logits=policy_model(batch["chosen"]).logits,
@@ -129,6 +157,12 @@ def compute_batch_loss(batch, policy_model, reference_model, beta):
        reference_chosen_logps=reference_chosen_logps,
        reference_rejected_logps=reference_rejected_logps,
    )
    # For SimPO, use the following instead:
    # loss = loss_fn(
    #     policy_chosen_logps=policy_chosen_logps,
    #     policy_rejected_logps=policy_rejected_logps,
    # )
    # return loss
    return loss, chosen_rewards, rejected_rewards


@@ -157,6 +191,7 @@ def compute_loss_dataloader(data_loader, policy_model, reference_model, beta, nu


if __name__ == "__main__":
    # Test compute_logprobs_f_cross against compute_logprobs
    logits = torch.tensor(
        [[2.0, 1.0, 0.1, 0.4],
         [0.5, 2.5, 0.3, 0.5],
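
To make the two loss classes concrete, here is a hedged usage sketch with dummy per-sequence log probabilities. The keyword names for `DPOLoss` follow the call in `compute_batch_loss`, and the import assumes the snippet runs from `llm_tricks/DPO_example/`.

```python
import torch

from loss import DPOLoss, SimPo

# Dummy summed log probabilities for two preference pairs, shape (batch_size,)
policy_chosen_logps = torch.tensor([-12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -11.0])
reference_chosen_logps = torch.tensor([-12.5, -10.0])
reference_rejected_logps = torch.tensor([-13.5, -10.5])

# DPO needs the reference model's log probabilities as well.
dpo_loss_fn = DPOLoss(0.1)
dpo_loss, chosen_rewards, rejected_rewards = dpo_loss_fn(
    policy_chosen_logps=policy_chosen_logps,
    policy_rejected_logps=policy_rejected_logps,
    reference_chosen_logps=reference_chosen_logps,
    reference_rejected_logps=reference_rejected_logps,
)

# SimPO is reference-free: only the policy log probabilities and a margin gamma.
simpo_loss_fn = SimPo(beta=2.0, gamma=0.5)
simpo_loss = simpo_loss_fn(policy_chosen_logps, policy_rejected_logps)
```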
1 change: 1 addition & 0 deletions main_train.py
@@ -14,6 +14,7 @@
from utils.args import CommonArgs
from datasets import load_dataset
from trl import DPOTrainer
from rlhf.utils.utils import is_right_apply_chat, fix_chat_template_if_needed


def initial_args():
1 change: 1 addition & 0 deletions rlhf/rlhf_train.py
@@ -16,6 +16,7 @@
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from common_args import CommonArgs
from loguru import logger
from utils.utils import is_right_apply_chat, fix_chat_template_if_needed


def load_config(args):
52 changes: 52 additions & 0 deletions rlhf/utils/utils.py
@@ -0,0 +1,52 @@
from typing import List, Dict
import copy


def is_right_apply_chat(tokenizer, prompt: List[Dict[str, str]], assistant_content: List[Dict[str, str]]) -> bool:
    """
    Checks whether the assistant's content is correctly appended to the prompt by the chat template.
    Args:
        tokenizer: The tokenizer.
        prompt: The initial prompt messages.
        assistant_content: The content provided by the assistant.
    Returns:
        bool: True if the assistant's content is correctly applied, False otherwise.
    """
    try:
        test_assistant = tokenizer.apply_chat_template(assistant_content, tokenize=False)
        test_prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
        conversation = copy.deepcopy(prompt)
        conversation.append(assistant_content[0])
        # Rendering the full conversation should equal the prompt and the assistant
        # content rendered separately and concatenated.
        return tokenizer.apply_chat_template(conversation, tokenize=False) == test_prompt + test_assistant
    except Exception:
        return False


def fix_chat_template_if_needed(tokenizer, prompt: List[Dict[str, str]], chosen: List[Dict[str, str]],
                                rejected: List[Dict[str, str]]):
    """
    Fixes the prompt/response split of the chat template if needed.
    Args:
        tokenizer: The tokenizer.
        prompt: The initial prompt messages.
        chosen: The chosen response, a list containing a single dictionary representing the chosen message.
        rejected: The rejected response, a list containing a single dictionary representing the rejected message.
    Returns:
        tuple: A tuple containing the fixed prompt, fixed chosen response, and fixed rejected response.
    """
    conversation_chosen = copy.deepcopy(prompt)
    conversation_rejected = copy.deepcopy(prompt)
    conversation_chosen.append(chosen[0])
    conversation_rejected.append(rejected[0])
    conversation_chosen = tokenizer.apply_chat_template(conversation_chosen, tokenize=False)
    conversation_rejected = tokenizer.apply_chat_template(conversation_rejected, tokenize=False)
    # Find where the chosen response content starts in the rendered conversation.
    start_position = conversation_chosen.find(chosen[0]['content'])
    # Everything before that position is the prompt (including the assistant header);
    # everything from it onward belongs to the response.
    fixed_prompt = conversation_chosen[:start_position]
    fixed_chosen = conversation_chosen[start_position:]
    fixed_rejected = conversation_rejected[start_position:]
    return fixed_prompt, fixed_chosen, fixed_rejected
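
A hedged usage sketch of the two helpers above; the checkpoint name is only an example choice, not something this commit specifies.

```python
from transformers import AutoTokenizer

from rlhf.utils.utils import is_right_apply_chat, fix_chat_template_if_needed

# Any model with a chat template works here; this checkpoint is an arbitrary example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

prompt = [{"role": "user", "content": "What is DPO?"}]
chosen = [{"role": "assistant", "content": "DPO optimizes a policy directly from preference pairs."}]
rejected = [{"role": "assistant", "content": "I am not sure."}]

if not is_right_apply_chat(tokenizer, prompt, chosen):
    # The template does not render prompt and response independently, so split
    # the fully rendered conversation into prompt and response strings instead.
    prompt_text, chosen_text, rejected_text = fix_chat_template_if_needed(
        tokenizer, prompt, chosen, rejected
    )
```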
