From 3fc9edf4ca15c3009841372cf69e2fddb7b2a3c3 Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Fri, 11 Jul 2025 16:06:21 -0700 Subject: [PATCH 1/7] Add refresh auth headers (sync and async) as alternate approach to allow bearer tokens to be updated Allow api_key to be a callable to enable refresh of keys/tokens. --- src/openai/_client.py | 55 +++++++++++++------ src/openai/lib/azure.py | 2 +- .../resources/beta/realtime/realtime.py | 2 + 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/openai/_client.py b/src/openai/_client.py index b99db786a7..b461368161 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -3,11 +3,13 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Union, Mapping +from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable from typing_extensions import Self, override import httpx +from openai._models import FinalRequestOptions + from . import _exceptions from ._qs import Querystring from ._types import ( @@ -95,6 +97,7 @@ def __init__( self, *, api_key: str | None = None, + bearer_token_provider: Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -128,11 +131,12 @@ def __init__( """ if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None: + if api_key is None and bearer_token_provider is None: raise OpenAIError( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + self.bearer_token_provider = bearer_token_provider + self.api_key = api_key or '' if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -165,6 +169,7 @@ def __init__( ) self._default_stream_cls = Stream + self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> Completions: @@ -281,21 +286,26 @@ def with_raw_response(self) -> OpenAIWithRawResponse: @cached_property def with_streaming_response(self) -> OpenAIWithStreamedResponse: return OpenAIWithStreamedResponse(self) - @property @override def qs(self) -> Querystring: return Querystring(array_format="brackets") + def refresh_auth_headers(self): + bearer_token = self.bearer_token_provider() if self.bearer_token_provider else self.api_key + self._auth_headers = {"Authorization": f"Bearer {bearer_token}"} + + + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + self.refresh_auth_headers() + return super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - if not api_key: - # if the api key is an empty string, encoding the header will fail - return {} - return {"Authorization": f"Bearer {api_key}"} - + return self._auth_headers + @property @override def default_headers(self) -> dict[str, str | Omit]: @@ -420,6 +430,7 @@ def __init__( self, *, api_key: str | None = None, + bearer_token_provider: Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -453,11 +464,12 @@ def __init__( """ if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None: + if api_key is None and bearer_token_provider is None: raise OpenAIError( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + self.bearer_token_provider = bearer_token_provider + self.api_key = api_key or '' if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -490,6 +502,7 @@ def __init__( ) self._default_stream_cls = AsyncStream + self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> AsyncCompletions: @@ -612,14 +625,22 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + async def refresh_auth_headers(self): + if self.bearer_token_provider: + bearer_token = await self.bearer_token_provider() + else: + bearer_token = self.api_key + self._auth_headers = {"Authorization": f"Bearer {bearer_token}"} + + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + await self.refresh_auth_headers() + return await super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - if not api_key: - # if the api key is an empty string, encoding the header will fail - return {} - return {"Authorization": f"Bearer {api_key}"} + return self._auth_headers @property @override diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index a994e4256c..a714982e0c 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt "api-version": self._api_version, "deployment": self._azure_deployment or model, } - if self.api_key != "": + if self.api_key and self.api_key != "": auth_headers = {"api-key": self.api_key} else: token = await self._get_azure_ad_token() diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index 7b99c7f6c4..90f6324d43 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + await self.__client.refresh_auth_headers() auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + self.__client.refresh_auth_headers() auth_headers = self.__client.auth_headers if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) From 6ff9ed038ac8c376df2aae167ef52d8500140e23 Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Mon, 14 Jul 2025 15:54:11 -0700 Subject: [PATCH 2/7] Validate only one of api_key and bearer_token_provider are passed in. Propagate bearer_token_provider in the `copy` method. --- src/openai/_client.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/openai/_client.py b/src/openai/_client.py index b461368161..5989f69b96 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -129,6 +129,8 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ + if api_key and bearer_token_provider: + raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") if api_key is None and bearer_token_provider is None: @@ -321,6 +323,7 @@ def copy( self, *, api_key: str | None = None, + bearer_token_provider: Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -359,6 +362,7 @@ def copy( http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, + bearer_token_provider = bearer_token_provider or self.bearer_token_provider, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, @@ -462,6 +466,8 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ + if api_key and bearer_token_provider: + raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") if api_key is None and bearer_token_provider is None: @@ -657,6 +663,7 @@ def copy( self, *, api_key: str | None = None, + bearer_token_provider: Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -695,6 +702,7 @@ def copy( http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, + bearer_token_provider = bearer_token_provider or self.bearer_token_provider, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, From 7873247a44182c580378929c766eb9eb0d88a1b8 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 23 Jul 2025 13:27:19 -0700 Subject: [PATCH 3/7] add tests, fix copy, add token provider to module client (#18) * add tests, fix copy, add token provider to module client * fix lint * ignore for azure copy * revert change --- src/openai/__init__.py | 14 ++++ src/openai/_client.py | 54 ++++++++----- src/openai/lib/azure.py | 8 +- tests/test_client.py | 148 +++++++++++++++++++++++++++++++++++- tests/test_module_client.py | 33 ++++++++ 5 files changed, 231 insertions(+), 26 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index b944fbed5e..b8d0edeeaa 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -119,6 +119,8 @@ api_key: str | None = None +bearer_token_provider: _t.Callable[[], str] | None = None + organization: str | None = None project: str | None = None @@ -165,6 +167,17 @@ def api_key(self, value: str | None) -> None: # type: ignore api_key = value + @property # type: ignore + @override + def bearer_token_provider(self) -> _t.Callable[[], str] | None: + return bearer_token_provider + + @bearer_token_provider.setter # type: ignore + def bearer_token_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore + global bearer_token_provider + + bearer_token_provider = value + @property # type: ignore @override def organization(self) -> str | None: @@ -348,6 +361,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] _client = _ModuleClient( api_key=api_key, + bearer_token_provider=bearer_token_provider, organization=organization, project=project, webhook_secret=webhook_secret, diff --git a/src/openai/_client.py b/src/openai/_client.py index 5989f69b96..ee5d740af3 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -135,10 +135,10 @@ def __init__( api_key = os.environ.get("OPENAI_API_KEY") if api_key is None and bearer_token_provider is None: raise OpenAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or bearer_token_provider client option must be set either by passing api_key or bearer_token_provider to the client or by setting the OPENAI_API_KEY environment variable" ) self.bearer_token_provider = bearer_token_provider - self.api_key = api_key or '' + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -288,26 +288,32 @@ def with_raw_response(self) -> OpenAIWithRawResponse: @cached_property def with_streaming_response(self) -> OpenAIWithStreamedResponse: return OpenAIWithStreamedResponse(self) + @property @override def qs(self) -> Querystring: return Querystring(array_format="brackets") - def refresh_auth_headers(self): - bearer_token = self.bearer_token_provider() if self.bearer_token_provider else self.api_key - self._auth_headers = {"Authorization": f"Bearer {bearer_token}"} - + def refresh_auth_headers(self) -> None: + secret = self.bearer_token_provider() if self.bearer_token_provider else self.api_key + if not secret: + # if the api key is an empty string, encoding the header will fail + # so we set it to an empty dict + # this is to avoid sending an invalid Authorization header + self._auth_headers = {} + else: + self._auth_headers = {"Authorization": f"Bearer {secret}"} @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: self.refresh_auth_headers() return super()._prepare_options(options) - + @property @override def auth_headers(self) -> dict[str, str]: return self._auth_headers - + @property @override def default_headers(self) -> dict[str, str | Omit]: @@ -359,10 +365,13 @@ def copy( elif set_default_query is not None: params = set_default_query + bearer_token_provider = bearer_token_provider or self.bearer_token_provider + if bearer_token_provider is not None: + _extra_kwargs = {**_extra_kwargs, "bearer_token_provider": bearer_token_provider} + http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, - bearer_token_provider = bearer_token_provider or self.bearer_token_provider, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, @@ -472,10 +481,10 @@ def __init__( api_key = os.environ.get("OPENAI_API_KEY") if api_key is None and bearer_token_provider is None: raise OpenAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or bearer_token_provider client option must be set either by passing api_key or bearer_token_provider to the client or by setting the OPENAI_API_KEY environment variable" ) self.bearer_token_provider = bearer_token_provider - self.api_key = api_key or '' + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -631,18 +640,24 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") - async def refresh_auth_headers(self): + async def refresh_auth_headers(self) -> None: if self.bearer_token_provider: - bearer_token = await self.bearer_token_provider() + secret = await self.bearer_token_provider() + else: + secret = self.api_key + if not secret: + # if the api key is an empty string, encoding the header will fail + # so we set it to an empty dict + # this is to avoid sending an invalid Authorization header + self._auth_headers = {} else: - bearer_token = self.api_key - self._auth_headers = {"Authorization": f"Bearer {bearer_token}"} - + self._auth_headers = {"Authorization": f"Bearer {secret}"} + @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: await self.refresh_auth_headers() return await super()._prepare_options(options) - + @property @override def auth_headers(self) -> dict[str, str]: @@ -699,10 +714,13 @@ def copy( elif set_default_query is not None: params = set_default_query + bearer_token_provider = bearer_token_provider or self.bearer_token_provider + if bearer_token_provider is not None: + _extra_kwargs = {**_extra_kwargs, "bearer_token_provider": bearer_token_provider} + http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, - bearer_token_provider = bearer_token_provider or self.bearer_token_provider, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index a714982e0c..c7b2a19fa4 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -255,7 +255,7 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( + def copy( # type: ignore self, *, api_key: str | None = None, @@ -301,7 +301,7 @@ def copy( }, ) - with_options = copy + with_options = copy # type: ignore def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: @@ -536,7 +536,7 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( + def copy( # type: ignore self, *, api_key: str | None = None, @@ -582,7 +582,7 @@ def copy( }, ) - with_options = copy + with_options = copy # type: ignore async def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: diff --git a/tests/test_client.py b/tests/test_client.py index ccda50a7f0..805b95f51b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,7 @@ import inspect import subprocess import tracemalloc -from typing import Any, Union, cast +from typing import Any, Union, Protocol, cast from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -41,6 +41,10 @@ api_key = "My API Key" +class MockRequestCall(Protocol): + request: httpx.Request + + def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) url = httpx.URL(request.url) @@ -337,7 +341,9 @@ def test_default_headers_option(self) -> None: def test_validate_headers(self) -> None: client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) + assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -939,6 +945,63 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + def test_refresh_auth_headers_token(self) -> None: + client = OpenAI(base_url=base_url, bearer_token_provider=lambda: "test_bearer_token") + client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + def test_refresh_auth_headers_key(self) -> None: + client = OpenAI(base_url=base_url, api_key="test_api_key") + client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.respx() + def test_bearer_token_refresh(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = OpenAI(base_url=base_url, bearer_token_provider=token_provider) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_auth_mutually_exclusive(self) -> None: + with pytest.raises(ValueError) as exc_info: + OpenAI(base_url=base_url, api_key=api_key, bearer_token_provider=lambda: "test_bearer_token") + assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" + + def test_copy_auth(self) -> None: + client = OpenAI(base_url=base_url, bearer_token_provider=lambda: "test_bearer_token_1").copy( + bearer_token_provider=lambda: "test_bearer_token_2" + ) + client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + + def test_copy_auth_mutually_exclusive(self) -> None: + with pytest.raises(ValueError) as exc_info: + OpenAI(base_url=base_url, api_key=api_key).copy(bearer_token_provider=lambda: "test_bearer_token") + assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" + class TestAsyncOpenAI: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1220,9 +1283,10 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: + async def test_validate_headers(self) -> None: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -1887,3 +1951,79 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + @pytest.mark.asyncio + async def test_refresh_auth_headers_token_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider) + await client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + @pytest.mark.asyncio + async def test_refresh_auth_headers_key_async(self) -> None: + client = AsyncOpenAI(base_url=base_url, api_key="test_api_key") + await client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.asyncio + @pytest.mark.respx() + async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + async def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_auth_mutually_exclusive_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + with pytest.raises(ValueError) as exc_info: + AsyncOpenAI(base_url=base_url, api_key=api_key, bearer_token_provider=token_provider) + assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" + + @pytest.mark.asyncio + async def test_copy_auth(self) -> None: + async def token_provider_1() -> str: + return "test_bearer_token_1" + + async def token_provider_2() -> str: + return "test_bearer_token_2" + + client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider_1).copy( + bearer_token_provider=token_provider_2 + ) + await client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + + def test_copy_auth_mutually_exclusive_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + with pytest.raises(ValueError) as exc_info: + AsyncOpenAI(base_url=base_url, api_key=api_key).copy(bearer_token_provider=token_provider) + assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 9c9a1addab..1cc29ae69e 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -15,6 +15,7 @@ def reset_state() -> None: openai._reset_client() openai.api_key = None or "My API Key" + openai.bearer_token_provider = None openai.organization = None openai.project = None openai.webhook_secret = None @@ -97,6 +98,17 @@ def test_http_client_option() -> None: assert openai.completions._client._client is new_client +def test_bearer_token_provider_option() -> None: + assert openai.bearer_token_provider is None + assert openai.completions._client.bearer_token_provider is None + + openai.bearer_token_provider = lambda: "foo" + + assert openai.bearer_token_provider() == "foo" + assert openai.completions._client.bearer_token_provider + assert openai.completions._client.bearer_token_provider() == "foo" + + import contextlib from typing import Iterator @@ -123,6 +135,27 @@ def test_only_api_key_results_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" +def test_only_bearer_token_provider_in_openai_api() -> None: + with fresh_env(): + openai.api_type = None + openai.api_key = None + openai.bearer_token_provider = lambda: "example bearer token" + + assert type(openai.completions._client).__name__ == "_ModuleClient" + + +def test_both_api_key_and_bearer_token_provider_in_openai_api() -> None: + with fresh_env(): + openai.api_key = "example API key" + openai.bearer_token_provider = lambda: "example bearer token" + + with pytest.raises( + ValueError, + match=r"The `api_key` and `bearer_token_provider` arguments are mutually exclusive", + ): + openai.completions._client # noqa: B018 + + def test_azure_api_key_env_without_api_version() -> None: with fresh_env(): openai.api_type = None From a6a6a2b82ee214099ceb08b3d3dab38db47548e6 Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Tue, 26 Aug 2025 12:26:56 -0700 Subject: [PATCH 4/7] Make api_key callable to enable token refresh for openai client --- src/openai/__init__.py | 24 ++++------------- src/openai/_client.py | 53 ++++++++++++++++--------------------- tests/test_client.py | 41 ++++++---------------------- tests/test_module_client.py | 20 +++++--------- 4 files changed, 42 insertions(+), 96 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index b8d0edeeaa..9c822b0da6 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -117,9 +117,7 @@ from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES -api_key: str | None = None - -bearer_token_provider: _t.Callable[[], str] | None = None +api_key: str | _t.Callable[[], str] | None = None organization: str | None = None @@ -158,25 +156,14 @@ class _ModuleClient(OpenAI): @property # type: ignore @override - def api_key(self) -> str | None: + def api_key(self) -> str | _t.Callable[[], str] | None: return api_key @api_key.setter # type: ignore - def api_key(self, value: str | None) -> None: # type: ignore + def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore global api_key - api_key = value - @property # type: ignore - @override - def bearer_token_provider(self) -> _t.Callable[[], str] | None: - return bearer_token_provider - - @bearer_token_provider.setter # type: ignore - def bearer_token_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore - global bearer_token_provider - - bearer_token_provider = value @property # type: ignore @override @@ -346,7 +333,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] _client = _AzureModuleClient( # type: ignore api_version=api_version, azure_endpoint=azure_endpoint, - api_key=api_key, + api_key=bearer_token_provider or api_key, azure_ad_token=azure_ad_token, azure_ad_token_provider=azure_ad_token_provider, organization=organization, @@ -360,8 +347,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] return _client _client = _ModuleClient( - api_key=api_key, - bearer_token_provider=bearer_token_provider, + api_key=api_key or bearer_token_provider, organization=organization, project=project, webhook_secret=webhook_secret, diff --git a/src/openai/_client.py b/src/openai/_client.py index ee5d740af3..7780d5df2b 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -81,6 +81,7 @@ class OpenAI(SyncAPIClient): # client options api_key: str + bearer_token_provider: Callable[[], str] | None = None organization: str | None project: str | None webhook_secret: str | None @@ -96,8 +97,7 @@ class OpenAI(SyncAPIClient): def __init__( self, *, - api_key: str | None = None, - bearer_token_provider: Callable[[], str] | None = None, + api_key: str | None | Callable[[], str] = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -129,16 +129,17 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ - if api_key and bearer_token_provider: - raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None and bearer_token_provider is None: + if api_key is None: raise OpenAIError( - "The api_key or bearer_token_provider client option must be set either by passing api_key or bearer_token_provider to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.bearer_token_provider = bearer_token_provider - self.api_key = api_key or "" + if callable(api_key): + self.bearer_token_provider = api_key + self.api_key = "" + else: + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -328,8 +329,7 @@ def default_headers(self) -> dict[str, str | Omit]: def copy( self, *, - api_key: str | None = None, - bearer_token_provider: Callable[[], str] | None = None, + api_key: str | Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -365,13 +365,9 @@ def copy( elif set_default_query is not None: params = set_default_query - bearer_token_provider = bearer_token_provider or self.bearer_token_provider - if bearer_token_provider is not None: - _extra_kwargs = {**_extra_kwargs, "bearer_token_provider": bearer_token_provider} - http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key, + api_key=api_key or self.api_key or self.bearer_token_provider, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, @@ -427,6 +423,7 @@ def _make_status_error( class AsyncOpenAI(AsyncAPIClient): # client options api_key: str + bearer_token_provider: Callable[[], Awaitable[str]] | None = None organization: str | None project: str | None webhook_secret: str | None @@ -442,8 +439,7 @@ class AsyncOpenAI(AsyncAPIClient): def __init__( self, *, - api_key: str | None = None, - bearer_token_provider: Callable[[], Awaitable[str]] | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -475,16 +471,18 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ - if api_key and bearer_token_provider: - raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None and bearer_token_provider is None: + if api_key is None: raise OpenAIError( - "The api_key or bearer_token_provider client option must be set either by passing api_key or bearer_token_provider to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.bearer_token_provider = bearer_token_provider - self.api_key = api_key or "" + if callable(api_key): + self.bearer_token_provider = api_key + self.api_key = "" + else: + self.bearer_token_provider = None + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -677,8 +675,7 @@ def default_headers(self) -> dict[str, str | Omit]: def copy( self, *, - api_key: str | None = None, - bearer_token_provider: Callable[[], Awaitable[str]] | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -714,13 +711,9 @@ def copy( elif set_default_query is not None: params = set_default_query - bearer_token_provider = bearer_token_provider or self.bearer_token_provider - if bearer_token_provider is not None: - _extra_kwargs = {**_extra_kwargs, "bearer_token_provider": bearer_token_provider} - http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key, + api_key=api_key or self.api_key or self.bearer_token_provider, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, diff --git a/tests/test_client.py b/tests/test_client.py index 805b95f51b..c50b382d22 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -946,7 +946,7 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" def test_refresh_auth_headers_token(self) -> None: - client = OpenAI(base_url=base_url, bearer_token_provider=lambda: "test_bearer_token") + client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token") client.refresh_auth_headers() assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" @@ -976,7 +976,7 @@ def token_provider() -> str: return "second" - client = OpenAI(base_url=base_url, bearer_token_provider=token_provider) + client = OpenAI(base_url=base_url, api_key=token_provider) client.chat.completions.create(messages=[], model="gpt-4") calls = cast("list[MockRequestCall]", respx_mock.calls) @@ -985,23 +985,14 @@ def token_provider() -> str: assert calls[0].request.headers.get("Authorization") == "Bearer first" assert calls[1].request.headers.get("Authorization") == "Bearer second" - def test_auth_mutually_exclusive(self) -> None: - with pytest.raises(ValueError) as exc_info: - OpenAI(base_url=base_url, api_key=api_key, bearer_token_provider=lambda: "test_bearer_token") - assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" def test_copy_auth(self) -> None: - client = OpenAI(base_url=base_url, bearer_token_provider=lambda: "test_bearer_token_1").copy( - bearer_token_provider=lambda: "test_bearer_token_2" + client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy( + api_key=lambda: "test_bearer_token_2" ) client.refresh_auth_headers() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} - def test_copy_auth_mutually_exclusive(self) -> None: - with pytest.raises(ValueError) as exc_info: - OpenAI(base_url=base_url, api_key=api_key).copy(bearer_token_provider=lambda: "test_bearer_token") - assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" - class TestAsyncOpenAI: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1957,7 +1948,7 @@ async def test_refresh_auth_headers_token_async(self) -> None: async def token_provider() -> str: return "test_bearer_token" - client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider) + client = AsyncOpenAI(base_url=base_url, api_key=token_provider) await client.refresh_auth_headers() assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" @@ -1989,7 +1980,7 @@ async def token_provider() -> str: return "second" - client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider) + client = AsyncOpenAI(base_url=base_url, api_key=token_provider) await client.chat.completions.create(messages=[], model="gpt-4") calls = cast("list[MockRequestCall]", respx_mock.calls) @@ -1998,14 +1989,6 @@ async def token_provider() -> str: assert calls[0].request.headers.get("Authorization") == "Bearer first" assert calls[1].request.headers.get("Authorization") == "Bearer second" - def test_auth_mutually_exclusive_async(self) -> None: - async def token_provider() -> str: - return "test_bearer_token" - - with pytest.raises(ValueError) as exc_info: - AsyncOpenAI(base_url=base_url, api_key=api_key, bearer_token_provider=token_provider) - assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" - @pytest.mark.asyncio async def test_copy_auth(self) -> None: async def token_provider_1() -> str: @@ -2014,16 +1997,8 @@ async def token_provider_1() -> str: async def token_provider_2() -> str: return "test_bearer_token_2" - client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider_1).copy( - bearer_token_provider=token_provider_2 + client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy( + api_key=token_provider_2 ) await client.refresh_auth_headers() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} - - def test_copy_auth_mutually_exclusive_async(self) -> None: - async def token_provider() -> str: - return "test_bearer_token" - - with pytest.raises(ValueError) as exc_info: - AsyncOpenAI(base_url=base_url, api_key=api_key).copy(bearer_token_provider=token_provider) - assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive" diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 1cc29ae69e..85777acd76 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -99,12 +99,8 @@ def test_http_client_option() -> None: def test_bearer_token_provider_option() -> None: - assert openai.bearer_token_provider is None - assert openai.completions._client.bearer_token_provider is None + openai.api_key = lambda: "foo" - openai.bearer_token_provider = lambda: "foo" - - assert openai.bearer_token_provider() == "foo" assert openai.completions._client.bearer_token_provider assert openai.completions._client.bearer_token_provider() == "foo" @@ -138,23 +134,19 @@ def test_only_api_key_results_in_openai_api() -> None: def test_only_bearer_token_provider_in_openai_api() -> None: with fresh_env(): openai.api_type = None - openai.api_key = None - openai.bearer_token_provider = lambda: "example bearer token" + openai.api_key = lambda: "example bearer token" assert type(openai.completions._client).__name__ == "_ModuleClient" def test_both_api_key_and_bearer_token_provider_in_openai_api() -> None: with fresh_env(): - openai.api_key = "example API key" - openai.bearer_token_provider = lambda: "example bearer token" + openai.api_key = lambda: "example bearer token" - with pytest.raises( - ValueError, - match=r"The `api_key` and `bearer_token_provider` arguments are mutually exclusive", - ): - openai.completions._client # noqa: B018 + assert(openai.api_key() == "example bearer token") + openai.api_key = "example API key" + assert(openai.api_key == "example API key") def test_azure_api_key_env_without_api_version() -> None: with fresh_env(): From 771af2c9d6b5312c7d83ba27bae1333b6d4ac643 Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Tue, 26 Aug 2025 14:38:42 -0700 Subject: [PATCH 5/7] Fix up missed tests and consistency tweaks --- src/openai/__init__.py | 4 ++-- src/openai/_client.py | 16 ++++++++++------ tests/test_module_client.py | 1 - 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 9c822b0da6..119b3d7ff6 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -333,7 +333,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] _client = _AzureModuleClient( # type: ignore api_version=api_version, azure_endpoint=azure_endpoint, - api_key=bearer_token_provider or api_key, + api_key=api_key, azure_ad_token=azure_ad_token, azure_ad_token_provider=azure_ad_token_provider, organization=organization, @@ -347,7 +347,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] return _client _client = _ModuleClient( - api_key=api_key or bearer_token_provider, + api_key=api_key, organization=organization, project=project, webhook_secret=webhook_secret, diff --git a/src/openai/_client.py b/src/openai/_client.py index 7780d5df2b..20ce0aaacf 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -136,10 +136,11 @@ def __init__( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) if callable(api_key): - self.bearer_token_provider = api_key self.api_key = "" + self.bearer_token_provider = api_key else: self.api_key = api_key or "" + self.bearer_token_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -296,7 +297,10 @@ def qs(self) -> Querystring: return Querystring(array_format="brackets") def refresh_auth_headers(self) -> None: - secret = self.bearer_token_provider() if self.bearer_token_provider else self.api_key + if self.bearer_token_provider: + secret = self.bearer_token_provider() + else: + secret = self.api_key if not secret: # if the api key is an empty string, encoding the header will fail # so we set it to an empty dict @@ -367,7 +371,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key or self.bearer_token_provider, + api_key=api_key or self.bearer_token_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, @@ -478,11 +482,11 @@ def __init__( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) if callable(api_key): - self.bearer_token_provider = api_key self.api_key = "" + self.bearer_token_provider = api_key else: - self.bearer_token_provider = None self.api_key = api_key or "" + self.bearer_token_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -713,7 +717,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key or self.bearer_token_provider, + api_key=api_key or self.bearer_token_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 85777acd76..4b1174bb8e 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -15,7 +15,6 @@ def reset_state() -> None: openai._reset_client() openai.api_key = None or "My API Key" - openai.bearer_token_provider = None openai.organization = None openai.project = None openai.webhook_secret = None From 1ffd959e8c7cf3355060965be493cbdf4ced9bb9 Mon Sep 17 00:00:00 2001 From: "Johan Stenberg (MSFT)" Date: Wed, 27 Aug 2025 14:54:30 -0700 Subject: [PATCH 6/7] Update tests/test_client.py Co-authored-by: Robert Craigie --- tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index c50b382d22..780a0cd83c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -956,7 +956,7 @@ def test_refresh_auth_headers_key(self) -> None: assert client.auth_headers.get("Authorization") == "Bearer test_api_key" @pytest.mark.respx() - def test_bearer_token_refresh(self, respx_mock: MockRouter) -> None: + def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None: respx_mock.post(base_url + "/chat/completions").mock( side_effect=[ httpx.Response(500, json={"error": "server error"}), From 60dba27ea4840770f65260e2f8869051b0819aba Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Wed, 27 Aug 2025 19:27:32 -0700 Subject: [PATCH 7/7] Review feedback + make sure you can swap between callable and string api key values for module level client. --- src/openai/__init__.py | 15 ++++- src/openai/_client.py | 62 +++++++------------ src/openai/lib/azure.py | 14 ++--- .../resources/beta/realtime/realtime.py | 4 +- tests/test_client.py | 49 ++++++++++----- tests/test_module_client.py | 27 +++++--- 6 files changed, 98 insertions(+), 73 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 119b3d7ff6..3391e8fa8c 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -157,15 +157,26 @@ class _ModuleClient(OpenAI): @property # type: ignore @override def api_key(self) -> str | _t.Callable[[], str] | None: - return api_key + return api_key() if callable(api_key) else api_key @api_key.setter # type: ignore def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore global api_key api_key = value + @property + def _api_key_provider(self) -> _t.Callable[[], str] | None: # type: ignore + return None - @property # type: ignore + @_api_key_provider.setter + def _api_key_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore + global api_key + # Yes, setting the api_key is intentional. The module level client accepts callables + # for the module level api_key and will call it to retrieve the value + # if it is a callable. + api_key = value + + @property @override def organization(self) -> str | None: return organization diff --git a/src/openai/_client.py b/src/openai/_client.py index 20ce0aaacf..50b0c99181 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -81,7 +81,6 @@ class OpenAI(SyncAPIClient): # client options api_key: str - bearer_token_provider: Callable[[], str] | None = None organization: str | None project: str | None webhook_secret: str | None @@ -137,10 +136,10 @@ def __init__( ) if callable(api_key): self.api_key = "" - self.bearer_token_provider = api_key + self._api_key_provider: Callable[[], str] | None = api_key else: self.api_key = api_key or "" - self.bearer_token_provider = None + self._api_key_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -173,7 +172,6 @@ def __init__( ) self._default_stream_cls = Stream - self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> Completions: @@ -296,28 +294,23 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") - def refresh_auth_headers(self) -> None: - if self.bearer_token_provider: - secret = self.bearer_token_provider() - else: - secret = self.api_key - if not secret: - # if the api key is an empty string, encoding the header will fail - # so we set it to an empty dict - # this is to avoid sending an invalid Authorization header - self._auth_headers = {} - else: - self._auth_headers = {"Authorization": f"Bearer {secret}"} + def _refresh_api_key(self) -> None: + if self._api_key_provider: + self.api_key = self._api_key_provider() @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - self.refresh_auth_headers() + self._refresh_api_key() return super()._prepare_options(options) @property @override def auth_headers(self) -> dict[str, str]: - return self._auth_headers + api_key = self.api_key + if not api_key: + # if the api key is an empty string, encoding the header will fail + return {} + return {"Authorization": f"Bearer {api_key}"} @property @override @@ -371,7 +364,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.bearer_token_provider or self.api_key, + api_key=api_key or self._api_key_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, @@ -427,7 +420,6 @@ def _make_status_error( class AsyncOpenAI(AsyncAPIClient): # client options api_key: str - bearer_token_provider: Callable[[], Awaitable[str]] | None = None organization: str | None project: str | None webhook_secret: str | None @@ -483,10 +475,10 @@ def __init__( ) if callable(api_key): self.api_key = "" - self.bearer_token_provider = api_key + self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key else: self.api_key = api_key or "" - self.bearer_token_provider = None + self._api_key_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -519,7 +511,6 @@ def __init__( ) self._default_stream_cls = AsyncStream - self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> AsyncCompletions: @@ -642,28 +633,23 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") - async def refresh_auth_headers(self) -> None: - if self.bearer_token_provider: - secret = await self.bearer_token_provider() - else: - secret = self.api_key - if not secret: - # if the api key is an empty string, encoding the header will fail - # so we set it to an empty dict - # this is to avoid sending an invalid Authorization header - self._auth_headers = {} - else: - self._auth_headers = {"Authorization": f"Bearer {secret}"} + async def _refresh_api_key(self) -> None: + if self._api_key_provider: + self.api_key = await self._api_key_provider() @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - await self.refresh_auth_headers() + await self._refresh_api_key() return await super()._prepare_options(options) @property @override def auth_headers(self) -> dict[str, str]: - return self._auth_headers + api_key = self.api_key + if not api_key: + # if the api key is an empty string, encoding the header will fail + return {} + return {"Authorization": f"Bearer {api_key}"} @property @override @@ -717,7 +703,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.bearer_token_provider or self.api_key, + api_key=api_key or self._api_key_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index c7b2a19fa4..d6143f916f 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -255,10 +255,10 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( # type: ignore + def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -301,7 +301,7 @@ def copy( # type: ignore }, ) - with_options = copy # type: ignore + with_options = copy def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: @@ -435,7 +435,7 @@ def __init__( azure_endpoint: str | None = None, azure_deployment: str | None = None, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, organization: str | None = None, @@ -536,10 +536,10 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( # type: ignore + def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -582,7 +582,7 @@ def copy( # type: ignore }, ) - with_options = copy # type: ignore + with_options = copy async def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index 90f6324d43..4fa35963b6 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -358,7 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query - await self.__client.refresh_auth_headers() + await self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) @@ -541,7 +541,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query - self.__client.refresh_auth_headers() + self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) diff --git a/tests/test_client.py b/tests/test_client.py index 780a0cd83c..c6435b70c6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -945,14 +945,24 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" - def test_refresh_auth_headers_token(self) -> None: + def test_api_key_before_after_refresh_provider(self) -> None: client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token") - client.refresh_auth_headers() + + assert client.api_key == "" + assert 'Authorization' not in client.auth_headers + + client._refresh_api_key() + + assert client.api_key == "test_bearer_token" assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" - def test_refresh_auth_headers_key(self) -> None: + + def test_api_key_before_after_refresh_str(self) -> None: client = OpenAI(base_url=base_url, api_key="test_api_key") - client.refresh_auth_headers() + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + client._refresh_api_key() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" @pytest.mark.respx() @@ -985,12 +995,11 @@ def token_provider() -> str: assert calls[0].request.headers.get("Authorization") == "Bearer first" assert calls[1].request.headers.get("Authorization") == "Bearer second" - def test_copy_auth(self) -> None: client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy( api_key=lambda: "test_bearer_token_2" ) - client.refresh_auth_headers() + client._refresh_api_key() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} @@ -1944,18 +1953,28 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" @pytest.mark.asyncio - async def test_refresh_auth_headers_token_async(self) -> None: - async def token_provider() -> str: + async def test_api_key_before_after_refresh_provider(self) -> None: + async def mock_api_key_provider(): return "test_bearer_token" + + client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider) - client = AsyncOpenAI(base_url=base_url, api_key=token_provider) - await client.refresh_auth_headers() + assert client.api_key == "" + assert 'Authorization' not in client.auth_headers + + await client._refresh_api_key() + + assert client.api_key == "test_bearer_token" assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + @pytest.mark.asyncio - async def test_refresh_auth_headers_key_async(self) -> None: + async def test_api_key_before_after_refresh_str(self) -> None: client = AsyncOpenAI(base_url=base_url, api_key="test_api_key") - await client.refresh_auth_headers() + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + await client._refresh_api_key() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" @pytest.mark.asyncio @@ -1997,8 +2016,6 @@ async def token_provider_1() -> str: async def token_provider_2() -> str: return "test_bearer_token_2" - client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy( - api_key=token_provider_2 - ) - await client.refresh_auth_headers() + client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2) + await client._refresh_api_key() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 4b1174bb8e..643ea85c2e 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -97,12 +97,22 @@ def test_http_client_option() -> None: assert openai.completions._client._client is new_client -def test_bearer_token_provider_option() -> None: - openai.api_key = lambda: "foo" +def test_api_key_callable() -> None: + openai.api_key = lambda: "1" + assert openai.completions._client.api_key == "1" - assert openai.completions._client.bearer_token_provider - assert openai.completions._client.bearer_token_provider() == "foo" +def test_api_key_overridable() -> None: + openai.api_key = lambda: "1" + assert openai.completions._client.api_key == "1" + assert openai.completions._client._api_key_provider is None + openai.api_key = "2" + assert openai.completions._client.api_key == "2" + assert openai.completions._client._api_key_provider is None + + openai.api_key = lambda: "3" + assert openai.completions._client.api_key == "3" + assert openai.completions._client._api_key_provider is None import contextlib from typing import Iterator @@ -130,7 +140,7 @@ def test_only_api_key_results_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" -def test_only_bearer_token_provider_in_openai_api() -> None: +def test_only_api_key_in_openai_api() -> None: with fresh_env(): openai.api_type = None openai.api_key = lambda: "example bearer token" @@ -138,14 +148,15 @@ def test_only_bearer_token_provider_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" -def test_both_api_key_and_bearer_token_provider_in_openai_api() -> None: +def test_both_api_key_and_api_key_provider_in_openai_api() -> None: with fresh_env(): openai.api_key = lambda: "example bearer token" - assert(openai.api_key() == "example bearer token") + assert openai.api_key() == "example bearer token" openai.api_key = "example API key" - assert(openai.api_key == "example API key") + assert openai.api_key == "example API key" + def test_azure_api_key_env_without_api_version() -> None: with fresh_env():