Skip to content

Commit

Permalink
Fix tls.py assertion issues. (aiortc#435)
Browse files Browse the repository at this point in the history
1) Some assertions had side effects and would cause a loss of
   framing if python was run with -O or -OO

2) Some assertions should have been error checks.

Co-authored-by: Jeremy Lainé <[email protected]>
  • Loading branch information
rthalley and jlaine authored Dec 28, 2023
1 parent bc7c480 commit 5772246
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 95 deletions.
146 changes: 103 additions & 43 deletions src/aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class AlertCertificateExpired(Alert):
description = AlertDescription.certificate_expired


class AlertDecodeError(Alert):
description = AlertDescription.decode_error


class AlertDecryptError(Alert):
description = AlertDescription.decrypt_error

Expand Down Expand Up @@ -330,6 +334,10 @@ class HandshakeType(IntEnum):
MESSAGE_HASH = 254


class NameType(IntEnum):
HOST_NAME = 0


class PskKeyExchangeMode(IntEnum):
PSK_KE = 0
PSK_DHE_KE = 1
Expand Down Expand Up @@ -365,7 +373,9 @@ def pull_block(buf: Buffer, capacity: int) -> Generator:
length = int.from_bytes(buf.pull_bytes(capacity), byteorder="big")
end = buf.tell() + length
yield length
assert buf.tell() == end
if buf.tell() != end:
# There was trailing garbage or our parsing was bad.
raise AlertDecodeError("extra bytes at the end of a block")


@contextmanager
Expand Down Expand Up @@ -433,6 +443,26 @@ def push_extension(buf: Buffer, extension_type: int) -> Generator:
yield


# ServerName


def pull_server_name(buf: Buffer) -> str:
with pull_block(buf, 2):
name_type = buf.pull_uint8()
if name_type != NameType.HOST_NAME:
# We don't know this name_type.
raise AlertIllegalParameter(
f"ServerName has an unknown name type {name_type}"
)
return pull_opaque(buf, 2).decode("ascii")


def push_server_name(buf: Buffer, server_name: str) -> None:
with push_block(buf, 2):
buf.push_uint8(NameType.HOST_NAME)
push_opaque(buf, 2, server_name.encode("ascii"))


# KeyShareEntry


Expand Down Expand Up @@ -466,6 +496,12 @@ def push_alpn_protocol(buf: Buffer, protocol: str) -> None:
PskIdentity = Tuple[bytes, int]


@dataclass
class OfferedPsks:
identities: List[PskIdentity]
binders: List[bytes]


def pull_psk_identity(buf: Buffer) -> PskIdentity:
identity = pull_opaque(buf, 2)
obfuscated_ticket_age = buf.pull_uint32()
Expand All @@ -485,15 +521,31 @@ def push_psk_binder(buf: Buffer, binder: bytes) -> None:
push_opaque(buf, 1, binder)


# MESSAGES
def pull_offered_psks(buf: Buffer) -> OfferedPsks:
return OfferedPsks(
identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
)

Extension = Tuple[int, bytes]

def push_offered_psks(buf: Buffer, pre_shared_key: OfferedPsks) -> None:
push_list(
buf,
2,
partial(push_psk_identity, buf),
pre_shared_key.identities,
)
push_list(
buf,
2,
partial(push_psk_binder, buf),
pre_shared_key.binders,
)

@dataclass
class OfferedPsks:
identities: List[PskIdentity]
binders: List[bytes]

# MESSAGES

Extension = Tuple[int, bytes]


@dataclass
Expand All @@ -517,10 +569,21 @@ class ClientHello:
other_extensions: List[Extension] = field(default_factory=list)


def pull_handshake_type(buf: Buffer, expected_type: HandshakeType) -> None:
"""
Pull the message type and assert it is the expected one.
If it is not, we have a programming error.
"""
message_type = buf.pull_uint8()
assert message_type == expected_type


def pull_client_hello(buf: Buffer) -> ClientHello:
assert buf.pull_uint8() == HandshakeType.CLIENT_HELLO
pull_handshake_type(buf, HandshakeType.CLIENT_HELLO)
with pull_block(buf, 3):
assert buf.pull_uint16() == TLS_VERSION_1_2
if buf.pull_uint16() != TLS_VERSION_1_2:
raise AlertDecodeError("ClientHello version is not 1.2")

