@@ -13,12 +13,15 @@ use super::os::convert_io_error;
13
13
#[ cfg( unix) ]
14
14
use super :: os:: convert_nix_error;
15
15
use crate :: function:: { OptionalArg , PyFuncArgs } ;
16
+ use crate :: obj:: objbytearray:: PyByteArrayRef ;
16
17
use crate :: obj:: objbyteinner:: PyBytesLike ;
17
18
use crate :: obj:: objbytes:: PyBytesRef ;
18
- use crate :: obj:: objstr:: PyStringRef ;
19
+ use crate :: obj:: objstr:: { PyString , PyStringRef } ;
19
20
use crate :: obj:: objtuple:: PyTupleRef ;
20
21
use crate :: obj:: objtype:: PyClassRef ;
21
- use crate :: pyobject:: { PyClassImpl , PyObjectRef , PyRef , PyResult , PyValue , TryFromObject } ;
22
+ use crate :: pyobject:: {
23
+ Either , IntoPyObject , PyClassImpl , PyObjectRef , PyRef , PyResult , PyValue , TryFromObject ,
24
+ } ;
22
25
use crate :: vm:: VirtualMachine ;
23
26
24
27
#[ cfg( unix) ]
@@ -172,14 +175,22 @@ impl PySocket {
172
175
}
173
176
174
177
#[ pymethod]
175
- fn recv ( & self , bufsize : usize , vm : & VirtualMachine ) -> PyResult {
178
+ fn recv ( & self , bufsize : usize , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
176
179
let mut buffer = vec ! [ 0u8 ; bufsize] ;
177
180
match self . sock . borrow_mut ( ) . read_exact ( & mut buffer) {
178
- Ok ( ( ) ) => Ok ( vm . ctx . new_bytes ( buffer) ) ,
181
+ Ok ( ( ) ) => Ok ( buffer) ,
179
182
Err ( err) => Err ( convert_sock_error ( vm, err) ) ,
180
183
}
181
184
}
182
185
186
+ #[ pymethod]
187
+ fn recv_into ( & self , buf : PyByteArrayRef , vm : & VirtualMachine ) -> PyResult < usize > {
188
+ let mut buffer = buf. inner . borrow_mut ( ) ;
189
+ self . sock ( )
190
+ . recv ( & mut buffer. elements )
191
+ . map_err ( |err| convert_sock_error ( vm, err) )
192
+ }
193
+
183
194
#[ pymethod]
184
195
fn recvfrom ( & self , bufsize : usize , vm : & VirtualMachine ) -> PyResult < ( Vec < u8 > , AddrTuple ) > {
185
196
let mut buffer = vec ! [ 0u8 ; bufsize] ;
@@ -191,11 +202,20 @@ impl PySocket {
191
202
192
203
#[ pymethod]
193
204
fn send ( & self , bytes : PyBytesLike , vm : & VirtualMachine ) -> PyResult < usize > {
205
+ // TODO: use PyBytesLike.with_ref() instead of to_cow()
194
206
self . sock ( )
195
207
. send ( bytes. to_cow ( ) . as_ref ( ) )
196
208
. map_err ( |err| convert_sock_error ( vm, err) )
197
209
}
198
210
211
+ #[ pymethod]
212
+ fn sendall ( & self , bytes : PyBytesLike , vm : & VirtualMachine ) -> PyResult < ( ) > {
213
+ self . sock
214
+ . borrow_mut ( )
215
+ . write_all ( bytes. to_cow ( ) . as_ref ( ) )
216
+ . map_err ( |err| convert_sock_error ( vm, err) )
217
+ }
218
+
199
219
#[ pymethod]
200
220
fn sendto ( & self , bytes : PyBytesLike , address : Address , vm : & VirtualMachine ) -> PyResult < ( ) > {
201
221
let addr = get_addr ( vm, address) ?;
@@ -315,10 +335,14 @@ impl TryFromObject for Address {
315
335
if tuple. elements . len ( ) != 2 {
316
336
Err ( vm. new_type_error ( "Address tuple should have only 2 values" . to_string ( ) ) )
317
337
} else {
318
- Ok ( Address {
319
- host : PyStringRef :: try_from_object ( vm, tuple. elements [ 0 ] . clone ( ) ) ?,
320
- port : u16:: try_from_object ( vm, tuple. elements [ 1 ] . clone ( ) ) ?,
321
- } )
338
+ let host = PyStringRef :: try_from_object ( vm, tuple. elements [ 0 ] . clone ( ) ) ?;
339
+ let host = if host. as_str ( ) . is_empty ( ) {
340
+ PyString :: from ( "0.0.0.0" ) . into_ref ( vm)
341
+ } else {
342
+ host
343
+ } ;
344
+ let port = u16:: try_from_object ( vm, tuple. elements [ 1 ] . clone ( ) ) ?;
345
+ Ok ( Address { host, port } )
322
346
}
323
347
}
324
348
}
@@ -368,6 +392,85 @@ fn socket_htonl(host: u32, vm: &VirtualMachine) -> PyResult {
368
392
Ok ( vm. new_int ( host. to_be ( ) ) )
369
393
}
370
394
395
+ #[ derive( FromArgs ) ]
396
+ struct GAIOptions {
397
+ #[ pyarg( positional_only) ]
398
+ host : Option < PyStringRef > ,
399
+ #[ pyarg( positional_only) ]
400
+ port : Option < Either < PyStringRef , i32 > > ,
401
+
402
+ #[ pyarg( positional_only, default = "0" ) ]
403
+ family : i32 ,
404
+ #[ pyarg( positional_only, default = "0" ) ]
405
+ ty : i32 ,
406
+ #[ pyarg( positional_only, default = "0" ) ]
407
+ proto : i32 ,
408
+ #[ pyarg( positional_only, default = "0" ) ]
409
+ flags : i32 ,
410
+ }
411
+
412
+ fn socket_getaddrinfo ( opts : GAIOptions , vm : & VirtualMachine ) -> PyResult {
413
+ let hints = dns_lookup:: AddrInfoHints {
414
+ socktype : opts. ty ,
415
+ protocol : opts. proto ,
416
+ address : opts. family ,
417
+ flags : opts. flags ,
418
+ } ;
419
+
420
+ let host = opts. host . as_ref ( ) . map ( |s| s. as_str ( ) ) ;
421
+ let port = opts. port . as_ref ( ) . map ( |p| -> std:: borrow:: Cow < str > {
422
+ match p {
423
+ Either :: A ( ref s) => s. as_str ( ) . into ( ) ,
424
+ Either :: B ( i) => i. to_string ( ) . into ( ) ,
425
+ }
426
+ } ) ;
427
+ let port = port. as_ref ( ) . map ( |p| p. as_ref ( ) ) ;
428
+
429
+ let addrs = dns_lookup:: getaddrinfo ( host, port, Some ( hints) ) . map_err ( |err| {
430
+ let error_type = vm. class ( "_socket" , "gaierror" ) ;
431
+ vm. new_exception ( error_type, io:: Error :: from ( err) . to_string ( ) )
432
+ } ) ?;
433
+
434
+ let list = addrs
435
+ . map ( |ai| {
436
+ ai. map ( |ai| {
437
+ vm. ctx . new_tuple ( vec ! [
438
+ vm. new_int( ai. address) ,
439
+ vm. new_int( ai. socktype) ,
440
+ vm. new_int( ai. protocol) ,
441
+ match ai. canonname {
442
+ Some ( s) => vm. new_str( s) ,
443
+ None => vm. get_none( ) ,
444
+ } ,
445
+ get_addr_tuple( ai. sockaddr) . into_pyobject( vm) . unwrap( ) ,
446
+ ] )
447
+ } )
448
+ } )
449
+ . collect :: < io:: Result < Vec < _ > > > ( )
450
+ . map_err ( |e| convert_sock_error ( vm, e) ) ?;
451
+ Ok ( vm. ctx . new_list ( list) )
452
+ }
453
+
454
+ fn socket_gethostbyaddr (
455
+ addr : PyStringRef ,
456
+ vm : & VirtualMachine ,
457
+ ) -> PyResult < ( String , PyObjectRef , PyObjectRef ) > {
458
+ // TODO: figure out how to do this properly
459
+ let ai = dns_lookup:: getaddrinfo ( Some ( addr. as_str ( ) ) , None , None )
460
+ . map_err ( |e| convert_sock_error ( vm, e. into ( ) ) ) ?
461
+ . next ( )
462
+ . unwrap ( )
463
+ . map_err ( |e| convert_sock_error ( vm, e) ) ?;
464
+ let ( hostname, _) =
465
+ dns_lookup:: getnameinfo ( & ai. sockaddr , 0 ) . map_err ( |e| convert_sock_error ( vm, e. into ( ) ) ) ?;
466
+ Ok ( (
467
+ hostname,
468
+ vm. ctx . new_list ( vec ! [ ] ) ,
469
+ vm. ctx
470
+ . new_list ( vec ! [ vm. new_str( ai. sockaddr. ip( ) . to_string( ) ) ] ) ,
471
+ ) )
472
+ }
473
+
371
474
fn get_addr < T , I > ( vm : & VirtualMachine , addr : T ) -> PyResult < socket2:: SockAddr >
372
475
where
373
476
T : ToSocketAddrs < Iter = I > ,
@@ -467,6 +570,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
467
570
"gethostname" => ctx. new_rustfunc( socket_gethostname) ,
468
571
"htonl" => ctx. new_rustfunc( socket_htonl) ,
469
572
"getdefaulttimeout" => ctx. new_rustfunc( |vm: & VirtualMachine | vm. get_none( ) ) ,
573
+ "getaddrinfo" => ctx. new_rustfunc( socket_getaddrinfo) ,
574
+ "gethostbyaddr" => ctx. new_rustfunc( socket_gethostbyaddr) ,
470
575
} ) ;
471
576
472
577
extend_module_platform_specific ( vm, & module) ;
0 commit comments