
Commit 35bd4fc

update dpo example
mst272 committed Aug 23, 2024
1 parent cadd1f7 commit 35bd4fc
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions llm_tricks/DPO_example/dpo_train.py
@@ -2,16 +2,17 @@
 import torch
 from dataset import RlhfDataset
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from loss import DPOLoss, compute_batch_loss
+from loss import compute_batch_loss
 from evaluate import evaluate_loss_dataloader
 import time
 from functools import partial
 
 # 1. Load the model and tokenizer
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-model_path = 'qwen\Qwen1___5-0___5B'
-model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16)
-ref_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16)
+device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
+model_path = '/IndexTeam/Index-1___9B-Chat'
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
+ref_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
 ref_model.eval()
 model.to(device)
 ref_model.to(device)
 tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
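Note: the next hunk's context line shows the collator signature data_collector(batch, pad_token_id, device, max_length=None, if_mask_prompt=...), and functools.partial is imported above. A minimal sketch of how such a collator is typically bound for a DataLoader follows; the batch size, dataset constructor arguments, and keyword values are illustrative assumptions, not code from this repository:

from functools import partial
from torch.utils.data import DataLoader

# Bind the static arguments once so the DataLoader only has to pass the batch.
collate_fn = partial(
    data_collector,
    pad_token_id=tokenizer.pad_token_id,
    device=device,
    max_length=1024,       # assumed cap; the signature defaults to None
    if_mask_prompt=True,   # assumed: exclude prompt tokens from the loss
)

train_loader = DataLoader(
    RlhfDataset("train.json", tokenizer),  # hypothetical constructor arguments
    batch_size=4,                          # assumed batch size
    shuffle=True,
    collate_fn=collate_fn,
)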
@@ -100,7 +101,6 @@ def data_collector(batch, pad_token_id, device, max_length=None, if_mask_prompt=
 # 3. Compute the DPO (or other) loss
 
 
-
 # 4. Define the training function
 def train_model(
         policy_model, reference_model, train_loader, val_loader,
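The loss.py module itself is outside this diff. For reference, the standard DPO objective that compute_batch_loss presumably builds on is -log sigmoid(beta * (policy log-prob margin - reference log-prob margin)); a self-contained sketch, not the repository's implementation:

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Each input: per-sequence summed log-probabilities, shape [batch].
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    loss = -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()
    # Implicit rewards, handy for the tracking dict used by train_model below.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return loss, chosen_rewards, rejected_rewards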
@@ -115,7 +115,6 @@ def train_model(
"val_rejected_rewards": [],
"tokens_seen": []
}

tokens_seen, global_step = 0, -1

# 训练
@@ -132,7 +131,6 @@
                reference_model=reference_model,
                beta=beta
            )
-
            loss.backward()
            optimizer.step()

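The hunk above shows only the backward pass and optimizer step; a full training step normally also clears gradients first. A generic sketch of the loop shape, where the nesting and the optimizer.zero_grad() placement are assumptions about the elided context, not part of this diff:

for epoch in range(num_epochs):
    policy_model.train()
    for batch in train_loader:
        optimizer.zero_grad()  # assumed to sit in the elided lines
        loss, chosen_rewards, rejected_rewards = compute_batch_loss(
            batch=batch,
            policy_model=policy_model,
            reference_model=reference_model,
            beta=beta,
        )
        loss.backward()
        optimizer.step()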
@@ -175,7 +173,7 @@ def main():
     start_time = time.time()
     optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
 
-    num_epochs = 2
+    num_epochs = 3
     tracking = train_model(
         policy_model=model,
         reference_model=ref_model,
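train_model returns the tracking dict initialized above. A small illustrative example of consuming it after training; "val_chosen_rewards" is assumed to mirror the "val_rejected_rewards" key shown in the diff:

# The reward margin (chosen - rejected) should grow if DPO is working.
margin = tracking["val_chosen_rewards"][-1] - tracking["val_rejected_rewards"][-1]
print(f"final val reward margin: {margin:.4f}")
print(f"total tokens seen: {tracking['tokens_seen'][-1]}")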