
Commit de91231

Merge pull request stacklok#132 from stacklok/add-fim-to-vlllm
Add FIM functionality for VLLM provider
2 parents db2ff1c + bd8219a commit de91231

File tree

6 files changed: +29 -7 lines


src/codegate/providers/anthropic/completion_handler.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@ async def execute_completion(
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.
@@ -30,4 +31,4 @@ async def execute_completion(
         model_in_request = request["model"]
         if not model_in_request.startswith("anthropic/"):
             request["model"] = f"anthropic/{model_in_request}"
-        return await super().execute_completion(request, api_key, stream)
+        return await super().execute_completion(request, api_key, stream, is_fim_request)

src/codegate/providers/base.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ async def complete(
         # This gives us either a single response or a stream of responses
         # based on the streaming flag
         model_response = await self._completion_handler.execute_completion(
-            provider_request, api_key=api_key, stream=streaming
+            provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request
         )
         if not streaming:
             normalized_response = self._output_normalizer.normalize(model_response)

src/codegate/providers/completion/base.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ async def execute_completion(
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,  # TODO: remove this param?
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """Execute the completion request"""
         pass

src/codegate/providers/litellmshim/litellmshim.py

Lines changed: 15 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Any, AsyncIterator, Optional, Union
+from typing import Any, AsyncIterator, Callable, Optional, Union

 from fastapi.responses import StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse, acompletion
@@ -13,20 +13,33 @@ class LiteLLmShim(BaseCompletionHandler):
     LiteLLM API.
     """

-    def __init__(self, stream_generator: StreamGenerator, completion_func=acompletion):
+    def __init__(
+        self,
+        stream_generator: StreamGenerator,
+        completion_func: Callable = acompletion,
+        fim_completion_func: Optional[Callable] = None,
+    ):
         self._stream_generator = stream_generator
         self._completion_func = completion_func
+        # Use the same function for FIM completion if one is not specified
+        if fim_completion_func is None:
+            self._fim_completion_func = completion_func
+        else:
+            self._fim_completion_func = fim_completion_func

     async def execute_completion(
         self,
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with LiteLLM's API
         """
         request["api_key"] = api_key
+        if is_fim_request:
+            return await self._fim_completion_func(**request)
         return await self._completion_func(**request)

     def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
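For reference, a minimal, self-contained sketch of the dispatch behaviour introduced above: FIM requests are routed to the FIM completion function, everything else to the regular completion function, and the FIM function falls back to the regular one when not supplied. The _ShimSketch class and the fake completion coroutines are illustrative stand-ins, not codegate or LiteLLM APIs.

# Illustrative sketch only (not codegate code): mirrors the constructor fallback
# and the is_fim_request branch added to LiteLLmShim in the diff above.
import asyncio
from typing import Any, Callable, Optional


class _ShimSketch:
    def __init__(self, completion_func: Callable, fim_completion_func: Optional[Callable] = None):
        self._completion_func = completion_func
        # Same fallback as the diff: reuse the regular function when no FIM function is given
        self._fim_completion_func = fim_completion_func or completion_func

    async def execute_completion(
        self, request: dict, api_key: Optional[str], is_fim_request: bool = False
    ) -> Any:
        request["api_key"] = api_key
        if is_fim_request:
            return await self._fim_completion_func(**request)
        return await self._completion_func(**request)


async def _fake_chat(**kwargs) -> dict:  # stand-in for litellm.acompletion
    return {"kind": "chat", "model": kwargs.get("model")}


async def _fake_fim(**kwargs) -> dict:  # stand-in for litellm.atext_completion
    return {"kind": "fim", "model": kwargs.get("model")}


async def _demo() -> None:
    shim = _ShimSketch(_fake_chat, fim_completion_func=_fake_fim)
    print(await shim.execute_completion({"model": "m"}, api_key="k", is_fim_request=True))   # -> fim
    print(await shim.execute_completion({"model": "m"}, api_key="k", is_fim_request=False))  # -> chat


asyncio.run(_demo())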

src/codegate/providers/llamacpp/completion_handler.py

Lines changed: 6 additions & 2 deletions
@@ -47,14 +47,18 @@ def __init__(self):
         self.inference_engine = LlamaCppInferenceEngine()

     async def execute_completion(
-        self, request: ChatCompletionRequest, api_key: Optional[str], stream: bool = False
+        self,
+        request: ChatCompletionRequest,
+        api_key: Optional[str],
+        stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with inference engine API
         """
         model_path = f"{Config.get_config().model_base_path}/{request['model']}.gguf"

-        if "prompt" in request:
+        if is_fim_request:
             response = await self.inference_engine.complete(
                 model_path,
                 Config.get_config().chat_model_n_ctx,

src/codegate/providers/vllm/provider.py

Lines changed: 4 additions & 1 deletion
@@ -2,6 +2,7 @@
 from typing import Optional

 from fastapi import Header, HTTPException, Request
+from litellm import atext_completion

 from codegate.config import Config
 from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
@@ -15,7 +16,9 @@ def __init__(
         pipeline_processor: Optional[SequentialPipelineProcessor] = None,
         fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
     ):
-        completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
+        completion_handler = LiteLLmShim(
+            stream_generator=sse_stream_generator, fim_completion_func=atext_completion
+        )
         super().__init__(
             VLLMInputNormalizer(),
             VLLMOutputNormalizer(),
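With atext_completion wired in as the FIM function, a FIM request handled by the vLLM provider is executed through LiteLLM's text-completion API rather than its chat API. A hedged, standalone example of what such a call can look like; the model name, api_base, and FIM prompt markers are assumptions for illustration, and codegate's VLLMInputNormalizer may shape the actual request differently.

# Hypothetical example of the FIM path this provider now enables: calling
# LiteLLM's atext_completion against an OpenAI-compatible vLLM server.
# Model name, api_base, api_key, and the prompt markers are assumed values.
import asyncio

from litellm import atext_completion


async def fim_call() -> None:
    response = await atext_completion(
        model="openai/codellama-7b",           # hypothetical model served by vLLM
        api_base="http://localhost:8000/v1",   # hypothetical OpenAI-compatible endpoint
        api_key="dummy",
        prompt="<PRE> def add(a, b):\n    <SUF>\n    return c <MID>",  # CodeLlama-style FIM prompt
        max_tokens=32,
    )
    print(response)


asyncio.run(fim_call())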
