Skip to content

Commit

Permalink
Merge pull request BerriAI#4449 from BerriAI/litellm_azure_tts
Browse files Browse the repository at this point in the history
feat(azure.py): azure tts support
  • Loading branch information
krrishdholakia authored Jun 28, 2024
2 parents efee284 + c14cc35 commit 1223b2b
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 12 deletions.
134 changes: 134 additions & 0 deletions litellm/llms/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AsyncCursorPage,
HttpxBinaryResponseContent,
MessageData,
OpenAICreateThreadParamsMessage,
OpenAIMessage,
Expand Down Expand Up @@ -414,6 +415,49 @@ def validate_environment(self, api_key, azure_ad_token):
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers

def _get_sync_azure_client(
self,
api_version: Optional[str],
api_base: Optional[str],
api_key: Optional[str],
azure_ad_token: Optional[str],
model: str,
max_retries: int,
timeout: Union[float, httpx.Timeout],
client: Optional[Any],
client_type: Literal["sync", "async"],
):
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
elif client_type == "async":
azure_client = AsyncAzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)

return azure_client

def completion(
self,
model: str,
Expand Down Expand Up @@ -1256,6 +1300,96 @@ async def async_audio_transcriptions(
)
raise e

def audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
organization: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
azure_ad_token: Optional[str] = None,
aspeech: Optional[bool] = None,
client=None,
) -> HttpxBinaryResponseContent:

max_retries = optional_params.pop("max_retries", 2)

if aspeech is not None and aspeech is True:
return self.async_audio_speech(
model=model,
input=input,
voice=voice,
optional_params=optional_params,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
max_retries=max_retries,
timeout=timeout,
client=client,
) # type: ignore

azure_client: AzureOpenAI = self._get_sync_azure_client(
api_base=api_base,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
model=model,
max_retries=max_retries,
timeout=timeout,
client=client,
client_type="sync",
) # type: ignore

response = azure_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)
return response

async def async_audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
) -> HttpxBinaryResponseContent:

azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(
api_base=api_base,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
model=model,
max_retries=max_retries,
timeout=timeout,
client=client,
client_type="async",
) # type: ignore

response = await azure_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)

return response

def get_headers(
self,
model: Optional[str],
Expand Down
40 changes: 40 additions & 0 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4410,6 +4410,7 @@ def speech(
voice: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
max_retries: Optional[int] = None,
Expand Down Expand Up @@ -4483,6 +4484,45 @@ def speech(
client=client, # pass AsyncOpenAI, OpenAI client
aspeech=aspeech,
)
elif custom_llm_provider == "azure":
# azure configs
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore

api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
) # type: ignore

api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore

azure_ad_token: Optional[str] = optional_params.get("extra_body", {}).pop( # type: ignore
"azure_ad_token", None
) or get_secret(
"AZURE_AD_TOKEN"
)

headers = headers or litellm.headers

response = azure_chat_completions.audio_speech(
model=model,
input=input,
voice=voice,
optional_params=optional_params,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
organization=organization,
max_retries=max_retries,
timeout=timeout,
client=client, # pass AsyncOpenAI, OpenAI client
aspeech=aspeech,
)

if response is None:
raise Exception(
Expand Down
1 change: 0 additions & 1 deletion litellm/proxy/_super_secret_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ model_list:
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
max_new_tokens: 256

# - litellm_params:
# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
# api_key: os.environ/AZURE_EUROPE_API_KEY
Expand Down
45 changes: 34 additions & 11 deletions litellm/tests/test_audio_speech.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# What is this?
## unit tests for openai tts endpoint

import sys, os, asyncio, time, random, uuid
import asyncio
import os
import random
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv

load_dotenv()
Expand All @@ -11,23 +17,40 @@
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm, openai
from pathlib import Path

import openai
import pytest

@pytest.mark.parametrize("sync_mode", [True, False])
import litellm


@pytest.mark.parametrize(
"sync_mode",
[True, False],
)
@pytest.mark.parametrize(
"model, api_key, api_base",
[
(
"azure/azure-tts",
os.getenv("AZURE_SWEDEN_API_KEY"),
os.getenv("AZURE_SWEDEN_API_BASE"),
),
("openai/tts-1", os.getenv("OPENAI_API_KEY"), None),
],
) # ,
@pytest.mark.asyncio
async def test_audio_speech_litellm(sync_mode):
async def test_audio_speech_litellm(sync_mode, model, api_base, api_key):
speech_file_path = Path(__file__).parent / "speech.mp3"

if sync_mode:
response = litellm.speech(
model="openai/tts-1",
model=model,
voice="alloy",
input="the quick brown fox jumped over the lazy dogs",
api_base=None,
api_key=None,
api_base=api_base,
api_key=api_key,
organization=None,
project=None,
max_retries=1,
Expand All @@ -41,11 +64,11 @@ async def test_audio_speech_litellm(sync_mode):
assert isinstance(response, HttpxBinaryResponseContent)
else:
response = await litellm.aspeech(
model="openai/tts-1",
model=model,
voice="alloy",
input="the quick brown fox jumped over the lazy dogs",
api_base=None,
api_key=None,
api_base=api_base,
api_key=api_key,
organization=None,
project=None,
max_retries=1,
Expand Down

0 comments on commit 1223b2b

Please sign in to comment.