hello = ClientHello(
random=buf.pull_bytes(32),
Expand All @@ -535,7 +598,9 @@ def pull_client_hello(buf: Buffer) -> ClientHello:
def pull_extension() -> None:
# pre_shared_key MUST be last
nonlocal after_psk
assert not after_psk
if after_psk:
# the alert is Illegal Parameter per RFC 8446 section 4.2.11.
raise AlertIllegalParameter("PreSharedKey is not the last extension")

extension_type = buf.pull_uint16()
extension_length = buf.pull_uint16()
Expand All @@ -550,20 +615,15 @@ def pull_extension() -> None:
elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES:
hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8)
elif extension_type == ExtensionType.SERVER_NAME:
with pull_block(buf, 2):
assert buf.pull_uint8() == 0
hello.server_name = pull_opaque(buf, 2).decode("ascii")
hello.server_name = pull_server_name(buf)
elif extension_type == ExtensionType.ALPN:
hello.alpn_protocols = pull_list(
buf, 2, partial(pull_alpn_protocol, buf)
)
elif extension_type == ExtensionType.EARLY_DATA:
hello.early_data = True
elif extension_type == ExtensionType.PRE_SHARED_KEY:
hello.pre_shared_key = OfferedPsks(
identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
)
hello.pre_shared_key = pull_offered_psks(buf)
after_psk = True
else:
hello.other_extensions.append(
Expand Down Expand Up @@ -604,9 +664,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None:

if hello.server_name is not None:
with push_extension(buf, ExtensionType.SERVER_NAME):
with push_block(buf, 2):
buf.push_uint8(0)
push_opaque(buf, 2, hello.server_name.encode("ascii"))
push_server_name(buf, hello.server_name)

if hello.alpn_protocols is not None:
with push_extension(buf, ExtensionType.ALPN):
Expand All @@ -625,18 +683,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None:
# pre_shared_key MUST be last
if hello.pre_shared_key is not None:
with push_extension(buf, ExtensionType.PRE_SHARED_KEY):
push_list(
buf,
2,
partial(push_psk_identity, buf),
hello.pre_shared_key.identities,
)
push_list(
buf,
2,
partial(push_psk_binder, buf),
hello.pre_shared_key.binders,
)
push_offered_psks(buf, hello.pre_shared_key)


@dataclass
Expand All @@ -654,9 +701,10 @@ class ServerHello:


def pull_server_hello(buf: Buffer) -> ServerHello:
assert buf.pull_uint8() == HandshakeType.SERVER_HELLO
pull_handshake_type(buf, HandshakeType.SERVER_HELLO)
with pull_block(buf, 3):
assert buf.pull_uint16() == TLS_VERSION_1_2
if buf.pull_uint16() != TLS_VERSION_1_2:
raise AlertDecodeError("ServerHello version is not 1.2")

hello = ServerHello(
random=buf.pull_bytes(32),
Expand Down Expand Up @@ -729,7 +777,7 @@ class NewSessionTicket:
def pull_new_session_ticket(buf: Buffer) -> NewSessionTicket:
new_session_ticket = NewSessionTicket()

assert buf.pull_uint8() == HandshakeType.NEW_SESSION_TICKET
pull_handshake_type(buf, HandshakeType.NEW_SESSION_TICKET)
with pull_block(buf, 3):
new_session_ticket.ticket_lifetime = buf.pull_uint32()
new_session_ticket.ticket_age_add = buf.pull_uint32()
Expand Down Expand Up @@ -780,7 +828,7 @@ class EncryptedExtensions:
def pull_encrypted_extensions(buf: Buffer) -> EncryptedExtensions:
extensions = EncryptedExtensions()

assert buf.pull_uint8() == HandshakeType.ENCRYPTED_EXTENSIONS
pull_handshake_type(buf, HandshakeType.ENCRYPTED_EXTENSIONS)
with pull_block(buf, 3):

def pull_extension() -> None:
Expand Down Expand Up @@ -836,7 +884,7 @@ class Certificate:
def pull_certificate(buf: Buffer) -> Certificate:
certificate = Certificate()

assert buf.pull_uint8() == HandshakeType.CERTIFICATE
pull_handshake_type(buf, HandshakeType.CERTIFICATE)
with pull_block(buf, 3):
certificate.request_context = pull_opaque(buf, 1)

Expand Down Expand Up @@ -876,7 +924,7 @@ class CertificateRequest:
def pull_certificate_request(buf: Buffer) -> CertificateRequest:
certificate_request = CertificateRequest()

assert buf.pull_uint8() == HandshakeType.CERTIFICATE_REQUEST
pull_handshake_type(buf, HandshakeType.CERTIFICATE_REQUEST)
with pull_block(buf, 3):
certificate_request.request_context = pull_opaque(buf, 1)

Expand Down Expand Up @@ -922,7 +970,7 @@ class CertificateVerify:


def pull_certificate_verify(buf: Buffer) -> CertificateVerify:
assert buf.pull_uint8() == HandshakeType.CERTIFICATE_VERIFY
pull_handshake_type(buf, HandshakeType.CERTIFICATE_VERIFY)
with pull_block(buf, 3):
algorithm = buf.pull_uint16()
signature = pull_opaque(buf, 2)
Expand All @@ -945,7 +993,7 @@ class Finished:
def pull_finished(buf: Buffer) -> Finished:
finished = Finished()

assert buf.pull_uint8() == HandshakeType.FINISHED
pull_handshake_type(buf, HandshakeType.FINISHED)
finished.verify_data = pull_opaque(buf, 3)

return finished
Expand Down Expand Up @@ -1373,6 +1421,9 @@ def handle_message(
elif self.state == State.SERVER_POST_HANDSHAKE:
raise AlertUnexpectedMessage

# This condition should never be reached, because if the message
# contains any extra bytes, the `pull_block` inside the message
# parser will raise `AlertDecodeError`.
assert input_buf.eof()

def _build_session_ticket(
Expand Down Expand Up @@ -1402,7 +1453,10 @@ def _build_session_ticket(
)

def _check_certificate_verify_signature(self, verify: CertificateVerify) -> None:
assert verify.algorithm in self._signature_algorithms
if verify.algorithm not in self._signature_algorithms:
raise AlertDecryptError(
"CertificateVerify has a signature algorithm we did not advertise"
)

try:
self._peer_certificate.public_key().verify(
Expand Down Expand Up @@ -1524,8 +1578,14 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None:
[peer_hello.cipher_suite],
AlertHandshakeFailure("Unsupported cipher suite"),
)
assert peer_hello.compression_method in self._legacy_compression_methods
assert peer_hello.supported_version in self._supported_versions
if peer_hello.compression_method not in self._legacy_compression_methods:
raise AlertIllegalParameter(
"ServerHello has a compression method we did not advertise"
)
if peer_hello.supported_version not in self._supported_versions:
raise AlertIllegalParameter(
"ServerHello has a version we did not advertise"
)

# select key schedule
if peer_hello.pre_shared_key is not None:
Expand Down
Loading

0 comments on commit 5772246

Please sign in to comment.