Skip to content

Commit 10997ae

Browse files
Merge pull request RustPython#1620 from RustPython/coolreader18/socket-fixes
Add some dns functions to socket
2 parents 711c222 + f24cdf1 commit 10997ae

File tree

3 files changed

+126
-8
lines changed

3 files changed

+126
-8
lines changed

Cargo.lock

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ gethostname = "0.2.0"
8686
subprocess = "0.1.18"
8787
num_cpus = "1"
8888
socket2 = { version = "0.3", features = ["unix"] }
89+
dns-lookup = "1.0"
8990

9091
[target."cfg(windows)".dependencies.winapi]
9192
version = "0.3"

vm/src/stdlib/socket.rs

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@ use super::os::convert_io_error;
1313
#[cfg(unix)]
1414
use super::os::convert_nix_error;
1515
use crate::function::{OptionalArg, PyFuncArgs};
16+
use crate::obj::objbytearray::PyByteArrayRef;
1617
use crate::obj::objbyteinner::PyBytesLike;
1718
use crate::obj::objbytes::PyBytesRef;
18-
use crate::obj::objstr::PyStringRef;
19+
use crate::obj::objstr::{PyString, PyStringRef};
1920
use crate::obj::objtuple::PyTupleRef;
2021
use crate::obj::objtype::PyClassRef;
21-
use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject};
22+
use crate::pyobject::{
23+
Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
24+
};
2225
use crate::vm::VirtualMachine;
2326

2427
#[cfg(unix)]
@@ -172,14 +175,22 @@ impl PySocket {
172175
}
173176

