Commit
Optimize the faster_whisper model loading logic, greatly improving processing speed
Ikaros-521 committed Apr 13, 2024
1 parent 2ccfc4a commit 5f0e5db
Showing 2 changed files with 21 additions and 11 deletions.
30 changes: 20 additions & 10 deletions main.py
@@ -54,7 +54,7 @@
# Ignition and liftoff
def start_server():
global config, common, my_handle, last_username_list, config_path, last_liveroom_data
global do_listen_and_comment_thread, stop_do_listen_and_comment_thread_event
global do_listen_and_comment_thread, stop_do_listen_and_comment_thread_event, faster_whisper_model


# 按键监听相关
@@ -274,13 +274,27 @@ def audio_listen(volume_threshold=800.0, silence_threshold=15):

# Perform recording, recognition & submission
def do_listen_and_comment(status=True):
global stop_do_listen_and_comment_thread_event
global stop_do_listen_and_comment_thread_event, faster_whisper_model

config = Config(config_path)

# Whether key listening is enabled; if not, there is no need to run this
if False == config.get("talk", "key_listener_enable"):
return

# For the faster_whisper case, load the model once and share it, to reduce overhead
if "faster_whisper" == config.get("talk", "type"):
from faster_whisper import WhisperModel

if faster_whisper_model is None:
logging.info("faster_whisper model loading, please wait...")
# Device and compute type are taken from the config (e.g. GPU with FP16)
faster_whisper_model = WhisperModel(model_size_or_path=config.get("talk", "faster_whisper", "model_size"), \
device=config.get("talk", "faster_whisper", "device"), \
compute_type=config.get("talk", "faster_whisper", "compute_type"), \
download_root=config.get("talk", "faster_whisper", "download_root"))
logging.info("faster_whisper model loaded, you can start speaking now, meow~")


while True:
try:
@@ -374,8 +388,6 @@ def do_listen_and_comment(status=True):
except sr.RequestError as e:
logging.error("Request error: " + str(e))
elif "faster_whisper" == config.get("talk", "type"):
from faster_whisper import WhisperModel

# Set audio parameters
FORMAT = pyaudio.paInt16
CHANNELS = config.get("talk", "CHANNELS")
@@ -399,13 +411,9 @@
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))

# Run on GPU with FP16
model = WhisperModel(model_size_or_path=config.get("talk", "faster_whisper", "model_size"), \
device=config.get("talk", "faster_whisper", "device"), \
compute_type=config.get("talk", "faster_whisper", "compute_type"), \
download_root=config.get("talk", "faster_whisper", "download_root"))
logging.debug("faster_whisper model loading...")

segments, info = model.transcribe(WAVE_OUTPUT_FILENAME, beam_size=config.get("talk", "faster_whisper", "beam_size"))
segments, info = faster_whisper_model.transcribe(WAVE_OUTPUT_FILENAME, beam_size=config.get("talk", "faster_whisper", "beam_size"))

logging.debug("Detected language: '%s', probability: %f" % (info.language, info.language_probability))

@@ -2566,6 +2574,8 @@ def exit_handler(signum, frame):
# Key listener related
do_listen_and_comment_thread = None
stop_do_listen_and_comment_thread_event = None
# Store the loaded model object
faster_whisper_model = None

# Special signal handling
signal.signal(signal.SIGINT, exit_handler)
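For readers who want to reuse the pattern outside this repository, below is a minimal, self-contained sketch of the lazy-initialization approach this commit introduces: the WhisperModel is constructed once, cached in a module-level variable, and reused for every later transcription. The model size, device, compute type and download root are hard-coded placeholder values standing in for the project's config.get("talk", "faster_whisper", ...) lookups.

import logging

from faster_whisper import WhisperModel

logging.basicConfig(level=logging.INFO)

# Module-level cache, mirroring the faster_whisper_model global in main.py
_faster_whisper_model = None


def get_faster_whisper_model(model_size="base", device="cpu",
                             compute_type="int8", download_root="./models"):
    """Return a shared WhisperModel, loading it only on the first call."""
    global _faster_whisper_model
    if _faster_whisper_model is None:
        logging.info("Loading faster_whisper model, please wait...")
        _faster_whisper_model = WhisperModel(model_size_or_path=model_size,
                                             device=device,
                                             compute_type=compute_type,
                                             download_root=download_root)
        logging.info("faster_whisper model loaded.")
    return _faster_whisper_model


def transcribe_wav(wav_path, beam_size=5):
    """Transcribe a WAV file with the shared model (no reload per request)."""
    model = get_faster_whisper_model()
    segments, info = model.transcribe(wav_path, beam_size=beam_size)
    logging.debug("Detected language '%s' with probability %f",
                  info.language, info.language_probability)
    return "".join(segment.text for segment in segments)

Loading the model once this way removes the per-utterance construction cost (weight loading, device allocation), which is where the speedup described in the commit message comes from.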
2 changes: 1 addition & 1 deletion utils/audio.py
@@ -436,7 +436,7 @@ def get_priority_level(audio_json):
insert_position = i + 1
break

logging.info(f"insert_position={insert_position}")
logging.debug(f"insert_position={insert_position}")

# Check for an over-long data queue: if the insert position index exceeds the maximum, the priority is lower than the data already in the queue, so discard the data
if insert_position >= int(self.config.get("filter", "message_queue_max_len")):
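For context on this hunk, the surrounding utils/audio.py logic orders queued messages by priority and discards a new message when its computed insert position falls beyond the configured maximum queue length; the change itself only lowers the insert_position log line from info to debug. A small illustrative sketch of that check follows; the function and parameter names are hypothetical and only mirror the comments above, not the project's actual attributes.

import logging

def insert_by_priority(queue, item, get_priority, max_len):
    """Insert item into a priority-ordered list; drop it if it would land
    past max_len, i.e. it is lower priority than everything already kept."""
    insert_position = len(queue)
    for i, existing in enumerate(queue):
        if get_priority(item) > get_priority(existing):
            insert_position = i
            break
    logging.debug(f"insert_position={insert_position}")
    if insert_position >= max_len:
        return False  # lower priority than the whole (full) queue: discard
    queue.insert(insert_position, item)
    del queue[max_len:]  # keep the queue bounded at max_len
    return True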
