Skip to content

Commit

Permalink
Limit the number of pending connection IDs marked for retirement.
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley authored and jlaine committed Mar 12, 2024
1 parent c32862a commit 4f73f18
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 48 deletions.
50 changes: 37 additions & 13 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
STREAM_FLAGS = 0x07
STREAM_COUNT_MAX = 0x1000000000000000
UDP_HEADER_SIZE = 8
MAX_PENDING_RETIRES = 100

NetworkAddress = Any

Expand Down Expand Up @@ -338,6 +339,7 @@ def __init__(
)
self._peer_cid_available: List[QuicConnectionId] = []
self._peer_cid_sequence_numbers: Set[int] = set([0])
self._peer_retire_prior_to = 0
self._peer_token = configuration.token
self._quic_logger: Optional[QuicLoggerTrace] = None
self._remote_ack_delay_exponent = 3
Expand Down Expand Up @@ -1914,24 +1916,30 @@ def _handle_new_connection_id_frame(
reason_phrase="Retire Prior To is greater than Sequence Number",
)

# only accept retire_prior_to if it is bigger than the one we know
self._peer_retire_prior_to = max(retire_prior_to, self._peer_retire_prior_to)

# determine which CIDs to retire
change_cid = False
retire = list(
filter(
lambda c: c.sequence_number < retire_prior_to, self._peer_cid_available
)
)
if self._peer_cid.sequence_number < retire_prior_to:
retire = [
cid
for cid in self._peer_cid_available
if cid.sequence_number < self._peer_retire_prior_to
]
if self._peer_cid.sequence_number < self._peer_retire_prior_to:
change_cid = True
retire.insert(0, self._peer_cid)

