forked from AbanteAI/archive-old-cli-mentat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
334 lines (254 loc) · 10.1 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import gc
import os
import shutil
import stat
import subprocess
import tempfile
import time
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as AsyncChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from mentat import config
from mentat.agent_handler import AgentHandler
from mentat.auto_completer import AutoCompleter
from mentat.code_context import CodeContext
from mentat.code_file_manager import CodeFileManager
from mentat.config import Config, config_file_name
from mentat.conversation import Conversation
from mentat.cost_tracker import CostTracker
from mentat.git_handler import get_git_root_for_path
from mentat.llm_api_handler import LlmApiHandler
from mentat.sampler.sampler import Sampler
from mentat.session_context import SESSION_CONTEXT, SessionContext
from mentat.session_stream import SessionStream, StreamMessage, StreamMessageSource
from mentat.streaming_printer import StreamingPrinter
from mentat.vision.vision_manager import VisionManager
pytest_plugins = ("pytest_reportlog",)
def filter_mark(items, mark, exists):
    """Return the items whose possession of `mark` matches `exists`.

    With exists=True, keep only items carrying the marker; with
    exists=False, keep only items without it.
    """
    return [
        item
        for item in items
        if bool(item.get_closest_marker(mark)) == bool(exists)
    ]
def pytest_addoption(parser):
    """Register the command line options used by this test suite."""
    parser.addoption("--benchmark", action="store_true")
    parser.addoption("--uitest", action="store_true")
    # The `benchmarks` fixture reads "--benchmarks" and applies len()/indexing
    # to it, but the option was never registered, so getoption("--benchmarks")
    # raised ValueError. Register it as a repeatable option defaulting to an
    # empty list so the fixture works as written.
    parser.addoption("--benchmarks", action="append", default=[])
@pytest.fixture
def benchmarks(request):
    """Selected benchmarks from the command line.

    Returns the bare value when exactly one benchmark was selected,
    otherwise the full list.
    """
    selected = request.config.getoption("--benchmarks")
    return selected[0] if len(selected) == 1 else selected
def pytest_configure(config):
    """Declare this suite's custom markers so pytest does not warn on them."""
    marker_descriptions = (
        "benchmark: run benchmarks that call openai",
        "uitest: run ui-tests that get evaluated by humans",
        "clear_testbed: create a testbed without any existing files",
        "no_git_testbed: create a testbed without git",
    )
    for description in marker_descriptions:
        config.addinivalue_line("markers", description)
def pytest_collection_modifyitems(config, items):
    """Filter collected tests in place by the --benchmark/--uitest flags.

    Without a flag, tests carrying the corresponding marker are dropped;
    with the flag, only tests carrying the marker are kept.
    """
    selected = items
    for mark_name, option_name in (
        ("benchmark", "--benchmark"),
        ("uitest", "--uitest"),
    ):
        selected = filter_mark(selected, mark_name, config.getoption(option_name))
    items[:] = selected
@pytest.fixture
def get_marks(request):
    """Names of all markers applied to the currently running test."""
    names = []
    for marker in request.node.iter_markers():
        names.append(marker.name)
    return names
@pytest.fixture
def mock_collect_user_input(mocker):
    """Patch mentat's input request with an AsyncMock.

    The returned mock gains a `set_stream_messages(values)` helper that
    queues one StreamMessage per value to be returned by successive calls.
    """
    input_mock = AsyncMock()
    mocker.patch("mentat.session_input._get_input_request", side_effect=input_mock)

    def set_stream_messages(values):
        queued = []
        for value in values:
            queued.append(
                StreamMessage(
                    id=uuid4(),
                    channel="default",
                    source=StreamMessageSource.CLIENT,
                    data=value,
                    extra={},
                    created_at=datetime.utcnow(),
                )
            )
        input_mock.side_effect = queued

    input_mock.set_stream_messages = set_stream_messages
    return input_mock
