Skip to content

Commit

Permalink
Refactor retry / version negotiation handling
Browse files Browse the repository at this point in the history
Move the code for handling retry and version negotiation to their own
methods to simplify `receive_datagram`.
  • Loading branch information
jlaine committed Jun 29, 2024
1 parent dd029b4 commit 70dd040
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 119 deletions.
253 changes: 143 additions & 110 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
STATELESS_RESET_TOKEN_SIZE,
QuicErrorCode,
QuicFrameType,
QuicHeader,
QuicPacketType,
QuicStreamFrame,
QuicTransportParameters,
Expand Down Expand Up @@ -791,7 +792,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
)
return

# check destination CID matches
# Check destination CID matches.
destination_cid_seq: Optional[int] = None
for connection_id in self._host_cids:
if header.destination_cid == connection_id.cid:
Expand All @@ -811,71 +812,16 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
)
return

# check protocol version
if (
self._is_client
and self._state == QuicConnectionState.FIRSTFLIGHT
and header.packet_type == QuicPacketType.VERSION_NEGOTIATION
and not self._version_negotiation_count
):
# version negotiation
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_received",
data={
"frames": [],
"header": {
"packet_type": self._quic_logger.packet_type(
header.packet_type
),
"scid": dump_cid(header.source_cid),
"dcid": dump_cid(header.destination_cid),
},
"raw": {"length": header.packet_length},
},
)
if self._version in header.supported_versions:
self._logger.warning(
"Version negotiation packet contains %s" % self._version
)
return
common = [
x
for x in self._configuration.supported_versions
if x in header.supported_versions
]
chosen_version = common[0] if common else None
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="version_information",
data={
"server_versions": header.supported_versions,
"client_versions": self._configuration.supported_versions,
"chosen_version": chosen_version,
},
)
if chosen_version is None:
self._logger.error("Could not find a common protocol version")
self._close_event = events.ConnectionTerminated(
error_code=QuicErrorCode.INTERNAL_ERROR,
frame_type=QuicFrameType.PADDING,
reason_phrase="Could not find a common protocol version",
)
self._close_end()
return
self._packet_number = 0
self._version = chosen_version
self._version_negotiation_count += 1
self._logger.info("Retrying with %s", self._version)
self._connect(now=now)
# Handle version negotiation packet.
if header.packet_type == QuicPacketType.VERSION_NEGOTIATION:
self._receive_version_negotiation_packet(header=header, now=now)
return
elif (

# Check long header packet protocol version.
if (
header.version is not None
and header.version not in self._configuration.supported_versions
):
# unsupported version
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
Expand All @@ -887,55 +833,15 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
)
return

# handle retry packet
# Handle retry packet.
if header.packet_type == QuicPacketType.RETRY:
if (
self._is_client
and not self._retry_count
and header.destination_cid == self.host_cid
and header.integrity_tag
== get_retry_integrity_tag(
buf.data_slice(
start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE
),
self._peer_cid.cid,
version=header.version,
)
):
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_received",
data={
"frames": [],
"header": {
"packet_type": "retry",
"scid": dump_cid(header.source_cid),
"dcid": dump_cid(header.destination_cid),
},
"raw": {"length": header.packet_length},
},
)

self._peer_cid.cid = header.source_cid
self._peer_token = header.token
self._retry_count += 1
self._retry_source_connection_id = header.source_cid
self._logger.info(
"Retrying with token (%d bytes)" % len(header.token)
)
self._connect(now=now)
else:
# unexpected or invalid retry packet
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={
"trigger": "unexpected_packet",
"raw": {"length": header.packet_length},
},
)
self._receive_retry_packet(
header=header,
packet_without_tag=buf.data_slice(
start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE
),
now=now,
)
return

crypto_frame_required = False
Expand Down Expand Up @@ -2500,6 +2406,133 @@ def _payload_received(

return is_ack_eliciting, bool(is_probing)

def _receive_retry_packet(
self, header: QuicHeader, packet_without_tag: bytes, now: float
) -> None:
"""
Handle a retry packet.
"""
if (
self._is_client
and not self._retry_count
and header.destination_cid == self.host_cid
and header.integrity_tag
== get_retry_integrity_tag(
packet_without_tag,
self._peer_cid.cid,
version=header.version,
)
):
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_received",
data={
"frames": [],
"header": {
"packet_type": "retry",
"scid": dump_cid(header.source_cid),
"dcid": dump_cid(header.destination_cid),
},
"raw": {"length": header.packet_length},
},
)

self._peer_cid.cid = header.source_cid
self._peer_token = header.token
self._retry_count += 1
self._retry_source_connection_id = header.source_cid
self._logger.info("Retrying with token (%d bytes)" % len(header.token))
self._connect(now=now)
else:
# Unexpected or invalid retry packet.
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={
"trigger": "unexpected_packet",
"raw": {"length": header.packet_length},
},
)

def _receive_version_negotiation_packet(
self, header: QuicHeader, now: float
) -> None:
"""
Handle a version negotiation packet.
This is used in "Incompatible Version Negotiation", see:
https://datatracker.ietf.org/doc/html/rfc9368#section-2.2
"""
if (
self._is_client
and self._state == QuicConnectionState.FIRSTFLIGHT
and not self._version_negotiation_count
):
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_received",
data={
"frames": [],
"header": {
"packet_type": self._quic_logger.packet_type(
header.packet_type
),
"scid": dump_cid(header.source_cid),
"dcid": dump_cid(header.destination_cid),
},
"raw": {"length": header.packet_length},
},
)
if self._version in header.supported_versions:
self._logger.warning(
"Version negotiation packet contains %s" % self._version
)
return
common = [
x
for x in self._configuration.supported_versions
if x in header.supported_versions
]
chosen_version = common[0] if common else None
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="version_information",
data={
"server_versions": header.supported_versions,
"client_versions": self._configuration.supported_versions,
"chosen_version": chosen_version,
},
)
if chosen_version is None:
self._logger.error("Could not find a common protocol version")
self._close_event = events.ConnectionTerminated(
error_code=QuicErrorCode.INTERNAL_ERROR,
frame_type=QuicFrameType.PADDING,
reason_phrase="Could not find a common protocol version",
)
self._close_end()
return
self._packet_number = 0
self._version = chosen_version
self._version_negotiation_count += 1
self._logger.info("Retrying with %s", self._version)
self._connect(now=now)
else:
# Unexpected version negotiation packet.
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={
"trigger": "unexpected_packet",
"raw": {"length": header.packet_length},
},
)

def _replenish_connection_ids(self) -> None:
"""
Generate new connection IDs.
Expand Down
47 changes: 38 additions & 9 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ def transfer(sender, receiver):


class QuicConnectionTest(TestCase):
def assertPacketDropped(self, connection: QuicConnection, trigger: str):
log = connection.configuration.quic_logger.to_dict()
found_trigger = None
for event in log["traces"][0]["events"]:
if event["name"] == "transport:packet_dropped":
found_trigger = event["data"]["trigger"]
break
self.assertEqual(found_trigger, trigger)

def check_handshake(self, client, server, alpn_protocol=None):
"""
Check handshake completed.
Expand Down Expand Up @@ -472,7 +481,9 @@ def test_connect_with_loss_2(self):
and decides to retransmit its own CRYPTO to speedup handshake completion.
"""

client_configuration = QuicConfiguration(is_client=True)
client_configuration = QuicConfiguration(
is_client=True, quic_logger=QuicLogger()
)
client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE)

client = QuicConnection(configuration=client_configuration)
Expand Down Expand Up @@ -513,6 +524,8 @@ def test_connect_with_loss_2(self):
self.assertAlmostEqual(client.get_timer(), 0.3)
self.assertIsNone(client.next_event())

self.assertPacketDropped(client, "key_unavailable")

# server receives duplicate INITIAL, retransmits INITIAL + HANDSHAKE
now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
Expand Down Expand Up @@ -856,14 +869,9 @@ def test_initial_that_is_too_small(self):

for datagram in builder.flush()[0]:
server.receive_datagram(datagram, SERVER_ADDR, now=time.time())
# look for the drop event
log = server_configuration.quic_logger.to_dict()
trigger = None
for event in log["traces"][0]["events"]:
trigger = event["data"].get("trigger")
if trigger is not None:
break
self.assertEqual(trigger, "initial_packet_datagram_too_small")

# Look for the drop event.
self.assertPacketDropped(server, "initial_packet_datagram_too_small")

def test_connect_with_no_crypto_frame(self):
def patch(client):
Expand Down Expand Up @@ -1240,6 +1248,8 @@ def test_receive_datagram_wrong_version(self):
client.receive_datagram(datagram, SERVER_ADDR, now=time.time())
self.assertEqual(drop(client), 0)

self.assertPacketDropped(client, "unsupported_version")

def test_receive_datagram_retry(self):
client = create_standalone_client(self)

Expand All @@ -1259,6 +1269,7 @@ def test_receive_datagram_retry(self):
def test_receive_datagram_retry_wrong_destination_cid(self):
client = create_standalone_client(self)

# The client does not reply to a retry packet with a wrong destination CID.
client.receive_datagram(
encode_quic_retry(
version=client._version,
Expand All @@ -1272,6 +1283,8 @@ def test_receive_datagram_retry_wrong_destination_cid(self):
)
self.assertEqual(drop(client), 0)

self.assertPacketDropped(client, "unknown_connection_id")

def test_receive_datagram_retry_wrong_integrity_tag(self):
client = create_standalone_client(self)

Expand Down Expand Up @@ -2982,6 +2995,22 @@ def test_version_negotiation_ignore(self):
)
self.assertEqual(drop(client), 0)

def test_version_negotiation_ignore_server(self):
with client_and_server() as (client, server):
# The server does not reply to the version negotiation packet.
server.receive_datagram(
encode_quic_version_negotiation(
source_cid=server._peer_cid.cid,
destination_cid=server.host_cid,
supported_versions=[QuicProtocolVersion.VERSION_1],
),
CLIENT_ADDR,
now=time.time(),
)
self.assertEqual(drop(client), 0)

self.assertPacketDropped(server, "unexpected_packet")

def test_version_negotiation_ok(self):
client = create_standalone_client(
self,
Expand Down

0 comments on commit 70dd040

Please sign in to comment.