Skip to content

Commit

Permalink
[asyncio] pass QuicConfiguration to connect()
Browse files Browse the repository at this point in the history
This avoids the need to add keyword arguments whenever QuicConfiguration
changes.
  • Loading branch information
jlaine committed Aug 12, 2019
1 parent 6df7695 commit 57a9e52
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 84 deletions.
30 changes: 6 additions & 24 deletions aioquic/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@ async def connect(
host: str,
port: int,
*,
alpn_protocols: Optional[List[str]] = None,
idle_timeout: Optional[float] = None,
quic_logger: Optional[QuicLogger] = None,
secrets_log_file: Optional[TextIO] = None,
session_ticket: Optional[SessionTicket] = None,
configuration: Optional[QuicConfiguration] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stream_handler: Optional[QuicStreamHandler] = None,
supported_versions: Optional[List[int]] = None,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
"""
Connect to a QUIC server at the given `host` and `port`.
Expand All @@ -36,12 +31,7 @@ async def connect(
:func:`connect` also accepts the following optional arguments:
* ``alpn_protocols`` is a list of ALPN protocols to offer in the
ClientHello.
* ``secrets_log_file`` is a file-like object in which to log traffic
secrets. This is useful to analyze traffic captures with Wireshark.
* ``session_ticket`` is a TLS session ticket which should be used for
resumption.
* ``configuration`` is a QUIC configuration object.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is received.
* ``stream_handler`` is a callback which is invoked whenever a stream is
Expand All @@ -63,18 +53,10 @@ async def connect(
if len(addr) == 2:
addr = ("::ffff:" + addr[0], addr[1], 0, 0)

configuration = QuicConfiguration(
alpn_protocols=alpn_protocols,
is_client=True,
quic_logger=quic_logger,
secrets_log_file=secrets_log_file,
server_name=server_name,
session_ticket=session_ticket,
)
if idle_timeout is not None:
configuration.idle_timeout = idle_timeout
if supported_versions is not None:
configuration.supported_versions = supported_versions
if configuration is None:
configuration = QuicConfiguration(is_client=True)
if server_name is not None:
configuration.server_name = server_name

connection = QuicConnection(
configuration=configuration, session_ticket_handler=session_ticket_handler
Expand Down
93 changes: 48 additions & 45 deletions examples/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from aioquic.asyncio import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h3.connection import H3Connection
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.logger import QuicLogger
from aioquic.quic.packet import QuicProtocolVersion

Expand Down Expand Up @@ -112,21 +113,18 @@ async def http3_request(connection: QuicConnectionProtocol, authority: str, path
return await reader.read()


async def test_version_negotiation(server: Server, **kwargs):
quic_logger = QuicLogger()
async def test_version_negotiation(server: Server, configuration: QuicConfiguration):
configuration.supported_versions = [0x1A2A3A4A, QuicProtocolVersion.DRAFT_22]

async with connect(
server.host,
server.port,
quic_logger=quic_logger,
supported_versions=[0x1A2A3A4A, QuicProtocolVersion.DRAFT_22],
**kwargs
server.host, server.port, configuration=configuration
) as connection:
await connection.ping()

# check log
for stamp, category, event, data in quic_logger.to_dict()["traces"][0][
"events"
]:
for stamp, category, event, data in configuration.quic_logger.to_dict()[
"traces"
][0]["events"]:
if (
category == "TRANSPORT"
and event == "PACKET_RECEIVED"
Expand All @@ -135,24 +133,25 @@ async def test_version_negotiation(server: Server, **kwargs):
server.result |= Result.V


async def test_handshake_and_close(server: Server, **kwargs):
async with connect(server.host, server.port, **kwargs) as connection:
async def test_handshake_and_close(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, configuration=configuration
) as connection:
await connection.ping()
server.result |= Result.H
server.result |= Result.C


async def test_stateless_retry(server: Server, **kwargs):
quic_logger = QuicLogger()
async def test_stateless_retry(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.retry_port, quic_logger=quic_logger, **kwargs
server.host, server.retry_port, configuration=configuration
) as connection:
await connection.ping()

# check log
for stamp, category, event, data in quic_logger.to_dict()["traces"][0][
"events"
]:
for stamp, category, event, data in configuration.quic_logger.to_dict()[
"traces"
][0]["events"]:
if (
category == "TRANSPORT"
and event == "PACKET_RECEIVED"
Expand All @@ -161,30 +160,34 @@ async def test_stateless_retry(server: Server, **kwargs):
server.result |= Result.S


async def test_http_0(server: Server, **kwargs):
async def test_http_0(server: Server, configuration: QuicConfiguration):
if server.path is None:
return

kwargs["alpn_protocols"] = ["hq-22"]
async with connect(server.host, server.port, **kwargs) as connection:
configuration.alpn_protocols = ["hq-22"]
async with connect(
server.host, server.port, configuration=configuration
) as connection:
response = await http_request(connection, server.path)
if response:
server.result |= Result.D


async def test_http_3(server: Server, **kwargs):
async def test_http_3(server: Server, configuration: QuicConfiguration):
if server.path is None:
return

kwargs["alpn_protocols"] = ["h3-22"]
async with connect(server.host, server.port, **kwargs) as connection:
configuration.alpn_protocols = ["h3-22"]
async with connect(
server.host, server.port, configuration=configuration
) as connection:
response = await http3_request(connection, server.host, server.path)
if response:
server.result |= Result.D
server.result |= Result.three


async def test_session_resumption(server: Server, **kwargs):
async def test_session_resumption(server: Server, configuration: QuicConfiguration):
saved_ticket = None

def session_ticket_handler(ticket):
Expand All @@ -195,15 +198,16 @@ def session_ticket_handler(ticket):
async with connect(
server.host,
server.port,
configuration=configuration,
session_ticket_handler=session_ticket_handler,
**kwargs
) as connection:
await connection.ping()

# connect a second time, with the ticket
if saved_ticket is not None:
configuration.session_ticket = saved_ticket
async with connect(
server.host, server.port, session_ticket=saved_ticket, **kwargs
server.host, server.port, configuration=configuration
) as connection:
await connection.ping()

Expand All @@ -216,8 +220,10 @@ def session_ticket_handler(ticket):
server.result |= Result.Z


async def test_key_update(server: Server, **kwargs):
async with connect(server.host, server.port, **kwargs) as connection:
async def test_key_update(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, configuration=configuration
) as connection:
# cause some traffic
await connection.ping()

Expand All @@ -230,19 +236,18 @@ async def test_key_update(server: Server, **kwargs):
server.result |= Result.U


async def test_spin_bit(server: Server, **kwargs):
quic_logger = QuicLogger()
async def test_spin_bit(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, quic_logger=quic_logger, **kwargs
server.host, server.port, configuration=configuration
) as connection:
for i in range(5):
await connection.ping()

# check log
spin_bits = set()
for stamp, category, event, data in quic_logger.to_dict()["traces"][0][
"events"
]:
for stamp, category, event, data in configuration.quic_logger.to_dict()[
"traces"
][0]["events"]:
if category == "CONNECTIVITY" and event == "SPIN_BIT_UPDATE":
spin_bits.add(data["state"])
if len(spin_bits) == 2:
Expand All @@ -255,12 +260,17 @@ def print_result(server: Server) -> None:
print("%s%s%s" % (server.name, " " * (20 - len(server.name)), result))


async def run(servers, tests, **kwargs) -> None:
async def run(servers, tests) -> None:
for server in servers:
for test_name, test_func in tests:
print("\n=== %s %s ===\n" % (server.name, test_name))
configuration = QuicConfiguration(
alpn_protocols=["hq-22", "h3-22"],
is_client=True,
quic_logger=QuicLogger(),
)
try:
await asyncio.wait_for(test_func(server, **kwargs), timeout=5)
await asyncio.wait_for(test_func(server, configuration), timeout=5)
except Exception as exc:
print(exc)
print("")
Expand Down Expand Up @@ -305,11 +315,4 @@ async def run(servers, tests, **kwargs) -> None:
tests = list(filter(lambda x: x[0] == args.test, tests))

loop = asyncio.get_event_loop()
loop.run_until_complete(
run(
alpn_protocols=["hq-22", "h3-22"],
servers=servers,
tests=tests,
secrets_log_file=secrets_log_file,
)
)
loop.run_until_complete(run(servers=servers, tests=tests))
51 changes: 36 additions & 15 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aioquic.asyncio.client import connect
from aioquic.asyncio.server import serve
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.logger import QuicLogger
from aioquic.quic.packet import QuicProtocolVersion

Expand Down Expand Up @@ -83,8 +84,8 @@ def test_connect_and_serve_large(self):
server.close()

def test_connect_and_serve_writelines(self):
async def run_client_writelines(host, port=4433, **kwargs):
async with connect(host, port, **kwargs) as client:
async def run_client_writelines(host, port=4433):
async with connect(host, port) as client:
reader, writer = await client.create_stream()
assert writer.can_write_eof() is True

Expand Down Expand Up @@ -113,8 +114,9 @@ def test_connect_and_serve_with_packet_loss(self, mock_sendto):
),
run_client(
"127.0.0.1",
idle_timeout=300.0,
quic_logger=QuicLogger(),
configuration=QuicConfiguration(
is_client=True, idle_timeout=300.0, quic_logger=QuicLogger()
),
request=data,
),
)
Expand Down Expand Up @@ -146,7 +148,12 @@ def save_ticket(t):
server, response = run(
asyncio.gather(
run_server(session_ticket_fetcher=store.pop),
run_client("127.0.0.1", session_ticket=client_ticket),
run_client(
"127.0.0.1",
configuration=QuicConfiguration(
is_client=True, session_ticket=client_ticket
),
),
)
)
self.assertEqual(response, b"gnip")
Expand All @@ -170,7 +177,12 @@ def test_connect_and_serve_with_stateless_retry_bad(self, mock_validate):

server = run(run_server(stateless_retry=True))
with self.assertRaises(ConnectionError):
run(run_client("127.0.0.1", idle_timeout=4.0))
run(
run_client(
"127.0.0.1",
configuration=QuicConfiguration(is_client=True, idle_timeout=4.0),
)
)
server.close()

def test_connect_and_serve_with_version_negotiation(self):
Expand All @@ -179,8 +191,11 @@ def test_connect_and_serve_with_version_negotiation(self):
run_server(),
run_client(
"127.0.0.1",
quic_logger=QuicLogger(),
supported_versions=[0x1A2A3A4A, QuicProtocolVersion.DRAFT_22],
configuration=QuicConfiguration(
is_client=True,
quic_logger=QuicLogger(),
supported_versions=[0x1A2A3A4A, QuicProtocolVersion.DRAFT_22],
),
),
)
)
Expand All @@ -189,11 +204,17 @@ def test_connect_and_serve_with_version_negotiation(self):

def test_connect_timeout(self):
with self.assertRaises(ConnectionError):
run(run_client("127.0.0.1", port=4400, idle_timeout=5))
run(
run_client(
"127.0.0.1",
port=4400,
configuration=QuicConfiguration(is_client=True, idle_timeout=5),
)
)

def test_change_connection_id(self):
async def run_client_key_update(host, **kwargs):
async with connect(host, 4433, **kwargs) as client:
async def run_client_key_update(host, port=4433):
async with connect(host, port) as client:
await client.ping()
client.change_connection_id()
await client.ping()
Expand All @@ -206,8 +227,8 @@ async def run_client_key_update(host, **kwargs):
server.close()

def test_key_update(self):
async def run_client_key_update(host, **kwargs):
async with connect(host, 4433, **kwargs) as client:
async def run_client_key_update(host, port=4433):
async with connect(host, port) as client:
await client.ping()
client.request_key_update()
await client.ping()
Expand All @@ -220,8 +241,8 @@ async def run_client_key_update(host, **kwargs):
server.close()

def test_ping(self):
async def run_client_ping(host, **kwargs):
async with connect(host, 4433, **kwargs) as client:
async def run_client_ping(host, port=4433):
async with connect(host, port) as client:
await client.ping()
await client.ping()

Expand Down

0 comments on commit 57a9e52

Please sign in to comment.