Skip to content

Commit 8847081

Browse files
Merge pull request RustPython#582 from palaviv/socket-any
Use PyObjectPayload::AnyRustValue for socket
2 parents 7bb6f8f + af2f790 commit 8847081

File tree

2 files changed

+91
-109
lines changed

2 files changed

+91
-109
lines changed

vm/src/pyobject.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ use crate::obj::objsuper;
2929
use crate::obj::objtuple;
3030
use crate::obj::objtype;
3131
use crate::obj::objzip;
32-
use crate::stdlib::socket::Socket;
3332
use crate::vm::VirtualMachine;
3433
use num_bigint::BigInt;
3534
use num_bigint::ToBigInt;
@@ -1236,9 +1235,6 @@ pub enum PyObjectPayload {
12361235
RustFunction {
12371236
function: Box<Fn(&mut VirtualMachine, PyFuncArgs) -> PyResult>,
12381237
},
1239-
Socket {
1240-
socket: RefCell<Socket>,
1241-
},
12421238
AnyRustValue {
12431239
value: Box<dyn std::any::Any>,
12441240
},
@@ -1278,7 +1274,6 @@ impl fmt::Debug for PyObjectPayload {
12781274
PyObjectPayload::Instance { .. } => write!(f, "instance"),
12791275
PyObjectPayload::RustFunction { .. } => write!(f, "rust function"),
12801276
PyObjectPayload::Frame { .. } => write!(f, "frame"),
1281-
PyObjectPayload::Socket { .. } => write!(f, "socket"),
12821277
PyObjectPayload::AnyRustValue { .. } => write!(f, "some rust value"),
12831278
}
12841279
}

vm/src/stdlib/socket.rs

Lines changed: 91 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::io;
33
use std::io::Read;
44
use std::io::Write;
55
use std::net::{SocketAddr, TcpListener, TcpStream};
6+
use std::ops::DerefMut;
67

78
use crate::obj::objbytes;
89
use crate::obj::objint;
@@ -108,6 +109,15 @@ impl Socket {
108109
}
109110
}
110111

112+
fn get_socket<'a>(obj: &'a PyObjectRef) -> impl DerefMut<Target = Socket> + 'a {
113+
if let PyObjectPayload::AnyRustValue { ref value } = obj.payload {
114+
if let Some(socket) = value.downcast_ref::<RefCell<Socket>>() {
115+
return socket.borrow_mut();
116+
}
117+
}
118+
panic!("Inner error getting socket {:?}", obj);
119+
}
120+
111121
fn socket_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
112122
arg_check!(
113123
vm,
@@ -125,7 +135,9 @@ fn socket_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
125135
let socket = RefCell::new(Socket::new(address_family, kind));
126136

127137
Ok(PyObject::new(
128-
PyObjectPayload::Socket { socket },
138+
PyObjectPayload::AnyRustValue {
139+
value: Box::new(socket),
140+
},
129141
cls.clone(),
130142
))
131143
}
@@ -139,17 +151,14 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
139151

140152
let address_string = get_address_string(vm, address)?;
141153

142-
match zelf.payload {
143-
PyObjectPayload::Socket { ref socket } => {
144-
if let Ok(stream) = TcpStream::connect(address_string) {
145-
socket.borrow_mut().con = Some(Connection::TcpStream(stream));
146-
Ok(vm.get_none())
147-
} else {
148-
// TODO: Socket error
149-
Err(vm.new_type_error("socket failed".to_string()))
150-
}
151-
}
152-
_ => Err(vm.new_type_error("".to_string())),
154+
let mut socket = get_socket(zelf);
155+
156+
if let Ok(stream) = TcpStream::connect(address_string) {
157+
socket.con = Some(Connection::TcpStream(stream));
158+
Ok(vm.get_none())
159+
} else {
160+
// TODO: Socket error
161+
Err(vm.new_type_error("socket failed".to_string()))
153162
}
154163
}
155164

@@ -162,17 +171,14 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
162171

163172
let address_string = get_address_string(vm, address)?;
164173

165-
match zelf.payload {
166-
PyObjectPayload::Socket { ref socket } => {
167-
if let Ok(stream) = TcpListener::bind(address_string) {
168-
socket.borrow_mut().con = Some(Connection::TcpListener(stream));
169-
Ok(vm.get_none())
170-
} else {
171-
// TODO: Socket error
172-
Err(vm.new_type_error("socket failed".to_string()))
173-
}
174-
}
175-
_ => Err(vm.new_type_error("".to_string())),
174+
let mut socket = get_socket(zelf);
175+
176+
if let Ok(stream) = TcpListener::bind(address_string) {
177+
socket.con = Some(Connection::TcpListener(stream));
178+
Ok(vm.get_none())
179+
} else {
180+
// TODO: Socket error
181+
Err(vm.new_type_error("socket failed".to_string()))
176182
}
177183
}
178184

