@@ -11,10 +11,10 @@ use crate::pyobject::{
11
11
use crate :: types:: create_type;
12
12
use crate :: VirtualMachine ;
13
13
14
- use std:: cell:: { Ref , RefCell , RefMut } ;
15
14
use std:: convert:: TryFrom ;
16
15
use std:: ffi:: { CStr , CString } ;
17
16
use std:: fmt;
17
+ use std:: sync:: { RwLock , RwLockWriteGuard } ;
18
18
19
19
use foreign_types_shared:: { ForeignType , ForeignTypeRef } ;
20
20
use openssl:: {
@@ -230,7 +230,7 @@ fn ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, bool
230
230
231
231
#[ pyclass( name = "_SSLContext" ) ]
232
232
struct PySslContext {
233
- ctx : RefCell < SslContextBuilder > ,
233
+ ctx : RwLock < SslContextBuilder > ,
234
234
check_hostname : bool ,
235
235
}
236
236
@@ -248,16 +248,18 @@ impl PyValue for PySslContext {
248
248
249
249
#[ pyimpl( flags( BASETYPE ) ) ]
250
250
impl PySslContext {
251
- fn builder ( & self ) -> RefMut < SslContextBuilder > {
252
- self . ctx . borrow_mut ( )
251
+ fn builder ( & self ) -> RwLockWriteGuard < ' _ , SslContextBuilder > {
252
+ self . ctx . write ( ) . unwrap ( )
253
253
}
254
- fn ctx ( & self ) -> Ref < ssl:: SslContextRef > {
255
- Ref :: map ( self . ctx . borrow ( ) , |ctx| unsafe {
256
- & * * ( ctx as * const SslContextBuilder as * const ssl:: SslContext )
257
- } )
254
+ fn exec_ctx < F , R > ( & self , func : F ) -> R
255
+ where
256
+ F : Fn ( & ssl:: SslContextRef ) -> R ,
257
+ {
258
+ let c = self . ctx . read ( ) . unwrap ( ) ;
259
+ func ( unsafe { & * * ( & * c as * const SslContextBuilder as * const ssl:: SslContext ) } )
258
260
}
259
261
fn ptr ( & self ) -> * mut sys:: SSL_CTX {
260
- self . ctx . borrow ( ) . as_ptr ( )
262
+ ( * self . ctx . write ( ) . unwrap ( ) ) . as_ptr ( )
261
263
}
262
264
263
265
#[ pyslot]
@@ -306,7 +308,7 @@ impl PySslContext {
306
308
. map_err ( |e| convert_openssl_error ( vm, e) ) ?;
307
309
308
310
PySslContext {
309
- ctx : RefCell :: new ( builder) ,
311
+ ctx : RwLock :: new ( builder) ,
310
312
check_hostname,
311
313
}
312
314
. into_ref_with_type ( vm, cls)
@@ -325,7 +327,7 @@ impl PySslContext {
325
327
326
328
#[ pyproperty]
327
329
fn verify_mode ( & self ) -> i32 {
328
- let mode = self . ctx ( ) . verify_mode ( ) ;
330
+ let mode = self . exec_ctx ( |ctx| ctx . verify_mode ( ) ) ;
329
331
if mode == SslVerifyMode :: NONE {
330
332
CertRequirements :: None . into ( )
331
333
} else if mode == SslVerifyMode :: PEER {
@@ -385,9 +387,10 @@ impl PySslContext {
385
387
Either :: B ( b) => b. with_ref ( X509 :: from_der) ,
386
388
} ;
387
389
let cert = cert. map_err ( |e| convert_openssl_error ( vm, e) ) ?;
388
- let ctx = self . ctx ( ) ;
389
- let store = ctx. cert_store ( ) ;
390
- let ret = unsafe { sys:: X509_STORE_add_cert ( store. as_ptr ( ) , cert. as_ptr ( ) ) } ;
390
+ let ret = self . exec_ctx ( |ctx| {
391
+ let store = ctx. cert_store ( ) ;
392
+ unsafe { sys:: X509_STORE_add_cert ( store. as_ptr ( ) , cert. as_ptr ( ) ) }
393
+ } ) ;
391
394
if ret <= 0 {
392
395
return Err ( convert_openssl_error ( vm, ErrorStack :: get ( ) ) ) ;
393
396
}
@@ -424,7 +427,8 @@ impl PySslContext {
424
427
use openssl:: stack:: StackRef ;
425
428
let binary_form = binary_form. unwrap_or ( false ) ;
426
429
let certs = unsafe {
427
- let stack = sys:: X509_STORE_get0_objects ( self . ctx ( ) . cert_store ( ) . as_ptr ( ) ) ;
430
+ let stack =
431
+ sys:: X509_STORE_get0_objects ( self . exec_ctx ( |ctx| ctx. cert_store ( ) . as_ptr ( ) ) ) ;
428
432
assert ! ( !stack. is_null( ) ) ;
429
433
StackRef :: < X509Object > :: from_ptr ( stack)
430
434
} ;
@@ -467,10 +471,10 @@ impl PySslContext {
467
471
468
472
Ok ( PySslSocket {
469
473
ctx : zelf,
470
- stream : RefCell :: new ( Some ( stream) ) ,
474
+ stream : RwLock :: new ( Some ( stream) ) ,
471
475
socket_type,
472
476
server_hostname : args. server_hostname ,
473
- owner : RefCell :: new ( args. owner . as_ref ( ) . map ( PyWeak :: downgrade) ) ,
477
+ owner : RwLock :: new ( args. owner . as_ref ( ) . map ( PyWeak :: downgrade) ) ,
474
478
} )
475
479
}
476
480
}
@@ -503,10 +507,10 @@ struct LoadVerifyLocationsArgs {
503
507
#[ pyclass( name = "_SSLSocket" ) ]
504
508
struct PySslSocket {
505
509
ctx : PyRef < PySslContext > ,
506
- stream : RefCell < Option < ssl:: SslStreamBuilder < PySocketRef > > > ,
510
+ stream : RwLock < Option < ssl:: SslStreamBuilder < PySocketRef > > > ,
507
511
socket_type : SslServerOrClient ,
508
512
server_hostname : Option < PyStringRef > ,
509
- owner : RefCell < Option < PyWeak > > ,
513
+ owner : RwLock < Option < PyWeak > > ,
510
514
}
511
515
512
516
impl fmt:: Debug for PySslSocket {
@@ -524,28 +528,32 @@ impl PyValue for PySslSocket {
524
528
#[ pyimpl]
525
529
impl PySslSocket {
526
530
fn stream_builder ( & self ) -> ssl:: SslStreamBuilder < PySocketRef > {
527
- self . stream . replace ( None ) . unwrap ( )
528
- }
529
- fn stream ( & self ) -> RefMut < ssl:: SslStream < PySocketRef > > {
530
- RefMut :: map ( self . stream . borrow_mut ( ) , |b| {
531
- let b = b. as_mut ( ) . unwrap ( ) ;
532
- unsafe { & mut * ( b as * mut ssl:: SslStreamBuilder < _ > as * mut ssl:: SslStream < _ > ) }
531
+ std:: mem:: replace ( & mut * self . stream . write ( ) . unwrap ( ) , None ) . unwrap ( )
532
+ }
533
+ fn exec_stream < F , R > ( & self , func : F ) -> R
534
+ where
535
+ F : Fn ( & mut ssl:: SslStream < PySocketRef > ) -> R ,
536
+ {
537
+ let mut b = self . stream . write ( ) . unwrap ( ) ;
538
+ func ( unsafe {
539
+ & mut * ( b. as_mut ( ) . unwrap ( ) as * mut ssl:: SslStreamBuilder < _ > as * mut ssl:: SslStream < _ > )
533
540
} )
534
541
}
535
542
fn set_stream ( & self , stream : ssl:: SslStream < PySocketRef > ) {
536
- let prev = self
537
- . stream
538
- . replace ( Some ( unsafe { std:: mem:: transmute ( stream) } ) ) ;
539
- debug_assert ! ( prev. is_none( ) ) ;
543
+ * self . stream . write ( ) . unwrap ( ) = Some ( unsafe { std:: mem:: transmute ( stream) } ) ;
540
544
}
541
545
542
546
#[ pyproperty]
543
547
fn owner ( & self ) -> Option < PyObjectRef > {
544
- self . owner . borrow ( ) . as_ref ( ) . and_then ( PyWeak :: upgrade)
548
+ self . owner
549
+ . read ( )
550
+ . unwrap ( )
551
+ . as_ref ( )
552
+ . and_then ( PyWeak :: upgrade)
545
553
}
546
554
#[ pyproperty( setter) ]
547
555
fn set_owner ( & self , owner : PyObjectRef ) {
548
- * self . owner . borrow_mut ( ) = Some ( PyWeak :: downgrade ( & owner) )
556
+ * self . owner . write ( ) . unwrap ( ) = Some ( PyWeak :: downgrade ( & owner) )
549
557
}
550
558
#[ pyproperty]
551
559
fn server_side ( & self ) -> bool {
@@ -567,12 +575,10 @@ impl PySslSocket {
567
575
vm : & VirtualMachine ,
568
576
) -> PyResult < Option < PyObjectRef > > {
569
577
let binary = binary. unwrap_or ( false ) ;
570
- if !self . stream ( ) . ssl ( ) . is_init_finished ( ) {
578
+ if !self . exec_stream ( |stream| stream . ssl ( ) . is_init_finished ( ) ) {
571
579
return Err ( vm. new_value_error ( "handshake not done yet" . to_owned ( ) ) ) ;
572
580
}
573
- self . stream ( )
574
- . ssl ( )
575
- . peer_certificate ( )
581
+ self . exec_stream ( |stream| stream. ssl ( ) . peer_certificate ( ) )
576
582
. map ( |cert| cert_to_py ( vm, & cert, binary) )
577
583
. transpose ( )
578
584
}
@@ -605,17 +611,18 @@ impl PySslSocket {
605
611
606
612
#[ pymethod]
607
613
fn write ( & self , data : PyBytesLike , vm : & VirtualMachine ) -> PyResult < usize > {
608
- data. with_ref ( |b| self . stream ( ) . ssl_write ( b) )
614
+ data. with_ref ( |b| self . exec_stream ( |stream| stream . ssl_write ( b) ) )
609
615
. map_err ( |e| convert_ssl_error ( vm, e) )
610
616
}
611
617
612
618
#[ pymethod]
613
619
fn read ( & self , n : usize , buffer : OptionalArg < PyByteArrayRef > , vm : & VirtualMachine ) -> PyResult {
614
620
if let OptionalArg :: Present ( buffer) = buffer {
615
- let mut buf = buffer. borrow_value_mut ( ) ;
616
621
let n = self
617
- . stream ( )
618
- . ssl_read ( & mut buf. elements )
622
+ . exec_stream ( |stream| {
623
+ let mut buf = buffer. borrow_value_mut ( ) ;
624
+ stream. ssl_read ( & mut buf. elements )
625
+ } )
619
626
. map_err ( |e| convert_ssl_error ( vm, e) ) ?;
620
627
Ok ( vm. new_int ( n) )
621
628
} else {
0 commit comments