
Commit

update dpo example
mst272 committed Aug 21, 2024
1 parent ddf540a commit edf2670
Showing 3 changed files with 139 additions and 79 deletions.
63 changes: 1 addition & 62 deletions llm_tricks/DPO_example/dataset.py
@@ -32,65 +32,4 @@ def __getitem__(self, item):
return input

def __len__(self):
return len(self.data_list)


def data_collate(
batch,
pad_token_id=50256,
allowed_max_length=None,
mask_prompt_tokens=True,
device="cpu"
):
# Initialize lists to hold batch data
batch_data = {
"prompt": [],
"chosen": [],
"rejected": [],
"rejected_mask": [],
"chosen_mask": []

}

# Determine the longest sequence to set a common padding length
max_length_common = 0
if batch:
for key in ["chosen", "rejected"]:
current_max = max(len(item[key]) + 1 for item in batch)
max_length_common = max(max_length_common, current_max)

# Process each item in the batch
for item in batch:
prompt = torch.tensor(item["prompt"])
batch_data["prompt"].append(prompt)

for key in ["chosen", "rejected"]:
# Adjust padding according to the common maximum length
sequence = item[key]
padded = sequence + [pad_token_id] * (max_length_common - len(sequence))
mask = torch.ones(len(padded)).bool()

# Set mask for all padding tokens to False
mask[len(sequence):] = False

# Set mask for all input tokens to False
# +2 sets the 2 newline ("\n") tokens before "### Response" to False
if mask_prompt_tokens:
mask[:prompt.shape[0] + 2] = False

batch_data[key].append(torch.tensor(padded))
batch_data[f"{key}_mask"].append(mask)

# Final processing
for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
# Stack all sequences into a tensor for the given key
tensor_stack = torch.stack(batch_data[key])

# Optionally truncate to maximum sequence length
if allowed_max_length is not None:
tensor_stack = tensor_stack[:, :allowed_max_length]

# Move to the specified device
batch_data[key] = tensor_stack.to(device)

return batch_data
return len(self.data_list)
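The collation logic deleted here reappears in dpo_example.py below as data_collector. As a rough illustration of what the padding and prompt masking produce, here is a minimal sketch on a toy batch; the token ids and pad id are made up, and it only assumes the data_collate function above with torch available:

# Hypothetical toy batch; ids are invented and "prompt" is a prefix of both responses
toy_batch = [
    {"prompt": [11, 12], "chosen": [11, 12, 13, 14, 15], "rejected": [11, 12, 16]},
    {"prompt": [21], "chosen": [21, 22, 23, 24, 25], "rejected": [21, 26]},
]
collated = data_collate(toy_batch, pad_token_id=0, mask_prompt_tokens=True)
# Every sequence is padded to the longest chosen/rejected length + 1, e.g.
#   collated["chosen"][1]      -> tensor([21, 22, 23, 24, 25, 0])
#   collated["chosen_mask"][1] -> [False, False, False, True, True, False]
# i.e. the prompt (+2 tokens) and the padding are excluded from the loss.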
83 changes: 66 additions & 17 deletions llm_tricks/DPO_example/dpo_example.py
@@ -1,30 +1,79 @@
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from dataset import RlhfDataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from loss import DPOLoss

# 1. Load the model and tokenizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = r'D:\GithubProject\LLM\download_llm\qwen\Qwen1___5-0___5B'
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16)
ref_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16)
model.to(device)
ref_model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

# 2. Prepare the data

# Load the data
data_file = './unsloth_dpo.jsonl'
# See the RlhfDataset implementation for the dataset logic
dataset = RlhfDataset(data_file, tokenizer)
# Split into training and validation sets
train_size = int(len(dataset) * 0.85) # 85% for training
val_size = len(dataset) - train_size # Remaining for validation
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Collate function that pads each batch and builds the loss masks
IGNORE_INDEX = False


def data_collector(batch, pad_token, device, max_length=None, if_mask_prompt=True):
batch_data = {
"prompt": [],
"chosen": [],
"rejected": [],
"rejected_mask": [],
"chosen_mask": []
}

    # Work out the common max length for padding
max_length_common = 0
for key in ["chosen", "rejected"]:
current_max = max(len(item[key]) for item in batch)
max_length_common = max(max_length_common, current_max)

    # Convert to torch tensors, pad, and decide whether to mask the prompt
for item in batch:
prompt = torch.tensor(item['prompt'])
batch_data['prompt'].append(prompt)

for key in ["chosen", "rejected"]:
out = item[key]
out_padding = out + [pad_token] * (max_length_common - len(out))
mask = torch.ones(len(out_padding)).bool()

            # Set the mask over the padding positions to IGNORE_INDEX
mask[len(out):] = IGNORE_INDEX

if if_mask_prompt:
mask[:prompt.shape[0] + 2] = IGNORE_INDEX
batch_data[key].append(torch.tensor(out_padding))
batch_data[f"{key}_mask"].append(mask)

    # Truncate to the maximum allowed length
for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
tensor_stack = torch.stack(batch_data[key])
if max_length is not None:
tensor_stack = tensor_stack[:, :max_length]
        # Move the tensor to the target device
batch_data[key] = tensor_stack.to(device)
return batch_data


# 3. Compute the DPO (or another preference) loss
loss_fn = DPOLoss()


if __name__ == "__main__":
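The rest of dpo_example.py is truncated in this view. As a hedged sketch only (not the file's actual code), the pieces above could be wired together roughly as follows, assuming a hypothetical batch size of 2, a max length of 1024, and the compute_logprobs helper from loss.py below:

# Illustrative sketch, not part of this commit
from functools import partial
from loss import compute_logprobs

collate = partial(data_collector, pad_token=tokenizer.pad_token_id, device=device,
                  max_length=1024, if_mask_prompt=True)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate)

for batch in train_loader:
    # Log probabilities of the chosen/rejected responses under the policy model
    policy_chosen = compute_logprobs(model(batch["chosen"]).logits,
                                     batch["chosen"], batch["chosen_mask"])
    policy_rejected = compute_logprobs(model(batch["rejected"]).logits,
                                       batch["rejected"], batch["rejected_mask"])
    # The reference model is frozen, so no gradients are needed
    with torch.no_grad():
        ref_chosen = compute_logprobs(ref_model(batch["chosen"]).logits,
                                      batch["chosen"], batch["chosen_mask"])
        ref_rejected = compute_logprobs(ref_model(batch["rejected"]).logits,
                                        batch["rejected"], batch["rejected_mask"])
    loss, chosen_reward, rejected_reward = loss_fn(policy_chosen, policy_rejected,
                                                   ref_chosen, ref_rejected)
    break  # a real loop would call loss.backward() and an optimizer step here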
72 changes: 72 additions & 0 deletions llm_tricks/DPO_example/loss.py
@@ -0,0 +1,72 @@
import torch.nn.functional as F
import torch.nn as nn
import torch


# Formula for computing the DPO loss
class DPOLoss(nn.Module):
"""
DPO Loss
"""

def __init__(self, beta: float = 0.1) -> None:
super().__init__()
self.beta = beta

def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
):
"""
        policy_chosen_logps: log probabilities output by the policy model. Shape: (batch_size,)
policy_rejected_logps: Shape: (batch_size,)
reference_chosen_logps: Shape: (batch_size,)
reference_rejected_logps: Shape: (batch_size,)
"""
policy_logps = policy_chosen_logps - policy_rejected_logps
reference_logps = reference_chosen_logps - reference_rejected_logps
logits = policy_logps - reference_logps

loss = -F.logsigmoid(self.beta * logits)

        # The two values below are only used to track training progress
chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach()

        # Average over the batch (expectation)
return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()


# Compute the log probabilities for each model
def compute_logprobs(logits, labels, mask=None):
"""
    logits: shape (batch_size, sequence_len, vocab_size), i.e. the model output obtained by feeding in the labels
labels: shape (batch_size, sequence_len)
"""

    # Shift first so the labels line up with the next-token predictions
    # Drop the first label token
labels = labels[:, 1:].clone()
    # Drop the last position of the model output
logits = logits[:, :-1, :]

logps = F.log_softmax(logits, dim=-1)

    # Pick out the log probability assigned to each label token (gather along the vocab dimension)
    select_logprobs = torch.gather(
        input=logps,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

if mask is not None:
mask = mask[:, 1:].clone()
        # Mask out the padding positions
select_logprobs = select_logprobs * mask
        # Average over each sequence
average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)
return average_logprobs
else:
return select_logprobs.mean(-1)
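
As a quick standalone sanity check of the two pieces above (random tensors and made-up shapes, assuming only the imports at the top of loss.py):

torch.manual_seed(0)
logits = torch.randn(2, 6, 10)           # (batch, seq_len, vocab)
labels = torch.randint(0, 10, (2, 6))    # (batch, seq_len)
mask = torch.ones(2, 6, dtype=torch.bool)
mask[:, -2:] = False                     # pretend the last two positions are padding

logps = compute_logprobs(logits, labels, mask)   # shape (2,)
loss_fn = DPOLoss(beta=0.1)
loss, chosen_reward, rejected_reward = loss_fn(logps, logps - 0.5,
                                               logps.detach(), logps.detach() - 0.5)
print(loss.item(), chosen_reward.item(), rejected_reward.item())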