@@ -212,35 +218,37 @@ fn socket_listen(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
212218
fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
213219
arg_check!(vm, args, required = [(zelf, None)]);
214220

215-
match zelf.payload {
216-
PyObjectPayload::Socket { ref socket } => {
217-
let ret = match socket.borrow_mut().con {
218-
Some(ref mut v) => v.accept(),
219-
None => return Err(vm.new_type_error("".to_string())),
220-
};
221+
let mut socket = get_socket(zelf);
221222

222-
let tcp_stream = match ret {
223-
Ok((socket, _addr)) => socket,
224-
_ => return Err(vm.new_type_error("".to_string())),
225-
};
223+
let ret = match socket.con {
224+
Some(ref mut v) => v.accept(),
225+
None => return Err(vm.new_type_error("".to_string())),
226+
};
226227

227-
let socket = RefCell::new(Socket {
228-
address_family: socket.borrow().address_family,
229-
socket_kind: socket.borrow().socket_kind,
230-
con: Some(Connection::TcpStream(tcp_stream)),
231-
});
228+
let tcp_stream = match ret {
229+
Ok((socket, _addr)) => socket,
230+
_ => return Err(vm.new_type_error("".to_string())),
231+
};
232232

233-
let sock_obj = PyObject::new(PyObjectPayload::Socket { socket }, zelf.typ());
233+
let socket = RefCell::new(Socket {
234+
address_family: socket.address_family,
235+
socket_kind: socket.socket_kind,
236+
con: Some(Connection::TcpStream(tcp_stream)),
237+
});
238+
239+
let sock_obj = PyObject::new(
240+
PyObjectPayload::AnyRustValue {
241+
value: Box::new(socket),
242+
},
243+
zelf.typ(),
244+
);
234245

235-
let elements = RefCell::new(vec![sock_obj, vm.get_none()]);
246+
let elements = RefCell::new(vec![sock_obj, vm.get_none()]);
236247

237-
Ok(PyObject::new(
238-
PyObjectPayload::Sequence { elements },
239-
vm.ctx.tuple_type(),
240-
))
241-
}
242-
_ => Err(vm.new_type_error("".to_string())),
243-
}
248+
Ok(PyObject::new(
249+
PyObjectPayload::Sequence { elements },
250+
vm.ctx.tuple_type(),
251+
))
244252
}
245253

246254
fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -249,17 +257,14 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
249257
args,
250258
required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))]
251259
);
252-
match zelf.payload {
253-
PyObjectPayload::Socket { ref socket } => {
254-
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
255-
match socket.borrow_mut().con {
256-
Some(ref mut v) => v.read_exact(&mut buffer).unwrap(),
257-
None => return Err(vm.new_type_error("".to_string())),
258-
};
259-
Ok(vm.ctx.new_bytes(buffer))
260-
}
261-
_ => Err(vm.new_type_error("".to_string())),
262-
}
260+
let mut socket = get_socket(zelf);
261+
262+
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
263+
match socket.con {
264+
Some(ref mut v) => v.read_exact(&mut buffer).unwrap(),
265+
None => return Err(vm.new_type_error("".to_string())),
266+
};
267+
Ok(vm.ctx.new_bytes(buffer))
263268
}
264269

265270
fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -268,60 +273,42 @@ fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
268273
args,
269274
required = [(zelf, None), (bytes, Some(vm.ctx.bytes_type()))]
270275
);
271-
match zelf.payload {
272-
PyObjectPayload::Socket { ref socket } => {
273-
match socket.borrow_mut().con {
274-
Some(ref mut v) => v.write(&objbytes::get_value(&bytes)).unwrap(),
275-
None => return Err(vm.new_type_error("".to_string())),
276-
};
277-
Ok(vm.get_none())
278-
}
279-
_ => Err(vm.new_type_error("".to_string())),
280-
}
276+
let mut socket = get_socket(zelf);
277+
278+
match socket.con {
279+
Some(ref mut v) => v.write(&objbytes::get_value(&bytes)).unwrap(),
280+
None => return Err(vm.new_type_error("".to_string())),
281+
};
282+
Ok(vm.get_none())
281283
}
282284

283285
fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
284286
arg_check!(vm, args, required = [(zelf, None)]);
285-
match zelf.payload {
286-
PyObjectPayload::Socket { ref socket } => {
287-
let mut socket = socket.borrow_mut();
288-
match socket.address_family {
289-
AddressFamily::Inet => match socket.socket_kind {
290-
SocketKind::Stream => {
291-
socket.con = None;
292-
Ok(vm.get_none())
293-
}
294-
_ => Err(vm.new_type_error("".to_string())),
295-
},
296-
_ => Err(vm.new_type_error("".to_string())),
297-
}
298-
}
299-
_ => Err(vm.new_type_error("".to_string())),
300-
}
287+
288+
let mut socket = get_socket(zelf);
289+
socket.con = None;
290+
Ok(vm.get_none())
301291
}
302292

303293
fn socket_getsockname(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
304294
arg_check!(vm, args, required = [(zelf, None)]);
305-
match zelf.payload {
306-
PyObjectPayload::Socket { ref socket } => {
307-
let addr = match socket.borrow_mut().con {
308-
Some(ref mut v) => v.local_addr(),
309-
None => return Err(vm.new_type_error("".to_string())),
310-
};
311-
312-
match addr {
313-
Ok(addr) => {
314-
let port = vm.ctx.new_int(addr.port());
315-
let ip = vm.ctx.new_str(addr.ip().to_string());
316-
let elements = RefCell::new(vec![ip, port]);
317-
318-
Ok(PyObject::new(
319-
PyObjectPayload::Sequence { elements },
320-
vm.ctx.tuple_type(),
321-
))
322-
}
323-
_ => Err(vm.new_type_error("".to_string())),
324-
}
295+
let mut socket = get_socket(zelf);
296+
297+
let addr = match socket.con {
298+
Some(ref mut v) => v.local_addr(),
299+
None => return Err(vm.new_type_error("".to_string())),
300+
};
301+
302+
match addr {
303+
Ok(addr) => {
304+
let port = vm.ctx.new_int(addr.port());
305+
let ip = vm.ctx.new_str(addr.ip().to_string());
306+
let elements = RefCell::new(vec![ip, port]);
307+
308+
Ok(PyObject::new(
309+
PyObjectPayload::Sequence { elements },
310+
vm.ctx.tuple_type(),
311+
))
325312
}
326313
_ => Err(vm.new_type_error("".to_string())),
327314
}

0 commit comments

Comments
 (0)