diff --git a/.env.example b/.env.example index 2b228c0..26d589e 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,4 @@ MODEL_API_KEY = "your-favorite-llm-api-key" BROWSERBASE_API_KEY = "browserbase-api-key" BROWSERBASE_PROJECT_ID = "browserbase-project-id" -STAGEHAND_API_URL = "api_url" STAGEHAND_ENV= "LOCAL or BROWSERBASE" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index b6d0d11..f99cd17 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -23,19 +23,19 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip - pip install build twine wheel setuptools ruff black tomllib + pip install build twine wheel setuptools ruff black pip install -r requirements.txt if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi diff --git a/stagehand/agent/agent.py b/stagehand/agent/agent.py index 09fe0c2..c3f0e97 100644 --- a/stagehand/agent/agent.py +++ b/stagehand/agent/agent.py @@ -36,7 +36,7 @@ def __init__(self, stagehand_client, **kwargs): self.stagehand = stagehand_client self.config = AgentConfig(**kwargs) if kwargs else AgentConfig() self.logger = self.stagehand.logger - if self.stagehand.env == "BROWSERBASE": + if self.stagehand.use_api: if self.config.model in MODEL_TO_PROVIDER_MAP: self.provider = MODEL_TO_PROVIDER_MAP[self.config.model] else: @@ -120,7 +120,7 @@ async def execute( instruction = options.instruction - if self.stagehand.env == "LOCAL": + if not self.stagehand.use_api: self.logger.info( f"Agent starting execution for instruction: '{instruction}'", category="agent", diff --git a/stagehand/api.py b/stagehand/api.py index 849ac8d..ff8601c 100644 --- a/stagehand/api.py +++ b/stagehand/api.py @@ -41,7 +41,6 @@ async def _create_session(self): }, } ), - "proxies": True, } # Add the new parameters if they have values diff --git a/stagehand/browser.py b/stagehand/browser.py index 88c60b7..832900c 100644 --- a/stagehand/browser.py +++ b/stagehand/browser.py @@ -6,6 +6,7 @@ from typing import Any, Optional from browserbase import Browserbase +from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams from playwright.async_api import ( Browser, BrowserContext, @@ -40,11 +41,30 @@ async def connect_browserbase_browser( # Connect to remote browser via Browserbase SDK and CDP bb = Browserbase(api_key=browserbase_api_key) try: - session = bb.sessions.retrieve(session_id) - if session.status != "RUNNING": - raise RuntimeError( - f"Browserbase session {session_id} is not running (status: {session.status})" + if session_id: + session = bb.sessions.retrieve(session_id) + if session.status != "RUNNING": + raise RuntimeError( + f"Browserbase session {session_id} is not running (status: {session.status})" + ) + else: + browserbase_session_create_params = ( + BrowserbaseSessionCreateParams( + project_id=stagehand_instance.browserbase_project_id, + browser_settings={ + "viewport": { + "width": 1024, + "height": 768, + }, + }, + ) + if not stagehand_instance.browserbase_session_create_params + else stagehand_instance.browserbase_session_create_params ) + session = bb.sessions.create(**browserbase_session_create_params) + if not session.id: + raise Exception("Could not create Browserbase session") + stagehand_instance.session_id = session.id connect_url = session.connectUrl except Exception as e: logger.error(f"Error retrieving or validating Browserbase session: {str(e)}") diff --git a/stagehand/config.py b/stagehand/config.py index 107a296..4dc1a0e 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -96,6 +96,16 @@ class StagehandConfig(BaseModel): alias="localBrowserLaunchOptions", description="Local browser launch options", ) + use_api: Optional[bool] = Field( + True, + alias=None, + description="Whether to use the Stagehand API", + ) + experimental: Optional[bool] = Field( + False, + alias=None, + description="Whether to use experimental features", + ) model_config = ConfigDict(populate_by_name=True) diff --git a/stagehand/main.py b/stagehand/main.py index eb453fa..3412284 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -168,13 +168,23 @@ def __init__( self._playwright_page: Optional[PlaywrightPage] = None self.page: Optional[StagehandPage] = None self.context: Optional[StagehandContext] = None + self.use_api = self.config.use_api + self.experimental = self.config.experimental + if self.experimental: + self.use_api = False + if ( + self.browserbase_session_create_params + and self.browserbase_session_create_params.get("region") + and self.browserbase_session_create_params.get("region") != "us-west-2" + ): + self.use_api = False self._initialized = False # Flag to track if init() has run self._closed = False # Flag to track if resources have been closed # Setup LLM client if LOCAL mode self.llm = None - if self.env == "LOCAL": + if not self.use_api: self.llm = LLMClient( stagehand_logger=self.logger, api_key=self.model_api_key, @@ -385,15 +395,16 @@ async def init(self): if self.env == "BROWSERBASE": # Create session if we don't have one - if not self.session_id: - await self._create_session() # Uses self._client and api_url - self.logger.debug( - f"Created new Browserbase session via Stagehand server: {self.session_id}" - ) - else: - self.logger.debug( - f"Using existing Browserbase session: {self.session_id}" - ) + if self.use_api: + if not self.session_id: + await self._create_session() # Uses self._client and api_url + self.logger.debug( + f"Created new Browserbase session via Stagehand server: {self.session_id}" + ) + else: + self.logger.debug( + f"Using existing Browserbase session: {self.session_id}" + ) # Connect to remote browser try: @@ -470,8 +481,8 @@ async def close(self): self.logger.debug("Closing resources...") - if self.env == "BROWSERBASE": - # --- BROWSERBASE Cleanup --- + if self.use_api: + # --- BROWSERBASE Cleanup (API) --- # End the session on the server if we have a session ID if self.session_id and self._client: # Check if client was initialized try: diff --git a/stagehand/page.py b/stagehand/page.py index ecaf112..300d7b2 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -80,7 +80,7 @@ async def goto( Returns: The result from the Stagehand server's navigation execution. """ - if self._stagehand.env == "LOCAL": + if not self._stagehand.use_api: await self._page.goto( url, referer=referer, timeout=timeout, wait_until=wait_until ) @@ -142,7 +142,7 @@ async def act( ) # TODO: Temporary until we move api based logic to client - if self._stagehand.env == "LOCAL": + if not self._stagehand.use_api: # TODO: revisit passing user_provided_instructions if not hasattr(self, "_observe_handler"): # TODO: revisit handlers initialization on page creation @@ -207,7 +207,7 @@ async def observe( payload = options_obj.model_dump(exclude_none=True, by_alias=True) # If in LOCAL mode, use local implementation - if self._stagehand.env == "LOCAL": + if not self._stagehand.use_api: self._stagehand.logger.debug( "observe", category="observe", auxiliary=payload ) @@ -324,8 +324,7 @@ async def extract( else: schema_to_validate_with = DefaultExtractSchema - # If in LOCAL mode, use local implementation - if self._stagehand.env == "LOCAL": + if not self._stagehand.use_api: # If we don't have an extract handler yet, create one if not hasattr(self, "_extract_handler"): self._extract_handler = ExtractHandler( @@ -391,18 +390,16 @@ async def screenshot(self, options: Optional[dict] = None) -> str: Returns: str: Base64-encoded screenshot data. """ - if self._stagehand.env == "LOCAL": - self._stagehand.logger.info( - "Local execution of screenshot is not implemented" - ) - return None payload = options or {} - lock = self._stagehand._get_lock_for_session() - async with lock: - result = await self._stagehand._execute("screenshot", payload) + if self._stagehand.use_api: + lock = self._stagehand._get_lock_for_session() + async with lock: + result = await self._stagehand._execute("screenshot", payload) - return result + return result + else: + return await self._page.screenshot(options) # Method to get or initialize the persistent CDP client async def get_cdp_client(self) -> CDPSession: