Skip to content

Commit

Permalink
[http3-server] use QuicServer
Browse files Browse the repository at this point in the history
This removes a whole raft of duplicated code
  • Loading branch information
jlaine committed Aug 12, 2019
1 parent fa59c31 commit 478bbfc
Showing 1 changed file with 14 additions and 127 deletions.
141 changes: 14 additions & 127 deletions examples/http3-server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,24 @@
import importlib
import json
import logging
import os
import time
from email.utils import formatdate
from typing import Callable, Dict, Optional, Text, Union, cast
from typing import Callable, Dict, Optional, Union

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

import aioquic.quic.events
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.buffer import Buffer
from aioquic.asyncio.server import QuicServer
from aioquic.h0.connection import H0Connection
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HttpEvent, RequestReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import NetworkAddress, QuicConnection
from aioquic.quic.events import QuicEvent
from aioquic.quic.logger import QuicLogger
from aioquic.quic.packet import (
PACKET_TYPE_INITIAL,
encode_quic_retry,
encode_quic_version_negotiation,
pull_quic_header,
)
from aioquic.quic.retry import QuicRetryTokenHandler
from aioquic.tls import SessionTicket, SessionTicketFetcher, SessionTicketHandler
from aioquic.tls import SessionTicket

try:
import uvloop
Expand Down Expand Up @@ -82,128 +73,24 @@ async def send(self, message: Dict):
self.send_pending()


class HttpServer(asyncio.DatagramProtocol):
def __init__(
self,
*,
application: AsgiApplication,
configuration: QuicConfiguration,
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stateless_retry: bool = False,
) -> None:
self._application = application
self._configuration = configuration
self._loop = asyncio.get_event_loop()
self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._transport: Optional[asyncio.DatagramTransport] = None

if stateless_retry:
self._retry = QuicRetryTokenHandler()
else:
self._retry = None

def close(self):
for protocol in set(self._protocols.values()):
protocol.close()
self._protocols.clear()
self._transport.close()

def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)

def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
data = cast(bytes, data)
buf = Buffer(data=data)
header = pull_quic_header(buf, host_cid_length=8)

# version negotiation
if (
header.version is not None
and header.version not in self._configuration.supported_versions
):
self._transport.sendto(
encode_quic_version_negotiation(
source_cid=header.destination_cid,
destination_cid=header.source_cid,
supported_versions=self._configuration.supported_versions,
),
addr,
)
return

protocol = self._protocols.get(header.destination_cid, None)
original_connection_id: Optional[bytes] = None
if protocol is None and header.packet_type == PACKET_TYPE_INITIAL:
# stateless retry
if self._retry is not None:
if not header.token:
# create a retry token
self._transport.sendto(
encode_quic_retry(
version=header.version,
source_cid=os.urandom(8),
destination_cid=header.source_cid,
original_destination_cid=header.destination_cid,
retry_token=self._retry.create_token(
addr, header.destination_cid
),
),
addr,
)
return
else:
# validate retry token
try:
original_connection_id = self._retry.validate_token(
addr, header.token
)
except ValueError:
return

# create new connection
connection = QuicConnection(
configuration=self._configuration,
original_connection_id=original_connection_id,
session_ticket_fetcher=self._session_ticket_fetcher,
session_ticket_handler=self._session_ticket_handler,
)
protocol = HttpServerProtocol(connection, self)
protocol.connection_made(self._transport)

self._protocols[header.destination_cid] = protocol
self._protocols[connection.host_cid] = protocol

if protocol is not None:
protocol.datagram_received(data, addr)


class HttpServerProtocol(QuicConnectionProtocol):
def __init__(self, quic: QuicConnection, server: HttpServer):
super().__init__(quic)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._handlers: Dict[int, HttpRequestHandler] = {}
self._http: Optional[HttpConnection] = None
self._server = server

def quic_event_received(self, event: QuicEvent):
if isinstance(event, aioquic.quic.events.ConnectionTerminated):
# remove the connection
for cid, protocol in list(self._server._protocols.items()):
if protocol == self:
del self._server._protocols[cid]
return
if isinstance(event, aioquic.quic.events.ConnectionIdIssued):
self._connection_id_issued_handler(event.connection_id)
elif isinstance(event, aioquic.quic.events.ConnectionIdRetired):
self._connection_id_retired_handler(event.connection_id)
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._connection_terminated_handler()
elif isinstance(event, aioquic.quic.events.ProtocolNegotiated):
if event.alpn_protocol == "h3-22":
self._http = H3Connection(self._quic)
elif event.alpn_protocol == "hq-22":
self._http = H0Connection(self._quic)
elif isinstance(event, aioquic.quic.events.ConnectionIdIssued):
self._server._protocols[event.connection_id] = self
elif isinstance(event, aioquic.quic.events.ConnectionIdRetired):
assert self._server._protocols[event.connection_id] == self
del self._server._protocols[event.connection_id]

#  pass event to the HTTP layer
if self._http is not None:
Expand Down Expand Up @@ -249,7 +136,7 @@ def http_event_received(self, event: HttpEvent) -> None:
stream_id=event.stream_id,
)
self._handlers[event.stream_id] = handler
asyncio.ensure_future(handler.run_asgi(self._server._application))
asyncio.ensure_future(handler.run_asgi(application))
elif isinstance(event, DataReceived):
handler = self._handlers[event.stream_id]
handler.queue.put_nowait(
Expand Down Expand Up @@ -378,9 +265,9 @@ def pop(self, label: bytes) -> Optional[SessionTicket]:
loop = asyncio.get_event_loop()
loop.run_until_complete(
loop.create_datagram_endpoint(
lambda: HttpServer(
application=application,
lambda: QuicServer(
configuration=configuration,
create_protocol=HttpServerProtocol,
session_ticket_fetcher=ticket_store.pop,
session_ticket_handler=ticket_store.add,
stateless_retry=args.stateless_retry,
Expand Down

0 comments on commit 478bbfc

Please sign in to comment.