Skip to content

Commit

Permalink
try to catch up this PR with the new, stricter requirements treq has …
Browse files Browse the repository at this point in the history
…inherited in the meanwhile
  • Loading branch information
glyph committed Jun 18, 2024
1 parent 1c35956 commit 9198250
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions src/treq/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

import binascii
from enum import Enum
from typing import Union, Optional
from typing import Union, Optional, TypedDict
from urllib.parse import urlparse

from twisted.python.randbytes import secureRandom
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from twisted.internet.defer import Deferred

from zope.interface import implementer
from requests.utils import parse_dict_header

Expand Down Expand Up @@ -52,6 +54,22 @@ def _sha512_utf_digest(x: str) -> str:
return hashlib.sha512(x.encode("utf-8")).hexdigest()


class _DigestAuthCacheParams(TypedDict):
path: bytes
method: bytes
cached: bool
nonce: str
realm: str
qop: str | None
algorithm: _DIGEST_ALGO
opaque: str | None


class _DigestAuthCacheEntry(TypedDict):
c: int
p: _DigestAuthCacheParams


class HTTPDigestAuth(object):
"""
The container for HTTP Digest authentication credentials.
Expand All @@ -61,17 +79,15 @@ class HTTPDigestAuth(object):
"""

def __init__(self, username: Union[str, bytes], password: Union[str, bytes]):
if isinstance(username, bytes):
self._username: str = username.decode("utf-8")
else:
self._username: str = username
if isinstance(password, bytes):
self._password: str = password.decode("utf-8")
else:
self._password: str = password
self._username: str = (
username.decode("utf-8") if isinstance(username, bytes) else username
)
self._password: str = (
password.decode("utf-8") if isinstance(password, bytes) else password
)

# (method,uri) --> digest auth cache
self._digest_auth_cache = {}
self._digest_auth_cache: dict[tuple[bytes, bytes], _DigestAuthCacheEntry] = {}

def _build_authentication_header(
self,
Expand Down Expand Up @@ -129,7 +145,7 @@ def _build_authentication_header(
digest_hash_func = _sha512_utf_digest
else:
raise ValueError(
f"Unsupported Digest Auth algorithm identifier " f"passed: {algo.name}"
f"Unsupported Digest Auth algorithm identifier passed: {algo}"
)

KD = lambda s, d: digest_hash_func(f"{s}:{d}") # noqa:E731
Expand Down Expand Up @@ -169,7 +185,7 @@ def _build_authentication_header(
base += f', qop="auth", nc={ncvalue}, cnonce="{cnonce}"'

if not cached:
cache_params = {
cache_params: _DigestAuthCacheParams = {
"path": url,
"method": method,
"cached": cached,
Expand All @@ -183,7 +199,9 @@ def _build_authentication_header(

return f"Digest {base}"

def _cached_metadata_for(self, method: bytes, uri: bytes) -> Optional[dict]:
def _cached_metadata_for(
self, method: bytes, uri: bytes
) -> Optional[_DigestAuthCacheEntry]:
return self._digest_auth_cache.get((method, uri))


Expand Down Expand Up @@ -245,7 +263,7 @@ def _on_401_response(
uri: bytes,
headers: Optional[Headers],
bodyProducer: Optional[IBodyProducer],
):
) -> Deferred[IResponse]:
"""
Handle the server`s 401 response, that is capable with authentication
headers, build the Authorization header
Expand Down Expand Up @@ -299,7 +317,7 @@ def _perform_request(
uri: bytes,
headers: Optional[Headers],
bodyProducer: Optional[IBodyProducer],
):
) -> Deferred[IResponse]:
"""
Add Authorization header and perform the request with
actual credentials
Expand All @@ -316,7 +334,7 @@ def _perform_request(
"""
if not headers:
headers = Headers(
{b"Authorization": digest_authentication_header.encode("utf-8")}
{b"Authorization": [digest_authentication_header.encode("utf-8")]}
)
else:
headers.addRawHeader(
Expand All @@ -332,7 +350,7 @@ def request(
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
):
) -> Deferred[IResponse]:
"""
Wrap the agent with HTTP Digest authentication.
Expand Down Expand Up @@ -409,7 +427,7 @@ def add_digest_auth(agent: IAgent, http_digest_auth: HTTPDigestAuth) -> IAgent:
return _RequestDigestAuthenticationAgent(agent, http_digest_auth)


def add_auth(agent: IAgent, auth_config: Union[tuple, HTTPDigestAuth]):
def add_auth(agent: IAgent, auth_config: Union[tuple, HTTPDigestAuth]) -> IAgent:
"""
Wrap an agent to perform authentication
Expand Down

0 comments on commit 9198250

Please sign in to comment.