Skip to content

gh-134698: Hold a lock when the thread state is detached in ssl #134724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,25 @@ def getpass(self):
# Make sure the password function isn't called if it isn't needed
ctx.load_cert_chain(CERTFILE, password=getpass_exception)

@threading_helper.requires_working_threading()
def test_load_cert_chain_thread_safety(self):
# gh-134698: _ssl detaches the thread state (and as such,
# releases the GIL and critical sections) around expensive
# OpenSSL calls. Unfortunately, OpenSSL structures aren't
# thread-safe, so executing these calls concurrently led
# to crashes.
ctx = ssl.create_default_context()

def race():
ctx.load_cert_chain(CERTFILE)

threads = [threading.Thread(target=race) for _ in range(8)]
with threading_helper.catch_threading_exception() as cm:
with threading_helper.start_threads(threads):
pass

self.assertIsNone(cm.exc_value)

def test_load_verify_locations(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.load_verify_locations(CERTFILE)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix a crash when calling methods of :class:`ssl.SSLContext` or
:class:`ssl.SSLSocket` across multiple threads.
99 changes: 57 additions & 42 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@
/* Redefined below for Windows debug builds after important #includes */
#define _PySSL_FIX_ERRNO

#define PySSL_BEGIN_ALLOW_THREADS_S(save) \
do { (save) = PyEval_SaveThread(); } while(0)
#define PySSL_END_ALLOW_THREADS_S(save) \
do { PyEval_RestoreThread(save); _PySSL_FIX_ERRNO; } while(0)
#define PySSL_BEGIN_ALLOW_THREADS { \
#define PySSL_BEGIN_ALLOW_THREADS_S(save, mutex) \
do { (save) = PyEval_SaveThread(); PyMutex_Lock(mutex); } while(0)
#define PySSL_END_ALLOW_THREADS_S(save, mutex) \
do { PyMutex_Unlock(mutex); PyEval_RestoreThread(save); _PySSL_FIX_ERRNO; } while(0)
#define PySSL_BEGIN_ALLOW_THREADS(self) { \
PyThreadState *_save = NULL; \
PySSL_BEGIN_ALLOW_THREADS_S(_save);
#define PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS_S(_save); }
PySSL_BEGIN_ALLOW_THREADS_S(_save, &self->tstate_mutex);
#define PySSL_END_ALLOW_THREADS(self) PySSL_END_ALLOW_THREADS_S(_save, &self->tstate_mutex); }

#if defined(HAVE_POLL_H)
#include <poll.h>
Expand Down Expand Up @@ -309,6 +309,9 @@ typedef struct {
PyObject *psk_client_callback;
PyObject *psk_server_callback;
#endif
/* Lock to synchronize calls when the thread state is detached.
See also gh-134698. */
PyMutex tstate_mutex;
} PySSLContext;

#define PySSLContext_CAST(op) ((PySSLContext *)(op))
Expand Down Expand Up @@ -336,6 +339,9 @@ typedef struct {
* and shutdown methods check for chained exceptions.
*/
PyObject *exc;
/* Lock to synchronize calls when the thread state is detached.
See also gh-134698. */
PyMutex tstate_mutex;
} PySSLSocket;

#define PySSLSocket_CAST(op) ((PySSLSocket *)(op))
Expand Down Expand Up @@ -885,13 +891,14 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
self->server_hostname = NULL;
self->err = err;
self->exc = NULL;
self->tstate_mutex = (PyMutex){0};

/* Make sure the SSL error state is initialized */
ERR_clear_error();

PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(sslctx)
self->ssl = SSL_new(ctx);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(sslctx)
if (self->ssl == NULL) {
Py_DECREF(self);
_setSSLError(get_state_ctx(self), NULL, 0, __FILE__, __LINE__);
Expand Down Expand Up @@ -960,12 +967,12 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
BIO_set_nbio(SSL_get_wbio(self->ssl), 1);
}

PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
if (socket_type == PY_SSL_CLIENT)
SSL_set_connect_state(self->ssl);
else
SSL_set_accept_state(self->ssl);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)

