diff --git a/README.md b/README.md
index c2ca68e..b94764c 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-
+from litellm.proxy.client.cli.commands.models import models
@@ -160,6 +160,22 @@ if __name__ == "__main__":
asyncio.run(main())
```
+## LLM Customization
+If you’d like to use a custom LLM, you can do so by providing an `apiKey` and `baseUrl` in the `model_client_options` parameter of the `StagehandConfig`.
+Most LLMs are OpenAI-compatible, and thus can be used with Stagehand as long as they support structured outputs.
+Only supported for 'LOCAL' environment.
+
+```python
+config = StagehandConfig(
+ env="LOCAL",
+ model_name="llama3.3",
+ model_client_options={
+ "apiKey": "llama3.3",
+ "baseUrl": "http://localhost:11434/v1"
+ }
+)
+```
+
## Documentation
See our full documentation [here](https://docs.stagehand.dev/).
diff --git a/stagehand/config.py b/stagehand/config.py
index a577230..89e0e33 100644
--- a/stagehand/config.py
+++ b/stagehand/config.py
@@ -19,7 +19,7 @@ class StagehandConfig(BaseModel):
browserbase_session_create_params (Optional[BrowserbaseSessionCreateParams]): Browserbase session create params.
browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions.
model_name (Optional[str]): Name of the model to use.
- model_api_key (Optional[str]): Model API key.
+ model_client_options (Optional[dict[str, Any]]): Options for the model client.
logger (Optional[Callable[[Any], None]]): Custom logging function.
verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed).
use_rich_logging (bool): Whether to use Rich for colorized logging.
@@ -47,8 +47,8 @@ class StagehandConfig(BaseModel):
alias="apiUrl",
description="Stagehand API URL",
)
- model_api_key: Optional[str] = Field(
- None, alias="modelApiKey", description="Model API key"
+ model_client_options: Optional[dict[str, Any]] = Field(
+ None, alias="modelClientOptions", description="Configuration options for the language model client (i.e. apiKey, baseURL)",
)
verbose: Optional[int] = Field(
1,
diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py
index 855b0fe..9d681d5 100644
--- a/stagehand/llm/client.py
+++ b/stagehand/llm/client.py
@@ -54,7 +54,7 @@ def __init__(
setattr(litellm, key, value)
self.logger.debug(f"Set global litellm.{key}", category="llm")
# Handle common aliases or expected config names if necessary
- elif key == "api_base": # Example: map api_base if needed
+ elif key == "api_base" or key == "baseURL": # Example: map api_base if needed
litellm.api_base = value
self.logger.debug(
f"Set global litellm.api_base to {value}", category="llm"
diff --git a/stagehand/main.py b/stagehand/main.py
index 0de682e..f826947 100644
--- a/stagehand/main.py
+++ b/stagehand/main.py
@@ -68,7 +68,10 @@ def __init__(
# Handle non-config parameters
self.api_url = self.config.api_url
- self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY")
+
+ # Handle model-related settings
+ self.model_client_options = self.config.model_client_options or {}
+ self.model_api_key = self.model_client_options.get("apiKey") or os.getenv("MODEL_API_KEY")
self.model_name = self.config.model_name
# Extract frequently used values from config for convenience
@@ -89,11 +92,6 @@ def __init__(
self.config.local_browser_launch_options or {}
)
- # Handle model-related settings
- self.model_client_options = {}
- if self.model_api_key and "apiKey" not in self.model_client_options:
- self.model_client_options["apiKey"] = self.model_api_key
-
# Handle browserbase session create params
self.browserbase_session_create_params = make_serializable(
self.config.browserbase_session_create_params
diff --git a/tests/unit/llm/test_llm_integration.py b/tests/unit/llm/test_llm_integration.py
index a01e7a7..00d09c4 100644
--- a/tests/unit/llm/test_llm_integration.py
+++ b/tests/unit/llm/test_llm_integration.py
@@ -40,6 +40,7 @@ def test_llm_client_with_custom_options(self):
api_key="test-key",
default_model="gpt-4o-mini",
stagehand_logger=StagehandLogger(),
+ api_base="https://test-api-base.com",
)
assert client.default_model == "gpt-4o-mini"
diff --git a/tests/unit/test_client_api.py b/tests/unit/test_client_api.py
index f6cb20b..e76e30d 100644
--- a/tests/unit/test_client_api.py
+++ b/tests/unit/test_client_api.py
@@ -19,7 +19,7 @@ async def mock_client(self):
browserbase_session_id="test-session-123",
api_key="test-api-key",
project_id="test-project-id",
- model_api_key="test-model-api-key",
+ model_client_options={"apiKey": "test-model-api-key"}
)
return client
diff --git a/tests/unit/test_client_initialization.py b/tests/unit/test_client_initialization.py
index cd748ac..ff22039 100644
--- a/tests/unit/test_client_initialization.py
+++ b/tests/unit/test_client_initialization.py
@@ -23,7 +23,7 @@ def test_init_with_direct_params(self):
browserbase_session_id="test-session",
api_key="test-api-key",
project_id="test-project-id",
- model_api_key="test-model-api-key",
+ model_client_options={"apiKey": "test-model-api-key"},
verbose=2,
)
@@ -203,3 +203,32 @@ async def mock_create_session():
# Call _create_session and expect error
with pytest.raises(RuntimeError, match="Invalid response format"):
await client._create_session()
+
+ @mock.patch.dict(os.environ, {"MODEL_API_KEY": "test-model-api-key"}, clear=True)
+ def test_init_with_model_api_key_in_env(self):
+ config = StagehandConfig(env="LOCAL")
+ client = Stagehand(config=config)
+ assert client.model_api_key == "test-model-api-key"
+
+ def test_init_with_custom_llm(self):
+ config = StagehandConfig(
+ env="LOCAL",
+ model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
+ )
+ client = Stagehand(config=config)
+ assert client.model_api_key == "custom-llm-key"
+ assert client.model_client_options["apiKey"] == "custom-llm-key"
+ assert client.model_client_options["baseURL"] == "https://custom-llm.com"
+
+ def test_init_with_custom_llm_override(self):
+ config = StagehandConfig(
+ env="LOCAL",
+ model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
+ )
+ client = Stagehand(
+ config=config,
+ model_client_options={"apiKey": "override-llm-key", "baseURL": "https://override-llm.com"}
+ )
+ assert client.model_api_key == "override-llm-key"
+ assert client.model_client_options["apiKey"] == "override-llm-key"
+ assert client.model_client_options["baseURL"] == "https://override-llm.com"
\ No newline at end of file