174177
#[pymethod]
175-
fn recv(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult {
178+
fn recv(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
176179
let mut buffer = vec![0u8; bufsize];
177180
match self.sock.borrow_mut().read_exact(&mut buffer) {
178-
Ok(()) => Ok(vm.ctx.new_bytes(buffer)),
181+
Ok(()) => Ok(buffer),
179182
Err(err) => Err(convert_sock_error(vm, err)),
180183
}
181184
}
182185

186+
#[pymethod]
187+
fn recv_into(&self, buf: PyByteArrayRef, vm: &VirtualMachine) -> PyResult<usize> {
188+
let mut buffer = buf.inner.borrow_mut();
189+
self.sock()
190+
.recv(&mut buffer.elements)
191+
.map_err(|err| convert_sock_error(vm, err))
192+
}
193+
183194
#[pymethod]
184195
fn recvfrom(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult<(Vec<u8>, AddrTuple)> {
185196
let mut buffer = vec![0u8; bufsize];
@@ -191,11 +202,20 @@ impl PySocket {
191202

192203
#[pymethod]
193204
fn send(&self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
205+
// TODO: use PyBytesLike.with_ref() instead of to_cow()
194206
self.sock()
195207
.send(bytes.to_cow().as_ref())
196208
.map_err(|err| convert_sock_error(vm, err))
197209
}
198210

211+
#[pymethod]
212+
fn sendall(&self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult<()> {
213+
self.sock
214+
.borrow_mut()
215+
.write_all(bytes.to_cow().as_ref())
216+
.map_err(|err| convert_sock_error(vm, err))
217+
}
218+
199219
#[pymethod]
200220
fn sendto(&self, bytes: PyBytesLike, address: Address, vm: &VirtualMachine) -> PyResult<()> {
201221
let addr = get_addr(vm, address)?;
@@ -315,10 +335,14 @@ impl TryFromObject for Address {
315335
if tuple.elements.len() != 2 {
316336
Err(vm.new_type_error("Address tuple should have only 2 values".to_string()))
317337
} else {
318-
Ok(Address {
319-
host: PyStringRef::try_from_object(vm, tuple.elements[0].clone())?,
320-
port: u16::try_from_object(vm, tuple.elements[1].clone())?,
321-
})
338+
let host = PyStringRef::try_from_object(vm, tuple.elements[0].clone())?;
339+
let host = if host.as_str().is_empty() {
340+
PyString::from("0.0.0.0").into_ref(vm)
341+
} else {
342+
host
343+
};
344+
let port = u16::try_from_object(vm, tuple.elements[1].clone())?;
345+
Ok(Address { host, port })
322346
}
323347
}
324348
}
@@ -368,6 +392,85 @@ fn socket_htonl(host: u32, vm: &VirtualMachine) -> PyResult {
368392
Ok(vm.new_int(host.to_be()))
369393
}
370394

395+
#[derive(FromArgs)]
396+
struct GAIOptions {
397+
#[pyarg(positional_only)]
398+
host: Option<PyStringRef>,
399+
#[pyarg(positional_only)]
400+
port: Option<Either<PyStringRef, i32>>,
401+
402+
#[pyarg(positional_only, default = "0")]
403+
family: i32,
404+
#[pyarg(positional_only, default = "0")]
405+
ty: i32,
406+
#[pyarg(positional_only, default = "0")]
407+
proto: i32,
408+
#[pyarg(positional_only, default = "0")]
409+
flags: i32,
410+
}
411+
412+
fn socket_getaddrinfo(opts: GAIOptions, vm: &VirtualMachine) -> PyResult {
413+
let hints = dns_lookup::AddrInfoHints {
414+
socktype: opts.ty,
415+
protocol: opts.proto,
416+
address: opts.family,
417+
flags: opts.flags,
418+
};
419+
420+
let host = opts.host.as_ref().map(|s| s.as_str());
421+
let port = opts.port.as_ref().map(|p| -> std::borrow::Cow<str> {
422+
match p {
423+
Either::A(ref s) => s.as_str().into(),
424+
Either::B(i) => i.to_string().into(),
425+
}
426+
});
427+
let port = port.as_ref().map(|p| p.as_ref());
428+
429+
let addrs = dns_lookup::getaddrinfo(host, port, Some(hints)).map_err(|err| {
430+
let error_type = vm.class("_socket", "gaierror");
431+
vm.new_exception(error_type, io::Error::from(err).to_string())
432+
})?;
433+
434+
let list = addrs
435+
.map(|ai| {
436+
ai.map(|ai| {
437+
vm.ctx.new_tuple(vec![
438+
vm.new_int(ai.address),
439+
vm.new_int(ai.socktype),
440+
vm.new_int(ai.protocol),
441+
match ai.canonname {
442+
Some(s) => vm.new_str(s),
443+
None => vm.get_none(),
444+
},
445+
get_addr_tuple(ai.sockaddr).into_pyobject(vm).unwrap(),
446+
])
447+
})
448+
})
449+
.collect::<io::Result<Vec<_>>>()
450+
.map_err(|e| convert_sock_error(vm, e))?;
451+
Ok(vm.ctx.new_list(list))
452+
}
453+
454+
fn socket_gethostbyaddr(
455+
addr: PyStringRef,
456+
vm: &VirtualMachine,
457+
) -> PyResult<(String, PyObjectRef, PyObjectRef)> {
458+
// TODO: figure out how to do this properly
459+
let ai = dns_lookup::getaddrinfo(Some(addr.as_str()), None, None)
460+
.map_err(|e| convert_sock_error(vm, e.into()))?
461+
.next()
462+
.unwrap()
463+
.map_err(|e| convert_sock_error(vm, e))?;
464+
let (hostname, _) =
465+
dns_lookup::getnameinfo(&ai.sockaddr, 0).map_err(|e| convert_sock_error(vm, e.into()))?;
466+
Ok((
467+
hostname,
468+
vm.ctx.new_list(vec![]),
469+
vm.ctx
470+
.new_list(vec![vm.new_str(ai.sockaddr.ip().to_string())]),
471+
))
472+
}
473+
371474
fn get_addr<T, I>(vm: &VirtualMachine, addr: T) -> PyResult<socket2::SockAddr>
372475
where
373476
T: ToSocketAddrs<Iter = I>,
@@ -467,6 +570,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
467570
"gethostname" => ctx.new_rustfunc(socket_gethostname),
468571
"htonl" => ctx.new_rustfunc(socket_htonl),
469572
"getdefaulttimeout" => ctx.new_rustfunc(|vm: &VirtualMachine| vm.get_none()),
573+
"getaddrinfo" => ctx.new_rustfunc(socket_getaddrinfo),
574+
"gethostbyaddr" => ctx.new_rustfunc(socket_gethostbyaddr),
470575
});
471576

472577
extend_module_platform_specific(vm, &module);

0 commit comments

Comments
 (0)