Skip to content

Commit 2b4621d

Browse files
committed
Switch to QUIC draft-25
1 parent b68f60f commit 2b4621d

16 files changed

+303
-74
lines changed

docs/http_client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
async def http_client(host, port):
11-
configuration = QuicConfiguration(alpn_protocols=["hq-23"])
11+
configuration = QuicConfiguration(alpn_protocols=["hq-25"])
1212

1313
async with connect(host, port, configuration=configuration) as connection:
1414
reader, writer = await connection.create_stream()

src/aioquic/h0/connection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from aioquic.quic.connection import QuicConnection
55
from aioquic.quic.events import QuicEvent, StreamDataReceived
66

7-
H0_ALPN = ["hq-24", "hq-23"]
7+
H0_ALPN = ["hq-25"]
88

99

1010
class H0Connection:

src/aioquic/h3/connection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
logger = logging.getLogger("http3")
2121

22-
H3_ALPN = ["h3-24", "h3-23"]
22+
H3_ALPN = ["h3-25"]
2323

2424

2525
class ErrorCode(IntEnum):

src/aioquic/quic/configuration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class QuicConfiguration:
6868
private_key: Any = None
6969
quantum_readiness_test: bool = False
7070
supported_versions: List[int] = field(
71-
default_factory=lambda: [QuicProtocolVersion.DRAFT_24]
71+
default_factory=lambda: [QuicProtocolVersion.DRAFT_25]
7272
)
7373
verify_mode: Optional[int] = None
7474

src/aioquic/quic/connection.py

