Skip to content

Commit 6184edd

Browse files
ptelang authored and jhrozek committed
Modify storageclient to singleton pattern
1 parent dd8f8d7 commit 6184edd

File tree

6 files changed

+68
-64
lines changed

6 files changed

+68
-64
lines changed

config.yaml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,13 @@ log_level: "INFO" # One of: ERROR, WARNING, INFO, DEBUG
1919
##
2020

2121
# Model to use for chatting
22-
chat_model_path: "./models"
22+
model_base_path: "./models"
2323

2424
# Context length of the model
2525
chat_model_n_ctx: 32768
2626

2727
# Number of layers to offload to GPU. If -1, all layers are offloaded.
2828
chat_model_n_gpu_layers: -1
2929

30+
# Embedding model
31+
embedding_model: "all-minilm-L6-v2-q5_k_m.gguf"

src/codegate/config.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ class Config:
4040
model_base_path: str = "./models"
4141
chat_model_n_ctx: int = 32768
4242
chat_model_n_gpu_layers: int = -1
43+
embedding_model: str = "all-minilm-L6-v2-q5_k_m.gguf"
4344

4445
# Provider URLs with defaults
4546
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
@@ -117,11 +118,12 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config":
117118
host=config_data.get("host", cls.host),
118119
log_level=config_data.get("log_level", cls.log_level.value),
119120
log_format=config_data.get("log_format", cls.log_format.value),
120-
model_base_path=config_data.get("chat_model_path", cls.model_base_path),
121+
model_base_path=config_data.get("model_base_path", cls.model_base_path),
121122
chat_model_n_ctx=config_data.get("chat_model_n_ctx", cls.chat_model_n_ctx),
122123
chat_model_n_gpu_layers=config_data.get(
123124
"chat_model_n_gpu_layers", cls.chat_model_n_gpu_layers
124125
),
126+
embedding_model=config_data.get("embedding_model", cls.embedding_model),
125127
prompts=prompts_config,
126128
provider_urls=provider_urls,
127129
)

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,6 @@ class CodegateContextRetriever(PipelineStep):
2020
the word "codegate" in the user message.
2121
"""
2222

23-
def __init__(self):
24-
self.storage_engine = StorageEngine()
25-
2623
@property
2724
def name(self) -> str:
2825
"""
@@ -33,7 +30,8 @@ def name(self) -> str:
3330
async def get_objects_from_search(
3431
self, search: str, packages: list[str] = None
3532
) -> list[object]:
36-
objects = await self.storage_engine.search(search, distance=0.8, packages=packages)
33+
storage_engine = StorageEngine()
34+
objects = await storage_engine.search(search, distance=0.8, packages=packages)
3735
return objects
3836

3937
def generate_context_str(self, objects: list[object]) -> str:

src/codegate/pipeline/extract_snippets/output.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,6 @@
1616
class CodeCommentStep(OutputPipelineStep):
1717
"""Pipeline step that adds comments after code blocks"""
1818

19-
def __init__(self):
20-
self._storage_engine = StorageEngine()
21-
2219
@property
2320
def name(self) -> str:
2421
return "code-comment"
@@ -52,7 +49,8 @@ async def _snippet_comment(self, snippet: CodeSnippet, secrets: PipelineSensitiv
5249
base_url=secrets.api_base,
5350
)
5451

55-
libobjects = await self._storage_engine.search_by_property("name", snippet.libraries)
52+
storage_engine = StorageEngine()
53+
libobjects = await storage_engine.search_by_property("name", snippet.libraries)
5654
logger.info(f"Found {len(libobjects)} libraries in the storage engine")
5755

5856
libraries_text = ""

src/codegate/storage/storage_engine.py

Lines changed: 48 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -26,22 +26,60 @@
2626

2727

2828
class StorageEngine:
29-
def get_client(self, data_path):
29+
__storage_engine = None
30+
31+
def __new__(cls, *args, **kwargs):
32+
if cls.__storage_engine is None:
33+
cls.__storage_engine = super().__new__(cls)
34+
return cls.__storage_engine
35+
36+
# This function is needed only for the unit testing for the
37+
# mocks to work.
38+
@classmethod
39+
def recreate_instance(cls, *args, **kwargs):
40+
cls.__storage_engine = None
41+
return cls(*args, **kwargs)
42+
43+
def __init__(self, data_path="./weaviate_data"):
44+
if hasattr(self, "initialized"):
45+
return
46+
47+
self.initialized = True
48+
self.data_path = data_path
49+
self.inference_engine = LlamaCppInferenceEngine()
50+
self.model_path = (
51+
f"{Config.get_config().model_base_path}/{Config.get_config().embedding_model}"
52+
)
53+
self.schema_config = schema_config
54+
55+
# setup schema for weaviate
56+
self.weaviate_client = self.get_client(self.data_path)
57+
if self.weaviate_client is not None:
58+
try:
59+
self.weaviate_client.connect()
60+
self.setup_schema(self.weaviate_client)
61+
except Exception as e:
62+
logger.error(f"Failed to connect or setup schema: {str(e)}")
63+
64+
def __del__(self):
3065
try:
31-
# Get current config
32-
config = Config.get_config()
66+
self.weaviate_client.close()
67+
except Exception as e:
68+
logger.error(f"Failed to close client: {str(e)}")
3369

