1
+ use super :: os:: PyPathLike ;
1
2
use super :: socket:: PySocketRef ;
2
3
use crate :: builtins:: bytearray:: PyByteArrayRef ;
3
4
use crate :: builtins:: pystr:: PyStrRef ;
4
5
use crate :: builtins:: { pytype:: PyTypeRef , weakref:: PyWeak } ;
5
6
use crate :: byteslike:: PyBytesLike ;
6
- use crate :: common:: lock:: { PyRwLock , PyRwLockWriteGuard } ;
7
+ use crate :: common:: lock:: { PyRwLock , PyRwLockReadGuard , PyRwLockWriteGuard } ;
7
8
use crate :: exceptions:: { IntoPyException , PyBaseExceptionRef } ;
8
- use crate :: function:: OptionalArg ;
9
+ use crate :: function:: { OptionalArg , OptionalOption } ;
9
10
use crate :: pyobject:: {
10
- BorrowValue , Either , IntoPyObject , ItemProtocol , PyClassImpl , PyObjectRef , PyRef , PyResult ,
11
- PyValue , StaticType ,
11
+ BorrowValue , Either , IntoPyObject , ItemProtocol , PyCallable , PyClassImpl , PyObjectRef , PyRef ,
12
+ PyResult , PyValue , StaticType ,
12
13
} ;
13
14
use crate :: types:: create_simple_type;
14
15
use crate :: VirtualMachine ;
15
16
17
+ use crossbeam_utils:: atomic:: AtomicCell ;
16
18
use foreign_types_shared:: { ForeignType , ForeignTypeRef } ;
17
19
use openssl:: {
18
20
asn1:: { Asn1Object , Asn1ObjectRef } ,
@@ -24,6 +26,7 @@ use openssl::{
24
26
use std:: convert:: TryFrom ;
25
27
use std:: ffi:: { CStr , CString } ;
26
28
use std:: fmt;
29
+ use std:: time:: Instant ;
27
30
28
31
mod sys {
29
32
#![ allow( non_camel_case_types, unused) ]
@@ -231,7 +234,7 @@ fn _ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, boo
231
234
#[ pyclass( module = "ssl" , name = "_SSLContext" ) ]
232
235
struct PySslContext {
233
236
ctx : PyRwLock < SslContextBuilder > ,
234
- check_hostname : bool ,
237
+ check_hostname : AtomicCell < bool > ,
235
238
}
236
239
237
240
impl fmt:: Debug for PySslContext {
@@ -246,6 +249,10 @@ impl PyValue for PySslContext {
246
249
}
247
250
}
248
251
252
+ fn builder_as_ctx ( x : & SslContextBuilder ) -> & ssl:: SslContextRef {
253
+ unsafe { ssl:: SslContextRef :: from_ptr ( x. as_ptr ( ) ) }
254
+ }
255
+
249
256
#[ pyimpl( flags( BASETYPE ) ) ]
250
257
impl PySslContext {
251
258
fn builder ( & self ) -> PyRwLockWriteGuard < ' _ , SslContextBuilder > {
@@ -256,7 +263,7 @@ impl PySslContext {
256
263
F : Fn ( & ssl:: SslContextRef ) -> R ,
257
264
{
258
265
let c = self . ctx . read ( ) ;
259
- func ( unsafe { & * * ( & * c as * const SslContextBuilder as * const ssl :: SslContext ) } )
266
+ func ( builder_as_ctx ( & c ) )
260
267
}
261
268
fn ptr ( & self ) -> * mut sys:: SSL_CTX {
262
269
( * self . ctx . write ( ) ) . as_ptr ( )
@@ -309,7 +316,7 @@ impl PySslContext {
309
316
310
317
PySslContext {
311
318
ctx : PyRwLock :: new ( builder) ,
312
- check_hostname,
319
+ check_hostname : AtomicCell :: new ( check_hostname ) ,
313
320
}
314
321
. into_ref_with_type ( vm, cls)
315
322
}
@@ -340,10 +347,11 @@ impl PySslContext {
340
347
}
341
348
#[ pyproperty( setter) ]
342
349
fn set_verify_mode ( & self , cert : i32 , vm : & VirtualMachine ) -> PyResult < ( ) > {
350
+ let mut ctx = self . builder ( ) ;
343
351
let cert_req = CertRequirements :: try_from ( cert)
344
352
. map_err ( |_| vm. new_value_error ( "invalid value for verify_mode" . to_owned ( ) ) ) ?;
345
353
let mode = match cert_req {
346
- CertRequirements :: None if self . check_hostname => {
354
+ CertRequirements :: None if self . check_hostname . load ( ) => {
347
355
return Err ( vm. new_value_error (
348
356
"Cannot set verify_mode to CERT_NONE when check_hostname is enabled."
349
357
. to_owned ( ) ,
@@ -353,9 +361,21 @@ impl PySslContext {
353
361
CertRequirements :: Optional => SslVerifyMode :: PEER ,
354
362
CertRequirements :: Required => SslVerifyMode :: PEER | SslVerifyMode :: FAIL_IF_NO_PEER_CERT ,
355
363
} ;
356
- self . builder ( ) . set_verify ( mode) ;
364
+ ctx . set_verify ( mode) ;
357
365
Ok ( ( ) )
358
366
}
367
+ #[ pyproperty]
368
+ fn check_hostname ( & self ) -> bool {
369
+ self . check_hostname . load ( )
370
+ }
371
+ #[ pyproperty( setter) ]
372
+ fn set_check_hostname ( & self , ch : bool ) {
373
+ let mut ctx = self . builder ( ) ;
374
+ if ch && builder_as_ctx ( & ctx) . verify_mode ( ) == SslVerifyMode :: NONE {
375
+ ctx. set_verify ( SslVerifyMode :: PEER | SslVerifyMode :: FAIL_IF_NO_PEER_CERT ) ;
376
+ }
377
+ self . check_hostname . store ( ch) ;
378
+ }
359
379
360
380
#[ pymethod]
361
381
fn set_default_verify_paths ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
@@ -442,6 +462,30 @@ impl PySslContext {
442
462
Ok ( vm. ctx . new_list ( certs) )
443
463
}
444
464
465
+ #[ pymethod]
466
+ fn load_cert_chain (
467
+ & self ,
468
+ certfile : PyPathLike ,
469
+ keyfile : OptionalArg < PyPathLike > ,
470
+ password : OptionalOption < Either < PyStrRef , PyCallable > > ,
471
+ vm : & VirtualMachine ,
472
+ ) -> PyResult < ( ) > {
473
+ // TODO: requires passing a callback to C
474
+ if password. flatten ( ) . is_some ( ) {
475
+ return Err ( vm. new_not_implemented_error ( "password arg not yet supported" . to_owned ( ) ) ) ;
476
+ }
477
+ let mut ctx = self . builder ( ) ;
478
+ ctx. set_certificate_chain_file ( & certfile)
479
+ . and_then ( |( ) | {
480
+ ctx. set_private_key_file (
481
+ keyfile. as_ref ( ) . unwrap_or ( & certfile) ,
482
+ ssl:: SslFiletype :: PEM ,
483
+ )
484
+ } )
485
+ . and_then ( |( ) | ctx. check_private_key ( ) )
486
+ . map_err ( |e| convert_openssl_error ( vm, e) )
487
+ }
488
+
445
489
#[ pymethod]
446
490
fn _wrap_socket (
447
491
zelf : PyRef < Self > ,
@@ -471,7 +515,7 @@ impl PySslContext {
471
515
472
516
Ok ( PySslSocket {
473
517
ctx : zelf,
474
- stream : PyRwLock :: new ( Some ( stream) ) ,
518
+ stream : PyRwLock :: new ( stream) ,
475
519
socket_type,
476
520
server_hostname : args. server_hostname ,
477
521
owner : PyRwLock :: new ( args. owner . as_ref ( ) . map ( PyWeak :: downgrade) ) ,
@@ -507,7 +551,7 @@ struct LoadVerifyLocationsArgs {
507
551
#[ pyclass( module = "ssl" , name = "_SSLSocket" ) ]
508
552
struct PySslSocket {
509
553
ctx : PyRef < PySslContext > ,
510
- stream : PyRwLock < Option < ssl:: SslStreamBuilder < PySocketRef > > > ,
554
+ stream : PyRwLock < ssl:: SslStreamBuilder < PySocketRef > > ,
511
555
socket_type : SslServerOrClient ,
512
556
server_hostname : Option < PyStrRef > ,
513
557
owner : PyRwLock < Option < PyWeak > > ,
@@ -527,20 +571,19 @@ impl PyValue for PySslSocket {
527
571
528
572
#[ pyimpl]
529
573
impl PySslSocket {
530
- fn stream_builder ( & self ) -> ssl:: SslStreamBuilder < PySocketRef > {
531
- std:: mem:: replace ( & mut * self . stream . write ( ) , None ) . unwrap ( )
532
- }
533
- fn exec_stream < F , R > ( & self , func : F ) -> R
534
- where
535
- F : Fn ( & mut ssl:: SslStream < PySocketRef > ) -> R ,
536
- {
537
- let mut b = self . stream . write ( ) ;
538
- func ( unsafe {
539
- & mut * ( b. as_mut ( ) . unwrap ( ) as * mut ssl:: SslStreamBuilder < _ > as * mut ssl:: SslStream < _ > )
574
+ fn stream ( & self ) -> impl std:: ops:: Deref < Target = ssl:: SslStream < PySocketRef > > + ' _ {
575
+ let s = self . stream . read ( ) ;
576
+ // SAFETY: SslStreamBuilder is just a wrapper around SslStream
577
+ PyRwLockReadGuard :: map ( s, |s| unsafe {
578
+ & * ( s as * const _ as * const ssl:: SslStream < _ > )
540
579
} )
541
580
}
542
- fn set_stream ( & self , stream : ssl:: SslStream < PySocketRef > ) {
543
- * self . stream . write ( ) = Some ( unsafe { std:: mem:: transmute ( stream) } ) ;
581
+ fn stream_mut ( & self ) -> impl std:: ops:: DerefMut < Target = ssl:: SslStream < PySocketRef > > + ' _ {
582
+ let s = self . stream . write ( ) ;
583
+ // SAFETY: SslStreamBuilder is just a wrapper around SslStream
584
+ PyRwLockWriteGuard :: map ( s, |s| unsafe {
585
+ & mut * ( s as * mut _ as * mut ssl:: SslStream < _ > )
586
+ } )
544
587
}
545
588
546
589
#[ pyproperty]
@@ -571,61 +614,101 @@ impl PySslSocket {
571
614
vm : & VirtualMachine ,
572
615
) -> PyResult < Option < PyObjectRef > > {
573
616
let binary = binary. unwrap_or ( false ) ;
574
- if !self . exec_stream ( |stream| stream. ssl ( ) . is_init_finished ( ) ) {
617
+ let stream = self . stream ( ) ;
618
+ if !stream. ssl ( ) . is_init_finished ( ) {
575
619
return Err ( vm. new_value_error ( "handshake not done yet" . to_owned ( ) ) ) ;
576
620
}
577
- self . exec_stream ( |stream| stream. ssl ( ) . peer_certificate ( ) )
621
+ stream
622
+ . ssl ( )
623
+ . peer_certificate ( )
578
624
. map ( |cert| cert_to_py ( vm, & cert, binary) )
579
625
. transpose ( )
580
626
}
581
627
582
628
#[ pymethod]
583
629
fn do_handshake ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
584
- // Either a stream builder or a mid-handshake stream from WANT_READ or WANT_WRITE
585
- let mut handshaker: Either < _ , ssl:: MidHandshakeSslStream < _ > > =
586
- Either :: A ( self . stream_builder ( ) ) ;
630
+ let stream_builder = self . stream . write ( ) ;
631
+ let timeout = stream_builder
632
+ . get_ref ( )
633
+ . sock ( )
634
+ . read_timeout ( )
635
+ . ok ( )
636
+ . flatten ( )
637
+ . map ( |dur| ( dur, Instant :: now ( ) ) ) ;
638
+ let mut stream = unsafe { std:: ptr:: read ( & * stream_builder) } ;
587
639
loop {
588
- let handshake_result = match handshaker {
589
- Either :: A ( s) => s. handshake ( ) ,
590
- Either :: B ( s) => s. handshake ( ) ,
591
- } ;
592
- match handshake_result {
593
- Ok ( stream) => {
594
- self . set_stream ( stream) ;
640
+ match stream. handshake ( ) {
641
+ Ok ( s) => {
642
+ // s and stream_builder are the same thing
643
+ std:: mem:: forget ( s) ;
595
644
return Ok ( ( ) ) ;
596
645
}
597
- Err ( ssl:: HandshakeError :: SetupFailure ( e) ) => {
598
- return Err ( convert_openssl_error ( vm, e) )
646
+ Err ( ssl:: HandshakeError :: SetupFailure ( _e) ) => {
647
+ // handshake() error handling code never constructs this
648
+ unreachable ! ( ) ;
649
+ // return Err(convert_openssl_error(vm, e))
650
+ }
651
+ Err ( ssl:: HandshakeError :: WouldBlock ( s) ) => {
652
+ std:: mem:: forget ( s) ;
653
+ stream = unsafe { std:: ptr:: read ( & * stream_builder) } ;
599
654
}
600
- Err ( ssl:: HandshakeError :: WouldBlock ( s) ) => handshaker = Either :: B ( s) ,
601
655
Err ( ssl:: HandshakeError :: Failure ( s) ) => {
602
- return Err ( convert_ssl_error ( vm, s. into_error ( ) ) )
656
+ let err = convert_ssl_error ( vm, s. error ( ) ) ;
657
+ std:: mem:: forget ( s) ;
658
+ return Err ( err) ;
659
+ }
660
+ }
661
+ if let Some ( ( timeout, ref start) ) = timeout {
662
+ if start. elapsed ( ) >= timeout {
663
+ std:: mem:: forget ( stream) ;
664
+ let socket_timeout = vm. class ( "_socket" , "timeout" ) ;
665
+ return Err ( vm. new_exception_msg (
666
+ socket_timeout,
667
+ "The handshake operation timed out" . to_owned ( ) ,
668
+ ) ) ;
603
669
}
604
670
}
605
671
}
606
672
}
607
673
608
674
#[ pymethod]
609
675
fn write ( & self , data : PyBytesLike , vm : & VirtualMachine ) -> PyResult < usize > {
610
- data. with_ref ( |b| self . exec_stream ( |stream| stream. ssl_write ( b) ) )
676
+ let mut stream = self . stream_mut ( ) ;
677
+ data. with_ref ( |b| stream. ssl_write ( b) )
611
678
. map_err ( |e| convert_ssl_error ( vm, e) )
612
679
}
613
680
614
681
#[ pymethod]
615
682
fn read ( & self , n : usize , buffer : OptionalArg < PyByteArrayRef > , vm : & VirtualMachine ) -> PyResult {
616
- if let OptionalArg :: Present ( buffer) = buffer {
617
- let n = self
618
- . exec_stream ( |stream| {
619
- let mut buf = buffer. borrow_value_mut ( ) ;
620
- stream. ssl_read ( & mut buf. elements )
621
- } )
622
- . map_err ( |e| convert_ssl_error ( vm, e) ) ?;
623
- Ok ( vm. ctx . new_int ( n) )
683
+ let mut stream = self . stream_mut ( ) ;
684
+ let ret_nread = buffer. is_present ( ) ;
685
+ let ssl_res = if let OptionalArg :: Present ( buffer) = buffer {
686
+ let mut buf = buffer. borrow_value_mut ( ) ;
687
+ stream
688
+ . ssl_read ( & mut buf. elements )
689
+ . map ( |n| vm. ctx . new_int ( n) )
624
690
} else {
625
691
let mut buf = vec ! [ 0u8 ; n] ;
626
- buf. truncate ( n) ;
627
- Ok ( vm. ctx . new_bytes ( buf) )
628
- }
692
+ stream. ssl_read ( & mut buf) . map ( |n| {
693
+ buf. truncate ( n) ;
694
+ vm. ctx . new_bytes ( buf)
695
+ } )
696
+ } ;
697
+ ssl_res. or_else ( |e| {
698
+ if e. code ( ) == ssl:: ErrorCode :: ZERO_RETURN
699
+ && stream. get_shutdown ( ) == ssl:: ShutdownState :: RECEIVED
700
+ {
701
+ Ok ( if ret_nread {
702
+ vm. ctx . new_int ( 0 )
703
+ } else {
704
+ vm. ctx . new_bytes ( vec ! [ ] )
705
+ } )
706
+ } else {
707
+ Err ( convert_ssl_error ( vm, e) )
708
+ }
709
+ } )
710
+
711
+ // .map_err(|e| convert_ssl_error(vm, e))?;
629
712
}
630
713
}
631
714
@@ -645,15 +728,28 @@ fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptio
645
728
// );
646
729
// TODO: map the error codes to code names, e.g. "CERTIFICATE_VERIFY_FAILED", just requires a big hashmap/dict
647
730
let msg = e. to_string ( ) ;
648
- vm. new_exception_msg ( cls, msg)
731
+ vm. new_exception ( cls, vec ! [ vm . ctx . new_int ( e . code ( ) ) , vm . ctx . new_str ( msg) ] )
649
732
}
650
733
None => vm. new_exception_empty ( cls) ,
651
734
}
652
735
}
653
- fn convert_ssl_error ( vm : & VirtualMachine , e : ssl:: Error ) -> PyBaseExceptionRef {
654
- match e. into_io_error ( ) {
655
- Ok ( io_err) => io_err. into_pyexception ( vm) ,
656
- Err ( e) => convert_openssl_error ( vm, e. ssl_error ( ) . unwrap ( ) . clone ( ) ) ,
736
+ fn convert_ssl_error (
737
+ vm : & VirtualMachine ,
738
+ e : impl std:: borrow:: Borrow < ssl:: Error > ,
739
+ ) -> PyBaseExceptionRef {
740
+ let e = e. borrow ( ) ;
741
+ match e. io_error ( ) {
742
+ Some ( io_err) => io_err. into_pyexception ( vm) ,
743
+ None => match e. ssl_error ( ) {
744
+ Some ( e) => convert_openssl_error ( vm, e. clone ( ) ) ,
745
+ None => vm. new_exception (
746
+ ssl_error ( vm) ,
747
+ vec ! [
748
+ vm. ctx. new_int( e. code( ) . as_raw( ) ) ,
749
+ vm. ctx. new_str( e. to_string( ) ) ,
750
+ ] ,
751
+ ) ,
752
+ } ,
657
753
}
658
754
}
659
755
@@ -805,7 +901,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
805
901
"SSL_ERROR_SYSCALL" => ctx. new_int( sys:: SSL_ERROR_SYSCALL ) ,
806
902
"SSL_ERROR_SSL" => ctx. new_int( sys:: SSL_ERROR_SSL ) ,
807
903
"SSL_ERROR_WANT_CONNECT" => ctx. new_int( sys:: SSL_ERROR_WANT_CONNECT ) ,
808
- // "SSL_ERROR_EOF" => ctx.new_int(sys::SSL_ERROR_EOF),
904
+ "SSL_ERROR_EOF" => ctx. new_int( 8 ) , // custom for python
809
905
// "SSL_ERROR_INVALID_ERROR_CODE" => ctx.new_int(sys::SSL_ERROR_INVALID_ERROR_CODE),
810
906
// TODO: so many more of these
811
907
"ALERT_DESCRIPTION_DECODE_ERROR" => ctx. new_int( sys:: SSL_AD_DECODE_ERROR ) ,
0 commit comments