Skip to content

Commit

Permalink
During address validation, count the entire received datagram
Browse files Browse the repository at this point in the history
For anti-amplification purposes, servers need to keep track of the
amount of data received on unvalidated network paths. We must count the
entire datagram size regardless of whether packets are processed or
dropped.

This is particularly important when talking to clients who pad
datagrams containing INITIAL packets by appending bytes after the
long-header packets, which is legitimate behaviour.
  • Loading branch information
jlaine committed Jul 1, 2024
1 parent 79a8caf commit afe5525
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,13 +753,14 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
:param addr: The network address from which the datagram was received.
:param now: The current time.
"""
payload_length = len(data)

# stop handling packets when closing
if self._state in END_STATES:
return

# log datagram
if self._quic_logger is not None:
payload_length = len(data)
self._quic_logger.log_event(
category="transport",
event="datagrams_received",
Expand All @@ -774,6 +775,20 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
},
)

# For anti-amplification purposes, servers need to keep track of the
# amount of data received on unvalidated network paths. We must count the
# entire datagram size regardless of whether packets are processed or
# dropped.
#
# This is particularly important when talking to clients who pad
# datagrams containing INITIAL packets by appending bytes after the
# long-header packets, which is legitimate behaviour.
#
# https://datatracker.ietf.org/doc/html/rfc9000#section-8.1
network_path = self._find_network_path(addr)
if not network_path.is_validated:
network_path.bytes_received += payload_length

# for servers, arm the idle timeout on the first datagram
if self._close_at is None:
self._close_at = now + self._idle_timeout()
Expand Down Expand Up @@ -802,7 +817,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
if (
not self._is_client
and header.packet_type == QuicPacketType.INITIAL
and len(data) < SMALLEST_MAX_DATAGRAM_SIZE
and payload_length < SMALLEST_MAX_DATAGRAM_SIZE
):
if self._quic_logger is not None:
self._quic_logger.log_event(
Expand Down Expand Up @@ -868,9 +883,8 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
return

crypto_frame_required = False
network_path = self._find_network_path(addr)

# server initialization
# Server initialization.
if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT:
assert (
header.packet_type == QuicPacketType.INITIAL
Expand All @@ -880,7 +894,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
self._version = header.version
self._initialize(header.destination_cid)

# determine crypto and packet space
# Determine crypto and packet space.
epoch = get_epoch(header.packet_type)
if epoch == tls.Epoch.INITIAL:
crypto = self._cryptos_initial[header.version]
Expand Down Expand Up @@ -1051,7 +1065,6 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
"Network path %s validated by handshake", network_path.addr
)
network_path.is_validated = True
network_path.bytes_received += end_off - start_off
if network_path not in self._network_paths:
self._network_paths.append(network_path)
idx = self._network_paths.index(network_path)
Expand Down

0 comments on commit afe5525

Please sign in to comment.