diff --git a/llm_tricks/DPO_example/dpo.ipynb b/llm_tricks/DPO_example/dpo.ipynb
index ee49521..6fc8cf5 100644
--- a/llm_tricks/DPO_example/dpo.ipynb
+++ b/llm_tricks/DPO_example/dpo.ipynb
@@ -160,7 +160,7 @@
    "execution_count": null,
    "outputs": [],
    "source": [
-    "def compute_logprobs(logits, labels):\n",
+    "def compute_logprobs(logits, labels, mask=None):\n",
     "    \"\"\"\n",
     "    logits: shape (batch_size, sequence_len, vocab_size)\n",
     "    labels: shape (batch_size, sequence_len)\n",
@@ -178,16 +178,73 @@
     "        input=logps,\n",
     "        dim=1,\n",
     "        index=labels.unsqueeze(1)\n",
-    "    ).squeeze(1)"
+    "    ).squeeze(1)\n",
+    "    \n",
+    "    if mask is not None:\n",
+    "        mask = mask[:, 1:].clone()\n",
+    "        # Mask out the padding positions\n",
+    "        select_logprobs = select_logprobs * mask\n",
+    "        # Average over the non-masked tokens\n",
+    "        average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)\n",
+    "        return average_logprobs\n",
+    "    else:\n",
+    "        return select_logprobs.mean(-1)"
    ],
    "metadata": {
     "collapsed": false
    },
    "id": "ca63797467a68c1e"
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "clone example"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "5636ff98ad0814b3"
+  },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 13,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([2, 3, 4])\n"
+     ]
+    }
+   ],
+   "source": [
+    "mask = torch.tensor([1,2,3])\n",
+    "mask1 = mask\n",
+    "mask1 += 1\n",
+    "print(mask)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2024-08-12T16:24:01.170731600Z",
+     "start_time": "2024-08-12T16:24:01.158727300Z"
+    }
+   },
+   "id": "72b9adb35c2cf553"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "tensor shape example"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "dbac1fe29cc58257"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
@@ -230,11 +287,44 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2024-08-12T10:05:21.666926700Z",
-     "start_time": "2024-08-12T10:05:21.637344Z"
+     "end_time": "2024-08-12T16:01:46.797611600Z",
+     "start_time": "2024-08-12T16:01:44.279995600Z"
     }
    },
    "id": "4b1f1f33b9e7f613"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "Compute the DPO loss for a batch"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "4043bc53c7bcd4d0"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "def compute_batch_loss(batch, policy_model, reference_model, beta):\n",
+    "    \"\"\"Compute the DPO loss on an input batch\"\"\"\n",
+    "    policy_chosen_logps = compute_logprobs(\n",
+    "        logits=policy_model(batch[\"chosen\"]),\n",
+    "        labels=batch[\"chosen\"],\n",
+    "        mask=batch[\"chosen_mask\"]\n",
+    "    )\n",
+    "    policy_rejected_logps = compute_logprobs(\n",
+    "        logits=policy_model(batch[\"rejected\"]),\n",
+    "        labels=batch[\"rejected\"],\n",
+    "        mask=batch[\"rejected_mask\"]\n",
+    "    )"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "3211a04a645fb478"
+  }
  ],
  "metadata": {
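The `mask1 += 1` cell above is the reason `compute_logprobs` calls `.clone()` on its sliced tensors: plain assignment only aliases the same storage, so an in-place update would silently modify the caller's labels or mask as well. The `torch.gather` call then picks, at every position, the log-probability the model assigned to the token that actually came next. Below is a small self-contained sketch of that shift-and-gather indexing; the toy shapes and the seed are illustrative assumptions, not values from the notebook, and the index is taken along the last (vocabulary) dimension, as in dpo.py below.

import torch
import torch.nn.functional as F

# Toy shapes (illustrative only): batch_size=2, sequence_len=4, vocab_size=6.
torch.manual_seed(0)
logits = torch.randn(2, 4, 6)
labels = torch.randint(0, 6, (2, 4))

# Same shift as in compute_logprobs: position t predicts token t+1.
shifted_labels = labels[:, 1:]
shifted_logits = logits[:, :-1, :]

logps = F.log_softmax(shifted_logits, dim=-1)
# Select, at every position, the log-probability of the actual next token.
selected = torch.gather(
    input=logps,
    dim=-1,
    index=shifted_labels.unsqueeze(-1)
).squeeze(-1)

# The same selection written as explicit loops, to make the indexing concrete.
manual = torch.stack([
    torch.stack([logps[b, t, shifted_labels[b, t]] for t in range(3)])
    for b in range(2)
])
assert torch.allclose(selected, manual)
print(selected.shape)  # torch.Size([2, 3]): one log-prob per predicted token
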
diff --git a/llm_tricks/DPO_example/dpo.py b/llm_tricks/DPO_example/dpo.py
index e69de29..22e9ee6 100644
--- a/llm_tricks/DPO_example/dpo.py
+++ b/llm_tricks/DPO_example/dpo.py
@@ -0,0 +1,85 @@
+import torch.nn.functional as F
+import torch.nn as nn
+import torch
+
+
+def compute_logprobs(logits, labels, mask=None):
+    """
+    logits: shape (batch_size, sequence_len, vocab_size)
+    labels: shape (batch_size, sequence_len)
+    """
+
+    # The inputs need to be shifted by one position first:
+    # drop the first token of the labels
+    labels = labels[:, 1:].clone()
+    # and drop the last position of the model output
+    logits = logits[:, :-1, :]
+
+    logps = F.log_softmax(logits, dim=-1)
+
+    select_logprobs = torch.gather(
+        input=logps,
+        dim=-1,
+        index=labels.unsqueeze(-1)
+    ).squeeze(-1)
+
+    if mask is not None:
+        mask = mask[:, 1:].clone()
+        # Mask out the padding positions
+        select_logprobs = select_logprobs * mask
+        # Average over the non-masked tokens
+        average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)
+        return average_logprobs
+    else:
+        return select_logprobs.mean(-1)
+
+
+def compute_batch_loss(batch, policy_model, reference_model, beta):
+    """Compute the DPO loss on an input batch"""
+    policy_chosen_logps = compute_logprobs(
+        logits=policy_model(batch["chosen"]),
+        labels=batch["chosen"],
+        mask=batch["chosen_mask"]
+    )
+    policy_rejected_logps = compute_logprobs(
+        logits=policy_model(batch["rejected"]),
+        labels=batch["rejected"],
+        mask=batch["rejected_mask"]
+    )
+
+
+class DPOLoss(nn.Module):
+    """
+    DPO Loss
+    """
+
+    def __init__(self, beta: float = 0.1) -> None:
+        super().__init__()
+        self.beta = beta
+
+    def forward(
+            self,
+            policy_chosen_logps: torch.Tensor,
+            policy_rejected_logps: torch.Tensor,
+            reference_chosen_logps: torch.Tensor,
+            reference_rejected_logps: torch.Tensor,
+    ):
+        """
+        policy_chosen_logps: log probabilities of the chosen responses under the policy model. Shape: (batch_size,)
+        policy_rejected_logps: Shape: (batch_size,)
+        reference_chosen_logps: Shape: (batch_size,)
+        reference_rejected_logps: Shape: (batch_size,)
+
+        """
+        policy_logps = policy_chosen_logps - policy_rejected_logps
+        reference_logps = reference_chosen_logps - reference_rejected_logps
+        logits = policy_logps - reference_logps
+
+        loss = -F.logsigmoid(self.beta * logits)
+
+        # The two rewards below are only used to track training progress
+        chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach()
+
+        # Average over the batch
+        return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()
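compute_batch_loss in this patch stops after the policy log-probabilities. Below is a minimal sketch of how the remaining steps could look, assuming dpo.py is importable, the batch keys used above, and that the frozen reference model is evaluated under torch.no_grad(); the name compute_batch_loss_sketch and the default beta=0.1 are illustrative, not part of the patch. DPOLoss itself implements loss = -log sigmoid(beta * ((log pi_chosen - log pi_rejected) - (log ref_chosen - log ref_rejected))).

import torch

from dpo import DPOLoss, compute_logprobs


def compute_batch_loss_sketch(batch, policy_model, reference_model, beta=0.1):
    """Assumed completion: policy and reference log-probs are fed into DPOLoss."""
    policy_chosen_logps = compute_logprobs(
        logits=policy_model(batch["chosen"]),
        labels=batch["chosen"],
        mask=batch["chosen_mask"]
    )
    policy_rejected_logps = compute_logprobs(
        logits=policy_model(batch["rejected"]),
        labels=batch["rejected"],
        mask=batch["rejected_mask"]
    )
    # The reference model stays frozen, so its forward pass needs no gradients.
    with torch.no_grad():
        reference_chosen_logps = compute_logprobs(
            logits=reference_model(batch["chosen"]),
            labels=batch["chosen"],
            mask=batch["chosen_mask"]
        )
        reference_rejected_logps = compute_logprobs(
            logits=reference_model(batch["rejected"]),
            labels=batch["rejected"],
            mask=batch["rejected_mask"]
        )

    loss_fn = DPOLoss(beta=beta)
    loss, chosen_rewards, rejected_rewards = loss_fn(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    )
    return loss, chosen_rewards, rejected_rewards

As in the notebook cell above, policy_model(batch[...]) and reference_model(batch[...]) are assumed to return raw logits of shape (batch_size, sequence_len, vocab_size).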