Commit

fix chat template bugs
mst272 committed Aug 15, 2024
1 parent 2783673 commit 8b709b0
Showing 1 changed file with 34 additions and 16 deletions.
50 changes: 34 additions & 16 deletions rlhf/rlhf_train.py
@@ -1,6 +1,6 @@
 import importlib
-import multiprocessing
 
+import copy
 from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 from datasets import Dataset
 from transformers import (
@@ -15,6 +15,7 @@
 import torch.nn as nn
 from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
 from common_args import CommonArgs
+from loguru import logger
 
 
 def load_config(args):
@@ -80,26 +81,43 @@ def tokenize(element):
     return train_dataset, eval_dataset
 
 
-def judge_chat_template():
-    pass
-
-
-def test_tokenizer_chat_template(tokenizer, data):
-    bos = tokenizer.bos_token
-    eos = tokenizer.eos_token
-    test_prompt = tokenizer.apply_chat_template(data['prompt'][0], tokenize=False)
-    pass
+def test_tokenizer_chat_template(tokenizer, prompt, chosen):
+    test_prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
+    test_chosen = tokenizer.apply_chat_template(chosen, tokenize=False)
+    one_conversation = copy.deepcopy(prompt)
+    one_conversation.append(chosen[-1])
+    one_conversation = tokenizer.apply_chat_template(one_conversation, tokenize=False)
+    if one_conversation == test_prompt + test_chosen:
+        return True
+    else:
+        logger.warning("Chat template is not right, Start automatic repair!")
+        return False
 
 
 def load_data_all(tokenizer, train_data_path, eval_samples):
     raw_datasets = pd.read_json(train_data_path, lines=True)
-    for i in range(len(raw_datasets)):
-        raw_datasets.loc[i, 'prompt'] = tokenizer.apply_chat_template(raw_datasets['prompt'][i], tokenize=False)
-        raw_datasets.loc[i, 'chosen'] = raw_datasets.loc[i, 'prompt'] + tokenizer.apply_chat_template(
-            raw_datasets['chosen'][i], tokenize=False)
-        raw_datasets.loc[i, 'rejected'] = raw_datasets.loc[i, 'prompt'] + tokenizer.apply_chat_template(
-            raw_datasets['rejected'][i], tokenize=False)
+    prompt, chosen = raw_datasets['prompt'][0], raw_datasets['chosen'][0]
+    if not test_tokenizer_chat_template(tokenizer, prompt, chosen):
+        for i in range(len(raw_datasets)):
+            conversation_chosen = raw_datasets['prompt'][i][:]
+            conversation_rejected = raw_datasets['prompt'][i][:]
+            conversation_chosen.append(raw_datasets['chosen'][i][-1])
+            conversation_rejected.append(raw_datasets['rejected'][i][-1])
+            conversation_chosen = tokenizer.apply_chat_template(conversation_chosen, tokenize=False)
+            conversation_rejected = tokenizer.apply_chat_template(conversation_rejected, tokenize=False)
+            start_position = conversation_chosen.find(raw_datasets['chosen'][i][-1]['content'])
+            raw_datasets.loc[i, 'prompt'] = conversation_chosen[:start_position]
+            raw_datasets.loc[i, 'chosen'] = conversation_chosen[start_position:]
+            raw_datasets.loc[i, 'rejected'] = conversation_rejected[start_position:]
+    else:
+        for i in range(len(raw_datasets)):
+            raw_datasets.loc[i, 'prompt'] = tokenizer.apply_chat_template(raw_datasets['prompt'][i], tokenize=False)
+            raw_datasets.loc[i, 'chosen'] = raw_datasets.loc[i, 'prompt'] + tokenizer.apply_chat_template(
+                raw_datasets['chosen'][i], tokenize=False)
+            raw_datasets.loc[i, 'rejected'] = raw_datasets.loc[i, 'prompt'] + tokenizer.apply_chat_template(
+                raw_datasets['rejected'][i], tokenize=False)
     raw_datasets = Dataset.from_pandas(raw_datasets, preserve_index=False)
+    logger.warning("Now, the chat template is not right!")
     train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
     eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
     return train_dataset, eval_dataset
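
For context, the check introduced above renders the prompt and the chosen reply separately and compares their concatenation with a single rendering of the whole conversation; many chat templates prepend a BOS token or a default system prompt on every call, so naive concatenation (the else branch) would duplicate that prefix. The sketch below is a minimal, self-contained illustration of that comparison and is not code from this repository: toy_apply_chat_template is a hypothetical stand-in for tokenizer.apply_chat_template(..., tokenize=False).

```python
import copy

def toy_apply_chat_template(messages):
    # Hypothetical stand-in for tokenizer.apply_chat_template(..., tokenize=False):
    # prepends a BOS token and wraps each message in role tags.
    rendered = "<bos>"
    for m in messages:
        rendered += f"<|{m['role']}|>{m['content']}<|end|>"
    return rendered

def chat_template_is_concatenable(prompt, chosen):
    # Mirrors the commit's test_tokenizer_chat_template: the conversation rendered
    # in one call must equal prompt-rendered-alone + chosen-rendered-alone.
    one_conversation = copy.deepcopy(prompt)
    one_conversation.append(chosen[-1])
    full = toy_apply_chat_template(one_conversation)
    return full == toy_apply_chat_template(prompt) + toy_apply_chat_template(chosen)

prompt = [{"role": "user", "content": "Hi"}]
chosen = [{"role": "assistant", "content": "Hello!"}]
# The toy template repeats "<bos>" when prompt and chosen are rendered
# separately, so the check fails and the repair branch in load_data_all would run.
print(chat_template_is_concatenable(prompt, chosen))  # False
```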