Skip to content

Commit

Permalink
Fix bugs: switch download target to Qwen1.5-0.5B and rewrite RlhfDataset to parse/tokenize JSONL lazily per item
Browse files Browse the repository at this point in the history
  • Loading branch information
mst272 committed Aug 20, 2024
1 parent 2be5189 commit ddf540a
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 25 deletions.
2 changes: 1 addition & 1 deletion download_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from modelscope import snapshot_download

# Download the Qwen1.5-0.5B checkpoint into ../download_llm.
# NOTE: a superseded call for 'qwen/Qwen-1_8B-Chat' was removed — keeping
# both would trigger two full downloads and discard the first result.
model_dir = snapshot_download('qwen/Qwen1.5-0.5B', cache_dir='../download_llm')
print("模型下载完成")
46 changes: 25 additions & 21 deletions llm_tricks/DPO_example/dataset.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,38 @@
import torch
from torch.utils.data import Dataset
import json


class RlhfDataset(Dataset):
    """Preference-pair dataset for DPO-style training.

    Reads a JSON-lines file in which every line is a dict with keys
    ``prompt``, ``chosen`` and ``rejected``. Lines are kept raw and each
    sample is parsed and tokenized lazily in ``__getitem__``, so the whole
    corpus is never tokenized up front.
    """

    def __init__(self, file_path, tokenizer):
        """Load the raw JSONL lines and keep the tokenizer for later use.

        :param file_path: path to a UTF-8 JSON-lines preference file.
        :param tokenizer: object exposing ``encode(text, add_special_tokens=...)``
            (e.g. a HuggingFace tokenizer) — assumed, confirm against caller.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            self.data_list = file.readlines()
        self.tokenizer = tokenizer

    def __getitem__(self, item):
        """Tokenize one preference pair.

        Returns a dict with token-id lists:
        ``prompt`` — the bare prompt; ``chosen``/``rejected`` — the prompt
        joined with each response via the ``### Response:`` template, so
        both share an identical prompt prefix.
        """
        data = json.loads(self.data_list[item])
        prompt = data['prompt']
        chosen = data['chosen']
        rejected = data['rejected']

        chosen_full_text = f"{prompt}\n\n### Response:\n{chosen}"
        rejected_full_text = f"{prompt}\n\n### Response:\n{rejected}"

        # add_special_tokens=False keeps the raw template tokens so the
        # prompt prefix stays aligned between the chosen/rejected encodings.
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        chosen_full_tokens = self.tokenizer.encode(chosen_full_text, add_special_tokens=False)
        rejected_full_tokens = self.tokenizer.encode(rejected_full_text, add_special_tokens=False)

        # Renamed from `input` to avoid shadowing the builtin.
        encoded = {
            "prompt": prompt_tokens,
            "chosen": chosen_full_tokens,
            "rejected": rejected_full_tokens,
        }
        return encoded

    def __len__(self):
        return len(self.data_list)


def data_collate(
Expand Down
Loading

0 comments on commit ddf540a

Please sign in to comment.