Commit

update dpo example
mst272 committed Aug 12, 2024
1 parent 3769aad commit 84d2024
Showing 2 changed files with 180 additions and 5 deletions.
100 changes: 95 additions & 5 deletions llm_tricks/DPO_example/dpo.ipynb
@@ -160,7 +160,7 @@
"execution_count": null,
"outputs": [],
"source": [
"def compute_logprobs(logits, labels):\n",
"def compute_logprobs(logits, labels, mask=None):\n",
"    \"\"\"\n",
"    logits: shape (batch_size, sequence_len, vocab_size)\n",
"    labels: shape (batch_size, sequence_len)\n",
@@ -178,16 +178,73 @@
"        input=logps,\n",
"        dim=-1,\n",
"        index=labels.unsqueeze(-1)\n",
"    ).squeeze(-1)"
"    ).squeeze(-1)\n",
"    \n",
"    if mask is not None:\n",
"        mask = mask[:, 1:].clone()\n",
"        # zero out the log-probs at padding positions\n",
"        select_logprobs = select_logprobs * mask\n",
"        # average over the unmasked tokens only\n",
"        average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)\n",
"        return average_logprobs\n",
"    else:\n",
"        return select_logprobs.mean(-1)"
],
"metadata": {
"collapsed": false
},
"id": "ca63797467a68c1e"
},
{
"cell_type": "markdown",
"source": [
"clone example: an in-place op on an alias also mutates the original tensor, which is why the code above copies the mask with .clone()"
],
"metadata": {
"collapsed": false
},
"id": "5636ff98ad0814b3"
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 13,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([2, 3, 4])\n"
]
}
],
"source": [
"mask = torch.tensor([1,2,3])\n",
"mask1 = mask\n",
"mask1 += 1\n",
"print(mask)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-08-12T16:24:01.170731600Z",
"start_time": "2024-08-12T16:24:01.158727300Z"
}
},
"id": "72b9adb35c2cf553"
},
{
"cell_type": "markdown",
"source": [
"tensor shape example"
],
"metadata": {
"collapsed": false
},
"id": "dbac1fe29cc58257"
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"name": "stdout",
@@ -230,11 +287,44 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-08-12T10:05:21.666926700Z",
"start_time": "2024-08-12T10:05:21.637344Z"
"end_time": "2024-08-12T16:01:46.797611600Z",
"start_time": "2024-08-12T16:01:44.279995600Z"
}
},
"id": "4b1f1f33b9e7f613"
},
{
"cell_type": "markdown",
"source": [
"Compute the DPO loss for a batch"
],
"metadata": {
"collapsed": false
},
"id": "4043bc53c7bcd4d0"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def compute_batch_loss(batch, policy_model, reference_model, beta):\n",
"    \"\"\"Compute the DPO loss on an input batch\"\"\"\n",
"    policy_chosen_logps = compute_logprobs(\n",
"        logits=policy_model(batch[\"chosen\"]),\n",
"        labels=batch[\"chosen\"],\n",
"        mask=batch[\"chosen_mask\"]\n",
"    )\n",
"    policy_rejected_logps = compute_logprobs(\n",
"        logits=policy_model(batch[\"rejected\"]),\n",
"        labels=batch[\"rejected\"],\n",
"        mask=batch[\"rejected_mask\"]\n",
"    )\n",
"    # the frozen reference model needs no gradients\n",
"    with torch.no_grad():\n",
"        reference_chosen_logps = compute_logprobs(\n",
"            logits=reference_model(batch[\"chosen\"]),\n",
"            labels=batch[\"chosen\"], mask=batch[\"chosen_mask\"])\n",
"        reference_rejected_logps = compute_logprobs(\n",
"            logits=reference_model(batch[\"rejected\"]),\n",
"            labels=batch[\"rejected\"], mask=batch[\"rejected_mask\"])\n",
"    logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps)\n",
"    # standard DPO loss, averaged over the batch\n",
"    return -F.logsigmoid(beta * logits).mean()"
],
"metadata": {
"collapsed": false
},
"id": "3211a04a645fb478"
}
],
"metadata": {
85 changes: 85 additions & 0 deletions llm_tricks/DPO_example/dpo.py
@@ -0,0 +1,85 @@
import torch.nn.functional as F
import torch.nn as nn
import torch


def compute_logprobs(logits, labels, mask=None):
    """
    logits: shape (batch_size, sequence_len, vocab_size)
    labels: shape (batch_size, sequence_len)
    mask: shape (batch_size, sequence_len); 1 for real tokens, 0 for padding
    """

    # Shift so that the logits at position n predict the token at n + 1:
    # drop the first label token
    labels = labels[:, 1:].clone()
    # drop the last logit, which has no next-token label
    logits = logits[:, :-1, :]

logps = F.log_softmax(logits, dim=-1)

    select_logprobs = torch.gather(
        input=logps,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    if mask is not None:
        mask = mask[:, 1:].clone()
        # zero out the log-probs at padding positions
        select_logprobs = select_logprobs * mask
        # average over the unmasked tokens only
        average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)
        return average_logprobs
    else:
        return select_logprobs.mean(-1)
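
A quick sanity check of the shift-and-gather logic above (a minimal sketch; the toy shapes and token ids are invented for illustration):

import torch

# toy batch: batch_size=1, sequence_len=3, vocab_size=5
logits = torch.randn(1, 3, 5)
labels = torch.tensor([[2, 4, 1]])   # arbitrary token ids
mask = torch.tensor([[1, 1, 0]])     # last position is padding

avg_logp = compute_logprobs(logits, labels, mask)
print(avg_logp.shape)  # torch.Size([1]): one averaged log-prob per sequence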


def compute_batch_loss(batch, policy_model, reference_model, beta):
    """Compute the DPO loss on an input batch"""
    policy_chosen_logps = compute_logprobs(
        logits=policy_model(batch["chosen"]),
        labels=batch["chosen"],
        mask=batch["chosen_mask"]
    )
    policy_rejected_logps = compute_logprobs(
        logits=policy_model(batch["rejected"]),
        labels=batch["rejected"],
        mask=batch["rejected_mask"]
    )
    # the frozen reference model needs no gradients
    with torch.no_grad():
        reference_chosen_logps = compute_logprobs(
            logits=reference_model(batch["chosen"]),
            labels=batch["chosen"], mask=batch["chosen_mask"])
        reference_rejected_logps = compute_logprobs(
            logits=reference_model(batch["rejected"]),
            labels=batch["rejected"], mask=batch["rejected_mask"])
    # returns (loss, chosen_rewards, rejected_rewards), all batch-averaged
    return DPOLoss(beta)(policy_chosen_logps, policy_rejected_logps,
                         reference_chosen_logps, reference_rejected_logps)
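
For reference, the DPOLoss class below implements the standard DPO objective from Rafailov et al. (2023), applied to the per-sequence average log-probabilities computed above:

loss = −log σ(β · [(log π_θ(y_w|x) − log π_θ(y_l|x)) − (log π_ref(y_w|x) − log π_ref(y_l|x))])

where y_w is the chosen response, y_l the rejected one, π_θ the policy model, and π_ref the frozen reference model.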


class DPOLoss(nn.Module):
"""
DPO Loss
"""

def __init__(self, beta: float = 0.1) -> None:
super().__init__()
self.beta = beta

def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
):
"""
        policy_chosen_logps: per-sequence average log-probabilities from the policy model. Shape: (batch_size,)
policy_rejected_logps: Shape: (batch_size,)
reference_chosen_logps: Shape: (batch_size,)
reference_rejected_logps: Shape: (batch_size,)
"""
policy_logps = policy_chosen_logps - policy_rejected_logps
reference_logps = reference_chosen_logps - reference_rejected_logps
logits = policy_logps - reference_logps

loss = -F.logsigmoid(self.beta * logits)

        # the two rewards below are only used to track training progress
        chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach()

        # average over the batch
        return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()
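
A minimal usage sketch of DPOLoss (the log-prob tensors below are dummy stand-ins for real model outputs, with shape (batch_size,) as in the docstring):

import torch

loss_fn = DPOLoss(beta=0.1)

policy_chosen = torch.tensor([-1.0, -0.8])       # dummy values
policy_rejected = torch.tensor([-1.5, -1.6])
reference_chosen = torch.tensor([-1.1, -0.9])
reference_rejected = torch.tensor([-1.4, -1.5])

loss, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen, policy_rejected, reference_chosen, reference_rejected
)
print(loss.item(), chosen_rewards.item(), rejected_rewards.item())
# the loss falls as the policy's chosen-vs-rejected margin grows relative to the reference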
