Commit

fix bugs and update dpo example

mst272 committed Aug 14, 2024
1 parent 08e39f8 commit 30a5d8c

Showing 8 changed files with 271 additions and 21 deletions.
92 changes: 92 additions & 0 deletions llm_tricks/DPO_example/dataset.py
@@ -0,0 +1,92 @@
import torch
from torch.utils.data import Dataset


class RlhfDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        self.encoded_texts = []
        for entry in data:
            prompt = entry["prompt"]
            rejected_response = entry["rejected"]
            chosen_response = entry["chosen"]

            prompt_tokens = tokenizer.encode(prompt)
            chosen_full_text = f"{prompt}\n\n### Response:\n{chosen_response}"
            rejected_full_text = f"{prompt}\n\n### Response:\n{rejected_response}"
            chosen_full_tokens = tokenizer.encode(chosen_full_text)
            rejected_full_tokens = tokenizer.encode(rejected_full_text)

            self.encoded_texts.append({
                "prompt": prompt_tokens,
                "chosen": chosen_full_tokens,
                "rejected": rejected_full_tokens,
            })

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)


def data_collate(
        batch,
        pad_token_id=50256,
        allowed_max_length=None,
        mask_prompt_tokens=True,
        device="cpu"
):
    # Initialize lists to hold batch data
    batch_data = {
        "prompt": [],
        "chosen": [],
        "rejected": [],
        "rejected_mask": [],
        "chosen_mask": []
    }

    # Determine the longest sequence to set a common padding length
    max_length_common = 0
    if batch:
        for key in ["chosen", "rejected"]:
            current_max = max(len(item[key]) + 1 for item in batch)
            max_length_common = max(max_length_common, current_max)

    # Process each item in the batch
    for item in batch:
        prompt = torch.tensor(item["prompt"])
        batch_data["prompt"].append(prompt)

        for key in ["chosen", "rejected"]:
            # Pad the sequence up to the common maximum length
            sequence = item[key]
            padded = sequence + [pad_token_id] * (max_length_common - len(sequence))
            mask = torch.ones(len(padded)).bool()

            # Set mask for all padding tokens to False
            mask[len(sequence):] = False

            # Set mask for all prompt tokens to False
            # +2 sets the 2 newline ("\n") tokens before "### Response" to False
            if mask_prompt_tokens:
                mask[:prompt.shape[0] + 2] = False

            batch_data[key].append(torch.tensor(padded))
            batch_data[f"{key}_mask"].append(mask)

    # Final processing
    for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
        # Stack all sequences into a tensor for the given key
        tensor_stack = torch.stack(batch_data[key])

        # Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            tensor_stack = tensor_stack[:, :allowed_max_length]

        # Move to the specified device
        batch_data[key] = tensor_stack.to(device)

    return batch_data
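
As a point of reference, here is a minimal usage sketch showing how this dataset and collate function could be wired into a PyTorch DataLoader. The tiktoken GPT-2 tokenizer (whose <|endoftext|> id 50256 matches the default pad_token_id), the toy preference pairs, and the hyperparameters are illustrative assumptions, not part of this commit.

# Hypothetical wiring of RlhfDataset + data_collate into a DataLoader (not part of this commit).
from functools import partial

import tiktoken
from torch.utils.data import DataLoader

from dataset import RlhfDataset, data_collate  # assumes llm_tricks/DPO_example is the working directory

tokenizer = tiktoken.get_encoding("gpt2")  # id 50256 is GPT-2's <|endoftext|>, matching pad_token_id

pairs = [
    {"prompt": "What is DPO?",
     "chosen": "DPO optimizes a policy directly from preference pairs.",
     "rejected": "DPO is a learning-rate scheduler."},
    {"prompt": "What does the reference model do?",
     "chosen": "It anchors the policy so it does not drift too far from the starting model.",
     "rejected": "It replaces the optimizer."},
]

dataset = RlhfDataset(pairs, tokenizer)

collate_fn = partial(
    data_collate,
    pad_token_id=50256,
    allowed_max_length=1024,
    mask_prompt_tokens=True,   # exclude prompt (and padding) tokens from the loss via the masks
    device="cpu",
)

loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn, drop_last=False)

batch = next(iter(loader))
# batch["chosen"] / batch["rejected"]: (batch, seq_len) token-id tensors padded to a common length
# batch["chosen_mask"] / batch["rejected_mask"]: boolean masks that are True only on response tokens
print(batch["chosen"].shape, batch["chosen_mask"].dtype)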
62 changes: 61 additions & 1 deletion llm_tricks/DPO_example/dpo.ipynb
@@ -37,12 +37,72 @@
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"source": [
"import json\n",
"\n",
"file_path = \"./unsloth_dpo.jsonl\"\n",
"\n",
"with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" data_list = file.readlines()\n",
"# data = json.loads(data_list)\n",
"data = [json.loads(data) for data in data_list]\n",
"data"
],
"metadata": {
"collapsed": false
},
"id": "691f55cb1faa5a5f"
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "{'prompt': 'What is one benefit of using Unsloth for LLM fine-tuning?',\n 'chosen': 'One benefit of using Unsloth for LLM fine-tuning is that it offers a 0% accuracy degradation compared to normal QLoRA, as no approximations are made in the optimized code.',\n 'rejected': 'Using Unsloth for LLM fine-tuning increases accuracy degradation compared to normal QLoRA.'}"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[1]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-08-14T09:17:46.958167500Z",
"start_time": "2024-08-14T09:17:46.942167Z"
}
},
"id": "9aa37c970611f41d"
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "[{'prompt': 'How can Unsloth accelerate LLM fine-tuning?',\n 'chosen': 'Unsloth accelerates LLM fine-tuning by overwriting some parts of the modeling code with optimized operations and rewriting all Pytorch modules into Triton kernels, resulting in a 2x speedup and a 40% reduction in memory usage.',\n 'rejected': 'Unsloth slows down LLM fine-tuning.'},\n {'prompt': 'What is one benefit of using Unsloth for LLM fine-tuning?',\n 'chosen': 'One benefit of using Unsloth for LLM fine-tuning is that it offers a 0% accuracy degradation compared to normal QLoRA, as no approximations are made in the optimized code.',\n 'rejected': 'Using Unsloth for LLM fine-tuning increases accuracy degradation compared to normal QLoRA.'},\n {'prompt': 'Who developed the Unsloth library for LLM fine-tuning?',\n 'chosen': 'The Unsloth library for LLM fine-tuning was developed by the Unsloth team, including Daniel and Michael, in collaboration with the open-source community.',\n 'rejected': 'The Unsloth library for LLM fine-tuning was developed by Google and Microsoft.'},\n {'prompt': 'What GPU types are compatible with Unsloth for LLM fine-tuning?',\n 'chosen': 'Unsloth supports most NVIDIA GPUs, ranging from GTX 1070 to H100s, making it compatible with a wide range of hardware configurations.',\n 'rejected': 'Unsloth only supports AMD GPUs for LLM fine-tuning.'},\n {'prompt': 'How does Unsloth contribute to reducing memory usage during LLM fine-tuning?',\n 'chosen': 'Unsloth contributes to reducing memory usage during LLM fine-tuning by rewriting all Pytorch modules into Triton kernels, resulting in a 40% decrease in memory usage.',\n 'rejected': 'Unsloth increases memory usage during LLM fine-tuning.'},\n {'prompt': 'What is the primary goal of Unsloth for LLM fine-tuning?',\n 'chosen': 'The primary goal of Unsloth for LLM fine-tuning is to accelerate the process, achieving a 2x speedup while maintaining 0% accuracy degradation compared to normal QLoRA.',\n 'rejected': 'The primary goal of Unsloth for LLM fine-tuning is to slow down the process and increase memory usage.'},\n {'prompt': 'How does Unsloth improve LLM fine-tuning performance?',\n 'chosen': 'Unsloth improves LLM fine-tuning performance by manually deriving backpropagation steps and rewriting Pytorch modules into Triton kernels, resulting in a 2x speed increase and a 40% reduction in memory usage.',\n 'rejected': 'Unsloth degrades LLM fine-tuning performance compared to traditional methods.'},\n {'prompt': 'What makes Unsloth different from other tools for LLM fine-tuning?',\n 'chosen': 'What makes Unsloth different from other tools for LLM fine-tuning is its ability to achieve a 2x speed increase and a 40% reduction in memory usage while maintaining 0% accuracy degradation, thanks to optimized operations and Triton kernel rewriting.',\n 'rejected': 'Unsloth is not different from other tools for LLM fine-tuning.'},\n {'prompt': 'Why is Unsloth considered a lightweight library for LLM fine-tuning?',\n 'chosen': 'Unsloth is considered a lightweight library for LLM fine-tuning because it offers significant performance improvements, including a 2x speed increase and a 40% reduction in memory usage, without compromising accuracy.',\n 'rejected': 'Unsloth is considered a heavyweight library for LLM fine-tuning.'},\n {'prompt': 'What are the supported architectures for LLM fine-tuning by Unsloth?',\n 'chosen': 'Unsloth supports the Llama and Mistral architectures for LLM fine-tuning, ensuring compatibility with a range of model 
structures.',\n 'rejected': 'Unsloth does not support any specific architectures for LLM fine-tuning.'},\n {'prompt': 'How does Unsloth contribute to making LLM fine-tuning 2x faster?',\n 'chosen': 'Unsloth accelerates LLM fine-tuning by overwriting certain parts of the modeling code with optimized operations and manually deriving backpropagation steps, resulting in a 2x speedup without sacrificing accuracy.',\n 'rejected': 'Unsloth slows down LLM fine-tuning by introducing unnecessary complexities and approximations, leading to decreased efficiency.'},\n {'prompt': 'What benefits does Unsloth offer for LLM fine-tuning?',\n 'chosen': 'Unsloth provides a 2x acceleration in LLM fine-tuning speed, reduces memory usage by 40%, and ensures 0% accuracy degradation compared to regular QLoRA fine-tuning methods.',\n 'rejected': 'Unsloth has no impact on LLM fine-tuning efficiency and does not address memory usage concerns, leading to potential accuracy loss.'},\n {'prompt': 'Who developed the Unsloth library for LLM fine-tuning acceleration?',\n 'chosen': 'The Unsloth library was developed by Daniel, Michael, and the open-source community to address the need for faster fine-tuning of large language models.',\n 'rejected': 'Unsloth was solely developed by the Hugging Face team without any community contributions, making it less reliable and effective.'},\n {'prompt': 'Which GPUs are supported by Unsloth for LLM fine-tuning acceleration?',\n 'chosen': 'Unsloth supports most NVIDIA GPUs, ranging from GTX 1070 to H100s, ensuring compatibility with a wide range of hardware configurations for fine-tuning large language models.',\n 'rejected': 'Unsloth only supports outdated NVIDIA GPUs and is incompatible with modern hardware, limiting its usefulness for fine-tuning large language models.'},\n {'prompt': 'How does Unsloth ensure accuracy preservation during LLM fine-tuning?',\n 'chosen': 'Unsloth achieves 0% accuracy degradation by avoiding approximations in the optimized code and ensuring that all backpropagation steps are manually derived, maintaining fidelity to regular QLoRA fine-tuning methods.',\n 'rejected': 'Unsloth compromises accuracy during LLM fine-tuning by making approximations in the optimized code, resulting in a loss of model fidelity and decreased performance.'},\n {'prompt': 'What is the compatibility of Unsloth with the Hugging Face ecosystem?',\n 'chosen': 'Unsloth is fully compatible with the Hugging Face ecosystem, including Hub, transformers, PEFT, and TRL libraries, providing seamless integration into existing workflows for fine-tuning large language models.',\n 'rejected': 'Unsloth lacks compatibility with the Hugging Face ecosystem, requiring extensive modifications to existing workflows and libraries for integration, which can lead to inefficiencies and errors.'}]"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example_data = data[:16]\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-08-14T09:48:33.915596900Z",
"start_time": "2024-08-14T09:48:33.887571800Z"
}
},
"id": "c4947ba3cc9fb205"
},
{
"cell_type": "markdown",
"source": [
107 changes: 107 additions & 0 deletions llm_tricks/DPO_example/dpo.py
@@ -134,3 +134,110 @@ def forward(

        # Average over the batch
        return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()


def evaluate_dpo_loss_loader(policy_model, reference_model, train_loader, val_loader, beta, eval_iter):
    """Compute the DPO loss for the training and validation dataset"""

    policy_model.eval()
    with torch.no_grad():
        train_loss, train_chosen_rewards, train_rejected_rewards = compute_dataloader_loss(
            data_loader=train_loader,
            policy_model=policy_model,
            reference_model=reference_model,
            beta=beta,
            num_batches=eval_iter
        )

        val_loss, val_chosen_rewards, val_rejected_rewards = compute_dataloader_loss(
            data_loader=val_loader,
            policy_model=policy_model,
            reference_model=reference_model,
            beta=beta,
            num_batches=eval_iter
        )

    res = {
        "train_loss": train_loss,
        "train_chosen_reward": train_chosen_rewards,
        "train_rejected_reward": train_rejected_rewards,
        "val_loss": val_loss,
        "val_chosen_reward": val_chosen_rewards,
        "val_rejected_reward": val_rejected_rewards
    }

    policy_model.train()
    return res


# Start training the model
def train_model(
        policy_model, reference_model, train_loader, val_loader,
        optimizer, num_epochs, beta,
        eval_freq, eval_iter, start_context, tokenizer
):
    tracking = {
        "train_losses": [],
        "train_chosen_rewards": [],
        "train_rejected_rewards": [],
        "val_losses": [],
        "val_chosen_rewards": [],
        "val_rejected_rewards": [],
        "tokens_seen": []
    }

    tokens_seen, global_step = 0, -1

    # Training loop
    for epoch in range(num_epochs):
        # The policy model is the one being trained
        policy_model.train()

        for idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            loss, chosen_rewards, rejected_rewards = compute_batch_loss(
                batch=batch,
                policy_model=policy_model,
                reference_model=reference_model,
                beta=beta
            )

            loss.backward()
            optimizer.step()

            global_step += 1
            tokens_seen += batch["chosen"].numel()

            # Evaluation
            if global_step % eval_freq == 0:
                res = evaluate_dpo_loss_loader(
                    policy_model=policy_model,
                    reference_model=reference_model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    beta=beta,
                    eval_iter=eval_iter
                )
                tracking["train_losses"].append(res["train_loss"])
                tracking["train_chosen_rewards"].append(res["train_chosen_reward"])
                tracking["train_rejected_rewards"].append(res["train_rejected_reward"])
                tracking["val_losses"].append(res["val_loss"])
                tracking["val_chosen_rewards"].append(res["val_chosen_reward"])
                tracking["val_rejected_rewards"].append(res["val_rejected_reward"])
                tracking["tokens_seen"].append(tokens_seen)
                train_reward_margin = res["train_chosen_reward"] - res["train_rejected_reward"]
                val_reward_margin = res["val_chosen_reward"] - res["val_rejected_reward"]

                print(
                    f"Ep {epoch + 1} (Step {global_step:06d}): "
                    f"Train loss {res['train_loss']:.3f}, Val loss {res['val_loss']:.3f}, "
                    f"Train reward margins {train_reward_margin:.3f}, "
                    f"Val reward margins {val_reward_margin:.3f}"
                )

    return tracking


def main():
    pass
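
For context, train_model and evaluate_dpo_loss_loader above call compute_batch_loss and compute_dataloader_loss, which fall outside this hunk. The sketch below shows one conventional way such helpers can be built around the DPO objective; the helper names, signatures, and the assumption that the models return raw logits of shape (batch, seq, vocab) are illustrative, not taken from this repository.

# Illustrative sketch only; names and signatures are assumptions, not this repo's API.
import torch
import torch.nn.functional as F


def compute_logprobs(logits, labels, selection_mask=None):
    """Mean per-token log-probability of `labels` under `logits` (shifted for next-token prediction)."""
    labels = labels[:, 1:].clone()      # targets are the inputs shifted left by one
    logits = logits[:, :-1, :]
    log_probs = F.log_softmax(logits, dim=-1)
    selected = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

    if selection_mask is not None:
        mask = selection_mask[:, 1:].clone()
        selected = selected * mask
        return selected.sum(-1) / mask.sum(-1)   # average over unmasked (response) tokens only
    return selected.mean(-1)


def compute_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps, beta=0.1):
    """DPO objective: -log sigmoid(beta * (policy log-ratio - reference log-ratio))."""
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps
    loss = -F.logsigmoid(beta * (policy_logratios - reference_logratios))

    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()


def compute_batch_loss(batch, policy_model, reference_model, beta):
    """Glue code matching the call in train_model above (assumes the models return raw logits)."""
    policy_chosen_logps = compute_logprobs(policy_model(batch["chosen"]), batch["chosen"], batch["chosen_mask"])
    policy_rejected_logps = compute_logprobs(policy_model(batch["rejected"]), batch["rejected"], batch["rejected_mask"])
    with torch.no_grad():
        ref_chosen_logps = compute_logprobs(reference_model(batch["chosen"]), batch["chosen"], batch["chosen_mask"])
        ref_rejected_logps = compute_logprobs(reference_model(batch["rejected"]), batch["rejected"], batch["rejected_mask"])
    return compute_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                            ref_chosen_logps, ref_rejected_logps, beta=beta)


def compute_dataloader_loss(data_loader, policy_model, reference_model, beta, num_batches=None):
    """Average the batch loss and rewards over (at most) num_batches batches."""
    total_loss, total_chosen, total_rejected = 0.0, 0.0, 0.0
    if len(data_loader) == 0:
        return float("nan"), float("nan"), float("nan")
    num_batches = len(data_loader) if num_batches is None else min(num_batches, len(data_loader))
    for i, batch in enumerate(data_loader):
        if i >= num_batches:
            break
        loss, chosen_rewards, rejected_rewards = compute_batch_loss(batch, policy_model, reference_model, beta)
        total_loss += loss.item()
        total_chosen += chosen_rewards.item()
        total_rejected += rejected_rewards.item()
    return total_loss / num_batches, total_chosen / num_batches, total_rejected / num_batches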
2 changes: 2 additions & 0 deletions rlhf/rlhf_args/cpo-simpo_config.py
@@ -13,4 +13,6 @@ class CPOSimPOConfig(CPOConfig):
    cpo_alpha: float = 0.5
    """A non-zero cpo_alpha enables the combined use of CPO and SimPO, which gives more stable training and
    improved performance."""
    eval_samples: int = 30
    """Number of eval samples; note this must not be fewer than batch_size * gradient_accumulation_steps."""

2 changes: 2 additions & 0 deletions rlhf/rlhf_args/cpo_config.py
@@ -65,3 +65,5 @@ class CPOConfig(CPOConfig):
"help": "Remove columns not required by the model when using an nlp.Dataset."})
bf16: bool = field(default=True, metadata={"help": "是否使用bf16精度"})
fp16: bool = field(default=False, metadata={"help": "是否使用bf16精度"})
eval_samples: int = 30
"""eval sample的数量,注意不能少于batchsize*gradient_accumulation_steps"""
17 changes: 3 additions & 14 deletions rlhf/rlhf_args/ppo_config.py
@@ -5,16 +5,9 @@
from transformers.training_args import OptimizerNames
from trl.trainer.ppov2_trainer import PPOv2Config


@dataclass
class PPOConfig(PPOv2Config):
    # common config
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    run_name: Optional[str] = None
    """a unique name of this run"""
    sanity_check: bool = False
    """whether to run in debug mode"""

    # batch size related config
    num_mini_batches: int = 1
    """Number of minibatches to split a batch into"""
@@ -87,9 +80,5 @@ class PPOConfig(PPOv2Config):
    remove_unused_columns: Optional[bool] = field(default=False, metadata={
        "help": "Remove columns not required by the model when using an nlp.Dataset."})
    bf16: bool = field(default=True, metadata={"help": "Whether to use bf16 precision"})

    # Deepspeed-related training parameter; set default=None when not using Deepspeed
    deepspeed: Optional[str] = field(default=None, metadata={"help": "Config file required when Deepspeed is enabled"})

    world_size: Optional[int] = 1
    """The number of processes (GPUs) to use"""
    eval_samples: int = 30
    """Number of eval samples; note this must not be fewer than batch_size * gradient_accumulation_steps."""
8 changes: 2 additions & 6 deletions rlhf/rlhf_args/rloo_config.py
@@ -7,12 +7,6 @@
# The number of training steps can be set directly via total_episodes, or by configuring num_train_epochs in TrainingArguments.
@dataclass
class RLOOConfig(RLOOConfig):
    # common config
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    run_name: Optional[str] = None
    """a unique name of this run"""


    # batch size related config
    num_mini_batches: int = 1
@@ -87,3 +81,5 @@ class RLOOConfig(RLOOConfig):
"help": "Remove columns not required by the model when using an nlp.Dataset."})
bf16: bool = field(default=True, metadata={"help": "是否使用bf16精度"})
fp16: bool = field(default=False, metadata={"help": "是否使用bf16精度"})
eval_samples: int = 30
"""eval sample的数量,注意不能少于batchsize*gradient_accumulation_steps"""
2 changes: 2 additions & 0 deletions rlhf/rlhf_args/simpo_config.py
@@ -14,3 +14,5 @@ class SimPOConfig(CPOConfig):
"""A hyperparameter that controls the strength of the BC regularizer in CPO training."""
simpo_gamma: float = 0.5
"""A target reward margin for the SimPO loss, used only when the "simpo" option is enabled."""
eval_samples: int = 30
"""eval sample的数量,注意不能少于batchsize*gradient_accumulation_steps"""
