Skip to content

Commit

Permalink
fix add_generation_prompt
Browse files — browse the repository at this point in the history
  • Loading branch information
zRzRzRzRzRzRzR committed Dec 9, 2024
1 parent 821910f commit 8268d5f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion finetune/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def process_batch(
batched_input_ids = []
batched_labels = []
for conv in batched_conv:
new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False, add_generation_prompt=False)
input_ids = new_input_ids
loss_masks = [False] * len(input_ids)
last_assistant_index = len(input_ids) - input_ids[::-1].index(59254) - 1 # <|assistant|>
Expand Down
4 changes: 2 additions & 2 deletions finetune/finetune_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def process_batch(
loss_mask_val = False if message["role"] in ("system", "user") else True
new_input_ids_all = tokenizer.apply_chat_template(
[message],
add_generation_prompt=True,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
Expand All @@ -288,7 +288,7 @@ def process_batch(
input_ids.append(59253) # EOS
attention_mask.append(1)
position_ids.append(len(position_ids))
loss_masks.append(False)
loss_masks.append(True)

padding_length = max(0, max_length - len(input_ids))

Expand Down

0 comments on commit 8268d5f

Please sign in to comment.