update dpo
mst272 committed Aug 12, 2024
1 parent f418585 commit 3769aad
Showing 2 changed files with 261 additions and 0 deletions.
261 changes: 261 additions & 0 deletions llm_tricks/DPO_example/dpo.ipynb
@@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 从0实现一个DPO"
],
"metadata": {
"collapsed": false
},
"id": "fb69f628966bb719"
},
{
"cell_type": "markdown",
"source": [
"## 1.准备数据"
],
"metadata": {
"collapsed": false
},
"id": "8bbb0f857b2a051"
},
{
"cell_type": "markdown",
"source": [
"DPO所需要的数据主要三个字段:\n",
"- instruction:指令问题\n",
"- chosen:选择的偏好回答\n",
"- rejected: 不好的回答"
],
"metadata": {
"collapsed": false
},
"id": "76726cb0c67792af"
},
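{
"cell_type": "markdown",
"source": [
"A minimal sketch of what one record could look like (the concrete values below are made up purely for illustration; real data would come from a preference dataset):"
],
"metadata": {
"collapsed": false
},
"id": "3f2a8c1d9e4b5a70"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# A toy DPO preference record (values are illustrative only)\n",
"dpo_example = {\n",
"    \"instruction\": \"What is the capital of France?\",\n",
"    \"chosen\": \"The capital of France is Paris.\",\n",
"    \"rejected\": \"France is a country in Europe.\"\n",
"}\n",
"print(dpo_example[\"instruction\"])"
],
"metadata": {
"collapsed": false
},
"id": "5b7d9f1a3c5e7092"
},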
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "691f55cb1faa5a5f"
},
{
"cell_type": "markdown",
"source": [
"# 2、数据集处理"
],
"metadata": {
"collapsed": false
},
"id": "f6eaa11e7c529604"
},
{
"cell_type": "markdown",
"source": [
"了解DPO训练流程的可以知道,一般的DPO实现是需要将prompt(即instruction)分别和chsoen、rejected拼接在一起的。"
],
"metadata": {
"collapsed": false
},
"id": "c4726f451d1b5a4"
},
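{
"cell_type": "markdown",
"source": [
"A minimal sketch of that concatenation using the toy record above (real code would tokenize these strings and usually mask the prompt tokens out of the loss; the newline separator is just an illustrative choice):"
],
"metadata": {
"collapsed": false
},
"id": "7c1e3a5b9d2f4816"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Build the two sequences fed to the model: prompt + chosen and prompt + rejected\n",
"prompt = dpo_example[\"instruction\"]\n",
"chosen_text = prompt + \"\\n\" + dpo_example[\"chosen\"]\n",
"rejected_text = prompt + \"\\n\" + dpo_example[\"rejected\"]\n",
"print(chosen_text)\n",
"print(rejected_text)"
],
"metadata": {
"collapsed": false
},
"id": "8d2f4a6c0e1b3957"
},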
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "4ce6798d1afd79ed"
},
{
"cell_type": "markdown",
"source": [
"# LOSS "
],
"metadata": {
"collapsed": false
},
"id": "9777171077888583"
},
{
"cell_type": "markdown",
"source": [
"DPO主要是两个模型,policy model(即我们主要要调优的模型) 和 reference model(用来约束的模型)"
],
"metadata": {
"collapsed": false
},
"id": "c97e1b5434c5aa01"
},
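{
"cell_type": "markdown",
"source": [
"For reference, the per-example DPO loss that the code below implements is\n",
"\n",
"$\\mathcal{L}_{\\mathrm{DPO}} = -\\log \\sigma\\Big(\\beta\\big[(\\log \\pi_\\theta(y_w \\mid x) - \\log \\pi_\\theta(y_l \\mid x)) - (\\log \\pi_{\\mathrm{ref}}(y_w \\mid x) - \\log \\pi_{\\mathrm{ref}}(y_l \\mid x))\\big]\\Big)$\n",
"\n",
"where $y_w$ is the chosen response, $y_l$ the rejected one, and $\\beta$ controls how strongly the policy is kept close to the reference model."
],
"metadata": {
"collapsed": false
},
"id": "9e3a5c7d1f2b4068"
},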
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
"class DPOLoss(nn.Module):\n",
" \"\"\"\n",
" DPO Loss\n",
" \"\"\"\n",
"\n",
" def __init__(self, beta: float=0.1) -> None:\n",
" super().__init__()\n",
" self.beta = beta\n",
"\n",
" def forward(\n",
" self,\n",
" policy_chosen_logps: torch.Tensor,\n",
" policy_rejected_logps: torch.Tensor,\n",
" reference_chosen_logps: torch.Tensor,\n",
" reference_rejected_logps: torch.Tensor,\n",
" ) :\n",
" \"\"\"\n",
" policy_chosen_logps: 模型输出的对数概率。Shape: (batch_size,)\n",
" policy_rejected_logps: Shape: (batch_size,)\n",
" reference_chosen_logps: Shape: (batch_size,)\n",
" reference_rejected_logps: Shape: (batch_size,)\n",
" \n",
" \"\"\"\n",
" policy_logps = policy_chosen_logps - policy_rejected_logps\n",
" reference_logps = reference_chosen_logps - reference_rejected_logps\n",
" logits = policy_logps - reference_logps\n",
" \n",
" loss = -F.logsigmoid(self.beta * logits)\n",
" \n",
" # 下面两个用于追踪训练的进度\n",
" chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach()\n",
" rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach()\n",
" \n",
" # 对每个batch进行平均\n",
" return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()\n",
"\n",
" "
],
"metadata": {
"collapsed": false
},
"id": "c245fc84671838dd"
},
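{
"cell_type": "markdown",
"source": [
"A quick sanity check of `DPOLoss` with random log-prob tensors (the numbers are arbitrary; this only illustrates the expected shapes and the call signature):"
],
"metadata": {
"collapsed": false
},
"id": "1a4c6e8b0d2f4a61"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Dummy per-example log-probs for a batch of 4 (values are arbitrary)\n",
"torch.manual_seed(0)\n",
"dpo_loss_fn = DPOLoss(beta=0.1)\n",
"policy_chosen = torch.randn(4)\n",
"policy_rejected = torch.randn(4)\n",
"reference_chosen = torch.randn(4)\n",
"reference_rejected = torch.randn(4)\n",
"loss, chosen_rewards, rejected_rewards = dpo_loss_fn(\n",
"    policy_chosen, policy_rejected, reference_chosen, reference_rejected\n",
")\n",
"print(loss, chosen_rewards, rejected_rewards)"
],
"metadata": {
"collapsed": false
},
"id": "2b5d7f9a1c3e5072"
},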
{
"cell_type": "markdown",
"source": [
"计算log probs ,也就是 $\\pi_\\theta (y_w \\mid x)$,"
],
"metadata": {
"collapsed": false
},
"id": "2b1be59f347b82bd"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def compute_logprobs(logits, labels):\n",
" \"\"\"\n",
" logits: shape (batch_size, sequence_len, vocab_size)\n",
" labels: shape (batch_size, sequence_len)\n",
" \"\"\"\n",
" \n",
" # 需要先进行位移操作\n",
" # 去掉标签的第一个\n",
" labels = labels[:, 1:].clone()\n",
" # 去掉模型输出的最后一个\n",
" logits = logits[:,:-1,:]\n",
" \n",
" logps = F.log_softmax(logits, dim=-1)\n",
" \n",
" select_logprobs = torch.gather(\n",
" input=logps,\n",
" dim=1,\n",
" index=labels.unsqueeze(1)\n",
" ).squeeze(1)"
],
"metadata": {
"collapsed": false
},
"id": "ca63797467a68c1e"
},
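{
"cell_type": "markdown",
"source": [
"A small shape check for `compute_logprobs` with random inputs (batch of 2, sequence length 5, vocabulary of 10; all values are random, so only the output shape is meaningful):"
],
"metadata": {
"collapsed": false
},
"id": "3c6e8a0b2d4f6183"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Random logits and labels, just to verify we get one log-prob per example\n",
"torch.manual_seed(0)\n",
"dummy_logits = torch.randn(2, 5, 10)          # (batch_size, sequence_len, vocab_size)\n",
"dummy_labels = torch.randint(0, 10, (2, 5))   # (batch_size, sequence_len)\n",
"logps = compute_logprobs(dummy_logits, dummy_labels)\n",
"print(logps, logps.shape)  # expected: torch.Size([2])"
],
"metadata": {
"collapsed": false
},
"id": "4d7f9b1c3e5a7294"
},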
{
"cell_type": "code",
"execution_count": 19,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-0.4170],\n",
" [-2.4200]]) torch.Size([2, 1])\n",
"tensor([-0.4170, -2.4200]) torch.Size([2])\n",
"tensor(1.4185) tensor(1.4185)\n"
]
}
],
"source": [
"import torch.nn.functional as F\n",
"import torch\n",
"logits = torch.tensor(\n",
" [[2.0, 1.0, 0.1],\n",
" [0.5, 2.5, 0.3]]) # Shape: (2, 3)\n",
"targets = torch.tensor([0, 2]) # Shape: (2,)\n",
"# print(targets.unsqueeze(-1).shape)\n",
"\n",
"# Manual loss using torch.gather\n",
"log_softmax_logits = F.log_softmax(logits, dim=1) # Shape: (2, 3)\n",
"# print(log_softmax_logits)\n",
"selected_log_probs = torch.gather(\n",
" input=log_softmax_logits,\n",
" dim=1,\n",
" index=targets.unsqueeze(1), # Shape 2, 1\n",
") # Shape: (2,)\n",
"print(selected_log_probs,selected_log_probs.shape)\n",
"print(selected_log_probs.squeeze(1),selected_log_probs.squeeze(1).shape)\n",
"manual_loss = -selected_log_probs.mean() # Averaging over the batch\n",
"\n",
"\n",
"# PyTorch loss\n",
"cross_entropy_loss = F.cross_entropy(logits, targets)\n",
"\n",
"print(manual_loss, cross_entropy_loss)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-08-12T10:05:21.666926700Z",
"start_time": "2024-08-12T10:05:21.637344Z"
}
},
"id": "4b1f1f33b9e7f613"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Empty file added llm_tricks/DPO_example/dpo.py
Empty file.
