Skip to content

Commit 10a7f5d

Browse files
committed
Fix SslSocket methods, sorta
1 parent 9647fce commit 10a7f5d

File tree

2 files changed

+158
-57
lines changed

2 files changed

+158
-57
lines changed

vm/src/stdlib/os.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ fn make_path<'a>(
170170
}
171171

172172
impl IntoPyException for io::Error {
173+
fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef {
174+
(&self).into_pyexception(vm)
175+
}
176+
}
177+
impl IntoPyException for &'_ io::Error {
173178
fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef {
174179
#[allow(unreachable_patterns)] // some errors are just aliases of each other
175180
let exc_type = match self.kind() {

vm/src/stdlib/ssl.rs

Lines changed: 153 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1+
use super::os::PyPathLike;
12
use super::socket::PySocketRef;
23
use crate::builtins::bytearray::PyByteArrayRef;
34
use crate::builtins::pystr::PyStrRef;
45
use crate::builtins::{pytype::PyTypeRef, weakref::PyWeak};
56
use crate::byteslike::PyBytesLike;
6-
use crate::common::lock::{PyRwLock, PyRwLockWriteGuard};
7+
use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard};
78
use crate::exceptions::{IntoPyException, PyBaseExceptionRef};
8-
use crate::function::OptionalArg;
9+
use crate::function::{OptionalArg, OptionalOption};
910
use crate::pyobject::{
10-
BorrowValue, Either, IntoPyObject, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult,
11-
PyValue, StaticType,
11+
BorrowValue, Either, IntoPyObject, ItemProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef,
12+
PyResult, PyValue, StaticType,
1213
};
1314
use crate::types::create_simple_type;
1415
use crate::VirtualMachine;
1516

17+
use crossbeam_utils::atomic::AtomicCell;
1618
use foreign_types_shared::{ForeignType, ForeignTypeRef};
1719
use openssl::{
1820
asn1::{Asn1Object, Asn1ObjectRef},
@@ -24,6 +26,7 @@ use openssl::{
2426
use std::convert::TryFrom;
2527
use std::ffi::{CStr, CString};
2628
use std::fmt;
29+
use std::time::Instant;
2730

2831
mod sys {
2932
#![allow(non_camel_case_types, unused)]
@@ -231,7 +234,7 @@ fn _ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, boo
231234
#[pyclass(module = "ssl", name = "_SSLContext")]
232235
struct PySslContext {
233236
ctx: PyRwLock<SslContextBuilder>,
234-
check_hostname: bool,
237+
check_hostname: AtomicCell<bool>,
235238
}
236239

237240
impl fmt::Debug for PySslContext {
@@ -246,6 +249,10 @@ impl PyValue for PySslContext {
246249
}
247250
}
248251

252+
fn builder_as_ctx(x: &SslContextBuilder) -> &ssl::SslContextRef {
253+
unsafe { ssl::SslContextRef::from_ptr(x.as_ptr()) }
254+
}
255+
249256
#[pyimpl(flags(BASETYPE))]
250257
impl PySslContext {
251258
fn builder(&self) -> PyRwLockWriteGuard<'_, SslContextBuilder> {
@@ -256,7 +263,7 @@ impl PySslContext {
256263
F: Fn(&ssl::SslContextRef) -> R,
257264
{
258265
let c = self.ctx.read();
259-
func(unsafe { &**(&*c as *const SslContextBuilder as *const ssl::SslContext) })
266+
func(builder_as_ctx(&c))
260267
}
261268
fn ptr(&self) -> *mut sys::SSL_CTX {
262269
(*self.ctx.write()).as_ptr()
@@ -309,7 +316,7 @@ impl PySslContext {
309316

310317
PySslContext {
311318
ctx: PyRwLock::new(builder),
312-
check_hostname,
319+
check_hostname: AtomicCell::new(check_hostname),
313320
}
314321
.into_ref_with_type(vm, cls)
315322
}
@@ -340,10 +347,11 @@ impl PySslContext {
340347
}
341348
#[pyproperty(setter)]
342349
fn set_verify_mode(&self, cert: i32, vm: &VirtualMachine) -> PyResult<()> {
350+
let mut ctx = self.builder();
343351
let cert_req = CertRequirements::try_from(cert)
344352
.map_err(|_| vm.new_value_error("invalid value for verify_mode".to_owned()))?;
345353
let mode = match cert_req {
346-
CertRequirements::None if self.check_hostname => {
354+
CertRequirements::None if self.check_hostname.load() => {
347355
return Err(vm.new_value_error(
348356
"Cannot set verify_mode to CERT_NONE when check_hostname is enabled."
349357
.to_owned(),
@@ -353,9 +361,21 @@ impl PySslContext {
353361
CertRequirements::Optional => SslVerifyMode::PEER,
354362
CertRequirements::Required => SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT,
355363
};
356-
self.builder().set_verify(mode);
364+
ctx.set_verify(mode);
357365
Ok(())
358366
}
367+
#[pyproperty]
368+
fn check_hostname(&self) -> bool {
369+
self.check_hostname.load()
370+
}
371+
#[pyproperty(setter)]
372+
fn set_check_hostname(&self, ch: bool) {
373+
let mut ctx = self.builder();
374+
if ch && builder_as_ctx(&ctx).verify_mode() == SslVerifyMode::NONE {
375+
ctx.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT);
376+
}
377+
self.check_hostname.store(ch);
378+
}
359379

360380
#[pymethod]
361381
fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> {
@@ -442,6 +462,30 @@ impl PySslContext {
442462
Ok(vm.ctx.new_list(certs))
443463
}
444464

465+
#[pymethod]
466+
fn load_cert_chain(
467+
&self,
468+
certfile: PyPathLike,
469+
keyfile: OptionalArg<PyPathLike>,
470+
password: OptionalOption<Either<PyStrRef, PyCallable>>,
471+
vm: &VirtualMachine,
472+
) -> PyResult<()> {
473+
// TODO: requires passing a callback to C
474+
if password.flatten().is_some() {
475+
return Err(vm.new_not_implemented_error("password arg not yet supported".to_owned()));
476+
}
477+
let mut ctx = self.builder();
478+
ctx.set_certificate_chain_file(&certfile)
479+
.and_then(|()| {
480+
ctx.set_private_key_file(
481+
keyfile.as_ref().unwrap_or(&certfile),
482+
ssl::SslFiletype::PEM,
483+
)
484+
})
485+
.and_then(|()| ctx.check_private_key())
486+
.map_err(|e| convert_openssl_error(vm, e))
487+
}
488+
445489
#[pymethod]
446490
fn _wrap_socket(
447491
zelf: PyRef<Self>,
@@ -471,7 +515,7 @@ impl PySslContext {
471515

472516
Ok(PySslSocket {
473517
ctx: zelf,
474-
stream: PyRwLock::new(Some(stream)),
518+
stream: PyRwLock::new(stream),
475519
socket_type,
476520
server_hostname: args.server_hostname,
477521
owner: PyRwLock::new(args.owner.as_ref().map(PyWeak::downgrade)),
@@ -507,7 +551,7 @@ struct LoadVerifyLocationsArgs {
507551
#[pyclass(module = "ssl", name = "_SSLSocket")]
508552
struct PySslSocket {
509553
ctx: PyRef<PySslContext>,
510-
stream: PyRwLock<Option<ssl::SslStreamBuilder<PySocketRef>>>,
554+
stream: PyRwLock<ssl::SslStreamBuilder<PySocketRef>>,
511555
socket_type: SslServerOrClient,
512556
server_hostname: Option<PyStrRef>,
513557
owner: PyRwLock<Option<PyWeak>>,
@@ -527,20 +571,19 @@ impl PyValue for PySslSocket {
527571

528572
#[pyimpl]
529573
impl PySslSocket {
530-
fn stream_builder(&self) -> ssl::SslStreamBuilder<PySocketRef> {
531-
std::mem::replace(&mut *self.stream.write(), None).unwrap()
532-
}
533-
fn exec_stream<F, R>(&self, func: F) -> R
534-
where
535-
F: Fn(&mut ssl::SslStream<PySocketRef>) -> R,
536-
{
537-
let mut b = self.stream.write();
538-
func(unsafe {
539-
&mut *(b.as_mut().unwrap() as *mut ssl::SslStreamBuilder<_> as *mut ssl::SslStream<_>)
574+
fn stream(&self) -> impl std::ops::Deref<Target = ssl::SslStream<PySocketRef>> + '_ {
575+
let s = self.stream.read();
576+
// SAFETY: SslStreamBuilder is just a wrapper around SslStream
577+
PyRwLockReadGuard::map(s, |s| unsafe {
578+
&*(s as *const _ as *const ssl::SslStream<_>)
540579
})
541580
}
542-
fn set_stream(&self, stream: ssl::SslStream<PySocketRef>) {
543-
*self.stream.write() = Some(unsafe { std::mem::transmute(stream) });
581+
fn stream_mut(&self) -> impl std::ops::DerefMut<Target = ssl::SslStream<PySocketRef>> + '_ {
582+
let s = self.stream.write();
583+
// SAFETY: SslStreamBuilder is just a wrapper around SslStream
584+
PyRwLockWriteGuard::map(s, |s| unsafe {
585+
&mut *(s as *mut _ as *mut ssl::SslStream<_>)
586+
})
544587
}
545588

546589
#[pyproperty]
@@ -571,61 +614,101 @@ impl PySslSocket {
571614
vm: &VirtualMachine,
572615
) -> PyResult<Option<PyObjectRef>> {
573616
let binary = binary.unwrap_or(false);
574-
if !self.exec_stream(|stream| stream.ssl().is_init_finished()) {
617+
let stream = self.stream();
618+
if !stream.ssl().is_init_finished() {
575619
return Err(vm.new_value_error("handshake not done yet".to_owned()));
576620
}
577-
self.exec_stream(|stream| stream.ssl().peer_certificate())
621+
stream
622+
.ssl()
623+
.peer_certificate()
578624
.map(|cert| cert_to_py(vm, &cert, binary))
579625
.transpose()
580626
}
581627

582628
#[pymethod]
583629
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
584-
// Either a stream builder or a mid-handshake stream from WANT_READ or WANT_WRITE
585-
let mut handshaker: Either<_, ssl::MidHandshakeSslStream<_>> =
586-
Either::A(self.stream_builder());
630+
let stream_builder = self.stream.write();
631+
let timeout = stream_builder
632+
.get_ref()
633+
.sock()
634+
.read_timeout()
635+
.ok()
636+
.flatten()
637+
.map(|dur| (dur, Instant::now()));
638+
let mut stream = unsafe { std::ptr::read(&*stream_builder) };
587639
loop {
588-
let handshake_result = match handshaker {
589-
Either::A(s) => s.handshake(),
590-
Either::B(s) => s.handshake(),
591-
};
592-
match handshake_result {
593-
Ok(stream) => {
594-
self.set_stream(stream);
640+
match stream.handshake() {
641+
Ok(s) => {
642+
// s and stream_builder are the same thing
643+
std::mem::forget(s);
595644
return Ok(());
596645
}
597-
Err(ssl::HandshakeError::SetupFailure(e)) => {
598-
return Err(convert_openssl_error(vm, e))
646+
Err(ssl::HandshakeError::SetupFailure(_e)) => {
647+
// handshake() error handling code never constructs this
648+
unreachable!();
649+
// return Err(convert_openssl_error(vm, e))
650+
}
651+
Err(ssl::HandshakeError::WouldBlock(s)) => {
652+
std::mem::forget(s);
653+
stream = unsafe { std::ptr::read(&*stream_builder) };
599654
}
600-
Err(ssl::HandshakeError::WouldBlock(s)) => handshaker = Either::B(s),
601655
Err(ssl::HandshakeError::Failure(s)) => {
602-
return Err(convert_ssl_error(vm, s.into_error()))
656+
let err = convert_ssl_error(vm, s.error());
657+
std::mem::forget(s);
658+
return Err(err);
659+
}
660+
}
661+
if let Some((timeout, ref start)) = timeout {
662+
if start.elapsed() >= timeout {
663+
std::mem::forget(stream);
664+
let socket_timeout = vm.class("_socket", "timeout");
665+
return Err(vm.new_exception_msg(
666+
socket_timeout,
667+
"The handshake operation timed out".to_owned(),
668+
));
603669
}
604670
}
605671
}
606672
}
607673

608674
#[pymethod]
609675
fn write(&self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
610-
data.with_ref(|b| self.exec_stream(|stream| stream.ssl_write(b)))
676+
let mut stream = self.stream_mut();
677+
data.with_ref(|b| stream.ssl_write(b))
611678
.map_err(|e| convert_ssl_error(vm, e))
612679
}
613680

614681
#[pymethod]
615682
fn read(&self, n: usize, buffer: OptionalArg<PyByteArrayRef>, vm: &VirtualMachine) -> PyResult {
616-
if let OptionalArg::Present(buffer) = buffer {
617-
let n = self
618-
.exec_stream(|stream| {
619-
let mut buf = buffer.borrow_value_mut();
620-
stream.ssl_read(&mut buf.elements)
621-
})
622-
.map_err(|e| convert_ssl_error(vm, e))?;
623-
Ok(vm.ctx.new_int(n))
683+
let mut stream = self.stream_mut();
684+
let ret_nread = buffer.is_present();
685+
let ssl_res = if let OptionalArg::Present(buffer) = buffer {
686+
let mut buf = buffer.borrow_value_mut();
687+
stream
688+
.ssl_read(&mut buf.elements)
689+
.map(|n| vm.ctx.new_int(n))
624690
} else {
625691
let mut buf = vec![0u8; n];
626-
buf.truncate(n);
627-
Ok(vm.ctx.new_bytes(buf))
628-
}
692+
stream.ssl_read(&mut buf).map(|n| {
693+
buf.truncate(n);
694+
vm.ctx.new_bytes(buf)
695+
})
696+
};
697+
ssl_res.or_else(|e| {
698+
if e.code() == ssl::ErrorCode::ZERO_RETURN
699+
&& stream.get_shutdown() == ssl::ShutdownState::RECEIVED
700+
{
701+
Ok(if ret_nread {
702+
vm.ctx.new_int(0)
703+
} else {
704+
vm.ctx.new_bytes(vec![])
705+
})
706+
} else {
707+
Err(convert_ssl_error(vm, e))
708+
}
709+
})
710+
711+
// .map_err(|e| convert_ssl_error(vm, e))?;
629712
}
630713
}
631714

@@ -645,15 +728,28 @@ fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptio
645728
// );
646729
// TODO: map the error codes to code names, e.g. "CERTIFICATE_VERIFY_FAILED", just requires a big hashmap/dict
647730
let msg = e.to_string();
648-
vm.new_exception_msg(cls, msg)
731+
vm.new_exception(cls, vec![vm.ctx.new_int(e.code()), vm.ctx.new_str(msg)])
649732
}
650733
None => vm.new_exception_empty(cls),
651734
}
652735
}
653-
fn convert_ssl_error(vm: &VirtualMachine, e: ssl::Error) -> PyBaseExceptionRef {
654-
match e.into_io_error() {
655-
Ok(io_err) => io_err.into_pyexception(vm),
656-
Err(e) => convert_openssl_error(vm, e.ssl_error().unwrap().clone()),
736+
fn convert_ssl_error(
737+
vm: &VirtualMachine,
738+
e: impl std::borrow::Borrow<ssl::Error>,
739+
) -> PyBaseExceptionRef {
740+
let e = e.borrow();
741+
match e.io_error() {
742+
Some(io_err) => io_err.into_pyexception(vm),
743+
None => match e.ssl_error() {
744+
Some(e) => convert_openssl_error(vm, e.clone()),
745+
None => vm.new_exception(
746+
ssl_error(vm),
747+
vec![
748+
vm.ctx.new_int(e.code().as_raw()),
749+
vm.ctx.new_str(e.to_string()),
750+
],
751+
),
752+
},
657753
}
658754
}
659755

@@ -805,7 +901,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
805901
"SSL_ERROR_SYSCALL" => ctx.new_int(sys::SSL_ERROR_SYSCALL),
806902
"SSL_ERROR_SSL" => ctx.new_int(sys::SSL_ERROR_SSL),
807903
"SSL_ERROR_WANT_CONNECT" => ctx.new_int(sys::SSL_ERROR_WANT_CONNECT),
808-
// "SSL_ERROR_EOF" => ctx.new_int(sys::SSL_ERROR_EOF),
904+
"SSL_ERROR_EOF" => ctx.new_int(8), // custom for python
809905
// "SSL_ERROR_INVALID_ERROR_CODE" => ctx.new_int(sys::SSL_ERROR_INVALID_ERROR_CODE),
810906
// TODO: so many more of these
811907
"ALERT_DESCRIPTION_DECODE_ERROR" => ctx.new_int(sys::SSL_AD_DECODE_ERROR),

0 commit comments

Comments
 (0)