feat: Vllm whisper model #3901
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes-sigs/prow repository. |
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing |
            return result.text
        except Exception as err:
            maxkb_logger.error(f":Error: {str(err)}: {traceback.format_exc()}")
This code appears to be an implementation of a speech-to-text service using the VLLM Whisper model through OpenAI's API. Here are some observations and potential areas for improvement:
Observations:

- **Imports:** The necessary imports (`os`, `traceback`, and `typing`) are present but not imported from specific modules.
- **Static methods:** Both `is_cache_model` and `new_instance` are implemented as static methods without providing additional functionality that differs from their parent class.
- **Audio processing:** In the `check_auth` method, the function attempts to read `iat_mp3_16k.mp3` from the current working directory. This is unconventional and might need further justification based on project requirements.
- **Exception handling:** The exception handling in the `speech_to_text` method logs errors using `maxkb_logger.error`. This is appropriate, but the error messages should be detailed enough for debugging purposes (e.g., include more context about what went wrong).
- **API key validation:** Without specific validation logic (e.g., checking that the API key is valid before making requests), there is no guarantee of proper authentication; a hedged sketch follows this list.
- **Logging configuration:** There is little indication of how logging is configured elsewhere in the system; the code assumes the logger instance is properly initialized somewhere else in the application stack.
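To make the API-key observation concrete, here is a minimal, hedged sketch (not part of this PR) of failing fast on bad credentials by probing the endpoint before use; it assumes the vLLM deployment exposes the OpenAI-compatible `GET /v1/models` route and uses `client.models.list()` purely as a cheap probe:

```python
from openai import OpenAI, AuthenticationError, APIConnectionError


def check_credentials(api_url: str, api_key: str) -> bool:
    """Return True if the OpenAI-compatible endpoint accepts the key.

    Sketch only: assumes the deployment serves /v1/models like the
    OpenAI API; adjust if that route is disabled.
    """
    client = OpenAI(api_key=api_key, base_url=f"{api_url}/v1")
    try:
        client.models.list()  # cheap round-trip that exercises authentication
        return True
    except AuthenticationError:
        return False  # key rejected by the server
    except APIConnectionError:
        return False  # endpoint unreachable or wrong URL
```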
Potential Improvements:

- **Import cleanup:** Import statements should preferably use absolute paths where practical to reduce import-cycle dependencies.
- **Detailed logging:** Enhance the logging details in the exception-handling section to provide more useful information, such as request parameters or response content.
- **Validation of API credentials:** Implement checks to validate the API credentials before establishing connections, to ensure security (a probe sketch appears after the observations list above).
- **Better error management:** Consider adding retries and fallback mechanisms for network issues or rate limiting from the OpenAI-compatible API; see the retry sketch after this list.
- **Configuration file:** Move configuration settings (like log level) to external files (e.g., `.env`, Configurations.yml) for easier management and maintenance.
- **Testing:** Ensure comprehensive test coverage, including unit tests for classes and functions and integration tests with mock responses.
- **Documentation:** Provide clear documentation comments at the beginning of each class and method explaining their purpose and usage.
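For the error-management item above, a minimal retry sketch might look like the following; it is illustrative only, relies on the `openai` v1 exception classes (`RateLimitError`, `APIConnectionError`), and `transcribe_once` is a hypothetical zero-argument wrapper around the existing `client.audio.transcriptions.create(...)` call:

```python
import time
from typing import Callable, Optional

from openai import APIConnectionError, RateLimitError


def transcribe_with_retries(transcribe_once: Callable[[], str],
                            max_attempts: int = 3,
                            base_delay_seconds: float = 2.0) -> Optional[str]:
    """Retry transient failures with an increasing delay between attempts.

    `transcribe_once` is a hypothetical zero-argument wrapper around the
    existing client.audio.transcriptions.create(...) call.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return transcribe_once()
        except (RateLimitError, APIConnectionError):
            if attempt == max_attempts:
                raise  # give up after the final attempt
            time.sleep(base_delay_seconds * attempt)  # 2s, 4s, ... between tries
    return None  # unreachable, kept for type-checkers
```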
Here is a refined version of the code addressing some of these points:
# Imports
import logging
import os
import traceback
from typing import Dict, Optional

from openai import OpenAI

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText

# Set up logging configuration
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    api_key: str
    api_url: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')
        self.api_url = kwargs.get('api_url')

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return VllmWhisperSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
            **model_kwargs
        )

    def check_auth(self):
        cwd = os.path.dirname(os.path.abspath(__file__))
        try:
            # Read the bundled sample audio file to verify the credentials work
            with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file:
                self.speech_to_text(audio_file)
        except FileNotFoundError as e:
            logger.error(f"FileNotFoundError: {str(e)}", exc_info=True)

    def speech_to_text(self, audio_file) -> Optional[str]:
        """
        Convert an audio file to text using the vLLM Whisper model.
        :param audio_file: Audio file opened in binary mode
        :return: Transcribed text, or None if transcription failed
        """
        base_url = f"{self.api_url}/v1"
        try:
            client = OpenAI(api_key=self.api_key, base_url=base_url)
            result = client.audio.transcriptions.create(
                file=audio_file,
                model=self.model,
                language=self.params.get('language'),
                response_format="json"
            )
            return result.text
        except Exception as err:
            logger.error(f"An error occurred during transcription: {str(err)}: {traceback.format_exc()}")
            return None  # Return None instead of an empty string
These changes improve the structure, enhance logging, clarify method names, and add basic error handling for the `speech_to_text` method.
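As a usage sketch only (not code from this PR), the class above could be exercised like this; the endpoint URL, API key, model name, and audio file are placeholders, and the `language` keyword mirrors the `params` handling shown above:

```python
# Hypothetical endpoint, key, model name and audio file; replace with real values.
credential = {'api_key': 'sk-anything', 'api_url': 'http://localhost:8000'}

stt = VllmWhisperSpeechToText.new_instance(
    model_type='STT',                      # placeholder type label
    model_name='openai/whisper-large-v3',  # placeholder model name
    model_credential=credential,
    language='zh',                         # forwarded into self.params by new_instance
)

with open('sample.mp3', 'rb') as audio_file:
    print(stt.speech_to_text(audio_file))
```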
        return self

    def get_model_params_setting_form(self, model_name):
        return VLLMWhisperModelParams()
No irregularities found. The code looks correctly structured for Django-based form handling with model credential validation. Here are some optimizations you might consider:

- **Use `gettext` directly:** Since both `_('Language')` and `gettext('{model_type} Model type is not supported').format(model_type=model_type)` use `gettext`, they can be consolidated into a single call.
- **Remove unnecessary imports:** `langchain_core.messages.HumanMessage` is imported but never referenced within this class, so it is safe to remove from the imports list.
- **Encapsulate logic:** Some of the exception handling and message formatting could be moved into helper functions rather than repeated across lines.
- **Consider using context managers:** If you anticipate making multiple network requests or database interactions, async context managers (Python 3.7+) would help manage those operations more cleanly.
Here's an updated version with these considerations:
# coding=utf-8
from typing import Dict

from django.utils.translation import gettext, gettext_lazy as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


def create_message(msg):
    """Return the (possibly lazy) translated label; kept as a single place to
    adjust message formatting later."""
    return msg


class VLLMWhisperModelParams(BaseForm):
    Language = forms.TextInputField(
        TooltipLabel(create_message(_('Language')),
                     _("If not passed, the default value is 'zh'")),
        required=True,
        default_value='zh',
    )


class VLLMWhisperModelCredential(BaseForm, BaseModelCredential):
    api_url = forms.TextInputField('API URL', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    @staticmethod
    def _create_error_msg(code: int, msg: str) -> str:
        """Build a uniform 'code: message' error string."""
        return f'{code}: {msg}'

    def is_valid(self,
                 model_type: str,
                 model_name,
                 model_credential: Dict[str, object],
                 model_params,
                 provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(filter(lambda mt: mt['value'] == model_type, model_type_list)):
            err_code = ValidCode.valid_error.value
            err_msg = self._create_error_msg(
                err_code,
                gettext("'{model_type}' Model type is not supported").format(model_type=model_type))
            if raise_exception:
                raise AppApiException(err_code, err_msg)
            raise ValueError(err_msg)

        try:
            model_list = provider.get_base_model_list(model_credential['api_url'], model_credential['api_key'])
        except Exception:
            err_code = ValidCode.valid_error.value
            err_msg = self._create_error_msg(err_code, gettext('API domain name is invalid'))
            if raise_exception:
                raise AppApiException(err_code, err_msg)
            raise ValueError(err_msg)

        exist = provider.get_model_info_by_name(model_list, model_name)
        if len(exist) == 0:
            err_code = ValidCode.valid_error.value
            err_msg = self._create_error_msg(
                err_code, gettext('The model does not exist, please download the model first'))
            if raise_exception:
                raise AppApiException(err_code, err_msg)
            raise ValueError(err_msg)

        provider.get_model(model_type, model_name, model_credential, **model_params)
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        for key in ['api_key', 'model']:
            if key not in model_info:
                err_code = 500
                err_msg = self._create_error_msg(err_code, gettext('{key} is required').format(key=key))
                raise AppApiException(err_code, err_msg)
        self.api_key = model_info.get('api_key')
        return self

    def get_model_params_setting_form(self, model_name):
        return VLLMWhisperModelParams()
This version introduces helpers like `_create_error_msg` for better readability and consolidates duplicated message creation patterns.
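For illustration only, the credential form above would typically be driven like this; `vllm_model_provider` stands in for whatever provider object the surrounding module supplies, and the endpoint values and model name are placeholders:

```python
# Hypothetical endpoint values and provider object, for illustration only.
credential_form = VLLMWhisperModelCredential()

model_credential = {
    'api_url': 'http://localhost:8000',
    'api_key': 'sk-anything',
}

try:
    credential_form.is_valid(
        model_type='STT',                      # placeholder type label
        model_name='openai/whisper-large-v3',  # placeholder model name
        model_credential=model_credential,
        model_params={'language': 'zh'},
        provider=vllm_model_provider,          # assumed provider instance from the surrounding module
        raise_exception=True,
    )
    # build_model stores the api_key on the form instance and returns self
    credential_form.build_model({'api_key': model_credential['api_key'],
                                 'model': 'openai/whisper-large-v3'})
except Exception as err:
    print(f'credential check failed: {err}')
```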
@@ -45,6 +60,8 @@
    .append_default_model_info(image_model_info_list[0])
    .append_model_info_list(embedding_model_info_list)
    .append_default_model_info(embedding_model_info_list[0])
    .append_model_info_list(whisper_model_info_list)
    .append_default_model_info(whisper_model_info_list[0])
    .build()
)
This code looks mostly correct and should work without significant issues. However, there are a few areas that could be improved:
- **Whitespace consistency:** Ensure consistent indentation throughout the file to improve readability.
- **Redundant code:** `image_model_info_list` is used twice in the configuration process. It might be better to combine these lists if they contain the same information.
- **Default model info duplication:** If the intention is to have all default models included at some point, ensure this logic works correctly.
Here's an updated version with some optimizations:
from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
from models_provider.impl.vllm_model_provider.model.image import VllmImage
from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
# Assumed path for the new speech-to-text model class; adjust to the actual module in this PR.
from models_provider.impl.vllm_model_provider.model.whisper_stt import VllmWhisperSpeechToText
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _

v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()
embedding_model_credential = VllmEmbeddingCredential()
whisper_model_credential = VLLMWhisperModelCredential()

# Pair each model class with its credential for easier iteration.
model_info_lists = [
    (VllmChatModel, v_llm_model_credential),
    (VllmEmbeddingModel, embedding_model_credential),
    (VllmWhisperSpeechToText, whisper_model_credential),
]

# image_model_info_list, embedding_model_info_list and whisper_model_info_list
# are assumed to be built earlier in this module, as in the original file.
all_models = (
    image_model_info_list +
    embedding_model_info_list +
    whisper_model_info_list
)

# Assuming append_default_model_info handles adding the first element of each list if needed
config_management.append_model_info_list(all_models).build()
Changes Made:

- **Consistent indentation:** Ensured all lines within functions, loops, etc. are properly indented.
- **Combined lists:** Combined the separate model info lists into a single `all_models` list.
- **Improved naming:** Used a list of tuples, `model_info_lists`, for consistency with the final combined list.
These changes aim to make the code cleaner and potentially more efficient by avoiding redundancy in the model info lists.
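A note on the alternative above: staying closer to the builder chain shown in the diff may be less disruptive. The following hedged sketch keeps that chain and only adds the whisper entries; `ModelInfoManage` and the `*_model_info_list` variables are assumed to exist in the provider module as the diff implies:

```python
# Sketch only: ModelInfoManage and the *_model_info_list variables are assumed
# to be defined/imported earlier in the provider module, as the diff implies.
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(image_model_info_list)
    .append_default_model_info(image_model_info_list[0])
    .append_model_info_list(embedding_model_info_list)
    .append_default_model_info(embedding_model_info_list[0])
    # Added for this PR: register the whisper speech-to-text entries and use
    # the first one as the default STT model.
    .append_model_info_list(whisper_model_info_list)
    .append_default_model_info(whisper_model_info_list[0])
    .build()
)
```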