Skip to content

Commit a8964f4

Browse files
coolreader18youknowone
authored andcommitted
Add select.epoll
1 parent 740aeed commit a8964f4

File tree

9 files changed

+322
-44
lines changed

9 files changed

+322
-44
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ once_cell = "1.19.0"
171171
parking_lot = "0.12.1"
172172
paste = "1.0.7"
173173
rand = "0.8.5"
174+
rustix = { version = "0.38", features = ["event"] }
174175
rustyline = "14.0.0"
175176
serde = { version = "1.0.133", default-features = false }
176177
schannel = "0.1.22"

Lib/test/test_poll.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,6 @@ def test_threaded_poll(self):
211211
os.write(w, b'spam')
212212
t.join()
213213

214-
# TODO: RUSTPYTHON add support for negative timeout
215-
@unittest.expectedFailure
216214
@unittest.skipUnless(threading, 'Threading required for this test.')
217215
@threading_helper.reap_threads
218216
def test_poll_blocks_with_negative_ms(self):

stdlib/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ page_size = "0.4"
9898
[target.'cfg(all(unix, not(target_os = "redox"), not(target_os = "ios")))'.dependencies]
9999
termios = "0.3.3"
100100

101+
[target.'cfg(unix)'.dependencies]
102+
rustix = { workspace = true }
103+
101104
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
102105
gethostname = "0.2.3"
103106
socket2 = { version = "0.5.6", features = ["all"] }

stdlib/src/select.rs

