@@ -325,12 +325,59 @@ mod decl {
325
325
pub ( super ) mod poll {
326
326
use super :: * ;
327
327
use crate :: vm:: {
328
- builtins:: PyFloat , common:: lock:: PyMutex , convert:: ToPyObject , function:: OptionalArg ,
329
- stdlib:: io:: Fildes , AsObject , PyPayload ,
328
+ builtins:: PyFloat ,
329
+ common:: lock:: PyMutex ,
330
+ convert:: { IntoPyException , ToPyObject } ,
331
+ function:: OptionalArg ,
332
+ stdlib:: io:: Fildes ,
333
+ AsObject , PyPayload ,
330
334
} ;
331
335
use libc:: pollfd;
332
- use num_traits:: ToPrimitive ;
333
- use std:: time;
336
+ use num_traits:: { Signed , ToPrimitive } ;
337
+ use std:: time:: { Duration , Instant } ;
338
+
339
+ #[ derive( Default ) ]
340
+ pub ( super ) struct TimeoutArg < const MILLIS : bool > ( pub Option < Duration > ) ;
341
+
342
+ impl < const MILLIS : bool > TryFromObject for TimeoutArg < MILLIS > {
343
+ fn try_from_object ( vm : & VirtualMachine , obj : PyObjectRef ) -> PyResult < Self > {
344
+ let timeout = if vm. is_none ( & obj) {
345
+ None
346
+ } else if let Some ( float) = obj. payload :: < PyFloat > ( ) {
347
+ let float = float. to_f64 ( ) ;
348
+ if float. is_nan ( ) {
349
+ return Err (
350
+ vm. new_value_error ( "Invalid value NaN (not a number)" . to_owned ( ) )
351
+ ) ;
352
+ }
353
+ if float. is_sign_negative ( ) {
354
+ None
355
+ } else {
356
+ let secs = if MILLIS { float * 1000.0 } else { float } ;
357
+ Some ( Duration :: from_secs_f64 ( secs) )
358
+ }
359
+ } else if let Some ( int) = obj. try_index_opt ( vm) . transpose ( ) ? {
360
+ if int. as_bigint ( ) . is_negative ( ) {
361
+ None
362
+ } else {
363
+ let n = int. as_bigint ( ) . to_u64 ( ) . ok_or_else ( || {
364
+ vm. new_overflow_error ( "value out of range" . to_owned ( ) )
365
+ } ) ?;
366
+ Some ( if MILLIS {
367
+ Duration :: from_millis ( n)
368
+ } else {
369
+ Duration :: from_secs ( n)
370
+ } )
371
+ }
372
+ } else {
373
+ return Err ( vm. new_type_error ( format ! (
374
+ "expected an int or float for duration, got {}" ,
375
+ obj. class( )
376
+ ) ) ) ;
377
+ } ;
378
+ Ok ( Self ( timeout) )
379
+ }
380
+ }
334
381
335
382
#[ pyclass( module = "select" , name = "poll" ) ]
336
383
#[ derive( Default , Debug , PyPayload ) ]
@@ -399,50 +446,31 @@ mod decl {
399
446
#[ pymethod]
400
447
fn poll (
401
448
& self ,
402
- timeout : OptionalOption ,
449
+ timeout : OptionalArg < TimeoutArg < true > > ,
403
450
vm : & VirtualMachine ,
404
451
) -> PyResult < Vec < PyObjectRef > > {
405
452
let mut fds = self . fds . lock ( ) ;
406
- let timeout_ms = match timeout. flatten ( ) {
407
- Some ( ms) => {
408
- let ms = if let Some ( float) = ms. payload :: < PyFloat > ( ) {
409
- float. to_f64 ( ) . to_i32 ( )
410
- } else if let Some ( int) = ms. try_index_opt ( vm) {
411
- int?. as_bigint ( ) . to_i32 ( )
412
- } else {
413
- return Err ( vm. new_type_error ( format ! (
414
- "expected an int or float for duration, got {}" ,
415
- ms. class( )
416
- ) ) ) ;
417
- } ;
418
- ms. ok_or_else ( || vm. new_value_error ( "value out of range" . to_owned ( ) ) ) ?
419
- }
420
- None => -1 ,
453
+ let TimeoutArg ( timeout) = timeout. unwrap_or_default ( ) ;
454
+ let timeout_ms = match timeout {
455
+ Some ( d) => i32:: try_from ( d. as_millis ( ) )
456
+ . map_err ( |_| vm. new_overflow_error ( "value out of range" . to_owned ( ) ) ) ?,
457
+ None => -1i32 ,
421
458
} ;
422
- let timeout_ms = if timeout_ms < 0 { -1 } else { timeout_ms } ;
423
- let deadline = ( timeout_ms >= 0 )
424
- . then ( || time:: Instant :: now ( ) + time:: Duration :: from_millis ( timeout_ms as u64 ) ) ;
459
+ let deadline = timeout. map ( |d| Instant :: now ( ) + d) ;
425
460
let mut poll_timeout = timeout_ms;
426
461
loop {
427
462
let res = unsafe { libc:: poll ( fds. as_mut_ptr ( ) , fds. len ( ) as _ , poll_timeout) } ;
428
- let res = if res < 0 {
429
- Err ( io:: Error :: last_os_error ( ) )
430
- } else {
431
- Ok ( ( ) )
432
- } ;
433
- match res {
434
- Ok ( ( ) ) => break ,
435
- Err ( e) if e. kind ( ) == io:: ErrorKind :: Interrupted => {
436
- vm. check_signals ( ) ?;
437
- if let Some ( d) = deadline {
438
- match d. checked_duration_since ( time:: Instant :: now ( ) ) {
439
- Some ( remaining) => poll_timeout = remaining. as_millis ( ) as i32 ,
440
- // we've timed out
441
- None => break ,
442
- }
443
- }
463
+ match nix:: Error :: result ( res) {
464
+ Ok ( _) => break ,
465
+ Err ( nix:: Error :: EINTR ) => vm. check_signals ( ) ?,
466
+ Err ( e) => return Err ( e. into_pyexception ( vm) ) ,
467
+ }
468
+ if let Some ( d) = deadline {
469
+ if let Some ( remaining) = d. checked_duration_since ( Instant :: now ( ) ) {
470
+ poll_timeout = remaining. as_millis ( ) as i32 ;
471
+ } else {
472
+ break ;
444
473
}
445
- Err ( e) => return Err ( e. to_pyexception ( vm) ) ,
446
474
}
447
475
}
448
476
Ok ( fds
@@ -453,4 +481,216 @@ mod decl {
453
481
}
454
482
}
455
483
}
484
+
485
+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
486
+ #[ pyattr( name = "epoll" , once) ]
487
+ fn epoll ( vm : & VirtualMachine ) -> PyTypeRef {
488
+ use crate :: vm:: class:: PyClassImpl ;
489
+ epoll:: PyEpoll :: make_class ( & vm. ctx )
490
+ }
491
+
492
+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
493
+ #[ pyattr]
494
+ use libc:: {
495
+ EPOLLERR , EPOLLEXCLUSIVE , EPOLLHUP , EPOLLIN , EPOLLMSG , EPOLLONESHOT , EPOLLOUT , EPOLLPRI ,
496
+ EPOLLRDBAND , EPOLLRDHUP , EPOLLRDNORM , EPOLLWAKEUP , EPOLLWRBAND , EPOLLWRNORM , EPOLL_CLOEXEC ,
497
+ } ;
498
+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
499
+ #[ pyattr]
500
+ const EPOLLET : u32 = libc:: EPOLLET as u32 ;
501
+
502
+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
503
+ pub ( super ) mod epoll {
504
+ use super :: * ;
505
+ use crate :: vm:: {
506
+ builtins:: PyTypeRef ,
507
+ common:: lock:: { PyRwLock , PyRwLockReadGuard } ,
508
+ convert:: { IntoPyException , ToPyObject } ,
509
+ function:: OptionalArg ,
510
+ stdlib:: io:: Fildes ,
511
+ types:: Constructor ,
512
+ PyPayload ,
513
+ } ;
514
+ use rustix:: event:: epoll:: { self , EventData , EventFlags } ;
515
+ use std:: ops:: Deref ;
516
+ use std:: os:: fd:: { AsRawFd , IntoRawFd , OwnedFd } ;
517
+ use std:: time:: { Duration , Instant } ;
518
+
519
+ #[ pyclass( module = "select" , name = "epoll" ) ]
520
+ #[ derive( Debug , rustpython_vm:: PyPayload ) ]
521
+ pub struct PyEpoll {
522
+ epoll_fd : PyRwLock < Option < OwnedFd > > ,
523
+ }
524
+
525
+ #[ derive( FromArgs ) ]
526
+ pub struct EpollNewArgs {
527
+ #[ pyarg( any, default = "-1" ) ]
528
+ sizehint : i32 ,
529
+ #[ pyarg( any, default = "0" ) ]
530
+ flags : i32 ,
531
+ }
532
+
533
+ impl Constructor for PyEpoll {
534
+ type Args = EpollNewArgs ;
535
+ fn py_new ( cls : PyTypeRef , args : EpollNewArgs , vm : & VirtualMachine ) -> PyResult {
536
+ if let ..=-2 | 0 = args. sizehint {
537
+ return Err ( vm. new_value_error ( "negative sizehint" . to_owned ( ) ) ) ;
538
+ }
539
+ if !matches ! ( args. flags, 0 | libc:: EPOLL_CLOEXEC ) {
540
+ return Err ( vm. new_os_error ( "invalid flags" . to_owned ( ) ) ) ;
541
+ }
542
+ Self :: new ( )
543
+ . map_err ( |e| e. into_pyexception ( vm) ) ?
544
+ . into_ref_with_type ( vm, cls)
545
+ . map ( Into :: into)
546
+ }
547
+ }
548
+
549
+ #[ derive( FromArgs ) ]
550
+ struct EpollPollArgs {
551
+ #[ pyarg( any, default ) ]
552
+ timeout : poll:: TimeoutArg < false > ,
553
+ #[ pyarg( any, default = "-1" ) ]
554
+ maxevents : i32 ,
555
+ }
556
+
557
+ #[ pyclass( with( Constructor ) ) ]
558
+ impl PyEpoll {
559
+ fn new ( ) -> std:: io:: Result < Self > {
560
+ let epoll_fd = epoll:: create ( epoll:: CreateFlags :: CLOEXEC ) ?;
561
+ let epoll_fd = Some ( epoll_fd) . into ( ) ;
562
+ Ok ( PyEpoll { epoll_fd } )
563
+ }
564
+
565
+ #[ pymethod]
566
+ fn close ( & self ) -> std:: io:: Result < ( ) > {
567
+ let fd = self . epoll_fd . write ( ) . take ( ) ;
568
+ if let Some ( fd) = fd {
569
+ nix:: unistd:: close ( fd. into_raw_fd ( ) ) ?;
570
+ }
571
+ Ok ( ( ) )
572
+ }
573
+
574
+ #[ pygetset]
575
+ fn closed ( & self ) -> bool {
576
+ self . epoll_fd . read ( ) . is_none ( )
577
+ }
578
+
579
+ fn get_epoll (
580
+ & self ,
581
+ vm : & VirtualMachine ,
582
+ ) -> PyResult < impl Deref < Target = OwnedFd > + ' _ > {
583
+ PyRwLockReadGuard :: try_map ( self . epoll_fd . read ( ) , |x| x. as_ref ( ) ) . map_err ( |_| {
584
+ vm. new_value_error ( "I/O operation on closed epoll object" . to_owned ( ) )
585
+ } )
586
+ }
587
+
588
+ #[ pymethod]
589
+ fn fileno ( & self , vm : & VirtualMachine ) -> PyResult < i32 > {
590
+ self . get_epoll ( vm) . map ( |epoll_fd| epoll_fd. as_raw_fd ( ) )
591
+ }
592
+
593
+ #[ pyclassmethod]
594
+ fn fromfd ( cls : PyTypeRef , fd : OwnedFd , vm : & VirtualMachine ) -> PyResult < PyRef < Self > > {
595
+ let epoll_fd = Some ( fd) . into ( ) ;
596
+ Self { epoll_fd } . into_ref_with_type ( vm, cls)
597
+ }
598
+
599
+ #[ pymethod]
600
+ fn register (
601
+ & self ,
602
+ fd : Fildes ,
603
+ eventmask : OptionalArg < u32 > ,
604
+ vm : & VirtualMachine ,
605
+ ) -> PyResult < ( ) > {
606
+ let events = match eventmask {
607
+ OptionalArg :: Present ( mask) => EventFlags :: from_bits_retain ( mask) ,
608
+ OptionalArg :: Missing => EventFlags :: IN | EventFlags :: PRI | EventFlags :: OUT ,
609
+ } ;
610
+ let epoll_fd = & * self . get_epoll ( vm) ?;
611
+ let data = EventData :: new_u64 ( fd. as_raw_fd ( ) as u64 ) ;
612
+ epoll:: add ( epoll_fd, fd, data, events) . map_err ( |e| e. into_pyexception ( vm) )
613
+ }
614
+
615
+ #[ pymethod]
616
+ fn modify ( & self , fd : Fildes , eventmask : u32 , vm : & VirtualMachine ) -> PyResult < ( ) > {
617
+ let events = EventFlags :: from_bits_retain ( eventmask) ;
618
+ let epoll_fd = & * self . get_epoll ( vm) ?;
619
+ let data = EventData :: new_u64 ( fd. as_raw_fd ( ) as u64 ) ;
620
+ epoll:: modify ( epoll_fd, fd, data, events) . map_err ( |e| e. into_pyexception ( vm) )
621
+ }
622
+
623
+ #[ pymethod]
624
+ fn unregister ( & self , fd : Fildes , vm : & VirtualMachine ) -> PyResult < ( ) > {
625
+ let epoll_fd = & * self . get_epoll ( vm) ?;
626
+ epoll:: delete ( epoll_fd, fd) . map_err ( |e| e. into_pyexception ( vm) )
627
+ }
628
+
629
+ #[ pymethod]
630
+ fn poll ( & self , args : EpollPollArgs , vm : & VirtualMachine ) -> PyResult < PyListRef > {
631
+ let poll:: TimeoutArg ( timeout) = args. timeout ;
632
+ let maxevents = args. maxevents ;
633
+
634
+ let make_poll_timeout = |d : Duration | i32:: try_from ( d. as_millis ( ) ) ;
635
+ let mut poll_timeout = match timeout {
636
+ Some ( d) => make_poll_timeout ( d)
637
+ . map_err ( |_| vm. new_overflow_error ( "timeout is too large" . to_owned ( ) ) ) ?,
638
+ None => -1 ,
639
+ } ;
640
+
641
+ let deadline = timeout. map ( |d| Instant :: now ( ) + d) ;
642
+ let maxevents = match maxevents {
643
+ ..-1 => {
644
+ return Err ( vm. new_value_error ( format ! (
645
+ "maxevents must be greater than 0, got {maxevents}"
646
+ ) ) )
647
+ }
648
+ -1 => libc:: FD_SETSIZE - 1 ,
649
+ _ => maxevents as usize ,
650
+ } ;
651
+
652
+ let mut events = epoll:: EventVec :: with_capacity ( maxevents) ;
653
+
654
+ let epoll = & * self . get_epoll ( vm) ?;
655
+
656
+ loop {
657
+ match epoll:: wait ( epoll, & mut events, poll_timeout) {
658
+ Ok ( ( ) ) => break ,
659
+ Err ( rustix:: io:: Errno :: INTR ) => vm. check_signals ( ) ?,
660
+ Err ( e) => return Err ( e. into_pyexception ( vm) ) ,
661
+ }
662
+ if let Some ( deadline) = deadline {
663
+ if let Some ( new_timeout) = deadline. checked_duration_since ( Instant :: now ( ) ) {
664
+ poll_timeout = make_poll_timeout ( new_timeout) . unwrap ( ) ;
665
+ } else {
666
+ break ;
667
+ }
668
+ }
669
+ }
670
+
671
+ let ret = events
672
+ . iter ( )
673
+ . map ( |ev| ( ev. data . u64 ( ) as i32 , { ev. flags } . bits ( ) ) . to_pyobject ( vm) )
674
+ . collect ( ) ;
675
+
676
+ Ok ( vm. ctx . new_list ( ret) )
677
+ }
678
+
679
+ #[ pymethod( magic) ]
680
+ fn enter ( zelf : PyRef < Self > , vm : & VirtualMachine ) -> PyResult < PyRef < Self > > {
681
+ zelf. get_epoll ( vm) ?;
682
+ Ok ( zelf)
683
+ }
684
+
685
+ #[ pymethod( magic) ]
686
+ fn exit (
687
+ & self ,
688
+ _exc_type : OptionalArg ,
689
+ _exc_value : OptionalArg ,
690
+ _exc_tb : OptionalArg ,
691
+ ) -> std:: io:: Result < ( ) > {
692
+ self . close ( )
693
+ }
694
+ }
695
+ }
456
696
}
0 commit comments