From 4c472f67d0688bfde93c6c0d6a229b05d6594f65 Mon Sep 17 00:00:00 2001 From: Chris Woodward Date: Wed, 29 Dec 2021 16:20:30 -0500 Subject: [PATCH] wip,adding ca file support --- arango/ca_certificate.py | 33 +++++++++++++++++++++++++++++++++ arango/client.py | 9 +++++++-- arango/connection.py | 23 ++++++++++++++++++++--- arango/http.py | 13 ++++++++++--- arango/version.py | 4 ++-- 5 files changed, 72 insertions(+), 10 deletions(-) create mode 100644 arango/ca_certificate.py diff --git a/arango/ca_certificate.py b/arango/ca_certificate.py new file mode 100644 index 00000000..ed040d9c --- /dev/null +++ b/arango/ca_certificate.py @@ -0,0 +1,33 @@ +import tempfile +import os +import base64 +import typing + +class CA_Certificate(object): + """A CA certificate. If encoded is True the certificate will be automatically base64 decoded""" + def __init__( + self, + certificate: typing.Union[str, bytes], + encoded: bool = True + ): + super(CA_Certificate, self).__init__() + self.certificate = certificate + if encoded: + self.certificate = base64.b64decode(self.certificate) + self.tmp_file = None + + def get_file_path(self): + """saves the cetificate into a tmp file and returns the file path""" + if self.tmp_file is not None: + return self.tmp_file + _ , self.tmp_file = tempfile.mkstemp(text=True) + f = open(self.tmp_file, "wb") + f.write(self.certificate) + f.close() + return self.tmp_file + + def clean(self): + """erases the tmp_file containing the certificate""" + if self.tmp_file is not None: + os.remove(self.tmp_file) + self.tmp_file = None \ No newline at end of file diff --git a/arango/client.py b/arango/client.py index 71b4eb51..f2c8d8fc 100644 --- a/arango/client.py +++ b/arango/client.py @@ -1,7 +1,8 @@ __all__ = ["ArangoClient"] from json import dumps, loads -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Text, Union +import typing from pkg_resources import get_distribution @@ -56,6 +57,7 @@ def __init__( http_client: Optional[HTTPClient] = None, serializer: Callable[..., str] = lambda x: dumps(x), deserializer: Callable[[str], Any] = lambda x: loads(x), + cert: Optional[Union[Text, typing.Tuple[Text, Text]]] = None ) -> None: if isinstance(hosts, str): self._hosts = [host.strip("/") for host in hosts.split(",")] @@ -75,7 +77,7 @@ def __init__( self._http = http_client or DefaultHTTPClient() self._serializer = serializer self._deserializer = deserializer - self._sessions = [self._http.create_session(h) for h in self._hosts] + self._sessions = [self._http.create_session(h, cert) for h in self._hosts] def __repr__(self) -> str: return f"" @@ -111,6 +113,7 @@ def db( verify: bool = False, auth_method: str = "basic", superuser_token: Optional[str] = None, + cert: Optional[Union[Text, typing.Tuple[Text, Text]]] = None ) -> StandardDatabase: """Connect to an ArangoDB database and return the database API wrapper. @@ -160,6 +163,7 @@ def db( http_client=self._http, serializer=self._serializer, deserializer=self._deserializer, + cert=cert ) elif auth_method.lower() == "jwt": connection = JwtConnection( @@ -172,6 +176,7 @@ def db( http_client=self._http, serializer=self._serializer, deserializer=self._deserializer, + cert=cert ) else: raise ValueError(f"invalid auth_method: {auth_method}") diff --git a/arango/connection.py b/arango/connection.py index d49edfe3..e037cb46 100644 --- a/arango/connection.py +++ b/arango/connection.py @@ -10,7 +10,8 @@ import sys import time from abc import abstractmethod -from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Set, Text, Tuple, Union +import typing import jwt from requests import ConnectionError, Session @@ -23,6 +24,8 @@ from arango.response import Response from arango.typings import Fields, Json +from .ca_certificate import CA_Certificate + Connection = Union["BasicConnection", "JwtConnection", "JwtSuperuserConnection"] @@ -38,6 +41,8 @@ def __init__( http_client: HTTPClient, serializer: Callable[..., str], deserializer: Callable[[str], Any], + cert: Optional[Union[Text, typing.Tuple[Text, Text]]] = None + ) -> None: self._url_prefixes = [f"{host}/_db/{db_name}" for host in hosts] self._host_resolver = host_resolver @@ -47,6 +52,8 @@ def __init__( self._serializer = serializer self._deserializer = deserializer self._username: Optional[str] = None + self.cert: Optional[Union[Text, typing.Tuple[Text, Text]]] = None + @property def db_name(self) -> str: @@ -112,7 +119,7 @@ def prep_response(self, resp: Response, deserialize: bool = True) -> Response: return resp def process_request( - self, host_index: int, request: Request, auth: Optional[Tuple[str, str]] = None + self, host_index: int, request: Request, auth: Optional[Tuple[str, str]] = None, cert: Optional[Union[Text, typing.Tuple[Text, Text]]] = None ) -> Response: """Execute a request until a valid response has been returned. @@ -124,6 +131,9 @@ def process_request( :rtype: arango.response.Response """ tries = 0 + if cert is not None: + cert=cert.get_file_path() + indexes_to_filter: Set[int] = set() while tries < self._host_resolver.max_tries: try: @@ -135,6 +145,7 @@ def process_request( data=self.normalize_data(request.data), headers=request.headers, auth=auth, + cert=cert ) return self.prep_response(resp, request.deserialize) @@ -248,6 +259,7 @@ def __init__( http_client: HTTPClient, serializer: Callable[..., str], deserializer: Callable[[str], Any], + cert: Optional[Union[Text, typing.Tuple[Text, Text]]] ) -> None: super().__init__( hosts, @@ -257,9 +269,11 @@ def __init__( http_client, serializer, deserializer, + cert ) self._username = username self._auth = (username, password) + self._cert = cert def send_request(self, request: Request) -> Response: """Send an HTTP request to ArangoDB server. @@ -270,7 +284,7 @@ def send_request(self, request: Request) -> Response: :rtype: arango.response.Response """ host_index = self._host_resolver.get_host_index() - return self.process_request(host_index, request, auth=self._auth) + return self.process_request(host_index, request, auth=self._auth, cert=self._cert) class JwtConnection(BaseConnection): @@ -303,6 +317,8 @@ def __init__( http_client: HTTPClient, serializer: Callable[..., str], deserializer: Callable[[str], Any], + cert: Optional[Union[Text, typing.Tuple[Text, Text]]] + ) -> None: super().__init__( hosts, @@ -312,6 +328,7 @@ def __init__( http_client, serializer, deserializer, + cert ) self._username = username self._password = password diff --git a/arango/http.py b/arango/http.py index 195eecf6..3d7da83e 100644 --- a/arango/http.py +++ b/arango/http.py @@ -1,7 +1,8 @@ __all__ = ["HTTPClient", "DefaultHTTPClient"] from abc import ABC, abstractmethod -from typing import MutableMapping, Optional, Tuple, Union +from typing import MutableMapping, Optional, Text, Tuple, Union +import typing from requests import Session from requests.adapters import HTTPAdapter @@ -16,7 +17,8 @@ class HTTPClient(ABC): # pragma: no cover """Abstract base class for HTTP clients.""" @abstractmethod - def create_session(self, host: str) -> Session: + def create_session(self, host: str, cert: Optional[Union[Text, typing.Tuple[Text, Text]]] +) -> Session: """Return a new requests session given the host URL. This method must be overridden by the user. @@ -24,6 +26,8 @@ def create_session(self, host: str) -> Session: :param host: ArangoDB host URL. :type host: str :returns: Requests session object. + :param cert: Cert file location + :type Optional[Union[Text, tuple[Text, Text]]] :rtype: requests.Session """ raise NotImplementedError @@ -38,6 +42,7 @@ def send_request( params: Optional[MutableMapping[str, str]] = None, data: Union[str, MultipartEncoder, None] = None, auth: Optional[Tuple[str, str]] = None, + cert: Optional[Union[Text, typing.Tuple[Text, Text]]] = None ) -> Response: """Send an HTTP request. @@ -70,7 +75,7 @@ class DefaultHTTPClient(HTTPClient): RETRY_ATTEMPTS = 3 BACKOFF_FACTOR = 1 - def create_session(self, host: str) -> Session: + def create_session(self, host: str, cert: Optional[Union[Text, typing.Tuple[Text, Text]]]) -> Session: """Create and return a new session/connection. :param host: ArangoDB host URL. @@ -101,6 +106,7 @@ def send_request( params: Optional[MutableMapping[str, str]] = None, data: Union[str, MultipartEncoder, None] = None, auth: Optional[Tuple[str, str]] = None, + cert: Optional[Union[Text, typing.Tuple[Text, Text]]] = None ) -> Response: """Send an HTTP request. @@ -129,6 +135,7 @@ def send_request( headers=headers, auth=auth, timeout=self.REQUEST_TIMEOUT, + cert=cert ) return Response( method=method, diff --git a/arango/version.py b/arango/version.py index 0895d736..44ca331e 100644 --- a/arango/version.py +++ b/arango/version.py @@ -1,5 +1,5 @@ # coding: utf-8 # file generated by setuptools_scm # don't change, don't track in version control -version = '6.1.1.dev0+g0e82788.d20210214' -version_tuple = (6, 1, 1, 'dev0+g0e82788', 'd20210214') +version = '7.3.1.dev1+g403d74a.d20211229' +version_tuple = (7, 3, 1, 'dev1', 'g403d74a.d20211229')