Skip to content

Commit

Permalink
Feat: Allow checking multiple conversations running at the same time (A…
Browse files Browse the repository at this point in the history
  • Loading branch information
tofarr authored Dec 26, 2024
1 parent 69a9080 commit 5005986
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 41 deletions.
96 changes: 68 additions & 28 deletions openhands/server/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import time
from dataclasses import dataclass, field
from uuid import uuid4

import socketio

Expand All @@ -27,6 +28,14 @@ class ConversationDoesNotExistError(Exception):
pass


@dataclass
class _SessionIsRunningCheck:
request_id: str
request_sids: list[str]
running_sids: set[str] = field(default_factory=set)
flag: asyncio.Event = field(default_factory=asyncio.Event)


@dataclass
class SessionManager:
sio: socketio.AsyncServer
Expand All @@ -36,7 +45,9 @@ class SessionManager:
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_last_alive_timestamps: dict[str, float] = field(default_factory=dict)
_redis_listen_task: asyncio.Task | None = None
_session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
_session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
default_factory=dict
)
_active_conversations: dict[str, tuple[Conversation, int]] = field(
default_factory=dict
)
Expand Down Expand Up @@ -97,27 +108,41 @@ async def _redis_subscribe(self):
async def _process_message(self, message: dict):
data = json.loads(message['data'])
logger.debug(f'got_published_message:{message}')
sid = data['sid']
message_type = data['message_type']
if message_type == 'event':
sid = data['sid']
session = self._local_agent_loops_by_sid.get(sid)
if session:
await session.dispatch(data['data'])
elif message_type == 'is_session_running':
# Another node in the cluster is asking if the current node is running the session given.
session = self._local_agent_loops_by_sid.get(sid)
if session:
request_id = data['request_id']
sids = [
sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
]
if sids:
await self._get_redis_client().publish(
'oh_event',
json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
json.dumps(
{
'request_id': request_id,
'sids': sids,
'message_type': 'session_is_running',
}
),
)
elif message_type == 'session_is_running':
self._last_alive_timestamps[sid] = time.time()
flag = self._session_is_running_flags.get(sid)
if flag:
flag.set()
request_id = data['request_id']
for sid in data['sids']:
self._last_alive_timestamps[sid] = time.time()
check = self._session_is_running_checks.get(request_id)
if check:
check.running_sids.update(data['sids'])
if len(check.request_sids) == len(check.running_sids):
check.flag.set()
elif message_type == 'has_remote_connections_query':
# Another node in the cluster is asking if the current node is connected to a session
sid = data['sid']
required = sid in self.local_connection_id_to_session_id.values()
if required:
await self._get_redis_client().publish(
Expand All @@ -127,12 +152,14 @@ async def _process_message(self, message: dict):
),
)
elif message_type == 'has_remote_connections_response':
sid = data['sid']
flag = self._has_remote_connections_flags.get(sid)
if flag:
flag.set()
elif message_type == 'session_closing':
# Session closing event - We only get this in the event of graceful shutdown,
# which can't be guaranteed - nodes can simply vanish unexpectedly!
sid = data['sid']
logger.debug(f'session_closing:{sid}')
for (
connection_id,
Expand Down Expand Up @@ -234,47 +261,60 @@ async def _cleanup_detached_conversations(self):
logger.warning('error_cleaning_detached_conversations', exc_info=True)
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)

async def _is_agent_loop_running(self, sid: str) -> bool:
if await self._is_agent_loop_running_locally(sid):
async def get_agent_loop_running(self, sids: set[str]) -> set[str]:
running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
check_cluster_sids = [sid for sid in sids if sid not in running_sids]
running_cluster_sids = await self.get_agent_loop_running_in_cluster(
check_cluster_sids
)
running_sids.union(running_cluster_sids)
return running_sids

async def is_agent_loop_running(self, sid: str) -> bool:
if await self.is_agent_loop_running_locally(sid):
return True
if await self._is_agent_loop_running_in_cluster(sid):
if await self.is_agent_loop_running_in_cluster(sid):
return True
return False

async def _is_agent_loop_running_locally(self, sid: str) -> bool:
if self._local_agent_loops_by_sid.get(sid, None):
return True
return False
async def is_agent_loop_running_locally(self, sid: str) -> bool:
return sid in self._local_agent_loops_by_sid

async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
running_sids = await self.get_agent_loop_running_in_cluster([sid])
return bool(running_sids)

async def _is_agent_loop_running_in_cluster(self, sid: str) -> bool:
async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
redis_client = self._get_redis_client()
if not redis_client:
return False
return set()

