Skip to content

Commit

Permalink
Improve padding of coalesced datagrams containing INITIAL
Browse files Browse the repository at this point in the history
Our previous padding algorithm padded all client-sent or ack-eliciting
INITIAL packets to a full datagram size. While this satisfies the
specification requirements, the downside is that it made it impossible
to coalesce any other packets after the INITIAL.

We now mostly defer the padding decision until the datagram is finalised
and perform padding by appending zeroes at the end of the datagram. As
an exception to this rule, in the presence of short-header packets we
insert the padding inside the packet.
  • Loading branch information
jlaine committed Jul 1, 2024
1 parent afe5525 commit 6987588
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 63 deletions.
35 changes: 27 additions & 8 deletions src/aioquic/quic/packet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
self._datagrams: List[bytes] = []
self._datagram_flight_bytes = 0
self._datagram_init = True
self._datagram_needs_padding = False
self._packets: List[QuicSentPacket] = []
self._flight_bytes = 0
self._total_bytes = 0
Expand Down Expand Up @@ -217,6 +218,7 @@ def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
self._flight_capacity = remaining_flight_bytes
self._datagram_flight_bytes = 0
self._datagram_init = False
self._datagram_needs_padding = False

# calculate header size
if packet_type != QuicPacketType.ONE_RTT:
Expand Down Expand Up @@ -270,15 +272,23 @@ def _end_packet(self) -> None:
- packet_size
)

# Padding for initial packets; see RFC 9000 section
# 14.1.
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if (
(self._is_client or self._packet.is_ack_eliciting)
and self._packet_type == QuicPacketType.INITIAL
and self.remaining_flight_space
and self.remaining_flight_space > padding_size
self._is_client or self._packet.is_ack_eliciting
) and self._packet_type == QuicPacketType.INITIAL:
self._datagram_needs_padding = True

# For datagrams containing 1-RTT data, we *must* apply the padding
# inside the packet, we cannot tack bytes onto the end of the
# datagram.
if (
self._datagram_needs_padding
and self._packet_type == QuicPacketType.ONE_RTT
):
padding_size = self.remaining_flight_space
if self.remaining_flight_space > padding_size:
padding_size = self.remaining_flight_space
self._datagram_needs_padding = False

# write padding
if padding_size > 0:
Expand Down Expand Up @@ -343,7 +353,7 @@ def _end_packet(self) -> None:
if self._packet.in_flight:
self._datagram_flight_bytes += self._packet.sent_bytes

# short header packets cannot be coalesced, we need a new datagram
# Short header packets cannot be coalesced, we need a new datagram.
if self._packet_type == QuicPacketType.ONE_RTT:
self._flush_current_datagram()

Expand All @@ -358,6 +368,15 @@ def _end_packet(self) -> None:
def _flush_current_datagram(self) -> None:
datagram_bytes = self._buffer.tell()
if datagram_bytes:
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if self._datagram_needs_padding:
extra_bytes = self._flight_capacity - self._buffer.tell()
if extra_bytes > 0:
self._buffer.push_bytes(bytes(extra_bytes))
self._datagram_flight_bytes += extra_bytes
datagram_bytes += extra_bytes

self._datagrams.append(self._buffer.data)
self._flight_bytes += self._datagram_flight_bytes
self._total_bytes += datagram_bytes
Expand Down
64 changes: 28 additions & 36 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@
)

CLIENT_ADDR = ("1.2.3.4", 1234)
CLIENT_HANDSHAKE_DATAGRAM_SIZES = [1200]

SERVER_ADDR = ("2.3.4.5", 4433)
SERVER_INITIAL_DATAGRAM_SIZES = [1200, 1200, 986]
SERVER_INITIAL_DATAGRAM_SIZES = [1200, 1162]

