Skip to content

Commit 29d9af0

Browse files
committed
update for data format
1 parent a8a742b commit 29d9af0

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

spin/alignment/data.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,9 @@ def _strip_prefix(s, pattern):
3131
rejected_messages = example["generated"][1:]
3232
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
3333
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
34-
# example["text_prompt"] = tokenizer.apply_chat_template(
35-
# prompt_messages, tokenize=False, add_generation_prompt=True
36-
# )
34+
example["text_prompt"] = tokenizer.apply_chat_template(
35+
prompt_messages, tokenize=False, add_generation_prompt=True
36+
)
3737
example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
3838
example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
3939
else:

spin/run_spin.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def main():
8989
# Replace column names with what TRL needs: text_prompt -> prompt, text_real -> real and text_generated -> generated
9090
for split in ["train", "test"]:
9191
raw_datasets[split] = raw_datasets[split].rename_columns(
92-
{"text_real": "real", "text_generated": "generated"}
92+
{"text_prompt": "prompt", "text_real": "real", "text_generated": "generated"}
9393
)
9494

9595
torch_dtype = (

0 commit comments

Comments
 (0)