# update available CIDs
self._peer_cid_available = list(
filter(
lambda c: c.sequence_number >= retire_prior_to, self._peer_cid_available
)
)
if sequence_number not in self._peer_cid_sequence_numbers:
self._peer_cid_available = [
cid
for cid in self._peer_cid_available
if cid.sequence_number >= self._peer_retire_prior_to
]
if (
sequence_number >= self._peer_retire_prior_to
and sequence_number not in self._peer_cid_sequence_numbers
):
self._peer_cid_available.append(
QuicConnectionId(
cid=connection_id,
Expand All @@ -1957,6 +1965,21 @@ def _handle_new_connection_id_frame(
reason_phrase="Too many active connection IDs",
)

# Check the number of retired connection IDs pending, though with a safer limit
# than the 2x recommended in section 5.1.2 of the RFC. Note that we are doing
# the check here and not in _retire_peer_cid() because we know the frame type to
# use here, and because it is the new connection id path that is potentially
# dangerous. We may transiently go a bit over the limit due to unacked frames
# getting added back to the list, but that's ok as it is bounded.
if len(self._retire_connection_ids) > min(
self._local_active_connection_id_limit * 4, MAX_PENDING_RETIRES
):
raise QuicConnectionError(
error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR,
frame_type=frame_type,
reason_phrase="Too many pending retired connection IDs",
)

def _handle_new_token_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
Expand Down Expand Up @@ -2484,9 +2507,10 @@ def _retire_peer_cid(self, connection_id: QuicConnectionId) -> None:
Retire a destination connection ID.
"""
self._logger.debug(
"Retiring CID %s (%d)",
"Retiring CID %s (%d) [%d]",
dump_cid(connection_id.cid),
connection_id.sequence_number,
len(self._retire_connection_ids) + 1,
)
self._retire_connection_ids.append(connection_id.sequence_number)

Expand Down
123 changes: 88 additions & 35 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ def datagram_sizes(items: List[Tuple[bytes, NetworkAddress]]) -> List[int]:
return [len(x[0]) for x in items]


def new_connection_id(
*,
sequence_number: int,
retire_prior_to: int = 0,
connection_id: bytes = bytes(8),
capacity: int = 100,
):
buf = Buffer(capacity=capacity)
buf.push_uint_var(sequence_number)
buf.push_uint_var(retire_prior_to)
buf.push_uint_var(len(connection_id))
buf.push_bytes(connection_id)
buf.push_bytes(bytes(16)) # stateless reset token
buf.seek(0)
return buf


@contextlib.contextmanager
def client_and_server(
client_kwargs={},
Expand Down Expand Up @@ -1574,13 +1591,7 @@ def test_handle_max_streams_uni_frame(self):

def test_handle_new_connection_id_duplicate(self):
with client_and_server() as (client, server):
buf = Buffer(capacity=100)
buf.push_uint_var(7) # sequence_number
buf.push_uint_var(0) # retire_prior_to
buf.push_uint_var(8)
buf.push_bytes(bytes(8))
buf.push_bytes(bytes(16))
buf.seek(0)
buf = new_connection_id(sequence_number=7)

# client receives NEW_CONNECTION_ID
client._handle_new_connection_id_frame(
Expand All @@ -1596,13 +1607,7 @@ def test_handle_new_connection_id_duplicate(self):

def test_handle_new_connection_id_over_limit(self):
with client_and_server() as (client, server):
buf = Buffer(capacity=100)
buf.push_uint_var(8) # sequence_number
buf.push_uint_var(0) # retire_prior_to
buf.push_uint_var(8)
buf.push_bytes(bytes(8))
buf.push_bytes(bytes(16))
buf.seek(0)
buf = new_connection_id(sequence_number=8)

# client receives NEW_CONNECTION_ID
with self.assertRaises(QuicConnectionError) as cm:
Expand All @@ -1621,13 +1626,7 @@ def test_handle_new_connection_id_over_limit(self):

def test_handle_new_connection_id_with_retire_prior_to(self):
with client_and_server() as (client, server):
buf = Buffer(capacity=42)
buf.push_uint_var(8) # sequence_number
buf.push_uint_var(2) # retire_prior_to
buf.push_uint_var(8)
buf.push_bytes(bytes(8))
buf.push_bytes(bytes(16))
buf.seek(0)
buf = new_connection_id(sequence_number=8, retire_prior_to=2, capacity=42)

# client receives NEW_CONNECTION_ID
client._handle_new_connection_id_frame(
Expand All @@ -1641,15 +1640,75 @@ def test_handle_new_connection_id_with_retire_prior_to(self):
sequence_numbers(client._peer_cid_available), [3, 4, 5, 6, 7, 8]
)

def test_handle_new_connection_id_with_retire_prior_to_lower(self):
with client_and_server() as (client, server):
buf = new_connection_id(sequence_number=80, retire_prior_to=80)

# client receives NEW_CONNECTION_ID
client._handle_new_connection_id_frame(
client_receive_context(client),
QuicFrameType.NEW_CONNECTION_ID,
buf,
)

self.assertEqual(client._peer_cid.sequence_number, 80)
self.assertEqual(sequence_numbers(client._peer_cid_available), [])

buf = new_connection_id(sequence_number=30, retire_prior_to=30)

# client receives NEW_CONNECTION_ID
client._handle_new_connection_id_frame(
client_receive_context(client),
QuicFrameType.NEW_CONNECTION_ID,
buf,
)

self.assertEqual(client._peer_cid.sequence_number, 80)
self.assertEqual(sequence_numbers(client._peer_cid_available), [])

def test_handle_excessive_new_connection_id_retires(self):
with client_and_server() as (client, server):
for i in range(25):
sequence_number = 8 + i
buf = new_connection_id(
sequence_number=sequence_number, retire_prior_to=sequence_number
)

# client receives NEW_CONNECTION_ID
client._handle_new_connection_id_frame(
client_receive_context(client),
QuicFrameType.NEW_CONNECTION_ID,
buf,
)

# So far, so good! We should be at the (default) limit of 4*8 pending
# retirements.
self.assertEqual(len(client._retire_connection_ids), 32)

# Now we will go one too many!
sequence_number = 8 + 25
buf = new_connection_id(
sequence_number=sequence_number, retire_prior_to=sequence_number
)
with self.assertRaises(QuicConnectionError) as cm:
client._handle_new_connection_id_frame(
client_receive_context(client),
QuicFrameType.NEW_CONNECTION_ID,
buf,
)
self.assertEqual(
cm.exception.error_code, QuicErrorCode.CONNECTION_ID_LIMIT_ERROR
)
self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_CONNECTION_ID)
self.assertEqual(
cm.exception.reason_phrase, "Too many pending retired connection IDs"
)

def test_handle_new_connection_id_with_connection_id_invalid(self):
with client_and_server() as (client, server):
buf = Buffer(capacity=100)
buf.push_uint_var(8) # sequence_number
buf.push_uint_var(2) # retire_prior_to
buf.push_uint_var(21)
buf.push_bytes(bytes(21))
buf.push_bytes(bytes(16))
buf.seek(0)
buf = new_connection_id(
sequence_number=8, retire_prior_to=2, connection_id=bytes(21)
)

# client receives NEW_CONNECTION_ID
with self.assertRaises(QuicConnectionError) as cm:
Expand All @@ -1670,13 +1729,7 @@ def test_handle_new_connection_id_with_connection_id_invalid(self):

def test_handle_new_connection_id_with_retire_prior_to_invalid(self):
with client_and_server() as (client, server):
buf = Buffer(capacity=100)
buf.push_uint_var(8) # sequence_number
buf.push_uint_var(9) # retire_prior_to
buf.push_uint_var(8)
buf.push_bytes(bytes(8))
buf.push_bytes(bytes(16))
buf.seek(0)
buf = new_connection_id(sequence_number=8, retire_prior_to=9)

# client receives NEW_CONNECTION_ID
with self.assertRaises(QuicConnectionError) as cm:
Expand Down

0 comments on commit 4f73f18

Please sign in to comment.