Skip to content

Commit 70eeb9b

Browse files
author
Luke Hinds
authored
Merge pull request stacklok#117 from jhrozek/codegate_system_prompt
If codegate is detected in the message, change the system prompt
2 parents b8a6e80 + 1543fb6 commit 70eeb9b

File tree

4 files changed

+175
-0
lines changed

4 files changed

+175
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
2+
3+
__all__ = ["CodegateSystemPrompt"]
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Optional
2+
3+
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
4+
5+
from codegate.pipeline.base import (
6+
PipelineContext,
7+
PipelineResult,
8+
PipelineStep,
9+
)
10+
11+
12+
class CodegateSystemPrompt(PipelineStep):
13+
"""
14+
Pipeline step that adds a system prompt to the completion request when it detects
15+
the word "codegate" in the user message.
16+
"""
17+
18+
def __init__(self, system_prompt_message: Optional[str] = None):
19+
self._system_message = ChatCompletionSystemMessage(
20+
content=system_prompt_message,
21+
role="system"
22+
)
23+
24+
@property
25+
def name(self) -> str:
26+
"""
27+
Returns the name of this pipeline step.
28+
"""
29+
return "codegate-system-prompt"
30+
31+
async def process(
32+
self, request: ChatCompletionRequest, context: PipelineContext
33+
) -> PipelineResult:
34+
"""
35+
Process the completion request and add a system prompt if the user message contains
36+
the word "codegate".
37+
"""
38+
# no prompt configured
39+
if not self._system_message["content"]:
40+
return PipelineResult(request=request)
41+
42+
last_user_message = self.get_last_user_message(request)
43+
44+
if last_user_message is not None:
45+
last_user_message_str, last_user_idx = last_user_message
46+
if "codegate" in last_user_message_str.lower():
47+
# Add a system prompt to the completion request
48+
new_request = request.copy()
49+
new_request["messages"].insert(last_user_idx, self._system_message)
50+
return PipelineResult(
51+
request=new_request,
52+
)
53+
54+
# Fall through
55+
return PipelineResult(request=request)

src/codegate/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from fastapi import APIRouter, FastAPI
44

55
from codegate import __description__, __version__
6+
from codegate.config import Config
67
from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
8+
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
79
from codegate.pipeline.secrets.secrets import CodegateSecrets
810
from codegate.pipeline.secrets.signatures import CodegateSignatures
911
from codegate.pipeline.version.version import CodegateVersion
@@ -23,6 +25,7 @@ def init_app() -> FastAPI:
2325

2426
steps: List[PipelineStep] = [
2527
CodegateVersion(),
28+
CodegateSystemPrompt(Config.get_config().prompts.codegate_chat),
2629
# CodegateSecrets(),
2730
]
2831
# Leaving the pipeline empty for now
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
from litellm.types.llms.openai import ChatCompletionRequest
5+
6+
from codegate.pipeline.base import PipelineContext
7+
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
8+
9+
10+
@pytest.mark.asyncio
11+
class TestCodegateSystemPrompt:
12+
def test_init_no_system_message(self):
13+
"""
14+
Test initialization with no system message
15+
"""
16+
step = CodegateSystemPrompt()
17+
assert step._system_message["content"] is None
18+
19+
def test_init_with_system_message(self):
20+
"""
21+
Test initialization with a system message
22+
"""
23+
test_message = "Test system prompt"
24+
step = CodegateSystemPrompt(system_prompt_message=test_message)
25+
assert step._system_message["content"] == test_message
26+
27+
@pytest.mark.parametrize("user_message,expected_modification", [
28+
# Test cases with different scenarios
29+
("Hello CodeGate", True),
30+
("CODEGATE in uppercase", True),
31+
("No matching message", False),
32+
("codegate with lowercase", True)
33+
])
34+
async def test_process_system_prompt_insertion(
35+
self,
36+
user_message,
37+
expected_modification
38+
):
39+
"""
40+
Test system prompt insertion based on message content
41+
"""
42+
# Prepare mock request with user message
43+
mock_request = {
44+
"messages": [
45+
{"role": "user", "content": user_message}
46+
]
47+
}
48+
mock_context = Mock(spec=PipelineContext)
49+
50+
# Create system prompt step
51+
system_prompt = "Security analysis system prompt"
52+
step = CodegateSystemPrompt(system_prompt_message=system_prompt)
53+
54+
# Mock the get_last_user_message method
55+
step.get_last_user_message = Mock(
56+
return_value=(user_message, 0)
57+
)
58+
59+
# Process the request
60+
result = await step.process(ChatCompletionRequest(**mock_request), mock_context)
61+
62+
if expected_modification:
63+
# Check that system message was inserted
64+
assert len(result.request['messages']) == 2
65+
assert result.request['messages'][0]['role'] == 'system'
66+
assert result.request['messages'][0]['content'] == system_prompt
67+
assert result.request['messages'][1]['role'] == 'user'
68+
assert result.request['messages'][1]['content'] == user_message
69+
else:
70+
# Ensure no modification occurred
71+
assert len(result.request['messages']) == 1
72+
73+
async def test_no_system_message_configured(self):
74+
"""
75+
Test behavior when no system message is configured
76+
"""
77+
mock_request = {
78+
"messages": [
79+
{"role": "user", "content": "CodeGate test"}
80+
]
81+
}
82+
mock_context = Mock(spec=PipelineContext)
83+
84+
# Create step without system message
85+
step = CodegateSystemPrompt()
86+
87+
# Process the request
88+
result = await step.process(ChatCompletionRequest(**mock_request), mock_context)
89+
90+
# Verify request remains unchanged
91+
assert result.request == mock_request
92+
93+
@pytest.mark.parametrize("edge_case", [
94+
None, # No messages
95+
[], # Empty messages list
96+
])
97+
async def test_edge_cases(self, edge_case):
98+
"""
99+
Test edge cases with None or empty message list
100+
"""
101+
mock_request = {"messages": edge_case} if edge_case is not None else {}
102+
mock_context = Mock(spec=PipelineContext)
103+
104+
system_prompt = "Security edge case prompt"
105+
step = CodegateSystemPrompt(system_prompt_message=system_prompt)
106+
107+
# Mock get_last_user_message to return None
108+
step.get_last_user_message = Mock(return_value=None)
109+
110+
# Process the request
111+
result = await step.process(ChatCompletionRequest(**mock_request), mock_context)
112+
113+
# Verify request remains unchanged
114+
assert result.request == mock_request

0 commit comments

Comments
 (0)