Skip to content

Commit

Permalink
fix sft bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mst272 committed Aug 19, 2024
1 parent 1553604 commit de8378d
Showing 1 changed file with 30 additions and 43 deletions.
73 changes: 30 additions & 43 deletions utils/data_process.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import copy
import json
from typing import List, Dict

from torch.utils.data import Dataset
from loguru import logger

Expand All @@ -19,54 +16,24 @@ def __init__(self, file, tokenizer, max_length):
def __len__(self):
return len(self.data_list)

def __getitem__1(self, item):
def __getitem__(self, item):
# 开始自动判断并适配chat template
data = self.data_list[item]
data = json.loads(data)
input_ids, target_mask = [], []
message = data['message']

text = self.tokenizer.apply_chat_template(message, tokenize=False)
input_ids += self.tokenizer.encode(text, add_special_tokens=False)
target_mask += [0] * len(input_ids)
start_position = 0
for i, conv in enumerate(message):
if conv['role'] == 'assistant':
position = text[start_position:].find(conv('content'))
start_position += position
fixed_prompt = text[start_position:position]
fixed_output = text[position:position + len(conv('content'))]
input_tokens = self.tokenizer.encode(fixed_prompt, add_special_tokens=False)
output_tokens = self.tokenizer.encode(fixed_output, add_special_tokens=False)
input_ids += input_tokens
target_mask += [0] * len(input_tokens)
input_ids += output_tokens
target_mask += [1] * len(output_tokens)

def __getitem__(self, item):
# 开始拼接每条数据
data = self.data_list[item]
data = json.loads(data)
input_ids, target_mask = [], []

if data['message'][0]['role'] == 'system':
system_text = [data['message'][0]]
system_text = self.tokenizer.apply_chat_template(system_text, tokenize=False, add_generation_prompt=False,
return_tensors="pt")
input_ids = self.tokenizer.encode(system_text, add_special_tokens=False)
target_mask = [0] * len(input_ids)

message = data['message']
# 拼接多轮对话 todo:优化apply,有小问题,同rlhf问题
for i, conv in enumerate(message):
if conv['role'] == 'user':
user_text = self.tokenizer.apply_chat_template([conv], tokenize=False, add_generation_prompt=False,
return_tensors="pt")
input_tokens = self.tokenizer.encode(user_text, add_special_tokens=False)
input_ids += input_tokens
target_mask += [0] * len(input_tokens)
elif conv['role'] == 'assistant':
assistant_text = self.tokenizer.apply_chat_template([conv], tokenize=False, add_generation_prompt=False,
return_tensors="pt")
output_tokens = self.tokenizer.encode(assistant_text, add_special_tokens=False)
input_ids += output_tokens
target_mask += [1] * len(output_tokens)
assistant_ids = self.tokenizer.encode(conv['content'], add_special_tokens=False)
position = find_sublist_start(input_ids[start_position:], assistant_ids)
assistant_len = len(assistant_ids)
target_mask[start_position + position:start_position + position + assistant_len] = [1] * assistant_len
start_position += position + assistant_len

# 判断一下输入和掩码长度是否相等
assert len(input_ids) == len(target_mask)
Expand Down Expand Up @@ -167,3 +134,23 @@ def __getitem__(self, item):
# 适配DPOTrainer的接口
def map(self, func, **kwargs):
return self


def find_sublist_start(main_list, sub_list):
"""
find_sublist_start
Args:
main_list (list)
sub_list (list)
"""
sub_len = len(sub_list)
main_len = len(main_list)

for i in range(main_len - sub_len + 1):
if main_list[i:i + sub_len] == sub_list:
return i
# 因为会有开头decode不一样的情况出现
elif main_list[i + 1:i + sub_len] == sub_list[1:]:
return i
return -1

0 comments on commit de8378d

Please sign in to comment.