Skip to content

Commit

Permalink
Add initial user msg to /new_conversation route (All-Hands-AI#6314)
Browse files Browse the repository at this point in the history
  • Loading branch information
malhotra5 authored Jan 17, 2025
1 parent 2edb233 commit 000055b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
8 changes: 6 additions & 2 deletions openhands/server/routes/manage_conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
class InitSessionRequest(BaseModel):
github_token: str | None = None
selected_repository: str | None = None
initial_user_msg: str | None = None


async def _create_new_conversation(
user_id: str | None,
token: str | None,
selected_repository: str | None,
initial_user_msg: str | None,
):
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
Expand Down Expand Up @@ -89,7 +91,7 @@ async def _create_new_conversation(

logger.info(f'Starting agent loop for conversation {conversation_id}')
event_stream = await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data, user_id
conversation_id, conversation_init_data, user_id, initial_user_msg
)
try:
event_stream.subscribe(
Expand All @@ -114,10 +116,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
user_id = get_user_id(request)
github_token = getattr(request.state, 'github_token', '') or data.github_token
selected_repository = data.selected_repository
initial_user_msg = data.initial_user_msg

try:
conversation_id = await _create_new_conversation(
user_id, github_token, selected_repository
user_id, github_token, selected_repository, initial_user_msg
)

return JSONResponse(
Expand All @@ -140,6 +143,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
'message': str(e),
'msg_id': 'STATUS$ERROR_LLM_AUTHENTICATION',
},
status_code=400,
)


Expand Down
8 changes: 8 additions & 0 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import ChangeAgentStateAction
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.microagent import BaseMicroAgent
Expand Down Expand Up @@ -71,6 +72,7 @@ async def start(
agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
selected_repository: str | None = None,
initial_user_msg: str | None = None,
):
"""Starts the Agent session
Parameters:
Expand Down Expand Up @@ -112,6 +114,12 @@ async def start(
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)

if initial_user_msg:
self.event_stream.add_event(
MessageAction(content=initial_user_msg), EventSource.USER
)

self._starting = False

async def close(self):
Expand Down
8 changes: 6 additions & 2 deletions openhands/server/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,11 @@ async def _get_connections_remotely(
self._connection_queries.pop(query_id, None)

async def maybe_start_agent_loop(
self, sid: str, settings: Settings, user_id: str | None
self,
sid: str,
settings: Settings,
user_id: str | None,
initial_user_msg: str | None = None,
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None
Expand All @@ -462,7 +466,7 @@ async def maybe_start_agent_loop(
user_id=user_id,
)
self._local_agent_loops_by_sid[sid] = session
asyncio.create_task(session.initialize_agent(settings))
asyncio.create_task(session.initialize_agent(settings, initial_user_msg))

event_stream = await self._get_event_stream(sid)
if not event_stream:
Expand Down
6 changes: 2 additions & 4 deletions openhands/server/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ async def close(self):
self.is_alive = False
await self.agent_session.close()

async def initialize_agent(
self,
settings: Settings,
):
async def initialize_agent(self, settings: Settings, initial_user_msg: str | None):
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
Expand Down Expand Up @@ -122,6 +119,7 @@ async def initialize_agent(
agent_configs=self.config.get_agent_configs(),
github_token=github_token,
selected_repository=selected_repository,
initial_user_msg=initial_user_msg,
)
except Exception as e:
logger.exception(f'Error creating agent_session: {e}')
Expand Down

0 comments on commit 000055b

Please sign in to comment.