@pytest.fixture(scope="function")
def mock_call_llm_api(mocker):
    """Patch LlmApiHandler.call_llm_api and return the mock.

    The mock is augmented with three configuration helpers:
      * set_streamed_values(values): next call returns an async generator of
        ChatCompletionChunk objects, one per string in `values`.
      * set_unstreamed_values(value): next call returns a single
        ChatCompletion containing `value`.
      * set_return_values(values): successive calls consume `values` in
        order, wrapping each as streamed or unstreamed depending on the
        `stream` argument the mock is called with.
    """
    completion_mock = mocker.patch.object(LlmApiHandler, "call_llm_api")

    def wrap_unstreamed_string(value):
        # Wrap `value` as a complete (non-streamed) chat completion response.
        timestamp = int(time.time())
        return ChatCompletion(
            id="test-id",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(
                        content=value,
                        role="assistant",
                    ),
                )
            ],
            created=timestamp,
            model="test-model",
            object="chat.completion",
        )

    def wrap_streamed_strings(values):
        # Wrap each string in `values` as a ChatCompletionChunk yielded from
        # an async generator, mimicking a streamed API response.
        async def _async_generator():
            timestamp = int(time.time())
            for value in values:
                yield ChatCompletionChunk(
                    id="test-id",
                    choices=[
                        AsyncChoice(
                            delta=ChoiceDelta(content=value, role="assistant"),
                            finish_reason=None,
                            index=0,
                        )
                    ],
                    created=timestamp,
                    model="test-model",
                    object="chat.completion.chunk",
                )

        return _async_generator()

    def set_streamed_values(values):
        completion_mock.return_value = wrap_streamed_strings(values)

    completion_mock.set_streamed_values = set_streamed_values

    def set_unstreamed_values(value):
        completion_mock.return_value = wrap_unstreamed_string(value)

    completion_mock.set_unstreamed_values = set_unstreamed_values

    def set_return_values(values):
        async def call_llm_api_mock(messages, model, stream, response_format="unused"):
            # Values are popped from the end, so the stored list is reversed
            # below to preserve the caller's ordering.
            value = call_llm_api_mock.values.pop()
            if stream:
                return wrap_streamed_strings([value])
            else:
                return wrap_unstreamed_string(value)

        call_llm_api_mock.values = values[::-1]
        completion_mock.side_effect = call_llm_api_mock

    completion_mock.set_return_values = set_return_values
    return completion_mock
@pytest.fixture(scope="function")
def mock_call_embedding_api(mocker):
    """Patch LlmApiHandler.call_embedding_api and return the mock.

    The mock gains a `set_embedding_values(value)` helper that sets the
    value returned by subsequent calls.
    """
    api_mock = mocker.patch.object(LlmApiHandler, "call_embedding_api")

    def set_embedding_values(value):
        api_mock.return_value = value

    api_mock.set_embedding_values = set_embedding_values
    return api_mock
### Auto-used fixtures
@pytest.fixture(autouse=True, scope="function")
def mock_initialize_client(mocker, request):
    """Stub out API client initialization except when running benchmarks,
    which make real OpenAI calls."""
    running_benchmarks = request.config.getoption("--benchmark")
    if not running_benchmarks:
        mocker.patch.object(LlmApiHandler, "initialize_client")
# ContextVars need to be set in a synchronous fixture due to pytest not propagating
# async fixture contexts to test contexts.
# https://github.com/pytest-dev/pytest-asyncio/issues/127
@pytest.fixture(autouse=True)
def mock_session_context(temp_testbed):
    """
    This is autoused to make it easier to write tests without having to worry about whether
    or not SessionContext is set; however, this SessionContext will be overwritten by the SessionContext
    set by a Session if the test creates a Session.
    If you create a Session or Client in your test, do NOT use this SessionContext!
    """
    # raise_error=False: returns None when the testbed was created without
    # git (no_git_testbed mark).
    git_root = get_git_root_for_path(temp_testbed, raise_error=False)
    stream = SessionStream()
    cost_tracker = CostTracker()
    # NOTE: this local shadows the imported `mentat.config` module inside
    # this fixture.
    config = Config()
    llm_api_handler = LlmApiHandler()
    code_context = CodeContext(stream, git_root)
    code_file_manager = CodeFileManager()
    conversation = Conversation()
    vision_manager = VisionManager()
    agent_handler = AgentHandler()
    auto_completer = AutoCompleter()
    sampler = Sampler()
    # Positional arguments — the order must match SessionContext's
    # constructor signature.
    session_context = SessionContext(
        Path.cwd(),
        stream,
        llm_api_handler,
        cost_tracker,
        config,
        code_context,
        code_file_manager,
        conversation,
        vision_manager,
        agent_handler,
        auto_completer,
        sampler,
    )
    # Set the ContextVar here (a sync fixture) and reset after the test so
    # state does not leak between tests.
    token = SESSION_CONTEXT.set(session_context)
    yield session_context
    SESSION_CONTEXT.reset(token)
