Skip to content

Commit

Permalink
fix returned dtypes for LLaVA (oobabooga#1547)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wojtab authored Apr 26, 2023
1 parent 9b272bc commit 65beb51
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion extensions/llava/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds):

prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device)
return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))


def ui():
Expand Down

0 comments on commit 65beb51

Please sign in to comment.