Merge branch 'main' of https://github.com/mst272/LLM-Dojo
mst272 committed Aug 14, 2024
2 parents 0368766 + cb6f57a commit 7195dcf
Showing 10 changed files with 406 additions and 22 deletions.
92 changes: 92 additions & 0 deletions llm_tricks/DPO_example/dataset.py
@@ -0,0 +1,92 @@
import torch
from torch.utils.data import Dataset


class RlhfDataset(Dataset):
def __init__(self, data, tokenizer):
self.data = data

self.encoded_texts = []
for entry in data:
prompt = entry["prompt"]
rejected_response = entry["rejected"]
chosen_response = entry["chosen"]

prompt_tokens = tokenizer.encode(prompt)
chosen_full_text = f"{prompt}\n\n### Response:\n{chosen_response}"
rejected_full_text = f"{prompt}\n\n### Response:\n{rejected_response}"
chosen_full_tokens = tokenizer.encode(chosen_full_text)
rejected_full_tokens = tokenizer.encode(rejected_full_text)

self.encoded_texts.append({
"prompt": prompt_tokens,
"chosen": chosen_full_tokens,
"rejected": rejected_full_tokens,
})

def __getitem__(self, index):
return self.encoded_texts[index]

def __len__(self):
return len(self.data)


def data_collate(
batch,
pad_token_id=50256,
allowed_max_length=None,
mask_prompt_tokens=True,
device="cpu"
):
# Initialize lists to hold batch data
batch_data = {
"prompt": [],
"chosen": [],
"rejected": [],
"rejected_mask": [],
"chosen_mask": []

}

# Determine the longest sequence to set a common padding length
max_length_common = 0
if batch:
for key in ["chosen", "rejected"]:
current_max = max(len(item[key]) + 1 for item in batch)
max_length_common = max(max_length_common, current_max)

# Process each item in the batch
for item in batch:
prompt = torch.tensor(item["prompt"])
batch_data["prompt"].append(prompt)

for key in ["chosen", "rejected"]:
# Adjust padding according to the common maximum length
sequence = item[key]
padded = sequence + [pad_token_id] * (max_length_common - len(sequence))
mask = torch.ones(len(padded)).bool()

# Set mask for all padding tokens to False
mask[len(sequence):] = False

# Set mask for all input tokens to False
# +2 sets the 2 newline ("\n") tokens before "### Response" to False
if mask_prompt_tokens:
mask[:prompt.shape[0] + 2] = False

batch_data[key].append(torch.tensor(padded))
batch_data[f"{key}_mask"].append(mask)

# Final processing
for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
# Stack all sequences into a tensor for the given key
tensor_stack = torch.stack(batch_data[key])

# Optionally truncate to maximum sequence length
if allowed_max_length is not None:
tensor_stack = tensor_stack[:, :allowed_max_length]

# Move to the specified device
batch_data[key] = tensor_stack.to(device)

return batch_data
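
For context, here is a minimal usage sketch (not part of the commit) showing how RlhfDataset and data_collate plug into a PyTorch DataLoader. The tiktoken GPT-2 tokenizer, the batch size, and the allowed_max_length value are assumptions; pad_token_id=50256 corresponds to GPT-2's <|endoftext|> token.

import json
from functools import partial

import tiktoken  # assumed tokenizer choice
from torch.utils.data import DataLoader

# Load the preference pairs used in the notebook (one prompt/chosen/rejected dict per line).
with open("./unsloth_dpo.jsonl", "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f]

tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = RlhfDataset(data, tokenizer)

collate_fn = partial(
    data_collate,
    pad_token_id=50256,
    allowed_max_length=1024,  # assumed truncation length
    mask_prompt_tokens=True,
    device="cpu",
)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

batch = next(iter(train_loader))
print(batch["chosen"].shape, batch["chosen_mask"].shape)  # both (batch_size, padded_seq_len)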
94 changes: 92 additions & 2 deletions llm_tricks/DPO_example/dpo.ipynb
@@ -37,12 +37,72 @@
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"source": [
"import json\n",
"\n",
"file_path = \"./unsloth_dpo.jsonl\"\n",
"\n",
"with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" data_list = file.readlines()\n",
"# data = json.loads(data_list)\n",
"data = [json.loads(data) for data in data_list]\n",
"data"
],
"metadata": {
"collapsed": false
},
"id": "691f55cb1faa5a5f"
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "{'prompt': 'What is one benefit of using Unsloth for LLM fine-tuning?',\n 'chosen': 'One benefit of using Unsloth for LLM fine-tuning is that it offers a 0% accuracy degradation compared to normal QLoRA, as no approximations are made in the optimized code.',\n 'rejected': 'Using Unsloth for LLM fine-tuning increases accuracy degradation compared to normal QLoRA.'}"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[1]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-08-14T09:17:46.958167500Z",
"start_time": "2024-08-14T09:17:46.942167Z"
}
},
"id": "9aa37c970611f41d"
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "[{'prompt': 'How can Unsloth accelerate LLM fine-tuning?',\n 'chosen': 'Unsloth accelerates LLM fine-tuning by overwriting some parts of the modeling code with optimized operations and rewriting all Pytorch modules into Triton kernels, resulting in a 2x speedup and a 40% reduction in memory usage.',\n 'rejected': 'Unsloth slows down LLM fine-tuning.'},\n {'prompt': 'What is one benefit of using Unsloth for LLM fine-tuning?',\n 'chosen': 'One benefit of using Unsloth for LLM fine-tuning is that it offers a 0% accuracy degradation compared to normal QLoRA, as no approximations are made in the optimized code.',\n 'rejected': 'Using Unsloth for LLM fine-tuning increases accuracy degradation compared to normal QLoRA.'},\n {'prompt': 'Who developed the Unsloth library for LLM fine-tuning?',\n 'chosen': 'The Unsloth library for LLM fine-tuning was developed by the Unsloth team, including Daniel and Michael, in collaboration with the open-source community.',\n 'rejected': 'The Unsloth library for LLM fine-tuning was developed by Google and Microsoft.'},\n {'prompt': 'What GPU types are compatible with Unsloth for LLM fine-tuning?',\n 'chosen': 'Unsloth supports most NVIDIA GPUs, ranging from GTX 1070 to H100s, making it compatible with a wide range of hardware configurations.',\n 'rejected': 'Unsloth only supports AMD GPUs for LLM fine-tuning.'},\n {'prompt': 'How does Unsloth contribute to reducing memory usage during LLM fine-tuning?',\n 'chosen': 'Unsloth contributes to reducing memory usage during LLM fine-tuning by rewriting all Pytorch modules into Triton kernels, resulting in a 40% decrease in memory usage.',\n 'rejected': 'Unsloth increases memory usage during LLM fine-tuning.'},\n {'prompt': 'What is the primary goal of Unsloth for LLM fine-tuning?',\n 'chosen': 'The primary goal of Unsloth for LLM fine-tuning is to accelerate the process, achieving a 2x speedup while maintaining 0% accuracy degradation compared to normal QLoRA.',\n 'rejected': 'The primary goal of Unsloth for LLM fine-tuning is to slow down the process and increase memory usage.'},\n {'prompt': 'How does Unsloth improve LLM fine-tuning performance?',\n 'chosen': 'Unsloth improves LLM fine-tuning performance by manually deriving backpropagation steps and rewriting Pytorch modules into Triton kernels, resulting in a 2x speed increase and a 40% reduction in memory usage.',\n 'rejected': 'Unsloth degrades LLM fine-tuning performance compared to traditional methods.'},\n {'prompt': 'What makes Unsloth different from other tools for LLM fine-tuning?',\n 'chosen': 'What makes Unsloth different from other tools for LLM fine-tuning is its ability to achieve a 2x speed increase and a 40% reduction in memory usage while maintaining 0% accuracy degradation, thanks to optimized operations and Triton kernel rewriting.',\n 'rejected': 'Unsloth is not different from other tools for LLM fine-tuning.'},\n {'prompt': 'Why is Unsloth considered a lightweight library for LLM fine-tuning?',\n 'chosen': 'Unsloth is considered a lightweight library for LLM fine-tuning because it offers significant performance improvements, including a 2x speed increase and a 40% reduction in memory usage, without compromising accuracy.',\n 'rejected': 'Unsloth is considered a heavyweight library for LLM fine-tuning.'},\n {'prompt': 'What are the supported architectures for LLM fine-tuning by Unsloth?',\n 'chosen': 'Unsloth supports the Llama and Mistral architectures for LLM fine-tuning, ensuring compatibility with a range of model 
structures.',\n 'rejected': 'Unsloth does not support any specific architectures for LLM fine-tuning.'},\n {'prompt': 'How does Unsloth contribute to making LLM fine-tuning 2x faster?',\n 'chosen': 'Unsloth accelerates LLM fine-tuning by overwriting certain parts of the modeling code with optimized operations and manually deriving backpropagation steps, resulting in a 2x speedup without sacrificing accuracy.',\n 'rejected': 'Unsloth slows down LLM fine-tuning by introducing unnecessary complexities and approximations, leading to decreased efficiency.'},\n {'prompt': 'What benefits does Unsloth offer for LLM fine-tuning?',\n 'chosen': 'Unsloth provides a 2x acceleration in LLM fine-tuning speed, reduces memory usage by 40%, and ensures 0% accuracy degradation compared to regular QLoRA fine-tuning methods.',\n 'rejected': 'Unsloth has no impact on LLM fine-tuning efficiency and does not address memory usage concerns, leading to potential accuracy loss.'},\n {'prompt': 'Who developed the Unsloth library for LLM fine-tuning acceleration?',\n 'chosen': 'The Unsloth library was developed by Daniel, Michael, and the open-source community to address the need for faster fine-tuning of large language models.',\n 'rejected': 'Unsloth was solely developed by the Hugging Face team without any community contributions, making it less reliable and effective.'},\n {'prompt': 'Which GPUs are supported by Unsloth for LLM fine-tuning acceleration?',\n 'chosen': 'Unsloth supports most NVIDIA GPUs, ranging from GTX 1070 to H100s, ensuring compatibility with a wide range of hardware configurations for fine-tuning large language models.',\n 'rejected': 'Unsloth only supports outdated NVIDIA GPUs and is incompatible with modern hardware, limiting its usefulness for fine-tuning large language models.'},\n {'prompt': 'How does Unsloth ensure accuracy preservation during LLM fine-tuning?',\n 'chosen': 'Unsloth achieves 0% accuracy degradation by avoiding approximations in the optimized code and ensuring that all backpropagation steps are manually derived, maintaining fidelity to regular QLoRA fine-tuning methods.',\n 'rejected': 'Unsloth compromises accuracy during LLM fine-tuning by making approximations in the optimized code, resulting in a loss of model fidelity and decreased performance.'},\n {'prompt': 'What is the compatibility of Unsloth with the Hugging Face ecosystem?',\n 'chosen': 'Unsloth is fully compatible with the Hugging Face ecosystem, including Hub, transformers, PEFT, and TRL libraries, providing seamless integration into existing workflows for fine-tuning large language models.',\n 'rejected': 'Unsloth lacks compatibility with the Hugging Face ecosystem, requiring extensive modifications to existing workflows and libraries for integration, which can lead to inefficiencies and errors.'}]"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example_data = data[:16]\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-08-14T09:48:33.915596900Z",
"start_time": "2024-08-14T09:48:33.887571800Z"
}
},
"id": "c4947ba3cc9fb205"
},
{
"cell_type": "markdown",
"source": [
@@ -310,6 +370,8 @@
"source": [
"def compute_batch_loss(batch, policy_model, reference_model, beta):\n",
" \"\"\"Compute the DPO loss on an input batch\"\"\"\n",
" loss_fn = DPOLoss(beta)\n",
" \n",
" policy_chosen_logps = compute_logprobs(\n",
" logits=policy_model(batch[\"chosen\"]),\n",
" labels=batch[\"chosen\"],\n",
@@ -319,12 +381,40 @@
" logits=policy_model(batch[\"rejected\"]),\n",
" labels=batch[\"rejected\"],\n",
" mask=batch[\"rejected_mask\"]\n",
" )"
" )\n",
" reference_chosen_logps = compute_logprobs(\n",
" logits=reference_model(batch['chosen']),\n",
" labels=batch['chosen'],\n",
" mask=batch[\"chosen_mask\"]\n",
" )\n",
" reference_rejected_logps = compute_logprobs(\n",
" logits=reference_model(batch['rejected']),\n",
" labels=batch['rejected'],\n",
" mask=batch[\"rejected_mask\"]\n",
" )\n",
" loss, chosen_rewards, rejected_rewards = loss_fn(\n",
" policy_chosen_logps=policy_chosen_logps,\n",
" policy_rejected_logps=policy_rejected_logps,\n",
" reference_chosen_logps=reference_chosen_logps,\n",
" reference_rejected_logps=reference_rejected_logps,\n",
" beta=beta\n",
" )\n",
" return loss, chosen_rewards, rejected_rewards"
],
"metadata": {
"collapsed": false
},
"id": "3211a04a645fb478"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "89e5c1930fb1041"
}
],
"metadata": {
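Both the notebook cell above and dpo.py below call compute_logprobs, whose body is collapsed in this diff. As a reading aid, here is a minimal sketch of a masked per-token log-probability helper with that signature; the distinct name signals that it is an assumption for orientation, not the committed implementation.

import torch
import torch.nn.functional as F


def compute_logprobs_sketch(logits, labels, mask=None):
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len) token ids.
    # Shift so that the token at position t is predicted from positions < t.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].clone()

    log_probs = F.log_softmax(logits, dim=-1)
    # Log-probability assigned to each actual next token.
    selected = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

    if mask is not None:
        mask = mask[:, 1:].clone()
        # Average only over unmasked (response) positions.
        return (selected * mask).sum(-1) / mask.sum(-1)
    return selected.mean(-1)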
158 changes: 158 additions & 0 deletions llm_tricks/DPO_example/dpo.py
@@ -36,6 +36,8 @@ def compute_logprobs(logits, labels, mask=None):

def compute_batch_loss(batch, policy_model, reference_model, beta):
"""Compute the DPO loss on an input batch"""
loss_fn = DPOLoss(beta)

policy_chosen_logps = compute_logprobs(
logits=policy_model(batch["chosen"]),
labels=batch["chosen"],
@@ -46,6 +48,55 @@ def compute_batch_loss(batch, policy_model, reference_model, beta):
labels=batch["rejected"],
mask=batch["rejected_mask"]
)
reference_chosen_logps = compute_logprobs(
logits=reference_model(batch['chosen']),
labels=batch['chosen'],
mask=batch["chosen_mask"]
)
reference_rejected_logps = compute_logprobs(
logits=reference_model(batch['rejected']),
labels=batch['rejected'],
mask=batch["rejected_mask"]
)
loss, chosen_rewards, rejected_rewards = loss_fn(
policy_chosen_logps=policy_chosen_logps,
policy_rejected_logps=policy_rejected_logps,
reference_chosen_logps=reference_chosen_logps,
reference_rejected_logps=reference_rejected_logps,
beta=beta
)
return loss, chosen_rewards, rejected_rewards


def compute_dataloader_loss(data_loader, policy_model, reference_model, beta, num_batches=None):
total_loss, total_chosen_rewards, total_rejected_rewards = 0., 0., 0.
    if len(data_loader) == 0:
        return float("nan"), float("nan"), float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Cap num_batches at the number of batches actually available
        # in the data loader
        num_batches = min(num_batches, len(data_loader))

for i, batch in enumerate(data_loader):
if i < num_batches:
loss, chosen_rewards, rejected_rewards = compute_batch_loss(
batch=batch,
policy_model=policy_model,
reference_model=reference_model,
beta=beta
)
total_loss += loss.item()
total_chosen_rewards += chosen_rewards.item()
total_rejected_rewards += rejected_rewards.item()

else:
break
total_loss /= num_batches
total_chosen_rewards /= num_batches
total_rejected_rewards /= num_batches
return total_loss, total_chosen_rewards, total_rejected_rewards


class DPOLoss(nn.Module):
@@ -83,3 +134,110 @@ def forward(

        # Average over the batch
return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()


def evaluate_dpo_loss_loader(policy_model, reference_model, train_loader, val_loader, beta, eval_iter):
"""Compute the DPO loss for the training and validation dataset"""

policy_model.eval()
with torch.no_grad():
train_loss, train_chosen_rewards, train_rejected_rewards = compute_dataloader_loss(
data_loader=train_loader,
policy_model=policy_model,
reference_model=reference_model,
beta=beta,
num_batches=eval_iter
)

val_loss, val_chosen_rewards, val_rejected_rewards = compute_dataloader_loss(
data_loader=val_loader,
policy_model=policy_model,
reference_model=reference_model,
beta=beta,
num_batches=eval_iter
)

res = {
"train_loss": train_loss,
"train_chosen_reward": train_chosen_rewards,
"train_rejected_reward": train_rejected_rewards,
"val_loss": val_loss,
"val_chosen_reward": val_chosen_rewards,
"val_rejected_reward": val_rejected_rewards
}

policy_model.train()
return res


# Train the model
def train_model(
policy_model, reference_model, train_loader, val_loader,
optimizer, num_epochs, beta,
eval_freq, eval_iter, start_context, tokenizer
):
tracking = {
"train_losses": [],
"train_chosen_rewards": [],
"train_rejected_rewards": [],
"val_losses": [],
"val_chosen_rewards": [],
"val_rejected_rewards": [],
"tokens_seen": []
}

tokens_seen, global_step = 0, -1

    # Training loop
for epoch in range(num_epochs):
        # The policy model is the one being trained
policy_model.train()

for idx, batch in enumerate(train_loader):
optimizer.zero_grad()

loss, chosen_rewards, rejected_rewards = compute_batch_loss(
batch=batch,
policy_model=policy_model,
reference_model=reference_model,
beta=beta
)

loss.backward()
optimizer.step()

global_step += 1
tokens_seen += batch["chosen"].numel()

            # Validation
if global_step % eval_freq == 0:
res = evaluate_dpo_loss_loader(
policy_model=policy_model,
reference_model=reference_model,
train_loader=train_loader,
val_loader=val_loader,
beta=beta,
eval_iter=eval_iter
)
tracking["train_losses"].append(res["train_loss"])
tracking["train_chosen_rewards"].append(res["train_chosen_reward"])
tracking["train_rejected_rewards"].append(res["train_rejected_reward"])
tracking["val_losses"].append(res["val_loss"])
tracking["val_chosen_rewards"].append(res["val_chosen_reward"])
tracking["val_rejected_rewards"].append(res["val_rejected_reward"])
tracking["tokens_seen"].append(tokens_seen)
train_reward_margin = res["train_chosen_reward"] - res["train_rejected_reward"]
val_reward_margin = res["val_chosen_reward"] - res["val_rejected_reward"]

print(
f"Ep {epoch + 1} (Step {global_step:06d}): "
f"Train loss {res['train_loss']:.3f}, Val loss {res['val_loss']:.3f}, "
f"Train reward margins {train_reward_margin:.3f}, "
f"Val reward margins {val_reward_margin:.3f}"
)

return tracking


def main():
pass
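
DPOLoss.forward is collapsed in the diff above except for its final return statement, and compute_batch_loss also passes beta at call time. For reference, here is a minimal sketch of the standard DPO objective that such a module typically computes; the class name, the beta default, and the forward signature are assumptions, not the committed code.

import torch.nn.functional as F
from torch import nn


class DPOLossSketch(nn.Module):
    """Standard DPO objective: -log sigmoid(beta * (policy margin - reference margin))."""

    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(self, policy_chosen_logps, policy_rejected_logps,
                reference_chosen_logps, reference_rejected_logps):
        policy_margin = policy_chosen_logps - policy_rejected_logps
        reference_margin = reference_chosen_logps - reference_rejected_logps
        loss = -F.logsigmoid(self.beta * (policy_margin - reference_margin))

        # Implicit rewards, detached so they can be logged without affecting gradients.
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

        # Average over the batch, matching the visible return statement above.
        return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()

Detaching the reward terms keeps them as pure logging quantities, which is why the training loop above can track chosen/rejected reward margins without retaining extra computation graph.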