@pytest.fixture
def mock_code_context(temp_testbed, mock_session_context):
    """Convenience accessor for the CodeContext held by the autouse
    mock session context."""
    session_context = mock_session_context
    return session_context.code_context
### Helper functions
def run_git_command(cwd, *args):
    """Run a git command in `cwd`, discarding stdout and stderr."""
    command = ["git", *args]
    subprocess.run(
        command,
        cwd=cwd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
def add_permissions(func, path, exc_info):
    """
    Error handler for ``shutil.rmtree``.

    If the failure looks like an access error (read only file), grant the
    owner write permission and retry the failed operation. Otherwise
    re-raise the active exception.
    """
    gc.collect()  # Force garbage collection
    writable = os.access(path, os.W_OK)
    if writable:
        # Already writable, so this was not a permission problem.
        raise
    os.chmod(path, stat.S_IWUSR)
    func(path)
@pytest.fixture(autouse=True)
def temp_testbed(monkeypatch, get_marks):
    """Create a temporary copy of the testbed, chdir into it, and clean up.

    Honors two markers (declared in pytest_configure): `clear_testbed`
    skips copying the testbed files, and `no_git_testbed` skips git
    initialization and the initial commit. Yields the testbed path.
    """
    # Allow us to run tests from any directory
    base_dir = Path(__file__).parent.parent
    # create temporary copy of testbed, complete with git repo
    # realpath() resolves symlinks, required for paths to match on macOS
    temp_dir = os.path.realpath(tempfile.mkdtemp())
    temp_testbed = os.path.join(temp_dir, "testbed")
    os.mkdir(temp_testbed)

    if "no_git_testbed" not in get_marks:
        # Initialize git repo
        run_git_command(temp_testbed, "init")

        # Set local config for user.name and user.email. Set automatically on
        # MacOS, but not Windows/Ubuntu, which prevents commits from taking.
        run_git_command(temp_testbed, "config", "user.email", "[email protected]")
        run_git_command(temp_testbed, "config", "user.name", "Test User")

    if "clear_testbed" not in get_marks:
        # Copy testbed
        shutil.copytree(base_dir / "testbed", temp_testbed, dirs_exist_ok=True)
        shutil.copy(base_dir / ".gitignore", temp_testbed)

        if "no_git_testbed" not in get_marks:
            # Add all files and commit
            run_git_command(temp_testbed, "add", ".")
            run_git_command(temp_testbed, "commit", "-m", "add testbed")

    # necessary to undo chdir before calling rmtree, or it fails on windows
    with monkeypatch.context() as m:
        m.chdir(temp_testbed)
        yield Path(temp_testbed)

    # add_permissions retries removal of read-only files (e.g. git objects).
    shutil.rmtree(temp_dir, onerror=add_permissions)
# Always set the user config to just be a config in the temp_testbed; that way,
# it will be unset unless a specific test wants to make a config in the testbed
@pytest.fixture(autouse=True)
def mock_user_config(mocker):
    # Point the user config at a path relative to the cwd (the temp testbed),
    # so no real user config leaks into tests. A test that wants a user config
    # creates one inside the testbed. NOTE(review): this mutates the module
    # attribute directly and is never restored — presumably fine because every
    # test gets this fixture; confirm if tests ever need the real path.
    config.user_config_path = Path(config_file_name)
@pytest.fixture(autouse=True)
def mock_sleep_time(mocker):
    # Replace StreamingPrinter.sleep_time with a callable returning 0 so
    # streamed output is emitted without artificial delays during tests.
    mocker.patch.object(StreamingPrinter, "sleep_time", new=lambda self: 0)