Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] handle prompt length for multi-GPU #87

Merged
merged 10 commits into from
Dec 9, 2024

Conversation

hppRC
Copy link
Collaborator

@hppRC hppRC commented Dec 9, 2024

関連する Issue / PR

埋め込み作成時にpromptを利用するモデルについて、DPSentenceTransformerを用いてmulti GPUでencode関数を用いるとエラーが発生する。

PR をマージした後の挙動の変化

encode関数では内部でprompt_lengthと呼ばれる値を作成している。
埋め込み作成時にこの値がdict型のfeature変数に格納される。この時、prompt_lengthはint型である。
encode関数はDPでの推論時にfeatureの中身をgatherするが、このgathertorch.Tensor型以外を受け取るとエラーになる。
現状はprompt_lengthfeature変数を介してgather関数に渡っているため、promptとともにencode関数を使用するとエラーが発生する。
この問題を解決する。

挙動の変更を達成するために行ったこと

対策として、feature中にprompt_lengthがkeyとして含まれる場合、その値を入力文数と同じ数・同じ値を持つtorch.Tensor型に変換する。

また、そもそもprompt_lengthinclude_prompt=Falseなモデルのためのパラメータであるが、include_prompt=Trueの場合はprompt_length自体が不要なので、その場合は事前にkeyを消しておく。

動作確認として、埋め込み作成時にprefixを用いる例をテストに追加した。

動作確認

  • テストが通ることを確認した
  • マージ先がdevブランチであることを確認した

@hppRC hppRC requested a review from akiFQC December 9, 2024 08:41
@hppRC hppRC requested a review from lsz05 December 9, 2024 08:59
Copy link
Collaborator

@akiFQC akiFQC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@lsz05 lsz05 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

全体的にLGTMです!importの件だけ,書き方をご検討いただければ

@@ -7,7 +7,7 @@
import torch
from accelerate.utils import find_executable_batch_size
from loguru import logger
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, models
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: modelsの中のclassは一個しか使われなかったので,from sentence_transformers.models import Poolingと書くのがどうでしょう(self.modelと混同するのを防ぐため)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

確かにそちらの方が良さそうですね、ありがとうございます、変更します!(modelsはちょっと一般的すぎるモジュール名で嫌ですよね)

@lsz05 lsz05 changed the title Fix/handle prompt length for multi gpus [Fix] handle prompt length for multi-GPU Dec 9, 2024
@hppRC hppRC merged commit 06dbef6 into dev Dec 9, 2024
3 checks passed
@lsz05 lsz05 mentioned this pull request Dec 11, 2024
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants