Skip to content

Commit

Permalink
Merge pull request #1190 from pipecat-ai/mb/languages-hosted-whisper
Browse files Browse the repository at this point in the history
Add language support to OpenAI and Groq hosted Whisper
  • Loading branch information
markbackman authored Feb 10, 2025
2 parents 2dc585a + cd52d73 commit 0d2e90c
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 5 deletions.
83 changes: 83 additions & 0 deletions src/pipecat/services/base_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
from pipecat.services.ai_services import SegmentedSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601

try:
Expand All @@ -23,6 +24,82 @@
raise Exception(f"Missing module: {e}")


def language_to_whisper_language(language: Language) -> Optional[str]:
"""Language support for Whisper API.
Docs: https://platform.openai.com/docs/guides/speech-to-text#supported-languages
"""
BASE_LANGUAGES = {
Language.AF: "af",
Language.AR: "ar",
Language.HY: "hy",
Language.AZ: "az",
Language.BE: "be",
Language.BS: "bs",
Language.BG: "bg",
Language.CA: "ca",
Language.ZH: "zh",
Language.HR: "hr",
Language.CS: "cs",
Language.DA: "da",
Language.NL: "nl",
Language.EN: "en",
Language.ET: "et",
Language.FI: "fi",
Language.FR: "fr",
Language.GL: "gl",
Language.DE: "de",
Language.EL: "el",
Language.HE: "he",
Language.HI: "hi",
Language.HU: "hu",
Language.IS: "is",
Language.ID: "id",
Language.IT: "it",
Language.JA: "ja",
Language.KN: "kn",
Language.KK: "kk",
Language.KO: "ko",
Language.LV: "lv",
Language.LT: "lt",
Language.MK: "mk",
Language.MS: "ms",
Language.MR: "mr",
Language.MI: "mi",
Language.NE: "ne",
Language.NO: "no",
Language.FA: "fa",
Language.PL: "pl",
Language.PT: "pt",
Language.RO: "ro",
Language.RU: "ru",
Language.SR: "sr",
Language.SK: "sk",
Language.SL: "sl",
Language.ES: "es",
Language.SW: "sw",
Language.SV: "sv",
Language.TL: "tl",
Language.TA: "ta",
Language.TH: "th",
Language.TR: "tr",
Language.UK: "uk",
Language.UR: "ur",
Language.VI: "vi",
Language.CY: "cy",
}

result = BASE_LANGUAGES.get(language)

# If not found in base languages, try to find the base language from a variant
if not result:
lang_str = str(language.value)
base_code = lang_str.split("-")[0].lower()
result = base_code if base_code in BASE_LANGUAGES.values() else None

return result


class BaseWhisperSTTService(SegmentedSTTService):
"""Base class for Whisper-based speech-to-text services.
Expand All @@ -33,6 +110,7 @@ class BaseWhisperSTTService(SegmentedSTTService):
model: Name of the Whisper model to use.
api_key: Service API key. Defaults to None.
base_url: Service API base URL. Defaults to None.
language: Language of the audio input. Defaults to English.
**kwargs: Additional arguments passed to SegmentedSTTService.
"""

Expand All @@ -42,11 +120,13 @@ def __init__(
model: str,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
language: Optional[Language] = Language.EN,
**kwargs,
):
super().__init__(**kwargs)
self.set_model_name(model)
self._client = self._create_client(api_key, base_url)
self._language = self.language_to_service_language(language or Language.EN)

def _create_client(self, api_key: Optional[str], base_url: Optional[str]):
return AsyncOpenAI(api_key=api_key, base_url=base_url)
Expand All @@ -57,6 +137,9 @@ async def set_model(self, model: str):
def can_generate_metrics(self) -> bool:
return True

def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_whisper_language(language)

async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
try:
await self.start_processing_metrics()
Expand Down
14 changes: 11 additions & 3 deletions src/pipecat/services/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
from pipecat.services.openai import OpenAILLMService
from pipecat.transcriptions.language import Language


class GroqLLMService(OpenAILLMService):
Expand Down Expand Up @@ -52,8 +53,8 @@ class GroqSTTService(BaseWhisperSTTService):
model: Whisper model to use. Defaults to "whisper-large-v3-turbo".
api_key: Groq API key. Defaults to None.
base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
language: Language of the audio input. Defaults to English.
**kwargs: Additional arguments passed to BaseWhisperSTTService.
"""

def __init__(
Expand All @@ -62,11 +63,18 @@ def __init__(
model: str = "whisper-large-v3-turbo",
api_key: Optional[str] = None,
base_url: str = "https://api.groq.com/openai/v1",
language: Optional[Language] = Language.EN,
**kwargs,
):
super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
super().__init__(
model=model, api_key=api_key, base_url=base_url, language=language, **kwargs
)

async def _transcribe(self, audio: bytes) -> Transcription:
assert self._language is not None # Assigned in the BaseWhisperSTTService class
return await self._client.audio.transcriptions.create(
file=("audio.wav", audio, "audio/wav"), model=self.model_name, response_format="json"
file=("audio.wav", audio, "audio/wav"),
model=self.model_name,
response_format="json",
language=self._language,
)
10 changes: 8 additions & 2 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
TTSService,
)
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601

try:
Expand Down Expand Up @@ -406,6 +407,7 @@ class OpenAISTTService(BaseWhisperSTTService):
model: Whisper model to use. Defaults to "whisper-1".
api_key: OpenAI API key. Defaults to None.
base_url: API base URL. Defaults to None.
language: Language of the audio input. Defaults to English.
**kwargs: Additional arguments passed to BaseWhisperSTTService.
"""

Expand All @@ -415,13 +417,17 @@ def __init__(
model: str = "whisper-1",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
language: Optional[Language] = Language.EN,
**kwargs,
):
super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
super().__init__(
model=model, api_key=api_key, base_url=base_url, language=language, **kwargs
)

async def _transcribe(self, audio: bytes) -> Transcription:
assert self._language is not None # Assigned in the BaseWhisperSTTService class
return await self._client.audio.transcriptions.create(
file=("audio.wav", audio, "audio/wav"), model=self.model_name
file=("audio.wav", audio, "audio/wav"), model=self.model_name, language=self._language
)


Expand Down
6 changes: 6 additions & 0 deletions src/pipecat/transcriptions/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class Language(StrEnum):
AZ = "az"
AZ_AZ = "az-AZ"

# Belarusian
BE = "be"

# Bulgarian
BG = "bg"
BG_BG = "bg-BG"
Expand Down Expand Up @@ -264,6 +267,9 @@ class Language(StrEnum):
MN = "mn"
MN_MN = "mn-MN"

# Maori
MI = "mi"

# Marathi
MR = "mr"
MR_IN = "mr-IN"
Expand Down

0 comments on commit 0d2e90c

Please sign in to comment.