@@ -3,6 +3,7 @@ use std::io;
3
3
use std:: io:: Read ;
4
4
use std:: io:: Write ;
5
5
use std:: net:: { Ipv4Addr , SocketAddr , TcpListener , TcpStream , ToSocketAddrs , UdpSocket } ;
6
+ use std:: time:: Duration ;
6
7
7
8
#[ cfg( all( unix, not( target_os = "redox" ) ) ) ]
8
9
use nix:: unistd:: sethostname;
@@ -122,6 +123,27 @@ impl Connection {
122
123
fn fileno ( & self ) -> i64 {
123
124
unimplemented ! ( ) ;
124
125
}
126
+
127
+ fn setblocking ( & mut self , value : bool ) -> io:: Result < ( ) > {
128
+ match self {
129
+ Connection :: TcpListener ( con) => con. set_nonblocking ( !value) ,
130
+ Connection :: UdpSocket ( con) => con. set_nonblocking ( !value) ,
131
+ Connection :: TcpStream ( con) => con. set_nonblocking ( !value) ,
132
+ }
133
+ }
134
+
135
+ fn settimeout ( & mut self , duration : Duration ) -> io:: Result < ( ) > {
136
+ match self {
137
+ // net
138
+ Connection :: TcpListener ( _con) => Ok ( ( ) ) ,
139
+ Connection :: UdpSocket ( con) => con
140
+ . set_read_timeout ( Some ( duration) )
141
+ . and_then ( |_| con. set_write_timeout ( Some ( duration) ) ) ,
142
+ Connection :: TcpStream ( con) => con
143
+ . set_read_timeout ( Some ( duration) )
144
+ . and_then ( |_| con. set_write_timeout ( Some ( duration) ) ) ,
145
+ }
146
+ }
125
147
}
126
148
127
149
impl Read for Connection {
@@ -152,6 +174,7 @@ pub struct Socket {
152
174
address_family : AddressFamily ,
153
175
socket_kind : SocketKind ,
154
176
con : RefCell < Option < Connection > > ,
177
+ timeout : RefCell < Option < Duration > > ,
155
178
}
156
179
157
180
impl PyValue for Socket {
@@ -166,6 +189,7 @@ impl Socket {
166
189
address_family,
167
190
socket_kind,
168
191
con : RefCell :: new ( None ) ,
192
+ timeout : RefCell :: new ( None ) ,
169
193
}
170
194
}
171
195
}
@@ -194,13 +218,41 @@ impl SocketRef {
194
218
let address_string = address. get_address_string ( ) ;
195
219
196
220
match self . socket_kind {
197
- SocketKind :: Stream => match TcpStream :: connect ( address_string) {
198
- Ok ( stream) => {
199
- self . con . borrow_mut ( ) . replace ( Connection :: TcpStream ( stream) ) ;
200
- Ok ( ( ) )
221
+ SocketKind :: Stream => {
222
+ let con = if let Some ( duration) = self . timeout . borrow ( ) . as_ref ( ) {
223
+ let sock_addr = match address_string. to_socket_addrs ( ) {
224
+ Ok ( mut sock_addrs) => {
225
+ if sock_addrs. len ( ) == 0 {
226
+ let error_type = vm. class ( "socket" , "gaierror" ) ;
227
+ return Err ( vm. new_exception (
228
+ error_type,
229
+ "nodename nor servname provided, or not known" . to_string ( ) ,
230
+ ) ) ;
231
+ } else {
232
+ sock_addrs. next ( ) . unwrap ( )
233
+ }
234
+ }
235
+ Err ( e) => {
236
+ let error_type = vm. class ( "socket" , "gaierror" ) ;
237
+ return Err ( vm. new_exception ( error_type, e. to_string ( ) ) ) ;
238
+ }
239
+ } ;
240
+ TcpStream :: connect_timeout ( & sock_addr, * duration)
241
+ } else {
242
+ TcpStream :: connect ( address_string)
243
+ } ;
244
+ match con {
245
+ Ok ( stream) => {
246
+ self . con . borrow_mut ( ) . replace ( Connection :: TcpStream ( stream) ) ;
247
+ Ok ( ( ) )
248
+ }
249
+ Err ( ref e) if e. kind ( ) == io:: ErrorKind :: TimedOut => {
250
+ let socket_timeout = vm. class ( "socket" , "timeout" ) ;
251
+ Err ( vm. new_exception ( socket_timeout, "Timed out" . to_string ( ) ) )
252
+ }
253
+ Err ( s) => Err ( vm. new_os_error ( s. to_string ( ) ) ) ,
201
254
}
202
- Err ( s) => Err ( vm. new_os_error ( s. to_string ( ) ) ) ,
203
- } ,
255
+ }
204
256
SocketKind :: Dgram => {
205
257
if let Some ( Connection :: UdpSocket ( con) ) = self . con . borrow ( ) . as_ref ( ) {
206
258
match con. connect ( address_string) {
@@ -254,6 +306,7 @@ impl SocketRef {
254
306
address_family : self . address_family ,
255
307
socket_kind : self . socket_kind ,
256
308
con : RefCell :: new ( Some ( Connection :: TcpStream ( tcp_stream) ) ) ,
309
+ timeout : RefCell :: new ( None ) ,
257
310
}
258
311
. into_ref ( vm) ;
259
312
@@ -267,6 +320,10 @@ impl SocketRef {
267
320
match self . con . borrow_mut ( ) . as_mut ( ) {
268
321
Some ( v) => match v. read_exact ( & mut buffer) {
269
322
Ok ( _) => ( ) ,
323
+ Err ( ref e) if e. kind ( ) == io:: ErrorKind :: TimedOut => {
324
+ let socket_timeout = vm. class ( "socket" , "timeout" ) ;
325
+ return Err ( vm. new_exception ( socket_timeout, "Timed out" . to_string ( ) ) ) ;
326
+ }
270
327
Err ( s) => return Err ( vm. new_os_error ( s. to_string ( ) ) ) ,
271
328
} ,
272
329
None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
@@ -295,9 +352,13 @@ impl SocketRef {
295
352
match self . con . borrow_mut ( ) . as_mut ( ) {
296
353
Some ( v) => match v. write ( & bytes) {
297
354
Ok ( _) => ( ) ,
355
+ Err ( ref e) if e. kind ( ) == io:: ErrorKind :: TimedOut => {
356
+ let socket_timeout = vm. class ( "socket" , "timeout" ) ;
357
+ return Err ( vm. new_exception ( socket_timeout, "Timed out" . to_string ( ) ) ) ;
358
+ }
298
359
Err ( s) => return Err ( vm. new_os_error ( s. to_string ( ) ) ) ,
299
360
} ,
300
- None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
361
+ None => return Err ( vm. new_type_error ( "Socket is not connected " . to_string ( ) ) ) ,
301
362
} ;
302
363
Ok ( ( ) )
303
364
}
@@ -352,6 +413,75 @@ impl SocketRef {
352
413
Err ( s) => Err ( vm. new_os_error ( s. to_string ( ) ) ) ,
353
414
}
354
415
}
416
+
417
+ fn gettimeout ( self , _vm : & VirtualMachine ) -> PyResult < Option < f64 > > {
418
+ match self . timeout . borrow ( ) . as_ref ( ) {
419
+ Some ( duration) => Ok ( Some ( duration. as_secs ( ) as f64 ) ) ,
420
+ None => Ok ( None ) ,
421
+ }
422
+ }
423
+
424
+ fn setblocking ( self , block : Option < bool > , vm : & VirtualMachine ) -> PyResult < ( ) > {
425
+ match block {
426
+ Some ( value) => {
427
+ if value {
428
+ self . timeout . replace ( None ) ;
429
+ } else {
430
+ self . timeout . borrow_mut ( ) . replace ( Duration :: from_secs ( 0 ) ) ;
431
+ }
432
+ if let Some ( conn) = self . con . borrow_mut ( ) . as_mut ( ) {
433
+ return match conn. setblocking ( value) {
434
+ Ok ( _) => Ok ( ( ) ) ,
435
+ Err ( err) => Err ( vm. new_os_error ( err. to_string ( ) ) ) ,
436
+ } ;
437
+ } else {
438
+ Ok ( ( ) )
439
+ }
440
+ }
441
+ None => {
442
+ // Avoid converting None to bool
443
+ Err ( vm. new_type_error ( "an bool is required" . to_string ( ) ) )
444
+ }
445
+ }
446
+ }
447
+
448
+ fn getblocking ( self , _vm : & VirtualMachine ) -> PyResult < Option < bool > > {
449
+ match self . timeout . borrow ( ) . as_ref ( ) {
450
+ Some ( duration) => {
451
+ if duration. as_secs ( ) != 0 {
452
+ Ok ( Some ( true ) )
453
+ } else {
454
+ Ok ( Some ( false ) )
455
+ }
456
+ }
457
+ None => Ok ( Some ( true ) ) ,
458
+ }
459
+ }
460
+
461
+ fn settimeout ( self , timeout : Option < f64 > , vm : & VirtualMachine ) -> PyResult < ( ) > {
462
+ match timeout {
463
+ Some ( timeout) => {
464
+ self . timeout
465
+ . borrow_mut ( )
466
+ . replace ( Duration :: from_secs ( timeout as u64 ) ) ;
467
+
468
+ let block = timeout > 0.0 ;
469
+
470
+ if let Some ( conn) = self . con . borrow_mut ( ) . as_mut ( ) {
471
+ conn. setblocking ( block)
472
+ . and_then ( |_| conn. settimeout ( Duration :: from_secs ( timeout as u64 ) ) )
473
+ . map_err ( |err| vm. new_os_error ( err. to_string ( ) ) )
474
+ . map ( |_| ( ) )
475
+ } else {
476
+ Ok ( ( ) )
477
+ }
478
+ }
479
+ None => {
480
+ self . timeout . replace ( None ) ;
481
+ Ok ( ( ) )
482
+ }
483
+ }
484
+ }
355
485
}
356
486
357
487
struct Address {
@@ -432,6 +562,8 @@ fn socket_htonl(host: PyIntRef, vm: &VirtualMachine) -> PyResult {
432
562
433
563
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
434
564
let ctx = & vm. ctx ;
565
+ let socket_timeout = ctx. new_class ( "socket.timeout" , vm. ctx . exceptions . os_error . clone ( ) ) ;
566
+ let socket_gaierror = ctx. new_class ( "socket.gaierror" , vm. ctx . exceptions . os_error . clone ( ) ) ;
435
567
436
568
let socket = py_class ! ( ctx, "socket" , ctx. object( ) , {
437
569
"__new__" => ctx. new_rustfunc( SocketRef :: new) ,
@@ -448,9 +580,16 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
448
580
"sendto" => ctx. new_rustfunc( SocketRef :: sendto) ,
449
581
"recvfrom" => ctx. new_rustfunc( SocketRef :: recvfrom) ,
450
582
"fileno" => ctx. new_rustfunc( SocketRef :: fileno) ,
583
+ "getblocking" => ctx. new_rustfunc( SocketRef :: getblocking) ,
584
+ "setblocking" => ctx. new_rustfunc( SocketRef :: setblocking) ,
585
+ "gettimeout" => ctx. new_rustfunc( SocketRef :: gettimeout) ,
586
+ "settimeout" => ctx. new_rustfunc( SocketRef :: settimeout) ,
451
587
} ) ;
452
588
453
589
let module = py_module ! ( vm, "socket" , {
590
+ "error" => ctx. exceptions. os_error. clone( ) ,
591
+ "timeout" => socket_timeout,
592
+ "gaierror" => socket_gaierror,
454
593
"AF_INET" => ctx. new_int( AddressFamily :: Inet as i32 ) ,
455
594
"SOCK_STREAM" => ctx. new_int( SocketKind :: Stream as i32 ) ,
456
595
"SOCK_DGRAM" => ctx. new_int( SocketKind :: Dgram as i32 ) ,
0 commit comments