HANDSHAKE_COMPLETED_EVENTS = [
events.HandshakeCompleted,
Expand Down Expand Up @@ -464,9 +465,8 @@ def test_connect_without_loss(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.425)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -475,7 +475,6 @@ def test_connect_without_loss(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.425)
Expand Down Expand Up @@ -529,9 +528,8 @@ def test_connect_with_loss_1(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.625)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -540,7 +538,6 @@ def test_connect_with_loss_1(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.625)
Expand Down Expand Up @@ -607,9 +604,8 @@ def test_connect_with_loss_2(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.525)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -618,7 +614,6 @@ def test_connect_with_loss_2(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.525)
Expand Down Expand Up @@ -683,9 +678,8 @@ def test_connect_with_loss_3(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.625)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -694,7 +688,6 @@ def test_connect_with_loss_3(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.625)
Expand Down Expand Up @@ -733,12 +726,11 @@ def test_connect_with_loss_4(self):
self.assertSentPackets(server, [1, 2, 0])
self.assertEvents(server, [events.ProtocolNegotiated])

# client only receives first two datagrams and sends ACKS
# client only receives the first datagram and sends ACKS
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 48])
self.assertEqual(datagram_sizes(items), [1200])
self.assertAlmostEqual(client.get_timer(), 0.325)
self.assertSentPackets(client, [0, 1, 0])
self.assertEvents(client, [events.ProtocolNegotiated])
Expand Down Expand Up @@ -821,9 +813,8 @@ def test_connect_with_loss_5(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.425)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -833,7 +824,6 @@ def test_connect_with_loss_5(self):
# server completes handshake, but HANDSHAKE_DONE is lost
now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.425)
Expand Down Expand Up @@ -1066,7 +1056,7 @@ def save_session_ticket(ticket):
stream_id = client.get_next_available_stream_id()
client.send_stream_data(stream_id, b"hello")

self.assertEqual(roundtrip(client, server), (2, 2))
self.assertEqual(roundtrip(client, server), (1, 1))

event = server.next_event()
self.assertEqual(type(event), events.ProtocolNegotiated)
Expand Down Expand Up @@ -2785,16 +2775,19 @@ def test_payload_received_malformed_frame(self):

def test_send_max_data_blocked_by_cc(self):
with client_and_server() as (client, server):
# check congestion control
# Check congestion control. We do not check the congestion
# window too strictly as its exact value depends on the size
# of our ACKs, which depends on the execution time.
self.assertEqual(client._loss.bytes_in_flight, 0)
self.assertEqual(client._loss.congestion_window, 13423)
self.assertGreaterEqual(client._loss.congestion_window, 13530)
self.assertLessEqual(client._loss.congestion_window, 13540)

# artificially raise received data counter
client._local_max_data_used = client._local_max_data
self.assertEqual(server._remote_max_data, 1048576)

# artificially raise bytes in flight
client._loss._cc.bytes_in_flight = 13423
client._loss._cc.bytes_in_flight = client._loss.congestion_window

# MAX_DATA is not sent due to congestion control
self.assertEqual(drop(client), 0)
Expand Down Expand Up @@ -3153,20 +3146,19 @@ def test_version_negotiation_ignore(self):
self.assertEqual(drop(client), 0)

def test_version_negotiation_ignore_server(self):
with client_and_server() as (client, server):
# The server does not reply to the version negotiation packet.
server.receive_datagram(
encode_quic_version_negotiation(
source_cid=server._peer_cid.cid,
destination_cid=server.host_cid,
supported_versions=[QuicProtocolVersion.VERSION_1],
),
CLIENT_ADDR,
now=time.time(),
)
self.assertEqual(drop(client), 0)
server = create_standalone_server(self)

self.assertPacketDropped(server, "unexpected_packet")
# Servers do not expect version negotiation packets.
server.receive_datagram(
encode_quic_version_negotiation(
source_cid=server._peer_cid.cid,
destination_cid=server.host_cid,
supported_versions=[QuicProtocolVersion.VERSION_1],
),
CLIENT_ADDR,
now=time.time(),
)
self.assertPacketDropped(server, "unexpected_packet")

def test_version_negotiation_ok(self):
client = create_standalone_client(
Expand Down
Loading

0 comments on commit 6987588

Please sign in to comment.