+75-17
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
QuicProtocolVersion,
2626
QuicStreamFrame,
2727
QuicTransportParameters,
28+
get_retry_integrity_tag,
2829
get_spin_bit,
2930
is_long_header,
3031
pull_ack_frame,
@@ -283,6 +284,7 @@ def __init__(
283284
# things to send
284285
self._close_pending = False
285286
self._datagrams_pending: Deque[bytes] = deque()
287+
self._handshake_done_pending = False
286288
self._ping_pending: List[int] = []
287289
self._probe_pending = False
288290
self._retire_connection_ids: List[int] = []
@@ -324,6 +326,7 @@ def __init__(
324326
0x1B: (self._handle_path_response_frame, EPOCHS("01")),
325327
0x1C: (self._handle_connection_close_frame, EPOCHS("IH1")),
326328
0x1D: (self._handle_connection_close_frame, EPOCHS("1")),
329+
0x1E: (self._handle_handshake_done_frame, EPOCHS("1")),
327330
0x30: (self._handle_datagram_frame, EPOCHS("01")),
328331
0x31: (self._handle_datagram_frame, EPOCHS("01")),
329332
}
@@ -666,10 +669,18 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
666669
return
667670

668671
if self._is_client and header.packet_type == PACKET_TYPE_RETRY:
669-
# stateless retry
672+
# calculate stateless retry integrity tag
673+
integrity_tag = get_retry_integrity_tag(
674+
version=header.version,
675+
source_cid=header.source_cid,
676+
destination_cid=header.destination_cid,
677+
original_destination_cid=self._peer_cid,
678+
retry_token=header.token,
679+
)
680+
670681
if (
671682
header.destination_cid == self.host_cid
672-
and header.original_destination_cid == self._peer_cid
683+
and header.integrity_tag == integrity_tag
673684
and not self._stateless_retry_count
674685
):
675686
if self._quic_logger is not None:
@@ -862,12 +873,13 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
862873
self._network_paths.insert(0, network_path)
863874

864875
# record packet as received
865-
if packet_number > space.largest_received_packet:
866-
space.largest_received_packet = packet_number
867-
space.largest_received_time = now
868-
space.ack_queue.add(packet_number)
869-
if is_ack_eliciting and space.ack_at is None:
870-
space.ack_at = now + self._ack_delay
876+
if not space.discarded:
877+
if packet_number > space.largest_received_packet:
878+
space.largest_received_packet = packet_number
879+
space.largest_received_time = now
880+
space.ack_queue.add(packet_number)
881+
if is_ack_eliciting and space.ack_at is None:
882+
space.ack_at = now + self._ack_delay
871883

872884
def request_key_update(self) -> None:
873885
"""
@@ -1019,6 +1031,7 @@ def _discard_epoch(self, epoch: tls.Epoch) -> None:
10191031
self._logger.debug("Discarding epoch %s", epoch)
10201032
self._cryptos[epoch].teardown()
10211033
self._loss.discard_space(self._spaces[epoch])
1034+
self._spaces[epoch].discarded = True
10221035

10231036
def _find_network_path(self, addr: NetworkAddress) -> QuicNetworkPath:
10241037
# check existing network paths
@@ -1191,15 +1204,6 @@ def _handle_ack_frame(
11911204
now=context.time,
11921205
)
11931206

1194-
# check if we can discard handshake keys
1195-
if (
1196-
not self._handshake_confirmed
1197-
and self._handshake_complete
1198-
and context.epoch == tls.Epoch.ONE_RTT
1199-
):
1200-
self._discard_epoch(tls.Epoch.HANDSHAKE)
1201-
self._handshake_confirmed = True
1202-
12031207
def _handle_connection_close_frame(
12041208
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
12051209
) -> None:
@@ -1291,6 +1295,13 @@ def _handle_crypto_frame(
12911295
tls.State.SERVER_POST_HANDSHAKE,
12921296
]:
12931297
self._handshake_complete = True
1298+
1299+
# for servers, the handshake is now confirmed
1300+
if not self._is_client:
1301+
self._discard_epoch(tls.Epoch.HANDSHAKE)
1302+
self._handshake_confirmed = True
1303+
self._handshake_done_pending = True
1304+
12941305
self._loss.is_client_without_1rtt = False
12951306
self._replenish_connection_ids()
12961307
self._events.append(
@@ -1352,6 +1363,30 @@ def _handle_datagram_frame(
13521363

13531364
self._events.append(events.DatagramFrameReceived(data=data))
13541365

1366+
def _handle_handshake_done_frame(
1367+
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
1368+
) -> None:
1369+
"""
1370+
Handle a HANDSHAKE_DONE frame.
1371+
"""
1372+
# log frame
1373+
if self._quic_logger is not None:
1374+
context.quic_logger_frames.append(
1375+
self._quic_logger.encode_handshake_done_frame()
1376+
)
1377+
1378+
if not self._is_client:
1379+
raise QuicConnectionError(
1380+
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
1381+
frame_type=frame_type,
1382+
reason_phrase="Clients must not send HANDSHAKE_DONE frames",
1383+
)
1384+
1385+
#  for clients, the handshake is now confirmed
1386+
if not self._handshake_confirmed:
1387+
self._discard_epoch(tls.Epoch.HANDSHAKE)
1388+
self._handshake_confirmed = True
1389+
13551390
def _handle_max_data_frame(
13561391
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
13571392
) -> None:
@@ -1765,6 +1800,13 @@ def _on_ack_delivery(
17651800
if delivery == QuicDeliveryState.ACKED:
17661801
space.ack_queue.subtract(0, highest_acked + 1)
17671802

1803+
def _on_handshake_done_delivery(self, delivery: QuicDeliveryState) -> None:
1804+
"""
1805+
Callback when a HANDSHAKE_DONE frame is acknowledged or lost.
1806+
"""
1807+
if delivery != QuicDeliveryState.ACKED:
1808+
self._handshake_done_pending = True
1809+
17681810
def _on_max_data_delivery(self, delivery: QuicDeliveryState) -> None:
17691811
"""
17701812
Callback when a MAX_DATA frame is acknowledged or lost.
@@ -2063,6 +2105,11 @@ def _write_application(
20632105
if space.ack_at is not None and space.ack_at <= now:
20642106
self._write_ack_frame(builder=builder, space=space, now=now)
20652107

2108+
# HANDSHAKE_DONE
2109+
if self._handshake_done_pending:
2110+
self._write_handshake_done_frame(builder=builder)
2111+
self._handshake_done_pending = False
2112+
20662113
# PATH CHALLENGE
20672114
if (
20682115
not network_path.is_validated
@@ -2329,6 +2376,17 @@ def _write_datagram_frame(
23292376

23302377
return True
23312378

2379+
def _write_handshake_done_frame(self, builder: QuicPacketBuilder) -> None:
2380+
builder.start_frame(
2381+
QuicFrameType.HANDSHAKE_DONE, self._on_handshake_done_delivery,
2382+
)
2383+
2384+
# log frame
2385+
if self._quic_logger is not None:
2386+
builder.quic_logger_frames.append(
2387+
self._quic_logger.encode_handshake_done_frame()
2388+
)
2389+
23322390
def _write_new_connection_id_frame(
23332391
self, builder: QuicPacketBuilder, connection_id: QuicConnectionId
23342392
) -> None:

src/aioquic/quic/logger.py

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def encode_datagram_frame(self, length: int) -> Dict:
8383
"length": length,
8484
}
8585

86+
def encode_handshake_done_frame(self) -> Dict:
87+
return {"frame_type": "handshake_done"}
88+
8689
def encode_max_data_frame(self, maximum: int) -> Dict:
8790
return {"frame_type": "max_data", "maximum": str(maximum)}
8891

src/aioquic/quic/packet.py

+65-13
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import binascii
12
import ipaddress
23
import os
34
from dataclasses import dataclass
45
from enum import IntEnum
56
from typing import List, Optional, Tuple
67

8+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
9+
710
from ..buffer import Buffer
811
from ..tls import pull_block, push_block
912
from .rangeset import RangeSet
@@ -21,6 +24,9 @@
2124

2225
CONNECTION_ID_MAX_SIZE = 20
2326
PACKET_NUMBER_MAX_SIZE = 4
27+
RETRY_AEAD_KEY = binascii.unhexlify("4d32ecdb2a2133c841e4043df27d4430")
28+
RETRY_AEAD_NONCE = binascii.unhexlify("4d1611d05513a552c587d575")
29+
RETRY_INTEGRITY_SIZE = 16
2430

2531

2632
class QuicErrorCode(IntEnum):
@@ -33,14 +39,16 @@ class QuicErrorCode(IntEnum):
3339
FINAL_SIZE_ERROR = 0x6
3440
FRAME_ENCODING_ERROR = 0x7
3541
TRANSPORT_PARAMETER_ERROR = 0x8
42+
CONNECTION_ID_LIMIT_ERROR = 0x9
3643
PROTOCOL_VIOLATION = 0xA
44+
INVALID_TOKEN = 0xB
3745
CRYPTO_BUFFER_EXCEEDED = 0xD
3846
CRYPTO_ERROR = 0x100
3947

4048

4149
class QuicProtocolVersion(IntEnum):
4250
NEGOTIATION = 0
43-
DRAFT_24 = 0xFF000018
51+
DRAFT_25 = 0xFF000019
4452

4553

4654
@dataclass
@@ -50,8 +58,8 @@ class QuicHeader:
5058
packet_type: int
5159
destination_cid: bytes
5260
source_cid: bytes
53-
original_destination_cid: bytes = b""
5461
token: bytes = b""
62+
integrity_tag: bytes = b""
5563
rest_length: int = 0
5664

5765

@@ -72,6 +80,42 @@ def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
7280
return candidate
7381

7482

83+
def get_retry_integrity_tag(
84+
version: int,
85+
source_cid: bytes,
86+
destination_cid: bytes,
87+
original_destination_cid: bytes,
88+
retry_token: bytes,
89+
) -> bytes:
90+
"""
91+
Calculate the integrity tag for a RETRY packet.
92+
"""
93+
# build Retry pseudo packet
94+
buf = Buffer(
95+
capacity=8
96+
+ len(destination_cid)
97+
+ len(source_cid)
98+
+ len(original_destination_cid)
99+
+ len(retry_token)
100+
)
101+
buf.push_uint8(len(original_destination_cid))
102+
buf.push_bytes(original_destination_cid)
103+
buf.push_uint8(PACKET_TYPE_RETRY)
104+
buf.push_uint32(version)
105+
buf.push_uint8(len(destination_cid))
106+
buf.push_bytes(destination_cid)
107+
buf.push_uint8(len(source_cid))
108+
buf.push_bytes(source_cid)
109+
buf.push_bytes(retry_token)
110+
assert buf.eof()
111+
112+
# run AES-128-GCM
113+
aead = AESGCM(RETRY_AEAD_KEY)
114+
integrity_tag = aead.encrypt(RETRY_AEAD_NONCE, b"", buf.data)
115+
assert len(integrity_tag) == RETRY_INTEGRITY_SIZE
116+
return integrity_tag
117+
118+
75119
def get_spin_bit(first_byte: int) -> bool:
76120
return bool(first_byte & PACKET_SPIN_BIT)
77121

@@ -83,7 +127,7 @@ def is_long_header(first_byte: int) -> bool:
83127
def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
84128
first_byte = buf.pull_uint8()
85129

86-
original_destination_cid = b""
130+
integrity_tag = b""
87131
token = b""
88132
if is_long_header(first_byte):
89133
# long header packet
@@ -115,11 +159,9 @@ def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> Quic
115159
token = buf.pull_bytes(token_length)
116160
rest_length = buf.pull_uint_var()
117161
elif packet_type == PACKET_TYPE_RETRY:
118-
original_destination_cid_length = buf.pull_uint8()
119-
original_destination_cid = buf.pull_bytes(
120-
original_destination_cid_length
121-
)
122-
token = buf.pull_bytes(buf.capacity - buf.tell())
162+
token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_SIZE
163+
token = buf.pull_bytes(token_length)
164+
integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_SIZE)
123165
rest_length = 0
124166
else:
125167
rest_length = buf.pull_uint_var()
@@ -130,8 +172,8 @@ def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> Quic
130172
packet_type=packet_type,
131173
destination_cid=destination_cid,
132174
source_cid=source_cid,
133-
original_destination_cid=original_destination_cid,
134175
token=token,
176+
integrity_tag=integrity_tag,
135177
rest_length=rest_length,
136178
)
137179
else:
@@ -159,11 +201,20 @@ def encode_quic_retry(
159201
original_destination_cid: bytes,
160202
retry_token: bytes,
161203
) -> bytes:
204+
# calculate integrity tag
205+
integrity_tag = get_retry_integrity_tag(
206+
version=version,
207+
source_cid=source_cid,
208+
destination_cid=destination_cid,
209+
original_destination_cid=original_destination_cid,
210+
retry_token=retry_token,
211+
)
212+
162213
buf = Buffer(
163-
capacity=8
214+
capacity=7
164215
+ len(destination_cid)
165216
+ len(source_cid)
166-
+ len(original_destination_cid)
217+
+ len(integrity_tag)
167218
+ len(retry_token)
168219
)
169220
buf.push_uint8(PACKET_TYPE_RETRY)
@@ -172,9 +223,9 @@ def encode_quic_retry(
172223
buf.push_bytes(destination_cid)
173224
buf.push_uint8(len(source_cid))
174225
buf.push_bytes(source_cid)
175-
buf.push_uint8(len(original_destination_cid))
176-
buf.push_bytes(original_destination_cid)
177226
buf.push_bytes(retry_token)
227+
buf.push_bytes(integrity_tag)
228+
assert buf.eof()
178229
return buf.data
179230

180231

@@ -368,6 +419,7 @@ class QuicFrameType(IntEnum):
368419
PATH_RESPONSE = 0x1B
369420
TRANSPORT_CLOSE = 0x1C
370421
APPLICATION_CLOSE = 0x1D
422+
HANDSHAKE_DONE = 0x1E
371423
DATAGRAM = 0x30
372424
DATAGRAM_WITH_LENGTH = 0x31
373425

0 commit comments

Comments
 (0)