Skip to content

Commit 621c309

Browse files
committed
Make PySslContext, PySslSocket ThreadSafe
1 parent fb5862d commit 621c309

File tree

1 file changed

+46
-39
lines changed

1 file changed

+46
-39
lines changed

vm/src/stdlib/ssl.rs

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ use crate::pyobject::{
1111
use crate::types::create_type;
1212
use crate::VirtualMachine;
1313

14-
use std::cell::{Ref, RefCell, RefMut};
1514
use std::convert::TryFrom;
1615
use std::ffi::{CStr, CString};
1716
use std::fmt;
17+
use std::sync::{RwLock, RwLockWriteGuard};
1818

1919
use foreign_types_shared::{ForeignType, ForeignTypeRef};
2020
use openssl::{
@@ -230,7 +230,7 @@ fn ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, bool
230230

231231
#[pyclass(name = "_SSLContext")]
232232
struct PySslContext {
233-
ctx: RefCell<SslContextBuilder>,
233+
ctx: RwLock<SslContextBuilder>,
234234
check_hostname: bool,
235235
}
236236

@@ -248,16 +248,18 @@ impl PyValue for PySslContext {
248248

249249
#[pyimpl(flags(BASETYPE))]
250250
impl PySslContext {
251-
fn builder(&self) -> RefMut<SslContextBuilder> {
252-
self.ctx.borrow_mut()
251+
fn builder(&self) -> RwLockWriteGuard<'_, SslContextBuilder> {
252+
self.ctx.write().unwrap()
253253
}
254-
fn ctx(&self) -> Ref<ssl::SslContextRef> {
255-
Ref::map(self.ctx.borrow(), |ctx| unsafe {
256-
&**(ctx as *const SslContextBuilder as *const ssl::SslContext)
257-
})
254+
fn exec_ctx<F, R>(&self, func: F) -> R
255+
where
256+
F: Fn(&ssl::SslContextRef) -> R,
257+
{
258+
let c = self.ctx.read().unwrap();
259+
func(unsafe { &**(&*c as *const SslContextBuilder as *const ssl::SslContext) })
258260
}
259261
fn ptr(&self) -> *mut sys::SSL_CTX {
260-
self.ctx.borrow().as_ptr()
262+
(*self.ctx.write().unwrap()).as_ptr()
261263
}
262264

263265
#[pyslot]
@@ -306,7 +308,7 @@ impl PySslContext {
306308
.map_err(|e| convert_openssl_error(vm, e))?;
307309

308310
PySslContext {
309-
ctx: RefCell::new(builder),
311+
ctx: RwLock::new(builder),
310312
check_hostname,
311313
}
312314
.into_ref_with_type(vm, cls)
@@ -325,7 +327,7 @@ impl PySslContext {
325327

326328
#[pyproperty]
327329
fn verify_mode(&self) -> i32 {
328-
let mode = self.ctx().verify_mode();
330+
let mode = self.exec_ctx(|ctx| ctx.verify_mode());
329331
if mode == SslVerifyMode::NONE {
330332
CertRequirements::None.into()
331333
} else if mode == SslVerifyMode::PEER {
@@ -385,9 +387,10 @@ impl PySslContext {
385387
Either::B(b) => b.with_ref(X509::from_der),
386388
};
387389
let cert = cert.map_err(|e| convert_openssl_error(vm, e))?;
388-
let ctx = self.ctx();
389-
let store = ctx.cert_store();
390-
let ret = unsafe { sys::X509_STORE_add_cert(store.as_ptr(), cert.as_ptr()) };
390+
let ret = self.exec_ctx(|ctx| {
391+
let store = ctx.cert_store();
392+
unsafe { sys::X509_STORE_add_cert(store.as_ptr(), cert.as_ptr()) }
393+
});
391394
if ret <= 0 {
392395
return Err(convert_openssl_error(vm, ErrorStack::get()));
393396
}
@@ -424,7 +427,8 @@ impl PySslContext {
424427
use openssl::stack::StackRef;
425428
let binary_form = binary_form.unwrap_or(false);
426429
let certs = unsafe {
427-
let stack = sys::X509_STORE_get0_objects(self.ctx().cert_store().as_ptr());
430+
let stack =
431+
sys::X509_STORE_get0_objects(self.exec_ctx(|ctx| ctx.cert_store().as_ptr()));
428432
assert!(!stack.is_null());
429433
StackRef::<X509Object>::from_ptr(stack)
430434
};
@@ -467,10 +471,10 @@ impl PySslContext {
467471

468472
Ok(PySslSocket {
469473
ctx: zelf,
470-
stream: RefCell::new(Some(stream)),
474+
stream: RwLock::new(Some(stream)),
471475
socket_type,
472476
server_hostname: args.server_hostname,
473-
owner: RefCell::new(args.owner.as_ref().map(PyWeak::downgrade)),
477+
owner: RwLock::new(args.owner.as_ref().map(PyWeak::downgrade)),
474478
})
475479
}
476480
}
@@ -503,10 +507,10 @@ struct LoadVerifyLocationsArgs {
503507
#[pyclass(name = "_SSLSocket")]
504508
struct PySslSocket {
505509
ctx: PyRef<PySslContext>,
506-
stream: RefCell<Option<ssl::SslStreamBuilder<PySocketRef>>>,
510+
stream: RwLock<Option<ssl::SslStreamBuilder<PySocketRef>>>,
507511
socket_type: SslServerOrClient,
508512
server_hostname: Option<PyStringRef>,
509-
owner: RefCell<Option<PyWeak>>,
513+
owner: RwLock<Option<PyWeak>>,
510514
}
511515

512516
impl fmt::Debug for PySslSocket {
@@ -524,28 +528,32 @@ impl PyValue for PySslSocket {
524528
#[pyimpl]
525529
impl PySslSocket {
526530
fn stream_builder(&self) -> ssl::SslStreamBuilder<PySocketRef> {
527-
self.stream.replace(None).unwrap()
528-
}
529-
fn stream(&self) -> RefMut<ssl::SslStream<PySocketRef>> {
530-
RefMut::map(self.stream.borrow_mut(), |b| {
531-
let b = b.as_mut().unwrap();
532-
unsafe { &mut *(b as *mut ssl::SslStreamBuilder<_> as *mut ssl::SslStream<_>) }
531+
std::mem::replace(&mut *self.stream.write().unwrap(), None).unwrap()
532+
}
533+
fn exec_stream<F, R>(&self, func: F) -> R
534+
where
535+
F: Fn(&mut ssl::SslStream<PySocketRef>) -> R,
536+
{
537+
let mut b = self.stream.write().unwrap();
538+
func(unsafe {
539+
&mut *(b.as_mut().unwrap() as *mut ssl::SslStreamBuilder<_> as *mut ssl::SslStream<_>)
533540
})
534541
}
535542
fn set_stream(&self, stream: ssl::SslStream<PySocketRef>) {
536-
let prev = self
537-
.stream
538-
.replace(Some(unsafe { std::mem::transmute(stream) }));
539-
debug_assert!(prev.is_none());
543+
*self.stream.write().unwrap() = Some(unsafe { std::mem::transmute(stream) });
540544
}
541545

542546
#[pyproperty]
543547
fn owner(&self) -> Option<PyObjectRef> {
544-
self.owner.borrow().as_ref().and_then(PyWeak::upgrade)
548+
self.owner
549+
.read()
550+
.unwrap()
551+
.as_ref()
552+
.and_then(PyWeak::upgrade)
545553
}
546554
#[pyproperty(setter)]
547555
fn set_owner(&self, owner: PyObjectRef) {
548-
*self.owner.borrow_mut() = Some(PyWeak::downgrade(&owner))
556+
*self.owner.write().unwrap() = Some(PyWeak::downgrade(&owner))
549557
}
550558
#[pyproperty]
551559
fn server_side(&self) -> bool {
@@ -567,12 +575,10 @@ impl PySslSocket {
567575
vm: &VirtualMachine,
568576
) -> PyResult<Option<PyObjectRef>> {
569577
let binary = binary.unwrap_or(false);
570-
if !self.stream().ssl().is_init_finished() {
578+
if !self.exec_stream(|stream| stream.ssl().is_init_finished()) {
571579
return Err(vm.new_value_error("handshake not done yet".to_owned()));
572580
}
573-
self.stream()
574-
.ssl()
575-
.peer_certificate()
581+
self.exec_stream(|stream| stream.ssl().peer_certificate())
576582
.map(|cert| cert_to_py(vm, &cert, binary))
577583
.transpose()
578584
}
@@ -605,17 +611,18 @@ impl PySslSocket {
605611

606612
#[pymethod]
607613
fn write(&self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
608-
data.with_ref(|b| self.stream().ssl_write(b))
614+
data.with_ref(|b| self.exec_stream(|stream| stream.ssl_write(b)))
609615
.map_err(|e| convert_ssl_error(vm, e))
610616
}
611617

612618
#[pymethod]
613619
fn read(&self, n: usize, buffer: OptionalArg<PyByteArrayRef>, vm: &VirtualMachine) -> PyResult {
614620
if let OptionalArg::Present(buffer) = buffer {
615-
let mut buf = buffer.borrow_value_mut();
616621
let n = self
617-
.stream()
618-
.ssl_read(&mut buf.elements)
622+
.exec_stream(|stream| {
623+
let mut buf = buffer.borrow_value_mut();
624+
stream.ssl_read(&mut buf.elements)
625+
})
619626
.map_err(|e| convert_ssl_error(vm, e))?;
620627
Ok(vm.new_int(n))
621628
} else {

0 commit comments

Comments
 (0)