Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mst272 committed Aug 18, 2024
1 parent d05ecaa commit 1553604
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions utils/data_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import copy
import json
from typing import List, Dict

from torch.utils.data import Dataset
from loguru import logger

Expand All @@ -16,6 +19,26 @@ def __init__(self, file, tokenizer, max_length):
def __len__(self):
return len(self.data_list)

def __getitem__1(self, item):
data = self.data_list[item]
data = json.loads(data)
input_ids, target_mask = [], []
message = data['message']
text = self.tokenizer.apply_chat_template(message, tokenize=False)
start_position = 0
for i, conv in enumerate(message):
if conv['role'] == 'assistant':
position = text[start_position:].find(conv('content'))
start_position += position
fixed_prompt = text[start_position:position]
fixed_output = text[position:position + len(conv('content'))]
input_tokens = self.tokenizer.encode(fixed_prompt, add_special_tokens=False)
output_tokens = self.tokenizer.encode(fixed_output, add_special_tokens=False)
input_ids += input_tokens
target_mask += [0] * len(input_tokens)
input_ids += output_tokens
target_mask += [1] * len(output_tokens)

def __getitem__(self, item):
# 开始拼接每条数据
data = self.data_list[item]
Expand Down

0 comments on commit 1553604

Please sign in to comment.