diff --git a/stagehand/api.py b/stagehand/api.py index 3935dea..849ac8d 100644 --- a/stagehand/api.py +++ b/stagehand/api.py @@ -1,8 +1,6 @@ import json from typing import Any -import httpx - from .utils import convert_dict_keys_to_camel_case __all__ = ["_create_session", "_execute"] @@ -73,21 +71,20 @@ async def _create_session(self): "x-language": "python", } - client = httpx.AsyncClient(timeout=self.timeout_settings) - async with client: - resp = await client.post( - f"{self.api_url}/sessions/start", - json=payload, - headers=headers, - ) - if resp.status_code != 200: - raise RuntimeError(f"Failed to create session: {resp.text}") - data = resp.json() - self.logger.debug(f"Session created: {data}") - if not data.get("success") or "sessionId" not in data.get("data", {}): - raise RuntimeError(f"Invalid response format: {resp.text}") + # async with self._client: + resp = await self._client.post( + f"{self.api_url}/sessions/start", + json=payload, + headers=headers, + ) + if resp.status_code != 200: + raise RuntimeError(f"Failed to create session: {resp.text}") + data = resp.json() + self.logger.debug(f"Session created: {data}") + if not data.get("success") or "sessionId" not in data.get("data", {}): + raise RuntimeError(f"Invalid response format: {resp.text}") - self.session_id = data["data"]["sessionId"] + self.session_id = data["data"]["sessionId"] async def _execute(self, method: str, payload: dict[str, Any]) -> Any: @@ -109,65 +106,61 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any: # Convert snake_case keys to camelCase for the API modified_payload = convert_dict_keys_to_camel_case(payload) - client = httpx.AsyncClient(timeout=self.timeout_settings) - - async with client: - try: - # Always use streaming for consistent log handling - async with client.stream( - "POST", - f"{self.api_url}/sessions/{self.session_id}/{method}", - json=modified_payload, - headers=headers, - ) as response: - if response.status_code != 200: - error_text = await response.aread() - error_message = error_text.decode("utf-8") - self.logger.error( - f"[HTTP ERROR] Status {response.status_code}: {error_message}" - ) - raise RuntimeError( - f"Request failed with status {response.status_code}: {error_message}" - ) - result = None - - async for line in response.aiter_lines(): - # Skip empty lines - if not line.strip(): - continue - - try: - # Handle SSE-style messages that start with "data: " - if line.startswith("data: "): - line = line[len("data: ") :] - - message = json.loads(line) - # Handle different message types - msg_type = message.get("type") - - if msg_type == "system": - status = message.get("data", {}).get("status") - if status == "error": - error_msg = message.get("data", {}).get( - "error", "Unknown error" - ) - self.logger.error(f"[ERROR] {error_msg}") - raise RuntimeError( - f"Server returned error: {error_msg}" - ) - elif status == "finished": - result = message.get("data", {}).get("result") - elif msg_type == "log": - # Process log message using _handle_log - await self._handle_log(message) - else: - # Log any other message types - self.logger.debug(f"[UNKNOWN] Message type: {msg_type}") - except json.JSONDecodeError: - self.logger.warning(f"Could not parse line as JSON: {line}") - - # Return the final result - return result - except Exception as e: - self.logger.error(f"[EXCEPTION] {str(e)}") - raise + # async with self._client: + try: + # Always use streaming for consistent log handling + async with self._client.stream( + "POST", + f"{self.api_url}/sessions/{self.session_id}/{method}", + json=modified_payload, + headers=headers, + ) as response: + if response.status_code != 200: + error_text = await response.aread() + error_message = error_text.decode("utf-8") + self.logger.error( + f"[HTTP ERROR] Status {response.status_code}: {error_message}" + ) + raise RuntimeError( + f"Request failed with status {response.status_code}: {error_message}" + ) + result = None + + async for line in response.aiter_lines(): + # Skip empty lines + if not line.strip(): + continue + + try: + # Handle SSE-style messages that start with "data: " + if line.startswith("data: "): + line = line[len("data: ") :] + + message = json.loads(line) + # Handle different message types + msg_type = message.get("type") + + if msg_type == "system": + status = message.get("data", {}).get("status") + if status == "error": + error_msg = message.get("data", {}).get( + "error", "Unknown error" + ) + self.logger.error(f"[ERROR] {error_msg}") + raise RuntimeError(f"Server returned error: {error_msg}") + elif status == "finished": + result = message.get("data", {}).get("result") + elif msg_type == "log": + # Process log message using _handle_log + await self._handle_log(message) + else: + # Log any other message types + self.logger.debug(f"[UNKNOWN] Message type: {msg_type}") + except json.JSONDecodeError: + self.logger.error(f"Could not parse line as JSON: {line}") + + # Return the final result + return result + except Exception as e: + self.logger.error(f"[EXCEPTION] {str(e)}") + raise diff --git a/stagehand/main.py b/stagehand/main.py index 96a28ef..e8662c2 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -144,7 +144,7 @@ def __init__( ) if not self.model_api_key: # Model API key needed if Stagehand server creates the session - self.logger.warning( + self.logger.info( "model_api_key is recommended when creating a new BROWSERBASE session to configure the Stagehand server's LLM." ) elif self.session_id: @@ -161,9 +161,7 @@ def __init__( # Register signal handlers for graceful shutdown self._register_signal_handlers() - self._client: Optional[httpx.AsyncClient] = ( - None # Used for server communication in BROWSERBASE - ) + self._client = httpx.AsyncClient(timeout=self.timeout_settings) self._playwright: Optional[Playwright] = None self._browser = None @@ -388,9 +386,6 @@ async def init(self): self._playwright = await async_playwright().start() if self.env == "BROWSERBASE": - if not self._client: - self._client = httpx.AsyncClient(timeout=self.timeout_settings) - # Create session if we don't have one if not self.session_id: await self._create_session() # Uses self._client and api_url diff --git a/tests/unit/llm/test_llm_integration.py b/tests/unit/llm/test_llm_integration.py index cb1120b..a01e7a7 100644 --- a/tests/unit/llm/test_llm_integration.py +++ b/tests/unit/llm/test_llm_integration.py @@ -6,6 +6,7 @@ from stagehand.llm.client import LLMClient from tests.mocks.mock_llm import MockLLMClient, MockLLMResponse +from stagehand.logging import StagehandLogger class TestLLMClientInitialization: @@ -15,7 +16,8 @@ def test_llm_client_creation_with_openai(self): """Test LLM client creation with OpenAI provider""" client = LLMClient( api_key="test-openai-key", - default_model="gpt-4o" + default_model="gpt-4o", + stagehand_logger=StagehandLogger(), ) assert client.default_model == "gpt-4o" @@ -25,7 +27,8 @@ def test_llm_client_creation_with_anthropic(self): """Test LLM client creation with Anthropic provider""" client = LLMClient( api_key="test-anthropic-key", - default_model="claude-3-sonnet" + default_model="claude-3-sonnet", + stagehand_logger=StagehandLogger(), ) assert client.default_model == "claude-3-sonnet" @@ -35,7 +38,8 @@ def test_llm_client_with_custom_options(self): """Test LLM client with custom configuration options""" client = LLMClient( api_key="test-key", - default_model="gpt-4o-mini" + default_model="gpt-4o-mini", + stagehand_logger=StagehandLogger(), ) assert client.default_model == "gpt-4o-mini"