
Commit dc988a3

Normalize and denormalize llamacpp streaming reply
Originally, I wanted to add normalizers to convert the `im_start`/`im_end` tags, but we worked around that by setting llamacpp to use the OpenAI format. We'll still need a normalizer for the vllm provider, though.

At the moment we really need the denormalizer so that the blocking pipeline can return a stream of `ModelResponse`s and the denormalizer can convert them to the `CreateChatCompletionStreamResponse` structure that is then serialized to the client. This avoids any guessing or special casing that would otherwise be needed in `llamacpp_stream_generator`, which previously expected an `Iterator[CreateChatCompletionStreamResponse]`.

Another change that simplifies the logic: `llamacpp_stream_generator` now accepts an `AsyncIterator` instead of the plain `Iterator` that the llamacpp completion handler used to return. Again, this lets us pass the iterator straight from the blocking pipeline. On the completion side we add a simple sync-to-async wrapper.

Fixes: stacklok#94
1 parent 9954610 commit dc988a3
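In short, the streaming path now looks roughly like the sketch below. This is a minimal illustration rather than the exact provider code: `stream_to_client` is a made-up name, and the output-pipeline step is only a placeholder comment because no output pipeline exists yet.

```python
from typing import AsyncIterator

from litellm import ModelResponse
from llama_cpp.llama_types import CreateChatCompletionStreamResponse

from codegate.providers.llamacpp.normalizer import LLamaCppOutputNormalizer


async def stream_to_client(
    llamacpp_chunks: AsyncIterator[CreateChatCompletionStreamResponse],
) -> AsyncIterator[CreateChatCompletionStreamResponse]:
    """Sketch of the flow: normalize -> (future) output pipeline -> denormalize."""
    normalizer = LLamaCppOutputNormalizer()
    # llama.cpp chunks -> normalized ModelResponse stream
    normalized: AsyncIterator[ModelResponse] = normalizer.normalize_streaming(llamacpp_chunks)
    # a future output pipeline would transform the normalized stream here (no-op today)
    # normalized ModelResponse stream -> llama.cpp chunks, ready for SSE serialization
    async for chunk in normalizer.denormalize_streaming(normalized):
        yield chunk
```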


8 files changed, +309 -58 lines changed


src/codegate/pipeline/codegate_system_prompt/codegate.py

Lines changed: 2 additions & 3 deletions
@@ -17,8 +17,7 @@ class CodegateSystemPrompt(PipelineStep):
 
     def __init__(self, system_prompt_message: Optional[str] = None):
         self._system_message = ChatCompletionSystemMessage(
-            content=system_prompt_message,
-            role="system"
+            content=system_prompt_message, role="system"
         )
 
     @property
@@ -29,7 +28,7 @@ def name(self) -> str:
         return "codegate-system-prompt"
 
     async def process(
-        self, request: ChatCompletionRequest, context: PipelineContext
+        self, request: ChatCompletionRequest, context: PipelineContext
     ) -> PipelineResult:
         """
         Process the completion request and add a system prompt if the user message contains

src/codegate/providers/base.py

Lines changed: 21 additions & 2 deletions
@@ -49,6 +49,20 @@ def _setup_routes(self) -> None:
     def provider_route_name(self) -> str:
         pass
 
+    async def _run_output_stream_pipeline(
+        self,
+        normalized_stream: AsyncIterator[ModelResponse],
+    ) -> AsyncIterator[ModelResponse]:
+        # we don't have a pipeline for output stream yet
+        return normalized_stream
+
+    def _run_output_pipeline(
+        self,
+        normalized_response: ModelResponse,
+    ) -> ModelResponse:
+        # we don't have a pipeline for output yet
+        return normalized_response
+
     async def _run_input_pipeline(
         self, normalized_request: ChatCompletionRequest, is_fim_request: bool
     ) -> PipelineResult:
@@ -149,8 +163,13 @@ async def complete(
             provider_request, api_key=api_key, stream=streaming
         )
         if not streaming:
-            return self._output_normalizer.denormalize(model_response)
-        return self._output_normalizer.denormalize_streaming(model_response)
+            normalized_response = self._output_normalizer.normalize(model_response)
+            pipeline_output = self._run_output_pipeline(normalized_response)
+            return self._output_normalizer.denormalize(pipeline_output)
+
+        normalized_stream = self._output_normalizer.normalize_streaming(model_response)
+        pipeline_output_stream = await self._run_output_stream_pipeline(normalized_stream)
+        return self._output_normalizer.denormalize_streaming(pipeline_output_stream)
 
     def get_routes(self) -> APIRouter:
         return self.router
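`_run_output_pipeline` and `_run_output_stream_pipeline` are deliberate pass-throughs for now. Just to illustrate the intended shape, a future output-stream step could be wired in along these lines (purely hypothetical; no such step exists in this commit):

```python
from typing import AsyncIterator

from litellm import ModelResponse


async def example_output_stream_pipeline(
    normalized_stream: AsyncIterator[ModelResponse],
) -> AsyncIterator[ModelResponse]:
    """Hypothetical output-stream step operating on normalized chunks."""
    async for chunk in normalized_stream:
        # a real step would inspect or rewrite the normalized chunk here
        yield chunk
```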

src/codegate/providers/litellmshim/generators.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import asyncio
 import json
 from typing import Any, AsyncIterator
 

src/codegate/providers/llamacpp/completion_handler.py

Lines changed: 24 additions & 8 deletions
@@ -4,21 +4,24 @@
 
 from fastapi.responses import StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
+from llama_cpp.llama_types import (
+    CreateChatCompletionStreamResponse,
+)
 
 from codegate.config import Config
 from codegate.inference.inference_engine import LlamaCppInferenceEngine
 from codegate.providers.base import BaseCompletionHandler
 
 
-async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]:
+async def llamacpp_stream_generator(
+    stream: AsyncIterator[CreateChatCompletionStreamResponse],
+) -> AsyncIterator[str]:
     """OpenAI-style SSE format"""
     try:
-        for chunk in stream:
-            if hasattr(chunk, "model_dump_json"):
-                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+        async for chunk in stream:
+            chunk = json.dumps(chunk)
             try:
-                yield f"data:{json.dumps(chunk)}\n\n"
-                await asyncio.sleep(0)
+                yield f"data:{chunk}\n\n"
             except Exception as e:
                 yield f"data:{str(e)}\n\n"
     except Exception as e:
@@ -27,6 +30,18 @@ async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]
     yield "data: [DONE]\n\n"
 
 
+async def convert_to_async_iterator(
+    sync_iterator: Iterator[CreateChatCompletionStreamResponse],
+) -> AsyncIterator[CreateChatCompletionStreamResponse]:
+    """
+    Convert a synchronous iterator to an asynchronous iterator. This makes the logic easier
+    because both the pipeline and the completion handler can use async iterators.
+    """
+    for item in sync_iterator:
+        yield item
+        await asyncio.sleep(0)
+
+
 class LlamaCppCompletionHandler(BaseCompletionHandler):
     def __init__(self):
         self.inference_engine = LlamaCppInferenceEngine()
@@ -53,9 +68,10 @@ async def execute_completion(
             Config.get_config().chat_model_n_gpu_layers,
             **request,
         )
-        return response
 
-    def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse:
+        return convert_to_async_iterator(response) if stream else response
+
+    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
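The sync-to-async wrapper is easy to exercise on its own. A quick standalone sketch (the chunk dict below is invented, only there to have something to iterate over):

```python
import asyncio

from codegate.providers.llamacpp.completion_handler import convert_to_async_iterator


async def main() -> None:
    # Invented, minimal stand-in for the sync chunk iterator llama.cpp returns
    # when stream=True; real chunks also carry id/model/created fields.
    sync_chunks = iter([{"choices": [{"index": 0, "delta": {"content": "hi"}}]}])
    async for chunk in convert_to_async_iterator(sync_chunks):
        print(chunk)


asyncio.run(main())
```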

src/codegate/providers/llamacpp/normalizer.py

Lines changed: 95 additions & 7 deletions
@@ -1,6 +1,13 @@
-from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Union
+from typing import Any, AsyncIterable, AsyncIterator, Dict, Union
 
 from litellm import ChatCompletionRequest, ModelResponse
+from litellm.types.utils import Delta, StreamingChoices
+from llama_cpp.llama_types import (
+    ChatCompletionStreamResponseChoice,
+    ChatCompletionStreamResponseDelta,
+    ChatCompletionStreamResponseDeltaEmpty,
+    CreateChatCompletionStreamResponse,
+)
 
 from codegate.providers.normalizer import ModelInputNormalizer, ModelOutputNormalizer
 
@@ -32,16 +39,97 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
         return data
 
 
+class ModelToLlamaCpp(AsyncIterator[CreateChatCompletionStreamResponse]):
+    def __init__(self, normalized_reply: AsyncIterable[ModelResponse]):
+        self.normalized_reply = normalized_reply
+        self._aiter = normalized_reply.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    @staticmethod
+    def _create_delta(
+        choice_delta: Delta,
+    ) -> Union[ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty]:
+        if not choice_delta:
+            return ChatCompletionStreamResponseDeltaEmpty()
+        return ChatCompletionStreamResponseDelta(
+            content=choice_delta.content,
+            role=choice_delta.role,
+        )
+
+    async def __anext__(self) -> CreateChatCompletionStreamResponse:
+        try:
+            chunk = await self._aiter.__anext__()
+            return CreateChatCompletionStreamResponse(
+                id=chunk["id"],
+                model=chunk["model"],
+                object="chat.completion.chunk",
+                created=chunk["created"],
+                choices=[
+                    ChatCompletionStreamResponseChoice(
+                        index=choice.index,
+                        delta=self._create_delta(choice.delta),
+                        finish_reason=choice.finish_reason,
+                        logprobs=None,
+                    )
+                    for choice in chunk["choices"]
+                ],
+            )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+
+
+class LlamaCppToModel(AsyncIterator[ModelResponse]):
+    def __init__(self, normalized_reply: AsyncIterable[CreateChatCompletionStreamResponse]):
+        self.normalized_reply = normalized_reply
+        self._aiter = normalized_reply.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    @staticmethod
+    def _create_delta(
+        choice_delta: Union[
+            ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
+        ]
+    ) -> Delta:
+        if not choice_delta:  # Handles empty dict case
+            return Delta(content=None, role=None)
+        return Delta(content=choice_delta.get("content"), role=choice_delta.get("role"))
+
+    async def __anext__(self) -> ModelResponse:
+        try:
+            chunk = await self._aiter.__anext__()
+            return ModelResponse(
+                id=chunk["id"],
+                choices=[
+                    StreamingChoices(
+                        finish_reason=choice.get("finish_reason", None),
+                        index=choice["index"],
+                        delta=self._create_delta(choice.get("delta")),
+                        logprobs=None,
+                    )
+                    for choice in chunk["choices"]
+                ],
+                created=chunk["created"],
+                model=chunk["model"],
+                object=chunk["object"],
+            )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+
+
 class LLamaCppOutputNormalizer(ModelOutputNormalizer):
     def normalize_streaming(
         self,
-        model_reply: Union[AsyncIterable[Any], Iterable[Any]],
-    ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]:
+        llamacpp_stream: AsyncIterable[CreateChatCompletionStreamResponse],
+    ) -> AsyncIterator[ModelResponse]:
         """
         Normalize the output stream. This is a pass-through for liteLLM output normalizer
         as the liteLLM output is already in the normalized format.
         """
-        return model_reply
+        return LlamaCppToModel(llamacpp_stream)
 
     def normalize(self, model_reply: Any) -> ModelResponse:
         """
@@ -59,10 +147,10 @@ def denormalize(self, normalized_reply: ModelResponse) -> Any:
 
     def denormalize_streaming(
         self,
-        normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]],
-    ) -> Union[AsyncIterator[Any], Iterator[Any]]:
+        model_stream: AsyncIterable[ModelResponse],
+    ) -> AsyncIterator[CreateChatCompletionStreamResponse]:
         """
         Denormalize the output stream from the completion function to the format
         expected by the client
         """
-        return normalized_reply
+        return ModelToLlamaCpp(model_stream)
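A rough round-trip through the two adapters, going llama.cpp chunk -> `ModelResponse` -> llama.cpp chunk via the normalizer (sketch only; the chunk contents are invented, and the exact field handling depends on the installed litellm and llama-cpp-python versions):

```python
import asyncio

from codegate.providers.llamacpp.normalizer import LLamaCppOutputNormalizer


async def main() -> None:
    async def fake_llamacpp_stream():
        # Invented chunk shaped like CreateChatCompletionStreamResponse.
        yield {
            "id": "chunk-1",
            "model": "some-model",
            "object": "chat.completion.chunk",
            "created": 0,
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "Hello"},
                    "finish_reason": None,
                }
            ],
        }

    normalizer = LLamaCppOutputNormalizer()
    # llama.cpp chunks -> ModelResponse stream (LlamaCppToModel) ...
    normalized = normalizer.normalize_streaming(fake_llamacpp_stream())
    # ... and back to llama.cpp chunks (ModelToLlamaCpp).
    async for chunk in normalizer.denormalize_streaming(normalized):
        print(chunk["id"], chunk["choices"][0]["delta"])


asyncio.run(main())
```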

src/codegate/server.py

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@
 from codegate.config import Config
 from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
 from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
-from codegate.pipeline.secrets.secrets import CodegateSecrets
-from codegate.pipeline.secrets.signatures import CodegateSignatures
 from codegate.pipeline.version.version import CodegateVersion
 from codegate.providers.anthropic.provider import AnthropicProvider
 from codegate.providers.llamacpp.provider import LlamaCppProvider

tests/pipeline/codegate_system_prompt/test_codegate_system_prompt.py

Lines changed: 27 additions & 35 deletions
@@ -24,61 +24,50 @@ def test_init_with_system_message(self):
         step = CodegateSystemPrompt(system_prompt_message=test_message)
         assert step._system_message["content"] == test_message
 
-    @pytest.mark.parametrize("user_message,expected_modification", [
-        # Test cases with different scenarios
-        ("Hello CodeGate", True),
-        ("CODEGATE in uppercase", True),
-        ("No matching message", False),
-        ("codegate with lowercase", True)
-    ])
-    async def test_process_system_prompt_insertion(
-        self,
-        user_message,
-        expected_modification
-    ):
+    @pytest.mark.parametrize(
+        "user_message,expected_modification",
+        [
+            # Test cases with different scenarios
+            ("Hello CodeGate", True),
+            ("CODEGATE in uppercase", True),
+            ("No matching message", False),
+            ("codegate with lowercase", True),
+        ],
+    )
+    async def test_process_system_prompt_insertion(self, user_message, expected_modification):
         """
         Test system prompt insertion based on message content
         """
         # Prepare mock request with user message
-        mock_request = {
-            "messages": [
-                {"role": "user", "content": user_message}
-            ]
-        }
+        mock_request = {"messages": [{"role": "user", "content": user_message}]}
         mock_context = Mock(spec=PipelineContext)
 
         # Create system prompt step
         system_prompt = "Security analysis system prompt"
         step = CodegateSystemPrompt(system_prompt_message=system_prompt)
 
         # Mock the get_last_user_message method
-        step.get_last_user_message = Mock(
-            return_value=(user_message, 0)
-        )
+        step.get_last_user_message = Mock(return_value=(user_message, 0))
 
         # Process the request
         result = await step.process(ChatCompletionRequest(**mock_request), mock_context)
 
         if expected_modification:
             # Check that system message was inserted
-            assert len(result.request['messages']) == 2
-            assert result.request['messages'][0]['role'] == 'system'
-            assert result.request['messages'][0]['content'] == system_prompt
-            assert result.request['messages'][1]['role'] == 'user'
-            assert result.request['messages'][1]['content'] == user_message
+            assert len(result.request["messages"]) == 2
+            assert result.request["messages"][0]["role"] == "system"
+            assert result.request["messages"][0]["content"] == system_prompt
+            assert result.request["messages"][1]["role"] == "user"
+            assert result.request["messages"][1]["content"] == user_message
         else:
             # Ensure no modification occurred
-            assert len(result.request['messages']) == 1
+            assert len(result.request["messages"]) == 1
 
     async def test_no_system_message_configured(self):
         """
         Test behavior when no system message is configured
         """
-        mock_request = {
-            "messages": [
-                {"role": "user", "content": "CodeGate test"}
-            ]
-        }
+        mock_request = {"messages": [{"role": "user", "content": "CodeGate test"}]}
         mock_context = Mock(spec=PipelineContext)
 
         # Create step without system message
@@ -90,10 +79,13 @@ async def test_no_system_message_configured(self):
         # Verify request remains unchanged
         assert result.request == mock_request
 
-    @pytest.mark.parametrize("edge_case", [
-        None,  # No messages
-        [],  # Empty messages list
-    ])
+    @pytest.mark.parametrize(
+        "edge_case",
+        [
+            None,  # No messages
+            [],  # Empty messages list
+        ],
+    )
     async def test_edge_cases(self, edge_case):
         """
         Test edge cases with None or empty message list
