Commit 29d9af0 1 parent a8a742b commit 29d9af0 Copy full SHA for 29d9af0
File tree 2 files changed +4
-4
lines changed
2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -31,9 +31,9 @@ def _strip_prefix(s, pattern):
31
31
rejected_messages = example ["generated" ][1 :]
32
32
example ["text_chosen" ] = tokenizer .apply_chat_template (chosen_messages , tokenize = False )
33
33
example ["text_rejected" ] = tokenizer .apply_chat_template (rejected_messages , tokenize = False )
34
- # example["text_prompt"] = tokenizer.apply_chat_template(
35
- # prompt_messages, tokenize=False, add_generation_prompt=True
36
- # )
34
+ example ["text_prompt" ] = tokenizer .apply_chat_template (
35
+ prompt_messages , tokenize = False , add_generation_prompt = True
36
+ )
37
37
example ["text_chosen" ] = _strip_prefix (example ["text_chosen" ], assistant_prefix )
38
38
example ["text_rejected" ] = _strip_prefix (example ["text_rejected" ], assistant_prefix )
39
39
else :
Original file line number Diff line number Diff line change @@ -89,7 +89,7 @@ def main():
89
89
# Replace column names with what TRL needs: text_prompt -> prompt, text_real -> real and text_generated -> generated
90
90
for split in ["train" , "test" ]:
91
91
raw_datasets [split ] = raw_datasets [split ].rename_columns (
92
- {"text_real" : "real" , "text_generated" : "generated" }
92
+ {"text_prompt" : "prompt" , " text_real" : "real" , "text_generated" : "generated" }
93
93
)
94
94
95
95
torch_dtype = (
You can’t perform that action at this time.
0 commit comments