From 65dcd947a3b86294b081214ff8449e274a8d51e9 Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Tue, 26 Jul 2022 16:00:09 -0400 Subject: [PATCH 1/9] feat(credentials): Add async credentials. --- firebase_admin/credentials.py | 133 ++++++++++++++++++++++++---------- 1 file changed, 96 insertions(+), 37 deletions(-) diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 5477e1cf7..4c1e68593 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -17,14 +17,22 @@ import json import pathlib -import google.auth -from google.auth.transport import requests -from google.oauth2 import credentials +from typing import Type + +import google.auth # type: ignore +from google.auth import default +from google.auth._default_async import default_async # type: ignore +from google.auth.transport import requests # type: ignore +from google.auth.transport import _aiohttp_requests as aiohttp_requests +from google.oauth2 import credentials # type: ignore +from google.oauth2 import _credentials_async as credentials_async from google.oauth2 import service_account +from google.oauth2 import _service_account_async as service_account_async -_request = requests.Request() -_scopes = [ +_request: requests.Request = requests.Request() +_request_async: aiohttp_requests.Request = aiohttp_requests.Request() +_scopes: list[str] = [ 'https://www.googleapis.com/auth/cloud-platform', 'https://www.googleapis.com/auth/datastore', 'https://www.googleapis.com/auth/devstorage.read_write', @@ -33,7 +41,7 @@ 'https://www.googleapis.com/auth/userinfo.email' ] -AccessTokenInfo = collections.namedtuple('AccessTokenInfo', ['access_token', 'expiry']) +AccessTokenInfo: Type[tuple] = collections.namedtuple('AccessTokenInfo', ['access_token', 'expiry']) """Data included in an OAuth2 access token. Contains the access token string and the expiry time. The expirty time is exposed as a @@ -44,8 +52,8 @@ class Base: """Provides OAuth2 access tokens for accessing Firebase services.""" - def get_access_token(self): - """Fetches a Google OAuth2 access token using this credential instance. + def get_access_token(self) -> tuple: + """Fetches a Google OAuth2 access token using the synchronous credential instance. Returns: AccessTokenInfo: An access token obtained using the credential. @@ -54,8 +62,22 @@ def get_access_token(self): google_cred.refresh(_request) return AccessTokenInfo(google_cred.token, google_cred.expiry) + async def get_access_token_async(self) -> tuple: + """Fetches a Google OAuth2 access token using the asynchronous credential instance. + + Returns: + AccessTokenInfo: An access token obtained using the credential. + """ + google_cred = self.get_credential_async() + await google_cred.refresh(_request_async) + return AccessTokenInfo(google_cred.token, google_cred.expiry) + def get_credential(self): - """Returns the Google credential instance used for authentication.""" + """Returns the Google synchronous credential instance used for authentication.""" + raise NotImplementedError + + def get_credential_async(self): + """Returns the Google asynchronous credential instance used for authentication.""" raise NotImplementedError @@ -64,8 +86,8 @@ class Certificate(Base): _CREDENTIAL_TYPE = 'service_account' - def __init__(self, cert): - """Initializes a credential from a Google service account certificate. + def __init__(self, cert: str) -> None: + """Initializes credentials from a Google service account certificate. Service account certificates can be downloaded as JSON files from the Firebase console. To instantiate a credential from a certificate file, either specify the file path or a @@ -95,44 +117,54 @@ def __init__(self, cert): try: self._g_credential = service_account.Credentials.from_service_account_info( json_data, scopes=_scopes) + self._g_credential_async = service_account_async.Credentials.from_service_account_info( + json_data, scopes=_scopes) except ValueError as error: raise ValueError('Failed to initialize a certificate credential. ' 'Caused by: "{0}"'.format(error)) @property - def project_id(self): + def project_id(self) -> str: return self._g_credential.project_id @property - def signer(self): + def signer(self) -> google.auth.crypt.Signer: return self._g_credential.signer @property - def service_account_email(self): + def service_account_email(self) -> str: return self._g_credential.service_account_email - def get_credential(self): - """Returns the underlying Google credential. + def get_credential(self) -> service_account.Credentials: + """Returns the underlying Google synchronous credential. Returns: - google.auth.credentials.Credentials: A Google Auth credential instance.""" + google.auth.credentials.Credentials: A Google Auth synchronous credential instance.""" return self._g_credential + def get_credential_async(self) -> service_account_async.Credentials: + """Returns the underlying Google asynchronous credential. + + Returns: + google.auth._credentials_async.Credentials: A Google Auth asynchronous credential + instance.""" + return self._g_credential_async class ApplicationDefault(Base): """A Google Application Default credential.""" - def __init__(self): + def __init__(self) -> None: """Creates an instance that will use Application Default credentials. - The credentials will be lazily initialized when get_credential() or - project_id() is called. See those methods for possible errors raised. + The credentials will be lazily initialized when get_credential(), get_credential_async() + or project_id() is called. See those methods for possible errors raised. """ super(ApplicationDefault, self).__init__() self._g_credential = None # Will be lazily-loaded via _load_credential(). + self._g_credential_async = None # Will be lazily-loaded via _load_credential_async(). - def get_credential(self): - """Returns the underlying Google credential. + def get_credential(self) -> credentials.Credentials: + """Returns the underlying Google synchronous credential. Raises: google.auth.exceptions.DefaultCredentialsError: If Application Default @@ -142,9 +174,20 @@ def get_credential(self): self._load_credential() return self._g_credential + def get_credential_async(self) -> credentials_async.Credentials: + """Returns the underlying Google asynchronous credential. + + Raises: + google.auth.exceptions.DefaultCredentialsError: If Application Default + credentials cannot be initialized in the current environment. + Returns: + google.auth._credentials_async.Credentials: A Google Auth credential instance.""" + self._load_credential_async() + return self._g_credential_async + @property - def project_id(self): - """Returns the project_id from the underlying Google credential. + def project_id(self) -> str: + """Returns the project_id from the underlying Google credentials. Raises: google.auth.exceptions.DefaultCredentialsError: If Application Default @@ -154,21 +197,25 @@ def project_id(self): self._load_credential() return self._project_id - def _load_credential(self): + def _load_credential(self) -> None: if not self._g_credential: - self._g_credential, self._project_id = google.auth.default(scopes=_scopes) + self._g_credential, self._project_id = default(scopes=_scopes) + + def _load_credential_async(self) -> None: + if not self._g_credential_async: + self._g_credential_async, self._project_id = default_async(scopes=_scopes) class RefreshToken(Base): - """A credential initialized from an existing refresh token.""" + """Credentials initialized from an existing refresh token.""" _CREDENTIAL_TYPE = 'authorized_user' - def __init__(self, refresh_token): - """Initializes a credential from a refresh token JSON file. + def __init__(self, refresh_token: str) -> None: + """Initializes credentials from a refresh token JSON file. The JSON must consist of client_id, client_secret and refresh_token fields. Refresh token files are typically created and managed by the gcloud SDK. To instantiate - a credential from a refresh token file, either specify the file path or a dict + credentials from a refresh token file, either specify the file path or a dict representing the parsed contents of the file. Args: @@ -194,28 +241,40 @@ def __init__(self, refresh_token): raise ValueError('Invalid refresh token configuration. JSON must contain a ' '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) self._g_credential = credentials.Credentials.from_authorized_user_info(json_data, _scopes) + self._g_credential_async = credentials_async.Credentials.from_authorized_user_info( + json_data, + _scopes + ) @property - def client_id(self): + def client_id(self) -> str: return self._g_credential.client_id @property - def client_secret(self): + def client_secret(self) -> str: return self._g_credential.client_secret @property - def refresh_token(self): + def refresh_token(self) -> str: return self._g_credential.refresh_token - def get_credential(self): - """Returns the underlying Google credential. + def get_credential(self) -> credentials.Credentials: + """Returns the underlying Google synchronous credential. Returns: - google.auth.credentials.Credentials: A Google Auth credential instance.""" + google.auth.credentials.Credentials: A Google Auth synchronous credential instance.""" return self._g_credential + def get_credential_async(self) -> credentials_async.Credentials: + """Returns the underlying Google asynchronous credential. + + Returns: + google.auth._credentials_async.Credentials: A Google Auth asynchronous credential + instance.""" + return self._g_credential_async + -def _is_file_path(path): +def _is_file_path(path) -> bool: try: pathlib.Path(path) return True From 1a1a348cd49fe2c3218a2a0ce3dec5d9ffa38e76 Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Tue, 26 Jul 2022 16:54:33 -0400 Subject: [PATCH 2/9] fix: Added to required modules libraries. --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 0dd529c04..b8aa3d724 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pytest >= 6.2.0 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 +aiohttp == 3.8.1 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 From bb8fd246503f799a6d273e345b0db40e7fd7db1e Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Tue, 26 Jul 2022 17:06:13 -0400 Subject: [PATCH 3/9] fix: Imported correct instance of typing.List. --- firebase_admin/credentials.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 4c1e68593..b25f4e377 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -17,7 +17,10 @@ import json import pathlib -from typing import Type +from typing import ( + Type, + List +) import google.auth # type: ignore from google.auth import default @@ -32,7 +35,7 @@ _request: requests.Request = requests.Request() _request_async: aiohttp_requests.Request = aiohttp_requests.Request() -_scopes: list[str] = [ +_scopes: List[str] = [ 'https://www.googleapis.com/auth/cloud-platform', 'https://www.googleapis.com/auth/datastore', 'https://www.googleapis.com/auth/devstorage.read_write', From bb6e93aab7169f3d674cdacd1125473cfaf8a266 Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Tue, 26 Jul 2022 17:47:17 -0400 Subject: [PATCH 4/9] fix: Added method in subclass to override abstract method in class. --- integration/test_auth.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/integration/test_auth.py b/integration/test_auth.py index 82974732d..3af53e29f 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -898,13 +898,24 @@ class CredentialWrapper(credentials.Base): def __init__(self, token): self._delegate = google.oauth2.credentials.Credentials(token) + self._delegate_async = google.oauth2._credentials_async.Credentials(token) def get_credential(self): return self._delegate + def get_credential_async(self): + return self._delegate_async + @classmethod def from_existing_credential(cls, google_cred): if not google_cred.token: request = transport.requests.Request() google_cred.refresh(request) return CredentialWrapper(google_cred.token) + + @classmethod + async def from_existing_credential_async(cls, google_cred): + if not google_cred.token: + request = transport._aiohttp_requests.Request() + await google_cred.refresh(request) + return CredentialWrapper(google_cred.token) From c9b650374b1eb317f8c08463fec800b470dac87b Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Wed, 3 Aug 2022 13:49:35 -0400 Subject: [PATCH 5/9] Draft for async http client tests --- firebase_admin/_http_client_async.py | 161 ++++++++++++++++++++++++ tests/test_http_client_async.py | 177 +++++++++++++++++++++++++++ tests/testutils.py | 88 +++++++++++++ 3 files changed, 426 insertions(+) create mode 100644 firebase_admin/_http_client_async.py create mode 100644 tests/test_http_client_async.py diff --git a/firebase_admin/_http_client_async.py b/firebase_admin/_http_client_async.py new file mode 100644 index 000000000..8ebb6ff8d --- /dev/null +++ b/firebase_admin/_http_client_async.py @@ -0,0 +1,161 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal HTTP client module. + + This module provides utilities for making HTTP calls using the requests library. + """ +import json + +import aiohttp +from google.auth.transport import _aiohttp_requests +import requests +from requests.packages.urllib3.util import retry # pylint: disable=import-error + +import urllib3 # type: ignore + + +DEFAULT_RETRY_ATTEMPTS = 4 +DEFAULT_RETRY_CODES = [500, 503] +DEFAULT_TIMEOUT_SECONDS = 120 + + + + +class HttpClientAsync: + """Base HTTP client used to make HTTP calls. + + HttpClient maintains an HTTP session, and handles request authentication and retries if + necessary. + """ + + def __init__( + self, credential=None, + session=None, + base_url='', + headers=None, + retry_attempts=DEFAULT_RETRY_ATTEMPTS, + retry_codes=DEFAULT_RETRY_CODES, + timeout=DEFAULT_TIMEOUT_SECONDS + ): + """Creates a new HttpClientAsync instance from the provided arguments. + + If a credential is provided, initializes a new HTTP session authorized with it. If neither + a credential nor a session is provided, initializes a new unauthorized session. + + Args: + credential: A Google credential that can be used to authenticate requests (optional). + session: A custom HTTP session (optional). + base_url: A URL prefix to be added to all outgoing requests (optional). + headers: A map of headers to be added to all outgoing requests (optional). + retries: A urllib retry configuration. Default settings would retry once for low-level + connection and socket read errors, and up to 4 times for HTTP 500 and 503 errors. + Pass a False value to disable retries (optional). + timeout: HTTP timeout in seconds. Defaults to 120 seconds when not specified. Set to + None to disable timeouts (optional). + """ + if credential: + # self._session = _aiohttp_requests.AuthorizedSession(credential) + self._session = _aiohttp_requests.AuthorizedSession( + credential, + refresh_status_codes=retry_codes, + max_refresh_attempts=retry_attempts, + refresh_timeout=timeout + ) + elif session: + self._session = session + else: + self._session = aiohttp.ClientSession() # pylint: disable=redefined-variable-type + + if headers: + self._session.headers.update(headers) + self._base_url = base_url + self._timeout = timeout + + @property + def session(self): + return self._session + + @property + def base_url(self): + return self._base_url + + @property + def timeout(self): + return self._timeout + + def parse_body(self, resp): + raise NotImplementedError + + async def request(self, method, url, **kwargs): + """Makes an HTTP call using the Python requests library. + + This is the sole entry point to the requests library. All other helper methods in this + class call this method to send HTTP requests out. Refer to + http://docs.python-requests.org/en/master/api/ for more information on supported options + and features. + + Args: + method: HTTP method name as a string (e.g. get, post). + url: URL of the remote endpoint. + **kwargs: An additional set of keyword arguments to be passed into the requests API + (e.g. json, params, timeout). + + Returns: + Response: An HTTP response object. + + Raises: + RequestException: Any requests exceptions encountered while making the HTTP call. + """ + if 'timeout' not in kwargs: + kwargs['timeout'] = self.timeout + resp = await self._session.request(method, self.base_url + url, **kwargs) + resp.raise_for_status() + return resp + + async def headers(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return resp.headers + + async def body_and_response(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return await self.parse_body(resp), resp + + async def body(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return await self.parse_body(resp) + return resp + + async def headers_and_body(self, method, url, **kwargs): + resp = await self.request(method, url, **kwargs) + return await resp.headers, self.parse_body(resp) + + async def close(self): + await self._session.close() + self._session = None + + async def parse_body(self, response): + wrapped_response = _aiohttp_requests._CombinedResponse(response) + content = await wrapped_response.content() + return json.loads(content) + + +class JsonHttpClientAsync(HttpClientAsync): + """An HTTP client that parses response messages as JSON.""" + + def __init__(self, **kwargs): + HttpClientAsync.__init__(self, **kwargs) + + def parse_body(self, resp): + return resp.json() \ No newline at end of file diff --git a/tests/test_http_client_async.py b/tests/test_http_client_async.py new file mode 100644 index 000000000..a00dbaa20 --- /dev/null +++ b/tests/test_http_client_async.py @@ -0,0 +1,177 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for firebase_admin._http_client.""" +from __future__ import absolute_import +import aiohttp + +import pytest +from pytest_localserver import http +from google.auth.transport import requests + +from firebase_admin import _http_client_async +from tests import testutils + + +_TEST_URL = 'http://firebase.test.url/' + +@pytest.mark.asyncio +async def test_http_client_default_session(): + client = _http_client_async.HttpClientAsync() + assert client.session is not None + assert isinstance(client.session, aiohttp.ClientSession) + assert client.base_url == '' + await client.close() + +@pytest.mark.asyncio +async def test_http_client_custom_session(): + session, recorder = make_mock_client_session() + client = _http_client_async.HttpClientAsync(session=session) + assert client.session is session + assert client.base_url == '' + resp = await client.request('GET', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + print(recorder) + assert len(recorder) == 1 + # assert recorder[0] + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + await client.close() + +@pytest.mark.asyncio +async def test_base_url(): + session, recorder = make_mock_client_session() + client = _http_client_async.HttpClientAsync(base_url=_TEST_URL, session=session) + assert client.session is not None + assert client.base_url == _TEST_URL + resp = await client.request('GET', 'foo') + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + 'foo' + await client.close() + +@pytest.mark.asyncio +async def test_credential_async(): + credential = testutils.MockGoogleCredentialAsync() + client = _http_client_async.HttpClientAsync( + credential=credential) + assert client.session is not None + session, recorder = make_mock_authorized_session(credential) + client._session = session + resp = await client.request('GET', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + print(recorder[0].extra_kwargs) + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + assert recorder[0].extra_kwargs['headers']['authorization'] == 'Bearer mock-token' + await client.close() + +@pytest.mark.asyncio +@pytest.mark.parametrize('options, timeout', [ + ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), + ({'timeout': 7}, 7), + ({'timeout': 0}, 0), + ({'timeout': None}, None), +]) +async def test_timeout(options, timeout): + session, recorder = make_mock_client_session() + client = _http_client_async.HttpClientAsync(**options, session=session) + assert client.timeout == timeout + await client.request('get', _TEST_URL) + assert len(recorder) == 1 + if timeout is None: + assert recorder[0].extra_kwargs['timeout'] is None + else: + assert recorder[0].extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + await client.close() + +def make_mock_client_session(payload='body', status=200): + recorder = [] + session = testutils.MockClientSession(payload, status, recorder) + client = _http_client_async.HttpClientAsync(session=session) + return session, recorder + +def make_mock_authorized_session(credentials, payload='body', status=200): + recorder = [] + session = testutils.MockAuthorizedSession(payload, status, recorder, credentials) + client = _http_client_async.HttpClientAsync(session=session) + return session, recorder + + +class TestHttpRetry: + """Unit tests for the default HTTP retry configuration.""" + + ENTITY_ENCLOSING_METHODS = ['post', 'put', 'patch'] + ALL_METHODS = ENTITY_ENCLOSING_METHODS + ['get', 'delete', 'head', 'options'] + + @classmethod + def setup_class(cls): + # Start a test server instance scoped to the class. + server = http.ContentServer() + server.start() + cls.httpserver = server + + @classmethod + def teardown_class(cls): + cls.httpserver.stop() + + def setup_method(self): + # Clean up any state in the server before starting a new test case. + self.httpserver.requests = [] + + @pytest.mark.asyncio + @pytest.mark.parametrize('method', ALL_METHODS) + async def test_retry_on_503(self, method): + self.httpserver.serve_content({}, 503) + client = _http_client_async.JsonHttpClientAsync( + credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) + body = None + if method in self.ENTITY_ENCLOSING_METHODS: + body = {'key': 'value'} + with pytest.raises(aiohttp.ClientError) as excinfo: + await client.request(method, '/', json=body) + assert excinfo.value.status == 503 + assert len(self.httpserver.requests) == 5 + await client.close() + + @pytest.mark.asyncio + @pytest.mark.parametrize('method', ALL_METHODS) + async def test_retry_on_500(self, method): + self.httpserver.serve_content({}, 500) + client = _http_client_async.JsonHttpClientAsync( + credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) + body = None + if method in self.ENTITY_ENCLOSING_METHODS: + body = {'key': 'value'} + with pytest.raises(aiohttp.ClientError) as excinfo: + await client.request(method, '/', json=body) + assert excinfo.value.status == 500 + assert len(self.httpserver.requests) == 5 + await client.close() + + @pytest.mark.asyncio + async def test_no_retry_on_404(self): + self.httpserver.serve_content({}, 404) + client = _http_client_async.JsonHttpClientAsync( + credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) + with pytest.raises(aiohttp.ClientError) as excinfo: + await client.request('get', '/') + await client.close() + assert excinfo.value.status == 404 + assert len(self.httpserver.requests) == 1 diff --git a/tests/testutils.py b/tests/testutils.py index 92755107c..8dcabd353 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -17,12 +17,18 @@ import os import pytest +import urllib3 from google.auth import credentials +from google.auth import _credentials_async from google.auth import transport +from google.auth.transport._aiohttp_requests import AuthorizedSession from requests import adapters from requests import models +import aiohttp +import asyncio + import firebase_admin @@ -119,6 +125,12 @@ class MockGoogleCredential(credentials.Credentials): def refresh(self, request): self.token = 'mock-token' +class MockGoogleCredentialAsync(_credentials_async.Credentials): + """A mock Google authentication credential.""" + async def refresh(self, request): + self.token = 'mock-token' + await asyncio.sleep(1) + class MockCredential(firebase_admin.credentials.Base): """A mock Firebase credential implementation.""" @@ -129,6 +141,14 @@ def __init__(self): def get_credential(self): return self._g_credential +class MockCredentialAsync(firebase_admin.credentials.Base): + """A mock Firebase credential implementation.""" + + def __init__(self): + self._g_credential_async = MockGoogleCredentialAsync() + + def get_credential_async(self): + return self._g_credential_async class MockMultiRequestAdapter(adapters.HTTPAdapter): """A mock HTTP adapter that supports multiple responses for the Python requests module.""" @@ -171,3 +191,71 @@ def status(self): @property def data(self): return self._responses[0] + +class MockClientResponse(aiohttp.ClientResponse): + def __init__(self, responses, statuses, recorder, current_response, method, url, **kwargs): + if len(responses) != len(statuses): + raise ValueError('The lengths of responses and statuses do not match.') + + self._url = url + self.status_code = statuses[current_response] + self.content = responses[current_response] + self.raw = io.BytesIO(responses[current_response].encode()) + + @property + def url(self): + return self._url + + @property + def status(self): + return self.status_code + + @property + def data(self): + return self + + @property + def text(self): + return self.content + +class MockSession(aiohttp.ClientSession): + def __init__(self, data, status, recorder, credentials=None): + super(MockSession, self).__init__(credentials) + # self._response = MockClientResponse(data, status, recorder, method, url) + self._current_response = 0 + self._data = data + self._responses = [data] + self._status = status + self._statuses = [status] + self.recorder = recorder + + # self._extra_kwargs = None + + async def _request(self, method, url, *args, **kwargs): + + self.method = method + self.url = url + self.args = args + self.extra_kwargs = kwargs + self.recorder.append(self) + resp = MockClientResponse(self._responses, self._statuses, self.recorder, self._current_response, method, url) + self._current_response = min(self._current_response + 1, len(self._responses) - 1) + return resp + + @property + def status(self): + return self._status_code + + @property + def data(self): + return self + +class MockClientSession(MockSession): + def __init__(self, data, status, recorder): + super(MockClientSession, self).__init__(data, status, recorder) + + +class MockAuthorizedSession(MockClientSession, AuthorizedSession): + def __init__(self, data, status, recorder, credentials): + super(MockAuthorizedSession, self).__init__(data, status, recorder) + self.credentials = credentials From d331a63d7c91cf142ab95e895a1bf07e7a0d0e53 Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Tue, 16 Aug 2022 14:27:59 -0400 Subject: [PATCH 6/9] functional async http client and async messaging send --- firebase_admin/_http_client_async.py | 84 +++++---- firebase_admin/_utils.py | 27 +++ firebase_admin/messaging_async.py | 259 +++++++++++++++++++++++++++ integration/conftest.py | 14 +- integration/test_messaging_async.py | 102 +++++++++++ tests/test_http_client_async.py | 186 +++++++++---------- tests/testutils.py | 84 ++++----- 7 files changed, 572 insertions(+), 184 deletions(-) create mode 100644 firebase_admin/messaging_async.py create mode 100644 integration/test_messaging_async.py diff --git a/firebase_admin/_http_client_async.py b/firebase_admin/_http_client_async.py index 8ebb6ff8d..cfdc04364 100644 --- a/firebase_admin/_http_client_async.py +++ b/firebase_admin/_http_client_async.py @@ -12,36 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Internal HTTP client module. +"""Internal async HTTP client module. - This module provides utilities for making HTTP calls using the requests library. + This module provides utilities for making async HTTP calls using the aiohttp library. """ import json import aiohttp +from aiohttp.client_exceptions import ClientResponseError from google.auth.transport import _aiohttp_requests -import requests -from requests.packages.urllib3.util import retry # pylint: disable=import-error - -import urllib3 # type: ignore +from google.auth.transport._aiohttp_requests import _CombinedResponse DEFAULT_RETRY_ATTEMPTS = 4 -DEFAULT_RETRY_CODES = [500, 503] +DEFAULT_RETRY_CODES = (500, 503) DEFAULT_TIMEOUT_SECONDS = 120 class HttpClientAsync: - """Base HTTP client used to make HTTP calls. + """Base HTTP client used to make aiohttp calls. - HttpClient maintains an HTTP session, and handles request authentication and retries if + HttpClientAsync maintains an aiohttp session, and handles request authentication and retries if necessary. """ def __init__( - self, credential=None, + self, + credential=None, session=None, base_url='', headers=None, @@ -51,22 +50,23 @@ def __init__( ): """Creates a new HttpClientAsync instance from the provided arguments. - If a credential is provided, initializes a new HTTP session authorized with it. If neither - a credential nor a session is provided, initializes a new unauthorized session. + If a credential is provided, initializes a new aiohttp client session authorized with it. + If neither a credential nor a session is provided, initializes a new unauthorized client + session. Args: credential: A Google credential that can be used to authenticate requests (optional). - session: A custom HTTP session (optional). + session: A custom aiohttp session (optional). base_url: A URL prefix to be added to all outgoing requests (optional). headers: A map of headers to be added to all outgoing requests (optional). - retries: A urllib retry configuration. Default settings would retry once for low-level - connection and socket read errors, and up to 4 times for HTTP 500 and 503 errors. - Pass a False value to disable retries (optional). - timeout: HTTP timeout in seconds. Defaults to 120 seconds when not specified. Set to + retry_attempts: The maximum number of retries that should be attempeted for a request + (optional). + retry_codes: A list of status codes for which the request retry should be attempted + (optional). + timeout: A request timeout in seconds. Defaults to 120 seconds when not specified. Set to None to disable timeouts (optional). """ if credential: - # self._session = _aiohttp_requests.AuthorizedSession(credential) self._session = _aiohttp_requests.AuthorizedSession( credential, refresh_status_codes=retry_codes, @@ -99,30 +99,44 @@ def parse_body(self, resp): raise NotImplementedError async def request(self, method, url, **kwargs): - """Makes an HTTP call using the Python requests library. + """Makes an async HTTP call using the aiohttp library. - This is the sole entry point to the requests library. All other helper methods in this - class call this method to send HTTP requests out. Refer to + This is the sole entry point to the aiohttp library. All other helper methods in this + class call this method to send async HTTP requests out. Refer to http://docs.python-requests.org/en/master/api/ for more information on supported options and features. Args: method: HTTP method name as a string (e.g. get, post). url: URL of the remote endpoint. - **kwargs: An additional set of keyword arguments to be passed into the requests API + **kwargs: An additional set of keyword arguments to be passed into the aiohttp API (e.g. json, params, timeout). Returns: - Response: An HTTP response object. + Response: A ``_CombinedResponse`` wrapped ``ClientResponse`` object. Raises: - RequestException: Any requests exceptions encountered while making the HTTP call. + ClientResponseError: Any requests exceptions encountered while making the HTTP call. """ if 'timeout' not in kwargs: kwargs['timeout'] = self.timeout resp = await self._session.request(method, self.base_url + url, **kwargs) - resp.raise_for_status() - return resp + wrapped_resp = _CombinedResponse(resp) + + try: + # Get response content from StreamReader before it is closed by error. + print(wrapped_resp.content, "idk") + resp_content = await wrapped_resp.content() + # print(wrapped_resp._response.content) + resp.raise_for_status() + + # Catch response error and re-release it with after appending response body needed to + # determine the underlying reason for the error. + except ClientResponseError as err: + err.response = wrapped_resp + err.response_content = resp_content + raise err + return wrapped_resp async def headers(self, method, url, **kwargs): resp = await self.request(method, url, **kwargs) @@ -135,27 +149,23 @@ async def body_and_response(self, method, url, **kwargs): async def body(self, method, url, **kwargs): resp = await self.request(method, url, **kwargs) return await self.parse_body(resp) - return resp async def headers_and_body(self, method, url, **kwargs): resp = await self.request(method, url, **kwargs) return await resp.headers, self.parse_body(resp) async def close(self): - await self._session.close() - self._session = None - - async def parse_body(self, response): - wrapped_response = _aiohttp_requests._CombinedResponse(response) - content = await wrapped_response.content() - return json.loads(content) + if self._session is not None: + await self._session.close() + self._session = None class JsonHttpClientAsync(HttpClientAsync): - """An HTTP client that parses response messages as JSON.""" + """An async HTTP client that parses response messages as JSON.""" def __init__(self, **kwargs): HttpClientAsync.__init__(self, **kwargs) - def parse_body(self, resp): - return resp.json() \ No newline at end of file + async def parse_body(self, resp): + content = await resp.content() + return json.loads(content) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index dcfb520d2..ca96494d4 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -125,6 +125,33 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +async def handle_platform_error_from_aiohttp(error, handle_func=None): + """Constructs a ``FirebaseError`` from the given requests error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the aiohttp module while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_requests``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if error.response is None: + return handle_requests_error(error) + + response = error.response + content = error.response_content.decode() + status_code = response.status + error_dict, message = _parse_platform_error(content, status_code) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict) + + return exc if exc else _handle_func_requests(error, message, error_dict) + def handle_operation_error(error): """Constructs a ``FirebaseError`` from the given operation error. diff --git a/firebase_admin/messaging_async.py b/firebase_admin/messaging_async.py new file mode 100644 index 000000000..236f5406c --- /dev/null +++ b/firebase_admin/messaging_async.py @@ -0,0 +1,259 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase Cloud Messaging Async module.""" + +import asyncio + +import firebase_admin +from firebase_admin.messaging import TopicManagementResponse +from firebase_admin._http_client_async import ( + JsonHttpClientAsync, + ClientResponseError, + DEFAULT_TIMEOUT_SECONDS +) +from firebase_admin._messaging_encoder import ( + Message, + MessageEncoder +) +from firebase_admin._messaging_utils import ( + QuotaExceededError, + SenderIdMismatchError, + ThirdPartyAuthError, + UnregisteredError +) +from firebase_admin import _utils + + + +_MESSAGING_ATTRIBUTE = '_messaging_async' + + +__all__ = [ + 'send', + # 'send_all', + # 'send_multicast', + 'subscribe_to_topic', + 'unsubscribe_from_topic', +] + + +def _get_messaging_service(app): + return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingServiceAsync) + +async def send(message, dry_run=False, app=None): + """Sends the given message via Firebase Cloud Messaging (FCM). + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + message: An instance of ``messaging.Message``. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + string: A message ID string that uniquely identifies the sent message. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + return await _get_messaging_service(app).send(message, dry_run) + +async def subscribe_to_topic(tokens, topic, app=None): + """Subscribes a list of registration tokens to an FCM topic. + + Args: + tokens: A non-empty list of device registration tokens. List may not have more than 1000 + elements. + topic: Name of the topic to subscribe to. May contain the ``/topics/`` prefix. + app: An App instance (optional). + + Returns: + TopicManagementResponse: A ``TopicManagementResponse`` instance. + + Raises: + FirebaseError: If an error occurs while communicating with instance ID service. + ValueError: If the input arguments are invalid. + """ + return await _get_messaging_service(app).make_topic_management_request( + tokens, topic, 'iid/v1:batchAdd') + +async def unsubscribe_from_topic(tokens, topic, app=None): + """Unsubscribes a list of registration tokens from an FCM topic. + + Args: + tokens: A non-empty list of device registration tokens. List may not have more than 1000 + elements. + topic: Name of the topic to unsubscribe from. May contain the ``/topics/`` prefix. + app: An App instance (optional). + + Returns: + TopicManagementResponse: A ``TopicManagementResponse`` instance. + + Raises: + FirebaseError: If an error occurs while communicating with instance ID service. + ValueError: If the input arguments are invalid. + """ + return await _get_messaging_service(app).make_topic_management_request( + tokens, topic, 'iid/v1:batchRemove') + + +class _MessagingServiceAsync: + """Service class that implements Firebase Cloud Messaging (FCM) functionality asynchronously.""" + + FCM_URL = 'https://fcm.googleapis.com/v1/projects/{0}/messages:send' + FCM_BATCH_URL = 'https://fcm.googleapis.com/batch' + IID_URL = 'https://iid.googleapis.com' + IID_HEADERS = {'access_token_auth': 'true'} + JSON_ENCODER = MessageEncoder() + + FCM_ERROR_TYPES = { + 'APNS_AUTH_ERROR': ThirdPartyAuthError, + 'QUOTA_EXCEEDED': QuotaExceededError, + 'SENDER_ID_MISMATCH': SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': ThirdPartyAuthError, + 'UNREGISTERED': UnregisteredError, + } + + def __init__(self, app): + project_id = app.project_id + if not project_id: + raise ValueError( + 'Project ID is required to access Cloud Messaging service. Either set the ' + 'projectId option, or use service account credentials. Alternatively, set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + self._fcm_url = _MessagingServiceAsync.FCM_URL.format(project_id) + self._fcm_headers = { + 'X-GOOG-API-FORMAT-VERSION': '2', + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + } + timeout = app.options.get('httpTimeout', DEFAULT_TIMEOUT_SECONDS) + self._credential = app.credential.get_credential_async() + self._client = JsonHttpClientAsync(credential=self._credential, timeout=timeout) + self._loop = asyncio.get_event_loop() + + def close(self): + if self._client is not None: + self._loop.run_until_complete(self._client.close()) + self._client = None + + @classmethod + def encode_message(cls, message): + if not isinstance(message, Message): + raise ValueError('Message must be an instance of messaging.Message class.') + return cls.JSON_ENCODER.default(message) + + async def send(self, message, dry_run=False): + """Sends the given message to FCM via the FCM v1 API.""" + data = self._message_data(message, dry_run) + try: + resp = await self._client.body( + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data + ) + except ClientResponseError as error: + raise await self._handle_fcm_error(error) + else: + return resp['name'] + + async def make_topic_management_request(self, tokens, topic, operation): + """Invokes the IID service for topic management functionality.""" + if isinstance(tokens, str): + tokens = [tokens] + if not isinstance(tokens, list) or not tokens: + raise ValueError('Tokens must be a string or a non-empty list of strings.') + invalid_str = [t for t in tokens if not isinstance(t, str) or not t] + if invalid_str: + raise ValueError('Tokens must be non-empty strings.') + + if not isinstance(topic, str) or not topic: + raise ValueError('Topic must be a non-empty string.') + if not topic.startswith('/topics/'): + topic = '/topics/{0}'.format(topic) + data = { + 'to': topic, + 'registration_tokens': tokens, + } + url = '{0}/{1}'.format(_MessagingServiceAsync.IID_URL, operation) + try: + resp = await self._client.body( + 'post', + url=url, + json=data, + headers=_MessagingServiceAsync.IID_HEADERS + ) + except ClientResponseError as error: + raise self._handle_iid_error(error) + else: + return TopicManagementResponse(resp) + + def _message_data(self, message, dry_run): + data = {'message': _MessagingServiceAsync.encode_message(message)} + if dry_run: + data['validate_only'] = True + return data + + async def _handle_fcm_error(self, error): + """Handles errors received from the FCM API.""" + return await _utils.handle_platform_error_from_aiohttp( + error, _MessagingServiceAsync._build_fcm_error_aiohttp) + + def _handle_iid_error(self, error): + """Handles errors received from the Instance ID API.""" + if error.response is None: + raise _utils.handle_requests_error(error) + + data = {} + try: + parsed_body = error.response.json() + if isinstance(parsed_body, dict): + data = parsed_body + except ValueError: + pass + + # IID error response format: {"error": "ErrorCode"} + code = data.get('error') + msg = None + if code: + msg = 'Error while calling the IID service: {0}'.format(code) + else: + msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( + error.response.status_code, error.response.content.decode()) + + return _utils.handle_requests_error(error, msg) + + @classmethod + def _build_fcm_error_aiohttp(cls, error, message, error_dict): + """Parses an aiohttp error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + return exc_type( + message, cause=error, + http_response=error.request_info + ) if exc_type else None + + @classmethod + def _build_fcm_error(cls, error_dict): + if not error_dict: + return None + fcm_code = None + for detail in error_dict.get('details', []): + if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': + fcm_code = detail.get('errorCode') + break + return _MessagingServiceAsync.FCM_ERROR_TYPES.get(fcm_code) diff --git a/integration/conftest.py b/integration/conftest.py index 169e02d5b..050156f80 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -15,6 +15,7 @@ """pytest configuration and global fixtures for integration tests.""" import json +import asyncio import pytest import firebase_admin @@ -60,7 +61,9 @@ def default_app(request): 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), 'storageBucket' : '{0}.appspot.com'.format(project_id) } - return firebase_admin.initialize_app(cred, ops) + app = firebase_admin.initialize_app(cred, ops) + yield app + firebase_admin.delete_app(app) @pytest.fixture(scope='session') def api_key(request): @@ -70,3 +73,12 @@ def api_key(request): 'command-line option.') with open(path) as keyfile: return keyfile.read().strip() + +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for test session. + This avoids early eventloop closure. + """ + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() diff --git a/integration/test_messaging_async.py b/integration/test_messaging_async.py new file mode 100644 index 000000000..bb58f82bd --- /dev/null +++ b/integration/test_messaging_async.py @@ -0,0 +1,102 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.messaging module.""" + +import re +from datetime import datetime + +import pytest + +from firebase_admin import ( + exceptions, + messaging, + messaging_async, +) + + +_REGISTRATION_TOKEN = ('fGw0qy4TGgk:APA91bGtWGjuhp4WRhHXgbabIYp1jxEKI08ofj_v1bKhWAGJQ4e3arRCWzeTf' + 'HaLz83mBnDh0aPWB1AykXAVUUGl2h1wT4XI6XazWpvY7RBUSYfoxtqSWGIm2nvWh2BOP1YG50' + '1SsRoE') + +@pytest.mark.asyncio +async def test_send(): + msg = messaging.Message( + topic='foo-bar', + notification=messaging.Notification('test-title', 'test-body', + 'https://images.unsplash.com/photo-1494438639946' + '-1ebd1d20bf85?fit=crop&w=900&q=60'), + android=messaging.AndroidConfig( + restricted_package_name='com.google.firebase.demos', + notification=messaging.AndroidNotification( + title='android-title', + body='android-body', + image='https://images.unsplash.com/' + 'photo-1494438639946-1ebd1d20bf85?fit=crop&w=900&q=60', + event_timestamp=datetime.now(), + priority='high', + vibrate_timings_millis=[100, 200, 300, 400], + visibility='public', + sticky=True, + local_only=False, + default_vibrate_timings=False, + default_sound=True, + default_light_settings=False, + light_settings=messaging.LightSettings( + color='#aabbcc', + light_off_duration_millis=200, + light_on_duration_millis=300 + ), + notification_count=1 + ) + ), + apns=messaging.APNSConfig(payload=messaging.APNSPayload( + aps=messaging.Aps( + alert=messaging.ApsAlert( + title='apns-title', + body='apns-body' + ) + ) + )) + ) + msg_id = await messaging_async.send(msg, dry_run=True) + assert re.match('^projects/.*/messages/.*$', msg_id) + +@pytest.mark.asyncio +async def test_send_invalid_token(): + msg = messaging.Message( + token=_REGISTRATION_TOKEN, + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(messaging.UnregisteredError): + await messaging_async.send(msg, dry_run=True) + +@pytest.mark.asyncio +async def test_send_malformed_token(): + msg = messaging.Message( + token='not-a-token', + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(exceptions.InvalidArgumentError): + await messaging_async.send(msg, dry_run=True) + +@pytest.mark.asyncio +async def test_subscribe(): + resp = await messaging_async.subscribe_to_topic(_REGISTRATION_TOKEN, 'mock-topic') + assert resp.success_count + resp.failure_count == 1 + +@pytest.mark.asyncio +async def test_unsubscribe(): + resp = await messaging_async.unsubscribe_from_topic(_REGISTRATION_TOKEN, 'mock-topic') + assert resp.success_count + resp.failure_count == 1 diff --git a/tests/test_http_client_async.py b/tests/test_http_client_async.py index a00dbaa20..7c164fb7c 100644 --- a/tests/test_http_client_async.py +++ b/tests/test_http_client_async.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2022 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for firebase_admin._http_client.""" +"""Tests for firebase_admin._http_client_async.""" from __future__ import absolute_import -import aiohttp +import asyncio +import aiohttp import pytest from pytest_localserver import http -from google.auth.transport import requests from firebase_admin import _http_client_async from tests import testutils @@ -26,93 +26,91 @@ _TEST_URL = 'http://firebase.test.url/' -@pytest.mark.asyncio -async def test_http_client_default_session(): - client = _http_client_async.HttpClientAsync() - assert client.session is not None - assert isinstance(client.session, aiohttp.ClientSession) - assert client.base_url == '' - await client.close() - -@pytest.mark.asyncio -async def test_http_client_custom_session(): - session, recorder = make_mock_client_session() - client = _http_client_async.HttpClientAsync(session=session) - assert client.session is session - assert client.base_url == '' - resp = await client.request('GET', _TEST_URL) - assert resp.status_code == 200 - assert resp.text == 'body' - print(recorder) - assert len(recorder) == 1 - # assert recorder[0] - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL - await client.close() - -@pytest.mark.asyncio -async def test_base_url(): - session, recorder = make_mock_client_session() - client = _http_client_async.HttpClientAsync(base_url=_TEST_URL, session=session) - assert client.session is not None - assert client.base_url == _TEST_URL - resp = await client.request('GET', 'foo') - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL + 'foo' - await client.close() - -@pytest.mark.asyncio -async def test_credential_async(): - credential = testutils.MockGoogleCredentialAsync() - client = _http_client_async.HttpClientAsync( - credential=credential) - assert client.session is not None - session, recorder = make_mock_authorized_session(credential) - client._session = session - resp = await client.request('GET', _TEST_URL) - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - print(recorder[0].extra_kwargs) - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL - assert recorder[0].extra_kwargs['headers']['authorization'] == 'Bearer mock-token' - await client.close() - -@pytest.mark.asyncio -@pytest.mark.parametrize('options, timeout', [ - ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), - ({'timeout': 7}, 7), - ({'timeout': 0}, 0), - ({'timeout': None}, None), -]) -async def test_timeout(options, timeout): - session, recorder = make_mock_client_session() - client = _http_client_async.HttpClientAsync(**options, session=session) - assert client.timeout == timeout - await client.request('get', _TEST_URL) - assert len(recorder) == 1 - if timeout is None: - assert recorder[0].extra_kwargs['timeout'] is None - else: - assert recorder[0].extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) - await client.close() - def make_mock_client_session(payload='body', status=200): recorder = [] session = testutils.MockClientSession(payload, status, recorder) - client = _http_client_async.HttpClientAsync(session=session) return session, recorder def make_mock_authorized_session(credentials, payload='body', status=200): recorder = [] session = testutils.MockAuthorizedSession(payload, status, recorder, credentials) - client = _http_client_async.HttpClientAsync(session=session) return session, recorder +class TestHttpClient: + def seutp_method(self): + self.client = None + + def teardown_method(self): + if self.client is not None: + asyncio.get_event_loop().run_until_complete(self.client.close()) + + @pytest.mark.asyncio + async def test_http_client_default_session(self): + self.client = _http_client_async.HttpClientAsync() + assert self.client.session is not None + assert isinstance(self.client.session, aiohttp.ClientSession) + assert self.client.base_url == '' + + @pytest.mark.asyncio + async def test_http_client_custom_session(self): + session, recorder = make_mock_client_session() + self.client = _http_client_async.HttpClientAsync(session=session) + assert self.client.session is session + assert self.client.base_url == '' + resp = await self.client.request('GET', _TEST_URL) + assert resp.status == 200 + assert await resp.content() == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + + @pytest.mark.asyncio + async def test_base_url(self): + session, recorder = make_mock_client_session() + self.client = _http_client_async.HttpClientAsync(base_url=_TEST_URL, session=session) + assert self.client.session is not None + assert self.client.base_url == _TEST_URL + resp = await self.client.request('GET', 'foo') + assert resp.status == 200 + assert await resp.content() == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + 'foo' + + @pytest.mark.asyncio + async def test_credential_async(self): + credential = testutils.MockGoogleCredentialAsync() + self.client = _http_client_async.HttpClientAsync( + credential=credential) + assert self.client.session is not None + session, recorder = make_mock_authorized_session(credential) + self.client._session = session + resp = await self.client.request('GET', _TEST_URL) + assert resp.status == 200 + assert await resp.content() == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + assert recorder[0].extra_kwargs['headers']['authorization'] == 'Bearer mock-token' + + @pytest.mark.asyncio + @pytest.mark.parametrize('options, timeout', [ + ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), + ({'timeout': 7}, 7), + ({'timeout': 0}, 0), + ({'timeout': None}, None), + ]) + async def test_timeout(self, options, timeout): + session, recorder = make_mock_client_session() + self.client = _http_client_async.HttpClientAsync(**options, session=session) + assert self.client.timeout == timeout + await self.client.request('get', _TEST_URL) + assert len(recorder) == 1 + if timeout is None: + assert recorder[0].extra_kwargs['timeout'] is None + else: + assert recorder[0].extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + class TestHttpRetry: """Unit tests for the default HTTP retry configuration.""" @@ -132,46 +130,48 @@ def teardown_class(cls): cls.httpserver.stop() def setup_method(self): + self.client = None # Clean up any state in the server before starting a new test case. self.httpserver.requests = [] + def teardown_method(self): + if self.client is not None: + asyncio.get_event_loop().run_until_complete(self.client.close()) + @pytest.mark.asyncio @pytest.mark.parametrize('method', ALL_METHODS) async def test_retry_on_503(self, method): self.httpserver.serve_content({}, 503) - client = _http_client_async.JsonHttpClientAsync( + self.client = _http_client_async.JsonHttpClientAsync( credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) body = None if method in self.ENTITY_ENCLOSING_METHODS: body = {'key': 'value'} with pytest.raises(aiohttp.ClientError) as excinfo: - await client.request(method, '/', json=body) - assert excinfo.value.status == 503 + await self.client.request(method, '/', json=body) + assert excinfo.value.response.status == 503 assert len(self.httpserver.requests) == 5 - await client.close() @pytest.mark.asyncio @pytest.mark.parametrize('method', ALL_METHODS) async def test_retry_on_500(self, method): self.httpserver.serve_content({}, 500) - client = _http_client_async.JsonHttpClientAsync( + self.client = _http_client_async.JsonHttpClientAsync( credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) body = None if method in self.ENTITY_ENCLOSING_METHODS: body = {'key': 'value'} with pytest.raises(aiohttp.ClientError) as excinfo: - await client.request(method, '/', json=body) - assert excinfo.value.status == 500 + await self.client.request(method, '/', json=body) + assert excinfo.value.response.status == 500 assert len(self.httpserver.requests) == 5 - await client.close() @pytest.mark.asyncio async def test_no_retry_on_404(self): self.httpserver.serve_content({}, 404) - client = _http_client_async.JsonHttpClientAsync( + self.client = _http_client_async.JsonHttpClientAsync( credential=testutils.MockGoogleCredentialAsync(), base_url=self.httpserver.url) with pytest.raises(aiohttp.ClientError) as excinfo: - await client.request('get', '/') - await client.close() - assert excinfo.value.status == 404 + await self.client.request('get', '/') + assert excinfo.value.response.status == 404 assert len(self.httpserver.requests) == 1 diff --git a/tests/testutils.py b/tests/testutils.py index 8dcabd353..f35a2585a 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -16,8 +16,12 @@ import io import os + +from unittest.mock import MagicMock import pytest -import urllib3 + +import aiohttp +from aiohttp import streams from google.auth import credentials from google.auth import _credentials_async @@ -26,8 +30,6 @@ from requests import adapters from requests import models -import aiohttp -import asyncio import firebase_admin @@ -128,8 +130,9 @@ def refresh(self, request): class MockGoogleCredentialAsync(_credentials_async.Credentials): """A mock Google authentication credential.""" async def refresh(self, request): + # filename = inspect.stack() + # print("refresh async") self.token = 'mock-token' - await asyncio.sleep(1) class MockCredential(firebase_admin.credentials.Base): @@ -193,69 +196,44 @@ def data(self): return self._responses[0] class MockClientResponse(aiohttp.ClientResponse): - def __init__(self, responses, statuses, recorder, current_response, method, url, **kwargs): - if len(responses) != len(statuses): - raise ValueError('The lengths of responses and statuses do not match.') - + def __init__(self, method, url, payload, status, recorder): # pylint: disable=super-init-not-called + self._cache = {} self._url = url - self.status_code = statuses[current_response] - self.content = responses[current_response] - self.raw = io.BytesIO(responses[current_response].encode()) - @property - def url(self): - return self._url - - @property - def status(self): - return self.status_code - - @property - def data(self): - return self - - @property - def text(self): - return self.content + mock_reader = AsyncMock(spec=streams.StreamReader) + mock_reader.read.return_value = payload + self.content = mock_reader + self.status = status + self.recorder = recorder + self._headers = [] class MockSession(aiohttp.ClientSession): - def __init__(self, data, status, recorder, credentials=None): + def __init__(self, payload, status, recorder, credentials=None): super(MockSession, self).__init__(credentials) - # self._response = MockClientResponse(data, status, recorder, method, url) - self._current_response = 0 - self._data = data - self._responses = [data] - self._status = status - self._statuses = [status] + self.payload = payload + self.status = status self.recorder = recorder + self.current_response = 0 - # self._extra_kwargs = None - - async def _request(self, method, url, *args, **kwargs): - + async def _request(self, method, url, *args, **kwargs): # pylint: disable=arguments-differ self.method = method self.url = url self.args = args self.extra_kwargs = kwargs self.recorder.append(self) - resp = MockClientResponse(self._responses, self._statuses, self.recorder, self._current_response, method, url) - self._current_response = min(self._current_response + 1, len(self._responses) - 1) - return resp - - @property - def status(self): - return self._status_code - - @property - def data(self): - return self + self.current_response += 1 + return MockClientResponse(method, url, self.payload, self.status, self.recorder) class MockClientSession(MockSession): - def __init__(self, data, status, recorder): - super(MockClientSession, self).__init__(data, status, recorder) - + def __init__(self, payload, status, recorder): + super(MockClientSession, self).__init__(payload, status, recorder) class MockAuthorizedSession(MockClientSession, AuthorizedSession): - def __init__(self, data, status, recorder, credentials): - super(MockAuthorizedSession, self).__init__(data, status, recorder) + def __init__(self, payload, status, recorder, credentials): + super(MockAuthorizedSession, self).__init__(payload, status, recorder) self.credentials = credentials + +# Custom async mock class since unuttest.mock.AsyncMock is only avaible in python 3.8+ +class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): # pylint: disable=useless-super-delegation + return super(AsyncMock, self).__call__(*args, **kwargs) From 254500e92ad5b77d36e3fc5aeea8e0d00cfffa0e Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Fri, 19 Aug 2022 11:26:16 -0400 Subject: [PATCH 7/9] Added type hints --- firebase_admin/_gapic_utils.py | 4 +- firebase_admin/_http_client.py | 4 +- firebase_admin/_http_client_async.py | 34 +++++++---- firebase_admin/_utils.py | 2 +- firebase_admin/messaging.py | 2 +- firebase_admin/messaging_async.py | 90 ++++++++++++++++++---------- 6 files changed, 88 insertions(+), 48 deletions(-) diff --git a/firebase_admin/_gapic_utils.py b/firebase_admin/_gapic_utils.py index 3c975808c..5e0f42fdd 100644 --- a/firebase_admin/_gapic_utils.py +++ b/firebase_admin/_gapic_utils.py @@ -17,8 +17,8 @@ import io import socket -import googleapiclient -import httplib2 +import googleapiclient # type: ignore +import httplib2 # type: ignore import requests from firebase_admin import exceptions diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index d259faddf..e30b421cd 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -17,9 +17,9 @@ This module provides utilities for making HTTP calls using the requests library. """ -from google.auth import transport +from google.auth import transport # type: ignore import requests -from requests.packages.urllib3.util import retry # pylint: disable=import-error +from requests.packages.urllib3.util import retry # type: ignore # pylint: disable=import-error if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): diff --git a/firebase_admin/_http_client_async.py b/firebase_admin/_http_client_async.py index cfdc04364..b10190b01 100644 --- a/firebase_admin/_http_client_async.py +++ b/firebase_admin/_http_client_async.py @@ -20,8 +20,8 @@ import aiohttp from aiohttp.client_exceptions import ClientResponseError -from google.auth.transport import _aiohttp_requests -from google.auth.transport._aiohttp_requests import _CombinedResponse +from google.auth.transport import _aiohttp_requests # type: ignore +from google.auth.transport._aiohttp_requests import _CombinedResponse # type: ignore DEFAULT_RETRY_ATTEMPTS = 4 @@ -69,8 +69,8 @@ def __init__( if credential: self._session = _aiohttp_requests.AuthorizedSession( credential, - refresh_status_codes=retry_codes, max_refresh_attempts=retry_attempts, + refresh_status_codes=retry_codes, refresh_timeout=timeout ) elif session: @@ -116,7 +116,8 @@ class call this method to send async HTTP requests out. Refer to Response: A ``_CombinedResponse`` wrapped ``ClientResponse`` object. Raises: - ClientResponseError: Any requests exceptions encountered while making the HTTP call. + ClientResponseWithBodyError: Any requests exceptions encountered while making the async + HTTP call. """ if 'timeout' not in kwargs: kwargs['timeout'] = self.timeout @@ -124,18 +125,19 @@ class call this method to send async HTTP requests out. Refer to wrapped_resp = _CombinedResponse(resp) try: - # Get response content from StreamReader before it is closed by error. - print(wrapped_resp.content, "idk") + # Get response content from StreamReader before it is closed by error throw. resp_content = await wrapped_resp.content() - # print(wrapped_resp._response.content) resp.raise_for_status() - # Catch response error and re-release it with after appending response body needed to + # Catch response error and re-release it after appending response body needed to # determine the underlying reason for the error. except ClientResponseError as err: - err.response = wrapped_resp - err.response_content = resp_content - raise err + raise ClientResponseWithBodyError( + err.request_info, + err.history, + wrapped_resp, + resp_content + ) from err return wrapped_resp async def headers(self, method, url, **kwargs): @@ -169,3 +171,13 @@ def __init__(self, **kwargs): async def parse_body(self, resp): content = await resp.content() return json.loads(content) + + +class ClientResponseWithBodyError(aiohttp.ClientResponseError): + """A ClientResponseError wrapper to hold the response body of the underlying falied + aiohttp request. + """ + def __init__(self, request_info, history, response, response_content): + super(ClientResponseWithBodyError, self).__init__(request_info, history) + self.response = response + self.response_content = response_content diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index ca96494d4..af2bc8ebe 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -16,7 +16,7 @@ import json -import google.auth +import google.auth # type: ignore import requests import firebase_admin diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 46dd7d410..2fc96fef8 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -16,7 +16,7 @@ import json -from googleapiclient import http +from googleapiclient import http # type: ignore from googleapiclient import _auth import requests diff --git a/firebase_admin/messaging_async.py b/firebase_admin/messaging_async.py index 236f5406c..716f15c21 100644 --- a/firebase_admin/messaging_async.py +++ b/firebase_admin/messaging_async.py @@ -16,11 +16,23 @@ import asyncio +from typing import ( + Optional, + Any, + Type, + List, + Dict +) + import firebase_admin +from firebase_admin.exceptions import FirebaseError +from firebase_admin import ( + App +) from firebase_admin.messaging import TopicManagementResponse from firebase_admin._http_client_async import ( JsonHttpClientAsync, - ClientResponseError, + ClientResponseWithBodyError, DEFAULT_TIMEOUT_SECONDS ) from firebase_admin._messaging_encoder import ( @@ -40,7 +52,7 @@ _MESSAGING_ATTRIBUTE = '_messaging_async' -__all__ = [ +__all__: List[str] = [ 'send', # 'send_all', # 'send_multicast', @@ -48,11 +60,14 @@ 'unsubscribe_from_topic', ] - -def _get_messaging_service(app): +# pylint: disable=unsubscriptable-object +# TODO:(/b)Remove false positive unsubscriptable-object lint warnings caused by type hints Optional type. +# This is fixed in pylint 2.7.0 but this version introduces new lint rules and requires multiple +# file changes. +def _get_messaging_service(app: Optional[App]) -> "_MessagingServiceAsync": return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingServiceAsync) -async def send(message, dry_run=False, app=None): +async def send(message: Message, dry_run: bool = False, app: Optional[App] = None) -> str: """Sends the given message via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -72,7 +87,10 @@ async def send(message, dry_run=False, app=None): """ return await _get_messaging_service(app).send(message, dry_run) -async def subscribe_to_topic(tokens, topic, app=None): +async def subscribe_to_topic( + tokens: List[str], + topic: str, app: Optional[App] = None + ) -> TopicManagementResponse: """Subscribes a list of registration tokens to an FCM topic. Args: @@ -91,7 +109,11 @@ async def subscribe_to_topic(tokens, topic, app=None): return await _get_messaging_service(app).make_topic_management_request( tokens, topic, 'iid/v1:batchAdd') -async def unsubscribe_from_topic(tokens, topic, app=None): +async def unsubscribe_from_topic( + tokens: List[str], + topic: str, + app: Optional[App] = None + ) -> TopicManagementResponse: """Unsubscribes a list of registration tokens from an FCM topic. Args: @@ -114,13 +136,13 @@ async def unsubscribe_from_topic(tokens, topic, app=None): class _MessagingServiceAsync: """Service class that implements Firebase Cloud Messaging (FCM) functionality asynchronously.""" - FCM_URL = 'https://fcm.googleapis.com/v1/projects/{0}/messages:send' - FCM_BATCH_URL = 'https://fcm.googleapis.com/batch' - IID_URL = 'https://iid.googleapis.com' - IID_HEADERS = {'access_token_auth': 'true'} - JSON_ENCODER = MessageEncoder() + FCM_URL: str = 'https://fcm.googleapis.com/v1/projects/{0}/messages:send' + FCM_BATCH_URL: str = 'https://fcm.googleapis.com/batch' + IID_URL: str = 'https://iid.googleapis.com' + IID_HEADERS: Dict[str, str] = {'access_token_auth': 'true'} + JSON_ENCODER: MessageEncoder = MessageEncoder() - FCM_ERROR_TYPES = { + FCM_ERROR_TYPES: Dict[str, Type[FirebaseError]] = { 'APNS_AUTH_ERROR': ThirdPartyAuthError, 'QUOTA_EXCEEDED': QuotaExceededError, 'SENDER_ID_MISMATCH': SenderIdMismatchError, @@ -128,7 +150,7 @@ class _MessagingServiceAsync: 'UNREGISTERED': UnregisteredError, } - def __init__(self, app): + def __init__(self, app: App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -145,18 +167,18 @@ def __init__(self, app): self._client = JsonHttpClientAsync(credential=self._credential, timeout=timeout) self._loop = asyncio.get_event_loop() - def close(self): + def close(self) -> None: if self._client is not None: self._loop.run_until_complete(self._client.close()) - self._client = None + self._client = None # type: ignore[assignment] @classmethod - def encode_message(cls, message): + def encode_message(cls, message: Message) -> Dict[str, Any]: if not isinstance(message, Message): raise ValueError('Message must be an instance of messaging.Message class.') return cls.JSON_ENCODER.default(message) - async def send(self, message, dry_run=False): + async def send(self, message: Message, dry_run: bool = False) -> str: """Sends the given message to FCM via the FCM v1 API.""" data = self._message_data(message, dry_run) try: @@ -166,7 +188,7 @@ async def send(self, message, dry_run=False): headers=self._fcm_headers, json=data ) - except ClientResponseError as error: + except ClientResponseWithBodyError as error: raise await self._handle_fcm_error(error) else: return resp['name'] @@ -197,23 +219,23 @@ async def make_topic_management_request(self, tokens, topic, operation): json=data, headers=_MessagingServiceAsync.IID_HEADERS ) - except ClientResponseError as error: + except ClientResponseWithBodyError as error: raise self._handle_iid_error(error) else: return TopicManagementResponse(resp) - def _message_data(self, message, dry_run): + def _message_data(self, message: Message, dry_run: bool) -> Dict[str, Any]: data = {'message': _MessagingServiceAsync.encode_message(message)} if dry_run: - data['validate_only'] = True + data['validate_only'] = True # type: ignore[assignment] return data - async def _handle_fcm_error(self, error): + async def _handle_fcm_error(self, error: ClientResponseWithBodyError) -> FirebaseError: """Handles errors received from the FCM API.""" return await _utils.handle_platform_error_from_aiohttp( error, _MessagingServiceAsync._build_fcm_error_aiohttp) - def _handle_iid_error(self, error): + def _handle_iid_error(self, error: ClientResponseWithBodyError) -> FirebaseError: """Handles errors received from the Instance ID API.""" if error.response is None: raise _utils.handle_requests_error(error) @@ -238,22 +260,28 @@ def _handle_iid_error(self, error): return _utils.handle_requests_error(error, msg) @classmethod - def _build_fcm_error_aiohttp(cls, error, message, error_dict): + def _build_fcm_error_aiohttp( + cls, + error: ClientResponseWithBodyError, + message: Message, + error_dict: Dict[Any, Any] + ) -> Optional[FirebaseError]: """Parses an aiohttp error response from the FCM API and creates a FCM-specific exception if appropriate.""" - exc_type = cls._build_fcm_error(error_dict) - return exc_type( - message, cause=error, + exc_type: Optional[Type[FirebaseError]] = cls._build_fcm_error(error_dict) + return exc_type( # type: ignore[call-arg] + message, + cause=error, http_response=error.request_info ) if exc_type else None @classmethod - def _build_fcm_error(cls, error_dict): + def _build_fcm_error(cls, error_dict: Dict[str, Any]) -> Optional[Type[FirebaseError]]: if not error_dict: return None - fcm_code = None + fcm_code: Optional[str] = None for detail in error_dict.get('details', []): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': fcm_code = detail.get('errorCode') break - return _MessagingServiceAsync.FCM_ERROR_TYPES.get(fcm_code) + return _MessagingServiceAsync.FCM_ERROR_TYPES.get(fcm_code) # type: ignore[arg-type] From c7bdc63673736b51843304990cd3d3214f9df156 Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Fri, 19 Aug 2022 12:22:12 -0400 Subject: [PATCH 8/9] Lint changes to allow some files to be complient with newer pylint versions --- firebase_admin/_http_client_async.py | 12 ++++++------ firebase_admin/credentials.py | 24 ++++++++++++------------ firebase_admin/messaging_async.py | 18 +++++++++--------- tests/test_credentials.py | 4 ++-- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/firebase_admin/_http_client_async.py b/firebase_admin/_http_client_async.py index b10190b01..178a6aac4 100644 --- a/firebase_admin/_http_client_async.py +++ b/firebase_admin/_http_client_async.py @@ -95,7 +95,7 @@ def base_url(self): def timeout(self): return self._timeout - def parse_body(self, resp): + async def parse_body(self, resp): raise NotImplementedError async def request(self, method, url, **kwargs): @@ -124,13 +124,13 @@ class call this method to send async HTTP requests out. Refer to resp = await self._session.request(method, self.base_url + url, **kwargs) wrapped_resp = _CombinedResponse(resp) - try: - # Get response content from StreamReader before it is closed by error throw. - resp_content = await wrapped_resp.content() - resp.raise_for_status() + # Get response content from StreamReader before it is closed by error throw. + resp_content = await wrapped_resp.content() # Catch response error and re-release it after appending response body needed to # determine the underlying reason for the error. + try: + resp.raise_for_status() except ClientResponseError as err: raise ClientResponseWithBodyError( err.request_info, @@ -178,6 +178,6 @@ class ClientResponseWithBodyError(aiohttp.ClientResponseError): aiohttp request. """ def __init__(self, request_info, history, response, response_content): - super(ClientResponseWithBodyError, self).__init__(request_info, history) + super().__init__(request_info, history) self.response = response self.response_content = response_content diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index b25f4e377..8af671900 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -103,20 +103,20 @@ def __init__(self, cert: str) -> None: IOError: If the specified certificate file doesn't exist or cannot be read. ValueError: If the specified certificate is invalid. """ - super(Certificate, self).__init__() + super().__init__() if _is_file_path(cert): - with open(cert) as json_file: + with open(cert, encoding="utf-8") as json_file: json_data = json.load(json_file) elif isinstance(cert, dict): json_data = cert else: raise ValueError( - 'Invalid certificate argument: "{0}". Certificate argument must be a file path, ' - 'or a dict containing the parsed file contents.'.format(cert)) + f'Invalid certificate argument: "{cert}". Certificate argument must be a file ' + 'path, or a dict containing the parsed file contents.') if json_data.get('type') != self._CREDENTIAL_TYPE: raise ValueError('Invalid service account certificate. Certificate must contain a ' - '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) + f'"type" field set to "{self._CREDENTIAL_TYPE}".') try: self._g_credential = service_account.Credentials.from_service_account_info( json_data, scopes=_scopes) @@ -124,7 +124,7 @@ def __init__(self, cert: str) -> None: json_data, scopes=_scopes) except ValueError as error: raise ValueError('Failed to initialize a certificate credential. ' - 'Caused by: "{0}"'.format(error)) + f'Caused by: "{error}"') from error @property def project_id(self) -> str: @@ -162,7 +162,7 @@ def __init__(self) -> None: The credentials will be lazily initialized when get_credential(), get_credential_async() or project_id() is called. See those methods for possible errors raised. """ - super(ApplicationDefault, self).__init__() + super().__init__() self._g_credential = None # Will be lazily-loaded via _load_credential(). self._g_credential_async = None # Will be lazily-loaded via _load_credential_async(). @@ -229,20 +229,20 @@ def __init__(self, refresh_token: str) -> None: IOError: If the specified file doesn't exist or cannot be read. ValueError: If the refresh token configuration is invalid. """ - super(RefreshToken, self).__init__() + super().__init__() if _is_file_path(refresh_token): - with open(refresh_token) as json_file: + with open(refresh_token, encoding="utf-8") as json_file: json_data = json.load(json_file) elif isinstance(refresh_token, dict): json_data = refresh_token else: raise ValueError( - 'Invalid refresh token argument: "{0}". Refresh token argument must be a file ' - 'path, or a dict containing the parsed file contents.'.format(refresh_token)) + f'Invalid refresh token argument: "{refresh_token}". Refresh token argument must be' + ' a file path, or a dict containing the parsed file contents.') if json_data.get('type') != self._CREDENTIAL_TYPE: raise ValueError('Invalid refresh token configuration. JSON must contain a ' - '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) + f'"type" field set to "{self._CREDENTIAL_TYPE}".') self._g_credential = credentials.Credentials.from_authorized_user_info(json_data, _scopes) self._g_credential_async = credentials_async.Credentials.from_authorized_user_info( json_data, diff --git a/firebase_admin/messaging_async.py b/firebase_admin/messaging_async.py index 716f15c21..73917f9fc 100644 --- a/firebase_admin/messaging_async.py +++ b/firebase_admin/messaging_async.py @@ -61,9 +61,9 @@ ] # pylint: disable=unsubscriptable-object -# TODO:(/b)Remove false positive unsubscriptable-object lint warnings caused by type hints Optional type. -# This is fixed in pylint 2.7.0 but this version introduces new lint rules and requires multiple -# file changes. +# TODO:(/b)Remove false positive unsubscriptable-object lint warnings caused by type hints Optional +# type. This is fixed in pylint 2.7.0 but this version introduces new lint rules and requires +# multiple file changes. def _get_messaging_service(app: Optional[App]) -> "_MessagingServiceAsync": return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingServiceAsync) @@ -160,7 +160,7 @@ def __init__(self, app: App) -> None: self._fcm_url = _MessagingServiceAsync.FCM_URL.format(project_id) self._fcm_headers = { 'X-GOOG-API-FORMAT-VERSION': '2', - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}' } timeout = app.options.get('httpTimeout', DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential_async() @@ -206,12 +206,12 @@ async def make_topic_management_request(self, tokens, topic, operation): if not isinstance(topic, str) or not topic: raise ValueError('Topic must be a non-empty string.') if not topic.startswith('/topics/'): - topic = '/topics/{0}'.format(topic) + topic = f'/topics/{topic}' data = { 'to': topic, 'registration_tokens': tokens, } - url = '{0}/{1}'.format(_MessagingServiceAsync.IID_URL, operation) + url = f'{_MessagingServiceAsync.IID_URL}/{operation}' try: resp = await self._client.body( 'post', @@ -252,10 +252,10 @@ def _handle_iid_error(self, error: ClientResponseWithBodyError) -> FirebaseError code = data.get('error') msg = None if code: - msg = 'Error while calling the IID service: {0}'.format(code) + msg = f'Error while calling the IID service: {code}' else: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( - error.response.status_code, error.response.content.decode()) + msg = (f'Unexpected HTTP response with status: {error.response.status_code}; ' + f'body: {error.response.content.decode()}') return _utils.handle_requests_error(error, msg) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index cceb6b6f9..1e1db6460 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -64,7 +64,7 @@ def test_init_from_invalid_certificate(self, file_name, error): with pytest.raises(error): credentials.Certificate(testutils.resource_filename(file_name)) - @pytest.mark.parametrize('arg', [None, 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('arg', [None, 0, 1, True, False, [], tuple(), {}]) def test_invalid_args(self, arg): with pytest.raises(ValueError): credentials.Certificate(arg) @@ -156,7 +156,7 @@ def test_init_from_invalid_file(self): credentials.RefreshToken( testutils.resource_filename('service_account.json')) - @pytest.mark.parametrize('arg', [None, 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('arg', [None, 0, 1, True, False, [], tuple(), {}]) def test_invalid_args(self, arg): with pytest.raises(ValueError): credentials.RefreshToken(arg) From 2c733eb16d01190e89670ebea3666ae545d93af7 Mon Sep 17 00:00:00 2001 From: jkyle109 Date: Wed, 24 Aug 2022 13:56:46 -0400 Subject: [PATCH 9/9] Unit tests for messaging async --- firebase_admin/credentials.py | 9 +- firebase_admin/messaging_async.py | 2 + tests/test_http_client_async.py | 9 +- tests/test_messaging_async.py | 444 ++++++++++++++++++++++++++++++ tests/testutils.py | 22 +- 5 files changed, 475 insertions(+), 11 deletions(-) create mode 100644 tests/test_messaging_async.py diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 8af671900..a556972a6 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -159,8 +159,9 @@ class ApplicationDefault(Base): def __init__(self) -> None: """Creates an instance that will use Application Default credentials. - The credentials will be lazily initialized when get_credential(), get_credential_async() - or project_id() is called. See those methods for possible errors raised. + The credentials will be lazily initialized when ``get_credential()``, + ``get_credential_async()`` or ``project_id()`` is called. See those methods for possible + errors raised. """ super().__init__() self._g_credential = None # Will be lazily-loaded via _load_credential(). @@ -216,8 +217,8 @@ class RefreshToken(Base): def __init__(self, refresh_token: str) -> None: """Initializes credentials from a refresh token JSON file. - The JSON must consist of client_id, client_secret and refresh_token fields. Refresh - token files are typically created and managed by the gcloud SDK. To instantiate + The JSON must consist of ``client_id``, ``client_secret`` and ``refresh_token`` fields. + Refresh token files are typically created and managed by the gcloud SDK. To instantiate credentials from a refresh token file, either specify the file path or a dict representing the parsed contents of the file. diff --git a/firebase_admin/messaging_async.py b/firebase_admin/messaging_async.py index 73917f9fc..2bd399746 100644 --- a/firebase_admin/messaging_async.py +++ b/firebase_admin/messaging_async.py @@ -169,6 +169,8 @@ def __init__(self, app: App) -> None: def close(self) -> None: if self._client is not None: + if self._loop.is_closed(): + self._loop = asyncio.get_event_loop() self._loop.run_until_complete(self._client.close()) self._client = None # type: ignore[assignment] diff --git a/tests/test_http_client_async.py b/tests/test_http_client_async.py index 7c164fb7c..8719f4909 100644 --- a/tests/test_http_client_async.py +++ b/tests/test_http_client_async.py @@ -59,7 +59,8 @@ async def test_http_client_custom_session(self): assert self.client.base_url == '' resp = await self.client.request('GET', _TEST_URL) assert resp.status == 200 - assert await resp.content() == 'body' + content = await resp.content() + assert content.decode() == 'body' assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == _TEST_URL @@ -72,7 +73,8 @@ async def test_base_url(self): assert self.client.base_url == _TEST_URL resp = await self.client.request('GET', 'foo') assert resp.status == 200 - assert await resp.content() == 'body' + content = await resp.content() + assert content.decode() == 'body' assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == _TEST_URL + 'foo' @@ -87,7 +89,8 @@ async def test_credential_async(self): self.client._session = session resp = await self.client.request('GET', _TEST_URL) assert resp.status == 200 - assert await resp.content() == 'body' + content = await resp.content() + assert content.decode() == 'body' assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == _TEST_URL diff --git a/tests/test_messaging_async.py b/tests/test_messaging_async.py new file mode 100644 index 000000000..480f6aa31 --- /dev/null +++ b/tests/test_messaging_async.py @@ -0,0 +1,444 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.messaging module.""" +import datetime +import json +import numbers + +import asyncio + +from googleapiclient import http +from googleapiclient import _helpers +import pytest + +import firebase_admin +from firebase_admin import exceptions +from firebase_admin import messaging +from firebase_admin import messaging_async +from firebase_admin import _http_client_async + +from google.auth.transport._aiohttp_requests import AuthorizedSession +from tests import testutils + + +NON_STRING_ARGS = [list(), tuple(), dict(), True, False, 1, 0] +NON_DICT_ARGS = ['', list(), tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] +NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] +NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] +NON_UINT_ARGS = ['1.23s', list(), tuple(), dict(), -1.23] +HTTP_ERROR_CODES = { + 400: exceptions.InvalidArgumentError, + 403: exceptions.PermissionDeniedError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + 503: exceptions.UnavailableError, +} +FCM_ERROR_CODES = { + 'APNS_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'QUOTA_EXCEEDED': messaging.QuotaExceededError, + 'SENDER_ID_MISMATCH': messaging.SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'UNREGISTERED': messaging.UnregisteredError, +} + + +def check_exception(exception, message, status): + assert isinstance(exception, exceptions.FirebaseError) + assert str(exception) == message + assert exception.cause is not None + assert exception.http_response is not None + assert exception.http_response.status_code == status + + +class TestTimeoutAsync: + + def teardown(self): + testutils.cleanup_apps() + + def _instrument_service(self, response): + app = firebase_admin.get_app() + fcm_service_async = messaging_async._get_messaging_service(app) + recorder = [] + credentials = fcm_service_async._client.session.credentials + session = testutils.MockAuthorizedSession(json.dumps(response), 200, recorder, credentials) + fcm_service_async._client._session = session + return recorder + + def _check_timeout(self, recorder, timeout): + assert len(recorder) == 1 + if timeout is None: + assert recorder[0].extra_kwargs['timeout'] is None + else: + assert recorder[0].extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + + @pytest.mark.parametrize('options, timeout', [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), + ]) + @pytest.mark.asyncio + async def test_send_async(self, options, timeout): + cred = testutils.MockCredentialAsync() + all_options = {'projectId': 'explicit-project-id'} + all_options.update(options) + firebase_admin.initialize_app(cred, all_options) + recorder = self._instrument_service({'name': 'message-id'}) + msg = messaging.Message(topic='foo') + await messaging_async.send(msg) + self._check_timeout(recorder, timeout) + + @pytest.mark.parametrize('options, timeout', [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client_async.DEFAULT_TIMEOUT_SECONDS), + ]) + @pytest.mark.asyncio + async def test_topic_management_custom_timeout(self, options, timeout): + cred = testutils.MockCredentialAsync() + all_options = {'projectId': 'explicit-project-id'} + all_options.update(options) + firebase_admin.initialize_app(cred, all_options) + recorder = self._instrument_service({'results': [{}, {'error': 'error_reason'}]}) + await messaging_async.subscribe_to_topic(['1'], 'a') + self._check_timeout(recorder, timeout) + + +class TestSendAsync: + + _DEFAULT_RESPONSE = json.dumps({'name': 'message-id'}) + _CLIENT_VERSION = 'fire-admin-python/{0}'.format(firebase_admin.__version__) + + def setup(self): + cred = testutils.MockCredentialAsync() + firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + + def teardown(self): + testutils.cleanup_apps() + + def _instrument_messaging_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + fcm_service_async = messaging_async._get_messaging_service(app) + recorder = [] + + credentials = fcm_service_async._client.session.credentials + session = testutils.MockAuthorizedSession(payload, status, recorder, credentials) + fcm_service_async._client._session = session + + return fcm_service_async, recorder + + def _get_url(self, project_id): + return messaging_async._MessagingServiceAsync.FCM_URL.format(project_id) + + @pytest.mark.asyncio + async def test_no_project_id(self): + async def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredentialAsync(), name='no_project_id') + with pytest.raises(ValueError): + await messaging_async.send(messaging.Message(topic='foo'), app=app) + await testutils.run_without_project_id_async(evaluate) + + @pytest.mark.parametrize('msg', NON_OBJECT_ARGS + [None]) + @pytest.mark.asyncio + async def test_invalid_send(self, msg): + with pytest.raises(ValueError) as excinfo: + await messaging_async.send(msg) + assert str(excinfo.value) == 'Message must be an instance of messaging.Message class.' + + @pytest.mark.asyncio + async def test_send_dry_run(self): + _, recorder = self._instrument_messaging_service() + msg = messaging.Message(topic='foo') + msg_id = await messaging_async.send(msg, dry_run=True) + assert msg_id == 'message-id' + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('explicit-project-id') + assert recorder[0].extra_kwargs['headers']['X-GOOG-API-FORMAT-VERSION'] == '2' + assert recorder[0].extra_kwargs['headers']['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + body = { + 'message': messaging_async._MessagingServiceAsync.encode_message(msg), + 'validate_only': True, + } + assert recorder[0].extra_kwargs['json'] == body + + @pytest.mark.asyncio + async def test_send(self): + _, recorder = self._instrument_messaging_service() + msg = messaging.Message(topic='foo') + msg_id = await messaging_async.send(msg) + assert msg_id == 'message-id' + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('explicit-project-id') + assert recorder[0].extra_kwargs['headers']['X-GOOG-API-FORMAT-VERSION'] == '2' + assert recorder[0].extra_kwargs['headers']['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + body = {'message': messaging_async._MessagingServiceAsync.encode_message(msg)} + assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status,exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_send_error(self, status, exc_type): + # _, recorder = self._instrument_messaging_service(status=status, payload='{}') + # msg = messaging.Message(topic='foo') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.send(msg) + # expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) + # check_exception(excinfo.value, expected, status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('explicit-project-id') + # assert recorder[0].extra_kwargs['headers']['X-GOOG-API-FORMAT-VERSION'] == '2' + # assert recorder[0].extra_kwargs['headers']['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.asyncio + # async def test_send_detailed_error(self, status): + # payload = json.dumps({ + # 'error': { + # 'status': 'INVALID_ARGUMENT', + # 'message': 'test error' + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.asyncio + # async def test_send_canonical_error_code(self, status): + # payload = json.dumps({ + # 'error': { + # 'status': 'NOT_FOUND', + # 'message': 'test error' + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exceptions.NotFoundError) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_send_fcm_error_code(self, status, fcm_error_code, exc_type): + # payload = json.dumps({ + # 'error': { + # 'status': 'INVALID_ARGUMENT', + # 'message': 'test error', + # 'details': [ + # { + # '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + # 'errorCode': fcm_error_code, + # }, + # ], + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + # @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + # @pytest.mark.asyncio + # async def test_send_unknown_fcm_error_code(self, status): + # payload = json.dumps({ + # 'error': { + # 'status': 'INVALID_ARGUMENT', + # 'message': 'test error', + # 'details': [ + # { + # '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + # 'errorCode': 'SOME_UNKNOWN_CODE', + # }, + # ], + # } + # }) + # _, recorder = self._instrument_messaging_service(status=status, payload=payload) + # msg = messaging.Message(topic='foo') + # with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + # await messaging_async.send(msg) + # check_exception(excinfo.value, 'test error', status) + # assert len(recorder) == 1 + # assert recorder[0].method == 'post' + # assert recorder[0].url == self._get_url('explicit-project-id') + # body = {'message': messaging_async._MessagingServiceAsync.JSON_ENCODER.default(msg)} + # assert recorder[0].extra_kwargs['json'] == body + + +class TestTopicManagementAsync: + + _DEFAULT_RESPONSE = json.dumps({'results': [{}, {'error': 'error_reason'}]}) + _DEFAULT_ERROR_RESPONSE = json.dumps({'error': 'error_reason'}) + _VALID_ARGS = [ + # (tokens, topic, expected) + ( + ['foo', 'bar'], + 'test-topic', + {'to': '/topics/test-topic', 'registration_tokens': ['foo', 'bar']} + ), + ( + 'foo', + '/topics/test-topic', + {'to': '/topics/test-topic', 'registration_tokens': ['foo']} + ), + ] + + def setup(self): + cred = testutils.MockCredentialAsync() + firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + + def teardown(self): + testutils.cleanup_apps() + + def _instrument_iid_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + fcm_service_async = messaging_async._get_messaging_service(app) + recorder = [] + + credentials = fcm_service_async._client.session.credentials + session = testutils.MockAuthorizedSession(payload, status, recorder, credentials) + fcm_service_async._client._session = session + + return fcm_service_async, recorder + + def _get_url(self, path): + return '{0}/{1}'.format(messaging_async._MessagingServiceAsync.IID_URL, path) + + @pytest.mark.parametrize('tokens', [None, '', [], {}, tuple()]) + @pytest.mark.asyncio + async def test_invalid_tokens(self, tokens): + expected = 'Tokens must be a string or a non-empty list of strings.' + if isinstance(tokens, str): + expected = 'Tokens must be non-empty strings.' + + with pytest.raises(ValueError) as excinfo: + await messaging_async.subscribe_to_topic(tokens, 'test-topic') + assert str(excinfo.value) == expected + + with pytest.raises(ValueError) as excinfo: + await messaging_async.unsubscribe_from_topic(tokens, 'test-topic') + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('topic', NON_STRING_ARGS + [None, '']) + @pytest.mark.asyncio + async def test_invalid_topic(self, topic): + expected = 'Topic must be a non-empty string.' + with pytest.raises(ValueError) as excinfo: + await messaging_async.subscribe_to_topic('test-token', topic) + assert str(excinfo.value) == expected + + with pytest.raises(ValueError) as excinfo: + await messaging_async.unsubscribe_from_topic('test-tokens', topic) + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('args', _VALID_ARGS) + @pytest.mark.asyncio + async def test_subscribe_to_topic(self, args): + _, recorder = self._instrument_iid_service() + resp = await messaging_async.subscribe_to_topic(args[0], args[1]) + self._check_response(resp) + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('iid/v1:batchAdd') + assert recorder[0].extra_kwargs['json'] == args[2] + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_subscribe_to_topic_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service( + # status=status, payload=self._DEFAULT_ERROR_RESPONSE) + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.subscribe_to_topic('foo', 'test-topic') + # assert str(excinfo.value) == 'Error while calling the IID service: error_reason' + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchAdd') + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_subscribe_to_topic_non_json_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service(status=status, payload='not json') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.subscribe_to_topic('foo', 'test-topic') + # reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) + # assert str(excinfo.value) == reason + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchAdd') + + @pytest.mark.parametrize('args', _VALID_ARGS) + @pytest.mark.asyncio + async def test_unsubscribe_from_topic(self, args): + _, recorder = self._instrument_iid_service() + resp = await messaging_async.unsubscribe_from_topic(args[0], args[1]) + self._check_response(resp) + assert len(recorder) == 1 + assert recorder[0].method == 'post' + assert recorder[0].url == self._get_url('iid/v1:batchRemove') + assert recorder[0].extra_kwargs['json'] == args[2] + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_unsubscribe_from_topic_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service( + # status=status, payload=self._DEFAULT_ERROR_RESPONSE) + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.unsubscribe_from_topic('foo', 'test-topic') + # assert str(excinfo.value) == 'Error while calling the IID service: error_reason' + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchRemove') + + # @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + # @pytest.mark.asyncio + # async def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): + # _, recorder = self._instrument_iid_service(status=status, payload='not json') + # with pytest.raises(exc_type) as excinfo: + # await messaging_async.unsubscribe_from_topic('foo', 'test-topic') + # reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) + # assert str(excinfo.value) == reason + # assert len(recorder) == 1 + # assert recorder[0].method == 'POST' + # assert recorder[0].url == self._get_url('iid/v1:batchRemove') + + def _check_response(self, resp): + assert resp.success_count == 1 + assert resp.failure_count == 1 + assert len(resp.errors) == 1 + assert resp.errors[0].index == 1 + assert resp.errors[0].reason == 'error_reason' diff --git a/tests/testutils.py b/tests/testutils.py index f35a2585a..9fe788d4f 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -67,6 +67,22 @@ def run_without_project_id(func): if gcloud_project: os.environ[env_var] = gcloud_project +async def run_without_project_id_async(func): + env_vars = ['GCLOUD_PROJECT', 'GOOGLE_CLOUD_PROJECT'] + env_values = [] + for env_var in env_vars: + gcloud_project = os.environ.get(env_var) + if gcloud_project: + del os.environ[env_var] + env_values.append(gcloud_project) + try: + await func() + finally: + for idx, env_var in enumerate(env_vars): + gcloud_project = env_values[idx] + if gcloud_project: + os.environ[env_var] = gcloud_project + def new_monkeypatch(): return pytest.MonkeyPatch() @@ -130,8 +146,6 @@ def refresh(self, request): class MockGoogleCredentialAsync(_credentials_async.Credentials): """A mock Google authentication credential.""" async def refresh(self, request): - # filename = inspect.stack() - # print("refresh async") self.token = 'mock-token' @@ -201,7 +215,7 @@ def __init__(self, method, url, payload, status, recorder): # pylint: disable=su self._url = url mock_reader = AsyncMock(spec=streams.StreamReader) - mock_reader.read.return_value = payload + mock_reader.read.return_value = str.encode(payload) self.content = mock_reader self.status = status self.recorder = recorder @@ -233,7 +247,7 @@ def __init__(self, payload, status, recorder, credentials): super(MockAuthorizedSession, self).__init__(payload, status, recorder) self.credentials = credentials -# Custom async mock class since unuttest.mock.AsyncMock is only avaible in python 3.8+ +# Custom async mock class since unittest.mock.AsyncMock is only avaible in python 3.8+ class AsyncMock(MagicMock): async def __call__(self, *args, **kwargs): # pylint: disable=useless-super-delegation return super(AsyncMock, self).__call__(*args, **kwargs)