self->socket_type = socket_type;
if (sock != NULL) {
Expand Down Expand Up @@ -1034,10 +1041,10 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
/* Actually negotiate SSL connection */
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
do {
PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
ret = SSL_do_handshake(self->ssl);
err = _PySSL_errno(ret < 1, self->ssl, ret);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
self->err = err;

if (PyErr_CheckSignals())
Expand Down Expand Up @@ -2414,9 +2421,10 @@ PySSL_select(PySocketSockObject *s, int writing, PyTime_t timeout)
ms = (int)_PyTime_AsMilliseconds(timeout, _PyTime_ROUND_CEILING);
assert(ms <= INT_MAX);

PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
rc = poll(&pollfd, 1, (int)ms);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
#else
/* Guard against socket too large for select*/
if (!_PyIsSelectable_fd(s->sock_fd))
Expand All @@ -2428,13 +2436,14 @@ PySSL_select(PySocketSockObject *s, int writing, PyTime_t timeout)
FD_SET(s->sock_fd, &fds);

/* Wait until the socket becomes ready */
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
nfds = Py_SAFE_DOWNCAST(s->sock_fd+1, SOCKET_T, int);
if (writing)
rc = select(nfds, NULL, &fds, NULL, &tv);
else
rc = select(nfds, &fds, NULL, NULL, &tv);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
#endif

/* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise
Expand Down Expand Up @@ -2505,10 +2514,10 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
}

do {
PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
retval = SSL_write_ex(self->ssl, b->buf, (size_t)b->len, &count);
err = _PySSL_errno(retval == 0, self->ssl, retval);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
self->err = err;

if (PyErr_CheckSignals())
Expand Down Expand Up @@ -2566,10 +2575,10 @@ _ssl__SSLSocket_pending_impl(PySSLSocket *self)
int count = 0;
_PySSLError err;

PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
count = SSL_pending(self->ssl);
err = _PySSL_errno(count < 0, self->ssl, count);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
self->err = err;

if (count < 0)
Expand Down Expand Up @@ -2660,10 +2669,10 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
deadline = _PyDeadline_Init(timeout);

do {
PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
retval = SSL_read_ex(self->ssl, mem, (size_t)len, &count);
err = _PySSL_errno(retval == 0, self->ssl, retval);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
self->err = err;

if (PyErr_CheckSignals())
Expand Down Expand Up @@ -2762,7 +2771,7 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
}

while (1) {
PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
/* Disable read-ahead so that unwrap can work correctly.
* Otherwise OpenSSL might read in too much data,
* eating clear text data that happens to be
Expand All @@ -2775,7 +2784,7 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
SSL_set_read_ahead(self->ssl, 0);
ret = SSL_shutdown(self->ssl);
err = _PySSL_errno(ret < 0, self->ssl, ret);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
self->err = err;

/* If err == 1, a secure shutdown with SSL_shutdown() is complete */
Expand Down Expand Up @@ -3167,9 +3176,10 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
// no other thread can be touching this object yet.
// (Technically, we can't even lock if we wanted to, as the
// lock hasn't been initialized yet.)
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
ctx = SSL_CTX_new(method);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;

if (ctx == NULL) {
_setSSLError(get_ssl_state(module), NULL, 0, __FILE__, __LINE__);
Expand All @@ -3194,6 +3204,7 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
self->psk_client_callback = NULL;
self->psk_server_callback = NULL;
#endif
self->tstate_mutex = (PyMutex){0};

/* Don't check host name by default */
if (proto_version == PY_SSL_VERSION_TLS_CLIENT) {
Expand Down Expand Up @@ -3312,9 +3323,10 @@ context_clear(PyObject *op)
Py_CLEAR(self->psk_server_callback);
#endif
if (self->keylog_bio != NULL) {
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
BIO_free_all(self->keylog_bio);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
self->keylog_bio = NULL;
}
return 0;
Expand Down Expand Up @@ -4037,7 +4049,8 @@ _password_callback(char *buf, int size, int rwflag, void *userdata)
_PySSLPasswordInfo *pw_info = (_PySSLPasswordInfo*) userdata;
PyObject *fn_ret = NULL;

PySSL_END_ALLOW_THREADS_S(pw_info->thread_state);
pw_info->thread_state = PyThreadState_Swap(pw_info->thread_state);
_PySSL_FIX_ERRNO;

if (pw_info->error) {
/* already failed previously. OpenSSL 3.0.0-alpha14 invokes the
Expand Down Expand Up @@ -4067,13 +4080,13 @@ _password_callback(char *buf, int size, int rwflag, void *userdata)
goto error;
}

PySSL_BEGIN_ALLOW_THREADS_S(pw_info->thread_state);
pw_info->thread_state = PyThreadState_Swap(pw_info->thread_state);
memcpy(buf, pw_info->password, pw_info->size);
return pw_info->size;

error:
Py_XDECREF(fn_ret);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info->thread_state);
pw_info->thread_state = PyThreadState_Swap(pw_info->thread_state);
pw_info->error = 1;
return -1;
}
Expand Down Expand Up @@ -4126,10 +4139,10 @@ _ssl__SSLContext_load_cert_chain_impl(PySSLContext *self, PyObject *certfile,
SSL_CTX_set_default_passwd_cb(self->ctx, _password_callback);
SSL_CTX_set_default_passwd_cb_userdata(self->ctx, &pw_info);
}
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
r = SSL_CTX_use_certificate_chain_file(self->ctx,
PyBytes_AS_STRING(certfile_bytes));
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
if (r != 1) {
if (pw_info.error) {
ERR_clear_error();
Expand All @@ -4144,11 +4157,11 @@ _ssl__SSLContext_load_cert_chain_impl(PySSLContext *self, PyObject *certfile,
}
goto error;
}
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
r = SSL_CTX_use_PrivateKey_file(self->ctx,
PyBytes_AS_STRING(keyfile ? keyfile_bytes : certfile_bytes),
SSL_FILETYPE_PEM);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
Py_CLEAR(keyfile_bytes);
Py_CLEAR(certfile_bytes);
if (r != 1) {
Expand All @@ -4165,9 +4178,9 @@ _ssl__SSLContext_load_cert_chain_impl(PySSLContext *self, PyObject *certfile,
}
goto error;
}
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
r = SSL_CTX_check_private_key(self->ctx);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
if (r != 1) {
_setSSLError(get_state_ctx(self), NULL, 0, __FILE__, __LINE__);
goto error;
Expand Down Expand Up @@ -4384,9 +4397,9 @@ _ssl__SSLContext_load_verify_locations_impl(PySSLContext *self,
cafile_buf = PyBytes_AS_STRING(cafile_bytes);
if (capath)
capath_buf = PyBytes_AS_STRING(capath_bytes);
PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
r = SSL_CTX_load_verify_locations(self->ctx, cafile_buf, capath_buf);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
if (r != 1) {
if (errno != 0) {
PyErr_SetFromErrno(PyExc_OSError);
Expand Down Expand Up @@ -4438,10 +4451,11 @@ _ssl__SSLContext_load_dh_params_impl(PySSLContext *self, PyObject *filepath)
return NULL;

errno = 0;
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
dh = PEM_read_DHparams(f, NULL, NULL, NULL);
fclose(f);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
if (dh == NULL) {
if (errno != 0) {
PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, filepath);
Expand Down Expand Up @@ -4593,6 +4607,7 @@ _ssl__SSLContext_set_default_verify_paths_impl(PySSLContext *self)
Py_BEGIN_ALLOW_THREADS
rc = SSL_CTX_set_default_verify_paths(self->ctx);
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
if (!rc) {
_setSSLError(get_state_ctx(self), NULL, 0, __FILE__, __LINE__);
return NULL;
Expand Down
15 changes: 9 additions & 6 deletions Modules/_ssl/debughelpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,15 @@ _PySSL_keylog_callback(const SSL *ssl, const char *line)
* critical debug helper.
*/

PySSL_BEGIN_ALLOW_THREADS
assert(PyMutex_IsLocked(&ssl_obj->tstate_mutex));
Py_BEGIN_ALLOW_THREADS
PyThread_acquire_lock(lock, 1);
res = BIO_printf(ssl_obj->ctx->keylog_bio, "%s\n", line);
e = errno;
(void)BIO_flush(ssl_obj->ctx->keylog_bio);
PyThread_release_lock(lock);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;

if (res == -1) {
errno = e;
Expand Down Expand Up @@ -187,9 +189,10 @@ _PySSLContext_set_keylog_filename(PyObject *op, PyObject *arg,
if (self->keylog_bio != NULL) {
BIO *bio = self->keylog_bio;
self->keylog_bio = NULL;
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
BIO_free_all(bio);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're missing a call to _PySSL_FIX_ERRNO here

_PySSL_FIX_ERRNO;
}

if (arg == Py_None) {
Expand All @@ -211,13 +214,13 @@ _PySSLContext_set_keylog_filename(PyObject *op, PyObject *arg,
self->keylog_filename = Py_NewRef(arg);

/* Write a header for seekable, empty files (this excludes pipes). */
PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
if (BIO_tell(self->keylog_bio) == 0) {
BIO_puts(self->keylog_bio,
"# TLS secrets log file, generated by OpenSSL / Python\n");
(void)BIO_flush(self->keylog_bio);
}
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
SSL_CTX_set_keylog_callback(self->ctx, _PySSL_keylog_callback);
return 0;
}
Loading