70+
def get_client(self, data_path):
71+
try:
3472
# Configure Weaviate logging
3573
additional_env_vars = {
3674
# Basic logging configuration
37-
"LOG_FORMAT": config.log_format.value.lower(),
38-
"LOG_LEVEL": config.log_level.value.lower(),
75+
"LOG_FORMAT": Config.get_config().log_format.value.lower(),
76+
"LOG_LEVEL": Config.get_config().log_level.value.lower(),
3977
# Disable colored output
4078
"LOG_FORCE_COLOR": "false",
4179
# Configure JSON format
4280
"LOG_JSON_FIELDS": "timestamp, level,message",
4381
# Configure text format
44-
"LOG_METHOD": config.log_format.value.lower(),
82+
"LOG_METHOD": Config.get_config().log_format.value.lower(),
4583
"LOG_LEVEL_IN_UPPER": "false", # Keep level lowercase like codegate format
4684
# Disable additional fields
4785
"LOG_GIT_HASH": "false",
@@ -60,28 +98,6 @@ def get_client(self, data_path):
6098
logger.error(f"Error during client creation: {str(e)}")
6199
return None
62100

63-
def __init__(self, data_path="./weaviate_data"):
64-
self.data_path = data_path
65-
self.inference_engine = LlamaCppInferenceEngine()
66-
self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
67-
self.schema_config = schema_config
68-
69-
# setup schema for weaviate
70-
weaviate_client = self.get_client(self.data_path)
71-
if weaviate_client is not None:
72-
try:
73-
weaviate_client.connect()
74-
self.setup_schema(weaviate_client)
75-
except Exception as e:
76-
logger.error(f"Failed to connect or setup schema: {str(e)}")
77-
finally:
78-
try:
79-
weaviate_client.close()
80-
except Exception as e:
81-
logger.error(f"Failed to close client: {str(e)}")
82-
else:
83-
logger.error("Could not find client, skipping schema setup.")
84-
85101
def setup_schema(self, client):
86102
for class_config in self.schema_config:
87103
if not client.collections.exists(class_config["name"]):
@@ -95,18 +111,16 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
95111
return []
96112

97113
# Perform the vector search
98-
weaviate_client = self.get_client(self.data_path)
99-
if weaviate_client is None:
114+
if self.weaviate_client is None:
100115
logger.error("Could not find client, not returning results.")
101116
return []
102117

103-
if not weaviate_client:
118+
if not self.weaviate_client:
104119
logger.error("Invalid client, cannot perform search.")
105120
return []
106121

107122
try:
108-
weaviate_client.connect()
109-
packages = weaviate_client.collections.get("Package")
123+
packages = self.weaviate_client.collections.get("Package")
110124
response = packages.query.fetch_objects(
111125
filters=Filter.by_property(name).contains_any(properties),
112126
)
@@ -117,8 +131,6 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
117131
except Exception as e:
118132
logger.error(f"An error occurred: {str(e)}")
119133
return []
120-
finally:
121-
weaviate_client.close()
122134

123135
async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list[object]:
124136
"""
@@ -135,14 +147,8 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
135147
query_vector = await self.inference_engine.embed(self.model_path, [query])
136148

137149
# Perform the vector search
138-
weaviate_client = self.get_client(self.data_path)
139-
if weaviate_client is None:
140-
logger.error("Could not find client, not returning results.")
141-
return []
142-
143150
try:
144-
weaviate_client.connect()
145-
collection = weaviate_client.collections.get("Package")
151+
collection = self.weaviate_client.collections.get("Package")
146152
if packages:
147153
response = collection.query.near_vector(
148154
query_vector[0],
@@ -159,16 +165,10 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
159165
return_metadata=MetadataQuery(distance=True),
160166
)
161167

162-
weaviate_client.close()
163168
if not response:
164169
return []
165170
return response.objects
166171

167172
except Exception as e:
168173
logger.error(f"Error during search: {str(e)}")
169174
return []
170-
finally:
171-
try:
172-
weaviate_client.close()
173-
except Exception as e:
174-
logger.error(f"Failed to close client: {str(e)}")

tests/test_storage.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from codegate.config import Config
56
from codegate.storage.storage_engine import (
67
StorageEngine,
78
) # Adjust the import based on your actual path
@@ -34,17 +35,21 @@ def mock_inference_engine():
3435

3536
@pytest.mark.asyncio
3637
async def test_search(mock_weaviate_client, mock_inference_engine):
38+
Config.load(config_path="./config.yaml")
39+
3740
# Patch the LlamaCppInferenceEngine.embed method (not the entire class)
3841
with patch(
3942
"codegate.inference.inference_engine.LlamaCppInferenceEngine.embed",
4043
mock_inference_engine.embed,
4144
):
42-
43-
# Mock the WeaviateClient as before
44-
with patch("weaviate.WeaviateClient", return_value=mock_weaviate_client):
45-
45+
# Initialize StorageEngine
46+
with patch(
47+
"codegate.storage.storage_engine.StorageEngine.get_client",
48+
return_value=mock_weaviate_client,
49+
):
4650
# Initialize StorageEngine
47-
storage_engine = StorageEngine(data_path="./weaviate_data")
51+
# Need to recreate instance to use the mock
52+
storage_engine = StorageEngine.recreate_instance(data_path="./weaviate_data")
4853

4954
# Invoke the search method
5055
results = await storage_engine.search("test query", 5, 0.3)
@@ -53,4 +58,3 @@ async def test_search(mock_weaviate_client, mock_inference_engine):
5358
assert len(results) == 1 # Assert that one result is returned
5459
assert results[0]["properties"]["name"] == "test"
5560
mock_weaviate_client.connect.assert_called()
56-
mock_weaviate_client.close.assert_called()

0 commit comments

Comments (0)