flag = asyncio.Event()
self._session_is_running_flags[sid] = flag
request_id = str(uuid4())
check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
self._session_is_running_checks[request_id] = check
try:
logger.debug(f'publish:is_session_running:{sid}')
logger.debug(f'publish:is_session_running:{sids}')
await redis_client.publish(
'oh_event',
json.dumps(
{
'sid': sid,
'request_id': request_id,
'sids': sids,
'message_type': 'is_session_running',
}
),
)
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
await flag.wait()

result = flag.is_set()
return result
return check.running_sids
except TimeoutError:
# Nobody replied in time
return False
return check.running_sids
finally:
self._session_is_running_flags.pop(sid, None)
self._session_is_running_checks.pop(request_id, None)

async def _has_remote_connections(self, sid: str) -> bool:
"""As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
Expand Down Expand Up @@ -307,7 +347,7 @@ async def maybe_start_agent_loop(
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None
if not await self._is_agent_loop_running(sid):
if not await self.is_agent_loop_running(sid):
logger.info(f'start_agent_loop:{sid}')
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
Expand All @@ -328,7 +368,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None:
logger.info(f'found_local_agent_loop:{sid}')
return session.agent_session.event_stream

if await self._is_agent_loop_running_in_cluster(sid):
if await self.is_agent_loop_running_in_cluster(sid):
logger.info(f'found_remote_agent_loop:{sid}')
return EventStream(sid, self.file_store)

Expand All @@ -352,7 +392,7 @@ async def send_to_event_stream(self, connection_id: str, data: dict):
next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
if (
next_alive_check > time.time()
or await self._is_agent_loop_running_in_cluster(sid)
or await self.is_agent_loop_running_in_cluster(sid)
):
# Send the event to the other pod
await redis_client.publish(
Expand Down
4 changes: 2 additions & 2 deletions openhands/storage/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
class InMemoryFileStore(FileStore):
files: dict[str, str]

def __init__(self):
self.files = IN_MEMORY_FILES
def __init__(self, files: dict[str, str] = IN_MEMORY_FILES):
self.files = files

def write(self, path: str, contents: str) -> None:
self.files[path] = contents
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store
from openhands.storage.memory import InMemoryFileStore


@pytest.fixture
Expand Down Expand Up @@ -168,7 +169,7 @@ async def on_event(event: Event):
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck():
config = AppConfig()
file_store = get_file_store(config.file_store, config.file_store_path)
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)

agent = MagicMock(spec=Agent)
Expand Down
33 changes: 23 additions & 10 deletions tests/unit/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

Expand Down Expand Up @@ -35,44 +36,56 @@ def get_mock_sio(get_message: GetMessageMock | None = None):
@pytest.mark.asyncio
async def test_session_not_running_in_cluster():
sio = get_mock_sio()
id = uuid4()
with (
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
):
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
result = await session_manager._is_agent_loop_running_in_cluster(
result = await session_manager.is_agent_loop_running_in_cluster(
'non-existant-session'
)
assert result is False
assert sio.manager.redis.publish.await_count == 1
sio.manager.redis.publish.assert_called_once_with(
'oh_event',
'{"sid": "non-existant-session", "message_type": "is_session_running"}',
'{"request_id": "'
+ str(id)
+ '", "sids": ["non-existant-session"], "message_type": "is_session_running"}',
)


@pytest.mark.asyncio
async def test_session_is_running_in_cluster():
id = uuid4()
sio = get_mock_sio(
GetMessageMock(
{'sid': 'existing-session', 'message_type': 'session_is_running'}
{
'request_id': str(id),
'sids': ['existing-session'],
'message_type': 'session_is_running',
}
)
)
with (
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
):
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
result = await session_manager._is_agent_loop_running_in_cluster(
result = await session_manager.is_agent_loop_running_in_cluster(
'existing-session'
)
assert result is True
assert sio.manager.redis.publish.await_count == 1
sio.manager.redis.publish.assert_called_once_with(
'oh_event',
'{"sid": "existing-session", "message_type": "is_session_running"}',
'{"request_id": "'
+ str(id)
+ '", "sids": ["existing-session"], "message_type": "is_session_running"}',
)


Expand All @@ -93,7 +106,7 @@ async def test_init_new_local_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
Expand Down Expand Up @@ -125,7 +138,7 @@ async def test_join_local_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
Expand Down Expand Up @@ -158,7 +171,7 @@ async def test_join_cluster_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
Expand Down Expand Up @@ -187,7 +200,7 @@ async def test_add_to_local_event_stream():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
Expand Down Expand Up @@ -221,7 +234,7 @@ async def test_add_to_cluster_event_stream():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
Expand Down

0 comments on commit 5005986

Please sign in to comment.