Skip to content

enable session affinity for cache optimization #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 71 additions & 78 deletions stagehand/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
from typing import Any

import httpx

from .utils import convert_dict_keys_to_camel_case

__all__ = ["_create_session", "_execute"]
Expand Down Expand Up @@ -73,21 +71,20 @@ async def _create_session(self):
"x-language": "python",
}

client = httpx.AsyncClient(timeout=self.timeout_settings)
async with client:
resp = await client.post(
f"{self.api_url}/sessions/start",
json=payload,
headers=headers,
)
if resp.status_code != 200:
raise RuntimeError(f"Failed to create session: {resp.text}")
data = resp.json()
self.logger.debug(f"Session created: {data}")
if not data.get("success") or "sessionId" not in data.get("data", {}):
raise RuntimeError(f"Invalid response format: {resp.text}")
# async with self._client:
resp = await self._client.post(
f"{self.api_url}/sessions/start",
json=payload,
headers=headers,
)
if resp.status_code != 200:
raise RuntimeError(f"Failed to create session: {resp.text}")
data = resp.json()
self.logger.debug(f"Session created: {data}")
if not data.get("success") or "sessionId" not in data.get("data", {}):
raise RuntimeError(f"Invalid response format: {resp.text}")

self.session_id = data["data"]["sessionId"]
self.session_id = data["data"]["sessionId"]


async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
Expand All @@ -109,65 +106,61 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
# Convert snake_case keys to camelCase for the API
modified_payload = convert_dict_keys_to_camel_case(payload)

client = httpx.AsyncClient(timeout=self.timeout_settings)

async with client:
try:
# Always use streaming for consistent log handling
async with client.stream(
"POST",
f"{self.api_url}/sessions/{self.session_id}/{method}",
json=modified_payload,
headers=headers,
) as response:
if response.status_code != 200:
error_text = await response.aread()
error_message = error_text.decode("utf-8")
self.logger.error(
f"[HTTP ERROR] Status {response.status_code}: {error_message}"
)
raise RuntimeError(
f"Request failed with status {response.status_code}: {error_message}"
)
result = None

async for line in response.aiter_lines():
# Skip empty lines
if not line.strip():
continue

try:
# Handle SSE-style messages that start with "data: "
if line.startswith("data: "):
line = line[len("data: ") :]

message = json.loads(line)
# Handle different message types
msg_type = message.get("type")

if msg_type == "system":
status = message.get("data", {}).get("status")
if status == "error":
error_msg = message.get("data", {}).get(
"error", "Unknown error"
)
self.logger.error(f"[ERROR] {error_msg}")
raise RuntimeError(
f"Server returned error: {error_msg}"
)
elif status == "finished":
result = message.get("data", {}).get("result")
elif msg_type == "log":
# Process log message using _handle_log
await self._handle_log(message)
else:
# Log any other message types
self.logger.debug(f"[UNKNOWN] Message type: {msg_type}")
except json.JSONDecodeError:
self.logger.warning(f"Could not parse line as JSON: {line}")

# Return the final result
return result
except Exception as e:
self.logger.error(f"[EXCEPTION] {str(e)}")
raise
# async with self._client:
try:
# Always use streaming for consistent log handling
async with self._client.stream(
"POST",
f"{self.api_url}/sessions/{self.session_id}/{method}",
json=modified_payload,
headers=headers,
) as response:
if response.status_code != 200:
error_text = await response.aread()
error_message = error_text.decode("utf-8")
self.logger.error(
f"[HTTP ERROR] Status {response.status_code}: {error_message}"
)
raise RuntimeError(
f"Request failed with status {response.status_code}: {error_message}"
)
result = None

async for line in response.aiter_lines():
# Skip empty lines
if not line.strip():
continue

try:
# Handle SSE-style messages that start with "data: "
if line.startswith("data: "):
line = line[len("data: ") :]

message = json.loads(line)
# Handle different message types
msg_type = message.get("type")

if msg_type == "system":
status = message.get("data", {}).get("status")
if status == "error":
error_msg = message.get("data", {}).get(
"error", "Unknown error"
)
self.logger.error(f"[ERROR] {error_msg}")
raise RuntimeError(f"Server returned error: {error_msg}")
elif status == "finished":
result = message.get("data", {}).get("result")
elif msg_type == "log":
# Process log message using _handle_log
await self._handle_log(message)
else:
# Log any other message types
self.logger.debug(f"[UNKNOWN] Message type: {msg_type}")
except json.JSONDecodeError:
self.logger.error(f"Could not parse line as JSON: {line}")

# Return the final result
return result
except Exception as e:
self.logger.error(f"[EXCEPTION] {str(e)}")
raise
9 changes: 2 additions & 7 deletions stagehand/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(
)
if not self.model_api_key:
# Model API key needed if Stagehand server creates the session
self.logger.warning(
self.logger.info(
"model_api_key is recommended when creating a new BROWSERBASE session to configure the Stagehand server's LLM."
)
elif self.session_id:
Expand All @@ -161,9 +161,7 @@ def __init__(
# Register signal handlers for graceful shutdown
self._register_signal_handlers()

self._client: Optional[httpx.AsyncClient] = (
None # Used for server communication in BROWSERBASE
)
self._client = httpx.AsyncClient(timeout=self.timeout_settings)

self._playwright: Optional[Playwright] = None
self._browser = None
Expand Down Expand Up @@ -388,9 +386,6 @@ async def init(self):
self._playwright = await async_playwright().start()

if self.env == "BROWSERBASE":
if not self._client:
self._client = httpx.AsyncClient(timeout=self.timeout_settings)

# Create session if we don't have one
if not self.session_id:
await self._create_session() # Uses self._client and api_url
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/llm/test_llm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from stagehand.llm.client import LLMClient
from tests.mocks.mock_llm import MockLLMClient, MockLLMResponse
from stagehand.logging import StagehandLogger


class TestLLMClientInitialization:
Expand All @@ -15,7 +16,8 @@ def test_llm_client_creation_with_openai(self):
"""Test LLM client creation with OpenAI provider"""
client = LLMClient(
api_key="test-openai-key",
default_model="gpt-4o"
default_model="gpt-4o",
stagehand_logger=StagehandLogger(),
)

assert client.default_model == "gpt-4o"
Expand All @@ -25,7 +27,8 @@ def test_llm_client_creation_with_anthropic(self):
"""Test LLM client creation with Anthropic provider"""
client = LLMClient(
api_key="test-anthropic-key",
default_model="claude-3-sonnet"
default_model="claude-3-sonnet",
stagehand_logger=StagehandLogger(),
)

assert client.default_model == "claude-3-sonnet"
Expand All @@ -35,7 +38,8 @@ def test_llm_client_with_custom_options(self):
"""Test LLM client with custom configuration options"""
client = LLMClient(
api_key="test-key",
default_model="gpt-4o-mini"
default_model="gpt-4o-mini",
stagehand_logger=StagehandLogger(),
)

assert client.default_model == "gpt-4o-mini"
Expand Down