Skip to content

Commit 6f2bbd2

Browse files
Merge pull request RustPython#545 from palaviv/socket1
Socket improvments
2 parents cca19cf + 8087ccf commit 6f2bbd2

File tree

2 files changed

+99
-26
lines changed

2 files changed

+99
-26
lines changed

tests/snippets/stdlib_socket.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,37 @@
11
import socket
2+
from testutils import assertRaises
3+
24

35
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
4-
listener.bind(("127.0.0.1", 8080))
6+
listener.bind(("127.0.0.1", 0))
57
listener.listen(1)
68

79
connector = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
8-
connector.connect(("127.0.0.1", 8080))
10+
connector.connect(("127.0.0.1", listener.getsockname()[1]))
911
connection = listener.accept()[0]
1012

1113
message_a = b'aaaa'
1214
message_b = b'bbbbb'
1315

1416
connector.send(message_a)
15-
connector.close()
16-
recv_a = connection.recv(10)
17+
connection.send(message_b)
18+
recv_a = connection.recv(len(message_a))
19+
recv_b = connector.recv(len(message_b))
20+
assert recv_a == message_a
21+
assert recv_b == message_b
1722

1823
connection.close()
24+
connector.close()
1925
listener.close()
2026

27+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
28+
with assertRaises(TypeError):
29+
s.connect(("127.0.0.1", 8888, 8888))
30+
31+
with assertRaises(TypeError):
32+
s.bind(("127.0.0.1", 8888, 8888))
33+
34+
with assertRaises(TypeError):
35+
s.bind((888, 8888))
36+
37+
s.close()

vm/src/stdlib/socket.rs

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::cell::RefCell;
22
use std::io;
33
use std::io::Read;
44
use std::io::Write;
5-
use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
5+
use std::net::{SocketAddr, TcpListener, TcpStream};
66

77
use crate::obj::objbytes;
88
use crate::obj::objint;
@@ -52,7 +52,7 @@ impl SocketKind {
5252
enum Connection {
5353
TcpListener(TcpListener),
5454
TcpStream(TcpStream),
55-
UdpSocket(UdpSocket),
55+
// UdpSocket(UdpSocket),
5656
}
5757

5858
impl Connection {
@@ -62,6 +62,13 @@ impl Connection {
6262
_ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")),
6363
}
6464
}
65+
66+
fn local_addr(&self) -> io::Result<SocketAddr> {
67+
match self {
68+
Connection::TcpListener(con) => con.local_addr(),
69+
_ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")),
70+
}
71+
}
6572
}
6673

