forked from mst272/LLM-Dojo
Showing 7 changed files with 110 additions and 458 deletions.
# DPO (SimPO) Reinforcement Learning Training, Implemented from Scratch

## Quick start
```bash
python dpo_train.py
```

## Notes
The from-scratch implementation in this directory is a learning demo intended to illustrate the underlying principles; it does not add distributed training or similar features. As a result, even with a small 2B model, GPU memory usage exceeds 30 GB.

`dpo_train.py` is the main training entry point; the related loss computation is in `loss.py`.

To actually train with DPO, SimPO, CPO, or other reinforcement-learning methods, use the RL framework built under this project's rlhf directory: [RLHF](../../rlhf/README.md)

It supports DeepSpeed single-node multi-GPU LoRA, DoRA, QLoRA, and full-parameter training, and automatically adapts to the model's chat template.
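As a hedged sketch of what the loss computation in `loss.py` covers (the function names and signatures below are illustrative, not the repo's actual API), the standard per-pair DPO loss and the reference-free SimPO variant can be written in plain Python:

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """DPO loss for one preference pair:
    -log(sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r))))."""
    logits = beta * ((policy_chosen_logp - ref_chosen_logp)
                     - (policy_rejected_logp - ref_rejected_logp))
    # -log(sigmoid(x)) rewritten as log(1 + exp(-x))
    return math.log1p(math.exp(-logits))


def simpo_loss(avg_chosen_logp: float, avg_rejected_logp: float,
               beta: float = 2.0, gamma: float = 0.5) -> float:
    """SimPO: reference-free, uses length-averaged log-probs and a
    target reward margin gamma instead of a reference model."""
    logits = beta * (avg_chosen_logp - avg_rejected_logp) - gamma
    return math.log1p(math.exp(-logits))


# When the policy widens the chosen/rejected margin relative to the
# reference, the DPO loss drops below log(2) ~ 0.693.
print(dpo_loss(-10.0, -20.0, -12.0, -18.0))
```

With equal log-probabilities everywhere, both losses reduce to `log(2)`, which is a quick sanity check when porting the formulas to a tensor library.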
This file was deleted.
```python
from typing import Dict, List
import copy


def is_right_apply_chat(tokenizer, prompt: List[Dict[str, str]],
                        assistant_content: List[Dict[str, str]]) -> bool:
    """
    Checks whether applying the chat template to the full conversation yields
    the concatenation of the templated prompt and the templated assistant reply.
    Args:
        tokenizer: The tokenizer.
        prompt: The initial prompt messages.
        assistant_content: A list containing the single assistant message.
    Returns:
        bool: True if the assistant's content is correctly applied, False otherwise.
    """
    try:
        test_assistant = tokenizer.apply_chat_template(assistant_content, tokenize=False)
        test_prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
        conversation = copy.deepcopy(prompt)
        conversation.append(assistant_content[0])
        # tokenize=False so the comparison is string-to-string; the default
        # (tokenize=True) returns token IDs and the check would always fail.
        return tokenizer.apply_chat_template(conversation, tokenize=False) == test_prompt + test_assistant
    except Exception:
        return False


def fix_chat_template_if_needed(tokenizer, prompt: List[Dict[str, str]], chosen: List[Dict[str, str]],
                                rejected: List[Dict[str, str]]):
    """
    Fixes the chat template if needed by templating each full conversation and
    re-splitting it at the start of the response text.
    Args:
        tokenizer: The tokenizer.
        prompt: The initial prompt messages.
        chosen: A list containing the single chosen response message.
        rejected: A list containing the single rejected response message.
    Returns:
        tuple: (fixed_prompt, fixed_chosen, fixed_rejected).
    """
    conversation_chosen = copy.deepcopy(prompt)
    conversation_rejected = copy.deepcopy(prompt)
    conversation_chosen.append(chosen[0])
    conversation_rejected.append(rejected[0])
    conversation_chosen = tokenizer.apply_chat_template(conversation_chosen, tokenize=False)
    conversation_rejected = tokenizer.apply_chat_template(conversation_rejected, tokenize=False)
    # Locate where the chosen response text begins; searching for the full
    # content (not just its first character) avoids matching inside the prompt.
    start_position = conversation_chosen.find(chosen[0]['content'])
    # Both conversations share the same templated prompt prefix, so the same
    # split position is valid for the rejected conversation.
    fixed_prompt = conversation_chosen[:start_position]
    fixed_chosen = conversation_chosen[start_position:]
    fixed_rejected = conversation_rejected[start_position:]
    return fixed_prompt, fixed_chosen, fixed_rejected
```
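As a usage sketch, the split performed by `fix_chat_template_if_needed` can be demonstrated with a toy chat template (`ToyTokenizer` is a hypothetical stand-in for a Hugging Face tokenizer, not part of the repo; the splitting logic is inlined here to keep the example self-contained):

```python
class ToyTokenizer:
    """Hypothetical minimal stand-in for a tokenizer's apply_chat_template."""
    def apply_chat_template(self, messages, tokenize=False):
        return "".join(f"<|{m['role']}|>\n{m['content']}\n" for m in messages)


tok = ToyTokenizer()
prompt = [{"role": "user", "content": "What is the capital of France?"}]
chosen = [{"role": "assistant", "content": "Paris."}]
rejected = [{"role": "assistant", "content": "London."}]

# Same idea as fix_chat_template_if_needed: template both full conversations,
# then cut at the position where the chosen response text begins.
conv_c = tok.apply_chat_template(prompt + chosen, tokenize=False)
conv_r = tok.apply_chat_template(prompt + rejected, tokenize=False)
start = conv_c.find(chosen[0]["content"])
fixed_prompt = conv_c[:start]
fixed_chosen = conv_c[start:]
fixed_rejected = conv_r[start:]

print(repr(fixed_prompt))    # templated prompt, up to the assistant tag
print(repr(fixed_chosen))    # 'Paris.\n'
print(repr(fixed_rejected))  # 'London.\n'
```

Because both conversations share the same templated prompt prefix, a single split position recovers a consistent (prompt, chosen, rejected) triple for preference training.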