forked from mst272/LLM-Dojo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/mst272/LLM-Dojo
- Loading branch information
Showing
10 changed files
with
406 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class RlhfDataset(Dataset): | ||
def __init__(self, data, tokenizer): | ||
self.data = data | ||
|
||
self.encoded_texts = [] | ||
for entry in data: | ||
prompt = entry["prompt"] | ||
rejected_response = entry["rejected"] | ||
chosen_response = entry["chosen"] | ||
|
||
prompt_tokens = tokenizer.encode(prompt) | ||
chosen_full_text = f"{prompt}\n\n### Response:\n{chosen_response}" | ||
rejected_full_text = f"{prompt}\n\n### Response:\n{rejected_response}" | ||
chosen_full_tokens = tokenizer.encode(chosen_full_text) | ||
rejected_full_tokens = tokenizer.encode(rejected_full_text) | ||
|
||
self.encoded_texts.append({ | ||
"prompt": prompt_tokens, | ||
"chosen": chosen_full_tokens, | ||
"rejected": rejected_full_tokens, | ||
}) | ||
|
||
def __getitem__(self, index): | ||
return self.encoded_texts[index] | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
def data_collate( | ||
batch, | ||
pad_token_id=50256, | ||
allowed_max_length=None, | ||
mask_prompt_tokens=True, | ||
device="cpu" | ||
): | ||
# Initialize lists to hold batch data | ||
batch_data = { | ||
"prompt": [], | ||
"chosen": [], | ||
"rejected": [], | ||
"rejected_mask": [], | ||
"chosen_mask": [] | ||
|
||
} | ||
|
||
# Determine the longest sequence to set a common padding length | ||
max_length_common = 0 | ||
if batch: | ||
for key in ["chosen", "rejected"]: | ||
current_max = max(len(item[key]) + 1 for item in batch) | ||
max_length_common = max(max_length_common, current_max) | ||
|
||
# Process each item in the batch | ||
for item in batch: | ||
prompt = torch.tensor(item["prompt"]) | ||
batch_data["prompt"].append(prompt) | ||
|
||
for key in ["chosen", "rejected"]: | ||
# Adjust padding according to the common maximum length | ||
sequence = item[key] | ||
padded = sequence + [pad_token_id] * (max_length_common - len(sequence)) | ||
mask = torch.ones(len(padded)).bool() | ||
|
||
# Set mask for all padding tokens to False | ||
mask[len(sequence):] = False | ||
|
||
# Set mask for all input tokens to False | ||
# +2 sets the 2 newline ("\n") tokens before "### Response" to False | ||
if mask_prompt_tokens: | ||
mask[:prompt.shape[0] + 2] = False | ||
|
||
batch_data[key].append(torch.tensor(padded)) | ||
batch_data[f"{key}_mask"].append(mask) | ||
|
||
# Final processing | ||
for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]: | ||
# Stack all sequences into a tensor for the given key | ||
tensor_stack = torch.stack(batch_data[key]) | ||
|
||
# Optionally truncate to maximum sequence length | ||
if allowed_max_length is not None: | ||
tensor_stack = tensor_stack[:, :allowed_max_length] | ||
|
||
# Move to the specified device | ||
batch_data[key] = tensor_stack.to(device) | ||
|
||
return batch_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.