6774
impl Read for Connection {
@@ -87,15 +94,15 @@ impl Write for Connection {
8794

8895
pub struct Socket {
8996
address_family: AddressFamily,
90-
sk: SocketKind,
97+
socket_kind: SocketKind,
9198
con: Option<Connection>,
9299
}
93100

94101
impl Socket {
95-
fn new(address_family: AddressFamily, sk: SocketKind) -> Socket {
102+
fn new(address_family: AddressFamily, socket_kind: SocketKind) -> Socket {
96103
Socket {
97104
address_family,
98-
sk: sk,
105+
socket_kind: socket_kind,
99106
con: None,
100107
}
101108
}
@@ -130,11 +137,7 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
130137
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
131138
);
132139

133-
let elements = get_elements(address);
134-
let host = objstr::get_value(&elements[0]);
135-
let port = objint::get_value(&elements[1]);
136-
137-
let address_string = format!("{}:{}", host, port.to_string());
140+
let address_string = get_address_string(vm, address)?;
138141

139142
match zelf.payload {
140143
PyObjectPayload::Socket { ref socket } => {
@@ -157,11 +160,7 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
157160
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
158161
);
159162

160-
let elements = get_elements(address);
161-
let host = objstr::get_value(&elements[0]);
162-
let port = objint::get_value(&elements[1]);
163-
164-
let address_string = format!("{}:{}", host, port.to_string());
163+
let address_string = get_address_string(vm, address)?;
165164

166165
match zelf.payload {
167166
PyObjectPayload::Socket { ref socket } => {
@@ -177,7 +176,36 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
177176
}
178177
}
179178

179+
fn get_address_string(
180+
vm: &mut VirtualMachine,
181+
address: &PyObjectRef,
182+
) -> Result<String, PyObjectRef> {
183+
let args = PyFuncArgs {
184+
args: get_elements(address).to_vec(),
185+
kwargs: vec![],
186+
};
187+
arg_check!(
188+
vm,
189+
args,
190+
required = [
191+
(host, Some(vm.ctx.str_type())),
192+
(port, Some(vm.ctx.int_type()))
193+
]
194+
);
195+
196+
Ok(format!(
197+
"{}:{}",
198+
objstr::get_value(host),
199+
objint::get_value(port).to_string()
200+
))
201+
}
202+
180203
fn socket_listen(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
204+
arg_check!(
205+
vm,
206+
args,
207+
required = [(_zelf, None), (_num, Some(vm.ctx.int_type()))]
208+
);
181209
Ok(vm.get_none())
182210
}
183211

@@ -197,8 +225,8 @@ fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
197225
};
198226

199227
let socket = RefCell::new(Socket {
200-
address_family: socket.borrow().address_family.clone(),
201-
sk: socket.borrow().sk.clone(),
228+
address_family: socket.borrow().address_family,
229+
socket_kind: socket.borrow().socket_kind,
202230
con: Some(Connection::TcpStream(tcp_stream)),
203231
});
204232

@@ -223,10 +251,10 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
223251
);
224252
match zelf.payload {
225253
PyObjectPayload::Socket { ref socket } => {
226-
let mut buffer = Vec::new();
227-
let _temp = match socket.borrow_mut().con {
228-
Some(ref mut v) => v.read_to_end(&mut buffer).unwrap(),
229-
None => 0,
254+
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
255+
match socket.borrow_mut().con {
256+
Some(ref mut v) => v.read_exact(&mut buffer).unwrap(),
257+
None => return Err(vm.new_type_error("".to_string())),
230258
};
231259
Ok(vm.ctx.new_bytes(buffer))
232260
}
@@ -258,7 +286,7 @@ fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
258286
PyObjectPayload::Socket { ref socket } => {
259287
let mut socket = socket.borrow_mut();
260288
match socket.address_family {
261-
AddressFamily::AfInet => match socket.sk {
289+
AddressFamily::AfInet => match socket.socket_kind {
262290
SocketKind::SockStream => {
263291
socket.con = None;
264292
Ok(vm.get_none())
@@ -272,6 +300,33 @@ fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
272300
}
273301
}
274302

303+
fn socket_getsockname(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
304+
arg_check!(vm, args, required = [(zelf, None)]);
305+
match zelf.payload {
306+
PyObjectPayload::Socket { ref socket } => {
307+
let addr = match socket.borrow_mut().con {
308+
Some(ref mut v) => v.local_addr(),
309+
None => return Err(vm.new_type_error("".to_string())),
310+
};
311+
312+
match addr {
313+
Ok(addr) => {
314+
let port = vm.ctx.new_int(addr.port());
315+
let ip = vm.ctx.new_str(addr.ip().to_string());
316+
let elements = RefCell::new(vec![ip, port]);
317+
318+
Ok(PyObject::new(
319+
PyObjectPayload::Sequence { elements },
320+
vm.ctx.tuple_type(),
321+
))
322+
}
323+
_ => return Err(vm.new_type_error("".to_string())),
324+
}
325+
}
326+
_ => Err(vm.new_type_error("".to_string())),
327+
}
328+
}
329+
275330
pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
276331
let py_mod = ctx.new_module(&"socket".to_string(), ctx.new_scope(None));
277332

@@ -297,6 +352,7 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
297352
ctx.set_attr(&socket, "accept", ctx.new_rustfunc(socket_accept));
298353
ctx.set_attr(&socket, "listen", ctx.new_rustfunc(socket_listen));
299354
ctx.set_attr(&socket, "close", ctx.new_rustfunc(socket_close));
355+
ctx.set_attr(&socket, "getsockname", ctx.new_rustfunc(socket_getsockname));
300356
socket
301357
};
302358
ctx.set_attr(&py_mod, "socket", socket.clone());

0 commit comments

Comments
 (0)