Lines changed: 280 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -325,12 +325,59 @@ mod decl {
325325
pub(super) mod poll {
326326
use super::*;
327327
use crate::vm::{
328-
builtins::PyFloat, common::lock::PyMutex, convert::ToPyObject, function::OptionalArg,
329-
stdlib::io::Fildes, AsObject, PyPayload,
328+
builtins::PyFloat,
329+
common::lock::PyMutex,
330+
convert::{IntoPyException, ToPyObject},
331+
function::OptionalArg,
332+
stdlib::io::Fildes,
333+
AsObject, PyPayload,
330334
};
331335
use libc::pollfd;
332-
use num_traits::ToPrimitive;
333-
use std::time;
336+
use num_traits::{Signed, ToPrimitive};
337+
use std::time::{Duration, Instant};
338+
339+
#[derive(Default)]
340+
pub(super) struct TimeoutArg<const MILLIS: bool>(pub Option<Duration>);
341+
342+
impl<const MILLIS: bool> TryFromObject for TimeoutArg<MILLIS> {
343+
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
344+
let timeout = if vm.is_none(&obj) {
345+
None
346+
} else if let Some(float) = obj.payload::<PyFloat>() {
347+
let float = float.to_f64();
348+
if float.is_nan() {
349+
return Err(
350+
vm.new_value_error("Invalid value NaN (not a number)".to_owned())
351+
);
352+
}
353+
if float.is_sign_negative() {
354+
None
355+
} else {
356+
let secs = if MILLIS { float * 1000.0 } else { float };
357+
Some(Duration::from_secs_f64(secs))
358+
}
359+
} else if let Some(int) = obj.try_index_opt(vm).transpose()? {
360+
if int.as_bigint().is_negative() {
361+
None
362+
} else {
363+
let n = int.as_bigint().to_u64().ok_or_else(|| {
364+
vm.new_overflow_error("value out of range".to_owned())
365+
})?;
366+
Some(if MILLIS {
367+
Duration::from_millis(n)
368+
} else {
369+
Duration::from_secs(n)
370+
})
371+
}
372+
} else {
373+
return Err(vm.new_type_error(format!(
374+
"expected an int or float for duration, got {}",
375+
obj.class()
376+
)));
377+
};
378+
Ok(Self(timeout))
379+
}
380+
}
334381

335382
#[pyclass(module = "select", name = "poll")]
336383
#[derive(Default, Debug, PyPayload)]
@@ -399,50 +446,31 @@ mod decl {
399446
#[pymethod]
400447
fn poll(
401448
&self,
402-
timeout: OptionalOption,
449+
timeout: OptionalArg<TimeoutArg<true>>,
403450
vm: &VirtualMachine,
404451
) -> PyResult<Vec<PyObjectRef>> {
405452
let mut fds = self.fds.lock();
406-
let timeout_ms = match timeout.flatten() {
407-
Some(ms) => {
408-
let ms = if let Some(float) = ms.payload::<PyFloat>() {
409-
float.to_f64().to_i32()
410-
} else if let Some(int) = ms.try_index_opt(vm) {
411-
int?.as_bigint().to_i32()
412-
} else {
413-
return Err(vm.new_type_error(format!(
414-
"expected an int or float for duration, got {}",
415-
ms.class()
416-
)));
417-
};
418-
ms.ok_or_else(|| vm.new_value_error("value out of range".to_owned()))?
419-
}
420-
None => -1,
453+
let TimeoutArg(timeout) = timeout.unwrap_or_default();
454+
let timeout_ms = match timeout {
455+
Some(d) => i32::try_from(d.as_millis())
456+
.map_err(|_| vm.new_overflow_error("value out of range".to_owned()))?,
457+
None => -1i32,
421458
};
422-
let timeout_ms = if timeout_ms < 0 { -1 } else { timeout_ms };
423-
let deadline = (timeout_ms >= 0)
424-
.then(|| time::Instant::now() + time::Duration::from_millis(timeout_ms as u64));
459+
let deadline = timeout.map(|d| Instant::now() + d);
425460
let mut poll_timeout = timeout_ms;
426461
loop {
427462
let res = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as _, poll_timeout) };
428-
let res = if res < 0 {
429-
Err(io::Error::last_os_error())
430-
} else {
431-
Ok(())
432-
};
433-
match res {
434-
Ok(()) => break,
435-
Err(e) if e.kind() == io::ErrorKind::Interrupted => {
436-
vm.check_signals()?;
437-
if let Some(d) = deadline {
438-
match d.checked_duration_since(time::Instant::now()) {
439-
Some(remaining) => poll_timeout = remaining.as_millis() as i32,
440-
// we've timed out
441-
None => break,
442-
}
443-
}
463+
match nix::Error::result(res) {
464+
Ok(_) => break,
465+
Err(nix::Error::EINTR) => vm.check_signals()?,
466+
Err(e) => return Err(e.into_pyexception(vm)),
467+
}
468+
if let Some(d) = deadline {
469+
if let Some(remaining) = d.checked_duration_since(Instant::now()) {
470+
poll_timeout = remaining.as_millis() as i32;
471+
} else {
472+
break;
444473
}
445-
Err(e) => return Err(e.to_pyexception(vm)),
446474
}
447475
}
448476
Ok(fds
@@ -453,4 +481,216 @@ mod decl {
453481
}
454482
}
455483
}
484+
485+
#[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))]
486+
#[pyattr(name = "epoll", once)]
487+
fn epoll(vm: &VirtualMachine) -> PyTypeRef {
488+
use crate::vm::class::PyClassImpl;
489+
epoll::PyEpoll::make_class(&vm.ctx)
490+
}
491+
492+
#[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))]
493+
#[pyattr]
494+
use libc::{
495+
EPOLLERR, EPOLLEXCLUSIVE, EPOLLHUP, EPOLLIN, EPOLLMSG, EPOLLONESHOT, EPOLLOUT, EPOLLPRI,
496+
EPOLLRDBAND, EPOLLRDHUP, EPOLLRDNORM, EPOLLWAKEUP, EPOLLWRBAND, EPOLLWRNORM, EPOLL_CLOEXEC,
497+
};
498+
#[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))]
499+
#[pyattr]
500+
const EPOLLET: u32 = libc::EPOLLET as u32;
501+
502+
#[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))]
503+
pub(super) mod epoll {
504+
use super::*;
505+
use crate::vm::{
506+
builtins::PyTypeRef,
507+
common::lock::{PyRwLock, PyRwLockReadGuard},
508+
convert::{IntoPyException, ToPyObject},
509+
function::OptionalArg,
510+
stdlib::io::Fildes,
511+
types::Constructor,
512+
PyPayload,
513+
};
514+
use rustix::event::epoll::{self, EventData, EventFlags};
515+
use std::ops::Deref;
516+
use std::os::fd::{AsRawFd, IntoRawFd, OwnedFd};
517+
use std::time::{Duration, Instant};
518+
519+
#[pyclass(module = "select", name = "epoll")]
520+
#[derive(Debug, rustpython_vm::PyPayload)]
521+
pub struct PyEpoll {
522+
epoll_fd: PyRwLock<Option<OwnedFd>>,
523+
}
524+
525+
#[derive(FromArgs)]
526+
pub struct EpollNewArgs {
527+
#[pyarg(any, default = "-1")]
528+
sizehint: i32,
529+
#[pyarg(any, default = "0")]
530+
flags: i32,
531+
}
532+
533+
impl Constructor for PyEpoll {
534+
type Args = EpollNewArgs;
535+
fn py_new(cls: PyTypeRef, args: EpollNewArgs, vm: &VirtualMachine) -> PyResult {
536+
if let ..=-2 | 0 = args.sizehint {
537+
return Err(vm.new_value_error("negative sizehint".to_owned()));
538+
}
539+
if !matches!(args.flags, 0 | libc::EPOLL_CLOEXEC) {
540+
return Err(vm.new_os_error("invalid flags".to_owned()));
541+
}
542+
Self::new()
543+
.map_err(|e| e.into_pyexception(vm))?
544+
.into_ref_with_type(vm, cls)
545+
.map(Into::into)
546+
}
547+
}
548+
549+
#[derive(FromArgs)]
550+
struct EpollPollArgs {
551+
#[pyarg(any, default)]
552+
timeout: poll::TimeoutArg<false>,
553+
#[pyarg(any, default = "-1")]
554+
maxevents: i32,
555+
}
556+
557+
#[pyclass(with(Constructor))]
558+
impl PyEpoll {
559+
fn new() -> std::io::Result<Self> {
560+
let epoll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)?;
561+
let epoll_fd = Some(epoll_fd).into();
562+
Ok(PyEpoll { epoll_fd })
563+
}
564+
565+
#[pymethod]
566+
fn close(&self) -> std::io::Result<()> {
567+
let fd = self.epoll_fd.write().take();
568+
if let Some(fd) = fd {
569+
nix::unistd::close(fd.into_raw_fd())?;
570+
}
571+
Ok(())
572+
}
573+
574+
#[pygetset]
575+
fn closed(&self) -> bool {
576+
self.epoll_fd.read().is_none()
577+
}
578+
579+
fn get_epoll(
580+
&self,
581+
vm: &VirtualMachine,
582+
) -> PyResult<impl Deref<Target = OwnedFd> + '_> {
583+
PyRwLockReadGuard::try_map(self.epoll_fd.read(), |x| x.as_ref()).map_err(|_| {
584+
vm.new_value_error("I/O operation on closed epoll object".to_owned())
585+
})
586+
}
587+
588+
#[pymethod]
589+
fn fileno(&self, vm: &VirtualMachine) -> PyResult<i32> {
590+
self.get_epoll(vm).map(|epoll_fd| epoll_fd.as_raw_fd())
591+
}
592+
593+
#[pyclassmethod]
594+
fn fromfd(cls: PyTypeRef, fd: OwnedFd, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
595+
let epoll_fd = Some(fd).into();
596+
Self { epoll_fd }.into_ref_with_type(vm, cls)
597+
}
598+
599+
#[pymethod]
600+
fn register(
601+
&self,
602+
fd: Fildes,
603+
eventmask: OptionalArg<u32>,
604+
vm: &VirtualMachine,
605+
) -> PyResult<()> {
606+
let events = match eventmask {
607+
OptionalArg::Present(mask) => EventFlags::from_bits_retain(mask),
608+
OptionalArg::Missing => EventFlags::IN | EventFlags::PRI | EventFlags::OUT,
609+
};
610+
let epoll_fd = &*self.get_epoll(vm)?;
611+
let data = EventData::new_u64(fd.as_raw_fd() as u64);
612+
epoll::add(epoll_fd, fd, data, events).map_err(|e| e.into_pyexception(vm))
613+
}
614+
615+
#[pymethod]
616+
fn modify(&self, fd: Fildes, eventmask: u32, vm: &VirtualMachine) -> PyResult<()> {
617+
let events = EventFlags::from_bits_retain(eventmask);
618+
let epoll_fd = &*self.get_epoll(vm)?;
619+
let data = EventData::new_u64(fd.as_raw_fd() as u64);
620+
epoll::modify(epoll_fd, fd, data, events).map_err(|e| e.into_pyexception(vm))
621+
}
622+
623+
#[pymethod]
624+
fn unregister(&self, fd: Fildes, vm: &VirtualMachine) -> PyResult<()> {
625+
let epoll_fd = &*self.get_epoll(vm)?;
626+
epoll::delete(epoll_fd, fd).map_err(|e| e.into_pyexception(vm))
627+
}
628+
629+
#[pymethod]
630+
fn poll(&self, args: EpollPollArgs, vm: &VirtualMachine) -> PyResult<PyListRef> {
631+
let poll::TimeoutArg(timeout) = args.timeout;
632+
let maxevents = args.maxevents;
633+
634+
let make_poll_timeout = |d: Duration| i32::try_from(d.as_millis());
635+
let mut poll_timeout = match timeout {
636+
Some(d) => make_poll_timeout(d)
637+
.map_err(|_| vm.new_overflow_error("timeout is too large".to_owned()))?,
638+
None => -1,
639+
};
640+
641+
let deadline = timeout.map(|d| Instant::now() + d);
642+
let maxevents = match maxevents {
643+
..-1 => {
644+
return Err(vm.new_value_error(format!(
645+
"maxevents must be greater than 0, got {maxevents}"
646+
)))
647+
}
648+
-1 => libc::FD_SETSIZE - 1,
649+
_ => maxevents as usize,
650+
};
651+
652+
let mut events = epoll::EventVec::with_capacity(maxevents);
653+
654+
let epoll = &*self.get_epoll(vm)?;
655+
656+
loop {
657+
match epoll::wait(epoll, &mut events, poll_timeout) {
658+
Ok(()) => break,
659+
Err(rustix::io::Errno::INTR) => vm.check_signals()?,
660+
Err(e) => return Err(e.into_pyexception(vm)),
661+
}
662+
if let Some(deadline) = deadline {
663+
if let Some(new_timeout) = deadline.checked_duration_since(Instant::now()) {
664+
poll_timeout = make_poll_timeout(new_timeout).unwrap();
665+
} else {
666+
break;
667+
}
668+
}
669+
}
670+
671+
let ret = events
672+
.iter()
673+
.map(|ev| (ev.data.u64() as i32, { ev.flags }.bits()).to_pyobject(vm))
674+
.collect();
675+
676+
Ok(vm.ctx.new_list(ret))
677+
}
678+
679+
#[pymethod(magic)]
680+
fn enter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
681+
zelf.get_epoll(vm)?;
682+
Ok(zelf)
683+
}
684+
685+
#[pymethod(magic)]
686+
fn exit(
687+
&self,
688+
_exc_type: OptionalArg,
689+
_exc_value: OptionalArg,
690+
_exc_tb: OptionalArg,
691+
) -> std::io::Result<()> {
692+
self.close()
693+
}
694+
}
695+
}
456696
}

vm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ unic-ucd-category = "0.9.0"
9191
unic-ucd-ident = "0.9.0"
9292

9393
[target.'cfg(unix)'.dependencies]
94+
rustix = { workspace = true }
9495
exitcode = "1.1.2"
9596
uname = "0.1.1"
9697
strum = "0.24.0"

0 commit comments

Comments
 (0)