-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix] handle prompt length for multi-GPU #87
Conversation
[dev to main] v1.3.1
[dev to main] v1.3.1 hot fixes
[dev to main] v1.3.2
…ngth_for_multi_gpus
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
全体的にLGTMです!importの件だけ,書き方をご検討いただければ
@@ -7,7 +7,7 @@ | |||
import torch | |||
from accelerate.utils import find_executable_batch_size | |||
from loguru import logger | |||
from sentence_transformers import SentenceTransformer | |||
from sentence_transformers import SentenceTransformer, models |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: models
の中のclassは一個しか使われなかったので,from sentence_transformers.models import Pooling
と書くのがどうでしょう(self.model
と混同するのを防ぐため)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
確かにそちらの方が良さそうですね、ありがとうございます、変更します!(modelsはちょっと一般的すぎるモジュール名で嫌ですよね)
関連する Issue / PR
埋め込み作成時にpromptを利用するモデルについて、DPSentenceTransformerを用いてmulti GPUで
encode
関数を用いるとエラーが発生する。PR をマージした後の挙動の変化
encode
関数では内部でprompt_length
と呼ばれる値を作成している。埋め込み作成時にこの値がdict型の
feature
変数に格納される。この時、prompt_length
はint型である。encode
関数はDPでの推論時にfeature
の中身をgather
するが、このgather
はtorch.Tensor
型以外を受け取るとエラーになる。現状は
prompt_length
がfeature
変数を介してgather
関数に渡っているため、promptとともにencode
関数を使用するとエラーが発生する。この問題を解決する。
挙動の変更を達成するために行ったこと
対策として、
feature
中にprompt_length
がkeyとして含まれる場合、その値を入力文数と同じ数・同じ値を持つtorch.Tensor
型に変換する。また、そもそも
prompt_length
はinclude_prompt=False
なモデルのためのパラメータであるが、include_prompt=True
の場合はprompt_length
自体が不要なので、その場合は事前にkeyを消しておく。動作確認として、埋め込み作成時にprefixを用いる例をテストに追加した。
動作確認