fix(custom_llm.py): pass input params to custom llm
krrishdholakia committed Jul 26, 2024
1 parent bd7af04 commit 41abd51
Showing 3 changed files with 182 additions and 10 deletions.
80 changes: 76 additions & 4 deletions litellm/llms/custom_llm.py
@@ -59,16 +59,88 @@ class CustomLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

-    def completion(self, *args, **kwargs) -> ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

-    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+    async def astreaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")


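Note (illustration, not part of this commit): with the expanded base-class signature above, a custom handler now receives the concrete request parameters instead of opaque *args/**kwargs. A minimal sketch of such a handler follows; the class name MyEchoLLM and the mock_response echo are assumptions made purely for demonstration.

from typing import Callable, Optional, Union

import httpx

import litellm
from litellm import CustomLLM, ModelResponse


class MyEchoLLM(CustomLLM):
    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> ModelResponse:
        # The dispatcher now forwards model, messages, optional_params, etc.,
        # so the handler can read them directly instead of unpacking kwargs.
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=messages,
            mock_response=f"echo: {messages[-1]['content']}",
        )  # type: ignore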
21 changes: 20 additions & 1 deletion litellm/main.py
@@ -2711,8 +2711,27 @@ def completion(
                 async_fn=acompletion, stream=stream, custom_llm=custom_handler
             )

+            headers = headers or litellm.headers
+
             ## CALL FUNCTION
-            response = handler_fn()
+            response = handler_fn(
+                model=model,
+                messages=messages,
+                headers=headers,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                timeout=timeout,  # type: ignore
+                custom_prompt_dict=custom_prompt_dict,
+                client=client,  # pass AsyncOpenAI, OpenAI client
+                encoding=encoding,
+            )
             if stream is True:
                 return CustomStreamWrapper(
                     completion_stream=response,
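Note (illustration, not part of the diff): the dispatch above hands the full parameter set to whichever handler is registered for a custom provider. A usage sketch, assuming a handler such as MyEchoLLM from the sketch earlier is available; the provider name my-custom-llm is arbitrary, and litellm.custom_provider_map is the registration hook documented for custom providers.

import litellm

litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": MyEchoLLM()}
]

# Non-streaming: handler_fn(...) returns a ModelResponse directly.
resp = litellm.completion(
    model="my-custom-llm/any-model-name",
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(resp.choices[0].message.content)

# Streaming (only if the handler also implements streaming()): the handler's
# GenericStreamingChunk iterator is wrapped in CustomStreamWrapper by the
# `if stream is True` branch above.
for chunk in litellm.completion(
    model="my-custom-llm/any-model-name",
    messages=[{"role": "user", "content": "stream, please"}],
    stream=True,
):
    print(chunk)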
91 changes: 86 additions & 5 deletions litellm/tests/test_custom_llm.py
@@ -17,7 +17,16 @@
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Callable,
+    Coroutine,
+    Iterator,
+    Optional,
+    Union,
+)
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
@@ -94,21 +103,75 @@ async def __anext__(self) -> GenericStreamingChunk:


 class MyCustomLLM(CustomLLM):
-    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore

-    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore

-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
@@ -126,7 +189,25 @@ def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
         )
         return custom_iterator

-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
+    async def astreaming(  # type: ignore
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
