|
25 | 25 | QuicProtocolVersion,
|
26 | 26 | QuicStreamFrame,
|
27 | 27 | QuicTransportParameters,
|
| 28 | + get_retry_integrity_tag, |
28 | 29 | get_spin_bit,
|
29 | 30 | is_long_header,
|
30 | 31 | pull_ack_frame,
|
@@ -283,6 +284,7 @@ def __init__(
|
283 | 284 | # things to send
|
284 | 285 | self._close_pending = False
|
285 | 286 | self._datagrams_pending: Deque[bytes] = deque()
|
| 287 | + self._handshake_done_pending = False |
286 | 288 | self._ping_pending: List[int] = []
|
287 | 289 | self._probe_pending = False
|
288 | 290 | self._retire_connection_ids: List[int] = []
|
@@ -324,6 +326,7 @@ def __init__(
|
324 | 326 | 0x1B: (self._handle_path_response_frame, EPOCHS("01")),
|
325 | 327 | 0x1C: (self._handle_connection_close_frame, EPOCHS("IH1")),
|
326 | 328 | 0x1D: (self._handle_connection_close_frame, EPOCHS("1")),
|
| 329 | + 0x1E: (self._handle_handshake_done_frame, EPOCHS("1")), |
327 | 330 | 0x30: (self._handle_datagram_frame, EPOCHS("01")),
|
328 | 331 | 0x31: (self._handle_datagram_frame, EPOCHS("01")),
|
329 | 332 | }
|
@@ -666,10 +669,18 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
|
666 | 669 | return
|
667 | 670 |
|
668 | 671 | if self._is_client and header.packet_type == PACKET_TYPE_RETRY:
|
669 |
| - # stateless retry |
| 672 | + # calculate stateless retry integrity tag |
| 673 | + integrity_tag = get_retry_integrity_tag( |
| 674 | + version=header.version, |
| 675 | + source_cid=header.source_cid, |
| 676 | + destination_cid=header.destination_cid, |
| 677 | + original_destination_cid=self._peer_cid, |
| 678 | + retry_token=header.token, |
| 679 | + ) |
| 680 | + |
670 | 681 | if (
|
671 | 682 | header.destination_cid == self.host_cid
|
672 |
| - and header.original_destination_cid == self._peer_cid |
| 683 | + and header.integrity_tag == integrity_tag |
673 | 684 | and not self._stateless_retry_count
|
674 | 685 | ):
|
675 | 686 | if self._quic_logger is not None:
|
@@ -862,12 +873,13 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
|
862 | 873 | self._network_paths.insert(0, network_path)
|
863 | 874 |
|
864 | 875 | # record packet as received
|
865 |
| - if packet_number > space.largest_received_packet: |
866 |
| - space.largest_received_packet = packet_number |
867 |
| - space.largest_received_time = now |
868 |
| - space.ack_queue.add(packet_number) |
869 |
| - if is_ack_eliciting and space.ack_at is None: |
870 |
| - space.ack_at = now + self._ack_delay |
| 876 | + if not space.discarded: |
| 877 | + if packet_number > space.largest_received_packet: |
| 878 | + space.largest_received_packet = packet_number |
| 879 | + space.largest_received_time = now |
| 880 | + space.ack_queue.add(packet_number) |
| 881 | + if is_ack_eliciting and space.ack_at is None: |
| 882 | + space.ack_at = now + self._ack_delay |
871 | 883 |
|
872 | 884 | def request_key_update(self) -> None:
|
873 | 885 | """
|
@@ -1019,6 +1031,7 @@ def _discard_epoch(self, epoch: tls.Epoch) -> None:
|
1019 | 1031 | self._logger.debug("Discarding epoch %s", epoch)
|
1020 | 1032 | self._cryptos[epoch].teardown()
|
1021 | 1033 | self._loss.discard_space(self._spaces[epoch])
|
| 1034 | + self._spaces[epoch].discarded = True |
1022 | 1035 |
|
1023 | 1036 | def _find_network_path(self, addr: NetworkAddress) -> QuicNetworkPath:
|
1024 | 1037 | # check existing network paths
|
@@ -1191,15 +1204,6 @@ def _handle_ack_frame(
|
1191 | 1204 | now=context.time,
|
1192 | 1205 | )
|
1193 | 1206 |
|
1194 |
| - # check if we can discard handshake keys |
1195 |
| - if ( |
1196 |
| - not self._handshake_confirmed |
1197 |
| - and self._handshake_complete |
1198 |
| - and context.epoch == tls.Epoch.ONE_RTT |
1199 |
| - ): |
1200 |
| - self._discard_epoch(tls.Epoch.HANDSHAKE) |
1201 |
| - self._handshake_confirmed = True |
1202 |
| - |
1203 | 1207 | def _handle_connection_close_frame(
|
1204 | 1208 | self, context: QuicReceiveContext, frame_type: int, buf: Buffer
|
1205 | 1209 | ) -> None:
|
@@ -1291,6 +1295,13 @@ def _handle_crypto_frame(
|
1291 | 1295 | tls.State.SERVER_POST_HANDSHAKE,
|
1292 | 1296 | ]:
|
1293 | 1297 | self._handshake_complete = True
|
| 1298 | + |
| 1299 | + # for servers, the handshake is now confirmed |
| 1300 | + if not self._is_client: |
| 1301 | + self._discard_epoch(tls.Epoch.HANDSHAKE) |
| 1302 | + self._handshake_confirmed = True |
| 1303 | + self._handshake_done_pending = True |
| 1304 | + |
1294 | 1305 | self._loss.is_client_without_1rtt = False
|
1295 | 1306 | self._replenish_connection_ids()
|
1296 | 1307 | self._events.append(
|
@@ -1352,6 +1363,30 @@ def _handle_datagram_frame(
|
1352 | 1363 |
|
1353 | 1364 | self._events.append(events.DatagramFrameReceived(data=data))
|
1354 | 1365 |
|
| 1366 | + def _handle_handshake_done_frame( |
| 1367 | + self, context: QuicReceiveContext, frame_type: int, buf: Buffer |
| 1368 | + ) -> None: |
| 1369 | + """ |
| 1370 | + Handle a HANDSHAKE_DONE frame. |
| 1371 | + """ |
| 1372 | + # log frame |
| 1373 | + if self._quic_logger is not None: |
| 1374 | + context.quic_logger_frames.append( |
| 1375 | + self._quic_logger.encode_handshake_done_frame() |
| 1376 | + ) |
| 1377 | + |
| 1378 | + if not self._is_client: |
| 1379 | + raise QuicConnectionError( |
| 1380 | + error_code=QuicErrorCode.PROTOCOL_VIOLATION, |
| 1381 | + frame_type=frame_type, |
| 1382 | + reason_phrase="Clients must not send HANDSHAKE_DONE frames", |
| 1383 | + ) |
| 1384 | + |
| 1385 | + # for clients, the handshake is now confirmed |
| 1386 | + if not self._handshake_confirmed: |
| 1387 | + self._discard_epoch(tls.Epoch.HANDSHAKE) |
| 1388 | + self._handshake_confirmed = True |
| 1389 | + |
1355 | 1390 | def _handle_max_data_frame(
|
1356 | 1391 | self, context: QuicReceiveContext, frame_type: int, buf: Buffer
|
1357 | 1392 | ) -> None:
|
@@ -1765,6 +1800,13 @@ def _on_ack_delivery(
|
1765 | 1800 | if delivery == QuicDeliveryState.ACKED:
|
1766 | 1801 | space.ack_queue.subtract(0, highest_acked + 1)
|
1767 | 1802 |
|
| 1803 | + def _on_handshake_done_delivery(self, delivery: QuicDeliveryState) -> None: |
| 1804 | + """ |
| 1805 | + Callback when a HANDSHAKE_DONE frame is acknowledged or lost. |
| 1806 | + """ |
| 1807 | + if delivery != QuicDeliveryState.ACKED: |
| 1808 | + self._handshake_done_pending = True |
| 1809 | + |
1768 | 1810 | def _on_max_data_delivery(self, delivery: QuicDeliveryState) -> None:
|
1769 | 1811 | """
|
1770 | 1812 | Callback when a MAX_DATA frame is acknowledged or lost.
|
@@ -2063,6 +2105,11 @@ def _write_application(
|
2063 | 2105 | if space.ack_at is not None and space.ack_at <= now:
|
2064 | 2106 | self._write_ack_frame(builder=builder, space=space, now=now)
|
2065 | 2107 |
|
| 2108 | + # HANDSHAKE_DONE |
| 2109 | + if self._handshake_done_pending: |
| 2110 | + self._write_handshake_done_frame(builder=builder) |
| 2111 | + self._handshake_done_pending = False |
| 2112 | + |
2066 | 2113 | # PATH CHALLENGE
|
2067 | 2114 | if (
|
2068 | 2115 | not network_path.is_validated
|
@@ -2329,6 +2376,17 @@ def _write_datagram_frame(
|
2329 | 2376 |
|
2330 | 2377 | return True
|
2331 | 2378 |
|
| 2379 | + def _write_handshake_done_frame(self, builder: QuicPacketBuilder) -> None: |
| 2380 | + builder.start_frame( |
| 2381 | + QuicFrameType.HANDSHAKE_DONE, self._on_handshake_done_delivery, |
| 2382 | + ) |
| 2383 | + |
| 2384 | + # log frame |
| 2385 | + if self._quic_logger is not None: |
| 2386 | + builder.quic_logger_frames.append( |
| 2387 | + self._quic_logger.encode_handshake_done_frame() |
| 2388 | + ) |
| 2389 | + |
2332 | 2390 | def _write_new_connection_id_frame(
|
2333 | 2391 | self, builder: QuicPacketBuilder, connection_id: QuicConnectionId
|
2334 | 2392 | ) -> None:
|
|
0 commit comments