
Commit

update dpo
mst272 committed Aug 13, 2024
1 parent 84d2024 commit 08e39f8
Showing 3 changed files with 84 additions and 1 deletion.
32 changes: 31 additions & 1 deletion llm_tricks/DPO_example/dpo.ipynb
@@ -310,6 +310,8 @@
"source": [
"def compute_batch_loss(batch, policy_model, reference_model, beta):\n",
" \"\"\"Compute the DPO loss on an input batch\"\"\"\n",
" loss_fn = DPOLoss(beta)\n",
" \n",
" policy_chosen_logps = compute_logprobs(\n",
" logits=policy_model(batch[\"chosen\"]),\n",
" labels=batch[\"chosen\"],\n",
@@ -319,12 +321,40 @@
" logits=policy_model(batch[\"rejected\"]),\n",
" labels=batch[\"rejected\"],\n",
" mask=batch[\"rejected_mask\"]\n",
" )"
" )\n",
" reference_chosen_logps = compute_logprobs(\n",
" logits=reference_model(batch['chosen']),\n",
" labels=batch['chosen'],\n",
" mask=batch[\"chosen_mask\"]\n",
" )\n",
" reference_rejected_logps = compute_logprobs(\n",
" logits=reference_model(batch['rejected']),\n",
" labels=batch['rejected'],\n",
" mask=batch[\"rejected_mask\"]\n",
" )\n",
" loss, chosen_rewards, rejected_rewards = loss_fn(\n",
" policy_chosen_logps=policy_chosen_logps,\n",
" policy_rejected_logps=policy_rejected_logps,\n",
" reference_chosen_logps=reference_chosen_logps,\n",
" reference_rejected_logps=reference_rejected_logps,\n",
" beta=beta\n",
" )\n",
" return loss, chosen_rewards, rejected_rewards"
],
"metadata": {
"collapsed": false
},
"id": "3211a04a645fb478"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "89e5c1930fb1041"
}
],
"metadata": {
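The new cell relies on the `compute_logprobs` helper defined earlier in the notebook, whose body is collapsed in this diff. As a rough sketch of what such a helper typically does (take the log-softmax of the logits, gather the score of each target token, mask out padding, and average per sequence), it might look like the following; the exact implementation in the repository may differ.

```python
import torch
import torch.nn.functional as F


def compute_logprobs(logits, labels, mask=None):
    """Sketch: average per-token log-probability of `labels` under `logits`.

    Assumes logits has shape (batch, seq_len, vocab) and labels has shape
    (batch, seq_len); the real helper in dpo.py may shift or normalize
    differently.
    """
    # Next-token prediction: logits at position t score the token at t + 1.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]

    log_probs = F.log_softmax(logits, dim=-1)
    # Log-probability assigned to each target token.
    selected = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

    if mask is not None:
        mask = mask[:, 1:].float()  # align the mask with the shifted labels
        selected = selected * mask
        return selected.sum(-1) / mask.sum(-1)  # mean over unmasked tokens
    return selected.mean(-1)
```

Under these assumptions, each call returns one scalar log-probability per sequence, which is what the four `compute_logprobs` calls in `compute_batch_loss` feed into the loss.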
51 changes: 51 additions & 0 deletions llm_tricks/DPO_example/dpo.py
@@ -36,6 +36,8 @@ def compute_logprobs(logits, labels, mask=None):

def compute_batch_loss(batch, policy_model, reference_model, beta):
"""Compute the DPO loss on an input batch"""
loss_fn = DPOLoss(beta)

policy_chosen_logps = compute_logprobs(
logits=policy_model(batch["chosen"]),
labels=batch["chosen"],
@@ -46,6 +48,55 @@ def compute_batch_loss(batch, policy_model, reference_model, beta):
labels=batch["rejected"],
mask=batch["rejected_mask"]
)
reference_chosen_logps = compute_logprobs(
logits=reference_model(batch['chosen']),
labels=batch['chosen'],
mask=batch["chosen_mask"]
)
reference_rejected_logps = compute_logprobs(
logits=reference_model(batch['rejected']),
labels=batch['rejected'],
mask=batch["rejected_mask"]
)
loss, chosen_rewards, rejected_rewards = loss_fn(
policy_chosen_logps=policy_chosen_logps,
policy_rejected_logps=policy_rejected_logps,
reference_chosen_logps=reference_chosen_logps,
reference_rejected_logps=reference_rejected_logps,
beta=beta
)
return loss, chosen_rewards, rejected_rewards


def compute_dataloader_loss(data_loader, policy_model, reference_model, beta, num_batch=None):
total_loss, total_chosen_rewards, total_rejected_rewards = 0., 0., 0.
if len(data_loader) == 0:
return float("nan")
elif num_batch is None:
num_batches = len(data_loader)
else:
        # Cap num_batch at the number of batches actually available in the data loader
num_batches = min(num_batch, len(data_loader))

for i, batch in enumerate(data_loader):
if i < num_batches:
loss, chosen_rewards, rejected_rewards = compute_batch_loss(
batch=batch,
policy_model=policy_model,
reference_model=reference_model,
beta=beta
)
total_loss += loss.item()
total_chosen_rewards += chosen_rewards.item()
total_rejected_rewards += rejected_rewards.item()

else:
break
total_loss /= num_batches
total_chosen_rewards /= num_batches
total_rejected_rewards /= num_batches
return total_loss, total_chosen_rewards, total_rejected_rewards


class DPOLoss(nn.Module):
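`compute_batch_loss` instantiates `DPOLoss(beta)` and then calls it with the four log-probability tensors plus `beta`, but the class body is collapsed in this hunk. For orientation, a minimal sketch of a DPO loss module with that calling convention is shown below; it implements the standard DPO objective, and the actual class in `dpo.py` may differ in details such as label smoothing or how rewards are reduced.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DPOLoss(nn.Module):
    """Sketch of a DPO loss module matching the call site in compute_batch_loss."""

    def __init__(self, beta=0.1):
        super().__init__()
        self.beta = beta

    def forward(
        self,
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        beta=None,
    ):
        # compute_batch_loss passes beta again at call time; fall back to the
        # value given at construction if it is omitted.
        beta = self.beta if beta is None else beta

        # Log-ratios of policy vs. reference for chosen and rejected responses.
        chosen_logratio = policy_chosen_logps - reference_chosen_logps
        rejected_logratio = policy_rejected_logps - reference_rejected_logps

        # DPO objective: -log sigmoid(beta * (chosen margin - rejected margin)).
        loss = -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()

        # Implied rewards, detached so they are tracked but not backpropagated.
        chosen_rewards = (beta * chosen_logratio).detach().mean()
        rejected_rewards = (beta * rejected_logratio).detach().mean()
        return loss, chosen_rewards, rejected_rewards
```

Under the same assumptions, `compute_dataloader_loss` can then be used for evaluation, for example over a hypothetical validation loader and model pair:

```python
with torch.no_grad():
    val_loss, val_chosen, val_rejected = compute_dataloader_loss(
        data_loader=val_loader,  # hypothetical validation DataLoader
        policy_model=policy_model,
        reference_model=reference_model,
        beta=0.1,
        num_batch=5,
    )
```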
2 changes: 2 additions & 0 deletions run_example.sh
@@ -5,6 +5,7 @@ MODEL_PATH=""

# Launch with deepspeed
deepspeed --include localhost:0,1 main_train.py\
--train_args_path "sft_args" \
--train_data_path "$DATA_PATH" \
--model_name_or_path "$MODEL_PATH" \
--max_len 1024 \
@@ -29,6 +30,7 @@

# task_type:[pretrain, sft, dpo_multi, dpo_single]
# train_mode:[qlora, lora, full]
# train_args_path: [sft_args, dpo_args]



