fix bugs and update dpo example
mst272 committed Aug 15, 2024
1 parent 7195dcf commit 2be5189
Showing 4 changed files with 12 additions and 86 deletions.
main_train.py: 5 changes (3 additions & 2 deletions)
@@ -170,8 +170,9 @@ def load_dpo_dataset(args, tokenizer):
     train_dataset = load_dataset(data_files=args.train_data_path, path='json')
 
     def process(row):
-        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
-        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
+        row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
+        row["chosen"] = tokenizer.apply_chat_template(row["chosen"][-1], tokenize=False)
+        row["rejected"] = tokenizer.apply_chat_template(row["rejected"][-1], tokenize=False)
         return row
 
     train_dataset = train_dataset.map(process)
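For context, here is a minimal sketch (not part of the commit) of what the updated `process` yields on one hh-rlhf-helpful-base-trl-style record. The tokenizer name is an assumption, and the sketch wraps the final message in a list, since `apply_chat_template` iterates over a sequence of message dicts:

```python
from transformers import AutoTokenizer

# Assumed model name, purely for illustration; any tokenizer with a chat
# template behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

row = {
    "chosen": [
        {"role": "user", "content": "How do I boil an egg?"},
        {"role": "assistant", "content": "Simmer it for 7 to 10 minutes."},
    ],
    "rejected": [
        {"role": "user", "content": "How do I boil an egg?"},
        {"role": "assistant", "content": "Eggs cannot be boiled."},
    ],
}

# Mirrors the new process(): everything before the last message becomes the
# prompt; the last message alone becomes the chosen/rejected completion.
prompt = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
chosen = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
rejected = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
print(prompt, chosen, rejected, sep="\n---\n")
```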
rlhf/rlhf_train.py: 14 changes (0 additions & 14 deletions)
@@ -16,7 +16,6 @@
 from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
 from common_args import CommonArgs
 
-
 def load_config(args):
     # Load the corresponding config according to config_option
     module_path = args.train_args_path.replace("/", ".").rstrip(".py")
@@ -151,19 +150,6 @@ def main():
         policy = prepare_model_for_kbit_training(policy, use_gradient_checkpointing=config.gradient_checkpointing)
         policy = get_peft_model(policy, lora_config)
 
-    ################
-    # Dataset
-    ################
-    # raw_datasets = pd.read_json(config.train_data_path, lines=True)
-    # for i in range(len(raw_datasets)):
-    #     pro = raw_datasets['prompt'][i]
-    #     res = tokenizer.apply_chat_template(pro, tokenize=False)
-    #     raw_datasets.loc[i, 'prompt'] = res
-    # raw_datasets = Dataset.from_pandas(raw_datasets, preserve_index=False)
-    # eval_samples = config.eval_samples
-    # train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
-    # eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
-
     ################
     # Training
     ################
train_args/dpo/README.md: 4 changes (3 additions & 1 deletion)
@@ -7,7 +7,9 @@ All DPO training methods support the framework's deepspeed or python launch modes; the corresponding
 
 The latter is a data organization format built from scratch. In principle it follows the same form as DPOTrainer, but only single-turn dialogue is implemented. The **goal is to better understand the DPO process and to make some custom modifications easier**; treat it as learning material.
 
-For DPO data, see the example data in ```data/dpo_multi_data.jsonl```.
+🤓**Note:** For DPO data, see the example data in ```data/dpo_multi_data.jsonl```. The data follows Hugging Face's hh-rlhf-helpful-base-trl-style format, where prompt is a single utterance while chosen and
+rejected are complete conversations that include the prompt. So when building your own dataset, whether single-turn or multi-turn, include the prompt in both chosen and rejected: single-turn amounts to taking the first utterance as the prompt,
+and multi-turn amounts to taking everything before the last utterance as the prompt (each turn's user message could also be taken as a prompt; that may be implemented later, time permitting).
 
 For the self-built single_dpo data format, an example is:
 ```json lines
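To make the prompt rule in the note above concrete, here is a hedged sketch (conversation invented for illustration) of how a multi-turn record splits: everything before the last utterance becomes the prompt, and the final assistant message is the completion being compared:

```python
# A two-turn chosen conversation in hh-rlhf-helpful-base-trl-style form.
chosen = [
    {"role": "user", "content": "What is DPO?"},
    {"role": "assistant", "content": "Direct Preference Optimization."},
    {"role": "user", "content": "Does it need a reward model?"},
    {"role": "assistant", "content": "No, it optimizes on preference pairs directly."},
]

prompt_messages = chosen[:-1]  # multi-turn: all messages before the last one
completion = chosen[-1:]       # the assistant message being preferred
assert prompt_messages + completion == chosen
```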
utils/data_process.py: 75 changes (6 additions & 69 deletions)
@@ -3,72 +3,6 @@
 from loguru import logger
 
 
-# class CommonSingleRoundDataProcess(Dataset):
-#     """
-#     Regular single-turn dialogue dataset used by default
-#     """
-#
-#     def __init__(self, file, tokenizer, max_length, template):
-#         self.tokenizer = tokenizer
-#         self.template_name = template.template_name
-#         self.system_format = template.system_format
-#         self.user_format = template.user_format
-#         self.assistant_format = template.assistant_format
-#         self.system = template.system
-#
-#         self.max_length = max_length
-#
-#         logger.info(f'Loading data: {file}')
-#         with open(file, 'r', encoding='utf8') as f:
-#             data_list = f.readlines()
-#         logger.info(f'Use template "{self.template_name}" for training')
-#         logger.info(f"There are {len(data_list)} data in dataset")
-#         self.data_list = data_list
-#
-#     def __len__(self):
-#         return len(self.data_list)
-#
-#     def __getitem__(self, item):
-#         # Start assembling each data sample
-#         data = self.data_list[item]
-#         data = json.loads(data)
-#         input_ids, target_mask = [], []
-#
-#         if self.system_format is not None:
-#             system = self.system
-#             if system is not None:
-#                 system_text = self.system_format.format(content=system)
-#                 input_ids = self.tokenizer.encode(system_text, add_special_tokens=False)
-#                 target_mask = [0] * len(input_ids)
-#         instruction = data['instruction']
-#         output = data['output']
-#
-#         instruction_text = self.user_format.format(content=instruction, stop_token=self.tokenizer.eos_token)
-#         output_text = self.assistant_format.format(content=output, stop_token=self.tokenizer.eos_token)
-#
-#         input_tokens = self.tokenizer.encode(instruction_text, add_special_tokens=False)
-#         output_tokens = self.tokenizer.encode(output_text, add_special_tokens=False)
-#
-#         input_ids += input_tokens + output_tokens
-#         target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)
-#
-#         # Check that input and mask lengths are equal
-#         assert len(input_ids) == len(target_mask)
-#
-#         # Truncate to max length
-#         input_ids = input_ids[:self.max_length]
-#         target_mask = target_mask[:self.max_length]
-#         attention_mask = [1] * len(input_ids)
-#         assert len(input_ids) == len(target_mask) == len(attention_mask)
-#         inputs = {
-#             "input_ids": input_ids,
-#             "attention_mask": attention_mask,
-#             "target_mask": target_mask
-#         }
-#
-#         return inputs
-#
-
 class MultiRoundDataProcess(Dataset):
     def __init__(self, file, tokenizer, max_length):
         self.tokenizer = tokenizer
@@ -90,20 +24,23 @@ def __getitem__(self, item):
 
         if data['message'][0]['role'] == 'system':
             system_text = [data['message'][0]]
-            system_text = self.tokenizer.apply_chat_template(system_text, tokenize=False, add_generation_prompt=False, return_tensors="pt")
+            system_text = self.tokenizer.apply_chat_template(system_text, tokenize=False, add_generation_prompt=False,
+                                                             return_tensors="pt")
             input_ids = self.tokenizer.encode(system_text, add_special_tokens=False)
             target_mask = [0] * len(input_ids)
 
         message = data['message']
         # Concatenate the multi-turn dialogue
         for i, conv in enumerate(message):
             if conv['role'] == 'user':
-                user_text = self.tokenizer.apply_chat_template([conv], tokenize=False, add_generation_prompt=False, return_tensors="pt")
+                user_text = self.tokenizer.apply_chat_template([conv], tokenize=False, add_generation_prompt=False,
+                                                               return_tensors="pt")
                 input_tokens = self.tokenizer.encode(user_text, add_special_tokens=False)
                 input_ids += input_tokens
                 target_mask += [0] * len(input_tokens)
             elif conv['role'] == 'assistant':
-                assistant_text = self.tokenizer.apply_chat_template([conv], tokenize=False, add_generation_prompt=False, return_tensors="pt")
+                assistant_text = self.tokenizer.apply_chat_template([conv], tokenize=False, add_generation_prompt=False,
+                                                                    return_tensors="pt")
                 output_tokens = self.tokenizer.encode(assistant_text, add_special_tokens=False)
                 input_ids += output_tokens
                 target_mask += [1] * len(output_tokens)
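For context, a minimal sketch (not from the repo) of how a `target_mask` like the one built above is typically consumed downstream: labels are set to -100 wherever the mask is 0, so only assistant tokens contribute to the cross-entropy loss:

```python
import torch

# Toy ids; in practice these come from MultiRoundDataProcess.__getitem__.
input_ids = torch.tensor([101, 202, 303, 404, 505])
target_mask = torch.tensor([0, 0, 1, 1, 1])  # 1 only on assistant tokens

# -100 is the default ignore_index of PyTorch's cross-entropy loss, so
# system/user positions are skipped when computing the LM loss.
labels = input_ids.clone()
labels[target_mask == 0] = -100
print(labels)  # tensor([-100, -100,  303,  404,  505])
```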
