Skip to content

Commit 2999d70

Browse files
authored
Merge pull request RustPython#1309 from Lynskylate/extend-socket
Add settimeout and setblocking for socket module
2 parents 124f164 + 936af5b commit 2999d70

File tree

2 files changed

+156
-7
lines changed

2 files changed

+156
-7
lines changed

tests/snippets/stdlib_socket.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,13 @@
137137

138138
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
139139
pass
140+
141+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
142+
listener.bind(("127.0.0.1", 0))
143+
listener.listen(1)
144+
connector = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
145+
connector.connect(("127.0.0.1", listener.getsockname()[1]))
146+
(connection, addr) = listener.accept()
147+
connection.settimeout(1.0)
148+
with assertRaises(OSError):
149+
connection.recv(len(MESSAGE_A))

vm/src/stdlib/socket.rs

Lines changed: 146 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::io;
33
use std::io::Read;
44
use std::io::Write;
55
use std::net::{Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
6+
use std::time::Duration;
67

78
#[cfg(all(unix, not(target_os = "redox")))]
89
use nix::unistd::sethostname;
@@ -122,6 +123,27 @@ impl Connection {
122123
fn fileno(&self) -> i64 {
123124
unimplemented!();
124125
}
126+
127+
fn setblocking(&mut self, value: bool) -> io::Result<()> {
128+
match self {
129+
Connection::TcpListener(con) => con.set_nonblocking(!value),
130+
Connection::UdpSocket(con) => con.set_nonblocking(!value),
131+
Connection::TcpStream(con) => con.set_nonblocking(!value),
132+
}
133+
}
134+
135+
fn settimeout(&mut self, duration: Duration) -> io::Result<()> {
136+
match self {
137+
// net
138+
Connection::TcpListener(_con) => Ok(()),
139+
Connection::UdpSocket(con) => con
140+
.set_read_timeout(Some(duration))
141+
.and_then(|_| con.set_write_timeout(Some(duration))),
142+
Connection::TcpStream(con) => con
143+
.set_read_timeout(Some(duration))
144+
.and_then(|_| con.set_write_timeout(Some(duration))),
145+
}
146+
}
125147
}
126148

127149
impl Read for Connection {
@@ -152,6 +174,7 @@ pub struct Socket {
152174
address_family: AddressFamily,
153175
socket_kind: SocketKind,
154176
con: RefCell<Option<Connection>>,
177+
timeout: RefCell<Option<Duration>>,
155178
}
156179

157180
impl PyValue for Socket {
@@ -166,6 +189,7 @@ impl Socket {
166189
address_family,
167190
socket_kind,
168191
con: RefCell::new(None),
192+
timeout: RefCell::new(None),
169193
}
170194
}
171195
}
@@ -194,13 +218,41 @@ impl SocketRef {
194218
let address_string = address.get_address_string();
195219

196220
match self.socket_kind {
197-
SocketKind::Stream => match TcpStream::connect(address_string) {
198-
Ok(stream) => {
199-
self.con.borrow_mut().replace(Connection::TcpStream(stream));
200-
Ok(())
221+
SocketKind::Stream => {
222+
let con = if let Some(duration) = self.timeout.borrow().as_ref() {
223+
let sock_addr = match address_string.to_socket_addrs() {
224+
Ok(mut sock_addrs) => {
225+
if sock_addrs.len() == 0 {
226+
let error_type = vm.class("socket", "gaierror");
227+
return Err(vm.new_exception(
228+
error_type,
229+
"nodename nor servname provided, or not known".to_string(),
230+
));
231+
} else {
232+
sock_addrs.next().unwrap()
233+
}
234+
}
235+
Err(e) => {
236+
let error_type = vm.class("socket", "gaierror");
237+
return Err(vm.new_exception(error_type, e.to_string()));
238+
}
239+
};
240+
TcpStream::connect_timeout(&sock_addr, *duration)
241+
} else {
242+
TcpStream::connect(address_string)
243+
};
244+
match con {
245+
Ok(stream) => {
246+
self.con.borrow_mut().replace(Connection::TcpStream(stream));
247+
Ok(())
248+
}
249+
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
250+
let socket_timeout = vm.class("socket", "timeout");
251+
Err(vm.new_exception(socket_timeout, "Timed out".to_string()))
252+
}
253+
Err(s) => Err(vm.new_os_error(s.to_string())),
201254
}
202-
Err(s) => Err(vm.new_os_error(s.to_string())),
203-
},
255+
}
204256
SocketKind::Dgram => {
205257
if let Some(Connection::UdpSocket(con)) = self.con.borrow().as_ref() {
206258
match con.connect(address_string) {
@@ -254,6 +306,7 @@ impl SocketRef {
254306
address_family: self.address_family,
255307
socket_kind: self.socket_kind,
256308
con: RefCell::new(Some(Connection::TcpStream(tcp_stream))),
309+
timeout: RefCell::new(None),
257310
}
258311
.into_ref(vm);
259312

@@ -267,6 +320,10 @@ impl SocketRef {
267320
match self.con.borrow_mut().as_mut() {
268321
Some(v) => match v.read_exact(&mut buffer) {
269322
Ok(_) => (),
323+
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
324+
let socket_timeout = vm.class("socket", "timeout");
325+
return Err(vm.new_exception(socket_timeout, "Timed out".to_string()));
326+
}
270327
Err(s) => return Err(vm.new_os_error(s.to_string())),
271328
},
272329
None => return Err(vm.new_type_error("".to_string())),
@@ -295,9 +352,13 @@ impl SocketRef {
295352
match self.con.borrow_mut().as_mut() {
296353
Some(v) => match v.write(&bytes) {
297354
Ok(_) => (),
355+
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
356+
let socket_timeout = vm.class("socket", "timeout");
357+
return Err(vm.new_exception(socket_timeout, "Timed out".to_string()));
358+
}
298359
Err(s) => return Err(vm.new_os_error(s.to_string())),
299360
},
300-
None => return Err(vm.new_type_error("".to_string())),
361+
None => return Err(vm.new_type_error("Socket is not connected".to_string())),
301362
};
302363
Ok(())
303364
}
@@ -352,6 +413,75 @@ impl SocketRef {
352413
Err(s) => Err(vm.new_os_error(s.to_string())),
353414
}
354415
}
416+
417+
fn gettimeout(self, _vm: &VirtualMachine) -> PyResult<Option<f64>> {
418+
match self.timeout.borrow().as_ref() {
419+
Some(duration) => Ok(Some(duration.as_secs() as f64)),
420+
None => Ok(None),
421+
}
422+
}
423+
424+
fn setblocking(self, block: Option<bool>, vm: &VirtualMachine) -> PyResult<()> {
425+
match block {
426+
Some(value) => {
427+
if value {
428+
self.timeout.replace(None);
429+
} else {
430+
self.timeout.borrow_mut().replace(Duration::from_secs(0));
431+
}
432+
if let Some(conn) = self.con.borrow_mut().as_mut() {
433+
return match conn.setblocking(value) {
434+
Ok(_) => Ok(()),
435+
Err(err) => Err(vm.new_os_error(err.to_string())),
436+
};
437+
} else {
438+
Ok(())
439+
}
440+
}
441+
None => {
442+
// Avoid converting None to bool
443+
Err(vm.new_type_error("an bool is required".to_string()))
444+
}
445+
}
446+
}
447+
448+
fn getblocking(self, _vm: &VirtualMachine) -> PyResult<Option<bool>> {
449+
match self.timeout.borrow().as_ref() {
450+
Some(duration) => {
451+
if duration.as_secs() != 0 {
452+
Ok(Some(true))
453+
} else {
454+
Ok(Some(false))
455+
}
456+
}
457+
None => Ok(Some(true)),
458+
}
459+
}
460+
461+
fn settimeout(self, timeout: Option<f64>, vm: &VirtualMachine) -> PyResult<()> {
462+
match timeout {
463+
Some(timeout) => {
464+
self.timeout
465+
.borrow_mut()
466+
.replace(Duration::from_secs(timeout as u64));
467+
468+
let block = timeout > 0.0;
469+
470+
if let Some(conn) = self.con.borrow_mut().as_mut() {
471+
conn.setblocking(block)
472+
.and_then(|_| conn.settimeout(Duration::from_secs(timeout as u64)))
473+
.map_err(|err| vm.new_os_error(err.to_string()))
474+
.map(|_| ())
475+
} else {
476+
Ok(())
477+
}
478+
}
479+
None => {
480+
self.timeout.replace(None);
481+
Ok(())
482+
}
483+
}
484+
}
355485
}
356486

357487
struct Address {
@@ -432,6 +562,8 @@ fn socket_htonl(host: PyIntRef, vm: &VirtualMachine) -> PyResult {
432562

433563
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
434564
let ctx = &vm.ctx;
565+
let socket_timeout = ctx.new_class("socket.timeout", vm.ctx.exceptions.os_error.clone());
566+
let socket_gaierror = ctx.new_class("socket.gaierror", vm.ctx.exceptions.os_error.clone());
435567

436568
let socket = py_class!(ctx, "socket", ctx.object(), {
437569
"__new__" => ctx.new_rustfunc(SocketRef::new),
@@ -448,9 +580,16 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
448580
"sendto" => ctx.new_rustfunc(SocketRef::sendto),
449581
"recvfrom" => ctx.new_rustfunc(SocketRef::recvfrom),
450582
"fileno" => ctx.new_rustfunc(SocketRef::fileno),
583+
"getblocking" => ctx.new_rustfunc(SocketRef::getblocking),
584+
"setblocking" => ctx.new_rustfunc(SocketRef::setblocking),
585+
"gettimeout" => ctx.new_rustfunc(SocketRef::gettimeout),
586+
"settimeout" => ctx.new_rustfunc(SocketRef::settimeout),
451587
});
452588

453589
let module = py_module!(vm, "socket", {
590+
"error" => ctx.exceptions.os_error.clone(),
591+
"timeout" => socket_timeout,
592+
"gaierror" => socket_gaierror,
454593
"AF_INET" => ctx.new_int(AddressFamily::Inet as i32),
455594
"SOCK_STREAM" => ctx.new_int(SocketKind::Stream as i32),
456595
"SOCK_DGRAM" => ctx.new_int(SocketKind::Dgram as i32),

0 commit comments

Comments
 (0)