Skip to content

Commit

Permalink
OAuth2 to use current request header (home-assistant#43668)
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob authored Nov 27, 2020
1 parent 69c2818 commit f9fa249
Show file tree
Hide file tree
Showing 18 changed files with 258 additions and 90 deletions.
4 changes: 2 additions & 2 deletions homeassistant/components/toon/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ async def async_resolve_external_data(self, external_data: Any) -> dict:
"""Initialize local Toon auth implementation."""
data = {
"grant_type": "authorization_code",
"code": external_data,
"redirect_uri": self.redirect_uri,
"code": external_data["code"],
"redirect_uri": external_data["state"]["redirect_uri"],
"tenant_id": self.tenant_id,
}

Expand Down
35 changes: 25 additions & 10 deletions homeassistant/helpers/config_entry_oauth2_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from yarl import URL

from homeassistant import config_entries
from homeassistant.components.http import HomeAssistantView
from homeassistant.components import http
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.network import NoURLAvailableError, get_url
from homeassistant.helpers.network import NoURLAvailableError

from .aiohttp_client import async_get_clientsession

Expand All @@ -32,6 +32,7 @@
DATA_IMPLEMENTATIONS = "oauth2_impl"
DATA_PROVIDERS = "oauth2_providers"
AUTH_CALLBACK_PATH = "/auth/external/callback"
HEADER_FRONTEND_BASE = "HA-Frontend-Base"

CLOCK_OUT_OF_SYNC_MAX_SEC = 20

Expand Down Expand Up @@ -64,7 +65,7 @@ async def async_generate_authorize_url(self, flow_id: str) -> str:
Pass external data in with:
await hass.config_entries.flow.async_configure(
flow_id=flow_id, user_input=external_data
flow_id=flow_id, user_input={'code': 'abcd', 'state': { … }
)
"""
Expand Down Expand Up @@ -124,7 +125,17 @@ def domain(self) -> str:
@property
def redirect_uri(self) -> str:
"""Return the redirect uri."""
return f"{get_url(self.hass, require_current_request=True)}{AUTH_CALLBACK_PATH}"
req = http.current_request.get()

if req is None:
raise RuntimeError("No current request in context")

ha_host = req.headers.get(HEADER_FRONTEND_BASE)

if ha_host is None:
raise RuntimeError("No header in request")

return f"{ha_host}{AUTH_CALLBACK_PATH}"

@property
def extra_authorize_data(self) -> dict:
Expand All @@ -133,14 +144,17 @@ def extra_authorize_data(self) -> dict:

async def async_generate_authorize_url(self, flow_id: str) -> str:
"""Generate a url for the user to authorize."""
redirect_uri = self.redirect_uri
return str(
URL(self.authorize_url)
.with_query(
{
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": _encode_jwt(self.hass, {"flow_id": flow_id}),
"redirect_uri": redirect_uri,
"state": _encode_jwt(
self.hass, {"flow_id": flow_id, "redirect_uri": redirect_uri}
),
}
)
.update_query(self.extra_authorize_data)
Expand All @@ -151,8 +165,8 @@ async def async_resolve_external_data(self, external_data: Any) -> dict:
return await self._token_request(
{
"grant_type": "authorization_code",
"code": external_data,
"redirect_uri": self.redirect_uri,
"code": external_data["code"],
"redirect_uri": external_data["state"]["redirect_uri"],
}
)

Expand Down Expand Up @@ -384,7 +398,7 @@ def async_add_implementation_provider(
] = async_provide_implementation


class OAuth2AuthorizeCallbackView(HomeAssistantView):
class OAuth2AuthorizeCallbackView(http.HomeAssistantView):
"""OAuth2 Authorization Callback View."""

requires_auth = False
Expand All @@ -406,7 +420,8 @@ async def get(self, request: web.Request) -> web.Response:
return web.Response(text="Invalid state")

await hass.config_entries.flow.async_configure(
flow_id=state["flow_id"], user_input=request.query["code"]
flow_id=state["flow_id"],
user_input={"state": state, "code": request.query["code"]},
)

return web.Response(
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/helpers/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import yarl

from homeassistant.components.http import current_request
from homeassistant.components import http
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass
Expand Down Expand Up @@ -49,7 +49,7 @@ def get_url(
prefer_cloud: bool = False,
) -> str:
"""Get a URL to this instance."""
if require_current_request and current_request.get() is None:
if require_current_request and http.current_request.get() is None:
raise NoURLAvailableError

order = [TYPE_URL_INTERNAL, TYPE_URL_EXTERNAL]
Expand Down Expand Up @@ -125,7 +125,7 @@ def get_url(

def _get_request_host() -> Optional[str]:
"""Get the host address of the current request."""
request = current_request.get()
request = http.current_request.get()
if request is None:
raise NoURLAvailableError
return yarl.URL(request.url).host
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
CLIENT_SECRET = "5678"


async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
Expand All @@ -27,7 +29,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
"NEW_DOMAIN", context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)

assert result["url"] == (
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
Expand Down
12 changes: 10 additions & 2 deletions tests/components/almond/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ async def test_abort_if_existing_entry(hass):
assert result["reason"] == "single_instance_allowed"


async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
Expand All @@ -109,7 +111,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)

assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (
Expand Down
12 changes: 10 additions & 2 deletions tests/components/home_connect/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
CLIENT_SECRET = "5678"


async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
Expand All @@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
"home_connect", context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)

assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (
Expand Down
12 changes: 10 additions & 2 deletions tests/components/nest/test_config_flow_sdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
SUBSCRIBER_ID = "subscriber-id-9876"


async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
Expand All @@ -31,7 +33,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)

oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID)
assert result["url"] == (
Expand Down
12 changes: 10 additions & 2 deletions tests/components/netatmo/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ async def test_abort_if_existing_entry(hass):
assert result["reason"] == "already_configured"


async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
Expand All @@ -56,7 +58,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
"netatmo", context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)

scope = "+".join(
[
Expand Down
12 changes: 10 additions & 2 deletions tests/components/smappee/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ async def test_abort_cloud_flow_if_local_device_exists(hass):
assert len(hass.config_entries.async_entries(DOMAIN)) == 1


async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_user_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
Expand All @@ -351,7 +353,13 @@ async def test_full_user_flow(hass, aiohttp_client, aioclient_mock, current_requ
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"environment": ENV_CLOUD}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)

client = await aiohttp_client(hass.http.app)
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
Expand Down
12 changes: 10 additions & 2 deletions tests/components/somfy/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ async def test_abort_if_existing_entry(hass):
assert result["reason"] == "single_instance_allowed"


async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
async def test_full_flow(
hass, aiohttp_client, aioclient_mock, current_request_with_host
):
"""Check full flow."""
assert await setup.async_setup_component(
hass,
Expand All @@ -69,7 +71,13 @@ async def test_full_flow(hass, aiohttp_client, aioclient_mock, current_request):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)

assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (
Expand Down
Loading

0 comments on commit f9fa249

Please sign in to comment.