Skip to content

Commit

Permalink
Add support for SECP384R1 key exchange
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Jun 18, 2024
1 parent 7dc7214 commit d914a46
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 33 deletions.
40 changes: 19 additions & 21 deletions src/aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ def __init__(
self._signature_algorithms.append(SignatureAlgorithm.ED25519)
if default_backend().ed448_supported():
self._signature_algorithms.append(SignatureAlgorithm.ED448)
self._supported_groups = [Group.SECP256R1]
self._supported_groups = [Group.SECP256R1, Group.SECP384R1]
if default_backend().x25519_supported():
self._supported_groups.append(Group.X25519)
if default_backend().x448_supported():
Expand All @@ -1337,7 +1337,7 @@ def __init__(
self._dec_key: Optional[bytes] = None
self.__logger = logger

self._ec_private_key: Optional[ec.EllipticCurvePrivateKey] = None
self._ec_private_keys: List[ec.EllipticCurvePrivateKey] = []
self._x25519_private_key: Optional[x25519.X25519PrivateKey] = None
self._x448_private_key: Optional[x448.X448PrivateKey] = None

Expand Down Expand Up @@ -1525,13 +1525,7 @@ def _client_send_hello(self, output_buf: Buffer) -> None:
supported_groups: List[int] = []

for group in self._supported_groups:
if group == Group.SECP256R1:
self._ec_private_key = ec.generate_private_key(
GROUP_TO_CURVE[Group.SECP256R1]()
)
key_share.append(encode_public_key(self._ec_private_key.public_key()))
supported_groups.append(Group.SECP256R1)
elif group == Group.X25519:
if group == Group.X25519:
self._x25519_private_key = x25519.X25519PrivateKey.generate()
key_share.append(
encode_public_key(self._x25519_private_key.public_key())
Expand All @@ -1544,6 +1538,11 @@ def _client_send_hello(self, output_buf: Buffer) -> None:
elif group == Group.GREASE:
key_share.append((Group.GREASE, b"\x00"))
supported_groups.append(Group.GREASE)
elif group in GROUP_TO_CURVE:
ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[group]())
self._ec_private_keys.append(ec_private_key)
key_share.append(encode_public_key(ec_private_key.public_key()))
supported_groups.append(group)

assert len(key_share), "no key share entries"

Expand Down Expand Up @@ -1668,13 +1667,13 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None:
and self._x448_private_key is not None
):
shared_key = self._x448_private_key.exchange(peer_public_key)
elif (
isinstance(peer_public_key, ec.EllipticCurvePublicKey)
and self._ec_private_key is not None
and self._ec_private_key.public_key().curve.__class__
== peer_public_key.curve.__class__
):
shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key)
elif isinstance(peer_public_key, ec.EllipticCurvePublicKey):
for ec_private_key in self._ec_private_keys:
if (
ec_private_key.public_key().curve.__class__
== peer_public_key.curve.__class__
):
shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key)
assert shared_key is not None

self.key_schedule.update_hash(input_buf.data)
Expand Down Expand Up @@ -1989,11 +1988,10 @@ def _server_handle_hello(
shared_key = self._x448_private_key.exchange(peer_public_key)
break
elif isinstance(peer_public_key, ec.EllipticCurvePublicKey):
self._ec_private_key = ec.generate_private_key(
GROUP_TO_CURVE[key_share[0]]()
)
public_key = self._ec_private_key.public_key()
shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key)
ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[key_share[0]]())
self._ec_private_keys.append(ec_private_key)
public_key = ec_private_key.public_key()
shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key)
break
assert shared_key is not None

Expand Down
37 changes: 25 additions & 12 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def reset_buffers(buffers):


class ContextTest(TestCase):
def assertClientHello(self, data: bytes):
self.assertEqual(data[0], tls.HandshakeType.CLIENT_HELLO)
self.assertGreaterEqual(len(data), 191)
self.assertLessEqual(len(data), 564)

def create_client(
self, alpn_protocols=None, cadata=None, cafile=SERVER_CACERTFILE, **kwargs
):
Expand Down Expand Up @@ -379,8 +384,7 @@ def _handshake(self, client, server):
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 181)
self.assertLessEqual(len(server_input), 358)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -445,8 +449,7 @@ def test_handshake_with_certificate_request_no_certificate(self):
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 181)
self.assertLessEqual(len(server_input), 358)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -504,8 +507,7 @@ def test_handshake_with_certificate_request_with_certificate(self):
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 181)
self.assertLessEqual(len(server_input), 358)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -660,6 +662,20 @@ def test_handshake_with_grease_group(self):

self._handshake(client, server)

def test_handshake_with_secp256r1_group(self):
client = self.create_client()
client._supported_groups = [tls.Group.SECP256R1]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_secp384r1_group(self):
client = self.create_client()
client._supported_groups = [tls.Group.SECP384R1]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_x25519(self):
client = self.create_client()
client._supported_groups = [tls.Group.X25519]
Expand Down Expand Up @@ -729,8 +745,7 @@ def second_handshake():
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 383)
self.assertLessEqual(len(server_input), 483)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -782,8 +797,7 @@ def second_handshake_bad_binder():
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 383)
self.assertLessEqual(len(server_input), 483)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# tamper with binder
Expand All @@ -808,8 +822,7 @@ def second_handshake_bad_pre_shared_key():
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 383)
self.assertLessEqual(len(server_input), 483)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# handle client hello
Expand Down

0 comments on commit d914a46

Please sign in to comment.